/**
 * General matrix-matrix multiply: computes {@code c = alpha * op(a) * op(b) + beta * c}, where
 * {@code op(X)} is either {@code X} or its transpose, as selected by {@code transposeA} and
 * {@code transposeB}.<br>
 * The result array {@code c} is required to be in fortran ('f') order, to have zero offset and to
 * satisfy {@code c.data().length == c.length()}; an exception is thrown otherwise.<br>
 * Only use this if you are familiar with level 3 BLAS and NDArray storage orders.
 *
 * @param a          first matrix
 * @param b          second matrix
 * @param c          result matrix; it takes part in the calculation (when beta != 0) and the
 *                   result is written into it. Must be f order, zero offset and
 *                   length == data.length
 * @param transposeA if true, matrix a is transposed before the multiply
 * @param transposeB if true, matrix b is transposed before the multiply
 * @param alpha      scalar applied to the product {@code op(a) * op(b)}
 * @param beta       scalar applied to the existing contents of {@code c}
 * @return the result, i.e. the same {@code c} instance, returned for convenience
 */
public static INDArray gemm(INDArray a, INDArray b, INDArray c, boolean transposeA, boolean transposeB,
                double alpha, double beta) {
    // Pure delegation: the level 3 routine writes into c, which is then handed back.
    getBlasWrapper().level3().gemm(a, b, c, transposeA, transposeB, alpha, beta);
    return c;
}
/**
 * Complex general matrix multiply. Checks the three operands with
 * {@link LinAlgExceptions#assertMatrix} and then forwards them, together with the two complex
 * scalars, to the level 3 {@code gemm} routine.
 *
 * @param alpha complex scalar passed through to the level 3 routine
 * @param a     first operand
 * @param b     second operand
 * @param beta  complex scalar passed through to the level 3 routine
 * @param c     output operand; also passed to the level 3 routine
 * @return the same {@code c} instance that was passed in
 */
@Override
public IComplexNDArray gemm(IComplexNumber alpha, IComplexNDArray a, IComplexNDArray b, IComplexNumber beta,
                IComplexNDArray c) {
    LinAlgExceptions.assertMatrix(a, b, c);

    // NOTE(review): the three leading flags are derived from a, b and c in that order -- confirm
    // this matches the slots the level 3 gemm overload expects.
    level3().gemm(
                    BlasBufferUtil.getCharForTranspose(a),
                    BlasBufferUtil.getCharForTranspose(b),
                    BlasBufferUtil.getCharForTranspose(c),
                    alpha, a, b, beta, c);

    return c;
}
/**
 * Single-precision general matrix multiply. Checks the three operands with
 * {@link LinAlgExceptions#assertMatrix}; if {@code a} is backed by a DOUBLE buffer the call is
 * re-routed to the double-precision overload, otherwise the operands are forwarded to the
 * level 3 {@code gemm} routine. Only the buffer type of {@code a} is inspected.
 *
 * @param alpha scalar passed through to the level 3 routine
 * @param a     first operand; its buffer type decides which precision is used
 * @param b     second operand
 * @param beta  scalar passed through to the level 3 routine
 * @param c     output operand; also passed to the level 3 routine
 * @return the same {@code c} instance that was passed in, or whatever the double-precision
 *         overload returns when the call is re-routed
 */
@Override
public INDArray gemm(float alpha, INDArray a, INDArray b, float beta, INDArray c) {
    LinAlgExceptions.assertMatrix(a, b, c);

    final boolean doubleBacked = a.data().dataType() == DataBuffer.Type.DOUBLE;
    if (doubleBacked) {
        // The explicit casts are what select the double-precision overload.
        return gemm((double) alpha, a, b, (double) beta, c);
    }

    // NOTE(review): the three leading flags are derived from a, b and c in that order -- confirm
    // this matches the slots the level 3 gemm overload expects.
    level3().gemm(
                    BlasBufferUtil.getCharForTranspose(a),
                    BlasBufferUtil.getCharForTranspose(b),
                    BlasBufferUtil.getCharForTranspose(c),
                    alpha, a, b, beta, c);

    return c;
}
/**
 * Double-precision general matrix multiply. Checks the three operands with
 * {@link LinAlgExceptions#assertMatrix}; if {@code a} is backed by a FLOAT buffer the call is
 * re-routed to the single-precision overload (with {@code alpha} and {@code beta} narrowed to
 * float), otherwise the operands are forwarded to the level 3 {@code gemm} routine. Only the
 * buffer type of {@code a} is inspected.
 *
 * @param alpha scalar passed through to the level 3 routine
 * @param a     first operand; its buffer type decides which precision is used
 * @param b     second operand
 * @param beta  scalar passed through to the level 3 routine
 * @param c     output operand; also passed to the level 3 routine
 * @return the same {@code c} instance that was passed in, or whatever the single-precision
 *         overload returns when the call is re-routed
 */
@Override
public INDArray gemm(double alpha, INDArray a, INDArray b, double beta, INDArray c) {
    LinAlgExceptions.assertMatrix(a, b, c);

    final boolean floatBacked = a.data().dataType() == DataBuffer.Type.FLOAT;
    if (floatBacked) {
        // The explicit narrowing casts are what select the single-precision overload.
        return gemm((float) alpha, a, b, (float) beta, c);
    }

    // NOTE(review): the three leading flags are derived from a, b and c in that order -- confirm
    // this matches the slots the level 3 gemm overload expects.
    level3().gemm(
                    BlasBufferUtil.getCharForTranspose(a),
                    BlasBufferUtil.getCharForTranspose(b),
                    BlasBufferUtil.getCharForTranspose(c),
                    alpha, a, b, beta, c);

    return c;
}
Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(result), BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(temp), 1.0, this, other, 0.0, temp); Nd4j.getBlasWrapper().level3().gemm( BlasBufferUtil.getCharForTranspose(result), BlasBufferUtil.getCharForTranspose(this), Nd4j.getBlasWrapper().level3().gemm(ordering(), BlasBufferUtil.getCharForTranspose(other), BlasBufferUtil.getCharForTranspose(gemmResultArr),
BlasBufferUtil.getCharForTranspose(this), Nd4j.UNIT, this, otherArray, Nd4j.ZERO, temp); } else { Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(temp), BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(other), Nd4j.UNIT, this, otherArray, Nd4j.ZERO, temp); Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(resultArray), BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(other), Nd4j.UNIT, this, otherArray, Nd4j.ZERO, resultArray);
Nd4j.getBlasWrapper().level3().gemm(ordering(), BlasBufferUtil.getCharForTranspose(other), BlasBufferUtil.getCharForTranspose(gemmResultArr), 1.0, this, other, 0.0, gemmResultArr);
/**
 * General matrix-matrix multiply: computes {@code c = alpha * op(a) * op(b) + beta * c}, where
 * {@code op(X)} is either {@code X} or its transpose, as selected by {@code transposeA} and
 * {@code transposeB}.<br>
 * The result array {@code c} is required to be in fortran ('f') order, to have zero offset and to
 * satisfy {@code c.data().length == c.length()}; an exception is thrown otherwise.<br>
 * Only use this if you are familiar with level 3 BLAS and NDArray storage orders.
 *
 * @param a          first matrix
 * @param b          second matrix
 * @param c          result matrix; it takes part in the calculation (when beta != 0) and the
 *                   result is written into it. Must be f order, zero offset and
 *                   length == data.length
 * @param transposeA if true, matrix a is transposed before the multiply
 * @param transposeB if true, matrix b is transposed before the multiply
 * @param alpha      scalar applied to the product {@code op(a) * op(b)}
 * @param beta       scalar applied to the existing contents of {@code c}
 * @return the result, i.e. the same {@code c} instance, returned for convenience
 */
public static INDArray gemm(INDArray a, INDArray b, INDArray c, boolean transposeA, boolean transposeB,
                double alpha, double beta) {
    // Pure delegation: the level 3 routine writes into c, which is then handed back.
    getBlasWrapper().level3().gemm(a, b, c, transposeA, transposeB, alpha, beta);
    return c;
}
/**
 * Complex general matrix multiply. Checks the three operands with
 * {@link LinAlgExceptions#assertMatrix} and then forwards them, together with the two complex
 * scalars, to the level 3 {@code gemm} routine.
 *
 * @param alpha complex scalar passed through to the level 3 routine
 * @param a     first operand
 * @param b     second operand
 * @param beta  complex scalar passed through to the level 3 routine
 * @param c     output operand; also passed to the level 3 routine
 * @return the same {@code c} instance that was passed in
 */
@Override
public IComplexNDArray gemm(IComplexNumber alpha, IComplexNDArray a, IComplexNDArray b, IComplexNumber beta,
                IComplexNDArray c) {
    LinAlgExceptions.assertMatrix(a, b, c);

    // NOTE(review): the three leading flags are derived from a, b and c in that order -- confirm
    // this matches the slots the level 3 gemm overload expects.
    level3().gemm(
                    BlasBufferUtil.getCharForTranspose(a),
                    BlasBufferUtil.getCharForTranspose(b),
                    BlasBufferUtil.getCharForTranspose(c),
                    alpha, a, b, beta, c);

    return c;
}
/**
 * Single-precision general matrix multiply. Checks the three operands with
 * {@link LinAlgExceptions#assertMatrix}; if {@code a} is backed by a DOUBLE buffer the call is
 * re-routed to the double-precision overload, otherwise the operands are forwarded to the
 * level 3 {@code gemm} routine. Only the buffer type of {@code a} is inspected.
 *
 * @param alpha scalar passed through to the level 3 routine
 * @param a     first operand; its buffer type decides which precision is used
 * @param b     second operand
 * @param beta  scalar passed through to the level 3 routine
 * @param c     output operand; also passed to the level 3 routine
 * @return the same {@code c} instance that was passed in, or whatever the double-precision
 *         overload returns when the call is re-routed
 */
@Override
public INDArray gemm(float alpha, INDArray a, INDArray b, float beta, INDArray c) {
    LinAlgExceptions.assertMatrix(a, b, c);

    final boolean doubleBacked = a.data().dataType() == DataBuffer.Type.DOUBLE;
    if (doubleBacked) {
        // The explicit casts are what select the double-precision overload.
        return gemm((double) alpha, a, b, (double) beta, c);
    }

    // NOTE(review): the three leading flags are derived from a, b and c in that order -- confirm
    // this matches the slots the level 3 gemm overload expects.
    level3().gemm(
                    BlasBufferUtil.getCharForTranspose(a),
                    BlasBufferUtil.getCharForTranspose(b),
                    BlasBufferUtil.getCharForTranspose(c),
                    alpha, a, b, beta, c);

    return c;
}
/**
 * Double-precision general matrix multiply. Checks the three operands with
 * {@link LinAlgExceptions#assertMatrix}; if {@code a} is backed by a FLOAT buffer the call is
 * re-routed to the single-precision overload (with {@code alpha} and {@code beta} narrowed to
 * float), otherwise the operands are forwarded to the level 3 {@code gemm} routine. Only the
 * buffer type of {@code a} is inspected.
 *
 * @param alpha scalar passed through to the level 3 routine
 * @param a     first operand; its buffer type decides which precision is used
 * @param b     second operand
 * @param beta  scalar passed through to the level 3 routine
 * @param c     output operand; also passed to the level 3 routine
 * @return the same {@code c} instance that was passed in, or whatever the single-precision
 *         overload returns when the call is re-routed
 */
@Override
public INDArray gemm(double alpha, INDArray a, INDArray b, double beta, INDArray c) {
    LinAlgExceptions.assertMatrix(a, b, c);

    final boolean floatBacked = a.data().dataType() == DataBuffer.Type.FLOAT;
    if (floatBacked) {
        // The explicit narrowing casts are what select the single-precision overload.
        return gemm((float) alpha, a, b, (float) beta, c);
    }

    // NOTE(review): the three leading flags are derived from a, b and c in that order -- confirm
    // this matches the slots the level 3 gemm overload expects.
    level3().gemm(
                    BlasBufferUtil.getCharForTranspose(a),
                    BlasBufferUtil.getCharForTranspose(b),
                    BlasBufferUtil.getCharForTranspose(c),
                    alpha, a, b, beta, c);

    return c;
}
BlasBufferUtil.getCharForTranspose(this), Nd4j.UNIT, this, otherArray, Nd4j.ZERO, temp); } else { Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(temp), BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(other), Nd4j.UNIT, this, otherArray, Nd4j.ZERO, temp); Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(resultArray), BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(other), Nd4j.UNIT, this, otherArray, Nd4j.ZERO, resultArray);
Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(result), BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(temp), 1.0, this, other, 0.0, temp); Nd4j.getBlasWrapper().level3().gemm(ordering(), BlasBufferUtil.getCharForTranspose(other), BlasBufferUtil.getCharForTranspose(gemmResultArr),