/**
 * {@inheritDoc}
 *
 * <p>Per partition, copies the labels into the context as the starting {@code u} and returns the
 * partition's {@code ||b||_2} (BLAS {@code dnrm2}); partial norms are combined as
 * {@code sqrt(a^2 + b^2)}. A {@code null} partial result means "no contribution".
 */
@Override protected double bnorm() {
    return dataset.computeWithCtx((ctx, data) -> {
        double[] labels = data.getLabels();

        // NOTE(review): assumes partitions can be label-less in the same way beta()/iter() expect
        // feature-less partitions; skip them instead of failing inside Arrays.copyOf.
        if (labels == null)
            return null;

        ctx.setU(Arrays.copyOf(labels, labels.length));

        return BLAS.getInstance().dnrm2(labels.length, labels, 1);
    }, (a, b) -> {
        // Explicit returns instead of a nested ternary: mixing Double with a primitive double in
        // "?:" makes the whole expression a double and unboxes the returned branch, which threw
        // NullPointerException when both partial results were null.
        if (a == null)
            return b;
        if (b == null)
            return a;
        return Math.sqrt(a * a + b * b);
    });
}
/**
 * {@inheritDoc}
 *
 * <p>Per partition, updates {@code u := alfa * A * x + beta * u} in place (BLAS {@code dgemv}
 * with "N", A stored column-major with leading dimension {@code rows}) and returns the
 * partition's {@code ||u||_2}. Partial norms are combined as {@code sqrt(a^2 + b^2)}; partitions
 * without features return {@code null} and contribute nothing.
 */
@Override protected double beta(double[] x, double alfa, double beta) {
    return dataset.computeWithCtx((ctx, data) -> {
        if (data.getFeatures() == null)
            return null;

        int rows = data.getRows();
        int cols = data.getFeatures().length / rows;
        double[] u = ctx.getU();

        BLAS.getInstance().dgemv("N", rows, cols, alfa, data.getFeatures(), Math.max(1, rows),
            x, 1, beta, u, 1);

        return BLAS.getInstance().dnrm2(u.length, u, 1);
    }, (a, b) -> {
        // Explicit returns instead of a nested ternary: mixing Double with a primitive double in
        // "?:" makes the whole expression a double and unboxes the returned branch, so the map's
        // null partial results used to trigger NullPointerException when two met in the reduction.
        if (a == null)
            return b;
        if (b == null)
            return a;
        return Math.sqrt(a * a + b * b);
    });
}
/**
 * {@inheritDoc}
 *
 * <p>Per partition, rescales {@code u} in place by {@code 1 / bnorm} (BLAS {@code dscal}) and
 * computes {@code v = A^T * u} (BLAS {@code dgemv} with "T"). The per-partition vectors are
 * summed and the total is added to {@code target} in place.
 *
 * @param bnorm Divisor applied to every partition's {@code u}.
 * @param target Vector that receives the summed {@code A^T * u}; modified in place.
 * @return {@code target}.
 */
@Override protected double[] iter(double bnorm, double[] target) {
    double[] res = dataset.computeWithCtx((ctx, data) -> {
        if (data.getFeatures() == null)
            return null;

        int rows = data.getRows();
        int cols = data.getFeatures().length / rows;
        double[] u = ctx.getU();

        BLAS.getInstance().dscal(u.length, 1 / bnorm, u, 1);

        double[] v = new double[cols];
        BLAS.getInstance().dgemv("T", rows, cols, 1.0, data.getFeatures(), Math.max(1, rows),
            u, 1, 0, v, 1);

        return v;
    }, (a, b) -> {
        if (a == null)
            return b;
        if (b == null)
            return a;
        // The arrays come from the map step's fresh "new double[cols]" allocations, so
        // accumulating into b in place does not touch dataset state.
        BLAS.getInstance().daxpy(a.length, 1.0, a, 1, b, 1);
        return b;
    });

    // If no partition had features the reduction yields null; A^T * u contributes nothing then,
    // so leave target unchanged instead of failing on res.length.
    if (res != null)
        BLAS.getInstance().daxpy(res.length, 1.0, res, 1, target, 1);

    return target;
}
/**
 * Checks which BLAS implementation is currently active.
 *
 * @return {@code true} when the runtime class name of {@code BLAS.getInstance()} contains
 *     "Native" (a native BLAS is loaded), {@code false} otherwise (the Java BLAS is loaded).
 */
public static boolean isNativeBLAS() {
    String implClassName = BLAS.getInstance().getClass().getName();
    return implClassName.contains("Native");
}
/**
 * Symmetric rank-1 update {@code A := alpha * x * x^T + A} using BLAS {@code dsyr} on the
 * triangle selected by {@code uplo}.
 *
 * @param alpha Scaling factor.
 * @param x Update vector.
 * @param y Must be the very same object as {@code x}.
 * @return This matrix.
 * @throws IllegalArgumentException If {@code x} and {@code y} are different objects.
 */
@Override public Matrix rank1(double alpha, Vector x, Vector y) {
    if (x != y) {
        throw new IllegalArgumentException("x != y");
    }

    if (x instanceof DenseVector) {
        checkRank1(x, y);

        double[] xValues = ((DenseVector) x).getData();
        int leadingDim = Math.max(1, numRows);

        BLAS.getInstance().dsyr(uplo.netlib(), numRows, alpha, xValues, 1, data, leadingDim);
        return this;
    }

    // Non-dense vectors: defer to the generic implementation.
    return super.rank1(alpha, x, y);
}
/**
 * Symmetric rank-k update {@code A := alpha * C^T * C + A} using BLAS {@code dsyrk} with
 * {@code trans = "T"} on the triangle selected by {@code uplo}.
 *
 * <p>With {@code trans = "T"}, dsyrk treats {@code C} as a {@code k x n} column-major matrix,
 * where {@code n = numRows} and {@code k = C.numRows()}; C's leading dimension is therefore
 * {@code C.numRows()}, not this matrix's row count.
 *
 * @param alpha Scaling factor.
 * @param C Left factor; dense input takes the BLAS path, anything else falls back to the
 *     superclass.
 * @return This matrix.
 */
@Override public Matrix transRank1(double alpha, Matrix C) {
    if (!(C instanceof DenseMatrix))
        return super.transRank1(alpha, C);

    // NOTE(review): presumably validates C.numColumns() == numRows; not visible from here.
    checkTransRank1(C);

    double[] Cd = ((DenseMatrix) C).getData();

    // k and lda must come from C itself. Passing numRows for both (as before) was only correct
    // for square C: otherwise dsyrk used the wrong stride or read past the end of Cd.
    int k = C.numRows();

    BLAS.getInstance().dsyrk(uplo.netlib(), Transpose.Transpose.netlib(), numRows, k, alpha, Cd,
        Math.max(1, k), 1, data, Math.max(1, numRows));

    return this;
}
/**
 * Scales {@code N} elements of {@code X} (stride {@code incX}) by {@code alpha} with BLAS
 * {@code dscal}. The call works on a double[] snapshot of X, which is written back afterwards.
 */
@Override protected void dscal(int N, double alpha, INDArray X, int incX) {
    double[] values = getDoubleData(X);
    int startOffset = BlasBufferUtil.getBlasOffset(X);

    BLAS.getInstance().dscal(N, alpha, values, startOffset, incX);

    // Push the scaled values back into X.
    setData(values, X);
}
/**
 * Computes {@code y := alpha * A * x + y} with BLAS {@code dgemv} (no transpose, beta = 1)
 * when both vectors are dense; otherwise delegates to the superclass.
 *
 * @return {@code y}, updated in place.
 */
@Override public Vector multAdd(double alpha, Vector x, Vector y) {
    boolean bothDense = x instanceof DenseVector && y instanceof DenseVector;
    if (!bothDense)
        return super.multAdd(alpha, x, y);

    checkMultAdd(x, y);

    double[] src = ((DenseVector) x).getData();
    double[] dst = ((DenseVector) y).getData();
    int lda = Math.max(numRows, 1);

    BLAS.getInstance().dgemv(Transpose.NoTranspose.netlib(), numRows, numColumns, alpha, data,
        lda, src, 1, 1, dst, 1);

    return y;
}
/**
 * Computes {@code y := alpha * A^T * x + y} with BLAS {@code dgemv} (transpose, beta = 1)
 * when both vectors are dense; otherwise delegates to the superclass.
 *
 * @return {@code y}, updated in place.
 */
@Override public Vector transMultAdd(double alpha, Vector x, Vector y) {
    if (x instanceof DenseVector && y instanceof DenseVector) {
        checkTransMultAdd(x, y);

        double[] in = ((DenseVector) x).getData();
        double[] out = ((DenseVector) y).getData();

        BLAS.getInstance().dgemv(Transpose.Transpose.netlib(), numRows, numColumns, alpha, data,
            Math.max(numRows, 1), in, 1, 1, out, 1);

        return y;
    }

    return super.transMultAdd(alpha, x, y);
}
/**
 * Symmetric rank-2k update {@code A := alpha * B^T * C + alpha * C^T * B + A} using BLAS
 * {@code dsyr2k} with {@code trans = "T"}. Non-dense operands go to the superclass.
 *
 * @return This matrix.
 */
@Override public Matrix transRank2(double alpha, Matrix B, Matrix C) {
    if (B instanceof DenseMatrix && C instanceof DenseMatrix) {
        checkTransRank2(B, C);

        double[] bValues = ((DenseMatrix) B).getData();
        double[] cValues = ((DenseMatrix) C).getData();

        // With trans = "T" the factors are k x n, k being B's row count.
        int k = B.numRows();
        int ldFactors = Math.max(1, k);

        BLAS.getInstance().dsyr2k(uplo.netlib(), Transpose.Transpose.netlib(), numRows, k, alpha,
            bValues, ldFactors, cValues, ldFactors, 1, data, Math.max(1, numRows));

        return this;
    }

    return super.transRank2(alpha, B, C);
}
/**
 * Symmetric rank-2k update {@code A := alpha * B^T * C + alpha * C^T * B + A} via BLAS
 * {@code dsyr2k} ({@code trans = "T"}); falls back to the superclass unless both operands
 * are dense.
 *
 * @return This matrix.
 */
@Override public Matrix transRank2(double alpha, Matrix B, Matrix C) {
    boolean denseOperands = B instanceof DenseMatrix && C instanceof DenseMatrix;
    if (!denseOperands)
        return super.transRank2(alpha, B, C);

    checkTransRank2(B, C);

    double[] left = ((DenseMatrix) B).getData();
    double[] right = ((DenseMatrix) C).getData();
    int innerDim = B.numRows();

    BLAS.getInstance().dsyr2k(uplo.netlib(), Transpose.Transpose.netlib(), numRows, innerDim,
        alpha, left, Math.max(1, innerDim), right, Math.max(1, innerDim), 1, data,
        Math.max(1, numRows));

    return this;
}
/**
 * General band matrix-vector product
 * {@code Y := alpha * op(A) * X + beta * Y} via BLAS {@code sgbmv}, where {@code op} is
 * selected by {@code TransA}. Y is updated through a float[] snapshot that is written back.
 * The {@code order} argument is not used by this implementation.
 */
@Override protected void sgbmv(char order, char TransA, int M, int N, int KL, int KU, float alpha, INDArray A, int lda, INDArray X, int incX, float beta, INDArray Y, int incY) {
    float[] yValues = getFloatData(Y);

    String trans = String.valueOf(TransA);
    float[] aValues = getFloatData(A);
    int aOffset = getBlasOffset(A);
    float[] xValues = getFloatData(X);
    int xOffset = getBlasOffset(X);
    int yOffset = getBlasOffset(Y);

    BLAS.getInstance().sgbmv(trans, M, N, KL, KU, alpha, aValues, aOffset, lda, xValues,
        xOffset, incX, beta, yValues, yOffset, incY);

    setData(yValues, Y);
}
/**
 * Symmetric rank-2 update {@code A := alpha * X * Y^T + alpha * Y * X^T + A} via BLAS
 * {@code ssyr2} on the triangle selected by {@code Uplo}. A is updated through a float[]
 * snapshot that is written back. The {@code order} argument is not used here.
 *
 * @param incX Stride between consecutive elements of {@code X}.
 * @param incY Stride between consecutive elements of {@code Y}.
 */
@Override protected void ssyr2(char order, char Uplo, int N, float alpha, INDArray X, int incX, INDArray Y, int incY, INDArray A, int lda) {
    float[] aData = getFloatData(A);

    // X is walked with its own stride incX. The previous code passed incY here, which used the
    // wrong stride for X whenever incX != incY.
    BLAS.getInstance().ssyr2(String.valueOf(Uplo), N, alpha, getFloatData(X), getBlasOffset(X),
        incX, getFloatData(Y), getBlasOffset(Y), incY, aData, getBlasOffset(A), lda);

    setData(aData, A);
}
/**
 * Symmetric band matrix-vector product {@code Y := alpha * A * X + beta * Y} via BLAS
 * {@code dsbmv}, with {@code K} off-diagonals on the {@code Uplo} side. Y is updated through a
 * double[] snapshot that is written back. The {@code order} argument is not used here.
 */
@Override protected void dsbmv(char order, char Uplo, int N, int K, double alpha, INDArray A, int lda, INDArray X, int incX, double beta, INDArray Y, int incY) {
    double[] yValues = getDoubleData(Y);

    String uplo = String.valueOf(Uplo);
    double[] aValues = getDoubleData(A);
    int aOffset = getBlasOffset(A);
    double[] xValues = getDoubleData(X);
    int xOffset = getBlasOffset(X);
    int yOffset = getBlasOffset(Y);

    BLAS.getInstance().dsbmv(uplo, N, K, alpha, aValues, aOffset, lda, xValues, xOffset, incX,
        beta, yValues, yOffset, incY);

    setData(yValues, Y);
}
/**
 * Computes {@code y := alpha * A * x + y} for this symmetric band matrix via BLAS
 * {@code dsbmv} (band storage with leading dimension {@code kd + 1}, beta = 1). Non-dense
 * vectors are handled by the superclass.
 *
 * @return {@code y}, updated in place.
 */
@Override public Vector multAdd(double alpha, Vector x, Vector y) {
    if (x instanceof DenseVector && y instanceof DenseVector) {
        checkMultAdd(x, y);

        double[] src = ((DenseVector) x).getData();
        double[] dst = ((DenseVector) y).getData();
        int bandLd = kd + 1;

        BLAS.getInstance().dsbmv(uplo.netlib(), numRows, kd, alpha, data, bandLd, src, 1, 1,
            dst, 1);

        return y;
    }

    return super.multAdd(alpha, x, y);
}
/**
 * Computes {@code y := alpha * A^T * x + y} for this band matrix via BLAS {@code dgbmv}
 * ({@code kl} sub- and {@code ku} super-diagonals, leading dimension {@code kl + ku + 1},
 * beta = 1). Non-dense vectors are handled by the superclass.
 *
 * @return {@code y}, updated in place.
 */
@Override public Vector transMultAdd(double alpha, Vector x, Vector y) {
    boolean bothDense = x instanceof DenseVector && y instanceof DenseVector;
    if (!bothDense)
        return super.transMultAdd(alpha, x, y);

    checkTransMultAdd(x, y);

    double[] in = ((DenseVector) x).getData();
    double[] out = ((DenseVector) y).getData();
    int bandLd = kl + ku + 1;

    BLAS.getInstance().dgbmv(Transpose.Transpose.netlib(), numRows, numColumns, kl, ku, alpha,
        data, bandLd, in, 1, 1, out, 1);

    return y;
}
/**
 * Transposed band product {@code y := alpha * A^T * x + y} through BLAS {@code dgbmv}
 * (band storage, {@code kl + ku + 1} rows per column, beta = 1); falls back to the superclass
 * for non-dense vectors.
 *
 * @return {@code y}, updated in place.
 */
@Override public Vector transMultAdd(double alpha, Vector x, Vector y) {
    if (x instanceof DenseVector && y instanceof DenseVector) {
        checkTransMultAdd(x, y);

        double[] xValues = ((DenseVector) x).getData();
        double[] yValues = ((DenseVector) y).getData();

        BLAS.getInstance().dgbmv(Transpose.Transpose.netlib(), numRows, numColumns, kl, ku,
            alpha, data, kl + ku + 1, xValues, 1, 1, yValues, 1);

        return y;
    }

    return super.transMultAdd(alpha, x, y);
}
/**
 * Symmetric rank-2 update {@code A := alpha * x * y^T + alpha * y * x^T + A} on this
 * packed-storage matrix via BLAS {@code dspr2}. Non-dense vectors are handled by the
 * superclass.
 *
 * @return This matrix.
 */
@Override public Matrix rank2(double alpha, Vector x, Vector y) {
    boolean bothDense = x instanceof DenseVector && y instanceof DenseVector;
    if (!bothDense)
        return super.rank2(alpha, x, y);

    checkRank2(x, y);

    double[] xValues = ((DenseVector) x).getData();
    double[] yValues = ((DenseVector) y).getData();

    BLAS.getInstance().dspr2(uplo.netlib(), numRows, alpha, xValues, 1, yValues, 1, data);

    return this;
}
/**
 * Swaps {@code N} strided elements between {@code X} and {@code Y} via BLAS {@code sswap}.
 * Both arrays are swapped through float[] snapshots, which are then written back.
 */
@Override protected void sswap(int N, INDArray X, int incX, INDArray Y, int incY) {
    float[] second = getFloatData(Y);
    float[] first = getFloatData(X);

    int firstOffset = getBlasOffset(X);
    int secondOffset = getBlasOffset(Y);
    BLAS.getInstance().sswap(N, first, firstOffset, incX, second, secondOffset, incY);

    // Write X back first, then Y.
    setData(first, X);
    setData(second, Y);
}
/**
 * Builds modified-Givens rotation parameters into {@code P} via BLAS {@code srotmg}.
 *
 * <p>{@code d1}, {@code d2} and {@code b1} are wrapped in fresh {@code floatW} holders that are
 * not returned, so any values srotmg writes into them are not visible to the caller. Only
 * {@code P} is written back.
 */
@Override protected void srotmg(float d1, float d2, float b1, float b2, INDArray P) {
    float[] params = getFloatData(P);

    floatW sd1 = new floatW(d1);
    floatW sd2 = new floatW(d2);
    floatW sx1 = new floatW(b1);

    BLAS.getInstance().srotmg(sd1, sd2, sx1, b2, params, getBlasOffset(P));

    setData(params, P);
}