/**
 * Describes an {@code int[]} argument.
 *
 * @param array the int values to wrap (stored by reference, not copied)
 */
public ArrayDescriptor(int[] array) {
    // Capture the global buffer type that is in effect when the descriptor is built.
    this.bufferType = Nd4j.dataType();
    this.dtype = DTYPE.INT;
    this.intArray = array;
}
/**
 * Describes a {@code long[]} argument.
 *
 * @param array the long values to wrap (stored by reference, not copied)
 */
public ArrayDescriptor(long[] array) {
    // Capture the global buffer type that is in effect when the descriptor is built.
    this.bufferType = Nd4j.dataType();
    this.dtype = DTYPE.LONG;
    this.longArray = array;
}
/**
 * Describes a {@code float[]} argument.
 *
 * @param array the float values to wrap (stored by reference, not copied)
 */
public ArrayDescriptor(float[] array) {
    // Capture the global buffer type that is in effect when the descriptor is built.
    this.bufferType = Nd4j.dataType();
    this.dtype = DTYPE.FLOAT;
    this.floatArray = array;
}
/**
 * Returns the data type reported for this ndarray.
 *
 * <p>NOTE(review): this implementation does not inspect the array's own buffer; it
 * returns the global default from {@code Nd4j.dataType()}. Confirm that this is
 * intended for arrays whose buffer type differs from the global default.
 *
 * @return the global data type from {@code Nd4j.dataType()}
 */
@Override
public DataBuffer.Type dtype() {
    return Nd4j.dataType();
}
/**
 * Describes a {@code double[]} argument.
 *
 * @param array the double values to wrap (stored by reference, not copied)
 */
public ArrayDescriptor(double[] array) {
    // Capture the global buffer type that is in effect when the descriptor is built.
    this.bufferType = Nd4j.dataType();
    this.dtype = DTYPE.DOUBLE;
    this.doubleArray = array;
}
/**
 * Tests whether {@code input} differs from {@code value}.
 *
 * <p>When the global data type is DOUBLE, the values are compared as doubles.
 * Otherwise they are compared after narrowing to float.
 *
 * @param input the number to test
 * @return {@code true} if the two values are not equal at the active precision
 */
@Override
public Boolean apply(Number input) {
    boolean doublePrecision = Nd4j.dataType() == DataBuffer.Type.DOUBLE;
    return doublePrecision
            ? input.doubleValue() != value.doubleValue()
            : input.floatValue() != value.floatValue();
}
/**
 * Tests whether {@code input} equals {@code value}.
 *
 * <p>When the global data type is DOUBLE, the values are compared as doubles.
 * Otherwise they are compared after narrowing to float.
 *
 * @param input the number to test
 * @return {@code true} if the two values are equal at the active precision
 */
@Override
public Boolean apply(Number input) {
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
        return input.floatValue() == value.floatValue();
    }
    return input.doubleValue() == value.doubleValue();
}
/**
 * Returns the width, in bytes, of one element of the current global data type.
 *
 * <p>Delegates to {@code sizeOfDataType(Nd4j.dataType())}.
 *
 * @return number of bytes per element
 */
public static int sizeOfDataType() {
    final int bytesPerElement = sizeOfDataType(Nd4j.dataType());
    return bytesPerElement;
}
/**
 * Converts the current global data type ({@code Nd4j.dataType()}) to its
 * {@code TypeEx} equivalent via {@code convertType}.
 *
 * @return the {@code TypeEx} matching the global data type
 */
protected DataBuffer.TypeEx getGlobalTypeEx() {
    return convertType(Nd4j.dataType());
}
/**
 * Initializes the op with its operands, then sets {@code extraArgs} to a single
 * zero value matching the global data type.
 *
 * <p>For DOUBLE, FLOAT and HALF the zero comes from {@code zeroDouble()},
 * {@code zeroFloat()} or {@code zeroHalf()} respectively. For any other type,
 * {@code extraArgs} is left as set by {@code super.init}.
 *
 * @param x first input
 * @param y second input
 * @param z output
 * @param n number of elements
 */
@Override
public void init(INDArray x, INDArray y, INDArray z, long n) {
    super.init(x, y, z, n);
    DataBuffer.Type globalType = Nd4j.dataType();
    if (globalType == DataBuffer.Type.DOUBLE) {
        this.extraArgs = new Object[] {zeroDouble()};
    } else if (globalType == DataBuffer.Type.FLOAT) {
        this.extraArgs = new Object[] {zeroFloat()};
    } else if (globalType == DataBuffer.Type.HALF) {
        this.extraArgs = new Object[] {zeroHalf()};
    }
}
/**
 * Creates a complex number from a real and an imaginary component.
 *
 * <p>When the global data type is FLOAT, a single-precision complex number is
 * created. For every other data type (including HALF), a double-precision one is
 * created.
 *
 * @param real real component
 * @param imag imaginary component
 * @return the new complex number
 */
public static IComplexNumber createComplexNumber(Number real, Number imag) {
    if (dataType() != DataBuffer.Type.FLOAT) {
        return INSTANCE.createDouble(real.doubleValue(), imag.doubleValue());
    }
    return INSTANCE.createFloat(real.floatValue(), imag.floatValue());
}
/**
 * Creates a rank-0 scalar array (empty shape and stride, offset 0) holding
 * {@code value}.
 *
 * <p>DOUBLE uses a {@code double[]} payload. FLOAT and HALF both use a
 * {@code float[]} payload.
 *
 * @param value the scalar value
 * @return the scalar array
 * @throws UnsupportedOperationException if the global data type is not DOUBLE, FLOAT or HALF
 */
@Override
public INDArray trueScalar(Number value) {
    val current = Nd4j.dataType();
    switch (current) {
        case DOUBLE:
            return create(new double[] {value.doubleValue()}, new int[0], new int[0], 0);
        case FLOAT:
        case HALF:
            // Both single- and half-precision scalars are staged through a float[].
            return create(new float[] {value.floatValue()}, new int[0], new int[0], 0);
        default:
            throw new UnsupportedOperationException("Unsupported data type: [" + current + "]");
    }
}
/**
 * Returns the number of bytes needed for the graph's elements.
 *
 * <p>Computed as {@code numElements()} multiplied by the byte width of the global
 * data type. The element count is widened to {@code long} before multiplying, so
 * the product cannot overflow {@code int} even if {@code numElements()} returns an
 * {@code int}.
 *
 * @return the number of bytes for the graph
 */
public long memoryForGraph() {
    return (long) numElements() * DataTypeUtil.lengthForDtype(Nd4j.dataType());
}
/**
 * Creates a scalar ndarray holding {@code value}.
 *
 * <p>When the global data type is DOUBLE, a 1x1 array backed by a {@code double[]}
 * is created directly. Otherwise the value is narrowed to float and the
 * {@code scalar(float)} overload is used.
 *
 * @param value the value of the scalar
 * @return the scalar ndarray
 */
@Override
public INDArray scalar(double value) {
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
        return scalar((float) value);
    }
    return create(new double[] {value}, new int[] {1, 1}, new int[] {1, 1}, 0);
}
/**
 * Creates a detached data buffer from float data, in the global data type.
 *
 * <p>FLOAT keeps the data as-is, HALF creates a half buffer, and any other type
 * converts the data to doubles first. Creation is passed to
 * {@code logCreationIfNecessary}.
 *
 * @param data the source values
 * @return the new buffer
 */
public static DataBuffer createBufferDetached(float[] data) {
    final DataBuffer.Type type = dataType();
    final DataBuffer buffer;
    if (type == DataBuffer.Type.FLOAT) {
        buffer = DATA_BUFFER_FACTORY_INSTANCE.createFloat(data);
    } else if (type == DataBuffer.Type.HALF) {
        buffer = DATA_BUFFER_FACTORY_INSTANCE.createHalf(data);
    } else {
        buffer = DATA_BUFFER_FACTORY_INSTANCE.createDouble(ArrayUtil.toDoubles(data));
    }
    logCreationIfNecessary(buffer);
    return buffer;
}
/**
 * Creates a detached data buffer from double data, in the global data type.
 *
 * <p>DOUBLE keeps the data as-is, HALF converts to floats and creates a half
 * buffer, and any other type converts to floats and creates a float buffer.
 * Creation is passed to {@code logCreationIfNecessary}.
 *
 * @param data the source values
 * @return the new buffer
 */
public static DataBuffer createBufferDetached(double[] data) {
    final DataBuffer.Type type = dataType();
    final DataBuffer buffer;
    if (type == DataBuffer.Type.DOUBLE) {
        buffer = DATA_BUFFER_FACTORY_INSTANCE.createDouble(data);
    } else if (type == DataBuffer.Type.HALF) {
        buffer = DATA_BUFFER_FACTORY_INSTANCE.createHalf(ArrayUtil.toFloats(data));
    } else {
        buffer = DATA_BUFFER_FACTORY_INSTANCE.createFloat(ArrayUtil.toFloats(data));
    }
    logCreationIfNecessary(buffer);
    return buffer;
}
@Override public long getRequiredBatchMemorySize() { long result = maxIntArrays() * maxIntArraySize() * 4; result += maxArguments() * 8; // pointers result += maxShapes() * 8; // pointers result += maxIndexArguments() * 4; result += maxRealArguments() * (Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 8 : Nd4j.dataType() == DataBuffer.Type.FLOAT ? 4 : 2); result += 5 * 4; // numArgs return result * Batch.getBatchLimit(); } }
/**
 * Converts an {@code int[]} into a 1 x n row vector.
 *
 * <p>Values are converted to doubles when the global data type is DOUBLE, and to
 * floats otherwise.
 *
 * @param nums the values to convert
 * @return a row vector of shape {@code [1, nums.length]}
 */
public static INDArray toNDArray(int[] nums) {
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
        float[] asFloats = ArrayUtil.toFloats(nums);
        return Nd4j.create(asFloats, new int[] {1, nums.length});
    }
    double[] asDoubles = ArrayUtil.toDoubles(nums);
    return Nd4j.create(asDoubles, new int[] {1, nums.length});
}
/**
 * Converts a {@code long[]} into a 1 x n row vector.
 *
 * <p>Values are converted to doubles when the global data type is DOUBLE, and to
 * floats otherwise.
 *
 * @param nums the values to convert
 * @return a row vector of shape {@code [1, nums.length]}
 */
public static INDArray toNDArray(long[] nums) {
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
        float[] asFloats = ArrayUtil.toFloats(nums);
        return Nd4j.create(asFloats, new int[] {1, nums.length});
    }
    double[] asDoubles = ArrayUtil.toDoubles(nums);
    return Nd4j.create(asDoubles, new int[] {1, nums.length});
}
/**
 * Converts an {@code int[][]} into a 2-d ndarray.
 *
 * <p>Values are flattened and converted to doubles when the global data type is
 * DOUBLE, and to floats otherwise. {@code nums} must be non-empty, because
 * {@code nums[0]} is read to build the shape.
 *
 * <p>NOTE(review): the shape used is {@code [nums[0].length, nums.length]}. If
 * {@code ArrayUtil} flattens row by row, this is not the natural
 * {@code rows x cols} layout of {@code nums}. Confirm that this is intended.
 *
 * @param nums the values to convert
 * @return an array of shape {@code [nums[0].length, nums.length]}
 */
public static INDArray toNDArray(int[][] nums) {
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
        float[] asFloats = ArrayUtil.toFloats(nums);
        return Nd4j.create(asFloats, new int[] {nums[0].length, nums.length});
    }
    double[] asDoubles = ArrayUtil.toDoubles(nums);
    return Nd4j.create(asDoubles, new int[] {nums[0].length, nums.length});
}