/**
 * Sample a dataset.
 *
 * @param numSamples      the number of samples to getFromOrigin
 * @param rng             the rng to use
 * @param withReplacement whether to allow duplicates (only tracked by example row number)
 * @return the sample dataset
 * @throws IllegalArgumentException if sampling without replacement and
 *                                  {@code numSamples > numExamples()} (previously this hung forever)
 */
@Override
public DataSet sample(int numSamples, org.nd4j.linalg.api.rng.Random rng, boolean withReplacement) {
    if (!withReplacement && numSamples > numExamples())
        throw new IllegalArgumentException("Cannot draw " + numSamples
                        + " samples without replacement from " + numExamples() + " examples");
    INDArray examples = Nd4j.create(numSamples, getFeatures().columns());
    INDArray outcomes = Nd4j.create(numSamples, numOutcomes());
    Set<Integer> added = new HashSet<>();
    for (int i = 0; i < numSamples; i++) {
        int picked = rng.nextInt(numExamples());
        if (!withReplacement) {
            // BUG FIX: the picked index was never recorded, so the "no duplicates"
            // contract was silently violated — added.contains(picked) was always false.
            while (added.contains(picked))
                picked = rng.nextInt(numExamples());
            added.add(picked);
        }
        examples.putRow(i, get(picked).getFeatures());
        outcomes.putRow(i, get(picked).getLabels());
    }
    return new DataSet(examples, outcomes);
}
/** * Gets a copy of example i * * @param i the example to getFromOrigin * @return the example at i (one example) */ @Override public DataSet get(int i) { if (i > numExamples() || i < 0) throw new IllegalArgumentException("invalid example number"); if (i == 0 && numExamples() == 1) return this; if (getFeatureMatrix().rank() == 4) { //ensure rank is preserved INDArray slice = getFeatureMatrix().slice(i); return new DataSet(slice.reshape(ArrayUtil.combine(new long[] {1}, slice.shape())), getLabels().slice(i)); } return new DataSet(getFeatures().slice(i), getLabels().slice(i)); }
/**
 * Hash code combining features, labels and both mask arrays (null-safe),
 * using the conventional 31-multiplier accumulation.
 */
@Override
public int hashCode() {
    final INDArray f = getFeatures();
    final INDArray l = getLabels();
    final INDArray fMask = getFeaturesMaskArray();
    final INDArray lMask = getLabelsMaskArray();

    int h = (f == null) ? 0 : f.hashCode();
    h = 31 * h + ((l == null) ? 0 : l.hashCode());
    h = 31 * h + ((fMask == null) ? 0 : fMask.hashCode());
    h = 31 * h + ((lMask == null) ? 0 : lMask.hashCode());
    return h;
}
/**
 * Reshapes the input in to the given rows and columns.
 *
 * @param rows the row size
 * @param cols the column size
 * @return a copy of this data op with the input resized
 */
@Override
public DataSet reshape(int rows, int cols) {
    INDArray reshapedFeatures = getFeatures().reshape(new long[] {rows, cols});
    return new DataSet(reshapedFeatures, getLabels());
}
@Override public int numExamples() { // FIXME: int cast if (getFeatureMatrix() != null) return (int) getFeatureMatrix().size(0); else if (getLabels() != null) return (int) getLabels().size(0); return 0; }
// NOTE(review): truncated fragment — braces are unbalanced and the loop body is cut off,
// so this cannot be safely rewritten from what is visible here.
// What the visible code does: a single pass over the iterator maintaining a running
// count (runningTotal), a running mean, and a running variance accumulator (std holds
// batchCount-scaled squared std until the final sqrt). First batch initializes mean/std;
// subsequent batches apply what looks like a chunked/parallel variance update
// (delta-squared term scaled by (total - batch) * batch / total) — presumably Chan et al.'s
// parallel variance algorithm; verify against the full method. EPS_THRESHOLD is added
// before/when std is zero to avoid NaNs on divide, and the iterator is reset at the end.
while (iterator.hasNext()) { DataSet next = iterator.next(); runningTotal += next.numExamples(); batchCount = next.getFeatures().size(0); if (mean == null) { mean = next.getFeatureMatrix().mean(0); std = (batchCount == 1) ? Nd4j.zeros(mean.shape()) : Transforms.pow(next.getFeatureMatrix().std(0), 2); std.muli(batchCount); } else { INDArray xMinusMean = next.getFeatureMatrix().subRowVector(mean); INDArray newMean = mean.add(xMinusMean.sum(0).divi(runningTotal)); INDArray meanB = next.getFeatureMatrix().mean(0); INDArray deltaSq = Transforms.pow(meanB.subRowVector(mean), 2); INDArray deltaSqScaled = deltaSq.mul(((float) runningTotal - batchCount) * batchCount / (float) runningTotal); INDArray mtwoB = Transforms.pow(next.getFeatureMatrix().std(0), 2); mtwoB.muli(batchCount); std = std.add(mtwoB); std.divi(runningTotal); std = Transforms.sqrt(std); std.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD)); if (std.min(1) == Nd4j.scalar(Nd4j.EPS_THRESHOLD)) logger.info("API_INFO: Std deviation found to be zero. Transform will round upto epsilon to avoid nans."); iterator.reset();
// NOTE(review): truncated fragment — braces are unbalanced (the collapse makes it look like
// 'int i' is declared twice in nested scopes; in the full source the first loop presumably
// closes before the second begins — verify against the complete method).
// Visible intent: filter this dataset to the given labels, map each example's outcome through
// labelMap to a compacted label index, then build a fresh one-hot label matrix
// (FeatureUtil.toOutcomeVector) and install the filtered features + new labels on this dataset.
// An IllegalStateException is thrown on size mismatch or an unmapped label.
DataSet filtered = filterBy(labels); List<Integer> newLabels = new ArrayList<>(); for (int i = 0; i < filtered.numExamples(); i++) { DataSet example = filtered.get(i); int o2 = example.outcome(); Integer outcome = labelMap.get(o2); newLabels.add(outcome); INDArray newLabelMatrix = Nd4j.create(filtered.numExamples(), labels.length); if (newLabelMatrix.rows() != newLabels.size()) throw new IllegalStateException("Inconsistent label sizes"); for (int i = 0; i < newLabelMatrix.rows(); i++) { Integer i2 = newLabels.get(i); if (i2 == null) throw new IllegalStateException("Label not found on row " + i); INDArray newRow = FeatureUtil.toOutcomeVector(i2, labels.length); newLabelMatrix.putRow(i, newRow); setFeatures(filtered.getFeatures()); setLabels(newLabelMatrix);
/**
 * Displays each MNIST example side by side with its reconstruction.
 * For every example j: the real image (features * 255) is drawn in one window and a
 * binomial sample of reconstruction row j (scaled by 255) in another; both windows
 * stay up for one second and are then disposed.
 *
 * @param mnist       dataset whose examples are MNIST images
 * @param reconstruct matrix whose row j is the reconstruction of example j
 * @throws InterruptedException if interrupted while sleeping between frames
 */
public static void drawMnist(DataSet mnist, INDArray reconstruct) throws InterruptedException {
    int count = mnist.numExamples();
    for (int idx = 0; idx < count; idx++) {
        INDArray realImage = mnist.get(idx).getFeatureMatrix().mul(255);
        INDArray reconRow = reconstruct.getRow(idx);
        // Sample a binary image from the reconstruction probabilities, then scale to pixel range
        INDArray sampledImage = Nd4j.getDistributions()
                        .createBinomial(1, reconRow)
                        .sample(reconRow.shape())
                        .mul(255);

        DrawReconstruction realFrame = new DrawReconstruction(realImage);
        realFrame.title = "REAL";
        realFrame.draw();

        DrawReconstruction testFrame = new DrawReconstruction(sampledImage, 1000, 1000);
        testFrame.title = "TEST";
        testFrame.draw();

        Thread.sleep(1000);
        realFrame.frame.dispose();
        testFrame.frame.dispose();
    }
}
// NOTE(review): truncated fragment — the loops that define i, j and idxs are not visible,
// so only the allocation and the single assignment can be described with confidence.
// Allocates: f = [numExamples, vectorSize, maxLength] in 'f' (column-major) order for
// time-series features; l = [numExamples, numClasses] labels; fm = optional feature mask.
// For each (example i, time step j): the word vector at idxs[j] is copied, unit-normalized
// (Transforms.unitVec on a dup so the stored vector is untouched), and written into
// f[i, :, j]. Labels mask is always null in the returned DataSet.
INDArray f = Nd4j.create(new int[]{examples.size(), vectorSize, maxLength},'f'); INDArray l = Nd4j.create(examples.size(), numClasses); INDArray fm = (needsFM ? Nd4j.create(examples.size(), maxLength) : null); INDArray w = Transforms.unitVec(storage.get(idxs[j]).dup()); f.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j)).assign(w); return new DataSet(f,l,fm,null);
/**
 * Binarizes the dataset such that any number greater than cutoff is 1 otherwise zero.
 *
 * @param cutoff the cutoff point
 */
@Override
public void binarize(double cutoff) {
    // Linear (flattened) view for reading; writes go through the same flat index
    INDArray linear = getFeatureMatrix().linearView();
    long total = getFeatures().length();
    for (int idx = 0; idx < total; idx++) {
        double value = linear.getDouble(idx);
        // strictly-greater-than comparison: values equal to cutoff become 0
        getFeatures().putScalar(idx, value > cutoff ? 1 : 0);
    }
}
// NOTE(review): truncated fragment — braces are unbalanced (an 'else' appears without its
// matching 'if' body closing, and the enclosing loop over i is out of view), so the exact
// control flow cannot be confirmed from here.
// Visible intent: build per-sentence features plus a features mask for variable-length
// sequences: mask row i gets 1.0 over [0, sentenceLength) marking real (non-padding) time
// steps. The assembled DataSet (labels mask null) is run through dataSetPreProcessor when
// set, the cursor is advanced by the number of examples, and the DataSet is returned.
INDArray features = Nd4j.create(featuresShape); indices[3] = NDArrayIndex.point(sentenceLength); features.put(indices, vector); featuresMask = Nd4j.create(currMinibatchSize, maxLength); featuresMask.getRow(i).assign(Double.valueOf(1.0D)); } else { featuresMask.get(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.interval(0, sentenceLength)}).assign(Double.valueOf(1.0D)); DataSet ds = new DataSet(features, labels, featuresMask, null); if(this.dataSetPreProcessor != null) { this.dataSetPreProcessor.preProcess(ds); this.cursor += ds.numExamples(); return ds;
// NOTE(review): truncated fragment — the loops over reviews (i) and tokens (j) are out of
// view; only the allocation and the per-step writes are visible.
// Visible intent (sentiment RNN featurization): features = [reviews, vectorSize, maxLength],
// labels = [reviews, 2, maxLength] (positive/negative one-hot), with zero-initialized masks.
// Per token: its word vector is written at features[i, :, j] and the features mask gets 1.0
// at that time step. The class label is written only at the final time step (lastIdx-1), with
// a matching 1.0 in the labels mask so the loss is computed only there.
INDArray features = Nd4j.create(reviews.size(), vectorSize, maxLength); INDArray labels = Nd4j.create(reviews.size(), 2, maxLength); //Two labels: positive or negative INDArray featuresMask = Nd4j.zeros(reviews.size(), maxLength); INDArray labelsMask = Nd4j.zeros(reviews.size(), maxLength); String token = tokens.get(j); INDArray vector = wordVectors.getWordVectorMatrix(token); features.put(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector); featuresMask.putScalar(temp, 1.0); //Word is present (not padding) for this example + time step -> 1.0 in features mask labels.putScalar(new int[]{i,idx,lastIdx-1},1.0); //Set label: [0,1] for negative, [1,0] for positive labelsMask.putScalar(new int[]{i,lastIdx-1},1.0); //Specify that an output exists at the final time step for this example return new DataSet(features,labels,featuresMask,labelsMask);
// NOTE(review): truncated method — it ends mid-switch, so it cannot be safely rewritten here.
// BUG (visible even in the fragment): the switch statements have NO break statements, so
// case 2 falls through 3 and 4 straight into the throw; every rank would hit the
// IllegalStateException. Each case needs a break (and the throw belongs under 'default:').
// Intent: split this DataSet into one single-example DataSet per row, slicing features and
// labels with interval(i, i, true) to keep the leading dimension (rank preserved) for
// ranks 2 through 4.
@Override public List<DataSet> asList() { List<DataSet> list = new ArrayList<>(numExamples()); INDArray featuresHere, labelsHere, featureMaskHere, labelMaskHere; int rank = getFeatures().rank(); int labelsRank = getLabels().rank(); for (int i = 0; i < numExamples(); i++) { switch (rank) { case 2: featuresHere = getFeatures().get(interval(i, i, true), all()); case 3: featuresHere = getFeatures().get(interval(i, i, true), all(), all()); case 4: featuresHere = getFeatures().get(interval(i, i, true), all(), all(), all()); throw new IllegalStateException( "Cannot convert to list: feature set rank must be in range 2 to 4 inclusive. First example labels shape: " + Arrays.toString(getFeatures().shape())); labelsHere = getLabels().get(interval(i, i, true), all()); case 3: labelsHere = getLabels().get(interval(i, i, true), all(), all()); case 4: labelsHere = getLabels().get(interval(i, i, true), all(), all(), all());
/**
 * Clone the dataset.
 * Features, labels and any present mask arrays are deep-copied (dup); column and
 * label names are carried over.
 *
 * @return a clone of the dataset
 */
@Override
public DataSet copy() {
    DataSet clone = new DataSet(getFeatures().dup(), getLabels().dup());

    INDArray labelsMask = getLabelsMaskArray();
    if (labelsMask != null)
        clone.setLabelsMaskArray(labelsMask.dup());

    INDArray featuresMask = getFeaturesMaskArray();
    if (featuresMask != null)
        clone.setFeaturesMaskArray(featuresMask.dup());

    clone.setColumnNames(getColumnNames());
    clone.setLabelNames(getLabelNames());
    return clone;
}
// NOTE(review): truncated fragment — braces are unbalanced and the arrays
// (featuresToMerge, labelsToMerge, featuresMasksToMerge) are sized/declared partly out of
// view, so this cannot be rewritten safely from here.
// Visible intent, pass 1: validate consistency across the DataSets being merged — skip
// empties; once any set has features (anyFeaturesPreset) every non-empty set must; same for
// labels; otherwise throw IllegalStateException. (Typo in the second message: "enountered".)
// Pass 2: collect feature matrices, labels and (lazily allocated) feature/label mask arrays
// into parallel arrays indexed by 'count' for the actual merge.
boolean first = true; for(DataSet ds : data){ if(ds.isEmpty()){ continue; if(anyFeaturesPreset && ds.getFeatures() == null || (!first && !anyFeaturesPreset && ds.getFeatures() != null)){ throw new IllegalStateException("Cannot merge features: encountered null features in one or more DataSets"); if(anyLabelsPreset && ds.getLabels() == null || (!first && !anyLabelsPreset && ds.getLabels() != null)){ throw new IllegalStateException("Cannot merge labels: enountered null labels in one or more DataSets"); anyFeaturesPreset |= ds.getFeatures() != null; anyLabelsPreset |= ds.getLabels() != null; first = false; int count = 0; for (DataSet ds : data) { if(ds.isEmpty()) continue; featuresToMerge[count] = ds.getFeatureMatrix(); labelsToMerge[count] = ds.getLabels(); if (ds.getFeaturesMaskArray() != null) { if (featuresMasksToMerge == null) { featuresMasksToMerge = new INDArray[data.size()]; featuresMasksToMerge[count] = ds.getFeaturesMaskArray(); if (ds.getLabelsMaskArray() != null) {
public DataSet next(int num) { if( exampleStartOffsets.size() == 0 ) throw new NoSuchElementException(); int currMinibatchSize = Math.min(num, exampleStartOffsets.size()); //Allocate space: //Note the order here: // dimension 0 = number of examples in minibatch // dimension 1 = size of each vector (i.e., number of characters) // dimension 2 = length of each time series/example //Why 'f' order here? See http://deeplearning4j.org/usingrnns.html#data section "Alternative: Implementing a custom DataSetIterator" INDArray input = Nd4j.create(new int[]{currMinibatchSize,validCharacters.length,exampleLength}, 'f'); INDArray labels = Nd4j.create(new int[]{currMinibatchSize,validCharacters.length,exampleLength}, 'f'); for( int i=0; i<currMinibatchSize; i++ ){ int startIdx = exampleStartOffsets.removeFirst(); int endIdx = startIdx + exampleLength; int currCharIdx = charToIdxMap.get(fileCharacters[startIdx]); //Current input int c=0; for( int j=startIdx+1; j<endIdx; j++, c++ ){ int nextCharIdx = charToIdxMap.get(fileCharacters[j]); //Next character to predict input.putScalar(new int[]{i,currCharIdx,c}, 1.0); labels.putScalar(new int[]{i,nextCharIdx,c}, 1.0); currCharIdx = nextCharIdx; } } return new DataSet(input,labels); }
/**
 * Reads up to {@code num} image files from HDFS and converts them into a merged DataSet.
 * The one-hot label for each image is derived from the name of its parent directory.
 *
 * @param num maximum number of examples to read
 * @return the merged DataSet, or an empty DataSet when no files were read
 * @throws RuntimeException wrapping any failure while reading or decoding
 */
public DataSet convertDataSet(int num) {
    int batchNumCount = 0;
    List<DataSet> dataSets = new ArrayList<>(); // FIX: was a raw ArrayList
    FileSystem fs = CommonUtils.openHdfsConnect();
    try {
        while (batchNumCount != num && fileIterator.hasNext()) {
            ++batchNumCount;
            String fullPath = fileIterator.next();
            // Label text = base name of the file's parent directory
            Writable labelText = new Text(FilenameUtils.getBaseName(new File(fullPath).getParent()));
            INDArray label = Nd4j.zeros(1, labels.size())
                            .putScalar(new int[] {0, labels.indexOf(labelText)}, 1);
            INDArray features;
            // FIX: try-with-resources — the stream previously leaked if asMatrix() threw
            try (InputStream imageios = fs.open(new Path(fullPath))) {
                features = asMatrix(imageios);
            }
            Nd4j.getAffinityManager().tagLocation(features, AffinityManager.Location.HOST);
            dataSets.add(new DataSet(features, label));
        }
    } catch (Exception e) {
        // FIX: wrap the exception itself; e.getCause() may be null and loses the stack trace
        throw new RuntimeException(e);
    } finally {
        CommonUtils.closeHdfsConnect(fs);
    }
    if (dataSets.isEmpty()) {
        return new DataSet();
    }
    return DataSet.merge(dataSets);
}
// NOTE(review): truncated fragment — the trailing 'if (offset >= totalExamples)' has no
// body in view, so the loop's termination handling cannot be confirmed.
// QUESTIONABLE (visible): IllegalAccessError is a linkage Error, not an argument check —
// IllegalArgumentException would be the conventional type here; confirm callers before changing.
// Visible intent: split baseData into numExamples/batchSize full batches; each batch slices
// rows [offset, offset+batchSize) of features and labels, is written out via writeData, and
// its path recorded. Any remainder (numExamples % batchSize) is implicitly dropped.
if (baseData.numExamples() < batchSize) throw new IllegalAccessError("Number of examples smaller than batch size"); this.batchSize = batchSize; currIdx = 0; paths = new ArrayList<>(); totalExamples = baseData.numExamples(); totalLabels = baseData.numOutcomes(); int offset = 0; totalBatches = baseData.numExamples() / batchSize; for (int i = 0; i < baseData.numExamples() / batchSize; i++) { paths.add(writeData(new DataSet( baseData.getFeatureMatrix().get(NDArrayIndex.interval(offset, offset + batchSize)), baseData.getLabels().get(NDArrayIndex.interval(offset, offset + batchSize))))); offset += batchSize; if (offset >= totalExamples)
// NOTE(review): truncated fragment — 'arrays' is used but its declaration is out of view.
// POSSIBLE NPE (visible): getFeaturesMaskArray()/getLabelsMaskArray() are dereferenced with
// .rank() without null checks, unlike other methods in this file that guard masks — confirm
// against the full method whether masks are guaranteed non-null on this path.
// Visible intent: shuffle features, labels and both masks together along dimension 0 using
// the same seeded Random (so rows stay aligned), then permute exampleMetaData with an
// identically-seeded interleave map so metadata stays in sync with the data.
if (numExamples() < 2) return; List<int[]> dimensions = new ArrayList<>(); arrays.add(getFeatures()); dimensions.add(ArrayUtil.range(1, getFeatures().rank())); arrays.add(getLabels()); dimensions.add(ArrayUtil.range(1, getLabels().rank())); arrays.add(getFeaturesMaskArray()); dimensions.add(ArrayUtil.range(1, getFeaturesMaskArray().rank())); arrays.add(getLabelsMaskArray()); dimensions.add(ArrayUtil.range(1, getLabelsMaskArray().rank())); Nd4j.shuffle(arrays, new Random(seed), dimensions); int[] map = ArrayUtil.buildInterleavedVector(new Random(seed), numExamples()); ArrayUtil.shuffleWithMap(exampleMetaData, map);
/**
 * Rounds every feature value in place via MathUtils.roundDouble with the given
 * rounding parameter, iterating the features by flat (linear) index.
 *
 * @param roundTo rounding parameter passed to MathUtils.roundDouble
 */
@Override
public void roundToTheNearest(int roundTo) {
    long total = getFeatures().length();
    for (int idx = 0; idx < total; idx++) {
        double value = (double) getFeatures().getScalar(idx).element();
        double rounded = MathUtils.roundDouble(value, roundTo);
        getFeatures().put(idx, Nd4j.scalar(rounded));
    }
}