/**
 * Computes the type produced by this preprocessor: a convolutional type built from the
 * configured {@code inputHeight}, {@code inputWidth} and {@code numChannels}.
 *
 * @param inputType the incoming type; must be non-null, of type RNN, and have a size equal to
 *     {@code inputHeight * inputWidth * numChannels}
 * @return {@code InputType.convolutional(inputHeight, inputWidth, numChannels)}
 * @throws IllegalStateException if {@code inputType} is null, is not RNN, or has the wrong size
 */
@Override
public InputType getOutputType(InputType inputType) {
    boolean isRecurrent = inputType != null && inputType.getType() == InputType.Type.RNN;
    if (!isRecurrent) {
        throw new IllegalStateException("Invalid input type: Expected input of type RNN, got " + inputType);
    }
    final InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType;
    // The recurrent feature size must match the flattened (h * w * d) volume exactly.
    final int expectedSize = inputHeight * inputWidth * numChannels;
    if (recurrent.getSize() == expectedSize) {
        return InputType.convolutional(inputHeight, inputWidth, numChannels);
    }
    throw new IllegalStateException("Invalid input: expected RNN input of size " + expectedSize
            + " = (d=" + numChannels + " * w=" + inputWidth + " * h=" + inputHeight + "), got "
            + inputType);
}
/**
 * Computes the type produced by this preprocessor: a feed-forward type whose size equals the
 * size of the incoming recurrent type.
 *
 * @param inputType the incoming type; must be non-null and of type RNN
 * @return {@code InputType.feedForward(size)} where {@code size} is the recurrent input's size
 * @throws IllegalStateException if {@code inputType} is null or not an RNN type
 */
@Override
public InputType getOutputType(InputType inputType) {
    if (inputType != null && inputType.getType() == InputType.Type.RNN) {
        // The per-step feature size carries over unchanged into the feed-forward type.
        return InputType.feedForward(((InputType.InputTypeRecurrent) inputType).getSize());
    }
    throw new IllegalStateException("Invalid input: expected input of type RNN, got " + inputType);
}
/**
 * Sets {@code nIn} from the size of a recurrent input type.
 *
 * <p>{@code nIn} is only written when it is not yet set ({@code nIn <= 0}) or when
 * {@code override} is true; otherwise the existing value is kept.
 *
 * @param inputType the incoming type; must be non-null and of type RNN
 * @param override  whether to replace an already-positive {@code nIn}
 * @throws IllegalStateException if {@code inputType} is null or not an RNN type
 */
@Override
public void setNIn(InputType inputType, boolean override) {
    if (inputType == null || inputType.getType() != InputType.Type.RNN) {
        throw new IllegalStateException("Invalid input for RNN layer (layer name = \"" + getLayerName()
                + "\"): expect RNN input type with size > 0. Got: " + inputType);
    }
    boolean keepExisting = nIn > 0 && !override;
    if (keepExisting) {
        return;
    }
    this.nIn = ((InputType.InputTypeRecurrent) inputType).getSize();
}
break; case RNN: thisSize = ((InputType.InputTypeRecurrent) vertexInputs[i]).getSize(); type = InputType.Type.RNN; break;
break; case RNN: thisSize = ((InputType.InputTypeRecurrent) vertexInputs[i]).getSize(); type = InputType.Type.RNN; break;
break; case RNN: thisSize = ((InputType.InputTypeRecurrent) vertexInputs[i]).getSize(); tsLength = ((InputType.InputTypeRecurrent) vertexInputs[i]).getTimeSeriesLength(); type = InputType.Type.RNN;
break; case RNN: thisSize = ((InputType.InputTypeRecurrent) vertexInputs[i]).getSize(); type = InputType.Type.RNN; break;
if (collapseDimensions) { return InputType.feedForward(recurrent.getSize()); } else {
/**
 * Computes the type produced by this vertex: a feed-forward type whose size equals the size of
 * its single recurrent input.
 *
 * @param layerIndex   index of this vertex (not used by this computation)
 * @param vertexInputs the input types; exactly one non-null RNN input is required
 * @return {@code InputType.feedForward(size)} where {@code size} is the recurrent input's size
 * @throws InvalidInputTypeException if the number of inputs is not exactly one, or the single
 *     input is null or not of type RNN
 */
@Override
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
    // Treat a null varargs array as zero inputs, and report the actual count so that a missing
    // input is not misreported as "more than 1".
    int numInputs = (vertexInputs == null) ? 0 : vertexInputs.length;
    if (numInputs != 1) {
        throw new InvalidInputTypeException("Invalid input type: cannot get last time step of "
                + numInputs + " inputs (expected exactly 1)");
    }
    InputType input = vertexInputs[0];
    // A null element would otherwise surface as a NullPointerException from getType().
    if (input == null || input.getType() != InputType.Type.RNN) {
        throw new InvalidInputTypeException(
                "Invalid input type: cannot get last time step of non RNN input (got: " + input + ")");
    }
    return InputType.feedForward(((InputType.InputTypeRecurrent) input).getSize());
}
/**
 * Sets {@code nIn} for this 1D convolution layer from the size of a recurrent input type.
 *
 * <p>{@code nIn} is only written when it is not yet set ({@code nIn <= 0}) or when
 * {@code override} is true.
 *
 * @param inputType the incoming type; must be non-null and of type RNN
 * @param override  whether to replace an already-positive {@code nIn}
 * @throws IllegalStateException if {@code inputType} is null or not an RNN type
 */
@Override
public void setNIn(InputType inputType, boolean override) {
    boolean validInput = inputType != null && inputType.getType() == InputType.Type.RNN;
    if (!validInput) {
        throw new IllegalStateException("Invalid input for 1D CNN layer (layer name = \"" + getLayerName()
                + "\"): expect RNN input type with size > 0. Got: " + inputType);
    }
    if (override || nIn <= 0) {
        InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType;
        this.nIn = recurrent.getSize();
    }
}
/**
 * Sets {@code nIn} for this RnnOutputLayer from the size of a recurrent input type.
 *
 * <p>An existing positive {@code nIn} is kept unless {@code override} is true.
 *
 * @param inputType the incoming type; must be non-null and of type RNN
 * @param override  whether to replace an already-positive {@code nIn}
 * @throws IllegalStateException if {@code inputType} is null or not an RNN type
 */
@Override
public void setNIn(InputType inputType, boolean override) {
    if (inputType == null || inputType.getType() != InputType.Type.RNN) {
        throw new IllegalStateException("Invalid input type for RnnOutputLayer (layer name=\"" + getLayerName()
                + "\"): Expected RNN input, got " + inputType);
    }
    if (nIn > 0 && !override) {
        return;
    }
    this.nIn = ((InputType.InputTypeRecurrent) inputType).getSize();
}