Python torch.empty_like() Examples
The following are 30 code examples of torch.empty_like().
You can go to the original project or source file by following the links above each example, or browse the other available functions and classes of the torch module.
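Before the project examples, here is a minimal standalone sketch (not taken from any of the projects below) of what torch.empty_like does: it allocates a new, uninitialized tensor with the same shape, dtype, layout, and device as its input, so every element should be written before it is read.

import torch

x = torch.randn(3, 4)
out = torch.empty_like(x)                      # same shape/dtype/device as x, contents are uninitialized
out.copy_(x * 2)                               # write before reading
noise = torch.empty_like(x).normal_(0, 0.1)    # common idiom: allocate, then fill in place
print(out.shape, noise.shape)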
Example #1
Source File: abstract.py From rising with MIT License | 6 votes |
def forward(self, **data) -> dict:
    """
    Apply transformation

    Args:
        data: dict with tensors

    Returns:
        dict: dict with augmented data
    """
    if self.per_channel:
        kwargs = {}
        for k in self.property_names:
            kwargs[k] = getattr(self, k)
        kwargs.update(self.kwargs)

        for _key in self.keys:
            out = torch.empty_like(data[_key])
            for _i in range(data[_key].shape[1]):
                out[:, _i] = self.augment_fn(data[_key][:, _i], out=out[:, _i], **kwargs)
            data[_key] = out
        return data
    else:
        return super().forward(**data)
Example #2
Source File: oracle_controls.py From attn2d with MIT License | 6 votes |
def fill_controls_emissions_grid(self, controls, emissions, indices, src_length):
    """
    Return controls (C) and emissions (E) covering all the grid
    C : Tt, N, Ts, 2
    E : Tt, N, Ts
    """
    N = controls[0].size(0)
    tgt_length = len(controls)
    gamma = controls[0].new_zeros((tgt_length, src_length, N))
    Cread = controls[0].new_zeros((tgt_length, src_length, N, 1))
    Cwrite = utils.fill_with_neg_inf(torch.empty_like(Cread))
    triu_mask = torch.triu(controls[0].new_ones(tgt_length, src_length), 1).byte()
    triu_mask = triu_mask.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, N, 1)
    Cwrite.masked_fill_(triu_mask, 0)
    C = torch.cat((Cread, Cwrite), dim=-1)
    E = utils.fill_with_neg_inf(emissions[0].new(tgt_length, src_length, N))
    for t, (subC, subE) in enumerate(zip(controls, emissions)):
        select = [indices[t].to(C.device)]
        C[t].index_put_(select, subC.transpose(0, 1))
        E[t].index_put_(select, subE.transpose(0, 1))
        gamma[t].index_fill_(0, select[0], 1)
    # Normalize gamma:
    gamma = gamma / gamma.sum(dim=1, keepdim=True)
    return C.transpose(1, 2), E.transpose(1, 2), gamma.transpose(1, 2)
Example #3
Source File: label_smooth.py From pytorch-loss with MIT License | 6 votes |
def forward(ctx, logits, label, lb_smooth, lb_ignore):
    # prepare label
    num_classes = logits.size(1)
    lb_pos, lb_neg = 1. - lb_smooth, lb_smooth / num_classes
    label = label.clone().detach()
    ignore = label == lb_ignore
    n_valid = (label != lb_ignore).sum()
    label[ignore] = 0
    lb_one_hot = torch.empty_like(logits).fill_(
        lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()

    ignore = ignore.nonzero()
    _, M = ignore.size()
    a, *b = ignore.chunk(M, dim=1)
    mask = [a, torch.arange(logits.size(1)), *b]
    lb_one_hot[mask] = 0
    coeff = (num_classes - 1) * lb_neg + lb_pos
    ctx.variables = coeff, mask, logits, lb_one_hot

    loss = torch.log_softmax(logits, dim=1).neg_().mul_(lb_one_hot).sum(dim=1)
    return loss
Example #4
Source File: networks.py From connecting_the_dots with MIT License | 6 votes |
def tforward(self, disp0, im, std=None):
    self.pattern = self.pattern.to(disp0.device)
    self.uv0 = self.uv0.to(disp0.device)

    uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
    uv1 = torch.empty_like(uv0)
    uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1)
    uv1[..., 1] = uv0[..., 1]

    uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
    uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
    uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
    pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:])
    pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border')
    mask = torch.ones_like(im)
    if std is not None:
        mask = mask * std

    diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps)
    val = (mask * diff).sum() / mask.sum()
    return val, pattern_proj
Example #5
Source File: dynamic_controls.py From attn2d with MIT License | 6 votes |
def fill_controls_emissions_grid(self, controls, emissions, indices, src_length):
    """
    Return controls (C) and emissions (E) covering all the grid
    C : Tt, N, Ts, 2
    E : Tt, N, Ts
    """
    N = controls[0].size(0)
    tgt_length = len(controls)
    Cread = controls[0].new_zeros((tgt_length, src_length, N, 1))
    Cwrite = utils.fill_with_neg_inf(torch.empty_like(Cread))
    triu_mask = torch.triu(controls[0].new_ones(tgt_length, src_length), 1).byte()
    triu_mask = triu_mask.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, N, 1)
    Cwrite.masked_fill_(triu_mask, 0)
    C = torch.cat((Cread, Cwrite), dim=-1)
    E = utils.fill_with_neg_inf(emissions[0].new(tgt_length, src_length, N))
    for t, (subC, subE) in enumerate(zip(controls, emissions)):
        select = [indices[t]]
        C[t].index_put_(select, subC.transpose(0, 1))
        E[t].index_put_(select, subE.transpose(0, 1))
    return C.transpose(1, 2), E.transpose(1, 2)
Example #6
Source File: sampling_result.py From mmdetection with Apache License 2.0 | 6 votes |
def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags):
    self.pos_inds = pos_inds
    self.neg_inds = neg_inds
    self.pos_bboxes = bboxes[pos_inds]
    self.neg_bboxes = bboxes[neg_inds]
    self.pos_is_gt = gt_flags[pos_inds]

    self.num_gts = gt_bboxes.shape[0]
    self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1

    if gt_bboxes.numel() == 0:
        # hack for index error case
        assert self.pos_assigned_gt_inds.numel() == 0
        self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
    else:
        if len(gt_bboxes.shape) < 2:
            gt_bboxes = gt_bboxes.view(-1, 4)
        self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]

    if assign_result.labels is not None:
        self.pos_gt_labels = assign_result.labels[pos_inds]
    else:
        self.pos_gt_labels = None
Example #7
Source File: hmm_controls2.py From attn2d with MIT License | 6 votes |
def _forward_alpha(self, emissions, M):
    Tt, B, Ts = emissions.size()
    alpha = utils.fill_with_neg_inf(torch.empty_like(emissions))  # Tt, B, Ts
    # initialization  t=1
    initial = torch.empty_like(alpha[0]).fill_(-math.log(Ts))  # log(1/Ts)
    # initial = utils.fill_with_neg_inf(torch.empty_like(alpha[0]))
    # initial[:, 0] = 0
    alpha[0] = emissions[0] + initial
    # print('Initialize alpha:', alpha[0])
    # induction
    for i in range(1, Tt):
        alpha[i] = torch.logsumexp(alpha[i-1].unsqueeze(-1) + M[i-1], dim=1)
        alpha[i] = alpha[i] + emissions[i]
        # print('Emissions@', i, emissions[i])
        # print('alpha@', i, alpha[i])
    return alpha
Example #8
Source File: hmm_controls.py From attn2d with MIT License | 6 votes |
def _forward_alpha(self, emissions, M):
    Tt, B, Ts = emissions.size()
    alpha = utils.fill_with_neg_inf(torch.empty_like(emissions))  # Tt, B, Ts
    # initialization  t=1
    initial = torch.empty_like(alpha[0]).fill_(-math.log(Ts))  # log(1/Ts)
    # initial = utils.fill_with_neg_inf(torch.empty_like(alpha[0]))
    # initial[:, 0] = 0
    alpha[0] = emissions[0] + initial
    # print('Initialize alpha:', alpha[0])
    # induction
    for i in range(1, Tt):
        alpha[i] = torch.logsumexp(alpha[i-1].unsqueeze(-1) + M[i-1], dim=1)
        alpha[i] = alpha[i] + emissions[i]
        # print('Emissions@', i, emissions[i])
        # print('alpha@', i, alpha[i])
    return alpha
Example #9
Source File: sync_bn.py From mmcv with Apache License 2.0 | 6 votes |
def backward(self, grad_output):
    norm, std, weight = self.saved_tensors
    grad_weight = torch.empty_like(weight)
    grad_bias = torch.empty_like(weight)
    grad_input = torch.empty_like(grad_output)
    grad_output3d = grad_output.view(
        grad_output.size(0), grad_output.size(1), -1)
    grad_input3d = grad_input.view_as(grad_output3d)
    ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight,
                                      grad_bias)
    # all reduce
    if self.group_size > 1:
        dist.all_reduce(grad_weight, group=self.group)
        dist.all_reduce(grad_bias, group=self.group)
        grad_weight /= self.group_size
        grad_bias /= self.group_size
    ext_module.sync_bn_backward_data(grad_output3d, weight, grad_weight,
                                     grad_bias, norm, std, grad_input3d)
    return grad_input, None, None, grad_weight, grad_bias, \
        None, None, None, None
Example #10
Source File: abstract.py From rising with MIT License | 6 votes |
def forward(self, **data) -> dict:
    """
    Args:
        data: dict with tensors

    Returns:
        dict: dict with augmented data
    """
    kwargs = {}
    for k in self.property_names:
        kwargs[k] = getattr(self, k)
    kwargs.update(self.kwargs)

    for _key in self.keys:
        out = torch.empty_like(data[_key])
        for _i in range(data[_key].shape[0]):
            out[_i] = self.augment_fn(data[_key][_i], out=out[_i], **kwargs)
        data[_key] = out
    return data
Example #11
Source File: intensity.py From rising with MIT License | 6 votes |
def add_noise(data: torch.Tensor, noise_type: str,
              out: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
    """
    Add noise to input

    Args:
        data: input data
        noise_type: supports all inplace functions of a pytorch tensor
        out: if provided, result is saved in here
        kwargs: keyword arguments passed to generating function

    Returns:
        torch.Tensor: data with added noise

    See Also:
        :func:`torch.Tensor.normal_`, :func:`torch.Tensor.exponential_`
    """
    if not noise_type.endswith('_'):
        noise_type = noise_type + '_'
    noise_tensor = torch.empty_like(data, requires_grad=False)
    getattr(noise_tensor, noise_type)(**kwargs)
    return torch.add(data, noise_tensor, out=out)
Example #12
Source File: swa.py From elektronn3 with MIT License | 6 votes |
def swap_swa_sgd(self):
    r"""Swaps the values of the optimized variables and swa buffers.

    It's meant to be called in the end of training to use the collected
    swa running averages. It can also be used to evaluate the running
    averages during training; to continue training `swap_swa_sgd`
    should be called again.
    """
    for group in self.param_groups:
        for p in group['params']:
            param_state = self.state[p]
            if 'swa_buffer' not in param_state:
                # If swa wasn't applied we don't swap params
                warnings.warn(
                    "SWA wasn't applied to param {}; skipping it".format(p))
                continue
            buf = param_state['swa_buffer']
            tmp = torch.empty_like(p.data)
            tmp.copy_(p.data)
            p.data.copy_(buf)
            buf.copy_(tmp)
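The empty_like-based swap in this example is a general idiom for exchanging the contents of two same-shaped tensors in place. A minimal sketch of just that idiom (a simplification for illustration, not part of elektronn3):

import torch

def swap_(a: torch.Tensor, b: torch.Tensor) -> None:
    # scratch buffer with a's shape/dtype/device, as in the SWA example above
    tmp = torch.empty_like(a)
    tmp.copy_(a)
    a.copy_(b)
    b.copy_(tmp)

p = torch.zeros(3)
q = torch.ones(3)
swap_(p, q)   # p is now all ones, q all zeros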
Example #13
Source File: test_torch_scattering2d.py From kymatio with BSD 3-Clause "New" or "Revised" License | 6 votes |
def test_fft(self, backend):
    x = torch.randn(2, 2, 2)

    y = torch.empty_like(x)
    y[0, 0, :] = x[0, 0, :] + x[0, 1, :] + x[1, 0, :] + x[1, 1, :]
    y[0, 1, :] = x[0, 0, :] - x[0, 1, :] + x[1, 0, :] - x[1, 1, :]
    y[1, 0, :] = x[0, 0, :] + x[0, 1, :] - x[1, 0, :] - x[1, 1, :]
    y[1, 1, :] = x[0, 0, :] - x[0, 1, :] - x[1, 0, :] + x[1, 1, :]

    z = backend.fft(x, direction='C2C')
    assert torch.allclose(y, z)

    z = backend.fft(x, direction='C2C', inverse=True)
    z = z * 4.0
    assert torch.allclose(y, z)

    z = backend.fft(x, direction='C2R', inverse=True)
    z = z * 4.0
    assert z.shape == x.shape[:-1]
    assert torch.allclose(y[..., 0], z)
Example #14
Source File: SoftSelect.py From AutoDL-Projects with MIT License | 6 votes |
def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
    if tau <= 0:
        new_logits = logits
        probs = nn.functional.softmax(new_logits, dim=1)
    else:
        while True:  # a trick to avoid the gumbels bug
            gumbels = -torch.empty_like(logits).exponential_().log()
            new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
            probs = nn.functional.softmax(new_logits, dim=1)
            if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()):
                break

    if just_prob:
        return probs

    # with torch.no_grad():  # add eps for unexpected torch error
    #     probs = nn.functional.softmax(new_logits, dim=1)
    #     selected_index = torch.multinomial(probs + eps, 2, False)
    with torch.no_grad():  # add eps for unexpected torch error
        probs = probs.cpu()
        selected_index = torch.multinomial(probs + eps, num, False).to(logits.device)
    selected_logit = torch.gather(new_logits, 1, selected_index)
    selcted_probs = nn.functional.softmax(selected_logit, dim=1)
    return selected_index, selcted_probs
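The line gumbels = -torch.empty_like(logits).exponential_().log() in this example draws Gumbel(0, 1) noise: exponential_() fills the uninitialized tensor with Exp(1) samples, and -log of an Exp(1) variable is Gumbel(0, 1)-distributed. A minimal sketch of just that sampling step with a plain softmax relaxation (an illustration, not the full AutoDL-Projects selection logic):

import torch

def gumbel_softmax_sample(logits: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    # Gumbel(0, 1) noise via the empty_like().exponential_().log() idiom
    gumbels = -torch.empty_like(logits).exponential_().log()
    return torch.softmax((logits.log_softmax(dim=-1) + gumbels) / tau, dim=-1)

probs = gumbel_softmax_sample(torch.randn(4, 10), tau=0.5)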
Example #15
Source File: swa.py From fast-reid with Apache License 2.0 | 6 votes |
def swap_swa_param(self):
    r"""Swaps the values of the optimized variables and swa buffers.

    It's meant to be called in the end of training to use the collected
    swa running averages. It can also be used to evaluate the running
    averages during training; to continue training `swap_swa_sgd`
    should be called again.
    """
    for group in self.param_groups:
        for p in group['params']:
            param_state = self.state[p]
            if 'swa_buffer' not in param_state:
                # If swa wasn't applied we don't swap params
                warnings.warn(
                    "SWA wasn't applied to param {}; skipping it".format(p))
                continue
            buf = param_state['swa_buffer']
            tmp = torch.empty_like(p.data)
            tmp.copy_(p.data)
            p.data.copy_(buf)
            buf.copy_(tmp)
Example #16
Source File: search_model_gdas.py From AutoDL-Projects with MIT License | 6 votes |
def forward(self, inputs):
    while True:
        gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
        logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
        probs = nn.functional.softmax(logits, dim=1)
        index = probs.max(-1, keepdim=True)[1]
        one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
        hardwts = one_h - probs.detach() + probs
        if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
            continue
        else:
            break

    feature = self.stem(inputs)
    for i, cell in enumerate(self.cells):
        if isinstance(cell, SearchCell):
            feature = cell.forward_gdas(feature, hardwts, index)
        else:
            feature = cell(feature)
    out = self.lastact(feature)
    out = self.global_pooling(out)
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    return out, logits
Example #17
Source File: loss.py From oft with MIT License | 6 votes |
def hard_neg_mining_loss(scores, labels, neg_ratio=5):
    # Flatten tensors along the spatial dimensions
    scores = scores.flatten(2, 3)
    labels = labels.flatten(2, 3)
    count = labels.size(-1)

    # Rank negative locations by the predicted confidence
    _, inds = (scores.sigmoid() * (~labels).float()).sort(-1, descending=True)
    ordinals = torch.arange(count, out=inds.new_empty(count)).expand_as(inds)
    rank = torch.empty_like(inds)
    rank.scatter_(-1, inds, ordinals)

    # Include only positive locations + N most confident negative locations
    num_pos = labels.long().sum(dim=-1, keepdim=True)
    num_neg = (num_pos + 1) * neg_ratio
    mask = (labels | (rank < num_neg)).float()

    # Apply cross entropy loss
    return F.binary_cross_entropy_with_logits(
        scores, labels.float(), mask, reduction='sum')
Example #18
Source File: ngram.py From espnet with Apache License 2.0 | 6 votes |
def score_partial_(self, y, next_token, state, x):
    """Score interface for both full and partial scorer.

    Args:
        y: previous char
        next_token: next token need to be score
        state: previous state
        x: encoded feature

    Returns:
        tuple[torch.Tensor, List[Any]]: Tuple of
            batchfied scores for next token with shape of `(n_batch, n_vocab)`
            and next state list for ys.
    """
    out_state = kenlm.State()
    ys = self.chardict[y[-1]] if y.shape[0] > 1 else "<s>"
    self.lm.BaseScore(state, ys, out_state)
    scores = torch.empty_like(next_token, dtype=x.dtype, device=y.device)
    for i, j in enumerate(next_token):
        scores[i] = self.lm.BaseScore(
            out_state, self.chardict[j], self.tmpkenlmstate
        )
    return scores, out_state
Example #19
Source File: prefetch_data.py From DenseNAS with Apache License 2.0 | 6 votes |
def preload(self):
    try:
        self.next_input, self.next_target = next(self.loader)
    except StopIteration:
        self.next_input = None
        self.next_target = None
        return
    # if record_stream() doesn't work, another option is to make sure device inputs are created
    # on the main stream.
    # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
    # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
    # Need to make sure the memory allocated for next_* is not still in use by the main stream
    # at the time we start copying to next_*:
    # self.stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(self.stream):
        self.next_input = self.next_input.cuda(non_blocking=True)
        self.next_target = self.next_target.cuda(non_blocking=True)
        self.next_input = self.normalize(self.next_input)
        if self.is_cutout:
            self.next_input = self.cutout(self.next_input)
Example #20
Source File: distributed.py From nnUNet with Apache License 2.0 | 5 votes |
def forward(ctx, input):
    world_size = distributed.get_world_size()
    # create a destination list for the allgather. I'm assuming you're gathering from 3 workers.
    allgather_list = [torch.empty_like(input) for _ in range(world_size)]
    # if distributed.get_rank() == 0:
    #     import IPython; IPython.embed()
    distributed.all_gather(allgather_list, input)
    return torch.cat(allgather_list, dim=0)
Example #21
Source File: jpeg_compression.py From HiDDeN with MIT License | 5 votes |
def forward(self, noised_and_cover):
    noised_image = noised_and_cover[0]
    # pad the image so that we can do dct on 8x8 blocks
    pad_height = (8 - noised_image.shape[2] % 8) % 8
    pad_width = (8 - noised_image.shape[3] % 8) % 8
    noised_image = nn.ZeroPad2d((0, pad_width, 0, pad_height))(noised_image)

    # convert to yuv
    image_yuv = torch.empty_like(noised_image)
    rgb2yuv(noised_image, image_yuv)

    assert image_yuv.shape[2] % 8 == 0
    assert image_yuv.shape[3] % 8 == 0

    # apply dct
    image_dct = self.apply_conv(image_yuv, 'dct')
    # get the jpeg-compression mask
    mask = self.get_mask(image_dct.shape[1:])
    # multiply the dct-ed image with the mask.
    image_dct_mask = torch.mul(image_dct, mask)

    # apply inverse dct (idct)
    image_idct = self.apply_conv(image_dct_mask, 'idct')
    # transform from yuv to rgb
    image_ret_padded = torch.empty_like(image_dct)
    yuv2rgb(image_idct, image_ret_padded)

    # un-pad
    noised_and_cover[0] = image_ret_padded[:, :, :image_ret_padded.shape[2] - pad_height,
                                           :image_ret_padded.shape[3] - pad_width].clone()

    return noised_and_cover
Example #22
Source File: butterfly_factor.py From learning-circuits with Apache License 2.0 | 5 votes |
def forward(ctx, coefficients, input):
    ctx.save_for_backward(coefficients, input)
    return butterfly_factor_multiply(coefficients, input)
    # output = torch.empty_like(input)
    # ABCD_mult(coefficients.detach().numpy(), input.detach().numpy(), output.detach().numpy())
    # return output
Example #23
Source File: loss.py From EDVR with Apache License 2.0 | 5 votes |
def get_target_label(self, input, target_is_real):
    if self.gan_type == 'wgan-gp':
        return target_is_real
    if target_is_real:
        return torch.empty_like(input).fill_(self.real_label_val)
    else:
        return torch.empty_like(input).fill_(self.fake_label_val)
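As an aside (not from EDVR): torch.empty_like(input).fill_(value) allocates and then writes in place, while torch.full_like(input, value) produces the same result in a single call, so either spelling works for constant target tensors like these.

import torch

x = torch.randn(2, 3)
a = torch.empty_like(x).fill_(0.9)   # allocate, then fill
b = torch.full_like(x, 0.9)          # equivalent one-step alternative
assert torch.equal(a, b)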
Example #24
Source File: focal_loss.py From pytorch-loss with MIT License | 5 votes |
def forward(self, logits, label):
    '''
    args:
        logits: tensor of shape (N, ...)
        label: tensor of shape(N, ...)
    '''
    # compute loss
    logits = logits.float()  # use fp32 if logits is fp16
    with torch.no_grad():
        alpha = torch.empty_like(logits).fill_(1 - self.alpha)
        alpha[label == 1] = self.alpha

    probs = torch.sigmoid(logits)
    pt = torch.where(label == 1, probs, 1 - probs)
    ce_loss = self.crit(logits, label.double())
    loss = (alpha * torch.pow(1 - pt, self.gamma) * ce_loss)
    if self.reduction == 'mean':
        loss = loss.mean()
    if self.reduction == 'sum':
        loss = loss.sum()
    return loss

##
# version 2: user derived grad computation
Example #25
Source File: focal_loss.py From pytorch-loss with MIT License | 5 votes |
def forward(ctx, logits, label, alpha, gamma):
    logits = logits.float()
    coeff = torch.empty_like(logits).fill_(1 - alpha)
    coeff[label == 1] = alpha

    probs = torch.sigmoid(logits)
    log_probs = torch.where(logits >= 0,
                            F.softplus(logits, -1, 50),
                            logits - F.softplus(logits, 1, 50))
    log_1_probs = torch.where(logits >= 0,
                              -logits + F.softplus(logits, -1, 50),
                              -F.softplus(logits, 1, 50))
    probs_gamma = probs ** gamma
    probs_1_gamma = (1. - probs) ** gamma

    ctx.coeff = coeff
    ctx.probs = probs
    ctx.log_probs = log_probs
    ctx.log_1_probs = log_1_probs
    ctx.probs_gamma = probs_gamma
    ctx.probs_1_gamma = probs_1_gamma
    ctx.label = label
    ctx.gamma = gamma

    term1 = probs_1_gamma * log_probs
    term2 = probs_gamma * log_1_probs
    loss = torch.where(label == 1, term1, term2).mul_(coeff).neg_()
    return loss
Example #26
Source File: label_smooth.py From pytorch-loss with MIT License | 5 votes |
def forward(self, logits, label):
    '''
    args:
        logits: tensor of shape (N, C, H, W)
        label: tensor of shape(N, H, W)
    '''
    # overcome ignored label
    logits = logits.float()  # use fp32 to avoid nan
    with torch.no_grad():
        num_classes = logits.size(1)
        label = label.clone().detach()
        ignore = label == self.lb_ignore
        n_valid = (ignore == 0).sum()
        label[ignore] = 0
        lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
        lb_one_hot = torch.empty_like(logits).fill_(
            lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()

    logs = self.log_softmax(logits)
    loss = -torch.sum(logs * lb_one_hot, dim=1)
    loss[ignore] = 0
    if self.reduction == 'mean':
        loss = loss.sum() / n_valid
    if self.reduction == 'sum':
        loss = loss.sum()
    return loss

##
# version 2: user derived grad computation
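The empty_like().fill_().scatter_() chain in this example builds the label-smoothed one-hot target in place. A minimal sketch of just that construction, simplified to 2-D logits (an illustration, not the full pytorch-loss module):

import torch

num_classes, lb_smooth = 5, 0.1
lb_pos, lb_neg = 1. - lb_smooth, lb_smooth / num_classes
label = torch.tensor([0, 2, 4])                       # (N,)
logits = torch.randn(3, num_classes)                  # (N, C)
# every entry gets lb_neg, then the true class of each row is overwritten with lb_pos
lb_one_hot = torch.empty_like(logits).fill_(lb_neg).scatter_(
    1, label.unsqueeze(1), lb_pos)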
Example #27
Source File: search_model_gdas_nasnet.py From AutoDL-Projects with MIT License | 5 votes |
def forward(self, inputs):
    def get_gumbel_prob(xins):
        while True:
            gumbels = -torch.empty_like(xins).exponential_().log()
            logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
            probs = nn.functional.softmax(logits, dim=1)
            index = probs.max(-1, keepdim=True)[1]
            one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
            hardwts = one_h - probs.detach() + probs
            if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
                continue
            else:
                break
        return hardwts, index

    normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters)
    reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters)

    s0 = s1 = self.stem(inputs)
    for i, cell in enumerate(self.cells):
        if cell.reduction:
            hardwts, index = reduce_hardwts, reduce_index
        else:
            hardwts, index = normal_hardwts, normal_index
        s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
    out = self.lastact(s1)
    out = self.global_pooling(out)
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    return out, logits
Example #28
Source File: common.py From catalyst with Apache License 2.0 | 5 votes |
def forward(self, x: torch.Tensor):
    """Forward call."""
    noise = torch.empty_like(x)
    noise.normal_(0, self.stddev)
Example #29
Source File: hmm_controls2.py From attn2d with MIT License | 5 votes |
def _backward_beta(self, emissions, M):
    Tt, B, Ts = emissions.size()
    beta = utils.fill_with_neg_inf(torch.empty_like(emissions))  # Tt, B, Ts
    # initialization
    beta[-1] = 0
    for i in range(Tt-2, -1, -1):
        beta[i] = torch.logsumexp(M[i].transpose(1, 2) +         # N, Ts, Ts
                                  beta[i+1].unsqueeze(-1) +      # N, Ts, 1
                                  emissions[i+1].unsqueeze(-1),  # N, Ts, 1
                                  dim=1)
    return beta
Example #30
Source File: lookahead.py From torch-toolbox with BSD 3-Clause "New" or "Revised" License | 5 votes |
def update(self, group):
    for fast in group["params"]:
        param_state = self.state[fast]
        if "slow_param" not in param_state:
            param_state["slow_param"] = torch.empty_like(fast.data)
            param_state["slow_param"].copy_(fast.data)
        slow = param_state["slow_param"]
        slow += (fast.data - slow) * self.alpha
        fast.data.copy_(slow)