/**
 * Counts the consecutive unit-length (size 1) dimensions at the end of this array's shape.
 * For example, shape [3, 1, 4, 1, 1] has 2 trailing ones and shape [1, 1] has 2.
 *
 * @return the number of trailing size-1 dimensions
 */
@Override
public int getTrailingOnes() {
    int numTrailingOnes = 0;
    // Walk back from the last dimension (dimension 0 included) and stop at the first
    // dimension whose size is not 1: unit dimensions before it are not "trailing".
    for (int i = rank() - 1; i >= 0 && size(i) == 1; i--) {
        numTrailingOnes++;
    }
    return numTrailingOnes;
}
/**
 * Counts the consecutive unit-length (size 1) dimensions at the start of this array's shape.
 * For example, shape [1, 1, 3, 1] has 2 leading ones.
 *
 * @return the number of leading size-1 dimensions
 */
@Override
public int getLeadingOnes() {
    int numLeadingOnes = 0;
    // Stop at the first dimension whose size is not 1: unit dimensions after it are not "leading".
    for (int i = 0; i < rank() && size(i) == 1; i++) {
        numLeadingOnes++;
    }
    return numLeadingOnes;
}
/**
 * Returns how many slices this array has.
 *
 * <p>A row vector reports its total {@code length()}; every other array reports the size of
 * its first dimension ({@code size(0)}).
 *
 * @return the slice count of this array
 */
@Override
public long slices() {
    return isRowVector() ? length() : size(0);
}
/**
 * Deprecated linear-index helper.
 *
 * <p>Starts from {@code i}, adds {@code i * stride(j)} for every dimension {@code j} except the
 * last whose size is not 1, then adds the array's offset.
 * NOTE(review): the accumulation formula itself looks questionable; confirm before relying on it.
 *
 * @param i the element index
 * @return the computed buffer index, including the array offset
 */
@Override
@Deprecated
public long linearIndex(long i) {
    long idx = i;
    for (int j = 0; j < Shape.rank(javaShapeInformation) - 1; j++) {
        // Skip unit-length dimensions. The check must use the dimension index j; it previously
        // used size((int) i), treating the element index as a dimension index.
        if (size(j) == 1)
            continue;
        idx += i * stride(j);
    }
    return Shape.offset(javaShapeInformation) + idx;
}
if (dimension == 0 && isVector() || isRowVector()) return 1; if (size(dimension) == 1 && !isVector()) { for (int i = dimension; i < rank(); i++) { if (size(i) != 1) return vectorsAlongDimension(i); } else if (size(0) == 1 && !isVector()) { int realDimension = rank() - getLeadingOnes(); long length = length(); if (length / size(realDimension) >= Integer.MAX_VALUE) throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE"); return length / size(realDimension); if (length / size(Shape.rank(javaShapeInformation) - 1) >= Integer.MAX_VALUE) throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE"); return (int) (length / size(Shape.rank(javaShapeInformation) - 1)); if (length / size(dimension) >= Integer.MAX_VALUE) throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE"); return length / size(dimension);
/** * Number of columns (shape[1]), throws an exception when * called when not 2d * * @return the number of columns in the array (only 2d) */ @Override public int columns() { // FIXME: int cast if (isMatrix()) return (int) size(1); else if (Shape.isColumnVectorShape(shape())) { return 1; } else if (Shape.isRowVectorShape(shape())) { return (int) length(); } throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid"); }
/** * Inserts the element at the specified index * * @param indices the indices to insert into * @param element a scalar ndarray * @return a scalar ndarray of the element at this index */ @Override public INDArray put(int[] indices, INDArray element) { Nd4j.getCompressor().autoDecompress(this); if (!element.isScalar()) throw new IllegalArgumentException("Unable to insert anything but a scalar"); if (isRowVector() && indices[0] == 0 && indices.length == 2) { int ix = 0; //Shape.offset(javaShapeInformation); for (int i = 1; i < indices.length; i++) ix += indices[i] * stride(i); if (ix >= data.length()) throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); data.put(ix, element.getDouble(0)); } else { int ix = 0; //Shape.offset(javaShapeInformation); for (int i = 0; i < indices.length; i++) if (size(i) != 1) ix += indices[i] * stride(i); if (ix >= data.length()) throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); data.put(ix, element.getDouble(0)); } return this; }
/** * Returns the number of rows * in the array (only 2d) throws an exception when * called when not 2d * * @return the number of rows in the matrix */ @Override public int rows() { // FIXME: if (isMatrix()) return (int) size(0); else if (Shape.isRowVectorShape(shape())) { return 1; } else if (Shape.isColumnVectorShape(shape())) { return (int) length(); } throw new IllegalStateException("Rank is " + rank() + " rows() call is not valid"); }
/**
 * Returns the {@code slice}-th sub-array taken along {@code dimension}.
 *
 * <p>The result is {@code get(...)} with a point index at {@code dimension} and {@code all()}
 * everywhere else. Negative values count back from the end:
 * <ul>
 *   <li>a negative {@code dimension} is offset by {@code rank()};</li>
 *   <li>a negative {@code slice} is offset by the size of {@code dimension}.</li>
 * </ul>
 *
 * @param slice     index along {@code dimension} (negative counts from the end)
 * @param dimension the dimension to slice along (negative counts from the last dimension)
 * @return the selected slice; for a 0-d array only slice 0 is allowed and yields a scalar
 * @throws IllegalArgumentException if {@code slice} is out of range, or if the array is 0-d and
 *         {@code slice} is not 0
 */
@Override
public INDArray slice(long slice, int dimension) {
    Nd4j.getCompressor().autoDecompress(this);

    // Accept negative dimensions, consistent with vectorAlongDimension and repeat.
    if (dimension < 0)
        dimension += rank();

    long slices = size(dimension);
    if (slice >= slices)
        throw new IllegalArgumentException("Illegal slice " + slice);

    if (Shape.rank(javaShapeInformation) == 0) {
        if (slice == 0)
            return createScalarForIndex(slice, true);
        else
            throw new IllegalArgumentException("Can't slice a 0-d NDArray");
    }

    // A negative slice counts back from the end of *this dimension*, so wrap by its extent
    // (not by rank()), then reject anything that is still out of range.
    long requested = slice;
    if (slice < 0)
        slice += slices;
    if (slice < 0)
        throw new IllegalArgumentException("Illegal slice " + requested);

    INDArrayIndex[] indexes = new INDArrayIndex[rank()];
    for (int i = 0; i < indexes.length; i++) {
        indexes[i] = (i == dimension) ? NDArrayIndex.point(slice) : NDArrayIndex.all();
    }
    return get(indexes);
}
} else { if (i < shape().length) retShape[i] = Math.max(shape[i], size(i)); else retShape[i] = shape[i]; if (i < rank() && size(i) == 1) broadCastDimensions.add(i); else nonBroadCastDimensions.add(i); if (i < shape().length) retShape[i] = Math.max(shape[i], size(i)); else retShape[i] = shape[i]; for(int i = 0; i < shape.length; i++) { if(i < rank()) { if(size(i) == 1) repeat[i] = (int) shape[i]; else {
/** * Get the vector along a particular dimension * * @param index the index of the vector to get * @param dimension the dimension to get the vector from * @return the vector along a particular dimension */ @Override public INDArray vectorAlongDimension(int index, int dimension) { if (dimension < 0) dimension = Shape.rank(javaShapeInformation) + dimension; //return the whole thing if (dimension == Shape.rank(javaShapeInformation) - 1 && size(dimension) == 1 && rank() > 2 || rank() > 2 && dimension == 0 && size(dimension) == 1) { return this; } INDArray ret = tensorAlongDimension(index, dimension); if (isMatrix() && ret.isVector() && dimension == 1 && !ret.isRowVector()) return ret.reshape(ArrayUtil.reverseCopy(ret.shape())); else if (isMatrix() && ret.isVector() && dimension == 0 && !ret.isColumnVector()) return ret.reshape(ArrayUtil.reverseCopy(ret.shape())); return ret; }
@Override public INDArray repeat(int dimension, long... repeats) { Nd4j.getCompressor().autoDecompress(this); if (dimension < 0) dimension += rank(); if (repeats.length < rank()) { if (dimension > 0) repeats = Longs.concat(ArrayUtil.nTimes((long) rank() - repeats.length, 1), repeats); //append rather than prepend for dimension == 0 else repeats = Longs.concat(repeats, ArrayUtil.nTimes((long) rank() - repeats.length, 1)); } long[] newShape = new long[rank()]; for (int i = 0; i < newShape.length; i++) newShape[i] = size(i) * repeats[i]; INDArray ret = Nd4j.create(newShape); //number of times to repeat each value long repeatDelta = ArrayUtil.prod(newShape) / length(); for (int i = 0; i < tensorssAlongDimension(dimension); i++) { INDArray thisTensor = tensorAlongDimension(i, dimension); INDArray retTensor = ret.tensorAlongDimension(i, dimension); int retIdx = 0; for (int k = 0; k < thisTensor.length(); k++) { for (int j = 0; j < repeatDelta; j++) { retTensor.putScalar(retIdx++, thisTensor.getDouble(k)); } } } return ret; }
if (!columnVector.isColumnVector() || this.size(0) != columnVector.size(0) || columnVector.length() <= 1) { throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) + ", column vector shape =" + Arrays.toString(columnVector.shape()) + ")");
if (!rowVector.isRowVector() || this.rank() > 1 && rowVector.rank() > 1 && this.size(1) != rowVector.size(1) || rowVector.length() <= 1) { throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) + ", row vector shape =" + Arrays.toString(rowVector.shape()) + ")");
/**
 * Counts the consecutive unit-length (size 1) dimensions at the start of this array's shape.
 * For example, shape [1, 1, 3, 1] has 2 leading ones.
 *
 * @return the number of leading size-1 dimensions
 */
@Override
public int getLeadingOnes() {
    int numLeadingOnes = 0;
    // Stop at the first dimension whose size is not 1: unit dimensions after it are not "leading".
    for (int i = 0; i < rank() && size(i) == 1; i++) {
        numLeadingOnes++;
    }
    return numLeadingOnes;
}
/**
 * Counts the consecutive unit-length (size 1) dimensions at the end of this array's shape.
 * For example, shape [3, 1, 4, 1, 1] has 2 trailing ones and shape [1, 1] has 2.
 *
 * @return the number of trailing size-1 dimensions
 */
@Override
public int getTrailingOnes() {
    int numTrailingOnes = 0;
    // Walk back from the last dimension (dimension 0 included) and stop at the first
    // dimension whose size is not 1: unit dimensions before it are not "trailing".
    for (int i = rank() - 1; i >= 0 && size(i) == 1; i--) {
        numTrailingOnes++;
    }
    return numTrailingOnes;
}
/**
 * Returns how many slices this array has.
 *
 * <p>A row vector reports its total {@code length()}; every other array reports the size of
 * its first dimension ({@code size(0)}).
 *
 * @return the slice count of this array
 */
@Override
public int slices() {
    return isRowVector() ? length() : size(0);
}
/**
 * Deprecated linear-index helper.
 *
 * <p>Starts from {@code i}, adds {@code i * stride(j)} for every dimension {@code j} except the
 * last whose size is not 1, then adds the array's offset.
 * NOTE(review): the accumulation formula itself looks questionable; confirm before relying on it.
 *
 * @param i the element index
 * @return the computed buffer index, including the array offset
 */
@Override
@Deprecated
public int linearIndex(int i) {
    int idx = i;
    for (int j = 0; j < Shape.rank(javaShapeInformation) - 1; j++) {
        // Skip unit-length dimensions. The check must use the dimension index j; it previously
        // used size(i), treating the element index as a dimension index.
        if (size(j) == 1)
            continue;
        idx += i * stride(j);
    }
    return Shape.offset(javaShapeInformation) + (idx);
}