Python torch.nn.functional.gumbel_softmax() Examples

The following are 19 code examples of torch.nn.functional.gumbel_softmax(), drawn from open-source projects; the source file, project, and license are listed above each example. Unless an example shows otherwise, F refers to torch.nn.functional (Example #1 imports it as functional).
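Before the project examples, here is a minimal, self-contained sketch (not taken from any of the projects below) of the call signature they all rely on: tau is the temperature of the relaxation, and hard=True returns one-hot samples in the forward pass while keeping soft gradients for the backward pass.

import torch
import torch.nn.functional as F

logits = torch.randn(8, 5)                             # unnormalized log-probabilities
soft = F.gumbel_softmax(logits, tau=1.0, hard=False)   # relaxed one-hot samples over the last dim
hard = F.gumbel_softmax(logits, tau=1.0, hard=True)    # one-hot forward, straight-through gradients
print(soft.shape, hard.sum(dim=-1))                    # torch.Size([8, 5]), a vector of ones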
Example #1
Source File: synthesizer.py    From CTGAN with MIT License
def _apply_activate(self, data):
        data_t = []
        st = 0
        for item in self.transformer.output_info:
            if item[1] == 'tanh':
                ed = st + item[0]
                data_t.append(torch.tanh(data[:, st:ed]))
                st = ed
            elif item[1] == 'softmax':
                ed = st + item[0]
                data_t.append(functional.gumbel_softmax(data[:, st:ed], tau=0.2))
                st = ed
            else:
                assert 0

        return torch.cat(data_t, dim=1) 
Example #2
Source File: qv.py    From attn2d with MIT License
def assign(self, points, distance='euclid', greedy=False):
        # points = points.data
        centroids = self.centroids
        if distance == 'cosine':
            # nearest neighbor in the centroids (cosine distance):
            points = F.normalize(points, dim=-1)
            centroids = F.normalize(centroids, dim=-1)

        distances = (torch.sum(points**2, dim=1, keepdim=True) +
                     torch.sum(centroids**2, dim=1, keepdim=True).t() -
                     2 * torch.matmul(points, centroids.t()))  # T*B, e
        print('Distances:', distances[:3])
        if not greedy:
            logits = - distances
            resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
        else:
            # Greedy non-differentiable responsibilities:
            indices = torch.argmin(distances, dim=-1)  # T*B
            resp = torch.zeros(points.size(0), self.ne).type_as(points)
            resp.scatter_(1, indices.unsqueeze(1), 1)
        return resp 
Example #3
Source File: qv.py    From attn2d with MIT License
def assign(self, points, distance='euclid', greedy=False):
        points = points.data
        centroids = self.centroids
        if distance == 'cosine':
            # nearest neighbor in the centroids (cosine distance):
            points = F.normalize(points, dim=-1)
            centroids = F.normalize(centroids, dim=-1)

        distances = (torch.sum(points**2, dim=1, keepdim=True) +
                     torch.sum(centroids**2, dim=1, keepdim=True).t() -
                     2 * torch.matmul(points, centroids.t()))  # T*B, e
        if not greedy:
            logits = - distances
            resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
            # batch_counts = resp.sum(dim=0).view(-1).data

        else:
            # Greedy non-differentiable responsibilities:
            indices = torch.argmin(distances, dim=-1)  # T*B
            resp = torch.zeros(points.size(0), self.ne).type_as(points)
            resp.scatter_(1, indices.unsqueeze(1), 1)
        return resp 
Example #4
Source File: qv.py    From attn2d with MIT License
def assign(self, points, distance='euclid', greedy=False):
        # points = points.data  # the only diff from 16
        centroids = self.centroids
        if distance == 'cosine':
            # nearest neighbor in the centroids (cosine distance):
            points = F.normalize(points, dim=-1)
            centroids = F.normalize(centroids, dim=-1)

        distances = (torch.sum(points**2, dim=1, keepdim=True) +
                     torch.sum(centroids**2, dim=1, keepdim=True).t() -
                     2 * torch.matmul(points, centroids.t()))  # T*B, e
        if not greedy:
            logits = - distances
            resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
            # batch_counts = resp.sum(dim=0).view(-1).data

        else:
            # Greedy non-differentiable responsibilities:
            indices = torch.argmin(distances, dim=-1)  # T*B
            resp = torch.zeros(points.size(0), self.ne).type_as(points)
            resp.scatter_(1, indices.unsqueeze(1), 1)
        return resp 
Example #5
Source File: qv_v1.py    From attn2d with MIT License
def assign(self, points, distance='euclid', greedy=False):
        points = points.data
        centroids = self.centroids
        if distance == 'cosine':
            # nearest neighbor in the centroids (cosine distance):
            points = F.normalize(points, dim=-1)
            centroids = F.normalize(centroids, dim=-1)

        distances = (torch.sum(points**2, dim=1, keepdim=True) +
                     torch.sum(centroids**2, dim=1, keepdim=True).t() -
                     2 * torch.matmul(points, centroids.t()))  # T*B, e
        if not greedy:
            logits = - distances
            resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
            # batch_counts = resp.sum(dim=0).view(-1).data

        else:
            # Greedy non-differentiable responsibilities:
            indices = torch.argmin(distances, dim=-1)  # T*B
            resp = torch.zeros(points.size(0), self.ne).type_as(points)
            resp.scatter_(1, indices.unsqueeze(1), 1)
        return resp 
Example #6
Source File: qv_v1.py    From attn2d with MIT License
def assign(self, points, distance='euclid', greedy=False):
        points = points.data
        centroids = self.centroids
        if distance == 'cosine':
            # nearest neighbor in the centroids (cosine distance):
            points = F.normalize(points, dim=-1)
            centroids = F.normalize(centroids, dim=-1)

        distances = (torch.sum(points**2, dim=1, keepdim=True) +
                     torch.sum(centroids**2, dim=1, keepdim=True).t() -
                     2 * torch.matmul(points, centroids.t()))  # T*B, e
        if not greedy:
            logits = - distances
            resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
            # batch_counts = resp.sum(dim=0).view(-1).data

        else:
            # Greedy non-differentiable responsibilities:
            indices = torch.argmin(distances, dim=-1)  # T*B
            resp = torch.zeros(points.size(0), self.ne).type_as(points)
            resp.scatter_(1, indices.unsqueeze(1), 1)
        return resp 
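Examples #2 through #6 all follow the same assignment pattern. A standalone sketch of it, with names, shapes, and temperature chosen for illustration rather than taken from the attn2d API, looks like this: squared Euclidean distances to a set of centroids are negated and passed through gumbel_softmax to obtain relaxed, differentiable one-hot responsibilities.

import torch
import torch.nn.functional as F

points = torch.randn(32, 16)       # T*B flattened points of dimension C (assumed shapes)
centroids = torch.randn(8, 16)     # ne centroids

# ||p||^2 + ||c||^2 - 2 p.c gives a (32, 8) matrix of squared Euclidean distances
distances = (points.pow(2).sum(dim=1, keepdim=True)
             + centroids.pow(2).sum(dim=1, keepdim=True).t()
             - 2 * points @ centroids.t())
resp = F.gumbel_softmax(-distances, tau=1.0, hard=False)   # soft responsibilities; each row sums to 1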
Example #7
Source File: model_search.py    From nasbench-1shot1 with Apache License 2.0
def forward(self, input, discrete=False, normalize=False):
        # NASBench only has one input to each cell
        s0 = self.stem(input)
        for i, cell in enumerate(self.cells):
            if i in [self._layers // 3, 2 * self._layers // 3]:
                # Perform down-sampling by factor 1/2
                # Equivalent to https://github.com/google-research/nasbench/blob/master/nasbench/lib/model_builder.py#L68
                s0 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)(s0)
            # If using discrete architecture from random_ws search with weight sharing then pass through architecture
            # weights directly.
            # For GDAS, use hard gumbel-softmax, so only a single operation per mixed block is evaluated
            preprocess_op_mixed_op = lambda x: x if discrete else F.gumbel_softmax(x, tau=self.tau, hard=True, dim=-1)
            # Don't use hard for the rest, because it very quickly gave exploding gradients
            preprocess_op = lambda x: x if discrete else F.gumbel_softmax(x, tau=self.tau, hard=False, dim=-1)

            # Normalize mixed_op weights for the choice blocks in the graph
            mixed_op_weights = preprocess_op_mixed_op(self._arch_parameters[0])
            # Normalize the output weights
            output_weights = preprocess_op(self._arch_parameters[1]) if self._output_weights else None
            # Normalize the input weights for the nodes in the cell
            input_weights = [preprocess_op(alpha) for alpha in self._arch_parameters[2:]]
            s0 = cell(s0, mixed_op_weights, output_weights, input_weights)

        # Include one more preprocessing step here
        s0 = self.postprocess(s0)  # [N, C_max * (steps + 1), w, h] -> [N, C_max, w, h]

        # Global Average Pooling by averaging over last two remaining spatial dimensions
        # Like in nasbench: https://github.com/google-research/nasbench/blob/master/nasbench/lib/model_builder.py#L92
        out = s0.view(*s0.shape[:2], -1).mean(-1)
        logits = self.classifier(out.view(out.size(0), -1))
        return logits 
Example #8
Source File: test_pyprof_nvtx.py    From apex with BSD 3-Clause "New" or "Revised" License
def test_gumbel_softmax(self):
        inp = torch.randn(16, 1024, device='cuda', dtype=self.dtype)
        output = F.gumbel_softmax(inp, tau=1, hard=False, eps=1e-10, dim=-1) 
Example #9
Source File: multi_categorical.py    From multi-categorical-gans with BSD 3-Clause "New" or "Revised" License
def forward(self, logits, training=True, temperature=None):
        # gumbel-softmax (training and evaluation)
        if temperature is not None:
            return F.gumbel_softmax(logits, hard=not training, tau=temperature)
        # softmax training
        elif training:
            return F.softmax(logits, dim=1)
        # softmax evaluation
        else:
            return OneHotCategorical(logits=logits).sample() 
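A self-contained sketch of the three branches above, wrapped in a hypothetical module (the class name and its use here are assumptions, not the multi-categorical-gans API):

import torch
import torch.nn.functional as F
from torch.distributions import OneHotCategorical

class CategoricalSampler(torch.nn.Module):
    # Hypothetical minimal wrapper reproducing the branch logic of the forward above.
    def forward(self, logits, training=True, temperature=None):
        if temperature is not None:
            return F.gumbel_softmax(logits, hard=not training, tau=temperature)
        elif training:
            return F.softmax(logits, dim=1)
        return OneHotCategorical(logits=logits).sample()

sampler = CategoricalSampler()
logits = torch.randn(4, 6)
train_sample = sampler(logits, training=True, temperature=0.5)   # soft gumbel-softmax (hard=False)
eval_sample = sampler(logits, training=False, temperature=0.5)   # straight-through gumbel-softmax (hard=True)
plain_eval = sampler(logits, training=False)                     # plain one-hot categorical sample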
Example #10
Source File: attention.py    From seq2seq.pytorch with MIT License
def forward(self, q, k, v):
        b_q, t_q, dim_q = list(q.size())
        b_k, t_k, dim_k = list(k.size())
        b_v, t_v, dim_v = list(v.size())
        assert(b_q == b_k and b_k == b_v)  # batch size should be equal
        assert(dim_q == dim_k)  # dims should be equal
        assert(t_k == t_v)  # times should be equal
        b = b_q
        qk = torch.bmm(q, k.transpose(1, 2))  # b x t_q x t_k
        qk = qk / (dim_k ** 0.5)
        mask = None
        with torch.no_grad():
            if self.causal and t_q > 1:
                causal_mask = q.data.new(t_q, t_k).byte().fill_(1).triu_(1)
                mask = causal_mask.unsqueeze(0).expand(b, t_q, t_k)
            if self.mask_k is not None:
                mask_k = self.mask_k.unsqueeze(1).expand(b, t_q, t_k)
                mask = mask_k if mask is None else mask | mask_k
            if self.mask_q is not None:
                mask_q = self.mask_q.unsqueeze(2).expand(b, t_q, t_k)
                mask = mask_q if mask is None else mask | mask_q
        if mask is not None:
            qk.masked_fill_(mask, -1e12)

        if self.gumbel:
            sm_qk = F.gumbel_softmax(qk, dim=2, hard=True)
        else:
            sm_qk = F.softmax(qk, dim=2,
                              dtype=torch.float32 if qk.dtype == torch.float16 else qk.dtype)
        sm_qk = self.dropout(sm_qk)
        return torch.bmm(sm_qk, v), sm_qk  # b x t_q x dim_v 
Example #11
Source File: dynamic_halters.py    From attn2d with MIT License
def step(self, x, n, cumul=None, total_computes=None):
        """
        n is the index of the upcoming block, 
        Given the current activation decide whether to go in or skip/exit.
        returns the binary decision and the log-(p, 1-p)
        """
        T, B, C = x.size()
        if self.detach_before_classifier:
            x = x.detach()
        x = self.halting_predictors[n if self.separate_halting_predictors else 0](x)
        halt_logits = F.logsigmoid(x)  # the log-p of halting
        # Apply the gumbel trick
        halt = halt_logits.view(-1, 2)
        halt = F.gumbel_softmax(halt, tau=self.gumbel_tau).view(T, B, 2)
        return halt 
Example #12
Source File: qv.py    From attn2d with MIT License
def forward(self, x, key):
        T, B, C = x.size()
        loss = torch.zeros(1).type_as(x).to(x.device)
        if key is not None:
            Tr = 1
        else:
            key = x
            Tr = T

        if self.tau:
            resp = F.gumbel_softmax(
                self.assign(key.contiguous().view(Tr*B, self.key_dim)),
                tau=self.tau,
                hard=self.hard
            )  # T*B, ne
        else:
            resp = torch.softmax(
                self.assign(key.contiguous().view(Tr*B, self.key_dim)),
                dim=-1
            )  # T*B, ne

        importance = resp.sum(dim=0)
        loss = self.loss_scale * torch.std(importance) / torch.mean(importance)
        print('importance', importance.data.round())
        # w = torch.matmul(resp, self.pw_w1)  # T*B, C_out * C_in
        # w = w.view(T, B, self.output_dim, self.input_dim)
        # x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
        # if self.pw_bias is not None:
            # x = x + self.pw_bias(x0)
        # First evaluate each expert output
        resp = resp.view(Tr, B, self.ne, 1)
        residual = x
        x = torch.matmul(self.pw_w1, x.unsqueeze(2).unsqueeze(-1)).squeeze(-1)  # T, B, ne, out
        x = F.relu(x)
        x = torch.sum(resp * x, dim=2)
        if self.pw_bias is not None:
            x = x + self.pw_bias(key)
        return x + residual, loss 
Example #13
Source File: ctgan.py    From SDGym with MIT License
def apply_activate(data, output_info):
    data_t = []
    st = 0
    for item in output_info:
        if item[1] == 'tanh':
            ed = st + item[0]
            data_t.append(torch.tanh(data[:, st:ed]))
            st = ed
        elif item[1] == 'softmax':
            ed = st + item[0]
            data_t.append(F.gumbel_softmax(data[:, st:ed], tau=0.2))
            st = ed
        else:
            assert 0
    return torch.cat(data_t, dim=1) 
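A hedged usage sketch of apply_activate above; output_info is assumed to be a list of (width, activation) pairs, which is what the loop implies:

import torch
import torch.nn.functional as F

fake = torch.randn(4, 1 + 3)                                     # one tanh column plus a 3-way categorical block
activated = apply_activate(fake, [(1, 'tanh'), (3, 'softmax')])
print(activated.shape)                                           # torch.Size([4, 4])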
Example #14
Source File: adv_masker.py    From bert_on_stilts with Apache License 2.0
def forward(self, x, attention_mask, gumbel_softmax=True, tau=None):
        extended_attention_mask = self.convert_mask(attention_mask)
        h = self.bert_layer(x, extended_attention_mask)
        h = self.linear_layer(h)
        log_probs = self.log_sigmoid(h).squeeze(dim=2)

        if gumbel_softmax:
            tau = self.tau if tau is None else tau
            return F.gumbel_softmax(log_probs, tau=tau)
        else:
            return log_probs 
Example #15
Source File: qv.py    From attn2d with MIT License
def forward(self, x, key):
        T, B, C = x.size()
        loss = torch.zeros(1).type_as(x).to(x.device)
        if key is not None:
            # outsider influence:
            scales = torch.softmax(self.pw_scales(key), dim=-1).unsqueeze(-1).unsqueeze(-1)  # 1, B, ne, 1, 1
            w = self.pw_w1 * scales  # 1, B, ne, C_out, C_in
            w = torch.sum(w, dim=2)
            x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
            if self.pw_bias is not None:
                x = x + self.pw_bias(key)
            return x, loss
        else:
            residual = x
            if self.training:
                if self.tau:
                    resp = F.gumbel_softmax(
                        self.assign(x.contiguous().view(T*B, C)),
                        tau=self.tau,
                        hard=self.hard
                    )  # T*B, ne
                else:
                    resp = torch.softmax(
                        self.assign(x.contiguous().view(T*B, C)),
                        dim=-1
                    )  # T*B, ne

            else:
                if self.tau:
                    resp = F.gumbel_softmax(
                        self.assign(x.contiguous().view(T*B, C)),
                        tau=self.tau,
                        hard=self.hard
                    )  # T*B, ne
                else:
                    resp = torch.softmax(
                        self.assign(x.contiguous().view(T*B, C)),
                        dim=-1
                    )  # T*B, ne

                # For the new exp with ne600 eval is soft as well
                # logits = self.assign(x.contiguous().view(T*B, C))
                # indices = torch.argmax(logits, dim=-1)
                # resp = torch.zeros(logits.size(0), self.ne).type_as(logits)
                # resp.scatter_(1, indices.unsqueeze(1), 1)
            importance = resp.sum(dim=0)
            loss = self.loss_scale * torch.std(importance) / torch.mean(importance)

            print('importance', importance.data.round())
            w = torch.matmul(resp, self.pw_w1)  # T*B, C_out * C_in
            w = w.view(T, B, self.output_dim, self.input_dim)
            x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
            x = x + residual  # v1 sigmoid on x before residual
            if self.pw_bias is not None:
                x = x + self.pw_bias(residual)
            return x, loss 
Example #16
Source File: qv.py    From attn2d with MIT License
def forward(self, x, key):
        T, B, C = x.size()
        loss = torch.zeros(1).type_as(x).to(x.device)
        if key is not None:
            # outsider influence:
            scales = torch.softmax(self.pw_scales(key), dim=-1).unsqueeze(-1).unsqueeze(-1)  # 1, B, ne, 1, 1
            w = self.pw_w1 * scales  # 1, B, ne, C_out, C_in
            w = torch.sum(w, dim=2)
            x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
            if self.pw_bias is not None:
                x = x + self.pw_bias(key)
            return x, loss
        else:
            residual = x
            if self.training:
                if self.tau:
                    resp = F.gumbel_softmax(
                        self.assign(x.contiguous().view(T*B, C)),
                        tau=self.tau,
                        hard=self.hard
                    )  # T*B, ne
                else:
                    resp = torch.softmax(
                        self.assign(x.contiguous().view(T*B, C)),
                        dim=-1
                    )  # T*B, ne

            else:
                if self.tau:
                    resp = F.gumbel_softmax(
                        self.assign(x.contiguous().view(T*B, C)),
                        tau=self.tau,
                        hard=self.hard
                    )  # T*B, ne
                else:
                    resp = torch.softmax(
                        self.assign(x.contiguous().view(T*B, C)),
                        dim=-1
                    )  # T*B, ne

                # For the new exp with ne600 eval is soft as well
                # logits = self.assign(x.contiguous().view(T*B, C))
                # indices = torch.argmax(logits, dim=-1)
                # resp = torch.zeros(logits.size(0), self.ne).type_as(logits)
                # resp.scatter_(1, indices.unsqueeze(1), 1)
            importance = resp.sum(dim=0)
            loss = self.loss_scale * torch.std(importance) / torch.mean(importance)

            print('importance', importance.data.round())
            w = torch.matmul(resp, self.pw_w1)  # T*B, C_out * C_in
            w = w.view(T, B, self.output_dim, self.input_dim)
            x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
            if self.pw_bias is not None:
                x = x + self.pw_bias(residual)
            x = torch.sigmoid(x) + residual
            return x, loss 
Example #17
Source File: qv.py    From attn2d with MIT License
def forward(self, x, key):
        T, B, C = x.size()
        loss = torch.zeros(1).type_as(x).to(x.device)
        if key is not None:
            # outsider influence:
            scales = torch.softmax(self.pw_scales(key), dim=-1).unsqueeze(-1).unsqueeze(-1)  # 1, B, ne, 1, 1
            w = self.pw_w1 * scales  # 1, B, ne, C_out, C_in
            w = torch.sum(w, dim=2)
            x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
            if self.pw_bias is not None:
                x = x + self.pw_bias(key)
            return x, loss
        else:
            x0 = x
            key = x.contiguous().view(T*B, C)
            energies = self.assign(key)
            
            if self.training:
                noise = F.softplus(self.noise(key)) * torch.randn_like(energies)
                energies = keeptopk_masked(energies + noise, self.topk)
                if self.tau:
                    resp = F.gumbel_softmax(
                        energies,
                        tau=self.tau,
                        hard=self.hard
                    )  # T*B, ne
                else:
                    resp = F.softmax(
                        energies,
                        dim=-1
                        )  # T*B, ne

            else:
                energies = keeptopk_masked(energies, self.topk)
                if self.tau:
                    resp = F.gumbel_softmax(
                        energies,
                        tau=self.tau,
                        hard=self.hard
                    )  # T*B, ne
                else:
                    resp = F.softmax(
                        energies,
                        dim=-1
                        )  # T*B, ne

                # indices = torch.argmax(energies, dim=-1)
                # resp = torch.zeros(energies.size(0), self.ne).type_as(energies)
                # resp.scatter_(1, indices.unsqueeze(1), 1)
            importance = resp.sum(dim=0)
            loss = self.loss_scale * torch.std(importance) / torch.mean(importance)

            print('importance', importance.data.round())
            w = torch.matmul(resp, self.pw_w1)  # T*B, C_out * C_in
            w = w.view(T, B, self.output_dim, self.input_dim)
            x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
            if self.pw_bias is not None:
                x = x + self.pw_bias(x0)
            return x, loss 
Example #18
Source File: qv.py    From attn2d with MIT License
def forward(self, x, key):
        T, B, C = x.size()
        loss = torch.zeros(1).type_as(x).to(x.device)
        if key is not None:
            # outsider influence:
            scales = torch.softmax(self.pw_scales(key), dim=-1).unsqueeze(-1).unsqueeze(-1)  # 1, B, ne, 1, 1
            w = self.pw_w1 * scales  # 1, B, ne, C_out, C_in
            w = torch.sum(w, dim=2)
            x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
            if self.pw_bias is not None:
                x = x + self.pw_bias(key)
            return x, loss
        else:
            x0 = x
            if self.training:
                if self.tau:
                    resp = F.gumbel_softmax(
                        self.assign(x.contiguous().view(T*B, C)),
                        tau=self.tau,
                        hard=self.hard
                    )  # T*B, ne
                else:
                    resp = torch.softmax(
                        self.assign(x.contiguous().view(T*B, C)),
                        dim=-1
                    )  # T*B, ne

            else:
                if self.tau:
                    resp = F.gumbel_softmax(
                        self.assign(x.contiguous().view(T*B, C)),
                        tau=self.tau,
                        hard=self.hard
                    )  # T*B, ne
                else:
                    resp = torch.softmax(
                        self.assign(x.contiguous().view(T*B, C)),
                        dim=-1
                    )  # T*B, ne

                # For the new exp with ne600 eval is soft as well
                # logits = self.assign(x.contiguous().view(T*B, C))
                # indices = torch.argmax(logits, dim=-1)
                # resp = torch.zeros(logits.size(0), self.ne).type_as(logits)
                # resp.scatter_(1, indices.unsqueeze(1), 1)
            importance = resp.sum(dim=0) 
            loss = self.loss_scale * torch.std(importance) / torch.mean(importance)

            print('importance', importance.data.round())
            w = torch.matmul(resp, self.pw_w1)  # T*B, C_out * C_in
            w = w.view(T, B, self.output_dim, self.input_dim)
            x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
            if self.pw_bias is not None:
                x = x + self.pw_bias(x0)
            return x, loss 
Example #19
Source File: genotypes.py    From nni with MIT License
def parse_gumbel(alpha, beta, k):
    """
    parse continuous alpha to discrete gene.
    alpha is ParameterList:
    ParameterList [
        Parameter(n_edges1, n_ops),
        Parameter(n_edges2, n_ops),
        ...
    ]

    beta is ParameterList:
    ParameterList [
        Parameter(n_edges1),
        Parameter(n_edges2),
        ...
    ]

    gene is list:
    [
        [('node1_ops_1', node_idx), ..., ('node1_ops_k', node_idx)],
        [('node2_ops_1', node_idx), ..., ('node2_ops_k', node_idx)],
        ...
    ]
    each node has two edges (k=2) in CNN.
    """

    gene = []
    assert PRIMITIVES[-1] == 'none'  # assume last PRIMITIVE is 'none'

    # 1) Convert the mixed op to discrete edge (single op) by choosing top-1 weight edge
    # 2) Choose top-k edges per node by edge score (top-1 weight in edge)
    # output the connect idx[(node_idx, connect_idx, op_idx).... () ()]
    connect_idx = []
    for edges, w in zip(alpha, beta):
        # edges: Tensor(n_edges, n_ops)
        discrete_a = F.gumbel_softmax(edges[:, :-1].reshape(-1), tau=1, hard=True)
        for i in range(k-1):
            discrete_a = discrete_a + F.gumbel_softmax(edges[:, :-1].reshape(-1), tau=1, hard=True)
        discrete_a = discrete_a.reshape(-1, len(PRIMITIVES)-1)
        reserved_edge = (discrete_a > 0).nonzero()

        node_gene = []
        node_idx = []
        for i in range(reserved_edge.shape[0]):
            edge_idx = reserved_edge[i][0].item()
            prim_idx = reserved_edge[i][1].item()
            prim = PRIMITIVES[prim_idx]
            node_gene.append((prim, edge_idx))
            node_idx.append((edge_idx, prim_idx))

        gene.append(node_gene)
        connect_idx.append(node_idx)

    return gene, connect_idx
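A hedged usage sketch of parse_gumbel; PRIMITIVES and the alpha/beta shapes below are assumptions based on the docstring (k = 2 edges per node, 'none' as the last primitive), not the actual nni search space:

import torch
import torch.nn as nn
import torch.nn.functional as F

PRIMITIVES = ['max_pool_3x3', 'sep_conv_3x3', 'skip_connect', 'none']  # hypothetical op set; 'none' must be last

n_ops = len(PRIMITIVES)
alpha = nn.ParameterList([nn.Parameter(torch.randn(n_edges, n_ops)) for n_edges in (2, 3)])
beta = nn.ParameterList([nn.Parameter(torch.randn(n_edges)) for n_edges in (2, 3)])

gene, connect_idx = parse_gumbel(alpha, beta, k=2)
print(gene)          # e.g. [[('sep_conv_3x3', 0), ('skip_connect', 1)], ...]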