public static void main(String [] args) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException { //import org.deeplearning4j.transferlearning.vgg16 and print summary LOGGER.info("\n\nLoading org.deeplearning4j.transferlearning.vgg16...\n\n"); ZooModel zooModel = VGG16.builder().build(); ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained(); LOGGER.info(vgg16.summary()); //use the TransferLearningHelper to freeze the specified vertices and below //NOTE: This is done in place! Pass in a cloned version of the model if you would prefer to not do this in place TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16, featurizeExtractionLayer); LOGGER.info(vgg16.summary()); FlowerDataSetIterator.setup(batchSize,trainPerc); DataSetIterator trainIter = FlowerDataSetIterator.trainIterator(); DataSetIterator testIter = FlowerDataSetIterator.testIterator(); int trainDataSaved = 0; while(trainIter.hasNext()) { DataSet currentFeaturized = transferLearningHelper.featurize(trainIter.next()); saveToDisk(currentFeaturized,trainDataSaved,true); trainDataSaved++; } int testDataSaved = 0; while(testIter.hasNext()) { DataSet currentFeaturized = transferLearningHelper.featurize(testIter.next()); saveToDisk(currentFeaturized,testDataSaved,false); testDataSaved++; } LOGGER.info("Finished pre saving featurized test and train data"); }
// Wrap vgg16Transfer in a helper; presumably vgg16Transfer is a ComputationGraph whose
// vertices are already frozen (unfrozenGraph() below is only meaningful for graphs) — confirm.
TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16Transfer);
// Log the summary of the trainable (unfrozen) subset graph.
log.info(transferLearningHelper.unfrozenGraph().summary());
// NOTE(review): fs is not used in the visible lines; confirm it is used later or remove it.
// sc appears to be a Spark context (it exposes hadoopConfiguration()) — confirm.
FileSystem fs = FileSystem.get(sc.hadoopConfiguration());
// Distribute training of only the unfrozen subset graph; tm is presumably a TrainingMaster — confirm.
SparkComputationGraph sparkComputationGraph = new SparkComputationGraph(sc,transferLearningHelper.unfrozenGraph(),tm);
// Featurize through the MultiDataSet overload (presumably inbW wraps `input` as a MultiDataSet — confirm),
// then repackage the single featurized array as a DataSet that keeps the original labels and labels mask.
// NOTE(review): if the DataSet constructor is (features, labels, featuresMask, labelsMask), the third
// argument passes a *labels* mask where a features mask is expected — confirm whether this should be null.
MultiDataSet ret = featurize(inbW);
return new DataSet(ret.getFeatures()[0], input.getLabels(), ret.getLabelsMaskArrays()[0], input.getLabelsMaskArray());
/**
 * Fits the unfrozen subset model on a single featurized DataSet and then copies the
 * updated parameters back into the original model (graph or MLN, whichever this helper wraps).
 *
 * @param input featurized data
 */
public void fitFeaturized(DataSet input) {
    if (!isGraph) {
        // MultiLayerNetwork-backed helper
        unFrozenSubsetMLN.fit(input);
        copyParamsFromSubsetMLNToOrig();
        return;
    }
    // ComputationGraph-backed helper
    unFrozenSubsetGraph.fit(input);
    copyParamsFromSubsetGraphToOrig();
}
/**
 * Returns the unfrozen layers of the original MultiLayerNetwork as a MultiLayerNetwork.
 * Note that with each call to {@code fitFeaturized} the parameters of the original MLN are also updated.
 *
 * @return the MultiLayerNetwork containing only the unfrozen layers
 */
public MultiLayerNetwork unfrozenMLN() {
    // Only valid for helpers built from a MultiLayerNetwork; errorIfGraphIfMLN() presumably
    // throws to report the mismatch (its body is not shown here).
    if (isGraph) {
        errorIfGraphIfMLN();
    }
    return unFrozenSubsetMLN;
}
/**
 * Creates a helper around a computation graph in which some vertices are already frozen.
 * Unlike the varargs constructor, this one does not set {@code applyFrozen}.
 *
 * @param orig computation graph containing frozen vertices
 */
public TransferLearningHelper(ComputationGraph orig) {
    this.origGraph = orig;
    initHelperGraph();
}
/**
 * Fit from a featurized dataset.
 * The fit is conducted on an internally instantiated subset model that is representative of the
 * unfrozen part of the original model. After each call to fit, the parameters of the original
 * model are updated.
 * Only supported for helpers created from a ComputationGraph.
 *
 * @param iter iterator over featurized MultiDataSets
 */
public void fitFeaturized(MultiDataSetIterator iter) {
    // MultiDataSets only apply to computation graphs. The MLN constructors call initHelperMLN()
    // rather than initHelperGraph(), so the subset graph is presumably unset there. Fail with the
    // same error path used by unfrozenGraph()/outputFromFeaturized() instead of a bare NPE.
    if (!isGraph) {
        errorIfGraphIfMLN();
    }
    unFrozenSubsetGraph.fit(iter);
    copyParamsFromSubsetGraphToOrig();
}
/**
 * Creates a helper around a MultiLayerNetwork in which some layers are already frozen.
 * Unlike the two-argument MLN constructor, this one does not set {@code applyFrozen}.
 *
 * @param orig MultiLayerNetwork containing frozen layers
 */
public TransferLearningHelper(MultiLayerNetwork orig) {
    this.origMLN = orig;
    this.isGraph = false;
    initHelperMLN();
}
/**
 * Fits the unfrozen subset model on an iterator of featurized DataSets and then copies the
 * updated parameters back into the original model (graph or MLN, whichever this helper wraps).
 *
 * @param iter iterator over featurized data
 */
public void fitFeaturized(DataSetIterator iter) {
    if (isGraph) {
        // ComputationGraph-backed helper
        unFrozenSubsetGraph.fit(iter);
        copyParamsFromSubsetGraphToOrig();
        return;
    }
    // MultiLayerNetwork-backed helper
    unFrozenSubsetMLN.fit(iter);
    copyParamsFromSubsetMLNToOrig();
}
/**
 * Returns the unfrozen subset of the original computation graph as a computation graph.
 * Note that with each call to {@code fitFeaturized} the parameters of the original computation
 * graph are also updated.
 *
 * @return the ComputationGraph containing only the unfrozen vertices
 */
public ComputationGraph unfrozenGraph() {
    // Only valid for helpers built from a ComputationGraph; errorIfGraphIfMLN() presumably
    // throws to report the mismatch (its body is not shown here).
    if (!isGraph) {
        errorIfGraphIfMLN();
    }
    return unFrozenSubsetGraph;
}
/**
 * Freezes the given computation graph from its inputs up to the specified vertices.
 * The graph is modified in place, so pass a clone to keep the original unchanged.
 *
 * @param orig           computation graph to freeze
 * @param frozenOutputAt vertices to freeze at (their parameters are held constant during training)
 */
public TransferLearningHelper(ComputationGraph orig, String... frozenOutputAt) {
    this.applyFrozen = true;
    this.frozenOutputAt = frozenOutputAt;
    this.origGraph = orig;
    initHelperGraph();
}
/**
 * Fits the unfrozen subset graph on a single featurized MultiDataSet and then copies the
 * updated parameters back into the original computation graph.
 * Only supported for helpers created from a ComputationGraph.
 *
 * @param input featurized data
 */
public void fitFeaturized(MultiDataSet input) {
    // MultiDataSets only apply to computation graphs. The MLN constructors call initHelperMLN()
    // rather than initHelperGraph(), so the subset graph is presumably unset there. Fail with the
    // same error path used by unfrozenGraph()/outputFromFeaturized() instead of a bare NPE.
    if (!isGraph) {
        errorIfGraphIfMLN();
    }
    unFrozenSubsetGraph.fit(input);
    copyParamsFromSubsetGraphToOrig();
}
/**
 * Freezes the given MultiLayerNetwork from its first layer up to and including the given index.
 * The network is modified in place, so pass a clone to keep the original unchanged.
 *
 * @param orig       MultiLayerNetwork to freeze
 * @param frozenTill index of the last layer to freeze (that layer and all layers below it are held constant)
 */
public TransferLearningHelper(MultiLayerNetwork orig, int frozenTill) {
    this.origMLN = orig;
    this.isGraph = false;
    this.applyFrozen = true;
    this.frozenTill = frozenTill;
    initHelperMLN();
}
public static void main(String [] args) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException { //import org.deeplearning4j.transferlearning.vgg16 and print summary LOGGER.info("\n\nLoading org.deeplearning4j.transferlearning.vgg16...\n\n"); ZooModel zooModel = VGG16.builder().build(); ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained(); LOGGER.info(vgg16.summary()); //use the TransferLearningHelper to freeze the specified vertices and below //NOTE: This is done in place! Pass in a cloned version of the model if you would prefer to not do this in place TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16, featurizeExtractionLayer); LOGGER.info(vgg16.summary()); FlowerDataSetIterator.setup(batchSize,trainPerc); DataSetIterator trainIter = FlowerDataSetIterator.trainIterator(); DataSetIterator testIter = FlowerDataSetIterator.testIterator(); int trainDataSaved = 0; while(trainIter.hasNext()) { DataSet currentFeaturized = transferLearningHelper.featurize(trainIter.next()); saveToDisk(currentFeaturized,trainDataSaved,true); trainDataSaved++; } int testDataSaved = 0; while(testIter.hasNext()) { DataSet currentFeaturized = transferLearningHelper.featurize(testIter.next()); saveToDisk(currentFeaturized,testDataSaved,false); testDataSaved++; } LOGGER.info("Finished pre saving featurized test and train data"); }
// Wrap vgg16Transfer in a helper; presumably vgg16Transfer is a ComputationGraph whose
// vertices are already frozen (unfrozenGraph() below is only meaningful for graphs) — confirm.
TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16Transfer);
// Log the summary of the trainable (unfrozen) subset graph.
log.info(transferLearningHelper.unfrozenGraph().summary());
// Build a ParallelWrapper over only the unfrozen subset graph (the builder chain continues beyond this view).
ParallelWrapper wrapper = new ParallelWrapper.Builder(transferLearningHelper.unfrozenGraph())
/**
 * Runs featurized input through the unfrozen subset of the computation graph.
 * Only meaningful for helpers created from a ComputationGraph.
 *
 * @param input featurized data, one array per input of the subset graph
 * @return the subset graph's outputs
 */
public INDArray[] outputFromFeaturized(INDArray[] input) {
    if (!isGraph) {
        // Helper was built from a MultiLayerNetwork: report the mismatch.
        errorIfGraphIfMLN();
    }
    final INDArray[] featurizedOutput = unFrozenSubsetGraph.output(input);
    return featurizedOutput;
}