/**
 * Returns the input (feature) matrix of this dataset.
 * This is simply an alias that forwards to {@link #getFeatures()}.
 *
 * @return the feature matrix
 */
@Override
public INDArray getFeatureMatrix() {
    final INDArray features = getFeatures();
    return features;
}
/**
 * Combines the hash codes of features, labels, features mask and labels mask
 * (in that order) with the conventional multiplier 31; a null array contributes 0.
 */
@Override
public int hashCode() {
    INDArray[] components = {getFeatures(), getLabels(), getFeaturesMaskArray(), getLabelsMaskArray()};
    // Starting from 0 means the first component's hash is used as-is, matching the chained form.
    int hash = 0;
    for (INDArray component : components) {
        int componentHash = (component == null) ? 0 : component.hashCode();
        hash = 31 * hash + componentHash;
    }
    return hash;
}
/**
 * Writes a feature vector into the given row of the feature matrix via {@code putRow}.
 *
 * @param feature the feature vector to write
 * @param example the row (example) index to write it to
 */
@Override
public void addFeatureVector(INDArray feature, int example) {
    INDArray target = getFeatures();
    target.putRow(example, feature);
}
@Override public int inputColumns() { // FIXME: int cast return (int) singleFold.getFeatures().size(1); }
/**
 * Scales the feature matrix by passing it to {@code FeatureUtil.scaleByMax}
 * (intended to divide each row by its maximum value). The helper's return
 * value is not used, so the scaling presumably happens in place.
 */
@Override
public void scale() {
    INDArray features = getFeatures();
    FeatureUtil.scaleByMax(features);
}
/** * The number of inputs in the feature matrix * * @return */ @Override public int numInputs() { // FIXME: int cast return (int) getFeatures().size(1); }
/**
 * Sums the feature matrix along dimension 1, giving one value per row (example).
 *
 * @return the per-row sums
 */
@Override
public INDArray exampleSums() {
    final int rowDimension = 1;
    INDArray features = getFeatures();
    return features.sum(rowDimension);
}
/**
 * Takes the maximum of the feature matrix along dimension 1, giving one value per row (example).
 *
 * @return the per-row maxima
 */
@Override
public INDArray exampleMaxs() {
    final int rowDimension = 1;
    INDArray features = getFeatures();
    return features.max(rowDimension);
}
/**
 * Averages the feature matrix along dimension 1, giving one value per row (example).
 *
 * @return the per-row means
 */
@Override
public INDArray exampleMeans() {
    final int rowDimension = 1;
    INDArray features = getFeatures();
    return features.mean(rowDimension);
}
/**
 * Standardizes a dataset's features using the stored {@code mean} and {@code std}
 * row vectors: features become {@code (features - mean) / std}. The result is
 * written back with {@code setFeatures} after each of the two steps.
 *
 * @param dataSet the dataset whose features are replaced
 */
public void transform(DataSet dataSet) {
    INDArray centered = dataSet.getFeatures().subRowVector(mean);
    dataSet.setFeatures(centered);
    INDArray scaled = dataSet.getFeatures().divRowVector(std);
    dataSet.setFeatures(scaled);
}
/**
 * Rounds every element of the feature matrix in place using
 * {@code MathUtils.roundDouble(value, roundTo)}.
 *
 * <p>Elements are addressed by linear index. Values are read with {@code getDouble}
 * rather than casting {@code getScalar(i).element()}: that cast fails when the
 * boxed element is not a {@code Double} (e.g. float data), and it allocates two
 * scalar arrays per element.
 *
 * @param roundTo the rounding argument forwarded to {@code MathUtils.roundDouble}
 */
@Override
public void roundToTheNearest(int roundTo) {
    INDArray features = getFeatures();
    long length = features.length();
    for (long i = 0; i < length; i++) {
        features.putScalar(i, MathUtils.roundDouble(features.getDouble(i), roundTo));
    }
}
/**
 * Divides every element of the feature matrix in place ({@code divi}) by {@code num}.
 *
 * @param num the divisor
 */
@Override
public void divideBy(int num) {
    INDArray features = getFeatures();
    INDArray divisor = Nd4j.scalar(num);
    features.divi(divisor);
}
/**
 * Overwrites row {@code i} of this dataset's features and labels with the
 * features and labels of {@code d} (via {@code putRow}).
 *
 * @param d the dataset supplying the replacement row; must not be null
 * @param i the row index to write; must be in {@code [0, numExamples())}
 * @throws IllegalArgumentException if {@code d} is null or {@code i} is out of range
 */
@Override
public void addRow(DataSet d, int i) {
    if (d == null) {
        throw new IllegalArgumentException("DataSet to add must not be null");
    }
    int n = numExamples();
    // Valid row indices are 0..n-1; the previous check (i > n) wrongly accepted i == n and negatives.
    if (i < 0 || i >= n) {
        throw new IllegalArgumentException("Invalid index for adding a row: " + i + " (numExamples=" + n + ")");
    }
    getFeatures().putRow(i, d.getFeatures());
    getLabels().putRow(i, d.getLabels());
}
/**
 * Clamps every element of the feature matrix in place: values below {@code min}
 * become {@code min}, and values above {@code max} become {@code max}.
 *
 * <p>The lower bound is checked first, so if {@code min > max} a value below
 * {@code min} is set to {@code min}. Values are read with {@code getDouble}
 * instead of casting the boxed {@code getScalar(i).element()}, which fails for
 * non-{@code Double} elements and allocates per element.
 *
 * @param min the minimum value allowed in the dataset
 * @param max the maximum value allowed in the dataset
 */
@Override
public void squishToRange(double min, double max) {
    INDArray features = getFeatures();
    long length = features.length();
    for (long i = 0; i < length; i++) {
        double curr = features.getDouble(i);
        if (curr < min) {
            features.putScalar(i, min);
        } else if (curr > max) {
            features.putScalar(i, max);
        }
    }
}
/**
 * Multiplies every element of the feature matrix in place ({@code muli}) by {@code num}.
 *
 * @param num the multiplier
 */
@Override
public void multiplyBy(double num) {
    INDArray features = getFeatures();
    INDArray factor = Nd4j.scalar(num);
    features.muli(factor);
}
/**
 * Checks that features and labels have the same size along dimension 0 (row count).
 *
 * @throws IllegalStateException if the row counts differ
 */
@Override
public void validate() {
    long featureRows = getFeatures().size(0);
    long labelRows = getLabels().size(0);
    if (featureRows == labelRows) {
        return;
    }
    throw new IllegalStateException("Invalid dataset");
}
/**
 * Binarizes the feature matrix in place: each element strictly greater than
 * {@code cutoff} becomes 1, and every other element becomes 0.
 *
 * @param cutoff the threshold; values equal to it map to 0
 */
@Override
public void binarize(double cutoff) {
    INDArray flat = getFeatureMatrix().linearView();
    for (int idx = 0; idx < getFeatures().length(); idx++) {
        int bit = flat.getDouble(idx) > cutoff ? 1 : 0;
        getFeatures().putScalar(idx, bit);
    }
}
/**
 * Standardizes each feature column: subtracts the column mean, then divides by
 * the column standard deviation plus {@code Nd4j.EPS_THRESHOLD}. The epsilon
 * guards against division by zero for constant columns. The in-place
 * {@code subiRowVector}/{@code diviRowVector} results are also stored via
 * {@code setFeatures}.
 *
 * @deprecated retained for backward compatibility only.
 */
@Deprecated
@Override
public void normalizeZeroMeanZeroUnitVariance() {
    INDArray means = getFeatures().mean(0);
    INDArray stdDevs = getFeatureMatrix().std(0);
    INDArray centered = getFeatures().subiRowVector(means);
    setFeatures(centered);
    INDArray epsilon = Nd4j.scalar(Nd4j.EPS_THRESHOLD);
    stdDevs.addi(epsilon);
    INDArray scaled = getFeatures().diviRowVector(stdDevs);
    setFeatures(scaled);
}
/**
 * Builds a new dataset whose features are reshaped to {@code rows x cols}.
 * The labels are passed through unchanged.
 *
 * @param rows the number of rows of the reshaped feature matrix
 * @param cols the number of columns of the reshaped feature matrix
 * @return a new DataSet wrapping the reshaped features and the original labels
 */
@Override
public DataSet reshape(int rows, int cols) {
    long[] targetShape = {rows, cols};
    return new DataSet(getFeatures().reshape(targetShape), getLabels());
}
/**
 * Selects the examples at the given row indices.
 *
 * @param i the row indices of the examples to select
 * @return a new DataSet built from {@code getRows(i)} of the features and labels
 *         (whether the rows are copies or views depends on {@code getRows})
 */
@Override
public DataSet get(int[] i) {
    INDArray selectedFeatures = getFeatures().getRows(i);
    INDArray selectedLabels = getLabels().getRows(i);
    return new DataSet(selectedFeatures, selectedLabels);
}