// Multi-GPU training wrapper for the VGG16 graph: one worker per available
// compute device, using SHARED_GRADIENTS (asynchronous gradient sharing)
// instead of periodic parameter averaging.
// Fix: use the diamond (Builder<>) — the Builder is generic (see the other
// usages in this file); the raw type caused unchecked-conversion warnings.
ParallelWrapper wrapper = new ParallelWrapper.Builder<>(vgg16)
        .prefetchBuffer(24) // number of DataSets pre-fetched ahead of the workers
        .workers(Nd4j.getAffinityManager().getNumberOfDevices())
        .trainingMode(ParallelWrapper.TrainingMode.SHARED_GRADIENTS)
        .build();
// Data-parallel wrapper: 4 workers averaging parameters every 3 iterations,
// logging the model score after each averaging step.
// Fix: diamond (Builder<>) instead of the raw generic type.
ParallelWrapper wrapper = new ParallelWrapper.Builder<>(model)
        .prefetchBuffer(24) // DataSets pre-fetched ahead of the workers
        .workers(4)
        .averagingFrequency(3)
        .reportScoreAfterAveraging(true)
        .build();
// Fully parameterized data-parallel wrapper; all tuning knobs are supplied
// by the caller (prefetch depth, worker count, averaging cadence, whether
// updater state is averaged, and score reporting).
// Fix: diamond (Builder<>) instead of the raw generic type.
ParallelWrapper wrapper = new ParallelWrapper.Builder<>(model)
        .prefetchBuffer(prefetchSize)
        .workers(workers)
        .averagingFrequency(averagingFrequency)
        .averageUpdaters(averageUpdaters)
        .reportScoreAfterAveraging(reportScore)
        .build();
// Fully parameterized data-parallel wrapper; all tuning knobs are supplied
// by the caller (prefetch depth, worker count, averaging cadence, whether
// updater state is averaged, and score reporting).
// Fix: diamond (Builder<>) instead of the raw generic type.
ParallelWrapper wrapper = new ParallelWrapper.Builder<>(model)
        .prefetchBuffer(prefetchSize)
        .workers(workers)
        .averagingFrequency(averagingFrequency)
        .averageUpdaters(averageUpdaters)
        .reportScoreAfterAveraging(reportScore)
        .build();
/**
 * Creates an early-stopping trainer that delegates the actual fitting to a
 * {@code ParallelWrapper}.
 *
 * <p>Exactly one of {@code train}/{@code trainMulti} is expected to be non-null;
 * {@code iterator} is set to whichever one is provided ({@code train} wins if
 * both are non-null). An {@link AveragingIterationListener} bound to this
 * trainer is appended to the model's existing listeners so early-stopping
 * checks run as averaging iterations complete.
 *
 * @param earlyStoppingConfiguration termination/evaluation configuration
 * @param model                      network to train (MultiLayerNetwork or ComputationGraph)
 * @param train                      single-dataset iterator, or null
 * @param trainMulti                 multi-dataset iterator, or null
 * @param listener                   early-stopping event listener
 * @param workers                    number of parallel training workers
 * @param prefetchBuffer             datasets pre-fetched ahead of the workers
 * @param averagingFrequency         iterations between parameter averaging
 * @param reportScoreAfterAveraging  whether to log the score after each averaging step
 * @param useLegacyAveraging         NOTE(review): accepted but currently NOT applied —
 *                                   the builder call below is commented out; confirm
 *                                   whether this parameter can be deprecated
 */
public EarlyStoppingParallelTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model,
                DataSetIterator train, MultiDataSetIterator trainMulti, EarlyStoppingListener<T> listener,
                int workers, int prefetchBuffer, int averagingFrequency, boolean reportScoreAfterAveraging,
                boolean useLegacyAveraging) {
    this.esConfig = earlyStoppingConfiguration;
    this.train = train;
    this.trainMulti = trainMulti;
    // Prefer the single-DataSet iterator when both are supplied.
    this.iterator = (train != null ? train : trainMulti);
    this.listener = listener;
    this.model = model;

    // adjust UI listeners: append our averaging listener without discarding
    // any listeners already registered on the model. The two branches differ
    // only in the cast needed to reach getListeners().
    AveragingIterationListener trainerListener = new AveragingIterationListener(this);

    if (model instanceof MultiLayerNetwork) {
        Collection<IterationListener> listeners = ((MultiLayerNetwork) model).getListeners();
        Collection<IterationListener> newListeners = new LinkedList<>(listeners);
        newListeners.add(trainerListener);
        model.setListeners(newListeners);
    } else if (model instanceof ComputationGraph) {
        Collection<IterationListener> listeners = ((ComputationGraph) model).getListeners();
        Collection<IterationListener> newListeners = new LinkedList<>(listeners);
        newListeners.add(trainerListener);
        model.setListeners(newListeners);
    }

    this.wrapper = new ParallelWrapper.Builder<>(model).workers(workers).prefetchBuffer(prefetchBuffer)
                    .averagingFrequency(averagingFrequency)
                    //.useLegacyAveraging(useLegacyAveraging)
                    .reportScoreAfterAveraging(reportScoreAfterAveraging).build();
}
/**
 * Creates an early-stopping trainer that delegates the actual fitting to a
 * {@code ParallelWrapper}.
 *
 * <p>Exactly one of {@code train}/{@code trainMulti} is expected to be non-null;
 * {@code iterator} is set to whichever one is provided ({@code train} wins if
 * both are non-null). An {@link AveragingIterationListener} bound to this
 * trainer is appended to the model's existing listeners so early-stopping
 * checks run as averaging iterations complete.
 *
 * @param earlyStoppingConfiguration termination/evaluation configuration
 * @param model                      network to train (MultiLayerNetwork or ComputationGraph)
 * @param train                      single-dataset iterator, or null
 * @param trainMulti                 multi-dataset iterator, or null
 * @param listener                   early-stopping event listener
 * @param workers                    number of parallel training workers
 * @param prefetchBuffer             datasets pre-fetched ahead of the workers
 * @param averagingFrequency         iterations between parameter averaging
 * @param reportScoreAfterAveraging  whether to log the score after each averaging step
 * @param useLegacyAveraging         NOTE(review): accepted but currently NOT applied —
 *                                   the builder call below is commented out; confirm
 *                                   whether this parameter can be deprecated
 */
public EarlyStoppingParallelTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model,
                DataSetIterator train, MultiDataSetIterator trainMulti, EarlyStoppingListener<T> listener,
                int workers, int prefetchBuffer, int averagingFrequency, boolean reportScoreAfterAveraging,
                boolean useLegacyAveraging) {
    this.esConfig = earlyStoppingConfiguration;
    this.train = train;
    this.trainMulti = trainMulti;
    // Prefer the single-DataSet iterator when both are supplied.
    this.iterator = (train != null ? train : trainMulti);
    this.listener = listener;
    this.model = model;

    // adjust UI listeners: append our averaging listener without discarding
    // any listeners already registered on the model. The two branches differ
    // only in the cast needed to reach getListeners().
    AveragingIterationListener trainerListener = new AveragingIterationListener(this);

    if (model instanceof MultiLayerNetwork) {
        Collection<IterationListener> listeners = ((MultiLayerNetwork) model).getListeners();
        Collection<IterationListener> newListeners = new LinkedList<>(listeners);
        newListeners.add(trainerListener);
        model.setListeners(newListeners);
    } else if (model instanceof ComputationGraph) {
        Collection<IterationListener> listeners = ((ComputationGraph) model).getListeners();
        Collection<IterationListener> newListeners = new LinkedList<>(listeners);
        newListeners.add(trainerListener);
        model.setListeners(newListeners);
    }

    this.wrapper = new ParallelWrapper.Builder<>(model).workers(workers).prefetchBuffer(prefetchBuffer)
                    .averagingFrequency(averagingFrequency)
                    //.useLegacyAveraging(useLegacyAveraging)
                    .reportScoreAfterAveraging(reportScoreAfterAveraging).build();
}
/**
 * Builds a GPU-parallel trainer around the given graph.
 *
 * <p>Worker count, prefetch depth, and averaging frequency are tunable via
 * the {@code framework.parallelWrapper.*} system properties; sensible
 * defaults apply when a property is absent.
 *
 * @param graph                    the computation graph to train
 * @param miniBatchSize            examples per mini-batch
 * @param totalExamplesPerIterator total examples each iterator should supply
 */
public ParallelTrainerOnGPU(ComputationGraph graph, int miniBatchSize, int totalExamplesPerIterator) {
    // Integer.getInteger(name, default) replaces the manual
    // getProperty + parseInt dance, and additionally falls back to the
    // default on a malformed value instead of throwing NumberFormatException.
    int numWorkers = Integer.getInteger("framework.parallelWrapper.numWorkers", 4);
    int prefetchBuffer = Integer.getInteger("framework.parallelWrapper.prefetchBuffer", 12 * numWorkers);
    int averagingFrequency = Integer.getInteger("framework.parallelWrapper.averagingFrequency", 3);

    wrapper = new ParallelWrapper.Builder<>(graph)
            .prefetchBuffer(prefetchBuffer)
            .workers(numWorkers)
            .averagingFrequency(averagingFrequency)
            .reportScoreAfterAveraging(false)
            // .useLegacyAveraging(true)
            .build();
    wrapper.setListeners(perListener);
    this.numExamplesPerIterator = totalExamplesPerIterator;
    this.miniBatchSize = miniBatchSize;
}
// Custom-trainer wrapper: 2 workers with a symmetric trainer context and an
// encoded (compressed) gradients accumulator, threshold 1e-3, using SINGLE
// workspace mode.
// Fix: diamond (Builder<>) instead of the raw generic type.
ParallelWrapper wrapper = new ParallelWrapper.Builder<>(model)
        .prefetchBuffer(24) // DataSets pre-fetched ahead of the workers
        .workers(2)
        .workspaceMode(WorkspaceMode.SINGLE)
        .trainerFactory(new SymmetricTrainerContext())
        .trainingMode(ParallelWrapper.TrainingMode.CUSTOM)
        .gradientsAccumulator(new EncodedGradientsAccumulator(2, 1e-3))
        .build();
// Data-parallel wrapper: 8 workers averaging parameters every 3 iterations,
// logging the model score after each averaging step.
// Fix: diamond (Builder<>) instead of the raw generic type.
ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
        .prefetchBuffer(24) // DataSets pre-fetched ahead of the workers
        .workers(8)
        .averagingFrequency(3)
        .reportScoreAfterAveraging(true)
        .build();
// Data-parallel wrapper: 2 workers averaging parameters every 3 iterations,
// logging the model score after each averaging step.
// Fix: diamond (Builder<>) instead of the raw generic type.
ParallelWrapper wrapper = new ParallelWrapper.Builder<>(model)
        .prefetchBuffer(24) // DataSets pre-fetched ahead of the workers
        .workers(2)
        .averagingFrequency(3)
        .reportScoreAfterAveraging(true)
        .build();
// Data-parallel wrapper: 4 workers averaging parameters every 3 iterations,
// logging the model score after each averaging step.
// Fix: diamond (Builder<>) instead of the raw generic type.
ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
        .prefetchBuffer(24) // DataSets pre-fetched ahead of the workers
        .workers(4)
        .averagingFrequency(3)
        .reportScoreAfterAveraging(true)
        .build();
// Log the trainable (unfrozen) portion of the transfer-learning graph, then
// wrap it for data-parallel training with one worker per compute device.
// Fix: diamond (Builder<>) instead of the raw generic type.
log.info(transferLearningHelper.unfrozenGraph().summary());
ParallelWrapper wrapper = new ParallelWrapper.Builder<>(transferLearningHelper.unfrozenGraph())
        .prefetchBuffer(24) // DataSets pre-fetched ahead of the workers
        .workers(Nd4j.getAffinityManager().getNumberOfDevices())
        .averagingFrequency(3)
        .reportScoreAfterAveraging(true)
        .build();
// Load the pre-saved mini-batch test set from disk, then build a
// data-parallel wrapper whose worker count and prefetch depth both scale
// with the number of available compute devices (16 pre-fetched DataSets
// per device; parameters averaged every 10 iterations, score logged after
// each averaging step).
DataSetIterator test = new ExistingMiniBatchDataSetIterator(new File(TEST_PATH));
ParallelWrapper pw = new ParallelWrapper.Builder<>(net)
        .prefetchBuffer(16 * Nd4j.getAffinityManager().getNumberOfDevices())
        .reportScoreAfterAveraging(true)
        .averagingFrequency(10)
        .workers(Nd4j.getAffinityManager().getNumberOfDevices())
        .build();