Python torch.nn.functional.gumbel_softmax() Examples
The following are 19 code examples of torch.nn.functional.gumbel_softmax(). You can go to the original project or source file by following the link above each example. You may also want to check out all available functions/classes of the module torch.nn.functional.
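Before diving into the project examples, here is a minimal, self-contained sketch of how the function is typically called. The tensor shapes, temperature values, and the toy loss are illustrative choices for this sketch only and are not taken from any of the projects below.

import torch
import torch.nn.functional as F

logits = torch.randn(8, 5, requires_grad=True)  # 8 rows of unnormalized scores over 5 classes

# Soft relaxation: each row is a probability vector, sharper as tau -> 0.
soft = F.gumbel_softmax(logits, tau=1.0, hard=False, dim=-1)

# Straight-through: the returned sample is one-hot in the forward pass,
# while gradients are taken through the underlying soft relaxation.
hard = F.gumbel_softmax(logits, tau=0.5, hard=True, dim=-1)

# A toy downstream loss, just to show that gradients reach the logits.
loss = (hard * torch.randn(8, 5)).sum()
loss.backward()
print(soft.shape, hard.sum(dim=-1), logits.grad is not None)

The hard=True path is the pattern most of the examples below rely on when they need a discrete, one-hot choice (an expert, an operation, a halting decision) that still admits gradients.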
Example #1
Source File: synthesizer.py From CTGAN with MIT License

def _apply_activate(self, data):
    data_t = []
    st = 0
    for item in self.transformer.output_info:
        if item[1] == 'tanh':
            ed = st + item[0]
            data_t.append(torch.tanh(data[:, st:ed]))
            st = ed
        elif item[1] == 'softmax':
            ed = st + item[0]
            data_t.append(functional.gumbel_softmax(data[:, st:ed], tau=0.2))
            st = ed
        else:
            assert 0
    return torch.cat(data_t, dim=1)
Example #2
Source File: qv.py From attn2d with MIT License

def assign(self, points, distance='euclid', greedy=False):
    # points = points.data
    centroids = self.centroids
    if distance == 'cosine':
        # nearest neighbor in the centroids (cosine distance):
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)
    distances = (torch.sum(points**2, dim=1, keepdim=True)
                 + torch.sum(centroids**2, dim=1, keepdim=True).t()
                 - 2 * torch.matmul(points, centroids.t()))  # T*B, e
    print('Distances:', distances[:3])
    if not greedy:
        logits = - distances
        resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
    else:
        # Greedy non-differentiable responsibilities:
        indices = torch.argmin(distances, dim=-1)  # T*B
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    return resp
Example #3
Source File: qv.py From attn2d with MIT License

def assign(self, points, distance='euclid', greedy=False):
    points = points.data
    centroids = self.centroids
    if distance == 'cosine':
        # nearest neighbor in the centroids (cosine distance):
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)
    distances = (torch.sum(points**2, dim=1, keepdim=True)
                 + torch.sum(centroids**2, dim=1, keepdim=True).t()
                 - 2 * torch.matmul(points, centroids.t()))  # T*B, e
    if not greedy:
        logits = - distances
        resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
        # batch_counts = resp.sum(dim=0).view(-1).data
    else:
        # Greedy non-differentiable responsibilities:
        indices = torch.argmin(distances, dim=-1)  # T*B
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    return resp
Example #4
Source File: qv.py From attn2d with MIT License

def assign(self, points, distance='euclid', greedy=False):
    # points = points.data  # the only diff from 16
    centroids = self.centroids
    if distance == 'cosine':
        # nearest neighbor in the centroids (cosine distance):
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)
    distances = (torch.sum(points**2, dim=1, keepdim=True)
                 + torch.sum(centroids**2, dim=1, keepdim=True).t()
                 - 2 * torch.matmul(points, centroids.t()))  # T*B, e
    if not greedy:
        logits = - distances
        resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
        # batch_counts = resp.sum(dim=0).view(-1).data
    else:
        # Greedy non-differentiable responsibilities:
        indices = torch.argmin(distances, dim=-1)  # T*B
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    return resp
Example #5
Source File: qv_v1.py From attn2d with MIT License

def assign(self, points, distance='euclid', greedy=False):
    points = points.data
    centroids = self.centroids
    if distance == 'cosine':
        # nearest neighbor in the centroids (cosine distance):
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)
    distances = (torch.sum(points**2, dim=1, keepdim=True)
                 + torch.sum(centroids**2, dim=1, keepdim=True).t()
                 - 2 * torch.matmul(points, centroids.t()))  # T*B, e
    if not greedy:
        logits = - distances
        resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
        # batch_counts = resp.sum(dim=0).view(-1).data
    else:
        # Greedy non-differentiable responsibilities:
        indices = torch.argmin(distances, dim=-1)  # T*B
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    return resp
Example #6
Source File: qv_v1.py From attn2d with MIT License

def assign(self, points, distance='euclid', greedy=False):
    points = points.data
    centroids = self.centroids
    if distance == 'cosine':
        # nearest neighbor in the centroids (cosine distance):
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)
    distances = (torch.sum(points**2, dim=1, keepdim=True)
                 + torch.sum(centroids**2, dim=1, keepdim=True).t()
                 - 2 * torch.matmul(points, centroids.t()))  # T*B, e
    if not greedy:
        logits = - distances
        resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
        # batch_counts = resp.sum(dim=0).view(-1).data
    else:
        # Greedy non-differentiable responsibilities:
        indices = torch.argmin(distances, dim=-1)  # T*B
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, indices.unsqueeze(1), 1)
    return resp
Example #7
Source File: model_search.py From nasbench-1shot1 with Apache License 2.0

def forward(self, input, discrete=False, normalize=False):
    # NASBench only has one input to each cell
    s0 = self.stem(input)
    for i, cell in enumerate(self.cells):
        if i in [self._layers // 3, 2 * self._layers // 3]:
            # Perform down-sampling by factor 1/2
            # Equivalent to https://github.com/google-research/nasbench/blob/master/nasbench/lib/model_builder.py#L68
            s0 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)(s0)

        # If using discrete architecture from random_ws search with weight sharing
        # then pass through architecture weights directly.
        # For GDAS use hard gumbel softmax, therefore only a single operation is evaluated per mixed block
        preprocess_op_mixed_op = lambda x: x if discrete else F.gumbel_softmax(x, tau=self.tau, hard=True, dim=-1)
        # Don't use hard for the rest, because it very quickly gave exploding gradients
        preprocess_op = lambda x: x if discrete else F.gumbel_softmax(x, tau=self.tau, hard=False, dim=-1)

        # Normalize mixed_op weights for the choice blocks in the graph
        mixed_op_weights = preprocess_op_mixed_op(self._arch_parameters[0])
        # Normalize the output weights
        output_weights = preprocess_op(self._arch_parameters[1]) if self._output_weights else None
        # Normalize the input weights for the nodes in the cell
        input_weights = [preprocess_op(alpha) for alpha in self._arch_parameters[2:]]

        s0 = cell(s0, mixed_op_weights, output_weights, input_weights)

    # Include one more preprocessing step here
    s0 = self.postprocess(s0)  # [N, C_max * (steps + 1), w, h] -> [N, C_max, w, h]

    # Global Average Pooling by averaging over last two remaining spatial dimensions
    # Like in nasbench: https://github.com/google-research/nasbench/blob/master/nasbench/lib/model_builder.py#L92
    out = s0.view(*s0.shape[:2], -1).mean(-1)
    logits = self.classifier(out.view(out.size(0), -1))
    return logits
Example #8
Source File: test_pyprof_nvtx.py From apex with BSD 3-Clause "New" or "Revised" License

def test_gumbel_softmax(self):
    inp = torch.randn(16, 1024, device='cuda', dtype=self.dtype)
    output = F.gumbel_softmax(inp, tau=1, hard=False, eps=1e-10, dim=-1)
Example #9
Source File: multi_categorical.py From multi-categorical-gans with BSD 3-Clause "New" or "Revised" License

def forward(self, logits, training=True, temperature=None):
    # gumbel-softmax (training and evaluation)
    if temperature is not None:
        return F.gumbel_softmax(logits, hard=not training, tau=temperature)
    # softmax training
    elif training:
        return F.softmax(logits, dim=1)
    # softmax evaluation
    else:
        return OneHotCategorical(logits=logits).sample()
Example #10
Source File: attention.py From seq2seq.pytorch with MIT License

def forward(self, q, k, v):
    b_q, t_q, dim_q = list(q.size())
    b_k, t_k, dim_k = list(k.size())
    b_v, t_v, dim_v = list(v.size())
    assert(b_q == b_k and b_k == b_v)  # batch size should be equal
    assert(dim_q == dim_k)  # dims should be equal
    assert(t_k == t_v)  # times should be equal
    b = b_q
    qk = torch.bmm(q, k.transpose(1, 2))  # b x t_q x t_k
    qk = qk / (dim_k ** 0.5)
    mask = None
    with torch.no_grad():
        if self.causal and t_q > 1:
            causal_mask = q.data.new(t_q, t_k).byte().fill_(1).triu_(1)
            mask = causal_mask.unsqueeze(0).expand(b, t_q, t_k)
        if self.mask_k is not None:
            mask_k = self.mask_k.unsqueeze(1).expand(b, t_q, t_k)
            mask = mask_k if mask is None else mask | mask_k
        if self.mask_q is not None:
            mask_q = self.mask_q.unsqueeze(2).expand(b, t_q, t_k)
            mask = mask_q if mask is None else mask | mask_q
    if mask is not None:
        qk.masked_fill_(mask, -1e12)

    if self.gumbel:
        sm_qk = F.gumbel_softmax(qk, dim=2, hard=True)
    else:
        sm_qk = F.softmax(qk, dim=2,
                          dtype=torch.float32 if qk.dtype == torch.float16 else qk.dtype)
    sm_qk = self.dropout(sm_qk)
    return torch.bmm(sm_qk, v), sm_qk  # b x t_q x dim_v
Example #11
Source File: dynamic_halters.py From attn2d with MIT License

def step(self, x, n, cumul=None, total_computes=None):
    """
    n is the index of the upcoming block.
    Given the current activation, decide whether to go in or skip/exit.
    Returns the binary decision and the log-(p, 1-p).
    """
    T, B, C = x.size()
    if self.detach_before_classifier:
        x = x.detach()
    x = self.halting_predictors[n if self.separate_halting_predictors else 0](x)
    halt_logits = F.logsigmoid(x)  # the log-p of halting
    # Apply the gumbel trick
    halt = halt_logits.view(-1, 2)
    halt = F.gumbel_softmax(halt, tau=self.gumbel_tau).view(T, B, 2)
    return halt
Example #12
Source File: qv.py From attn2d with MIT License

def forward(self, x, key):
    T, B, C = x.size()
    loss = torch.zeros(1).type_as(x).to(x.device)
    if key is not None:
        Tr = 1
    else:
        key = x
        Tr = T
    if self.tau:
        resp = F.gumbel_softmax(
            self.assign(key.contiguous().view(Tr*B, self.key_dim)),
            tau=self.tau, hard=self.hard
        )  # T*B, ne
    else:
        resp = torch.softmax(
            self.assign(key.contiguous().view(Tr*B, self.key_dim)),
            dim=-1
        )  # T*B, ne
    importance = resp.sum(dim=0)
    loss = self.loss_scale * torch.std(importance) / torch.mean(importance)
    print('importance', importance.data.round())
    # w = torch.matmul(resp, self.pw_w1)  # T*B, C_out * C_in
    # w = w.view(T, B, self.output_dim, self.input_dim)
    # x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
    # if self.pw_bias is not None:
    #     x = x + self.pw_bias(x0)

    # First evaluate each expert output
    resp = resp.view(Tr, B, self.ne, 1)
    residual = x
    x = torch.matmul(self.pw_w1, x.unsqueeze(2).unsqueeze(-1)).squeeze(-1)  # T, B, ne, out
    x = F.relu(x)
    x = torch.sum(resp * x, dim=2)
    if self.pw_bias is not None:
        x = x + self.pw_bias(key)
    return x + residual, loss
Example #13
Source File: ctgan.py From SDGym with MIT License

def apply_activate(data, output_info):
    data_t = []
    st = 0
    for item in output_info:
        if item[1] == 'tanh':
            ed = st + item[0]
            data_t.append(torch.tanh(data[:, st:ed]))
            st = ed
        elif item[1] == 'softmax':
            ed = st + item[0]
            data_t.append(F.gumbel_softmax(data[:, st:ed], tau=0.2))
            st = ed
        else:
            assert 0
    return torch.cat(data_t, dim=1)
Example #14
Source File: adv_masker.py From bert_on_stilts with Apache License 2.0

def forward(self, x, attention_mask, gumbel_softmax=True, tau=None):
    extended_attention_mask = self.convert_mask(attention_mask)
    h = self.bert_layer(x, extended_attention_mask)
    h = self.linear_layer(h)
    log_probs = self.log_sigmoid(h).squeeze(dim=2)
    if gumbel_softmax:
        tau = self.tau if tau is None else tau
        return F.gumbel_softmax(log_probs, tau=tau)
    else:
        return log_probs
Example #15
Source File: qv.py From attn2d with MIT License

def forward(self, x, key):
    T, B, C = x.size()
    loss = torch.zeros(1).type_as(x).to(x.device)
    if key is not None:
        # outsider influence:
        scales = torch.softmax(self.pw_scales(key), dim=-1).unsqueeze(-1).unsqueeze(-1)  # 1, B, ne, 1, 1
        w = self.pw_w1 * scales  # 1, B, ne, C_out, C_in
        w = torch.sum(w, dim=2)
        x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
        if self.pw_bias is not None:
            x = x + self.pw_bias(key)
        return x, loss
    else:
        residual = x
        if self.training:
            if self.tau:
                resp = F.gumbel_softmax(
                    self.assign(x.contiguous().view(T*B, C)),
                    tau=self.tau, hard=self.hard
                )  # T*B, ne
            else:
                resp = torch.softmax(
                    self.assign(x.contiguous().view(T*B, C)),
                    dim=-1
                )  # T*B, ne
        else:
            if self.tau:
                resp = F.gumbel_softmax(
                    self.assign(x.contiguous().view(T*B, C)),
                    tau=self.tau, hard=self.hard
                )  # T*B, ne
            else:
                resp = torch.softmax(
                    self.assign(x.contiguous().view(T*B, C)),
                    dim=-1
                )  # T*B, ne
            # For the new exp with ne600 eval is soft as well
            # logits = self.assign(x.contiguous().view(T*B, C))
            # indices = torch.argmax(logits, dim=-1)
            # resp = torch.zeros(logits.size(0), self.ne).type_as(logits)
            # resp.scatter_(1, indices.unsqueeze(1), 1)

        importance = resp.sum(dim=0)
        loss = self.loss_scale * torch.std(importance) / torch.mean(importance)
        print('importance', importance.data.round())
        w = torch.matmul(resp, self.pw_w1)  # T*B, C_out * C_in
        w = w.view(T, B, self.output_dim, self.input_dim)
        x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
        x = x + residual  # v1 sigmoid on x before residual
        if self.pw_bias is not None:
            x = x + self.pw_bias(residual)
        return x, loss
Example #16
Source File: qv.py From attn2d with MIT License

def forward(self, x, key):
    T, B, C = x.size()
    loss = torch.zeros(1).type_as(x).to(x.device)
    if key is not None:
        # outsider influence:
        scales = torch.softmax(self.pw_scales(key), dim=-1).unsqueeze(-1).unsqueeze(-1)  # 1, B, ne, 1, 1
        w = self.pw_w1 * scales  # 1, B, ne, C_out, C_in
        w = torch.sum(w, dim=2)
        x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
        if self.pw_bias is not None:
            x = x + self.pw_bias(key)
        return x, loss
    else:
        residual = x
        if self.training:
            if self.tau:
                resp = F.gumbel_softmax(
                    self.assign(x.contiguous().view(T*B, C)),
                    tau=self.tau, hard=self.hard
                )  # T*B, ne
            else:
                resp = torch.softmax(
                    self.assign(x.contiguous().view(T*B, C)),
                    dim=-1
                )  # T*B, ne
        else:
            if self.tau:
                resp = F.gumbel_softmax(
                    self.assign(x.contiguous().view(T*B, C)),
                    tau=self.tau, hard=self.hard
                )  # T*B, ne
            else:
                resp = torch.softmax(
                    self.assign(x.contiguous().view(T*B, C)),
                    dim=-1
                )  # T*B, ne
            # For the new exp with ne600 eval is soft as well
            # logits = self.assign(x.contiguous().view(T*B, C))
            # indices = torch.argmax(logits, dim=-1)
            # resp = torch.zeros(logits.size(0), self.ne).type_as(logits)
            # resp.scatter_(1, indices.unsqueeze(1), 1)

        importance = resp.sum(dim=0)
        loss = self.loss_scale * torch.std(importance) / torch.mean(importance)
        print('importance', importance.data.round())
        w = torch.matmul(resp, self.pw_w1)  # T*B, C_out * C_in
        w = w.view(T, B, self.output_dim, self.input_dim)
        x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
        if self.pw_bias is not None:
            x = x + self.pw_bias(residual)
        x = torch.sigmoid(x) + residual
        return x, loss
Example #17
Source File: qv.py From attn2d with MIT License

def forward(self, x, key):
    T, B, C = x.size()
    loss = torch.zeros(1).type_as(x).to(x.device)
    if key is not None:
        # outsider influence:
        scales = torch.softmax(self.pw_scales(key), dim=-1).unsqueeze(-1).unsqueeze(-1)  # 1, B, ne, 1, 1
        w = self.pw_w1 * scales  # 1, B, ne, C_out, C_in
        w = torch.sum(w, dim=2)
        x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
        if self.pw_bias is not None:
            x = x + self.pw_bias(key)
        return x, loss
    else:
        x0 = x
        key = x.contiguous().view(T*B, C)
        energies = self.assign(key)
        if self.training:
            noise = F.softplus(self.noise(key)) * torch.randn_like(energies)
            energies = keeptopk_masked(energies + noise, self.topk)
            if self.tau:
                resp = F.gumbel_softmax(energies, tau=self.tau, hard=self.hard)  # T*B, ne
            else:
                resp = F.softmax(energies, dim=-1)  # T*B, ne
        else:
            energies = keeptopk_masked(energies, self.topk)
            if self.tau:
                resp = F.gumbel_softmax(energies, tau=self.tau, hard=self.hard)  # T*B, ne
            else:
                resp = F.softmax(energies, dim=-1)  # T*B, ne
            # indices = torch.argmax(energies, dim=-1)
            # resp = torch.zeros(energies.size(0), self.ne).type_as(energies)
            # resp.scatter_(1, indices.unsqueeze(1), 1)

        importance = resp.sum(dim=0)
        loss = self.loss_scale * torch.std(importance) / torch.mean(importance)
        print('importance', importance.data.round())
        w = torch.matmul(resp, self.pw_w1)  # T*B, C_out * C_in
        w = w.view(T, B, self.output_dim, self.input_dim)
        x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
        if self.pw_bias is not None:
            x = x + self.pw_bias(x0)
        return x, loss
Example #18
Source File: qv.py From attn2d with MIT License

def forward(self, x, key):
    T, B, C = x.size()
    loss = torch.zeros(1).type_as(x).to(x.device)
    if key is not None:
        # outsider influence:
        scales = torch.softmax(self.pw_scales(key), dim=-1).unsqueeze(-1).unsqueeze(-1)  # 1, B, ne, 1, 1
        w = self.pw_w1 * scales  # 1, B, ne, C_out, C_in
        w = torch.sum(w, dim=2)
        x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
        if self.pw_bias is not None:
            x = x + self.pw_bias(key)
        return x, loss
    else:
        x0 = x
        if self.training:
            if self.tau:
                resp = F.gumbel_softmax(
                    self.assign(x.contiguous().view(T*B, C)),
                    tau=self.tau, hard=self.hard
                )  # T*B, ne
            else:
                resp = torch.softmax(
                    self.assign(x.contiguous().view(T*B, C)),
                    dim=-1
                )  # T*B, ne
        else:
            if self.tau:
                resp = F.gumbel_softmax(
                    self.assign(x.contiguous().view(T*B, C)),
                    tau=self.tau, hard=self.hard
                )  # T*B, ne
            else:
                resp = torch.softmax(
                    self.assign(x.contiguous().view(T*B, C)),
                    dim=-1
                )  # T*B, ne
            # For the new exp with ne600 eval is soft as well
            # logits = self.assign(x.contiguous().view(T*B, C))
            # indices = torch.argmax(logits, dim=-1)
            # resp = torch.zeros(logits.size(0), self.ne).type_as(logits)
            # resp.scatter_(1, indices.unsqueeze(1), 1)

        importance = resp.sum(dim=0)
        loss = self.loss_scale * torch.std(importance) / torch.mean(importance)
        print('importance', importance.data.round())
        w = torch.matmul(resp, self.pw_w1)  # T*B, C_out * C_in
        w = w.view(T, B, self.output_dim, self.input_dim)
        x = torch.matmul(w, x.unsqueeze(-1)).squeeze(-1)
        if self.pw_bias is not None:
            x = x + self.pw_bias(x0)
        return x, loss
Example #19
Source File: genotypes.py From nni with MIT License

def parse_gumbel(alpha, beta, k):
    """
    parse continuous alpha to discrete gene.
    alpha is ParameterList:
    ParameterList [
        Parameter(n_edges1, n_ops),
        Parameter(n_edges2, n_ops),
        ...
    ]
    beta is ParameterList:
    ParameterList [
        Parameter(n_edges1),
        Parameter(n_edges2),
        ...
    ]
    gene is list:
    [
        [('node1_ops_1', node_idx), ..., ('node1_ops_k', node_idx)],
        [('node2_ops_1', node_idx), ..., ('node2_ops_k', node_idx)],
        ...
    ]
    each node has two edges (k=2) in CNN.
    """
    gene = []
    assert PRIMITIVES[-1] == 'none'  # assume last PRIMITIVE is 'none'

    # 1) Convert the mixed op to discrete edge (single op) by choosing top-1 weight edge
    # 2) Choose top-k edges per node by edge score (top-1 weight in edge)
    # output the connect idx[(node_idx, connect_idx, op_idx).... () ()]
    connect_idx = []
    for edges, w in zip(alpha, beta):
        # edges: Tensor(n_edges, n_ops)
        discrete_a = F.gumbel_softmax(edges[:, :-1].reshape(-1), tau=1, hard=True)
        for i in range(k - 1):
            discrete_a = discrete_a + F.gumbel_softmax(edges[:, :-1].reshape(-1), tau=1, hard=True)
        discrete_a = discrete_a.reshape(-1, len(PRIMITIVES) - 1)
        reserved_edge = (discrete_a > 0).nonzero()

        node_gene = []
        node_idx = []
        for i in range(reserved_edge.shape[0]):
            edge_idx = reserved_edge[i][0].item()
            prim_idx = reserved_edge[i][1].item()
            prim = PRIMITIVES[prim_idx]
            node_gene.append((prim, edge_idx))
            node_idx.append((edge_idx, prim_idx))

        gene.append(node_gene)
        connect_idx.append(node_idx)

    return gene, connect_idx