Python torch.nn.functional.threshold() Examples

The following are 24 code examples of torch.nn.functional.threshold(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.nn.functional, or try the search function.
Example #1
Source File: pytorch_to_caffe.py    From PytorchToCaffe with MIT License 6 votes vote down vote up
def _threshold(raw, input, threshold, value, inplace=False):
    """Run ``F.threshold`` through ``raw`` and record an equivalent Caffe layer.

    ``raw`` is the original (un-patched) threshold function; its result is
    returned after a matching layer has been appended to ``log.cnet``:

    * ``threshold == 0 and value == 0`` is exactly ReLU -> Caffe ``ReLU``;
    * ``value == 0`` with any other ``threshold`` -> Caffe ``Threshold`` with
      ``threshold_param.threshold`` set;
    * ``value != 0`` has no mapping here and is rejected before ``raw`` runs.

    Args:
        raw: original threshold function, called as
            ``raw(input, threshold, value, inplace)``.
        input: input tensor whose blob is already known to ``log``.
        threshold, value, inplace: forwarded to ``raw``.

    Returns:
        The tensor produced by ``raw``.

    Raises:
        NotImplementedError: if ``value != 0``.
    """
    # threshold(x, 0, 0) is ReLU
    if threshold == 0 and value == 0:
        x = raw(input, threshold, value, inplace)
        bottom_blobs = [log.blobs(input)]
        name = log.add_layer(name='relu')
        log.add_blobs([x], name='relu_blob')
        layer = caffe_net.Layer_param(name=name, type='ReLU',
                                      bottom=bottom_blobs, top=[log.blobs(x)])
        log.cnet.add_layer(layer)
        return x
    if value != 0:
        # NotImplementedError (the exception), not NotImplemented (a constant
        # that is not callable and would surface as a TypeError).
        raise NotImplementedError("value !=0 not implemented in caffe")
    # Same call shape as the ReLU branch: `input` is passed exactly once.
    x = raw(input, threshold, value, inplace)
    bottom_blobs = [log.blobs(input)]
    layer_name = log.add_layer(name='threshold')
    top_blobs = log.add_blobs([x], name='threshold_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Threshold',
                                  bottom=bottom_blobs, top=top_blobs)
    # NOTE(review): Caffe's Threshold layer is documented as a binary step
    # (1 if x > threshold else 0), whereas F.threshold passes x through above
    # the threshold -- confirm the exported model matches PyTorch's output.
    layer.param.threshold_param.threshold = threshold
    log.cnet.add_layer(layer)
    return x
Example #2
Source File: pytorch_to_caffe.py    From PytorchToCaffe with MIT License 6 votes vote down vote up
def _relu(raw, input, inplace=False):
    """Run ``F.relu`` through ``raw`` and record a Caffe ``ReLU`` layer.

    ``inplace`` is accepted for signature compatibility but ignored: ``raw``
    is always called with ``False``.  NOTE(review): presumably so the output
    is a tensor distinct from ``input`` for blob tracking -- confirm.

    Returns:
        The tensor produced by ``raw``.
    """
    out = raw(input, False)
    layer_name = log.add_layer(name='relu')
    log.add_blobs([out], name='relu_blob')
    src = [log.blobs(input)]
    dst = [log.blobs(out)]
    relu_layer = caffe_net.Layer_param(name=layer_name, type='ReLU',
                                       bottom=src, top=dst)
    log.cnet.add_layer(relu_layer)
    return out
Example #3
Source File: pytorch_to_caffe.py    From fast-reid with Apache License 2.0 6 votes vote down vote up
def _threshold(raw, input, threshold, value, inplace=False):
    """Run ``F.threshold`` through ``raw`` and record an equivalent Caffe layer.

    ``raw`` is the original (un-patched) threshold function; its result is
    returned after a matching layer has been appended to ``log.cnet``:

    * ``threshold == 0 and value == 0`` is exactly ReLU -> Caffe ``ReLU``;
    * ``value == 0`` with any other ``threshold`` -> Caffe ``Threshold`` with
      ``threshold_param.threshold`` set;
    * ``value != 0`` has no mapping here and is rejected before ``raw`` runs.

    Args:
        raw: original threshold function, called as
            ``raw(input, threshold, value, inplace)``.
        input: input tensor whose blob is already known to ``log``.
        threshold, value, inplace: forwarded to ``raw``.

    Returns:
        The tensor produced by ``raw``.

    Raises:
        NotImplementedError: if ``value != 0``.
    """
    # threshold(x, 0, 0) is ReLU
    if threshold == 0 and value == 0:
        x = raw(input, threshold, value, inplace)
        bottom_blobs = [log.blobs(input)]
        name = log.add_layer(name='relu')
        log.add_blobs([x], name='relu_blob')
        layer = caffe_net.Layer_param(name=name, type='ReLU',
                                      bottom=bottom_blobs, top=[log.blobs(x)])
        log.cnet.add_layer(layer)
        return x
    if value != 0:
        # NotImplementedError (the exception), not NotImplemented (a constant
        # that is not callable and would surface as a TypeError).
        raise NotImplementedError("value !=0 not implemented in caffe")
    # Same call shape as the ReLU branch: `input` is passed exactly once.
    x = raw(input, threshold, value, inplace)
    bottom_blobs = [log.blobs(input)]
    layer_name = log.add_layer(name='threshold')
    top_blobs = log.add_blobs([x], name='threshold_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Threshold',
                                  bottom=bottom_blobs, top=top_blobs)
    # NOTE(review): Caffe's Threshold layer is documented as a binary step
    # (1 if x > threshold else 0), whereas F.threshold passes x through above
    # the threshold -- confirm the exported model matches PyTorch's output.
    layer.param.threshold_param.threshold = threshold
    log.cnet.add_layer(layer)
    return x
Example #4
Source File: pytorch_to_caffe.py    From fast-reid with Apache License 2.0 5 votes vote down vote up
def _relu(raw, input, inplace=False):
    """Trace ``F.relu``: call ``raw`` and append a Caffe ``ReLU`` layer.

    The ``inplace`` flag is ignored -- ``raw`` always receives ``False``
    (NOTE(review): presumably to keep output and input as separate blobs;
    confirm).  The tensor returned by ``raw`` is passed back to the caller.
    """
    result = raw(input, False)
    relu_name = log.add_layer(name='relu')
    log.add_blobs([result], name='relu_blob')
    bottom = [log.blobs(input)]
    top = [log.blobs(result)]
    log.cnet.add_layer(
        caffe_net.Layer_param(name=relu_name, type='ReLU', bottom=bottom, top=top)
    )
    return result
Example #5
Source File: threshold.py    From onnx2keras with MIT License 5 votes vote down vote up
def forward(self, x):
    """Apply ``F.threshold`` using this module's ``threshold`` and ``value``.

    Elements ``<= self.threshold`` are replaced by ``self.value``; the rest
    pass through unchanged.
    """
    from torch.nn import functional as F
    cutoff, fill = self.threshold, self.value
    return F.threshold(x, cutoff, fill)
Example #6
Source File: threshold.py    From onnx2keras with MIT License 5 votes vote down vote up
def __init__(self):
    super(FThresholdTest, self).__init__()
    # Draw a threshold in [0, 1), then add a second draw in [0, 1) to get the
    # replacement value, so value is never below threshold.
    base = random.random()
    self.threshold = base
    self.value = base + random.random()
Example #7
Source File: threshold.py    From onnx2keras with MIT License 5 votes vote down vote up
def __init__(self):
    super(LayerThresholdTest, self).__init__()
    # threshold in [0, 1); value = threshold + another draw in [0, 1)
    lo = random.random()
    hi = lo + random.random()
    self.threshold, self.value = lo, hi
    # the nn.Threshold module under test, built from the same parameters
    self.thresh = nn.Threshold(self.threshold, self.value)
Example #8
Source File: MobilenetV3.py    From DBNet.pytorch with Apache License 2.0 5 votes vote down vote up
def forward(self, x):
    """Hard sigmoid: ``slope * x + offset`` limited to ``[0, 1]``.

    ``F.threshold(t, th, v)`` keeps ``t`` where ``t > th`` and writes ``v``
    elsewhere.  Applying it to the negated value with ``(-1, -1)`` and
    negating back caps the result at 1; a second call with ``(0, 0)`` floors
    it at 0.
    """
    y = self.slope * x + self.offset
    capped = -F.threshold(-y, -1, -1)
    return F.threshold(capped, 0, 0)
Example #9
Source File: layers.py    From modular-metalearning with MIT License 5 votes vote down vote up
def relu(input):
    """ReLU via ``F.threshold``, applied in place.

    Elements ``<= 0`` of ``input`` are overwritten with 0 (``inplace=True``).
    """
    return F.threshold(input, threshold=0, value=0, inplace=True)
Example #10
Source File: layers.py    From modular-metalearning with MIT License 5 votes vote down vote up
def relu(input):
    """In-place ReLU: zero every element of ``input`` that is ``<= 0``."""
    # positional form: (tensor, threshold, replacement value, inplace)
    return F.threshold(input, 0, 0, True)
Example #11
Source File: test_pyprof_nvtx.py    From apex with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_softplus(self):
    """Run ``F.softplus`` (beta=1, threshold=20) on a random CUDA tensor.

    The result is not inspected; the test only executes the op.
    """
    sample = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype)
    F.softplus(sample, beta=1, threshold=20)
Example #12
Source File: test_pyprof_nvtx.py    From apex with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_threshold(self):
    """Run out-of-place ``F.threshold`` (threshold 6, value 6) on a random CUDA tensor.

    The result is not inspected; the test only executes the op.
    """
    sample = torch.randn(1, 8, 32, 32, device='cuda', dtype=self.dtype)
    F.threshold(sample, 6, 6, inplace=False)
Example #13
Source File: pytorch_to_caffe.py    From fast-reid with Apache License 2.0 5 votes vote down vote up
def _prelu(raw, input, weight):
    """Run ``F.prelu`` through ``raw`` and record a Caffe ``PReLU`` layer.

    When ``weight``'s first dimension has size 1 the layer is marked
    ``channel_shared`` and only ``weight[0]`` is stored; otherwise the whole
    weight array is stored.

    Returns:
        The tensor produced by ``raw``.
    """
    out = raw(input, weight)
    src = [log.blobs(input)]
    layer_name = log.add_layer(name='prelu')
    log.add_blobs([out], name='prelu_blob')
    prelu_layer = caffe_net.Layer_param(name=layer_name, type='PReLU',
                                        bottom=src, top=[log.blobs(out)])
    slopes = weight.cpu().data.numpy()
    if weight.size()[0] == 1:
        prelu_layer.param.prelu_param.channel_shared = True
        slopes = slopes[0]
    prelu_layer.add_data(slopes)
    log.cnet.add_layer(prelu_layer)
    return out
Example #14
Source File: lstm_hard_sigmoid.py    From SemEval2019Task3 with MIT License 5 votes vote down vote up
def hard_sigmoid(x):
    """
    Element-wise hard sigmoid: ``0.2 * x + 0.5`` limited to ``[0, 1]``.

    Built from two ``F.threshold`` calls (``F.threshold(t, th, v)`` keeps
    ``t`` where ``t > th`` and writes ``v`` elsewhere): negating, flooring
    at -1 and negating back caps the value at 1; a threshold at 0 then
    floors it at 0.
    See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279
    """
    linear = 0.2 * x + 0.5
    upper_capped = -F.threshold(-linear, -1, -1)
    return F.threshold(upper_capped, 0, 0)
Example #15
Source File: lstm.py    From neural_chat with MIT License 5 votes vote down vote up
def hard_sigmoid(x):
    """
    Compute the hard sigmoid of ``x`` element-wise.

    The linear ramp ``0.2 * x + 0.5`` is capped at 1 and floored at 0 using
    ``F.threshold`` twice.
    See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279
    """
    # -threshold(-t, -1, -1) caps t at 1
    capped = -F.threshold(-(0.2 * x + 0.5), -1, -1)
    # threshold(t, 0, 0) floors t at 0
    return F.threshold(capped, 0, 0)
Example #16
Source File: basic_batch.py    From landmark-detection with MIT License 5 votes vote down vote up
def find_tensor_peak_batch(heatmap, radius, downsample, threshold = 0.000001):
  """Locate the sub-pixel peak of every channel of ``heatmap``.

  The arg-max of each channel gives a coarse integer location.  A
  ``(2*radius+1) x (2*radius+1)`` window centred on it is resampled with
  ``F.affine_grid``/``F.grid_sample`` (zero padding outside the heatmap),
  window values ``<= threshold`` are replaced by a tiny epsilon, and the
  value-weighted centroid of the window refines the location.  Finally each
  coordinate ``c`` is mapped to ``c * downsample + downsample / 2 - 0.5``.

  Args:
    heatmap: ``(num_pts, H, W)`` tensor with ``H > 1`` and ``W > 1``.
    radius: window half-size in heatmap pixels; it is used to build a
      ``torch.Size``, so it must be an integer.
    downsample: scale factor applied to the refined heatmap coordinates.
    threshold: window values ``<= threshold`` are replaced by epsilon.

  Returns:
    ``(coords, score)``: ``coords`` is ``(num_pts, 2)`` holding ``(x, y)``
    per channel and ``score`` is each channel's maximum value.
  """
  assert heatmap.dim() == 3, 'The dimension of the heatmap is wrong : {}'.format(heatmap.size())
  assert radius > 0 and isinstance(radius, numbers.Number), 'The radius is not ok : {}'.format(radius)
  num_pts, H, W = heatmap.size(0), heatmap.size(1), heatmap.size(2)
  assert W > 1 and H > 1, 'To avoid the normalization function divide zero'
  # find the approximate location: flat arg-max of each channel
  score, index = torch.max(heatmap.view(num_pts, -1), 1)
  index_w = (index % W).float()
  # Floor division: on integer tensors recent PyTorch makes `/` a true
  # division, which would turn the row index into a fraction.
  index_h = (index // W).float()

  def normalize(x, L):
    # pixel coordinate -> [-1, 1] with pixel 0 at -1 and pixel L-1 at +1,
    # i.e. the align_corners=True convention used for sampling below
    return -1. + 2. * x.data / (L-1)
  boxes = [index_w - radius, index_h - radius, index_w + radius, index_h + radius]
  boxes[0] = normalize(boxes[0], W)
  boxes[1] = normalize(boxes[1], H)
  boxes[2] = normalize(boxes[2], W)
  boxes[3] = normalize(boxes[3], H)

  # affine map from the sampling grid's [-1, 1] square onto each window
  affine_parameter = torch.zeros((num_pts, 2, 3))
  affine_parameter[:,0,0] = (boxes[2]-boxes[0])/2
  affine_parameter[:,0,2] = (boxes[2]+boxes[0])/2
  affine_parameter[:,1,1] = (boxes[3]-boxes[1])/2
  affine_parameter[:,1,2] = (boxes[3]+boxes[1])/2
  # extract the sub-region heatmap (same device and dtype as the input)
  theta = affine_parameter.to(heatmap)
  grid_size = torch.Size([num_pts, 1, radius*2+1, radius*2+1])
  # align_corners=True must match normalize(); since PyTorch 1.3 the default
  # is False, which would sample the window off-centre.
  grid = F.affine_grid(theta, grid_size, align_corners=True)
  sub_feature = F.grid_sample(heatmap.unsqueeze(1), grid, align_corners=True).squeeze(1)
  # values <= threshold -> eps, so (for threshold >= 0) every entry is
  # positive and sum_region below cannot be zero
  sub_feature = F.threshold(sub_feature, threshold, np.finfo(float).eps)

  X = torch.arange(-radius, radius+1).to(heatmap).view(1, 1, radius*2+1)
  Y = torch.arange(-radius, radius+1).to(heatmap).view(1, radius*2+1, 1)

  sum_region = torch.sum(sub_feature.view(num_pts,-1),1)
  # value-weighted centroid offset inside the window, added to the coarse peak
  x = torch.sum((sub_feature*X).view(num_pts,-1),1) / sum_region + index_w
  y = torch.sum((sub_feature*Y).view(num_pts,-1),1) / sum_region + index_h

  x = x * downsample + downsample / 2.0 - 0.5
  y = y * downsample + downsample / 2.0 - 0.5
  return torch.stack([x, y],1), score
Example #17
Source File: ops.py    From MLDG with MIT License 5 votes vote down vote up
def relu(inputs):
    """Rectify ``inputs`` in place: elements ``<= 0`` become 0.

    Implemented with ``F.threshold(..., inplace=True)``, so the caller's
    tensor is modified.
    """
    return F.threshold(inputs, threshold=0, value=0, inplace=True)
Example #18
Source File: pytorch_to_caffe.py    From PytorchToCaffe with MIT License 5 votes vote down vote up
def _prelu(raw, input, weight):
    """Trace ``F.prelu``: call ``raw`` and append a Caffe ``PReLU`` layer.

    A ``weight`` whose first dimension has size 1 becomes a channel-shared
    PReLU storing ``weight[0]``; any other ``weight`` is stored as-is.
    The tensor returned by ``raw`` is passed back to the caller.
    """
    result = raw(input, weight)
    bottom = [log.blobs(input)]
    prelu_name = log.add_layer(name='prelu')
    log.add_blobs([result], name='prelu_blob')
    layer = caffe_net.Layer_param(name=prelu_name, type='PReLU',
                                  bottom=bottom, top=[log.blobs(result)])
    weights_np = weight.cpu().data.numpy()
    shared = weight.size()[0] == 1
    if shared:
        layer.param.prelu_param.channel_shared = True
    layer.add_data(weights_np[0] if shared else weights_np)
    log.cnet.add_layer(layer)
    return result
Example #19
Source File: transformer.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def forward(self, x, encoder_padding_mask):
    """Apply one encoder layer: self-attention, then a ReLU feed-forward net.

    Both sub-blocks follow the same pattern: ``maybe_layer_norm(..., before=True)``,
    the sub-block itself, dropout plus residual add, then
    ``maybe_layer_norm(..., after=True)``.  The dropout/add step uses
    ``fused_dropout_add`` when ``self.fuse_dropout_add`` is set and the module
    is training; the ReLU/dropout step uses ``fused_relu_dropout`` when
    ``self.fuse_relu_dropout`` is set.

    Args:
        x: layer input.
        encoder_padding_mask: forwarded to ``self_attn`` as ``key_padding_mask``.

    Returns:
        The layer output.
    """
    def add_residual(out, shortcut):
        # dropout(out) + shortcut, using the fused kernel when allowed
        if self.fuse_dropout_add and self.training:
            return fused_dropout_add(out, shortcut, self.dropout)
        return shortcut + F.dropout(out, p=self.dropout, training=self.training)

    # self-attention sub-block (norm index 0)
    shortcut = x
    h = self.maybe_layer_norm(0, x, before=True)
    h, _ = self.self_attn(query=h, key=h, value=h,
                          key_padding_mask=encoder_padding_mask)
    h = self.maybe_layer_norm(0, add_residual(h, shortcut), after=True)

    # feed-forward sub-block (norm index 1)
    shortcut = h
    h = self.maybe_layer_norm(1, h, before=True)
    if self.fuse_relu_dropout:
        h = fused_relu_dropout(self.fc1(h), self.relu_dropout)
    else:
        # F.threshold(t, 0, 0) is ReLU
        h = F.dropout(F.threshold(self.fc1(h), 0, 0),
                      p=self.relu_dropout, training=self.training)
    h = self.fc2(h)
    return self.maybe_layer_norm(1, add_residual(h, shortcut), after=True)
Example #20
Source File: pixelcnn_loss.py    From ssl_bad_gan with MIT License 4 votes vote down vote up
def discretized_mix_logistic_loss(x, l, sum_all=True):
    """Negative log-likelihood of ``x`` under a discretized logistic mixture.

    Three-channel variant: the mean of channel 1 is shifted by a coefficient
    times channel 0, and the mean of channel 2 by coefficients times channels
    0 and 1, with the (tanh-squashed) coefficients predicted in ``l``.

    Args:
        x: targets, channels-last ``(B, H, W, C)`` with ``C == 3`` (channels
            0-2 are indexed explicitly below).  The edge-case masks
            (``x > 0.999`` / ``x < -0.999``) and the ``1/255`` half-bin width
            suggest values rescaled to ``[-1, 1]`` -- confirm with the caller.
        l: model output, ``(B, H, W, 10 * nr_mix)``: ``nr_mix`` mixture
            logits followed by per-channel means, log-scales and coefficients.
        sum_all: if true, return one scalar summed over everything; otherwise
            sum over the two spatial axes only.

    Returns:
        The negative log-likelihood.  Mixture weights come from
        ``log_prob_from_logits`` and components are combined by
        ``log_sum_exp`` (both defined elsewhere).
    """
    xs = x.size()    # (B,32,32,C)
    ls = l.size()    # (B,32,32,100)

    # here and below: unpacking the params of the mixture of logistics
    # per component: 1 logit + 3 channels x (mean, log-scale, coeff)
    nr_mix = int(ls[-1] / 10)    # 10 when ls[-1] == 100

    logit_probs = l[:, :, :, :nr_mix]  # size: [B, 32, 32, nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs[0], xs[1], xs[2], xs[3], -1)  # size: [B, 32, 32, C, 9 * nr_mix / C]

    # size: [B, 32, 32, C, nr_mix]
    means = l[:, :, :, :, :nr_mix]
    # floor the log-scales at -7 (F.threshold keeps values > -7, else -7)
    log_scales = F.threshold(l[:, :, :, :, nr_mix:2 * nr_mix], -7., -7.)
    coeffs = torch.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix])

    # here and below: getting the means and adjusting them based on preceding
    # sub-pixels
    x = x.unsqueeze(4).expand(xs[0], xs[1], xs[2], xs[3], nr_mix)  # size: [B, 32, 32, C, nr_mix]

    m1 = means[:, :, :, 0, :]
    m2 = means[:, :, :, 1, :] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :]
    m3 = means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :]
    # Each m_i is [B, 32, 32, nr_mix]; stacking on dim 3 gives
    # [B, 32, 32, 3, nr_mix] like x.  Concatenating instead would give
    # [B, 32, 32, 3 * nr_mix], which does not broadcast against x.
    means = torch.stack([m1, m2, m3], 3)

    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1. / 255.)
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_x - 1. / 255.)
    cdf_min = torch.sigmoid(min_in)

    # log probability for edge case of 0 (before scaling)
    log_cdf_plus = plus_in - F.softplus(plus_in)
    # log probability for edge case of 255 (before scaling)
    log_one_minus_cdf_min = -F.softplus(min_in)
    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_x
    # log probability in the center of the bin, to be used in extreme cases
    # (not actually used in our code)
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # now select the right output: left edge case, right edge case, normal
    # case, extremely low prob case (doesn't actually happen for us)

    mask1 = (cdf_delta > 1e-5).float().detach()
    # cdf_delta is floored at 1e-12 before the log
    term1 = mask1 * torch.log(F.threshold(cdf_delta, 1e-12, 1e-12)) + (1. - mask1) * (log_pdf_mid - np.log(127.5))

    mask2 = (x > 0.999).float().detach()
    term2 = mask2 * log_one_minus_cdf_min + (1. - mask2) * term1

    mask3 = (x < -0.999).float().detach()
    term3 = mask3 * log_cdf_plus + (1. - mask3) * term2

    log_probs = term3.sum(3) + log_prob_from_logits(logit_probs)

    if not sum_all:
        # Reduce spatial axis 2 before axis 1 so both get summed whether or
        # not log_sum_exp keeps its reduced axis; .sum(1).sum(2) indexes a
        # dimension that no longer exists when it does not.
        return -log_sum_exp(log_probs).sum(2).sum(1).squeeze()
    else:
        return -log_sum_exp(log_probs).sum()
Example #21
Source File: pixelcnn_loss.py    From ssl_bad_gan with MIT License 4 votes vote down vote up
def discretized_mix_logistic_loss_c1(x, l, sum_all=True):
    """Single-channel negative log-likelihood under a discretized logistic mixture.

    Unlike the three-channel variant there are no inter-channel coefficients:
    each mixture component only has a mean and a log-scale.

    Args:
        x: targets, channels-last ``(B, H, W, 1)``; the ``view`` below needs
            the channel axis to be 1.  The edge-case masks
            (``x > 0.999`` / ``x < -0.999``) and the ``1/255`` half-bin width
            suggest values rescaled to ``[-1, 1]`` -- confirm with the caller.
        l: model output, ``(B, H, W, 3 * nr_mix)``: ``nr_mix`` mixture logits,
            then ``nr_mix`` means and ``nr_mix`` log-scales.
        sum_all: if true, return one scalar summed over everything; otherwise
            sum over the two spatial axes only.

    Returns:
        The negative log-likelihood.  Mixture weights come from
        ``log_prob_from_logits`` and components are combined by
        ``log_sum_exp`` (both defined elsewhere).
    """
    xs = x.size()    # (B,32,32,1)
    ls = l.size()    # (B,32,32,3*nr_mix)

    # here and below: unpacking the params of the mixture of logistics
    # per component: 1 logit + 1 mean + 1 log-scale
    nr_mix = int(ls[-1] / 3)

    logit_probs = l[:, :, :, :nr_mix]  # size: [B, 32, 32, nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs[0], xs[1], xs[2], xs[3], nr_mix * 2)  # size: [B, 32, 32, 1, 2 * nr_mix]

    # size: [B, 32, 32, 1, nr_mix]
    means = l[:, :, :, :, :nr_mix]
    # floor the log-scales at -7 (F.threshold keeps values > -7, else -7)
    log_scales = F.threshold(l[:, :, :, :, nr_mix:2 * nr_mix], -7., -7.)

    # compare every pixel value against every mixture component
    x = x.unsqueeze(4).expand(xs[0], xs[1], xs[2], xs[3], nr_mix)  # size: [B, 32, 32, 1, nr_mix]

    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1. / 255.)
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_x - 1. / 255.)
    cdf_min = torch.sigmoid(min_in)

    # log probability for edge case of 0 (before scaling)
    log_cdf_plus = plus_in - F.softplus(plus_in)
    # log probability for edge case of 255 (before scaling)
    log_one_minus_cdf_min = -F.softplus(min_in)
    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_x
    # log probability in the center of the bin, to be used in extreme cases
    # (not actually used in our code)
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # now select the right output: left edge case, right edge case, normal
    # case, extremely low prob case (doesn't actually happen for us)

    mask1 = (cdf_delta > 1e-5).float().detach()
    # cdf_delta is floored at 1e-12 before the log
    term1 = mask1 * torch.log(F.threshold(cdf_delta, 1e-12, 1e-12)) + (1. - mask1) * (log_pdf_mid - np.log(127.5))

    mask2 = (x > 0.999).float().detach()
    term2 = mask2 * log_one_minus_cdf_min + (1. - mask2) * term1

    mask3 = (x < -0.999).float().detach()
    term3 = mask3 * log_cdf_plus + (1. - mask3) * term2

    log_probs = term3.sum(3) + log_prob_from_logits(logit_probs)

    if not sum_all:
        # Reduce spatial axis 2 before axis 1 so both get summed whether or
        # not log_sum_exp keeps its reduced axis; .sum(1).sum(2) indexes a
        # dimension that no longer exists when it does not.
        return -log_sum_exp(log_probs).sum(2).sum(1).squeeze()
    else:
        return -log_sum_exp(log_probs).sum()
Example #22
Source File: basic_batch.py    From landmark-detection with MIT License 4 votes vote down vote up
def find_tensor_peak_batch(heatmap, radius, downsample, threshold = 0.000001):
  """Locate the sub-pixel peak of every channel of ``heatmap``.

  The arg-max of each channel gives a coarse integer location.  A
  ``(2*radius+1) x (2*radius+1)`` window centred on it is resampled with
  ``F.affine_grid``/``F.grid_sample`` (zero padding outside the heatmap),
  window values ``<= threshold`` are replaced by a tiny epsilon, and the
  value-weighted centroid of the window refines the location.  Finally each
  coordinate ``c`` is mapped to ``c * downsample + downsample / 2 - 0.5``.

  Args:
    heatmap: ``(num_pts, H, W)`` tensor with ``H > 1`` and ``W > 1``.
    radius: window half-size in heatmap pixels; it is used to build a
      ``torch.Size``, so it must be an integer.
    downsample: scale factor applied to the refined heatmap coordinates.
    threshold: window values ``<= threshold`` are replaced by epsilon.

  Returns:
    ``(coords, score)``: ``coords`` is ``(num_pts, 2)`` holding ``(x, y)``
    per channel and ``score`` is each channel's maximum value.
  """
  assert heatmap.dim() == 3, 'The dimension of the heatmap is wrong : {}'.format(heatmap.size())
  assert radius > 0 and isinstance(radius, numbers.Number), 'The radius is not ok : {}'.format(radius)
  num_pts, H, W = heatmap.size(0), heatmap.size(1), heatmap.size(2)
  assert W > 1 and H > 1, 'To avoid the normalization function divide zero'
  # find the approximate location: flat arg-max of each channel
  score, index = torch.max(heatmap.view(num_pts, -1), 1)
  index_w = (index % W).float()
  # Floor division: on integer tensors recent PyTorch makes `/` a true
  # division, which would turn the row index into a fraction.
  index_h = (index // W).float()

  def normalize(x, L):
    # pixel coordinate -> [-1, 1] with pixel 0 at -1 and pixel L-1 at +1,
    # i.e. the align_corners=True convention used for sampling below
    return -1. + 2. * x.data / (L-1)
  boxes = [index_w - radius, index_h - radius, index_w + radius, index_h + radius]
  boxes[0] = normalize(boxes[0], W)
  boxes[1] = normalize(boxes[1], H)
  boxes[2] = normalize(boxes[2], W)
  boxes[3] = normalize(boxes[3], H)

  # affine map from the sampling grid's [-1, 1] square onto each window
  affine_parameter = torch.zeros((num_pts, 2, 3))
  affine_parameter[:,0,0] = (boxes[2]-boxes[0])/2
  affine_parameter[:,0,2] = (boxes[2]+boxes[0])/2
  affine_parameter[:,1,1] = (boxes[3]-boxes[1])/2
  affine_parameter[:,1,2] = (boxes[3]+boxes[1])/2
  # extract the sub-region heatmap (same device and dtype as the input)
  theta = affine_parameter.to(heatmap)
  grid_size = torch.Size([num_pts, 1, radius*2+1, radius*2+1])
  # align_corners=True must match normalize(); since PyTorch 1.3 the default
  # is False, which would sample the window off-centre.
  grid = F.affine_grid(theta, grid_size, align_corners=True)
  sub_feature = F.grid_sample(heatmap.unsqueeze(1), grid, align_corners=True).squeeze(1)
  # values <= threshold -> eps, so (for threshold >= 0) every entry is
  # positive and sum_region below cannot be zero
  sub_feature = F.threshold(sub_feature, threshold, np.finfo(float).eps)

  X = torch.arange(-radius, radius+1).to(heatmap).view(1, 1, radius*2+1)
  Y = torch.arange(-radius, radius+1).to(heatmap).view(1, radius*2+1, 1)

  sum_region = torch.sum(sub_feature.view(num_pts,-1),1)
  # value-weighted centroid offset inside the window, added to the coarse peak
  x = torch.sum((sub_feature*X).view(num_pts,-1),1) / sum_region + index_w
  y = torch.sum((sub_feature*Y).view(num_pts,-1),1) / sum_region + index_h

  x = x * downsample + downsample / 2.0 - 0.5
  y = y * downsample + downsample / 2.0 - 0.5
  return torch.stack([x, y],1), score
Example #23
Source File: basic_batch.py    From landmark-detection with MIT License 4 votes vote down vote up
def find_tensor_peak_batch(heatmap, radius, downsample, threshold = 0.000001):
  """Locate the sub-pixel peak of every channel of ``heatmap``.

  The arg-max of each channel gives a coarse integer location.  A
  ``(2*radius+1) x (2*radius+1)`` window centred on it is resampled with
  ``F.affine_grid``/``F.grid_sample`` (zero padding outside the heatmap),
  window values ``<= threshold`` are replaced by a tiny epsilon, and the
  value-weighted centroid of the window refines the location.  Finally each
  coordinate ``c`` is mapped to ``c * downsample + downsample / 2 - 0.5``.

  Args:
    heatmap: ``(num_pts, H, W)`` tensor with ``H > 1`` and ``W > 1``.
    radius: window half-size in heatmap pixels; it is used to build a
      ``torch.Size``, so it must be an integer.
    downsample: scale factor applied to the refined heatmap coordinates.
    threshold: window values ``<= threshold`` are replaced by epsilon.

  Returns:
    ``(coords, score)``: ``coords`` is ``(num_pts, 2)`` holding ``(x, y)``
    per channel and ``score`` is each channel's maximum value.
  """
  assert heatmap.dim() == 3, 'The dimension of the heatmap is wrong : {}'.format(heatmap.size())
  assert radius > 0 and isinstance(radius, numbers.Number), 'The radius is not ok : {}'.format(radius)
  num_pts, H, W = heatmap.size(0), heatmap.size(1), heatmap.size(2)
  assert W > 1 and H > 1, 'To avoid the normalization function divide zero'
  # find the approximate location: flat arg-max of each channel
  score, index = torch.max(heatmap.view(num_pts, -1), 1)
  index_w = (index % W).float()
  # Floor division: on integer tensors recent PyTorch makes `/` a true
  # division, which would turn the row index into a fraction.
  index_h = (index // W).float()

  def normalize(x, L):
    # pixel coordinate -> [-1, 1] with pixel 0 at -1 and pixel L-1 at +1,
    # i.e. the align_corners=True convention used for sampling below
    return -1. + 2. * x.data / (L-1)
  boxes = [index_w - radius, index_h - radius, index_w + radius, index_h + radius]
  boxes[0] = normalize(boxes[0], W)
  boxes[1] = normalize(boxes[1], H)
  boxes[2] = normalize(boxes[2], W)
  boxes[3] = normalize(boxes[3], H)

  # affine map from the sampling grid's [-1, 1] square onto each window
  affine_parameter = torch.zeros((num_pts, 2, 3))
  affine_parameter[:,0,0] = (boxes[2]-boxes[0])/2
  affine_parameter[:,0,2] = (boxes[2]+boxes[0])/2
  affine_parameter[:,1,1] = (boxes[3]-boxes[1])/2
  affine_parameter[:,1,2] = (boxes[3]+boxes[1])/2
  # extract the sub-region heatmap (same device and dtype as the input)
  theta = affine_parameter.to(heatmap)
  grid_size = torch.Size([num_pts, 1, radius*2+1, radius*2+1])
  # align_corners=True must match normalize(); since PyTorch 1.3 the default
  # is False, which would sample the window off-centre.
  grid = F.affine_grid(theta, grid_size, align_corners=True)
  sub_feature = F.grid_sample(heatmap.unsqueeze(1), grid, align_corners=True).squeeze(1)
  # values <= threshold -> eps, so (for threshold >= 0) every entry is
  # positive and sum_region below cannot be zero
  sub_feature = F.threshold(sub_feature, threshold, np.finfo(float).eps)

  X = torch.arange(-radius, radius+1).to(heatmap).view(1, 1, radius*2+1)
  Y = torch.arange(-radius, radius+1).to(heatmap).view(1, radius*2+1, 1)

  sum_region = torch.sum(sub_feature.view(num_pts,-1),1)
  # value-weighted centroid offset inside the window, added to the coarse peak
  x = torch.sum((sub_feature*X).view(num_pts,-1),1) / sum_region + index_w
  y = torch.sum((sub_feature*Y).view(num_pts,-1),1) / sum_region + index_h

  x = x * downsample + downsample / 2.0 - 0.5
  y = y * downsample + downsample / 2.0 - 0.5
  return torch.stack([x, y],1), score
Example #24
Source File: transformer.py    From training_results_v0.5 with Apache License 2.0 4 votes vote down vote up
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
    """Apply one decoder layer.

    Sub-blocks, in order: masked self-attention (``mask_future_timesteps=True``),
    encoder-decoder attention (only when ``self.encoder_attn`` is not None),
    and a ReLU feed-forward net.  Each is wrapped as
    ``maybe_layer_norm(..., before=True)`` -> sub-block -> dropout + residual
    add -> ``maybe_layer_norm(..., after=True)``.  The dropout/add step uses
    ``fused_dropout_add`` when ``self.fuse_dropout_add`` is set and the module
    is training; the ReLU/dropout step uses ``fused_relu_dropout`` when
    ``self.fuse_relu_dropout`` is set.

    Returns:
        ``(output, attn)`` where ``attn`` is whatever ``encoder_attn`` returned
        as its second value, or None when there is no encoder attention.
    """
    def add_residual(out, shortcut):
        # dropout(out) + shortcut, using the fused kernel when allowed
        if self.fuse_dropout_add and self.training:
            return fused_dropout_add(out, shortcut, self.dropout)
        return shortcut + F.dropout(out, p=self.dropout, training=self.training)

    # masked self-attention
    shortcut = x
    h = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
    h, _ = self.self_attn(
        query=h,
        key=h,
        value=h,
        mask_future_timesteps=True,
        incremental_state=incremental_state,
        need_weights=False,
    )
    h = self.maybe_layer_norm(self.self_attn_layer_norm, add_residual(h, shortcut), after=True)

    # encoder-decoder attention (optional)
    attn = None
    if self.encoder_attn is not None:
        shortcut = h
        h = self.maybe_layer_norm(self.encoder_attn_layer_norm, h, before=True)
        h, attn = self.encoder_attn(
            query=h,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=(not self.training and self.need_attn),
        )
        h = self.maybe_layer_norm(self.encoder_attn_layer_norm, add_residual(h, shortcut), after=True)

    # feed-forward; F.threshold(t, 0, 0) is ReLU
    shortcut = h
    h = self.maybe_layer_norm(self.final_layer_norm, h, before=True)
    if self.fuse_relu_dropout:
        h = fused_relu_dropout(self.fc1(h), self.relu_dropout)
    else:
        h = F.dropout(F.threshold(self.fc1(h), 0, 0),
                      p=self.relu_dropout, training=self.training)
    h = self.fc2(h)
    h = self.maybe_layer_norm(self.final_layer_norm, add_residual(h, shortcut), after=True)
    return h, attn