Python torch.unbind() Examples

The following are 30 code examples of torch.unbind(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
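As a quick orientation before the examples: torch.unbind(input, dim=0) returns a tuple of all slices of a tensor along the given dimension, with that dimension removed, and torch.stack along the same dimension inverts it. A minimal sketch:

import torch

x = torch.arange(6).reshape(2, 3)
rows = torch.unbind(x, dim=0)  # (tensor([0, 1, 2]), tensor([3, 4, 5]))
cols = torch.unbind(x, dim=1)  # (tensor([0, 3]), tensor([1, 4]), tensor([2, 5]))
assert torch.equal(torch.stack(rows, dim=0), x)  # stack inverts unbind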
Example #1
Source File: gradient.py    From captum with BSD 3-Clause "New" or "Revised" License    6 votes
def _neuron_gradients(
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    saved_layer: Dict[device, Tuple[Tensor, ...]],
    key_list: List[device],
    gradient_neuron_index: Union[int, Tuple[int, ...]],
) -> Tuple[Tensor, ...]:
    with torch.autograd.set_grad_enabled(True):
        gradient_tensors = []
        for key in key_list:
            assert (
                len(saved_layer[key]) == 1
            ), "Cannot compute neuron gradients for layer with multiple tensors."
            current_out_tensor = saved_layer[key][0]
            gradient_tensors.append(
                torch.autograd.grad(
                    torch.unbind(
                        _verify_select_column(current_out_tensor, gradient_neuron_index)
                    ),
                    inputs,
                )
            )
        _total_gradients = _reduce_list(gradient_tensors, sum)
    return _total_gradients 
Example #2
Source File: context_conditioned_policy.py    From garage with MIT License    6 votes
def compute_kl_div(self):
        r"""Compute :math:`KL(q(z|c) \| p(z))`.

        Returns:
            float: :math:`KL(q(z|c) \| p(z))`.

        """
        prior = torch.distributions.Normal(
            torch.zeros(self._latent_dim).to(global_device()),
            torch.ones(self._latent_dim).to(global_device()))
        posteriors = [
            torch.distributions.Normal(mu, torch.sqrt(var)) for mu, var in zip(
                torch.unbind(self.z_means), torch.unbind(self.z_vars))
        ]
        kl_divs = [
            torch.distributions.kl.kl_divergence(post, prior)
            for post in posteriors
        ]
        kl_div_sum = torch.sum(torch.stack(kl_divs))
        return kl_div_sum 
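The example above unbinds the rows of z_means and z_vars so that each task's posterior becomes its own Normal distribution. For a Gaussian posterior against a standard normal prior, kl_divergence reduces to the closed form 0.5 * (var + mu^2 - 1 - log(var)) per dimension; a small self-contained check with illustrative values:

import torch

mu = torch.tensor([0.5, -1.0])
var = torch.tensor([0.8, 1.5])
post = torch.distributions.Normal(mu, torch.sqrt(var))
prior = torch.distributions.Normal(torch.zeros(2), torch.ones(2))
kl = torch.distributions.kl.kl_divergence(post, prior)
# closed form of KL(N(mu, var) || N(0, 1)), computed per dimension
expected = 0.5 * (var + mu ** 2 - 1 - torch.log(var))
assert torch.allclose(kl, expected)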
Example #3
Source File: __init__.py    From yolo2-pytorch with GNU Lesser General Public License v3.0    6 votes
def fit_positive(rows, cols, yx_min, yx_max, anchors):
    device_id = anchors.get_device() if torch.cuda.is_available() else None
    batch_size, num, _ = yx_min.size()
    num_anchors, _ = anchors.size()
    valid = torch.prod(yx_min < yx_max, -1)
    center = (yx_min + yx_max) / 2
    ij = torch.floor(center)
    i, j = torch.unbind(ij.long(), -1)
    index = i * cols + j
    anchors2 = anchors / 2
    iou_matrix = utils.iou.torch.iou_matrix((yx_min - center).view(-1, 2), (yx_max - center).view(-1, 2), -anchors2, anchors2).view(batch_size, -1, num_anchors)
    iou, index_anchor = iou_matrix.max(-1)
    _positive = []
    cells = rows * cols
    for valid, index, index_anchor in zip(torch.unbind(valid), torch.unbind(index), torch.unbind(index_anchor)):
        index, index_anchor = (t[valid] for t in (index, index_anchor))
        t = utils.ensure_device(torch.ByteTensor(cells, num_anchors).zero_(), device_id)
        t[index, index_anchor] = 1
        _positive.append(t)
    return torch.stack(_positive) 
Example #4
Source File: vqa_net.py    From block.bootstrap.pytorch with BSD 3-Clause "New" or "Revised" License    6 votes
def forward(self, q, v):
        alpha = self.process_attention(q, v)

        if self.mlp_glimpses > 0:
            alpha = self.linear0(alpha)
            alpha = F.relu(alpha)
            alpha = self.linear1(alpha)

        alpha = F.softmax(alpha, dim=1)

        if alpha.size(2) > 1: # nb_glimpses > 1
            alphas = torch.unbind(alpha, dim=2)
            v_outs = []
            for alpha in alphas:
                alpha = alpha.unsqueeze(2).expand_as(v)
                v_out = alpha*v
                v_out = v_out.sum(1)
                v_outs.append(v_out)
            v_out = torch.cat(v_outs, dim=1)
        else:
            alpha = alpha.expand_as(v)
            v_out = alpha*v
            v_out = v_out.sum(1)
        return v_out 
Example #5
Source File: recurrent.py    From Tagger with BSD 3-Clause "New" or "Revised" License    6 votes
def forward(self, x, state):
        c, h = state

        gates = self.gates(torch.cat([x, h], 1))

        if self.layer_norm is not None:
            combined = self.layer_norm(
                torch.reshape(gates, [-1, 4, self.output_size]))
        else:
            combined = torch.reshape(gates, [-1, 4, self.output_size])

        i, j, f, o = torch.unbind(combined, 1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)

        new_c = f * c + i * torch.tanh(j)

        if self.activation is None:
            # Do not use tanh activation
            new_h = o * new_c
        else:
            new_h = o * self.activation(new_c)

        return new_h, (new_c, new_h) 
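The cell above computes all four LSTM gate pre-activations with a single fused linear layer, then splits them via reshape plus torch.unbind. The splitting pattern in isolation, with illustrative shapes:

import torch

batch, hidden = 4, 8
gates = torch.randn(batch, 4 * hidden)        # fused projection for i, j, f, o
combined = torch.reshape(gates, [-1, 4, hidden])
i, j, f, o = torch.unbind(combined, 1)        # each has shape (batch, hidden)
assert i.shape == (batch, hidden)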
Example #6
Source File: pose.py    From photometric-mesh-optim with MIT License    6 votes
def rotation_matrix_to_quaternion(R): # [B,3,3]
	row0,row1,row2 = torch.unbind(R,dim=-2)
	R00,R01,R02 = torch.unbind(row0,dim=-1)
	R10,R11,R12 = torch.unbind(row1,dim=-1)
	R20,R21,R22 = torch.unbind(row2,dim=-1)
	t = R[...,0,0]+R[...,1,1]+R[...,2,2]
	r = (1+t).sqrt()
	qa = 0.5*r
	qb = (R21-R12).sign()*0.5*(1+R00-R11-R22).sqrt()
	qc = (R02-R20).sign()*0.5*(1-R00+R11-R22).sqrt()
	qd = (R10-R01).sign()*0.5*(1-R00-R11+R22).sqrt()
	q = torch.stack([qa,qb,qc,qd],dim=-1)
	for i,qi in enumerate(q):
		if torch.isnan(qi).any():
			print(i)
			K = torch.stack([torch.stack([R00-R11-R22,R10+R01,R20+R02,R12-R21],dim=-1),
							 torch.stack([R10+R01,R11-R00-R22,R21+R12,R20-R02],dim=-1),
							 torch.stack([R20+R02,R21+R12,R22-R00-R11,R01-R10],dim=-1),
							 torch.stack([R12-R21,R20-R02,R01-R10,R00+R11+R22],dim=-1)],dim=-2)/3.0
			K = K[i]
			eigval,eigvec = K.eig(eigenvectors=True)
			idx = eigval[:,0].argmax()
			V = eigvec[:,idx]
			q[i] = torch.stack([V[3],V[0],V[1],V[2]])
	return q 
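A quick sanity check for the conversion above, assuming the function is in scope: the identity rotation should map to the unit quaternion, and because torch.sign(0) is 0 the vector part vanishes exactly:

import torch

R = torch.eye(3).unsqueeze(0)          # batch containing a single identity rotation
q = rotation_matrix_to_quaternion(R)   # tensor([[1., 0., 0., 0.]])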
Example #7
Source File: evaluator.py    From Vanilla_NER with Apache License 2.0    6 votes
def calc_f1_batch(self, decoded_data, target_data):
        """
        update statics for f1 score.

        Parameters
        ----------
        decoded_data: ``torch.LongTensor``, required.
            the decoded best label index pathes.
        target_data:  ``torch.LongTensor``, required.
            the golden label index pathes.
        """
        batch_decoded = torch.unbind(decoded_data, 1)

        for decoded, target in zip(batch_decoded, target_data):
            length = len(target)
            best_path = decoded[:length]

            correct_labels_i, total_labels_i, gold_count_i, guess_count_i, overlap_count_i = self.eval_instance(best_path.numpy(), target)
            self.correct_labels += correct_labels_i
            self.total_labels += total_labels_i
            self.gold_count += gold_count_i
            self.guess_count += guess_count_i
            self.overlap_count += overlap_count_i 
Example #8
Source File: evaluator.py    From Vanilla_NER with Apache License 2.0    6 votes
def calc_acc_batch(self, decoded_data, target_data):
        """
        update statics for accuracy score.

        Parameters
        ----------
        decoded_data: ``torch.LongTensor``, required.
            the decoded best label index pathes.
        target_data:  ``torch.LongTensor``, required.
            the golden label index pathes.
        """
        batch_decoded = torch.unbind(decoded_data, 1)

        for decoded, target in zip(batch_decoded, target_data):
            
            # remove padding
            length = len(target)
            best_path = decoded[:length].numpy()
            gold = np.asarray(target[:length])

            self.total_labels += length
            self.correct_labels += np.sum(np.equal(best_path, gold))
Example #9
Source File: utils.py    From RPNet-Pytorch with MIT License    6 votes
def batch_transform(batch, transform):
    """Applies a transform to a batch of samples.

    Keyword arguments:
    - batch (): a batch os samples
    - transform (callable): A function/transform to apply to ``batch``

    """

    # Convert the single channel label to RGB in tensor form
    # 1. torch.unbind removes the 0-dimension of "labels" and returns a tuple of
    # all slices along that dimension
    # 2. the transform is applied to each slice
    transf_slices = [transform(tensor) for tensor in torch.unbind(batch)]

    return torch.stack(transf_slices)
Example #10
Source File: fusion.py    From VQA_ReGAT with MIT License    6 votes
def forward(self, q, v):
        alpha = self.process_attention(q, v)

        if self.mlp_glimpses > 0:
            alpha = self.linear0(alpha)
            alpha = F.relu(alpha)
            alpha = self.linear1(alpha)

        alpha = F.softmax(alpha, dim=1)

        if alpha.size(2) > 1:  # nb_glimpses > 1
            alphas = torch.unbind(alpha, dim=2)
            v_outs = []
            for alpha in alphas:
                alpha = alpha.unsqueeze(2).expand_as(v)
                v_out = alpha*v
                v_out = v_out.sum(1)
                v_outs.append(v_out)
            v_out = torch.cat(v_outs, dim=1)
        else:
            alpha = alpha.expand_as(v)
            v_out = alpha*v
            v_out = v_out.sum(1)
        return v_out 
Example #11
Source File: warp.py    From inverse-compositional-STN with MIT License    6 votes
def transformImage(opt,image,pMtrx):
	refMtrx = torch.from_numpy(opt.refMtrx).cuda()
	refMtrx = refMtrx.repeat(opt.batchSize,1,1)
	transMtrx = refMtrx.matmul(pMtrx)
	# warp the canonical coordinates
	X,Y = np.meshgrid(np.linspace(-1,1,opt.W),np.linspace(-1,1,opt.H))
	X,Y = X.flatten(),Y.flatten()
	XYhom = np.stack([X,Y,np.ones_like(X)],axis=1).T
	XYhom = np.tile(XYhom,[opt.batchSize,1,1]).astype(np.float32)
	XYhom = torch.from_numpy(XYhom).cuda()
	XYwarpHom = transMtrx.matmul(XYhom)
	XwarpHom,YwarpHom,ZwarpHom = torch.unbind(XYwarpHom,dim=1)
	Xwarp = (XwarpHom/(ZwarpHom+1e-8)).reshape(opt.batchSize,opt.H,opt.W)
	Ywarp = (YwarpHom/(ZwarpHom+1e-8)).reshape(opt.batchSize,opt.H,opt.W)
	grid = torch.stack([Xwarp,Ywarp],dim=-1)
	# sampling with bilinear interpolation
	imageWarp = torch.nn.functional.grid_sample(image,grid,mode="bilinear")
	return imageWarp 
Example #12
Source File: util.py    From multee with Apache License 2.0    6 votes
def sentencewise_scores2paragraph_tokenwise_scores(sentences_scores, sentences_mask):
    """
    # Input:
    # sentences_mask: (batch_size X num_sentences X sent_seq_len)
    # sentences_scores: (batch_size X num_sentences)

    # Output:
    # paragraph_tokenwise_scores: (batch_size X max_para_seq_len)
    """
    paragraph_tokenwise_scores = []
    for instance_sentences_scores, instance_sentences_mask in zip(torch.unbind(sentences_scores, dim=0),
                                                                  torch.unbind(sentences_mask, dim=0)):
        instance_paragraph_tokenwise_scores = torch.masked_select(instance_sentences_scores.unsqueeze(-1),
                                                                    instance_sentences_mask.byte())
        paragraph_tokenwise_scores.append(instance_paragraph_tokenwise_scores)
    paragraph_tokenwise_scores = torch.nn.utils.rnn.pad_sequence(paragraph_tokenwise_scores, batch_first=True)
    return paragraph_tokenwise_scores 
Example #13
Source File: run_pplm_discrim_train.py    From PPLM with Apache License 2.0    6 votes
def get_cached_data_loader(dataset, batch_size, discriminator,
                           shuffle=False, device='cpu'):
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              collate_fn=collate_fn)

    xs = []
    ys = []
    for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
        with torch.no_grad():
            x = x.to(device)
            avg_rep = discriminator.avg_representation(x).cpu().detach()
            avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
            xs += avg_rep_list
            ys += y.cpu().numpy().tolist()

    data_loader = torch.utils.data.DataLoader(
        dataset=Dataset(xs, ys),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=cached_collate_fn)

    return data_loader 
Example #14
Source File: test_neuron_gradient_shap.py    From captum with BSD 3-Clause "New" or "Revised" License    6 votes
def _assert_attributions(
        self,
        model: Module,
        layer: Module,
        inputs: Tensor,
        baselines: Union[Tensor, Callable[..., Tensor]],
        neuron_ind: Union[int, tuple],
        n_samples: int = 5,
    ) -> None:
        ngs = NeuronGradientShap(model, layer)
        nig = NeuronIntegratedGradients(model, layer)
        attrs_gs = ngs.attribute(
            inputs, neuron_ind, baselines=baselines, n_samples=n_samples, stdevs=0.09
        )

        if callable(baselines):
            baselines = baselines(inputs)

        attrs_ig = []
        for baseline in torch.unbind(baselines):
            attrs_ig.append(
                nig.attribute(inputs, neuron_ind, baselines=baseline.unsqueeze(0))
            )
        combined_attrs_ig = torch.stack(attrs_ig, dim=0).mean(dim=0)
        assertTensorAlmostEqual(self, attrs_gs, combined_attrs_ig, 0.5) 
Example #15
Source File: run_pplm_discrim_train.py    From exbert with Apache License 2.0    6 votes
def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False, device="cpu"):
    data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate_fn)

    xs = []
    ys = []
    for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
        with torch.no_grad():
            x = x.to(device)
            avg_rep = discriminator.avg_representation(x).cpu().detach()
            avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
            xs += avg_rep_list
            ys += y.cpu().numpy().tolist()

    data_loader = torch.utils.data.DataLoader(
        dataset=Dataset(xs, ys), batch_size=batch_size, shuffle=shuffle, collate_fn=cached_collate_fn
    )

    return data_loader 
Example #16
Source File: utils.py    From PyTorch-ENet with MIT License    6 votes
def batch_transform(batch, transform):
    """Applies a transform to a batch of samples.

    Keyword arguments:
    - batch (): a batch os samples
    - transform (callable): A function/transform to apply to ``batch``

    """

    # Convert the single channel label to RGB in tensor form
    # 1. torch.unbind removes the 0-dimension of "labels" and returns a tuple of
    # all slices along that dimension
    # 2. the transform is applied to each slice
    transf_slices = [transform(tensor) for tensor in torch.unbind(batch)]

    return torch.stack(transf_slices) 
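A hypothetical usage sketch of batch_transform, with an illustrative per-sample normalization as the transform:

import torch

batch = torch.randn(4, 3, 8, 8)  # (N, C, H, W)
out = batch_transform(batch, lambda t: t / t.abs().max())
assert out.shape == batch.shape  # transform applied per sample, shape preserved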
Example #17
Source File: vqa_net.py    From block.bootstrap.pytorch with BSD 3-Clause "New" or "Revised" License    5 votes
def process_question(self, q, l):
        q_emb = self.txt_enc.embedding(q)
        q, _ = self.txt_enc.rnn(q_emb)

        if self.self_q_att:
            q_att = self.q_att_linear0(q)
            q_att = F.relu(q_att)
            q_att = self.q_att_linear1(q_att)
            q_att = mask_softmax(q_att, l)
            #self.q_att_coeffs = q_att
            if q_att.size(2) > 1:
                q_atts = torch.unbind(q_att, dim=2)
                q_outs = []
                for q_att in q_atts:
                    q_att = q_att.unsqueeze(2)
                    q_att = q_att.expand_as(q)
                    q_out = q_att*q
                    q_out = q_out.sum(1)
                    q_outs.append(q_out)
                q = torch.cat(q_outs, dim=1)
            else:
                q_att = q_att.expand_as(q)
                q = q_att * q
                q = q.sum(1)
        else:
            # l contains the number of words for each question
            # in case of multi-gpus it must be a Tensor
            # thus we convert it into a list during the forward pass
            l = list(l.data[:,0])
            q = self.txt_enc._select_last(q, l)

        return q 
Example #18
Source File: util.py    From multee with Apache License 2.0    5 votes
def paragraph2sentences_tensor(paragraphwise_tensor, sentence_lengths):
    """
    # Input:
    # paragraphwise_tensor: (batch_size, paragraph_max_seq_len, ...)
    # sentence_lengths: (batch_size, premises_count)

    # Output:
    # sentencewise_tensor: (batch_size, num_sentences, sentence_max_seq_len, ...)

    # rough example for one instance of the batch:
    # paragraphwise_tensor = torch.tensor(45, 10)
    # sentence_lengths = torch.tensor([10, 10, 10, 15])
    # cumulated_sentence_lengths = (10, 20, 30, 45)
    # shifted_cumulated_sentence_lengths = (0, 10, 20, 30)
    # range_indices = zip(shifted_cumulated_sentence_lengths, cumulated_sentence_lengths)
    # sentencewise_tensor = ([paragraphwise_tensor[start:end] for start, end in range_indices])
    # sentencewise_tensor = torch.nn.utils.rnn.pad_sequence(sentencewise_tensor, batch_first=True)
    # return sentencewise_tensor
    """
    sentencewise_tensors = []
    # max_sentence_length across all paragraphs (i.e., any batch instance)
    max_sentence_length = sentence_lengths.max()
    for instance_paragraphwise_tensor, instance_sentence_lengths in zip(torch.unbind(paragraphwise_tensor, dim=0),
                                                                        torch.unbind(sentence_lengths, dim=0)):

        instance_cumulated_sentence_lengths = instance_sentence_lengths.cumsum(dim=0)
        instance_shifted_cumulated_sentence_lengths = instance_cumulated_sentence_lengths - instance_sentence_lengths
        range_indices = zip(instance_shifted_cumulated_sentence_lengths, instance_cumulated_sentence_lengths)

        sentencewise_tensor = [instance_paragraphwise_tensor[start.int():end.int()] for start, end in range_indices]
        sentencewise_tensor = torch.nn.utils.rnn.pad_sequence(sentencewise_tensor, batch_first=True)

        # sentencewise_tensor: (num_sentences, sentence_max_seq_len, ...)
        # adjust first dim by max sentence length across all the batch instances.
        padding = max_sentence_length - sentencewise_tensor.shape[1]
        padding_tuple = ([0]*(len(sentencewise_tensor.shape)-2)*2) + [0, padding.int(), 0, 0]
        sentencewise_tensor = F.pad(sentencewise_tensor, pad=padding_tuple)
        sentencewise_tensors.append(sentencewise_tensor)

    sentencewise_tensors = torch.nn.utils.rnn.pad_sequence(sentencewise_tensors, batch_first=True)
    return sentencewise_tensors 
Example #19
Source File: sgd_decoder_nm.py    From NeMo with Apache License 2.0    5 votes
def _get_noncategorical_slot_goals(self, encoded_utterance, utterance_mask, noncat_slot_emb, token_embeddings):
        """
        Obtain logits for status and slot spans for non-categorical slots.
        Slot status values: none, dontcare, active
        """
        # Predict the status of all non-categorical slots.
        max_num_slots = noncat_slot_emb.size()[1]
        status_logits = self.noncat_slot_layer(
            encoded_utterance=encoded_utterance,
            token_embeddings=token_embeddings,
            element_embeddings=noncat_slot_emb,
            utterance_mask=utterance_mask,
        )

        # Predict the distribution for span indices.
        max_num_tokens = token_embeddings.size()[1]

        repeated_token_embeddings = token_embeddings.unsqueeze(1).repeat(1, max_num_slots, 1, 1)
        repeated_slot_embeddings = noncat_slot_emb.unsqueeze(2).repeat(1, 1, max_num_tokens, 1)

        # Shape: (batch_size, max_num_slots, max_num_tokens, 2 * embedding_dim).
        slot_token_embeddings = torch.cat([repeated_slot_embeddings, repeated_token_embeddings], axis=3)

        # Project the combined embeddings to obtain logits, Shape: (batch_size, max_num_slots, max_num_tokens, 2)
        span_logits = self.noncat_layer1(slot_token_embeddings)
        span_logits = self.noncat_activation(span_logits)
        span_logits = self.noncat_layer2(span_logits)

        # Mask out invalid logits for padded tokens.
        utterance_mask = utterance_mask.to(bool)  # Shape: (batch_size, max_num_tokens).
        repeated_utterance_mask = utterance_mask.unsqueeze(1).unsqueeze(3).repeat(1, max_num_slots, 1, 2)
        negative_logits = (torch.finfo(span_logits.dtype).max * -0.7) * torch.ones(
            span_logits.size(), device=self._device, dtype=span_logits.dtype
        )

        span_logits = torch.where(repeated_utterance_mask, span_logits, negative_logits)

        # Shape of both tensors: (batch_size, max_num_slots, max_num_tokens).
        span_start_logits, span_end_logits = torch.unbind(span_logits, dim=3)
        return status_logits, span_start_logits, span_end_logits 
Example #20
Source File: ParsingNetwork.py    From PRPN with MIT License    5 votes
def forward(self, emb, parser_state):
        emb_last, cum_gate = parser_state
        ntimestep = emb.size(0)

        emb_last = torch.cat([emb_last, emb], dim=0)
        emb = emb_last.transpose(0, 1).transpose(1, 2)  # bsz, ninp, ntimestep + nlookback

        gates = self.gate(emb)  # bsz, 2, ntimestep
        gate = gates[:, 0, :]
        gate_next = gates[:, 1, :]
        cum_gate = torch.cat([cum_gate, gate], dim=1)
        gate_hat = torch.stack([cum_gate[:, i:i + ntimestep] for i in range(self.nslots, 0, -1)],
                               dim=2)  # bsz, ntimestep, nslots

        if self.hard:
            memory_gate = (F.hardtanh((gate[:, :, None] - gate_hat) / self.resolution * 2 + 1) + 1) / 2
        else:
            memory_gate = F.sigmoid(
                (gate[:, :, None] - gate_hat) / self.resolution * 10 + 5)  # bsz, ntimestep, nslots
        memory_gate = torch.cumprod(memory_gate, dim=2)  # bsz, ntimestep, nlookback+1
        memory_gate = torch.unbind(memory_gate, dim=1)

        if self.hard:
            memory_gate_next = (F.hardtanh((gate_next[:, :, None] - gate_hat) / self.resolution * 2 + 1) + 1) / 2
        else:
            memory_gate_next = F.sigmoid(
                (gate_next[:, :, None] - gate_hat) / self.resolution * 10 + 5)  # bsz, ntimestep, nslots
        memory_gate_next = torch.cumprod(memory_gate_next, dim=2)  # bsz, ntimestep, nlookback+1
        memory_gate_next = torch.unbind(memory_gate_next, dim=1)

        return (memory_gate, memory_gate_next), gate, (emb_last[-self.nlookback:], cum_gate[:, -self.nslots:]) 
Example #21
Source File: ParsingNetwork.py    From PRPN-Analysis with MIT License    5 votes
def forward(self, emb, parser_state):
        emb_last, cum_gate = parser_state
        ntimestep = emb.size(0)

        emb_last = torch.cat([emb_last, emb], dim=0)
        emb = emb_last.transpose(0, 1).transpose(1, 2)  # bsz, ninp, ntimestep + nlookback

        gates = self.gate(emb)  # bsz, 2, ntimestep
        gate = gates[:, 0, :]
        gate_next = gates[:, 1, :]
        cum_gate = torch.cat([cum_gate, gate], dim=1)
        gate_hat = torch.stack([cum_gate[:, i:i + ntimestep] for i in range(self.nslots, 0, -1)],
                               dim=2)  # bsz, ntimestep, nslots

        if self.hard:
            memory_gate = (F.hardtanh((gate[:, :, None] - gate_hat) / self.resolution * 2 + 1) + 1) / 2
        else:
            memory_gate = F.sigmoid(
                (gate[:, :, None] - gate_hat) / self.resolution * 10 + 5)  # bsz, ntimestep, nslots
        memory_gate = torch.cumprod(memory_gate, dim=2)  # bsz, ntimestep, nlookback+1
        memory_gate = torch.unbind(memory_gate, dim=1)

        if self.hard:
            memory_gate_next = (F.hardtanh((gate_next[:, :, None] - gate_hat) / self.resolution * 2 + 1) + 1) / 2
        else:
            memory_gate_next = F.sigmoid(
                (gate_next[:, :, None] - gate_hat) / self.resolution * 10 + 5)  # bsz, ntimestep, nslots
        memory_gate_next = torch.cumprod(memory_gate_next, dim=2)  # bsz, ntimestep, nlookback+1
        memory_gate_next = torch.unbind(memory_gate_next, dim=1)

        return (memory_gate, memory_gate_next), gate, (emb_last[-self.nlookback:], cum_gate[:, -self.nslots:]) 
Example #22
Source File: recurrent.py    From Tagger with BSD 3-Clause "New" or "Revised" License    5 votes
def forward(self, x, state):
        c, h = state

        gates = self.gates(torch.cat([x, h], 1))
        combined = torch.reshape(gates, [-1, 5, self.output_size])
        i, j, f, o, t = torch.unbind(combined, 1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        t = torch.sigmoid(t)

        new_c = f * c + i * torch.tanh(j)
        tmp_h = o * torch.tanh(new_c)
        new_h = t * tmp_h + (1.0 - t) * self.trans(x)

        return new_h, (new_c, new_h) 
Example #23
Source File: util.py    From multee with Apache License 2.0    5 votes
def sentences2paragraph_tensor(sentencewise_tensor, sentences_mask):
    """
    # Input:
    # sentencewise_tensor: (batch_size, num_sentences, sentence_max_seq_len, ...)
    # sentences_mask: (batch_size, num_sentences, sentence_max_seq_len)

    # Output:
    # paragraphwise_tensor: (batch_size, paragraph_max_seq_len, ...)
    """
    num_sentences = sentencewise_tensor.shape[1]
    trailing_shape = list(sentencewise_tensor.shape[3:])

    sentences_mask = sentences_mask.byte()
    # keep unsqueezing instance_sentences_mask at -1 to make it same shape as sentencewise_tensor and then .byte()
    while len(sentences_mask.shape) < len(sentencewise_tensor.shape):
        sentences_mask = sentences_mask.unsqueeze(-1)

    paragraphwise_tensor = []
    for instance_sentencewise_tensor, instance_sentences_mask in zip(torch.unbind(sentencewise_tensor, dim=0),
                                                                     torch.unbind(sentences_mask, dim=0)):
        instance_paragraphwise_tensor = instance_sentencewise_tensor.masked_select(instance_sentences_mask)
        instance_paragraphwise_tensor = instance_paragraphwise_tensor.reshape([-1]+trailing_shape)
        paragraphwise_tensor.append(instance_paragraphwise_tensor)

    paragraphwise_tensor = torch.nn.utils.rnn.pad_sequence(paragraphwise_tensor, batch_first=True)
    return paragraphwise_tensor 
Example #24
Source File: box_regression.py    From detectron2 with Apache License 2.0    5 votes
def get_deltas(self, src_boxes, target_boxes):
        """
        Get box regression transformation deltas (dx, dy, dw, dh, da) that can be used
        to transform the `src_boxes` into the `target_boxes`. That is, the relation
        ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless
        any delta is too large and is clamped).

        Args:
            src_boxes (Tensor): Nx5 source boxes, e.g., object proposals
            target_boxes (Tensor): Nx5 target of the transformation, e.g., ground-truth
                boxes.
        """
        assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
        assert isinstance(target_boxes, torch.Tensor), type(target_boxes)

        src_ctr_x, src_ctr_y, src_widths, src_heights, src_angles = torch.unbind(src_boxes, dim=1)

        target_ctr_x, target_ctr_y, target_widths, target_heights, target_angles = torch.unbind(
            target_boxes, dim=1
        )

        wx, wy, ww, wh, wa = self.weights
        dx = wx * (target_ctr_x - src_ctr_x) / src_widths
        dy = wy * (target_ctr_y - src_ctr_y) / src_heights
        dw = ww * torch.log(target_widths / src_widths)
        dh = wh * torch.log(target_heights / src_heights)
        # Angles of deltas are in radians while angles of boxes are in degrees.
        # The conversion to radians serves as a way to normalize the values.
        da = target_angles - src_angles
        da = (da + 180.0) % 360.0 - 180.0  # make it in [-180, 180)
        da *= wa * math.pi / 180.0

        deltas = torch.stack((dx, dy, dw, dh, da), dim=1)
        assert (
            (src_widths > 0).all().item()
        ), "Input boxes to Box2BoxTransformRotated are not valid!"
        return deltas 
Example #25
Source File: gradient.py    From captum with BSD 3-Clause "New" or "Revised" License    5 votes
def compute_gradients(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
) -> Tuple[Tensor, ...]:
    r"""
        Computes gradients of the output with respect to inputs for an
        arbitrary forward function.

        Args:

            forward_fn: forward function. This can be for example model's
                        forward function.
            inputs:     Input at which gradients are evaluated,
                        will be passed to forward_fn.
            target_ind: Index of the target class for which gradients
                        must be computed (classification only).
            additional_forward_args: Additional input arguments that forward
                        function requires. It takes an empty tuple (no additional
                        arguments) if no additional arguments are required
    """
    with torch.autograd.set_grad_enabled(True):
        # runs forward pass
        outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
        assert outputs[0].numel() == 1, (
            "Target not provided when necessary, cannot"
            " take gradient with respect to multiple outputs."
        )
        # torch.unbind(outputs) is a tuple of scalar tensors and
        # contains batch_size * #steps elements
        grads = torch.autograd.grad(torch.unbind(outputs), inputs)
    return grads 
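Passing torch.unbind(outputs) to torch.autograd.grad works because grad accepts a sequence of single-element tensors and accumulates their gradients into the inputs; a minimal sketch:

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
outputs = x ** 2                                   # one scalar output per sample
grads = torch.autograd.grad(torch.unbind(outputs), x)
# accumulated over the unbound scalars: d/dx sum(x_i^2) = 2x
assert torch.allclose(grads[0], 2 * x.detach())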
Example #26
Source File: modules.py    From kur with Apache License 2.0    5 votes
def parallel(layer):
	""" Creates a parallel operation (i.e., map/distributed operation).
	"""
	def func(module, x):
		""" The actual wrapped operation.
		"""
		return torch.stack(
			tuple(Layer.resolve(layer)(module, X) for X in torch.unbind(x, 0)),
			0
		)
	func.pure = True
	return func

############################################################################### 
Example #27
Source File: policy.py    From midlevel-reps with MIT License    5 votes
def apply(func, tensor):
    tList = [func(m) for m in torch.unbind(tensor, dim=0) ]
    res = torch.stack(tList, dim=0)
    return res 
Example #28
Source File: utils_dionysus.py    From TopologyLayer with MIT License    5 votes
def top_batch_cost(gen_imgs, diagramlayer, filtration):
    start_time = time.time()
    axis=0
    costs = torch.stack([
        top_cost(x_i.view(-1), diagramlayer, filtration) for i, x_i in enumerate(torch.unbind(gen_imgs, dim=axis), 0)
    ], dim=axis)
    avg = torch.mean(costs.view(-1))
    print("top_batch_cost", "time: ", time.time() - start_time, "cost: ", avg)
    return avg
    ''' *** End Topology *** ''' 
Example #29
Source File: utils_dionysus.py    From TopologyLayer with MIT License    5 votes
def top_batch_features(input, diagramlayer, filtration, dim=1):
    #print(gen_imgs.shape)
    start_time = time.time()
    axis=0
    #print("input",input)
    feats = torch.stack([
        top_features(x_i.view(-1), diagramlayer, filtration, dim) for i, x_i in enumerate(torch.unbind(input, dim=axis), 0)
    ], dim=axis)
    #avg = torch.mean(costs.view(-1))
    print("feats", "time: ", time.time() - start_time)
    #print feats.shape
    return feats 
Example #30
Source File: util.py    From multee with Apache License 2.0    5 votes
def unbind_tensor_dict(dict_tensors, dim):
    """
    Unbinds each tensor dict as returned by text_field.as_tensor in forward method
    on a certain dimension and returns a list of such tensor dicts
    """
    intermediate_dict = {}
    for key, tensor in dict_tensors.items():
        intermediate_dict[key] = torch.unbind(tensor, dim=dim)
        items_count = len(intermediate_dict[key])
    dict_tensor_list = [{} for _ in range(items_count)]
    for key, tensors in intermediate_dict.items():
        for index, tensor in enumerate(tensors):
            dict_tensor_list[index][key] = tensor
    return dict_tensor_list
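A hypothetical usage sketch of unbind_tensor_dict with a two-instance batch; the key names are illustrative:

import torch

batch = {"tokens": torch.randint(0, 100, (2, 5)), "mask": torch.ones(2, 5)}
instances = unbind_tensor_dict(batch, dim=0)
# a list of 2 dicts, each holding "tokens" and "mask" of shape (5,)
assert len(instances) == 2 and instances[0]["tokens"].shape == (5,)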