switch (this.inputShape.length) { case 1: myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]); break; case 2:
InputType.InputTypeFeedForward c = (InputType.InputTypeFeedForward) inputType; int expSize = inputHeight * inputWidth * numChannels; if (c.getSize() != expSize) { throw new IllegalStateException("Invalid input: expected FeedForward input of size " + expSize + " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got "
switch (vertexInputs[i].getType()) { case FF: thisSize = ((InputType.InputTypeFeedForward) vertexInputs[i]).getSize(); type = InputType.Type.FF; break;
switch (vertexInputs[i].getType()) { case FF: thisSize = ((InputType.InputTypeFeedForward) vertexInputs[i]).getSize(); type = InputType.Type.FF; break;
switch (vertexInputs[i].getType()) { case FF: thisSize = ((InputType.InputTypeFeedForward) vertexInputs[i]).getSize(); type = InputType.Type.FF; break;
switch (vertexInputs[i].getType()) { case FF: thisSize = ((InputType.InputTypeFeedForward) vertexInputs[i]).getSize(); type = InputType.Type.FF; break;
/**
 * Sets nIn (and nOut, which equals nIn for batch normalization) from the given input type,
 * when nIn has not been set yet or {@code override} is true.
 *
 * @param inputType input type of the preceding layer; must be FF, CNN or CNNFlat
 * @param override  if true, overwrite any previously configured nIn/nOut
 * @throws IllegalStateException if the input type is not FF, CNN or CNNFlat
 */
@Override
public void setNIn(InputType inputType, boolean override) {
    if (nIn <= 0 || override) {
        switch (inputType.getType()) {
            case FF:
                nIn = ((InputType.InputTypeFeedForward) inputType).getSize();
                break;
            case CNN:
                nIn = ((InputType.InputTypeConvolutional) inputType).getDepth();
                break;
            case CNNFlat:
                nIn = ((InputType.InputTypeConvolutionalFlat) inputType).getDepth();
                break; // bug fix: previously fell through into default and always threw for CNNFlat
            default:
                throw new IllegalStateException(
                        "Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got "
                                + inputType + " for layer " + getLayerName() + "\"");
        }
        nOut = nIn;
    }
}
/**
 * Sets nIn from the given input type: the layer size for FeedForward input, or the
 * flattened size for CNNFlat input. Existing values are kept unless {@code override} is set.
 *
 * @param inputType input type of the preceding layer; must be FF or CNNFlat
 * @param override  if true, overwrite any previously configured nIn
 * @throws IllegalStateException if the input type is null or neither FF nor CNNFlat
 */
@Override
public void setNIn(InputType inputType, boolean override) {
    boolean supported = inputType != null
            && (inputType.getType() == InputType.Type.FF
                    || inputType.getType() == InputType.Type.CNNFlat);
    if (!supported) {
        throw new IllegalStateException("Invalid input type (layer name=\"" + getLayerName()
                + "\"): expected FeedForward input type. Got: " + inputType);
    }
    if (nIn > 0 && !override) {
        return; // keep the previously configured value
    }
    this.nIn = (inputType.getType() == InputType.Type.FF)
            ? ((InputType.InputTypeFeedForward) inputType).getSize()
            : ((InputType.InputTypeConvolutionalFlat) inputType).getFlattenedSize();
}
/**
 * Computes the output type: a recurrent (time series) type whose size equals the
 * incoming FeedForward size, or the flattened size for CNNFlat input.
 *
 * @param inputType input type of the preceding layer; must be FF or CNNFlat
 * @return recurrent input type of the corresponding size
 * @throws IllegalStateException if the input type is null or neither FF nor CNNFlat
 */
@Override
public InputType getOutputType(InputType inputType) {
    InputType.Type t = (inputType == null) ? null : inputType.getType();
    if (t != InputType.Type.FF && t != InputType.Type.CNNFlat) {
        throw new IllegalStateException("Invalid input: expected input of type FeedForward, got " + inputType);
    }
    int size = (t == InputType.Type.FF)
            ? ((InputType.InputTypeFeedForward) inputType).getSize()
            : ((InputType.InputTypeConvolutionalFlat) inputType).getFlattenedSize();
    return InputType.recurrent(size);
}
@Override public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { if (vertexInputs.length != 1) throw new InvalidInputTypeException("Invalid input type: cannot duplicate more than 1 input"); int tsLength = 1; //TODO work this out properly if (vertexInputs[0].getType() == InputType.Type.FF) { return InputType.recurrent(((InputType.InputTypeFeedForward) vertexInputs[0]).getSize(), tsLength); } else if (vertexInputs[0].getType() != InputType.Type.CNNFlat) { return InputType.recurrent(((InputType.InputTypeConvolutionalFlat) vertexInputs[0]).getFlattenedSize(), tsLength); } else { throw new InvalidInputTypeException( "Invalid input type: cannot duplicate to time series non feed forward (or CNN flat) input (got: " + vertexInputs[0] + ")"); } }
/**
 * Creates an InputType for flat (feed-forward) network activations.
 *
 * @param size number of activations (the layer size)
 * @return a feed-forward InputType of the given size
 */
public static InputType feedForward(int size) {
    return new InputTypeFeedForward(size);
}