Refine search
/**
 * Verifies that every non-null operand is backed by a buffer of the expected data type.
 *
 * @param expectedType the data type every operand must have
 * @param operands     arrays to check; {@code null} entries are skipped, and a {@code null}
 *                     or empty varargs array makes this a no-op
 * @throws ND4JIllegalStateException if an operand's buffer type differs from
 *                                   {@code expectedType}; the message names the operand's
 *                                   position within {@code operands}
 */
public static void validateDataType(DataBuffer.Type expectedType, INDArray... operands) {
    if (operands == null || operands.length == 0) {
        return;
    }

    // Index by position so the message names the offending operand. The previous counter
    // was only incremented inside the throw expression, so it always reported [0].
    for (int i = 0; i < operands.length; i++) {
        INDArray operand = operands[i];
        if (operand == null) {
            continue;
        }
        DataBuffer.Type actual = operand.data().dataType();
        if (actual != expectedType) {
            throw new ND4JIllegalStateException("INDArray [" + i + "] dataType is [" + actual
                            + "] instead of expected [" + expectedType + "]");
        }
    }
}
/**
 * Creates a log entry describing {@code toLog}, capturing the current thread's stack trace
 * and the current wall-clock time.
 *
 * @param toLog  the array being logged; its shape, stride, implementation class name,
 *               length, buffer {@code references()} value and buffer data type are copied
 *               into this entry
 * @param status free-form status string stored with the entry
 */
public LogEntry(INDArray toLog, String status) {
    this.shape = toLog.shape();
    this.stride = toLog.stride();
    this.ndArrayType = toLog.getClass().getName();
    this.length = toLog.length();
    this.references = toLog.data().references();
    // Record the buffer's actual type in lower case ("double", "float", "int", ...). The old
    // DOUBLE-vs-everything-else check labelled every non-double buffer as "float".
    this.dataType = toLog.data().dataType().name().toLowerCase(java.util.Locale.ROOT);
    this.timestamp = System.currentTimeMillis();
    this.stackTraceElements = Thread.currentThread().getStackTrace();
    this.status = status;
}
/**
 * Creates a log entry describing {@code toLog} with a caller-supplied stack trace, stamped
 * with the current wall-clock time.
 *
 * @param toLog              the array being logged; its shape, stride, implementation class
 *                           name, length, buffer {@code references()} value and buffer data
 *                           type are copied into this entry
 * @param stackTraceElements stack trace to associate with the entry (stored by reference,
 *                           not copied)
 * @param status             free-form status string stored with the entry
 */
public LogEntry(INDArray toLog, StackTraceElement[] stackTraceElements, String status) {
    this.shape = toLog.shape();
    this.stride = toLog.stride();
    this.ndArrayType = toLog.getClass().getName();
    this.length = toLog.length();
    this.references = toLog.data().references();
    // Record the buffer's actual type in lower case ("double", "float", "int", ...). The old
    // DOUBLE-vs-everything-else check labelled every non-double buffer as "float".
    this.dataType = toLog.data().dataType().name().toLowerCase(java.util.Locale.ROOT);
    this.timestamp = System.currentTimeMillis();
    this.stackTraceElements = stackTraceElements;
    this.status = status;
}
/**
 * Prints a two-dimensional array to standard output at full precision, preceded by the
 * header produced by {@code printNDArrayHeader}. Each row goes on its own line with values
 * separated by ", ". FLOAT buffers are read with {@code getFloat}; every other buffer type
 * is read with {@code getDouble}.
 *
 * @param matrix the array to print; only indices within shape[0] x shape[1] are read
 */
public static void printMatrixFullPrecision(INDArray matrix) {
    final boolean asFloat = matrix.data().dataType() == DataBuffer.Type.FLOAT;
    printNDArrayHeader(matrix);
    final long[] dims = matrix.shape();
    for (int row = 0; row < dims[0]; row++) {
        for (int col = 0; col < dims[1]; col++) {
            if (asFloat) {
                System.out.print(matrix.getFloat(row, col));
            } else {
                System.out.print(matrix.getDouble(row, col));
            }
            // Terminate the row after its last column; otherwise emit the separator.
            final boolean lastColumn = col == dims[1] - 1;
            if (lastColumn) {
                System.out.println();
            } else {
                System.out.print(", ");
            }
        }
    }
}
/**
 * Prints a one-line summary of {@code array} to standard output: buffer data type,
 * ordering, offset, shape, stride, length, and the length of the underlying buffer.
 *
 * @param array the array to describe
 */
public static void printNDArrayHeader(INDArray array) {
    final StringBuilder header = new StringBuilder();
    header.append(array.data().dataType())
                    .append(" - order=").append(array.ordering())
                    .append(", offset=").append(array.offset())
                    .append(", shape=").append(Arrays.toString(array.shape()))
                    .append(", stride=").append(Arrays.toString(array.stride()))
                    .append(", length=").append(array.length())
                    .append(", data().length()=").append(array.data().length());
    System.out.println(header);
}
/**
 * Returns a decompressed version of {@code array}.
 *
 * <p>If the array's buffer is not of type COMPRESSED, the same instance is returned as-is.
 * Otherwise, the buffer's compression descriptor supplies an algorithm name. The codec
 * registered under that name in {@code codecs} then performs the decompression.
 *
 * @param array the array to decompress
 * @return {@code array} itself when it is not compressed, otherwise the result of the
 *         matching codec's {@code decompress}
 * @throws RuntimeException if no codec is registered for the descriptor's algorithm
 */
public INDArray decompress(INDArray array) {
    if (array.data().dataType() == DataBuffer.Type.COMPRESSED) {
        final CompressedDataBuffer compressed = (CompressedDataBuffer) array.data();
        final CompressionDescriptor descriptor = compressed.getCompressionDescriptor();
        final String algorithm = descriptor.getCompressionAlgorithm();
        if (codecs.containsKey(algorithm)) {
            return codecs.get(algorithm).decompress(array);
        }
        throw new RuntimeException("Non-existent compression algorithm requested: [" + algorithm + "]");
    }
    return array;
}
/**
 * Decompresses {@code array} in place. If its buffer is not of type COMPRESSED, this does
 * nothing.
 *
 * <p>The codec is selected by the algorithm name in the buffer's compression descriptor.
 * Its {@code decompressi} method performs the work.
 *
 * @param array the array to decompress, if it is compressed
 * @throws RuntimeException if no codec is registered for the descriptor's algorithm
 */
public void decompressi(INDArray array) {
    if (array.data().dataType() == DataBuffer.Type.COMPRESSED) {
        final CompressedDataBuffer compressed = (CompressedDataBuffer) array.data();
        final CompressionDescriptor descriptor = compressed.getCompressionDescriptor();
        final String algorithm = descriptor.getCompressionAlgorithm();
        if (!codecs.containsKey(algorithm)) {
            throw new RuntimeException("Non-existent compression algorithm requested: [" + algorithm + "]");
        }
        codecs.get(algorithm).decompressi(array);
    }
}
/** * This method checks if any Op operand has data opType of INT, and throws exception if any. * * @param op */ protected void interceptIntDataType(Op op) { // FIXME: Remove this method, after we'll add support for <int> dtype operations if (op.x() != null && op.x().data().dataType() == DataBuffer.Type.INT) throw new ND4JIllegalStateException( "Op.X contains INT data. Operations on INT dataType are not supported yet"); if (op.z() != null && op.z().data().dataType() == DataBuffer.Type.INT) throw new ND4JIllegalStateException( "Op.Z contains INT data. Operations on INT dataType are not supported yet"); if (op.y() != null && op.y().data().dataType() == DataBuffer.Type.INT) throw new ND4JIllegalStateException( "Op.Y contains INT data. Operations on INT dataType are not supported yet."); }
/**
 * Scales {@code toScale} by the reciprocal of its L2 norm using the BLAS wrapper's
 * {@code scal}. Single precision is used for FLOAT buffers and double precision otherwise.
 * If the norm is not strictly positive (zero or NaN), {@code toScale} is returned without
 * calling {@code scal}.
 *
 * <p>NOTE(review): BLAS {@code scal} normally works in place. Confirm whether
 * {@code toScale} itself is mutated.
 *
 * @param toScale the ndarray to scale
 * @return the result of {@code scal}, or {@code toScale} when its norm is not positive
 */
public static INDArray unitVec(INDArray toScale) {
    final double norm = toScale.norm2Number().doubleValue();
    // Written as !(norm > 0) rather than norm <= 0 so that a NaN norm also skips scaling.
    if (!(norm > 0)) {
        return toScale;
    }
    final boolean singlePrecision = toScale.data().dataType() == DataBuffer.Type.FLOAT;
    return singlePrecision
                    ? Nd4j.getBlasWrapper().scal(1.0f / (float) norm, toScale)
                    : Nd4j.getBlasWrapper().scal(1.0 / norm, toScale);
}
@Override public INDArray gemv(double alpha, INDArray a, INDArray x, double beta, INDArray y) { LinAlgExceptions.assertVector(x, y); LinAlgExceptions.assertMatrix(a); if (a.data().dataType() == DataBuffer.Type.FLOAT) { // DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, a, x, y); return gemv((float) alpha, a, x, (float) beta, y); } else { level2().gemv('N', 'N', alpha, a, x, beta, y); } return y; }
/**
 * Single-precision entry point for the level-2 BLAS {@code gemv}, called with both
 * character flags set to 'N'. NOTE(review): presumably computes y = alpha * a * x + beta * y;
 * confirm against the Level2 implementation.
 *
 * <p>Matrices backed by a DOUBLE buffer are forwarded to the double overload, with alpha
 * and beta widened to double.
 *
 * @param alpha scalar applied to the product
 * @param a     the matrix
 * @param x     the input vector
 * @param beta  scalar applied to {@code y}
 * @param y     the output vector
 * @return {@code y} in the float path, otherwise the double overload's result
 */
@Override
public INDArray gemv(float alpha, INDArray a, INDArray x, float beta, INDArray y) {
    LinAlgExceptions.assertVector(x, y);
    LinAlgExceptions.assertMatrix(a);
    final boolean isDouble = a.data().dataType() == DataBuffer.Type.DOUBLE;
    if (isDouble) {
        return gemv((double) alpha, a, x, (double) beta, y);
    }
    level2().gemv('N', 'N', alpha, a, x, beta, y);
    return y;
}
/**
 * Returns the elements of a DOUBLE-typed ndarray as a {@code double[]}.
 *
 * <p>The buffer's {@code asDouble()} result is returned directly when two conditions hold:
 * the buffer is heap-allocated, and the array spans the whole buffer (offset 0 and length
 * equal to the buffer length). NOTE(review): whether that is a live reference or a copy
 * depends on the buffer implementation. In every other case the elements are copied one by
 * one from {@code linearView()}.
 *
 * @param buf the ndarray to get the data for
 * @return the double data for this ndarray
 * @throws IllegalArgumentException if {@code buf} is not backed by a DOUBLE buffer, or has
 *                                  more elements than fit in a Java array
 */
public static double[] getDoubleData(INDArray buf) {
    if (buf.data().dataType() != DataBuffer.Type.DOUBLE) {
        throw new IllegalArgumentException("Double data must be obtained from a double buffer");
    }

    // Only hand back the raw heap buffer when this array covers all of it. For a view with
    // a non-zero offset or a sub-range, the raw buffer also holds elements outside the view.
    boolean coversWholeBuffer = buf.offset() == 0 && buf.length() == buf.data().length();
    if (buf.data().allocationMode() == DataBuffer.AllocationMode.HEAP && coversWholeBuffer) {
        return buf.data().asDouble();
    }

    long length = buf.length();
    // Guard the int cast: larger lengths would silently truncate or go negative.
    if (length > Integer.MAX_VALUE) {
        throw new IllegalArgumentException(
                        "Array has " + length + " elements, too many to return as a double[]");
    }
    double[] ret = new double[(int) length];
    INDArray linear = buf.linearView();
    for (int i = 0; i < ret.length; i++) {
        ret[i] = linear.getDouble(i);
    }
    return ret;
}