/**
 * Reports whether this array is a vector or a scalar.
 *
 * @return {@code true} when {@link #isVector()} holds, otherwise the result of {@link #isScalar()}
 */
@Override
public boolean isVectorOrScalar() {
    if (isVector()) {
        return true;
    }
    return isScalar();
}
/**
 * Recursively collects the vectors that make up this array into {@code list}.
 *
 * <p>If this array is itself a vector it is appended to {@code list} directly. Otherwise every
 * slice (index {@code 0} to {@code slices() - 1}, in order) is asked to contribute its own
 * vectors, so the traversal descends until vector-shaped pieces are reached.
 *
 * @param list the list that receives the collected vectors; it is mutated in place
 */
@Override
public void sliceVectors(List<INDArray> list) {
    if (isVector()) {
        list.add(this);
        return;
    }
    for (int i = 0; i < slices(); i++) {
        slice(i).sliceVectors(list);
    }
}
/**
 * Produces a single-element array for linear index {@code i}.
 *
 * <p>For vectors this delegates to {@code getScalar(i)}. For any other array it builds a
 * {@code [1, 1]} array with strides {@code [1, 1]} over this array's {@code data()} buffer,
 * passing {@code i} as the offset argument to {@code Nd4j.create}.
 *
 * @param i           linear element index
 * @param applyOffset not read by this implementation (NOTE(review): confirm whether callers
 *                    expect it to affect the offset)
 * @return a single-element array for index {@code i}
 */
protected INDArray createScalarForIndex(long i, boolean applyOffset) {
    if (!isVector()) {
        return Nd4j.create(data(), new long[] {1, 1}, new long[] {1, 1}, i);
    }
    return getScalar(i);
}
/**
 * Gathers whole columns of this matrix (or vector) into a new array.
 *
 * <p>For a vector the work is delegated to {@code Nd4j.pullRows} along dimension 0, using this
 * array's ordering. For a matrix a new {@code rows() x cindices.length} array is allocated and
 * column {@code k} of the result is filled from column {@code cindices[k]} of this array.
 *
 * @param cindices indices of the columns to fetch, in output order
 * @return the gathered columns
 * @throws IllegalArgumentException if this array is neither a matrix nor a vector
 */
@Override
public INDArray getColumns(int... cindices) {
    if (!(isMatrix() || isVector())) {
        throw new IllegalArgumentException("Unable to get columns from a non matrix or vector");
    }
    if (isVector()) {
        return Nd4j.pullRows(this, 0, cindices, this.ordering());
    }
    INDArray result = Nd4j.create(rows(), cindices.length);
    int target = 0;
    for (int column : cindices) {
        result.putColumn(target++, getColumn(column));
    }
    return result;
}
/**
 * Copies the elements of this vector into a new {@code double[]}.
 *
 * <p>The values are read from the data buffer of {@code dup()}, not from this array's own
 * buffer.
 *
 * @return the elements as doubles
 * @throws ND4JIllegalStateException if this array is not a vector
 */
@Override
public double[] toDoubleVector() {
    if (isVector()) {
        return dup().data().asDouble();
    }
    throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector!");
}
/**
 * Copies the elements of this vector into a new {@code float[]}.
 *
 * <p>The values are read from the data buffer of {@code dup()}, not from this array's own
 * buffer.
 *
 * @return the elements as floats
 * @throws ND4JIllegalStateException if this array is not a vector
 */
@Override
public float[] toFloatVector() {
    if (isVector()) {
        return dup().data().asFloat();
    }
    throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector!");
}
/**
 * Copies the elements of this vector into a new {@code int[]}.
 *
 * <p>The values are read from the data buffer of {@code dup()}, not from this array's own
 * buffer.
 *
 * @return the elements as ints
 * @throws ND4JIllegalStateException if this array is not a vector
 */
@Override
public int[] toIntVector() {
    if (isVector()) {
        return dup().data().asInt();
    }
    throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector!");
}
/**
 * Copies the elements of this vector into a new {@code long[]}.
 *
 * <p>The values are read from the data buffer of {@code dup()}, not from this array's own
 * buffer.
 *
 * @return the elements as longs
 * @throws ND4JIllegalStateException if this array is not a vector
 */
@Override
public long[] toLongVector() {
    if (isVector()) {
        return dup().data().asLong();
    }
    throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector!");
}
/**
 * Gathers whole rows of this matrix (or vector) into a new array.
 *
 * <p>For a vector the work is delegated to {@code Nd4j.pullRows} along dimension 1. For a matrix
 * a new {@code rindices.length x columns()} array is allocated and row {@code i} of the result
 * is filled via {@code putRow} from row {@code rindices[i]} of this array.
 *
 * @param rindices indices of the rows to fetch, in output order
 * @return the gathered rows
 * @throws IllegalArgumentException if this array is neither a matrix nor a vector
 */
@Override
public INDArray getRows(int[] rindices) {
    // Ensure the array is decompressed before any element access.
    Nd4j.getCompressor().autoDecompress(this);
    if (!isMatrix() && !isVector()) {
        // Message previously said "columns"; this method extracts rows.
        throw new IllegalArgumentException("Unable to get rows from a non matrix or vector");
    }
    if (isVector()) {
        return Nd4j.pullRows(this, 1, rindices);
    }
    INDArray ret = Nd4j.create(rindices.length, columns());
    for (int i = 0; i < rindices.length; i++) {
        ret.putRow(i, getRow(rindices[i]));
    }
    return ret;
}
protected void assertSlice(INDArray put, long slice) { assert slice <= slices() : "Invalid slice specified " + slice; long[] sliceShape = put.shape(); if (Shape.isRowVectorShape(sliceShape)) { return; } else { long[] requiredShape = ArrayUtil.removeIndex(shape(), 0); //no need to compare for scalar; primarily due to shapes either being [1] or length 0 if (put.isScalar()) return; if (isVector() && put.isVector() && put.length() < length()) return; //edge case for column vectors if (Shape.isColumnVectorShape(sliceShape)) return; if (!Shape.shapeEquals(sliceShape, requiredShape) && !Shape.isRowVectorShape(requiredShape) && !Shape.isRowVectorShape(sliceShape)) throw new IllegalStateException(String.format("Invalid shape size of %s . Should have been %s ", Arrays.toString(sliceShape), Arrays.toString(requiredShape))); } }
} else if (isVector() && n.isVector()) {
public INDArray cumsumi(int dimension) { if (isVector()) { double s = 0.0; for (int i = 0; i < length(); i++) {
put(0, put.getScalar(0)); return this; } else if (isVector()) { assert put.isScalar() || put.isVector() && put .length() == length() : "Invalid dimension on insertion. Can only insert scalars input vectors";
if (isVector()) { switch (operation) { case 'a':
return doRowWise(rowVector.dup(), operation); if (isVector()) { switch (operation) { case 'a':
/**
 * Recursively collects the vectors that make up this array into {@code list}.
 *
 * <p>If this array is itself a vector it is appended to {@code list} directly. Otherwise every
 * slice (index {@code 0} to {@code slices() - 1}, in order) is asked to contribute its own
 * vectors, so the traversal descends until vector-shaped pieces are reached.
 *
 * @param list the list that receives the collected vectors; it is mutated in place
 */
@Override
public void sliceVectors(List<INDArray> list) {
    if (isVector()) {
        list.add(this);
        return;
    }
    for (int i = 0; i < slices(); i++) {
        slice(i).sliceVectors(list);
    }
}
/**
 * Gathers whole rows of this matrix (or vector) into a new array.
 *
 * <p>For a vector the work is delegated to {@code Nd4j.pullRows} along dimension 1. For a matrix
 * a new {@code rindices.length x columns()} array is allocated and row {@code i} of the result
 * is filled via {@code putRow} from row {@code rindices[i]} of this array.
 *
 * @param rindices indices of the rows to fetch, in output order
 * @return the gathered rows
 * @throws IllegalArgumentException if this array is neither a matrix nor a vector
 */
@Override
public INDArray getRows(int[] rindices) {
    // Ensure the array is decompressed before any element access.
    Nd4j.getCompressor().autoDecompress(this);
    if (!isMatrix() && !isVector()) {
        // Message previously said "columns"; this method extracts rows.
        throw new IllegalArgumentException("Unable to get rows from a non matrix or vector");
    }
    if (isVector()) {
        return Nd4j.pullRows(this, 1, rindices);
    }
    INDArray ret = Nd4j.create(rindices.length, columns());
    for (int i = 0; i < rindices.length; i++) {
        ret.putRow(i, getRow(rindices[i]));
    }
    return ret;
}
protected void assertSlice(INDArray put, int slice) { assert slice <= slices() : "Invalid slice specified " + slice; int[] sliceShape = put.shape(); if (Shape.isRowVectorShape(sliceShape)) { return; } else { int[] requiredShape = ArrayUtil.removeIndex(shape(), 0); //no need to compare for scalar; primarily due to shapes either being [1] or length 0 if (put.isScalar()) return; if (isVector() && put.isVector() && put.length() < length()) return; //edge case for column vectors if (Shape.isColumnVectorShape(sliceShape)) return; if (!Shape.shapeEquals(sliceShape, requiredShape) && !Shape.isRowVectorShape(requiredShape) && !Shape.isRowVectorShape(sliceShape)) throw new IllegalStateException(String.format("Invalid shape size of %s . Should have been %s ", Arrays.toString(sliceShape), Arrays.toString(requiredShape))); } }