cg.getUpdater().setStateViewArray(updaterState); } else if (gotOldUpdater && updater != null) { cg.setUpdater(updater);
/**
 * Copies the state of {@code model} into this instance's replicated model.
 *
 * <p>First sets {@code shouldUpdate} to {@code true}. Then, if the replica is a
 * {@link MultiLayerNetwork} or a {@link ComputationGraph}, it does two things:
 * <ul>
 *   <li>overwrites the replica's parameters with a duplicate of {@code model}'s parameters;</li>
 *   <li>if {@code model}'s updater exposes a non-null state view array, installs a duplicate of
 *       that state into the replica's updater.</li>
 * </ul>
 * A replica of any other type is left untouched. An executioner commit is issued at the end
 * in every case.
 *
 * <p>NOTE(review): {@code model} is cast, without a check, to the same concrete type as the
 * replica. A model of a different type is expected to fail with a {@link ClassCastException}.
 *
 * @param model source of parameters and updater state; must not be null
 */
@Override
public void updateModel(@NonNull Model model) {
    this.shouldUpdate.set(true);

    if (replicatedModel instanceof MultiLayerNetwork) {
        MultiLayerNetwork replica = (MultiLayerNetwork) replicatedModel;
        replica.setParams(model.params().dup());

        INDArray sourceState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
        if (sourceState != null) {
            // In this branch the replica's updater is fetched before the state is duplicated.
            Updater replicaUpdater = replica.getUpdater();
            INDArray stateCopy = sourceState.dup();
            // Commit the executioner before handing the copy over
            // (presumably flushes queued ops -- confirm against ND4J docs).
            Nd4j.getExecutioner().commit();
            replicaUpdater.setStateViewArray(replica, stateCopy, false);
        }
    } else if (replicatedModel instanceof ComputationGraph) {
        ComputationGraph replica = (ComputationGraph) replicatedModel;
        replica.setParams(model.params().dup());

        INDArray sourceState = ((ComputationGraph) model).getUpdater().getStateViewArray();
        if (sourceState != null) {
            INDArray stateCopy = sourceState.dup();
            Nd4j.getExecutioner().commit();
            // In this branch the replica's updater is fetched only after the commit.
            replica.getUpdater().setStateViewArray(stateCopy);
        }
    }

    // Final commit, issued regardless of the replica's type.
    Nd4j.getExecutioner().commit();
}
/**
 * Copies the state of {@code model} into this instance's replicated model.
 *
 * <p>First sets {@code shouldUpdate} to {@code true}. Then, if the replica is a
 * {@link MultiLayerNetwork} or a {@link ComputationGraph}, it does two things:
 * <ul>
 *   <li>overwrites the replica's parameters with a duplicate of {@code model}'s parameters;</li>
 *   <li>if {@code model}'s updater exposes a non-null state view array, installs a duplicate of
 *       that state into the replica's updater.</li>
 * </ul>
 * A replica of any other type is left untouched. An executioner commit is issued at the end
 * in every case.
 *
 * <p>NOTE(review): {@code model} is cast, without a check, to the same concrete type as the
 * replica. A model of a different type is expected to fail with a {@link ClassCastException}.
 *
 * @param model source of parameters and updater state; must not be null
 */
@Override
public void updateModel(@NonNull Model model) {
    this.shouldUpdate.set(true);

    if (replicatedModel instanceof MultiLayerNetwork) {
        MultiLayerNetwork replica = (MultiLayerNetwork) replicatedModel;
        replica.setParams(model.params().dup());

        INDArray sourceState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
        if (sourceState != null) {
            // In this branch the replica's updater is fetched before the state is duplicated.
            Updater replicaUpdater = replica.getUpdater();
            INDArray stateCopy = sourceState.dup();
            // Commit the executioner before handing the copy over
            // (presumably flushes queued ops -- confirm against ND4J docs).
            Nd4j.getExecutioner().commit();
            replicaUpdater.setStateViewArray(replica, stateCopy, false);
        }
    } else if (replicatedModel instanceof ComputationGraph) {
        ComputationGraph replica = (ComputationGraph) replicatedModel;
        replica.setParams(model.params().dup());

        INDArray sourceState = ((ComputationGraph) model).getUpdater().getStateViewArray();
        if (sourceState != null) {
            INDArray stateCopy = sourceState.dup();
            Nd4j.getExecutioner().commit();
            // In this branch the replica's updater is fetched only after the commit.
            replica.getUpdater().setStateViewArray(stateCopy);
        }
    }

    // Final commit, issued regardless of the replica's type.
    Nd4j.getExecutioner().commit();
}
@Override public ComputationGraph clone() { ComputationGraph cg = new ComputationGraph(configuration.clone()); cg.init(params().dup(), false); if (solver != null) { //If solver is null: updater hasn't been initialized -> getUpdater call will force initialization, however ComputationGraphUpdater u = this.getUpdater(); INDArray updaterState = u.getStateViewArray(); if (updaterState != null) { cg.getUpdater().setStateViewArray(updaterState.dup()); } } cg.listeners = this.listeners; for (int i = 0; i < topologicalOrder.length; i++) { if (!vertices[topologicalOrder[i]].hasLayer()) continue; String layerName = vertices[topologicalOrder[i]].getVertexName(); if (getLayer(layerName) instanceof FrozenLayer) { cg.getVertex(layerName).setLayerAsFrozen(); } } return cg; }
updaterReplica.setStateViewArray( updaterOrigina.getStateViewArray().unsafeDuplication(true));
updaterReplica.setStateViewArray( updaterOrigina.getStateViewArray().unsafeDuplication(true));