Python torch.enable_grad() Examples
The following are 30 code examples of torch.enable_grad().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: trainers.py From homura with Apache License 2.0 | 7 votes |
def train(self, data_loader: Iterable or DataLoader, mode: str = TRAIN):
    """Run one training epoch over ``data_loader``.

    :param data_loader: Source of batches; handed to ``self._loop`` unchanged.
    :param mode: Name of this loop, forwarded to ``self._loop``.
        Defaults to the ``TRAIN`` constant.
    """
    self._is_train = True
    self._epoch += 1

    # Put the model, and the loss when it exposes ``train``, in training mode.
    self.model.train()
    loss_f = self.loss_f
    if hasattr(loss_f, "train"):
        loss_f.train()

    # Gradients are enabled explicitly in case the caller disabled them.
    with torch.enable_grad():
        self._loop(data_loader, mode=mode)

    # Per-epoch scheduler update, only when configured that way.
    if self.scheduler is not None:
        if self.update_scheduler_by_epoch:
            self.scheduler.step()

    # Tell a distributed sampler which epoch was just run.
    if isinstance(data_loader, DataLoader):
        sampler = data_loader.sampler
        if isinstance(sampler, DistributedSampler):
            sampler.set_epoch(self.epoch)
Example #2
Source File: pix2pix.py From jdit with Apache License 2.0 | 6 votes |
def compute_valid(self):
    """Collect the values to visualize for a validation epoch.

    Rewrite this method to compute custom valid_epoch values. You can
    return a ``dict`` of values that you want to visualize.

    .. note::

        This method is under ``torch.no_grad():``. So, it will never compute grad.
        If you want to compute grad, please use ``torch.enable_grad():`` to wrap your operations.

    Example::

        d_fake = self.netD(self.fake.detach())
        d_real = self.netD(self.ground_truth)
        var_dic = {}
        var_dic["WD"] = w_distance = (d_real.mean() - d_fake.mean()).detach()
        return var_dic

    """
    # Only the second element (the dict of values) of each result is used.
    _, g_loss_values = self.compute_g_loss()
    _, d_loss_values = self.compute_d_loss()
    # On a key clash the entry coming from ``compute_d_loss`` wins.
    return dict(g_loss_values, **d_loss_values)
Example #3
Source File: attacks.py From ss-ood with MIT License | 6 votes |
def forward(self, model, bx, by, by_prime, curr_batch_size):
    """Craft an adversarial batch with iterated signed-gradient steps.

    :param model: the classifier's forward method; called as
        ``model(x * 2 - 1)`` and expected to return ``(logits, pen)``
    :param bx: batch of images
    :param by: true labels for the first ``curr_batch_size`` logits
    :param by_prime: labels for the ``rot_pred`` head; only used when
        ``self.attack_rotations`` is truthy
    :param curr_batch_size: number of leading rows of ``logits`` scored
        against ``by``; the rows of ``pen`` after it go to ``rot_pred``
    :return: perturbed batch of images, clamped to ``[0, 1]`` and kept within
        ``self.epsilon`` of ``bx``
    """
    # ``detach()`` shares storage with ``bx``, so the random start must be
    # added out of place; an in-place ``+=`` would overwrite the caller's
    # batch and shift the centre of the epsilon-ball used below.
    adv_bx = bx.detach() + torch.zeros_like(bx).uniform_(-self.epsilon, self.epsilon)

    for _ in range(self.num_steps):
        adv_bx.requires_grad_()
        with torch.enable_grad():
            logits, pen = model(adv_bx * 2 - 1)
            loss = F.cross_entropy(logits[:curr_batch_size], by, reduction='sum')
            if self.attack_rotations:
                loss += F.cross_entropy(model.module.rot_pred(pen[curr_batch_size:]), by_prime, reduction='sum') / 8.
        grad = torch.autograd.grad(loss, adv_bx, only_inputs=True)[0]

        # Signed ascent step, then project back into the epsilon-ball
        # around the clean input and into the valid range.
        adv_bx = adv_bx.detach() + self.step_size * torch.sign(grad.detach())
        adv_bx = torch.min(torch.max(adv_bx, bx - self.epsilon), bx + self.epsilon).clamp(0, 1)

    return adv_bx
Example #4
Source File: train_trades_mnist_binary.py From TRADES with MIT License | 6 votes |
def perturb_hinge(net, x_nat):
    """Perturb ``x_nat`` by signed-gradient ascent on a hinge loss.

    Perturb function based on (E[\\phi(f(x)f(x'))]). Step count, step size,
    radius and ``beta`` are read from the module-level ``args``.

    :param net: model to attack; switched to eval mode for the attack and
        to train mode before returning
    :param x_nat: batch of natural inputs
    :return: perturbed batch, within ``args.epsilon`` of ``x_nat`` and
        clamped to ``[0, 1]``
    """
    net.eval()

    # init with random noise, created directly on x_nat's device so the
    # function does not depend on CUDA being the target device
    x = x_nat.detach() + 0.001 * torch.randn_like(x_nat)

    # The natural-input term does not depend on ``x``; compute it once.
    with torch.no_grad():
        nat_term = net(x_nat).squeeze(1) / args.beta

    for _ in range(args.num_steps):
        x.requires_grad_()
        with torch.enable_grad():
            # perturb based on hinge loss
            loss = torch.mean(torch.clamp(1 - net(x).squeeze(1) * nat_term, min=0))
        grad = torch.autograd.grad(loss, [x])[0]
        x = x.detach() + args.step_size * torch.sign(grad.detach())
        # Project into the epsilon-ball around x_nat, then the valid range.
        x = torch.min(torch.max(x, x_nat - args.epsilon), x_nat + args.epsilon)
        x = torch.clamp(x, 0.0, 1.0)

    net.train()
    return x
Example #5
Source File: iresblock.py From residual-flows with MIT License | 6 votes |
def forward(ctx, estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *g_params):
    """Evaluate ``gnet`` and the log-det estimate, caching gradients on ``ctx``.

    ``x`` is detached and re-marked as requiring grad, so the graph built
    here starts from this call. When ``training`` is truthy, the gradients
    of ``logdetgrad.sum()`` with respect to ``x`` and ``g_params`` are
    computed immediately and saved on ``ctx`` along with ``g_params``.

    Returns ``safe_detach(g), safe_detach(logdetgrad)``.
    """
    ctx.training = training
    with torch.enable_grad():
        x = x.detach().requires_grad_(True)
        g = gnet(x)
        ctx.g, ctx.x = g, x
        logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, training)

        if training:
            all_grads = torch.autograd.grad(
                logdetgrad.sum(), (x,) + g_params, retain_graph=True, allow_unused=True
            )
            grad_x, grad_params = all_grads[0], all_grads[1:]
            # ``allow_unused=True`` yields None when x is unused; use zeros instead.
            if grad_x is None:
                grad_x = torch.zeros_like(x)
            ctx.save_for_backward(grad_x, *g_params, *grad_params)

    return safe_detach(g), safe_detach(logdetgrad)
Example #6
Source File: pl.py From neuralsort with MIT License | 6 votes |
def rsample(self, sample_shape, log_score=True):
    """Draw relaxed samples by perturbing the scores with Gumbel noise.

    sample_shape: indexable; ``sample_shape[0]`` is the number of samples
        from the PL distribution.
    log_score: when falsy, ``torch.log`` is applied to ``self.scores``
        before the noise is added; otherwise the scores are used as-is.

    Returns the output of ``self.relaxed_sort`` viewed as
    ``(n_samples, -1, self.n, self.n)`` and then squeezed.
    """
    with torch.enable_grad():  # torch.distributions turns off autograd
        n_samples = sample_shape[0]
        # Sample the noise on the same device as the scores instead of
        # assuming CUDA.
        device = self.scores.device

        def sample_gumbel(samples_shape, eps=1e-20):
            U = torch.zeros(samples_shape, device=device).uniform_()
            return -torch.log(-torch.log(U + eps) + eps)

        noise = sample_gumbel([n_samples, 1, self.n, 1])
        if not log_score:
            log_s_perturb = torch.log(self.scores.unsqueeze(0)) + noise
        else:
            log_s_perturb = self.scores.unsqueeze(0) + noise
        log_s_perturb = log_s_perturb.view(-1, self.n, 1)
        P_hat = self.relaxed_sort(log_s_perturb)
        P_hat = P_hat.view(n_samples, -1, self.n, self.n)

        return P_hat.squeeze()
Example #7
Source File: InnerCosFunction.py From Shift-Net_pytorch with MIT License | 6 votes |
def backward(ctx, grad_output):
    """Add the gradient of the inner criterion to the incoming gradient.

    The first ``ctx.c // 2`` channels of the saved input are masked and
    scored against the saved target with ``ctx.criterion`` (scaled by
    ``ctx.strength``). The gradient of that loss is added in place to the
    same channels of ``grad_output``. The loss is stored on ``ctx.loss``.

    Returns ``grad_output`` followed by four ``None`` placeholders.
    """
    with torch.enable_grad():
        features, target, mask = ctx.saved_tensors
        half = ctx.c // 2
        masked_former = torch.mul(features.narrow(1, 0, half), mask)

        if masked_former.size() != target.size():  # For the last iteration of one epoch
            first_target = target.narrow(0, 0, 1)
            target = first_target.expand_as(masked_former).type_as(masked_former)

        # A fresh leaf so that ``.grad`` holds only the criterion's gradient.
        leaf = masked_former.clone().detach().requires_grad_(True)
        ctx.loss = ctx.criterion(leaf, target) * ctx.strength
        ctx.loss.backward()

    grad_output[:, 0:half, :, :] += leaf.grad

    return grad_output, None, None, None, None
Example #8
Source File: rcgan_trainer.py From rGAN with MIT License | 6 votes |
def cond_gradient_penalty(self, x_real, x_fake, y, netD, index=0):
    """Gradient penalty of a conditional discriminator at random interpolates.

    A per-sample random weight mixes the detached real and fake batches.
    ``netD(x_hat, y)[index]`` is differentiated with respect to that mix,
    and the mean squared deviation of the per-sample gradient 2-norm from 1
    is returned.
    """
    device = x_real.device
    with torch.enable_grad():
        # One mixing coefficient per sample, broadcast over remaining dims.
        alpha = torch.rand(x_real.size(0), 1, 1, 1).to(device).expand_as(x_real)
        x_hat = alpha * x_real.detach() + (1 - alpha) * x_fake.detach()
        x_hat.requires_grad_(True)

        output = netD(x_hat, y)[index]
        grad_output = torch.ones(output.size()).to(device)
        grads = torch.autograd.grad(
            outputs=output,
            inputs=x_hat,
            grad_outputs=grad_output,
            retain_graph=True,
            create_graph=True,
            only_inputs=True,
        )
        flat_grad = grads[0].contiguous().view(grads[0].size(0), -1)
        grad_norm = flat_grad.norm(p=2, dim=1)
        loss_gp = ((grad_norm - 1) ** 2).mean()

    return loss_gp
Example #9
Source File: racgan_trainer.py From rGAN with MIT License | 6 votes |
def gradient_penalty(self, x_real, x_fake, netD, index=0):
    """Gradient penalty of a discriminator at random interpolates.

    :param x_real: batch of real samples; indexed with four dimensions
    :param x_fake: batch of generated samples, same shape as ``x_real``
    :param netD: discriminator; ``netD(x)[index]`` is the differentiated output
    :param index: which element of ``netD``'s return value to use
    :return: mean over the batch of ``(||grad||_2 - 1) ** 2``
    """
    device = x_real.device

    with torch.enable_grad():
        # One mixing coefficient per sample, broadcast over remaining dims.
        alpha = torch.rand(x_real.size(0), 1, 1, 1).to(device)
        alpha = alpha.expand_as(x_real)

        x_hat = alpha * x_real.detach() + (1 - alpha) * x_fake.detach()
        x_hat.requires_grad = True
        output = netD(x_hat)[index]

        grad_output = torch.ones(output.size()).to(device)
        grad = torch.autograd.grad(outputs=output,
                                   inputs=x_hat,
                                   grad_outputs=grad_output,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        # The gradient can come back non-contiguous (e.g. when netD permutes
        # its input); ``view`` would raise then, so make it contiguous first.
        grad = grad.contiguous().view(grad.size(0), -1)
        loss_gp = ((grad.norm(p=2, dim=1) - 1)**2).mean()

    return loss_gp
Example #10
Source File: imitator_training.py From ReAgent with BSD 3-Clause "New" or "Revised" License | 6 votes |
def train(self, training_batch, train=True):
    """Score the imitator on a batch and, when ``train`` is truthy, update it.

    The imitator is run on ``training_input.state.float_features``. Predicted
    and actual action indices are the arg-max positions along dim 1. With
    ``train`` set, a cross-entropy loss is back-propagated and
    ``self._maybe_run_optimizer`` is invoked.

    Returns whatever ``self._imitator_accuracy`` computes from the predicted
    and actual indices.
    """
    learning_input = training_batch.training_input

    with torch.enable_grad():
        action_preds = self.imitator(learning_input.state.float_features)

        # Classification label is index of action with value 1
        _, pred_action_idxs = torch.max(action_preds, dim=1)
        _, actual_action_idxs = torch.max(learning_input.action, dim=1)

        if train:
            criterion = torch.nn.CrossEntropyLoss()
            criterion(action_preds, actual_action_idxs).backward()
            self._maybe_run_optimizer(
                self.imitator_optimizer, self.minibatches_per_step
            )

    return self._imitator_accuracy(pred_action_idxs, actual_action_idxs)
Example #11
Source File: abstract.py From rising with MIT License | 6 votes |
def __call__(self, *args, **kwargs) -> Any:
    """
    Invoke the parent ``__call__`` inside the matching torch grad context.

    Gradients are enabled when ``self.grad`` is truthy and disabled
    otherwise.

    Args:
        *args: forwarded positional arguments
        **kwargs: forwarded keyword arguments

    Returns:
        Any: transformed data
    """
    context = torch.enable_grad() if self.grad else torch.no_grad()
    with context:
        return super().__call__(*args, **kwargs)
Example #12
Source File: checkpoint.py From pytorch-memonger with MIT License | 6 votes |
def backward(ctx, *args):
    """Re-run the checkpointed function and back-propagate through it.

    The saved inputs are detached, ``ctx.run_function`` is executed again
    with gradients enabled, and ``args`` (the incoming gradients) are
    back-propagated through the recomputed outputs.

    Returns ``None`` followed by the ``.grad`` of each detached input.
    """
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
    inputs = ctx.saved_tensors

    # Stash the surrounding rng state, and mimic the state that was
    # present at this time during forward. Restore the surrounding state
    # when we're done.
    rng_devices = []
    if ctx.had_cuda_in_fwd:
        rng_devices.append(torch.cuda.current_device())

    with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
        if preserve_rng_state:
            torch.set_rng_state(ctx.fwd_cpu_rng_state)
            if ctx.had_cuda_in_fwd:
                torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

    # ``torch.autograd.backward`` is given a tuple of outputs.
    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)
    torch.autograd.backward(outputs, args)

    input_grads = tuple(inp.grad for inp in detached_inputs)
    return (None,) + input_grads
Example #13
Source File: attack_whitebox.py From ME-Net with MIT License | 6 votes |
def forward(self, inputs, targets):
    """Run the model on an adversarially perturbed version of ``inputs``.

    When the module-level ``args.attack`` is falsy the attack is skipped and
    the model is evaluated on the clean inputs.

    :param inputs: batch of inputs
    :param targets: labels used for the cross-entropy attack loss
    :return: ``(model output, inputs actually fed to the model)``
    """
    if not args.attack:
        return self.model(inputs), inputs

    x = inputs.detach()
    if self.rand:
        # Random start inside the epsilon-ball (out of place, so ``inputs``
        # is left untouched).
        x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)

    for _ in range(self.num_steps):
        x.requires_grad_()
        with torch.enable_grad():
            logits = self.model(x)
            # ``reduction='sum'`` is the supported spelling of the
            # deprecated ``size_average=False``.
            loss = F.cross_entropy(logits, targets, reduction='sum')
        grad = torch.autograd.grad(loss, [x])[0]
        x = x.detach() + self.step_size * torch.sign(grad.detach())
        # Project into the epsilon-ball around the clean input, then into
        # the valid range.
        x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon)
        x = torch.clamp(x, 0, 1)

    return self.model(x), x
Example #14
Source File: test_dependency.py From torchgpipe with Apache License 2.0 | 6 votes |
def test_fork_join_enable_grad():
    """``fork``/``join`` under ``enable_grad`` return new, graph-attached tensors."""
    source = torch.rand(1, requires_grad=True)

    with torch.enable_grad():
        forked, phony = fork(source)

    # Both fork outputs are tracked, and the tensor output is a new object.
    assert phony.requires_grad
    assert forked is not source
    assert forked.requires_grad
    assert phony.requires_grad
    assert forked.grad_fn.__class__ is Fork._backward_cls
    assert phony.grad_fn.__class__ is Fork._backward_cls

    with torch.enable_grad():
        joined = join(forked, phony)

    # ``join`` likewise returns a new tensor, produced by Join's backward class.
    assert joined is not forked
    assert joined.requires_grad
    assert joined.grad_fn.__class__ is Join._backward_cls
Example #15
Source File: train_adv.py From ME-Net with MIT License | 6 votes |
def forward(self, inputs, targets):
    """Run the model on an adversarially perturbed version of ``inputs``.

    When the module-level ``args.attack`` is falsy the attack is skipped and
    the model is evaluated on the clean inputs.

    :param inputs: batch of inputs
    :param targets: labels used for the cross-entropy attack loss
    :return: ``(model output, inputs actually fed to the model)``
    """
    if not args.attack:
        return self.model(inputs), inputs

    x = inputs.detach()
    if self.rand:
        # Random start inside the epsilon-ball (out of place, so ``inputs``
        # is left untouched).
        x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)

    for _ in range(self.num_steps):
        x.requires_grad_()
        with torch.enable_grad():
            logits = self.model(x)
            # ``reduction='sum'`` is the supported spelling of the
            # deprecated ``size_average=False``.
            loss = F.cross_entropy(logits, targets, reduction='sum')
        grad = torch.autograd.grad(loss, [x])[0]
        x = x.detach() + self.step_size * torch.sign(grad.detach())
        # Project into the epsilon-ball around the clean input, then into
        # the valid range.
        x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon)
        x = torch.clamp(x, 0, 1)

    return self.model(x), x
Example #16
Source File: deq.py From deq with MIT License | 5 votes |
def backward(ctx, grad): torch.cuda.empty_cache() # grad should have dimension (bsz x d_model x seq_len) bsz, d_model, seq_len = grad.size() grad = grad.clone() z1ss, uss, z0 = ctx.saved_tensors args = ctx.args threshold, train_step = args[-2:] func = ctx.func z1ss = z1ss.clone().detach().requires_grad_() uss = uss.clone().detach() z0 = z0.clone().detach() with torch.enable_grad(): y = RootFind.g(func, z1ss, uss, z0, *args) def g(x): y.backward(x, retain_graph=True) # Retain for future calls to g JTx = z1ss.grad.clone().detach() z1ss.grad.zero_() return JTx + grad eps = 2e-10 * np.sqrt(bsz * seq_len * d_model) dl_df_est = torch.zeros_like(grad) result_info = broyden(g, dl_df_est, threshold=threshold, eps=eps, name="backward") dl_df_est = result_info['result'] y.backward(torch.zeros_like(dl_df_est), retain_graph=False) grad_args = [None for _ in range(len(args))] return (None, dl_df_est, None, None, *grad_args)
Example #17
Source File: norm_flows.py From ffjord with MIT License | 5 votes |
def _detgrad(self, z):
    """Return ``1 + psi(z) @ u``, with ``psi`` the gradient of ``h`` w.r.t. ``z``.

    ``h`` is ``self.activation`` applied to ``z @ w + b``. ``z`` is marked as
    requiring grad in place, and the graph of the gradient is kept
    (``create_graph=True``) so the result can be differentiated further.
    """
    with torch.enable_grad():
        z = z.requires_grad_(True)
        w_col = self.w.view(self.nd, 1)
        u_col = self.u.view(self.nd, 1)

        h = self.activation(torch.mm(z, w_col) + self.b)
        psi = grad(h, z, grad_outputs=torch.ones_like(h), create_graph=True, only_inputs=True)[0]
        detgrad = 1 + torch.mm(psi, u_col)
    return detgrad
Example #18
Source File: iresblock.py From residual-flows with MIT License | 5 votes |
def backward(ctx, grad_g, grad_logdetgrad):
    """Combine the saved log-det gradients with the gradient flowing through ``g``.

    The saved gradients are scaled in place by the first entry of
    ``grad_logdetgrad``. The gradients of ``ctx.g`` with respect to
    ``ctx.x`` and the saved parameters, weighted by ``grad_g``, are then
    added to them.

    Returns a tuple of seven entries (``grad_x`` in the third slot, ``None``
    elsewhere) followed by the per-parameter gradients.

    Raises:
        ValueError: if ``ctx.training`` is falsy.
    """
    training = ctx.training
    if not training:
        raise ValueError('Provide training=True if using backward.')

    with torch.enable_grad():
        grad_x, *params_and_grad = ctx.saved_tensors
        g, x = ctx.g, ctx.x

        # Precomputed gradients.
        # The saved list is split in half: parameters first, then their gradients.
        g_params = params_and_grad[:len(params_and_grad) // 2]
        grad_params = params_and_grad[len(params_and_grad) // 2:]

        dg_x, *dg_params = torch.autograd.grad(g, [x] + g_params, grad_g, allow_unused=True)

    # Update based on gradient from logdetgrad.
    dL = grad_logdetgrad[0].detach()
    with torch.no_grad():
        grad_x.mul_(dL)
        # Entries that are None are passed through unchanged.
        grad_params = tuple([g.mul_(dL) if g is not None else None for g in grad_params])

    # Update based on gradient from g.
    with torch.no_grad():
        grad_x.add_(dg_x)
        # NOTE(review): ``dg.add_`` assumes ``dg`` is not None whenever
        # ``djac`` is not None; confirm, since allow_unused=True can yield None.
        grad_params = tuple([dg.add_(djac) if djac is not None else dg for dg, djac in zip(dg_params, grad_params)])

    return (None, None, grad_x, None, None, None, None) + grad_params
Example #19
Source File: pgd_attack_mnist.py From TRADES with MIT License | 5 votes |
def _pgd_whitebox(model, X, y, epsilon=args.epsilon, num_steps=args.num_steps, step_size=args.step_size):
    """Count misclassifications before and after a white-box PGD attack.

    Each step moves ``X_pgd`` along the sign of the cross-entropy gradient,
    clamps the total perturbation to ``[-epsilon, epsilon]`` and the result
    to ``[0, 1]``. When the module-level ``args.random`` is truthy the attack
    starts from a uniformly perturbed copy of ``X``.

    Returns ``(err, err_pgd)``: the number of wrong predictions on ``X`` and
    on the attacked batch. ``err_pgd`` is also printed.
    """
    clean_logits = model(X)
    err = (clean_logits.data.max(1)[1] != y.data).float().sum()

    X_pgd = Variable(X.data, requires_grad=True)
    if args.random:
        random_noise = torch.FloatTensor(*X_pgd.shape).uniform_(-epsilon, epsilon).to(device)
        X_pgd = Variable(X_pgd.data + random_noise, requires_grad=True)

    for _ in range(num_steps):
        # The optimizer exists only so that ``zero_grad`` can be called on X_pgd.
        opt = optim.SGD([X_pgd], lr=1e-3)
        opt.zero_grad()

        with torch.enable_grad():
            loss = nn.CrossEntropyLoss()(model(X_pgd), y)
        loss.backward()

        # Signed-gradient step, then clamp the perturbation and the image.
        ascent = step_size * X_pgd.grad.data.sign()
        X_pgd = Variable(X_pgd.data + ascent, requires_grad=True)
        delta = torch.clamp(X_pgd.data - X.data, -epsilon, epsilon)
        X_pgd = Variable(X.data + delta, requires_grad=True)
        X_pgd = Variable(torch.clamp(X_pgd, 0, 1.0), requires_grad=True)

    adv_logits = model(X_pgd)
    err_pgd = (adv_logits.data.max(1)[1] != y.data).float().sum()
    print('err pgd (white-box): ', err_pgd)
    return err, err_pgd
Example #20
Source File: mutator.py From nni with MIT License | 5 votes |
def forward(ctx, x, binary_gates, run_func, backward_func):
    """Run ``run_func`` on a detached copy of ``x`` and stash state on ``ctx``.

    ``run_func`` and ``backward_func`` are stored on ``ctx``. The detached
    input and the output are saved with ``save_for_backward``.
    ``binary_gates`` is not read here.

    Returns ``output.data``.
    """
    ctx.run_func, ctx.backward_func = run_func, backward_func

    detached_x = detach_variable(x)
    # Gradients are switched on so that a graph exists from detached_x to output.
    with torch.enable_grad():
        output = run_func(detached_x)

    ctx.save_for_backward(detached_x, output)
    return output.data
Example #21
Source File: power.py From tensorgrad with Apache License 2.0 | 5 votes |
def backward(ctx, grad):
    """Accumulate gradients through repeated ``step`` applications.

    ``grad`` is pushed repeatedly through ``step(A, x)`` with respect to
    ``x`` and added to ``dA``. This stops as soon as the norm of the
    propagated gradient is no longer above ``ctx.tol``. The accumulated
    ``dA`` is finally pushed through ``step`` with respect to ``A``.

    Returns ``(dA, None, None)``.
    """
    A, x = detach_variable(ctx.saved_tensors)
    dA = grad

    keep_going = True
    while keep_going:
        with torch.enable_grad():
            grad = torch.autograd.grad(step(A, x), x, grad_outputs=grad)[0]
        if torch.norm(grad) > ctx.tol:
            dA = dA + grad
        else:
            keep_going = False

    with torch.enable_grad():
        dA = torch.autograd.grad(step(A, x), A, grad_outputs=dA)[0]

    return dA, None, None
Example #22
Source File: lazy_evaluated_kernel_tensor.py From gpytorch with MIT License | 5 votes |
def _quad_form_derivative(self, left_vecs, right_vecs):
    """Compute the quadratic-form derivative of the kernel in chunks.

    ``x1`` and ``left_vecs`` are split along dim ``-2`` into pieces of size
    ``beta_features.checkpoint_kernel.value()``. For each piece the kernel
    is evaluated with gradients enabled and back-propagated.

    Returns ``(x1.grad, x2.grad)`` for the detached copies of ``self.x1``
    and ``self.x2``.

    Raises:
        RuntimeError: if the checkpoint split size is falsy.
    """
    # This _quad_form_derivative computes the kernel in chunks
    # It is only used when we are using kernel checkpointing
    # It won't be called if checkpointing is off
    split_size = beta_features.checkpoint_kernel.value()
    if not split_size:
        raise RuntimeError(
            "Should not have ended up in LazyEvaluatedKernelTensor._quad_form_derivative without kernel "
            "checkpointing. This is probably a bug in GPyTorch."
        )

    # Fresh leaves, so gradients accumulate on these copies only.
    x1 = self.x1.detach().requires_grad_(True)
    x2 = self.x2.detach().requires_grad_(True)

    # Break objects into chunks
    sub_x1s = torch.split(x1, split_size, dim=-2)
    sub_left_vecss = torch.split(left_vecs, split_size, dim=-2)
    # Compute the gradient in chunks
    for sub_x1, sub_left_vecs in zip(sub_x1s, sub_left_vecss):
        # Detach the chunk in place so it gets its own ``.grad``.
        sub_x1.detach_().requires_grad_(True)
        with torch.enable_grad(), settings.lazily_evaluate_kernels(False):
            sub_kernel_matrix = lazify(
                self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params)
            )
        sub_grad_outputs = tuple(sub_kernel_matrix._quad_form_derivative(sub_left_vecs, right_vecs))
        sub_kernel_outputs = tuple(sub_kernel_matrix.representation())
        torch.autograd.backward(sub_kernel_outputs, sub_grad_outputs)

    # Stitch the per-chunk gradients back together along the split dimension.
    x1.grad = torch.cat([sub_x1.grad.data for sub_x1 in sub_x1s], dim=-2)
    return x1.grad, x2.grad
Example #23
Source File: torch_agent.py From KBRD with MIT License | 5 votes |
def batch_act(self, observations):
    """
    Process a batch of observations (batchsize list of message dicts).

    These observations have been preprocessed by the observe method.

    Subclasses can override this for special functionality, but if the
    default behaviors are fine then just override the ``train_step`` and
    ``eval_step`` methods instead. The former is called when labels are
    present in the observations batch; otherwise, the latter is called.
    """
    # One reply per observation, each tagged with this agent's id.
    agent_id_replies = [{'id': self.getID()} for _ in observations]

    # Train whenever at least one observation carries labels.
    self.is_training = any('labels' in obs for obs in observations)

    # Build the batch from the vectorized observations.
    batch = self.batchify(observations)

    if not self.is_training:
        # Disabling autograd saves memory and compute; wrap code in
        # `with torch.enable_grad()` to get gradients back.
        with torch.no_grad():
            output = self.eval_step(batch)
    else:
        output = self.train_step(batch)

    if output is None:
        self.replies['batch_reply'] = None
        return agent_id_replies

    self.match_batch(agent_id_replies, batch.valid_indices, output)
    self.replies['batch_reply'] = agent_id_replies
    # Record the model predictions in the history.
    self._save_history(observations, agent_id_replies)
    return agent_id_replies
Example #24
Source File: main.py From NJUNMT-pytorch with MIT License | 5 votes |
def compute_forward(model, critic, seqs_x, seqs_y, eval=False, normalization=1.0, norm_by_words=False):
    """
    Compute the loss for one batch; in training mode also back-propagate it.

    With ``eval`` falsy, model and critic are put in train mode, the
    per-sequence loss is summed (optionally divided by the count of non-PAD
    label tokens first) and back-propagated. With ``eval`` truthy, both are
    put in eval mode and the critic's reduced loss is computed without
    gradients. Returns ``loss.item()`` either way.

    :type model: nn.Module

    :type critic: NMTCriterion
    """
    # Decoder input drops the last token; the labels drop the first.
    y_inp = seqs_y[:, :-1].contiguous()
    y_label = seqs_y[:, 1:].contiguous()

    # Number of non-PAD label tokens per sequence.
    words_norm = y_label.ne(PAD).float().sum(1)

    if eval:
        model.eval()
        critic.eval()
        # For compute loss
        with torch.no_grad():
            log_probs = model(seqs_x, y_inp)
            loss = critic(inputs=log_probs, labels=y_label, normalization=normalization, reduce=True)
        return loss.item()

    model.train()
    critic.train()
    # For training
    with torch.enable_grad():
        log_probs = model(seqs_x, y_inp)
        loss = critic(inputs=log_probs, labels=y_label, reduce=False, normalization=normalization)
        loss = loss.div(words_norm).sum() if norm_by_words else loss.sum()
    torch.autograd.backward(loss)
    return loss.item()
Example #25
Source File: attack_pgd.py From semisup-adv with MIT License | 5 votes |
def pgd(model, X, y, epsilon=8 / 255, num_steps=20, step_size=0.01, random_start=True):
    """Run a PGD attack and record per-step correctness.

    The perturbation is updated along the sign of the cross-entropy
    gradient, clamped to ``[-epsilon, epsilon]`` and then limited so that
    ``X + perturbation`` stays in ``[0, 1]``.

    Returns ``(is_correct_natural, is_correct_adv)``: a float array marking
    which clean predictions match ``y``, and an array with one column per
    attack step marking which adversarial predictions match ``y``.
    """
    natural_pred = model(X).max(1)[1]
    is_correct_natural = (natural_pred == y).float().cpu().numpy()

    if random_start:
        # Uniform start in [-epsilon, epsilon].
        perturbation = torch.rand_like(X, requires_grad=True)
        perturbation.data = perturbation.data * 2 * epsilon - epsilon
    else:
        perturbation = torch.zeros_like(X, requires_grad=True)

    per_step_correct = []
    opt = optim.SGD([perturbation], lr=1e-3)  # This is just to clear the grad

    for _ in range(num_steps):
        opt.zero_grad()
        with torch.enable_grad():
            loss = nn.CrossEntropyLoss()(model(X + perturbation), y)
        loss.backward()

        stepped = perturbation + step_size * perturbation.grad.detach().sign()
        perturbation.data = stepped.clamp(-epsilon, epsilon)
        # clip X+delta to [0,1]
        perturbation.data = torch.min(torch.max(perturbation.detach(), -X), 1 - X)

        X_pgd = Variable(torch.clamp(X.data + perturbation.data, 0, 1.0), requires_grad=False)
        adv_correct = (model(X_pgd).max(1)[1] == y).float().cpu().numpy()
        per_step_correct.append(np.reshape(adv_correct, [-1, 1]))

    is_correct_adv = np.concatenate(per_step_correct, axis=1)
    return is_correct_natural, is_correct_adv
Example #26
Source File: torch_agent.py From neural_chat with MIT License | 5 votes |
def batch_act(self, observations):
    """
    Process a batch of observations (batchsize list of message dicts).

    These observations have been preprocessed by the observe method.

    Subclasses can override this for special functionality, but if the
    default behaviors are fine then just override the ``train_step`` and
    ``eval_step`` methods instead. The former is called when labels are
    present in the observations batch; otherwise, the latter is called.
    """
    # One ``Message`` reply per observation, each tagged with this agent's id.
    batch_reply = [Message({'id': self.getID()}) for _ in observations]

    # Train whenever at least one observation carries labels.
    self.is_training = any('labels' in obs for obs in observations)

    # Build the batch from the vectorized observations.
    batch = self.batchify(observations)

    if not self.is_training:
        # Disabling autograd saves memory and compute; wrap code in
        # `with torch.enable_grad()` to get gradients back.
        with torch.no_grad():
            output = self.eval_step(batch)
    else:
        output = self.train_step(batch)

    if output is None:
        self.replies['batch_reply'] = None
        return batch_reply

    self.match_batch(batch_reply, batch.valid_indices, output)
    self.replies['batch_reply'] = batch_reply
    return batch_reply
Example #27
Source File: utils.py From elastic with Apache License 2.0 | 5 votes |
def backward(ctx, *output_grads):
    """Recompute the forward pass and return gradients for inputs and params.

    Each tensor in ``ctx.input_tensors`` is replaced in that list by a
    detached copy carrying the same ``requires_grad`` flag.
    ``ctx.run_function`` is then re-run with gradients enabled and
    differentiated with respect to the input tensors followed by
    ``ctx.input_params``.

    Returns ``(None, None)`` followed by those gradients (entries may be
    ``None`` because ``allow_unused=True``).
    """
    for idx, original in enumerate(ctx.input_tensors):
        detached = original.detach()
        detached.requires_grad = original.requires_grad
        ctx.input_tensors[idx] = detached

    with torch.enable_grad():
        output_tensors = ctx.run_function(*ctx.input_tensors)

    wrt = ctx.input_tensors + ctx.input_params
    input_grads = torch.autograd.grad(output_tensors, wrt, output_grads, allow_unused=True)
    return (None, None) + input_grads
Example #28
Source File: ProxProp.py From proxprop with MIT License | 5 votes |
def backward(ctx, grad_z):
    """Backward pass: parameter update via ``optimization_step``, plain gradient for the input.

    The saved tensors are read as ``(input, *params, output)``. The
    parameter "gradients" returned are the differences between the current
    parameters and those produced by ``optimization_step``. The input
    gradient is obtained by differentiating ``apply_forward`` with
    ``grad_z`` as the output gradient.

    Returns a tuple ``(grad_input, *grad_params, None)``.
    """
    input = ctx.saved_variables[0]
    params = list(ctx.saved_variables[1:-1])
    output = ctx.saved_variables[-1]
    grad_input = None
    grad_params = [None] * len(params)
    layer = ctx.optimization_layer

    # explicit gradient step on z
    z_updated = output - grad_z

    # prox step or gradient step on the network parameters
    if layer.optimization_mode == 'prox_exact':
        A, Y, Z = layer.to_exact_solve_shape(input, z_updated, *params)
    else:
        A = input.detach()
        Y = z_updated.detach()
        Z = layer.to_cg_shape(params).detach()
    X_tensor = optimization_step(A, Y, Z, 1./layer.tau_prox, layer.apply_cg, mode=layer.optimization_mode)
    # Convert the solver output back to a list of parameters, using the
    # conversion that matches the mode chosen above.
    if 'prox_exact' == layer.optimization_mode:
        params_udpated = list(layer.from_exact_solve_shape(X_tensor).values())
    else:
        params_udpated = list(layer.from_cg_shape(X_tensor).values())

    # write difference in grad fields
    grad_params = [x[0] - x[1] for x in zip(params,params_udpated)]

    # explicit gradient step on a
    input.requires_grad_()
    with torch.enable_grad():
        out_temp = ctx.optimization_layer.apply_forward(input)
        grad_temp = torch.autograd.grad(out_temp, input, grad_z)
    grad_input = grad_temp[0]

    return tuple([grad_input] + grad_params + [None])
Example #29
Source File: metric_optimization_torch.py From pysaliency with MIT License | 5 votes |
def step(self, closure=None):
    """Performs a single optimization step.

    Each gradient is first capped element-wise at ``p / lr``, then its
    component along the all-ones direction is removed, and the result is
    applied with step size ``lr``. Parameters without a gradient are skipped.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        for p in group['params']:
            d_p = p.grad
            if d_p is None:
                continue

            learning_rate = group['lr']

            # All-ones constraint direction, divided by its sum of squares.
            constraint_grad = torch.ones_like(d_p)
            normed_constraint_grad = constraint_grad / torch.sum(torch.pow(constraint_grad, 2))

            # first step: make sure we are not running into negative values
            projected_grad1 = torch.min(d_p, p / learning_rate)

            # second step: Make sure that the gradient does not walk
            # out of the constraint
            overlap = torch.sum(projected_grad1 * constraint_grad)
            projected_grad2 = projected_grad1 - overlap * normed_constraint_grad

            p.add_(projected_grad2, alpha=-group['lr'])

    return loss
Example #30
Source File: NeuralIntegral.py From UMNN with BSD 3-Clause "New" or "Revised" License | 5 votes |
def computeIntegrand(x, h, integrand, x_tot):
    """Differentiate the integrand w.r.t. its parameters and ``h``.

    ``integrand.forward(x, h)`` is evaluated with gradients enabled and
    differentiated twice, using ``x_tot`` as the output gradient both times.

    Returns ``(g_param, g_h)``: the flattened gradients with respect to
    ``integrand.parameters()`` and with respect to ``h``.
    """
    with torch.enable_grad():
        f = integrand.forward(x, h)
        # The graph is retained here because a second grad call follows.
        param_grads = torch.autograd.grad(
            f, integrand.parameters(), x_tot, create_graph=True, retain_graph=True
        )
        h_grads = torch.autograd.grad(f, h, x_tot)
        g_param = _flatten(param_grads)
        g_h = _flatten(h_grads)

    return g_param, g_h