/**
 * Runs the forward pass by delegating to the wrapped layer's activation.
 *
 * @param training whether this pass is part of training (affects e.g. dropout in the layer)
 * @return the layer's activations for the current inputs
 * @throws IllegalStateException if not all inputs to this vertex have been set
 */
@Override
public INDArray doForward(boolean training) {
    if (!canDoForward()) {
        throw new IllegalStateException("Cannot do forward pass: all inputs not set");
    }
    return layer.activate(training);
}
@Override public boolean canDoBackward() { if (!isOutputVertex()) { //inputs to frozen layer go unchecked, so could be null if (getLayer() instanceof FrozenLayer) { return true; } else { return super.canDoBackward(); } } for (INDArray input : inputs) { if (input == null) { return false; } } if (!(layer instanceof IOutputLayer)) { if (epsilon == null) { return false; } } return true; } }
@Override public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) { if (!canDoBackward()) { throw new IllegalStateException("Cannot do backward pass: all epsilons not set. Layer " + vertexName + " (idx " + vertexIndex + ") numInputs " + getNumInputArrays() + "; numOutputs " + getNumOutputConnections()); } Pair<Gradient, INDArray> pair; if (tbptt && layer instanceof RecurrentLayer) { //Truncated BPTT for recurrent layers pair = ((RecurrentLayer) layer).tbpttBackpropGradient(epsilon, graph.getConfiguration().getTbpttBackLength()); } else { //Normal backprop pair = layer.backpropGradient(epsilon); //epsTotal may be null for OutputLayers } if (layerPreProcessor != null) { INDArray eps = pair.getSecond(); eps = layerPreProcessor.backprop(eps, graph.batchSize()); pair.setSecond(eps); } //Layers always have single activations input -> always have single epsilon output during backprop return new Pair<>(pair.getFirst(), new INDArray[] {pair.getSecond()}); }
@Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams) { //Now, we need to work out if this vertex is an output vertex or not... boolean isOutput = graph.getConfiguration().getNetworkOutputs().contains(name); org.deeplearning4j.nn.api.Layer layer = layerConf.getLayer().instantiate(layerConf, null, idx, paramsView, initializeParams); return new org.deeplearning4j.nn.graph.vertex.impl.LayerVertex(graph, name, idx, layer, preProcessor, isOutput); }
// NOTE(review): fragment of a larger method (likely building a per-layer summary table);
// the enclosing method and the if-block continue beyond this chunk — do not restructure.
String paramShape = "-"; // placeholder shown when no parameter shape applies — TODO confirm against full method
if (current.hasLayer()) {
    Layer currentLayer = ((LayerVertex) current).getLayer();
    // Derive the simple class name by taking the last segment of the fully-qualified name
    classNameArr = currentLayer.getClass().getName().split("\\.");
    className = classNameArr[classNameArr.length - 1];