/**
 * Samples an array for orthogonal initialization, scaled by {@code gain} and shaped as {@code shape}.
 *
 * <p>All dimensions of {@code shape} except the last are multiplied together to form the row
 * count {@code m}; the last dimension is the column count {@code n}. An {@code m x n} matrix is
 * filled from a standard normal distribution (mean 0, std 1) using {@code random} and decomposed
 * with {@code gesvd}. When {@code m <= n} the leading {@code m x n} slice of the {@code n x n}
 * factor {@code v} is used; otherwise the leading {@code m x n} slice of the {@code m x m}
 * factor {@code u} is used. The slice is multiplied by {@code gain}, which produces a new array
 * rather than a view of the SVD buffers, and is then reshaped to {@code shape}.
 *
 * @param shape requested output shape; must contain at least one dimension, all positive
 * @return a newly allocated array with the requested shape
 * @throws IllegalArgumentException if {@code shape} is empty or has a non-positive dimension
 * @throws ArithmeticException if the product of the leading dimensions overflows {@code int}
 * @throws UnsupportedOperationException if per-output {@code gains} are configured (not implemented)
 */
@Override
public INDArray sample(int[] shape) {
    if (shape.length == 0) {
        throw new IllegalArgumentException("shape must have at least one dimension");
    }
    // Checked up front so an unsupported configuration does not consume an RNG draw and an
    // SVD only to throw afterwards.
    if (gains != null) {
        throw new UnsupportedOperationException(
                "Per-output gains are not supported for orthogonal sampling");
    }
    for (int i = 0; i < shape.length; i++) {
        if (shape[i] <= 0) {
            throw new IllegalArgumentException(
                    "shape[" + i + "] must be positive but was " + shape[i]);
        }
    }

    // Flatten every leading dimension into rows. multiplyExact turns a silent int overflow
    // (which would yield a wrongly sized or negative flat shape) into an explicit error.
    int numRows = 1;
    for (int i = 0; i < shape.length - 1; i++) {
        numRows = Math.multiplyExact(numRows, shape[i]);
    }
    int numCols = shape[shape.length - 1];

    val flatShape = new int[]{numRows, numCols};
    val flatRng = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(flatShape, Nd4j.order()), 0.0, 1.0), random);

    long m = flatRng.rows();
    long n = flatRng.columns();
    val s = Nd4j.create(m < n ? m : n);
    // NOTE(review): when m < n, u is allocated m x n rather than m x m. This makes u's shape
    // equal the flattened shape, which steers the selection below to v. Confirm that gesvd
    // only writes the leading m x m part of u in this case.
    val u = m < n ? Nd4j.create(m, n) : Nd4j.create(m, m);
    val v = Nd4j.create(n, n, 'f');
    Nd4j.getBlasWrapper().lapack().gesvd(flatRng, s, u, v);

    // u always has numRows rows, and has numCols columns exactly when m <= n.
    // Wide or square (m <= n): take the leading numRows x numCols block of v (n x n).
    // Tall (m > n): take the leading numRows x numCols block of u (m x m).
    INDArray factor = (u.rows() == numRows && u.columns() == numCols) ? v : u;
    return factor.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols))
            .mul(gain)
            .reshape(ArrayUtil.toLongArray(shape));
}