/**
 * Convenience im2col overload taking kernel, stride and padding as arrays.
 *
 * <p>Only entries 0 and 1 of each array are read. By analogy with the
 * {@code (kh, kw, sy, sx, ph, pw)} parameter order of the scalar overloads, entry 0 is presumably
 * the height component and entry 1 the width component (confirm against callers). Delegates to
 * the scalar-argument overload, passing {@code 0} as the eighth argument and
 * {@code isSameMode = false}.
 *
 * @param img     input array; decompressed in place first if it is compressed
 * @param kernel  kernel sizes, read as {@code kernel[0]}, {@code kernel[1]}
 * @param stride  strides, read as {@code stride[0]}, {@code stride[1]}
 * @param padding paddings, read as {@code padding[0]}, {@code padding[1]}
 * @return whatever the delegated scalar-argument {@code im2col} overload returns
 */
public static INDArray im2col(INDArray img, int[] kernel, int[] stride, int[] padding) {
    Nd4j.getCompressor().autoDecompress(img);

    final int kernelH = kernel[0];
    final int kernelW = kernel[1];
    final int strideY = stride[0];
    final int strideX = stride[1];
    final int padH = padding[0];
    final int padW = padding[1];

    return im2col(img, kernelH, kernelW, strideY, strideX, padH, padW, 0, false);
}
/**
 * Overwrites, in place, every element of this array that satisfies {@code condition} with the
 * element at the corresponding position of {@code arr}.
 *
 * <p>This array is decompressed first if needed; the actual work is delegated to
 * {@code BooleanIndexing.replaceWhere}.
 *
 * @param arr       source of replacement values
 * @param condition predicate selecting which elements of this array get replaced
 * @return this array, after modification
 */
@Override
public INDArray replaceWhere(INDArray arr, Condition condition) {
    final INDArray target = this;
    Nd4j.getCompressor().autoDecompress(target);
    BooleanIndexing.replaceWhere(target, arr, condition);
    return target;
}
/**
 * Fills the sub-array selected by {@code indices} with a single value.
 *
 * <p>The selection is obtained via {@code get(indices)}, and each of its elements is written by
 * linear index. NOTE(review): this only affects this array if {@code get(indices)} returns a
 * view; confirm that holds for every index kind callers pass.
 *
 * @param indices selection into this array
 * @param element value written to every selected element (converted with {@code doubleValue()})
 * @return this array
 */
@Override
public INDArray put(INDArrayIndex[] indices, Number element) {
    Nd4j.getCompressor().autoDecompress(this);
    final INDArray selection = get(indices);
    int pos = 0;
    while (pos < selection.length()) {
        selection.putScalar(pos, element.doubleValue());
        pos++;
    }
    return this;
}
/**
 * Computes the Euclidean (L2) distance between this array and {@code other}.
 *
 * <p>Runs a {@code EuclideanDistance} op through the current executioner and returns its final
 * result as a double. Only this array is explicitly decompressed here; {@code other} is not.
 *
 * @param other the array to measure against
 * @return the Euclidean distance
 */
@Override
public double distance2(INDArray other) {
    Nd4j.getCompressor().autoDecompress(this);
    final Number result = Nd4j.getExecutioner()
            .execAndReturn(new EuclideanDistance(this, other))
            .getFinalResult();
    return result.doubleValue();
}
/**
 * Computes the Manhattan (L1) distance between this array and {@code other}.
 *
 * <p>Runs a {@code ManhattanDistance} op through the current executioner and returns its final
 * result as a double. Only this array is explicitly decompressed here; {@code other} is not.
 *
 * @param other the array to measure against
 * @return the Manhattan distance
 */
@Override
public double distance1(INDArray other) {
    Nd4j.getCompressor().autoDecompress(this);
    final Number result = Nd4j.getExecutioner()
            .execAndReturn(new ManhattanDistance(this, other))
            .getFinalResult();
    return result.doubleValue();
}
/**
 * Writes {@code value} at position ({@code row}, {@code col}).
 *
 * <p>Accepted for arrays of rank 2 or lower; the buffer offset comes from
 * {@code Shape.getOffsetUnsafe}.
 *
 * @param row   row index
 * @param col   column index
 * @param value value to store
 * @return this array
 * @throws IllegalStateException if this array's rank is greater than 2
 */
@Override
public INDArray putScalar(long row, long col, double value) {
    Nd4j.getCompressor().autoDecompress(this);
    autoProcessScalarCall();

    if (rank() <= 2) {
        data.put(Shape.getOffsetUnsafe(javaShapeInformation, row, col), value);
        return this;
    }
    throw new IllegalStateException("Cannot use putScalar(int,int,double) on a rank " + rank() + " INDArray");
}
/**
 * Writes {@code value} at position ({@code dim0}, {@code dim1}, {@code dim2}, {@code dim3}) of a
 * rank-4 array.
 *
 * <p>The buffer offset comes from {@code Shape.getOffsetUnsafe}.
 *
 * @param dim0  index along dimension 0
 * @param dim1  index along dimension 1
 * @param dim2  index along dimension 2
 * @param dim3  index along dimension 3
 * @param value value to store
 * @return this array
 * @throws IllegalStateException if this array is not rank 4
 */
@Override
public INDArray putScalar(long dim0, long dim1, long dim2, long dim3, double value) {
    Nd4j.getCompressor().autoDecompress(this);
    autoProcessScalarCall();

    if (rank() == 4) {
        final long target = Shape.getOffsetUnsafe(javaShapeInformation, dim0, dim1, dim2, dim3);
        data.put(target, value);
        return this;
    }
    throw new IllegalStateException(
            "Cannot use putScalar(int,int,int,int,double) on a rank " + rank() + " INDArray");
}
@Override public INDArray putScalar(long dim0, long dim1, long dim2, double value) { Nd4j.getCompressor().autoDecompress(this); autoProcessScalarCall(); if (rank() != 3) throw new IllegalStateException( "Cannot use putScalar(int,int,int,double) on a rank " + rank() + " INDArray"); long offset = 0; // Shape.getOffsetUnsafe(javaShapeInformation, dim0, dim1, dim2); long size_0 = javaShapeInformation[1]; long size_1 = javaShapeInformation[1 + 1]; long size_2 = javaShapeInformation[1 + 2]; if (size_0 != 1) offset += dim0 * javaShapeInformation[1 + 0 + 3]; if (size_1 != 1) offset += dim1 * javaShapeInformation[1 + 1 + 3]; if (size_2 != 1) offset += dim2 * javaShapeInformation[1 + 2 + 3]; data.put(offset, value); return this; }
/**
 * Replaces, in place, each element with {@code 1} if it satisfies {@code condition} and {@code 0}
 * otherwise.
 *
 * <p>Elements are visited by linear index. NOTE(review): the index is an {@code int}, so arrays
 * longer than {@code Integer.MAX_VALUE} are not fully handled.
 *
 * @param condition predicate applied to each element's double value
 * @return this array
 */
@Override
public INDArray condi(Condition condition) {
    Nd4j.getCompressor().autoDecompress(this);
    for (int idx = 0; idx < length(); idx++) {
        final double current = getDouble(idx);
        if (condition.apply(current)) {
            putScalar(idx, 1);
        } else {
            putScalar(idx, 0);
        }
    }
    return this;
}
protected void checkForCompression(Op op) { // check for INT datatype arrays interceptIntDataType(op); if (op.x() != null && op.x().isCompressed()) Nd4j.getCompressor().decompressi(op.x()); if (op.y() != null && op.y().isCompressed()) Nd4j.getCompressor().decompressi(op.y()); if (op.z() != null && op.z().isCompressed()) Nd4j.getCompressor().decompressi(op.z()); }
/**
 * Writes values from {@code put} wherever a comparison against {@code comp} satisfies
 * {@code condition}.
 *
 * <p>A {@code MatchConditionTransform} over (this, comp, condition) is executed to build a mask
 * in its {@code z()} array. The mask and {@code put} are then handed to
 * {@code putWhereWithMask}.
 *
 * @param comp      array compared against this one
 * @param put       source of values to write
 * @param condition condition evaluated by the match transform
 * @return the result of {@code putWhereWithMask}
 */
@Override
public INDArray putWhere(INDArray comp, INDArray put, Condition condition) {
    Nd4j.getCompressor().autoDecompress(this);

    final MatchConditionTransform maskOp = new MatchConditionTransform(this, comp, condition);
    Nd4j.getExecutioner().exec(maskOp);

    final INDArray mask = maskOp.z();
    return putWhereWithMask(mask, put);
}
/**
 * Returns column {@code c} of this array.
 *
 * <ul>
 *   <li>Column vector: {@code c == 0} returns this array itself; {@code c > 0} is rejected.</li>
 *   <li>Row vector: returns {@code Nd4j.scalar(getDouble(c))}, i.e. a scalar built from a copied
 *       value (not a view).</li>
 *   <li>Otherwise: returns {@code get(NDArrayIndex.all(), NDArrayIndex.point(c))}.</li>
 * </ul>
 *
 * @param c the column index
 * @return the requested column
 * @throws IllegalArgumentException if this is a column vector and {@code c > 0}
 */
@Override
public INDArray getColumn(long c) {
    Nd4j.getCompressor().autoDecompress(this);

    if (isColumnVector()) {
        if (c == 0)
            return this;
        if (c > 0)
            // Previously reported as "Illegal index for row": this is a column lookup.
            throw new IllegalArgumentException(
                    "Illegal index for column: " + c + " (a column vector only has column 0)");
    }

    if (isRowVector())
        return Nd4j.scalar(getDouble(c));

    return get(NDArrayIndex.all(), NDArrayIndex.point(c));
}
/**
 * Returns a copy of this array.
 *
 * <p>If this array is compressed and its ordering matches {@code Nd4j.order()}, the compressed
 * data buffer is duplicated as-is and the copy is marked compressed. Otherwise this array is
 * decompressed in place and copied with {@code Shape.toOffsetZeroCopy}.
 *
 * @return the duplicate
 */
@Override
public INDArray dup() {
    WorkspaceUtils.assertValidArray(this, "Cannot duplicate INDArray");

    final boolean copyCompressed = this.isCompressed() && this.ordering() == Nd4j.order().charValue();
    if (!copyCompressed) {
        Nd4j.getCompressor().autoDecompress(this);
        return Shape.toOffsetZeroCopy(this);
    }

    final INDArray copy = Nd4j.createArrayFromShapeBuffer(data().dup(), this.shapeInfoDataBuffer());
    copy.markAsCompressed(true);
    return copy;
}
/**
 * Writes {@code toPut} into column {@code column} of this array.
 *
 * <p>If this array is a column vector and {@code toPut} is a vector, the whole array is assigned
 * from {@code toPut} (the {@code column} argument is not consulted). Otherwise this delegates to
 * {@code put(all(), point(column))}. NOTE(review): earlier docs said this throws for non-matrix
 * arrays; no such check is visible here, so presumably {@code put} enforces it. Confirm.
 *
 * @param column the column index to write
 * @param toPut  the values to write
 * @return the result of {@code assign} or {@code put}
 */
@Override
public INDArray putColumn(int column, INDArray toPut) {
    Nd4j.getCompressor().autoDecompress(this);

    final boolean assignWhole = isColumnVector() && toPut.isVector();
    if (assignWhole)
        return assign(toPut);

    final INDArrayIndex[] target = {NDArrayIndex.all(), NDArrayIndex.point(column)};
    return put(target, toPut);
}
/**
 * Returns the element at linear index {@code i} as a double.
 *
 * <p>Index 0 is read directly from the data buffer. Any other index is converted to a
 * per-dimension subscript (C-order or F-order, matching {@code ordering()}) and read through
 * {@code getDouble(long...)}.
 *
 * @param i linear index, {@code 0 <= i < length()}
 * @return the element value
 * @throws IllegalArgumentException if {@code i} is negative or {@code >= length()}
 */
@Override
public double getDouble(long i) {
    Nd4j.getCompressor().autoDecompress(this);

    // Negative indices were previously passed straight to ind2sub, yielding meaningless subscripts.
    if (i < 0) {
        throw new IllegalArgumentException("Unable to get negative linear index " + i);
    }
    if (i >= length()) {
        throw new IllegalArgumentException("Unable to get linear index >= " + length());
    }

    autoProcessScalarCall();

    if (i == 0)
        return data().getDouble(i);

    long[] dimensions = ordering() == 'c' ? Shape.ind2subC(this, i) : Shape.ind2sub(this, i);
    Shape.assertShapeLessThan(dimensions, shape());
    return getDouble(dimensions);
}
public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dh, int dw, boolean isSameMode) { Nd4j.getCompressor().autoDecompress(img); //Input: NCHW format // FIXME: int cast int outH = outputSize((int) img.size(2), kh, sy, ph, dh, isSameMode); int outW = outputSize((int) img.size(3), kw, sx, pw, dw, isSameMode); //[miniBatch,depth,kH,kW,outH,outW] INDArray out = Nd4j.create(new long[]{img.size(0), img.size(1), kh, kw, outH, outW}, 'c'); return im2col(img, kh, kw, sy, sx, ph, pw, dh, dw, isSameMode, out); }
/** * Flattens the array for linear indexing * * @return the flattened version of this array */ @Override public INDArray ravel(char ordering) { Nd4j.getCompressor().autoDecompress(this); if (length() >= Integer.MAX_VALUE) throw new IllegalArgumentException("Length can not be >= Integer.MAX_VALUE"); INDArray ret = create(new int[] {1, (int) length()}, ordering); NDArrayIndex index = new NDArrayIndex(this.shape()); for (int i = 0; i < length(); i++) { // FIXME: LONG double val = getDouble((int) index.next()); ret.putScalar(new int[] {0, i}, val); } return ret; }
/**
 * Returns a copy of this array with the requested ordering.
 *
 * <p>If this array is compressed and already has ordering {@code order}, the compressed data
 * buffer is duplicated as-is and the copy is marked compressed. Otherwise this array is
 * decompressed in place and copied with {@code Shape.toOffsetZeroCopy(this, order)}.
 *
 * @param order desired ordering of the copy
 * @return the duplicate
 */
@Override
public INDArray dup(char order) {
    WorkspaceUtils.assertValidArray(this, "Cannot duplicate INDArray");

    final boolean copyCompressed = this.isCompressed() && this.ordering() == order;
    if (!copyCompressed) {
        Nd4j.getCompressor().autoDecompress(this);
        return Shape.toOffsetZeroCopy(this, order);
    }

    final INDArray copy = Nd4j.createArrayFromShapeBuffer(data().dup(), this.shapeInfoDataBuffer());
    copy.markAsCompressed(true);
    return copy;
}
@Override public INDArray put(INDArrayIndex[] indices, INDArray element) { Nd4j.getCompressor().autoDecompress(this); if (indices[0] instanceof SpecifiedIndex && element.isVector()) { indices[0].reset(); int cnt = 0; while (indices[0].hasNext()) { long idx = indices[0].next(); // FIXME: LONG putScalar((int) idx, element.getDouble(cnt)); cnt++; } return this; } else { return get(indices).assign(element); } }
/**
 * Replicates and tiles this array into a 2-D result.
 *
 * <p>See numpy's {@code matlib.repmat}:
 * https://github.com/numpy/numpy/blob/master/numpy/matlib.py#L310-L358
 *
 * <p>Only {@code shape[0]} (row repetitions) and {@code shape[1]} (column repetitions) are read.
 * The result is reshaped to {@code [rows() * shape[0], columns() * shape[1]]}.
 *
 * @param shape repetition counts; at least two entries
 * @return the tiled array, of shape {@code [rows() * shape[0], columns() * shape[1]]}
 * @throws IllegalArgumentException if {@code shape} is null or has fewer than two entries
 */
@Override
public INDArray repmat(int[] shape) {
    if (shape == null || shape.length < 2)
        throw new IllegalArgumentException(
                "repmat requires a shape with at least 2 entries (row and column repetitions)");

    Nd4j.getCompressor().autoDecompress(this);

    // Widen to long before multiplying so the product cannot overflow int arithmetic
    // (harmless if rows()/columns() already return long).
    long rows = (long) rows() * shape[0];
    long cols = (long) columns() * shape[1];

    INDArray ret = reshape(1, length()).repeat(0, shape[0]).reshape(rows, columns()).repeat(0, shape[1]);
    return ret.reshape(rows, cols);
}