Python cntk.reshape() Examples

The following are 15 code examples of cntk.reshape(), drawn from open-source projects. The source file and originating project are noted above each example. You may also want to check out all available functions/classes of the module cntk.
Example #1
Source File: cntk_backend.py    From DeepLearning_Wavelet-LSTM with MIT License
def squeeze(x, axis):
    if isinstance(axis, tuple):
        axis = list(axis)
    if not isinstance(axis, list):
        axis = [axis]

    # Keras-style shape: dynamic axes (batch/sequence) show up as None
    shape = list(int_shape(x))

    # normalize negative axes
    _axis = []
    for _ in axis:
        if isinstance(_, int):
            _axis.append(_ if _ >= 0 else _ + len(shape))

    if len(_axis) == 0:
        return x

    # drop the squeezed entries, then strip the dynamic axes, which are not
    # part of the static shape that C.reshape expects
    nones = _get_dynamic_axis_num(x)
    for _ in sorted(_axis, reverse=True):
        del shape[_]

    new_shape = shape[nones:]
    new_shape = tuple([C.InferredDimension if _ == C.FreeDimension else _ for _ in new_shape])
    return C.reshape(x, new_shape) 
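
In plain CNTK terms, squeezing a singleton axis is just a reshape to the smaller static shape. A minimal sketch of what the helper above reduces to (assumes CNTK 2.x; the variable names are illustrative):

import numpy as np
import cntk as C

x = C.input_variable((3, 1, 4))            # static shape with a singleton axis
y = C.reshape(x, (3, 4))                   # squeeze axis 1
data = np.arange(12, dtype=np.float32).reshape(1, 3, 1, 4)  # batch of 1
print(y.eval({x: data}).shape)             # -> (1, 3, 4)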
Example #2
Source File: cntk_backend.py    From GraphicDesignPatternByPython with MIT License
def _reshape_dummy_dim(x, axis):
    shape = list(x.shape)

    # normalize negative axes against the static shape
    _axis = [_ + len(shape) if _ < 0 else _ for _ in axis]

    if shape.count(C.InferredDimension) > 1 or shape.count(C.FreeDimension) > 1:
        # multiple inferred/free dimensions make a whole-shape reshape
        # ambiguous, so squeeze the dummy axes out one at a time instead
        result = x
        for index in sorted(_axis, reverse=True):
            result = C.reshape(result,
                               shape=(),
                               begin_axis=index,
                               end_axis=index + 1)
        return result
    else:
        for index in sorted(_axis, reverse=True):
            del shape[index]

        shape = [C.InferredDimension if _ == C.FreeDimension else _ for _ in shape]
        return C.reshape(x, shape) 
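
The first branch relies on C.reshape's ability to rewrite only a sub-range of axes via begin_axis/end_axis. A minimal sketch of that call pattern (assumes CNTK 2.x; names are illustrative):

import cntk as C

x = C.input_variable((2, 1, 3))
# collapse the axis range [1, 2) into an empty shape, i.e. drop the dummy axis
y = C.reshape(x, shape=(), begin_axis=1, end_axis=2)
print(y.shape)  # -> (2, 3)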
Example #3
Source File: cntk_backend.py    From DeepLearning_Wavelet-LSTM with MIT License
def repeat(x, n):
    # This is a workaround for the recurrent layer: if n is an inferred
    # dimension, we can't repeat it in CNTK yet, so return x unchanged and
    # rely on CNTK's broadcasting to make the recurrent layer work.
    # To be fixed in GA.
    if n is C.InferredDimension or n is C.FreeDimension:
        return x
    index = 1 - _get_dynamic_axis_num(x)
    if index < 0 or index > 1:
        raise NotImplementedError

    new_shape = list(x.shape)
    new_shape.insert(index, 1)
    new_shape = tuple(new_shape)
    x = C.reshape(x, new_shape)
    temp = [x] * n
    return C.splice(*temp, axis=index) 
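
The splice pattern above can be reproduced directly: insert a length-1 axis, then concatenate n copies along it. A minimal sketch (assumes CNTK 2.x; names are illustrative):

import numpy as np
import cntk as C

x = C.input_variable((4,))
xr = C.reshape(x, (1, 4))            # insert the axis to repeat along
y = C.splice(*([xr] * 3), axis=0)    # three copies -> static shape (3, 4)
print(y.eval({x: np.ones((1, 4), dtype=np.float32)}).shape)  # -> (1, 3, 4)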
Example #4
Source File: cntk_emitter.py    From MMdnn with MIT License
def _layer_Conv(self):
        self.add_body(0, """
def convolution(input, is_transpose, name, **kwargs):
    dim = __weights_dict[name]['weights'].ndim

    if is_transpose:
        weight = np.transpose(__weights_dict[name]['weights'], [dim - 2, dim - 1] + list(range(0, dim - 2)))
        kwargs.pop('groups', None)
    else:
        weight = np.transpose(__weights_dict[name]['weights'], [dim - 1, dim - 2] + list(range(0, dim - 2)))
    w = cntk.Parameter(init=weight, name=name + '_weight')

    input = cntk.transpose(input, [dim - 2] + list(range(0, dim - 2)))

    if is_transpose:
        layer = ops.convolution_transpose(w, input, **kwargs)
    else:
        layer = ops.convolution(w, input, **kwargs)
    if 'bias' in __weights_dict[name]:
        bias = np.reshape(__weights_dict[name]['bias'], [-1] + [1] * (dim - 2))
        b = cntk.Parameter(init=bias, name=name + '_bias')
        layer = layer + b
    layer = cntk.transpose(layer, list(range(1, dim - 1)) + [0])
    return layer
""") 
Example #5
Source File: cntk_backend.py    From DeepLearning_Wavelet-LSTM with MIT License
def random_binomial(shape, p=0.0, dtype=None, seed=None):
    # use a NumPy-based workaround for now
    if seed is None:
        # ensure that randomness is conditioned by the Numpy RNG
        seed = np.random.randint(10e7)
    # apply the seed whether it was supplied or generated above
    np.random.seed(seed)
    if dtype is None:
        dtype = np.float32
    else:
        dtype = _convert_string_dtype(dtype)

    size = 1
    for _ in shape:
        if _ is None:
            raise ValueError('CNTK Backend: randomness op with '
                             'dynamic shape is not supported now. '
                             'Please provide fixed dimension '
                             'instead of `None`.')
        size *= _

    binomial = np.random.binomial(1, p, size).astype(dtype).reshape(shape)
    return variable(value=binomial, dtype=dtype) 
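
Since the draw happens in NumPy, the core of the workaround is just np.random.binomial reshaped to the requested shape. A minimal sketch (names and values are illustrative):

import numpy as np

shape = (2, 3)
size = int(np.prod(shape))
binomial = np.random.binomial(1, 0.3, size).astype(np.float32).reshape(shape)
print(binomial)  # 0/1 samples drawn with p=0.3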
Example #6
Source File: cntk_backend.py    From DeepLearning_Wavelet-LSTM with MIT License
def _remove_dims(x, axis, keepdims=False):
    if keepdims is False and isinstance(axis, list):
        # the sequence axis is removed by default, so no reshape is needed for it
        reduce_axes = []
        for a in axis:
            if isinstance(a, C.Axis) is False:
                reduce_axes.append(a)
        return _reshape_dummy_dim(x, reduce_axes)
    else:
        if isinstance(axis, list):
            has_seq = False
            for a in axis:
                if isinstance(a, C.Axis):
                    has_seq = True
                    break
            if has_seq:
                nones = _get_dynamic_axis_num(x)
                x = expand_dims(x, nones)
        return x 
Example #7
Source File: cntk_backend.py    From DeepLearning_Wavelet-LSTM with MIT License
def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
    if data_format is None:
        data_format = image_data_format()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Unknown data_format ' + str(data_format))

    stride = strides[0]
    kernel_shape = int_shape(kernel)
    output_length, feature_dim, filters = kernel_shape

    xs = []
    for i in range(output_length):
        slice_length = slice(i * stride,
                             i * stride + kernel_size[0])
        xs.append(reshape(inputs[:, slice_length, :],
                          (-1, 1, feature_dim)))
    x_aggregate = concatenate(xs, axis=1)
    # transpose kernel to output_filters first, to apply broadcast
    weight = permute_dimensions(kernel, (2, 0, 1))
    # Shape: (batch, filters, output_length, feature_dim)
    output = x_aggregate * weight
    # Shape: (batch, filters, output_length)
    output = sum(output, axis=3)
    # Shape: (batch, output_length, filters)
    return permute_dimensions(output, (0, 2, 1)) 
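
The unshared-weights trick is easier to see in NumPy: every output position gets its own (feature_dim -> filters) weight slice, applied by broadcasting and a sum over the feature axis. A minimal sketch (names are illustrative):

import numpy as np

output_length, feature_dim, filters = 5, 6, 2
x_aggregate = np.random.rand(output_length, feature_dim).astype(np.float32)
weight = np.random.rand(filters, output_length, feature_dim).astype(np.float32)
out = (x_aggregate * weight).sum(axis=-1)   # -> (filters, output_length)
print(out.shape)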
Example #8
Source File: cntk_backend.py    From DeepLearning_Wavelet-LSTM with MIT License
def depthwise_conv2d(x, depthwise_kernel, strides=(1, 1), padding='valid',
                     data_format=None, dilation_rate=(1, 1)):
    if data_format is None:
        data_format = image_data_format()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Unknown data_format ' + str(data_format))

    x = _preprocess_conv2d_input(x, data_format)
    depthwise_kernel = _preprocess_conv2d_kernel(depthwise_kernel, data_format)
    depthwise_kernel = C.reshape(C.transpose(depthwise_kernel, (1, 0, 2, 3)),
                                 (-1, 1) + depthwise_kernel.shape[2:])
    padding = _preprocess_border_mode(padding)
    if dilation_rate == (1, 1):
        strides = (1,) + strides
        x = C.convolution(depthwise_kernel, x,
                          strides=strides,
                          auto_padding=[False, padding, padding],
                          groups=x.shape[0])
    else:
        if dilation_rate[0] != dilation_rate[1]:
            raise ValueError('CNTK Backend: non-square dilation_rate is '
                             'not supported.')
        if strides != (1, 1):
            raise ValueError('Invalid strides for dilated convolution')
        x = C.convolution(depthwise_kernel, x,
                          strides=dilation_rate[0],
                          auto_padding=[False, padding, padding],
                          groups=x.shape[0])
    return _postprocess_conv2d_output(x, data_format) 
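
The key to the depthwise step is the kernel reshape to (channels, 1, kh, kw) combined with groups equal to the channel count, so each channel is convolved with its own filter. A minimal sketch (assumes CNTK >= 2.4 for the groups argument; names are illustrative):

import cntk as C

x = C.input_variable((4, 8, 8))                         # 4 channels, 8x8 spatial
w = C.parameter((4, 1, 3, 3), init=C.glorot_uniform())  # one 3x3 filter per channel
y = C.convolution(w, x, strides=(1, 1, 1),
                  auto_padding=[False, True, True], groups=4)
print(y.shape)  # -> (4, 8, 8)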
Example #9
Source File: cntk_backend.py    From DeepLearning_Wavelet-LSTM with MIT License
def _reshape_batch(x, shape):
    # there is a bug in cntk 2.1's unpack_batch implementation
    if hasattr(C, 'unpack_batch') and _get_cntk_version() >= 2.2:
        const_a = C.unpack_batch(x)
        const_a = C.reshape(const_a, shape)
        return C.to_batch(const_a)
    else:
        return C.user_function(ReshapeBatch(x, shape[1:])) 
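
On newer CNTK versions the unpack/reshape/pack round trip can be written out directly. A minimal sketch (assumes CNTK >= 2.2, where unpack_batch and to_batch are available; names are illustrative):

import numpy as np
import cntk as C

x = C.input_variable((2, 3))
packed = C.unpack_batch(x)             # batch axis becomes static axis 0
reshaped = C.reshape(packed, (-1, 6))  # reshape across the batch boundary
y = C.to_batch(reshaped)               # restore the dynamic batch axis
print(y.eval({x: np.zeros((4, 2, 3), dtype=np.float32)}).shape)  # -> (4, 6)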
Example #10
Source File: cntk_backend.py    From GraphicDesignPatternByPython with MIT License
def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
                     padding='valid', data_format=None, dilation_rate=(1, 1)):
    data_format = normalize_data_format(data_format)

    x = _preprocess_conv2d_input(x, data_format)
    depthwise_kernel = _preprocess_conv2d_kernel(depthwise_kernel, data_format)
    depthwise_kernel = C.reshape(C.transpose(depthwise_kernel, (1, 0, 2, 3)),
                                 (-1, 1) + depthwise_kernel.shape[2:])
    pointwise_kernel = _preprocess_conv2d_kernel(pointwise_kernel, data_format)
    padding = _preprocess_border_mode(padding)

    if dilation_rate == (1, 1):
        strides = (1,) + strides
        x = C.convolution(depthwise_kernel, x,
                          strides=strides,
                          auto_padding=[False, padding, padding],
                          groups=x.shape[0])
        x = C.convolution(pointwise_kernel, x,
                          strides=(1, 1, 1),
                          auto_padding=[False])
    else:
        if dilation_rate[0] != dilation_rate[1]:
            raise ValueError('CNTK Backend: non-square dilation_rate is '
                             'not supported.')
        if strides != (1, 1):
            raise ValueError('Invalid strides for dilated convolution')
        x = C.convolution(depthwise_kernel, x,
                          strides=dilation_rate[0],
                          auto_padding=[False, padding, padding])
        x = C.convolution(pointwise_kernel, x,
                          strides=(1, 1, 1),
                          auto_padding=[False])
    return _postprocess_conv2d_output(x, data_format) 
Example #11
Source File: cntk_backend.py    From DeepLearning_Wavelet-LSTM with MIT License
def _reshape_sequence(x, time_step):
    tmp_shape = list(int_shape(x))
    tmp_shape[1] = time_step
    return reshape(x, tmp_shape) 
Example #12
Source File: cntk_backend.py    From DeepLearning_Wavelet-LSTM with MIT License
def sparse_categorical_crossentropy(target, output, from_logits=False):
    target = C.one_hot(target, output.shape[-1])
    target = C.reshape(target, output.shape)
    return categorical_crossentropy(target, output, from_logits) 
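
The one-hot expansion is what handles the "sparse" targets here. A minimal sketch of C.one_hot (assumes CNTK 2.x; names are illustrative):

import numpy as np
import cntk as C

target = C.input_variable((1,))     # integer class ids
onehot = C.one_hot(target, 4)       # expand to 4 classes
print(onehot.eval({target: np.array([[2.]], dtype=np.float32)}))
# -> [[[0. 0. 1. 0.]]]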
Example #13
Source File: cntk_backend.py    From DeepLearning_Wavelet-LSTM with MIT License
def categorical_crossentropy(target, output, from_logits=False):
    if from_logits:
        result = C.cross_entropy_with_softmax(output, target)
        # cntk's result shape is (batch, 1), while keras expects (batch,)
        return C.reshape(result, ())
    else:
        # scale preds so that the class probas of each sample sum to 1
        output /= C.reduce_sum(output, axis=-1)
        # avoid numerical instability with epsilon clipping
        output = C.clip(output, epsilon(), 1.0 - epsilon())
        return -sum(target * C.log(output), axis=-1) 
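
The reshape to () in the logits branch strips CNTK's trailing length-1 axis so the loss comes back per sample, as Keras expects. A minimal sketch (assumes CNTK 2.x; names are illustrative):

import numpy as np
import cntk as C

logits = C.input_variable((3,))
labels = C.input_variable((3,))
loss = C.reshape(C.cross_entropy_with_softmax(logits, labels), ())
print(loss.eval({logits: np.array([[1., 2., 3.]], dtype=np.float32),
                 labels: np.array([[0., 0., 1.]], dtype=np.float32)}))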
Example #14
Source File: cntk_backend.py    From DeepLearning_Wavelet-LSTM with MIT License
def reshape(x, shape):
    shape = tuple([C.InferredDimension if _ == C.FreeDimension else _ for _ in shape])
    if isinstance(x, C.variables.Parameter):
        return C.reshape(x, shape)
    else:
        num_dynamic_axis = _get_dynamic_axis_num(x)

        if num_dynamic_axis == 1 and len(shape) > 0 and shape[0] == -1:
            # collapse axis with batch axis
            if b_any(_ == C.InferredDimension for _ in x.shape) or b_any(
                    _ == C.FreeDimension for _ in x.shape):
                warnings.warn(
                    'Warning: CNTK backend does not support '
                    'collapse of batch axis with inferred dimension. '
                    'The reshape did not take place.')
                return x
            return _reshape_batch(x, shape)
        else:
            # no batch-axis collapse; first pad the shape to cover the dynamic axes
            if num_dynamic_axis >= len(shape):
                i = 0
                while i < len(shape):
                    if shape[i] is None or shape[i] == -1:
                        i += 1
                    else:
                        break
                shape = tuple([-1 for _ in range(num_dynamic_axis - i)]) + shape

            new_shape = list(shape)
            new_shape = new_shape[num_dynamic_axis:]
            new_shape = [C.InferredDimension if _ is None else _ for _ in new_shape]
            return C.reshape(x, new_shape) 
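
C.InferredDimension lets CNTK solve for one unknown static dimension, which is how the None/-1 entries above are handled. A minimal sketch (assumes CNTK 2.x; names are illustrative):

import cntk as C

x = C.input_variable((2, 3, 4))
y = C.reshape(x, (C.InferredDimension, 4))   # CNTK infers 6
print(y.shape)  # -> (6, 4)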
Example #15
Source File: cntk_backend.py    From GraphicDesignPatternByPython with MIT License
def batch_flatten(x):
    # cntk's batch axis is not part of x.shape,
    # so just flatten all the static dims in x.shape
    dim = np.prod(x.shape)
    x = C.reshape(x, (-1,))
    x._keras_shape = (None, dim)
    return x
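
Because the batch axis is dynamic and not part of x.shape, reshaping to (-1,) flattens only the static axes; the batch dimension survives untouched. A minimal sketch (assumes CNTK 2.x; names are illustrative):

import numpy as np
import cntk as C

x = C.input_variable((2, 3))
y = C.reshape(x, (-1,))
print(y.eval({x: np.zeros((5, 2, 3), dtype=np.float32)}).shape)  # -> (5, 6)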