Java Code Examples for org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutionalFlat
The following examples show how to use
org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutionalFlat.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: GlobalPoolingLayer.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { switch (inputType.getType()) { case FF: throw new UnsupportedOperationException( "Global max pooling cannot be applied to feed-forward input type. Got input type = " + inputType); case RNN: case CNN: case CNN3D: //No preprocessor required return null; case CNNFlat: InputType.InputTypeConvolutionalFlat cFlat = (InputType.InputTypeConvolutionalFlat) inputType; return new FeedForwardToCnnPreProcessor(cFlat.getHeight(), cFlat.getWidth(), cFlat.getDepth()); } return null; }
Example 2
Source File: FeedForwardToRnnPreProcessor.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Computes the recurrent output type produced by this preprocessor.
 *
 * @param inputType feed-forward (FF) or flattened convolutional (CNNFlat) input type
 * @return a recurrent type that uses this preprocessor's {@code rnnDataFormat}; its size is the
 *     feed-forward size for FF input, or the flattened size for CNNFlat input
 * @throws IllegalStateException if {@code inputType} is null or is neither FF nor CNNFlat
 */
@Override
public InputType getOutputType(InputType inputType) {
    if (inputType == null || (inputType.getType() != InputType.Type.FF
            && inputType.getType() != InputType.Type.CNNFlat)) {
        // Both accepted types are named so the message matches the check above.
        throw new IllegalStateException(
                "Invalid input: expected input of type FeedForward or CNNFlat, got " + inputType);
    }
    if (inputType.getType() == InputType.Type.FF) {
        InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward) inputType;
        return InputType.recurrent(ff.getSize(), rnnDataFormat);
    }
    // CNNFlat: the recurrent feature size is the total flattened size
    InputType.InputTypeConvolutionalFlat cf = (InputType.InputTypeConvolutionalFlat) inputType;
    return InputType.recurrent(cf.getFlattenedSize(), rnnDataFormat);
}
Example 3
Source File: FeedForwardToCnnPreProcessor.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Computes the convolutional output type produced by this preprocessor, validating the input
 * against the configured {@code inputHeight}, {@code inputWidth} and {@code numChannels}.
 *
 * @param inputType FF, CNN or CNNFlat input type
 * @return for FF input, a convolutional type built from the configured dimensions; for CNN input,
 *     the input type unchanged; for CNNFlat input, the input's unflattened type
 * @throws IllegalStateException if the input's size or dimensions do not match the configured
 *     ones, or if the input type is not FF, CNN or CNNFlat
 */
@Override
public InputType getOutputType(InputType inputType) {
    switch (inputType.getType()) {
        case FF:
            InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType;
            // Widen before multiplying so the product cannot overflow int arithmetic
            long expSize = (long) inputHeight * inputWidth * numChannels;
            if (c.getSize() != expSize) {
                throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize
                        + " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got "
                        + inputType);
            }
            return InputType.convolutional(inputHeight, inputWidth, numChannels);
        case CNN:
            InputType.InputTypeConvolutional c2 = (InputType.InputTypeConvolutional) inputType;
            if (c2.getChannels() != numChannels || c2.getHeight() != inputHeight || c2.getWidth() != inputWidth) {
                // Both tuples are printed in (d,w,h) order so actual and expected line up
                throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=("
                        + c2.getChannels() + "," + c2.getWidth() + "," + c2.getHeight() + ") but expected ("
                        + numChannels + "," + inputWidth + "," + inputHeight + ")");
            }
            return c2;
        case CNNFlat:
            InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
            if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
                throw new IllegalStateException("Invalid input: Got CNNFlat input type with (d,w,h)=("
                        + c3.getDepth() + "," + c3.getWidth() + "," + c3.getHeight() + ") but expected ("
                        + numChannels + "," + inputWidth + "," + inputHeight + ")");
            }
            return c3.getUnflattenedType();
        default:
            throw new IllegalStateException("Invalid input type: got " + inputType);
    }
}
Example 4
Source File: BatchNormalization.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Selects the preprocessor that batch normalization needs in front of it for the given input type.
 *
 * @param inputType input type arriving at this layer
 * @return a {@link RnnToFeedForwardPreProcessor} for RNN input, a {@link FeedForwardToCnnPreProcessor}
 *     (built from the flat type's height, width and depth) for CNNFlat input, or {@code null} for
 *     every other input type
 */
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
    InputType.Type kind = inputType.getType();
    if (kind == InputType.Type.RNN) {
        return new RnnToFeedForwardPreProcessor();
    }
    if (kind != InputType.Type.CNNFlat) {
        // No conversion needed for the remaining input types
        return null;
    }
    InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
    return new FeedForwardToCnnPreProcessor(flat.getHeight(), flat.getWidth(), flat.getDepth());
}
Example 5
Source File: Yolo2OutputLayer.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Selects the preprocessor required in front of this YOLO2 output layer for the given input type.
 *
 * @param inputType input type arriving at this layer
 * @return a {@link FeedForwardToCnnPreProcessor} for CNNFlat input; {@code null} for CNN input and
 *     for any other type not rejected below
 * @throws UnsupportedOperationException for FF or RNN input
 */
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
    final InputType.Type kind = inputType.getType();
    switch (kind) {
        case CNNFlat: {
            // Rebuild convolutional activations from the flat representation
            InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
            return new FeedForwardToCnnPreProcessor(flat.getHeight(), flat.getWidth(), flat.getDepth());
        }
        case FF:
        case RNN:
            throw new UnsupportedOperationException("Cannot use FF or RNN input types");
        case CNN:
        default:
            return null;
    }
}
Example 6
Source File: InputTypeUtil.java From deeplearning4j with Apache License 2.0 | 5 votes |
/** * Utility method for determining the appropriate preprocessor for CNN layers, such as {@link ConvolutionLayer} and * {@link SubsamplingLayer} * * @param inputType Input type to get the preprocessor for * @return Null if no preprocessor is required; otherwise the appropriate preprocessor for the given input type */ public static InputPreProcessor getPreProcessorForInputTypeCnnLayers(InputType inputType, String layerName) { //To add x-to-CNN preprocessor: need to know image channels/width/height after reshaping //But this can't be inferred from the FF/RNN activations directly (could be anything) switch (inputType.getType()) { case FF: //FF -> CNN // return new FeedForwardToCnnPreProcessor(inputSize[0], inputSize[1], inputDepth); log.info("Automatic addition of FF -> CNN preprocessors: not yet implemented (layer name: \"" + layerName + "\")"); return null; case RNN: //RNN -> CNN // return new RnnToCnnPreProcessor(inputSize[0], inputSize[1], inputDepth); log.warn("Automatic addition of RNN -> CNN preprocessors: not yet implemented (layer name: \"" + layerName + "\")"); return null; case CNN: //CNN -> CNN: no preprocessor required return null; case CNNFlat: //CNN (flat) -> CNN InputType.InputTypeConvolutionalFlat f = (InputType.InputTypeConvolutionalFlat) inputType; return new FeedForwardToCnnPreProcessor(f.getHeight(), f.getWidth(), f.getDepth()); default: throw new RuntimeException("Unknown input type: " + inputType); } }
Example 7
Source File: GlobalPoolingLayer.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Override public InputType getOutputType(int layerIndex, InputType inputType) { switch (inputType.getType()) { case FF: throw new UnsupportedOperationException( "Global max pooling cannot be applied to feed-forward input type. Got input type = " + inputType); case RNN: InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType; if (collapseDimensions) { //Return 2d (feed-forward) activations return InputType.feedForward(recurrent.getSize()); } else { //Return 3d activations, with shape [minibatch, timeStepSize, 1] return recurrent; } case CNN: InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType; if (collapseDimensions) { return InputType.feedForward(conv.getChannels()); } else { return InputType.convolutional(1, 1, conv.getChannels(), conv.getFormat()); } case CNN3D: InputType.InputTypeConvolutional3D conv3d = (InputType.InputTypeConvolutional3D) inputType; if (collapseDimensions) { return InputType.feedForward(conv3d.getChannels()); } else { return InputType.convolutional3D(1, 1, 1, conv3d.getChannels()); } case CNNFlat: InputType.InputTypeConvolutionalFlat convFlat = (InputType.InputTypeConvolutionalFlat) inputType; if (collapseDimensions) { return InputType.feedForward(convFlat.getDepth()); } else { return InputType.convolutional(1, 1, convFlat.getDepth()); } default: throw new UnsupportedOperationException("Unknown or not supported input type: " + inputType); } }