Python torch.nn.functional.one_hot() Examples

The following are 30 code examples of torch.nn.functional.one_hot(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.nn.functional, or try the search function.
Example #1
Source File: mutator.py    From nni with MIT License 6 votes vote down vote up
def _sample_layer_choice(self, mutable):
        """Sample a branch for a layer choice and record its statistics.

        Advances the controller via ``self._lstm_next_step()``, maps the last
        entry of ``self._h`` to logits with ``self.soft`` and draws one index
        from the softmax of those logits. The values derived from
        ``self.cross_entropy_loss`` are accumulated into
        ``self.sample_log_prob`` and ``self.sample_entropy``, and the embedding
        of the drawn index is stored in ``self._inputs``.

        Returns a flattened boolean one-hot encoding of the drawn index with
        ``self.max_layer_choice`` classes.
        """
        self._lstm_next_step()
        logit = self.soft(self._h[-1])

        # Optional logit shaping: temperature scaling, tanh squashing and a
        # per-key bias, applied in this order.
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if mutable.key in self.bias_dict:
            logit += self.bias_dict[mutable.key]

        probabilities = F.softmax(logit, dim=-1)
        branch_id = torch.multinomial(probabilities, 1).view(-1)

        log_prob = self.cross_entropy_loss(logit, branch_id)
        self.sample_log_prob += self.entropy_reduction(log_prob)
        # Detached so the entropy bookkeeping does not contribute gradients.
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)

        self._inputs = self.embedding(branch_id)
        chosen = F.one_hot(branch_id, num_classes=self.max_layer_choice)
        return chosen.bool().view(-1)
Example #2
Source File: criterion.py    From MusicTransformer-pytorch with MIT License 6 votes vote down vote up
def forward(self, input, target):
        """
        Label-smoothed cross entropy between ``input`` and ``target``.

        Args:
            input: [B * T, V]
            target: [B * T]
        Returns:
            cross entropy: [1]
        Raises:
            NotImplementedError: if ``self.reduction`` is neither 'mean' nor 'sum'.
        """
        smoothing = self.label_smoothing
        hard_targets = F.one_hot(target.long(), self.vocab_size).type(torch.float32)
        uniform_mass = 1.0 / self.vocab_size
        soft_targets = (1.0 - smoothing) * hard_targets + smoothing * uniform_mass

        # Rows whose target equals ignore_index get an all-zero distribution.
        ignored_rows = (target == self.ignore_index).unsqueeze(-1)
        soft_targets = soft_targets.masked_fill(ignored_rows, 0)

        ce = self.cross_entropy_with_logits(soft_targets, input)

        if self.reduction == 'sum':
            return ce.sum()
        if self.reduction == 'mean':
            # Average over the positions that are not ignored.
            lengths = torch.sum(target != self.ignore_index)
            return ce.sum() / lengths
        raise NotImplementedError
Example #3
Source File: datasets.py    From tape with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def collate_fn(self, batch):
        """Collate a list of samples into a dict of padded tensors.

        Each sample is unpacked as (msa, dist_bins, omega_bins, theta_bins,
        phi_bins). Every MSA is one-hot encoded with 21 classes and padded
        with 0 as float; the bin arrays are padded (``dist`` with -1, the
        others with 0) and converted to ``LongTensor``.
        """
        msa, dist_bins, omega_bins, theta_bins, phi_bins = tuple(zip(*batch))

        encoded_msas = [F.one_hot(torch.LongTensor(msa_), 21) for msa_ in msa]
        collated = {'msa1hot': pad_sequences(encoded_msas, 0, torch.float)}

        # (output key, per-sample arrays, padding value)
        bin_specs = (
            ('dist', dist_bins, -1),
            ('omega', omega_bins, 0),
            ('theta', theta_bins, 0),
            ('phi', phi_bins, 0),
        )
        for key, bins, pad_value in bin_specs:
            collated[key] = torch.LongTensor(pad_sequences(bins, pad_value))

        return collated
Example #4
Source File: discrete_sampler.py    From ReAgent with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """Sample one action per row of ``scores`` epsilon-greedily.

        Every action starts with probability ``epsilon / num_actions``; the
        arg-max action of each row is then raised to
        ``1 - epsilon + epsilon / num_actions`` so that each row sums to one.

        Args:
            scores: 2-D tensor of shape (batch_size, num_actions).

        Returns:
            An ``rlt.ActorOutput`` whose ``action`` is the one-hot encoding of
            the sampled indices, shape (batch_size, num_actions), and whose
            ``log_prob`` holds the log-probability of each sampled index,
            shape (batch_size,).
        """
        assert scores.dim() == 2, (
            "scores dim is %d" % scores.dim()
        )  # batch_size x num_actions
        batch_size, num_actions = scores.shape

        # pyre-fixme[16]: `Tensor` has no attribute `argmax`.
        argmax = F.one_hot(scores.argmax(dim=1), num_actions).bool()

        rand_prob = self.epsilon / num_actions
        # full_like(input, fill_value): the tensor that supplies shape/dtype
        # comes first, the scalar fill value second.
        p = torch.full_like(scores, rand_prob)

        greedy_prob = 1 - self.epsilon + rand_prob
        p[argmax] = greedy_prob

        m = torch.distributions.Categorical(probs=p)
        raw_action = m.sample()
        action = F.one_hot(raw_action, num_actions)
        assert action.shape == (batch_size, num_actions)
        log_prob = m.log_prob(raw_action)
        assert log_prob.shape == (batch_size,)
        return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #5
Source File: losses.py    From EfficientDet-PyTorch with Apache License 2.0 6 votes vote down vote up
def focal_loss(self, x, y):
        '''Focal loss.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].

        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25
        gamma = 2

        t = F.one_hot(y.data, 1+self.num_classes)  # [N, 1+num_classes]
        t = t[:,1:]  # exclude background
        # F.one_hot returns an integer tensor; binary_cross_entropy_with_logits
        # needs a floating-point target, so match x's dtype.
        t = t.type_as(x)

        p = x.sigmoid()
        pt = p*t + (1-p)*(1-t)         # pt = p if t > 0 else 1-p
        w = alpha*t + (1-alpha)*(1-t)  # w = alpha if t > 0 else 1-alpha
        w = w * (1-pt).pow(gamma)
        return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')
Example #6
Source File: datasets.py    From tape with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def featurize(self, msa):
        """Build a channel-first pairwise feature tensor from an MSA.

        The MSA is one-hot encoded with 21 classes, then 1-D and 2-D features
        are extracted with the weights from ``self.reweight``. The 1-D
        features are tiled along both pair axes and concatenated with the 2-D
        features on the last axis; the result is permuted so that the feature
        axis comes first.
        """
        msa1hot = F.one_hot(torch.LongTensor(msa), 21).float()
        seqlen = msa1hot.size(1)

        weights = self.reweight(msa1hot)
        features_1d = self.extract_features_1d(msa1hot, weights)
        features_2d = self.extract_features_2d(msa1hot, weights)

        # Entry [i, j] of row_tiled holds the 1-D features of position i,
        # entry [i, j] of col_tiled holds those of position j.
        row_tiled = features_1d.unsqueeze(1).repeat(1, seqlen, 1)
        col_tiled = features_1d.unsqueeze(0).repeat(seqlen, 1, 1)
        pairwise = torch.cat((row_tiled, col_tiled, features_2d), -1)

        return pairwise.permute(2, 0, 1)
Example #7
Source File: crossentropyloss.py    From backpack with MIT License 6 votes vote down vote up
def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
        """Return a sampled square-root factor built from ``mc_samples`` draws.

        Class indices are drawn with replacement from the probabilities given
        by ``self._get_probs(module)``; the result is
        ``(probs - one_hot(samples)) / sqrt(mc_samples)`` with the sample axis
        first. When ``module.reduction == "mean"`` it is additionally divided
        by ``sqrt(N)`` where ``N = module.input0.shape[0]``.
        """
        self._check_2nd_order_parameters(module)

        num_classes = module.input0.shape[1]
        probs = self._get_probs(module)

        # Draw labels per row of probs, one-hot encode them and move the
        # sample axis to the front.
        sampled = multinomial(probs, mc_samples, replacement=True)
        sampled_one_hot = one_hot(sampled, num_classes=num_classes)
        sampled_one_hot = einsum("nvc->vnc", sampled_one_hot).float()

        tiled_probs = probs.unsqueeze(0).repeat(mc_samples, 1, 1)
        sqrt_mc_h = (tiled_probs - sampled_one_hot) / sqrt(mc_samples)

        if module.reduction == "mean":
            batch_size = module.input0.shape[0]
            sqrt_mc_h /= sqrt(batch_size)

        return sqrt_mc_h
Example #8
Source File: sdf.py    From pytorch_geometric with MIT License 6 votes vote down vote up
def parse_sdf(src):
    """Parse SDF text into a ``Data`` object.

    The first three lines are skipped; the next line supplies the atom and
    bond counts. Atom lines provide positions (first three columns) and an
    element symbol (fourth column) that is one-hot encoded via ``elems``.
    Bond lines provide 1-based endpoints and an attribute column; every bond
    is added in both directions before the edges are coalesced.
    """
    lines = src.split('\n')[3:]
    num_atoms, num_bonds = map(int, lines[0].split()[:2])

    atom_end = 1 + num_atoms
    atom_block = lines[1:atom_end]
    bond_block = lines[atom_end:atom_end + num_bonds]

    pos = parse_txt_array(atom_block, end=3)
    atom_types = torch.tensor([elems[line.split()[3]] for line in atom_block])
    x = F.one_hot(atom_types, num_classes=len(elems))

    # Shift the 1-based atom indices to 0-based, then mirror every edge.
    row, col = parse_txt_array(bond_block, end=2, dtype=torch.long).t() - 1
    forward_ends = torch.cat([row, col], dim=0)
    backward_ends = torch.cat([col, row], dim=0)
    edge_index = torch.stack([forward_ends, backward_ends], dim=0)

    bond_attr = parse_txt_array(bond_block, start=2, end=3) - 1
    edge_attr = torch.cat([bond_attr, bond_attr], dim=0)
    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms,
                                     num_atoms)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)
Example #9
Source File: metric.py    From pytorch_geometric with MIT License 6 votes vote down vote up
def intersection_and_union(pred, target, num_classes, batch=None):
    r"""Computes per-class intersection and union counts of predictions.

    Args:
        pred (LongTensor): The predictions.
        target (LongTensor): The targets.
        num_classes (int): The number of classes.
        batch (LongTensor): The assignment vector which maps each pred-target
            pair to an example. If :obj:`None`, counts are summed over all
            pairs.

    :rtype: (:class:`LongTensor`, :class:`LongTensor`)
    """
    pred_one_hot = F.one_hot(pred, num_classes)
    target_one_hot = F.one_hot(target, num_classes)

    # Bitwise ops on the 0/1 integer encodings act as set intersection/union.
    both = pred_one_hot & target_one_hot
    either = pred_one_hot | target_one_hot

    if batch is not None:
        return (scatter_add(both, batch, dim=0),
                scatter_add(either, batch, dim=0))

    return both.sum(dim=0), either.sum(dim=0)
Example #10
Source File: losses.py    From EfficientDet-PyTorch with Apache License 2.0 6 votes vote down vote up
def focal_loss_alt(self, x, y, alpha=0.25, gamma=1.5):
        '''Focal loss alternative.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].
          alpha: weight of positive entries; negatives get 1-alpha.
          gamma: divisor applied to the weighted log term.

        Return:
          (tensor) focal loss.
        '''
        one_hot = F.one_hot(y, self.num_classes+1)
        t = one_hot[:,1:]  # drop column 0

        signed_x = x*(2*t-1)  # x where t > 0, -x elsewhere
        # Clamp from below so the log stays finite.
        pt = (2*signed_x+1).sigmoid().clamp(1e-7, 1.0)

        w = alpha*t + (1-alpha)*(1-t)
        return (-w*pt.log() / gamma).sum()
Example #11
Source File: normal_loss.py    From RMI with MIT License 6 votes vote down vote up
def forward(self, logits_4D, labels_4D):
		"""
		Summed binary cross entropy over valid pixels, divided by the
		number of valid pixels plus one.

		Args:
			logits_4D 	:	[N, C, H, W], dtype=float32
			labels_4D 	:	[N, H, W], dtype=long
		"""
		num_classes = self.num_classes

		label_flat = labels_4D.view(-1).requires_grad_(False)
		# Labels >= num_classes are treated as ignored.
		valid_mask = label_flat < num_classes
		# Ignored labels are mapped to class 0 so one_hot accepts them; their
		# loss contribution is zeroed by the weight below.
		safe_labels = label_flat * valid_mask.long()
		onehot_labels = F.one_hot(safe_labels, num_classes=num_classes).float()
		onehot_labels = onehot_labels.requires_grad_(False)

		logits_flat = logits_4D.permute(0, 2, 3, 1).contiguous().view([-1, num_classes])

		valid_weight = valid_mask.float()
		summed_loss = F.binary_cross_entropy_with_logits(
			logits_flat,
			target=onehot_labels,
			weight=valid_weight.unsqueeze(dim=1),
			reduction='sum',
		)
		return torch.div(summed_loss, torch.sum(valid_weight) + 1.0)
Example #12
Source File: test_focal_loss.py    From MONAI with Apache License 2.0 6 votes vote down vote up
def test_multi_class_seg_2d(self):
        """Focal loss of a near-perfect prediction is ~0 for both target formats."""
        num_classes = 6  # labels 0 to 5
        # one 2d label map with a leading batch axis: shape (1, H, W)
        target = torch.tensor(
            [[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]
        ).unsqueeze(0)

        # channel-first one-hot encoding of the labels: shape (1, C, H, W)
        target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2)
        # very confident logits that agree with the labels everywhere
        pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
        # class-index target with a singleton channel axis: shape (1, 1, H, W)
        target_indices = target.unsqueeze(1)

        loss = FocalLoss()

        # the loss should be close to 0 for index targets and one-hot targets
        for current_target in (target_indices, target_one_hot):
            focal_loss_good = float(loss(pred_very_good, current_target).cpu())
            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
Example #13
Source File: polybeast_loss_functions_test.py    From torchbeast with Apache License 2.0 6 votes vote down vote up
def test_compute_policy_gradient_loss(self):
        """Compare the policy gradient loss against a NumPy reference value."""
        T, B, N = self.logits.shape

        # Reference: cross entropy loss = -sum_over_j(y_j * log(p_j)), where
        # - `y_j` says whether the action at index j was taken (so y is a
        #   one-hot array whose size is the number of actions),
        # - `p_j` is the softmax of the logits for the j-th action.
        # Each term is additionally multiplied by the advantage.
        taken_actions = torch.from_numpy(self.actions)
        labels = F.one_hot(taken_actions, num_classes=N).numpy()
        log_probs = np.log(_softmax(self.logits))
        cross_entropy_loss = -labels * log_probs
        advantages = self.advantages.reshape(T, B, 1)
        ground_truth_value = np.sum(cross_entropy_loss * advantages)

        calculated_value = polybeast.compute_policy_gradient_loss(
            torch.from_numpy(self.logits),
            torch.from_numpy(self.actions),
            torch.from_numpy(self.advantages),
        )
        assert_allclose(ground_truth_value, calculated_value.item())
Example #14
Source File: SentiGAN_G.py    From TextGAN-PyTorch with MIT License 6 votes vote down vote up
def batchPGLoss(self, inp, target, reward):
        """
        Returns a policy gradient loss

        :param inp: batch_size x seq_len, inp should be target with <s> (start letter) prepended
        :param target: batch_size x seq_len
        :param reward: batch_size (discriminator reward for each sentence, applied to each token of the corresponding sentence)
        :return loss: policy loss, the negated sum of target-token outputs weighted by (1 - reward)
        """
        batch_size, _ = inp.size()

        out = self.forward(inp, self.init_hidden(batch_size), use_log=False)
        out = out.view(batch_size, self.max_seq_len, self.vocab_size)

        # Keep only the output at each target token: mask with the one-hot
        # target, then sum over the vocabulary axis.
        target_mask = F.one_hot(target, self.vocab_size).float()  # batch_size * seq_len * vocab_size
        picked = (out * target_mask).sum(dim=-1)  # batch_size * seq_len

        return -torch.sum(picked * (1 - reward))
Example #15
Source File: MaliGAN_G.py    From TextGAN-PyTorch with MIT License 6 votes vote down vote up
def adv_loss(self, inp, target, reward):
        """
        Returns a MaliGAN loss

        :param inp: batch_size x seq_len, inp should be target with <s> (start letter) prepended
        :param target: batch_size x seq_len
        :param reward: batch_size (discriminator reward for each sentence, applied to each token of the corresponding sentence)
        :return loss: policy loss, the negated reward-weighted sum of the outputs at the target tokens
        """
        batch_size, _ = inp.size()
        hidden = self.init_hidden(batch_size)

        token_scores = self.forward(inp, hidden)
        token_scores = token_scores.view(batch_size, self.max_seq_len, self.vocab_size)

        # A one-hot mask over the vocabulary selects the score of each target token.
        selector = F.one_hot(target, self.vocab_size).float()  # batch_size * seq_len * vocab_size
        target_scores = torch.sum(token_scores * selector, dim=-1)  # batch_size * seq_len

        weighted = target_scores * reward
        return -torch.sum(weighted)
Example #16
Source File: SeqGAN_G.py    From TextGAN-PyTorch with MIT License 6 votes vote down vote up
def batchPGLoss(self, inp, target, reward):
        """
        Returns a policy gradient loss

        :param inp: batch_size x seq_len, inp should be target with <s> (start letter) prepended
        :param target: batch_size x seq_len
        :param reward: batch_size (discriminator reward for each sentence, applied to each token of the corresponding sentence)
        :return loss: policy loss, the negated reward-weighted sum of the outputs at the target tokens
        """
        batch_size, _ = inp.size()
        out_shape = (batch_size, self.max_seq_len, self.vocab_size)

        out = self.forward(inp, self.init_hidden(batch_size)).view(*out_shape)

        # Zero every vocabulary entry except the target token, then collapse
        # the vocabulary axis.
        target_onehot = F.one_hot(target, self.vocab_size).float()  # batch_size * seq_len * vocab_size
        pred = (out * target_onehot).sum(dim=-1)  # batch_size * seq_len

        return -(pred * reward).sum()
Example #17
Source File: cross_entropy_loss.py    From kaggle-kuzushiji-recognition with MIT License 6 votes vote down vote up
def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  label_smooth=None):
    """Cross entropy with optional label smoothing, weighting and reduction.

    Without ``label_smooth`` the per-sample loss is ``F.cross_entropy``. With
    it, the one-hot target is smoothed and the per-sample loss is the KL
    divergence between the smoothed target and ``log_softmax(pred)`` summed
    over the class axis. The result is passed to ``weight_reduce_loss``
    together with ``weight``, ``reduction`` and ``avg_factor``.
    """
    # element-wise losses
    if label_smooth is not None:
        num_classes = pred.size(1)
        target = F.one_hot(label, num_classes).type_as(pred)
        # In place: subtract label_smooth, clamp negatives to 0, then add
        # label_smooth / num_classes to every entry.
        target.sub_(label_smooth).clamp_(0).add_(label_smooth / num_classes)
        log_probs = pred.log_softmax(1)
        loss = F.kl_div(log_probs, target, reduction='none').sum(1)
    else:
        loss = F.cross_entropy(pred, label, reduction='none')

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    return weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
Example #18
Source File: RelGAN_G.py    From TextGAN-PyTorch with MIT License 6 votes vote down vote up
def step(self, inp, hidden):
        """
        RelGAN step forward
        :param inp: [batch_size]
        :param hidden: memory size
        :return: pred, hidden, next_token, next_token_onehot, next_o
            - pred: batch_size * vocab_size, use for adversarial training backward
            - hidden: next hidden
            - next_token: [batch_size], next sentence token
            - next_token_onehot: always None (not used yet)
            - next_o: always None (not used yet)
        """
        emb = self.embeddings(inp).unsqueeze(1)
        out, hidden = self.lstm(emb, hidden)

        logits = self.lstm2out(out.squeeze(1))
        gumbel_t = self.add_gumbel(logits)

        # The arg-max token is detached from the graph.
        next_token = gumbel_t.argmax(dim=1).detach()
        pred = F.softmax(gumbel_t * self.temperature, dim=-1)  # batch_size * vocab_size

        # The last two slots are placeholders kept for the return signature.
        return pred, hidden, next_token, None, None
Example #19
Source File: Load_Agent.py    From FitML with MIT License 6 votes vote down vote up
def train_step(model, state_transitions, tgt, num_actions, gamma=0.99):
    """Run one Q-learning update of ``model`` on a batch of transitions.

    Args:
        model: callable network exposing ``device`` and an optimizer ``opt``.
        state_transitions: sequence of transitions with ``state``, ``reward``,
            ``done``, ``next_state`` and ``action`` attributes.
        tgt: accepted for interface compatibility; not used by this function.
        num_actions: number of discrete actions (width of the one-hot mask).
        gamma: discount factor applied to the bootstrapped next-state value.

    Returns:
        The smooth-L1 loss tensor, or ``None`` when ``state_transitions`` is
        empty (a message is printed in that case).
    """
    if not state_transitions:
        print("empty state transitions")
        return None

    device = model.device
    cur_states = torch.stack([torch.Tensor(s.state) for s in state_transitions]).to(device)
    rewards = torch.stack([torch.Tensor([s.reward]) for s in state_transitions]).to(device)
    # 0 for terminal transitions so nothing is bootstrapped past the episode end.
    mask = torch.stack(
        [torch.Tensor([0]) if s.done else torch.Tensor([1]) for s in state_transitions]
    ).to(device)
    next_states = torch.stack([torch.Tensor(s.next_state) for s in state_transitions]).to(device)
    actions = [s.action for s in state_transitions]

    # The bootstrap value must not receive gradients.
    with torch.no_grad():
        pred_qvals_next = model(next_states).max(-1)[0]
    model.opt.zero_grad()
    pred_qvals = model(cur_states)

    # Select the predicted Q-value of the action actually taken.
    one_hot_actions = F.one_hot(torch.LongTensor(actions), num_actions).to(device)
    chosen_qvals = torch.sum(pred_qvals * one_hot_actions, -1)
    td_target = rewards.view(-1) + gamma * mask[:, 0] * pred_qvals_next.view(-1)

    loss = F.smooth_l1_loss(chosen_qvals, td_target).mean()
    loss.backward()
    model.opt.step()
    return loss
Example #20
Source File: relgan_instructor.py    From TextGAN-PyTorch with MIT License 6 votes vote down vote up
def adv_train_discriminator(self, d_step):
        """Run ``d_step`` discriminator updates and return the mean loss.

        Each step takes a batch of real targets from ``self.train_data``
        (one-hot encoded over the vocabulary) and a batch of generated
        one-hot samples, then optimizes the discriminator on the
        discriminator loss from ``get_losses``. Returns 0 when ``d_step`` is 0.
        """
        loss_sum = 0
        for _ in range(d_step):
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples = real_samples.cuda()
                gen_samples = gen_samples.cuda()
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.dis_opt, d_loss, self.dis)
            loss_sum += d_loss.item()

        if d_step == 0:
            return 0
        return loss_sum / d_step
Example #21
Source File: relgan_instructor.py    From TextGAN-PyTorch with MIT License 6 votes vote down vote up
def adv_train_generator(self, g_step):
        """Run ``g_step`` adversarial generator updates and return the mean loss.

        Each step scores a one-hot batch of real targets from
        ``self.train_data`` and a batch of generated one-hot samples with the
        discriminator, then optimizes the generator on the generator loss
        from ``get_losses``. Returns 0 when ``g_step`` is 0.
        """
        loss_sum = 0
        for _ in range(g_step):
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples = real_samples.cuda()
                gen_samples = gen_samples.cuda()
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            g_loss, _ = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.gen_adv_opt, g_loss, self.gen)
            loss_sum += g_loss.item()

        if g_step == 0:
            return 0
        return loss_sum / g_step
Example #22
Source File: relgan_instructor.py    From TextGAN-PyTorch with MIT License 6 votes vote down vote up
def adv_train_discriminator(self, d_step):
        """Run ``d_step`` discriminator updates and return the mean loss.

        Real samples are targets drawn from ``self.oracle_data``, one-hot
        encoded over the vocabulary before any device transfer; fake samples
        are generated one-hot batches. Returns 0 when ``d_step`` is 0.
        """
        loss_sum = 0
        for _ in range(d_step):
            real_targets = self.oracle_data.random_batch()['target']
            real_samples = F.one_hot(real_targets, cfg.vocab_size).float()
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples = real_samples.cuda()
                gen_samples = gen_samples.cuda()

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.dis_opt, d_loss, self.dis)
            loss_sum += d_loss.item()

        if d_step == 0:
            return 0
        return loss_sum / d_step
Example #23
Source File: ATARI_DQN_CNN.py    From FitML with MIT License 6 votes vote down vote up
def train_step(model, state_transitions, tgt, num_actions, gamma):
    """Run one update that regresses chosen Q-values onto the stored ``qval``s.

    States are reshaped to (batch, 3, 160, 140*3) before being fed to
    ``model``. The loss is the smooth-L1 distance between the predicted
    Q-value of each taken action and the transition's ``qval``. ``tgt`` and
    ``gamma`` are accepted but not used by the active loss. Returns the loss
    tensor, or ``None`` (after printing a message) for an empty batch.
    """
    if len(state_transitions) <= 0:
        print("empty state transitions")
        return

    device = model.device
    frame_shape = (len(state_transitions), 3, 160, 140*3)

    def _stack(rows):
        # Convert each row to a float tensor, stack them and move to the model device.
        return torch.stack([torch.Tensor(row) for row in rows]).to(device)

    cur_states = _stack(s.state for s in state_transitions)
    rewards = _stack([s.reward] for s in state_transitions)
    Qs = _stack([s.qval] for s in state_transitions)
    mask = _stack([0] if s.done else [1] for s in state_transitions)
    next_states = _stack(s.next_state for s in state_transitions)
    actions = [s.action for s in state_transitions]

    with torch.no_grad():
        actual_Q_values = Qs
        # Evaluated as in the bootstrapped variant of the loss; the result is
        # not used by the regression loss below.
        pred_qvals_next = model(next_states.view(*frame_shape)).max(-1)[0]
    model.opt.zero_grad()
    pred_qvals = model(cur_states.view(*frame_shape))

    # Select the predicted Q-value of the action actually taken.
    one_hot_actions = F.one_hot(torch.LongTensor(actions), num_actions).to(device)
    chosen_qvals = torch.sum(pred_qvals * one_hot_actions, -1)
    loss = F.smooth_l1_loss(chosen_qvals, actual_Q_values.view(-1))
    loss.backward()
    model.opt.step()
    return loss
Example #24
Source File: relgan_instructor.py    From TextGAN-PyTorch with MIT License 6 votes vote down vote up
def adv_train_generator(self, g_step):
        """Run ``g_step`` adversarial generator updates and return the mean loss.

        Real samples are targets drawn from ``self.oracle_data``, one-hot
        encoded over the vocabulary before any device transfer; fake samples
        are generated one-hot batches. Returns 0 when ``g_step`` is 0.
        """
        loss_sum = 0
        for _ in range(g_step):
            real_targets = self.oracle_data.random_batch()['target']
            real_samples = F.one_hot(real_targets, cfg.vocab_size).float()
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples = real_samples.cuda()
                gen_samples = gen_samples.cuda()

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            g_loss, _ = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.gen_adv_opt, g_loss, self.gen)
            loss_sum += g_loss.item()

        if g_step == 0:
            return 0
        return loss_sum / g_step
Example #25
Source File: mutator.py    From nni with MIT License 6 votes vote down vote up
def sample_search(self):
        """
        Sample a random candidate.

        Returns a dict mapping each mutable's key to a boolean mask: a one-hot
        mask for a ``LayerChoice``, and for an ``InputChoice`` either an
        independent random 0/1 mask (``n_chosen`` is None) or a mask with
        ``n_chosen`` randomly picked candidates set.
        """
        result = {}
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                num_candidates = len(mutable)
                gen_index = torch.randint(high=num_candidates, size=(1, ))
                result[mutable.key] = F.one_hot(gen_index, num_classes=num_candidates).view(-1).bool()
                continue

            if not isinstance(mutable, InputChoice):
                continue

            if mutable.n_chosen is None:
                coin_flips = torch.randint(high=2, size=(mutable.n_candidates,))
                result[mutable.key] = coin_flips.view(-1).bool()
            else:
                # The first n_chosen entries of a random permutation are the picks.
                chosen = torch.randperm(mutable.n_candidates)[:mutable.n_chosen]
                mask = [i in chosen for i in range(mutable.n_candidates)]
                result[mutable.key] = torch.tensor(mask, dtype=torch.bool)  # pylint: disable=not-callable
        return result
Example #26
Source File: mutator.py    From nni with MIT License 6 votes vote down vote up
def sample_final(self):
        """
        Generate the final chosen architecture.

        Returns
        -------
        dict
            the choice of each mutable, i.e., LayerChoice, as a flat boolean
            one-hot mask keyed by the mutable's key
        """
        result = {}
        for mutable in self.undedup_mutables:
            assert isinstance(mutable, LayerChoice)
            # chosen_index unpacks into two values; only the first is used.
            index, _ = mutable.registered_module.chosen_index
            # pylint: disable=not-callable
            one_hot = F.one_hot(torch.tensor(index), num_classes=len(mutable))
            result[mutable.key] = one_hot.view(-1).bool()
        return result
Example #27
Source File: loss.py    From torch-toolbox with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def _get_body(self, x, target):
        """Return the additive margin correction for the target entries of ``x``.

        For each row, the value at the target column (cos(theta_yi)) is
        replaced by cos(theta_yi + m) where the margin condition holds and by
        a fallback value otherwise. The result is the difference between the
        replacement and the original value, placed at the target column and
        zero elsewhere.
        """
        cos_t = torch.gather(x, 1, target.unsqueeze(1))  # cos(theta_yi)

        if self.easy_margin:
            margin_gate = torch.relu(cos_t)
            zy_keep = cos_t
        else:
            margin_gate = torch.relu(cos_t - self.threshold)
            zy_keep = cos_t - self.mm  # (cos(theta_yi) - sin(pi - m)*m)

        # Apex would convert FP16 to FP32 here
        # cos(theta_yi + m)
        zy_with_margin = torch.cos(torch.acos(cos_t) + self.m).type(cos_t.dtype)

        # The gate is positive exactly where the margin should be applied.
        new_zy = torch.where(margin_gate.bool(), zy_with_margin, zy_keep)
        diff = new_zy - cos_t  # cos(theta_yi + m) - cos(theta_yi)

        return F.one_hot(target, num_classes=self.classes) * diff
Example #28
Source File: discrete_sampler.py    From ReAgent with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """Pick the greedy action for every row of ``scores``.

        Args:
            scores: 2-D tensor of shape (batch_size, num_actions).

        Returns:
            An ``rlt.ActorOutput`` whose ``action`` is the one-hot encoding of
            the indices from ``self._get_greedy_indices``, shape
            (batch_size, num_actions), and whose ``log_prob`` is a float
            tensor of zeros with the shape of those indices.
        """
        batch_size, num_actions = scores.shape
        raw_action = self._get_greedy_indices(scores)
        action = F.one_hot(raw_action, num_actions)
        assert action.shape == (batch_size, num_actions)
        # A greedy choice is deterministic (probability 1), so its
        # log-probability is 0; use a float dtype rather than the integer
        # dtype of the indices.
        log_prob = torch.zeros_like(raw_action, dtype=torch.float)
        return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #29
Source File: mutator.py    From nni with MIT License 5 votes vote down vote up
def _sample_input_choice(self, mutable):
        """Sample which candidate inputs of an input choice are connected.

        Attention logits are computed between the stored anchor states of the
        labels in ``mutable.choose_from`` and the last entry of ``self._h``.
        With ``n_chosen`` None every candidate is kept or skipped on its own
        and a KL term against ``self.skip_targets`` is added to
        ``self.sample_skip_penalty``; otherwise exactly one candidate is
        drawn. The values derived from ``self.cross_entropy_loss`` are
        accumulated into ``self.sample_log_prob`` and ``self.sample_entropy``
        and ``self._inputs`` is updated.

        Returns the selection as a boolean tensor.
        """
        query, anchors = [], []
        for label in mutable.choose_from:
            if label not in self._anchors_hid:
                self._lstm_next_step()
                self._mark_anchor(label)  # empty loop, fill not found
            anchor_hidden = self._anchors_hid[label]
            query.append(self.attn_anchor(anchor_hidden))
            anchors.append(anchor_hidden)

        query = torch.cat(query, 0)
        query = torch.tanh(query + self.attn_query(self._h[-1]))
        query = self.v_attn(query)
        if self.temperature is not None:
            query /= self.temperature
        if self.tanh_constant is not None:
            query = self.tanh_constant * torch.tanh(query)

        if mutable.n_chosen is not None:
            assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
            logit = query.view(1, -1)
            index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
            log_prob = self.cross_entropy_loss(logit, index)
            self._inputs = anchors[index.item()]
        else:
            # Two logits per candidate: column 0 is -query, column 1 is query.
            logit = torch.cat([-query, query], 1)  # pylint: disable=invalid-unary-operand-type

            skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, skip)
            # Sum of the selected anchors divided by (1 + number selected).
            selected_sum = torch.matmul(skip.float(), torch.cat(anchors, 0))
            self._inputs = (selected_sum / (1. + torch.sum(skip))).unsqueeze(0)

        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        return skip.bool()
Example #30
Source File: distribution.py    From WaveRNN with MIT License 5 votes vote down vote up
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """
    Sample from discretized mixture of logistic distributions
    Args:
        y (Tensor): B x C x T, where C is three times the number of mixtures
            (mixture logits, means, log scales)
        log_scale_min (float): Log scale minimum value; defaults to log(1e-14)
    Returns:
        Tensor: sample in range of [-1, 1].
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))

    num_channels = y.size(1)
    assert num_channels % 3 == 0
    nr_mix = num_channels // 3

    # B x T x C, then split the channel axis into its three parameter groups
    y = y.transpose(1, 2)
    logit_probs = y[:, :, :nr_mix]
    all_means = y[:, :, nr_mix:2 * nr_mix]
    all_log_scales = y[:, :, 2 * nr_mix:3 * nr_mix]

    # sample mixture indicator from softmax: perturb the logits with
    # -log(-log(u)) noise and take the arg-max
    noise = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
    perturbed = logit_probs.data - torch.log(- torch.log(noise))
    argmax = perturbed.max(dim=-1)[1]

    # (B, T) -> (B, T, nr_mix)
    one_hot = F.one_hot(argmax, nr_mix).float()

    # select the logistic parameters of the sampled mixture component
    means = (all_means * one_hot).sum(dim=-1)
    log_scales = torch.clamp(
        (all_log_scales * one_hot).sum(dim=-1), min=log_scale_min)

    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))

    return torch.clamp(x, min=-1., max=1.)