/** * This method does loops encoded data back to updates queue * @param message */ protected void sendMessage(INDArray message) { //INDArray update = decodeUpdates(message); accumulator.receiveUpdate(message); }
@Override protected void postInit() { super.postInit(); if (accumulator == null) { log.warn("GradientsAccumulator is undefined, gradients sharing will be skipped"); return; } // just pass accumulator down the hill if (replicatedModel instanceof ComputationGraph) { ((ComputationGraph) replicatedModel).setGradientsAccumulator(accumulator); } else if (replicatedModel instanceof MultiLayerNetwork) { ((MultiLayerNetwork) replicatedModel).setGradientsAccumulator(accumulator); } // need to attach this device id to accumulator's workspaces accumulator.touch(); }
// Pass the gradient array from this Gradient object to the accumulator via storeUpdate...
accumulator.storeUpdate(gradient.gradient());
// ...then ask the accumulator to apply an update to params using stepFunction and the same array.
// NOTE(review): whether applyUpdate uses only this gradient or updates accumulated from other
// workers is decided inside the accumulator implementation, not here — confirm there.
accumulator.applyUpdate(stepFunction, params, gradient.gradient());
@Override protected void postInit() { super.postInit(); if (accumulator == null) { log.warn("GradientsAccumulator is undefined, gradients sharing will be skipped"); return; } // just pass accumulator down the hill if (replicatedModel instanceof ComputationGraph) { ((ComputationGraph) replicatedModel).setGradientsAccumulator(accumulator); } else if (replicatedModel instanceof MultiLayerNetwork) { ((MultiLayerNetwork) replicatedModel).setGradientsAccumulator(accumulator); } // need to attach this device id to accumulator's workspaces accumulator.touch(); }
@Override public boolean broadcastUpdates(INDArray updates) { // we just loop back data immediately accumulator.receiveUpdate(updates); updates.assign(0.0); Nd4j.getExecutioner().commit(); return true; } }