Python torch.nn.parallel.data_parallel() Examples

The following are 9 code examples of torch.nn.parallel.data_parallel(), drawn from open-source projects; the source file, project, and license are listed above each example. You may also want to check out the other available functions and classes of the torch.nn.parallel module.
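
Before the project examples, here is a minimal, hypothetical sketch of the basic call (not taken from any of the projects below, and assuming at least two visible CUDA devices): data_parallel replicates the module on the given device ids, scatters the inputs along dim (0 by default), runs the replicas in parallel, and gathers the outputs on output_device.

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

# Toy module and batch; assumes GPUs 0 and 1 are available.
model = nn.Linear(16, 4).cuda()
x = torch.randn(8, 16, device="cuda")

# Replicate `model` on GPUs 0 and 1, scatter `x` along dim 0,
# run the replicas in parallel, and gather the outputs on GPU 0.
y = data_parallel(model, x, device_ids=[0, 1], output_device=0)
print(y.shape)  # torch.Size([8, 4])

The functional form is convenient when the decision to parallelize is made per call, as in the examples below; for a fixed setup, wrapping the module in torch.nn.DataParallel once serves the same purpose.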
Example #1
Source File: trainer.py    From torecsys with MIT License
def apply_model(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        r"""Apply model forward.

        Args:
            inputs (Dict[str, T]): Dictionary to input to sequential.
        
        Returns:
            torch.Tensor: output of sequential.
        """
        if hasattr(self, "_base_device_ordinal"):
            base_device_ordinal = self._base_device_ordinal
        else:
            base_device_ordinal = None
        
        if self._devices is not None:
            return nn_parallel.data_parallel(self._sequential, inputs, list(self._devices), output_device=base_device_ordinal)
        else:
            return self._sequential(inputs) 
Example #2
Source File: __init__.py    From EDSR-PyTorch with MIT License
def forward(self, x, idx_scale):
        self.idx_scale = idx_scale
        if hasattr(self.model, 'set_scale'):
            self.model.set_scale(idx_scale)

        if self.training:
            if self.n_GPUs > 1:
                return P.data_parallel(self.model, x, range(self.n_GPUs))
            else:
                return self.model(x)
        else:
            # At test time, optionally process the image in overlapping
            # patches (forward_chop) to bound memory use.
            if self.chop:
                forward_function = self.forward_chop
            else:
                forward_function = self.model.forward

            # Self-ensemble: average predictions over the 8 flip/rotation
            # variants of the input (forward_x8).
            if self.self_ensemble:
                return self.forward_x8(x, forward_function=forward_function)
            else:
                return forward_function(x)
Example #3
Source File: lm.py    From espnet with Apache License 2.0
def update_core(self):
        """Update the model."""
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")
        # Progress the dataset iterator for sentences at each iteration.
        self.model.zero_grad()  # Clear the parameter gradients
        accum = {"loss": 0.0, "nll": 0.0, "count": 0}
        for _ in range(self.accum_grad):
            batch = train_iter.__next__()
            # Concatenate the token IDs to matrices and send them to the device
            # self.converter does this job
            # (it is chainer.dataset.concat_examples by default)
            x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
            if self.device[0] == -1:
                loss, nll, count = self.model(x, t)
            else:
                # apex does not support torch.nn.DataParallel
                loss, nll, count = data_parallel(self.model, (x, t), self.device)

            # backward
            loss = loss.mean() / self.accum_grad
            if self.use_apex:
                from apex import amp

                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()  # Backprop
            # accumulate stats
            accum["loss"] += float(loss)
            accum["nll"] += float(nll.sum())
            accum["count"] += int(count.sum())

        for k, v in accum.items():
            reporter.report({k: v}, optimizer.target)
        if self.gradclip is not None:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.gradclip)
        optimizer.step()  # Update the parameters
        self.scheduler.step(n_iter=self.iteration) 
Example #4
Source File: lm.py    From espnet with Apache License 2.0
def evaluate(self):
        """Evaluate the model."""
        val_iter = self.get_iterator("main")
        loss = 0
        nll = 0
        count = 0
        self.model.eval()
        with torch.no_grad():
            for batch in copy.copy(val_iter):
                x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
                if self.device[0] == -1:
                    l, n, c = self.model(x, t)
                else:
                    # apex does not support torch.nn.DataParallel
                    l, n, c = data_parallel(self.model, (x, t), self.device)
                loss += float(l.sum())
                nll += float(n.sum())
                count += int(c.sum())
        self.model.train()
        # report validation loss
        observation = {}
        with reporter.report_scope(observation):
            reporter.report({"loss": loss}, self.model.reporter)
            reporter.report({"nll": nll}, self.model.reporter)
            reporter.report({"count": count}, self.model.reporter)
        return observation 
Example #5
Source File: networks.py    From IntroVAE with MIT License
def encode(self, x):
        # Run the encoder across all available GPUs and return the latent
        # Gaussian parameters (mean and log-variance).
        mu, logvar = data_parallel(self.encoder, x)
        return mu, logvar
Example #6
Source File: networks.py    From IntroVAE with MIT License
def decode(self, z):
        # Run the decoder across all available GPUs to map the latent code
        # z back to the data space.
        y = data_parallel(self.decoder, z)
        return y
Example #7
Source File: seq2seq_base.py    From seq2seq.pytorch with MIT License
def encode(self, inputs, hidden=None, device_ids=None):
        if isinstance(device_ids, tuple):
            # Scatter along the batch dimension: dim 0 for batch-first
            # encoders, dim 1 otherwise.
            return data_parallel(self.encoder, (inputs, hidden),
                                 device_ids=device_ids,
                                 dim=0 if self.encoder.batch_first else 1)
        else:
            return self.encoder(inputs, hidden)
Example #8
Source File: seq2seq_base.py    From seq2seq.pytorch with MIT License
def decode(self, *kargs, **kwargs):
        device_ids = kwargs.pop('device_ids', None)
        if isinstance(device_ids, tuple):
            return data_parallel(self.decoder, *kargs, **kwargs,
                                 device_ids=device_ids,
                                 dim=0 if self.decoder.batch_first else 1)
        else:
            return self.decoder(*kargs, **kwargs) 
Example #9
Source File: __init__.py    From EDSR-PyTorch with MIT License
def forward_chop(self, *args, shave=10, min_size=160000):
        scale = 1 if self.input_large else self.scale[self.idx_scale]
        n_GPUs = min(self.n_GPUs, 4)
        # height, width
        h, w = args[0].size()[-2:]

        top = slice(0, h//2 + shave)
        bottom = slice(h - h//2 - shave, h)
        left = slice(0, w//2 + shave)
        right = slice(w - w//2 - shave, w)
        x_chops = [torch.cat([
            a[..., top, left],
            a[..., top, right],
            a[..., bottom, left],
            a[..., bottom, right]
        ]) for a in args]

        y_chops = []
        if h * w < 4 * min_size:
            # The quadrants are small enough: run them through the model,
            # n_GPUs patches at a time, via data_parallel.
            for i in range(0, 4, n_GPUs):
                x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
                y = P.data_parallel(self.model, *x, range(n_GPUs))
                if not isinstance(y, list): y = [y]
                if not y_chops:
                    y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
                else:
                    for y_chop, _y in zip(y_chops, y):
                        y_chop.extend(_y.chunk(n_GPUs, dim=0))
        else:
            # The quadrants are still too large: split each one recursively.
            for p in zip(*x_chops):
                y = self.forward_chop(*p, shave=shave, min_size=min_size)
                if not isinstance(y, list): y = [y]
                if not y_chops:
                    y_chops = [[_y] for _y in y]
                else:
                    for y_chop, _y in zip(y_chops, y): y_chop.append(_y)

        h *= scale
        w *= scale
        top = slice(0, h//2)
        bottom = slice(h - h//2, h)
        bottom_r = slice(h//2 - h, None)
        left = slice(0, w//2)
        right = slice(w - w//2, w)
        right_r = slice(w//2 - w, None)

        # batch size, number of color channels
        b, c = y_chops[0][0].size()[:-2]
        y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]
        for y_chop, _y in zip(y_chops, y):
            _y[..., top, left] = y_chop[0][..., top, left]
            _y[..., top, right] = y_chop[1][..., top, right_r]
            _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
            _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]

        if len(y) == 1: y = y[0]

        return y