/**
 * Splits this dataset into train and test sets after a random shuffle.
 * The shuffle happens in place, so this dataset is modified before the split.
 *
 * @param numHoldout the number of examples to hold out for the training set
 * @param rng        random number generator used to seed the in-place shuffle
 * @return the train/test split of this dataset
 */
@Override
public SplitTestAndTrain splitTestAndTrain(int numHoldout, Random rng) {
    // Draw a seed from the caller's RNG so the shuffle is reproducible
    // when the caller seeds rng, then delegate to the count-based split.
    this.shuffle(rng.nextLong());
    return splitTestAndTrain(numHoldout);
}
/**
 * Splits this dataset into train and test sets by fraction.
 *
 * @param fractionTrain fraction of examples to use for training; must be
 *                      strictly between 0.0 and 1.0
 * @return the train/test split of this dataset
 * @throws IllegalArgumentException if {@code fractionTrain} is out of range
 */
@Override
public SplitTestAndTrain splitTestAndTrain(double fractionTrain) {
    Preconditions.checkArgument(fractionTrain > 0.0 && fractionTrain < 1.0,
            "Train fraction must be > 0.0 and < 1.0 - got %s", fractionTrain);
    // Truncate toward zero, but never let the training set be empty.
    int trainCount = Math.max(1, (int) (fractionTrain * numExamples()));
    return splitTestAndTrain(trainCount);
}
/**
 * Splits this dataset into train and test sets by fraction.
 *
 * @param percentTrain fraction of examples to use for training; must be
 *                     strictly between 0.0 and 1.0
 * @return the train/test split of this dataset
 * @throws IllegalArgumentException if {@code percentTrain} is out of range
 */
@Override
public SplitTestAndTrain splitTestAndTrain(double percentTrain) {
    // Validate the fraction explicitly: the previous code silently accepted
    // values >= 1.0, which produced a degenerate split with an empty test set.
    if (percentTrain <= 0.0 || percentTrain >= 1.0) {
        throw new IllegalArgumentException(
                "Train fraction must be > 0.0 and < 1.0 - got " + percentTrain);
    }
    int numPercent = (int) (percentTrain * numExamples());
    // Truncation can yield 0 for small datasets; guarantee at least one
    // training example.
    if (numPercent <= 0)
        numPercent = 1;
    return splitTestAndTrain(numPercent);
}
/**
 * Splits a dataset into test and train randomly.
 * This will modify the dataset in place to shuffle it before splitting into test/train!
 *
 * @param numHoldout the number of examples to hold out for the training set
 * @param rng Random Number Generator used to seed the in-place shuffle
 * @return the pair of datasets for the train test split
 */
@Override
public SplitTestAndTrain splitTestAndTrain(int numHoldout, Random rng) {
    // Draw a seed from the caller's RNG: seeding the caller's rng makes the
    // shuffle (and therefore the split) reproducible.
    long seed = rng.nextLong();
    this.shuffle(seed);
    return splitTestAndTrain(numHoldout);
}
// Pull one batch from the iterator — assumes the batch size covers the whole
// dataset so allData really is all the data; TODO confirm batchSize at the caller.
DataSet allData = iterator.next();
// Shuffle in place with a fixed seed for a reproducible split.
allData.shuffle(seed);
// NOTE(review): previous comment said "65%" but the actual fraction is whatever
// trainPercent holds — use trainPercent of the data for training.
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(trainPercent);
private void createDataSource() throws IOException, InterruptedException { //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing int numLinesToSkip = 0; String delimiter = ","; RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter); recordReader.initialize(new InputStreamInputSplit(dataFile)); //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network int labelIndex = 11; DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true); DataSet allData = iterator.next(); SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80); //Use 80% of data for training trainingData = testAndTrain.getTrain(); testData = testAndTrain.getTest(); //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance): DataNormalization normalizer = new NormalizerStandardize(); normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data normalizer.transform(trainingData); //Apply normalization to the training data normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set }
private void createDataSource() throws IOException, InterruptedException { //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing int numLinesToSkip = 0; String delimiter = ","; RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter); recordReader.initialize(new InputStreamInputSplit(dataFile)); //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network int labelIndex = 4; //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row int numClasses = 3; //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2 DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses); DataSet allData = iterator.next(); allData.shuffle(); SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80); //Use 80% of data for training trainingData = testAndTrain.getTrain(); testData = testAndTrain.getTest(); //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance): DataNormalization normalizer = new NormalizerStandardize(); normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data normalizer.transform(trainingData); //Apply normalization to the training data normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set }
// Split 80/20: 80% of the examples go to the training set, 20% to the test set.
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);