public static void main(String [] args) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException { //import org.deeplearning4j.transferlearning.vgg16 and print summary LOGGER.info("\n\nLoading org.deeplearning4j.transferlearning.vgg16...\n\n"); ZooModel zooModel = VGG16.builder().build(); ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained(); LOGGER.info(vgg16.summary()); //use the TransferLearningHelper to freeze the specified vertices and below //NOTE: This is done in place! Pass in a cloned version of the model if you would prefer to not do this in place TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16, featurizeExtractionLayer); LOGGER.info(vgg16.summary()); FlowerDataSetIterator.setup(batchSize,trainPerc); DataSetIterator trainIter = FlowerDataSetIterator.trainIterator(); DataSetIterator testIter = FlowerDataSetIterator.testIterator(); int trainDataSaved = 0; while(trainIter.hasNext()) { DataSet currentFeaturized = transferLearningHelper.featurize(trainIter.next()); saveToDisk(currentFeaturized,trainDataSaved,true); trainDataSaved++; } int testDataSaved = 0; while(testIter.hasNext()) { DataSet currentFeaturized = transferLearningHelper.featurize(testIter.next()); saveToDisk(currentFeaturized,testDataSaved,false); testDataSaved++; } LOGGER.info("Finished pre saving featurized test and train data"); }
public static void main(String [] args) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException { //import org.deeplearning4j.transferlearning.vgg16 and print summary LOGGER.info("\n\nLoading org.deeplearning4j.transferlearning.vgg16...\n\n"); ZooModel zooModel = VGG16.builder().build(); ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained(); LOGGER.info(vgg16.summary()); //use the TransferLearningHelper to freeze the specified vertices and below //NOTE: This is done in place! Pass in a cloned version of the model if you would prefer to not do this in place TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16, featurizeExtractionLayer); LOGGER.info(vgg16.summary()); FlowerDataSetIterator.setup(batchSize,trainPerc); DataSetIterator trainIter = FlowerDataSetIterator.trainIterator(); DataSetIterator testIter = FlowerDataSetIterator.testIterator(); int trainDataSaved = 0; while(trainIter.hasNext()) { DataSet currentFeaturized = transferLearningHelper.featurize(trainIter.next()); saveToDisk(currentFeaturized,trainDataSaved,true); trainDataSaved++; } int testDataSaved = 0; while(testIter.hasNext()) { DataSet currentFeaturized = transferLearningHelper.featurize(testIter.next()); saveToDisk(currentFeaturized,testDataSaved,false); testDataSaved++; } LOGGER.info("Finished pre saving featurized test and train data"); }
// Wrap the vgg16Transfer graph in a TransferLearningHelper (single-argument
// constructor: no explicit frozen-layer name is supplied here), then log the
// summary of the graph returned by the helper's unfrozenGraph().
// NOTE(review): presumably vgg16Transfer already has its frozen layers configured
// where it was built — confirm against the code that creates it.
TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16Transfer); log.info(transferLearningHelper.unfrozenGraph().summary());
// Build a TransferLearningHelper around vgg16Transfer using the one-argument
// constructor, and log the summary of the helper's unfrozenGraph() result.
// NOTE(review): which layers count as frozen is decided outside this snippet
// (likely when vgg16Transfer was assembled) — verify at the call site.
TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16Transfer); log.info(transferLearningHelper.unfrozenGraph().summary());