Python torch.set_printoptions() Examples

The following are 19 code examples of torch.set_printoptions(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: attn2d_dynamic_ll.py    From attn2d with MIT License 6 votes vote down vote up
def decide(self, prev_output_tokens, encoder_out, context_size):
        """Score the last cell of the source/target grid and return ``pwrite``.

        The first ``context_size`` source states are paired with every target
        embedding, the controller scores the resulting grid, and the
        exponential of entry ``[:, -1, -1, 1]`` of those scores is returned
        (presumably the write probability, going by the original variable
        name ``pwrite``).

        Note: also sets the global torch print precision to 2.
        """
        torch.set_printoptions(precision=2)
        # Source side: keep only the first `context_size` positions.
        source = encoder_out['encoder_out'][:, :context_size]  # B, Ts, ds
        # Target side: positional embeddings are optional.
        if self.embed_positions is not None:
            pos_emb = self.embed_positions(
                prev_output_tokens,
                incremental_state=None,
            )
        else:
            pos_emb = None
        target = self.embed_scale * self.embed_tokens(prev_output_tokens)
        if pos_emb is not None:
            target += pos_emb
        target = self.embedding_dropout(target)
        n_src, n_tgt = source.size(1), target.size(1)
        # Tile both sides into a 2d "image" and stack them along the channels.
        grid = torch.cat(
            (
                _expand(source, 1, n_tgt),  # B, Tt, Ts, ds
                _expand(target, 2, n_src),  # B, Tt, Ts, dt
            ),
            dim=3,
        )  # B, Tt, Ts, C=ds+dt
        controls = self.controller.predict_read_write(self.controller_feat(grid))
        return torch.exp(controls[:, -1, -1, 1])
Example #2
Source File: double_attn2d_dynamic_ll.py    From attn2d with MIT License 6 votes vote down vote up
def decide(self, prev_output_tokens, encoder_out, context_size):
        """Return ``pwrite`` for the last grid cell, using the controller's own embeddings.

        Reads ``encoder_out['ctrl_encoder_out']`` (truncated to
        ``context_size`` source positions) and the ``ctrl_embed_*`` modules,
        builds the joint source/target grid, and exponentiates entry
        ``[:, -1, -1, 1]`` of the controller's read/write scores.

        Note: also sets the global torch print precision to 1.
        """
        torch.set_printoptions(precision=1)

        # source embeddings, restricted to the available context
        ctx = encoder_out['ctrl_encoder_out'][:, :context_size]  # B, Ts, ds

        # target embeddings (+ optional positions)
        position_layer = self.ctrl_embed_positions
        if position_layer is None:
            offsets = None
        else:
            offsets = position_layer(prev_output_tokens, incremental_state=None)
        tokens = self.embed_scale * self.ctrl_embed_tokens(prev_output_tokens)
        if offsets is not None:
            tokens += offsets
        tokens = self.embedding_dropout(tokens)

        ctx_len = ctx.size(1)
        tok_len = tokens.size(1)

        # 2d "image" of embeddings: broadcast each side over the other's length
        ctx_grid = _expand(ctx, 1, tok_len)     # B, Tt, Ts, ds
        tok_grid = _expand(tokens, 2, ctx_len)  # B, Tt, Ts, dt
        features = self.controller_feat(
            torch.cat((ctx_grid, tok_grid), dim=3)  # B, Tt, Ts, C=ds+dt
        )
        scores = self.controller.predict_read_write(features)
        return torch.exp(scores[:, -1, -1, 1])
Example #3
Source File: visual.py    From GraphIE with GNU General Public License v3.0 6 votes vote down vote up
def plot_att_change(batch_doc, network, record, save_img_path, uid='temp',
                    epoch=0, device=torch.device('cpu'), word_alphabet=None, show_net=False, graph_types=None):
    """Run ``network.loss`` on one document batch with a widened print threshold.

    Args:
        batch_doc: mapping with the keys "chars", "word_ids", "posi",
            "ner_ids", "feat_ids" and "adjs"; every value is moved to ``device``.
        network: object exposing ``loss(...)``.
        record, save_img_path, uid, epoch: unused by the active code (only the
            commented-out plotting calls refer to ``record`` and ``epoch``).
        device: target device for the batch tensors.
        word_alphabet: optional alphabet; when truthy, the non-padding words of
            the first document are decoded with ``get_instance``.
        show_net: forwarded to ``network.loss``.
        graph_types: forwarded to ``network.loss``; defaults to ``['coref']``.
            A fresh list is built on every call, so a callee mutating it can
            no longer leak state into later calls.

    Side effects:
        Raises the global torch print threshold so the full adjacency tensor
        can be printed without summarisation.
    """
    if graph_types is None:
        graph_types = ['coref']

    char, word, posi, labels, feats, adjs = [batch_doc[i].to(device) for i in
                                             ["chars", "word_ids", "posi", "ner_ids", "feat_ids", "adjs"]]
    # Only consumed by the commented-out plotting calls below.
    word_txt = []
    if word_alphabet:
        doc = word[0][word[0] != PAD_ID_WORD]
        word_txt = [word_alphabet.get_instance(w) for w in doc]

    # The loss call gets a copy of the adjacency tensor, not the original.
    adjs_cp = adjs.clone()

    # save adj to file
    print_thres = adjs.size(-1) * adjs.size(-2) + 1000
    torch.set_printoptions(threshold=print_thres)

    # check adj_old, adj_new
    # select = plot_att(adjs_cp, word_txt, record, epoch=epoch)

    network.loss(None, word, char, adjs_cp, labels, show_net=show_net, graph_types=graph_types)
    # plot_att(adjs_cp, word_txt, record, epoch=epoch, select=select)
Example #4
Source File: dropblock.py    From Parsing-R-CNN with MIT License 5 votes vote down vote up
def forward(self, input):
        """Zero out square blocks of ``input`` and rescale the remainder.

        In eval mode, or when ``self.keep_prob == 1``, ``input`` is returned
        untouched. Otherwise Bernoulli seeds are drawn with rate ``gamma``,
        each seed is grown into a ``block_size`` x ``block_size`` block by a
        grouped convolution with an all-ones kernel, every position covered by
        a block is zeroed, and the result is scaled by ``numel / kept``.

        Args:
            input: 4-d tensor (the code uses ``F.conv2d`` with
                ``groups=input.shape[1]``).

        Returns:
            Tensor with the same shape as ``input``.
        """
        if not self.training or self.keep_prob == 1:
            return input
        gamma = (1. - self.keep_prob) / self.block_size ** 2
        for sh in input.shape[2:]:
            gamma *= sh / (sh - self.block_size + 1)
        M = torch.bernoulli(torch.ones_like(input) * gamma)
        kernel = torch.ones((input.shape[1], 1, self.block_size, self.block_size)).to(device=input.device,
                                                                                      dtype=input.dtype)
        Msum = F.conv2d(M,
                        kernel,
                        padding=self.block_size // 2,
                        groups=input.shape[1])
        # With an even block_size, padding=block_size // 2 yields one extra
        # row and column; crop back to the input size so the mask lines up.
        # For odd block sizes this slice is a no-op.
        Msum = Msum[:, :, :input.shape[2], :input.shape[3]]
        # NOTE(review): looks like a debugging leftover; it changes the global
        # print threshold on every training forward pass. Kept for compatibility.
        torch.set_printoptions(threshold=5000)
        mask = (Msum < 1).to(device=input.device, dtype=input.dtype)
        # NOTE(review): mask.sum() is 0 when every position is dropped, which
        # makes this division produce NaN.
        return input * mask * mask.numel() / mask.sum()  # TODO input * mask * self.keep_prob ?
Example #5
Source File: dropout.py    From LightNetPlusPlus with MIT License 5 votes vote down vote up
def forward(self, input):
        """Drop contiguous square regions of ``input`` during training.

        Returns ``input`` as-is when the module is not training or when
        ``self.keep_prob == 1``. Otherwise a Bernoulli seed map with rate
        ``gamma`` is expanded into ``block_size`` x ``block_size`` blocks via a
        grouped all-ones convolution; covered positions are zeroed and the
        output is rescaled by ``mask.numel() / mask.sum()``.

        Args:
            input: 4-d tensor (processed with ``F.conv2d`` using one group per
                channel).

        Returns:
            Tensor shaped like ``input``.
        """
        if not self.training or self.keep_prob == 1:
            return input
        gamma = (1. - self.keep_prob) / self.block_size ** 2
        for sh in input.shape[2:]:
            gamma *= sh / (sh - self.block_size + 1)
        M = torch.bernoulli(torch.ones_like(input) * gamma)
        ones_kernel = torch.ones((input.shape[1], 1, self.block_size, self.block_size)).to(device=input.device,
                                                                                           dtype=input.dtype)
        Msum = F.conv2d(M,
                        ones_kernel,
                        padding=self.block_size // 2,
                        groups=input.shape[1])
        # An even block_size makes the padded convolution one element larger
        # than the input in each spatial dimension; trim the surplus so the
        # mask can be multiplied with the input. No-op for odd block sizes.
        height, width = input.shape[2], input.shape[3]
        Msum = Msum[:, :, :height, :width]
        # NOTE(review): probably leftover debugging; mutates the global print
        # threshold on each training call. Retained to avoid a behaviour change.
        torch.set_printoptions(threshold=5000)
        mask = (Msum < 1).to(device=input.device, dtype=input.dtype)
        # NOTE(review): if every position is dropped, mask.sum() == 0 and the
        # result becomes NaN.
        return input * mask * mask.numel() / mask.sum()  # TODO input * mask * self.keep_prob ?
Example #6
Source File: test_label_smoothing.py    From apex with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def setUp(self, seed=1234):
        """Seed every random number generator in use and widen tensor print precision.

        Args:
            seed: value passed to the ``random``, NumPy, torch CPU and torch
                CUDA seeding functions.
        """
        # Set pytorch print precision
        torch.set_printoptions(precision=10)

        seeders = (
            random.seed,
            np.random.seed,
            torch.manual_seed,
            torch.cuda.manual_seed_all,
        )
        for apply_seed in seeders:
            apply_seed(seed)
Example #7
Source File: shallow_controller.py    From attn2d with MIT License 5 votes vote down vote up
def decide(self, x):
        """Return ``1 - p(read)`` computed from the last cell of the aggregated grid.

        Pipeline: optional ``final_ln`` -> ``aggregator`` (first returned item
        is used) -> take index ``[:, -1, -1]`` -> ``net`` -> sigmoid of
        ``gate`` with the trailing dimension squeezed.

        Note: also sets the global torch print precision to 2.
        """
        torch.set_printoptions(precision=2)
        # Final LN (skipped when the module has none)
        normed = x if self.final_ln is None else self.final_ln(x)
        # Aggregate, then keep only the last cell
        aggregated, _ = self.aggregator(normed)
        last_cell = aggregated[:, -1, -1]
        # A stack of linear layers
        hidden = self.net(last_cell)
        # The R/W decisions:
        p_read = torch.sigmoid(self.gate(hidden)).squeeze(-1)  # p(read)
        return 1 - p_read
Example #8
Source File: pa_controller.py    From attn2d with MIT License 5 votes vote down vote up
def decide(self, src_tokens, prev_output_tokens, writing_grid):
        """Return ``1 - p(read)`` for the last cell of the observation grid.

        When ``self.share_embeddings`` is set the supplied ``writing_grid`` is
        used directly; otherwise the grid is built from the tokens with
        ``self.observation_grid``.
        """
        if self.share_embeddings:
            grid = writing_grid
        else:
            grid = self.observation_grid(src_tokens, prev_output_tokens)

        # Cumulative ResNet, then cell aggregation: keep the last cell only
        last_cell = self.net(grid)[:, -1, -1]
        # The R/W decisions:
        p_read = torch.sigmoid(self.gate(last_cell)).squeeze(-1)  # p(read)
        return 1 - p_read
Example #9
Source File: DropBlock.py    From DropBlock-pytorch with MIT License 5 votes vote down vote up
def forward(self, input):
        """Mask out square blocks of ``input`` while training.

        ``input`` is passed through unchanged outside training or when
        ``self.keep_prob == 1``. Otherwise: sample Bernoulli seeds at rate
        ``gamma``, spread each seed over a ``block_size`` x ``block_size``
        window with a grouped all-ones convolution, zero every covered
        position, and rescale by ``mask.numel() / mask.sum()``.

        Args:
            input: 4-d tensor (``F.conv2d`` is applied with
                ``groups=input.shape[1]``).

        Returns:
            Tensor of the same shape as ``input``.
        """
        if not self.training or self.keep_prob == 1:
            return input
        gamma = (1. - self.keep_prob) / self.block_size ** 2
        for sh in input.shape[2:]:
            gamma *= sh / (sh - self.block_size + 1)
        M = torch.bernoulli(torch.ones_like(input) * gamma)
        block_kernel = torch.ones((input.shape[1], 1, self.block_size, self.block_size)).to(device=input.device,
                                                                                            dtype=input.dtype)
        Msum = F.conv2d(M,
                        block_kernel,
                        padding=self.block_size // 2,
                        groups=input.shape[1])
        # padding=block_size // 2 over-pads by one when block_size is even,
        # so the convolution output is (H + 1) x (W + 1); slice it back to
        # H x W. Odd block sizes are unaffected by this slice.
        Msum = Msum[..., :input.shape[2], :input.shape[3]]
        # NOTE(review): appears to be a debugging remnant; it resets the global
        # print threshold on every training forward. Left in place on purpose.
        torch.set_printoptions(threshold=5000)
        mask = (Msum < 1).to(device=input.device, dtype=input.dtype)
        # NOTE(review): mask.sum() == 0 (everything dropped) yields NaN here.
        return input * mask * mask.numel() / mask.sum()  # TODO input * mask * self.keep_prob ?
Example #10
Source File: trainer.py    From pytorch-asr with GNU General Public License v3.0 5 votes vote down vote up
def unit_train(self, data):
        """Run one optimisation step on a single batch.

        Args:
            data: 6-tuple unpacked as ``(xs, ys, frame_lens, label_lens,
                filenames, _)``; the last item is ignored.

        Returns:
            The loss as a Python number, or ``None`` when the loss is NaN or
            infinite (in which case no backward pass or optimizer step runs).

        Raises:
            Re-raises any exception from the step after printing it together
            with ``filenames``, ``frame_lens`` and ``label_lens``.
        """
        xs, ys, frame_lens, label_lens, filenames, _ = data
        try:
            # NOTE(review): batch_size is not used anywhere below.
            batch_size = xs.size(0)
            if self.use_cuda:
                xs = xs.cuda(non_blocking=True)
            # The model also returns updated frame lengths, which replace the input ones.
            ys_hat, frame_lens = self.model(xs, frame_lens)
            if self.fp16:
                ys_hat = ys_hat.float()
            ys_hat = ys_hat.transpose(0, 1).contiguous()  # TxNxH
            #torch.set_printoptions(threshold=5000000)
            #print(ys_hat.shape, frame_lens, ys.shape, label_lens)
            #print(onehot2int(ys_hat).squeeze(), ys)
            loss = self.loss(ys_hat, ys, frame_lens, label_lens)
            # Skip the update entirely for a non-finite loss.
            if torch.isnan(loss) or loss.item() == float("inf") or loss.item() == -float("inf"):
                logger.warning("received an nan/inf loss: probably frame_lens < label_lens or the learning rate is too high")
                #raise RuntimeError
                return None
            # If any sample has frame_lens < 2 * label_lens, zero the loss in place
            # (the backward pass and optimizer step below still run).
            if frame_lens.cpu().lt(2*label_lens).nonzero().numel():
                logger.debug("the batch includes a data with frame_lens < 2*label_lens: set loss to zero")
                loss.mul_(0)
            # Read the scalar before `loss` is deleted further down.
            loss_value = loss.item()
            self.optimizer.zero_grad()
            if self.fp16:
                #self.optimizer.backward(loss)
                #self.optimizer.clip_master_grads(self.max_norm)
                # NOTE(review): no gradient clipping happens on this branch.
                with self.optimizer.scale_loss(loss) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
            self.optimizer.step()
            if self.use_cuda:
                torch.cuda.synchronize()
            del loss
            return loss_value
        except Exception as e:
            # Print the failing batch's details, then propagate the error.
            print(e)
            print(filenames, frame_lens, label_lens)
            raise
Example #11
Source File: printing.py    From heat with MIT License 5 votes vote down vote up
def set_printoptions(
    precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None
):
    """
    Set the options used when printing. The option list mirrors the ones offered by NumPy and PyTorch.

    Parameters
    ----------
    precision: int
        Digits of precision used for floating point output (default=4).
    threshold: int
        Element count above which an array is summarized instead of printed in full (default=1000).
    edgeitems: int
        How many items to show at the start and end of each dimension when summarizing (default=3).
    linewidth: int
        Characters per line before a line break is inserted (default = 80).
    profile: str
        Preset for pretty printing, one of `default`, `short`, `full`. Any of the options above
        overrides the preset.
    sci_mode: bool
        Turn scientific notation on (True) or off (False). With None (default) HeAT decides
        automatically.
    """
    torch.set_printoptions(precision, threshold, edgeitems, linewidth, profile, sci_mode)

    # HeAT profiles will print a bit wider than PyTorch does, unless the
    # caller picked an explicit line width.
    named_profile = profile in ("default", "short", "full")
    if named_profile and linewidth is None:
        torch._tensor_str.PRINT_OPTS.linewidth = _DEFAULT_LINEWIDTH
Example #12
Source File: transformer.py    From GraphIE with GNU General Public License v3.0 5 votes vote down vote up
def get_attn_adj_mask(adjs):
    """Return a mask that is true wherever ``adjs`` has no edge.

    Rows without any non-zero entry get column 0 marked as connected first,
    so that no row of the returned mask is entirely true.
    """
    connected = adjs.ne(0)  # batch*n_node*n_node
    isolated = connected.sum(dim=2).eq(0)
    connected[:, :, 0] += isolated  # this is for making PAD not all zeros
    return connected.eq(0)
Example #13
Source File: test_sum_product.py    From smooth-topk with MIT License 5 votes vote down vote up
def setUp(self):
        """Fix the random seeds, set the test tolerance and widen tensor printing."""
        self.eps = 1e-4

        torch.set_printoptions(linewidth=160, threshold=1e3)

        # The torch seed is drawn from NumPy after seeding NumPy with a fixed value.
        np.random.seed(1234)
        torch.manual_seed(np.random.randint(1e5))
Example #14
Source File: off.py    From pytorch_geometric with MIT License 5 votes vote down vote up
def write_off(data, path):
    r"""Writes a :class:`torch_geometric.data.Data` object to an OFF (Object
    File Format) file.

    The global torch print threshold is lifted while the tensors are rendered
    to text, and is restored afterwards even if rendering or writing the file
    raises.

    Args:
        data (:class:`torch_geometric.data.Data`): The data object.
        path (str): The path to the file.
    """
    num_nodes, num_faces = data.pos.size(0), data.face.size(1)

    pos = data.pos.to(torch.float)
    face = data.face.t()
    # Each face row is prefixed with its vertex count.
    num_vertices = torch.full((num_faces, 1), face.size(1), dtype=torch.long)
    face = torch.cat([num_vertices, face], dim=-1)

    threshold = PRINT_OPTS.threshold
    # Disable summarisation ("...") so every row is printed.
    torch.set_printoptions(threshold=float('inf'))
    try:
        # Strip commas, then the leading/trailing brackets of every printed row.
        pos_repr = re.sub(',', '', _tensor_str(pos, indent=0))
        pos_repr = '\n'.join([x[2:-1] for x in pos_repr.split('\n')])[:-1]

        face_repr = re.sub(',', '', _tensor_str(face, indent=0))
        face_repr = '\n'.join([x[2:-1] for x in face_repr.split('\n')])[:-1]

        with open(path, 'w') as f:
            f.write('OFF\n{} {} 0\n'.format(num_nodes, num_faces))
            f.write(pos_repr)
            f.write('\n')
            f.write(face_repr)
            f.write('\n')
    finally:
        # Always hand the caller back the print threshold they had.
        torch.set_printoptions(threshold=threshold)
Example #15
Source File: generate_fbank_data.py    From audio with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def decode(fn, sound_path, exe_path, scp_path, out_dir):
    """
    Takes a filepath and logs the corresponding shell command to run that specific
    kaldi configuration. It also calls compliance.kaldi and logs the two outputs.

    Args:
        fn: file name of the form ``fbank-<value>-<value>-...-<value>.ark``,
            with one value per entry of the option list below.
        sound_path: path of the audio file passed to ``torchaudio.load_wav``.
        exe_path: executable placed at the start of the logged command line.
        scp_path: scp path placed in the logged command line.
        out_dir: prefix that is string-concatenated with ``fn`` to form the
            ark path; it must already end with a path separator.

    Raises:
        AssertionError: if the number of values in ``fn`` does not match the
            number of known options.

    Example:
        >> fn = 'fbank-1.1009-2.5985-1.1875-0.8750-5723-true-918-4-0.31-true-false-true-true-' \
            'false-false-false-true-4595-4281-1.0000-hamming.ark'
        >> decode(fn)
    """
    out_fn = out_dir + fn
    fn = fn[len('fbank-'):-len('.ark')]
    arr = [
        'blackman_coeff', 'energy_floor', 'frame_length', 'frame_shift', 'high_freq', 'htk_compat',
        'low_freq', 'num_mel_bins', 'preemphasis_coefficient', 'raw_energy', 'remove_dc_offset',
        'round_to_power_of_two', 'snip_edges', 'subtract_mean', 'use_energy', 'use_log_fbank',
        'use_power', 'vtln_high', 'vtln_low', 'vtln_warp', 'window_type']
    fn_split = fn.split('-')
    assert len(fn_split) == len(arr), ('Len mismatch: %d and %d' % (len(fn_split), len(arr)))
    inputs = {name: utils.parse(value) for name, value in zip(arr, fn_split)}

    # print flags for C++
    s = ' '.join('--' + name.replace('_', '-') + '=' + value for name, value in zip(arr, fn_split))
    logging.info(exe_path + ' --dither=0.0 --debug-mel=true ' + s + ' ' + scp_path + ' ' + out_fn)
    # logging.info() requires a message; an empty one keeps the blank separator line.
    logging.info('')
    # print args for python
    inputs['dither'] = 0.0
    logging.info(inputs)
    sound, sample_rate = torchaudio.load_wav(sound_path)
    kaldi_output_dict = dict(torchaudio.kaldi_io.read_mat_ark(out_fn))
    res = torchaudio.compliance.kaldi.fbank(sound, **inputs)
    # Note: this changes the global torch print options.
    torch.set_printoptions(precision=10, sci_mode=False)
    logging.info(res)
    logging.info(kaldi_output_dict['my_id'])
Example #16
Source File: models.py    From KG-A2C with MIT License 4 votes vote down vote up
def forward(self, input, input_hidden, vocab, vocab_rev, decode_steps_t, graphs):
        """Decode ``self.max_decode_steps`` steps, sampling one object per batch row per step.

        Args:
            input: tensor whose first dimension is the batch size; also passed
                to ``self.decoder`` at every step.
            input_hidden: initial hidden state (a leading dimension is added
                before the first decoder call).
            vocab: not used in this method.
            vocab_rev: mapping that must contain the start token ``'<s>'``.
            decode_steps_t: not used in this method.
            graphs: per-row selector used to index the decoder output
                (presumably a boolean/0-1 mask over objects; confirm with the caller).

        Returns:
            Tuple ``(torch.stack(all_outputs), torch.stack(all_words))`` with
            one entry per decoding step.

        Note:
            Requires CUDA (tensors are created with ``.cuda()``) and switches
            the global torch print profile to "full".
        """
        all_outputs, all_words = [], []

        # Every row starts decoding from the '<s>' token.
        decoder_input = torch.tensor([vocab_rev['<s>']] * input.size(0)).cuda()
        decoder_hidden = input_hidden.unsqueeze(0)
        torch.set_printoptions(profile="full")

        for di in range(self.max_decode_steps):
            ret_decoder_output, decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, input, graphs)

            if self.k == 1:
                all_outputs.append(ret_decoder_output)

                dec_objs = []
                for i in range(decoder_output.shape[0]):
                    # Softmax over the entries selected by graphs[i] only, then sample one.
                    dec_probs = F.softmax(ret_decoder_output[i][graphs[i]], dim=0)
                    idx = dec_probs.multinomial(1)
                    # Map the sampled position back to its index in the full output.
                    graph_list = graphs[i].nonzero().cpu().numpy().flatten().tolist()
                    assert len(graph_list) == dec_probs.numel()
                    dec_objs.append(graph_list[idx])
                topi = torch.LongTensor(dec_objs).cuda()

                # dec_probs = self.softmax(decoder_output)
                # topi = dec_probs.multinomial(num_samples=1)
                # topi = self.softmax(decoder_output).topk(1)[1]

                decoder_input = topi.squeeze().detach()

                all_words.append(topi)
            else:
                # Sample among the top-k candidates, weighted by their softmaxed scores.
                topv, topi = decoder_output.topk(self.k)
                topv = self.softmax(topv)
                topv = topv.cpu().numpy()
                topi = topi.cpu().numpy()
                cur_objs = []

                for i in range(graphs.size(0)):
                    cur_obj = np.random.choice(topi[i].reshape(-1), p=topv[i].reshape(-1))
                    cur_objs.append(cur_obj)

                decoder_input = torch.LongTensor(cur_objs).cuda()
                all_words.append(decoder_input)
                all_outputs.append(decoder_output)

        return torch.stack(all_outputs), torch.stack(all_words)
Example #17
Source File: attn2d_dynamic_v2.py    From attn2d with MIT License 4 votes vote down vote up
def forward_train(self, prev_output_tokens, encoder_out, target, **kwargs):
        """Training forward pass over the full source/target grid.

        Args:
            prev_output_tokens: previous target tokens; positions equal to
                ``self.padding_idx`` are treated as padding.
            encoder_out: dict providing ``'encoder_out'`` and
                ``'encoder_padding_mask'``.
            target: target token ids, used to gather the log-probability of
                the reference token at every grid cell.
            **kwargs: accepted but not used in this method.

        Returns:
            Tuple ``(x, observations, controls, gamma, read_labels,
            write_labels)`` where ``x`` holds the gathered log-probabilities
            (Tt, B, Ts), ``observations`` comes from ``self.controller_feat``
            and the remaining items come from ``self.controller``.

        Note:
            Also sets the global torch print precision to 2.
        """
        torch.set_printoptions(precision=2)
        # source embeddings
        src_emb = encoder_out['encoder_out']  # B, Ts, ds 
        # target embeddings:
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=None,
        ) if self.embed_positions is not None else None

        # Only pass a padding mask on when the batch actually contains padding.
        decoder_mask = prev_output_tokens.eq(self.padding_idx)
        if not decoder_mask.any():
            decoder_mask = None

        # Build the full grid
        tgt_emb = self.embed_scale * self.embed_tokens(prev_output_tokens)
        if positions is not None:
            tgt_emb += positions
        tgt_emb = self.embedding_dropout(tgt_emb)
        batch_size = src_emb.size(0)
        src_length = src_emb.size(1)
        tgt_length = tgt_emb.size(1)

        # build 2d "image" of embeddings
        src_emb = _expand(src_emb, 1, tgt_length)  # B, Tt, Ts, ds
        tgt_emb = _expand(tgt_emb, 2, src_length)  # B, Tt, Ts, dt
        x = torch.cat((src_emb, tgt_emb), dim=3)   # B, Tt, Ts, C=ds+dt
        x = self.input_dropout(x)

        # Controller features are computed from the grid before the main network runs.
        observations = self.controller_feat(x)
        # pass through dense convolutional layers
        encoder_mask = encoder_out['encoder_padding_mask']
        x = self.net(
            x, 
            decoder_mask=decoder_mask,
            encoder_mask=encoder_mask,
            incremental_state=None,
        )  # B, Tt, Ts, C
        x, _ = self.aggregator(x)  # B, Tt, Ts, C
        # NOTE(review): the shape comment on the next line drops Ts although the
        # following lines treat x as B, Tt, Ts, ...; confirm which is right.
        x = self.projection(x) if self.projection is not None else x  # B, Tt, C

        # Predict
        x = self.prediction_dropout(x)
        x = self.prediction(x)  # B, Tt, Ts, V
        x = utils.log_softmax(x, dim=-1)
        # Pick the log-probability of the reference token at every (target, source) cell.
        x = x.view(-1, x.size(-1)).gather(
            dim=-1,
            index=target.unsqueeze(-1).expand(-1, -1, src_length).contiguous().view(-1, 1)
        ).view(batch_size, tgt_length, src_length).permute(1,0,2)  # Tt, B, Ts
        controls, gamma, read_labels, write_labels = self.controller(observations, x)
        return x, observations, controls, gamma, read_labels, write_labels
Example #18
Source File: rule_miner.py    From CPL with MIT License 4 votes vote down vote up
def rollout(self, e_s, q, e_t, num_steps, visualize_action_probs=False):
        """Sample a path of ``num_steps`` actions and accumulate rewards along it.

        Args:
            e_s: start entities.
            q: query relations.
            e_t: target entities.
            num_steps: number of transitions to sample; must be positive.
            visualize_action_probs: when true, the top actions and their
                probabilities are recorded for every step.

        Returns:
            dict with the keys ``'pred_e2'``, ``'log_action_probs'``,
            ``'action_entropy'``, ``'path_trace'``, ``'path_components'`` and
            ``'reward'``.

        Side effects:
            Opens ``traces.txt`` in append mode (creating it if needed) for
            the optional debug prints; the handle is now closed when the
            rollout finishes or fails instead of being leaked. Also sets the
            global torch print threshold to 5000 and requires CUDA.
        """
        # Change: the reward is computed on the fly.

        assert (num_steps > 0)
        kg, pn = self.kg, self.mdl

        def reward_fun_binary(e1, r, e2, pred_e2, reward_binary):
            # 1 where the predicted entity matches the target ...
            reward = (pred_e2 == e2).float()
            # ... and also where a reward was already collected and the
            # prediction is the dummy end entity.
            for i in range(e1.size()[0]):
                if reward_binary[i] and pred_e2[i] == kg.dummy_end_e:
                    reward[i] = 1
            return reward

        # Initialization
        log_action_probs = []
        action_entropy = []
        r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
        seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)
        path_components = []
        reward = torch.zeros(e_s.size()).cuda()

        path_trace = [(r_s, e_s)]
        pn.initialize_path((r_s, e_s), kg)

        # The context manager guarantees the trace file is closed on every exit path.
        with open("traces.txt", "a") as logr:
            for t in range(num_steps):
                last_r, e = path_trace[-1]
                obs = [e_s, q, e_t, t == (num_steps - 1), last_r, seen_nodes]
                db_outcomes, inv_offset, policy_entropy = pn.transit(
                    e, obs, kg, use_action_space_bucketing=self.use_action_space_bucketing)
                sample_outcome = self.sample_action(db_outcomes, inv_offset)
                action = sample_outcome['action_sample']
                # compute the reward on the fly
                reward = reward + reward_fun_binary(e_s, q, e_t, action[1], reward)
                torch.set_printoptions(threshold=5000)
                pn.update_path(action, kg)
                action_prob = sample_outcome['action_prob']
                log_action_probs.append(ops.safe_log(action_prob))
                action_entropy.append(policy_entropy)
                seen_nodes = torch.cat([seen_nodes, e.unsqueeze(1)], dim=1)
                path_trace.append(action)
                #print(action[0], file=logr)

                if visualize_action_probs:
                    top_k_action = sample_outcome['top_actions']
                    top_k_action_prob = sample_outcome['top_action_probs']
                    path_components.append((e, top_k_action, top_k_action_prob))

            # In theory this should be changed, but in practice it seems to be
            # unused and it slows down backprop...
            pred_e2 = path_trace[-1][1]
            reward = self.reward_fun(e_s, q, e_t, pred_e2, reward)
            #print(reward, file=logr)
        self.record_path_trace(path_trace)

        return {
            'pred_e2': pred_e2,
            'log_action_probs': log_action_probs,
            'action_entropy': action_entropy,
            'path_trace': path_trace,
            'path_components': path_components,
            'reward': reward
        }
Example #19
Source File: loss.py    From batch-dropblock-network with MIT License 4 votes vote down vote up
def hard_example_mining(dist_mat, labels, margin, return_inds=False):
    """Pick the hardest positive and the hardest negative for every anchor.

    The hardest positive is the same-label sample with the largest distance;
    the hardest negative is the different-label sample with the smallest one.

    Args:
      dist_mat: pairwise distance between samples, shape [N, N]
      labels: sample labels, shape [N]
      margin: accepted for interface compatibility; not used by this function
      return_inds: also return the indices of the selected samples
    Returns:
      dist_ap: distance(anchor, positive); shape [N]
      dist_an: distance(anchor, negative); shape [N]
      p_inds: (only if `return_inds`) indices of the selected hard positives,
        shape [N]; 0 <= p_inds[i] <= N - 1
      n_inds: (only if `return_inds`) indices of the selected hard negatives,
        shape [N]; 0 <= n_inds[i] <= N - 1
    NOTE: Every label must occur the same number of times, otherwise the
      per-anchor reshape to [N, -1] does not work.
    NOTE: Also sets the global torch print threshold to 5000.
    """

    torch.set_printoptions(threshold=5000)
    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)

    # Label of the anchor (rows) against label of the candidate (columns), [N, N].
    label_rows = labels.expand(N, N)
    label_cols = label_rows.t()
    same_id = label_rows.eq(label_cols)
    other_id = label_rows.ne(label_cols)

    def _per_anchor(values, selector):
        # Selected entries of `values`, regrouped as one row per anchor.
        return values[selector].contiguous().view(N, -1)

    # Both results of each call have shape [N, 1]; the offsets are relative to
    # the per-anchor candidate list, not to the full row.
    hardest_pos, pos_offset = torch.max(_per_anchor(dist_mat, same_id), 1, keepdim=True)
    hardest_neg, neg_offset = torch.min(_per_anchor(dist_mat, other_id), 1, keepdim=True)
    dist_ap = hardest_pos.squeeze(1)
    dist_an = hardest_neg.squeeze(1)

    if not return_inds:
        return dist_ap, dist_an

    # Column index of every entry, [N, N], with the dtype/device of `labels`.
    column_ids = (labels.new().resize_as_(labels)
                  .copy_(torch.arange(0, N).long())
                  .unsqueeze(0).expand(N, N))
    # Translate the relative offsets back to absolute sample indices, shape [N].
    p_inds = torch.gather(_per_anchor(column_ids, same_id), 1, pos_offset.data).squeeze(1)
    n_inds = torch.gather(_per_anchor(column_ids, other_id), 1, neg_offset.data).squeeze(1)
    return dist_ap, dist_an, p_inds, n_inds