/**
 * Runs the single-array {@code assertDouble} check against each supplied array, in order.
 *
 * @param d the arrays to validate
 */
public static void assertDouble(INDArray... d) {
    for (int i = 0; i < d.length; i++) {
        assertDouble(d[i]);
    }
}
/**
 * Runs the single-array {@code assertFloat} check against each supplied array, in order.
 *
 * @param d2 the arrays to validate
 */
public static void assertFloat(INDArray... d2) {
    for (int idx = 0; idx < d2.length; idx++) {
        INDArray current = d2[idx];
        assertFloat(current);
    }
}
/**
 * Complex BLAS-style axpy (presumably {@code B := da * A + B}); not implemented here.
 *
 * <p>The operands are first checked with {@code DataTypeValidation.assertSameDataType}, so a
 * data-type mismatch is reported before the unsupported-operation error.
 *
 * @param da the complex scalar multiplier
 * @param A  the array scaled by {@code da}
 * @param B  the array accumulated into
 * @throws UnsupportedOperationException always, once validation passes
 */
public static void axpy(IComplexNumber da, IComplexNDArray A, IComplexNDArray B) {
    DataTypeValidation.assertSameDataType(A, B);
    throw new UnsupportedOperationException(
            "axpy(IComplexNumber, IComplexNDArray, IComplexNDArray) is not supported");
}
/**
 * Dot product between two complex ndarrays; not implemented here.
 *
 * <p>The operands are first checked with {@code DataTypeValidation.assertSameDataType}, so a
 * data-type mismatch is reported before the unsupported-operation error.
 *
 * @param x the first operand
 * @param y the second operand
 * @return never returns normally
 * @throws UnsupportedOperationException always, once validation passes
 */
public static IComplexDouble dot(IComplexNDArray x, IComplexNDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    throw new UnsupportedOperationException(
            "dot(IComplexNDArray, IComplexNDArray) is not supported");
}
/**
 * Validates every array in {@code d} by handing it to the single-array
 * {@code assertDouble} overload, front to back.
 *
 * @param d the arrays to validate
 */
public static void assertDouble(INDArray... d) {
    int pos = 0;
    while (pos < d.length) {
        assertDouble(d[pos]);
        pos++;
    }
}
/**
 * Complex unconjugated rank-1 update (BLAS geru, presumably {@code a := alpha * x * y^T + a});
 * not implemented here.
 *
 * <p>{@code x}, {@code y} and {@code a} are first checked with
 * {@code DataTypeValidation.assertSameDataType}, so a data-type mismatch is reported before the
 * unsupported-operation error.
 *
 * @param alpha the complex scalar multiplier
 * @param x     the column-vector operand
 * @param y     the row-vector operand
 * @param a     the matrix being updated
 * @return never returns normally
 * @throws UnsupportedOperationException always, once validation passes
 */
public static IComplexNDArray geru(IComplexNumber alpha, IComplexNDArray x, IComplexNDArray y,
        IComplexNDArray a) {
    DataTypeValidation.assertSameDataType(x, y, a);
    throw new UnsupportedOperationException(
            "geru(IComplexNumber, IComplexNDArray, IComplexNDArray, IComplexNDArray) is not supported");
}
/**
 * Applies the single-array {@code assertFloat} check to each element of {@code d2},
 * stopping at the first one that fails.
 *
 * @param d2 the arrays to validate
 */
public static void assertFloat(INDArray... d2) {
    int remaining = d2.length;
    int next = 0;
    while (remaining-- > 0) {
        assertFloat(d2[next++]);
    }
}
/**
 * Complex conjugated rank-1 update (BLAS gerc) for double-precision data; not implemented here.
 *
 * <p>{@code x}, {@code y} and {@code a} are first checked with
 * {@code DataTypeValidation.assertDouble}, matching the {@link IComplexDouble} scalar, so a
 * data-type problem is reported before the unsupported-operation error.
 *
 * @param x     the column-vector operand
 * @param y     the row-vector operand (conjugated in BLAS gerc)
 * @param a     the matrix being updated
 * @param alpha the complex scalar multiplier
 * @return never returns normally
 * @throws UnsupportedOperationException always, once validation passes
 */
public static IComplexNDArray gerc(IComplexNDArray x, IComplexNDArray y, IComplexNDArray a,
        IComplexDouble alpha) {
    DataTypeValidation.assertDouble(x, y, a);
    throw new UnsupportedOperationException(
            "gerc(IComplexNDArray, IComplexNDArray, IComplexNDArray, IComplexDouble) is not supported");
}
/**
 * Copies the contents of {@code x} into {@code y}.
 *
 * <p>After verifying both arrays share a data type, a {@code CopyOp} is built with {@code x}
 * as input, {@code y} as both the second operand and the destination, and {@code x.length()}
 * as the element count. It is then run through the default executioner.
 *
 * @param x the source array
 * @param y the destination array
 */
public static void copy(IComplexNDArray x, IComplexNDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    CopyOp copyOp = new CopyOp(x, y, y, x.length());
    Nd4j.getExecutioner().exec(copyOp);
}
/**
 * Scales a single-precision complex ndarray by {@code alpha}; not implemented here.
 *
 * <p>{@code x} is first checked with {@code DataTypeValidation.assertFloat}, so a data-type
 * problem is reported before the unsupported-operation error.
 *
 * @param alpha the complex scale factor
 * @param x     the array to scale
 * @return never returns normally
 * @throws UnsupportedOperationException always, once validation passes
 */
public static IComplexNDArray sscal(IComplexFloat alpha, IComplexNDArray x) {
    DataTypeValidation.assertFloat(x);
    throw new UnsupportedOperationException(
            "sscal(IComplexFloat, IComplexNDArray) is not supported");
}
/**
 * Single-precision complex matrix-vector product (BLAS gemv); not implemented here.
 *
 * <p>The operands are validated with {@code DataTypeValidation.assertFloat}. This overload takes
 * {@link IComplexFloat} scalars, and the {@link IComplexDouble} overload validates with
 * {@code assertDouble}; asserting double here would reject exactly the float data this overload
 * is for.
 *
 * @param alpha scalar applied to {@code a * x}
 * @param a     the matrix operand
 * @param x     the vector operand
 * @param beta  scalar applied to the incoming {@code y}
 * @param y     the accumulator/result vector
 * @return never returns normally
 * @throws UnsupportedOperationException always, once validation passes
 */
@Override
public IComplexNDArray gemv(IComplexFloat alpha, IComplexNDArray a, IComplexNDArray x,
        IComplexFloat beta, IComplexNDArray y) {
    DataTypeValidation.assertFloat(a, x, y);
    throw new UnsupportedOperationException(
            "gemv is not supported for single-precision complex ndarrays");
}
@Override protected float hdot(long N, INDArray X, int incX, INDArray Y, int incY) { DataTypeValidation.assertSameDataType(X, Y); // CudaContext ctx = allocator.getFlowController().prepareAction(null, X, Y); float ret = 1f; // CublasPointer xCPointer = new CublasPointer(X, ctx); // CublasPointer yCPointer = new CublasPointer(Y, ctx); Dot dot = new Dot(X, Y); Nd4j.getExecutioner().exec(dot); ret = dot.getFinalResult().floatValue(); /* cublasHandle_t handle = ctx.getHandle(); synchronized (handle) { long result = cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); if (result != 0) throw new IllegalStateException("cublasSetStream failed"); FloatPointer resultPointer = new FloatPointer(0.0f); cuBlasSdot_v2(new cublasContext(handle), N, xCPointer.getDevicePointer(), incX, yCPointer.getDevicePointer(), incY, resultPointer); ret = resultPointer.get(); } */ // allocator.registerAction(ctx, null, X, Y); return ret; }
/**
 * Scales a single-precision complex ndarray in place by {@code alpha} using jblas
 * {@code NativeBlas.cscal}, and returns the same array.
 *
 * <p>{@code alpha} is cast to {@code org.jblas.ComplexFloat}, so passing any other
 * {@link IComplexFloat} implementation throws {@link ClassCastException}.
 *
 * <p>NOTE(review): the scaling is written into the array returned by
 * {@code x.data().asFloat()}; if that returns a copy rather than the backing storage, the
 * result would not be visible in {@code x} — confirm for the buffer types in use.
 *
 * @param alpha the complex scale factor
 * @param x     the array to scale
 * @return {@code x}
 */
public static IComplexNDArray sscal(IComplexFloat alpha, IComplexNDArray x) {
    DataTypeValidation.assertFloat(x);
    int count = x.length();
    org.jblas.ComplexFloat factor = (org.jblas.ComplexFloat) alpha;
    float[] values = x.data().asFloat();
    int start = x.offset();
    int stride = x.majorStride();
    NativeBlas.cscal(count, factor, values, start, stride);
    return x;
}
/**
 * Double-precision complex matrix-vector product following BLAS gemv semantics
 * {@code y := alpha * a * x + beta * y}.
 *
 * <p>Only the degenerate case where {@code y} is a scalar is implemented. The single output
 * element is {@code alpha * (a . x) + beta * y[0]}, written back into {@code y}. All other
 * shapes throw {@link UnsupportedOperationException}.
 *
 * @param alpha scalar applied to {@code a * x}
 * @param a     the matrix (row-vector in the scalar case) operand
 * @param x     the vector operand
 * @param beta  scalar applied to the incoming {@code y}
 * @param y     the accumulator/result
 * @return {@code y} after {@code putScalar}, in the scalar case
 * @throws UnsupportedOperationException if {@code y} is not a scalar
 */
@Override
public IComplexNDArray gemv(IComplexDouble alpha, IComplexNDArray a, IComplexNDArray x,
        IComplexDouble beta, IComplexNDArray y) {
    DataTypeValidation.assertDouble(a, x, y);
    if (y.isScalar()) {
        // gemv scales the product by alpha and the incoming y by beta; both must be applied.
        // NOTE(review): BLAS gemv does not conjugate either operand, while dotc conjugates one
        // of them — confirm whether the unconjugated dot (dotu) is the intended product here.
        return y.putScalar(0, alpha.mul(dotc(a, x)).add(beta.mul(y.getComplex(0))));
    }
    throw new UnsupportedOperationException(
            "gemv is only supported for a scalar y with double-precision complex ndarrays");
}
@Override protected float hdot(int N, INDArray X, int incX, INDArray Y, int incY) { DataTypeValidation.assertSameDataType(X, Y); // CudaContext ctx = allocator.getFlowController().prepareAction(null, X, Y); float ret = 1f; // CublasPointer xCPointer = new CublasPointer(X, ctx); // CublasPointer yCPointer = new CublasPointer(Y, ctx); Dot dot = new Dot(X, Y); Nd4j.getExecutioner().exec(dot); ret = dot.getFinalResult().floatValue(); /* cublasHandle_t handle = ctx.getHandle(); synchronized (handle) { long result = cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); if (result != 0) throw new IllegalStateException("cublasSetStream failed"); FloatPointer resultPointer = new FloatPointer(0.0f); cuBlasSdot_v2(new cublasContext(handle), N, xCPointer.getDevicePointer(), incX, yCPointer.getDevicePointer(), incY, resultPointer); ret = resultPointer.get(); } */ // allocator.registerAction(ctx, null, X, Y); return ret; }
/**
 * Single-precision complex axpy on the GPU via {@code JCublas2.cublasCaxpy}.
 *
 * <p>Both arrays are validated as float, wrapped in {@code CublasPointer}s, and the call is
 * bracketed by {@code sync()}. The strides handed to cuBLAS are {@code majorStride() / 2},
 * presumably converting from interleaved-float units to complex-element units (TODO confirm).
 *
 * @param da the complex scalar multiplier
 * @param A  the array scaled by {@code da} (cuBLAS x)
 * @param B  the array accumulated into (cuBLAS y)
 */
public static void axpy(IComplexFloat da, IComplexNDArray A, IComplexNDArray B) {
    DataTypeValidation.assertFloat(A, B);
    CublasPointer aCPointer = new CublasPointer(A);
    CublasPointer bCPointer = new CublasPointer(B);
    sync();

    int elementCount = A.length();
    float realPart = da.realComponent().floatValue();
    float imagPart = da.imaginaryComponent().floatValue();
    int strideA = A.majorStride() / 2;
    int strideB = B.majorStride() / 2;

    JCublas2.cublasCaxpy(
            ContextHolder.getInstance().getHandle(),
            elementCount,
            PointerUtil.getPointer(jcuda.cuComplex.cuCmplx(realPart, imagPart)),
            aCPointer.getDevicePointer(),
            strideA,
            bCPointer.getDevicePointer(),
            strideB);

    sync();
}
/**
 * Multiplies {@code x} in place by {@code alpha} using {@code BLAS.getInstance().dscal},
 * after asserting the data is double, and returns {@code x}.
 *
 * <p>NOTE(review): dscal writes into the array returned by {@code x.data().asDouble()}; if that
 * is a copy rather than the backing storage, {@code x} would come back unscaled — confirm for
 * the buffer types in use.
 *
 * @param alpha the scale factor
 * @param x     the array to scale
 * @return {@code x}
 */
public static INDArray scal(double alpha, INDArray x) {
    DataTypeValidation.assertDouble(x);
    BLAS blas = BLAS.getInstance();
    int count = x.length();
    double[] values = x.data().asDouble();
    int start = x.offset();
    int stride = x.majorStride();
    blas.dscal(count, alpha, values, start, stride);
    return x;
}
/**
 * Single-precision dot product on the GPU via {@code cublasSdot_v2}.
 *
 * <p>Pending executioner work is flushed with {@code push()}, and the call is bracketed by the
 * allocator's {@code prepareAction}/{@code registerAction}. The cuBLAS handle is bound to the
 * context's stream while synchronized on the handle. A warning is logged when the global data
 * type is not FLOAT.
 *
 * @param N    number of elements
 * @param X    first operand
 * @param incX stride of {@code X}
 * @param Y    second operand
 * @param incY stride of {@code Y}
 * @return the dot product
 * @throws IllegalStateException if {@code cublasSetStream_v2} or {@code cublasSdot_v2} reports
 *                               a non-zero status
 */
@Override
protected float sdot(int N, INDArray X, int incX, INDArray Y, int incY) {
    if (Nd4j.dataType() != DataBuffer.Type.FLOAT)
        logger.warn("FLOAT dot called");

    DataTypeValidation.assertSameDataType(X, Y);

    Nd4j.getExecutioner().push();

    CudaContext ctx = allocator.getFlowController().prepareAction(null, X, Y);

    float ret;

    CublasPointer xCPointer = new CublasPointer(X, ctx);
    CublasPointer yCPointer = new CublasPointer(Y, ctx);

    cublasHandle_t handle = ctx.getHandle();
    synchronized (handle) {
        long result = cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
        if (result != 0)
            throw new IllegalStateException("cublasSetStream failed");

        FloatPointer resultPointer = new FloatPointer(0.0f);
        result = cublasSdot_v2(new cublasContext(handle), N, (FloatPointer) xCPointer.getDevicePointer(), incX,
                (FloatPointer) yCPointer.getDevicePointer(), incY, resultPointer);
        // Without this check a failed dot would silently return the 0.0f placeholder.
        if (result != 0)
            throw new IllegalStateException("cublasSdot failed");
        ret = resultPointer.get();
    }

    allocator.registerAction(ctx, null, X, Y);

    return ret;
}
/**
 * Multiplies {@code x} in place by {@code alpha} and returns {@code x}.
 *
 * <p>If {@code x}'s buffer is DOUBLE, the work is delegated to the {@code scal(double, INDArray)}
 * overload. Otherwise the data is asserted to be float and scaled with
 * {@code BLAS.getInstance().sscal}.
 *
 * <p>NOTE(review): sscal writes into the array returned by {@code x.data().asFloat()}; if that
 * is a copy rather than the backing storage, {@code x} would come back unscaled — confirm.
 *
 * @param alpha the scale factor
 * @param x     the array to scale
 * @return {@code x}
 */
public static INDArray scal(float alpha, INDArray x) {
    boolean doubleBacked = x.data().dataType() == DataBuffer.Type.DOUBLE;
    if (doubleBacked) {
        return scal((double) alpha, x);
    }

    DataTypeValidation.assertFloat(x);
    BLAS blas = BLAS.getInstance();
    int count = x.length();
    float[] values = x.data().asFloat();
    int start = x.offset();
    int stride = x.majorStride();
    blas.sscal(count, alpha, values, start, stride);
    return x;
}
/**
 * Double-precision complex axpy on the GPU via {@code JCublas2.cublasZaxpy}.
 *
 * <p>Both arrays are validated as double, wrapped in {@code CublasPointer}s, and the call is
 * bracketed by {@code sync()}. The scalar's components are passed at full double precision;
 * {@code cuDoubleComplex.cuCmplx} takes doubles, so narrowing through {@code floatValue()}
 * would discard precision.
 *
 * <p>NOTE(review): the single-precision overload passes {@code majorStride() / 2} as the cuBLAS
 * increments, while this one passes {@code majorStride()} unchanged. One of the two is likely
 * wrong for interleaved complex storage — confirm the stride units and align them.
 *
 * @param da the complex scalar multiplier
 * @param A  the array scaled by {@code da} (cuBLAS x)
 * @param B  the array accumulated into (cuBLAS y)
 */
public static void axpy(IComplexDouble da, IComplexNDArray A, IComplexNDArray B) {
    DataTypeValidation.assertDouble(A, B);

    CublasPointer aCPointer = new CublasPointer(A);
    CublasPointer bCPointer = new CublasPointer(B);
    sync();

    JCublas2.cublasZaxpy(
            ContextHolder.getInstance().getHandle(),
            A.length(),
            PointerUtil.getPointer(jcuda.cuDoubleComplex.cuCmplx(
                    da.realComponent().doubleValue(),
                    da.imaginaryComponent().doubleValue())),
            aCPointer.getDevicePointer(),
            A.majorStride(),
            bCPointer.getDevicePointer(),
            B.majorStride());

    sync();
}