@Override public int inputColumns() { // FIXME: int cast return (int)list.get(0).getFeatureMatrix().columns(); }
/**
 * Applies {@code function} in place to every feature element satisfying {@code condition}.
 *
 * @param condition predicate selecting which elements to transform
 * @param function  transformation applied to each selected element
 */
@Override
public void apply(Condition condition, Function<Number, Number> function) {
    final INDArray features = getFeatureMatrix();
    BooleanIndexing.applyWhere(features, condition, function);
}
/**
 * Rescales the feature matrix in place so its values span [{@code min}, {@code max}].
 *
 * @param min lower bound of the target range
 * @param max upper bound of the target range
 */
@Override
public void scaleMinAndMax(double min, double max) {
    final INDArray features = getFeatureMatrix();
    FeatureUtil.scaleMinMax(min, max, features);
}
/**
 * Persists the features and labels of {@code write} as two binary files under
 * {@code rootDir}, named with a fresh random UUID ({@code <id>.bin} for features,
 * {@code <id>.labels.bin} for labels).
 *
 * @param write the dataset to serialize
 * @return a two-element array: [0] absolute path of the features file,
 *         [1] absolute path of the labels file
 * @throws IOException if either file cannot be written
 */
private String[] writeData(DataSet write) throws IOException {
    String dataSetId = UUID.randomUUID().toString();
    File featuresFile = new File(rootDir, dataSetId + ".bin");
    File labelsFile = new File(rootDir, dataSetId + ".labels.bin");
    // try-with-resources: the original leaked both streams if Nd4j.write threw
    // before close(); this guarantees they are closed (and flushed) on all paths.
    try (DataOutputStream dos = new DataOutputStream(
            new BufferedOutputStream(new FileOutputStream(featuresFile)))) {
        Nd4j.write(write.getFeatureMatrix(), dos);
        dos.flush();
    }
    try (DataOutputStream dosLabels = new DataOutputStream(
            new BufferedOutputStream(new FileOutputStream(labelsFile)))) {
        Nd4j.write(write.getLabels(), dosLabels);
        dosLabels.flush();
    }
    return new String[] {featuresFile.getAbsolutePath(), labelsFile.getAbsolutePath()};
}
/**
 * Adds a feature for each example on to the current feature vector
 * (horizontal concatenation of {@code toAdd} onto the existing features).
 *
 * @param toAdd the feature vector to add
 */
@Override
public void addFeatureVector(INDArray toAdd) {
    INDArray widened = Nd4j.hstack(getFeatureMatrix(), toAdd);
    setFeatures(widened);
}
@Override public int numExamples() { // FIXME: int cast if (getFeatureMatrix() != null) return (int) getFeatureMatrix().size(0); else if (getLabels() != null) return (int) getLabels().size(0); return 0; }
/**
 * Fits this normalizer to {@code dataSet}: records the per-column mean and
 * standard deviation, shifting the std by epsilon so later division never
 * produces NaN/Inf for constant columns.
 *
 * @param dataSet the dataset to compute statistics from
 */
public void fit(DataSet dataSet) {
    mean = dataSet.getFeatureMatrix().mean(0);
    std = dataSet.getFeatureMatrix().std(0);
    std.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
    // BUG FIX: the original used == between two INDArray objects, which compares
    // references and is never true, so this warning could never fire. Use
    // content equality instead: after the epsilon shift, a minimum equal to
    // EPS_THRESHOLD means some column had zero standard deviation.
    if (std.min(1).equals(Nd4j.scalar(Nd4j.EPS_THRESHOLD)))
        logger.info("API_INFO: Std deviation found to be zero. Transform will round upto epsilon to avoid nans.");
}
/**
 * Initializes this data transform fetcher from the passed in datasets.
 * Stacks each example's features and labels row-by-row into two matrices
 * and wraps them in a single DataSet assigned to {@code curr}.
 *
 * @param examples the examples to use
 */
protected void initializeCurrFromList(List<DataSet> examples) {
    if (examples.isEmpty())
        log.warn("Warning: empty dataset from the fetcher");
    final int count = examples.size();
    INDArray inputs = createInputMatrix(count);
    INDArray labels = createOutputMatrix(count);
    for (int row = 0; row < count; row++) {
        DataSet example = examples.get(row);
        inputs.putRow(row, example.getFeatureMatrix());
        labels.putRow(row, example.getLabels());
    }
    curr = new DataSet(inputs, labels);
}
// Slice baseData into consecutive batches of batchSize rows and persist each one
// via writeData, collecting the written file paths.
// NOTE(review): integer division means a trailing remainder of
// numExamples % batchSize rows is silently dropped — confirm that is intended.
// NOTE(review): fragment is incomplete here — the loop's closing brace and the
// declarations of paths/offset/batchSize lie outside this view.
for (int i = 0; i < baseData.numExamples() / batchSize; i++) {
    paths.add(writeData(new DataSet(
            baseData.getFeatureMatrix().get(NDArrayIndex.interval(offset, offset + batchSize)),
            baseData.getLabels().get(NDArrayIndex.interval(offset, offset + batchSize)))));
    offset += batchSize;
/**
 * Subtract by the column means and divide by the standard deviation (in place).
 *
 * @deprecated superseded by the DataNormalization preprocessors
 */
@Deprecated
@Override
public void normalizeZeroMeanZeroUnitVariance() {
    INDArray columnMeans = getFeatures().mean(0);
    // Consistency: the original mixed getFeatures() with the deprecated alias
    // getFeatureMatrix() in the same method; use getFeatures() throughout.
    INDArray columnStds = getFeatures().std(0);
    setFeatures(getFeatures().subiRowVector(columnMeans));
    // Shift stds by epsilon so the division never yields NaN/Inf for constant columns.
    columnStds.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
    setFeatures(getFeatures().diviRowVector(columnStds));
}
/** * Gets a copy of example i * * @param i the example to getFromOrigin * @return the example at i (one example) */ @Override public DataSet get(int i) { if (i > numExamples() || i < 0) throw new IllegalArgumentException("invalid example number"); if (i == 0 && numExamples() == 1) return this; if (getFeatureMatrix().rank() == 4) { //ensure rank is preserved INDArray slice = getFeatureMatrix().slice(i); return new DataSet(slice.reshape(ArrayUtil.combine(new long[] {1}, slice.shape())), getLabels().slice(i)); } return new DataSet(getFeatures().slice(i), getLabels().slice(i)); }
/**
 * Binarizes the dataset such that any number greater than cutoff is 1 otherwise zero
 *
 * @param cutoff the cutoff point
 */
@Override
public void binarize(double cutoff) {
    INDArray flat = getFeatureMatrix().linearView();
    for (int i = 0; i < getFeatures().length(); i++) {
        getFeatures().putScalar(i, flat.getDouble(i) > cutoff ? 1 : 0);
    }
}
// First-batch branch: seed the running statistics directly from this batch.
mean = next.getFeatureMatrix().mean(0);
// A single batch has no accumulated spread yet; otherwise start from the
// batch variance (std squared).
std = (batchCount == 1) ? Nd4j.zeros(mean.shape()) : Transforms.pow(next.getFeatureMatrix().std(0), 2);
std.muli(batchCount);
} else {
// Subsequent batches: merge this batch's mean/variance into the running totals,
// in the style of a pairwise/online variance combination (sum-of-squares form).
// NOTE(review): fragment starts and ends mid-method — mean, std, batchCount and
// runningTotal are maintained by code outside this view; confirm against the
// full method before relying on these notes.
INDArray xMinusMean = next.getFeatureMatrix().subRowVector(mean);
INDArray newMean = mean.add(xMinusMean.sum(0).divi(runningTotal));
INDArray meanB = next.getFeatureMatrix().mean(0);
INDArray deltaSq = Transforms.pow(meanB.subRowVector(mean), 2);
// Correction term scaling the squared mean delta by n_A * n_B / n_total.
INDArray deltaSqScaled = deltaSq.mul(((float) runningTotal - batchCount) * batchCount / (float) runningTotal);
// This batch's sum of squared deviations (variance * batch size).
INDArray mtwoB = Transforms.pow(next.getFeatureMatrix().std(0), 2);
mtwoB.muli(batchCount);
std = std.add(mtwoB);
// Skip empty DataSets so only real batches land in the merge buffers.
// NOTE(review): fragment of a larger merge loop — ds, count, featuresToMerge and
// labelsToMerge are declared (and count presumably advanced) outside this view.
if(ds.isEmpty()) continue;
featuresToMerge[count] = ds.getFeatureMatrix();
labelsToMerge[count] = ds.getLabels();
/**
 * Convenience overload: feeds the dataset's feature matrix to
 * {@link #output(INDArray)}.
 *
 * @param dataSet dataset whose features are evaluated
 * @return the result of {@code output} on the dataset's features
 */
public INDArray output(DataSet dataSet) {
    final INDArray features = dataSet.getFeatureMatrix();
    return output(features);
}
/**
 * Evaluates this model on the given dataset by delegating to
 * {@link #output(INDArray)} with the dataset's feature matrix.
 *
 * @param dataSet the input dataset
 * @return the computed output array
 */
public INDArray output(DataSet dataSet) {
    return output(dataSet.getFeatureMatrix());
}
/**
 * Pulls the next DataSet from the backing iterator, records its input and
 * output widths (size of dimension 1 of features and labels), and enqueues it.
 * Does nothing when the iterator is exhausted.
 */
private void prefetchBatchSetInputOutputValues() {
    if (!iterator.hasNext())
        return;
    DataSet fetched = iterator.next();
    inputColumns = fetched.getFeatureMatrix().size(1);
    totalOutcomes = fetched.getLabels().size(1);
    queued.add(fetched);
}
}
/**
 * Number of examples: dimension 0 of the features, falling back to the labels,
 * or 0 when both are absent.
 *
 * @return the example count
 * @throws ArithmeticException if the count does not fit in an {@code int}
 */
@Override
public int numExamples() {
    // Math.toIntExact guards against silent truncation should size(0) exceed
    // Integer.MAX_VALUE (matches the overflow handling used elsewhere).
    if (getFeatureMatrix() != null)
        return Math.toIntExact(getFeatureMatrix().size(0));
    if (getLabels() != null)
        return Math.toIntExact(getLabels().size(0));
    return 0;
}
/**
 * Returns the next element in the iteration, with its labels replaced by its
 * own feature matrix (i.e. the dataset becomes its own reconstruction target).
 *
 * @return the next element in the iteration
 */
@Override
public DataSet next() {
    DataSet dataSet = iter.next();
    dataSet.setLabels(dataSet.getFeatureMatrix());
    return dataSet;
}
/**
 * Trains the wrapped network on the job's payload — either a DataSet (features
 * only) or a raw INDArray — then stores the network's parameters as the result.
 *
 * @param job the unit of work carrying the training data
 */
@Override
public void perform(Job job) {
    final Serializable work = job.getWork();
    if (work instanceof DataSet) {
        neuralNetwork.fit(((DataSet) work).getFeatureMatrix());
    } else if (work instanceof INDArray) {
        neuralNetwork.fit((INDArray) work);
    }
    job.setResult(neuralNetwork.params());
}