@Override public INDArray doForward(boolean training) { if (!canDoForward()) throw new IllegalStateException("Cannot do forward pass: inputs not set (L2NormalizeVertex " + vertexName + " idx " + vertexIndex + ")"); // L2 norm along all dimensions except 0, unless user-specified // x / |x|2 INDArray x = inputs[0]; int[] dimensions = getDimensions(x); INDArray xNorm2 = x.norm2(dimensions); Transforms.max(xNorm2, eps, false); if (x.rank() == 2) { return x.divColumnVector(xNorm2); } else { INDArray out = Nd4j.createUninitialized(x.shape(), x.ordering()); return Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(x, xNorm2, out, 0)); } }
if (x.rank() == 2) { dLdx = epsilon.divColumnVector(norm); INDArray xDivNorm3 = x.divColumnVector(norm3); dLdx.subi(xDivNorm3.muliColumnVector(epsilon.mul(x).sum(1))); } else {