/**
 * Builds a vertex graph that evaluates
 * {@code -0.5 * (q + k * LOG_2_PI + log(det(covariance)))}.
 *
 * <p>Here {@code k} is {@code numberOfDimensions(mu.getShape())} and {@code q} is the
 * quadratic term built from {@code (x - mu)}, its {@code permute(1, 0)} transpose and the
 * matrix inverse of {@code covariance}. When {@code isUnivariate(k)} holds, {@code q} is formed
 * with element-wise {@code times}; otherwise it is formed with {@code matrixMultiply}.
 *
 * <p>NOTE(review): the {@code permute(1, 0)} call presumes {@code x - mu} is rank 2 (e.g. a
 * column vector) - confirm against callers.
 *
 * @param x          placeholder for the point being evaluated
 * @param mu         placeholder whose shape determines {@code k}
 * @param covariance placeholder that is inverted and whose determinant is taken
 * @return the vertex at the root of the constructed graph, after a final {@code slice(0, 0)}
 */
public static DoubleVertex logProbGraph(DoublePlaceholderVertex x, DoublePlaceholderVertex mu, DoublePlaceholderVertex covariance) {
    final long k = numberOfDimensions(mu.getShape());
    final double normalisingTerm = k * LOG_2_PI;

    final DoubleVertex logDeterminant = covariance.matrixDeterminant().log();
    final DoubleVertex centred = x.minus(mu);
    final DoubleVertex centredTransposed = centred.permute(1, 0);
    final DoubleVertex precision = covariance.matrixInverse();

    // Quadratic term: element-wise products in the univariate case, matrix products otherwise.
    final DoubleVertex quadraticForm;
    if (isUnivariate(k)) {
        quadraticForm = precision.times(centred).times(centredTransposed).slice(0, 0);
    } else {
        quadraticForm = centredTransposed.matrixMultiply(precision.matrixMultiply(centred)).slice(0, 0);
    }

    return quadraticForm
        .plus(normalisingTerm)
        .plus(logDeterminant)
        .times(-0.5)
        .slice(0, 0);
}
/**
 * Adds two (1, 2, 3) tensors, permutes the result with (0, 2, 1), sums over dimension 0, and
 * checks that forward- and reverse-mode differentiation with respect to the first input both
 * yield a partial of shape (3, 2, 1, 2, 3) with matching values.
 */
@Test
public void canCalculateCorrectPermutedShapeWithDownstreamSum() {
    final long[] expectedPartialShape = new long[]{3, 2, 1, 2, 3};

    UniformVertex lhs = new UniformVertex(0, 10);
    lhs.setValue(DoubleTensor.arange(0, 6).reshape(1, 2, 3));

    UniformVertex rhs = new UniformVertex(0, 10);
    rhs.setValue(DoubleTensor.arange(0, 6).reshape(1, 2, 3));

    DoubleVertex added = lhs.plus(rhs);
    DoubleVertex permuted = added.permute(0, 2, 1);
    SumVertex summed = permuted.sum(0);

    DoubleTensor forwardPartial = Differentiator.forwardModeAutoDiff(lhs, summed).of(summed);
    DoubleTensor reversePartial = Differentiator.reverseModeAutoDiff(summed, ImmutableSet.of(lhs)).withRespectTo(lhs);

    // Both modes must agree on the shape and on every element.
    Assert.assertArrayEquals(expectedPartialShape, forwardPartial.getShape());
    Assert.assertArrayEquals(expectedPartialShape, reversePartial.getShape());
    Assert.assertArrayEquals(
        forwardPartial.asFlatDoubleArray(),
        reversePartial.asFlatDoubleArray(),
        1e-6
    );
}
/**
 * Adds two (1, 2, 3) tensors, sums over dimension 2, permutes the result with (1, 0), and
 * checks that forward- and reverse-mode differentiation with respect to the first input both
 * yield a partial of shape (2, 1, 1, 2, 3) with matching values.
 */
@Test
public void canCalculateCorrectPermutedShapeWithUpstreamSum() {
    final long[] expectedPartialShape = new long[]{2, 1, 1, 2, 3};

    UniformVertex lhs = new UniformVertex(0, 10);
    lhs.setValue(DoubleTensor.arange(0, 6).reshape(1, 2, 3));

    UniformVertex rhs = new UniformVertex(0, 10);
    rhs.setValue(DoubleTensor.arange(0, 6).reshape(1, 2, 3));

    DoubleVertex added = lhs.plus(rhs);
    DoubleVertex summed = added.sum(2);
    PermuteVertex permuted = summed.permute(1, 0);

    DoubleTensor forwardPartial = Differentiator.forwardModeAutoDiff(lhs, permuted).of(permuted);
    DoubleTensor reversePartial = Differentiator.reverseModeAutoDiff(permuted, ImmutableSet.of(lhs)).withRespectTo(lhs);

    // Both modes must agree on the shape and on every element.
    Assert.assertArrayEquals(expectedPartialShape, forwardPartial.getShape());
    Assert.assertArrayEquals(expectedPartialShape, reversePartial.getShape());
    Assert.assertArrayEquals(
        forwardPartial.asFlatDoubleArray(),
        reversePartial.asFlatDoubleArray(),
        1e-6
    );
}