Python torch.repeat_interleave() Examples

The following are 30 code examples of torch.repeat_interleave(). Each example is annotated with its original project, source file, and license. You may also want to check out all available functions and classes of the torch module.
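For reference, the function has three calling patterns: scalar repeats on a flattened input, per-element repeats along a dimension, and a single-tensor form where index i is emitted repeats[i] times.

import torch

x = torch.tensor([1, 2, 3])
torch.repeat_interleave(x, 2)                             # tensor([1, 1, 2, 2, 3, 3])

y = torch.tensor([[1, 2], [3, 4]])
torch.repeat_interleave(y, 3, dim=0)                      # each row repeated 3 times
torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)   # row 0 once, row 1 twice

torch.repeat_interleave(torch.tensor([2, 3]))             # tensor([0, 0, 1, 1, 1])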
Example #1
Source File: cuda_tensor.py    From CrypTen with MIT License
def __patched_conv_ops(op, x, y, *args, **kwargs):
        x_encoded = CUDALongTensor.__encode_as_fp64(x).data
        y_encoded = CUDALongTensor.__encode_as_fp64(y).data

        repeat_idx = [1] * (x_encoded.dim() - 1)
        x_enc_span = x_encoded.repeat(3, *repeat_idx)
        y_enc_span = torch.repeat_interleave(y_encoded, repeats=3, dim=0)

        bs, c, *img = x.size()
        c_out, c_in, *ks = y.size()

        x_enc_span = x_enc_span.transpose_(0, 1).reshape(bs, 9 * c, *img)
        y_enc_span = y_enc_span.reshape(9 * c_out, c_in, *ks)

        c_z = c_out if op in ["conv1d", "conv2d"] else c_in

        z_encoded = getattr(torch, op)(
            x_enc_span, y_enc_span, *args, **kwargs, groups=9
        )
        z_encoded = z_encoded.reshape(bs, 9, c_z, *z_encoded.size()[2:]).transpose_(
            0, 1
        )

        return CUDALongTensor.__decode_as_int64(z_encoded) 
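The interplay of repeat (block-wise copies of x) and repeat_interleave (element-wise copies of y) is what enumerates all 3 × 3 combinations of the encoded blocks for the grouped convolution. In miniature:

a = torch.tensor([0, 1, 2])
a.repeat(3)                     # tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
torch.repeat_interleave(a, 3)   # tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
# zipped element-wise, the two orderings yield every (i, j) pair exactly once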
Example #2
Source File: fastspeech.py    From NeMo with Apache License 2.0
def get_output(encoder_output, duration_predictor_output, alpha, mel_max_length=None):
        output = list()
        dec_pos = list()

        for i in range(encoder_output.size(0)):
            repeats = duration_predictor_output[i].float() * alpha
            repeats = torch.round(repeats).long()
            output.append(torch.repeat_interleave(encoder_output[i], repeats, dim=0))
            dec_pos.append(torch.from_numpy(np.indices((output[i].shape[0],))[0] + 1))

        output = torch.nn.utils.rnn.pad_sequence(output, batch_first=True)
        dec_pos = torch.nn.utils.rnn.pad_sequence(dec_pos, batch_first=True)

        dec_pos = dec_pos.to(output.device, non_blocking=True)

        if mel_max_length:
            output = output[:, :mel_max_length]
            dec_pos = dec_pos[:, :mel_max_length]

        return output, dec_pos 
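The core of this length regulator is the per-frame expansion. A minimal standalone version of the same idea, with hypothetical sizes:

encoder_frames = torch.randn(4, 8)        # (T, hidden)
durations = torch.tensor([2, 1, 0, 3])    # predicted repeats per frame
expanded = torch.repeat_interleave(encoder_frames, durations, dim=0)
expanded.shape                            # torch.Size([6, 8])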
Example #3
Source File: types.py    From ReAgent with BSD 3-Clause "New" or "Revised" License
def get_tiled_batch(self, num_tiles: int):
        """
        tiled_feature should be (batch_size * num_tiles, feature_dim)
        forall i in [batch_size],
        tiled_feature[i*num_tiles:(i+1)*num_tiles] should be feat[i]
        """
        assert (
            self.has_float_features_only
        ), f"only works for float features now: {self}"
        feat = self.float_features
        assert (
            len(feat.shape) == 2
        ), f"Need feat shape to be (batch_size, feature_dim), got {feat.shape}."
        batch_size, _ = feat.shape
        # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`.
        tiled_feat = feat.repeat_interleave(repeats=num_tiles, dim=0)
        return FeatureData(float_features=tiled_feat) 
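What the tiling does concretely, FeatureData aside:

feat = torch.tensor([[1., 2.], [3., 4.]])   # (batch_size=2, feature_dim=2)
feat.repeat_interleave(repeats=3, dim=0)
# tensor([[1., 2.], [1., 2.], [1., 2.], [3., 4.], [3., 4.], [3., 4.]])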
Example #4
Source File: scatter.py    From torchsupport with MIT License
def pairwise_no_pad(op, data, indices):
  unique, counts = indices.unique(return_counts=True)
  expansion = torch.cumsum(counts, dim=0)
  expansion = torch.repeat_interleave(expansion, counts)
  offset = torch.arange(0, counts.sum(), device=data.device)
  expansion = expansion - offset - 1
  expanded = torch.repeat_interleave(data, expansion.to(data.device), dim=0)

  expansion_offset = counts.roll(1)
  expansion_offset[0] = 0
  expansion_offset = expansion_offset.cumsum(dim=0)
  expansion_offset = torch.repeat_interleave(expansion_offset, counts)
  expansion_offset = torch.repeat_interleave(expansion_offset, expansion)
  off_start = torch.repeat_interleave(torch.repeat_interleave(counts, counts) - expansion, expansion)
  access = torch.arange(expansion.sum(), device=data.device)
  access = access - torch.repeat_interleave(expansion.roll(1).cumsum(dim=0), expansion) + off_start + expansion_offset

  result = op(expanded, data[access.to(data.device)])
  return result, torch.repeat_interleave(indices, expansion, dim=0) 
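A quick trace of what pairwise_no_pad computes, assuming the function above is in scope: for a single group of three values it emits one result per unordered pair within the group.

data = torch.tensor([0., 1., 3.])
indices = torch.zeros(3, dtype=torch.long)
result, out_indices = pairwise_no_pad(lambda a, b: a - b, data, indices)
# result: tensor([-1., -3., -2.]) for the pairs (0, 1), (0, 3), (1, 3)
# out_indices: tensor([0, 0, 0])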
Example #5
Source File: gradient.py    From torchsupport with MIT License
def hard_k_hot(logits, k, temperature=0.1):
  r"""Returns a hard k-hot sample given a categorical
  distribution defined by a tensor of unnormalized
  log-likelihoods.

  This is useful for example to sample a set of pixels in an
  image to move from a grid-structured data representation to a
  set- or graph-structured representation within a network.

  Args:
    logits (torch.Tensor): unnormalized log-likelihood tensor.
    k (int): number of items to sample without replacement.
    temperature (float): temperature of the soft distribution.

  Returns:
    Hard k-hot vector from the relaxed k-hot distribution
    defined by logits and temperature.
  """
  soft = soft_k_hot(logits, k, temperature=temperature)
  hard = torch.zeros_like(soft)
  _, top_k = torch.topk(logits, k)
  index = torch.repeat_interleave(torch.arange(0, hard.size(0)), k)
  hard[index, top_k.view(-1)] = 1.0
  return replace_gradient(hard, soft) 
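The paired advanced indexing above is a standard trick for scattering k ones per row; in isolation:

logits = torch.randn(4, 10)
k = 3
hard = torch.zeros_like(logits)
_, top_k = torch.topk(logits, k)                     # (4, 3) column indices
rows = torch.repeat_interleave(torch.arange(4), k)   # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
hard[rows, top_k.view(-1)] = 1.0                     # exactly k ones per row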
Example #6
Source File: mslr_slate.py    From ReAgent with Apache License 2.0
def sample_weights(self) -> Tensor:
        if self._sample_weights is None:
            samples = self.queries[:, 2]
            self._sample_weights = torch.repeat_interleave(
                samples.to(dtype=torch.float).reciprocal(), samples.to(dtype=torch.long)
            )
        return self._sample_weights 
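Concretely, this yields one weight per item, equal to the reciprocal of its query's size; with hypothetical counts:

samples = torch.tensor([2., 4.])   # items per query
torch.repeat_interleave(samples.reciprocal(), samples.to(torch.long))
# tensor([0.5000, 0.5000, 0.2500, 0.2500, 0.2500, 0.2500])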
Example #7
Source File: cuda_tensor.py    From CrypTen with MIT License
def matmul(x, y, *args, **kwargs):
        # Prepend 1 to the dimension of x or y if it is 1-dimensional
        remove_x, remove_y = False, False
        if x.dim() == 1:
            x = x.view(1, x.shape[0])
            remove_x = True
        if y.dim() == 1:
            y = y.view(y.shape[0], 1)
            remove_y = True

        x_encoded = CUDALongTensor.__encode_as_fp64(x).data
        y_encoded = CUDALongTensor.__encode_as_fp64(y).data

        # Span x and y for cross multiplication
        repeat_idx = [1] * (x_encoded.dim() - 1)
        x_enc_span = x_encoded.repeat(3, *repeat_idx)
        y_enc_span = torch.repeat_interleave(y_encoded, repeats=3, dim=0)

        # Broadcasting
        for _ in range(abs(x_enc_span.ndim - y_enc_span.ndim)):
            if x_enc_span.ndim > y_enc_span.ndim:
                y_enc_span.unsqueeze_(1)
            else:
                x_enc_span.unsqueeze_(1)

        z_encoded = torch.matmul(x_enc_span, y_enc_span, *args, **kwargs)

        if remove_x:
            z_encoded.squeeze_(-2)
        if remove_y:
            z_encoded.squeeze_(-1)

        return CUDALongTensor.__decode_as_int64(z_encoded) 
Example #8
Source File: auxiliary_stillimages.py    From ZeroShotVideoClassification with Apache License 2.0
def extract_video(self, img):
        buffer = self.transform(img)
        buffer = torch.repeat_interleave(torch.unsqueeze(buffer, 1), self.clip_len, 1)
        buffer = torch.repeat_interleave(torch.unsqueeze(buffer, 0), self.n_clips, 0)
        return buffer 
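Shape-wise, the two calls turn a single transformed frame into a clip batch; with hypothetical sizes clip_len=8 and n_clips=4:

frame = torch.randn(3, 112, 112)                               # (C, H, W)
clip = torch.repeat_interleave(frame.unsqueeze(1), 8, dim=1)   # (3, 8, 112, 112)
video = torch.repeat_interleave(clip.unsqueeze(0), 4, dim=0)   # (4, 3, 8, 112, 112)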
Example #9
Source File: adhoc_metric.py    From pt-ranking.github.io with MIT License
def torch_batch_ideal_err(batch_sorted_labels, k=10, gpu=False, point=True):
	assert batch_sorted_labels.size(1) > k

	batch_max, _ = torch.max(batch_sorted_labels, dim=1, keepdim=True)  # (batch, 1) max grade per ranking

	batch_labels = batch_sorted_labels[:, 0:k]
	batch_satis_pros = (torch.pow(2.0, batch_labels) - 1.0) / torch.pow(2.0, batch_max)

	batch_unsatis_pros = torch.ones_like(batch_labels) - batch_satis_pros
	batch_cum_unsatis_pros = torch.cumprod(batch_unsatis_pros, dim=1)

	positions = torch.arange(k) + 1.0
	positions = positions.view(1, -1)
	positions = torch.repeat_interleave(positions, batch_sorted_labels.size(0), dim=0)

	batch_expt_ranks = 1.0 / positions

	cascad_unsatis_pros = torch.ones_like(positions)
	cascad_unsatis_pros[:, 1:k] = batch_cum_unsatis_pros[:, 0:k-1]

	expt_satis_ranks = batch_expt_ranks * batch_satis_pros * cascad_unsatis_pros  # w.r.t. all rank positions

	if point:
		batch_errs = torch.sum(expt_satis_ranks, dim=1)
		return batch_errs
	else:
		batch_err_at_ks = torch.cumsum(expt_satis_ranks, dim=1)
		return batch_err_at_ks 
Example #10
Source File: datasets.py    From PySNN with MIT License
def __init__(
        self, data_encoder=None, data_transform=None, lbl_transform=None, repeats=1
    ):
        self.data_encoder = data_encoder
        self.data_transform = data_transform
        self.lbl_transform = lbl_transform
        self.data = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.float)
        self.data = torch.repeat_interleave(self.data, int(repeats), dim=1) 
Example #11
Source File: aev.py    From torchani with MIT License
def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Input: indices for pairs of atoms that are close to each other.
    each pair only appear once, i.e. only one of the pairs (1, 2) and
    (2, 1) exists.

    Output: indices for all central atoms and it pairs of neighbors. For
    example, if input has pair (0, 1), (0, 2), (0, 3), (0, 4), (1, 2),
    (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), then the output would have
    central atom 0, 1, 2, 3, 4 and for cental atom 0, its pairs of neighbors
    are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
    """
    # convert representation from pair to central-others
    ai1 = atom_index12.view(-1)
    sorted_ai1, rev_indices = ai1.sort()

    # sort and compute unique key
    uniqued_central_atom_index, counts = torch.unique_consecutive(sorted_ai1, return_inverse=False, return_counts=True)

    # compute central_atom_index
    pair_sizes = counts * (counts - 1) // 2
    pair_indices = torch.repeat_interleave(pair_sizes)
    central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices)

    # do local combinations within unique key, assuming sorted
    m = counts.max().item() if counts.numel() > 0 else 0
    n = pair_sizes.shape[0]
    intra_pair_indices = torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1)
    mask = (torch.arange(intra_pair_indices.shape[2], device=ai1.device) < pair_sizes.unsqueeze(1)).flatten()
    sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask]
    sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices)

    # unsort result from last part
    local_index12 = rev_indices[sorted_local_index12]

    # compute mapping between representation of central-other to pair
    n = atom_index12.shape[1]
    sign12 = ((local_index12 < n).to(torch.int8) * 2) - 1
    return central_atom_index, local_index12 % n, sign12 
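torch.repeat_interleave(pair_sizes) above uses the single-tensor form: the result contains index i repeated pair_sizes[i] times, which maps each generated pair back to its central atom.

pair_sizes = torch.tensor([3, 1, 0, 2])
torch.repeat_interleave(pair_sizes)   # tensor([0, 0, 0, 1, 3, 3])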
Example #12
Source File: test_graph_size_norm.py    From pytorch_geometric with MIT License
def test_graph_size_norm():
    batch = torch.repeat_interleave(torch.full((10, ), 10, dtype=torch.long))
    norm = GraphSizeNorm()
    out = norm(torch.randn(100, 16), batch)
    assert out.size() == (100, 16) 
Example #13
Source File: types.py    From ReAgent with BSD 3-Clause "New" or "Revised" License
def select_slate(self, action: torch.Tensor):
        row_idx = torch.repeat_interleave(
            torch.arange(action.shape[0]).unsqueeze(1), action.shape[1], dim=1
        )
        mask = self.mask[row_idx, action]
        # Make sure the indices are in the right range
        assert mask.to(torch.bool).all()
        float_features = self.float_features[row_idx, action]
        value = self.value[row_idx, action]
        return DocList(float_features, mask, value) 
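The row_idx construction pairs each slate position with its batch row; for a batch of two slates of length three:

action = torch.tensor([[2, 0, 1], [1, 3, 0]])
row_idx = torch.repeat_interleave(
    torch.arange(action.shape[0]).unsqueeze(1), action.shape[1], dim=1
)
# row_idx: tensor([[0, 0, 0], [1, 1, 1]]);
# float_features[row_idx, action] then gathers one document per (row, slot)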
Example #14
Source File: talknet_modules.py    From NeMo with Apache License 2.0
def _generate_text_rep(text, dur):
        text_rep = []
        for t, d in zip(text, dur):
            text_rep.append(torch.repeat_interleave(t, d))

        text_rep = Ops.merge(text_rep)

        return text_rep 
Example #15
Source File: architecture.py    From deep_gcns_torch with MIT License
def forward(self, inputs):
        feats = [self.head(inputs, self.knn(inputs[:, 0:3]))]
        for i in range(self.n_blocks-1):
            feats.append(self.backbone[i](feats[-1]))
        feats = torch.cat(feats, 1)
        fusion = self.fusion_block(feats)

        x1 = F.adaptive_max_pool2d(fusion, 1)
        x2 = F.adaptive_avg_pool2d(fusion, 1)
        feat_global_pool = torch.cat((x1, x2), dim=1)
        feat_global_pool = torch.repeat_interleave(feat_global_pool, repeats=fusion.shape[2], dim=2)
        cat_pooled = torch.cat((feat_global_pool, fusion), dim=1)
        out = self.prediction(cat_pooled).squeeze(-1)
        return F.log_softmax(out, dim=1) 
Example #16
Source File: architecture.py    From deep_gcns_torch with MIT License
def forward(self, inputs):
        feats = [self.head(inputs, self.knn(inputs[:, 0:3]))]
        for i in range(self.n_blocks-1):
            feats.append(self.backbone[i](feats[-1]))
        feats = torch.cat(feats, dim=1)

        fusion = torch.max_pool2d(self.fusion_block(feats), kernel_size=[feats.shape[2], feats.shape[3]])
        fusion = torch.repeat_interleave(fusion, repeats=feats.shape[2], dim=2)
        return self.prediction(torch.cat((fusion, feats), dim=1)).squeeze(-1) 
Example #17
Source File: architecture.py    From deep_gcns_torch with MIT License
def forward(self, data):
        corr, color, batch = data.pos, data.x, data.batch
        x = torch.cat((corr, color), dim=1)
        feats = [self.head(x, self.knn(x[:, 0:3], batch))]
        for i in range(self.n_blocks-1):
            feats.append(self.backbone[i](feats[-1], batch)[0])
        feats = torch.cat(feats, dim=1)

        fusion = tg.utils.scatter_('max', self.fusion_block(feats), batch)
        fusion = torch.repeat_interleave(fusion, repeats=feats.shape[0]//fusion.shape[0], dim=0)
        return self.prediction(torch.cat((fusion, feats), dim=1)) 
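The broadcast back onto the nodes only works here because every graph in the batch contributes the same number of nodes; a minimal sketch of that assumption, with hypothetical sizes:

fusion = torch.randn(2, 64)   # one pooled feature vector per graph
feats = torch.randn(10, 64)   # 5 nodes per graph, graphs stored contiguously
per_node = torch.repeat_interleave(
    fusion, repeats=feats.shape[0] // fusion.shape[0], dim=0
)
per_node.shape                # torch.Size([10, 64])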
Example #18
Source File: basic.py    From torchsupport with MIT License
def __init__(self, batch, width):
    # prepare offsets of connections:
    offset_factors = torch.arange(width * batch) * width
    offset = torch.repeat_interleave(offset_factors, width)

    structure_connections = torch.arange(width * width * batch) - offset
    structure_connections = structure_connections.reshape(batch * width, width)
    super(FullyConnectedConstant, self).__init__(
      0, 0,
      structure_connections
    ) 
Example #19
Source File: basic.py    From torchsupport with MIT License
def __init__(self, indices):
    unique, counts = indices.unique(return_counts=True)
    structure_indices = torch.arange(counts.sum(), device=indices.device)
    structure_indices = torch.repeat_interleave(
      structure_indices, torch.repeat_interleave(
        counts, counts
      )
    )

    # prepare offsets of connections:
    repeated_counts = torch.repeat_interleave(counts, counts)
    other_counts = repeated_counts.roll(1)
    other_counts[0] = 0
    other_counts = other_counts.cumsum(dim=0)
    offset_factors = other_counts
    offset = torch.repeat_interleave(offset_factors, repeated_counts)
    base = counts.roll(1)
    base[0] = 0
    base = base.cumsum(dim=0)
    base = torch.repeat_interleave(torch.repeat_interleave(base, counts), repeated_counts)

    structure_connections = torch.arange((counts * counts).sum(), device=indices.device)
    structure_connections = structure_connections - offset + base

    super(FullyConnectedScatter, self).__init__(
      0, 0,
      structure_indices,
      structure_connections
    ) 
Example #20
Source File: scatter.py    From torchsupport with MIT License
def pairwise(op, data, indices, padding_value=0):
  padded, _, _, counts = pad(data, indices, value=padding_value)
  padded = padded.transpose(1, 2)
  reference = padded.unsqueeze(-1)
  padded = padded.unsqueeze(-2)
  op_result = op(padded, reference)

  # batch indices into pairwise tensor:
  batch_indices = torch.arange(counts.size(0))
  batch_indices = torch.repeat_interleave(batch_indices, counts ** 2)

  # first dimension indices:
  first_offset = counts.roll(1)
  first_offset[0] = 0
  first_offset = torch.cumsum(first_offset, dim=0)
  first_offset = torch.repeat_interleave(first_offset, counts)
  first_indices = torch.arange(counts.sum()) - first_offset
  first_indices = torch.repeat_interleave(
    first_indices,
    torch.repeat_interleave(counts, counts)
  )

  # second dimension indices:
  second_offset = torch.repeat_interleave(counts, counts).roll(1)
  second_offset[0] = 0
  second_offset = torch.cumsum(second_offset, dim=0)
  second_offset = torch.repeat_interleave(second_offset, torch.repeat_interleave(counts, counts))
  second_indices = torch.arange((counts ** 2).sum()) - second_offset

  # extract tensor from padded result using indices:
  result = op_result[batch_indices, first_indices, second_indices]

  # access: cumsum(counts ** 2)[idx] + counts[idx] * idy + idz
  access_batch = (counts ** 2).roll(1)
  access_batch[0] = 0
  access_batch = torch.cumsum(access_batch, dim=0)
  access_first = counts

  access = (access_batch, access_first)

  return result, batch_indices, first_indices, second_indices, access 
Example #21
Source File: scatter.py    From torchsupport with MIT License
def repack(data, indices, target_indices):
  out = torch.zeros(
    target_indices.size(0), *data.shape[1:],
    dtype=data.dtype, device=data.device
  )
  unique, lengths = indices.unique(return_counts=True)
  unique, target_lengths = target_indices.unique(return_counts=True)
  offset = target_lengths - lengths
  offset = offset.roll(1, 0)
  offset[0] = 0
  offset = torch.repeat_interleave(offset.cumsum(dim=0), lengths, dim=0)
  index = offset + torch.arange(len(indices)).to(data.device)

  out[index] = data
  return out, target_indices
Example #22
Source File: set_mnist_ebm.py    From torchsupport with MIT License
def forward(self, image, condition):
    image = image.view(-1, 28 * 28)
    out = self.input_process(self.input(image))
    mean, logvar = self.condition(condition)
    # reparameterized sample, equivalent to Normal(mean, exp(0.5 * logvar)).rsample()
    sample = mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)
    cond = self.postprocess(sample)
    cond = torch.repeat_interleave(cond, 5, dim=0)
    result = self.combine(torch.cat((out, cond), dim=1))
    return result, (mean, logvar) 
Example #23
Source File: set_yeast_ebm.py    From torchsupport with MIT License
def forward(self, image, condition):
    image = image.view(-1, 3, 64, 64)
    out = self.input_process(self.input(image))
    mean, logvar = self.condition(condition)
    # reparameterized sample, equivalent to Normal(mean, exp(0.5 * logvar)).rsample()
    sample = mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)
    cond = self.postprocess(sample)
    cond = torch.repeat_interleave(cond, 5, dim=0)
    result = self.combine(torch.cat((out, cond), dim=1))
    return result, (mean, logvar) 
Example #24
Source File: set_mnist_gan.py    From torchsupport with MIT License
def forward(self, data):
    support, values = data
    mean, logvar = self.encoder(support)
    distribution = Normal(mean, torch.exp(0.5 * logvar))
    latent_sample = distribution.rsample()
    latent_sample = torch.repeat_interleave(latent_sample, self.size, dim=0)
    combined = torch.cat((values.view(-1, 28 * 28), latent_sample), dim=1)
    return self.verdict(combined) 
Example #25
Source File: set_mnist_gan.py    From torchsupport with MIT License
def sample(self, data):
    support, values = data
    mean, logvar = self.condition(support)
    distribution = Normal(mean, torch.exp(0.5 * logvar))
    latent_sample = distribution.rsample()
    latent_sample = torch.repeat_interleave(latent_sample, self.size, dim=0)
    local_samples = torch.randn(support.size(0) * self.size, 16)
    sample = torch.cat((latent_sample, local_samples), dim=1)
    return (support, sample), (mean, logvar) 
Example #26
Source File: tensor.py    From dgl with Apache License 2.0
def repeat(input, repeats, dim):
    # th.repeat_interleave(input, repeats, dim) requires PyTorch >= 1.1;
    # emulate it by stacking `repeats` copies along a new axis and flattening.
    if dim < 0:
        dim += input.dim()
    return th.flatten(th.stack([input] * repeats, dim=dim+1), dim, dim+1)
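A quick sanity check that the stack-and-flatten fallback matches repeat_interleave on PyTorch >= 1.1, assuming the repeat function above is in scope:

x = th.tensor([[1, 2], [3, 4]])
assert th.equal(repeat(x, 2, 0), th.repeat_interleave(x, 2, 0))
# both give tensor([[1, 2], [1, 2], [3, 4], [3, 4]])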
Example #27
Source File: sequence.py    From DeepCTR-Torch with Apache License 2.0
def forward(self, seq_value_len_list):
        if self.supports_masking:
            uiseq_embed_list, mask = seq_value_len_list  # [B, T, E], [B, 1]
            mask = mask.float()
            user_behavior_length = torch.sum(mask, dim=-1, keepdim=True)
            mask = mask.unsqueeze(2)
        else:
            uiseq_embed_list, user_behavior_length = seq_value_len_list  # [B, T, E], [B, 1]
            mask = self._sequence_mask(user_behavior_length, maxlen=uiseq_embed_list.shape[1],
                                       dtype=torch.float32)  # [B, 1, maxlen]
            mask = torch.transpose(mask, 1, 2)  # [B, maxlen, 1]

        embedding_size = uiseq_embed_list.shape[-1]

        mask = torch.repeat_interleave(mask, embedding_size, dim=2)  # [B, maxlen, E]

        if self.mode == 'max':
            hist = uiseq_embed_list - (1 - mask) * 1e9
            hist = torch.max(hist, dim=1, keepdim=True)[0]
            return hist
        hist = uiseq_embed_list * mask.float()
        hist = torch.sum(hist, dim=1, keepdim=False)

        if self.mode == 'mean':
            hist = torch.div(hist, user_behavior_length.type(torch.float32) + self.eps)

        hist = torch.unsqueeze(hist, dim=1)
        return hist 
Example #28
Source File: test_instance_norm.py    From pytorch_geometric with MIT License
def test_instance_norm():
    batch = torch.repeat_interleave(torch.full((10, ), 10, dtype=torch.long))

    norm = InstanceNorm(16)
    assert norm.__repr__() == (
        'InstanceNorm(16, eps=1e-05, momentum=0.1, affine=False, '
        'track_running_stats=False)')
    out = norm(torch.randn(100, 16), batch)
    assert out.size() == (100, 16)

    norm = InstanceNorm(16, affine=True, track_running_stats=True)
    out = norm(torch.randn(100, 16), batch)
    assert out.size() == (100, 16)

    # Should behave equally to `BatchNorm` for mini-batches of size 1.
    x = torch.randn(100, 16)
    norm1 = InstanceNorm(16, affine=False, track_running_stats=False)
    norm2 = BatchNorm(16, affine=False, track_running_stats=False)
    assert torch.allclose(norm1(x), norm2(x), atol=1e-6)

    norm1 = InstanceNorm(16, affine=False, track_running_stats=True)
    norm2 = BatchNorm(16, affine=False, track_running_stats=True)
    assert torch.allclose(norm1(x), norm2(x), atol=1e-6)
    assert torch.allclose(norm1.running_mean, norm2.running_mean, atol=1e-6)
    assert torch.allclose(norm1.running_var, norm2.running_var, atol=1e-6)
    assert torch.allclose(norm1(x), norm2(x), atol=1e-6)
    assert torch.allclose(norm1.running_mean, norm2.running_mean, atol=1e-6)
    assert torch.allclose(norm1.running_var, norm2.running_var, atol=1e-6)
    norm1.eval()
    norm2.eval()
    assert torch.allclose(norm1(x), norm2(x), atol=1e-6) 
Example #29
Source File: array.py    From MONAI with Apache License 2.0
def __call__(self, img):
        """
        Args:
            img: shape must be (batch_size, C, spatial_dim1[, spatial_dim2, ...]).

        Returns:
            A PyTorch Tensor with shape (batch_size, C, spatial_dim1[, spatial_dim2, ...]).
        """
        channel_dim = 1
        if img.shape[channel_dim] == 1:
            img = torch.squeeze(img, dim=channel_dim)

            if self.independent:
                for i in self.applied_labels:
                    foreground = (img == i).type(torch.uint8)
                    mask = get_largest_connected_component_mask(foreground, self.connectivity)
                    img[foreground != mask] = 0
            else:
                foreground = torch.zeros_like(img)
                for i in self.applied_labels:
                    foreground += (img == i).type(torch.uint8)
                mask = get_largest_connected_component_mask(foreground, self.connectivity)
                img[foreground != mask] = 0
            output = torch.unsqueeze(img, dim=channel_dim)
        else:
            # one-hot data is assumed to have binary value in each channel
            if self.independent:
                for i in self.applied_labels:
                    foreground = img[:, i, ...].type(torch.uint8)
                    mask = get_largest_connected_component_mask(foreground, self.connectivity)
                    img[:, i, ...][foreground != mask] = 0
            else:
                applied_img = img[:, self.applied_labels, ...].type(torch.uint8)
                foreground = torch.any(applied_img, dim=channel_dim)
                mask = get_largest_connected_component_mask(foreground, self.connectivity)
                background_mask = torch.unsqueeze(foreground != mask, dim=channel_dim)
                background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=channel_dim)
                applied_img[background_mask] = 0
                img[:, self.applied_labels, ...] = applied_img.type(img.type())
            output = img

        return output 
Example #30
Source File: densepose_head.py    From detectron2 with Apache License 2.0
def _forward_confidence_estimation_layers(
        self, confidence_model_cfg, head_outputs, interp2d, ann_index, index_uv
    ):
        sigma_1, sigma_2, kappa_u, kappa_v = None, None, None, None
        sigma_1_lowres, sigma_2_lowres, kappa_u_lowres, kappa_v_lowres = None, None, None, None
        fine_segm_confidence_lowres, fine_segm_confidence = None, None
        coarse_segm_confidence_lowres, coarse_segm_confidence = None, None
        if confidence_model_cfg.uv_confidence.enabled:
            if confidence_model_cfg.uv_confidence.type == DensePoseUVConfidenceType.IID_ISO:
                sigma_2_lowres = self.sigma_2_lowres(head_outputs)
                sigma_2 = interp2d(sigma_2_lowres)
            elif confidence_model_cfg.uv_confidence.type == DensePoseUVConfidenceType.INDEP_ANISO:
                sigma_2_lowres = self.sigma_2_lowres(head_outputs)
                kappa_u_lowres = self.kappa_u_lowres(head_outputs)
                kappa_v_lowres = self.kappa_v_lowres(head_outputs)
                sigma_2 = interp2d(sigma_2_lowres)
                kappa_u = interp2d(kappa_u_lowres)
                kappa_v = interp2d(kappa_v_lowres)
            else:
                raise ValueError(
                    f"Unknown confidence model type: {confidence_model_cfg.confidence_model_type}"
                )
        if confidence_model_cfg.segm_confidence.enabled:
            fine_segm_confidence_lowres = self.fine_segm_confidence_lowres(head_outputs)
            fine_segm_confidence = interp2d(fine_segm_confidence_lowres)
            fine_segm_confidence = (
                F.softplus(fine_segm_confidence) + confidence_model_cfg.segm_confidence.epsilon
            )
            index_uv = index_uv * torch.repeat_interleave(
                fine_segm_confidence, index_uv.shape[1], dim=1
            )
            coarse_segm_confidence_lowres = self.coarse_segm_confidence_lowres(head_outputs)
            coarse_segm_confidence = interp2d(coarse_segm_confidence_lowres)
            coarse_segm_confidence = (
                F.softplus(coarse_segm_confidence) + confidence_model_cfg.segm_confidence.epsilon
            )
            ann_index = ann_index * torch.repeat_interleave(
                coarse_segm_confidence, ann_index.shape[1], dim=1
            )
        return (
            (sigma_1, sigma_2, kappa_u, kappa_v, fine_segm_confidence, coarse_segm_confidence),
            (
                sigma_1_lowres,
                sigma_2_lowres,
                kappa_u_lowres,
                kappa_v_lowres,
                fine_segm_confidence_lowres,
                coarse_segm_confidence_lowres,
            ),
            (ann_index, index_uv),
        )
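The repeat_interleave calls here broadcast a single-channel confidence map across all class channels before rescaling the scores; in isolation, with hypothetical sizes:

index_uv = torch.randn(1, 25, 56, 56)   # per-part scores
conf = torch.rand(1, 1, 56, 56)         # one confidence channel
index_uv = index_uv * torch.repeat_interleave(conf, index_uv.shape[1], dim=1)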