/**
 * Validates each of the supplied arrays by passing it, in order, to the single-array
 * {@code assertDouble} overload (presumably a double-dtype check — confirm against that overload).
 *
 * @param d the arrays to validate; must not be {@code null}
 */
public static void assertDouble(INDArray... d) {
    for (int i = 0; i < d.length; i++) {
        assertDouble(d[i]);
    }
}
/**
 * Runs the single-array {@code assertDouble} check on every argument, stopping at the first
 * failure that overload reports.
 *
 * @param d arrays to check; must not be {@code null}
 */
public static void assertDouble(INDArray... d) {
    for (final INDArray array : d) {
        assertDouble(array);
    }
}
/**
 * Placeholder for a complex double-precision {@code gerc} operation; not implemented.
 *
 * <p>The operands are validated first, so a type mismatch surfaces from
 * {@code DataTypeValidation.assertDouble} before the unsupported-operation error.
 *
 * @param x first vector operand (validated only)
 * @param y second vector operand (validated only)
 * @param a matrix operand (validated only)
 * @param alpha scale factor (unused)
 * @return never returns normally
 * @throws UnsupportedOperationException always, once validation passes
 */
public static IComplexNDArray gerc(IComplexNDArray x, IComplexNDArray y, IComplexNDArray a,
        IComplexDouble alpha) {
    DataTypeValidation.assertDouble(x, y, a);

    // No implementation exists for this operation yet.
    throw new UnsupportedOperationException();
}
/**
 * Single-precision complex matrix-vector multiply ({@code gemv}); not implemented.
 *
 * <p>This is the {@link IComplexFloat} overload, so the operands are validated as float data.
 * Previously they were checked with {@code assertDouble}, which rejected every float operand this
 * overload is meant for.
 *
 * @param alpha scale factor for {@code a * x} (unused)
 * @param a matrix operand (validated only)
 * @param x vector operand (validated only)
 * @param beta scale factor for {@code y} (unused)
 * @param y output vector (validated only)
 * @return never returns normally
 * @throws UnsupportedOperationException always, once validation passes
 */
@Override
public IComplexNDArray gemv(IComplexFloat alpha, IComplexNDArray a, IComplexNDArray x,
        IComplexFloat beta, IComplexNDArray y) {
    // Float overload: operands must hold single-precision data.
    DataTypeValidation.assertFloat(a, x, y);
    throw new UnsupportedOperationException();
}
/**
 * Double-precision complex matrix-vector multiply ({@code gemv}), supported only when {@code y}
 * is a scalar.
 *
 * <p>In that case {@code dotc(a, x)} is written into element 0 of {@code y}. Every other shape is
 * rejected. NOTE(review): {@code alpha} and {@code beta} are ignored even in the scalar case —
 * confirm this is intended.
 *
 * @param alpha scale factor for {@code a * x} (unused)
 * @param a matrix operand
 * @param x vector operand
 * @param beta scale factor for {@code y} (unused)
 * @param y output; must be a scalar
 * @return the result of {@code y.putScalar(0, dotc(a, x))}
 * @throws UnsupportedOperationException if {@code y} is not a scalar
 */
@Override
public IComplexNDArray gemv(IComplexDouble alpha, IComplexNDArray a, IComplexNDArray x,
        IComplexDouble beta, IComplexNDArray y) {
    DataTypeValidation.assertDouble(a, x, y);

    if (!y.isScalar()) {
        throw new UnsupportedOperationException();
    }
    return y.putScalar(0, dotc(a, x));
}
/**
 * Scales {@code x} by {@code alpha} through the BLAS {@code dscal} routine.
 *
 * <p>The call receives {@code x}'s length, its buffer as {@code double[]}, its offset and its major
 * stride. NOTE(review): whether {@code asDouble()} exposes the backing array or a copy determines
 * whether {@code x} is really modified in place — confirm.
 *
 * @param alpha the scale factor
 * @param x the array to scale; must hold double data
 * @return {@code x}
 */
public static INDArray scal(double alpha, INDArray x) {
    DataTypeValidation.assertDouble(x);

    final BLAS blas = BLAS.getInstance();
    blas.dscal(x.length(), alpha, x.data().asDouble(), x.offset(), x.majorStride());

    return x;
}
/**
 * Complex double-precision axpy on the GPU via cuBLAS {@code Zaxpy}.
 *
 * <p>Following the cuBLAS axpy contract, {@code B} is updated as {@code da * A + B}. The result is
 * copied back to the host, and the device pointers are released afterwards.
 *
 * <p>Fixes relative to the previous version:
 * <ul>
 *   <li>The scale factor is passed at full double precision. It was previously narrowed through
 *       {@code floatValue()}.</li>
 *   <li>{@code B} is copied back to the host, so the caller sees the result.</li>
 *   <li>Both {@code CublasPointer}s are released instead of leaked.</li>
 * </ul>
 *
 * @param da complex scale factor applied to {@code A}
 * @param A source vector; must hold double data
 * @param B destination vector, updated in place; must hold double data
 */
public static void axpy(IComplexDouble da, IComplexNDArray A, IComplexNDArray B) {
    DataTypeValidation.assertDouble(A, B);
    CublasPointer aCPointer = new CublasPointer(A);
    CublasPointer bCPointer = new CublasPointer(B);
    sync();
    JCublas2.cublasZaxpy(
            ContextHolder.getInstance().getHandle(),
            A.length(),
            // Keep the double-complex scale factor at double precision.
            PointerUtil.getPointer(jcuda.cuDoubleComplex.cuCmplx(
                    da.realComponent().doubleValue(),
                    da.imaginaryComponent().doubleValue())),
            aCPointer.getDevicePointer(),
            A.majorStride(),
            bCPointer.getDevicePointer(),
            B.majorStride());
    sync();
    // Only B is written by axpy; bring it back to the host before freeing device memory.
    bCPointer.copyToHost();
    releaseCublasPointers(aCPointer, bCPointer);
}
/**
 * Scales {@code x} by {@code alpha}, picking the precision from {@code x}'s data buffer.
 *
 * <p>Float-backed arrays are delegated to the {@code scal(float, INDArray)} overload, with
 * {@code alpha} narrowed to float. Any other array must be double-backed and is scaled with BLAS
 * {@code dscal}.
 *
 * @param alpha the scale factor
 * @param x the array to scale
 * @return the scaled array ({@code x} itself on the double path)
 */
public static INDArray scal(double alpha, INDArray x) {
    if (x.data().dataType() != DataBuffer.Type.FLOAT) {
        DataTypeValidation.assertDouble(x);
        BLAS.getInstance().dscal(x.length(), alpha, x.data().asDouble(), x.offset(), x.majorStride());
        return x;
    }
    // Single-precision buffers go through the float overload.
    return scal((float) alpha, x);
}
public static INDArray ger(INDArray A, INDArray B, INDArray C, double alpha) { DataTypeValidation.assertDouble(A, B, C); // = alpha * A * transpose(B) + C BLAS.getInstance().dger( A.rows(), // m A.columns(),// n alpha, // alpha A.data().asDouble(), // d_A or x A.rows(), // incx B.data().asDouble(), // dB or y B.rows(), // incy C.data().asDouble(), // dC or A C.rows() // lda ); return C; }
/**
 * Scales {@code x} by {@code alpha} on the GPU with cuBLAS {@code Dscal}.
 *
 * <p>The steps are:
 * <ol>
 *   <li>Synchronize.</li>
 *   <li>Wrap {@code x} in a {@code CublasPointer} and run the kernel with {@code x}'s major
 *       stride.</li>
 *   <li>Synchronize again, copy the result back to the host, and release the pointer.</li>
 * </ol>
 *
 * @param alpha the scale factor
 * @param x the array to scale; must hold double data
 * @return {@code x}
 */
public static INDArray scal(double alpha, INDArray x) {
    DataTypeValidation.assertDouble(x);
    sync();

    CublasPointer deviceX = new CublasPointer(x);
    Pointer alphaOnHost = Pointer.to(new double[]{alpha});
    JCublas2.cublasDscal(ContextHolder.getInstance().getHandle(), x.length(), alphaOnHost,
            deviceX.getDevicePointer(), x.majorStride());

    sync();
    deviceX.copyToHost();
    releaseCublasPointers(deviceX);
    return x;
}
/** * * @param A * @param B * @param C * @param alpha * @return */ public static INDArray ger(INDArray A, INDArray B, INDArray C, double alpha) { DataTypeValidation.assertDouble(A, B, C); // = alpha * A * transpose(B) + C BLAS.getInstance().dger( A.rows(), // m A.columns(),// n alpha, // alpha A.data().asDouble(), // d_A or x A.rows(), // incx B.data().asDouble(), // dB or y B.rows(), // incy C.data().asDouble(), // dC or A C.rows() // lda ); return C; }
/**
 * Double-precision axpy ({@code y = alpha * x + y}) on the GPU via cuBLAS {@code Daxpy}.
 *
 * <p>Both vectors are passed with a stride of 1. NOTE(review): the arrays' own strides are not
 * consulted — confirm that {@code CublasPointer} handles non-contiguous views. After the kernel,
 * {@code y} is copied back to the host and both pointers are released.
 *
 * @param alpha the scale factor applied to {@code x}
 * @param x the source vector; must hold double data
 * @param y the destination vector, updated in place; must hold double data
 */
public static void axpy(double alpha, INDArray x, INDArray y) {
    DataTypeValidation.assertDouble(x, y);
    sync();

    CublasPointer deviceX = new CublasPointer(x);
    CublasPointer deviceY = new CublasPointer(y);
    Pointer alphaOnHost = Pointer.to(new double[]{alpha});
    JCublas2.cublasDaxpy(ContextHolder.getInstance().getHandle(), x.length(), alphaOnHost,
            deviceX.getDevicePointer(), 1,
            deviceY.getDevicePointer(), 1);

    sync();
    deviceY.copyToHost();
    releaseCublasPointers(deviceX, deviceY);
}
/**
 * Scales a complex double array by {@code alpha} on the GPU with cuBLAS {@code Zscal}.
 *
 * <p>The array is passed with a stride of 1. After the kernel, the result is copied back to the
 * host and the device pointer is released.
 *
 * @param alpha complex scale factor
 * @param x the array to scale; must hold double data
 * @return {@code x}
 */
public static IComplexNDArray scal(IComplexDouble alpha, IComplexNDArray x) {
    DataTypeValidation.assertDouble(x);
    sync();

    CublasPointer deviceX = new CublasPointer(x);
    double re = alpha.realComponent();
    double im = alpha.imaginaryComponent();
    JCublas2.cublasZscal(ContextHolder.getInstance().getHandle(), x.length(),
            PointerUtil.getPointer(jcuda.cuDoubleComplex.cuCmplx(re, im)),
            deviceX.getDevicePointer(), 1);

    sync();
    deviceX.copyToHost();
    releaseCublasPointers(deviceX);
    return x;
}
/**
 * Single-precision complex matrix-vector multiply via {@code NativeBlas.cgemv} (no transpose).
 *
 * <p>The operands are validated as float data, since their buffers are read with
 * {@code asFloat()}. Previously this overload asserted double data. That rejected real float
 * input, and any double input that passed was converted by {@code asFloat()}, so cgemv most likely
 * wrote into a temporary copy instead of {@code y}.
 *
 * @param alpha scale factor for {@code a * x}; must be a {@code ComplexFloat}
 * @param a the matrix operand; its row count is used as the leading dimension
 * @param x the vector operand
 * @param beta scale factor for {@code y}; must be a {@code ComplexFloat}
 * @param y the output vector
 * @return {@code y}
 */
@Override
public IComplexNDArray gemv(IComplexFloat alpha, IComplexNDArray a, IComplexNDArray x,
        IComplexFloat beta, IComplexNDArray y) {
    // Float overload: the buffers below are consumed as float[].
    DataTypeValidation.assertFloat(a, x, y);
    NativeBlas.cgemv(
            'N',
            a.rows(),
            a.columns(),
            (ComplexFloat) alpha,
            a.data().asFloat(),
            a.blasOffset(),
            a.rows(),
            x.data().asFloat(),
            x.offset(),
            x.secondaryStride(),
            (ComplexFloat) beta,
            y.data().asFloat(),
            y.blasOffset(),
            y.secondaryStride());
    return y;
}
/**
 * Double-precision GEMM on the GPU through cuBLAS {@code cublasDgemm_v2}.
 *
 * <p>Sequence:
 * <ol>
 *   <li>Flush the executioner queue.</li>
 *   <li>Obtain a {@code CudaContext} for C, A and B from the allocator's flow controller.</li>
 *   <li>Validate that the operands are double.</li>
 *   <li>Wrap each operand in a {@code CublasPointer} bound to that context.</li>
 *   <li>While holding the context's cuBLAS handle, bind the handle to the context's stream and
 *       issue the GEMM.</li>
 *   <li>Register the action with the allocator.</li>
 *   <li>Pass C to {@code OpExecutionerUtil.checkForAny} (presumably a NaN/Inf check — confirm).</li>
 * </ol>
 *
 * @param Order storage-order flag; NOTE(review): not referenced in this body — confirm callers
 *     already arrange column-major operands as cuBLAS expects
 * @param TransA transpose flag for A, mapped with {@code convertTranspose}
 * @param TransB transpose flag for B, mapped with {@code convertTranspose}
 * @param M GEMM dimension M
 * @param N GEMM dimension N
 * @param K GEMM dimension K
 * @param alpha scale factor for {@code op(A) * op(B)}
 * @param A left operand
 * @param lda leading dimension of A
 * @param B right operand
 * @param ldb leading dimension of B
 * @param beta scale factor for the existing contents of C
 * @param C output matrix
 * @param ldc leading dimension of C
 */
@Override
protected void dgemm(char Order, char TransA, char TransB, int M, int N, int K, double alpha,
        INDArray A, int lda, INDArray B, int ldb, double beta, INDArray C, int ldc) {
    //A = Shape.toOffsetZero(A);
    //B = Shape.toOffsetZero(B);
    // Warn when a DOUBLE gemm runs while the global default data type is something else.
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE)
        logger.warn("DOUBLE gemm called");

    // Flush any queued ops before acquiring the context for this call.
    Nd4j.getExecutioner().push();

    CudaContext ctx = allocator.getFlowController().prepareAction(C, A, B);

    // NOTE(review): validation runs after prepareAction; if it throws, registerAction below is
    // never reached — confirm that leaves no context/buffer state dangling.
    DataTypeValidation.assertDouble(A, B, C);

    CublasPointer cAPointer = new CublasPointer(A, ctx);
    CublasPointer cBPointer = new CublasPointer(B, ctx);
    CublasPointer cCPointer = new CublasPointer(C, ctx);

    cublasHandle_t handle = ctx.getHandle();
    // The handle is locked so the stream binding and the GEMM launch are not interleaved with
    // another thread's use of the same handle.
    synchronized (handle) {
        cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));

        cublasDgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
                new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda,
                (DoublePointer) cBPointer.getDevicePointer(), ldb, new DoublePointer(beta),
                (DoublePointer) cCPointer.getDevicePointer(), ldc);
    }

    allocator.registerAction(ctx, C, A, B);
    OpExecutionerUtil.checkForAny(C);
}
/**
 * Runs a double-precision matrix multiply on the GPU with cuBLAS {@code cublasDgemm_v2}.
 *
 * <p>Order of operations:
 * <ol>
 *   <li>Push pending executioner work.</li>
 *   <li>Prepare a {@code CudaContext} for C, A and B.</li>
 *   <li>Check that the operands are double.</li>
 *   <li>Wrap each operand in a context-bound {@code CublasPointer}.</li>
 *   <li>Under a lock on the cuBLAS handle, set the handle's stream and launch the GEMM.</li>
 *   <li>Register the action with the allocator.</li>
 *   <li>Hand C to {@code OpExecutionerUtil.checkForAny} (presumably a NaN/Inf check — confirm).</li>
 * </ol>
 *
 * @param Order storage-order flag; NOTE(review): unused here — confirm operands are column-major
 * @param TransA transpose flag for A, passed through {@code convertTranspose}
 * @param TransB transpose flag for B, passed through {@code convertTranspose}
 * @param M GEMM dimension M
 * @param N GEMM dimension N
 * @param K GEMM dimension K
 * @param alpha multiplier for {@code op(A) * op(B)}
 * @param A left operand
 * @param lda leading dimension of A
 * @param B right operand
 * @param ldb leading dimension of B
 * @param beta multiplier for the prior contents of C
 * @param C output matrix
 * @param ldc leading dimension of C
 */
@Override
protected void dgemm(char Order, char TransA, char TransB, int M, int N, int K, double alpha,
        INDArray A, int lda, INDArray B, int ldb, double beta, INDArray C, int ldc) {
    //A = Shape.toOffsetZero(A);
    //B = Shape.toOffsetZero(B);
    // Only a warning: a DOUBLE gemm is allowed even when the global default type differs.
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE)
        logger.warn("DOUBLE gemm called");

    Nd4j.getExecutioner().push();

    CudaContext ctx = allocator.getFlowController().prepareAction(C, A, B);

    // NOTE(review): this check follows prepareAction, so a failure skips registerAction — confirm.
    DataTypeValidation.assertDouble(A, B, C);

    CublasPointer cAPointer = new CublasPointer(A, ctx);
    CublasPointer cBPointer = new CublasPointer(B, ctx);
    CublasPointer cCPointer = new CublasPointer(C, ctx);

    cublasHandle_t handle = ctx.getHandle();
    // Hold the handle while binding its stream and launching, so both steps apply to this call.
    synchronized (handle) {
        cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));

        cublasDgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
                new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda,
                (DoublePointer) cBPointer.getDevicePointer(), ldb, new DoublePointer(beta),
                (DoublePointer) cCPointer.getDevicePointer(), ldc);
    }

    allocator.registerAction(ctx, C, A, B);
    OpExecutionerUtil.checkForAny(C);
}
/**
 * Double-precision complex matrix-vector multiply via {@code NativeBlas.zgemv} (no transpose).
 *
 * <p>When {@code y} is a scalar, the call short-circuits and stores {@code dotc(a, x)} into
 * element 0 of {@code y}. NOTE(review): {@code alpha} and {@code beta} are ignored on that
 * path — confirm this is intended.
 *
 * @param alpha scale factor for {@code a * x}; must be a {@code ComplexDouble}
 * @param a matrix operand; its row count is used as the leading dimension
 * @param x vector operand
 * @param beta scale factor for {@code y}; must be a {@code ComplexDouble}
 * @param y output vector
 * @return {@code y}, or the result of {@code putScalar} on the scalar path
 */
@Override
public IComplexNDArray gemv(IComplexDouble alpha, IComplexNDArray a, IComplexNDArray x,
        IComplexDouble beta, IComplexNDArray y) {
    DataTypeValidation.assertDouble(a, x, y);

    if (y.isScalar()) {
        return y.putScalar(0, dotc(a, x));
    }

    final ComplexDouble alphaValue = (ComplexDouble) alpha;
    final ComplexDouble betaValue = (ComplexDouble) beta;
    NativeBlas.zgemv('N', a.rows(), a.columns(), alphaValue,
            a.data().asDouble(), a.blasOffset(), a.rows(),
            x.data().asDouble(), x.offset(), x.secondaryStride(),
            betaValue,
            y.data().asDouble(), y.blasOffset(), y.secondaryStride());
    return y;
}
public static INDArray ger(INDArray A, INDArray B, INDArray C, double alpha) { DataTypeValidation.assertDouble(A, B, C); sync(); // = alpha * A * transpose(B) + C CublasPointer aCPointer = new CublasPointer(A); CublasPointer bCPointer = new CublasPointer(B); CublasPointer cCPointer = new CublasPointer(C); JCublas2.cublasDger( ContextHolder.getInstance().getHandle(), A.rows(), // m A.columns(),// n Pointer.to(new double[]{alpha}), // alpha aCPointer.getDevicePointer(), // d_A or x A.rows(), // incx bCPointer.getDevicePointer(), // dB or y B.rows(), // incy cCPointer.getDevicePointer(), // dC or A C.rows() // lda ); cCPointer.copyToHost(); releaseCublasPointers(aCPointer,bCPointer,cCPointer); sync(); return C; }
DataTypeValidation.assertDouble(A, B, C);
IComplexDouble Alpha) { DataTypeValidation.assertDouble(A, B, C);