Python torch.nn.html() Examples
The following are 30 code examples of torch.nn.html() (in practice, code that references the PyTorch `torch.nn` documentation page, nn.html).
You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch.nn, or try the search function.
Example #1
Source File: task.py From cloudml-samples with Apache License 2.0 | 6 votes |
def test(net, test_loader): """Test the DNN""" net.eval() criterion = nn.BCELoss() # https://pytorch.org/docs/stable/nn.html#bceloss test_loss = 0 correct = 0 with torch.no_grad(): for i, data in enumerate(test_loader, 0): features = data['features'] target = data['target'] output = net(features) # Binarize the output pred = output.apply_(lambda x: 0.0 if x < 0.5 else 1.0) test_loss += criterion(output, target) # sum up batch loss correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) print('\nTest set:\n\tAverage loss: {:.4f}'.format(test_loss)) print('\tAccuracy: {}/{} ({:.0f}%)\n'.format( correct, (len(test_loader) * test_loader.batch_size), 100. * correct / (len(test_loader) * test_loader.batch_size)))
Example #2
Source File: model.py From cups-rl with MIT License | 6 votes |
def __init__(self, in_features, out_features, std_init=0.5):
    """Create a noisy linear layer with learnable mu/sigma parameters.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        std_init: initial noise scale, stored on the module (presumably
            consumed by ``reset_parameters``, which is not shown here).
    """
    super(NoisyLinear, self).__init__()
    self.in_features, self.out_features = in_features, out_features
    self.std_init = std_init

    weight_shape = (out_features, in_features)
    bias_shape = (out_features,)

    # Trainable mean (mu) and scale (sigma) tensors. ``torch.empty`` leaves
    # them uninitialized; ``reset_parameters()`` below presumably fills them.
    self.weight_mu = nn.Parameter(torch.empty(*weight_shape))
    self.weight_sigma = nn.Parameter(torch.empty(*weight_shape))
    # The noise sample (epsilon) is persistent module state but must not be
    # a trainable parameter, so it is registered as a buffer (compare
    # BatchNorm's ``running_mean``). See
    # https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer
    self.register_buffer('weight_epsilon', torch.empty(*weight_shape))

    self.bias_mu = nn.Parameter(torch.empty(*bias_shape))
    self.bias_sigma = nn.Parameter(torch.empty(*bias_shape))
    self.register_buffer('bias_epsilon', torch.empty(*bias_shape))

    self.reset_parameters()
    self.reset_noise()
Example #3
Source File: layers.py From texar-pytorch with Apache License 2.0 | 6 votes |
def get_pooling_layer_hparams(hparams: Union[HParams, Dict[str, Any]]) \
        -> Dict[str, Any]:
    r"""Creates pooling layer hyperparameters `dict` for :func:`get_layer`.

    If the :attr:`hparams` sets `'kernel_size'` (inside `'kwargs'`) to `None`,
    the layer will be changed to the respective reduce-pooling layer. For
    example, :torch_docs:`torch.nn.MaxPool1d <nn.html#torch.nn.MaxPool1d>` is
    replaced with :class:`~texar.torch.core.MaxReducePool1d`, and the
    `'kernel_size'`, `'stride'` and `'padding'` kwargs are dropped.

    The input is never modified: a new top-level dict is returned, and the
    `'kwargs'` dict is copied before any key is removed from it.

    Args:
        hparams: pooling layer hyperparameters, as a dict or `HParams`.

    Returns:
        A new hyperparameter `dict`.
    """
    if isinstance(hparams, HParams):
        hparams = hparams.todict()

    new_hparams = copy.copy(hparams)
    kwargs = new_hparams.get('kwargs', None)
    if kwargs and kwargs.get('kernel_size', None) is None:
        # ``copy.copy`` above is shallow: without this second copy, the
        # ``pop`` calls below would mutate the caller's ``kwargs`` dict.
        kwargs = copy.copy(kwargs)
        new_hparams['kwargs'] = kwargs
        pool_type = hparams['type']
        new_hparams['type'] = _POOLING_TO_REDUCE.get(pool_type, pool_type)
        kwargs.pop('kernel_size', None)
        kwargs.pop('stride', None)
        kwargs.pop('padding', None)
    return new_hparams
Example #4
Source File: generative_adversarial_net.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def main(args: Namespace) -> None:
    """Build the GAN LightningModule from parsed CLI args and train it.

    Args:
        args: parsed command-line arguments; every attribute is forwarded to
            ``GAN`` as a keyword argument.
    """
    # 1. Lightning model, configured from the CLI namespace.
    hyperparams = vars(args)
    model = GAN(**hyperparams)

    # 2. Trainer with default settings.
    # For distributed training, PyTorch recommends DistributedDataParallel
    # over DataParallel; see
    # https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
    trainer = Trainer()

    # 3. Training loop.
    trainer.fit(model)
Example #5
Source File: main.py From ArtificialIntelligenceEngines with MIT License | 6 votes |
def loss_function(recon_x, x, mu, logvar):
    """Return the negative ELBO of a VAE for one batch (summed, not averaged).

    Args:
        recon_x: decoder output; probabilities in [0, 1], one row of 784
            values per sample.
        x: the input batch; reshaped to ``(-1, 784)`` to match ``recon_x``.
        mu: mean of the approximate posterior q(z|x).
        logvar: log-variance of the approximate posterior q(z|x).

    Returns:
        A scalar tensor equal to ``-(log p(x|z) - KL(q(z|x) || N(0, I)))``.
    """
    # Bernoulli log-likelihood of the reconstruction: the negated summed
    # binary cross-entropy. ``reduction='sum'`` replaces the deprecated
    # ``size_average=False``.
    # https://pytorch.org/docs/stable/generated/torch.nn.functional.binary_cross_entropy.html
    targets = x.view(-1, 784)
    log_likelihood = -F.binary_cross_entropy(recon_x, targets, reduction='sum')

    # Closed-form *negative* KL divergence between N(mu, sigma^2) and the unit
    # Gaussian (Kingma & Welling, "Auto-Encoding Variational Bayes", ICLR 2014,
    # Appendix B; https://arxiv.org/abs/1312.6114):
    #     -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    variance = logvar.exp()
    neg_kl = 0.5 * torch.sum(1 + logvar - mu.pow(2) - variance)

    # The reconstruction term pulls outputs toward the input. The KL term
    # pulls the posterior toward the unit Gaussian. Minimise -ELBO.
    elbo = log_likelihood + neg_kl
    return -elbo
Example #6
Source File: trial.py From torchbearer with MIT License | 6 votes |
def state_dict(self, **kwargs):
    """Get a dict containing the model and optimizer states, as well as the model history.

    Example: ::

        >>> from torchbearer import Trial
        >>> t = Trial(None)
        >>> state = t.state_dict() # State dict that can now be saved with torch.save

    Args:
        kwargs: Forwarded to the model's ``state_dict``. See:
            `torch.nn.Module.state_dict <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.state_dict>`_

    Returns:
        dict: A dict containing parameters and persistent buffers, keyed by
        the torchbearer VERSION, MODEL, OPTIMIZER, HISTORY and CALLBACK_LIST
        keys.
    """
    state = self.state
    # The stored version has any '.dev' marker removed.
    version = torchbearer.__version__.replace('.dev', '')
    return {
        torchbearer.VERSION: version,
        torchbearer.MODEL: state[torchbearer.MODEL].state_dict(**kwargs),
        torchbearer.OPTIMIZER: state[torchbearer.OPTIMIZER].state_dict(),
        torchbearer.HISTORY: state[torchbearer.HISTORY],
        torchbearer.CALLBACK_LIST: state[torchbearer.CALLBACK_LIST].state_dict(),
    }
Example #7
Source File: trial.py From torchbearer with MIT License | 6 votes |
def to(self, *args, **kwargs):
    """Move and/or cast the model parameters, buffers and optimizer state.

    Example: ::

        >>> from torchbearer import Trial
        >>> t = Trial(None).to('cuda:1')

    Args:
        args: See: `torch.nn.Module.to <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.to>`_
        kwargs: See: `torch.nn.Module.to <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.to>`_

    Returns:
        Trial: self
    """
    self.state[torchbearer.MODEL].to(*args, **kwargs)

    # Module.to does not touch optimizer state, so every tensor held in the
    # optimizer's per-parameter state is converted explicitly.
    optimizer_state = self.state[torchbearer.OPTIMIZER].state
    for param_state in optimizer_state.values():
        tensor_keys = [key for key, value in param_state.items() if torch.is_tensor(value)]
        for key in tensor_keys:
            param_state[key] = param_state[key].to(*args, **kwargs)

    self.state = update_device_and_dtype(self.state, *args, **kwargs)
    return self
Example #8
Source File: task.py From cloudml-samples with Apache License 2.0 | 6 votes |
def test(net, test_loader): """Test the DNN""" net.eval() criterion = nn.BCELoss() # https://pytorch.org/docs/stable/nn.html#bceloss test_loss = 0 correct = 0 with torch.no_grad(): for i, data in enumerate(test_loader, 0): features = data['features'] target = data['target'] output = net(features) # Binarize the output pred = output.apply_(lambda x: 0.0 if x < 0.5 else 1.0) test_loss += criterion(output, target) # sum up batch loss correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) total = len(test_loader) * test_loader.batch_size accuracy = 100. * correct / total return accuracy
Example #9
Source File: task.py From cloudml-samples with Apache License 2.0 | 6 votes |
def test(net, test_loader): """Test the DNN""" net.eval() criterion = nn.BCELoss() # https://pytorch.org/docs/stable/nn.html#bceloss test_loss = 0 correct = 0 with torch.no_grad(): for i, data in enumerate(test_loader, 0): features = data['features'] target = data['target'] output = net(features) # Binarize the output pred = output.apply_(lambda x: 0.0 if x < 0.5 else 1.0) test_loss += criterion(output, target) # sum up batch loss correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) print('\nTest set:\n\tAverage loss: {:.4f}'.format(test_loss)) print('\tAccuracy: {}/{} ({:.0f}%)\n'.format( correct, (len(test_loader) * test_loader.batch_size), 100. * correct / (len(test_loader) * test_loader.batch_size)))
Example #10
Source File: trial.py From torchbearer with MIT License | 5 votes |
def load_state_dict(self, state_dict, resume=True, **kwargs):
    """Resume this trial from the given state.

    Expects that this trial was constructed in the same way. Optionally, just
    load the model state when ``resume=False``.

    Example: ::

        >>> from torchbearer import Trial
        >>> t = Trial(None)
        >>> state = torch.load('some_state.pt')
        >>> t.load_state_dict(state)

    Args:
        state_dict (dict): The state dict to reload. If it has no torchbearer
            MODEL key, it is treated as a plain model state dict.
        resume (bool): If True, resume from the given state. Else, just load
            in the model weights.
        kwargs: Forwarded to the model's ``load_state_dict``. See:
            `torch.nn.Module.load_state_dict <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.load_state_dict>`_

    Returns:
        Trial: self
    """
    model = self.state[torchbearer.MODEL]

    # Not produced by Trial.state_dict: hand the whole thing to the model.
    if torchbearer.MODEL not in state_dict:
        warnings.warn('Not a torchbearer state dict, passing to model')
        model.load_state_dict(state_dict, **kwargs)
        return self

    # Torchbearer state dict, but only the model weights are wanted.
    if not resume:
        model.load_state_dict(state_dict[torchbearer.MODEL], **kwargs)
        return self

    # Full resume: warn on a version mismatch, then restore whatever is present.
    current_version = torchbearer.__version__.replace('.dev', '')
    if torchbearer.VERSION in state_dict and state_dict[torchbearer.VERSION] != current_version:
        warnings.warn('This state dict was saved with a different torchbearer version, loading available keys. Consider setting resume=False')

    model.load_state_dict(state_dict[torchbearer.MODEL], **kwargs)
    if torchbearer.OPTIMIZER in state_dict:
        self.state[torchbearer.OPTIMIZER].load_state_dict(state_dict[torchbearer.OPTIMIZER])
    if torchbearer.HISTORY in state_dict:
        self.state[torchbearer.HISTORY] = state_dict[torchbearer.HISTORY]
    if torchbearer.CALLBACK_LIST in state_dict:
        self.state[torchbearer.CALLBACK_LIST].load_state_dict(state_dict[torchbearer.CALLBACK_LIST])
    return self
Example #11
Source File: resnet.py From Teacher-free-Knowledge-Distillation with MIT License | 5 votes |
def loss_fn(outputs, labels):
    """Return the mean cross-entropy loss of ``outputs`` against ``labels``.

    Args:
        outputs: unnormalised class scores (logits), e.g. shape
            ``(batch_size, num_classes)``; log-softmax is applied inside the loss.
        labels: integer class indices, e.g. shape ``(batch_size,)``.

    Returns:
        A scalar tensor: the cross-entropy averaged over the batch.

    Note: this wraps a standard PyTorch loss
    (https://pytorch.org/docs/stable/nn.html#loss-functions); replace it to
    experiment with a custom loss function.
    """
    return nn.functional.cross_entropy(outputs, labels)
Example #12
Source File: wrn.py From Teacher-free-Knowledge-Distillation with MIT License | 5 votes |
def loss_fn(outputs, labels):
    """Return the mean cross-entropy loss of ``outputs`` against ``labels``.

    Args:
        outputs: unnormalised class scores (logits), e.g. shape
            ``(batch_size, num_classes)``; log-softmax is applied inside the loss.
        labels: integer class indices, e.g. shape ``(batch_size,)``.

    Returns:
        A scalar tensor: the cross-entropy averaged over the batch.

    Note: this wraps a standard PyTorch loss
    (https://pytorch.org/docs/stable/nn.html#loss-functions); replace it to
    experiment with a custom loss function.
    """
    return nn.functional.cross_entropy(outputs, labels)
Example #13
Source File: resnext.py From Teacher-free-Knowledge-Distillation with MIT License | 5 votes |
def loss_fn(outputs, labels):
    """Return the mean cross-entropy loss of ``outputs`` against ``labels``.

    Args:
        outputs: unnormalised class scores (logits), e.g. shape
            ``(batch_size, num_classes)``; log-softmax is applied inside the loss.
        labels: integer class indices, e.g. shape ``(batch_size,)``.

    Returns:
        A scalar tensor: the cross-entropy averaged over the batch.

    Note: this wraps a standard PyTorch loss
    (https://pytorch.org/docs/stable/nn.html#loss-functions); replace it to
    experiment with a custom loss function.
    """
    return nn.functional.cross_entropy(outputs, labels)
Example #14
Source File: auto.py From ignite with BSD 3-Clause "New" or "Revised" License | 5 votes |
def auto_optim(optimizer: Optimizer) -> Optimizer:
    """Adapt an optimizer for non-distributed and distributed configurations.

    Supports all backends from
    :meth:`~ignite.distributed.utils.available_backends()`.

    For non-distributed and native torch distributed configurations this is a
    no-op, and ``optimizer`` is returned unchanged. For the XLA TPU
    configuration, a new class is created that inherits from the provided
    optimizer's class and takes its attributes from
    ``_XLADistributedOptimizer``. The goal is to override the ``step()``
    method with the `xm.optimizer_step`_ implementation.

    Examples:

    .. code-block:: python

        import ignite.distributed as idist

        optimizer = idist.auto_optim(optimizer)

    Args:
        optimizer (Optimizer): input torch optimizer

    Returns:
        Optimizer

    .. _xm.optimizer_step: http://pytorch.org/xla/release/1.5/index.html#torch_xla.core.xla_model.optimizer_step
    """
    on_xla_tpu = idist.has_xla_support and idist.backend() == idist_xla.XLA_TPU
    if not on_xla_tpu:
        return optimizer

    base_cls = optimizer.__class__
    namespace = dict(_XLADistributedOptimizer.__dict__)
    xla_cls = type(base_cls.__name__, (base_cls,), namespace)
    return xla_cls(optimizer)
Example #15
Source File: word_embedding.py From claf with MIT License | 5 votes |
def __init__(
    self,
    vocab,
    dropout=0.2,
    embed_dim=100,
    padding_idx=None,
    max_norm=None,
    norm_type=2,
    scale_grad_by_freq=False,
    sparse=False,
    pretrained_path=None,
    trainable=True,
):
    """Word-level embedding with optional dropout and pretrained vectors.

    Args:
        vocab: vocabulary passed to the base class constructor.
        dropout: dropout probability; a falsy or non-positive value disables
            dropout (an identity function is used instead).
        embed_dim: embedding dimension, stored as ``self.embed_dim``.
        padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse: stored
            as attributes. They mirror the optional arguments of
            ``torch.nn.functional.embedding`` (presumably passed through in
            ``forward``, which is not shown here).
        pretrained_path: if truthy, the weight matrix is read from this file.
        trainable: ``requires_grad`` for a pretrained weight; otherwise
            forwarded to ``_init_weight``.
    """
    super(WordEmbedding, self).__init__(vocab)
    self.data_handler = DataHandler(cache_path=CachePath.PRETRAINED_VECTOR)
    self.embed_dim = embed_dim

    self.dropout = nn.Dropout(p=dropout) if dropout and dropout > 0 else (lambda x: x)

    if not pretrained_path:
        self.weight = self._init_weight(trainable=trainable)
    else:
        pretrained_weight = self._read_pretrained_file(pretrained_path)
        self.weight = torch.nn.Parameter(pretrained_weight, requires_grad=trainable)

    # Optional nn.functional.embedding parameters; see
    # https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
    self.padding_idx = padding_idx
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.scale_grad_by_freq = scale_grad_by_freq
    self.sparse = sparse
Example #16
Source File: util.py From DeepLab_v3_plus with MIT License | 5 votes |
def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True):
    """Per-pixel cross-entropy loss for semantic segmentation.

    Args:
        logit: raw network output of shape ``(N, C, H, W)`` (e.g.
            ``(batch, 21, 512, 512)``). Values are unnormalised scores:
            ``nn.CrossEntropyLoss`` applies log-softmax itself, so the
            network's last layer must not apply softmax.
        target: ground truth of shape ``(N, 1, H, W)``. Use 0 for background,
            1..C-1 for the object classes, and ``ignore_index`` for pixels
            that must not contribute to the loss (e.g. objects that are
            neither background nor a dataset class).
        ignore_index: target value excluded from the loss.
        weight: optional per-class weights (any sequence/array/tensor of ``C``
            numbers). They are placed on ``logit``'s device and dtype.
        size_average: if True, divide the summed loss by ``H * W``.
        batch_average: if True, additionally divide by ``N``.

    Returns:
        Scalar loss tensor. With both averaging flags set, this is the summed
        per-pixel loss divided by ``N * H * W``.

    See https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
    """
    n, _, h, w = logit.size()
    # (N, 1, H, W) -> (N, H, W): CrossEntropyLoss expects class indices
    # without a channel dimension.
    target = target.squeeze(1)

    if weight is not None:
        # Follow the input's device instead of hard-coding ``.cuda()``, so the
        # loss also works on CPU and on whichever GPU holds ``logit``.
        weight = torch.as_tensor(weight, dtype=logit.dtype, device=logit.device)
    # ``reduction='sum'`` replaces the deprecated ``size_average=False``.
    criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction='sum')
    loss = criterion(logit, target.long())

    if size_average:
        loss /= (h * w)
    if batch_average:
        loss /= n
    return loss
Example #17
Source File: resnet.py From knowledge-distillation-pytorch with MIT License | 5 votes |
def loss_fn(outputs, labels):
    """Return the mean cross-entropy loss of ``outputs`` against ``labels``.

    Args:
        outputs: unnormalised class scores (logits), e.g. shape
            ``(batch_size, num_classes)``; log-softmax is applied inside the loss.
        labels: integer class indices, e.g. shape ``(batch_size,)``.

    Returns:
        A scalar tensor: the cross-entropy averaged over the batch.

    Note: this wraps a standard PyTorch loss
    (https://pytorch.org/docs/stable/nn.html#loss-functions); replace it to
    experiment with a custom loss function.
    """
    return nn.functional.cross_entropy(outputs, labels)
Example #18
Source File: preresnet.py From knowledge-distillation-pytorch with MIT License | 5 votes |
def loss_fn(outputs, labels):
    """Return the mean cross-entropy loss of ``outputs`` against ``labels``.

    Args:
        outputs: unnormalised class scores (logits), e.g. shape
            ``(batch_size, num_classes)``; log-softmax is applied inside the loss.
        labels: integer class indices, e.g. shape ``(batch_size,)``.

    Returns:
        A scalar tensor: the cross-entropy averaged over the batch.

    Note: this wraps a standard PyTorch loss
    (https://pytorch.org/docs/stable/nn.html#loss-functions); replace it to
    experiment with a custom loss function.
    """
    return nn.functional.cross_entropy(outputs, labels)
Example #19
Source File: wrn.py From knowledge-distillation-pytorch with MIT License | 5 votes |
def loss_fn(outputs, labels):
    """Return the mean cross-entropy loss of ``outputs`` against ``labels``.

    Args:
        outputs: unnormalised class scores (logits), e.g. shape
            ``(batch_size, num_classes)``; log-softmax is applied inside the loss.
        labels: integer class indices, e.g. shape ``(batch_size,)``.

    Returns:
        A scalar tensor: the cross-entropy averaged over the batch.

    Note: this wraps a standard PyTorch loss
    (https://pytorch.org/docs/stable/nn.html#loss-functions); replace it to
    experiment with a custom loss function.
    """
    return nn.functional.cross_entropy(outputs, labels)
Example #20
Source File: net.py From knowledge-distillation-pytorch with MIT License | 5 votes |
def loss_fn(outputs, labels):
    """Return the mean cross-entropy loss of ``outputs`` against ``labels``.

    Args:
        outputs: unnormalised class scores (logits), e.g. shape
            ``(batch_size, num_classes)``; log-softmax is applied inside the loss.
        labels: integer class indices, e.g. shape ``(batch_size,)``.

    Returns:
        A scalar tensor: the cross-entropy averaged over the batch.

    Note: this wraps a standard PyTorch loss
    (https://pytorch.org/docs/stable/nn.html#loss-functions); replace it to
    experiment with a custom loss function.
    """
    return nn.functional.cross_entropy(outputs, labels)
Example #21
Source File: resnext.py From knowledge-distillation-pytorch with MIT License | 5 votes |
def loss_fn(outputs, labels):
    """Return the mean cross-entropy loss of ``outputs`` against ``labels``.

    Args:
        outputs: unnormalised class scores (logits), e.g. shape
            ``(batch_size, num_classes)``; log-softmax is applied inside the loss.
        labels: integer class indices, e.g. shape ``(batch_size,)``.

    Returns:
        A scalar tensor: the cross-entropy averaged over the batch.

    Note: this wraps a standard PyTorch loss
    (https://pytorch.org/docs/stable/nn.html#loss-functions); replace it to
    experiment with a custom loss function.
    """
    return nn.functional.cross_entropy(outputs, labels)
Example #22
Source File: layers.py From texar-pytorch with Apache License 2.0 | 5 votes |
def get_initializer(hparams=None) \
        -> Optional[Callable[[torch.Tensor], torch.Tensor]]:
    r"""Returns an initializer instance.

    Args:
        hparams (dict or HParams, optional): Hyperparameters with the
            structure

            .. code-block:: python

                {
                    "type": "initializer_class_or_function",
                    "kwargs": {
                        # ...
                    }
                }

            The `"type"` field can be a function name or module path. If a
            bare name is provided, it must come from one of the following
            modules: :torch_docs:`torch.nn.init <nn.html#torch-nn-init>`,
            :mod:`torch`, and :mod:`texar.torch.custom`.

            The `"type"` field can also be an initialization function called
            with :python:`initialization_fn(**kwargs)`. In this case `"type"`
            can be the function, or its name or module path. If no keyword
            argument is required, `"kwargs"` can be omitted.

    Returns:
        An initializer instance: the resolved function with `"kwargs"` bound
        via :func:`functools.partial`. `None` if :attr:`hparams` is `None`.
    """
    if hparams is None:
        return None

    init_kwargs = hparams.get('kwargs', {})
    if isinstance(init_kwargs, HParams):
        init_kwargs = init_kwargs.todict()

    # Candidate modules handed to utils.get_function to resolve a bare name.
    search_modules = ['torch.nn.init', 'torch', 'texar.torch.custom']
    init_fn = utils.get_function(hparams['type'], search_modules)
    return functools.partial(init_fn, **init_kwargs)
Example #23
Source File: lstm.py From online-normalization with BSD 3-Clause "New" or "Revised" License | 5 votes |
def __init__(self, input_size, hidden_size, bias=True, norm=None, cell_norm=True, **kwargs):
    """LSTM cell with optional Layer Norm or Online Norm modules.

    Args:
        input_size: number of input features.
        hidden_size: number of hidden features (also the size of each norm).
        bias: forwarded to the base cell constructor.
        norm: ``None``/empty for no normalization. Otherwise a string whose
            first letter (case-insensitive) selects the scheme: ``'l...'``
            for LayerNorm, ``'o...'`` for Online Norm. Any other value leaves
            ``self.norms`` as ``None``.
        cell_norm: when ``norm`` is set, allocate a fifth norm module
            (presumably for the cell state) in addition to four others.
        **kwargs: for Online Norm, must provide ``batch_size``,
            ``alpha_fwd``, ``alpha_bkw`` and ``ecm``.
    """
    super(NormLSTMCell, self).__init__(input_size=input_size, hidden_size=hidden_size, bias=bias, num_chunks=4)
    self.reset_parameters()
    self.cell_norm = norm and cell_norm
    num_norms = 5 if self.cell_norm else 4
    self.norms = None
    if not norm:
        warnings.warn('LSTMCell w/out LayerNorm see Pytorch\'s LSTMCell: '
                      'https://pytorch.org/docs/stable/nn.html#lstmcell')
    else:
        # Only index into ``norm`` once it is known to be non-empty; the
        # previous independent ``if`` raised TypeError for ``norm=None``.
        scheme = norm[0].lower()
        if scheme == 'l':
            warnings.warn('Using Layer Norm in LSTMCell')
            # NOTE(review): a plain list does not register submodules with
            # nn.Module; presumably set_norm_modules() below handles that.
            self.norms = [nn.LayerNorm(hidden_size) for _ in range(num_norms)]
        elif scheme == 'o':
            warnings.warn('Using Online Norm in LSTMCell')
            self.norms = [
                OnlineNorm1d(hidden_size, batch_size=kwargs['batch_size'],
                             alpha_fwd=kwargs['alpha_fwd'],
                             alpha_bkw=kwargs['alpha_bkw'],
                             ecm=kwargs['ecm'])
                for _ in range(num_norms)]
    self.reset_norm_parameters()
    self.set_norm_modules()
Example #24
Source File: nce_loss.py From Pytorch-NCE with MIT License | 5 votes |
def nce_loss(self, logit_target_in_model, logit_noise_in_model, logit_noise_in_noise, logit_target_in_noise):
    """Compute the NCE classification loss given all four sets of logits.

    Args:
        - logit_target_in_model: logit of target words given by the model (RNN)
        - logit_noise_in_model: logit of noise words given by the model
        - logit_noise_in_noise: logit of noise words given by the noise distribution
        - logit_target_in_noise: logit of target words given by the noise distribution

    The target logits are unsqueezed at dim 2 and concatenated in front of
    the noise logits, so target tensors have one dimension fewer than noise
    tensors.

    Returns:
        - loss: a mis-classification loss for every single case, summed over
          the target and noise candidates (the last dimension). Assumes
          ``self.bce_with_logits`` uses ``reduction='none'``.
    """
    # NOTE: prob <= 1 is not guaranteed (the model scores need not be normalized).
    # Candidate 0 along dim 2 is the real target; the rest are noise samples.
    model_logits = torch.cat((logit_target_in_model.unsqueeze(2), logit_noise_in_model), dim=2)
    noise_logits = torch.cat((logit_target_in_noise.unsqueeze(2), logit_noise_in_noise), dim=2)

    # Posterior that a word came from the data distribution:
    #     p_true = exp(m) / (exp(m) + k * exp(n)) = sigmoid(m - n - log k)
    # Feeding the logit (m - n - log k) to BCE-with-logits is numerically
    # stabler than forming p_true explicitly. See
    # https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
    true_logits = model_logits - noise_logits - math.log(self.noise_ratio)

    labels = torch.zeros_like(model_logits)
    labels[:, :, 0] = 1  # only the first candidate comes from the data

    return self.bce_with_logits(true_logits, labels).sum(dim=2)
Example #25
Source File: engine.py From NeuralDialog-LaRL with Apache License 2.0 | 5 votes |
def __init__(self, dictionary, device_id=None, bad_toks=(), reduction='mean'):
    """Build a cross-entropy criterion that gives zero weight to "bad" tokens.

    Args:
        dictionary: vocabulary. ``len(dictionary)`` gives the number of
            classes, and ``dictionary.get_idx(tok)`` maps a token to its index.
        device_id: if not None, the class-weight vector is moved to this
            CUDA device.
        bad_toks: tokens whose class weight is set to 0. Defaults to an
            empty immutable tuple rather than a shared mutable list.
        reduction: reduction mode for ``nn.CrossEntropyLoss``. The deprecated
            alias ``'elementwise_mean'`` (the former default) is accepted and
            mapped to the equivalent ``'mean'``.
    """
    # One weight per vocabulary entry; bad tokens are zeroed out.
    w = torch.ones(len(dictionary))
    for tok in bad_toks:
        w[dictionary.get_idx(tok)] = 0.0
    if device_id is not None:
        w = w.cuda(device_id)
    if reduction == 'elementwise_mean':
        reduction = 'mean'
    # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
    self.crit = nn.CrossEntropyLoss(w, reduction=reduction)
Example #26
Source File: simple_cnn.py From habitat-api with MIT License | 5 votes |
def _conv_output_dim(
    self, dimension, padding, dilation, kernel_size, stride
):
    r"""Calculates the output height and width based on the input height and
    width to the convolution layer.

    Applies the PyTorch ``Conv2d`` output-size formula independently to each
    of the two spatial axes:

        out = floor((in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

    ref: https://pytorch.org/docs/master/nn.html#torch.nn.Conv2d

    Returns:
        tuple of two ints: the output size along each axis.
    """
    assert len(dimension) == 2

    def _axis_output(axis):
        span = (
            dimension[axis]
            + 2 * padding[axis]
            - dilation[axis] * (kernel_size[axis] - 1)
            - 1
        )
        return int(np.floor(span / stride[axis] + 1))

    return tuple(_axis_output(axis) for axis in range(len(dimension)))
Example #27
Source File: __init__.py From tatk with Apache License 2.0 | 5 votes |
def __init__(self, dictionary, device_id=None, bad_toks=(), reduction='mean'):
    """Build a cross-entropy criterion that gives zero weight to "bad" tokens.

    Args:
        dictionary: vocabulary. ``len(dictionary)`` gives the number of
            classes, and ``dictionary.get_idx(tok)`` maps a token to its index.
        device_id: if not None, the class-weight vector is moved to this
            CUDA device.
        bad_toks: tokens whose class weight is set to 0. Defaults to an
            empty immutable tuple rather than a shared mutable list.
        reduction: reduction mode for ``nn.CrossEntropyLoss``. The deprecated
            alias ``'elementwise_mean'`` is accepted and mapped to the
            equivalent ``'mean'``.
    """
    # One weight per vocabulary entry; bad tokens are zeroed out.
    w = torch.ones(len(dictionary))
    for tok in bad_toks:
        w[dictionary.get_idx(tok)] = 0.0
    if device_id is not None:
        w = w.cuda(device_id)
    if reduction == 'elementwise_mean':
        reduction = 'mean'
    # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
    self.crit = nn.CrossEntropyLoss(w, reduction=reduction)
Example #28
Source File: engine.py From tatk with Apache License 2.0 | 5 votes |
def __init__(self, dictionary, device_id=None, bad_toks=(), reduction='mean'):
    """Build a cross-entropy criterion that gives zero weight to "bad" tokens.

    Args:
        dictionary: vocabulary. ``len(dictionary)`` gives the number of
            classes, and ``dictionary.get_idx(tok)`` maps a token to its index.
        device_id: if not None, the class-weight vector is moved to this
            CUDA device.
        bad_toks: tokens whose class weight is set to 0. Defaults to an
            empty immutable tuple rather than a shared mutable list.
        reduction: reduction mode for ``nn.CrossEntropyLoss``. The deprecated
            alias ``'elementwise_mean'`` is accepted and mapped to the
            equivalent ``'mean'``.
    """
    # One weight per vocabulary entry; bad tokens are zeroed out.
    w = torch.ones(len(dictionary))
    for tok in bad_toks:
        w[dictionary.get_idx(tok)] = 0.0
    if device_id is not None:
        w = w.cuda(device_id)
    if reduction == 'elementwise_mean':
        reduction = 'mean'
    # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
    self.crit = nn.CrossEntropyLoss(w, reduction=reduction)
Example #29
Source File: main.py From ArtificialIntelligenceEngines with MIT License | 5 votes |
def __init__(self):
    """Define two 5x5 conv layers followed by two fully connected layers.

    Layer references:
    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html and
    https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
    """
    super(Net, self).__init__()
    # Convolutions: 1 -> 20 -> 50 channels, 5x5 kernels, stride 1.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
    self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
    # Classifier: 4*4*50 flattened features -> 500 -> 10 outputs. The 4x4
    # spatial size presumably comes from a 28x28 input after two conv +
    # 2x2-pool stages in forward(), which is not shown here.
    self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
    self.fc2 = nn.Linear(in_features=500, out_features=10)
Example #30
Source File: upsample.py From MONAI with Apache License 2.0 | 5 votes |
def __init__(
    self,
    spatial_dims: int,
    in_channels: int,
    out_channels: Optional[int] = None,
    scale_factor=2,
    with_conv: bool = False,
    mode: Union[UpsampleMode, str] = UpsampleMode.LINEAR,
    align_corners: Optional[bool] = True,
):
    """
    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of channels of the input image.
        out_channels: number of channels of the output image. Defaults to
            `in_channels` (also used when a falsy value is passed).
        scale_factor: multiplier for spatial size. Has to match input size if
            it is a tuple. Defaults to 2.
        with_conv: if True, upsample with a transposed convolution whose
            kernel size and stride both equal `scale_factor`. If False
            (default), apply a 1x1 convolution followed by `torch.nn.Upsample`.
        mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
            The interpolation mode, used only when `with_conv` is False.
            If it ends with ``"linear"``, `spatial_dims` determines the actual
            mode: linear, bilinear or trilinear for 1D, 2D and 3D
            respectively. Defaults to ``"linear"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html
        align_corners: set the align_corners parameter of `torch.nn.Upsample`.
            Defaults to True.
    """
    super().__init__()
    out_channels = out_channels or in_channels

    if with_conv:
        self.upsample = Conv[Conv.CONVTRANS, spatial_dims](
            in_channels=in_channels, out_channels=out_channels, kernel_size=scale_factor, stride=scale_factor
        )
    else:
        mode = UpsampleMode(mode)
        linear_family = (UpsampleMode.LINEAR, UpsampleMode.BILINEAR, UpsampleMode.TRILINEAR)
        if mode in linear_family:
            # Pick the linear variant that matches the number of spatial dims.
            mode = linear_family[spatial_dims - 1]
        self.upsample = nn.Sequential(
            Conv[Conv.CONV, spatial_dims](in_channels=in_channels, out_channels=out_channels, kernel_size=1),
            nn.Upsample(scale_factor=scale_factor, mode=mode.value, align_corners=align_corners),
        )