/**
 * Applies one observation pair to the model.
 *
 * @param input  tensor assigned as the value of {@code xVertex}
 * @param output value passed to {@code observe} on {@code yObservationVertex}
 */
@Override
public void observeValues(DoubleTensor input, OUTPUT output) {
    xVertex.setValue(input);
    yObservationVertex.observe(output);
}
/**
 * Builds a small network with two latent Gaussian vertices {@code A} and {@code B}.
 * Both are constructed with parameters (20.0, 1.0) and initialised to 20.0.
 * They feed an observed Gaussian whose first parameter is {@code A + B}, second parameter
 * is 1.0, and observed value is 46.0. The resulting network is then probed for a
 * non-zero-probability state.
 */
public SumGaussianTestCase() {
    A = new GaussianVertex(20.0, 1.0);
    A.setValue(20.0);

    B = new GaussianVertex(20.0, 1.0);
    B.setValue(20.0);

    DoubleVertex observedSum = new GaussianVertex(A.plus(B), 1.0);
    observedSum.observe(46.0);

    model = new BayesianNetwork(Arrays.asList(A, B, observedSum));
    // NOTE(review): 100 is presumably the maximum number of probe attempts - confirm.
    model.probeForNonZeroProbability(100);
}
/**
 * Generates synthetic linear-regression data with 1000 rows.
 * <ul>
 *   <li>{@code x} is drawn from a Uniform(0, 100) vertex of shape [1000, featureCount].</li>
 *   <li>Weights come from the vertex built by {@code weightVertexFromShape} for shape
 *       [featureCount, 1].</li>
 *   <li>{@code y} comes from a Gaussian vertex of shape [1000, 1]. Its first parameter is
 *       {@code x.matrixMultiply(w) + 20} and its second parameter is 1.0.</li>
 * </ul>
 *
 * @param featureCount          number of feature columns in {@code x}
 * @param weightVertexFromShape factory producing the weight vertex for a given shape
 * @return the sampled weights, the intercept (20), the x data and the y data
 */
static TestData generateMultiFeatureData(int featureCount, Function<long[], DoubleVertex> weightVertexFromShape) {
    final long sampleCount = 1000;
    final double intercept = 20;

    DoubleVertex xGenerator = new UniformVertex(new long[]{sampleCount, featureCount}, 0, 100);
    DoubleVertex weightsGenerator = weightVertexFromShape.apply(new long[]{featureCount, 1});
    DoubleVertex linearPredictor = xGenerator.matrixMultiply(weightsGenerator).plus(intercept);
    DoubleVertex yGenerator = new GaussianVertex(new long[]{sampleCount, 1}, linearPredictor, 1.0);

    // Sample x before the weights, then pin both so y is generated from these exact draws.
    DoubleTensor sampledX = xGenerator.sample();
    DoubleTensor sampledWeights = weightsGenerator.sample();
    xGenerator.setValue(sampledX);
    weightsGenerator.setValue(sampledWeights);

    // NOTE(review): y's value is presumably produced lazily on first access - confirm.
    DoubleTensor sampledY = yGenerator.getValue();
    return new TestData(sampledWeights, intercept, sampledX, sampledY);
}
/**
 * Summing a scalar vertex yields an empty (rank-0) shape.
 */
@Test
public void doesSumScalarCorrectly() {
    DoubleVertex scalarVertex = new UniformVertex(0, 10);
    scalarVertex.setValue(2);

    long[] rankZeroShape = new long[0];
    assertArrayEquals(rankZeroShape, scalarVertex.sum().getShape());
}
/**
 * Setting a scalar on a vertex constructed with shape [1, 2] leaves a flat value of {@code [2]}.
 */
@Test
public void canSetValueAsScalarOnNonScalarVertex() {
    long[] constructedShape = new long[]{1, 2};
    DoubleVertex vertex = new GaussianVertex(constructedShape, 0, 1);
    vertex.setValue(2);

    double[] flattened = vertex.getValue().asFlatDoubleArray();
    assertArrayEquals(new double[]{2}, flattened, 0.0);
}
/**
 * A vertex set from a {@code double[]} reports exactly those values when flattened.
 */
@Test
public void canSetValueArrayOfValues() {
    double[] expectedValues = {1, 2, 3};
    DoubleVertex vertex = new GaussianVertex(0, 1);
    vertex.setValue(expectedValues);

    assertArrayEquals(expectedValues, vertex.getValue().asFlatDoubleArray(), 0.0);
}
/**
 * Reverse-mode autodiff of the determinant of an all-zero 2x2 matrix must throw
 * {@link SingularMatrixException}.
 */
@Test(expected = SingularMatrixException.class)
public void differentiationFailsWhenMatrixIsSingular() {
    long[] matrixShape = {2, 2};
    DoubleVertex matrix = new UniformVertex(matrixShape, 0, 10);
    matrix.setValue(DoubleTensor.create(new double[]{0, 0, 0, 0}, matrixShape));

    DoubleVertex determinant = matrix.matrixDeterminant();
    Differentiator.reverseModeAutoDiff(determinant, matrix);
}
/**
 * {@code sum()} with no arguments adds up every element of a [1, 5] vertex.
 */
@Test
public void doesSumAllDimensions() {
    DoubleVertex rowVertex = new UniformVertex(new long[]{1, 5}, 0, 10);
    rowVertex.setValue(new double[]{1, 2, 3, 4, 5});
    DoubleVertex total = rowVertex.sum();

    double expectedTotal = 1 + 2 + 3 + 4 + 5;
    assertEquals(expectedTotal, total.eval().scalar(), 1e-5);
}
@Test public void diffWrtScalarOverMultipleMultipliesAndSummation() { DoubleVertex A = new UniformVertex(new long[]{2, 2}, 0, 1); A.setValue(DoubleTensor.create(new double[]{1, 2, 3, 4}, 2, 2)); SumVertex B = A.sum().times(ConstantVertex.of(new double[]{1, 2, 3, 4})).sum(); DoubleTensor wrtA = Differentiator.reverseModeAutoDiff(B, A).withRespectTo(A); //B = 1*(a00 + a01 + a10 + a11) + 2*(a00 + a01 + a10 + a11)+ 3*(a00 + a01 + a10 + a11)+ 4*(a00 + a01 + a10 + a11) //dBda00 = 1 + 2 + 3 + 4 = 10 DoubleTensor expectedWrt = DoubleTensor.create(new double[]{10, 10, 10, 10}).reshape(2, 2); assertThat(wrtA, equalTo(expectedWrt)); }
/**
 * Concatenating two 2x2 vertices gives shape [4, 2] along dimension 0 and shape [2, 4]
 * along dimension 1.
 */
@Test
public void canConcat() {
    DoubleVertex A = new UniformVertex(0, 1);
    A.setValue(DoubleTensor.arange(1, 5).reshape(2, 2));
    DoubleVertex B = new UniformVertex(0, 1);
    B.setValue(DoubleTensor.arange(5, 9).reshape(2, 2));

    // JUnit's assertArrayEquals takes (expected, actual); keep that order so failure
    // messages report the values the right way round.
    DoubleVertex concatDimZero = DoubleVertex.concat(0, A, B);
    assertArrayEquals(new long[]{4, 2}, concatDimZero.getShape());

    DoubleVertex concatDimOne = DoubleVertex.concat(1, A, B);
    assertArrayEquals(new long[]{2, 4}, concatDimOne.getShape());
}
/**
 * Explicitly summing over both dimensions of a [1, 5] vertex gives a scalar tensor.
 */
@Test
public void doesSumAllSpecifiedDimensions() {
    DoubleVertex rowVertex = new UniformVertex(new long[]{1, 5}, 0, 10);
    rowVertex.setValue(DoubleTensor.create(new double[]{1, 2, 3, 4, 5}, 1, 5));

    DoubleVertex summedOverAll = rowVertex.sum(0, 1);
    DoubleTensor expectedScalar = DoubleTensor.scalar(1 + 2 + 3 + 4 + 5);

    assertEquals(expectedScalar, summedOverAll.eval());
}
/**
 * Slicing a 2x2 vertex at dimension 0, index 0, and then slicing the result again at
 * dimension 0, index 0, picks out the single element 1.
 */
@Test
public void canRepeatablySliceForAPick() {
    DoubleVertex m = new UniformVertex(0, 10);
    m.setValue(DoubleTensor.create(new double[]{1, 2, 3, 4}, 2, 2));

    SliceVertex firstSlice = new SliceVertex(m, 0, 0);
    SliceVertex elementZero = new SliceVertex(firstSlice, 0, 0);

    // JUnit's assertEquals takes (expected, actual, delta).
    Assert.assertEquals(1.0, elementZero.getValue().scalar(), 1e-6);
}
/**
 * Builds a Metropolis-Hastings step with the constant proposal -1 and the
 * {@code alwaysAccept} strategy. Because -1 lies outside Uniform(0, 1), the step must be
 * rejected and the vertex must keep its value of 0.5.
 */
@Test
public void doesRejectOnImpossibleProposal() {
    DoubleVertex uniform = new UniformVertex(0, 1);
    uniform.setValue(0.5);

    ProbabilisticModel model = new KeanuProbabilisticModel(uniform.getConnectedGraph());
    MetropolisHastingsStep stepFunction = stepFunctionWithConstantProposal(model, -1, alwaysAccept);

    MetropolisHastingsStep.StepResult result =
        stepFunction.step(Collections.singleton(uniform), model.logProb());

    assertFalse(result.isAccepted());
    assertEquals(0.5, uniform.getValue(0), 1e-10);
}
/**
 * Summing a [2, 2, 2, 2] vertex over dimensions 0 and 2 leaves a [2, 2] result.
 *
 * <p>With row-major values 0..15, the element at (i, j, k, l) is 8i + 4j + 2k + l.
 * Summing over i and k gives 20 + 16j + 4l, which is {20, 24, 36, 40}.
 */
@Test
public void doesSumSpecifiedDimensions() {
    long[] shape = {2, 2, 2, 2};
    DoubleTensor ramp = DoubleTensor.arange(0, TensorShape.getLength(shape)).reshape(shape);

    DoubleVertex input = new UniformVertex(shape, 0, 10);
    input.setValue(ramp);

    DoubleVertex reduced = input.sum(0, 2);
    assertEquals(DoubleTensor.create(new double[]{20, 24, 36, 40}, 2, 2), reduced.eval());
}
/**
 * Concatenates two 2x2 matrices along dimension 0 and along dimension 1, and checks both
 * the resulting shapes and the flattened values.
 */
@Test
public void canConcatMatricesOfSameSize() {
    DoubleVertex left = new UniformVertex(0, 10);
    left.setValue(DoubleTensor.create(new double[]{1, 2, 3, 4}, 2, 2));
    DoubleVertex right = new UniformVertex(0, 10);
    right.setValue(DoubleTensor.create(new double[]{10, 15, 20, 25}, 2, 2));

    // Dimension 0: the second matrix's values follow all of the first matrix's values.
    ConcatenationVertex joinedOnDimZero = new ConcatenationVertex(0, left, right);
    Assert.assertArrayEquals(new long[]{4, 2}, joinedOnDimZero.getShape());
    double[] expectedDimZero = {1, 2, 3, 4, 10, 15, 20, 25};
    Assert.assertArrayEquals(expectedDimZero, joinedOnDimZero.getValue().asFlatDoubleArray(), 0.001);

    // Dimension 1: each output row interleaves one row from each matrix.
    ConcatenationVertex joinedOnDimOne = new ConcatenationVertex(1, left, right);
    Assert.assertArrayEquals(new long[]{2, 4}, joinedOnDimOne.getShape());
    double[] expectedDimOne = {1, 2, 10, 15, 3, 4, 20, 25};
    Assert.assertArrayEquals(expectedDimOne, joinedOnDimOne.getValue().asFlatDoubleArray(), 0.001);
}
/**
 * Reshaping a 2x2 vertex to [4, 1] changes the shape but keeps the flat values.
 */
@Test
public void reshapeVertex() {
    DoubleVertex source = new UniformVertex(0, 10);
    source.setValue(DoubleTensor.create(new double[]{1, 2, 3, 4}, 2, 2));

    ReshapeVertex reshaped = new ReshapeVertex(source, 4, 1);
    // NOTE(review): return value is unused; presumably this forces evaluation - confirm it is needed.
    reshaped.getValue();

    Assert.assertArrayEquals(new long[]{4, 1}, reshaped.getShape());
    Assert.assertArrayEquals(new double[]{1, 2, 3, 4}, reshaped.getValue().asFlatDoubleArray(), 1e-6);
}
/**
 * Permuting dimensions (1, 0) of a 2x3 vertex matches the tensor transpose.
 */
@Test
public void canPermuteForTranpose() {
    DoubleVertex matrix = new UniformVertex(0, 10);
    matrix.setValue(DoubleTensor.create(new double[]{1, 2, 3, 4, 5, 6}, 2, 3));

    PermuteVertex permuted = new PermuteVertex(matrix, 1, 0);
    Assert.assertArrayEquals(new long[]{3, 2}, permuted.getShape());

    double[] expected = matrix.getValue().transpose().asFlatDoubleArray();
    double[] actual = permuted.getValue().asFlatDoubleArray();
    Assert.assertArrayEquals(expected, actual, 1e-6);
}
/**
 * Checks reverse-mode autodiff of a scalar through a chain of element-wise operations.
 *
 * <p>output = ((a * [1, 2, 3, 4] + [2, 4, 6, 8]) * [2, 4, 6, 8] + 5) * 2.
 * Its derivative with respect to a is 2 * [2, 4, 6, 8] * [1, 2, 3, 4] = [4, 16, 36, 64].
 */
@Test
public void diffWrtScalarOverMultipleMultiplies() {
    DoubleVertex a = new UniformVertex(0, 1);
    a.setValue(2);

    DoubleVertex scaled = a.times(ConstantVertex.of(new double[]{1, 2, 3, 4}));
    DoubleVertex shifted = scaled.plus(ConstantVertex.of(new double[]{2, 4, 6, 8}));
    DoubleVertex rescaled = shifted.times(ConstantVertex.of(new double[]{2, 4, 6, 8}));
    MultiplicationVertex output = rescaled.plus(5).times(2);

    DoubleTensor gradient = Differentiator.reverseModeAutoDiff(output, a).withRespectTo(a);
    DoubleTensor expectedGradient = DoubleTensor.create(4, 16, 36, 64);

    assertArrayEquals(expectedGradient.asFlatDoubleArray(), gradient.asFlatDoubleArray(), 0.0);
    assertArrayEquals(expectedGradient.getShape(), gradient.getShape());
}
/**
 * Summing over a set of dimensions of a rank-6 vertex removes exactly those dimensions
 * from the shape. Summing over all dimensions leaves an empty shape.
 */
@Test
public void doesCalculateCorrectShape() {
    long[] shape = {2, 3, 4, 5, 6, 1};
    DoubleVertex input = new UniformVertex(shape, 0, 10);
    DoubleTensor ramp = DoubleTensor.arange(0, TensorShape.getLength(shape)).reshape(shape);
    input.setValue(ramp);

    // Dimensions 0 and 2 removed.
    assertArrayEquals(new long[]{3, 5, 6, 1}, input.sum(0, 2).getShape());
    // Dimensions 0, 1, 2 and the trailing dimension 5 removed.
    assertArrayEquals(new long[]{5, 6}, input.sum(0, 1, 2, 5).getShape());
    // Leading four dimensions removed.
    assertArrayEquals(new long[]{6, 1}, input.sum(0, 1, 2, 3).getShape());
    // Only dimension 0 kept.
    assertArrayEquals(new long[]{2}, input.sum(1, 2, 3, 4, 5).getShape());
    // Only dimension 1 kept.
    assertArrayEquals(new long[]{3}, input.sum(0, 2, 3, 4, 5).getShape());
    // Every dimension removed: rank 0.
    assertArrayEquals(new long[]{}, input.sum(0, 1, 2, 3, 4, 5).getShape());
}