@Override
public void setNIn(InputType inputType, boolean override) {
    // This layer accepts only feed-forward or flattened-CNN input; anything
    // else indicates a network configuration error.
    if (inputType == null || (inputType.getType() != InputType.Type.FF
                    && inputType.getType() != InputType.Type.CNNFlat)) {
        throw new IllegalStateException("Invalid input type (layer name=\"" + getLayerName()
                        + "\"): expected FeedForward input type. Got: " + inputType);
    }

    // Only set nIn when it has not been configured yet, or when override is requested.
    if (nIn <= 0 || override) {
        if (inputType.getType() == InputType.Type.FF) {
            this.nIn = ((InputType.InputTypeFeedForward) inputType).getSize();
        } else {
            // Flattened CNN input: nIn is the flattened size (height * width * depth)
            this.nIn = ((InputType.InputTypeConvolutionalFlat) inputType).getFlattenedSize();
        }
    }
}
case CNNFlat:
    // Flattened CNN input: validate the unflattened dimensions, then return the
    // equivalent 4d (unflattened) convolutional input type.
    // Bug fix: the '}' closing the if-block was missing, leaving the return
    // statement dead (unreachable after throw) and the braces unbalanced.
    InputType.InputTypeConvolutionalFlat c3 = (InputType.InputTypeConvolutionalFlat) inputType;
    if (c3.getDepth() != numChannels || c3.getHeight() != inputHeight || c3.getWidth() != inputWidth) {
        // Both tuples are now printed in (d,w,h) order, matching the label
        throw new IllegalStateException("Invalid input: Got CNN input type with (d,w,h)=(" + c3.getDepth()
                        + "," + c3.getWidth() + "," + c3.getHeight() + ") but expected (" + numChannels + ","
                        + inputWidth + "," + inputHeight + ")");
    }
    return c3.getUnflattenedType();
default:
    throw new IllegalStateException("Invalid input type: got " + inputType);
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    // Extract spatial dimensions and depth from either CNN input representation.
    final int inH;
    final int inW;
    final int inDepth;
    if (inputType instanceof InputType.InputTypeConvolutional) {
        InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
        inH = c.getHeight();
        inW = c.getWidth();
        inDepth = c.getDepth();
    } else if (inputType instanceof InputType.InputTypeConvolutionalFlat) {
        InputType.InputTypeConvolutionalFlat c = (InputType.InputTypeConvolutionalFlat) inputType;
        inH = c.getHeight();
        inW = c.getWidth();
        inDepth = c.getDepth();
    } else {
        throw new IllegalStateException(
                        "Invalid input type: expected InputTypeConvolutional or InputTypeConvolutionalFlat."
                                        + " Got: " + inputType);
    }

    // Padding increases each spatial dimension by the total padding applied to it:
    // padding[0]/padding[1] = top/bottom, padding[2]/padding[3] = left/right.
    final int outH = inH + padding[0] + padding[1];
    final int outW = inW + padding[2] + padding[3];
    return InputType.convolutional(outH, outW, inDepth);
}
@Override public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { if (vertexInputs.length != 1) throw new InvalidInputTypeException("Invalid input type: cannot duplicate more than 1 input"); int tsLength = 1; //TODO work this out properly if (vertexInputs[0].getType() == InputType.Type.FF) { return InputType.recurrent(((InputType.InputTypeFeedForward) vertexInputs[0]).getSize(), tsLength); } else if (vertexInputs[0].getType() != InputType.Type.CNNFlat) { return InputType.recurrent(((InputType.InputTypeConvolutionalFlat) vertexInputs[0]).getFlattenedSize(), tsLength); } else { throw new InvalidInputTypeException( "Invalid input type: cannot duplicate to time series non feed forward (or CNN flat) input (got: " + vertexInputs[0] + ")"); } }
@Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { switch (inputType.getType()) { case FF: throw new UnsupportedOperationException( "Global max pooling cannot be applied to feed-forward input type. Got input type = " + inputType); case RNN: case CNN: //No preprocessor required return null; case CNNFlat: InputType.InputTypeConvolutionalFlat cFlat = (InputType.InputTypeConvolutionalFlat) inputType; return new FeedForwardToCnnPreProcessor(cFlat.getHeight(), cFlat.getWidth(), cFlat.getDepth()); } return null; }
@Override
public void setNIn(InputType inputType, boolean override) {
    // Batch norm: nIn (and nOut) is the channel/feature count of the input.
    if (nIn <= 0 || override) {
        switch (inputType.getType()) {
            case FF:
                nIn = ((InputType.InputTypeFeedForward) inputType).getSize();
                break;
            case CNN:
                nIn = ((InputType.InputTypeConvolutional) inputType).getDepth();
                break;
            case CNNFlat:
                nIn = ((InputType.InputTypeConvolutionalFlat) inputType).getDepth();
                // Bug fix: missing break caused valid CNNFlat input to fall
                // through into default and throw.
                break;
            default:
                // Message quoting fixed: layer name is now wrapped in matching quotes.
                throw new IllegalStateException(
                                "Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got "
                                                + inputType + " for layer \"" + getLayerName() + "\"");
        }
        // Batch norm preserves the feature count
        nOut = nIn;
    }
}
@Override
public InputType getOutputType(InputType inputType) {
    // Only feed-forward or flattened-CNN activations can be converted to a
    // recurrent output type here.
    boolean valid = inputType != null && (inputType.getType() == InputType.Type.FF
                    || inputType.getType() == InputType.Type.CNNFlat);
    if (!valid) {
        throw new IllegalStateException("Invalid input: expected input of type FeedForward, got " + inputType);
    }

    if (inputType.getType() == InputType.Type.FF) {
        return InputType.recurrent(((InputType.InputTypeFeedForward) inputType).getSize());
    }
    // CNNFlat: the recurrent feature size is the flattened size (h * w * d)
    return InputType.recurrent(((InputType.InputTypeConvolutionalFlat) inputType).getFlattenedSize());
}
/**
 * Input type for convolutional (CNN) data in flattened (row vector) format.
 * Expect data with shape [miniBatchSize, height * width * depth]. For CNN data in
 * 4d format, use {@link #convolutional(int, int, int)} instead.
 *
 * @param height Height of the (unflattened) data represented by this input type
 * @param width  Width of the (unflattened) data represented by this input type
 * @param depth  Depth of the (unflattened) data represented by this input type
 * @return An InputTypeConvolutionalFlat instance for the given dimensions
 */
public static InputType convolutionalFlat(int height, int width, int depth) {
    return new InputTypeConvolutionalFlat(height, width, depth);
}
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
    // Flattened CNN input must be reshaped back into 4d activations before this
    // layer; every other input type passes through with no preprocessor.
    if (inputType.getType() != InputType.Type.CNNFlat) {
        return null;
    }
    InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
    return new FeedForwardToCnnPreProcessor(flat.getHeight(), flat.getWidth(), flat.getDepth());
}