@Override
public void clearVertex() {
    clear();
    epsilon = null;
}
@Override
public boolean canDoBackward() {
    if (!isOutputVertex()) {
        //Inputs to a frozen layer go unchecked, so they could be null
        if (getLayer() instanceof FrozenLayer) {
            return true;
        } else {
            return super.canDoBackward();
        }
    }
    for (INDArray input : inputs) {
        if (input == null) {
            return false;
        }
    }
    if (!(layer instanceof IOutputLayer) && epsilon == null) {
        return false;
    }
    return true;
}
@Override
public void setInput(int inputNumber, INDArray input) {
    if (inputNumber >= getNumInputArrays()) {
        throw new IllegalArgumentException("Invalid input number: " + inputNumber
                        + " (vertex has " + getNumInputArrays() + " inputs)");
    }
    inputs[inputNumber] = input;
}
if (current.isInputVertex()) {
    VertexIndices[] inputsTo = current.getOutputVertices();
    INDArray input = inputs[current.getVertexIndex()];
    layerActivations.put(current.getVertexName(), input);

    for (VertexIndices v : inputsTo) {
        int vIdx = v.getVertexIndex();
        int vIdxInputNum = v.getVertexEdgeNumber();
        //Set the network input as the 'vIdxInputNum'th input to downstream vertex 'vIdx'
        vertices[vIdx].setInput(vIdxInputNum, input.dup()); //TODO When to dup?
    }
} else {
    INDArray out;
    if (current.hasLayer()) {
        Layer l = current.getLayer();
        if (l instanceof RecurrentLayer) {
            out = ((RecurrentLayer) l).rnnActivateUsingStoredState(current.getInputs()[0], training,
                            storeLastForTBPTT);
        } else if (l instanceof MultiLayerNetwork) {
            List<INDArray> temp = ((MultiLayerNetwork) l).rnnActivateUsingStoredState(
                            current.getInputs()[0], training, storeLastForTBPTT);
            out = temp.get(temp.size() - 1);
        } else {
            //Standard layer: no stored RNN state to handle
            out = current.doForward(training);
        }
    } else {
        //Layerless vertex (merge, subset, etc.)
        out = current.doForward(training);
    }
    layerActivations.put(current.getVertexName(), out);

    VertexIndices[] outputsTo = current.getOutputVertices();
    //Excerpt truncated: 'out' is then propagated to each downstream vertex
for (int v = 0; v < vertices.length; v++) {
    GraphVertex vertex = vertices[v];
    VertexIndices[] indices = vertex.getInputVertices();
    if (indices != null) {
        for (int i = 0; i < indices.length; i++) {
            GraphVertex cv = vertices[indices[i].getVertexIndex()];
            String inputName = cv.getVertexName();
            //Excerpt truncated: 'inputName' records where this vertex's i'th input comes from
        }
    }

    LayerInfo info = model.getLayerInfoByName(vertex.getVertexName());
    if (info == null)
        info = getLayerInfo(vertex.getLayer(), x, currentY, 121);
    info.setName(vertex.getVertexName());
    if (vertex.getLayer() == null) {
        info.setLayerType(vertex.getClass().getSimpleName());
    }
    if (model.getLayerInfoByName(vertex.getVertexName()) == null) {
        x++;
        model.addLayer(info);
    }
}
private void copyParamsFromSubsetGraphToOrig() {
    for (GraphVertex aVertex : unFrozenSubsetGraph.getVertices()) {
        if (!aVertex.hasLayer())
            continue;
        origGraph.getVertex(aVertex.getVertexName()).getLayer().setParams(aVertex.getLayer().params());
    }
}
@Override
protected Layer[] getOrderedLayers() {
    if (orderedLayers != null) {
        return orderedLayers;
    }
    GraphVertex[] vertices = network.getVertices();
    //In a ComputationGraph we need the topological ordering, as it determines
    //how parameters are laid out in the 1d view arrays
    int[] topologicalOrdering = network.topologicalSortOrder();
    Layer[] out = new Layer[network.getNumLayers()];
    int j = 0;
    for (int i = 0; i < topologicalOrdering.length; i++) {
        GraphVertex currentVertex = vertices[topologicalOrdering[i]];
        if (!currentVertex.hasLayer()) {
            continue;
        }
        out[j++] = currentVertex.getLayer();
    }
    orderedLayers = out;
    return orderedLayers;
}
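//A minimal usage sketch of the topological ordering logic above, built around a small
//hypothetical graph ("in" -> "dense" -> "out"; names and sizes are illustrative only).
//Parameters are laid out in the 1d view arrays in exactly this vertex order.
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class TopologicalOrderDemo {
    public static void main(String[] args) {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                        .graphBuilder()
                        .addInputs("in")
                        .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(3).build(), "in")
                        .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                                        .nIn(3).nOut(2).build(), "dense")
                        .setOutputs("out")
                        .build();
        ComputationGraph net = new ComputationGraph(conf);
        net.init();

        //Walk the vertices in topological order, as getOrderedLayers() does
        int[] order = net.topologicalSortOrder();
        GraphVertex[] vertices = net.getVertices();
        for (int idx : order) {
            GraphVertex gv = vertices[idx];
            System.out.println(gv.getVertexName() + " (hasLayer=" + gv.hasLayer() + ")");
        }
    }
}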
/**
 * Get a given layer by name.
 *
 * @param name name of the layer to return
 */
public Layer getLayer(String name) {
    return verticesMap.get(name).getLayer(); //TODO checks
}
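//A defensive-lookup sketch for the TODO above, assuming 'net' is an initialized
//ComputationGraph; with the current implementation an unknown name surfaces as a
//NullPointerException, so callers can pre-check the vertex themselves.
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;

class LayerLookupDemo {
    static Layer lookup(ComputationGraph net, String name) {
        if (net.getVertex(name) == null) {
            throw new IllegalArgumentException("No vertex with name \"" + name + "\" in graph");
        }
        return net.getLayer(name);
    }
}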
@Override
public ComputationGraph clone() {
    ComputationGraph cg = new ComputationGraph(configuration.clone());
    cg.init(params().dup(), false);
    if (solver != null) {
        //If solver is null, the updater hasn't been initialized yet; a getUpdater()
        //call here would force an unnecessary initialization
        ComputationGraphUpdater u = this.getUpdater();
        INDArray updaterState = u.getStateViewArray();
        if (updaterState != null) {
            cg.getUpdater().setStateViewArray(updaterState.dup());
        }
    }
    cg.listeners = this.listeners;
    for (int i = 0; i < topologicalOrder.length; i++) {
        if (!vertices[topologicalOrder[i]].hasLayer())
            continue;
        String layerName = vertices[topologicalOrder[i]].getVertexName();
        if (getLayer(layerName) instanceof FrozenLayer) {
            cg.getVertex(layerName).setLayerAsFrozen();
        }
    }
    return cg;
}
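//A short sketch of the deep-copy semantics of clone(), assuming 'net' is any
//initialized ComputationGraph: the clone's parameters are duplicated, so mutating
//the original's parameter view leaves the copy untouched.
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;

class CloneDemo {
    static void demonstrateIndependence(ComputationGraph net) {
        ComputationGraph copy = net.clone();

        INDArray originalParams = net.params();
        originalParams.addi(1.0); //in-place change to the original's parameter view

        //false: the clone was initialized from params().dup()
        System.out.println("params still equal: " + net.params().equals(copy.params()));
    }
}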
/**
 * Clears any state preserved within the layers and vertices (e.g., stored inputs)
 */
protected void clearLayersStates() {
    for (int f = 0; f < layers.length; f++) {
        layers[f].setInput(null);
    }
    for (int f = 0; f < vertices.length; f++) {
        vertices[f].clearVertex();
    }
}
for (int i = 0; i < graphInputs.size(); i++) {
    String anInput = graphInputs.get(i);
    if (origGraph.getVertex(anInput).isInputVertex()) {
        //Excerpt truncated: handling of inputs that map to input vertices of the original graph continues here
if (current.isInputVertex()) {
    VertexIndices[] inputsTo = current.getOutputVertices();
    INDArray input = inputs[current.getVertexIndex()];

    for (VertexIndices v : inputsTo) {
        int vIdx = v.getVertexIndex();
        int vIdxInputNum = v.getVertexEdgeNumber();
        //Set the network input as the 'vIdxInputNum'th input to downstream vertex 'vIdx'
        vertices[vIdx].setInput(vIdxInputNum, input.dup()); //TODO When to dup?
    }
} else {
    INDArray out;
    if (current.hasLayer()) {
        Layer l = current.getLayer();
        if (l instanceof RecurrentLayer) {
            out = ((RecurrentLayer) l).rnnTimeStep(current.getInputs()[0]);
        } else if (l instanceof MultiLayerNetwork) {
            out = ((MultiLayerNetwork) l).rnnTimeStep(current.getInputs()[0]);
        } else {
            //Standard layer: plain forward pass (inference)
            out = current.doForward(false);
        }
    } else {
        //Layerless vertex (merge, subset, etc.)
        out = current.doForward(false);
    }

    if (current.isOutputVertex()) {
        int idx = configuration.getNetworkOutputs().indexOf(current.getVertexName());
        outputs[idx] = out;
    }

    VertexIndices[] outputsTo = current.getOutputVertices();
    if (outputsTo != null) {
        //Excerpt truncated: 'out' is then propagated to each downstream vertex
private void copyOrigParamsToSubsetGraph() {
    for (GraphVertex aVertex : unFrozenSubsetGraph.getVertices()) {
        if (!aVertex.hasLayer())
            continue;
        aVertex.getLayer().setParams(origGraph.getLayer(aVertex.getVertexName()).params());
    }
}
/**
 * Get the parameters for the ComputationGraph
 *
 * @param backwardOnly If true: return backprop parameters only (i.e., no visible layer
 *                     biases used in layerwise pretraining layers)
 */
public INDArray params(boolean backwardOnly) {
    if (backwardOnly)
        return flattenedParams;

    List<INDArray> list = new ArrayList<>(layers.length);
    for (int i = 0; i < topologicalOrder.length; i++) {
        if (!vertices[topologicalOrder[i]].hasLayer())
            continue;

        Layer l = vertices[topologicalOrder[i]].getLayer();
        INDArray layerParams = l.params();
        if (layerParams != null)
            list.add(layerParams); //May be null: subsampling etc. layers have no parameters
    }
    return Nd4j.toFlattened('f', list);
}
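//A sanity-check sketch for the flattening above, assuming 'net' is an initialized
//ComputationGraph: the per-layer parameter counts should sum to the length of the
//single 'f'-ordered view returned by params().
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;

class ParamsLayoutDemo {
    static void checkParamCounts(ComputationGraph net) {
        long total = 0;
        for (GraphVertex gv : net.getVertices()) {
            if (!gv.hasLayer())
                continue;
            total += gv.getLayer().numParams(); //0 for parameterless layers (subsampling etc.)
        }
        System.out.println("sum of layer params: " + total
                        + ", params().length(): " + net.params().length());
    }
}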
/**
 * Get the state of the RNN layer, as used in {@link #rnnTimeStep(INDArray...)}.
 *
 * @param layerName name of the layer
 * @return Hidden state, or null if the layer is not an RNN layer
 */
public Map<String, INDArray> rnnGetPreviousState(String layerName) {
    Layer l = verticesMap.get(layerName).getLayer();
    if (!(l instanceof RecurrentLayer)) //also covers l == null
        return null;
    return ((RecurrentLayer) l).rnnGetPreviousState();
}
/**
 * Pretrain the network with multiple inputs and/or outputs
 */
public void pretrain(MultiDataSetIterator iter) {
    if (!configuration.isPretrain())
        return;
    if (flattenedGradients == null) {
        initGradientsView();
    }

    //Pretrain each pretrainable, non-output layer, in topological order
    for (int i = 0; i < topologicalOrder.length; i++) {
        if (!vertices[topologicalOrder[i]].hasLayer())
            continue;
        if (vertices[topologicalOrder[i]].getLayer() instanceof IOutputLayer)
            continue; //Don't pretrain output layer
        if (!vertices[topologicalOrder[i]].getLayer().isPretrainLayer())
            continue; //Skip layers that aren't pretrainable
        pretrainLayer(vertices[topologicalOrder[i]].getVertexName(), iter);
    }
}
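//A hedged usage sketch for pretrain(), assuming a graph with one pretrainable layer;
//the variational autoencoder and all names/sizes here are illustrative. The output
//layer is an IOutputLayer and is skipped by the loop above.
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class PretrainDemo {
    static void run(MultiDataSetIterator iter) {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                        .graphBuilder()
                        .addInputs("in")
                        .addLayer("vae", new VariationalAutoencoder.Builder()
                                        .nIn(10).nOut(5)
                                        .encoderLayerSizes(8)
                                        .decoderLayerSizes(8)
                                        .build(), "in")
                        .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                                        .nIn(5).nOut(2).build(), "vae")
                        .setOutputs("out")
                        .pretrain(true).backprop(true) //pretrain() is a no-op without pretrain(true)
                        .build();

        ComputationGraph net = new ComputationGraph(conf);
        net.init();
        net.pretrain(iter); //pretrains "vae" only
    }
}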
@Override
public void setBackpropGradientsViewArray(INDArray gradient) {
    int paramsSoFar = 0;
    for (int i = 0; i < topologicalOrder.length; i++) {
        if (!vertices[topologicalOrder[i]].hasLayer())
            continue;

        Layer layer = vertices[topologicalOrder[i]].getLayer();
        int range = layer.numParams();
        if (range <= 0)
            continue; //Some layers (subsampling, etc.) have no parameters
        layer.setBackpropGradientsViewArray(gradient.get(NDArrayIndex.point(0),
                        NDArrayIndex.interval(paramsSoFar, paramsSoFar + range)));
        paramsSoFar += range;
    }
}
/**
 * Set the state of the RNN layer, for use in {@link #rnnTimeStep(INDArray...)}
 *
 * @param layerName The name of the layer
 * @param state     The state to set the specified layer to
 */
public void rnnSetPreviousState(String layerName, Map<String, INDArray> state) {
    Layer l = verticesMap.get(layerName).getLayer();
    if (!(l instanceof RecurrentLayer)) {
        throw new UnsupportedOperationException(
                        "Layer \"" + layerName + "\" is not a recurrent layer. Cannot set state");
    }
    ((RecurrentLayer) l).rnnSetPreviousState(state);
}
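//A sketch tying the two state accessors to rnnTimeStep(), assuming 'net' contains a
//recurrent layer named "lstm" with nIn = 4 (both assumptions are illustrative).
//Saving and restoring the state map lets you rewind a stepped RNN.
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.HashMap;
import java.util.Map;

class RnnStateDemo {
    static void stepAndRewind(ComputationGraph net) {
        INDArray step = Nd4j.rand(new int[] {1, 4, 1}); //[minibatch, nIn, timeSeriesLength=1]

        net.rnnClearPreviousState();
        net.rnnTimeStep(step);

        //Snapshot the hidden state; dup each array, as the returned map may be live
        Map<String, INDArray> saved = new HashMap<>();
        for (Map.Entry<String, INDArray> e : net.rnnGetPreviousState("lstm").entrySet()) {
            saved.put(e.getKey(), e.getValue().dup());
        }

        net.rnnTimeStep(step); //advance one more step

        net.rnnSetPreviousState("lstm", saved); //...then rewind to the snapshot
    }
}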
@Override
public void setParams(INDArray params) {
    if (params == flattenedParams)
        return; //No op

    if (this.flattenedParams != null && this.flattenedParams.length() == params.length()) {
        this.flattenedParams.assign(params);
        return;
    }

    int idx = 0;
    for (int i = 0; i < topologicalOrder.length; i++) {
        if (!vertices[topologicalOrder[i]].hasLayer())
            continue;

        Layer layer = vertices[topologicalOrder[i]].getLayer();
        int range = layer.numParams();
        if (range <= 0)
            continue; //Some layers (subsampling, etc.) have no parameters
        INDArray get = params.get(NDArrayIndex.point(0), NDArrayIndex.interval(idx, range + idx));
        layer.setParams(get);
        idx += range;
    }
}
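//A round-trip sketch for setParams(), assuming 'source' and 'target' share the same
//architecture (e.g., built from a cloned configuration): setParams() distributes the
//flattened vector across layers in topological order, matching the layout of params().
import org.deeplearning4j.nn.graph.ComputationGraph;

class SetParamsDemo {
    static void copyParams(ComputationGraph source, ComputationGraph target) {
        target.setParams(source.params().dup());
        //After the copy, both graphs produce identical outputs for the same input
    }
}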