/**
 * Copies the complex vector {@code x} into {@code y}.
 *
 * Dispatches to {@code zcopy} when the buffer of {@code x} reports
 * {@code DataBuffer.Type.DOUBLE}; any other type goes to {@code ccopy}.
 * Increments come from {@link BlasBufferUtil#getBlasStride}.
 *
 * @param x the source vector
 * @param y the destination vector
 */
@Override
public void copy(IComplexNDArray x, IComplexNDArray y) {
    if (x.data().dataType() != DataBuffer.Type.DOUBLE) {
        ccopy(x.length(), x, BlasBufferUtil.getBlasStride(x),
                y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    zcopy(x.length(), x, BlasBufferUtil.getBlasStride(x),
            y, BlasBufferUtil.getBlasStride(y));
}
/**
 * Swaps the contents of the complex vectors {@code x} and {@code y}.
 *
 * Uses {@code zswap} when the buffer of {@code x} reports
 * {@code DataBuffer.Type.DOUBLE}, and {@code cswap} for any other type.
 *
 * @param x the first vector
 * @param y the second vector
 */
@Override
public void swap(IComplexNDArray x, IComplexNDArray y) {
    if (x.data().dataType() != DataBuffer.Type.DOUBLE) {
        cswap(x.length(), x, BlasBufferUtil.getBlasStride(x),
                y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    zswap(x.length(), x, BlasBufferUtil.getBlasStride(x),
            y, BlasBufferUtil.getBlasStride(y));
}
/**
 * Computes a vector-scalar product and adds the result to a vector
 * (complex variant).
 *
 * {@code alpha} is converted with {@code asDouble()} for the {@code zaxpy}
 * path (buffer type DOUBLE) and with {@code asFloat()} for the
 * {@code caxpy} path (every other buffer type).
 *
 * @param n     the element count handed to the BLAS routine
 * @param alpha the complex scalar multiplier
 * @param x     the vector that is scaled
 * @param y     the vector that is accumulated into
 */
@Override
public void axpy(long n, IComplexNumber alpha, IComplexNDArray x, IComplexNDArray y) {
    if (x.data().dataType() != DataBuffer.Type.DOUBLE) {
        caxpy(n, alpha.asFloat(), x, BlasBufferUtil.getBlasStride(x),
                y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    zaxpy(n, alpha.asDouble(), x, BlasBufferUtil.getBlasStride(x),
            y, BlasBufferUtil.getBlasStride(y));
}
/**
 * Builds a diagnostic string with the shape, stride, offset and BLAS
 * stride of the given array.
 *
 * @param arr the array to describe
 * @return text of the form
 *         {@code <shape> and stride <stride> and offset <offset> and blas stride of <n>}
 */
private String shapeInfo(INDArray arr) {
    StringBuilder description = new StringBuilder();
    description.append(Arrays.toString(arr.shape()));
    description.append(" and stride ").append(Arrays.toString(arr.stride()));
    description.append(" and offset ").append(arr.offset());
    description.append(" and blas stride of ").append(BlasBufferUtil.getBlasStride(arr));
    return description.toString();
}
/**
 * Finds the element of a complex vector that has the largest absolute
 * value.
 *
 * Delegates to {@code izamax} when the buffer type is DOUBLE and to
 * {@code icamax} otherwise.
 *
 * @param arr the vector to search
 * @return the value returned by the underlying BLAS routine
 */
@Override
public int iamax(IComplexNDArray arr) {
    if (arr.data().dataType() != DataBuffer.Type.DOUBLE) {
        return icamax(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    } else {
        return izamax(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    }
}
/**
 * Computes the product of a complex vector and a complex scalar.
 *
 * {@code alpha} is converted with {@code asDouble()} for {@code zscal}
 * (buffer type DOUBLE) and with {@code asFloat()} for {@code cscal}
 * (every other buffer type).
 *
 * @param N     the element count handed to the BLAS routine
 * @param alpha the complex scalar
 * @param X     the vector passed to the BLAS routine
 */
@Override
public void scal(long N, IComplexNumber alpha, IComplexNDArray X) {
    if (X.data().dataType() != DataBuffer.Type.DOUBLE) {
        cscal(N, alpha.asFloat(), X, BlasBufferUtil.getBlasStride(X));
        return;
    }
    zscal(N, alpha.asDouble(), X, BlasBufferUtil.getBlasStride(X));
}
/**
 * Returns the stride through a vector relative to the ordering of the
 * array. This is meant for the incX/incY parameters in BLAS.
 *
 * For FORTRAN ordering the result of {@code getBlasStride(arr)} is
 * returned. Otherwise the stride along dimension 1 is used, halved for
 * {@link IComplexNDArray} instances.
 *
 * @param arr the array to get the stride for
 * @return the stride with respect to the ordering of the given array
 */
public static int getStrideForOrdering(INDArray arr) {
    if (arr.ordering() == NDArrayFactory.FORTRAN) {
        return getBlasStride(arr);
    }
    final int strideAlongDimOne = arr.stride(1);
    return (arr instanceof IComplexNDArray) ? strideAlongDimOne / 2 : strideAlongDimOne;
}
/**
 * Computes a vector-scalar product and adds the result to a vector.
 *
 * When the executioner's profiling mode is ALL the call is first reported
 * to the {@code OpProfiler}. A sparse {@code x} combined with a non-sparse
 * {@code y} is forwarded to the sparse BLAS wrapper. Otherwise the routine
 * is picked from the buffer type of {@code x}: {@code daxpy} for DOUBLE,
 * {@code saxpy} for FLOAT and {@code haxpy} for anything else.
 *
 * @param n     the element count handed to the BLAS routine
 * @param alpha the scalar multiplier (narrowed to float on the non-double paths)
 * @param x     the vector that is scaled
 * @param y     the vector that is accumulated into
 */
@Override
public void axpy(long n, double alpha, INDArray x, INDArray y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    if (x.isSparse() && !y.isSparse()) {
        Nd4j.getSparseBlasWrapper().level1().axpy(n, alpha, x, y);
        return;
    }

    final DataBuffer.Type dtype = x.data().dataType();
    if (dtype == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, x, y);
        daxpy(n, alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    if (dtype == DataBuffer.Type.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, x, y);
        saxpy(n, (float) alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, x, y);
    haxpy(n, (float) alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
}
/**
 * Performs rotation of points in the plane.
 *
 * When the executioner's profiling mode is ALL the call is first reported
 * to the {@code OpProfiler}. A sparse {@code X} combined with a non-sparse
 * {@code Y} is forwarded to the sparse BLAS wrapper. Otherwise {@code drot}
 * is used when the buffer of {@code X} is DOUBLE and {@code srot} (with
 * {@code c} and {@code s} narrowed to float) for every other type.
 *
 * @param N the element count handed to the BLAS routine
 * @param X the first vector
 * @param Y the second vector
 * @param c the first rotation coefficient passed to the BLAS routine
 * @param s the second rotation coefficient passed to the BLAS routine
 */
@Override
public void rot(long N, INDArray X, INDArray Y, double c, double s) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, X, Y);
    }

    if (X.isSparse() && !Y.isSparse()) {
        Nd4j.getSparseBlasWrapper().level1().rot(N, X, Y, c, s);
    } else if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, X, Y);
        // incY must be Y's own BLAS stride; X's stride was previously passed
        // here, which walks Y incorrectly whenever the two strides differ.
        drot(N, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y), c, s);
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, X, Y);
        // Same stride fix as the double-precision branch.
        srot(N, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y), (float) c,
                (float) s);
    }
}
/**
 * Copies the vector {@code x} into {@code y}.
 *
 * When the executioner's profiling mode is ALL the call is first reported
 * to the {@code OpProfiler}. If either argument is sparse the sparse BLAS
 * wrapper handles the copy. Otherwise {@code dcopy} is used when the
 * buffer of {@code x} is DOUBLE and {@code scopy} for every other type.
 *
 * @param x the source vector
 * @param y the destination vector
 */
@Override
public void copy(INDArray x, INDArray y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    if (x.isSparse() || y.isSparse()) {
        Nd4j.getSparseBlasWrapper().level1().copy(x, y);
    } else if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, x, y);
        dcopy(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, x, y);
        scopy(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    }
}
/**
 * Swaps a vector with another vector.
 *
 * When the executioner's profiling mode is ALL the call is first reported
 * to the {@code OpProfiler}. If either argument is sparse the sparse BLAS
 * wrapper handles the swap. Otherwise {@code dswap} is used when the
 * buffer of {@code x} is DOUBLE and {@code sswap} for every other type.
 *
 * @param x the first vector
 * @param y the second vector
 */
@Override
public void swap(INDArray x, INDArray y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    if (x.isSparse() || y.isSparse()) {
        Nd4j.getSparseBlasWrapper().level1().swap(x, y);
    } else if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, x, y);
        dswap(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, x, y);
        sswap(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    }
}
return ddot(n, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y)); } else if (X.data().dataType() == DataBuffer.Type.FLOAT) { DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, X, Y); return sdot(n, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y)); } else { DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, X, Y); return hdot(n, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y));
/**
 * Computes the sum of magnitudes of all vector elements.
 *
 * Sparse arrays are forwarded to the sparse BLAS wrapper before anything
 * else. For dense arrays the call is reported to the {@code OpProfiler}
 * when the profiling mode is ALL, then the routine is picked from the
 * buffer type: {@code dasum} for DOUBLE, {@code sasum} for FLOAT and
 * {@code hasum} for anything else.
 *
 * @param arr the vector to reduce
 * @return the value returned by the selected routine
 */
@Override
public double asum(INDArray arr) {
    if (arr.isSparse()) {
        return Nd4j.getSparseBlasWrapper().level1().asum(arr);
    }

    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, arr);
    }

    final DataBuffer.Type dtype = arr.data().dataType();
    if (dtype == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, arr);
        return dasum(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    }
    if (dtype == DataBuffer.Type.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, arr);
        return sasum(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    }
    DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, arr);
    return hasum(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
}
/**
 * Finds the element of a vector that has the largest absolute value.
 *
 * Sparse arrays are forwarded to the sparse BLAS wrapper. For dense
 * arrays the call is reported to the {@code OpProfiler} when the
 * profiling mode is ALL, then {@code idamax} is used for DOUBLE buffers
 * and {@code isamax} for every other type.
 *
 * @param arr the vector to search
 * @return the value returned by the selected routine
 */
@Override
public int iamax(INDArray arr) {
    if (arr.isSparse()) {
        return Nd4j.getSparseBlasWrapper().level1().iamax(arr);
    }

    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, arr);
    }

    if (arr.data().dataType() != DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, arr);
        return isamax(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    }
    DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, arr);
    return idamax(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
}
/**
 * Computes the product of a vector and a scalar.
 *
 * When the executioner's profiling mode is ALL the call is first reported
 * to the {@code OpProfiler}. Sparse arrays are forwarded to the sparse
 * BLAS wrapper. Dense arrays use {@code dscal} for DOUBLE buffers,
 * {@code sscal} for FLOAT buffers, and a {@code ScalarMultiplication} op
 * run through the executioner for HALF buffers. Any other buffer type
 * results in no action.
 *
 * @param N     the element count handed to the BLAS routine
 * @param alpha the scalar (narrowed to float on the FLOAT path)
 * @param X     the vector passed to the selected routine
 */
@Override
public void scal(long N, double alpha, INDArray X) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, X);
    }

    if (X.isSparse()) {
        Nd4j.getSparseBlasWrapper().level1().scal(N, alpha, X);
        return;
    }

    final DataBuffer.Type dtype = X.data().dataType();
    if (dtype == DataBuffer.Type.DOUBLE) {
        dscal(N, alpha, X, BlasBufferUtil.getBlasStride(X));
    } else if (dtype == DataBuffer.Type.FLOAT) {
        sscal(N, (float) alpha, X, BlasBufferUtil.getBlasStride(X));
    } else if (dtype == DataBuffer.Type.HALF) {
        Nd4j.getExecutioner().exec(new ScalarMultiplication(X, alpha));
    }
}
/** * computes the Euclidean norm of a vector. * * @param arr * @return */ @Override public double nrm2(INDArray arr) { if (arr.isSparse()) { return Nd4j.getSparseBlasWrapper().level1().nrm2(arr); } if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, arr); if (arr.isSparse()) { return Nd4j.getSparseBlasWrapper().level1().nrm2(arr); } if (arr.data().dataType() == DataBuffer.Type.DOUBLE) { DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, arr); return dnrm2(arr.length(), arr, BlasBufferUtil.getBlasStride(arr)); } else { DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, arr); return snrm2(arr.length(), arr, BlasBufferUtil.getBlasStride(arr)); } // TODO: add nrm2 for half, as call to appropriate NativeOp<HALF> }
/**
 * Swaps the contents of the complex vectors {@code x} and {@code y}.
 *
 * Uses {@code zswap} when the buffer of {@code x} reports
 * {@code DataBuffer.Type.DOUBLE}, and {@code cswap} for any other type.
 *
 * @param x the first vector
 * @param y the second vector
 */
@Override
public void swap(IComplexNDArray x, IComplexNDArray y) {
    final boolean doubleBuffer = x.data().dataType() == DataBuffer.Type.DOUBLE;
    if (doubleBuffer) {
        zswap(x.length(), x, BlasBufferUtil.getBlasStride(x),
                y, BlasBufferUtil.getBlasStride(y));
    } else {
        cswap(x.length(), x, BlasBufferUtil.getBlasStride(x),
                y, BlasBufferUtil.getBlasStride(y));
    }
}
/**
 * Copies the complex vector {@code x} into {@code y}.
 *
 * Dispatches to {@code zcopy} when the buffer of {@code x} reports
 * {@code DataBuffer.Type.DOUBLE}; any other type goes to {@code ccopy}.
 *
 * @param x the source vector
 * @param y the destination vector
 */
@Override
public void copy(IComplexNDArray x, IComplexNDArray y) {
    final boolean doubleBuffer = x.data().dataType() == DataBuffer.Type.DOUBLE;
    if (doubleBuffer) {
        zcopy(x.length(), x, BlasBufferUtil.getBlasStride(x),
                y, BlasBufferUtil.getBlasStride(y));
    } else {
        ccopy(x.length(), x, BlasBufferUtil.getBlasStride(x),
                y, BlasBufferUtil.getBlasStride(y));
    }
}
/**
 * Builds a diagnostic string with the shape, stride, offset and BLAS
 * stride of the given array.
 *
 * @param arr the array to describe
 * @return text of the form
 *         {@code <shape> and stride <stride> and offset <offset> and blas stride of <n>}
 */
private String shapeInfo(INDArray arr) {
    StringBuilder description = new StringBuilder();
    description.append(Arrays.toString(arr.shape()));
    description.append(" and stride ").append(Arrays.toString(arr.stride()));
    description.append(" and offset ").append(arr.offset());
    description.append(" and blas stride of ").append(BlasBufferUtil.getBlasStride(arr));
    return description.toString();
}
/**
 * Finds the element of a complex vector that has the largest absolute
 * value.
 *
 * Delegates to {@code izamax} when the buffer type is DOUBLE and to
 * {@code icamax} otherwise.
 *
 * @param arr the vector to search
 * @return the value returned by the underlying BLAS routine
 */
@Override
public int iamax(IComplexNDArray arr) {
    if (arr.data().dataType() != DataBuffer.Type.DOUBLE) {
        return icamax(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    } else {
        return izamax(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    }
}