/**
 * Creates an int buffer from double input.
 *
 * @param data the source values, converted to an {@code int[]} via {@code ArrayUtil.toInts}
 * @param copy forwarded unchanged to the {@code IntBuffer} constructor
 * @return a new {@code IntBuffer} over the converted values
 */
@Override
public DataBuffer createInt(double[] data, boolean copy) {
    int[] asInts = ArrayUtil.toInts(data);
    return new IntBuffer(asInts, copy);
}
/** Are the elements in the buffer contiguous for this NDArray? */ public static boolean isContiguousInBuffer(INDArray in) { long length = in.length(); long dLength = in.data().length(); if (length == dLength) return true; //full buffer, always contiguous char order = in.ordering(); long[] shape = in.shape(); long[] stridesIfContiguous; if (order == 'f') { stridesIfContiguous = ArrayUtil.calcStridesFortran(shape); } else if (order == 'c') { stridesIfContiguous = ArrayUtil.calcStrides(shape); } else if (order == 'a') { stridesIfContiguous = new long[] {1, 1}; } else { throw new RuntimeException("Invalid order: not c or f (is: " + order + ")"); } return Arrays.equals(in.stride(), stridesIfContiguous); }
/**
 * Computes the default strides for a shape in the requested ordering.
 *
 * @param shape the array shape
 * @param order {@code NDArrayFactory.FORTRAN} selects Fortran-order strides; any other value
 *              selects C-order strides
 * @return the computed strides
 */
public static long[] getStrides(long[] shape, char order) {
    return order == NDArrayFactory.FORTRAN
            ? ArrayUtil.calcStridesFortran(shape)
            : ArrayUtil.calcStrides(shape);
}
/**
 * Convert all dimensions in the specified axes array to be positive,
 * based on the specified range of values.
 *
 * <p>A negative axis {@code a} is mapped to {@code a + range}. Non-negative axes are copied
 * unchanged. The input array is not modified. No bounds validation is performed, so an axis
 * below {@code -range} stays negative.
 *
 * @param range the number of dimensions (rank) the axes refer to
 * @param axes  the axes to normalize; may contain negative values
 * @return a new array of the same length with negative axes converted
 */
public static int[] convertNegativeIndices(int range, int[] axes) {
    int[] newAxes = new int[axes.length];
    for (int i = 0; i < axes.length; i++) {
        int axis = axes[i];
        newAxes[i] = axis < 0 ? axis + range : axis;
    }
    return newAxes;
}
return toConcat[0]; int sumAlongDim = 0; boolean allC = toConcat[0].ordering() == 'c'; long[] outputShape = ArrayUtil.copy(toConcat[0].shape()); outputShape[dimension] = sumAlongDim; sumAlongDim += toConcat[i].size(dimension); allC = allC && toConcat[i].ordering() == 'c'; for (int j = 0; j < toConcat[i].rank(); j++) { long[] sortedStrides = Nd4j.getStrides(outputShape); INDArray ret = Nd4j.create(outputShape, sortedStrides); allC &= (ret.ordering() == 'c'); int currBufferOffset = 0; for (int i = 0; i < ret.length(); i++) { ret.data().put(i, toConcat[currBuffer].data() .getDouble(toConcat[currBuffer].offset() + currBufferOffset++)); if (currBufferOffset >= toConcat[currBuffer].length()) { currBuffer++;
dimension[i] += op.x().rank(); if (dimension.length == op.x().rank()) dimension = new int[] {Integer.MAX_VALUE}; long[] retShape = Shape.wholeArrayDimension(dimension) ? new long[] {1, 1} : ArrayUtil.removeIndex(op.x().shape(), dimension); val yT = op.y().tensorssAlongDimension(dimension); ret = Nd4j.create(xT, yT); } else { if (Math.abs(op.zeroDouble()) < Nd4j.EPS_THRESHOLD) { ret = Nd4j.zeros(retShape); } else { ret = Nd4j.valueArrayOf(retShape, op.zeroDouble()); } else { if (op.z().lengthLong() != ArrayUtil.prodLong(retShape)) throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]"); if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) { op.z().assign(op.zeroDouble()); } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) { op.z().assign(op.zeroFloat()); } else if (op.x().data().dataType() == DataBuffer.Type.HALF) { op.z().assign(op.zeroHalf());
@Override // return 4 dimensions public INDArray backprop(INDArray epsilons, int miniBatchSize) { if (epsilons.ordering() != 'c' || !Shape.strideDescendingCAscendingF(epsilons)) epsilons = epsilons.dup('c'); if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) { if (epsilons.rank() == 2) return epsilons; //should never happen return epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth); } return epsilons.reshape('c', shape); }
public BaseNDArray(List<INDArray> slices, long[] shape, long[] stride, char ordering) { DataBuffer ret = slices.get(0).data().dataType() == (DataBuffer.Type.FLOAT) ? Nd4j.createBuffer(new float[ArrayUtil.prod(shape)]) : Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]); this.data = ret; setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering)); init(shape, stride); // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); if (slices.get(0).isScalar()) { for (int i = 0; i < length(); i++) { putScalar(i, slices.get(i).getDouble(0)); } } else { for (int i = 0; i < slices(); i++) { putSlice(i, slices.get(i)); } } }
@Override public INDArray reshape(char order, long... newShape) { Nd4j.getCompressor().autoDecompress(this); throw new ND4JIllegalStateException( "Can't reshape(int...) without shape arguments. Got empty shape instead."); long[] shape = ArrayUtil.copy(newShape); long prod = ArrayUtil.prodLong(shape); throw new ND4JIllegalStateException("New shape length doesn't match original length: [" + prod + "] vs [" + this.lengthLong() + "]. Original shape: "+Arrays.toString(this.shape())+" New Shape: "+Arrays.toString(newShape)); INDArray reshapeAttempt = Shape.newShapeNoCopy(this, shape, order == 'f'); if (reshapeAttempt != null) { INDArray ret = Nd4j.createUninitialized(shape, order); if (order != ordering()) { ret.setData(dup(order).data()); } else ret.assign(this); return ret;
/**
 * Produces a single 3d permuted test array with the requested shape.
 *
 * <p>A linspace array of shape {@code [shape[1], shape[2], shape[0]]} is built and permuted
 * with {@code (2, 0, 1)}. The random seed is set first.
 *
 * @param seed  the seed passed to {@code Nd4j.getRandom().setSeed}
 * @param shape the target shape; at least three entries are read
 * @return a single pair of the permuted array and a description string
 */
public static List<Pair<INDArray, String>> get3dPermutedWithShape(long seed, long... shape) {
    Nd4j.getRandom().setSeed(seed);
    long[] baseShape = new long[] {shape[1], shape[2], shape[0]};
    int count = ArrayUtil.prod(baseShape);
    INDArray permuted = Nd4j.linspace(1, count, count).reshape(baseShape).permute(2, 0, 1);
    String description = "get3dPermutedWithShape(" + seed + "," + Arrays.toString(shape) + ").get(0)";
    return Collections.singletonList(new Pair<>(permuted, description));
}
/**
 * An alias for repmat.
 *
 * <p>Follows numpy's {@code tile}. If {@code repeat} has fewer entries than the array has
 * dimensions, it is left-padded with 1s. If it has more, the array's shape is left-padded
 * with 1s, so every repeat factor lines up with a dimension and none is ignored.
 *
 * @param tile   the ndarray to tile
 * @param repeat the number of repetitions along each dimension
 * @return the tiled ndarray; its rank is {@code max(tile.rank(), repeat.length)}
 */
public static INDArray tile(INDArray tile, int... repeat) {
    int d = repeat.length;
    long[] shape = ArrayUtil.copy(tile.shape());
    long n = Math.max(tile.length(), 1);

    if (d < shape.length) {
        repeat = Ints.concat(ArrayUtil.nTimes(shape.length - d, 1), repeat);
    } else if (d > shape.length) {
        // numpy promotes the input to d dims (ndmin=d) by prepending size-1 axes
        int pad = d - shape.length;
        long[] padded = new long[d];
        for (int i = 0; i < pad; i++) {
            padded[i] = 1;
        }
        System.arraycopy(shape, 0, padded, pad, shape.length);
        shape = padded;
    }

    for (int i = 0; i < shape.length; i++) {
        if (repeat[i] != 1) {
            // view as (-1, n) so repeating along axis 0 replicates every block of this dimension
            tile = tile.reshape(-1, n).repeat(0, new int[] {repeat[i]});
        }
        long in = shape[i];
        shape[i] = in * repeat[i];
        n /= Math.max(in, 1);
    }
    return tile.reshape(shape);
}
/**
 * Wraps an int array as a {@code [1, nums.length]} row array.
 *
 * <p>When {@code Nd4j.dataType()} is DOUBLE, the values are converted to doubles. For any
 * other global data type they are converted to floats.
 *
 * @param nums the values to wrap
 * @return a new array of shape {@code [1, nums.length]}
 */
public static INDArray toNDArray(int[] nums) {
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
        float[] asFloats = ArrayUtil.toFloats(nums);
        return Nd4j.create(asFloats, new int[] {1, nums.length});
    }
    double[] asDoubles = ArrayUtil.toDoubles(nums);
    return Nd4j.create(asDoubles, new int[] {1, nums.length});
}
/**Choose tensor dimension for operations with 3 arguments: z=Op(x,y) or similar<br> * @see #chooseElementWiseTensorDimension(INDArray) */ public static int chooseElementWiseTensorDimension(INDArray x, INDArray y, INDArray z) { if (x.isVector()) return ArrayUtil.argMax(x.shape()); // FIXME: int cast int opAlongDimensionMinStride = (int) ArrayUtil.argMinOfMax(x.stride(), y.stride(), z.stride()); int opAlongDimensionMaxLength = ArrayUtil.argMax(x.shape()); //Edge case: shapes with 1s in them can have stride of 1 on the dimensions of length 1 if (opAlongDimensionMinStride == opAlongDimensionMaxLength || x.size((int) opAlongDimensionMinStride) == 1) return opAlongDimensionMaxLength; int nOpsAlongMinStride = ArrayUtil.prod(ArrayUtil.removeIndex(x.shape(), (int) opAlongDimensionMinStride)); int nOpsAlongMaxLength = ArrayUtil.prod(ArrayUtil.removeIndex(x.shape(), opAlongDimensionMaxLength)); if (nOpsAlongMinStride <= 10 * nOpsAlongMaxLength) return opAlongDimensionMinStride; else return opAlongDimensionMaxLength; }
/**
 * Materializes the graph initializer named {@code tensorName} as an ndarray.
 *
 * <p>The initializer's raw bytes are copied into a direct, native-order buffer. They are
 * interpreted with the data type derived from {@code tensorProto}, then reshaped to the
 * shape read from {@code tensorProto}.
 *
 * @param tensorName  the initializer name to look up in {@code graph}
 * @param tensorProto the tensor type description supplying data type and shape
 * @param graph       the graph whose initializers are searched
 * @return the ndarray, or {@code null} if no initializer has that name
 * @throws ND4JIllegalStateException if {@code tensorProto} is not initialized
 */
@Override
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
    DataBuffer.Type type = dataTypeForTensor(tensorProto);
    if (!tensorProto.isInitialized()) {
        throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
    }

    OnnxProto3.TensorProto tensor = null;
    for (int idx = 0; tensor == null && idx < graph.getInitializerCount(); idx++) {
        OnnxProto3.TensorProto candidate = graph.getInitializer(idx);
        if (candidate.getName().equals(tensorName)) {
            tensor = candidate;
        }
    }
    if (tensor == null) {
        return null;
    }

    ByteBuffer source = tensor.getRawData().asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
    ByteBuffer direct = ByteBuffer.allocateDirect(source.capacity()).order(ByteOrder.nativeOrder());
    direct.put(source);
    direct.rewind();

    long[] shape = getShapeFromTensor(tensorProto);
    DataBuffer buffer = Nd4j.createBuffer(direct, type, ArrayUtil.prod(shape));
    return Nd4j.create(buffer).reshape(shape);
}
if (tfTensor.getIntValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) { return Nd4j.trueScalar(0.0); arrayShape = new int[]{}; INDArray array = Nd4j.valueArrayOf(arrayShape, (double) val); return array; } else if (tfTensor.getInt64ValCount() > 0) { INDArray array = Nd4j.create(jArray, arrayShape, 0, 'c'); return array; } else { long length = ArrayUtil.prodLong(arrayShape); throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?"); if (tfTensor.getFloatValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) { throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?"); if (tfTensor.getDoubleValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) { throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?"); if (tfTensor.getInt64ValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) {
protected void assertSlice(INDArray put, long slice) { assert slice <= slices() : "Invalid slice specified " + slice; long[] sliceShape = put.shape(); if (Shape.isRowVectorShape(sliceShape)) { return; } else { long[] requiredShape = ArrayUtil.removeIndex(shape(), 0); //no need to compare for scalar; primarily due to shapes either being [1] or length 0 if (put.isScalar()) return; if (isVector() && put.isVector() && put.length() < length()) return; //edge case for column vectors if (Shape.isColumnVectorShape(sliceShape)) return; if (!Shape.shapeEquals(sliceShape, requiredShape) && !Shape.isRowVectorShape(requiredShape) && !Shape.isRowVectorShape(sliceShape)) throw new IllegalStateException(String.format("Invalid shape size of %s . Should have been %s ", Arrays.toString(sliceShape), Arrays.toString(requiredShape))); } }
/**
 * Reports whether an array's strides are exactly the default strides for its shape.
 *
 * <p>Arrays that fail {@code strideDescendingCAscendingF} are rejected immediately. Otherwise
 * the actual strides are compared with the Fortran-order defaults when the ordering is 'f',
 * and with the C-order defaults for any other ordering.
 *
 * @param input the array to inspect
 * @return {@code true} if the strides equal the default strides for the array's shape
 */
public static boolean hasDefaultStridesForShape(INDArray input) {
    if (!strideDescendingCAscendingF(input)) {
        return false;
    }
    char order = input.ordering();
    long[] expected = order == 'f'
            ? ArrayUtil.calcStridesFortran(input.shape())
            : ArrayUtil.calcStrides(input.shape());
    return Arrays.equals(input.stride(), expected);
}
}