/**
 * Builds an EncodedGradientsAccumulator from the configured settings.
 * <p>
 * If no MessageHandler was supplied, an EncodingHandler is created from {@code threshold},
 * additionally passing {@code boundary} when one has been set.
 *
 * @return a newly created EncodedGradientsAccumulator
 */
public EncodedGradientsAccumulator build() {
    if (handler == null) {
        handler = (boundary != null)
                ? new EncodingHandler(threshold, boundary)
                : new EncodingHandler(threshold);
    }

    return new EncodedGradientsAccumulator(parties, handler, initialMemory, queueSize, boundary);
}
}
/**
 * Creates a new BasicGradientsAccumulator for the given number of parties, using a
 * {@code LocalHandler} for message handling.
 * <p>
 * The handler is constructed without any threshold argument, so no threshold is configured here.
 *
 * @param parties number of parties sharing this accumulator
 */
public BasicGradientsAccumulator(int parties) {
    this(parties, new LocalHandler());
}
/**
 * Shorthand for the two-argument {@code synchronize} overload with its boolean flag set to {@code false}.
 *
 * @param consumers number of consumers to synchronize
 */
protected void synchronize(int consumers) {
    // delegate with the flag off; see the two-argument overload for what the flag controls
    this.synchronize(consumers, false);
}
.gradientsAccumulator(new EncodedGradientsAccumulator(2, 1e-3))
/**
 * Returns the current head of the queue to every registered consumer thread, in lock-step.
 * <p>
 * In bypass mode this simply delegates to {@code backingQueue.poll()}. Otherwise each calling thread:
 * <ol>
 *   <li>spins while the shared {@code step} counter equals its own local step counter;</li>
 *   <li>peeks the head of the queue;</li>
 *   <li>waits in {@code synchronize(...)} until all current consumers have peeked;</li>
 *   <li>advances its local step counter;</li>
 *   <li>if it is the last consumer to arrive (counted via {@code state}), removes the head,
 *       increments {@code numElementsDrained}, resets {@code state} and advances {@code step};</li>
 *   <li>synchronizes once more so all consumers observe the updated queue before returning.</li>
 * </ol>
 * NOTE(review): {@code currentStep} appears to be a per-thread holder (likely a ThreadLocal of AtomicLong) — confirm.
 *
 * @return the head element this consumer peeked; presumably null when the queue is empty — TODO confirm peek() semantics
 */
@Override
public E poll() {
    if (bypassMode.get())
        return backingQueue.poll();

    // first call from this thread: initialize its local step counter to -1
    if (currentStep.get() == null)
        currentStep.set(new AtomicLong(-1));

    // block until the other consumers step forward, i.e. while the shared step equals this thread's local step
    while (step.get() == currentStep.get().get())
        LockSupport.parkNanos(1000L);

    E object = peek();

    // wait until all consumers have peek()'ed this object from the queue
    synchronize(currentConsumers.get());

    currentStep.get().incrementAndGet();

    // the last consumer to arrive moves the queue one step further
    if (state.incrementAndGet() == currentConsumers.get()) {
        // remove the current head of the queue
        remove();
        numElementsDrained.incrementAndGet();

        // and move the shared step counter forward
        state.set(0);
        step.incrementAndGet();
    }

    // wait until all consumers know that the queue was updated (needed for isEmpty())
    synchronize(currentConsumers.get());
    //log.info("Second lock passed");

    // every consumer thread returns the same head element it peeked above (a shared reference, not a copy)
    return object;
}
@Override public void registerConsumers(int numConsumers) { // we don't want double spending here if (registered.get()) { if (isDebug) log.info("Master thread locks at RC"); while (registered.get()) { LockSupport.parkNanos(100L); if (throwable.isTriggered()) throw new RuntimeException(throwable.get()); } if (isDebug) log.info("Master thread unlocks at RC"); } // we're passing number of consumers for current session to externalSource, if applicable if (externalSource != null && externalSource instanceof Registerable) ((Registerable) externalSource).registerConsumers(numConsumers); currentConsumers.set(numConsumers); registered.set(true); }
@Override protected void postInit() { super.postInit(); if (accumulator == null) { log.warn("GradientsAccumulator is undefined, gradients sharing will be skipped"); return; } // just pass accumulator down the hill if (replicatedModel instanceof ComputationGraph) { ((ComputationGraph) replicatedModel).setGradientsAccumulator(accumulator); } else if (replicatedModel instanceof MultiLayerNetwork) { ((MultiLayerNetwork) replicatedModel).setGradientsAccumulator(accumulator); } // need to attach this device id to accumulator's workspaces accumulator.touch(); }
/**
 * Encodes the given updates and, if encoding produced a message, sends it.
 *
 * @param updates raw updates to broadcast
 * @return {@code true} if an encoded message was sent; {@code false} if {@code encodeUpdates}
 *         returned null (presumably nothing worth sending — confirm against encodeUpdates)
 */
@Override
public boolean broadcastUpdates(INDArray updates) {
    INDArray message = encodeUpdates(updates);

    // nothing encoded, nothing to send
    if (message == null) {
        return false;
    }

    sendMessage(message);
    return true;
}
}
handler.broadcastUpdates(accumulator.get()); synchronize(currentConsumers.get()); } catch (Exception e) { throwable.setIfFirst(e);
accumulator.storeUpdate(gradient.gradient()); accumulator.applyUpdate(stepFunction, params, gradient.gradient());
/** * This method does loops encoded data back to updates queue * @param message */ protected void sendMessage(INDArray message) { //INDArray update = decodeUpdates(message); accumulator.receiveUpdate(message); }
/**
 * Creates a new BasicGradientsAccumulator that uses the given MessageHandler for communication.
 *
 * @param parties number of parties sharing this accumulator; also used as the CyclicBarrier's party count
 * @param handler MessageHandler instance that'll be used for communication purposes
 */
public BasicGradientsAccumulator(int parties, @NonNull MessageHandler handler) {
    this.gradients = new LinkedTransferQueue<>();
    this.handler = handler;
    this.parties = parties;
    barrier = new CyclicBarrier(parties);

    // Initialize the handler only after every field is assigned, so that it never observes
    // a partially-constructed accumulator (e.g. zero parties or a null barrier).
    this.handler.initialize(this);
}
/**
 * Creates an accumulator that uses a new EncodingHandler with the given threshold.
 * <p>
 * The remaining constructor arguments are fixed: 100 MB (100L * 1024 * 1024 bytes) of initial memory,
 * 10 for the queue size and 1.0 for the boundary. These positions correspond to the
 * initialMemory / queueSize / boundary arguments passed by the Builder.
 *
 * @param parties   number of parties sharing this accumulator
 * @param threshold threshold passed to the EncodingHandler
 */
public EncodedGradientsAccumulator(int parties, double threshold) {
    this(parties, new EncodingHandler(threshold), 100L * 1024 * 1024, 10, 1.0);
}
/**
 * Switches bypass (single-consumer) mode on or off.
 * <p>
 * The switch is forwarded to {@code externalSource} first, when it implements Registerable.
 *
 * @param reallyFallback {@code true} to enable bypass mode, {@code false} to disable it
 */
@Override
public void fallbackToSingleConsumerMode(boolean reallyFallback) {
    // instanceof already yields false for null
    if (externalSource instanceof Registerable) {
        ((Registerable) externalSource).fallbackToSingleConsumerMode(reallyFallback);
    }

    bypassMode.set(reallyFallback);
}
/**
 * Convenience overload that derives the size input from the model's parameter count.
 *
 * @param model      model whose {@code params().length()} is used
 * @param numWorkers number of workers
 * @param queueSize  queue size
 * @return the result of the length-based {@code getOptimalBufferSize} overload
 */
public static int getOptimalBufferSize(Model model, int numWorkers, int queueSize) {
    return getOptimalBufferSize(
            model.params().length(),
            numWorkers,
            queueSize);
}
if (handler.broadcastUpdates(storage)) { ownCounter.getAndIncrement();
if (this.accumulator == null) { log.info("Creating new GradientsAccumulator instance"); this.accumulator = new EncodedGradientsAccumulator(workers, 1e-3);
@Override protected void postInit() { super.postInit(); if (accumulator == null) { log.warn("GradientsAccumulator is undefined, gradients sharing will be skipped"); return; } // just pass accumulator down the hill if (replicatedModel instanceof ComputationGraph) { ((ComputationGraph) replicatedModel).setGradientsAccumulator(accumulator); } else if (replicatedModel instanceof MultiLayerNetwork) { ((MultiLayerNetwork) replicatedModel).setGradientsAccumulator(accumulator); } // need to attach this device id to accumulator's workspaces accumulator.touch(); }
@Override public boolean broadcastUpdates(INDArray updates) { // we just loop back data immediately accumulator.receiveUpdate(updates); updates.assign(0.0); Nd4j.getExecutioner().commit(); return true; } }
handler.initialize(this);