/**
 * Resolves the HDFS folder used to publish the embedded ZooKeeper server address for this
 * application. The folder is read from {@link GuaguaConstants#GUAGUA_ZK_CLUSTER_SERVER_FOLDER};
 * when unset, a per-application default of {@code tmp/_guagua/<appId>/} is used.
 *
 * @param context the master context supplying the application id and job properties
 * @return the configured (or default) ZooKeeper server folder path
 */
private String getZookeeperServerFolder(MasterContext<MASTER_RESULT, WORKER_RESULT> context) {
    // NOTE(review): the default path is relative ("tmp/...", no leading separator) — confirm
    // this is intentional and resolved against the expected working/HDFS home directory.
    String fallbackFolder = new StringBuilder(200).append("tmp").append(Path.SEPARATOR)
            .append("_guagua").append(Path.SEPARATOR).append(context.getAppId())
            .append(Path.SEPARATOR).toString();
    return context.getProps().getProperty(GuaguaConstants.GUAGUA_ZK_CLUSTER_SERVER_FOLDER,
            fallbackFolder);
}
/**
 * Post-iteration hook; performs no work beyond emitting a debug trace of the
 * application id, container id, iteration number, and full context.
 *
 * @param context the master context for the iteration that just completed
 */
@Override
public void postIteration(MasterContext<MASTER_RESULT, WORKER_RESULT> context) {
    LOG.debug("post application:{} container:{} iteration:{}, context:{}", context.getAppId(),
            context.getContainerId(), context.getCurrentIteration(), context);
}
/** * Check whether GuaguaConstants.GUAGUA_WORKER_HALT_ENABLE) is enabled, if yes, check whether all workers are halted * and update master status. */ protected void updateMasterHaltStatus(final MasterContext<MASTER_RESULT, WORKER_RESULT> context) { MASTER_RESULT result = context.getMasterResult(); // a switch to make all workers have the right to terminate the application if(Boolean.TRUE.toString().equalsIgnoreCase( context.getProps().getProperty(GuaguaConstants.GUAGUA_WORKER_HALT_ENABLE, GuaguaConstants.GUAGUA_WORKER_DEFAULT_HALT_ENABLE))) { if(isAllWorkersHalt(context.getWorkerResults()) && result instanceof HaltBytable) { ((HaltBytable) result).setHalt(true); context.setMasterResult(result); } } }
/**
 * Pre-iteration hook; performs no work beyond emitting a debug trace of the
 * application id, container id, iteration number, and full context.
 *
 * @param context the master context for the iteration about to start
 */
@Override
public void preIteration(MasterContext<MASTER_RESULT, WORKER_RESULT> context) {
    // Bug fix: the format string has four placeholders but only three arguments were supplied,
    // so "container" was logged as the iteration number and "context" as the iteration. Pass the
    // missing container id, matching the sibling postIteration() hook.
    LOG.debug("pre application:{} container:{} iteration:{}, context:{}", context.getAppId(),
            context.getContainerId(), context.getCurrentIteration(), context);
}
/**
 * Pre-application hook; performs no work beyond emitting a debug trace of the
 * application id, container id, and full context.
 *
 * @param context the master context for the application about to start
 */
@Override
public void preApplication(MasterContext<MASTER_RESULT, WORKER_RESULT> context) {
    LOG.debug("pre application:{}, container:{} context:{}", context.getAppId(),
            context.getContainerId(), context);
}
/**
 * Get output file setting and write final sum value to HDFS file.
 *
 * <p>The output path is read from the {@code lr.model.output} property; the final master
 * weights are written as a single {@code Arrays.toString} line.
 *
 * @param context the master context holding the final logistic-regression parameters
 */
@Override
public void postApplication(final MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) {
    LOG.info("Starts to write final value to file.");
    Path out = new Path(context.getProps().getProperty("lr.model.output"));
    LOG.info("Writing results to {}", out.toString());
    PrintWriter pw = null;
    try {
        FSDataOutputStream fos = FileSystem.get(new Configuration()).create(out);
        // pw wraps fos, so closing pw in the finally block also closes fos.
        pw = new PrintWriter(fos);
        pw.println(Arrays.toString(context.getMasterResult().getParameters()));
        pw.flush();
        // Bug fix: PrintWriter swallows IOExceptions internally, so write failures never reached
        // the catch block below and the job finished "successfully" with a missing/truncated
        // model file. Surface them via checkError().
        if(pw.checkError()) {
            LOG.error("Error in writing output.");
        }
    } catch (IOException e) {
        LOG.error("Error in writing output.", e);
    } finally {
        IOUtils.closeStream(pw);
    }
}
}
/**
 * One master step of distributed logistic regression: on the first iteration initializes the
 * weight vector; afterwards sums the gradients and errors reported by all workers and applies a
 * gradient-descent update with {@code learnRate}.
 *
 * @param context the master context holding per-worker gradient/error results
 * @return the updated weights to broadcast to workers
 */
@Override
public LogisticRegressionParams doCompute(MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) {
    if(context.isFirstIteration()) {
        initWeights();
    } else {
        double[] gradients = new double[this.inputNum + 1];
        double sumError = 0.0d;
        int size = 0;
        for(LogisticRegressionParams param: context.getWorkerResults()) {
            if(param != null) {
                for(int i = 0; i < gradients.length; i++) {
                    gradients[i] += param.getParameters()[i];
                }
                sumError += param.getError();
                // Bug fix: size was previously incremented outside the null check, so null
                // worker results inflated the divisor and under-reported the average error.
                size++;
            }
        }
        for(int i = 0; i < weights.length; i++) {
            weights[i] -= learnRate * gradients[i];
        }
        LOG.debug("DEBUG: Weights: {}", Arrays.toString(this.weights));
        // Guard against division by zero when no worker reported a usable result.
        LOG.info("Iteration {} with error {}", context.getCurrentIteration(),
                size == 0 ? Double.NaN : sumError / size);
    }
    return new LogisticRegressionParams(weights);
}
@Override public void doExecute() throws KeeperException, InterruptedException { NettyMasterCoordinator.this.masterResult = context.getMasterResult(); String appCurrentMasterNode = getCurrentMasterNode(context.getAppId(), context.getCurrentIteration()) .toString(); String appCurrentMasterSplitNode = getCurrentMasterSplitNode(context.getAppId(), context.getCurrentIteration()).toString(); LOG.debug("master result:{}", context.getMasterResult()); final long start = System.nanoTime(); try { byte[] bytes = getMasterSerializer().objectToBytes(context.getMasterResult()); isSplit = setBytesToZNode(appCurrentMasterNode, appCurrentMasterSplitNode, bytes, CreateMode.PERSISTENT); clear(context.getProps()); NettyMasterCoordinator.this.currentInteration = context.getCurrentIteration() + 1; NettyMasterCoordinator.this.canUpdateWorkerResultMap = true; context.getProps().getProperty(GuaguaConstants.GUAGUA_CLEANUP_INTERVAL), GuaguaConstants.GUAGUA_DEFAULT_CLEANUP_INTERVAL); if(context.getCurrentIteration() >= (resultCleanUpInterval + 1)) { final boolean isLocalSplit = isSplit; NettyMasterCoordinator.this.cleanOldZkDataThreadPool.submit(new Runnable() {
@Override public void run() { saveTmpModelToHDFS(currentIteration - 1, parameters); // save model results for continue model training, if current job is failed, then next running // we can start from this point to save time. // another case for master recovery, if master is failed, read such checkpoint model Path out = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); // if current iteration is the last iteration, or it is halted by early stop condition, no // need to save checkpoint model here as it is replicated with postApplicaiton. // There is issue here if saving the same model in this thread and another thread in // postApplication, sometimes this conflict will cause model writing failed. if(!isHalt && currentIteration != totalIteration) { writeModelWeightsToFileSystem(optimizedWeights, out); } } }, "saveTmpModelToHDFS thread");
@Override public void init(MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) { loadConfigFiles(context.getProps()); int trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0")); context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING)); LOG.info("continuousEnabled: {}", this.isContinuousEnabled); if(!context.isFirstIteration()) { LogisticRegressionParams lastMasterResult = context.getMasterResult(); if(lastMasterResult != null && lastMasterResult.getParameters() != null) {
if(!context.isFirstIteration()) { LogisticRegressionParams lastMasterResult = context.getMasterResult(); if(lastMasterResult != null && lastMasterResult.getParameters() != null) { if(context.isFirstIteration()) { if(this.isContinuousEnabled) { return initOrRecoverParams(context); double trainError = 0.0d, testError = 0d; long trainSize = 0, testSize = 0; for(LogisticRegressionParams param: context.getWorkerResults()) { if(param != null) { for(int i = 0; i < gradients.length; i++) { (context.getCurrentIteration() - 1)); LOG.info("Iteration {} with train error {}, test error {}", context.getCurrentIteration(), finalTrainError, finalTestError); LogisticRegressionParams lrParams = new LogisticRegressionParams(weights, finalTrainError, finalTestError, LOG.info("LRMaster compute iteration {} converged !", context.getCurrentIteration()); lrParams.setHalt(true); } else { LOG.debug("LRMaster compute iteration {} not converged yet !", context.getCurrentIteration());
int[] dumpArray = { context.getWorkers() / 4, context.getWorkers() * 2 / 4, context.getWorkers() * 3 / 4, context.getWorkers() }; for(int i = nextIndex; i < dumpArray.length; i++) { if(doneWorkers >= dumpArray[i]) { nextIndex = i + 1; LOG.info("Iteration {}, workers compelted: {}, still {} workers are not synced (fixed).", context.getCurrentIteration(), doneWorkers, (context.getWorkers() - doneWorkers)); if(context.isFirstIteration() || context.getCurrentIteration() == context.getTotalIteration()) { timeOut = context.getMinWorkersTimeOut(); boolean isTerminated = isTerminated(doneWorkers, context.getWorkers(), context.getMinWorkersRatio(), timeOut); if(isTerminated) { + "minWorkersRatio {} minWorkersTimeOut {}.", context.getCurrentIteration(), context.getWorkers(), doneWorkers, context.getMinWorkersRatio(), GuaguaConstants.GUAGUA_DEFAULT_MIN_WORKERS_TIMEOUT);
context.getProps().getProperty(GuaguaConstants.GUAGUA_ZK_CLEANUP_ENABLE), GuaguaConstants.GUAGUA_ZK_DEFAULT_CLEANUP_VALUE); String appId = context.getAppId(); boolean isLastMaster = true; if(NumberFormatUtils.getInt(context.getProps().getProperty(GuaguaConstants.GUAGUA_MASTER_NUMBER), GuaguaConstants.DEFAULT_MASTER_NUMBER) > 1) { String masterElectionPath = getBaseMasterElectionNode(appId).toString(); final int currentIteration = context.getCurrentIteration(); final int workers = context.getWorkers(); final String endWorkersNode = getWorkerBaseNode(appId, currentIteration).toString();
double currentError = ((modelConfig.getTrain().getValidSetRate() < EPSILON) ? context.getMasterResult() .getTrainError() : context.getMasterResult().getTestError()); this.optimizedWeights = context.getMasterResult().getParameters(); final int tmpModelFactor = DTrainUtils.tmpModelFactor(context.getTotalIteration()); final int currentIteration = context.getCurrentIteration(); final double[] parameters = context.getMasterResult().getParameters(); final int totalIteration = context.getTotalIteration(); final boolean isHalt = context.getMasterResult().isHalt();
public void preApplication(final MasterContext<MASTER_RESULT, WORKER_RESULT> context) { initialize(context.getProps()); this.workerClassName = context.getWorkerResultClassName(); this.totalInteration = context.getTotalIteration(); if(!context.isInitIteration()) { new BasicCoordinatorCommand() { @Override startNettyServer(context.getProps()); if(context.isInitIteration()) { this.currentInteration = GuaguaConstants.GUAGUA_INIT_STEP; } else { this.currentInteration = context.getCurrentIteration(); if(!context.isInitIteration()) { this.clear(context.getProps());
@SuppressWarnings("deprecation") private void updateProgressLog(final MasterContext<NNParams, NNParams> context) { int currentIteration = context.getCurrentIteration(); if(context.isFirstIteration()) { // first iteration is used for training preparation return; } String progress = new StringBuilder(200).append(" Trainer ").append(this.trainerId).append(" Epoch #") .append(currentIteration - 1).append(" Training Error:") .append(String.format("%.10f", context.getMasterResult().getTrainError())).append(" Validation Error:") .append(String.format("%.10f", context.getMasterResult().getTestError())).append("\n").toString(); try { LOG.debug("Writing progress results to {} {}", context.getCurrentIteration(), progress.toString()); this.progressOutput.write(progress.getBytes("UTF-8")); this.progressOutput.flush(); this.progressOutput.sync(); } catch (IOException e) { LOG.error("Error in write progress log:", e); } }
@Override public void preApplication(final MasterContext<MASTER_RESULT, WORKER_RESULT> context) { initialize(context.getProps()); if(NumberFormatUtils.getInt(context.getProps().getProperty(GuaguaConstants.GUAGUA_MASTER_NUMBER), GuaguaConstants.DEFAULT_MASTER_NUMBER) > 1) { new MasterElectionCommand(context.getAppId()).execute(); if(context.getCurrentIteration() != GuaguaConstants.GUAGUA_INIT_STEP) {
/**
 * Master computation by accumulating all the k center points sum values from all workers, then average to get new k
 * center points.
 *
 * @param context the master context carrying per-worker center sums
 * @return the new master parameters (initial assignment on iteration 1, averaged centers after)
 * @throws NullPointerException
 *             if worker result or worker results is null.
 */
@Override
public KMeansMasterParams compute(MasterContext<KMeansMasterParams, KMeansWorkerParams> context) {
    if(context.getWorkerResults() == null) {
        throw new NullPointerException("No worker results received in Master.");
    }
    // iteration 1 seeds the centers; every later iteration averages worker sums
    return context.getCurrentIteration() == 1 ? doFirstIteration(context) : doOtherIterations(context);
}
@Override public void init(MasterContext<DTMasterParams, DTWorkerParams> context) { Properties props = context.getProps(); int trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0")); this.checkpointInterval = NumberFormatUtils.getInt(context.getProps().getProperty( CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_INTERVAL, "20")); this.checkpointOutput = new Path(context.getProps().getProperty( CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId())); context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING)); if(context.isFirstIteration()) { if(this.isRF) { TreeModel existingModel; try { Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT)); existingModel = (TreeModel) CommonUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
workerBaseNode = getWorkerBaseNode(context.getAppId(), context.getCurrentIteration() + 1) .toString(); getZooKeeper().createExt(workerBaseNode, null, Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT, false); String appCurrentMasterNode = getCurrentMasterNode(context.getAppId(), context.getCurrentIteration()) .toString(); String appCurrentMasterSplitNode = getCurrentMasterSplitNode(context.getAppId(), context.getCurrentIteration()).toString(); try { byte[] bytes = getMasterSerializer().objectToBytes(context.getMasterResult()); isSplit = setBytesToZNode(appCurrentMasterNode, appCurrentMasterSplitNode, bytes, CreateMode.PERSISTENT); if(context.getCurrentIteration() >= 2) { String znode = getMasterNode(context.getAppId(), context.getCurrentIteration() - 2).toString(); try { getZooKeeper().deleteExt(znode, -1, false); if(isSplit) { znode = getCurrentMasterSplitNode(context.getAppId(), context.getCurrentIteration() - 2) .toString(); getZooKeeper().deleteExt(znode, -1, true);