Python torch_geometric.nn.Set2Set() Examples
The following are 13 code examples of `torch_geometric.nn.Set2Set()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `torch_geometric.nn`, or try the search function.
Example #1
Source File: mpnn.py From GLN with MIT License | 6 votes |
def __init__(self, latent_dim, output_dim, num_node_feats, num_edge_feats, max_lv=3, act_func='elu', msg_aggregate_type='mean', dropout=None):
    """Build an NNConv + GRU message-passing encoder with a Set2Set readout.

    Args:
        latent_dim: hidden width of the node states.
        output_dim: width of the final embedding; a non-positive value means
            "reuse latent_dim".
        num_node_feats: number of raw node input features.
        num_edge_feats: number of raw edge input features (input width of
            the edge MLP).
        max_lv: stored as ``self.max_lv``; presumably the number of
            message-passing rounds run in forward() (TODO confirm).
        act_func: key into ``NONLINEARITIES``; also passed to the edge MLP
            as its nonlinearity.
        msg_aggregate_type: NNConv neighbour aggregation; ``'sum'`` is
            accepted as an alias for ``'add'``.
        dropout: forwarded unchanged to the base-class constructor.
    """
    embed_dim = output_dim if output_dim > 0 else latent_dim
    super(MPNN, self).__init__(embed_dim, dropout)

    # NNConv spells summation "add"; translate the user-facing "sum" alias.
    aggr = 'add' if msg_aggregate_type == 'sum' else msg_aggregate_type

    self.max_lv = max_lv
    # Set2Set's output is 2 * latent_dim wide; project it to the embedding
    # width (self.embed_dim is presumably set by the base class — confirm).
    self.readout = nn.Linear(2 * latent_dim, self.embed_dim)
    self.lin0 = torch.nn.Linear(num_node_feats, latent_dim)

    # The edge MLP's final width is latent_dim * latent_dim, matching the
    # per-edge weight matrix NNConv(latent_dim, latent_dim) consumes.
    edge_net = MLP(input_dim=num_edge_feats,
                   hidden_dims=[128, latent_dim * latent_dim],
                   nonlinearity=act_func)
    self.conv = NNConv(latent_dim, latent_dim, edge_net,
                       aggr=aggr, root_weight=False)
    self.act_func = NONLINEARITIES[act_func]
    self.gru = nn.GRU(latent_dim, latent_dim)
    self.set2set = Set2Set(latent_dim, processing_steps=3)
Example #2
Source File: gin.py From Alchemy with MIT License | 6 votes |
def __init__(self, node_input_dim=15, output_dim=12, node_hidden_dim=64, num_step_prop=6, num_step_set2set=6):
    """GIN graph regressor: input projection, stacked GINConv layers, Set2Set readout, two-layer head.

    Args:
        node_input_dim: number of raw node features.
        output_dim: number of regression targets.
        node_hidden_dim: hidden width used by every layer.
        num_step_prop: number of GINConv layers to build.
        num_step_set2set: number of Set2Set processing steps.
    """
    super(GIN, self).__init__()
    self.num_step_prop = num_step_prop
    hid = node_hidden_dim
    self.lin0 = nn.Linear(node_input_dim, hid)

    self.mlps = torch.nn.ModuleList()
    self.convs = torch.nn.ModuleList()
    for _ in range(num_step_prop):
        mlp = nn.Sequential(
            nn.Linear(hid, hid), nn.BatchNorm1d(hid), nn.ReLU(),
            nn.Linear(hid, hid), nn.BatchNorm1d(hid), nn.ReLU(),
        )
        self.mlps.append(mlp)
        # eps is fixed at 0 and not learned.
        self.convs.append(GINConv(mlp, eps=0, train_eps=False))

    self.set2set = Set2Set(hid, processing_steps=num_step_set2set)
    # Set2Set doubles the feature width.
    self.lin1 = nn.Linear(2 * hid, hid)
    self.lin2 = nn.Linear(hid, output_dim)
Example #3
Source File: mpnn.py From Alchemy with MIT License | 6 votes |
def __init__(self, node_input_dim=15, edge_input_dim=5, output_dim=1, node_hidden_dim=64, edge_hidden_dim=128, num_step_message_passing=6, num_step_set2set=6):
    """MPNN regressor: NNConv + GRU message passing, Set2Set readout, two-layer head.

    Args:
        node_input_dim: number of raw node features.
        edge_input_dim: number of raw edge features.
        output_dim: number of regression targets.
        node_hidden_dim: hidden width of node states.
        edge_hidden_dim: hidden width of the edge network.
        num_step_message_passing: stored for forward(); presumably the number
            of conv/GRU rounds (TODO confirm).
        num_step_set2set: number of Set2Set processing steps.
    """
    super(MPNN, self).__init__()
    self.num_step_message_passing = num_step_message_passing
    hid = node_hidden_dim
    self.lin0 = nn.Linear(node_input_dim, hid)

    # Edge network emits hid * hid values per edge, matching NNConv(hid, hid).
    edge_mlp = nn.Sequential(
        nn.Linear(edge_input_dim, edge_hidden_dim),
        nn.ReLU(),
        nn.Linear(edge_hidden_dim, hid * hid),
    )
    self.conv = NNConv(hid, hid, edge_mlp, aggr='mean', root_weight=False)
    self.gru = nn.GRU(hid, hid)

    self.set2set = Set2Set(hid, processing_steps=num_step_set2set)
    # Set2Set doubles the feature width.
    self.lin1 = nn.Linear(2 * hid, hid)
    self.lin2 = nn.Linear(hid, output_dim)
Example #4
Source File: set2set.py From pytorch_geometric with MIT License | 5 votes |
def __init__(self, dataset, num_layers, hidden):
    """Graph classifier: stacked SAGEConv layers, Set2Set readout, two-layer head.

    Args:
        dataset: object exposing ``num_features`` and ``num_classes``.
        num_layers: total number of SAGEConv layers, counting ``conv1``.
        hidden: hidden width of every layer.
    """
    super(Set2SetNet, self).__init__()
    self.conv1 = SAGEConv(dataset.num_features, hidden)
    # The remaining num_layers - 1 layers map hidden -> hidden.
    self.convs = torch.nn.ModuleList(
        [SAGEConv(hidden, hidden) for _ in range(num_layers - 1)]
    )
    self.set2set = Set2Set(hidden, processing_steps=4)
    # Set2Set doubles the feature width.
    self.lin1 = Linear(2 * hidden, hidden)
    self.lin2 = Linear(hidden, dataset.num_classes)
Example #5
Source File: qm9_nn_conv.py From pytorch_geometric with MIT License | 5 votes |
def __init__(self):
    """QM9 regressor: NNConv + GRU message passing, Set2Set readout, scalar output.

    Reads the module-level globals ``dataset`` and ``dim``, which are defined
    outside this constructor.
    """
    super(Net, self).__init__()
    self.lin0 = torch.nn.Linear(dataset.num_features, dim)

    # Edge network: Linear(5, ...) assumes 5 raw edge features; its output of
    # dim * dim values matches NNConv(dim, dim).
    edge_net = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim))
    self.conv = NNConv(dim, dim, edge_net, aggr='mean')
    self.gru = GRU(dim, dim)

    self.set2set = Set2Set(dim, processing_steps=3)
    # Set2Set doubles the feature width; the head ends in a single output.
    self.lin1 = torch.nn.Linear(2 * dim, dim)
    self.lin2 = torch.nn.Linear(dim, 1)
Example #6
Source File: test_set2set.py From pytorch_geometric with MIT License | 5 votes |
def test_set2set():
    """Set2Set pooling of a batch must match pooling each graph on its own.

    Two graphs (4 and 6 nodes, 2 features each) are pooled individually and
    then together in both concatenation orders. Each graph's row of the
    batched output must agree with its standalone result, wherever it sits
    in the batch.
    """
    set2set = Set2Set(in_channels=2, processing_steps=1)
    # Output width is twice the input width, hence "Set2Set(2, 4)".
    assert repr(set2set) == 'Set2Set(2, 4)'

    N = 4
    x_1, batch_1 = torch.randn(N, 2), torch.zeros(N, dtype=torch.long)
    out_1 = set2set(x_1, batch_1).view(-1)

    N = 6
    x_2, batch_2 = torch.randn(N, 2), torch.zeros(N, dtype=torch.long)
    out_2 = set2set(x_2, batch_2).view(-1)

    # Batched and per-graph runs may accumulate floating-point reductions in
    # a different order, so compare with a tolerance rather than bit-exactly.
    x, batch = torch.cat([x_1, x_2]), torch.cat([batch_1, batch_2 + 1])
    out = set2set(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out_1, out[0], atol=1e-6)
    assert torch.allclose(out_2, out[1], atol=1e-6)

    # Same graphs, reversed order in the batch.
    x, batch = torch.cat([x_2, x_1]), torch.cat([batch_2, batch_1 + 1])
    out = set2set(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out_1, out[1], atol=1e-6)
    assert torch.allclose(out_2, out[0], atol=1e-6)
Example #7
Source File: gcn.py From Alchemy with MIT License | 5 votes |
def __init__(self, node_input_dim=15, output_dim=12, node_hidden_dim=64, num_step_prop=6, num_step_set2set=6):
    """GCN graph regressor with a single GCNConv layer, Set2Set readout and two-layer head.

    Args:
        node_input_dim: number of raw node features.
        output_dim: number of regression targets.
        node_hidden_dim: hidden width used throughout.
        num_step_prop: stored as ``self.num_step_prop``; presumably how many
            times forward() applies ``self.conv`` (TODO confirm).
        num_step_set2set: number of Set2Set processing steps.
    """
    super(GCN, self).__init__()
    self.num_step_prop = num_step_prop
    hidden = node_hidden_dim
    self.lin0 = nn.Linear(node_input_dim, hidden)
    # Normalization caching stays off: batches contain different graphs.
    self.conv = GCNConv(hidden, hidden, cached=False)
    self.set2set = Set2Set(hidden, processing_steps=num_step_set2set)
    # Set2Set doubles the feature width.
    self.lin1 = nn.Linear(2 * hidden, hidden)
    self.lin2 = nn.Linear(hidden, output_dim)
Example #8
Source File: chebynet.py From Alchemy with MIT License | 5 votes |
def __init__(self, node_input_dim=15, output_dim=12, node_hidden_dim=64, polynomial_order=5, num_step_prop=6, num_step_set2set=6):
    """Chebyshev-spectral graph regressor with Set2Set readout and two-layer head.

    Args:
        node_input_dim: number of raw node features.
        output_dim: number of regression targets.
        node_hidden_dim: hidden width used throughout.
        polynomial_order: Chebyshev filter size, passed to ChebConv as ``K``.
        num_step_prop: stored as ``self.num_step_prop``; presumably how many
            times forward() applies ``self.conv`` (TODO confirm).
        num_step_set2set: number of Set2Set processing steps.
    """
    super(ChebyNet, self).__init__()
    self.num_step_prop = num_step_prop
    hidden = node_hidden_dim
    self.lin0 = nn.Linear(node_input_dim, hidden)
    self.conv = ChebConv(hidden, hidden, K=polynomial_order)
    self.set2set = Set2Set(hidden, processing_steps=num_step_set2set)
    # Set2Set doubles the feature width.
    self.lin1 = torch.nn.Linear(2 * hidden, hidden)
    self.lin2 = torch.nn.Linear(hidden, output_dim)
Example #9
Source File: gat.py From Alchemy with MIT License | 5 votes |
def __init__(self, node_input_dim=15, output_dim=12, node_hidden_dim=64, num_step_prop=6, num_step_set2set=6):
    """Graph-attention regressor with a single GATConv layer, Set2Set readout and two-layer head.

    Args:
        node_input_dim: number of raw node features.
        output_dim: number of regression targets.
        node_hidden_dim: hidden width used throughout.
        num_step_prop: stored as ``self.num_step_prop``; presumably how many
            times forward() applies ``self.conv`` (TODO confirm).
        num_step_set2set: number of Set2Set processing steps.
    """
    super(GAT, self).__init__()
    self.num_step_prop = num_step_prop
    hidden = node_hidden_dim
    self.lin0 = nn.Linear(node_input_dim, hidden)
    self.conv = GATConv(hidden, hidden)
    self.set2set = Set2Set(hidden, processing_steps=num_step_set2set)
    # Set2Set doubles the feature width.
    self.lin1 = nn.Linear(2 * hidden, hidden)
    self.lin2 = nn.Linear(hidden, output_dim)
Example #10
Source File: infograph.py From cogdl with MIT License | 5 votes |
def __init__(self, num_features, dim, num_layers=1):
    """Supervised encoder: NNConv + GRU message passing followed by a Set2Set readout.

    Args:
        num_features: number of raw node features.
        dim: hidden width.
        num_layers: accepted for interface compatibility; not used here.
    """
    super(SUPEncoder, self).__init__()
    self.lin0 = torch.nn.Linear(num_features, dim)

    # Edge network: Linear(5, ...) assumes 5 raw edge features; its output of
    # dim * dim values matches NNConv(dim, dim).
    edge_mlp = nn.Sequential(nn.Linear(5, 128), nn.ReLU(), nn.Linear(128, dim * dim))
    self.conv = NNConv(dim, dim, edge_mlp, aggr='mean', root_weight=False)
    self.gru = nn.GRU(dim, dim)

    self.set2set = Set2Set(dim, processing_steps=3)
Example #11
Source File: gnn.py From ogb with MIT License | 4 votes |
def __init__(self, num_tasks, num_layer=5, emb_dim=300, gnn_type='gin', virtual_node=True, residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
    """Graph-level predictor: GNN node encoder, graph pooling, linear head.

    Args:
        num_tasks (int): number of labels to be predicted.
        num_layer (int): number of GNN layers; must be at least 2.
        emb_dim (int): node embedding width.
        gnn_type (str): convolution type, forwarded to the node encoder.
        virtual_node (bool): whether to add a virtual node or not.
        residual (bool): forwarded to the node encoder.
        drop_ratio (float): dropout ratio, forwarded to the node encoder.
        JK (str): jumping-knowledge mode, forwarded to the node encoder.
        graph_pooling (str): one of "sum", "mean", "max", "attention",
            "set2set".

    Raises:
        ValueError: if ``num_layer < 2`` or ``graph_pooling`` is unknown.
    """
    super(GNN, self).__init__()

    self.num_layer = num_layer
    self.drop_ratio = drop_ratio
    self.JK = JK
    self.emb_dim = emb_dim
    self.num_tasks = num_tasks
    self.graph_pooling = graph_pooling

    if num_layer < 2:
        raise ValueError("Number of GNN layers must be greater than 1.")

    # Node-embedding backbone; both variants take identical arguments.
    encoder_cls = GNN_node_Virtualnode if virtual_node else GNN_node
    self.gnn_node = encoder_cls(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio,
                                residual=residual, gnn_type=gnn_type)

    # Graph-level pooling of node embeddings.
    pooling = graph_pooling
    if pooling == "sum":
        self.pool = global_add_pool
    elif pooling == "mean":
        self.pool = global_mean_pool
    elif pooling == "max":
        self.pool = global_max_pool
    elif pooling == "attention":
        gate_nn = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.BatchNorm1d(2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, 1),
        )
        self.pool = GlobalAttention(gate_nn=gate_nn)
    elif pooling == "set2set":
        self.pool = Set2Set(emb_dim, processing_steps=2)
    else:
        raise ValueError("Invalid graph pooling type.")

    # Set2Set's output is twice its input width; other poolings keep emb_dim.
    pooled_dim = 2 * emb_dim if pooling == "set2set" else emb_dim
    self.graph_pred_linear = torch.nn.Linear(pooled_dim, num_tasks)
Example #12
Source File: gnn.py From ogb with MIT License | 4 votes |
def __init__(self, num_vocab, max_seq_len, node_encoder, num_layer=5, emb_dim=300, gnn_type='gin', virtual_node=True, residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
    """Graph-to-sequence model: GNN encoder, graph pooling, one linear head per output position.

    Args:
        num_vocab (int): output vocabulary size; width of every prediction head.
        max_seq_len (int): number of output positions, i.e. number of heads.
        node_encoder: forwarded to the GNN node encoder (presumably embeds
            raw node features — confirm against GNN_node).
        num_layer (int): number of GNN layers; must be at least 2.
        emb_dim (int): node embedding width.
        gnn_type (str): convolution type, forwarded to the node encoder.
        virtual_node (bool): whether to add a virtual node or not.
        residual (bool): forwarded to the node encoder.
        drop_ratio (float): dropout ratio, forwarded to the node encoder.
        JK (str): jumping-knowledge mode, forwarded to the node encoder.
        graph_pooling (str): one of "sum", "mean", "max", "attention",
            "set2set".

    Raises:
        ValueError: if ``num_layer < 2`` or ``graph_pooling`` is unknown.
    """
    super(GNN, self).__init__()

    self.num_layer = num_layer
    self.drop_ratio = drop_ratio
    self.JK = JK
    self.emb_dim = emb_dim
    self.num_vocab = num_vocab
    self.max_seq_len = max_seq_len
    self.graph_pooling = graph_pooling

    if self.num_layer < 2:
        raise ValueError("Number of GNN layers must be greater than 1.")

    ### GNN to generate node embeddings
    if virtual_node:
        self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, node_encoder, JK=JK, drop_ratio=drop_ratio, residual=residual, gnn_type=gnn_type)
    else:
        self.gnn_node = GNN_node(num_layer, emb_dim, node_encoder, JK=JK, drop_ratio=drop_ratio, residual=residual, gnn_type=gnn_type)

    ### Pooling function to generate whole-graph embeddings
    if self.graph_pooling == "sum":
        self.pool = global_add_pool
    elif self.graph_pooling == "mean":
        self.pool = global_mean_pool
    elif self.graph_pooling == "max":
        self.pool = global_max_pool
    elif self.graph_pooling == "attention":
        self.pool = GlobalAttention(gate_nn=torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.BatchNorm1d(2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, 1)))
    elif self.graph_pooling == "set2set":
        self.pool = Set2Set(emb_dim, processing_steps=2)
    else:
        raise ValueError("Invalid graph pooling type.")

    # One head per output position; Set2Set's output is twice its input
    # width, so the heads' input size depends on the pooling choice.
    head_in_dim = 2 * emb_dim if graph_pooling == "set2set" else emb_dim
    self.graph_pred_linear_list = torch.nn.ModuleList(
        [torch.nn.Linear(head_in_dim, self.num_vocab) for _ in range(max_seq_len)]
    )
Example #13
Source File: gnn.py From ogb with MIT License | 4 votes |
def __init__(self, num_class, num_layer=5, emb_dim=300, gnn_type='gin', virtual_node=True, residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
    """Graph classifier: GNN node encoder, graph pooling, linear head.

    Args:
        num_class (int): number of classes to predict; width of the head.
        num_layer (int): number of GNN layers; must be at least 2.
        emb_dim (int): node embedding width.
        gnn_type (str): convolution type, forwarded to the node encoder.
        virtual_node (bool): whether to add a virtual node or not.
        residual (bool): forwarded to the node encoder.
        drop_ratio (float): dropout ratio, forwarded to the node encoder.
        JK (str): jumping-knowledge mode, forwarded to the node encoder.
        graph_pooling (str): one of "sum", "mean", "max", "attention",
            "set2set".

    Raises:
        ValueError: if ``num_layer < 2`` or ``graph_pooling`` is unknown.
    """
    super(GNN, self).__init__()

    self.num_layer = num_layer
    self.drop_ratio = drop_ratio
    self.JK = JK
    self.emb_dim = emb_dim
    self.num_class = num_class
    self.graph_pooling = graph_pooling

    if self.num_layer < 2:
        raise ValueError("Number of GNN layers must be greater than 1.")

    ### GNN to generate node embeddings
    if virtual_node:
        self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual, gnn_type=gnn_type)
    else:
        self.gnn_node = GNN_node(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual, gnn_type=gnn_type)

    ### Pooling function to generate whole-graph embeddings
    if self.graph_pooling == "sum":
        self.pool = global_add_pool
    elif self.graph_pooling == "mean":
        self.pool = global_mean_pool
    elif self.graph_pooling == "max":
        self.pool = global_max_pool
    elif self.graph_pooling == "attention":
        self.pool = GlobalAttention(gate_nn=torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.BatchNorm1d(2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, 1)))
    elif self.graph_pooling == "set2set":
        self.pool = Set2Set(emb_dim, processing_steps=2)
    else:
        raise ValueError("Invalid graph pooling type.")

    # Set2Set's output is twice its input width; other poolings keep emb_dim.
    if graph_pooling == "set2set":
        self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim, self.num_class)
    else:
        self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_class)