Python torch.nn.functional.batch_norm() Examples
The following are 30 code examples of torch.nn.functional.batch_norm().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch.nn.functional, or try the search function.
Example #1
Source File: batchnorm2d.py From BayesianDefense with MIT License | 6 votes |
def forward(self, input):
    """Bayesian batch-norm forward pass with sampled affine parameters.

    When ``self.affine`` is set, weight and bias are drawn with the
    reparameterisation ``mu + exp(sigma) * eps``, where ``eps`` is refilled
    in place with standard-normal noise. For each parameter the KL term
    ``log(sigma_0) - log(sigma) + (sigma**2 + mu**2) / (2 * sigma_0**2) - 0.5``
    is accumulated; this is KL(N(mu, sigma^2) || N(0, sigma_0^2)).

    Args:
        input: Tensor to normalise (validated by ``self._check_input_dim``).

    Returns:
        tuple ``(out, kl)``: the normalised tensor and the summed KL term.
        ``kl`` is a zero scalar tensor when ``self.affine`` is false.
    """
    self._check_input_dim(input)

    exponential_average_factor = 0.0
    if self.training and self.track_running_stats:
        self.num_batches_tracked += 1
        if self.momentum is None:
            # Cumulative moving average over every batch seen so far.
            exponential_average_factor = 1.0 / self.num_batches_tracked.item()
        else:
            # Exponential moving average.
            exponential_average_factor = self.momentum

    weight = bias = None
    # The KL terms only exist for affine layers; default to zero so a
    # non-affine layer does not hit a NameError at the return.
    kl = input.new_zeros(())
    if self.affine:
        sig_weight = torch.exp(self.sigma_weight)
        weight = self.mu_weight + sig_weight * self.eps_weight.normal_()
        kl_weight = (math.log(self.sigma_0) - self.sigma_weight
                     + (sig_weight ** 2 + self.mu_weight ** 2) / (2 * self.sigma_0 ** 2)
                     - 0.5)
        sig_bias = torch.exp(self.sigma_bias)
        bias = self.mu_bias + sig_bias * self.eps_bias.normal_()
        kl_bias = (math.log(self.sigma_0) - self.sigma_bias
                   + (sig_bias ** 2 + self.mu_bias ** 2) / (2 * self.sigma_0 ** 2)
                   - 0.5)
        kl = kl_weight.sum() + kl_bias.sum()

    out = F.batch_norm(input, self.running_mean, self.running_var, weight, bias,
                       self.training or not self.track_running_stats,
                       exponential_average_factor, self.eps)
    return out, kl
Example #2
Source File: module.py From openseg.pytorch with MIT License | 6 votes |
def forward(self, input):
    """Synchronized batch-norm forward pass.

    In eval mode this defers to ``batch_norm`` with the stored running
    statistics. In training mode, each replica's sum(x) and sum(x^2) are
    exchanged through the master/slave pipe to obtain the global mean and
    inverse std, which ``batchnormtrain`` then applies.
    """
    if not self.training:
        return batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    original_shape = input.size()
    # Collapse all trailing dims into one: (B, C, -1).
    flat = input.view(original_shape[0], self.num_features, -1)

    count = flat.size(0) * flat.size(2)
    xsum, xsqsum = sum_square(flat)
    message = _ChildMessage(xsum, xsqsum, count)

    # Replica 0 acts as the master; every other replica is a slave.
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(message)
    else:
        mean, inv_std = self._slave_pipe.run_slave(message)

    normed = batchnormtrain(flat, mean, 1.0 / inv_std, self.weight, self.bias)
    return normed.view(original_shape)
Example #3
Source File: CrossReplicaBN.py From BigGAN-pytorch with Apache License 2.0 | 6 votes |
def forward(self, input):
    """Standard batch-norm forward pass with running-statistics bookkeeping.

    While training with tracked statistics, ``num_batches_tracked`` is bumped
    and the update factor is ``1 / num_batches_tracked`` (cumulative average)
    when ``momentum`` is None, else ``momentum``.
    """
    self._check_input_dim(input)

    factor = 0.0
    if self.training and self.track_running_stats:
        self.num_batches_tracked += 1
        factor = (1.0 / self.num_batches_tracked.item()
                  if self.momentum is None else self.momentum)

    use_batch_stats = self.training or not self.track_running_stats
    return F.batch_norm(input, self.running_mean, self.running_var,
                        self.weight, self.bias, use_batch_stats, factor, self.eps)
Example #4
Source File: conditional_batchnorm.py From pytorch.sngan_projection with MIT License | 6 votes |
def forward(self, input, weight, bias, **kwargs):
    """Conditional batch norm: normalise, then apply a per-sample scale/shift.

    Args:
        input: Tensor shaped (N, C, *spatial); validated by
            ``self._check_input_dim``.
        weight: Scale of shape (C,) or (N, C).
        bias: Shift of shape (C,) or (N, C).
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        ``weight * batch_norm(input) + bias``, with the same shape as ``input``.
    """
    self._check_input_dim(input)

    exponential_average_factor = 0.0
    if self.training and self.track_running_stats:
        self.num_batches_tracked += 1
        if self.momentum is None:
            # Cumulative moving average.
            exponential_average_factor = 1.0 / self.num_batches_tracked.item()
        else:
            # Exponential moving average.
            exponential_average_factor = self.momentum

    output = F.batch_norm(input, self.running_mean, self.running_var,
                          self.weight, self.bias,
                          self.training or not self.track_running_stats,
                          exponential_average_factor, self.eps)

    if weight.dim() == 1:
        weight = weight.unsqueeze(0)
    if bias.dim() == 1:
        bias = bias.unsqueeze(0)

    # Append one singleton axis per spatial dim so (N, C) lines up with
    # (N, C, *spatial). Two axes were hard-coded before, which only worked
    # for 4-D input.
    size = output.size()
    trailing = (1,) * (output.dim() - 2)
    weight = weight.reshape(*weight.shape, *trailing).expand(size)
    bias = bias.reshape(*bias.shape, *trailing).expand(size)
    return weight * output + bias
Example #5
Source File: models.py From few-shot with MIT License | 6 votes |
def functional_conv_block(x: torch.Tensor, weights: torch.Tensor, biases: torch.Tensor,
                          bn_weights, bn_biases) -> torch.Tensor:
    """Run conv -> batch norm -> ReLU -> 2x2 max-pool with explicit parameters.

    Stateless conv block: every parameter is passed in, and batch norm always
    uses the current batch statistics (no running buffers are kept).

    # Arguments:
        x: Input tensor for the conv block.
        weights: Convolution kernel; applied with padding=1 (intended as 3x3).
        biases: Convolution bias.
        bn_weights: Batch-norm scale.
        bn_biases: Batch-norm shift.
    """
    conv_out = F.conv2d(x, weights, biases, padding=1)
    normed = F.batch_norm(conv_out, running_mean=None, running_var=None,
                          weight=bn_weights, bias=bn_biases, training=True)
    activated = F.relu(normed)
    return F.max_pool2d(activated, kernel_size=2, stride=2)


##########
# Models #
##########
Example #6
Source File: test_higher.py From higher with Apache License 2.0 | 6 votes |
def batch_norm(
    self, inputs, weight=None, bias=None, running_mean=None, running_var=None,
    training=True, eps=1e-5, momentum=0.1
):
    """Apply ``F.batch_norm`` with freshly initialised running statistics.

    The ``running_mean`` / ``running_var`` arguments are overwritten: every
    call starts from zero mean and unit variance, so nothing carries over
    between calls.
    """
    num_channels = inputs.size(1)
    # Allocate on the input's device; buffers on the default device break
    # F.batch_norm for inputs living elsewhere (e.g. CUDA).
    running_mean = torch.zeros(num_channels, device=inputs.device)
    running_var = torch.ones(num_channels, device=inputs.device)
    return F.batch_norm(
        inputs, running_mean, running_var, weight, bias, training, momentum, eps
    )
Example #7
Source File: batch_norm.py From fast-reid with Apache License 2.0 | 6 votes |
def forward(self, x):
    """Frozen batch norm using only the stored running statistics.

    If autograd is needed, the affine transform is written out by hand
    (scale and shift computed once, broadcast as (1, C, 1, 1)). Otherwise,
    ``F.batch_norm``'s backward would also compute weight/bias gradients and
    cost extra memory. Without grad, the single fused ``F.batch_norm``
    inference op is used.
    """
    if not x.requires_grad:
        # Fused op: more optimisation opportunities.
        return F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            training=False, eps=self.eps,
        )

    inv_std = (self.running_var + self.eps).rsqrt()
    scale = self.weight * inv_std
    shift = self.bias - self.running_mean * scale
    return x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)
Example #8
Source File: syncbn.py From DetNAS with MIT License | 6 votes |
def forward(self, x):
    """Batch norm that synchronises statistics across processes on demand.

    Training with ``self.sync`` routes through ``DistributedSyncBNFucntion``.
    Every other case uses ``F.batch_norm``, with ``num_batches_tracked`` and
    the running-statistics update factor maintained while training.
    """
    if self.training and self.sync:
        return DistributedSyncBNFucntion.apply(
            x, self.weight, self.bias, self.running_mean, self.running_var,
            self.training, self.momentum, self.eps, self.sync)

    avg_factor = 0.0
    # TODO: the None check only exists so the JIT can skip emitting this.
    if (self.training and self.track_running_stats
            and self.num_batches_tracked is not None):
        self.num_batches_tracked += 1
        if self.momentum is not None:
            avg_factor = self.momentum  # exponential moving average
        else:
            avg_factor = 1.0 / float(self.num_batches_tracked)  # cumulative average

    use_batch_stats = self.training or not self.track_running_stats
    return F.batch_norm(x, self.running_mean, self.running_var, self.weight,
                        self.bias, use_batch_stats, avg_factor, self.eps)
Example #9
Source File: syncbn.py From TreeFilter-Torch with MIT License | 6 votes |
def forward(self, input):
    """Synchronized batch-norm forward pass.

    In eval mode this defers to ``batch_norm`` with the stored running
    statistics. In training mode, each replica's sum(x) and sum(x^2) are
    exchanged through the master/slave pipe to obtain the global mean and
    inverse std, which ``batchnormtrain`` then applies.
    """
    if not self.training:
        return batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    original_shape = input.size()
    # Collapse all trailing dims into one: (B, C, -1).
    flat = input.view(original_shape[0], self.num_features, -1)

    count = flat.size(0) * flat.size(2)
    xsum, xsqsum = sum_square(flat)
    message = _ChildMessage(xsum, xsqsum, count)

    # Replica 0 acts as the master; every other replica is a slave.
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(message)
    else:
        mean, inv_std = self._slave_pipe.run_slave(message)

    normed = batchnormtrain(flat, mean, 1.0 / inv_std, self.weight, self.bias)
    return normed.view(original_shape)
Example #10
Source File: models.py From PyTorch-GAN with MIT License | 6 votes |
def forward(self, x):
    """Adaptive instance normalisation implemented through batch norm.

    Folding the batch into the channel axis, (B, C, H, W) -> (1, B*C, H, W),
    makes batch norm compute per-sample, per-channel statistics. This is
    instance norm, scaled and shifted by ``self.weight`` / ``self.bias``,
    which must hold B*C entries. The running statistics are tiled into
    temporary copies, so the module's own buffers are left untouched.
    """
    assert (
        self.weight is not None and self.bias is not None
    ), "Please assign weight and bias before calling AdaIN!"
    batch, channels, height, width = x.size()
    tiled_mean = self.running_mean.repeat(batch)
    tiled_var = self.running_var.repeat(batch)

    folded = x.contiguous().view(1, batch * channels, height, width)
    normed = F.batch_norm(folded, tiled_mean, tiled_var, self.weight, self.bias,
                          True, self.momentum, self.eps)
    return normed.view(batch, channels, height, width)
Example #11
Source File: BatchNorm.py From UCB with MIT License | 6 votes |
def forward(self, input, sample=False, calculate_log_probs=False):
    """Bayesian batch norm with variational posteriors over weight and bias.

    Weight and bias are sampled from their posteriors while training or when
    ``sample`` is true, and taken as the posterior means otherwise. When
    training or when ``calculate_log_probs`` is true, ``self.log_prior`` and
    ``self.log_variational_posterior`` are set from the drawn values;
    otherwise both are reset to 0.
    """
    self._check_input_dim(input)
    if self.mask_flag:
        self.weight = VariationalPosterior(self.pruned_weight_mu, self.weight_rho, self.device)
        # if self.use_bias:
        #     self.bias = VariationalPosterior(self.bias_mu, self.bias_rho)

    draw = self.training or sample
    weight = self.weight.sample() if draw else self.weight.mu
    bias = self.bias.sample() if draw else self.bias.mu

    if not (self.training or calculate_log_probs):
        self.log_prior, self.log_variational_posterior = 0, 0
    else:
        self.log_prior = self.weight_prior.log_prob(weight) + self.bias_prior.log_prob(bias)
        self.log_variational_posterior = self.weight.log_prob(weight) + self.bias.log_prob(bias)

    use_batch_stats = self.training or not self.track_running_stats
    return F.batch_norm(input, self.running_mean, self.running_var, weight, bias,
                        use_batch_stats, self.momentum, self.eps)
Example #12
Source File: layers.py From BigGAN-PyTorch with MIT License | 6 votes |
def forward(self, x, y):
    """Class-conditional batch norm.

    Gains are ``1 + self.gain(y)`` and biases ``self.bias(y)``, each viewed as
    (batch, -1, 1, 1). The input is normalised by ``self.bn`` when ``mybn``
    or ``cross_replica`` is set. Otherwise it uses the parameter-free
    normalisation named by ``self.norm_style``, and gain/bias are applied
    afterwards.

    Raises:
        ValueError: if ``self.norm_style`` is not one of 'bn', 'in', 'gn',
            'nonorm' (previously this surfaced as an unbound local ``out``).
    """
    # Calculate class-conditional gains and biases.
    gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
    bias = self.bias(y).view(y.size(0), -1, 1, 1)
    # If using my batchnorm.
    if self.mybn or self.cross_replica:
        return self.bn(x, gain=gain, bias=bias)

    if self.norm_style == 'bn':
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                           self.training, 0.1, self.eps)
    elif self.norm_style == 'in':
        out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                              self.training, 0.1, self.eps)
    elif self.norm_style == 'gn':
        # The attribute is ``norm_style``; the former ``self.normstyle``
        # spelling is not defined anywhere in this block.
        out = groupnorm(x, self.norm_style)
    elif self.norm_style == 'nonorm':
        out = x
    else:
        raise ValueError('Unknown norm_style: {!r}'.format(self.norm_style))
    return out * gain + bias
Example #13
Source File: batch_norm.py From fast-reid with Apache License 2.0 | 6 votes |
def forward(self, input):
    """Ghost batch norm over ``self.num_splits`` interleaved sub-batches.

    While training (or when running stats are not tracked), the (N, C, H, W)
    input is viewed as (N / num_splits, C * num_splits, H, W). Samples whose
    index has the same remainder modulo ``num_splits`` therefore share
    per-channel statistics. Running stats and affine parameters are tiled
    ``num_splits`` times for the call, and the running stats are averaged
    back to ``num_features`` entries afterwards. Otherwise this is plain
    eval-mode batch norm.
    """
    N, C, H, W = input.shape
    if not self.training and self.track_running_stats:
        return F.batch_norm(
            input, self.running_mean, self.running_var,
            self.weight, self.bias, False, self.momentum, self.eps)

    splits = self.num_splits
    self.running_mean = self.running_mean.repeat(splits)
    self.running_var = self.running_var.repeat(splits)
    grouped = input.view(-1, C * splits, H, W)
    outputs = F.batch_norm(
        grouped, self.running_mean, self.running_var,
        self.weight.repeat(splits), self.bias.repeat(splits),
        True, self.momentum, self.eps).view(N, C, H, W)
    # Fold the per-split running stats back to one entry per feature.
    self.running_mean = self.running_mean.view(splits, self.num_features).mean(dim=0)
    self.running_var = self.running_var.view(splits, self.num_features).mean(dim=0)
    return outputs
Example #14
Source File: layers.py From BigGAN-PyTorch with MIT License | 6 votes |
def forward(self, x, y=None):
    """Batch norm with a learned gain and bias; ``y`` is accepted but unused.

    Uses the custom ``self.bn`` (gain/bias viewed as (1, C, 1, 1)) when
    ``cross_replica`` or ``mybn`` is set. Otherwise it uses ``F.batch_norm``
    with the stored statistics.
    """
    if not (self.cross_replica or self.mybn):
        return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                            self.bias, self.training, self.momentum, self.eps)
    return self.bn(x,
                   gain=self.gain.view(1, -1, 1, 1),
                   bias=self.bias.view(1, -1, 1, 1))


# Generator blocks
# Note that this class assumes the kernel size and padding (and any other
# settings) have been selected in the main generator module and passed in
# through the which_conv arg. Similar rules apply with which_bn (the input
# size [which is actually the number of channels of the conditional info] must
# be preselected)
Example #15
Source File: norm_act.py From pytorch-image-models with Apache License 2.0 | 6 votes |
def _forward_jit(self, x): """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function """ # exponential_average_factor is self.momentum set to # (when it is available) only so that if gets updated # in ONNX graph when this node is exported to ONNX. if self.momentum is None: exponential_average_factor = 0.0 else: exponential_average_factor = self.momentum if self.training and self.track_running_stats: # TODO: if statement only here to tell the jit to skip emitting this when it is None if self.num_batches_tracked is not None: self.num_batches_tracked += 1 if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / float(self.num_batches_tracked) else: # use exponential moving average exponential_average_factor = self.momentum x = F.batch_norm( x, self.running_mean, self.running_var, self.weight, self.bias, self.training or not self.track_running_stats, exponential_average_factor, self.eps) return x
Example #16
Source File: batchnorm2d.py From BayesianDefense with MIT License | 6 votes |
def forward_(self, input):
    """Batch-norm forward with affine parameters produced by ``noise_fn``.

    Running-statistics bookkeeping uses a cumulative average when
    ``momentum`` is None and ``momentum`` otherwise. When ``affine`` is set,
    weight and bias come from ``noise_fn`` applied to the mu/sigma/eps
    tensors together with ``sigma_0`` and ``N``. Otherwise no affine term is
    applied.
    """
    self._check_input_dim(input)

    avg_factor = 0.0
    if self.training and self.track_running_stats:
        self.num_batches_tracked += 1
        avg_factor = self.momentum
        if self.momentum is None:
            avg_factor = 1.0 / self.num_batches_tracked.item()

    if self.affine:
        weight = noise_fn(self.mu_weight, self.sigma_weight, self.eps_weight, self.sigma_0, self.N)
        bias = noise_fn(self.mu_bias, self.sigma_bias, self.eps_bias, self.sigma_0, self.N)
    else:
        weight = bias = None

    return F.batch_norm(input, self.running_mean, self.running_var, weight, bias,
                        self.training or not self.track_running_stats,
                        avg_factor, self.eps)
Example #17
Source File: activated_batch_norm.py From pytorch-tools with MIT License | 6 votes |
def forward(self, x):
    """Batch norm followed by the configured in-place activation.

    ``self.activation_param`` is forwarded as ``negative_slope`` for leaky
    ReLU and as ``alpha`` for ELU; other activations get no extra argument.
    """
    normed = F.batch_norm(
        x, self.running_mean, self.running_var, self.weight, self.bias,
        self.training, self.momentum, self.eps,
    )
    act = ACT_FUNC_DICT[self.activation]
    extra = {}
    if self.activation == ACT.LEAKY_RELU:
        extra["negative_slope"] = self.activation_param
    elif self.activation == ACT.ELU:
        extra["alpha"] = self.activation_param
    return act(normed, inplace=True, **extra)
Example #18
Source File: batchnorm.py From fast-reid with Apache License 2.0 | 5 votes |
def forward(self, input):
    """Synchronized batch-norm forward pass.

    Uses ``F.batch_norm`` unless the module is both parallel and training.
    In that case each replica's sum and square-sum are reduced through the
    master/slave pipe into a global mean and inverse std, which are then
    applied directly.
    """
    # Not parallel, or evaluating: PyTorch's implementation is enough.
    if not self._is_parallel or not self.training:
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    input_shape = input.size()
    flat = input.view(input.size(0), self.num_features, -1)  # (B, C, -1)

    sum_size = flat.size(0) * flat.size(2)
    message = _ChildMessage(_sum_ft(flat), _sum_ft(flat ** 2), sum_size)
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(message)
    else:
        mean, inv_std = self._slave_pipe.run_slave(message)

    centered = flat - _unsqueeze_ft(mean)
    if self.affine:
        # Fold inv_std into the weight so only one multiply is needed.
        output = centered * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
    else:
        output = centered * _unsqueeze_ft(inv_std)
    return output.view(input_shape)
Example #19
Source File: batchnorm.py From pytorch-meta with MIT License | 5 votes |
def forward(self, input, params=None):
    """Batch norm whose affine parameters can be supplied per call.

    Args:
        input: Tensor to normalise.
        params: Mapping that may hold 'weight' and 'bias'; a missing key means
            no affine term for it. Defaults to this module's named parameters.
    """
    self._check_input_dim(input)
    if params is None:
        params = OrderedDict(self.named_parameters())

    # Default the factor to momentum (when set) so the value is updated in
    # the ONNX graph when this node is exported.
    factor = self.momentum if self.momentum is not None else 0.0
    if (self.training and self.track_running_stats
            and self.num_batches_tracked is not None):
        self.num_batches_tracked += 1
        if self.momentum is None:
            factor = 1.0 / float(self.num_batches_tracked)  # cumulative average
        else:
            factor = self.momentum  # exponential moving average

    return F.batch_norm(
        input, self.running_mean, self.running_var,
        params.get('weight', None), params.get('bias', None),
        self.training or not self.track_running_stats, factor, self.eps)
Example #20
Source File: batchnorm.py From reid_baseline_with_syncbn with MIT License | 5 votes |
def forward(self, input):
    """Synchronized batch-norm forward pass.

    Uses ``F.batch_norm`` unless the module is both parallel and training.
    In that case each replica's sum and square-sum are reduced through the
    master/slave pipe into a global mean and inverse std, which are then
    applied directly.
    """
    # Not parallel, or evaluating: PyTorch's implementation is enough.
    if not self._is_parallel or not self.training:
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    input_shape = input.size()
    flat = input.view(input.size(0), self.num_features, -1)  # (B, C, -1)

    sum_size = flat.size(0) * flat.size(2)
    message = _ChildMessage(_sum_ft(flat), _sum_ft(flat ** 2), sum_size)
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(message)
    else:
        mean, inv_std = self._slave_pipe.run_slave(message)

    centered = flat - _unsqueeze_ft(mean)
    if self.affine:
        # Fold inv_std into the weight so only one multiply is needed.
        output = centered * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
    else:
        output = centered * _unsqueeze_ft(inv_std)
    return output.view(input_shape)
Example #21
Source File: batchnorm.py From kaggle-understanding-clouds with BSD 2-Clause "Simplified" License | 5 votes |
def forward(self, input):
    """Synchronized batch-norm forward pass.

    Uses ``F.batch_norm`` unless the module is both parallel and training.
    In that case each replica's sum and square-sum are reduced through the
    master/slave pipe into a global mean and inverse std, which are then
    applied directly.
    """
    # Not parallel, or evaluating: PyTorch's implementation is enough.
    if not self._is_parallel or not self.training:
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    input_shape = input.size()
    flat = input.view(input.size(0), self.num_features, -1)  # (B, C, -1)

    sum_size = flat.size(0) * flat.size(2)
    message = _ChildMessage(_sum_ft(flat), _sum_ft(flat ** 2), sum_size)
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(message)
    else:
        mean, inv_std = self._slave_pipe.run_slave(message)

    centered = flat - _unsqueeze_ft(mean)
    if self.affine:
        # Fold inv_std into the weight so only one multiply is needed.
        output = centered * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
    else:
        output = centered * _unsqueeze_ft(inv_std)
    return output.view(input_shape)
Example #22
Source File: vnet.py From elektronn3 with MIT License | 5 votes |
def forward(self, input):
    """Batch norm that always normalises with the current batch statistics.

    ``training=True`` is passed unconditionally, so the module's train/eval
    mode is ignored. The running statistics (when provided) keep being
    updated with ``self.momentum`` on every call.
    """
    self._check_input_dim(input)
    always_use_batch_stats = True
    return F.batch_norm(input, self.running_mean, self.running_var,
                        self.weight, self.bias, always_use_batch_stats,
                        self.momentum, self.eps)
Example #23
Source File: batchnorm.py From RMI with MIT License | 5 votes |
def forward(self, input):
    """Synchronized batch-norm forward pass.

    Uses ``F.batch_norm`` unless the module is both parallel and training.
    In that case each replica's sum and square-sum are reduced through the
    master/slave pipe into a global mean and inverse std, which are then
    applied directly.
    """
    # Not parallel, or evaluating: PyTorch's implementation is enough.
    if not self._is_parallel or not self.training:
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    input_shape = input.size()
    flat = input.view(input.size(0), self.num_features, -1)  # (B, C, -1)

    sum_size = flat.size(0) * flat.size(2)
    message = _ChildMessage(_sum_ft(flat), _sum_ft(flat ** 2), sum_size)
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(message)
    else:
        mean, inv_std = self._slave_pipe.run_slave(message)

    centered = flat - _unsqueeze_ft(mean)
    if self.affine:
        # Fold inv_std into the weight so only one multiply is needed.
        output = centered * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
    else:
        output = centered * _unsqueeze_ft(inv_std)
    return output.view(input_shape)
Example #24
Source File: batchnorm.py From pytorch_segmentation with MIT License | 5 votes |
def forward(self, input):
    """Synchronized batch-norm forward pass.

    Uses ``F.batch_norm`` unless the module is both parallel and training.
    In that case each replica's sum and square-sum are reduced through the
    master/slave pipe into a global mean and inverse std, which are then
    applied directly.
    """
    # Not parallel, or evaluating: PyTorch's implementation is enough.
    if not self._is_parallel or not self.training:
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    input_shape = input.size()
    flat = input.view(input.size(0), self.num_features, -1)  # (B, C, -1)

    sum_size = flat.size(0) * flat.size(2)
    message = _ChildMessage(_sum_ft(flat), _sum_ft(flat ** 2), sum_size)
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(message)
    else:
        mean, inv_std = self._slave_pipe.run_slave(message)

    centered = flat - _unsqueeze_ft(mean)
    if self.affine:
        # Fold inv_std into the weight so only one multiply is needed.
        output = centered * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
    else:
        output = centered * _unsqueeze_ft(inv_std)
    return output.view(input_shape)
Example #25
Source File: batchnorm.py From BraTS-DMFNet with Apache License 2.0 | 5 votes |
def forward(self, input):
    """Synchronized batch-norm forward pass.

    Uses ``F.batch_norm`` unless the module is both parallel and training.
    In that case each replica's sum and square-sum are reduced through the
    master/slave pipe into a global mean and inverse std, which are then
    applied directly.
    """
    # Not parallel, or evaluating: PyTorch's implementation is enough.
    if not self._is_parallel or not self.training:
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    input_shape = input.size()
    flat = input.view(input.size(0), self.num_features, -1)  # (B, C, -1)

    sum_size = flat.size(0) * flat.size(2)
    message = _ChildMessage(_sum_ft(flat), _sum_ft(flat ** 2), sum_size)
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(message)
    else:
        mean, inv_std = self._slave_pipe.run_slave(message)

    centered = flat - _unsqueeze_ft(mean)
    if self.affine:
        # Fold inv_std into the weight so only one multiply is needed.
        output = centered * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
    else:
        output = centered * _unsqueeze_ft(inv_std)
    return output.view(input_shape)
Example #26
Source File: recurrent_BatchNorm.py From Recognizing-Textual-Entailment with MIT License | 5 votes |
def forward(self, input_, index):
    """Batch norm with separate running statistics for each time step.

    The statistics are read from ``running_mean_{index}`` /
    ``running_var_{index}``. Steps at or beyond ``self.max_len`` share the
    buffers of step ``max_len - 1``.
    """
    index = self.max_len - 1 if index >= self.max_len else index
    self._check_input_dim(input_, index)
    step_mean = getattr(self, 'running_mean_{}'.format(index))
    step_var = getattr(self, 'running_var_{}'.format(index))
    return F.batch_norm(input_, step_mean, step_var, self.weight, self.bias,
                        self.training, self.momentum, self.eps)
Example #27
Source File: layers.py From SO-Net with MIT License | 5 votes |
def forward(self, input, epoch=None):
    """Batch norm whose momentum decays step-wise with the training epoch.

    From epoch 1 onward, and when ``momentum_decay_step`` is a positive
    number, ``self.momentum`` is set to
    ``momentum_original * momentum_decay ** (epoch // momentum_decay_step)``,
    floored at 0.01. The updated momentum is then used for this call.
    """
    decay_enabled = (epoch is not None and epoch >= 1
                     and self.momentum_decay_step is not None
                     and self.momentum_decay_step > 0)
    if decay_enabled:
        decayed = self.momentum_original * (
            self.momentum_decay ** (epoch // self.momentum_decay_step))
        self.momentum = 0.01 if decayed < 0.01 else decayed
    return F.batch_norm(input, self.running_mean, self.running_var,
                        self.weight, self.bias, self.training,
                        self.momentum, self.eps)
Example #28
Source File: layers.py From SO-Net with MIT License | 5 votes |
def forward(self, input, epoch=None):
    """Batch norm whose momentum decays step-wise with the training epoch.

    From epoch 1 onward, and when ``momentum_decay_step`` is a positive
    number, ``self.momentum`` is set to
    ``momentum_original * momentum_decay ** (epoch // momentum_decay_step)``,
    floored at 0.01. The updated momentum is then used for this call.
    """
    decay_enabled = (epoch is not None and epoch >= 1
                     and self.momentum_decay_step is not None
                     and self.momentum_decay_step > 0)
    if decay_enabled:
        decayed = self.momentum_original * (
            self.momentum_decay ** (epoch // self.momentum_decay_step))
        self.momentum = 0.01 if decayed < 0.01 else decayed
    return F.batch_norm(input, self.running_mean, self.running_var,
                        self.weight, self.bias, self.training,
                        self.momentum, self.eps)
Example #29
Source File: cbbn.py From SingleGAN with MIT License | 5 votes |
def forward(self, input, ConInfor):
    """Conditional batch norm that swaps a pooled bias for a condition bias.

    The input is batch-normalised without an affine term. Then
    ``self.avgpool`` of that result is subtracted, and
    ``self.ConBias(ConInfor)`` viewed as (b, c, 1, 1) is added. With
    ``affine`` set, the learned per-channel weight and bias are applied last.
    """
    self._check_input_dim(input)
    b, c = input.size(0), input.size(1)

    factor = 0.0
    if self.training and self.track_running_stats:
        self.num_batches_tracked += 1
        if self.momentum is not None:
            factor = self.momentum  # exponential moving average
        else:
            factor = 1.0 / self.num_batches_tracked.item()  # cumulative average

    out = F.batch_norm(input, self.running_mean, self.running_var, None, None,
                       self.training or not self.track_running_stats,
                       factor, self.eps)
    shifted = out - self.avgpool(out) + self.ConBias(ConInfor).view(b, c, 1, 1)
    if not self.affine:
        return shifted

    weight = self.weight.repeat(b).view(b, c, 1, 1)
    bias = self.bias.repeat(b).view(b, c, 1, 1)
    return shifted * weight + bias
Example #30
Source File: batchnorm.py From bonnetal with MIT License | 5 votes |
def forward(self, input):
    """Synchronized batch-norm forward pass.

    Uses ``F.batch_norm`` unless the module is both parallel and training.
    In that case each replica's sum and square-sum are reduced through the
    master/slave pipe into a global mean and inverse std, which are then
    applied directly.
    """
    # Not parallel, or evaluating: PyTorch's implementation is enough.
    if not self._is_parallel or not self.training:
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    input_shape = input.size()
    flat = input.view(input.size(0), self.num_features, -1)  # (B, C, -1)

    sum_size = flat.size(0) * flat.size(2)
    message = _ChildMessage(_sum_ft(flat), _sum_ft(flat ** 2), sum_size)
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(message)
    else:
        mean, inv_std = self._slave_pipe.run_slave(message)

    centered = flat - _unsqueeze_ft(mean)
    if self.affine:
        # Fold inv_std into the weight so only one multiply is needed.
        output = centered * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
    else:
        output = centered * _unsqueeze_ft(inv_std)
    return output.view(input_shape)