return (MultiDataSet) mds; else return new MultiDataSet(mds.getFeatures(), mds.getLabels(), mds.getFeaturesMaskArrays(), mds.getLabelsMaskArrays()); mergedLabelsMasks = null; return new MultiDataSet(mergedFeatures, mergedLabels, mergedFeaturesMasks, mergedLabelsMasks);
/**
 * Creates a copy of this MultiDataSet.
 *
 * <p>The feature and label arrays are passed through the array-level {@code copy(...)} helper and
 * used to build a new MultiDataSet. Label and feature mask arrays are copied the same way, but only
 * when they are present (non-null).
 *
 * <p>NOTE(review): example metadata (if this class tracks any) is not carried over to the copy.
 *
 * @return a new MultiDataSet holding copies of this dataset's arrays
 */
@Override
public MultiDataSet copy() {
    MultiDataSet duplicate = new MultiDataSet(copy(getFeatures()), copy(getLabels()));

    // Masks are optional; only copy the ones that actually exist.
    boolean hasLabelMasks = labelsMaskArrays != null;
    if (hasLabelMasks) {
        duplicate.setLabelsMaskArray(copy(labelsMaskArrays));
    }
    boolean hasFeatureMasks = featuresMaskArrays != null;
    if (hasFeatureMasks) {
        duplicate.setFeaturesMaskArrays(copy(featuresMaskArrays));
    }
    return duplicate;
}
/**
 * Randomly reorders the contents of this MultiDataSet in place.
 *
 * <p>The dataset is split with {@code asList()} (presumably one entry per example — confirm), the
 * list is shuffled with {@link Collections#shuffle(List)}, the pieces are recombined with
 * {@code merge(...)}, and the merged features, labels, masks and example metadata replace this
 * instance's fields.
 */
@Override
public void shuffle() {
    List<org.nd4j.linalg.dataset.api.MultiDataSet> pieces = asList();
    Collections.shuffle(pieces);
    MultiDataSet reordered = merge(pieces);

    // Adopt the merged state wholesale.
    this.exampleMetaData = reordered.exampleMetaData;
    this.features = reordered.features;
    this.featuresMaskArrays = reordered.featuresMaskArrays;
    this.labels = reordered.labels;
    this.labelsMaskArrays = reordered.labelsMaskArrays;
}
thisFeatures[j] = getSubsetForExample(features[j], i); thisLabels[j] = getSubsetForExample(labels[j], i); if (featuresMaskArrays[j] == null) continue; thisFeaturesMaskArray[j] = getSubsetForExample(featuresMaskArrays[j], i); if (labelsMaskArrays[j] == null) continue; thisLabelsMaskArray[j] = getSubsetForExample(labelsMaskArrays[j], i); list.add(new MultiDataSet(thisFeatures, thisLabels, thisFeaturesMaskArray, thisLabelsMaskArray));
throw new IllegalArgumentException("Cannot use multidatasets with MultiLayerNetworks."); INDArray[] labels = input.getLabels(); INDArray[] features = input.getFeatures(); if (input.getFeaturesMaskArrays() != null) { throw new IllegalArgumentException("Currently cannot support featurizing datasets with feature masks"); INDArray[] labelMasks = input.getLabelsMaskArrays(); return new MultiDataSet(featuresNow, labels, featureMasks, labelMasks);
sb.append("MultiDataSet: ").append(numFeatureArrays()).append(" input arrays, ") .append(numLabelsArrays()).append(" label arrays, ") .append(nfMask).append(" input masks, ") .append(nlMask).append(" label masks"); for (int i = 0; i < numFeatureArrays(); i++) { sb.append("\n=== INPUT ").append(i).append(" ===\n").append(getFeatures(i).toString().replaceAll(";", "\n")); if (getFeaturesMaskArray(i) != null) { sb.append("\n--- INPUT MASK ---\n") .append(getFeaturesMaskArray(i).toString().replaceAll(";", "\n")); for( int i=0; i<numLabelsArrays(); i++){ sb.append("\n=== LABEL ").append(i).append(" ===\n") .append(getLabels(i).toString().replaceAll(";", "\n")); if (getLabelsMaskArray(i) != null) { sb.append("\n--- LABEL MASK ---\n") .append(getLabelsMaskArray(i).toString().replaceAll(";", "\n"));
new org.nd4j.linalg.dataset.MultiDataSet(fToKeep, lToKeep, fMaskToKeep, lMaskToKeep); MultiDataSet toCache = new org.nd4j.linalg.dataset.MultiDataSet(fToCache, lToCache, fMaskToCache, lMaskToCache); list.add(toKeep); out = list.get(0); } else { out = org.nd4j.linalg.dataset.MultiDataSet.merge(list);
"Currently cannot support featurizing datasets with feature masks"); MultiDataSet inbW = new MultiDataSet(new INDArray[] {input.getFeatures()}, new INDArray[] {input.getLabels()}, null, new INDArray[] {input.getLabelsMaskArray()}); MultiDataSet ret = featurize(inbW); return new DataSet(ret.getFeatures()[0], input.getLabels(), ret.getLabelsMaskArrays()[0], input.getLabelsMaskArray());
/**
 * Creates a benchmark iterator that serves the given example.
 *
 * <p>Every feature array and every label array of {@code example} is {@code dup()}'d into this
 * iterator's own base arrays.
 *
 * <p>Bug fix: the label loop previously wrote into {@code baseFeatures} instead of
 * {@code baseLabels}. That overwrote the duplicated features, left every label slot null, and
 * threw ArrayIndexOutOfBoundsException when there were more label arrays than feature arrays.
 *
 * @param example         the template MultiDataSet whose arrays are duplicated
 * @param totalIterations value stored as this iterator's {@code limit} (presumably the number of
 *                        batches served before the iterator is exhausted — confirm)
 */
public BenchmarkMultiDataSetIterator(MultiDataSet example, int totalIterations) {
    INDArray[] sourceFeatures = example.getFeatures();
    this.baseFeatures = new INDArray[sourceFeatures.length];
    for (int i = 0; i < sourceFeatures.length; i++) {
        baseFeatures[i] = sourceFeatures[i].dup();
    }

    INDArray[] sourceLabels = example.getLabels();
    this.baseLabels = new INDArray[sourceLabels.length];
    for (int i = 0; i < sourceLabels.length; i++) {
        baseLabels[i] = sourceLabels[i].dup();
    }

    // Presumably flushes any queued ops (e.g. the dup() calls) before benchmarking starts — confirm.
    Nd4j.getExecutioner().commit();
    this.limit = totalIterations;
}
/**
 * Returns the next {@code num} datasets from the backing list merged into a single MultiDataSet.
 *
 * <p>Fewer than {@code num} datasets are taken when the list runs out. The cursor {@code curr} is
 * advanced past every dataset consumed. If a preprocessor is configured, it is applied to the
 * merged result before returning.
 *
 * @param num maximum number of datasets to merge
 * @return the merged (and possibly preprocessed) MultiDataSet
 */
@Override
public MultiDataSet next(int num) {
    int stop = Math.min(curr + num, list.size());

    List<MultiDataSet> batch = new ArrayList<>();
    while (curr < stop) {
        batch.add(list.get(curr));
        curr++;
    }

    MultiDataSet merged = org.nd4j.linalg.dataset.MultiDataSet.merge(batch);
    if (preProcessor == null) {
        return merged;
    }
    preProcessor.preProcess(merged);
    return merged;
}
@Override public MultiDataSet next() { counter.incrementAndGet(); val p = backedIterator.next(); if (counter.get() == 1 && firstTrain == null) { // first epoch ever, we'll save first dataset and will use it to check for equality later firstTrain = (org.nd4j.linalg.dataset.MultiDataSet) p.copy(); firstTrain.detach(); } else if (counter.get() == 1) { // epoch > 1, comparing first dataset to previously stored dataset. they should be equal int cnt = 0; for (val c: p.getFeatures()) if (!c.equalsWithEps(firstTrain.getFeatures()[cnt++], 1e-5)) throw new ND4JIllegalStateException("First examples do not match. Randomization was used?"); } return p; }
/**
 * Returns a human-readable dump of every feature, label and mask array in this MultiDataSet.
 *
 * <p>When the number of feature arrays equals the number of label arrays, the output is grouped
 * into "ENTRY i" sections. Each section holds input i, output i and, when present, their mask
 * arrays (same format as before).
 *
 * <p>When the counts differ, all inputs (with masks) are listed first, then all outputs (with
 * masks). This is normal for multi-input/multi-output datasets. Previously this case returned an
 * empty string, which hid the dataset's contents.
 *
 * <p>In each array's {@code toString()} output, semicolons are replaced by newlines.
 *
 * @return the formatted description
 */
@Override
public String toString() {
    StringBuilder builder = new StringBuilder();
    int nInputs = numFeatureArrays();
    int nOutputs = numLabelsArrays();

    if (nInputs == nOutputs) {
        for (int i = 0; i < nInputs; i++) {
            builder.append("\n=========== ENTRY ").append(i).append(" =================\n");
            appendSection(builder, "INPUT", getFeatures(i).toString());
            appendSection(builder, "OUTPUT", getLabels(i).toString());
            if (getFeaturesMaskArray(i) != null) {
                appendSection(builder, "INPUT MASK", getFeaturesMaskArray(i).toString());
            }
            if (getLabelsMaskArray(i) != null) {
                appendSection(builder, "OUTPUT MASK", getLabelsMaskArray(i).toString());
            }
        }
        return builder.toString();
    }

    // Inputs and outputs cannot be paired up: list them separately instead of returning "".
    for (int i = 0; i < nInputs; i++) {
        appendSection(builder, "INPUT " + i, getFeatures(i).toString());
        if (getFeaturesMaskArray(i) != null) {
            appendSection(builder, "INPUT MASK " + i, getFeaturesMaskArray(i).toString());
        }
    }
    for (int i = 0; i < nOutputs; i++) {
        appendSection(builder, "OUTPUT " + i, getLabels(i).toString());
        if (getLabelsMaskArray(i) != null) {
            appendSection(builder, "OUTPUT MASK " + i, getLabelsMaskArray(i).toString());
        }
    }
    return builder.toString();
}

/** Appends a "=== title ===" header followed by {@code content} with ';' turned into newlines. */
private static void appendSection(StringBuilder out, String title, String content) {
    out.append("\n=== ").append(title).append(" ===\n").append(content.replaceAll(";", "\n"));
}
new org.nd4j.linalg.dataset.MultiDataSet(fToKeep, lToKeep, fMaskToKeep, lMaskToKeep); MultiDataSet toCache = new org.nd4j.linalg.dataset.MultiDataSet(fToCache, lToCache, fMaskToCache, lMaskToCache); list.add(toKeep); out = list.get(0); } else { out = org.nd4j.linalg.dataset.MultiDataSet.merge(list);
thisFeatures[j] = getSubsetForExample(features[j], i); thisLabels[j] = getSubsetForExample(labels[j], i); if (featuresMaskArrays[j] == null) continue; thisFeaturesMaskArray[j] = getSubsetForExample(featuresMaskArrays[j], i); if (labelsMaskArrays[j] == null) continue; thisLabelsMaskArray[j] = getSubsetForExample(labelsMaskArrays[j], i); list.add(new MultiDataSet(thisFeatures, thisLabels, thisFeaturesMaskArray, thisLabelsMaskArray));
/**
 * Creates a benchmark iterator that serves the given example.
 *
 * <p>Every feature array and every label array of {@code example} is {@code dup()}'d into this
 * iterator's own base arrays.
 *
 * <p>Bug fix: the label loop previously wrote into {@code baseFeatures} instead of
 * {@code baseLabels}. That overwrote the duplicated features, left every label slot null, and
 * threw ArrayIndexOutOfBoundsException when there were more label arrays than feature arrays.
 *
 * @param example         the template MultiDataSet whose arrays are duplicated
 * @param totalIterations value stored as this iterator's {@code limit} (presumably the number of
 *                        batches served before the iterator is exhausted — confirm)
 */
public BenchmarkMultiDataSetIterator(MultiDataSet example, int totalIterations) {
    INDArray[] sourceFeatures = example.getFeatures();
    this.baseFeatures = new INDArray[sourceFeatures.length];
    for (int i = 0; i < sourceFeatures.length; i++) {
        baseFeatures[i] = sourceFeatures[i].dup();
    }

    INDArray[] sourceLabels = example.getLabels();
    this.baseLabels = new INDArray[sourceLabels.length];
    for (int i = 0; i < sourceLabels.length; i++) {
        baseLabels[i] = sourceLabels[i].dup();
    }

    // Presumably flushes any queued ops (e.g. the dup() calls) before benchmarking starts — confirm.
    Nd4j.getExecutioner().commit();
    this.limit = totalIterations;
}
/**
 * Combines the given MultiDataSets into one.
 *
 * <p>This is a thin delegate to the static
 * {@code org.nd4j.linalg.dataset.MultiDataSet.merge(List)} implementation.
 *
 * @param toMerge the datasets to combine
 * @return the merged dataset
 */
@Override
public MultiDataSet merge(List<MultiDataSet> toMerge) {
    MultiDataSet combined = org.nd4j.linalg.dataset.MultiDataSet.merge(toMerge);
    return combined;
}
/**
 * Returns the next benchmark batch.
 *
 * <p>Each call increments {@code counter} and wraps shallow copies of the base feature and label
 * arrays in a new MultiDataSet. The outer arrays are fresh, but the INDArray elements are the same
 * stored base arrays; they are not dup()'d.
 *
 * @return a new MultiDataSet referencing the base feature and label arrays
 */
@Override
public MultiDataSet next() {
    counter.incrementAndGet();
    // clone() on an array is a shallow element-by-element copy, matching a manual copy loop.
    return new MultiDataSet(baseFeatures.clone(), baseLabels.clone());
}
/**
 * Creates a copy of this MultiDataSet.
 *
 * <p>The feature and label arrays are passed through the array-level {@code copy(...)} helper and
 * used to build a new MultiDataSet. Label and feature mask arrays are copied the same way, but only
 * when they are present (non-null).
 *
 * <p>NOTE(review): example metadata (if this class tracks any) is not carried over to the copy.
 *
 * @return a new MultiDataSet holding copies of this dataset's arrays
 */
@Override
public MultiDataSet copy() {
    MultiDataSet duplicate = new MultiDataSet(copy(getFeatures()), copy(getLabels()));

    // Masks are optional; only copy the ones that actually exist.
    boolean hasLabelMasks = labelsMaskArrays != null;
    if (hasLabelMasks) {
        duplicate.setLabelsMaskArray(copy(labelsMaskArrays));
    }
    boolean hasFeatureMasks = featuresMaskArrays != null;
    if (hasFeatureMasks) {
        duplicate.setFeaturesMaskArrays(copy(featuresMaskArrays));
    }
    return duplicate;
}
} else if (mds != null) { for (IEvaluation evaluation : evaluations) evalAtIndex(evaluation, mds.getLabels(), ((ComputationGraph) model).output(mds.getFeatures()), 0);
/**
 * Returns the next {@code num} datasets from the backing list merged into a single MultiDataSet.
 *
 * <p>Fewer than {@code num} datasets are taken when the list runs out. The cursor {@code curr} is
 * advanced past every dataset consumed. If a preprocessor is configured, it is applied to the
 * merged result before returning.
 *
 * @param num maximum number of datasets to merge
 * @return the merged (and possibly preprocessed) MultiDataSet
 */
@Override
public MultiDataSet next(int num) {
    int stop = Math.min(curr + num, list.size());

    List<MultiDataSet> batch = new ArrayList<>();
    while (curr < stop) {
        batch.add(list.get(curr));
        curr++;
    }

    MultiDataSet merged = org.nd4j.linalg.dataset.MultiDataSet.merge(batch);
    if (preProcessor == null) {
        return merged;
    }
    preProcessor.preProcess(merged);
    return merged;
}