Python torch.nn.functional.one_hot() Examples
The following are 30 code examples of torch.nn.functional.one_hot().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch.nn.functional, or try the search function.
Example #1
Source File: mutator.py From nni with MIT License | 6 votes |
def _sample_layer_choice(self, mutable):
    """Draw one branch for a layer choice and update the sampling statistics.

    Steps the LSTM, maps its last hidden state to branch scores (optionally
    temperature-scaled, tanh-squashed and biased per ``mutable.key``), samples
    a branch from the softmax of those scores, accumulates
    ``sample_log_prob`` / ``sample_entropy`` and stores the embedding of the
    sampled branch in ``self._inputs``.

    Returns:
        A flat boolean one-hot mask with ``self.max_layer_choice`` classes.
    """
    self._lstm_next_step()
    scores = self.soft(self._h[-1])

    if self.temperature is not None:
        scores /= self.temperature
    if self.tanh_constant is not None:
        scores = self.tanh_constant * torch.tanh(scores)
    if mutable.key in self.bias_dict:
        scores += self.bias_dict[mutable.key]

    distribution = F.softmax(scores, dim=-1)
    chosen = torch.multinomial(distribution, 1).view(-1)

    nll = self.cross_entropy_loss(scores, chosen)
    self.sample_log_prob += self.entropy_reduction(nll)

    # Entropy term is detached so it never contributes gradients.
    entropy = (nll * torch.exp(-nll)).detach()  # pylint: disable=invalid-unary-operand-type
    self.sample_entropy += self.entropy_reduction(entropy)

    self._inputs = self.embedding(chosen)
    one_hot_mask = F.one_hot(chosen, num_classes=self.max_layer_choice)
    return one_hot_mask.bool().view(-1)
Example #2
Source File: criterion.py From MusicTransformer-pytorch with MIT License | 6 votes |
def forward(self, input, target):
    """
    Label-smoothed cross entropy that ignores ``self.ignore_index`` targets.

    Args:
        input: [B * T, V]
        target: [B * T]
    Returns:
        cross entropy: [1]
    Raises:
        NotImplementedError: when ``self.reduction`` is neither 'mean' nor 'sum'.
    """
    ignored = (target == self.ignore_index).unsqueeze(-1)

    hard = F.one_hot(target.long(), self.vocab_size).type(torch.float32)
    uniform = 1.0 / self.vocab_size
    smoothed = (1.0 - self.label_smoothing) * hard + self.label_smoothing * uniform
    # Rows belonging to ignored targets carry no probability mass at all.
    smoothed = smoothed.masked_fill(ignored, 0)

    ce = self.cross_entropy_with_logits(smoothed, input)

    if self.reduction == 'mean':
        kept = torch.sum(target != self.ignore_index)
        return ce.sum() / kept
    if self.reduction == 'sum':
        return ce.sum()
    raise NotImplementedError
Example #3
Source File: datasets.py From tape with BSD 3-Clause "New" or "Revised" License | 6 votes |
def collate_fn(self, batch):
    """Turn a list of per-example tuples into one dict of padded tensors.

    Each example is unpacked as ``(msa, dist_bins, omega_bins, theta_bins,
    phi_bins)``.  The MSA is one-hot encoded with 21 classes and padded with
    0 as float; ``dist`` is padded with -1 and the remaining bins with 0,
    all converted to ``LongTensor``.
    """
    msa, dist_bins, omega_bins, theta_bins, phi_bins = tuple(zip(*batch))

    encoded = [F.one_hot(torch.LongTensor(msa_), 21) for msa_ in msa]
    msa1hot = pad_sequences(encoded, 0, torch.float)

    dist = torch.LongTensor(pad_sequences(dist_bins, -1))
    omega = torch.LongTensor(pad_sequences(omega_bins, 0))
    theta = torch.LongTensor(pad_sequences(theta_bins, 0))
    phi = torch.LongTensor(pad_sequences(phi_bins, 0))

    return {
        'msa1hot': msa1hot,
        'dist': dist,
        'omega': omega,
        'theta': theta,
        'phi': phi,
    }
Example #4
Source File: discrete_sampler.py From ReAgent with BSD 3-Clause "New" or "Revised" License | 6 votes |
def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
    """Sample epsilon-greedy actions from a batch of action scores.

    Every action gets probability ``epsilon / num_actions``; the arg-max
    action of each row gets ``1 - epsilon`` on top of that.  One action per
    row is then drawn from the resulting categorical distribution.

    Args:
        scores: 2-D tensor, batch_size x num_actions.

    Returns:
        ``rlt.ActorOutput`` with ``action`` as the one-hot encoding of the
        sampled indices (batch_size x num_actions) and ``log_prob`` as the
        log-probability of each sampled index (batch_size,).
    """
    assert scores.dim() == 2, (
        "scores dim is %d" % scores.dim()
    )  # batch_size x num_actions
    batch_size, num_actions = scores.shape

    # pyre-fixme[16]: `Tensor` has no attribute `argmax`.
    argmax = F.one_hot(scores.argmax(dim=1), num_actions).bool()

    rand_prob = self.epsilon / num_actions
    # torch.full_like expects (reference_tensor, fill_value); the arguments
    # were previously swapped, passing a Python float as the reference tensor.
    p = torch.full_like(scores, rand_prob)

    greedy_prob = 1 - self.epsilon + rand_prob
    p[argmax] = greedy_prob

    m = torch.distributions.Categorical(probs=p)
    raw_action = m.sample()
    action = F.one_hot(raw_action, num_actions)
    assert action.shape == (batch_size, num_actions)
    log_prob = m.log_prob(raw_action)
    assert log_prob.shape == (batch_size,)
    return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #5
Source File: losses.py From EfficientDet-PyTorch with Apache License 2.0 | 6 votes |
def focal_loss(self, x, y):
    '''Focal loss.

    Args:
      x: (tensor) sized [N,D].
      y: (tensor) sized [N,].

    Return:
      (tensor) focal loss.
    '''
    alpha = 0.25
    gamma = 2

    # One extra class for the background, which is dropped right after.
    t = F.one_hot(y.data, 1 + self.num_classes)
    # F.one_hot yields int64, but binary_cross_entropy_with_logits needs a
    # floating target of the same dtype as the logits, so cast explicitly.
    # (The deprecated autograd.Variable wrapper is no longer needed.)
    t = t[:, 1:].type_as(x)  # exclude background

    p = x.sigmoid()
    pt = p * t + (1 - p) * (1 - t)          # pt = p if t > 0 else 1-p
    w = alpha * t + (1 - alpha) * (1 - t)   # w = alpha if t > 0 else 1-alpha
    w = w * (1 - pt).pow(gamma)
    return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')
Example #6
Source File: datasets.py From tape with BSD 3-Clause "New" or "Revised" License | 6 votes |
def featurize(self, msa):
    """Build the feature tensor for one MSA.

    The MSA is one-hot encoded with 21 classes, reweighted, and turned into
    1-D and 2-D features.  The 1-D features are tiled along two axes
    (``seqlen`` is ``msa1hot.size(1)``), concatenated with the 2-D features
    on the last axis, and the last axis is finally moved to the front.
    """
    msa1hot = F.one_hot(torch.LongTensor(msa), 21).float()
    seqlen = msa1hot.size(1)

    weights = self.reweight(msa1hot)
    feats_1d = self.extract_features_1d(msa1hot, weights)
    feats_2d = self.extract_features_2d(msa1hot, weights)

    tiled_rows = feats_1d.unsqueeze(1).repeat(1, seqlen, 1)
    tiled_cols = feats_1d.unsqueeze(0).repeat(seqlen, 1, 1)
    stacked = torch.cat((tiled_rows, tiled_cols, feats_2d), -1)
    return stacked.permute(2, 0, 1)
Example #7
Source File: crossentropyloss.py From backpack with MIT License | 6 votes |
def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
    """Monte-Carlo sampled square-root factor built from class probabilities.

    Draws ``mc_samples`` classes per row of the probabilities (with
    replacement), one-hot encodes them and returns
    ``(probs - one_hot_sample) / sqrt(mc_samples)``, additionally divided by
    ``sqrt(N)`` when ``module.reduction == "mean"`` (``N`` being
    ``module.input0.shape[0]``).  ``g_inp`` and ``g_out`` are not used.
    """
    self._check_2nd_order_parameters(module)

    num_samples = mc_samples
    num_classes = module.input0.shape[1]

    probs = self._get_probs(module)
    # Leading axis indexes the Monte-Carlo sample.
    tiled_probs = probs.unsqueeze(0).repeat(num_samples, 1, 1)

    draws = multinomial(probs, num_samples, replacement=True)
    sampled = one_hot(draws, num_classes=num_classes)
    sampled = einsum("nvc->vnc", sampled).float()

    sqrt_mc_h = (tiled_probs - sampled) / sqrt(num_samples)

    if module.reduction == "mean":
        batch = module.input0.shape[0]
        sqrt_mc_h /= sqrt(batch)

    return sqrt_mc_h
Example #8
Source File: sdf.py From pytorch_geometric with MIT License | 6 votes |
def parse_sdf(src):
    """Parse SDF text into a ``Data`` object.

    The first three lines are skipped; the next line supplies the atom and
    bond counts, followed by the atom block and then the bond block.  Atoms
    become one-hot features over ``elems`` plus positions (first three
    columns); bonds become a symmetric, coalesced edge index (shifted to
    0-based) with the third bond column, minus one, as edge attribute.
    """
    lines = src.split('\n')[3:]
    num_atoms, num_bonds = [int(tok) for tok in lines[0].split()[:2]]

    atom_block = lines[1:num_atoms + 1]
    bond_start = 1 + num_atoms
    bond_block = lines[bond_start:bond_start + num_bonds]

    # Atoms: coordinates and one-hot element type.
    pos = parse_txt_array(atom_block, end=3)
    atom_ids = torch.tensor([elems[entry.split()[3]] for entry in atom_block])
    x = F.one_hot(atom_ids, num_classes=len(elems))

    # Bonds: add both directions of every bond.
    src_idx, dst_idx = parse_txt_array(bond_block, end=2, dtype=torch.long).t() - 1
    src_idx, dst_idx = (torch.cat([src_idx, dst_idx], dim=0),
                        torch.cat([dst_idx, src_idx], dim=0))
    edge_index = torch.stack([src_idx, dst_idx], dim=0)

    bond_type = parse_txt_array(bond_block, start=2, end=3) - 1
    edge_attr = torch.cat([bond_type, bond_type], dim=0)

    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms,
                                     num_atoms)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)
Example #9
Source File: metric.py From pytorch_geometric with MIT License | 6 votes |
def intersection_and_union(pred, target, num_classes, batch=None):
    r"""Computes intersection and union of predictions.

    Args:
        pred (LongTensor): The predictions.
        target (LongTensor): The targets.
        num_classes (int): The number of classes.
        batch (LongTensor): The assignment vector which maps each pred-target
            pair to an example.

    :rtype: (:class:`LongTensor`, :class:`LongTensor`)
    """
    pred_hot = F.one_hot(pred, num_classes)
    target_hot = F.one_hot(target, num_classes)

    overlap = pred_hot & target_hot
    combined = pred_hot | target_hot

    if batch is None:
        # No grouping: reduce over all pred-target pairs at once.
        return overlap.sum(dim=0), combined.sum(dim=0)

    return (scatter_add(overlap, batch, dim=0),
            scatter_add(combined, batch, dim=0))
Example #10
Source File: losses.py From EfficientDet-PyTorch with Apache License 2.0 | 6 votes |
def focal_loss_alt(self, x, y, alpha=0.25, gamma=1.5):
    '''Focal loss alternative.

    Args:
      x: (tensor) sized [N,D].
      y: (tensor) sized [N,].
      alpha: weight of positive entries; negatives get 1-alpha.
      gamma: divisor applied to the weighted log term.

    Return:
      (tensor) focal loss.
    '''
    # One-hot with an extra leading background class, which is dropped.
    t = F.one_hot(y, self.num_classes + 1)[:, 1:]

    signed = x * (2 * t - 1)  # x if t > 0 else -x
    # NOTE(review): the slope here is hard-coded to 2 while the result below
    # is divided by `gamma`; confirm whether the slope should follow `gamma`.
    pt = (2 * signed + 1).sigmoid().clamp(1e-7, 1.0)

    w = alpha * t + (1 - alpha) * (1 - t)
    return (-w * pt.log() / gamma).sum()
Example #11
Source File: normal_loss.py From RMI with MIT License | 6 votes |
def forward(self, logits_4D, labels_4D):
    """
    Binary cross entropy over one-hot labels, skipping out-of-range labels.

    Args:
        logits_4D   :   [N, C, H, W], dtype=float32
        labels_4D   :   [N, H, W], dtype=long
    Returns:
        Summed BCE-with-logits divided by (number of valid pixels + 1).
    """
    flat_labels = labels_4D.view(-1).requires_grad_(False)
    # Labels >= num_classes are treated as "ignore".
    valid = flat_labels < self.num_classes

    # Ignored labels are zeroed so one_hot stays in range; their loss is
    # removed through the weight below.
    targets = F.one_hot(flat_labels * valid.long(),
                        num_classes=self.num_classes).float()
    targets = targets.requires_grad_(False)

    flat_logits = logits_4D.permute(0, 2, 3, 1).contiguous().view([-1, self.num_classes])

    pixel_weight = valid.float()
    num_valid = torch.sum(pixel_weight)
    summed = F.binary_cross_entropy_with_logits(flat_logits,
                                                target=targets,
                                                weight=pixel_weight.unsqueeze(dim=1),
                                                reduction='sum')
    return torch.div(summed, num_valid + 1.0)
Example #12
Source File: test_focal_loss.py From MONAI with Apache License 2.0 | 6 votes |
def test_multi_class_seg_2d(self):
    """Focal loss of a near-perfect 2-D multi-class prediction is ~0.

    Checked for both an index target of shape (1, 1, H, W) and a one-hot
    target of shape (1, C, H, W).
    """
    num_classes = 6  # labels 0 to 5
    # define 2d examples
    grid = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
    # add another dimension corresponding to the batch (batch size = 1 here)
    target = grid.unsqueeze(0)  # shape (1, H, W)

    # channel-first one-hot encoding of the target
    target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2)
    # huge logits on the correct class make the prediction (almost) perfect
    pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()

    # initialize the focal loss
    loss = FocalLoss()

    target = target.unsqueeze(1)  # shape (1, 1, H, W)

    # focal loss for pred_very_good should be close to 0 for both encodings
    for labels in (target, target_one_hot):
        focal_loss_good = float(loss(pred_very_good, labels).cpu())
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
Example #13
Source File: polybeast_loss_functions_test.py From torchbeast with Apache License 2.0 | 6 votes |
def test_compute_policy_gradient_loss(self):
    """Compare the policy gradient loss against a NumPy reference value."""
    T, B, N = self.logits.shape

    # Calculate the cross entropy loss, with the formula:
    # loss = -sum_over_j(y_j * log(p_j))
    # Where:
    # - `y_j` is whether the action corresponding to index j has been taken
    #   or not (hence y is a one-hot array of size == number of actions).
    # - `p_j` is the value of the softmax logit corresponding to the jth
    #   action.
    # In our implementation, we also multiply by the advantages.
    taken = torch.from_numpy(self.actions)
    labels = F.one_hot(taken, num_classes=N).numpy()
    log_probs = np.log(_softmax(self.logits))
    cross_entropy_loss = -labels * log_probs
    weighted = cross_entropy_loss * self.advantages.reshape(T, B, 1)
    ground_truth_value = np.sum(weighted)

    calculated_value = polybeast.compute_policy_gradient_loss(
        torch.from_numpy(self.logits),
        torch.from_numpy(self.actions),
        torch.from_numpy(self.advantages),
    )

    assert_allclose(ground_truth_value, calculated_value.item())
Example #14
Source File: SentiGAN_G.py From TextGAN-PyTorch with MIT License | 6 votes |
def batchPGLoss(self, inp, target, reward):
    """
    Returns a policy gradient loss

    :param inp: batch_size x seq_len, inp should be target with <s> (start letter) prepended
    :param target: batch_size x seq_len
    :param reward: batch_size (discriminator reward for each sentence, applied to each token of the corresponding sentence)
    :return loss: policy loss
    """
    batch_size, seq_len = inp.size()
    hidden = self.init_hidden(batch_size)

    # Generator output (use_log=False), reshaped to one row per token.
    out = self.forward(inp, hidden, use_log=False)
    out = out.view(batch_size, self.max_seq_len, self.vocab_size)

    target_onehot = F.one_hot(target, self.vocab_size).float()  # batch_size * seq_len * vocab_size
    # Keep only the output value of the target token at each position.
    pred = (out * target_onehot).sum(dim=-1)  # batch_size * seq_len

    # Weighted by (1 - reward).
    return -torch.sum(pred * (1 - reward))
Example #15
Source File: MaliGAN_G.py From TextGAN-PyTorch with MIT License | 6 votes |
def adv_loss(self, inp, target, reward):
    """
    Returns a MaliGAN loss

    :param inp: batch_size x seq_len, inp should be target with <s> (start letter) prepended
    :param target: batch_size x seq_len
    :param reward: batch_size (discriminator reward for each sentence, applied to each token of the corresponding sentence)
    :return loss: policy loss
    """
    batch_size, seq_len = inp.size()
    state = self.init_hidden(batch_size)
    scores = self.forward(inp, state).view(batch_size, self.max_seq_len, self.vocab_size)

    # Select the score of each target token through a one-hot mask.
    selector = F.one_hot(target, self.vocab_size).float()  # batch_size * seq_len * vocab_size
    token_scores = torch.sum(scores * selector, dim=-1)  # batch_size * seq_len

    weighted = token_scores * reward
    return -torch.sum(weighted)
Example #16
Source File: SeqGAN_G.py From TextGAN-PyTorch with MIT License | 6 votes |
def batchPGLoss(self, inp, target, reward):
    """
    Returns a policy gradient loss

    :param inp: batch_size x seq_len, inp should be target with <s> (start letter) prepended
    :param target: batch_size x seq_len
    :param reward: batch_size (discriminator reward for each sentence, applied to each token of the corresponding sentence)
    :return loss: policy loss
    """
    batch_size, seq_len = inp.size()
    hidden = self.init_hidden(batch_size)
    out_shape = (batch_size, self.max_seq_len, self.vocab_size)
    out = self.forward(inp, hidden).view(*out_shape)

    mask = F.one_hot(target, self.vocab_size).float()  # batch_size * seq_len * vocab_size
    # Output value assigned to each target token.
    picked = torch.sum(out * mask, dim=-1)  # batch_size * seq_len

    loss = -torch.sum(picked * reward)
    return loss
Example #17
Source File: cross_entropy_loss.py From kaggle-kuzushiji-recognition with MIT License | 6 votes |
def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None, label_smooth=None):
    """Cross entropy with optional label smoothing and per-element weights.

    Without ``label_smooth`` the element-wise loss is plain
    ``F.cross_entropy``.  With it, the one-hot target is smoothed (true class
    ``1 - label_smooth + label_smooth / C``, other classes
    ``label_smooth / C``) and the loss is the per-row sum of the KL
    divergence against ``log_softmax(pred)``.  The element-wise loss is then
    passed to ``weight_reduce_loss`` together with ``weight`` (cast to
    float), ``reduction`` and ``avg_factor``.
    """
    # element-wise losses
    if label_smooth is not None:
        num_classes = pred.size(1)
        soft_target = F.one_hot(label, num_classes).type_as(pred)
        # 1 -> 1 - s, 0 -> clamp(-s, 0) = 0, then add s / C everywhere.
        soft_target = soft_target.sub_(label_smooth).clamp_(0).add_(label_smooth / num_classes)
        loss = F.kl_div(pred.log_softmax(1), soft_target, reduction='none').sum(1)
    else:
        loss = F.cross_entropy(pred, label, reduction='none')

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    return weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
Example #18
Source File: RelGAN_G.py From TextGAN-PyTorch with MIT License | 6 votes |
def step(self, inp, hidden):
    """
    RelGAN step forward

    :param inp: [batch_size]
    :param hidden: memory size
    :return: pred, hidden, next_token, next_token_onehot, next_o
        - pred: batch_size * vocab_size, use for adversarial training backward
        - hidden: next hidden
        - next_token: [batch_size], next sentence token
        - next_token_onehot: always None (not used yet)
        - next_o: always None (not used yet)
    """
    embedded = self.embeddings(inp).unsqueeze(1)
    lstm_out, hidden = self.lstm(embedded, hidden)

    gumbel_t = self.add_gumbel(self.lstm2out(lstm_out.squeeze(1)))
    # The sampled token is detached from the graph.
    next_token = torch.argmax(gumbel_t, dim=1).detach()

    pred = F.softmax(gumbel_t * self.temperature, dim=-1)  # batch_size * vocab_size

    # Placeholders kept so the return arity stays stable.
    next_token_onehot = None
    next_o = None
    return pred, hidden, next_token, next_token_onehot, next_o
Example #19
Source File: Load_Agent.py From FitML with MIT License | 6 votes |
def train_step(model, state_transitions, tgt, num_actions):
    """Run one optimisation step on a batch of transitions.

    The regression target is ``reward + 0.99 * not_done * max_a Q(next_state, a)``
    with the next-state values computed by ``model`` under ``no_grad``.  The
    prediction is the Q-value of the taken action, and the loss is smooth L1.
    ``tgt`` is not used.

    Returns:
        The loss tensor, or ``None`` (after printing a message) when
        ``state_transitions`` is empty.
    """
    if len(state_transitions) <= 0:
        print("empty state transitions")
        return

    device = model.device
    cur_states = torch.stack([torch.Tensor(s.state) for s in state_transitions]).to(device)
    rewards = torch.stack([torch.Tensor([s.reward]) for s in state_transitions]).to(device)
    # Stored Q-values are loaded but not used by the loss below.
    Qs = torch.stack([torch.Tensor([s.qval]) for s in state_transitions]).to(device)
    # 0 for terminal transitions, 1 otherwise.
    mask = torch.stack([torch.Tensor([0 if s.done else 1]) for s in state_transitions]).to(device)
    next_states = torch.stack([torch.Tensor(s.next_state) for s in state_transitions]).to(device)
    actions = [s.action for s in state_transitions]

    with torch.no_grad():
        pred_qvals_next = model(next_states).max(-1)[0]

    model.opt.zero_grad()
    pred_qvals = model(cur_states)
    one_hot_actions = F.one_hot(torch.LongTensor(actions), num_actions).to(device)

    chosen_q = torch.sum(pred_qvals * one_hot_actions, -1)
    td_target = rewards.view(-1) + 0.99 * mask[:, 0] * pred_qvals_next.view(-1)
    loss = F.smooth_l1_loss(chosen_q, td_target).mean()

    loss.backward()
    model.opt.step()
    return loss
Example #20
Source File: relgan_instructor.py From TextGAN-PyTorch with MIT License | 6 votes |
def adv_train_discriminator(self, d_step):
    """Train the discriminator for ``d_step`` adversarial steps.

    Each step one-hot encodes a random batch of real targets from
    ``self.train_data``, samples a one-hot batch from the generator and
    optimises the discriminator loss from ``get_losses``.

    Returns:
        The mean discriminator loss, or 0 when ``d_step`` is 0.
    """
    if d_step == 0:
        return 0

    running = 0
    for _ in range(d_step):
        real = self.train_data.random_batch()['target']
        fake = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
        if cfg.CUDA:
            real, fake = real.cuda(), fake.cuda()
        real = F.one_hot(real, cfg.vocab_size).float()

        # ===Train===
        d_out_real = self.dis(real)
        d_out_fake = self.dis(fake)
        _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)

        self.optimize(self.dis_opt, d_loss, self.dis)
        running += d_loss.item()

    return running / d_step
Example #21
Source File: relgan_instructor.py From TextGAN-PyTorch with MIT License | 6 votes |
def adv_train_generator(self, g_step):
    """Train the generator for ``g_step`` adversarial steps.

    Each step one-hot encodes a random batch of real targets from
    ``self.train_data``, samples a one-hot batch from the generator, scores
    both with the discriminator and optimises the generator loss from
    ``get_losses``.

    Returns:
        The mean generator loss, or 0 when ``g_step`` is 0.
    """
    if g_step == 0:
        return 0

    accumulated = 0
    for _ in range(g_step):
        real = self.train_data.random_batch()['target']
        fake = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
        if cfg.CUDA:
            real, fake = real.cuda(), fake.cuda()
        real = F.one_hot(real, cfg.vocab_size).float()

        # ===Train===
        d_out_real = self.dis(real)
        d_out_fake = self.dis(fake)
        g_loss, _ = get_losses(d_out_real, d_out_fake, cfg.loss_type)

        self.optimize(self.gen_adv_opt, g_loss, self.gen)
        accumulated += g_loss.item()

    return accumulated / g_step
Example #22
Source File: relgan_instructor.py From TextGAN-PyTorch with MIT License | 6 votes |
def adv_train_discriminator(self, d_step):
    """Train the discriminator for ``d_step`` steps on oracle data.

    Real samples are one-hot encoded targets of a random batch from
    ``self.oracle_data``; fake samples are one-hot generator samples.

    Returns:
        The mean discriminator loss, or 0 when ``d_step`` is 0.
    """
    loss_sum = 0
    for _ in range(d_step):
        oracle_batch = self.oracle_data.random_batch()['target']
        real_samples = F.one_hot(oracle_batch, cfg.vocab_size).float()
        gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
        if cfg.CUDA:
            real_samples = real_samples.cuda()
            gen_samples = gen_samples.cuda()

        # ===Train===
        real_scores = self.dis(real_samples)
        fake_scores = self.dis(gen_samples)
        _, d_loss = get_losses(real_scores, fake_scores, cfg.loss_type)

        self.optimize(self.dis_opt, d_loss, self.dis)
        loss_sum += d_loss.item()

    if d_step != 0:
        return loss_sum / d_step
    return 0
Example #23
Source File: ATARI_DQN_CNN.py From FitML with MIT License | 6 votes |
def train_step(model, state_transitions, tgt, num_actions, gamma):
    """Run one optimisation step that regresses Q(s, a) onto stored Q-values.

    States are reshaped to ``(batch, 3, 160, 140 * 3)`` before being fed to
    ``model``.  The loss is smooth L1 between the predicted Q-value of the
    taken action and the ``qval`` stored with each transition.  ``tgt`` and
    ``gamma`` do not enter the active loss; the next-state values are
    computed but unused.

    Returns:
        The loss tensor, or ``None`` (after printing a message) when
        ``state_transitions`` is empty.
    """
    if len(state_transitions) <= 0:
        print("empty state transitions")
        return

    device = model.device
    batch = len(state_transitions)
    frame_shape = (batch, 3, 160, 140 * 3)

    cur_states = torch.stack([torch.Tensor(s.state) for s in state_transitions]).to(device)
    rewards = torch.stack([torch.Tensor([s.reward]) for s in state_transitions]).to(device)
    Qs = torch.stack([torch.Tensor([s.qval]) for s in state_transitions]).to(device)
    # 0 for terminal transitions, 1 otherwise.
    mask = torch.stack([torch.Tensor([0 if s.done else 1]) for s in state_transitions]).to(device)
    next_states = torch.stack([torch.Tensor(s.next_state) for s in state_transitions]).to(device)
    actions = [s.action for s in state_transitions]

    with torch.no_grad():
        actual_Q_values = Qs
        pred_qvals_next = model(next_states.view(*frame_shape)).max(-1)[0]

    model.opt.zero_grad()
    pred_qvals = model(cur_states.view(*frame_shape))
    one_hot_actions = F.one_hot(torch.LongTensor(actions), num_actions).to(device)

    chosen_q = torch.sum(pred_qvals * one_hot_actions, -1)
    loss = F.smooth_l1_loss(chosen_q, actual_Q_values.view(-1))

    loss.backward()
    model.opt.step()
    return loss
Example #24
Source File: relgan_instructor.py From TextGAN-PyTorch with MIT License | 6 votes |
def adv_train_generator(self, g_step):
    """Train the generator for ``g_step`` steps against oracle data.

    Real samples are one-hot encoded targets of a random batch from
    ``self.oracle_data``; fake samples are one-hot generator samples.  Both
    are scored by the discriminator and the generator loss from
    ``get_losses`` is optimised.

    Returns:
        The mean generator loss, or 0 when ``g_step`` is 0.
    """
    loss_sum = 0
    for _ in range(g_step):
        oracle_batch = self.oracle_data.random_batch()['target']
        real_samples = F.one_hot(oracle_batch, cfg.vocab_size).float()
        gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
        if cfg.CUDA:
            real_samples = real_samples.cuda()
            gen_samples = gen_samples.cuda()

        # ===Train===
        real_scores = self.dis(real_samples)
        fake_scores = self.dis(gen_samples)
        g_loss, _ = get_losses(real_scores, fake_scores, cfg.loss_type)

        self.optimize(self.gen_adv_opt, g_loss, self.gen)
        loss_sum += g_loss.item()

    if g_step != 0:
        return loss_sum / g_step
    return 0
Example #25
Source File: mutator.py From nni with MIT License | 6 votes |
def sample_search(self):
    """
    Sample a random candidate.

    Returns a dict mapping each mutable's key to a boolean mask: a random
    one-hot mask for a ``LayerChoice``; for an ``InputChoice`` either an
    independent random 0/1 mask (``n_chosen`` is None) or a mask selecting
    ``n_chosen`` random candidates.
    """
    result = {}
    for mutable in self.mutables:
        if isinstance(mutable, LayerChoice):
            num_options = len(mutable)
            gen_index = torch.randint(high=num_options, size=(1, ))
            result[mutable.key] = F.one_hot(gen_index, num_classes=num_options).view(-1).bool()
        elif isinstance(mutable, InputChoice):
            if mutable.n_chosen is None:
                coin_flips = torch.randint(high=2, size=(mutable.n_candidates,))
                result[mutable.key] = coin_flips.view(-1).bool()
            else:
                # The first n_chosen entries of a random permutation are kept.
                perm = torch.randperm(mutable.n_candidates)
                kept = perm[:mutable.n_chosen]
                mask = [i in kept for i in range(mutable.n_candidates)]
                result[mutable.key] = torch.tensor(mask, dtype=torch.bool)  # pylint: disable=not-callable
    return result
Example #26
Source File: mutator.py From nni with MIT License | 6 votes |
def sample_final(self):
    """
    Generate the final chosen architecture.

    Returns
    -------
    dict
        the choice of each mutable, i.e., LayerChoice, as a flat boolean
        one-hot mask built from the registered module's chosen index
    """
    result = {}
    for mutable in self.undedup_mutables:
        assert isinstance(mutable, LayerChoice)
        # chosen_index yields a pair; only its first element is used.
        index, _ = mutable.registered_module.chosen_index
        # pylint: disable=not-callable
        chosen = torch.tensor(index)
        result[mutable.key] = F.one_hot(chosen, num_classes=len(mutable)).view(-1).bool()
    return result
Example #27
Source File: loss.py From torch-toolbox with BSD 3-Clause "New" or "Revised" License | 6 votes |
def _get_body(self, x, target):
    """Compute the additive margin correction for the target-class entries.

    For every row the target entry ``cos(theta_yi)`` is replaced by
    ``cos(theta_yi + m)`` where the margin condition holds, and otherwise
    kept (easy margin) or reduced by ``self.mm``.  The returned tensor has
    the difference to the original value at the target column and zero
    elsewhere.
    """
    cos_t = torch.gather(x, 1, target.unsqueeze(1))  # cos(theta_yi)

    if self.easy_margin:
        gate = torch.relu(cos_t)
        fallback = cos_t
    else:
        gate = torch.relu(cos_t - self.threshold)
        fallback = cos_t - self.mm  # (cos(theta_yi) - sin(pi - m)*m)
    gate = gate.bool()

    # Apex would convert FP16 to FP32 here
    # cos(theta_yi + m)
    with_margin = torch.cos(torch.acos(cos_t) + self.m).type(cos_t.dtype)
    adjusted = torch.where(gate, with_margin, fallback)

    delta = adjusted - cos_t  # cos(theta_yi + m) - cos(theta_yi)
    target_mask = F.one_hot(target, num_classes=self.classes)
    return target_mask * delta
Example #28
Source File: discrete_sampler.py From ReAgent with BSD 3-Clause "New" or "Revised" License | 5 votes |
def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
    """Pick the greedy action for every row of ``scores``.

    Args:
        scores: 2-D tensor, batch_size x num_actions.

    Returns:
        ``rlt.ActorOutput`` whose ``action`` is the one-hot encoding of the
        greedy indices and whose ``log_prob`` is a tensor of ones shaped like
        those indices.
    """
    batch_size, num_actions = scores.shape
    expected_shape = (batch_size, num_actions)

    raw_action = self._get_greedy_indices(scores)
    action = F.one_hot(raw_action, num_actions)
    assert action.shape == expected_shape

    # NOTE(review): log_prob is filled with ones in the dtype of the index
    # tensor; a deterministic choice would normally have log-probability 0.
    # Confirm this is intended.
    log_prob = torch.ones_like(raw_action)
    return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #29
Source File: mutator.py From nni with MIT License | 5 votes |
def _sample_input_choice(self, mutable):
    """Sample which candidate inputs to use and update sampling statistics.

    Builds one attention score per candidate label, then either samples an
    independent keep/skip decision per candidate (``n_chosen`` is None) or
    exactly one candidate (``n_chosen == 1``).  Updates ``sample_log_prob``,
    ``sample_entropy``, ``self._inputs`` and, in the first case,
    ``sample_skip_penalty``.

    Returns:
        A boolean mask over the candidates.
    """
    queries, anchors = [], []
    for label in mutable.choose_from:
        if label not in self._anchors_hid:
            # empty loop, fill not found
            self._lstm_next_step()
            self._mark_anchor(label)
        anchor_hidden = self._anchors_hid[label]
        queries.append(self.attn_anchor(anchor_hidden))
        anchors.append(anchor_hidden)

    query = torch.cat(queries, 0)
    query = torch.tanh(query + self.attn_query(self._h[-1]))
    query = self.v_attn(query)
    if self.temperature is not None:
        query /= self.temperature
    if self.tanh_constant is not None:
        query = self.tanh_constant * torch.tanh(query)

    if mutable.n_chosen is None:
        # Two-way decision per candidate: column 0 = skip, column 1 = keep.
        logit = torch.cat([-query, query], 1)  # pylint: disable=invalid-unary-operand-type

        skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
        skip_prob = torch.sigmoid(logit)
        kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
        self.sample_skip_penalty += kl
        log_prob = self.cross_entropy_loss(logit, skip)

        # Sum of the kept anchors divided by (1 + number kept).
        mixed = torch.matmul(skip.float(), torch.cat(anchors, 0))
        self._inputs = (mixed / (1. + torch.sum(skip))).unsqueeze(0)
    else:
        assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
        logit = query.view(1, -1)
        index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
        skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
        log_prob = self.cross_entropy_loss(logit, index)
        self._inputs = anchors[index.item()]

    self.sample_log_prob += self.entropy_reduction(log_prob)
    entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
    self.sample_entropy += self.entropy_reduction(entropy)
    return skip.bool()
Example #30
Source File: distribution.py From WaveRNN with MIT License | 5 votes |
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """
    Sample from discretized mixture of logistic distributions

    Args:
        y (Tensor): B x C x T
        log_scale_min (float): Log scale minimum value; defaults to log(1e-14)

    Returns:
        Tensor: sample in range of [-1, 1].
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))

    # Channels hold (mixture logits, means, log-scales) in equal thirds.
    assert y.size(1) % 3 == 0
    nr_mix = y.size(1) // 3

    # B x T x C
    y = y.transpose(1, 2)
    logit_probs = y[:, :, :nr_mix]
    all_means = y[:, :, nr_mix:2 * nr_mix]
    all_log_scales = y[:, :, 2 * nr_mix:3 * nr_mix]

    # sample mixture indicator from softmax (Gumbel-max on the logits)
    noise = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
    perturbed = logit_probs.data - torch.log(- torch.log(noise))
    _, argmax = perturbed.max(dim=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = F.one_hot(argmax, nr_mix).float()

    # select logistic parameters of the chosen component
    means = torch.sum(all_means * one_hot, dim=-1)
    log_scales = torch.sum(all_log_scales * one_hot, dim=-1)
    log_scales = torch.clamp(log_scales, min=log_scale_min)

    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))

    return torch.clamp(x, min=-1., max=1.)