Python torch.nn.Identity() Examples
The following are 30 code examples of torch.nn.Identity().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch.nn, or try the search function.
Example #1
Source File: exercise1_2.py From spinningup with MIT License | 7 votes |
def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a multi-layer perceptron in PyTorch.

    Args:
        sizes: Tuple, list, or other iterable giving the number of units for
            each layer of the MLP, input layer first and output layer last.
        activation: Activation module class (e.g. ``nn.Tanh``), instantiated
            after every linear layer except the last.
        output_activation: Activation module class instantiated after the
            last linear layer. Defaults to ``nn.Identity`` (no activation).

    Returns:
        An ``nn.Sequential`` alternating ``nn.Linear`` layers and activation
        modules, which can be called to give the output of the MLP.

    Raises:
        ValueError: if ``sizes`` has fewer than two entries, so there is no
            linear layer to build.
    """
    sizes = list(sizes)  # accept any iterable, including one-shot generators
    if len(sizes) < 2:
        raise ValueError(f"mlp needs at least two layer sizes, got {sizes!r}")
    last = len(sizes) - 2  # index of the final linear layer
    layers = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act = activation if idx < last else output_activation
        layers += [nn.Linear(n_in, n_out), act()]
    return nn.Sequential(*layers)
Example #2
Source File: visual_backbones.py From virtex with MIT License | 6 votes |
def __init__(
    self,
    name: str = "resnet50",
    visual_feature_size: int = 2048,
    pretrained: bool = False,
    frozen: bool = False,
):
    """Wrap a torchvision model as a visual backbone.

    Args:
        name: Attribute of ``torchvision.models`` to instantiate.
        visual_feature_size: Forwarded to the parent initializer.
        pretrained: Passed positionally to the torchvision constructor.
        frozen: When true, disables gradients for every CNN parameter and
            switches the CNN to eval mode.
    """
    super().__init__(visual_feature_size)
    build_model = getattr(torchvision.models, name)
    self.cnn = build_model(pretrained, zero_init_residual=True)

    # Replace the classification head with a pass-through so nothing runs
    # after the final residual stage.
    self.cnn.fc = nn.Identity()

    if frozen:
        for weight in self.cnn.parameters():
            weight.requires_grad = False
        self.cnn.eval()

    # Names of the intermediate stages: layer1 .. layer4.
    self._stage_names = ["layer{}".format(idx) for idx in range(1, 5)]
Example #3
Source File: polynomial.py From torchsupport with MIT License | 6 votes |
def __init__(self, in_size, out_size, hidden_size=128, depth=2,
             input_kwargs=None, internal_kwargs=None):
    """Build the layers of a ``depth``-stage polynomial network.

    Args:
        in_size: Input width, passed to every input block.
        out_size: Output width of the final block and constant.
        hidden_size: Width used by the input, internal and constant layers.
        depth: Number of input blocks / constants. There are ``depth - 1``
            internal blocks, preceded by an ``nn.Identity``.
        input_kwargs: Extra keyword arguments for ``self.make_block`` when
            building input blocks. ``None`` means no extras.
        internal_kwargs: Extra keyword arguments for ``self.make_block`` when
            building internal and output blocks. ``None`` means no extras.
    """
    super().__init__()
    # ``None`` defaults would otherwise fail the ``**`` unpacking below
    # with a TypeError.
    input_kwargs = {} if input_kwargs is None else input_kwargs
    internal_kwargs = {} if internal_kwargs is None else internal_kwargs

    self.depth = depth
    self.input_blocks = nn.ModuleList([
        self.make_block(in_size, hidden_size, **input_kwargs)
        for _ in range(depth)
    ])
    # The first internal slot is a pass-through; the remaining depth - 1
    # slots are real hidden -> hidden blocks.
    self.internal_blocks = nn.ModuleList([nn.Identity()] + [
        self.make_block(hidden_size, hidden_size, **internal_kwargs)
        for _ in range(depth - 1)
    ])
    self.internal_constants = nn.ParameterList([
        self.make_constant(hidden_size) for _ in range(depth)
    ])
    self.output_block = self.make_block(hidden_size, out_size, **internal_kwargs)
    self.output_constant = self.make_constant(out_size)
Example #4
Source File: efficientnet.py From segmentation_models.pytorch with MIT License | 6 votes |
def forward(self, x):
    """Run ``x`` through the first ``self._depth + 1`` stages, collecting outputs.

    Stages 0 and 1 are called as ``stage(x)``. Each later stage is iterated,
    and every module in it is called as ``module(x, drop_connect)``. Here
    ``drop_connect = drop_connect_rate * block_index / len(self._blocks)``,
    and ``block_index`` counts modules across all block stages, starting at 0.

    Returns:
        A list holding the output of each stage, in order.
    """
    stages = self.get_stages()
    rate = self._global_params.drop_connect_rate
    seen_blocks = 0.
    outputs = []
    for idx in range(self._depth + 1):
        stage = stages[idx]
        if idx >= 2:
            # Block stages: each module also receives its drop-connect rate.
            for module in stage:
                drop_connect = rate * seen_blocks / len(self._blocks)
                seen_blocks += 1.
                x = module(x, drop_connect)
        else:
            # Identity / Sequential stages take only the tensor.
            x = stage(x)
        outputs.append(x)
    return outputs
Example #5
Source File: modules.py From segmentation_models.pytorch with MIT License | 6 votes |
def __init__(self, name, **params):
    """Select the activation module stored on ``self.activation``.

    Args:
        name: ``None`` or ``'identity'``, one of ``'sigmoid'``, ``'softmax2d'``,
            ``'softmax'``, ``'logsoftmax'``, ``'argmax'``, ``'argmax2d'``, or
            any callable (called as ``name(**params)``).
        **params: Keyword arguments forwarded to the chosen constructor
            (not used for ``'sigmoid'``).

    Raises:
        ValueError: for any other ``name``.
    """
    super().__init__()
    # Factories for the named string activations; the "2d" variants pin dim=1.
    named = {
        'sigmoid': lambda: nn.Sigmoid(),
        'softmax2d': lambda: nn.Softmax(dim=1, **params),
        'softmax': lambda: nn.Softmax(**params),
        'logsoftmax': lambda: nn.LogSoftmax(**params),
        'argmax': lambda: ArgMax(**params),
        'argmax2d': lambda: ArgMax(dim=1, **params),
    }
    if name is None or name == 'identity':
        self.activation = nn.Identity(**params)
    elif isinstance(name, str) and name in named:
        self.activation = named[name]()
    elif callable(name):
        self.activation = name(**params)
    else:
        raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/None; got {}'.format(name))
Example #6
Source File: encoder_decoder.py From batchflow with Apache License 2.0 | 6 votes |
def body(cls, inputs, **kwargs):
    """Assemble encoder -> embedding -> decoder as a named ``nn.Sequential``.

    Each stage is built by the matching ``cls`` factory. Its keyword
    arguments are the shared ``kwargs`` overlaid with that stage's own config
    dict, and its ``inputs`` is the previous stage's output. An embedding
    config of ``None`` substitutes ``nn.Identity``.
    """
    kwargs = cls.get_defaults('body', kwargs)
    encoder_cfg = kwargs.pop('encoder')
    embedding_cfg = kwargs.pop('embedding')
    decoder_cfg = kwargs.pop('decoder')

    encoder = cls.encoder(inputs=inputs, **{**kwargs, **encoder_cfg})
    features = encoder(inputs)

    if embedding_cfg is None:
        embedding = nn.Identity()
    else:
        embedding = cls.embedding(inputs=features, **{**kwargs, **embedding_cfg})
    features = embedding(features)

    decoder = cls.decoder(inputs=features, **{**kwargs, **decoder_cfg})

    return nn.Sequential(OrderedDict([
        ('encoder', encoder),
        ('embedding', embedding),
        ('decoder', decoder),
    ]))
Example #7
Source File: efficientnet.py From convNet.pytorch with MIT License | 6 votes |
def __init__(self, in_channels, out_channels, expansion=1, kernel_size=3,
             stride=1, padding=1, se_ratio=0.25, hard_act=False):
    """Build an MBConv block: expand -> grouped conv -> SE -> project -> BN.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        expansion: Multiplier giving the hidden width ``in_channels * expansion``.
            The 1x1 expansion layer is skipped when that equals ``in_channels``.
        kernel_size, stride, padding: Settings of the grouped
            (``groups=expanded``, presumably depthwise) conv.
        se_ratio: SE bottleneck is ``int(in_channels * se_ratio)``; SE is
            skipped when ``se_ratio <= 0``.
        hard_act: Forwarded to the ConvBNAct / SESwishBlock layers.
    """
    expanded = in_channels * expansion
    super(MBConv, self).__init__()
    # A residual add is only possible when the block preserves its shape.
    self.add_res = stride == 1 and in_channels == out_channels

    if expanded != in_channels:
        expand = ConvBNAct(in_channels, expanded, 1, hard_act=hard_act)
    else:
        expand = nn.Identity()
    grouped = ConvBNAct(expanded, expanded, kernel_size, stride=stride,
                        padding=padding, groups=expanded, hard_act=hard_act)
    if se_ratio > 0:
        squeeze = SESwishBlock(expanded, expanded, int(in_channels * se_ratio),
                               hard_act=hard_act)
    else:
        squeeze = nn.Identity()
    project = nn.Conv2d(expanded, out_channels, 1, bias=False)
    norm = nn.BatchNorm2d(out_channels)

    self.block = nn.Sequential(expand, grouped, squeeze, project, norm)
    self.drop_prob = 0
Example #8
Source File: dla.py From pytorch-image-models with Apache License 2.0 | 5 votes |
def reset_classifier(self, num_classes, global_pool='avg'):
    """Replace the global pooling layer and the classifier head.

    Args:
        num_classes: Number of output classes. A falsy value (e.g. 0) makes
            ``self.fc`` an ``nn.Identity``.
        global_pool: ``pool_type`` for ``SelectAdaptivePool2d``.
    """
    self.num_classes = num_classes
    self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
    if not num_classes:
        self.fc = nn.Identity()
        return
    # The pooled feature width depends on the pooling mode's multiplier.
    pooled_width = self.num_features * self.global_pool.feat_mult()
    self.fc = nn.Conv2d(pooled_width, num_classes, kernel_size=1, bias=True)
Example #9
Source File: modeling_utils.py From NLP_Toolkit with Apache License 2.0 | 5 votes |
def __init__(self, config):
    """Set up the summary head's sub-modules from optional ``config`` attributes.

    Attributes read from ``config``:
        summary_type: Stored on ``self.summary_type``; defaults to ``'last'``.
        summary_use_proj: When truthy, ``self.summary`` is an ``nn.Linear``
            from ``hidden_size`` to ``num_labels`` (if
            ``summary_proj_to_labels`` is truthy and ``num_labels > 0``),
            otherwise to ``hidden_size``.
        summary_activation: ``'tanh'`` selects ``nn.Tanh``.
        summary_first_dropout / summary_last_dropout: When > 0, an
            ``nn.Dropout`` with that probability.
    Every sub-module that is not enabled is an ``nn.Identity``.

    Raises:
        NotImplementedError: if ``summary_type`` is ``'attn'``.
    """
    super(SequenceSummary, self).__init__()

    # Read ``summary_type`` itself, defaulting to 'last'. Guarding it on
    # ``summary_use_proj`` raised AttributeError for configs that set
    # ``summary_use_proj`` but not ``summary_type``, and ignored a configured
    # ``summary_type`` when ``summary_use_proj`` was absent.
    self.summary_type = getattr(config, 'summary_type', 'last')
    if self.summary_type == 'attn':
        # We should use a standard multi-head attention module with absolute positional embedding for that.
        # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
        # We can probably just use the multi-head attention module of PyTorch >=1.1.0
        raise NotImplementedError

    self.summary = nn.Identity()
    if getattr(config, 'summary_use_proj', False):
        if getattr(config, 'summary_proj_to_labels', False) and config.num_labels > 0:
            num_classes = config.num_labels
        else:
            num_classes = config.hidden_size
        self.summary = nn.Linear(config.hidden_size, num_classes)

    self.activation = nn.Identity()
    if getattr(config, 'summary_activation', None) == 'tanh':
        self.activation = nn.Tanh()

    self.first_dropout = nn.Identity()
    if getattr(config, 'summary_first_dropout', 0) > 0:
        self.first_dropout = nn.Dropout(config.summary_first_dropout)

    self.last_dropout = nn.Identity()
    if getattr(config, 'summary_last_dropout', 0) > 0:
        self.last_dropout = nn.Dropout(config.summary_last_dropout)
Example #10
Source File: modeling_utils.py From NLP_Toolkit with Apache License 2.0 | 5 votes |
def __init__(self, config):
    """Set up the summary head's sub-modules from optional ``config`` attributes.

    ``summary_type`` defaults to ``"last"`` (``"attn"`` is unsupported).
    ``summary_use_proj`` enables an ``nn.Linear`` projection to
    ``num_labels`` or ``hidden_size``. ``summary_activation == "tanh"``
    enables ``nn.Tanh``. Positive ``summary_first_dropout`` /
    ``summary_last_dropout`` enable ``nn.Dropout``. Anything not enabled is
    an ``Identity``.

    Raises:
        NotImplementedError: if ``summary_type`` is ``"attn"``.
    """
    super().__init__()
    self.summary_type = getattr(config, "summary_type", "last")
    if self.summary_type == "attn":
        # We should use a standard multi-head attention module with absolute positional embedding for that.
        # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
        # We can probably just use the multi-head attention module of PyTorch >=1.1.0
        raise NotImplementedError

    if getattr(config, "summary_use_proj", False):
        to_labels = getattr(config, "summary_proj_to_labels", False) and config.num_labels > 0
        out_dim = config.num_labels if to_labels else config.hidden_size
        self.summary = nn.Linear(config.hidden_size, out_dim)
    else:
        self.summary = Identity()

    if getattr(config, "summary_activation", None) == "tanh":
        self.activation = nn.Tanh()
    else:
        self.activation = Identity()

    # first_dropout / last_dropout share the same "enable when > 0" rule.
    for slot in ("first_dropout", "last_dropout"):
        rate = getattr(config, "summary_" + slot, 0)
        setattr(self, slot, nn.Dropout(rate) if rate > 0 else Identity())
Example #11
Source File: modeling_utils.py From NLP_Toolkit with Apache License 2.0 | 5 votes |
def __init__(self, config):
    """Set up the summary head's sub-modules from optional ``config`` attributes.

    ``summary_type`` defaults to ``"last"`` (``"attn"`` is unsupported).
    ``summary_use_proj`` enables an ``nn.Linear`` projection to
    ``num_labels`` or ``hidden_size``. ``summary_activation == "tanh"``
    enables ``nn.Tanh``. Positive ``summary_first_dropout`` /
    ``summary_last_dropout`` enable ``nn.Dropout``. Anything not enabled is
    an ``Identity``.

    Raises:
        NotImplementedError: if ``summary_type`` is ``"attn"``.
    """
    super().__init__()
    self.summary_type = getattr(config, "summary_type", "last")
    if self.summary_type == "attn":
        # We should use a standard multi-head attention module with absolute positional embedding for that.
        # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
        # We can probably just use the multi-head attention module of PyTorch >=1.1.0
        raise NotImplementedError

    if getattr(config, "summary_use_proj", False):
        to_labels = getattr(config, "summary_proj_to_labels", False) and config.num_labels > 0
        out_dim = config.num_labels if to_labels else config.hidden_size
        self.summary = nn.Linear(config.hidden_size, out_dim)
    else:
        self.summary = Identity()

    if getattr(config, "summary_activation", None) == "tanh":
        self.activation = nn.Tanh()
    else:
        self.activation = Identity()

    # first_dropout / last_dropout share the same "enable when > 0" rule.
    for slot in ("first_dropout", "last_dropout"):
        rate = getattr(config, "summary_" + slot, 0)
        setattr(self, slot, nn.Dropout(rate) if rate > 0 else Identity())
Example #12
Source File: modeling_utils.py From NLP_Toolkit with Apache License 2.0 | 5 votes |
def __init__(self, config):
    """Set up the summary head's sub-modules from optional ``config`` attributes.

    Attributes read from ``config``:
        summary_type: Stored on ``self.summary_type``; defaults to ``'last'``.
        summary_use_proj: When truthy, ``self.summary`` is an ``nn.Linear``
            from ``hidden_size`` to ``num_labels`` (if
            ``summary_proj_to_labels`` is truthy and ``num_labels > 0``),
            otherwise to ``hidden_size``.
        summary_activation: ``'tanh'`` selects ``nn.Tanh``.
        summary_first_dropout / summary_last_dropout: When > 0, an
            ``nn.Dropout`` with that probability.
    Every sub-module that is not enabled is an ``Identity``.

    Raises:
        NotImplementedError: if ``summary_type`` is ``'attn'``.
    """
    super(SequenceSummary, self).__init__()

    # Read ``summary_type`` itself, defaulting to 'last'. Guarding it on
    # ``summary_use_proj`` raised AttributeError for configs that set
    # ``summary_use_proj`` but not ``summary_type``, and ignored a configured
    # ``summary_type`` when ``summary_use_proj`` was absent.
    self.summary_type = getattr(config, 'summary_type', 'last')
    if self.summary_type == 'attn':
        # We should use a standard multi-head attention module with absolute positional embedding for that.
        # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
        # We can probably just use the multi-head attention module of PyTorch >=1.1.0
        raise NotImplementedError

    self.summary = Identity()
    if getattr(config, 'summary_use_proj', False):
        if getattr(config, 'summary_proj_to_labels', False) and config.num_labels > 0:
            num_classes = config.num_labels
        else:
            num_classes = config.hidden_size
        self.summary = nn.Linear(config.hidden_size, num_classes)

    self.activation = Identity()
    if getattr(config, 'summary_activation', None) == 'tanh':
        self.activation = nn.Tanh()

    self.first_dropout = Identity()
    if getattr(config, 'summary_first_dropout', 0) > 0:
        self.first_dropout = nn.Dropout(config.summary_first_dropout)

    self.last_dropout = Identity()
    if getattr(config, 'summary_last_dropout', 0) > 0:
        self.last_dropout = nn.Dropout(config.summary_last_dropout)
Example #13
Source File: modeling_utils.py From CCF-BDCI-Sentiment-Analysis-Baseline with Apache License 2.0 | 5 votes |
def __init__(self, config):
    """Set up the summary head's sub-modules from optional ``config`` attributes.

    Attributes read from ``config``:
        summary_type: Stored on ``self.summary_type``; defaults to ``'last'``.
        summary_use_proj: When truthy, ``self.summary`` is an ``nn.Linear``
            from ``hidden_size`` to ``num_labels`` (if
            ``summary_proj_to_labels`` is truthy and ``num_labels > 0``),
            otherwise to ``hidden_size``.
        summary_activation: ``'tanh'`` selects ``nn.Tanh``.
        summary_first_dropout / summary_last_dropout: When > 0, an
            ``nn.Dropout`` with that probability.
    Every sub-module that is not enabled is an ``Identity``.

    Raises:
        NotImplementedError: if ``summary_type`` is ``'attn'``.
    """
    super(SequenceSummary, self).__init__()

    # Read ``summary_type`` itself, defaulting to 'last'. Guarding it on
    # ``summary_use_proj`` raised AttributeError for configs that set
    # ``summary_use_proj`` but not ``summary_type``, and ignored a configured
    # ``summary_type`` when ``summary_use_proj`` was absent.
    self.summary_type = getattr(config, 'summary_type', 'last')
    if self.summary_type == 'attn':
        # We should use a standard multi-head attention module with absolute positional embedding for that.
        # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
        # We can probably just use the multi-head attention module of PyTorch >=1.1.0
        raise NotImplementedError

    self.summary = Identity()
    if getattr(config, 'summary_use_proj', False):
        if getattr(config, 'summary_proj_to_labels', False) and config.num_labels > 0:
            num_classes = config.num_labels
        else:
            num_classes = config.hidden_size
        self.summary = nn.Linear(config.hidden_size, num_classes)

    self.activation = Identity()
    if getattr(config, 'summary_activation', None) == 'tanh':
        self.activation = nn.Tanh()

    self.first_dropout = Identity()
    if getattr(config, 'summary_first_dropout', 0) > 0:
        self.first_dropout = nn.Dropout(config.summary_first_dropout)

    self.last_dropout = Identity()
    if getattr(config, 'summary_last_dropout', 0) > 0:
        self.last_dropout = nn.Dropout(config.summary_last_dropout)
Example #14
Source File: md_embedding_bag.py From dlrm with MIT License | 5 votes |
def __init__(self, num_embeddings, embedding_dim, base_dim):
    """Sparse sum-mode embedding bag whose outputs are mapped to ``base_dim``.

    Args:
        num_embeddings: Number of rows in the embedding table.
        embedding_dim: Width of each stored embedding.
        base_dim: Target width. Must be >= ``embedding_dim``. Equal widths use
            an ``nn.Identity`` projection; a smaller ``embedding_dim`` uses a
            bias-free ``nn.Linear``.

    Raises:
        ValueError: if ``embedding_dim`` exceeds ``base_dim``.
    """
    super(PrEmbeddingBag, self).__init__()
    self.embs = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum", sparse=True)
    torch.nn.init.xavier_uniform_(self.embs.weight)

    if embedding_dim == base_dim:
        # Widths already agree, so no projection is needed.
        self.proj = nn.Identity()
    elif embedding_dim < base_dim:
        self.proj = nn.Linear(embedding_dim, base_dim, bias=False)
        torch.nn.init.xavier_uniform_(self.proj.weight)
    else:
        raise ValueError(
            "Embedding dim " + str(embedding_dim) + " > base dim " + str(base_dim)
        )
Example #15
Source File: inception_resnet_v2.py From pytorch-image-models with Apache License 2.0 | 5 votes |
def reset_classifier(self, num_classes, global_pool='avg'):
    """Replace the global pooling layer and the ``classif`` head.

    Args:
        num_classes: Number of output classes. A falsy value (e.g. 0) makes
            ``self.classif`` an ``nn.Identity``.
        global_pool: ``pool_type`` for ``SelectAdaptivePool2d``.
    """
    self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
    self.num_classes = num_classes
    if not num_classes:
        self.classif = nn.Identity()
        return
    # The pooled feature width depends on the pooling mode's multiplier.
    pooled_width = self.num_features * self.global_pool.feat_mult()
    self.classif = nn.Linear(pooled_width, num_classes)
Example #16
Source File: pnasnet.py From pytorch-image-models with Apache License 2.0 | 5 votes |
def reset_classifier(self, num_classes, global_pool='avg'):
    """Replace the global pooling layer and the ``last_linear`` head.

    Args:
        num_classes: Number of output classes. A falsy value (e.g. 0) makes
            ``self.last_linear`` an ``nn.Identity``.
        global_pool: ``pool_type`` for ``SelectAdaptivePool2d``.
    """
    self.num_classes = num_classes
    self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
    if not num_classes:
        self.last_linear = nn.Identity()
        return
    # The pooled feature width depends on the pooling mode's multiplier.
    pooled_width = self.num_features * self.global_pool.feat_mult()
    self.last_linear = nn.Linear(pooled_width, num_classes)
Example #17
Source File: efficientnet_blocks.py From pytorch-image-models with Apache License 2.0 | 5 votes |
def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1,
             pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
             pw_act=False, se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d,
             norm_kwargs=None, drop_path_rate=0.):
    """Build a depthwise-separable conv block.

    Layer order: depthwise conv -> norm -> act -> optional SqueezeExcite ->
    pointwise conv -> norm -> act (only when ``pw_act``, else ``nn.Identity``).
    ``self.se`` is ``None`` when ``se_ratio`` is ``None`` or not positive.
    ``self.has_residual`` is true only for stride 1, matching channel counts,
    and ``noskip`` false.
    """
    super(DepthwiseSeparableConv, self).__init__()
    norm_kwargs = norm_kwargs or {}
    use_se = se_ratio is not None and se_ratio > 0.
    self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
    self.has_pw_act = pw_act  # whether an activation follows the point-wise conv
    self.drop_path_rate = drop_path_rate

    # Depthwise stage.
    self.conv_dw = create_conv2d(
        in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation,
        padding=pad_type, depthwise=True)
    self.bn1 = norm_layer(in_chs, **norm_kwargs)
    self.act1 = act_layer(inplace=True)

    # Optional squeeze-and-excitation.
    if not use_se:
        self.se = None
    else:
        resolved_se = resolve_se_args(se_kwargs, in_chs, act_layer)
        self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **resolved_se)

    # Point-wise stage.
    self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
    self.bn2 = norm_layer(out_chs, **norm_kwargs)
    if self.has_pw_act:
        self.act2 = act_layer(inplace=True)
    else:
        self.act2 = nn.Identity()
Example #18
Source File: senet.py From pytorch-image-models with Apache License 2.0 | 5 votes |
def reset_classifier(self, num_classes, global_pool='avg'):
    """Replace the ``avg_pool`` layer and the ``last_linear`` head.

    Args:
        num_classes: Number of output classes. A falsy value (e.g. 0) makes
            ``self.last_linear`` an ``nn.Identity``.
        global_pool: ``pool_type`` for ``SelectAdaptivePool2d``.
    """
    self.num_classes = num_classes
    self.avg_pool = SelectAdaptivePool2d(pool_type=global_pool)
    if not num_classes:
        self.last_linear = nn.Identity()
        return
    # The pooled feature width depends on the pooling mode's multiplier.
    pooled_width = self.num_features * self.avg_pool.feat_mult()
    self.last_linear = nn.Linear(pooled_width, num_classes)
Example #19
Source File: modeling_utils.py From NLP_Toolkit with Apache License 2.0 | 5 votes |
def __init__(self, *args, **kwargs):
    """Accept and ignore any constructor arguments.

    Only the parent initializer is called, with no arguments. Presumably this
    lets the placeholder be constructed with whatever arguments the layer it
    stands in for would take.
    """
    super(Identity, self).__init__()
Example #20
Source File: modeling_utils.py From NLP_Toolkit with Apache License 2.0 | 5 votes |
def __init__(self, config):
    """Set up the summary head's sub-modules from optional ``config`` attributes.

    Attributes read from ``config``:
        summary_type: Stored on ``self.summary_type``; defaults to ``'last'``.
        summary_use_proj: When truthy, ``self.summary`` is an ``nn.Linear``
            from ``hidden_size`` to ``num_labels`` (if
            ``summary_proj_to_labels`` is truthy and ``num_labels > 0``),
            otherwise to ``hidden_size``.
        summary_activation: ``'tanh'`` selects ``nn.Tanh``.
        summary_first_dropout / summary_last_dropout: When > 0, an
            ``nn.Dropout`` with that probability.
    Every sub-module that is not enabled is an ``Identity``.

    Raises:
        NotImplementedError: if ``summary_type`` is ``'attn'``.
    """
    super(SequenceSummary, self).__init__()

    # Read ``summary_type`` itself, defaulting to 'last'. Guarding it on
    # ``summary_use_proj`` raised AttributeError for configs that set
    # ``summary_use_proj`` but not ``summary_type``, and ignored a configured
    # ``summary_type`` when ``summary_use_proj`` was absent.
    self.summary_type = getattr(config, 'summary_type', 'last')
    if self.summary_type == 'attn':
        # We should use a standard multi-head attention module with absolute positional embedding for that.
        # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
        # We can probably just use the multi-head attention module of PyTorch >=1.1.0
        raise NotImplementedError

    self.summary = Identity()
    if getattr(config, 'summary_use_proj', False):
        if getattr(config, 'summary_proj_to_labels', False) and config.num_labels > 0:
            num_classes = config.num_labels
        else:
            num_classes = config.hidden_size
        self.summary = nn.Linear(config.hidden_size, num_classes)

    self.activation = Identity()
    if getattr(config, 'summary_activation', None) == 'tanh':
        self.activation = nn.Tanh()

    self.first_dropout = Identity()
    if getattr(config, 'summary_first_dropout', 0) > 0:
        self.first_dropout = nn.Dropout(config.summary_first_dropout)

    self.last_dropout = Identity()
    if getattr(config, 'summary_last_dropout', 0) > 0:
        self.last_dropout = nn.Dropout(config.summary_last_dropout)
Example #21
Source File: modeling_utils.py From NLP_Toolkit with Apache License 2.0 | 5 votes |
def __init__(self, *args, **kwargs):
    """Accept and ignore any constructor arguments.

    Only the parent initializer is called, with no arguments. Presumably this
    lets the placeholder be constructed with whatever arguments the layer it
    stands in for would take.
    """
    super(Identity, self).__init__()
Example #22
Source File: modeling_utils.py From NLP_Toolkit with Apache License 2.0 | 5 votes |
def __init__(self, config):
    """Set up the summary head's sub-modules from optional ``config`` attributes.

    Attributes read from ``config``:
        summary_type: Stored on ``self.summary_type``; defaults to ``'last'``.
        summary_use_proj: When truthy, ``self.summary`` is an ``nn.Linear``
            from ``hidden_size`` to ``num_labels`` (if
            ``summary_proj_to_labels`` is truthy and ``num_labels > 0``),
            otherwise to ``hidden_size``.
        summary_activation: ``'tanh'`` selects ``nn.Tanh``.
        summary_first_dropout / summary_last_dropout: When > 0, an
            ``nn.Dropout`` with that probability.
    Every sub-module that is not enabled is an ``nn.Identity``.

    Raises:
        NotImplementedError: if ``summary_type`` is ``'attn'``.
    """
    super(SequenceSummary, self).__init__()

    # Read ``summary_type`` itself, defaulting to 'last', and test the stored
    # value below. Guarding it on ``summary_use_proj`` (and re-reading
    # ``config.summary_type``) raised AttributeError for configs without a
    # ``summary_type``, and ignored a configured type when
    # ``summary_use_proj`` was absent.
    self.summary_type = getattr(config, 'summary_type', 'last')
    if self.summary_type == 'attn':
        # We should use a standard multi-head attention module with absolute positional embedding for that.
        # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
        # We can probably just use the multi-head attention module of PyTorch >=1.1.0
        raise NotImplementedError

    self.summary = nn.Identity()
    if getattr(config, 'summary_use_proj', False):
        if getattr(config, 'summary_proj_to_labels', False) and config.num_labels > 0:
            num_classes = config.num_labels
        else:
            num_classes = config.hidden_size
        self.summary = nn.Linear(config.hidden_size, num_classes)

    self.activation = nn.Identity()
    if getattr(config, 'summary_activation', None) == 'tanh':
        self.activation = nn.Tanh()

    self.first_dropout = nn.Identity()
    if getattr(config, 'summary_first_dropout', 0) > 0:
        self.first_dropout = nn.Dropout(config.summary_first_dropout)

    self.last_dropout = nn.Identity()
    if getattr(config, 'summary_last_dropout', 0) > 0:
        self.last_dropout = nn.Dropout(config.summary_last_dropout)
Example #23
Source File: modeling_utils.py From NLP_Toolkit with Apache License 2.0 | 5 votes |
def __init__(self, config: PretrainedConfig):
    """Set up the summary head's sub-modules from optional ``config`` attributes.

    ``summary_type`` defaults to ``"last"`` (``"attn"`` is unsupported).
    ``summary_use_proj`` enables an ``nn.Linear`` projection to
    ``num_labels`` or ``hidden_size``. A non-empty ``summary_activation`` is
    resolved through ``get_activation``. Positive ``summary_first_dropout`` /
    ``summary_last_dropout`` enable ``nn.Dropout``. Anything not enabled is
    an ``Identity``.

    Raises:
        NotImplementedError: if ``summary_type`` is ``"attn"``.
    """
    super().__init__()
    self.summary_type = getattr(config, "summary_type", "last")
    if self.summary_type == "attn":
        # We should use a standard multi-head attention module with absolute positional embedding for that.
        # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
        # We can probably just use the multi-head attention module of PyTorch >=1.1.0
        raise NotImplementedError

    if getattr(config, "summary_use_proj", False):
        to_labels = getattr(config, "summary_proj_to_labels", False) and config.num_labels > 0
        out_dim = config.num_labels if to_labels else config.hidden_size
        self.summary = nn.Linear(config.hidden_size, out_dim)
    else:
        self.summary = Identity()

    activation_name = getattr(config, "summary_activation", None)
    if activation_name:
        self.activation: Callable = get_activation(activation_name)
    else:
        self.activation: Callable = Identity()

    # first_dropout / last_dropout share the same "enable when > 0" rule.
    for slot in ("first_dropout", "last_dropout"):
        rate = getattr(config, "summary_" + slot, 0)
        setattr(self, slot, nn.Dropout(rate) if rate > 0 else Identity())
Example #24
Source File: attention.py From batchflow with Apache License 2.0 | 5 votes |
def identity(inputs, **kwargs):
    """Build a pass-through module.

    ``inputs`` and any keyword arguments are accepted but ignored (presumably
    so this matches the signature of sibling factories). The returned
    ``nn.Identity`` hands back whatever tensor it is later called with,
    unchanged. Note that this function returns the module, not a tensor.
    """
    del inputs, kwargs  # intentionally unused
    return nn.Identity()
Example #25
Source File: modules.py From batchflow with Apache License 2.0 | 5 votes |
def __init__(self, inputs, layout='cna', filters=None, kernel_size=1, pool_op='mean',
             pyramid=(0, 1, 2, 3, 6), **kwargs):
    """Build one branch per pyramid level, plus a concatenating combiner.

    Level 0 is an ``nn.Identity`` branch. Every other level ``L`` is a
    ``ConvBlock`` whose layout is prefixed with a pooling step ('p'), using
    pool size ``ceil(spatial / L)`` and stride ``floor((spatial - 1) / L + 1)``,
    followed by an ``Upsample`` back to the input's spatial shape.
    ``filters`` defaults to the string ``'same // <number of levels>'``.
    """
    super().__init__()
    spatial_shape = np.array(get_shape(inputs)[2:])  # dims after the first two
    filters = filters or 'same // {}'.format(len(pyramid))

    branches = nn.ModuleList()
    for level in pyramid:
        if level == 0:
            branches.append(nn.Identity())
            continue
        pool_size = tuple(np.ceil(spatial_shape / level).astype(np.int32).tolist())
        pool_strides = tuple(np.floor((spatial_shape - 1) / level + 1).astype(np.int32).tolist())
        conv = ConvBlock(inputs=inputs, layout='p' + layout, filters=filters,
                         kernel_size=kernel_size, pool_op=pool_op, pool_size=pool_size,
                         pool_strides=pool_strides, **kwargs)
        # Run the conv once so the upsampler is built from its output.
        pooled = conv(inputs)
        upsample = Upsample(inputs=pooled, factor=None, layout='b',
                            shape=tuple(spatial_shape.tolist()), **kwargs)
        branches.append(nn.Sequential(conv, upsample))

    self.blocks = branches
    self.combine = Combine(op='concat')
Example #26
Source File: conv_block.py From batchflow with Apache License 2.0 | 5 votes |
def __init__(self, inputs=None, **kwargs):
    """Wrap a ``ConvBlock`` built from ``kwargs``, or a pass-through layer.

    When ``kwargs`` has a truthy ``'layout'``, ``self.layer`` is
    ``ConvBlock(inputs=inputs, **kwargs)``; otherwise it is ``nn.Identity``.
    """
    super().__init__()
    has_layout = bool(kwargs.get('layout'))
    self.layer = ConvBlock(inputs=inputs, **kwargs) if has_layout else nn.Identity()
Example #27
Source File: encoder_decoder.py From batchflow with Apache License 2.0 | 5 votes |
def _make_modules(self, inputs, **kwargs): num_stages = kwargs.pop('num_stages') encoder_layout = ''.join([item[0] for item in kwargs.pop('order')]) block_args = kwargs.pop('blocks') downsample_args = kwargs.pop('downsample') self.layout = '' for i in range(num_stages): for letter in encoder_layout: if letter in ['b']: args = {**kwargs, **block_args, **unpack_args(block_args, i, num_stages)} layer = ConvBlock(inputs=inputs, **args) inputs = layer(inputs) layer_desc = 'block-{}'.format(i) elif letter in ['d', 'p']: args = {**kwargs, **downsample_args, **unpack_args(downsample_args, i, num_stages)} layer = ConvBlock(inputs=inputs, **args) inputs = layer(inputs) layer_desc = 'downsample-{}'.format(i) elif letter in ['s']: layer = nn.Identity() layer_desc = 'skip-{}'.format(i) else: raise ValueError('Unknown letter in order {}, use one of "b", "d", "p", "s"' .format(letter)) self.update([(layer_desc, layer)]) self.layout += letter
Example #28
Source File: modeling_utils.py From TextClassify with Apache License 2.0 | 5 votes |
def __init__(self, config):
    """Set up the summary head's sub-modules from optional ``config`` attributes.

    Attributes read from ``config``:
        summary_type: Stored on ``self.summary_type``; defaults to ``'last'``.
        summary_use_proj: When truthy, ``self.summary`` is an ``nn.Linear``
            from ``hidden_size`` to ``num_labels`` (if
            ``summary_proj_to_labels`` is truthy and ``num_labels > 0``),
            otherwise to ``hidden_size``.
        summary_activation: ``'tanh'`` selects ``nn.Tanh``.
        summary_first_dropout / summary_last_dropout: When > 0, an
            ``nn.Dropout`` with that probability.
    Every sub-module that is not enabled is an ``Identity``.

    Raises:
        NotImplementedError: if ``summary_type`` is ``'attn'``.
    """
    super(SequenceSummary, self).__init__()

    # Read ``summary_type`` itself, defaulting to 'last'. Guarding it on
    # ``summary_use_proj`` raised AttributeError for configs that set
    # ``summary_use_proj`` but not ``summary_type``, and ignored a configured
    # ``summary_type`` when ``summary_use_proj`` was absent.
    self.summary_type = getattr(config, 'summary_type', 'last')
    if self.summary_type == 'attn':
        # We should use a standard multi-head attention module with absolute positional embedding for that.
        # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
        # We can probably just use the multi-head attention module of PyTorch >=1.1.0
        raise NotImplementedError

    self.summary = Identity()
    if getattr(config, 'summary_use_proj', False):
        if getattr(config, 'summary_proj_to_labels', False) and config.num_labels > 0:
            num_classes = config.num_labels
        else:
            num_classes = config.hidden_size
        self.summary = nn.Linear(config.hidden_size, num_classes)

    self.activation = Identity()
    if getattr(config, 'summary_activation', None) == 'tanh':
        self.activation = nn.Tanh()

    self.first_dropout = Identity()
    if getattr(config, 'summary_first_dropout', 0) > 0:
        self.first_dropout = nn.Dropout(config.summary_first_dropout)

    self.last_dropout = Identity()
    if getattr(config, 'summary_last_dropout', 0) > 0:
        self.last_dropout = nn.Dropout(config.summary_last_dropout)
Example #29
Source File: modeling_utils.py From TextClassify with Apache License 2.0 | 5 votes |
def __init__(self, *args, **kwargs):
    """Accept and ignore any constructor arguments.

    Only the parent initializer is called, with no arguments. Presumably this
    lets the placeholder be constructed with whatever arguments the layer it
    stands in for would take.
    """
    super(Identity, self).__init__()
Example #30
Source File: decoder.py From segmentation_models.pytorch with MIT License | 5 votes |
def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        use_batchnorm=True,
        attention_type=None,
        center=False,
):
    """Build the decoder's optional center block and its chain of DecoderBlocks.

    ``encoder_channels`` drops its first entry and is then reversed, so the
    head (deepest) channel count comes first. Block ``k`` takes the previous
    block's output width (the head width for ``k == 0``), the matching skip
    width (0 for the last block), and outputs ``decoder_channels[k]``.

    Raises:
        ValueError: if ``len(decoder_channels) != n_blocks``.
    """
    super().__init__()

    if len(decoder_channels) != n_blocks:
        raise ValueError(
            "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                n_blocks, len(decoder_channels)
            )
        )

    # Drop the first skip (same spatial resolution) and start from the encoder head.
    reversed_channels = encoder_channels[1:][::-1]

    head_channels = reversed_channels[0]
    in_channels = [head_channels] + list(decoder_channels[:-1])
    skip_channels = list(reversed_channels[1:]) + [0]
    out_channels = decoder_channels

    self.center = (
        CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
        if center
        else nn.Identity()
    )

    block_kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
    self.blocks = nn.ModuleList(
        DecoderBlock(in_ch, skip_ch, out_ch, **block_kwargs)
        for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
    )