/**
 * Builds the vertex expression
 * <pre>
 *   logGamma((v + 1) / 2) - log(v) / 2 - HALF_LOG_PI - logGamma(v / 2)
 *       - ((v + 1) / 2) * log(1 + t^2 / v)
 * </pre>
 * from the two placeholder inputs.
 *
 * <p>NOTE(review): this has the shape of a Student's t log-density with {@code v} degrees of
 * freedom, assuming {@code HALF_LOG_PI} (declared outside this block) equals ln(pi) / 2 - confirm.
 *
 * @param t placeholder for the value at which the expression is evaluated
 * @param v integer placeholder, converted to a double vertex before use
 * @return the vertex at the root of the expression described above
 */
public static DoubleVertex logProbOutput(DoublePlaceholderVertex t, IntegerPlaceHolderVertex v) {
    final DoubleVertex degreesOfFreedom = v.toDouble();

    // (v + 1) / 2 is built once and reused: as the logGamma argument and as the final multiplier.
    final DoubleVertex halfOfVPlusOne = degreesOfFreedom.plus(1.).div(2.);
    final DoubleVertex logGammaOfHalfVPlusOne = halfOfVPlusOne.logGamma();
    final DoubleVertex logGammaOfHalfV = degreesOfFreedom.div(2.).logGamma();
    final DoubleVertex halfLogOfV = degreesOfFreedom.log().div(2.);

    // Terms that do not depend on t.
    final DoubleVertex termsIndependentOfT = logGammaOfHalfVPlusOne
        .minus(halfLogOfV)
        .minus(HALF_LOG_PI)
        .minus(logGammaOfHalfV);

    // log(1 + t^2 / v)
    final DoubleVertex logOnePlusTSquaredOverV = t.pow(2.).div(degreesOfFreedom).plus(1.).log();

    return termsIndependentOfT.minus(halfOfVPlusOne.times(logOnePlusTSquaredOverV));
}
/**
 * Builds the vertex expression
 * <pre>
 *   -log(scale) + NEG_LOG_PI - log(1 + ((x - location) / scale)^2)
 * </pre>
 * from the three placeholder inputs.
 *
 * <p>NOTE(review): this has the shape of a Cauchy log-density, assuming {@code NEG_LOG_PI}
 * (declared outside this block) equals -ln(pi) - confirm.
 *
 * @param x        placeholder for the value at which the expression is evaluated
 * @param location placeholder subtracted from {@code x}
 * @param scale    placeholder that divides the centred value and whose log is negated
 * @return the vertex at the root of the expression described above
 */
public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex location, DoublePlaceholderVertex scale) {
    // -log(scale) + NEG_LOG_PI
    final DoubleVertex normalisingTerm = scale.log().unaryMinus().plus(NEG_LOG_PI);

    // 1 + ((x - location) / scale)^2
    final DoubleVertex standardised = x.minus(location).div(scale);
    final DoubleVertex onePlusSquare = standardised.pow(2.).plus(1.);

    final DoubleVertex logOfOnePlusSquare = onePlusSquare.log();
    return normalisingTerm.minus(logOfOnePlusSquare);
}
/**
 * Checks the derivative of {@code log((vA + vB) * (vA / vB))} with respect to the
 * {@code vA} and {@code vB} fixtures via {@code assertDiffIsCorrect}.
 */
@Test
public void diffOverPlusDivideMultiplyLogCombination() {
    DoubleVertex sum = vA.plus(vB);
    DoubleVertex quotient = vA.divideBy(vB);
    DoubleVertex product = sum.multiply(quotient);

    assertDiffIsCorrect(vA, vB, product.log());
}
/**
 * Reverse-mode differentiates {@code Y = log(2 * A * B)} with respect to both inputs and
 * compares the results with the reciprocals of the input values (d/dA log(2AB) = 1/A and
 * d/dB log(2AB) = 1/B), to a tolerance of 1e-5.
 */
@Test
public void canReverseAutoDiffOfMultiplicationAndLogWithSingleOutputWithRespectToMany() {
    DoubleVertex a = new GaussianVertex(0, 1);
    a.setValue(3.0);
    DoubleVertex b = new GaussianVertex(0, 1);
    b.setValue(5.0);

    DoubleVertex product = a.times(b);
    DoubleVertex doubledProduct = product.times(2);
    DoubleVertex logOutput = doubledProduct.log();

    PartialsOf partials = Differentiator.reverseModeAutoDiff(logOutput, ImmutableSet.of(a, b));
    DoubleTensor wrtA = partials.withRespectTo(a);
    DoubleTensor wrtB = partials.withRespectTo(b);

    double expectedWrtA = a.getValue().reciprocal().scalar();
    assertEquals(expectedWrtA, wrtA.scalar(), 1e-5);

    double expectedWrtB = b.getValue().reciprocal().scalar();
    assertEquals(expectedWrtB, wrtB.scalar(), 1e-5);
}
/**
 * Builds a 2x2 graph from a long chain of operations ending in a sum {@code H}, places an
 * observed Gaussian {@code J} on top of it, and checks that the joint log-prob gradient
 * reported by {@code LogProbGradientCalculator} for each latent equals
 * {@code dLogProb(J)/dH * dH/dLatent}, where {@code dH/dLatent} comes from reverse-mode
 * auto-diff of {@code H}.
 */
@Test
public void doesMatchReverseAutoDiffWithManyOps() {
    long[] shape = new long[]{2, 2};

    DoubleVertex a = new GaussianVertex(shape, 0, 1);
    a.setValue(DoubleTensor.linspace(0.1, 2, 4).reshape(shape));
    DoubleVertex b = new GaussianVertex(shape, 0, 1);
    b.setValue(DoubleTensor.linspace(0.2, 1, 4).reshape(shape));

    // Intermediate graph; statement order matches the original construction order.
    DoubleVertex d = a.atan2(b).sigmoid().times(b);
    DoubleVertex c = a.sin().cos().div(d);
    DoubleVertex e = c.times(d).pow(a).acos();
    DoubleVertex g = e.log().tan().asin().atan();
    DoubleVertex f = d.plus(b).exp();
    SumVertex h = g.plus(f).sum();

    GaussianVertex observed = new GaussianVertex(h, 1);
    observed.observe(0.5);

    // Gradient under test.
    LogProbGradientCalculator gradientCalculator =
        new LogProbGradientCalculator(ImmutableList.of(observed), ImmutableList.of(a, b));
    Map<VertexId, DoubleTensor> jointGradient = gradientCalculator.getJointLogProbGradientWrtLatents();
    DoubleTensor actualWrtA = jointGradient.get(a.getId());
    DoubleTensor actualWrtB = jointGradient.get(b.getId());

    // Expected gradient assembled by hand via the chain rule.
    PartialsOf dHReverse = Differentiator.reverseModeAutoDiff(h, a, b);
    DoubleTensor dHdA = dHReverse.withRespectTo(a);
    DoubleTensor dHdB = dHReverse.withRespectTo(b);
    DoubleTensor dLogProbWrtH = observed.dLogProbAtValue(h).get(h);

    DoubleTensor expectedWrtA = dLogProbWrtH.times(dHdA);
    DoubleTensor expectedWrtB = dLogProbWrtH.times(dHdB);

    assertEquals(expectedWrtA, actualWrtA);
    assertEquals(expectedWrtB, actualWrtB);
}
/**
 * Creates a network whose output is an element-wise conditional over two 2x2 branches:
 * {@code log(sin(A) * A * B)} where the constant predicate is true and {@code A * B + B}
 * where it is false. The first input is labelled {@code INPUT_NAME} and the conditional
 * output is labelled {@code OUTPUT_NAME}.
 *
 * @return a {@code BayesianNetwork} built from the graph connected to the output vertex
 */
private BayesianNetwork createComplexNet() {
    DoubleVertex input = new GaussianVertex(new long[]{2, 2}, 0, 1).setLabel(INPUT_NAME);
    input.setValue(DoubleTensor.create(3.0, new long[]{2, 2}));

    DoubleVertex other = new GaussianVertex(new long[]{2, 2}, 0, 1);
    other.setValue(DoubleTensor.create(5.0, new long[]{2, 2}));

    DoubleVertex product = input.times(other);
    DoubleVertex sine = input.sin();
    DoubleVertex sineTimesProduct = sine.times(product);
    DoubleVertex whenTrue = sineTimesProduct.log();
    DoubleVertex whenFalse = product.plus(other);

    BooleanVertex predicate = ConstantVertex.of(
        BooleanTensor.create(new boolean[]{true, false, true, false}, new long[]{2, 2})
    );

    DoubleVertex output = If.isTrue(predicate)
        .then(whenTrue)
        .orElse(whenFalse)
        .setLabel(OUTPUT_NAME);

    return new BayesianNetwork(output.getConnectedGraph());
}
/**
 * Builds a graph over three 2x2x2 Gaussian inputs and hands the inputs and the final
 * multiplication vertex to {@code finiteDifferenceMatchesForwardAndReverseModeGradient}
 * with the arguments {@code 0.001} and {@code 1e-3}.
 */
@Test
public void reverseAutoDiffOfRank3MatchesForwardWithSingleOutputWithRespectToMany() {
    long[] shape = new long[]{2, 2, 2};
    // Number of elements implied by the shape; used to size each linspace before reshaping.
    int elementCount = (int) TensorShape.getLength(shape);

    GaussianVertex a = new GaussianVertex(shape, 0, 1);
    a.setValue(DoubleTensor.linspace(0.1, 2, elementCount).reshape(shape));

    GaussianVertex b = new GaussianVertex(shape, 0, 1);
    b.setValue(DoubleTensor.linspace(0.2, 1, elementCount).reshape(shape));

    GaussianVertex c = new GaussianVertex(shape, 0, 1);
    c.setValue(DoubleTensor.linspace(0.2, 0.8, elementCount).reshape(shape));

    // Intermediate graph; statement order matches the original construction order.
    DoubleVertex d = a.atan2(b).sigmoid().times(b);
    DoubleVertex j = a.sin().cos().div(d);
    DoubleVertex e = j.times(d).pow(a).acos();
    DoubleVertex g = e.log().tan().atan();
    DoubleVertex f = d.plus(b).exp();

    MultiplicationVertex output = g.plus(f).sum().times(a).sum().times(c);

    finiteDifferenceMatchesForwardAndReverseModeGradient(ImmutableList.of(a, b, c), output, 0.001, 1e-3);
}
/**
 * Builds a 2x2 graph ending in a sum {@code H} and checks that the partials of {@code H}
 * with respect to each input are equal whether obtained from one reverse-mode pass or from
 * a forward-mode pass per input.
 */
@Test
public void reverseAutoDiffMatchesForwardWithSingleOutputWithRespectToMany() {
    long[] shape = new long[]{2, 2};

    GaussianVertex a = new GaussianVertex(shape, 0, 1);
    a.setValue(DoubleTensor.linspace(0.1, 2, 4).reshape(shape));
    GaussianVertex b = new GaussianVertex(shape, 0, 1);
    b.setValue(DoubleTensor.linspace(0.2, 1, 4).reshape(shape));

    // Intermediate graph; statement order matches the original construction order.
    DoubleVertex d = a.atan2(b).sigmoid().times(b);
    DoubleVertex c = a.sin().cos().div(d);
    DoubleVertex e = c.times(d).pow(a).acos();
    DoubleVertex g = e.log().tan().asin().atan();
    DoubleVertex f = d.plus(b).exp();
    SumVertex h = g.plus(f).sum();

    // Reverse mode: a single pass yields the partials for both inputs.
    PartialsOf reversePartials = Differentiator.reverseModeAutoDiff(h, ImmutableSet.of(a, b));
    DoubleTensor reverseWrtA = reversePartials.withRespectTo(a);
    DoubleTensor reverseWrtB = reversePartials.withRespectTo(b);

    // Forward mode: one pass per input.
    DoubleTensor forwardWrtA = Differentiator.forwardModeAutoDiff(a, h).of(h);
    DoubleTensor forwardWrtB = Differentiator.forwardModeAutoDiff(b, h).of(h);

    assertEquals(reverseWrtA, forwardWrtA);
    assertEquals(reverseWrtB, forwardWrtB);
}
/**
 * Reverse-mode differentiates an element-wise conditional
 * {@code H = predicate ? log(sin(A) * A * B) : A * B + B} over 2x2 inputs and compares the
 * partials with hand-derived values:
 * <ul>
 *   <li>dH/dA = 1/A + cos(A)/sin(A) where the predicate is true, B where it is false;</li>
 *   <li>dH/dB = 1/B where the predicate is true, A + 1 where it is false.</li>
 * </ul>
 * Each expected value is reshaped to (1, 4), passed through {@code diag()} and reshaped to
 * (2, 2, 2, 2) before being compared with the computed partial.
 */
@Test
public void canReverseAutoDiffOfMultiplicationLogSinAndSumWithSingleConditionalOutputWithRespectToMany() {
    DoubleVertex a = new GaussianVertex(new long[]{2, 2}, 0, 1);
    a.setValue(DoubleTensor.create(3.0, new long[]{2, 2}));
    DoubleVertex b = new GaussianVertex(new long[]{2, 2}, 0, 1);
    b.setValue(DoubleTensor.create(5.0, new long[]{2, 2}));

    DoubleVertex product = a.times(b);
    DoubleVertex sine = a.sin();
    DoubleVertex sineTimesProduct = sine.times(product);
    DoubleVertex whenTrue = sineTimesProduct.log();
    DoubleVertex whenFalse = product.plus(b);

    BooleanVertex predicate = ConstantVertex.of(
        BooleanTensor.create(new boolean[]{true, false, true, false}, new long[]{2, 2})
    );
    DoubleVertex conditional = If.isTrue(predicate).then(whenTrue).orElse(whenFalse);

    PartialsOf partials = Differentiator.reverseModeAutoDiff(conditional, ImmutableSet.of(a, b));
    DoubleTensor actualWrtA = partials.withRespectTo(a);
    DoubleTensor actualWrtB = partials.withRespectTo(b);

    DoubleTensor trueMask = predicate.getValue().toDoubleMask();
    DoubleTensor falseMask = predicate.getValue().not().toDoubleMask();

    DoubleTensor aValue = a.getValue();
    DoubleTensor bValue = b.getValue();

    // dH/dA: (1/A + cos(A)/sin(A)) on the true branch, B on the false branch.
    DoubleTensor reciprocalOfA = aValue.reciprocal();
    DoubleTensor cosOverSinOfA = aValue.cos().div(aValue.sin());
    DoubleTensor wrtATrueBranch = reciprocalOfA.plus(cosOverSinOfA).times(trueMask);
    DoubleTensor wrtAFalseBranch = bValue.times(falseMask);
    DoubleTensor expectedWrtA = wrtATrueBranch.plus(wrtAFalseBranch)
        .reshape(1, 4)
        .diag()
        .reshape(2, 2, 2, 2);

    // dH/dB: 1/B on the true branch, (A + 1) on the false branch.
    DoubleTensor wrtBTrueBranch = bValue.reciprocal().times(trueMask);
    DoubleTensor wrtBFalseBranch = aValue.plus(1).times(falseMask);
    DoubleTensor expectedWrtB = wrtBTrueBranch.plus(wrtBFalseBranch)
        .reshape(1, 4)
        .diag()
        .reshape(2, 2, 2, 2);

    assertEquals(expectedWrtA, actualWrtA);
    assertEquals(expectedWrtB, actualWrtB);
}