/**
 * Computes the output type for the given input type by delegating to the
 * {@code layer} held by this object.
 *
 * @param layerIndex index of the layer, forwarded unchanged to the delegate
 * @param inputType  type of the activations going into the layer
 * @return whatever {@code layer.getOutputType(layerIndex, inputType)} returns
 */
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
    final InputType delegateResult = layer.getOutputType(layerIndex, inputType);
    return delegateResult;
}
@Override public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { if (vertexInputs.length != 1) { throw new InvalidInputTypeException( "LayerVertex expects exactly one input. Got: " + Arrays.toString(vertexInputs)); } //Assume any necessary preprocessors have already been added InputType afterPreprocessor; if (preProcessor == null) afterPreprocessor = vertexInputs[0]; else afterPreprocessor = preProcessor.getOutputType(vertexInputs[0]); return layerConf.getLayer().getOutputType(layerIndex, afterPreprocessor); }
/** * Get a {@link MemoryReport} for the given MultiLayerConfiguration. This is used to estimate the * memory requirements for the given network configuration and input * * @param inputType Input types for the network * @return Memory report for the network */ public NetworkMemoryReport getMemoryReport(InputType inputType) { Map<String, MemoryReport> memoryReportMap = new LinkedHashMap<>(); int nLayers = confs.size(); for (int i = 0; i < nLayers; i++) { String layerName = confs.get(i).getLayer().getLayerName(); if (layerName == null) { layerName = String.valueOf(i); } //Pass input type through preprocessor, if necessary InputPreProcessor preproc = getInputPreProcess(0); //TODO memory requirements for preprocessor if (preproc != null) { inputType = preproc.getOutputType(inputType); } LayerMemoryReport report = confs.get(i).getLayer().getMemoryReport(inputType); memoryReportMap.put(layerName, report); inputType = confs.get(i).getLayer().getOutputType(i, inputType); } return new NetworkMemoryReport(memoryReportMap, MultiLayerConfiguration.class, "MultiLayerNetwork", inputType); }
// Replace currentInputType with the output type that `l` reports for it at index i.
// NOTE(review): presumably this runs once per layer so each layer receives the previous
// layer's output type - confirm against the enclosing method.
currentInputType = l.getOutputType(i, currentInputType);