Python torch_geometric.nn.Set2Set() Examples

The following are 13 code examples of torch_geometric.nn.Set2Set(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch_geometric.nn, or try the search function.
Example #1
Source file: mpnn.py, from GLN (MIT License), 6 votes
def __init__(self, latent_dim, output_dim, num_node_feats, num_edge_feats, max_lv=3,
                act_func='elu', msg_aggregate_type='mean', dropout=None):
        """Build an NNConv/GRU message-passing encoder with a Set2Set readout.

        Args:
            latent_dim: width of the hidden node states.
            output_dim: width of the graph embedding; when not positive,
                latent_dim is used instead.
            num_node_feats: size of the raw node-feature vectors fed to lin0.
            num_edge_feats: size of the raw edge-feature vectors fed to the
                edge network.
            max_lv: stored as self.max_lv (presumably the number of
                message-passing rounds run in forward -- TODO confirm).
            act_func: key looked up in NONLINEARITIES; also handed to the
                edge MLP as its nonlinearity.
            msg_aggregate_type: neighbour aggregation passed to NNConv;
                'sum' is translated to 'add'.
            dropout: forwarded to the parent constructor with the embedding width.
        """
        # A non-positive output_dim means "no separate output width".
        embed_dim = output_dim if output_dim > 0 else latent_dim
        super(MPNN, self).__init__(embed_dim, dropout)

        # NNConv is given 'add' where callers may have said 'sum'.
        aggr = 'add' if msg_aggregate_type == 'sum' else msg_aggregate_type

        self.max_lv = max_lv
        # Set2Set doubles the feature width, so the readout consumes 2 * latent_dim.
        self.readout = nn.Linear(2 * latent_dim, self.embed_dim)
        self.lin0 = torch.nn.Linear(num_node_feats, latent_dim)

        # The edge network emits latent_dim * latent_dim values per edge.
        edge_net = MLP(input_dim=num_edge_feats,
                       hidden_dims=[128, latent_dim * latent_dim],
                       nonlinearity=act_func)
        self.conv = NNConv(latent_dim, latent_dim, edge_net,
                           aggr=aggr, root_weight=False)

        self.act_func = NONLINEARITIES[act_func]
        self.gru = nn.GRU(latent_dim, latent_dim)
        self.set2set = Set2Set(latent_dim, processing_steps=3)
Example #2
Source file: gin.py, from Alchemy (MIT License), 6 votes
def __init__(self,
                 node_input_dim=15,
                 output_dim=12,
                 node_hidden_dim=64,
                 num_step_prop=6,
                 num_step_set2set=6):
        """Build a stack of GIN convolutions followed by Set2Set pooling.

        Args:
            node_input_dim: size of the raw node-feature vectors (projected by lin0).
            output_dim: width of the final prediction (lin2 output).
            node_hidden_dim: width of the hidden node states.
            num_step_prop: number of GINConv layers created (one MLP each).
            num_step_set2set: number of Set2Set processing steps.
        """
        super(GIN, self).__init__()
        self.num_step_prop = num_step_prop
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        self.mlps = torch.nn.ModuleList()
        self.convs = torch.nn.ModuleList()
        hidden = node_hidden_dim
        for _ in range(num_step_prop):
            # Two Linear -> BatchNorm -> ReLU stages per layer; the same MLP
            # object is registered in self.mlps and wrapped by its GINConv.
            mlp = nn.Sequential(
                nn.Linear(hidden, hidden), nn.BatchNorm1d(hidden), nn.ReLU(),
                nn.Linear(hidden, hidden), nn.BatchNorm1d(hidden), nn.ReLU(),
            )
            self.mlps.append(mlp)
            self.convs.append(GINConv(mlp, eps=0, train_eps=False))
        # Set2Set emits 2 * hidden features per graph, hence lin1's input width.
        self.set2set = Set2Set(hidden, processing_steps=num_step_set2set)
        self.lin1 = nn.Linear(2 * hidden, hidden)
        self.lin2 = nn.Linear(hidden, output_dim)
Example #3
Source file: mpnn.py, from Alchemy (MIT License), 6 votes
def __init__(self,
                 node_input_dim=15,
                 edge_input_dim=5,
                 output_dim=1,
                 node_hidden_dim=64,
                 edge_hidden_dim=128,
                 num_step_message_passing=6,
                 num_step_set2set=6):
        """Build an NNConv + GRU message-passing network with Set2Set pooling.

        Args:
            node_input_dim: size of the raw node-feature vectors (projected by lin0).
            edge_input_dim: size of the raw edge-feature vectors fed to the edge network.
            output_dim: width of the final prediction (lin2 output).
            node_hidden_dim: width of the hidden node states.
            edge_hidden_dim: hidden width inside the edge network.
            num_step_message_passing: stored on the module; presumably the number
                of conv/GRU rounds run in forward (TODO confirm).
            num_step_set2set: number of Set2Set processing steps.
        """
        super(MPNN, self).__init__()
        self.num_step_message_passing = num_step_message_passing
        # Project raw node features into the hidden space.
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        # Maps each edge-feature vector to node_hidden_dim * node_hidden_dim
        # values (presumably reshaped by NNConv into a per-edge weight matrix).
        edge_network = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
        self.conv = NNConv(node_hidden_dim,
                           node_hidden_dim,
                           edge_network,
                           aggr='mean',
                           root_weight=False)
        self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)

        # Set2Set emits 2 * node_hidden_dim features per graph, hence lin1's input width.
        self.set2set = Set2Set(node_hidden_dim,
                               processing_steps=num_step_set2set)
        self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        self.lin2 = nn.Linear(node_hidden_dim, output_dim)
Example #4
Source file: set2set.py, from pytorch_geometric (MIT License), 5 votes
def __init__(self, dataset, num_layers, hidden):
        """Build a GraphSAGE stack with Set2Set pooling for graph classification.

        Args:
            dataset: object providing ``num_features`` (input width) and
                ``num_classes`` (output width).
            num_layers: total number of SAGEConv layers (conv1 plus
                num_layers - 1 layers in self.convs).
            hidden: hidden feature width.
        """
        super(Set2SetNet, self).__init__()
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        # Remaining hidden -> hidden layers after conv1.
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        # Set2Set emits 2 * hidden features per graph, hence lin1's input width.
        self.set2set = Set2Set(hidden, processing_steps=4)
        self.lin1 = Linear(2 * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)
Example #5
Source file: qm9_nn_conv.py, from pytorch_geometric (MIT License), 5 votes
def __init__(self):
        """Build the QM9 regressor: NNConv + GRU message passing, Set2Set readout.

        ``dataset`` (for the node-feature width) and ``dim`` (hidden width)
        are free names resolved outside this method, so both must be defined
        before the module is constructed.
        """
        super(Net, self).__init__()
        self.lin0 = torch.nn.Linear(dataset.num_features, dim)
        hidden = dim

        # Edge network: 5 input edge features (hard-coded) -> hidden * hidden
        # values per edge. Named edge_mlp so it does not shadow a `torch.nn`
        # alias called `nn`.
        edge_mlp = Sequential(
            Linear(5, 128),
            ReLU(),
            Linear(128, hidden * hidden),
        )
        self.conv = NNConv(hidden, hidden, edge_mlp, aggr='mean')
        self.gru = GRU(hidden, hidden)

        # Set2Set emits 2 * hidden features per graph, hence lin1's input width.
        self.set2set = Set2Set(hidden, processing_steps=3)
        self.lin1 = torch.nn.Linear(2 * hidden, hidden)
        self.lin2 = torch.nn.Linear(hidden, 1)
Example #6
Source file: test_set2set.py, from pytorch_geometric (MIT License), 5 votes
def test_set2set():
    """Check Set2Set's repr/output width and that pooling is per-graph.

    A graph's pooled vector must not depend on which other graphs share the
    batch, nor on the order of the graphs within the batch.
    """
    set2set = Set2Set(in_channels=2, processing_steps=1)
    # Output width is twice the input width (2 -> 4).
    assert repr(set2set) == 'Set2Set(2, 4)'

    N = 4
    x_1, batch_1 = torch.randn(N, 2), torch.zeros(N, dtype=torch.long)
    out_1 = set2set(x_1, batch_1).view(-1)

    N = 6
    x_2, batch_2 = torch.randn(N, 2), torch.zeros(N, dtype=torch.long)
    out_2 = set2set(x_2, batch_2).view(-1)

    # Batched and single-graph runs may go through differently shaped
    # kernels, so their floats need not match bit-for-bit: compare with a
    # tolerance instead of exact list equality.
    x, batch = torch.cat([x_1, x_2]), torch.cat([batch_1, batch_2 + 1])
    out = set2set(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out_1, out[0], atol=1e-6)
    assert torch.allclose(out_2, out[1], atol=1e-6)

    # Same two graphs, reversed order within the batch.
    x, batch = torch.cat([x_2, x_1]), torch.cat([batch_2, batch_1 + 1])
    out = set2set(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out_1, out[1], atol=1e-6)
    assert torch.allclose(out_2, out[0], atol=1e-6)
Example #7
Source file: gcn.py, from Alchemy (MIT License), 5 votes
def __init__(self,
                 node_input_dim=15,
                 output_dim=12,
                 node_hidden_dim=64,
                 num_step_prop=6,
                 num_step_set2set=6):
        """Build a GCN encoder (one GCNConv layer) with Set2Set pooling.

        Args:
            node_input_dim: size of the raw node-feature vectors (projected by lin0).
            output_dim: width of the final prediction (lin2 output).
            node_hidden_dim: width of the hidden node states.
            num_step_prop: stored on the module; only one GCNConv is created here,
                so presumably forward applies it num_step_prop times (TODO confirm).
            num_step_set2set: number of Set2Set processing steps.
        """
        super(GCN, self).__init__()
        self.num_step_prop = num_step_prop
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        # cached=False: the layer is told not to cache its normalization between calls.
        self.conv = GCNConv(node_hidden_dim, node_hidden_dim, cached=False)
        # Set2Set emits 2 * node_hidden_dim features per graph, hence lin1's input width.
        self.set2set = Set2Set(node_hidden_dim, processing_steps=num_step_set2set)
        self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        self.lin2 = nn.Linear(node_hidden_dim, output_dim)
Example #8
Source file: chebynet.py, from Alchemy (MIT License), 5 votes
def __init__(self,
                 node_input_dim=15,
                 output_dim=12,
                 node_hidden_dim=64,
                 polynomial_order=5,
                 num_step_prop=6,
                 num_step_set2set=6):
        """Build a Chebyshev spectral encoder (one ChebConv layer) with Set2Set pooling.

        Args:
            node_input_dim: size of the raw node-feature vectors (projected by lin0).
            output_dim: width of the final prediction (lin2 output).
            node_hidden_dim: width of the hidden node states.
            polynomial_order: passed to ChebConv as K.
            num_step_prop: stored on the module; only one ChebConv is created here,
                so presumably forward applies it num_step_prop times (TODO confirm).
            num_step_set2set: number of Set2Set processing steps.
        """
        super(ChebyNet, self).__init__()
        self.num_step_prop = num_step_prop
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        self.conv = ChebConv(node_hidden_dim, node_hidden_dim, K=polynomial_order)

        # Set2Set emits 2 * node_hidden_dim features per graph, hence lin1's input width.
        self.set2set = Set2Set(node_hidden_dim, processing_steps=num_step_set2set)
        self.lin1 = torch.nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        self.lin2 = torch.nn.Linear(node_hidden_dim, output_dim)
Example #9
Source file: gat.py, from Alchemy (MIT License), 5 votes
def __init__(self,
                 node_input_dim=15,
                 output_dim=12,
                 node_hidden_dim=64,
                 num_step_prop=6,
                 num_step_set2set=6):
        """Build a graph-attention encoder (one GATConv layer) with Set2Set pooling.

        Args:
            node_input_dim: size of the raw node-feature vectors (projected by lin0).
            output_dim: width of the final prediction (lin2 output).
            node_hidden_dim: width of the hidden node states.
            num_step_prop: stored on the module; only one GATConv is created here,
                so presumably forward applies it num_step_prop times (TODO confirm).
            num_step_set2set: number of Set2Set processing steps.
        """
        super(GAT, self).__init__()
        self.num_step_prop = num_step_prop
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        # GATConv with its default settings (no heads argument given).
        self.conv = GATConv(node_hidden_dim, node_hidden_dim)

        # Set2Set emits 2 * node_hidden_dim features per graph, hence lin1's input width.
        self.set2set = Set2Set(node_hidden_dim, processing_steps=num_step_set2set)
        self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        self.lin2 = nn.Linear(node_hidden_dim, output_dim)
Example #10
Source file: infograph.py, from cogdl (MIT License), 5 votes
def __init__(self, num_features, dim, num_layers=1, edge_dim=5):
        """Build a supervised NNConv + GRU encoder with Set2Set pooling.

        Args:
            num_features: size of the raw node-feature vectors (projected by lin0).
            dim: hidden node-embedding width.
            num_layers: accepted for interface compatibility; this constructor
                neither uses nor stores it.
            edge_dim: size of each edge-feature vector consumed by the edge
                network. Defaults to 5, the value that was previously hard-coded.
        """
        super(SUPEncoder, self).__init__()
        self.lin0 = torch.nn.Linear(num_features, dim)

        # Edge network: edge_dim features -> dim * dim values per edge.
        edge_network = nn.Sequential(nn.Linear(edge_dim, 128), nn.ReLU(), nn.Linear(128, dim * dim))
        self.conv = NNConv(dim, dim, edge_network, aggr='mean', root_weight=False)
        self.gru = nn.GRU(dim, dim)

        # Set2Set emits 2 * dim features per graph.
        self.set2set = Set2Set(dim, processing_steps=3)
Example #11
Source file: gnn.py, from ogb (MIT License), 4 votes
def __init__(self, num_tasks, num_layer = 5, emb_dim = 300,
                    gnn_type = 'gin', virtual_node = True, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
        '''
            Graph-level predictor: node-level GNN, graph pooling, linear head.

            num_tasks (int): number of labels to be predicted (head output width)
            num_layer (int): number of GNN layers; must be at least 2
            emb_dim (int): node embedding width
            gnn_type (str): forwarded to the node-level GNN
            virtual_node (bool): whether to add virtual node or not
            residual (bool): forwarded to the node-level GNN
            drop_ratio (float): forwarded to the node-level GNN
            JK (str): forwarded to the node-level GNN
            graph_pooling (str): one of "sum", "mean", "max", "attention", "set2set"

            Raises ValueError for num_layer < 2 or an unknown graph_pooling.
        '''

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        # Node-level GNN, with or without a virtual node.
        node_gnn_cls = GNN_node_Virtualnode if virtual_node else GNN_node
        self.gnn_node = node_gnn_cls(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)

        def _attention_pool():
            gate_nn = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim),
                torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1))
            return GlobalAttention(gate_nn = gate_nn)

        # Pooling choices, checked in order; only the selected module is built.
        pool_choices = (
            ("sum", lambda: global_add_pool),
            ("mean", lambda: global_mean_pool),
            ("max", lambda: global_max_pool),
            ("attention", _attention_pool),
            ("set2set", lambda: Set2Set(emb_dim, processing_steps = 2)),
        )
        for choice, build_pool in pool_choices:
            if self.graph_pooling == choice:
                self.pool = build_pool()
                break
        else:
            raise ValueError("Invalid graph pooling type.")

        # Set2Set doubles the pooled width; every other option keeps emb_dim.
        pooled_dim = 2*self.emb_dim if graph_pooling == "set2set" else self.emb_dim
        self.graph_pred_linear = torch.nn.Linear(pooled_dim, self.num_tasks)
Example #12
Source file: gnn.py, from ogb (MIT License), 4 votes
def __init__(self, num_vocab, max_seq_len, node_encoder, num_layer = 5, emb_dim = 300,
                    gnn_type = 'gin', virtual_node = True, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
        '''
            Graph-to-sequence predictor: node-level GNN, graph pooling, then
            max_seq_len independent linear heads over the vocabulary.

            num_vocab (int): number of vocabulary entries each head scores
            max_seq_len (int): number of prediction heads created (presumably
                one per position of the predicted sequence -- confirm in forward)
            node_encoder: forwarded to the node-level GNN (presumably embeds raw
                node features -- confirm in GNN_node)
            num_layer (int): number of GNN layers; must be at least 2
            emb_dim (int): node embedding width
            gnn_type (str): forwarded to the node-level GNN
            virtual_node (bool): whether to add virtual node or not
            residual (bool): forwarded to the node-level GNN
            drop_ratio (float): forwarded to the node-level GNN
            JK (str): forwarded to the node-level GNN
            graph_pooling (str): one of "sum", "mean", "max", "attention", "set2set"

            Raises ValueError for num_layer < 2 or an unknown graph_pooling.
        '''

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_vocab = num_vocab
        self.max_seq_len = max_seq_len
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, node_encoder, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)
        else:
            self.gnn_node = GNN_node(num_layer, emb_dim, node_encoder, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps = 2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # Set2Set doubles the pooled width; every other option keeps emb_dim.
        pred_in_dim = 2*emb_dim if graph_pooling == "set2set" else emb_dim
        self.graph_pred_linear_list = torch.nn.ModuleList(
            [torch.nn.Linear(pred_in_dim, self.num_vocab) for _ in range(max_seq_len)])
Example #13
Source file: gnn.py, from ogb (MIT License), 4 votes
def __init__(self, num_class, num_layer = 5, emb_dim = 300,
                    gnn_type = 'gin', virtual_node = True, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
        '''
            Graph classifier: node-level GNN, graph pooling, linear head.

            num_class (int): number of output classes (head output width)
            num_layer (int): number of GNN layers; must be at least 2
            emb_dim (int): node embedding width
            gnn_type (str): forwarded to the node-level GNN
            virtual_node (bool): whether to add virtual node or not
            residual (bool): forwarded to the node-level GNN
            drop_ratio (float): forwarded to the node-level GNN
            JK (str): forwarded to the node-level GNN
            graph_pooling (str): one of "sum", "mean", "max", "attention", "set2set"

            Raises ValueError for num_layer < 2 or an unknown graph_pooling.
        '''

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_class = num_class
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)
        else:
            self.gnn_node = GNN_node(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps = 2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # Set2Set doubles the pooled width; every other option keeps emb_dim.
        pred_in_dim = 2*self.emb_dim if graph_pooling == "set2set" else self.emb_dim
        self.graph_pred_linear = torch.nn.Linear(pred_in_dim, self.num_class)