@Override public INDArray put(int i, INDArray element) { // TODO remove and use basendarray method if (!element.isScalar()) throw new IllegalArgumentException("Element must be a scalar"); return putScalar(i, element.getDouble(0)); }
/**
 * Inserts the value of a scalar ndarray at the specified index.
 *
 * @param i       the index to insert into
 * @param element a scalar ndarray; its element 0 is the value stored
 * @return the result of the delegated {@code putScalar} call
 * @throws IllegalArgumentException if {@code element} is not a scalar
 */
@Override
public INDArray put(int i, INDArray element) {
    final boolean scalarElement = element.isScalar();
    if (!scalarElement) {
        throw new IllegalArgumentException("Element must be a scalar");
    }

    final double value = element.getDouble(0);
    return putScalar(i, value);
}
/**
 * Stores the value of a scalar ndarray at the position addressed by {@code indexes}.
 *
 * @param indexes one index per dimension; its length must equal this array's rank
 * @param element the ndarray supplying the value; must be a scalar
 * @return this array
 * @throws IllegalArgumentException if {@code element} is not a scalar
 * @throws IllegalStateException    if {@code indexes.length} differs from the rank
 */
@Override
public INDArray put(int[] indexes, INDArray element) {
    if (!element.isScalar()) {
        throw new IllegalArgumentException("Unable to insert anything but a scalar");
    }

    final int indexCount = indexes.length;
    if (indexCount != rank) {
        throw new IllegalStateException(
                "Cannot use putScalar with indexes length " + indexCount + " on rank " + rank);
    }

    addOrUpdate(ArrayUtil.toLongArray(indexes), element.getDouble(0));
    return this;
}
/**
 * Returns the next slice of the array being iterated and advances the cursor.
 *
 * @return the slice itself, or its single value (boxed) when the slice is a scalar
 */
@Override
public Object next() {
    final INDArray current = iterateOver.slice(i++);
    if (!current.isScalar()) {
        return current;
    }
    // Scalar slices are unwrapped to their numeric value.
    return current.getDouble(0);
}
/**
 * Checks that the given arrays can be concatenated along {@code dimension}.
 * <p>
 * When the first array is a scalar, every other array must be a scalar too.
 * Otherwise every array must have the same rank as the first one and the same
 * shape once {@code dimension} is removed from it.
 * <p>
 * The checks are performed with explicit exceptions rather than {@code assert},
 * so they stay active when the JVM runs without {@code -ea}.
 *
 * @param dimension the dimension the arrays are to be concatenated along
 * @param arrs      the arrays to validate; must contain at least one element
 * @throws IllegalArgumentException if the arrays are not compatible
 */
protected static void validateConcat(int dimension, INDArray... arrs) {
    if (arrs[0].isScalar()) {
        for (int i = 1; i < arrs.length; i++) {
            if (!arrs[i].isScalar()) {
                throw new IllegalArgumentException("All arrays must have same dimensions");
            }
        }
        return;
    }

    final long[] firstShape = arrs[0].shape();
    final long[] expected = ArrayUtil.removeIndex(firstShape, dimension);
    for (int i = 1; i < arrs.length; i++) {
        final long[] shape = arrs[i].shape();
        // Check the rank first so a rank mismatch is reported as such instead of
        // surfacing as a confusing shape comparison.
        if (shape.length != firstShape.length) {
            throw new IllegalArgumentException("All arrays must have same dimensions: array " + i
                    + " has rank " + shape.length + " but expected rank " + firstShape.length);
        }
        if (!Arrays.equals(expected, ArrayUtil.removeIndex(shape, dimension))) {
            throw new IllegalArgumentException("All arrays must have same dimensions: array " + i
                    + " has shape " + Arrays.toString(shape) + " which is incompatible with "
                    + Arrays.toString(firstShape) + " along dimension " + dimension);
        }
    }
}
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { val weightsName = nodeDef.getInput(1); val variable = initWith.getVariable(weightsName); val tmp = initWith.getArrForVarName(weightsName); // if second argument is scalar - we should provide array of same shape if (tmp != null) { if (tmp.isScalar()) { this.pow = tmp.getDouble(0); } } }
/**
 * Returns the scalar operand of this op.
 *
 * @return {@code scalarValue} when it is set; otherwise element 0 of {@code y()}
 *         when {@code y()} is a non-null scalar array; otherwise {@code scalarValue}
 *         (which is {@code null} at that point)
 */
@Override
public Number scalar() {
    if (scalarValue != null) {
        return scalarValue;
    }
    // No explicit value: fall back to a scalar y operand if one is available.
    if (y() != null && y().isScalar()) {
        return y().getDouble(0);
    }
    return scalarValue;
}
/** * Generate a linearly spaced vector * * @param lower upper bound * @param upper lower bound * @param num the step size * @return the linearly spaced vector */ @Override public INDArray linspace(int lower, int upper, int num) { double[] data = new double[num]; for (int i = 0; i < num; i++) { double t = (double) i / (num - 1); data[i] = lower * (1 - t) + t * upper; } //edge case for scalars INDArray ret = Nd4j.create(data.length); if (ret.isScalar()) return ret; for (int i = 0; i < ret.length(); i++) ret.putScalar(i, data[i]); return ret; }
/**
 * Inserts the element at the specified index.
 * <p>
 * A complex element is stored as-is; a real element is stored as a complex
 * number with a zero imaginary part.
 *
 * @param i       the index to insert into
 * @param element a scalar ndarray
 * @return this array
 * @throws IllegalArgumentException if {@code element} is {@code null} or not a scalar
 */
@Override
public IComplexNDArray put(int i, INDArray element) {
    if (element == null) {
        throw new IllegalArgumentException("Unable to insert null element");
    }
    // Explicit check instead of `assert`, which is skipped unless the JVM runs with -ea.
    if (!element.isScalar()) {
        throw new IllegalArgumentException("Unable to insert non scalar element");
    }

    if (element instanceof IComplexNDArray) {
        IComplexNDArray n1 = (IComplexNDArray) element;
        IComplexNumber n = n1.getComplex(0);
        put(i, n);
    } else {
        putScalar(i, Nd4j.createDouble(element.getDouble(0), 0.0));
    }
    return this;
}
@Override public boolean equalsWithEps(Object o, double eps) { if (o == null) return false; if (!(o instanceof INDArray)) return false; INDArray n = (INDArray) o; if (this.lengthLong() != n.lengthLong()) return false; if (isScalar() && n.isScalar()) { // TODO } else if (isVector && n.isVector()) { // TODO } if (!Arrays.equals(this.shape(), n.shape())) return false; // TODO return false; }
/**
 * Performs a copying matrix multiplication of this array with {@code other}.
 *
 * @param other the right-hand operand
 * @return a scalar holding the dot product when the {@code rows() x other.columns()}
 *         result is a scalar; otherwise the result of {@code mmuli} into a newly
 *         created 'f'-ordered array
 */
@Override
public INDArray mmul(INDArray other) {
    final long[] resultShape = new long[] {rows(), other.columns()};
    final INDArray target = createUninitialized(resultShape, 'f');

    if (!target.isScalar()) {
        return mmuli(other, target);
    }
    // Scalar result: compute it directly as a dot product.
    return Nd4j.scalar(Nd4j.getBlasWrapper().dot(this, other));
}
protected void assertSlice(INDArray put, long slice) { assert slice <= slices() : "Invalid slice specified " + slice; long[] sliceShape = put.shape(); if (Shape.isRowVectorShape(sliceShape)) { return; } else { long[] requiredShape = ArrayUtil.removeIndex(shape(), 0); //no need to compare for scalar; primarily due to shapes either being [1] or length 0 if (put.isScalar()) return; if (isVector() && put.isVector() && put.length() < length()) return; //edge case for column vectors if (Shape.isColumnVectorShape(sliceShape)) return; if (!Shape.shapeEquals(sliceShape, requiredShape) && !Shape.isRowVectorShape(requiredShape) && !Shape.isRowVectorShape(sliceShape)) throw new IllegalStateException(String.format("Invalid shape size of %s . Should have been %s ", Arrays.toString(sliceShape), Arrays.toString(requiredShape))); } }
public BaseNDArray(List<INDArray> slices, long[] shape, long[] stride, char ordering) { DataBuffer ret = slices.get(0).data().dataType() == (DataBuffer.Type.FLOAT) ? Nd4j.createBuffer(new float[ArrayUtil.prod(shape)]) : Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]); this.data = ret; setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering)); init(shape, stride); // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); if (slices.get(0).isScalar()) { for (int i = 0; i < length(); i++) { putScalar(i, slices.get(i).getDouble(0)); } } else { for (int i = 0; i < slices(); i++) { putSlice(i, slices.get(i)); } } }
/**
 * Copies the real components of this array into {@code arr}.
 *
 * @param arr the array to copy the real components to
 */
protected void copyRealTo(INDArray arr) {
    final INDArray targetLinear = arr.linearView();
    final IComplexNDArray sourceLinear = linearView();

    if (arr.isScalar()) {
        arr.putScalar(0, getReal(0));
        return;
    }

    final long count = targetLinear.length();
    for (int idx = 0; idx < count; idx++) {
        arr.putScalar(idx, sourceLinear.getReal(idx));
    }
}
/**
 * Copies the imaginary components of this array into the given ndarray.
 *
 * @param arr the array to copy the imaginary components to
 */
protected void copyImagTo(INDArray arr) {
    INDArray linear = arr.linearView();
    IComplexNDArray thisLinear = linearView();

    if (arr.isScalar()) {
        // The scalar case must copy the imaginary part as well; it previously
        // copied getReal(0), a leftover from copyRealTo.
        arr.putScalar(0, getImag(0));
    } else {
        for (int i = 0; i < linear.length(); i++) {
            arr.putScalar(i, thisLinear.getImag(i));
        }
    }
}
/** * Perform a copy matrix multiplication * * @param other the other matrix to perform matrix multiply with * @return the result of the matrix multiplication */ @Override public INDArray mmul(INDArray other) { // FIXME: for 1D case, we probably want vector output here? long[] shape = {rows(), other.rank() == 1 ? 1 : other.columns()}; INDArray result = createUninitialized(shape, 'f'); if (result.isScalar()) return Nd4j.scalar(Nd4j.getBlasWrapper().dot(this, other)); return mmuli(other, result); }
/**
 * Throws if {@code z} contains NaN values, but only when the executioner's
 * profiling mode is {@code NAN_PANIC} or {@code ANY_PANIC}; otherwise does nothing.
 * <p>
 * Non-scalar arrays are counted with a {@code MatchCondition}; a scalar is
 * inspected directly as a double (for {@code DOUBLE} buffers) or as a float.
 *
 * @param z the array to inspect
 * @throws ND4JIllegalStateException if at least one NaN value is found
 */
public static void checkForNaN(INDArray z) {
    final OpExecutioner.ProfilingMode mode = Nd4j.getExecutioner().getProfilingMode();
    final boolean panicOnNaN = mode == OpExecutioner.ProfilingMode.NAN_PANIC
            || mode == OpExecutioner.ProfilingMode.ANY_PANIC;
    if (!panicOnNaN) {
        return;
    }

    int match = 0;
    if (z.isScalar()) {
        final boolean nan = z.data().dataType() == DataBuffer.Type.DOUBLE
                ? Double.isNaN(z.getDouble(0))
                : Float.isNaN(z.getFloat(0));
        if (nan) {
            match = 1;
        }
    } else {
        MatchCondition condition = new MatchCondition(z, Conditions.isNan());
        match = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
    }

    if (match > 0) {
        throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): ");
    }
}
/**
 * Throws if {@code z} contains infinite values, but only when the executioner's
 * profiling mode is {@code INF_PANIC} or {@code ANY_PANIC}; otherwise does nothing.
 * <p>
 * Non-scalar arrays are counted with a {@code MatchCondition}; a scalar is
 * inspected directly as a double (for {@code DOUBLE} buffers) or as a float.
 *
 * @param z the array to inspect
 * @throws ND4JIllegalStateException if at least one infinite value is found
 */
public static void checkForInf(INDArray z) {
    final OpExecutioner.ProfilingMode mode = Nd4j.getExecutioner().getProfilingMode();
    final boolean panicOnInf = mode == OpExecutioner.ProfilingMode.INF_PANIC
            || mode == OpExecutioner.ProfilingMode.ANY_PANIC;
    if (!panicOnInf) {
        return;
    }

    int match = 0;
    if (z.isScalar()) {
        final boolean infinite = z.data().dataType() == DataBuffer.Type.DOUBLE
                ? Double.isInfinite(z.getDouble(0))
                : Float.isInfinite(z.getFloat(0));
        if (infinite) {
            match = 1;
        }
    } else {
        MatchCondition condition = new MatchCondition(z, Conditions.isInfinite());
        match = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
    }

    if (match > 0) {
        throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " Inf value(s)");
    }
}
/**
 * Element-wise multiplication of this array by {@code other}, written into {@code result}.
 * <p>
 * Both {@code other} and {@code result} are cast to {@code IComplexNDArray}.
 * When {@code other} is a scalar the work is delegated to the complex-number overload.
 *
 * @param other  the second ndarray to multiply
 * @param result the result ndarray
 * @return the result of the multiplication
 */
@Override
public IComplexNDArray muli(INDArray other, INDArray result) {
    final IComplexNDArray complexOther = (IComplexNDArray) other;
    final IComplexNDArray complexResult = (IComplexNDArray) result;

    final IComplexNDArray thisFlat = linearView();
    final IComplexNDArray otherFlat = complexOther.linearView();
    final IComplexNDArray resultFlat = complexResult.linearView();

    if (other.isScalar()) {
        return muli(complexOther.getComplex(0), result);
    }

    // Scratch numbers passed to getComplex on every iteration.
    final IComplexNumber lhsHolder = Nd4j.createComplexNumber(0, 0);
    final IComplexNumber rhsHolder = Nd4j.createComplexNumber(0, 0);
    for (int idx = 0; idx < length(); idx++) {
        resultFlat.putScalar(idx,
                thisFlat.getComplex(idx, lhsHolder).muli(otherFlat.getComplex(idx, rhsHolder)));
    }

    return complexResult;
}
/**
 * Element-wise division of this array by {@code other}, written into {@code result}.
 * <p>
 * Both {@code other} and {@code result} are cast to {@code IComplexNDArray}.
 * When {@code other} is a scalar the work is delegated to the complex-number overload.
 *
 * @param other  the second ndarray to divide by
 * @param result the result ndarray
 * @return the result of the division
 */
@Override
public IComplexNDArray divi(INDArray other, INDArray result) {
    final IComplexNDArray complexOther = (IComplexNDArray) other;
    final IComplexNDArray complexResult = (IComplexNDArray) result;

    final IComplexNDArray thisFlat = linearView();
    final IComplexNDArray otherFlat = complexOther.linearView();
    final IComplexNDArray resultFlat = complexResult.linearView();

    if (other.isScalar()) {
        return divi(complexOther.getComplex(0), result);
    }

    // Scratch numbers passed to getComplex on every iteration.
    final IComplexNumber lhsHolder = Nd4j.createComplexNumber(0, 0);
    final IComplexNumber rhsHolder = Nd4j.createComplexNumber(0, 0);
    for (int idx = 0; idx < length(); idx++) {
        resultFlat.putScalar(idx,
                thisFlat.getComplex(idx, lhsHolder).divi(otherFlat.getComplex(idx, rhsHolder)));
    }

    return complexResult;
}