/**
 * Reports the number of examples by delegating to the {@code data} member.
 *
 * @return the value of {@code data.numExamples()}
 */
@Override
public int numExamples() {
    final int count = data.numExamples();
    return count;
}
/**
 * Reports the total number of examples, taken directly from
 * {@code data.numExamples()}.
 *
 * @return the value of {@code data.numExamples()}
 */
@Override
public int totalExamples() {
    final int total = data.numExamples();
    return total;
}
/**
 * Reports the number of examples by delegating to the {@code sampleFrom} member.
 *
 * @return the value of {@code sampleFrom.numExamples()}
 */
@Override
public int numExamples() {
    final int count = sampleFrom.numExamples();
    return count;
}
/**Create an iterator given the dataset and a value of k (optional, defaults to 10) * If number of samples in the dataset is not a multiple of k, the last fold will have less samples with the rest having the same number of samples. * * @param k number of folds (optional, defaults to 10) * @param singleFold DataSet to split into k folds */ public KFoldIterator(int k, DataSet singleFold) { this.k = k; this.singleFold = singleFold.copy(); if (k <= 1) throw new IllegalArgumentException(); if (singleFold.numExamples() % k != 0) { if (k != 2) { this.batch = singleFold.numExamples() / (k - 1); this.lastBatch = singleFold.numExamples() % (k - 1); } else { this.lastBatch = singleFold.numExamples() / 2; this.batch = this.lastBatch + 1; } } else { this.batch = singleFold.numExamples() / k; this.lastBatch = singleFold.numExamples() / k; } }
/**
 * Replaces the label matrix with a freshly created one that has one row per
 * current example and {@code labels} columns. Any labels previously stored
 * for the examples are discarded.
 *
 * @param labels the number of labels/columns in the new outcome matrix
 */
@Override
public void setNewNumberOfLabels(int labels) {
    setLabels(Nd4j.create(numExamples(), labels));
}
/**
 * Splits this data set into train and test portions according to a fraction.
 * The fraction is converted to an example count (truncated towards zero, with
 * a minimum of one) and the work is delegated to {@code splitTestAndTrain(int)}.
 *
 * @param fractionTrain fraction of examples used for training; must be strictly
 *                      between 0.0 and 1.0
 * @return the result of {@code splitTestAndTrain(int)} for the computed count
 */
@Override
public SplitTestAndTrain splitTestAndTrain(double fractionTrain) {
    Preconditions.checkArgument(fractionTrain > 0.0 && fractionTrain < 1.0,
                    "Train fraction must be > 0.0 and < 1.0 - got %s", fractionTrain);

    final int scaled = (int) (fractionTrain * numExamples());
    // Always keep at least one training example.
    return splitTestAndTrain(Math.max(1, scaled));
}
if (baseData.numExamples() < batchSize) throw new IllegalAccessError("Number of examples smaller than batch size"); this.batchSize = batchSize; currIdx = 0; paths = new ArrayList<>(); totalExamples = baseData.numExamples(); totalLabels = baseData.numOutcomes(); int offset = 0; totalBatches = baseData.numExamples() / batchSize; for (int i = 0; i < baseData.numExamples() / batchSize; i++) { paths.add(writeData(new DataSet( baseData.getFeatureMatrix().get(NDArrayIndex.interval(offset, offset + batchSize)),
List<DataSet> data = asList(); int numLabels = numOutcomes(); int examples = numExamples(); for (DataSet d : data) { int label = d.outcome();
/**
 * Sets the outcome of a particular example by writing the outcome vector
 * produced by {@code FeatureUtil.toOutcomeVector(label, numOutcomes())}
 * into row {@code example} of the label matrix.
 *
 * @param example the index of the example to modify; must be in {@code [0, numExamples())}
 * @param label the label of the outcome; must be in {@code [0, numOutcomes())}
 * @throws IllegalArgumentException if {@code example} or {@code label} is out of range
 */
@Override
public void setOutcome(int example, int label) {
    // Valid row indices are 0 .. numExamples() - 1, so numExamples() itself is rejected too.
    if (example < 0 || example >= numExamples())
        throw new IllegalArgumentException("No example at " + example);

    final int outcomes = numOutcomes();
    // Valid labels are 0 .. outcomes - 1.
    if (label < 0 || label >= outcomes)
        throw new IllegalArgumentException("Illegal label");

    INDArray outcome = FeatureUtil.toOutcomeVector(label, outcomes);
    getLabels().putRow(example, outcome);
}
miniBatchSize = next.numExamples(); for (int i = 0; i < next.numExamples(); i++) { DataSet currExample = next.get(i); if (!labelRootDirs.get(currExample.outcome()).exists())
for (int i = 0; i < filtered.numExamples(); i++) { DataSet example = filtered.get(i); int o2 = example.outcome(); INDArray newLabelMatrix = Nd4j.create(filtered.numExamples(), labels.length);
/**
 * Writes the features and labels of {@code d} into row {@code i} of this
 * data set's feature and label matrices.
 *
 * @param d the data set whose features/labels are written; must not be null
 * @param i the target row index; must be in {@code [0, numExamples())}
 * @throws IllegalArgumentException if {@code d} is null or {@code i} is out of range
 */
@Override
public void addRow(DataSet d, int i) {
    // Rows are addressed 0 .. numExamples() - 1; i == numExamples() is already out of range.
    if (d == null || i < 0 || i >= numExamples())
        throw new IllegalArgumentException("Invalid index for adding a row");

    getFeatures().putRow(i, d.getFeatures());
    getLabels().putRow(i, d.getLabels());
}
if (numExamples() < 2) return; int[] map = ArrayUtil.buildInterleavedVector(new Random(seed), numExamples()); ArrayUtil.shuffleWithMap(exampleMetaData, map);
/**
 * Sample a dataset by picking example rows at random.
 *
 * @param numSamples the number of samples to draw
 * @param rng the rng to use
 * @param withReplacement whether to allow duplicates (only tracked by example row number)
 * @return a new DataSet with {@code numSamples} rows of features and labels
 * @throws IllegalArgumentException if {@code withReplacement} is false and
 *         {@code numSamples} exceeds the number of examples, since that many
 *         distinct rows do not exist
 */
@Override
public DataSet sample(int numSamples, org.nd4j.linalg.api.rng.Random rng, boolean withReplacement) {
    final int available = numExamples();
    // Without this guard the re-draw loop below could never terminate.
    if (!withReplacement && numSamples > available)
        throw new IllegalArgumentException("Cannot sample " + numSamples
                        + " examples without replacement from " + available + " examples");

    INDArray examples = Nd4j.create(numSamples, getFeatures().columns());
    INDArray outcomes = Nd4j.create(numSamples, numOutcomes());
    Set<Integer> added = new HashSet<>();
    for (int i = 0; i < numSamples; i++) {
        int picked = rng.nextInt(available);
        if (!withReplacement) {
            // Set.add returns false when the row was already chosen: re-draw until
            // an unused row is found, recording it at the same time.
            while (!added.add(picked)) {
                picked = rng.nextInt(available);
            }
        }
        // Fetch the example once and reuse it for both features and labels.
        DataSet chosen = get(picked);
        examples.putRow(i, chosen.getFeatures());
        outcomes.putRow(i, chosen.getLabels());
    }
    return new DataSet(examples, outcomes);
}
/** * Gets a copy of example i * * @param i the example to getFromOrigin * @return the example at i (one example) */ @Override public DataSet get(int i) { if (i > numExamples() || i < 0) throw new IllegalArgumentException("invalid example number"); if (i == 0 && numExamples() == 1) return this; if (getFeatureMatrix().rank() == 4) { //ensure rank is preserved INDArray slice = getFeatureMatrix().slice(i); return new DataSet(slice.reshape(ArrayUtil.combine(new long[] {1}, slice.shape())), getLabels().slice(i)); } return new DataSet(getFeatures().slice(i), getLabels().slice(i)); }
this.cursor += ds.numExamples(); return ds;
if (ds.getExampleMetaData() == null || ds.getExampleMetaData().size() != ds.numExamples()) { meta = null; break;
@Override public List<DataSet> asList() { List<DataSet> list = new ArrayList<>(numExamples()); INDArray featuresHere, labelsHere, featureMaskHere, labelMaskHere; int rank = getFeatures().rank(); for (int i = 0; i < numExamples(); i++) { switch (rank) { case 2:
int numExamples = numExamples(); if (numExamples <= 1) throw new IllegalStateException(
while (iterator.hasNext()) { DataSet next = iterator.next(); runningTotal += next.numExamples(); batchCount = next.getFeatures().size(0); if (mean == null) {