Python keras.utils.conv_utils.conv_output_length() Examples

The following are 30 code examples of keras.utils.conv_utils.conv_output_length(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module keras.utils.conv_utils, or try the search function.
Example #1
Source File: sn.py    From Coloring-greyscale-images with MIT License 7 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the output shape this convolution yields for `input_shape`.

        With `self.data_format == 'channels_last'` the spatial axes are
        `input_shape[1:-1]` and `self.filters` is appended as the last entry;
        with 'channels_first' the spatial axes are `input_shape[2:]` and
        `self.filters` is placed right after the batch entry. Any other
        `data_format` value falls through and yields None.
        """
        def _conv_dims(spatial):
            # One conv_output_length call per spatial axis, in axis order.
            return tuple(
                conv_utils.conv_output_length(
                    size,
                    self.kernel_size[axis],
                    padding=self.padding,
                    stride=self.strides[axis],
                    dilation=self.dilation_rate[axis])
                for axis, size in enumerate(spatial))

        layout = self.data_format
        if layout == 'channels_last':
            dims = _conv_dims(input_shape[1:-1])
            return (input_shape[0],) + dims + (self.filters,)
        if layout == 'channels_first':
            dims = _conv_dims(input_shape[2:])
            return (input_shape[0], self.filters) + dims
Example #2
Source File: depthwise_conv2d.py    From keras-mobilenet with MIT License 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the 4D output shape produced for `input_shape`.

        The spatial entries are read from axes (2, 3) for 'channels_first'
        and (1, 2) for 'channels_last', then each is passed through
        `conv_utils.conv_output_length` with the matching kernel size and
        stride. The channel entry of the result is `self.filters`.
        """
        layout = self.data_format
        if layout == 'channels_first':
            row_axis, col_axis = 2, 3
        elif layout == 'channels_last':
            row_axis, col_axis = 1, 2

        height = input_shape[row_axis]
        width = input_shape[col_axis]

        # Kernel size and stride are per-axis; the padding mode is shared.
        height = conv_utils.conv_output_length(
            height, self.kernel_size[0], self.padding, self.strides[0])
        width = conv_utils.conv_output_length(
            width, self.kernel_size[1], self.padding, self.strides[1])

        if layout == 'channels_first':
            return (input_shape[0], self.filters, height, width)
        elif layout == 'channels_last':
            return (input_shape[0], height, width, self.filters)
Example #3
Source File: mobilenet.py    From keras-FP16-test with Apache License 2.0 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the 4D output shape produced for `input_shape`.

        The axis order follows `self.data_format`: 'channels_first' reads the
        input as (batch, channels, rows, cols) and 'channels_last' as
        (batch, rows, cols, channels).

        Returns:
            A 4-tuple in the same axis order as the input. The batch entry
            is copied, the channel entry is the input channel count times
            `self.depth_multiplier`, and each spatial entry comes from
            `conv_utils.conv_output_length` for that axis.
        """
        layout = self.data_format
        if layout == 'channels_first':
            row_axis, col_axis, channel_axis = 2, 3, 1
        elif layout == 'channels_last':
            row_axis, col_axis, channel_axis = 1, 2, 3

        height = input_shape[row_axis]
        width = input_shape[col_axis]
        channels_out = input_shape[channel_axis] * self.depth_multiplier

        # Kernel size and stride are per-axis; the padding mode is shared.
        height = conv_utils.conv_output_length(
            height, self.kernel_size[0], self.padding, self.strides[0])
        width = conv_utils.conv_output_length(
            width, self.kernel_size[1], self.padding, self.strides[1])

        if layout == 'channels_first':
            return (input_shape[0], channels_out, height, width)
        elif layout == 'channels_last':
            return (input_shape[0], height, width, channels_out)
Example #4
Source File: mobilenet.py    From deep-learning-models with MIT License 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Compute the output shape tuple for a 4D `input_shape`.

        'channels_first' places channels at index 1 and the two spatial axes
        at indices 2 and 3; 'channels_last' places the spatial axes at
        indices 1 and 2 and channels at index 3. The output keeps that
        ordering, scales the channel count by `self.depth_multiplier`, and
        replaces each spatial size with `conv_utils.conv_output_length`.
        """
        fmt = self.data_format
        if fmt == 'channels_first':
            r_idx, c_idx, ch_idx = 2, 3, 1
        elif fmt == 'channels_last':
            r_idx, c_idx, ch_idx = 1, 2, 3

        n_rows = input_shape[r_idx]
        n_cols = input_shape[c_idx]
        n_out = input_shape[ch_idx] * self.depth_multiplier

        n_rows = conv_utils.conv_output_length(
            n_rows, self.kernel_size[0], self.padding, self.strides[0])
        n_cols = conv_utils.conv_output_length(
            n_cols, self.kernel_size[1], self.padding, self.strides[1])

        if fmt == 'channels_first':
            return (input_shape[0], n_out, n_rows, n_cols)
        elif fmt == 'channels_last':
            return (input_shape[0], n_rows, n_cols, n_out)
Example #5
Source File: mobilenetv2.py    From mobilenet_v2_keras with MIT License 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the 4D output shape produced for `input_shape`.

        The axis order follows `self.data_format`: 'channels_first' reads the
        input as (batch, channels, rows, cols) and 'channels_last' as
        (batch, rows, cols, channels).

        Returns:
            A 4-tuple in the same axis order as the input. The batch entry
            is copied, the channel entry is the input channel count times
            `self.depth_multiplier`, and each spatial entry comes from
            `conv_utils.conv_output_length` for that axis.
        """
        layout = self.data_format
        if layout == 'channels_first':
            row_axis, col_axis, channel_axis = 2, 3, 1
        elif layout == 'channels_last':
            row_axis, col_axis, channel_axis = 1, 2, 3

        height = input_shape[row_axis]
        width = input_shape[col_axis]
        channels_out = input_shape[channel_axis] * self.depth_multiplier

        # Kernel size and stride are per-axis; the padding mode is shared.
        height = conv_utils.conv_output_length(
            height, self.kernel_size[0], self.padding, self.strides[0])
        width = conv_utils.conv_output_length(
            width, self.kernel_size[1], self.padding, self.strides[1])

        if layout == 'channels_first':
            return (input_shape[0], channels_out, height, width)
        elif layout == 'channels_last':
            return (input_shape[0], height, width, channels_out)
Example #6
Source File: mobilenets-checkpoint.py    From CBAM-keras with MIT License 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Compute the output shape tuple for a 4D `input_shape`.

        'channels_first' places channels at index 1 and the two spatial axes
        at indices 2 and 3; 'channels_last' places the spatial axes at
        indices 1 and 2 and channels at index 3. The output keeps that
        ordering, scales the channel count by `self.depth_multiplier`, and
        replaces each spatial size with `conv_utils.conv_output_length`.
        """
        fmt = self.data_format
        if fmt == 'channels_first':
            r_idx, c_idx, ch_idx = 2, 3, 1
        elif fmt == 'channels_last':
            r_idx, c_idx, ch_idx = 1, 2, 3

        n_rows = input_shape[r_idx]
        n_cols = input_shape[c_idx]
        n_out = input_shape[ch_idx] * self.depth_multiplier

        n_rows = conv_utils.conv_output_length(
            n_rows, self.kernel_size[0], self.padding, self.strides[0])
        n_cols = conv_utils.conv_output_length(
            n_cols, self.kernel_size[1], self.padding, self.strides[1])

        if fmt == 'channels_first':
            return (input_shape[0], n_out, n_rows, n_cols)
        elif fmt == 'channels_last':
            return (input_shape[0], n_rows, n_cols, n_out)
Example #7
Source File: mobilenets.py    From CBAM-keras with MIT License 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the 4D output shape produced for `input_shape`.

        The axis order follows `self.data_format`: 'channels_first' reads the
        input as (batch, channels, rows, cols) and 'channels_last' as
        (batch, rows, cols, channels).

        Returns:
            A 4-tuple in the same axis order as the input. The batch entry
            is copied, the channel entry is the input channel count times
            `self.depth_multiplier`, and each spatial entry comes from
            `conv_utils.conv_output_length` for that axis.
        """
        layout = self.data_format
        if layout == 'channels_first':
            row_axis, col_axis, channel_axis = 2, 3, 1
        elif layout == 'channels_last':
            row_axis, col_axis, channel_axis = 1, 2, 3

        height = input_shape[row_axis]
        width = input_shape[col_axis]
        channels_out = input_shape[channel_axis] * self.depth_multiplier

        # Kernel size and stride are per-axis; the padding mode is shared.
        height = conv_utils.conv_output_length(
            height, self.kernel_size[0], self.padding, self.strides[0])
        width = conv_utils.conv_output_length(
            width, self.kernel_size[1], self.padding, self.strides[1])

        if layout == 'channels_first':
            return (input_shape[0], channels_out, height, width)
        elif layout == 'channels_last':
            return (input_shape[0], height, width, channels_out)
Example #8
Source File: ConvolutionalMoE.py    From mixture-of-experts with GNU General Public License v3.0 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the output shape this convolution yields for `input_shape`.

        With `self.data_format == 'channels_last'` the spatial axes are
        `input_shape[1:-1]` and `self.n_filters` is appended as the last
        entry; with 'channels_first' the spatial axes are `input_shape[2:]`
        and `self.n_filters` is placed right after the batch entry. Any
        other `data_format` value falls through and yields None.
        """
        def _conv_dims(spatial):
            # One conv_output_length call per spatial axis, in axis order.
            return tuple(
                conv_utils.conv_output_length(
                    size,
                    self.kernel_size[axis],
                    padding=self.padding,
                    stride=self.strides[axis],
                    dilation=self.dilation_rate[axis])
                for axis, size in enumerate(spatial))

        layout = self.data_format
        if layout == 'channels_last':
            dims = _conv_dims(input_shape[1:-1])
            return (input_shape[0],) + dims + (self.n_filters,)
        if layout == 'channels_first':
            dims = _conv_dims(input_shape[2:])
            return (input_shape[0], self.n_filters) + dims
Example #9
Source File: gc_mobilenets.py    From keras-global-context-networks with MIT License 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Compute the output shape tuple for a 4D `input_shape`.

        'channels_first' places channels at index 1 and the two spatial axes
        at indices 2 and 3; 'channels_last' places the spatial axes at
        indices 1 and 2 and channels at index 3. The output keeps that
        ordering, scales the channel count by `self.depth_multiplier`, and
        replaces each spatial size with `conv_utils.conv_output_length`.
        """
        fmt = self.data_format
        if fmt == 'channels_first':
            r_idx, c_idx, ch_idx = 2, 3, 1
        elif fmt == 'channels_last':
            r_idx, c_idx, ch_idx = 1, 2, 3

        n_rows = input_shape[r_idx]
        n_cols = input_shape[c_idx]
        n_out = input_shape[ch_idx] * self.depth_multiplier

        n_rows = conv_utils.conv_output_length(
            n_rows, self.kernel_size[0], self.padding, self.strides[0])
        n_cols = conv_utils.conv_output_length(
            n_cols, self.kernel_size[1], self.padding, self.strides[1])

        if fmt == 'channels_first':
            return (input_shape[0], n_out, n_rows, n_cols)
        elif fmt == 'channels_last':
            return (input_shape[0], n_rows, n_cols, n_out)
Example #10
Source File: se_mobilenets.py    From keras-squeeze-excite-network with MIT License 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the 4D output shape produced for `input_shape`.

        The axis order follows `self.data_format`: 'channels_first' reads the
        input as (batch, channels, rows, cols) and 'channels_last' as
        (batch, rows, cols, channels).

        Returns:
            A 4-tuple in the same axis order as the input. The batch entry
            is copied, the channel entry is the input channel count times
            `self.depth_multiplier`, and each spatial entry comes from
            `conv_utils.conv_output_length` for that axis.
        """
        layout = self.data_format
        if layout == 'channels_first':
            row_axis, col_axis, channel_axis = 2, 3, 1
        elif layout == 'channels_last':
            row_axis, col_axis, channel_axis = 1, 2, 3

        height = input_shape[row_axis]
        width = input_shape[col_axis]
        channels_out = input_shape[channel_axis] * self.depth_multiplier

        # Kernel size and stride are per-axis; the padding mode is shared.
        height = conv_utils.conv_output_length(
            height, self.kernel_size[0], self.padding, self.strides[0])
        width = conv_utils.conv_output_length(
            width, self.kernel_size[1], self.padding, self.strides[1])

        if layout == 'channels_first':
            return (input_shape[0], channels_out, height, width)
        elif layout == 'channels_last':
            return (input_shape[0], height, width, channels_out)
Example #11
Source File: utils.py    From face_landmark_dnn with MIT License 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Compute the output shape tuple for a 4D `input_shape`.

        'channels_first' places channels at index 1 and the two spatial axes
        at indices 2 and 3; 'channels_last' places the spatial axes at
        indices 1 and 2 and channels at index 3. The output keeps that
        ordering, scales the channel count by `self.depth_multiplier`, and
        replaces each spatial size with `conv_utils.conv_output_length`.
        """
        fmt = self.data_format
        if fmt == 'channels_first':
            r_idx, c_idx, ch_idx = 2, 3, 1
        elif fmt == 'channels_last':
            r_idx, c_idx, ch_idx = 1, 2, 3

        n_rows = input_shape[r_idx]
        n_cols = input_shape[c_idx]
        n_out = input_shape[ch_idx] * self.depth_multiplier

        n_rows = conv_utils.conv_output_length(
            n_rows, self.kernel_size[0], self.padding, self.strides[0])
        n_cols = conv_utils.conv_output_length(
            n_cols, self.kernel_size[1], self.padding, self.strides[1])

        if fmt == 'channels_first':
            return (input_shape[0], n_out, n_rows, n_cols)
        elif fmt == 'channels_last':
            return (input_shape[0], n_rows, n_cols, n_out)
Example #12
Source File: train_mobilenets.py    From face_landmark_dnn with MIT License 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the 4D output shape produced for `input_shape`.

        The axis order follows `self.data_format`: 'channels_first' reads the
        input as (batch, channels, rows, cols) and 'channels_last' as
        (batch, rows, cols, channels).

        Returns:
            A 4-tuple in the same axis order as the input. The batch entry
            is copied, the channel entry is the input channel count times
            `self.depth_multiplier`, and each spatial entry comes from
            `conv_utils.conv_output_length` for that axis.
        """
        layout = self.data_format
        if layout == 'channels_first':
            row_axis, col_axis, channel_axis = 2, 3, 1
        elif layout == 'channels_last':
            row_axis, col_axis, channel_axis = 1, 2, 3

        height = input_shape[row_axis]
        width = input_shape[col_axis]
        channels_out = input_shape[channel_axis] * self.depth_multiplier

        # Kernel size and stride are per-axis; the padding mode is shared.
        height = conv_utils.conv_output_length(
            height, self.kernel_size[0], self.padding, self.strides[0])
        width = conv_utils.conv_output_length(
            width, self.kernel_size[1], self.padding, self.strides[1])

        if layout == 'channels_first':
            return (input_shape[0], channels_out, height, width)
        elif layout == 'channels_last':
            return (input_shape[0], height, width, channels_out)
Example #13
Source File: se_mobilenets.py    From TF.Keras-Commonly-used-models with Apache License 2.0 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Compute the output shape tuple for a 4D `input_shape`.

        'channels_first' places channels at index 1 and the two spatial axes
        at indices 2 and 3; 'channels_last' places the spatial axes at
        indices 1 and 2 and channels at index 3. The output keeps that
        ordering, scales the channel count by `self.depth_multiplier`, and
        replaces each spatial size with `conv_utils.conv_output_length`.
        """
        fmt = self.data_format
        if fmt == 'channels_first':
            r_idx, c_idx, ch_idx = 2, 3, 1
        elif fmt == 'channels_last':
            r_idx, c_idx, ch_idx = 1, 2, 3

        n_rows = input_shape[r_idx]
        n_cols = input_shape[c_idx]
        n_out = input_shape[ch_idx] * self.depth_multiplier

        n_rows = conv_utils.conv_output_length(
            n_rows, self.kernel_size[0], self.padding, self.strides[0])
        n_cols = conv_utils.conv_output_length(
            n_cols, self.kernel_size[1], self.padding, self.strides[1])

        if fmt == 'channels_first':
            return (input_shape[0], n_out, n_rows, n_cols)
        elif fmt == 'channels_last':
            return (input_shape[0], n_rows, n_cols, n_out)
Example #14
Source File: mobile_net_fixed.py    From kaggle-carvana-2017 with MIT License 6 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the 4D output shape produced for `input_shape`.

        The axis order follows `self.data_format`: 'channels_first' reads the
        input as (batch, channels, rows, cols) and 'channels_last' as
        (batch, rows, cols, channels).

        Returns:
            A 4-tuple in the same axis order as the input. The batch entry
            is copied, the channel entry is the input channel count times
            `self.depth_multiplier`, and each spatial entry comes from
            `conv_utils.conv_output_length` for that axis.
        """
        layout = self.data_format
        if layout == 'channels_first':
            row_axis, col_axis, channel_axis = 2, 3, 1
        elif layout == 'channels_last':
            row_axis, col_axis, channel_axis = 1, 2, 3

        height = input_shape[row_axis]
        width = input_shape[col_axis]
        channels_out = input_shape[channel_axis] * self.depth_multiplier

        # Kernel size and stride are per-axis; the padding mode is shared.
        height = conv_utils.conv_output_length(
            height, self.kernel_size[0], self.padding, self.strides[0])
        width = conv_utils.conv_output_length(
            width, self.kernel_size[1], self.padding, self.strides[1])

        if layout == 'channels_first':
            return (input_shape[0], channels_out, height, width)
        elif layout == 'channels_last':
            return (input_shape[0], height, width, channels_out)
Example #15
Source File: conv_utils_test.py    From DeepLearning_Wavelet-LSTM with MIT License 5 votes vote down vote up
def test_conv_input_length():
    """Check `conv_utils.conv_input_length` for each supported padding mode.

    Covers the None pass-through, the 'same', 'valid' and 'full' modes at
    strides 1 and 2, and the rejection of an unknown padding mode.
    """
    # An unknown output length stays unknown.
    assert conv_utils.conv_input_length(None, 7, 'same', 1) is None
    assert conv_utils.conv_input_length(112, 7, 'same', 1) == 112
    assert conv_utils.conv_input_length(112, 7, 'same', 2) == 223
    assert conv_utils.conv_input_length(28, 5, 'valid', 1) == 32
    assert conv_utils.conv_input_length(14, 5, 'valid', 2) == 31
    assert conv_utils.conv_input_length(36, 5, 'full', 1) == 32
    assert conv_utils.conv_input_length(18, 5, 'full', 2) == 31

    # The negative case must exercise the function under test; it used to
    # call conv_output_length, leaving conv_input_length's validation of the
    # padding argument unchecked.
    with pytest.raises(AssertionError):
        conv_utils.conv_input_length(18, 5, 'diagonal', 2)
Example #16
Source File: conv_utils_test.py    From DeepLearning_Wavelet-LSTM with MIT License 5 votes vote down vote up
def test_conv_output_length():
    """Check `conv_utils.conv_output_length` across padding modes and strides."""
    # An unknown input length stays unknown.
    assert conv_utils.conv_output_length(None, 7, 'same', 1) is None

    # (input_length, filter_size, padding, stride) -> expected output length
    cases = [
        ((224, 7, 'same', 1), 224),
        ((224, 7, 'same', 2), 112),
        ((32, 5, 'valid', 1), 28),
        ((32, 5, 'valid', 2), 14),
        ((32, 5, 'causal', 1), 32),
        ((32, 5, 'causal', 2), 16),
        ((32, 5, 'full', 1), 36),
        ((32, 5, 'full', 2), 18),
    ]
    for args, expected in cases:
        assert conv_utils.conv_output_length(*args) == expected

    # An unsupported padding mode is rejected.
    with pytest.raises(AssertionError):
        conv_utils.conv_output_length(32, 5, 'diagonal', 2)
Example #17
Source File: conv_utils_test.py    From DeepLearning_Wavelet-LSTM with MIT License 5 votes vote down vote up
def test_conv_input_length():
    """Check `conv_utils.conv_input_length` for each supported padding mode.

    Covers the None pass-through, the 'same', 'valid' and 'full' modes at
    strides 1 and 2, and the rejection of an unknown padding mode.
    """
    # An unknown output length stays unknown.
    assert conv_utils.conv_input_length(None, 7, 'same', 1) is None
    assert conv_utils.conv_input_length(112, 7, 'same', 1) == 112
    assert conv_utils.conv_input_length(112, 7, 'same', 2) == 223
    assert conv_utils.conv_input_length(28, 5, 'valid', 1) == 32
    assert conv_utils.conv_input_length(14, 5, 'valid', 2) == 31
    assert conv_utils.conv_input_length(36, 5, 'full', 1) == 32
    assert conv_utils.conv_input_length(18, 5, 'full', 2) == 31

    # The negative case must exercise the function under test; it used to
    # call conv_output_length, leaving conv_input_length's validation of the
    # padding argument unchecked.
    with pytest.raises(AssertionError):
        conv_utils.conv_input_length(18, 5, 'diagonal', 2)
Example #18
Source File: DepthwiseConv3D.py    From keras-DepthwiseConv3D with MIT License 5 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the 5D output shape produced for `input_shape`.

        The three spatial entries (depth, rows, cols) sit at indices 2-4 for
        'channels_first' and 1-3 for 'channels_last'. Each one is replaced by
        `conv_utils.conv_output_length` for its axis, and the channel entry
        of the result is `self.groups * self.depth_multiplier`.
        """
        layout = self.data_format
        if layout == 'channels_first':
            d_axis, r_axis, c_axis = 2, 3, 4
        elif layout == 'channels_last':
            d_axis, r_axis, c_axis = 1, 2, 3

        planes = input_shape[d_axis]
        height = input_shape[r_axis]
        width = input_shape[c_axis]
        channels_out = self.groups * self.depth_multiplier

        # Kernel size and stride are per-axis; the padding mode is shared.
        planes = conv_utils.conv_output_length(
            planes, self.kernel_size[0], self.padding, self.strides[0])
        height = conv_utils.conv_output_length(
            height, self.kernel_size[1], self.padding, self.strides[1])
        width = conv_utils.conv_output_length(
            width, self.kernel_size[2], self.padding, self.strides[2])

        if layout == 'channels_first':
            return (input_shape[0], channels_out, planes, height, width)
        elif layout == 'channels_last':
            return (input_shape[0], planes, height, width, channels_out)
Example #19
Source File: qrnn.py    From nn_playground with MIT License 5 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the output shape for `input_shape`.

        When given a list of shapes only the first one is used. A falsy
        length at index 1 (None or 0) is passed through unchanged; otherwise
        the length is padded by `window_size - 1` and run through a 'valid'
        `conv_output_length` with the layer's first stride.

        Returns:
            (batch, length, units) when `self.return_sequences` is truthy,
            else (batch, units).
        """
        shape = input_shape[0] if isinstance(input_shape, list) else input_shape

        steps = shape[1]
        if steps:
            padded = steps + self.window_size - 1
            steps = conv_output_length(padded, self.window_size, 'valid',
                                       self.strides[0])

        if not self.return_sequences:
            return (shape[0], self.units)
        return (shape[0], steps, self.units)
Example #20
Source File: gcnn.py    From nn_playground with MIT License 5 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return (batch, length, output_dim) for `input_shape`.

        A falsy length at index 1 (None or 0) is passed through unchanged;
        otherwise the length is padded by `window_size - 1` and run through a
        'valid' `conv_output_length` with the layer's first stride.
        """
        steps = input_shape[1]
        if steps:
            padded = steps + self.window_size - 1
            steps = conv_output_length(padded, self.window_size, 'valid',
                                       self.strides[0])
        return (input_shape[0], steps, self.output_dim)
Example #21
Source File: capsule_layers.py    From SegCaps with Apache License 2.0 5 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the output shape for `input_shape`.

        The spatial axes are `input_shape[1:-2]`; each is run through
        `conv_output_length` with the layer's single kernel size and stride
        and a dilation of 1. The last two entries of the result are
        `self.num_capsule` and `self.num_atoms`.
        """
        conv_dims = tuple(
            conv_output_length(
                size,
                self.kernel_size,
                padding=self.padding,
                stride=self.strides,
                dilation=1)
            for size in input_shape[1:-2])

        return (input_shape[0],) + conv_dims + (self.num_capsule, self.num_atoms)
Example #22
Source File: capslayers.py    From deepcaps with MIT License 5 votes vote down vote up
def build(self, input_shape):
        """Record input/output spatial sizes and create the kernel weight.

        Stores entries 1-4 of `input_shape` as `h_i`, `w_i`, `ch_i`, `n_i`,
        then derives two pairs of output sizes from the first two of them:
        `h_j`/`w_j` using the layer's strides and `ah_j`/`aw_j` using a
        stride of 1. The kernel weight has shape
        `kernel_size + (ch_i, n_i, ch_j, n_j)`.
        """
        self.h_i, self.w_i, self.ch_i, self.n_i = input_shape[1:5]

        def _out_length(axis, strided):
            # Output size along one spatial axis; stride 1 when not strided.
            return conv_utils.conv_output_length(
                input_shape[axis + 1],
                self.kernel_size[axis],
                padding=self.padding,
                stride=self.strides[axis] if strided else 1,
                dilation=self.dilation_rate[axis])

        self.h_j, self.w_j = _out_length(0, True), _out_length(1, True)
        self.ah_j, self.aw_j = _out_length(0, False), _out_length(1, False)

        capsule_dims = (self.ch_i, self.n_i, self.ch_j, self.n_j)
        self.w_shape = self.kernel_size + capsule_dims

        self.w = self.add_weight(shape=self.w_shape,
                                 initializer=self.kernel_initializer,
                                 name='kernel',
                                 regularizer=self.kernel_regularizer,
                                 constraint=self.kernel_constraint)

        self.built = True
Example #23
Source File: capslayers.py    From deepcaps with MIT License 5 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the output shape for `input_shape`.

        Every axis in `input_shape[1:-2]` is run through `conv_output_length`
        with the layer's single kernel size and stride and a dilation of 1.
        The result ends with `self.num_capsule` and `self.num_atoms`.
        """
        conv_dims = tuple(
            conv_output_length(size, self.kernel_size, padding=self.padding,
                               stride=self.strides, dilation=1)
            for size in input_shape[1:-2])

        return (input_shape[0],) + conv_dims + (self.num_capsule, self.num_atoms)
Example #24
Source File: gram.py    From subjective-functions with MIT License 5 votes vote down vote up
def reduce_layer(a=0.4, padding_mode='valid'):
    """Build a Lambda layer that blurs with a 5-tap kernel and halves both
    spatial axes.

    The 5x5 filter is applied as two 1-D passes: a 5x1 convolution with
    strides (2, 1), then the same kernel with its first two axes swapped
    (1x5) with strides (1, 2). The kernel is hard-wired to 3 channels and
    filters each channel independently.

    Args:
        a: centre weight of the 5-tap kernel; the outer taps are
            `0.25 - a/2`.
        padding_mode: padding mode used both by the convolutions and by the
            output-shape computation, so the declared shape matches the
            tensor that is actually produced.

    Returns:
        A `Lambda` layer wrapping the two-pass convolution together with its
        output-shape function.
    """
    # A 5-tap Gaussian pyramid generating kernel from Burt & Adelson 1983.
    kernel_1d = [0.25 - a/2, 0.25, a, 0.25, 0.25 - a/2]

    # 5x1 kernel with 3 input and 3 output channels; only the diagonal
    # (channel c -> channel c) is non-zero, so channels are never mixed.
    kernel_3d = np.zeros((5, 1, 3, 3), 'float32')
    for channel in range(3):
        kernel_3d[:, 0, channel, channel] = kernel_1d

    def fn(x):
        # Pass padding_mode to both convolutions; previously it was only
        # used by shape(), so any mode other than 'valid' declared a shape
        # the convolutions did not produce.
        first_pass = K.conv2d(x, kernel_3d, strides=(2, 1),
                              padding=padding_mode)
        return K.conv2d(first_pass,
                        K.permute_dimensions(kernel_3d, (1, 0, 2, 3)),
                        strides=(1, 2),
                        padding=padding_mode)

    def shape(input_shape):
        assert len(input_shape) == 4
        assert K.image_data_format() == 'channels_last'
        # Each spatial axis sees a size-5 kernel with stride 2.
        new_space = [
            conv_utils.conv_output_length(
                dim,
                5,
                padding=padding_mode,
                stride=2)
            for dim in input_shape[1:-1]
        ]
        return (input_shape[0],) + tuple(new_space) + (input_shape[3],)

    return Lambda(fn, shape)
Example #25
Source File: custom_objects.py    From keras_mixnets with MIT License 5 votes vote down vote up
def compute_output_shape(self, input_shape):
        """Return the output shape for `input_shape`.

        The spatial axes are `input_shape[1:-1]`; each is run through
        `conv_utils.conv_output_length` with a fixed filter size of 1 and the
        layer's per-axis stride and dilation. `self.filters` is appended as
        the last entry.
        """
        conv_dims = tuple(
            conv_utils.conv_output_length(
                size,
                filter_size=1,
                padding=self.padding,
                stride=self.strides[axis],
                dilation=self.dilation_rate[axis])
            for axis, size in enumerate(input_shape[1:-1]))
        return (input_shape[0],) + conv_dims + (self.filters,)
Example #26
Source File: conv_utils_test.py    From DeepLearning_Wavelet-LSTM with MIT License 5 votes vote down vote up
def test_conv_output_length():
    """Check `conv_utils.conv_output_length` across padding modes and strides."""
    # An unknown input length stays unknown.
    assert conv_utils.conv_output_length(None, 7, 'same', 1) is None

    # (input_length, filter_size, padding, stride) -> expected output length
    cases = [
        ((224, 7, 'same', 1), 224),
        ((224, 7, 'same', 2), 112),
        ((32, 5, 'valid', 1), 28),
        ((32, 5, 'valid', 2), 14),
        ((32, 5, 'causal', 1), 32),
        ((32, 5, 'causal', 2), 16),
        ((32, 5, 'full', 1), 36),
        ((32, 5, 'full', 2), 18),
    ]
    for args, expected in cases:
        assert conv_utils.conv_output_length(*args) == expected

    # An unsupported padding mode is rejected.
    with pytest.raises(AssertionError):
        conv_utils.conv_output_length(32, 5, 'diagonal', 2)
Example #27
Source File: conv_utils_test.py    From DeepLearning_Wavelet-LSTM with MIT License 5 votes vote down vote up
def test_conv_output_length():
    """Verify the output lengths `conv_utils.conv_output_length` reports."""
    compute = conv_utils.conv_output_length

    # An unknown input length stays unknown.
    assert compute(None, 7, 'same', 1) is None

    # Keyed by padding mode: (input_length, filter_size, stride, expected).
    table = {
        'same': [(224, 7, 1, 224), (224, 7, 2, 112)],
        'valid': [(32, 5, 1, 28), (32, 5, 2, 14)],
        'causal': [(32, 5, 1, 32), (32, 5, 2, 16)],
        'full': [(32, 5, 1, 36), (32, 5, 2, 18)],
    }
    for padding in ('same', 'valid', 'causal', 'full'):
        for length, filter_size, stride, expected in table[padding]:
            assert compute(length, filter_size, padding, stride) == expected

    # An unsupported padding mode is rejected.
    with pytest.raises(AssertionError):
        conv_utils.conv_output_length(32, 5, 'diagonal', 2)
Example #28
Source File: conv_utils_test.py    From DeepLearning_Wavelet-LSTM with MIT License 5 votes vote down vote up
def test_conv_input_length():
    """Check `conv_utils.conv_input_length` for each supported padding mode.

    Covers the None pass-through, the 'same', 'valid' and 'full' modes at
    strides 1 and 2, and the rejection of an unknown padding mode.
    """
    # An unknown output length stays unknown.
    assert conv_utils.conv_input_length(None, 7, 'same', 1) is None
    assert conv_utils.conv_input_length(112, 7, 'same', 1) == 112
    assert conv_utils.conv_input_length(112, 7, 'same', 2) == 223
    assert conv_utils.conv_input_length(28, 5, 'valid', 1) == 32
    assert conv_utils.conv_input_length(14, 5, 'valid', 2) == 31
    assert conv_utils.conv_input_length(36, 5, 'full', 1) == 32
    assert conv_utils.conv_input_length(18, 5, 'full', 2) == 31

    # The negative case must exercise the function under test; it used to
    # call conv_output_length, leaving conv_input_length's validation of the
    # padding argument unchecked.
    with pytest.raises(AssertionError):
        conv_utils.conv_input_length(18, 5, 'diagonal', 2)
Example #29
Source File: conv_utils_test.py    From DeepLearning_Wavelet-LSTM with MIT License 5 votes vote down vote up
def test_conv_output_length():
    """Verify the output lengths `conv_utils.conv_output_length` reports."""
    compute = conv_utils.conv_output_length

    # An unknown input length stays unknown.
    assert compute(None, 7, 'same', 1) is None

    # Keyed by padding mode: (input_length, filter_size, stride, expected).
    table = {
        'same': [(224, 7, 1, 224), (224, 7, 2, 112)],
        'valid': [(32, 5, 1, 28), (32, 5, 2, 14)],
        'causal': [(32, 5, 1, 32), (32, 5, 2, 16)],
        'full': [(32, 5, 1, 36), (32, 5, 2, 18)],
    }
    for padding in ('same', 'valid', 'causal', 'full'):
        for length, filter_size, stride, expected in table[padding]:
            assert compute(length, filter_size, padding, stride) == expected

    # An unsupported padding mode is rejected.
    with pytest.raises(AssertionError):
        conv_utils.conv_output_length(32, 5, 'diagonal', 2)
Example #30
Source File: conv_utils_test.py    From DeepLearning_Wavelet-LSTM with MIT License 5 votes vote down vote up
def test_conv_input_length():
    """Check `conv_utils.conv_input_length` for each supported padding mode.

    Covers the None pass-through, the 'same', 'valid' and 'full' modes at
    strides 1 and 2, and the rejection of an unknown padding mode.
    """
    # An unknown output length stays unknown.
    assert conv_utils.conv_input_length(None, 7, 'same', 1) is None
    assert conv_utils.conv_input_length(112, 7, 'same', 1) == 112
    assert conv_utils.conv_input_length(112, 7, 'same', 2) == 223
    assert conv_utils.conv_input_length(28, 5, 'valid', 1) == 32
    assert conv_utils.conv_input_length(14, 5, 'valid', 2) == 31
    assert conv_utils.conv_input_length(36, 5, 'full', 1) == 32
    assert conv_utils.conv_input_length(18, 5, 'full', 2) == 31

    # The negative case must exercise the function under test; it used to
    # call conv_output_length, leaving conv_input_length's validation of the
    # padding argument unchecked.
    with pytest.raises(AssertionError):
        conv_utils.conv_input_length(18, 5, 'diagonal', 2)