/**
 * Refreshes the cached success probability from element {@code i} of {@code p}, read in
 * linear-view order.
 *
 * @param i linear index into {@code p}
 */
private void ensureConsistent(int i) {
    INDArray flattened = this.p.linearView();
    this.probabilityOfSuccess = flattened.getDouble(i);
}
/**
 * Verifies that every element of {@code n} is a finite number.
 *
 * @param n the array to scan, traversed through its linear view
 * @throws IllegalStateException on the first NaN or infinite element
 */
public static void assertValidNum(INDArray n) {
    final INDArray flat = n.linearView();
    final long total = flat.length();
    for (int idx = 0; idx < total; idx++) {
        final double value = flat.getDouble(idx);
        final boolean invalid = Double.isNaN(value) || Double.isInfinite(value);
        if (invalid) {
            throw new IllegalStateException("Found infinite or nan");
        }
    }
}
/**
 * Given a sequence of Iterators over a set of matrices, fills all of the matrices with the
 * entries of the theta vector, consumed in order. Each matrix is written through its linear view.
 *
 * @param theta    source values; must contain exactly as many elements as all matrices combined
 * @param matrices iterators over the matrices to fill (each iterator is exhausted)
 * @throws AssertionError if theta is too short to fill the matrices, or has elements left over
 */
@SafeVarargs // the varargs array is only iterated, never stored or exposed
public static void setParams(INDArray theta, Iterator<? extends INDArray>... matrices) {
    final long thetaLength = theta.length();
    int index = 0;
    for (Iterator<? extends INDArray> matrixIterator : matrices) {
        while (matrixIterator.hasNext()) {
            INDArray matrix = matrixIterator.next().linearView();
            for (int i = 0; i < matrix.length(); i++) {
                // Fail with a clear message instead of reading past the end of theta.
                if (index >= thetaLength) {
                    throw new AssertionError("Theta vector is too short to fill the matrices");
                }
                matrix.putScalar(i, theta.getDouble(index));
                index++;
            }
        }
    }
    if (index != thetaLength) {
        throw new AssertionError("Did not entirely use the theta vector");
    }
}
/**
 * Converts the elements of {@code n}, in linear-view order, to an {@code int[]}.
 * Each value is truncated toward zero.
 *
 * @param n a real-valued array with at most {@code Integer.MAX_VALUE} elements
 * @return the converted values
 * @throws IllegalArgumentException   if {@code n} is complex
 * @throws ND4JIllegalStateException  if {@code n} is too long for a Java array
 */
public static int[] toInts(INDArray n) {
    if (n instanceof IComplexNDArray)
        throw new IllegalArgumentException("Unable to convert complex array");
    if (n.length() > Integer.MAX_VALUE)
        throw new ND4JIllegalStateException("Can't convert INDArray with length > Integer.MAX_VALUE");

    INDArray linear = n.linearView();
    // Safe: length was checked against Integer.MAX_VALUE above.
    int length = (int) linear.length();
    int[] ret = new int[length];
    for (int i = 0; i < length; i++) {
        // Read as double: going through float loses exactness for |values| > 2^24.
        ret[i] = (int) linear.getDouble(i);
    }
    return ret;
}
public static long[] toLongs(INDArray n) { if (n instanceof IComplexNDArray) throw new IllegalArgumentException("Unable to convert complex array"); if (n.length() > Integer.MAX_VALUE) throw new ND4JIllegalStateException("Can't convert INDArray with length > Integer.MAX_VALUE"); n = n.linearView(); // FIXME: int cast long[] ret = new long[(int) n.length()]; for (int i = 0; i < n.length(); i++) ret[i] = (long) n.getFloat(i); return ret; }
/**
 * Given a sequence of collections of matrices, fills all of the matrices with the entries of
 * the theta vector, consumed in order. Each matrix is written through its linear view.
 * Sizes are validated before anything is written, so on error no matrix is modified.
 *
 * @param theta    source values; must contain exactly as many elements as all matrices combined
 * @param matrices collections of matrices to fill
 * @throws AssertionError if theta is too short to fill the matrices, or has elements left over
 */
@SafeVarargs // the varargs array is only iterated, never stored or exposed
public static void setParams(INDArray theta, Collection<INDArray>... matrices) {
    // Collections can be traversed twice, so check the total size up front.
    long required = 0;
    for (Collection<INDArray> matrixCollection : matrices) {
        for (INDArray matrix : matrixCollection) {
            required += matrix.length();
        }
    }
    long available = theta.length();
    if (required > available) {
        throw new AssertionError("Theta vector is too short to fill the matrices");
    }
    if (required < available) {
        throw new AssertionError("Did not entirely use the theta vector");
    }

    int index = 0;
    for (Collection<INDArray> matrixCollection : matrices) {
        for (INDArray matrix : matrixCollection) {
            INDArray linear = matrix.linearView();
            for (int i = 0; i < linear.length(); i++) {
                linear.putScalar(i, theta.getDouble(index));
                index++;
            }
        }
    }
}
/**
 * Writes {@code getComplex(i).add(n)} for every element of this array into {@code result}.
 * Both arrays are traversed in linear-view order.
 *
 * @param n      the complex scalar to add
 * @param result destination; must be an {@link IComplexNDArray}
 * @return {@code result}, cast to {@link IComplexNDArray}
 */
@Override
public IComplexNDArray addi(IComplexNumber n, INDArray result) {
    IComplexNDArray source = linearView();
    IComplexNDArray target = (IComplexNDArray) result.linearView();
    int idx = 0;
    while (idx < length()) {
        target.putScalar(idx, source.getComplex(idx).add(n));
        idx++;
    }
    return (IComplexNDArray) result;
}
/**
 * Overwrites this array with the values of {@code real}. Each element becomes a complex number
 * with a zero imaginary part. Both arrays are traversed in linear-view order.
 *
 * @param real the real-valued source; must have the same shape as this array
 * @throws IllegalStateException if the shapes differ
 */
protected void copyFromReal(INDArray real) {
    if (!Shape.shapeEquals(shape(), real.shape())) {
        throw new IllegalStateException("Unable to copy array. Not the same shape");
    }
    final INDArray source = real.linearView();
    final IComplexNDArray target = linearView();
    final long count = source.length();
    int pos = 0;
    while (pos < count) {
        target.putScalar(pos, Nd4j.createComplexNumber(source.getDouble(pos), 0.0));
        pos++;
    }
}
/**
 * Builds a new real-valued array of the same shape holding the real part of each element.
 *
 * @return a freshly created array (via {@code Nd4j.create}) with the real components
 */
@Override
public INDArray getReal() {
    INDArray realPart = Nd4j.create(shape());
    IComplexNDArray src = linearView();
    INDArray dst = realPart.linearView();
    int k = 0;
    while (k < src.length()) {
        dst.putScalar(k, src.getReal(k));
        k++;
    }
    return realPart;
}
/**
 * In-place "less than". Each element of this array becomes 1+0i when its magnitude is less than
 * the corresponding element of {@code other}, otherwise 0+0i. The comparison depends on the
 * type of {@code other}:
 * <ul>
 *   <li>complex: its magnitude is compared;</li>
 *   <li>real: its raw (signed) value is compared.</li>
 * </ul>
 * Both arrays are traversed in linear-view order.
 *
 * @param other the array to compare against
 * @return this array, overwritten with the 0/1 result
 */
@Override
public IComplexNDArray lti(INDArray other) {
    final IComplexNDArray self = linearView();
    final boolean complexOther = other instanceof IComplexNDArray;
    final INDArray rhs = other.linearView();
    final IComplexNDArray complexRhs = complexOther ? (IComplexNDArray) rhs : null;
    for (int idx = 0; idx < self.length(); idx++) {
        final double lhsMagnitude = self.getComplex(idx).absoluteValue().doubleValue();
        final double rhsValue = complexOther
                ? complexRhs.getComplex(idx).absoluteValue().doubleValue()
                : rhs.getDouble(idx);
        self.putScalar(idx, lhsMagnitude < rhsValue
                ? Nd4j.createComplexNumber(1, 0)
                : Nd4j.createComplexNumber(0, 0));
    }
    return this;
}
/**
 * In-place equality test. Each element of this array becomes 1+0i when its magnitude equals the
 * corresponding element of {@code other} (exact double comparison), otherwise 0+0i. The
 * comparison depends on the type of {@code other}:
 * <ul>
 *   <li>complex: its magnitude is compared;</li>
 *   <li>real: its raw (signed) value is compared.</li>
 * </ul>
 * Both arrays are traversed in linear-view order.
 *
 * @param other the array to compare against
 * @return this array, overwritten with the 0/1 result
 */
@Override
public IComplexNDArray eqi(INDArray other) {
    final IComplexNDArray self = linearView();
    final boolean complexOther = other instanceof IComplexNDArray;
    final INDArray rhs = other.linearView();
    final IComplexNDArray complexRhs = complexOther ? (IComplexNDArray) rhs : null;
    for (int idx = 0; idx < self.length(); idx++) {
        final double lhsMagnitude = self.getComplex(idx).absoluteValue().doubleValue();
        final double rhsValue = complexOther
                ? complexRhs.getComplex(idx).absoluteValue().doubleValue()
                : rhs.getDouble(idx);
        self.putScalar(idx, lhsMagnitude == rhsValue
                ? Nd4j.createComplexNumber(1, 0)
                : Nd4j.createComplexNumber(0, 0));
    }
    return this;
}
/**
 * In-place inequality test. Each element of this array becomes 1+0i when its magnitude differs
 * from the corresponding element of {@code other} (exact double comparison), otherwise 0+0i.
 * The comparison depends on the type of {@code other}:
 * <ul>
 *   <li>complex: its magnitude is compared;</li>
 *   <li>real: its raw (signed) value is compared.</li>
 * </ul>
 * Both arrays are traversed in linear-view order.
 *
 * @param other the array to compare against
 * @return this array, overwritten with the 0/1 result
 */
@Override
public IComplexNDArray neqi(INDArray other) {
    final IComplexNDArray self = linearView();
    final boolean complexOther = other instanceof IComplexNDArray;
    final INDArray rhs = other.linearView();
    final IComplexNDArray complexRhs = complexOther ? (IComplexNDArray) rhs : null;
    for (int idx = 0; idx < self.length(); idx++) {
        final double lhsMagnitude = self.getComplex(idx).absoluteValue().doubleValue();
        final double rhsValue = complexOther
                ? complexRhs.getComplex(idx).absoluteValue().doubleValue()
                : rhs.getDouble(idx);
        self.putScalar(idx, lhsMagnitude != rhsValue
                ? Nd4j.createComplexNumber(1, 0)
                : Nd4j.createComplexNumber(0, 0));
    }
    return this;
}
/**
 * In-place "greater than". Each element of this array becomes 1+0i when its magnitude is greater
 * than the corresponding element of {@code other}, otherwise 0+0i. The comparison depends on the
 * type of {@code other}:
 * <ul>
 *   <li>complex: its magnitude is compared;</li>
 *   <li>real: its raw (signed) value is compared.</li>
 * </ul>
 * Both arrays are traversed in linear-view order.
 *
 * @param other the array to compare against
 * @return this array, overwritten with the 0/1 result
 */
@Override
public IComplexNDArray gti(INDArray other) {
    final IComplexNDArray self = linearView();
    final boolean complexOther = other instanceof IComplexNDArray;
    final INDArray rhs = other.linearView();
    final IComplexNDArray complexRhs = complexOther ? (IComplexNDArray) rhs : null;
    for (int idx = 0; idx < self.length(); idx++) {
        final double lhsMagnitude = self.getComplex(idx).absoluteValue().doubleValue();
        final double rhsValue = complexOther
                ? complexRhs.getComplex(idx).absoluteValue().doubleValue()
                : rhs.getDouble(idx);
        self.putScalar(idx, lhsMagnitude > rhsValue
                ? Nd4j.createComplexNumber(1, 0)
                : Nd4j.createComplexNumber(0, 0));
    }
    return this;
}
/**
 * Copies the real part of each element of this array into {@code arr}.
 * This array is read through its linear view. {@code arr} is written with {@code putScalar}
 * at the same linear index.
 *
 * @param arr the array to copy to
 */
protected void copyRealTo(INDArray arr) {
    INDArray targetFlat = arr.linearView();
    IComplexNDArray sourceFlat = linearView();
    if (arr.isScalar()) {
        arr.putScalar(0, getReal(0));
        return;
    }
    final long total = targetFlat.length();
    for (int idx = 0; idx < total; idx++) {
        arr.putScalar(idx, sourceFlat.getReal(idx));
    }
}
/**
 * Copies the imaginary part of each element of this array into {@code arr}.
 * This array is read through its linear view. {@code arr} is written with {@code putScalar}
 * at the same linear index.
 *
 * @param arr the array to copy imaginary numbers to
 */
protected void copyImagTo(INDArray arr) {
    INDArray linear = arr.linearView();
    IComplexNDArray thisLinear = linearView();
    if (arr.isScalar()) {
        // Previously copied getReal(0) here, writing the real part into the imaginary target.
        arr.putScalar(0, getImag(0));
    } else {
        for (int i = 0; i < linear.length(); i++) {
            arr.putScalar(i, thisLinear.getImag(i));
        }
    }
}
/**
 * Returns the double data for this ndarray, in linear-view order.
 * If the array lives on the heap and spans its entire buffer (offset 0, same length as the
 * buffer), the buffer's {@code asDouble()} result is returned directly. Otherwise the elements
 * are copied into a new array.
 * NOTE(review): the direct path returns buffer order; confirm this matches linear-view order
 * for non-'c'-ordered arrays.
 *
 * @param buf the ndarray to get the data for
 * @return the double data for this ndarray
 * @throws IllegalArgumentException if the buffer is not of type DOUBLE
 */
public static double[] getDoubleData(INDArray buf) {
    if (buf.data().dataType() != DataBuffer.Type.DOUBLE)
        throw new IllegalArgumentException("Double data must be obtained from a double buffer");

    // A view (row, slice, sub-array) shares a larger buffer. Returning the raw buffer for it
    // would hand back elements that do not belong to the view, so require full coverage.
    boolean coversWholeBuffer = buf.offset() == 0 && buf.length() == buf.data().length();
    if (buf.data().allocationMode() == DataBuffer.AllocationMode.HEAP && coversWholeBuffer) {
        return buf.data().asDouble();
    }

    double[] ret = new double[(int) buf.length()];
    INDArray linear = buf.linearView();
    for (int i = 0; i < buf.length(); i++)
        ret[i] = linear.getDouble(i);
    return ret;
}
/**
 * Returns the float data for this ndarray, in linear-view order.
 * If the array lives on the heap and spans its entire buffer (offset 0, same length as the
 * buffer), the buffer's {@code asFloat()} result is returned directly. Otherwise the elements
 * are copied into a new array.
 * NOTE(review): the direct path returns buffer order; confirm this matches linear-view order
 * for non-'c'-ordered arrays.
 *
 * @param buf the ndarray to get the data for
 * @return the float data for this ndarray
 * @throws IllegalArgumentException if the buffer is not of type FLOAT
 */
public static float[] getFloatData(INDArray buf) {
    if (buf.data().dataType() != DataBuffer.Type.FLOAT)
        throw new IllegalArgumentException("Float data must be obtained from a float buffer");

    // A view (row, slice, sub-array) shares a larger buffer. Returning the raw buffer for it
    // would hand back elements that do not belong to the view, so require full coverage.
    boolean coversWholeBuffer = buf.offset() == 0 && buf.length() == buf.data().length();
    if (buf.data().allocationMode() == DataBuffer.AllocationMode.HEAP && coversWholeBuffer) {
        return buf.data().asFloat();
    }

    float[] ret = new float[(int) buf.length()];
    INDArray linear = buf.linearView();
    for (int i = 0; i < buf.length(); i++)
        ret[i] = linear.getFloat(i);
    return ret;
}
/** * Reverses the passed in matrix such that m[0] becomes m[m.length - 1] etc * * @param reverse the matrix to reverse * @return the reversed matrix */ @Override public INDArray reverse(INDArray reverse) { // FIXME: native method should be used instead INDArray rev = reverse.linearView(); INDArray ret = Nd4j.create(rev.shape()); int count = 0; for (long i = rev.length() - 1; i >= 0; i--) { ret.putScalar(count++, rev.getFloat(i)); } return ret.reshape(reverse.shape()); }
/**
 * Returns the (1-norm) distance: the sum, over all elements in linear-view order, of the
 * magnitude of the element-wise difference between this array and {@code other}.
 *
 * @param other a complex or real array with at least as many elements as this one
 * @return the accumulated 1-norm distance
 */
@Override
public double distance1(INDArray other) {
    // Accumulate in double: the previous float accumulator silently narrowed every addition.
    double d = 0.0;
    // Hoisted: previously linearView() was recomputed on every iteration of the real branch.
    IComplexNDArray thisLinear = linearView();
    if (other instanceof IComplexNDArray) {
        IComplexNDArray otherLinear = ((IComplexNDArray) other).linearView();
        for (int i = 0; i < length(); i++) {
            IComplexNumber n = thisLinear.getComplex(i).sub(otherLinear.getComplex(i));
            d += n.absoluteValue().doubleValue();
        }
        return d;
    }
    INDArray otherLinear = other.linearView();
    for (int i = 0; i < length(); i++) {
        IComplexNumber n = thisLinear.getComplex(i).sub(otherLinear.getDouble(i));
        d += n.absoluteValue().doubleValue();
    }
    return d;
}
/**
 * Binarizes the features in place. Values strictly greater than {@code cutoff} become 1; all
 * others become 0.
 * Values are read from {@code getFeatureMatrix().linearView()} and written back through
 * {@code getFeatures().putScalar} at the same linear index.
 * NOTE(review): this assumes getFeatureMatrix() and getFeatures() refer to the same array;
 * confirm.
 *
 * @param cutoff the cutoff point
 */
@Override
public void binarize(double cutoff) {
    INDArray source = getFeatureMatrix().linearView();
    int idx = 0;
    while (idx < getFeatures().length()) {
        double value = source.getDouble(idx);
        getFeatures().putScalar(idx, value > cutoff ? 1 : 0);
        idx++;
    }
}