Python torch.nn.parallel.data_parallel() Examples
The following are 9 code examples of torch.nn.parallel.data_parallel(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.nn.parallel, or try the search function.
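Before the project examples, here is a minimal, self-contained sketch of the call itself. data_parallel(module, inputs, device_ids=None, output_device=None, dim=0) replicates the module on the listed GPUs, scatters the inputs along dim, runs the replicas in parallel, and gathers the results on output_device (the first listed device by default). The ToyNet module and tensor sizes below are illustrative assumptions, not taken from any of the projects listed here.

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

class ToyNet(nn.Module):
    """A toy module; the name and layer sizes are illustrative assumptions."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        return self.fc(x)

model = ToyNet().cuda()
x = torch.randn(32, 64).cuda()  # a batch of 32 examples on GPU 0

if torch.cuda.device_count() > 1:
    # Replicate the module on every visible GPU, scatter x along dim 0,
    # run the replicas in parallel, and gather the outputs on GPU 0.
    y = data_parallel(model, x, device_ids=list(range(torch.cuda.device_count())))
else:
    y = model(x)  # single-GPU fallback: a plain forward pass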
Example #1
Source File: trainer.py From torecsys with MIT License | 6 votes |
def apply_model(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Apply model forward.

    Args:
        inputs (Dict[str, T]): Dictionary to input to sequential.

    Returns:
        torch.Tensor: output of sequential.
    """
    if hasattr(self, "_base_device_ordinal"):
        base_device_ordinal = self._base_device_ordinal
    else:
        base_device_ordinal = None

    if self._devices is not None:
        return nn_parallel.data_parallel(self._sequential, inputs, list(self._devices),
                                         output_device=base_device_ordinal)
    else:
        return self._sequential(inputs)
Example #2
Source File: __init__.py From EDSR-PyTorch with MIT License | 6 votes |
def forward(self, x, idx_scale):
    self.idx_scale = idx_scale
    if hasattr(self.model, 'set_scale'):
        self.model.set_scale(idx_scale)

    if self.training:
        if self.n_GPUs > 1:
            return P.data_parallel(self.model, x, range(self.n_GPUs))
        else:
            return self.model(x)
    else:
        if self.chop:
            forward_function = self.forward_chop
        else:
            forward_function = self.model.forward

        if self.self_ensemble:
            return self.forward_x8(x, forward_function=forward_function)
        else:
            return forward_function(x)
Example #3
Source File: lm.py From espnet with Apache License 2.0 | 5 votes |
def update_core(self):
    """Update the model."""
    # When we pass one iterator and optimizer to StandardUpdater.__init__,
    # they are automatically named 'main'.
    train_iter = self.get_iterator("main")
    optimizer = self.get_optimizer("main")
    # Progress the dataset iterator for sentences at each iteration.
    self.model.zero_grad()  # Clear the parameter gradients
    accum = {"loss": 0.0, "nll": 0.0, "count": 0}
    for _ in range(self.accum_grad):
        batch = train_iter.__next__()
        # Concatenate the token IDs to matrices and send them to the device
        # self.converter does this job
        # (it is chainer.dataset.concat_examples by default)
        x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
        if self.device[0] == -1:
            loss, nll, count = self.model(x, t)
        else:
            # apex does not support torch.nn.DataParallel
            loss, nll, count = data_parallel(self.model, (x, t), self.device)

        # backward
        loss = loss.mean() / self.accum_grad
        if self.use_apex:
            from apex import amp

            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()  # Backprop
        # accumulate stats
        accum["loss"] += float(loss)
        accum["nll"] += float(nll.sum())
        accum["count"] += int(count.sum())

    for k, v in accum.items():
        reporter.report({k: v}, optimizer.target)
    if self.gradclip is not None:
        nn.utils.clip_grad_norm_(self.model.parameters(), self.gradclip)
    optimizer.step()  # Update the parameters
    self.scheduler.step(n_iter=self.iteration)
Example #4
Source File: lm.py From espnet with Apache License 2.0 | 5 votes |
def evaluate(self):
    """Evaluate the model."""
    val_iter = self.get_iterator("main")
    loss = 0
    nll = 0
    count = 0
    self.model.eval()
    with torch.no_grad():
        for batch in copy.copy(val_iter):
            x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
            if self.device[0] == -1:
                l, n, c = self.model(x, t)
            else:
                # apex does not support torch.nn.DataParallel
                l, n, c = data_parallel(self.model, (x, t), self.device)
            loss += float(l.sum())
            nll += float(n.sum())
            count += int(c.sum())
    self.model.train()
    # report validation loss
    observation = {}
    with reporter.report_scope(observation):
        reporter.report({"loss": loss}, self.model.reporter)
        reporter.report({"nll": nll}, self.model.reporter)
        reporter.report({"count": count}, self.model.reporter)
    return observation
Example #5
Source File: networks.py From IntroVAE with MIT License | 5 votes |
def encode(self, x):
    # With device_ids omitted, data_parallel runs the encoder on all visible GPUs.
    mu, logvar = data_parallel(self.encoder, x)
    return mu, logvar
Example #6
Source File: networks.py From IntroVAE with MIT License | 5 votes |
def decode(self, z):
    # With device_ids omitted, data_parallel runs the decoder on all visible GPUs.
    y = data_parallel(self.decoder, z)
    return y
Example #7
Source File: seq2seq_base.py From seq2seq.pytorch with MIT License | 5 votes |
def encode(self, inputs, hidden=None, device_ids=None):
    if isinstance(device_ids, tuple):
        return data_parallel(self.encoder, (inputs, hidden),
                             device_ids=device_ids,
                             dim=0 if self.encoder.batch_first else 1)
    else:
        return self.encoder(inputs, hidden)
Example #8
Source File: seq2seq_base.py From seq2seq.pytorch with MIT License | 5 votes |
def decode(self, *kargs, **kwargs):
    device_ids = kwargs.pop('device_ids', None)
    if isinstance(device_ids, tuple):
        return data_parallel(self.decoder, *kargs, **kwargs,
                             device_ids=device_ids,
                             dim=0 if self.decoder.batch_first else 1)
    else:
        return self.decoder(*kargs, **kwargs)
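Examples #7 and #8 pass dim=0 if batch_first else 1 so that scattering and gathering happen along the batch axis even when the encoder or decoder consumes time-major input of shape (seq_len, batch, features). Below is a minimal sketch of the same idea outside the seq2seq.pytorch codebase; the GRU sizes and tensor shapes are illustrative assumptions.

import torch
from torch import nn
from torch.nn.parallel import data_parallel

# Time-major RNN: input shape is (seq_len, batch, features), so the batch
# dimension is dim=1 and data_parallel must scatter/gather along it.
rnn = nn.GRU(input_size=16, hidden_size=32, batch_first=False).cuda()
inputs = torch.randn(50, 8, 16).cuda()  # (seq_len=50, batch=8, features=16)

if torch.cuda.device_count() > 1:
    outputs, hidden = data_parallel(
        rnn, inputs,
        device_ids=list(range(torch.cuda.device_count())),
        dim=1,  # split along the batch axis, not the time axis
    )
else:
    outputs, hidden = rnn(inputs)  # single-GPU fallback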
Example #9
Source File: __init__.py From EDSR-PyTorch with MIT License | 4 votes |
def forward_chop(self, *args, shave=10, min_size=160000):
    scale = 1 if self.input_large else self.scale[self.idx_scale]
    n_GPUs = min(self.n_GPUs, 4)
    # height, width
    h, w = args[0].size()[-2:]

    top = slice(0, h//2 + shave)
    bottom = slice(h - h//2 - shave, h)
    left = slice(0, w//2 + shave)
    right = slice(w - w//2 - shave, w)
    x_chops = [torch.cat([
        a[..., top, left],
        a[..., top, right],
        a[..., bottom, left],
        a[..., bottom, right]
    ]) for a in args]

    y_chops = []
    if h * w < 4 * min_size:
        for i in range(0, 4, n_GPUs):
            x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
            y = P.data_parallel(self.model, *x, range(n_GPUs))
            if not isinstance(y, list):
                y = [y]
            if not y_chops:
                y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
            else:
                for y_chop, _y in zip(y_chops, y):
                    y_chop.extend(_y.chunk(n_GPUs, dim=0))
    else:
        for p in zip(*x_chops):
            y = self.forward_chop(*p, shave=shave, min_size=min_size)
            if not isinstance(y, list):
                y = [y]
            if not y_chops:
                y_chops = [[_y] for _y in y]
            else:
                for y_chop, _y in zip(y_chops, y):
                    y_chop.append(_y)

    h *= scale
    w *= scale
    top = slice(0, h//2)
    bottom = slice(h - h//2, h)
    bottom_r = slice(h//2 - h, None)
    left = slice(0, w//2)
    right = slice(w - w//2, w)
    right_r = slice(w//2 - w, None)

    # batch size, number of color channels
    b, c = y_chops[0][0].size()[:-2]
    y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]
    for y_chop, _y in zip(y_chops, y):
        _y[..., top, left] = y_chop[0][..., top, left]
        _y[..., top, right] = y_chop[1][..., top, right_r]
        _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
        _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]

    if len(y) == 1:
        y = y[0]

    return y