/**
 * Fine-tunes the network against the currently held labels: runs a forward pass and then
 * fits the output layer on its own input and those labels.
 *
 * <p>The method returns early, after logging a warning, when backprop is disabled in the
 * layer-wise configuration or when the output layer is not an {@code IOutputLayer}.
 * The gradients view is initialised on demand if it has not been created yet.
 *
 * @throws IllegalStateException if no labels are set
 * @throws UnsupportedOperationException if the output layer's optimization algorithm is
 *         {@code OptimizationAlgorithm.HESSIAN_FREE}
 */
public void finetune() {
    final boolean backpropEnabled = layerWiseConfigurations.isBackprop();
    if (!backpropEnabled) {
        log.warn("Warning: finetune is not applied.");
        return;
    }

    final boolean hasUsableOutputLayer = getOutputLayer() instanceof IOutputLayer;
    if (!hasUsableOutputLayer) {
        log.warn("Output layer not instance of output layer returning.");
        return;
    }

    // Lazily create the flattened gradients view before any fitting takes place.
    if (flattenedGradients == null) {
        initGradientsView();
    }

    if (labels == null) {
        throw new IllegalStateException("No labels found");
    }

    log.info("Finetune phase");

    final IOutputLayer outputLayer = (IOutputLayer) getOutputLayer();
    final boolean hessianFree =
            outputLayer.conf().getOptimizationAlgo() == OptimizationAlgorithm.HESSIAN_FREE;
    if (hessianFree) {
        // Hessian-free optimisation is not handled by this method.
        throw new UnsupportedOperationException();
    }

    // Populate the layer inputs first, then fit the output layer on its input and the labels.
    feedForward();
    outputLayer.fit(outputLayer.input(), labels);
}
private void initHelperMLN() { if (applyFrozen) { org.deeplearning4j.nn.api.Layer[] layers = origMLN.getLayers(); for (int i = frozenTill; i >= 0; i--) { //unchecked? layers[i] = new FrozenLayer(layers[i]); } origMLN.setLayers(layers); } for (int i = 0; i < origMLN.getnLayers(); i++) { if (origMLN.getLayer(i) instanceof FrozenLayer) { frozenInputLayer = i; } } List<NeuralNetConfiguration> allConfs = new ArrayList<>(); for (int i = frozenInputLayer + 1; i < origMLN.getnLayers(); i++) { allConfs.add(origMLN.getLayer(i).conf()); } MultiLayerConfiguration c = origMLN.getLayerWiseConfigurations(); unFrozenSubsetMLN = new MultiLayerNetwork(new MultiLayerConfiguration.Builder().backprop(c.isBackprop()) .inputPreProcessors(c.getInputPreProcessors()).pretrain(c.isPretrain()) .backpropType(c.getBackpropType()).tBPTTForwardLength(c.getTbpttFwdLength()) .tBPTTBackwardLength(c.getTbpttBackLength()).confs(allConfs).build()); unFrozenSubsetMLN.init(); //copy over params for (int i = frozenInputLayer + 1; i < origMLN.getnLayers(); i++) { unFrozenSubsetMLN.getLayer(i - frozenInputLayer - 1).setParams(origMLN.getLayer(i).params()); } //unFrozenSubsetMLN.setListeners(origMLN.getListeners()); }
ComputationGraph.workspaceCache); if (layerWiseConfigurations.isBackprop()) { update(TaskUtils.buildTask(iter)); if (!iter.hasNext() && iter.resetSupported()) {
if (layerWiseConfigurations.isBackprop()) { if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) { doTruncatedBPTT(features, labels, featuresMask, labelsMask);