Python torch_geometric.nn.global_add_pool() Examples
The following are 10 code examples of torch_geometric.nn.global_add_pool().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch_geometric.nn, or try the search function.
Example #1
Source File: models.py From IGMC with MIT License | 6 votes |
def forward(self, data):
    """Score each graph in ``data`` from its pooled multi-layer node states.

    Runs every conv in ``self.convs`` with a tanh activation, concatenates the
    per-layer node states along the feature axis, sums them per graph with
    ``global_add_pool`` and applies a two-layer head with dropout.

    Args:
        data: Batched graph object providing ``x``, ``edge_index`` and ``batch``.

    Returns:
        ``x[:, 0]`` of the head output when ``self.regression`` is set,
        otherwise log-probabilities from ``F.log_softmax`` over the last axis.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch
    if self.adj_dropout > 0:
        # The previous version passed an undefined ``edge_type`` here, which
        # raised NameError whenever adj_dropout > 0. The convs below only
        # consume ``edge_index``, so drop edges without carrying attributes.
        edge_index, _ = dropout_adj(
            edge_index, p=self.adj_dropout,
            force_undirected=self.force_undirected, num_nodes=len(x),
            training=self.training
        )
    concat_states = []
    for conv in self.convs:
        x = torch.tanh(conv(x, edge_index))
        concat_states.append(x)
    # Concatenate every layer's node states along the feature axis.
    concat_states = torch.cat(concat_states, 1)
    x = global_add_pool(concat_states, batch)
    x = F.relu(self.lin1(x))
    x = F.dropout(x, p=0.5, training=self.training)
    x = self.lin2(x)
    if self.regression:
        return x[:, 0]
    return F.log_softmax(x, dim=-1)
Example #2
Source File: mutag_gin.py From pytorch_geometric with MIT License | 6 votes |
def forward(self, x, edge_index, batch):
    """Classify each graph with five conv/batch-norm stages and sum pooling.

    Each stage computes ``relu(conv(h, edge_index))`` and then applies that
    stage's batch norm. Node features are summed per graph and passed through
    ``fc1`` (ReLU, dropout p=0.5) and ``fc2``.

    Returns:
        Log-probabilities from ``F.log_softmax`` over the last axis.
    """
    stages = (
        (self.conv1, self.bn1),
        (self.conv2, self.bn2),
        (self.conv3, self.bn3),
        (self.conv4, self.bn4),
        (self.conv5, self.bn5),
    )
    h = x
    for conv, norm in stages:
        h = norm(F.relu(conv(h, edge_index)))

    graph_emb = global_add_pool(h, batch)
    hidden = F.dropout(F.relu(self.fc1(graph_emb)), p=0.5, training=self.training)
    logits = self.fc2(hidden)
    return F.log_softmax(logits, dim=-1)
Example #3
Source File: MolecularFingerprint.py From gnn-comparison with GNU General Public License v3.0 | 5 votes |
def forward(self, data):
    """Sum the node features of each graph and feed the result to ``self.mlp``."""
    pooled = global_add_pool(data.x, data.batch)
    return self.mlp(pooled)
Example #4
Source File: DeepMultisets.py From gnn-comparison with GNU General Public License v3.0 | 5 votes |
def forward(self, data):
    """Embed nodes, sum the embeddings per graph, then apply a two-layer head.

    Returns:
        The output of ``self.fc_global2``, with no final activation applied.
    """
    node_emb = F.relu(self.fc_vertex(data.x))
    # One row per graph: the sum of the embeddings of that graph's nodes.
    graph_emb = global_add_pool(node_emb, data.batch)
    return self.fc_global2(F.relu(self.fc_global1(graph_emb)))
Example #5
Source File: colors_topk_pool.py From pytorch_geometric with MIT License | 5 votes |
def forward(self, data):
    """Run conv -> ``pool1`` -> conv -> sum pooling, plus an attention loss.

    Returns:
        A tuple ``(out, attn_loss, ratio)``:
        - ``out``: the flattened output of ``self.lin`` on the pooled graphs;
        - ``attn_loss``: the element-wise ``F.kl_div`` term between
          ``log(score + 1e-14)`` and ``data.attn`` at the kept nodes,
          averaged per graph with ``scatter_mean``;
        - ``ratio``: the node count after ``pool1`` divided by the node
          count before it.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch

    h = F.relu(self.conv1(x, edge_index))
    # The raw input features are handed to the pooling layer as ``attn``.
    h, edge_index, _, batch, perm, score = self.pool1(
        h, edge_index, None, batch, attn=x)
    kept_fraction = h.size(0) / x.size(0)

    h = F.relu(self.conv2(h, edge_index))
    prediction = self.lin(global_add_pool(h, batch)).view(-1)

    # The 1e-14 offset keeps log() finite when a score is exactly zero.
    log_score = torch.log(score + 1e-14)
    per_node_loss = F.kl_div(log_score, data.attn[perm], reduction='none')
    return prediction, scatter_mean(per_node_loss, batch), kept_fraction
Example #6
Source File: gin.py From pytorch_geometric with MIT License | 5 votes |
def forward(self, x, edge_index, batch):
    """Classify graphs: conv/BN/ReLU stack, sum pooling, then an MLP head.

    Returns:
        Log-probabilities from ``F.log_softmax`` over the last axis.
    """
    h = x
    for layer, norm in zip(self.convs, self.batch_norms):
        h = F.relu(norm(layer(h, edge_index)))

    pooled = global_add_pool(h, batch)
    hidden = F.relu(self.batch_norm1(self.lin1(pooled)))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    return F.log_softmax(self.lin2(hidden), dim=-1)
Example #7
Source File: pna.py From pytorch_geometric with MIT License | 5 votes |
def forward(self, x, edge_index, edge_attr, batch):
    """Embed nodes/edges, run conv+BN+ReLU layers, sum-pool, apply ``self.mlp``.

    Args:
        x: Node inputs for ``self.node_emb``; assumes shape ``[num_nodes, 1]``
            (a trailing singleton feature axis) — TODO confirm.
        edge_index: Graph connectivity passed to every conv.
        edge_attr: Edge inputs for ``self.edge_emb``.
        batch: Per-node graph assignment used by ``global_add_pool``.

    Returns:
        ``self.mlp`` applied to the per-graph sums of the final node states.
    """
    # Squeeze only the trailing feature axis. A bare squeeze() also drops the
    # node axis when the batch has a single node, leaving a 0-d index tensor.
    x = self.node_emb(x.squeeze(-1))
    edge_attr = self.edge_emb(edge_attr)
    for conv, batch_norm in zip(self.convs, self.batch_norms):
        x = F.relu(batch_norm(conv(x, edge_index, edge_attr)))
    x = global_add_pool(x, batch)
    return self.mlp(x)
Example #8
Source File: conv.py From ogb with MIT License | 4 votes |
def forward(self, batched_data):
    """Compute node representations with virtual-node message passing.

    Each layer first adds the virtual-node embedding of a node's graph to that
    node's state. It then applies conv -> batch norm -> (ReLU on every layer
    except the last) -> dropout, with an optional residual connection. Between
    layers, the virtual-node embeddings are updated from the sum-pooled input
    states of the current layer via ``self.mlp_virtualnode_list``.

    Args:
        batched_data: Batched graph object providing ``x``, ``edge_index``,
            ``edge_attr`` and ``batch``.

    Returns:
        Node representations combined according to ``self.JK``. ``"last"``
        returns the final layer's states. ``"sum"`` sums the input encoding
        and the states produced by every layer.

    Raises:
        ValueError: If ``self.JK`` is neither ``"last"`` nor ``"sum"``.
    """
    x, edge_index, edge_attr, batch = (
        batched_data.x, batched_data.edge_index,
        batched_data.edge_attr, batched_data.batch,
    )

    ### virtual node embeddings for graphs
    # Looks up index 0 once per graph. Assumes ``batch`` is sorted, so that
    # batch[-1] is the largest graph id — TODO confirm for all loaders.
    num_graphs = batch[-1].item() + 1
    virtualnode_embedding = self.virtualnode_embedding(
        torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device))

    h_list = [self.atom_encoder(x)]
    for layer in range(self.num_layer):
        ### add message from virtual nodes to graph nodes
        h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

        ### Message passing among graph nodes
        h = self.convs[layer](h_list[layer], edge_index, edge_attr)
        h = self.batch_norms[layer](h)
        if layer == self.num_layer - 1:
            # remove relu for the last layer
            h = F.dropout(h, self.drop_ratio, training=self.training)
        else:
            h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

        if self.residual:
            h = h + h_list[layer]

        h_list.append(h)

        ### update the virtual nodes
        if layer < self.num_layer - 1:
            ### add message from graph nodes to virtual nodes
            virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
            ### transform virtual nodes using MLP
            update = F.dropout(
                self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                self.drop_ratio, training=self.training)
            if self.residual:
                virtualnode_embedding = virtualnode_embedding + update
            else:
                virtualnode_embedding = update

    ### Different implementations of Jk-concat
    if self.JK == "last":
        return h_list[-1]
    if self.JK == "sum":
        # h_list holds num_layer + 1 entries: the input encoding plus one per
        # layer. The previous range(num_layer) loop silently dropped the final
        # layer's output from the sum.
        return sum(h_list)
    raise ValueError(f"Unknown JK mode: {self.JK!r}")
Example #9
Source File: conv.py From ogb with MIT License | 4 votes |
def forward(self, batched_data):
    """Compute node representations with virtual-node message passing.

    Nodes are encoded from ``x`` and ``node_depth``. Each layer first adds the
    virtual-node embedding of a node's graph to that node's state. It then
    applies conv -> batch norm -> (ReLU on every layer except the last) ->
    dropout, with an optional residual connection. Between layers, the
    virtual-node embeddings are updated from the sum-pooled input states of
    the current layer via ``self.mlp_virtualnode_list``.

    Args:
        batched_data: Batched graph object providing ``x``, ``edge_index``,
            ``edge_attr``, ``node_depth`` and ``batch``.

    Returns:
        Node representations combined according to ``self.JK``. ``"last"``
        returns the final layer's states. ``"sum"`` sums the input encoding
        and the states produced by every layer.

    Raises:
        ValueError: If ``self.JK`` is neither ``"last"`` nor ``"sum"``.
    """
    x, edge_index, edge_attr, node_depth, batch = (
        batched_data.x, batched_data.edge_index, batched_data.edge_attr,
        batched_data.node_depth, batched_data.batch,
    )

    ### virtual node embeddings for graphs
    # Looks up index 0 once per graph. Assumes ``batch`` is sorted, so that
    # batch[-1] is the largest graph id — TODO confirm for all loaders.
    num_graphs = batch[-1].item() + 1
    virtualnode_embedding = self.virtualnode_embedding(
        torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device))

    h_list = [self.node_encoder(x, node_depth.view(-1))]
    for layer in range(self.num_layer):
        ### add message from virtual nodes to graph nodes
        h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

        ### Message passing among graph nodes
        h = self.convs[layer](h_list[layer], edge_index, edge_attr)
        h = self.batch_norms[layer](h)
        if layer == self.num_layer - 1:
            # remove relu for the last layer
            h = F.dropout(h, self.drop_ratio, training=self.training)
        else:
            h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

        if self.residual:
            h = h + h_list[layer]

        h_list.append(h)

        ### update the virtual nodes
        if layer < self.num_layer - 1:
            ### add message from graph nodes to virtual nodes
            virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
            ### transform virtual nodes using MLP
            update = F.dropout(
                self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                self.drop_ratio, training=self.training)
            if self.residual:
                virtualnode_embedding = virtualnode_embedding + update
            else:
                virtualnode_embedding = update

    ### Different implementations of Jk-concat
    if self.JK == "last":
        return h_list[-1]
    if self.JK == "sum":
        # h_list holds num_layer + 1 entries: the input encoding plus one per
        # layer. The previous range(num_layer) loop silently dropped the final
        # layer's output from the sum.
        return sum(h_list)
    raise ValueError(f"Unknown JK mode: {self.JK!r}")
Example #10
Source File: conv.py From ogb with MIT License | 4 votes |
def forward(self, batched_data):
    """Compute node representations with virtual-node message passing.

    Each layer first adds the virtual-node embedding of a node's graph to that
    node's state. It then applies conv -> batch norm -> (ReLU on every layer
    except the last) -> dropout, with an optional residual connection. Between
    layers, the virtual-node embeddings are updated from the sum-pooled input
    states of the current layer via ``self.mlp_virtualnode_list``.

    Args:
        batched_data: Batched graph object providing ``x``, ``edge_index``,
            ``edge_attr`` and ``batch``.

    Returns:
        Node representations combined according to ``self.JK``. ``"last"``
        returns the final layer's states. ``"sum"`` sums the input encoding
        and the states produced by every layer.

    Raises:
        ValueError: If ``self.JK`` is neither ``"last"`` nor ``"sum"``.
    """
    x, edge_index, edge_attr, batch = (
        batched_data.x, batched_data.edge_index,
        batched_data.edge_attr, batched_data.batch,
    )

    ### virtual node embeddings for graphs
    # Looks up index 0 once per graph. Assumes ``batch`` is sorted, so that
    # batch[-1] is the largest graph id — TODO confirm for all loaders.
    num_graphs = batch[-1].item() + 1
    virtualnode_embedding = self.virtualnode_embedding(
        torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device))

    h_list = [self.node_encoder(x)]
    for layer in range(self.num_layer):
        ### add message from virtual nodes to graph nodes
        h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

        ### Message passing among graph nodes
        h = self.convs[layer](h_list[layer], edge_index, edge_attr)
        h = self.batch_norms[layer](h)
        if layer == self.num_layer - 1:
            # remove relu for the last layer
            h = F.dropout(h, self.drop_ratio, training=self.training)
        else:
            h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

        if self.residual:
            h = h + h_list[layer]

        h_list.append(h)

        ### update the virtual nodes
        if layer < self.num_layer - 1:
            ### add message from graph nodes to virtual nodes
            virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
            ### transform virtual nodes using MLP
            update = F.dropout(
                self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                self.drop_ratio, training=self.training)
            if self.residual:
                virtualnode_embedding = virtualnode_embedding + update
            else:
                virtualnode_embedding = update

    ### Different implementations of Jk-concat
    if self.JK == "last":
        return h_list[-1]
    if self.JK == "sum":
        # h_list holds num_layer + 1 entries: the input encoding plus one per
        # layer. The previous range(num_layer) loop silently dropped the final
        # layer's output from the sum.
        return sum(h_list)
    raise ValueError(f"Unknown JK mode: {self.JK!r}")