/**
 * Summing a rank-6 vertex over a subset of dimensions removes exactly those dimensions from
 * the result shape; summing over all of them leaves an empty (rank-0) shape.
 */
@Test
public void doesCalculateCorrectShape() {
    long[] inputShape = {2, 3, 4, 5, 6, 1};
    DoubleVertex vertex = new UniformVertex(inputShape, 0, 10);
    vertex.setValue(DoubleTensor.arange(0, TensorShape.getLength(inputShape)).reshape(inputShape));

    // Drop dims 0 and 2 (sizes 2 and 4).
    assertArrayEquals(new long[]{3, 5, 6, 1}, vertex.sum(0, 2).getShape());
    // Drop dims 0, 1, 2 and the trailing size-1 dim 5.
    assertArrayEquals(new long[]{5, 6}, vertex.sum(0, 1, 2, 5).getShape());
    // Drop the first four dims.
    assertArrayEquals(new long[]{6, 1}, vertex.sum(0, 1, 2, 3).getShape());
    // Keep only dim 0.
    assertArrayEquals(new long[]{2}, vertex.sum(1, 2, 3, 4, 5).getShape());
    // Keep only dim 1.
    assertArrayEquals(new long[]{3}, vertex.sum(0, 2, 3, 4, 5).getShape());
    // Every dim summed away.
    assertArrayEquals(new long[]{}, vertex.sum(0, 1, 2, 3, 4, 5).getShape());
}
/**
 * Summing a scalar vertex with no explicit dimensions yields a rank-0 result that holds the
 * scalar's own value.
 */
@Test
public void doesSumScalarCorrectly() {
    DoubleVertex a = new UniformVertex(0, 10);
    a.setValue(2);
    DoubleVertex summed = a.sum();

    assertArrayEquals(new long[0], summed.getShape());
    // The test name promises a correct sum, so pin the value and not just the shape.
    assertEquals(2.0, summed.eval().scalar(), 1e-10);
}
/**
 * Builds the {@link LogProbGraph} for this vertex.
 *
 * <p>Three placeholder vertices are created, shaped like this vertex, {@code mu} and
 * {@code sigma}. Each is registered as the graph input standing in for its counterpart. The
 * graph output is {@code Gaussian.logProbOutput} evaluated on the placeholders and then summed.
 *
 * @return a graph whose output is the summed Gaussian log-probability of the placeholder inputs
 */
@Override
public LogProbGraph logProbGraph() {
    final DoublePlaceholderVertex xInput = new DoublePlaceholderVertex(this.getShape());
    final DoublePlaceholderVertex muInput = new DoublePlaceholderVertex(mu.getShape());
    final DoublePlaceholderVertex sigmaInput = new DoublePlaceholderVertex(sigma.getShape());

    return LogProbGraph.builder()
        .input(this, xInput)
        .input(mu, muInput)
        .input(sigma, sigmaInput)
        .logProbOutput(Gaussian.logProbOutput(xInput, muInput, sigmaInput).sum())
        .build();
}
/**
 * Calling {@code sum()} with no arguments on a 1x5 vertex adds up every element.
 */
@Test
public void doesSumAllDimensions() {
    double[] values = {1, 2, 3, 4, 5};
    DoubleVertex vertex = new UniformVertex(new long[]{1, 5}, 0, 10);
    vertex.setValue(values);

    double expectedTotal = 1 + 2 + 3 + 4 + 5;
    assertEquals(expectedTotal, vertex.sum().eval().scalar(), 1e-5);
}
@Test public void diffWrtScalarOverMultipleMultipliesAndSummation() { DoubleVertex A = new UniformVertex(new long[]{2, 2}, 0, 1); A.setValue(DoubleTensor.create(new double[]{1, 2, 3, 4}, 2, 2)); SumVertex B = A.sum().times(ConstantVertex.of(new double[]{1, 2, 3, 4})).sum(); DoubleTensor wrtA = Differentiator.reverseModeAutoDiff(B, A).withRespectTo(A); //B = 1*(a00 + a01 + a10 + a11) + 2*(a00 + a01 + a10 + a11)+ 3*(a00 + a01 + a10 + a11)+ 4*(a00 + a01 + a10 + a11) //dBda00 = 1 + 2 + 3 + 4 = 10 DoubleTensor expectedWrt = DoubleTensor.create(new double[]{10, 10, 10, 10}).reshape(2, 2); assertThat(wrtA, equalTo(expectedWrt)); }
/**
 * Listing every dimension of a 1x5 vertex explicitly reduces it to a scalar holding the total.
 */
@Test
public void doesSumAllSpecifiedDimensions() {
    DoubleTensor rowVector = DoubleTensor.create(new double[]{1, 2, 3, 4, 5}, 1, 5);
    DoubleVertex vertex = new UniformVertex(new long[]{1, 5}, 0, 10);
    vertex.setValue(rowVector);

    DoubleTensor actual = vertex.sum(0, 1).eval();

    assertEquals(DoubleTensor.scalar(1 + 2 + 3 + 4 + 5), actual);
}
/**
 * Summing a 2x2x2x2 vertex over dimensions 0 and 2 leaves a 2x2 tensor of partial sums.
 */
@Test
public void doesSumSpecifiedDimensions() {
    long[] shape = {2, 2, 2, 2};
    DoubleVertex vertex = new UniformVertex(shape, 0, 10);
    vertex.setValue(DoubleTensor.arange(0, TensorShape.getLength(shape)).reshape(shape));

    // With values 0..15 laid out so element (i, j, k, l) = 8i + 4j + 2k + l, summing over i and k
    // gives (0,0) = 0+2+8+10 = 20, (0,1) = 24, (1,0) = 36, (1,1) = 40.
    DoubleTensor actual = vertex.sum(0, 2).eval();

    assertEquals(DoubleTensor.create(new double[]{20, 24, 36, 40}, 2, 2), actual);
}
); DoubleVertex answerTotal = new GaussianVertex(answer.sum(), 1); answerTotal.observe(numberOfYesAnswers);
/**
 * Forward- and reverse-mode gradients through a concatenation of two rank-3 operands, followed
 * by a sum, agree with finite differences.
 */
@Test
public void canConcatenateHighRankAutoDiff() {
    UniformVertex a = new UniformVertex(0, 10);
    a.setValue(DoubleTensor.arange(0, 12).reshape(2, 2, 3));
    UniformVertex b = new UniformVertex(0, 10);
    b.setValue(DoubleTensor.arange(0, 8).reshape(2, 2, 2));

    DoubleTensor aScale = DoubleTensor.linspace(0, 1, 12).reshape(2, 2, 3);
    DoubleTensor bOffset = DoubleTensor.linspace(1, 2, 8).reshape(2, 2, 2);
    DoubleVertex scaledA = a.times(ConstantVertex.of(aScale));
    DoubleVertex shiftedB = b.plus(ConstantVertex.of(bOffset));

    // The operands differ only on axis 2 (3 vs 2), which is the axis passed to the concatenation.
    DoubleVertex joined = new ConcatenationVertex(2, scaledA, shiftedB);
    SumVertex reduced = joined.sum(1);

    finiteDifferenceMatchesForwardAndReverseModeGradient(Arrays.asList(a, b), reduced, 10.0, 1e-10);
}
/**
 * A permute applied after a sum yields matching forward- and reverse-mode Jacobians of the
 * expected shape.
 */
@Test
public void canCalculateCorrectPermutedShapeWithUpstreamSum() {
    UniformVertex a = new UniformVertex(0, 10);
    a.setValue(DoubleTensor.arange(0, 6).reshape(1, 2, 3));
    UniformVertex b = new UniformVertex(0, 10);
    b.setValue(DoubleTensor.arange(0, 6).reshape(1, 2, 3));

    // (1, 2, 3) --sum(2)--> (1, 2) --permute(1, 0)--> (2, 1)
    DoubleVertex combined = a.plus(b);
    DoubleVertex summed = combined.sum(2);
    PermuteVertex permuted = summed.permute(1, 0);

    DoubleTensor forwardWrtA = Differentiator.forwardModeAutoDiff(a, permuted).of(permuted);
    DoubleTensor reverseWrtA = Differentiator.reverseModeAutoDiff(permuted, ImmutableSet.of(a)).withRespectTo(a);

    // Expected shape is the output shape (2, 1) followed by the input shape (1, 2, 3).
    long[] expectedShape = {2, 1, 1, 2, 3};
    Assert.assertArrayEquals(expectedShape, forwardWrtA.getShape());
    Assert.assertArrayEquals(expectedShape, reverseWrtA.getShape());
    Assert.assertArrayEquals(forwardWrtA.asFlatDoubleArray(), reverseWrtA.asFlatDoubleArray(), 1e-6);
}
/**
 * A sum applied after a permute yields matching forward- and reverse-mode Jacobians of the
 * expected shape.
 */
@Test
public void canCalculateCorrectPermutedShapeWithDownstreamSum() {
    UniformVertex a = new UniformVertex(0, 10);
    a.setValue(DoubleTensor.arange(0, 6).reshape(1, 2, 3));
    UniformVertex b = new UniformVertex(0, 10);
    b.setValue(DoubleTensor.arange(0, 6).reshape(1, 2, 3));

    // (1, 2, 3) --permute(0, 2, 1)--> (1, 3, 2) --sum(0)--> (3, 2)
    DoubleVertex combined = a.plus(b);
    DoubleVertex permuted = combined.permute(0, 2, 1);
    SumVertex summed = permuted.sum(0);

    DoubleTensor forwardWrtA = Differentiator.forwardModeAutoDiff(a, summed).of(summed);
    DoubleTensor reverseWrtA = Differentiator.reverseModeAutoDiff(summed, ImmutableSet.of(a)).withRespectTo(a);

    // Expected shape is the output shape (3, 2) followed by the input shape (1, 2, 3).
    long[] expectedShape = {3, 2, 1, 2, 3};
    Assert.assertArrayEquals(expectedShape, forwardWrtA.getShape());
    Assert.assertArrayEquals(expectedShape, reverseWrtA.getShape());
    Assert.assertArrayEquals(forwardWrtA.asFlatDoubleArray(), reverseWrtA.asFlatDoubleArray(), 1e-6);
}