/**
 * Checks an index against the size of the dimension it addresses.
 * <p>
 * An interval or point index whose {@code current()} position is not below {@code size}
 * is rejected, but only when {@code size > 1}. An interval index whose {@code end()}
 * exceeds {@code size} is rebuilt with the same begin and stride and an end of {@code size};
 * any other index is returned unchanged.
 *
 * @param size  the size of the dimension being indexed
 * @param index the index to check
 * @return the original index, or a clamped replacement interval
 * @throws IllegalArgumentException if an interval/point index starts at or beyond {@code size}
 */
protected static INDArrayIndex validate(long size, INDArrayIndex index) {
    final boolean positional = index instanceof IntervalIndex || index instanceof PointIndex;
    if (positional && size <= index.current() && size > 1) {
        throw new IllegalArgumentException("NDArrayIndex is out of range. Beginning index: "
                        + index.current() + " must be less than its size: " + size);
    }
    // Nothing to clamp unless this is an interval that runs past the dimension.
    if (!(index instanceof IntervalIndex) || size >= index.end()) {
        return index;
    }
    final IntervalIndex overrun = (IntervalIndex) index;
    return NDArrayIndex.interval(overrun.begin, overrun.stride(), size);
}
/**
 * Writes {@code element} into the region selected by {@code indices}.
 * <p>
 * Fast path: when the first index is a {@link SpecifiedIndex} and {@code element} is a vector,
 * that index is rewound and each position it yields receives the next value of
 * {@code element} through {@code putScalar} (the position is narrowed to {@code int}).
 * Otherwise the call is delegated to {@code get(indices).assign(element)}.
 *
 * @param indices the indices selecting the target region
 * @param element the values to write
 * @return {@code this} on the fast path, otherwise the result of {@code assign}
 */
@Override
public INDArray put(INDArrayIndex[] indices, INDArray element) {
    final INDArrayIndex leading = indices[0];
    if (!(leading instanceof SpecifiedIndex) || !element.isVector()) {
        return get(indices).assign(element);
    }
    leading.reset();
    for (int source = 0; leading.hasNext(); source++) {
        long target = leading.next();
        // NOTE(review): narrowing cast kept as-is; positions beyond Integer.MAX_VALUE would wrap.
        putScalar((int) target, element.getDouble(source));
    }
    return this;
}
/**
 * Creates an {@link IntervalIndex} with the given stride and inclusivity, initialised to
 * the range {@code [begin, end]} via {@code init(begin, end)}.
 * <p>
 * Note: the {@code max} parameter is accepted but not used by this overload.
 *
 * @param begin     the first position of the interval
 * @param stride    the step between selected positions
 * @param end       the end position passed to {@code init}
 * @param max       unused
 * @param inclusive whether the interval is constructed as inclusive
 * @return the initialised interval index
 */
public static INDArrayIndex interval(long begin, long stride, long end, long max, boolean inclusive) {
    assert begin <= end : "Beginning index in range must be less than end";
    final INDArrayIndex range = new IntervalIndex(inclusive, stride);
    range.init(begin, end);
    return range;
}
// NOTE(review): this line appears to be a non-contiguous excerpt of a larger index-resolution
// routine (its braces and operators do not balance here), so it is not readable as one unit.
// Visible steps: a PointIndex whose current() is >= the dimension size (shape[i + 1] for a vector
// indexed by a single index, shape[i] otherwise) raises IllegalArgumentException; point indexes
// record their offset() and the array stride into pointOffsets/pointStrides; interval indexes
// record arr.stride * idx.stride(), their length() and offset() into the accum* lists; the final
// offset is either indexes[1].offset() or ArrayUtil.dotProductLong2(pointOffsets, pointStrides).
// NOTE(review): accumOffsets.add(idx.offset()) appears more than once here — confirm against the
// full source that offsets are not recorded twice for the same index.
if (idx instanceof PointIndex && (arr.isVector() && indexes.length == 1 ? idx.current() >= shape[i + 1] : idx.current() >= shape[i])) { throw new IllegalArgumentException( "INDArrayIndex[" + i + "] is out of bounds (value: " + idx.current() + ")"); pointOffsets.add(idx.offset()); pointStrides.add((long) arr.stride(strideIndex)); numPointIndexes++; || idx instanceof SpecifiedIndex) { if (idx instanceof IntervalIndex) { accumStrides.add(arr.stride(strideIndex) * idx.stride()); intervalStrides.add(idx.stride()); numIntervals++; accumShape.add(idx.length()); accumOffsets.add(idx.offset()); } else accumOffsets.add(idx.offset()); accumOffsets.add(idx.offset()); this.offset = indexes[1].offset(); else this.offset = ArrayUtil.dotProductLong2(pointOffsets, pointStrides);
// NOTE(review): non-contiguous excerpt of a 2-D indexing fast path (braces do not balance here).
// Visible steps: the offset is indexes[0].offset() scaled by strides[1] or strides[0]; a PointIndex
// followed by an IntervalIndex takes offset from indexes[1] and yields shapes {1, indexes[1].length()}
// and strides {0, indexes[1].stride()}, returning true.
// NOTE(review): in the IntervalIndex-then-PointIndex branch, shapes[0] and strides[0] are read from
// indexes[1] (the PointIndex) rather than indexes[0] (the interval) — looks suspicious, confirm intent.
this.offset = indexes[0].offset() * strides[1]; else { this.offset = indexes[0].offset() * strides[0]; this.offset = indexes[0].offset() * strides[1]; else { this.offset = indexes[0].offset() * strides[0]; if (indexes[0] instanceof PointIndex) { if (indexes.length > 1 && indexes[1] instanceof IntervalIndex) { offset = indexes[1].offset(); this.shapes = new long[2]; shapes[0] = 1; shapes[1] = indexes[1].length(); this.strides = new long[2]; strides[0] = 0; strides[1] = indexes[1].stride(); this.offsets = new long[2]; return true; if (indexes.length > 1 && indexes[1] instanceof PointIndex) { if (indexes[0] instanceof IntervalIndex) { offset = indexes[0].offset(); this.shapes = new long[2]; shapes[1] = 1; shapes[0] = indexes[1].length(); this.strides = new long[2]; strides[1] = 0; strides[0] = indexes[1].stride();
/**
 * Produces one zero-based range per input index.
 * <p>
 * Each output entry is {@code NDArrayIndex.interval(0, length)}, where {@code length} is the
 * corresponding input index's {@code length()}. Positions and strides of the inputs are ignored;
 * only how many elements each one covers matters.
 *
 * @param indexes the indexes whose lengths define the ranges
 * @return a new array, the same size as {@code indexes}, of ranges starting at 0
 */
public static INDArrayIndex[] rangeOfLength(INDArrayIndex[] indexes) {
    final INDArrayIndex[] ranges = new INDArrayIndex[indexes.length];
    int slot = 0;
    for (INDArrayIndex source : indexes) {
        ranges[slot++] = NDArrayIndex.interval(0, source.length());
    }
    return ranges;
}
/**
 * Reports whether every given index has a strictly positive offset.
 *
 * @param indexes the indexes to inspect
 * @return {@code true} if each {@code offset()} is greater than zero (vacuously true for no
 *         indexes); {@code false} as soon as one is zero or negative
 */
private boolean allIndexGreatherThanZero(INDArrayIndex... indexes) {
    for (INDArrayIndex index : indexes) {
        // '<= 0' rather than '== 0': a negative offset must not satisfy a "greater than zero" check.
        if (index.offset() <= 0) {
            return false;
        }
    }
    return true;
}
/**
 * Reports whether at least one of the given indexes has a stride of exactly 1.
 * Stops inspecting indexes as soon as a match is found.
 *
 * @param indexes the indexes to inspect
 * @return {@code true} if some {@code stride()} equals 1, otherwise {@code false}
 */
private boolean anyHaveStrideOne(INDArrayIndex... indexes) {
    boolean found = false;
    for (int pos = 0; pos < indexes.length && !found; pos++) {
        found = indexes[pos].stride() == 1;
    }
    return found;
}
/**
 * Reconciles the number of indexes with the rank of {@code originalShape}.
 * <ul>
 *   <li>A single index against a vector shape is returned untouched.</li>
 *   <li>Too few indexes are first completed with {@code fillIn}.</li>
 *   <li>Too many indexes are truncated to the first {@code originalShape.length}.</li>
 *   <li>A matching count is returned as-is.</li>
 * </ul>
 * If {@code fillIn} still leaves fewer indexes than the rank, any index whose {@code end()}
 * reaches the dimension size, or which is an {@link NDArrayIndexAll}, is replaced in place by
 * {@code NDArrayIndex.interval(0, size - 1)}.
 *
 * @param originalShape the shape to adjust to
 * @param indexes       the indexes to adjust (may be modified in place on the last path)
 * @return the adjusted indexes
 */
public static INDArrayIndex[] adjustIndices(int[] originalShape, INDArrayIndex... indexes) {
    if (Shape.isVector(originalShape) && indexes.length == 1) {
        return indexes;
    }
    final int rank = originalShape.length;
    final INDArrayIndex[] working = indexes.length < rank ? fillIn(originalShape, indexes) : indexes;

    if (working.length > rank) {
        final INDArrayIndex[] truncated = new INDArrayIndex[rank];
        System.arraycopy(working, 0, truncated, 0, rank);
        return truncated;
    }
    if (working.length == rank) {
        return working;
    }

    // Only reached when fillIn returned fewer entries than the rank.
    for (int dim = 0; dim < working.length; dim++) {
        final INDArrayIndex candidate = working[dim];
        if (candidate.end() >= originalShape[dim] || candidate instanceof NDArrayIndexAll) {
            working[dim] = NDArrayIndex.interval(0, originalShape[dim] - 1);
        }
    }
    return working;
}
// NOTE(review): this line appears to be a non-contiguous excerpt of a larger index-resolution
// routine (its braces and operators do not balance here), so it is not readable as one unit.
// Visible steps: a PointIndex whose current() is >= the dimension size (shape[i + 1] for a vector
// indexed by a single index, shape[i] otherwise) raises IllegalArgumentException; point indexes
// record their offset() and the array stride into pointOffsets/pointStrides; interval indexes
// record arr.stride * idx.stride(), their length() and offset() into the accum* lists; the final
// offset is either indexes[1].offset() or ArrayUtil.dotProductLong2(pointOffsets, pointStrides).
// NOTE(review): accumOffsets.add(idx.offset()) appears more than once here — confirm against the
// full source that offsets are not recorded twice for the same index.
if (idx instanceof PointIndex && (arr.isVector() && indexes.length == 1 ? idx.current() >= shape[i + 1] : idx.current() >= shape[i])) { throw new IllegalArgumentException( "INDArrayIndex[" + i + "] is out of bounds (value: " + idx.current() + ")"); pointOffsets.add(idx.offset()); pointStrides.add((long) arr.stride(strideIndex)); numPointIndexes++; || idx instanceof SpecifiedIndex) { if (idx instanceof IntervalIndex) { accumStrides.add(arr.stride(strideIndex) * idx.stride()); intervalStrides.add(idx.stride()); numIntervals++; accumShape.add(idx.length()); accumOffsets.add(idx.offset()); } else accumOffsets.add(idx.offset()); accumOffsets.add(idx.offset()); this.offset = indexes[1].offset(); else this.offset = ArrayUtil.dotProductLong2(pointOffsets, pointStrides);
// NOTE(review): non-contiguous excerpt of a 2-D indexing fast path (braces do not balance here).
// Visible steps: the offset starts from indexes[0].offset(); a PointIndex followed by an
// IntervalIndex takes offset from indexes[1] and yields shapes {1, indexes[1].length()} and
// strides {0, indexes[1].stride()}, returning true. The trailing condition compares the number of
// specified indexes with arr.rank().
// NOTE(review): in the IntervalIndex-then-PointIndex branch, shapes[0] and strides[0] are read from
// indexes[1] (the PointIndex) rather than indexes[0] (the interval) — looks suspicious, confirm intent.
this.offset = indexes[0].offset(); this.offset = indexes[0].offset(); if (indexes[0] instanceof PointIndex) { if (indexes.length > 1 && indexes[1] instanceof IntervalIndex) { offset = indexes[1].offset(); this.shapes = new long[2]; shapes[0] = 1; shapes[1] = indexes[1].length(); this.strides = new long[2]; strides[0] = 0; strides[1] = indexes[1].stride(); this.offsets = new long[2]; return true; if (indexes.length > 1 && indexes[1] instanceof PointIndex) { if (indexes[0] instanceof IntervalIndex) { offset = indexes[0].offset(); this.shapes = new long[2]; shapes[1] = 1; shapes[0] = indexes[1].length(); this.strides = new long[2]; strides[1] = 0; strides[0] = indexes[1].stride(); this.offsets = new long[2]; return true; if (specifiedIndex.getIndexes().length >= arr.rank())
/** * Calculate the shape for the given set of indices. * <p/> * The shape is defined as (for each dimension) * the difference between the end index + 1 and * the begin index * * @param indices the indices to calculate the shape for * @return the shape for the given indices */ public static int[] shape(INDArrayIndex... indices) { int[] ret = new int[indices.length]; for (int i = 0; i < ret.length; i++) { // FIXME: LONG ret[i] = (int) indices[i].length(); } List<Integer> nonZeros = new ArrayList<>(); for (int i = 0; i < ret.length; i++) { if (ret[i] > 0) nonZeros.add(ret[i]); } return ArrayUtil.toArray(nonZeros); }
// NOTE(review): non-contiguous excerpt (braces do not balance here). Visible steps: each output
// offset ret[i] is either 0 or an index's offset(); offsets greater than zero are collected into
// nonZeros, and IllegalStateException is thrown if there are more of them than entries in shape.
// The final fragment reads offsets sequentially via shapeIndex.
ret[i] = 0; else { ret[i] = indices[i].offset(); List<Long> nonZeros = new ArrayList<>(); for (int i = 0; i < indices.length; i++) if (indices[i].offset() > 0) nonZeros.add(indices[i].offset()); if (nonZeros.size() > shape.length) throw new IllegalStateException("Non zeros greater than shape unable to continue"); ret[i] = 0; else { ret[i] = indices[shapeIndex++].offset();
/**
 * Reports whether at least one of the given indexes has a stride of exactly 1.
 * Stops inspecting indexes as soon as a match is found.
 *
 * @param indexes the indexes to inspect
 * @return {@code true} if some {@code stride()} equals 1, otherwise {@code false}
 */
private boolean anyHaveStrideOne(INDArrayIndex... indexes) {
    boolean found = false;
    for (int pos = 0; pos < indexes.length && !found; pos++) {
        found = indexes[pos].stride() == 1;
    }
    return found;
}
/**
 * Reconciles the number of indexes with the rank of {@code originalShape}.
 * <ul>
 *   <li>A single index against a vector shape is returned untouched.</li>
 *   <li>Too few indexes are first completed with {@code fillIn}.</li>
 *   <li>Too many indexes are truncated to the first {@code originalShape.length}.</li>
 *   <li>A matching count is returned as-is.</li>
 * </ul>
 * If {@code fillIn} still leaves fewer indexes than the rank, any index whose {@code end()}
 * reaches the dimension size, or which is an {@link NDArrayIndexAll}, is replaced in place by
 * {@code NDArrayIndex.interval(0, size - 1)}.
 *
 * @param originalShape the shape to adjust to
 * @param indexes       the indexes to adjust (may be modified in place on the last path)
 * @return the adjusted indexes
 */
public static INDArrayIndex[] adjustIndices(int[] originalShape, INDArrayIndex... indexes) {
    if (Shape.isVector(originalShape) && indexes.length == 1) {
        return indexes;
    }
    final int rank = originalShape.length;
    final INDArrayIndex[] working = indexes.length < rank ? fillIn(originalShape, indexes) : indexes;

    if (working.length > rank) {
        final INDArrayIndex[] truncated = new INDArrayIndex[rank];
        System.arraycopy(working, 0, truncated, 0, rank);
        return truncated;
    }
    if (working.length == rank) {
        return working;
    }

    // Only reached when fillIn returned fewer entries than the rank.
    for (int dim = 0; dim < working.length; dim++) {
        final INDArrayIndex candidate = working[dim];
        if (candidate.end() >= originalShape[dim] || candidate instanceof NDArrayIndexAll) {
            working[dim] = NDArrayIndex.interval(0, originalShape[dim] - 1);
        }
    }
    return working;
}
@Override public INDArray put(INDArrayIndex[] indices, INDArray element) { Nd4j.getCompressor().autoDecompress(this); if (indices[0] instanceof SpecifiedIndex && element.isVector()) { indices[0].reset(); int cnt = 0; while (indices[0].hasNext()) { long idx = indices[0].next(); // FIXME: LONG putScalar((int) idx, element.getDouble(cnt)); cnt++; } return this; } else { return get(indices).assign(element); } }
/**
 * Checks an index against the size of the dimension it addresses.
 * <p>
 * An interval or point index whose {@code current()} position is not below {@code size}
 * is rejected, but only when {@code size > 1}. An interval index whose {@code end()}
 * exceeds {@code size} is rebuilt with the same begin and stride and an end of {@code size};
 * any other index is returned unchanged.
 *
 * @param size  the size of the dimension being indexed
 * @param index the index to check
 * @return the original index, or a clamped replacement interval
 * @throws IllegalArgumentException if an interval/point index starts at or beyond {@code size}
 */
protected static INDArrayIndex validate(long size, INDArrayIndex index) {
    final boolean positional = index instanceof IntervalIndex || index instanceof PointIndex;
    if (positional && size <= index.current() && size > 1) {
        throw new IllegalArgumentException("NDArrayIndex is out of range. Beginning index: "
                        + index.current() + " must be less than its size: " + size);
    }
    // Nothing to clamp unless this is an interval that runs past the dimension.
    if (!(index instanceof IntervalIndex) || size >= index.end()) {
        return index;
    }
    final IntervalIndex overrun = (IntervalIndex) index;
    return NDArrayIndex.interval(overrun.begin, overrun.stride(), size);
}
// NOTE(review): excerpt from inside a loop (it ends in `continue`): records this index's length()
// as the next entry of accumShape, advances shapeIndex, and skips the rest of the iteration.
accumShape.add(idx.length()); shapeIndex++; continue;
/**
 * Creates an {@link IntervalIndex} with the given stride and inclusivity, initialised to
 * the range starting at {@code begin} and ending at {@code end} via {@code init(begin, end)}.
 *
 * @param begin     the first position of the interval
 * @param stride    the step between selected positions
 * @param end       the end position passed to {@code init}
 * @param inclusive whether the interval is constructed as inclusive
 * @return the initialised interval index
 */
public static INDArrayIndex interval(long begin, long stride, long end, boolean inclusive) {
    assert begin <= end : "Beginning index in range must be less than end";
    final INDArrayIndex range = new IntervalIndex(inclusive, stride);
    range.init(begin, end);
    return range;
}