Python tensorflow.contrib.slim.nets.resnet_v2.resnet_arg_scope() Examples

The following are 16 code examples of tensorflow.contrib.slim.nets.resnet_v2.resnet_arg_scope(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.contrib.slim.nets.resnet_v2, or try the search function.
Example #1
Source File: resnet_v2.py    From tf-imagenet with Apache License 2.0 6 votes vote down vote up
def forward(self, inputs, num_classes, data_format, is_training):
        """Build the ResNet-v2-50 graph inside the ResNet arg scope.

        Args:
          inputs: Input forwarded unchanged to `resnet_v2_50`.
          num_classes: Forwarded as `num_classes` to `resnet_v2_50`.
          data_format: Forwarded to `resnet_arg_scope`.
          is_training: Forwarded to both `resnet_arg_scope` and
            `resnet_v2_50`.

        Returns:
          The `(logits, end_points)` pair produced by `resnet_v2_50`.
        """
        # All arg-scope hyper-parameters are fixed here; only the data format
        # and the training flag are taken from the caller.
        scope_options = dict(
            weight_decay=0.0001,
            data_format=data_format,
            batch_norm_decay=0.997,
            batch_norm_epsilon=1e-5,
            batch_norm_scale=True,
            activation_fn=tf.nn.relu,
            use_batch_norm=True,
            is_training=is_training,
        )
        resnet_scope = resnet_arg_scope(**scope_options)

        with slim.arg_scope(resnet_scope):
            logits, end_points = resnet_v2_50(
                inputs,
                num_classes=num_classes,
                is_training=is_training,
                global_pool=True,
                output_stride=None,
                reuse=None,
                scope=self.scope,
            )
        return logits, end_points
Example #2
Source File: resnet_v2.py    From tf-imagenet with Apache License 2.0 6 votes vote down vote up
def forward(self, inputs, num_classes, data_format, is_training):
        """Build the ResNet-v2-101 graph inside the ResNet arg scope.

        Args:
          inputs: Input forwarded unchanged to `resnet_v2_101`.
          num_classes: Forwarded as `num_classes` to `resnet_v2_101`.
          data_format: Forwarded to `resnet_arg_scope`.
          is_training: Forwarded to both `resnet_arg_scope` and
            `resnet_v2_101`.

        Returns:
          The `(logits, end_points)` pair produced by `resnet_v2_101`.
        """
        # All arg-scope hyper-parameters are fixed here; only the data format
        # and the training flag are taken from the caller.
        scope_options = dict(
            weight_decay=0.0001,
            data_format=data_format,
            batch_norm_decay=0.997,
            batch_norm_epsilon=1e-5,
            batch_norm_scale=True,
            activation_fn=tf.nn.relu,
            use_batch_norm=True,
            is_training=is_training,
        )
        resnet_scope = resnet_arg_scope(**scope_options)

        with slim.arg_scope(resnet_scope):
            logits, end_points = resnet_v2_101(
                inputs,
                num_classes=num_classes,
                is_training=is_training,
                global_pool=True,
                output_stride=None,
                reuse=None,
                scope=self.scope,
            )
        return logits, end_points
Example #3
Source File: resnet_v2.py    From tf-imagenet with Apache License 2.0 6 votes vote down vote up
def forward(self, inputs, num_classes, data_format, is_training):
        """Build the ResNet-v2-152 graph inside the ResNet arg scope.

        Args:
          inputs: Input forwarded unchanged to `resnet_v2_152`.
          num_classes: Forwarded as `num_classes` to `resnet_v2_152`.
          data_format: Forwarded to `resnet_arg_scope`.
          is_training: Forwarded to both `resnet_arg_scope` and
            `resnet_v2_152`.

        Returns:
          The `(logits, end_points)` pair produced by `resnet_v2_152`.
        """
        # All arg-scope hyper-parameters are fixed here; only the data format
        # and the training flag are taken from the caller.
        scope_options = dict(
            weight_decay=0.0001,
            data_format=data_format,
            batch_norm_decay=0.997,
            batch_norm_epsilon=1e-5,
            batch_norm_scale=True,
            activation_fn=tf.nn.relu,
            use_batch_norm=True,
            is_training=is_training,
        )
        resnet_scope = resnet_arg_scope(**scope_options)

        with slim.arg_scope(resnet_scope):
            logits, end_points = resnet_v2_152(
                inputs,
                num_classes=num_classes,
                is_training=is_training,
                global_pool=True,
                output_stride=None,
                reuse=None,
                scope=self.scope,
            )
        return logits, end_points

# =========================================================================== #
# Functional definition.
# =========================================================================== # 
Example #4
Source File: model_lib.py    From adversarial-logit-pairing-analysis with Apache License 2.0 6 votes vote down vote up
def get_model(model_name, num_classes):
  """Returns function which creates model.

  Args:
    model_name: Name of the model. Must start with 'resnet' and be a key of
      `RESNET_MODELS`.
    num_classes: Number of classes.

  Raises:
    ValueError: If model_name is invalid, i.e. it does not start with
      'resnet' or has no entry in `RESNET_MODELS`.

  Returns:
    Function, which creates model when called. It takes
    `(images, is_training, reuse=tf.AUTO_REUSE)` and returns the logits
    reshaped to `[-1, num_classes]`.
  """
  if not model_name.startswith('resnet'):
    raise ValueError('Invalid model: %s' % model_name)

  # Resolve the constructor eagerly so that an unknown 'resnet*' name is
  # reported here as the documented ValueError instead of surfacing later as
  # a KeyError when the returned function is first called.
  try:
    resnet_fn = RESNET_MODELS[model_name]
  except KeyError:
    raise ValueError('Invalid model: %s' % model_name)

  def resnet_model(images, is_training, reuse=tf.AUTO_REUSE):
    with tf.contrib.framework.arg_scope(resnet_v2.resnet_arg_scope()):
      logits, _ = resnet_fn(images, num_classes, is_training=is_training,
                            reuse=reuse)
      logits = tf.reshape(logits, [-1, num_classes])
    return logits

  return resnet_model
Example #5
Source File: model_lib.py    From g-tensorflow-models with Apache License 2.0 6 votes vote down vote up
def get_model(model_name, num_classes):
  """Returns function which creates model.

  Args:
    model_name: Name of the model. Must start with 'resnet' and be a key of
      `RESNET_MODELS`.
    num_classes: Number of classes.

  Raises:
    ValueError: If model_name is invalid, i.e. it does not start with
      'resnet' or has no entry in `RESNET_MODELS`.

  Returns:
    Function, which creates model when called. It takes
    `(images, is_training, reuse=tf.AUTO_REUSE)` and returns the logits
    reshaped to `[-1, num_classes]`.
  """
  if not model_name.startswith('resnet'):
    raise ValueError('Invalid model: %s' % model_name)

  # Resolve the constructor eagerly so that an unknown 'resnet*' name is
  # reported here as the documented ValueError instead of surfacing later as
  # a KeyError when the returned function is first called.
  try:
    resnet_fn = RESNET_MODELS[model_name]
  except KeyError:
    raise ValueError('Invalid model: %s' % model_name)

  def resnet_model(images, is_training, reuse=tf.AUTO_REUSE):
    with tf.contrib.framework.arg_scope(resnet_v2.resnet_arg_scope()):
      logits, _ = resnet_fn(images, num_classes, is_training=is_training,
                            reuse=reuse)
      logits = tf.reshape(logits, [-1, num_classes])
    return logits

  return resnet_model
Example #6
Source File: model_lib.py    From models with Apache License 2.0 6 votes vote down vote up
def get_model(model_name, num_classes):
  """Returns function which creates model.

  Args:
    model_name: Name of the model. Must start with 'resnet' and be a key of
      `RESNET_MODELS`.
    num_classes: Number of classes.

  Raises:
    ValueError: If model_name is invalid, i.e. it does not start with
      'resnet' or has no entry in `RESNET_MODELS`.

  Returns:
    Function, which creates model when called. It takes
    `(images, is_training, reuse=tf.AUTO_REUSE)` and returns the logits
    reshaped to `[-1, num_classes]`.
  """
  if not model_name.startswith('resnet'):
    raise ValueError('Invalid model: %s' % model_name)

  # Resolve the constructor eagerly so that an unknown 'resnet*' name is
  # reported here as the documented ValueError instead of surfacing later as
  # a KeyError when the returned function is first called.
  try:
    resnet_fn = RESNET_MODELS[model_name]
  except KeyError:
    raise ValueError('Invalid model: %s' % model_name)

  def resnet_model(images, is_training, reuse=tf.AUTO_REUSE):
    with tf.contrib.framework.arg_scope(resnet_v2.resnet_arg_scope()):
      logits, _ = resnet_fn(images, num_classes, is_training=is_training,
                            reuse=reuse)
      logits = tf.reshape(logits, [-1, num_classes])
    return logits

  return resnet_model
Example #7
Source File: model_lib.py    From multilabel-image-classification-tensorflow with MIT License 6 votes vote down vote up
def get_model(model_name, num_classes):
  """Returns function which creates model.

  Args:
    model_name: Name of the model. Must start with 'resnet' and be a key of
      `RESNET_MODELS`.
    num_classes: Number of classes.

  Raises:
    ValueError: If model_name is invalid, i.e. it does not start with
      'resnet' or has no entry in `RESNET_MODELS`.

  Returns:
    Function, which creates model when called. It takes
    `(images, is_training, reuse=tf.AUTO_REUSE)` and returns the logits
    reshaped to `[-1, num_classes]`.
  """
  if not model_name.startswith('resnet'):
    raise ValueError('Invalid model: %s' % model_name)

  # Resolve the constructor eagerly so that an unknown 'resnet*' name is
  # reported here as the documented ValueError instead of surfacing later as
  # a KeyError when the returned function is first called.
  try:
    resnet_fn = RESNET_MODELS[model_name]
  except KeyError:
    raise ValueError('Invalid model: %s' % model_name)

  def resnet_model(images, is_training, reuse=tf.AUTO_REUSE):
    with tf.contrib.framework.arg_scope(resnet_v2.resnet_arg_scope()):
      logits, _ = resnet_fn(images, num_classes, is_training=is_training,
                            reuse=reuse)
      logits = tf.reshape(logits, [-1, num_classes])
    return logits

  return resnet_model
Example #8
Source File: resnet_v2_layernorm.py    From TwinGAN with Apache License 2.0 5 votes vote down vote up
def resnet_arg_scope(weight_decay=0.0001,
                     activation_fn=tf.nn.relu,
                     use_layer_norm=True):
  """Builds the default arg scope for the layer-normalized ResNet models.

  Args:
    weight_decay: The weight decay handed to `slim.l2_regularizer` for the
      convolution weights.
    activation_fn: The activation function applied by `slim.conv2d`.
    use_layer_norm: If True, `slim.layer_norm` is the conv normalizer;
      otherwise convolutions get no normalizer.

  Returns:
    An `arg_scope` to use for the resnet models.
  """
  # Defaults applied to every slim.conv2d created under the returned scope.
  conv_defaults = {
      'weights_regularizer': slim.l2_regularizer(weight_decay),
      'weights_initializer': slim.variance_scaling_initializer(),
      'activation_fn': activation_fn,
      'normalizer_fn': slim.layer_norm if use_layer_norm else None,
      'normalizer_params': None,
  }

  # pool1 is given padding='SAME', which keeps feature alignment simple for
  # dense prediction tasks (https://github.com/facebook/fb.resnet.torch does
  # the same). The code accompanying 'Deep Residual Learning for Image
  # Recognition' uses padding='VALID' for pool1 instead; callers can opt into
  # that with slim.arg_scope([slim.max_pool2d], padding='VALID').
  with slim.arg_scope([slim.conv2d], **conv_defaults), \
       slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
    return arg_sc
Example #9
Source File: ResNet.py    From TensorFlow-HelloWorld with Apache License 2.0 5 votes vote down vote up
def resnet_arg_scope(is_training=True,  # training flag
                     weight_decay=0.0001,  # weight decay rate
                     batch_norm_decay=0.997,  # decay rate used by BN
                     batch_norm_epsilon=1e-5,  # BN epsilon, 1e-5 by default
                     batch_norm_scale=True):  # default for the BN scale flag
  """Builds the default arg scope for ResNet models.

  Args:
    is_training: Forwarded to batch normalization as `is_training`.
    weight_decay: Passed to `slim.l2_regularizer` for the conv weights.
    batch_norm_decay: Forwarded to batch normalization as `decay`.
    batch_norm_epsilon: Forwarded to batch normalization as `epsilon`.
    batch_norm_scale: Forwarded to batch normalization as `scale`.

  Returns:
    The innermost of the nested arg scopes, for use with the resnet models.
  """
  # Parameter dictionary shared by every batch-normalization layer.
  bn_settings = dict(
      is_training=is_training,
      decay=batch_norm_decay,
      epsilon=batch_norm_epsilon,
      scale=batch_norm_scale,
      updates_collections=tf.GraphKeys.UPDATE_OPS,
  )

  # Defaults for slim.conv2d: L2 weight regularizer, variance-scaling weight
  # initializer, ReLU activation and batch norm as the normalizer.
  conv_defaults = dict(
      weights_regularizer=slim.l2_regularizer(weight_decay),
      weights_initializer=slim.variance_scaling_initializer(),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      normalizer_params=bn_settings,
  )

  # The original ResNet paper uses 'VALID' padding for max pooling; 'SAME'
  # makes feature alignment simpler. The scopes are entered left to right and
  # the last one is returned as the result.
  with slim.arg_scope([slim.conv2d], **conv_defaults), \
       slim.arg_scope([slim.batch_norm], **bn_settings), \
       slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
    return arg_sc



# Define the core bottleneck residual learning unit
Example #10
Source File: deeplab_model.py    From deepglobe_land_cover_classification_with_deeplabv3plus with MIT License 4 votes vote down vote up
def atrous_spatial_pyramid_pooling(inputs, output_stride, batch_norm_decay, is_training, depth=256):
    """Atrous Spatial Pyramid Pooling.

    Args:
      inputs: A tensor of size [batch, height, width, channels].
      output_stride: The ResNet unit's stride. Determines the rates for atrous convolution.
        the rates are (6, 12, 18) when the stride is 16, and doubled when 8.
      batch_norm_decay: The moving average decay when estimating layer activation
        statistics in batch normalization.
      is_training: A boolean denoting whether the input is for training.
      depth: The depth of the ResNet unit output.

    Returns:
      The atrous spatial pyramid pooling output.

    Raises:
      ValueError: If `output_stride` is neither 8 nor 16.
    """
    # Validate before opening the "aspp" variable scope so that a bad argument
    # is rejected without touching the graph.
    if output_stride not in (8, 16):
        raise ValueError('output_stride must be either 8 or 16.')

    # The base rates apply to output_stride == 16 and are doubled for 8.
    rate_multiplier = 2 if output_stride == 8 else 1
    atrous_rates = [rate_multiplier * rate for rate in (6, 12, 18)]

    with tf.variable_scope("aspp"):
        with tf.contrib.slim.arg_scope(resnet_v2.resnet_arg_scope(batch_norm_decay=batch_norm_decay)):
            with arg_scope([layers.batch_norm], is_training=is_training):
                inputs_size = tf.shape(inputs)[1:3]
                # (a) one 1x1 convolution and three 3x3 convolutions with rates = (6, 12, 18) when output stride = 16.
                # the rates are doubled when output stride = 8.
                branches = [layers_lib.conv2d(inputs, depth, [1, 1], stride=1, scope="conv_1x1")]
                for index, rate in enumerate(atrous_rates, start=1):
                    # Scopes are 'conv_3x3_1' .. 'conv_3x3_3', in rate order.
                    branches.append(
                        layers_lib.conv2d(inputs, depth, [3, 3], stride=1, rate=rate,
                                          scope='conv_3x3_%d' % index))

                # (b) the image-level features
                with tf.variable_scope("image_level_features"):
                    # global average pooling
                    image_level_features = tf.reduce_mean(inputs, [1, 2], name='global_average_pooling', keepdims=True)
                    # 1x1 convolution with 256 filters( and batch normalization)
                    image_level_features = layers_lib.conv2d(image_level_features, depth, [1, 1], stride=1,
                                                             scope='conv_1x1')
                    # bilinearly upsample features
                    image_level_features = tf.image.resize_bilinear(image_level_features, inputs_size, name='upsample')
                branches.append(image_level_features)

                # Fuse all five branches along the channel axis, then project
                # back to `depth` channels with a 1x1 convolution.
                net = tf.concat(branches, axis=3, name='concat')
                net = layers_lib.conv2d(net, depth, [1, 1], stride=1, scope='conv_1x1_concat')

                return net
Example #11
Source File: deeplab_model.py    From tensorflow-deeplab-v3 with MIT License 4 votes vote down vote up
def atrous_spatial_pyramid_pooling(inputs, output_stride, batch_norm_decay, is_training, depth=256):
  """Atrous Spatial Pyramid Pooling.

  Args:
    inputs: A tensor of size [batch, height, width, channels].
    output_stride: The ResNet unit's stride. Determines the rates for atrous convolution.
      the rates are (6, 12, 18) when the stride is 16, and doubled when 8.
    batch_norm_decay: The moving average decay when estimating layer activation
      statistics in batch normalization.
    is_training: A boolean denoting whether the input is for training.
    depth: The depth of the ResNet unit output.

  Returns:
    The atrous spatial pyramid pooling output.

  Raises:
    ValueError: If `output_stride` is neither 8 nor 16.
  """
  # Validate before opening the "aspp" variable scope so that a bad argument
  # is rejected without touching the graph.
  if output_stride not in (8, 16):
    raise ValueError('output_stride must be either 8 or 16.')

  # The base rates apply to output_stride == 16 and are doubled for 8.
  rate_multiplier = 2 if output_stride == 8 else 1
  atrous_rates = [rate_multiplier * rate for rate in (6, 12, 18)]

  with tf.variable_scope("aspp"):
    with tf.contrib.slim.arg_scope(resnet_v2.resnet_arg_scope(batch_norm_decay=batch_norm_decay)):
      with arg_scope([layers.batch_norm], is_training=is_training):
        inputs_size = tf.shape(inputs)[1:3]
        # (a) one 1x1 convolution and three 3x3 convolutions with rates = (6, 12, 18) when output stride = 16.
        # the rates are doubled when output stride = 8.
        branches = [layers_lib.conv2d(inputs, depth, [1, 1], stride=1, scope="conv_1x1")]
        for index, rate in enumerate(atrous_rates, start=1):
          # Scopes are 'conv_3x3_1' .. 'conv_3x3_3', in rate order.
          branches.append(
              resnet_utils.conv2d_same(inputs, depth, 3, stride=1, rate=rate,
                                       scope='conv_3x3_%d' % index))

        # (b) the image-level features
        with tf.variable_scope("image_level_features"):
          # global average pooling
          image_level_features = tf.reduce_mean(inputs, [1, 2], name='global_average_pooling', keepdims=True)
          # 1x1 convolution with 256 filters( and batch normalization)
          image_level_features = layers_lib.conv2d(image_level_features, depth, [1, 1], stride=1, scope='conv_1x1')
          # bilinearly upsample features
          image_level_features = tf.image.resize_bilinear(image_level_features, inputs_size, name='upsample')
        branches.append(image_level_features)

        # Fuse all five branches along the channel axis, then project back to
        # `depth` channels with a 1x1 convolution.
        net = tf.concat(branches, axis=3, name='concat')
        net = layers_lib.conv2d(net, depth, [1, 1], stride=1, scope='conv_1x1_concat')

        return net
Example #12
Source File: deeplab_model.py    From tensorflow-deeplab-v3-plus with MIT License 4 votes vote down vote up
def atrous_spatial_pyramid_pooling(inputs, output_stride, batch_norm_decay, is_training, depth=256):
  """Atrous Spatial Pyramid Pooling.

  Args:
    inputs: A tensor of size [batch, height, width, channels].
    output_stride: The ResNet unit's stride. Determines the rates for atrous convolution.
      the rates are (6, 12, 18) when the stride is 16, and doubled when 8.
    batch_norm_decay: The moving average decay when estimating layer activation
      statistics in batch normalization.
    is_training: A boolean denoting whether the input is for training.
    depth: The depth of the ResNet unit output.

  Returns:
    The atrous spatial pyramid pooling output.

  Raises:
    ValueError: If `output_stride` is neither 8 nor 16.
  """
  # Validate before opening the "aspp" variable scope so that a bad argument
  # is rejected without touching the graph.
  if output_stride not in (8, 16):
    raise ValueError('output_stride must be either 8 or 16.')

  # The base rates apply to output_stride == 16 and are doubled for 8.
  rate_multiplier = 2 if output_stride == 8 else 1
  atrous_rates = [rate_multiplier * rate for rate in (6, 12, 18)]

  with tf.variable_scope("aspp"):
    with tf.contrib.slim.arg_scope(resnet_v2.resnet_arg_scope(batch_norm_decay=batch_norm_decay)):
      with arg_scope([layers.batch_norm], is_training=is_training):
        inputs_size = tf.shape(inputs)[1:3]
        # (a) one 1x1 convolution and three 3x3 convolutions with rates = (6, 12, 18) when output stride = 16.
        # the rates are doubled when output stride = 8.
        branches = [layers_lib.conv2d(inputs, depth, [1, 1], stride=1, scope="conv_1x1")]
        for index, rate in enumerate(atrous_rates, start=1):
          # Scopes are 'conv_3x3_1' .. 'conv_3x3_3', in rate order.
          branches.append(
              layers_lib.conv2d(inputs, depth, [3, 3], stride=1, rate=rate,
                                scope='conv_3x3_%d' % index))

        # (b) the image-level features
        with tf.variable_scope("image_level_features"):
          # global average pooling
          image_level_features = tf.reduce_mean(inputs, [1, 2], name='global_average_pooling', keepdims=True)
          # 1x1 convolution with 256 filters( and batch normalization)
          image_level_features = layers_lib.conv2d(image_level_features, depth, [1, 1], stride=1, scope='conv_1x1')
          # bilinearly upsample features
          image_level_features = tf.image.resize_bilinear(image_level_features, inputs_size, name='upsample')
        branches.append(image_level_features)

        # Fuse all five branches along the channel axis, then project back to
        # `depth` channels with a 1x1 convolution.
        net = tf.concat(branches, axis=3, name='concat')
        net = layers_lib.conv2d(net, depth, [1, 1], stride=1, scope='conv_1x1_concat')

        return net
Example #13
Source File: 6_4_ResNet.py    From TensorFlow-HelloWorld with Apache License 2.0 4 votes vote down vote up
def resnet_arg_scope(is_training=True,
                     weight_decay=0.0001,
                     batch_norm_decay=0.997,
                     batch_norm_epsilon=1e-5,
                     batch_norm_scale=True):
  """Builds the default arg scope for ResNet models.

  NOTE: per the original TODO(gpapan), the batch-normalization defaults are
  appropriate for the reference ResNet models released at
  https://github.com/KaimingHe/deep-residual-networks and might need tuning
  when training ResNets from scratch.

  Args:
    is_training: Whether or not we are training the parameters in the batch
      normalization layers of the model.
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: The moving average decay when estimating layer
      activation statistics in batch normalization.
    batch_norm_epsilon: Small constant to prevent division by zero when
      normalizing activations by their variance in batch normalization.
    batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale
      the activations in the batch normalization layer.

  Returns:
    An `arg_scope` to use for the resnet models.
  """
  # Shared by the conv normalizer and by stand-alone slim.batch_norm calls.
  bn_settings = dict(
      is_training=is_training,
      decay=batch_norm_decay,
      epsilon=batch_norm_epsilon,
      scale=batch_norm_scale,
      updates_collections=tf.GraphKeys.UPDATE_OPS,
  )

  conv_defaults = dict(
      weights_regularizer=slim.l2_regularizer(weight_decay),
      weights_initializer=slim.variance_scaling_initializer(),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      normalizer_params=bn_settings,
  )

  # pool1 is given padding='SAME', which keeps feature alignment simple for
  # dense prediction tasks (https://github.com/facebook/fb.resnet.torch does
  # the same). The code accompanying 'Deep Residual Learning for Image
  # Recognition' uses padding='VALID' for pool1 instead; callers can opt into
  # that with slim.arg_scope([slim.max_pool2d], padding='VALID').
  with slim.arg_scope([slim.conv2d], **conv_defaults), \
       slim.arg_scope([slim.batch_norm], **bn_settings), \
       slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
    return arg_sc
Example #14
Source File: embedders.py    From g-tensorflow-models with Apache License 2.0 4 votes vote down vote up
def build(self, images):
    """Builds a ResNet50 embedder for the input images.

    The pixel values in `images` are assumed to lie in [0,255] and must be
    castable to tf.uint8.

    Args:
      images: a tensor holding the input images, shaped NxTxHxWxC, where N is
          the batch size, T is the maximum length of the sequence, H and W
          are the height and width of the images and C, the number of
          channels, must be 3.
    Returns:
      The embedding of the input images, shaped NxTxL where L is the
        embedding size of the output.

    Raises:
      ValueError: if the shape of the input does not agree with the expected
      shape explained in the Args section.
    """
    dims = images.get_shape().as_list()
    rank = len(dims)
    if rank != 5:
      raise ValueError(
          'The tensor shape should have 5 elements, {} is provided'.format(
              rank))
    batch, steps, height, width, channels = dims
    if channels != 3:
      raise ValueError('Three channels are expected for the input image')

    def resize_to_299(x):
      # Resize one image by adding, then dropping, a leading batch axis.
      batched = tf.expand_dims(x, 0)
      resized = tf.image.resize_bilinear(batched, [299, 299],
                                         align_corners=False)
      return tf.squeeze(resized, [0])

    # Fold the sequence axis into the batch axis so that each frame is
    # processed as an independent image.
    frames = tf.reshape(tf.cast(images, tf.uint8),
                        [batch * steps, height, width, channels])
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
      frames = tf.map_fn(resize_to_299, frames, dtype=tf.float32)

      net, _ = resnet_v2.resnet_v2_50(
          frames, is_training=False, global_pool=True)
      # Restore the (batch, sequence) leading axes.
      return tf.reshape(net, [batch, steps, -1])
Example #15
Source File: embedders.py    From models with Apache License 2.0 4 votes vote down vote up
def build(self, images):
    """Builds a ResNet50 embedder for the input images.

    The pixel values in `images` are assumed to lie in [0,255] and must be
    castable to tf.uint8.

    Args:
      images: a tensor holding the input images, shaped NxTxHxWxC, where N is
          the batch size, T is the maximum length of the sequence, H and W
          are the height and width of the images and C, the number of
          channels, must be 3.
    Returns:
      The embedding of the input images, shaped NxTxL where L is the
        embedding size of the output.

    Raises:
      ValueError: if the shape of the input does not agree with the expected
      shape explained in the Args section.
    """
    dims = images.get_shape().as_list()
    rank = len(dims)
    if rank != 5:
      raise ValueError(
          'The tensor shape should have 5 elements, {} is provided'.format(
              rank))
    batch, steps, height, width, channels = dims
    if channels != 3:
      raise ValueError('Three channels are expected for the input image')

    def resize_to_299(x):
      # Resize one image by adding, then dropping, a leading batch axis.
      batched = tf.expand_dims(x, 0)
      resized = tf.image.resize_bilinear(batched, [299, 299],
                                         align_corners=False)
      return tf.squeeze(resized, [0])

    # Fold the sequence axis into the batch axis so that each frame is
    # processed as an independent image.
    frames = tf.reshape(tf.cast(images, tf.uint8),
                        [batch * steps, height, width, channels])
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
      frames = tf.map_fn(resize_to_299, frames, dtype=tf.float32)

      net, _ = resnet_v2.resnet_v2_50(
          frames, is_training=False, global_pool=True)
      # Restore the (batch, sequence) leading axes.
      return tf.reshape(net, [batch, steps, -1])
Example #16
Source File: embedders.py    From multilabel-image-classification-tensorflow with MIT License 4 votes vote down vote up
def build(self, images):
    """Builds a ResNet50 embedder for the input images.

    The pixel values in `images` are assumed to lie in [0,255] and must be
    castable to tf.uint8.

    Args:
      images: a tensor holding the input images, shaped NxTxHxWxC, where N is
          the batch size, T is the maximum length of the sequence, H and W
          are the height and width of the images and C, the number of
          channels, must be 3.
    Returns:
      The embedding of the input images, shaped NxTxL where L is the
        embedding size of the output.

    Raises:
      ValueError: if the shape of the input does not agree with the expected
      shape explained in the Args section.
    """
    dims = images.get_shape().as_list()
    rank = len(dims)
    if rank != 5:
      raise ValueError(
          'The tensor shape should have 5 elements, {} is provided'.format(
              rank))
    batch, steps, height, width, channels = dims
    if channels != 3:
      raise ValueError('Three channels are expected for the input image')

    def resize_to_299(x):
      # Resize one image by adding, then dropping, a leading batch axis.
      batched = tf.expand_dims(x, 0)
      resized = tf.image.resize_bilinear(batched, [299, 299],
                                         align_corners=False)
      return tf.squeeze(resized, [0])

    # Fold the sequence axis into the batch axis so that each frame is
    # processed as an independent image.
    frames = tf.reshape(tf.cast(images, tf.uint8),
                        [batch * steps, height, width, channels])
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
      frames = tf.map_fn(resize_to_299, frames, dtype=tf.float32)

      net, _ = resnet_v2.resnet_v2_50(
          frames, is_training=False, global_pool=True)
      # Restore the (batch, sequence) leading axes.
      return tf.reshape(net, [batch, steps, -1])