/**
 * Adds a scalar to this vertex.
 *
 * <p>The scalar is wrapped in a {@code ConstantDoubleVertex} and passed to the
 * vertex-accepting {@code plus} overload, whose result is returned unchanged.
 *
 * @param that the scalar operand
 * @return the result of the vertex-accepting {@code plus} overload
 */
@Override
public AdditionVertex plus(double that) {
    final ConstantDoubleVertex constantOperand = new ConstantDoubleVertex(that);
    return plus(constantOperand);
}
/**
 * Builds the vertex for the cubic coefficient from {@code Sw} and {@code Bw}.
 *
 * <p>The graph is {@code (Sw^3 * (Sw + Bw)).reverseDiv(-2)}; assuming
 * {@code reverseDiv(c)} means {@code c / this} (TODO confirm), that is
 * {@code -2 / (Sw^3 * (Sw + Bw))}.
 *
 * @param Sw first input vertex
 * @param Bw second input vertex
 * @return the coefficient vertex
 */
private static DoubleVertex getCubeCoefficientVertex(DoubleVertex Sw, DoubleVertex Bw) {
    final DoubleVertex swCubed = Sw.pow(3);
    final DoubleVertex swPlusBw = Sw.plus(Bw);
    return swCubed.times(swPlusBw).reverseDiv(-2.);
}
/**
 * Builds the vertex for the quadratic coefficient from {@code Sw} and {@code Bw}.
 *
 * <p>The graph is {@code (Sw^2 * (Sw + Bw)).reverseDiv(3)}; assuming
 * {@code reverseDiv(c)} means {@code c / this} (TODO confirm), that is
 * {@code 3 / (Sw^2 * (Sw + Bw))}.
 *
 * @param Sw first input vertex
 * @param Bw second input vertex
 * @return the coefficient vertex
 */
private static DoubleVertex getSquareCoefficientVertex(DoubleVertex Sw, DoubleVertex Bw) {
    final DoubleVertex swSquared = Sw.pow(2);
    final DoubleVertex swPlusBw = Sw.plus(Bw);
    return swSquared.times(swPlusBw).reverseDiv(3.);
}
/**
 * Checks the derivative of {@code vA + vB} via {@code assertDiffIsCorrect}.
 */
@Test
public void diffOverAddition() {
    assertDiffIsCorrect(
        vA,
        vB,
        vA.plus(vB)
    );
}
/**
 * Builds the log-probability graph {@code -(|mu - x| / beta + log(2 * beta))},
 * which has the form of a Laplace log-density with location {@code mu} and
 * scale {@code beta}.
 *
 * @param x    placeholder for the value being evaluated
 * @param mu   placeholder for the location
 * @param beta placeholder for the scale
 * @return the vertex holding the log-probability
 */
public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex mu, DoublePlaceholderVertex beta) {
    // |mu - x| / beta (not yet negated; the sign flip happens at the end).
    final DoubleVertex absoluteDeviation = mu.minus(x).abs();
    final DoubleVertex scaledDeviation = absoluteDeviation.div(beta);

    // log(2 * beta)
    final DoubleVertex twiceBeta = beta.times(2.);
    final DoubleVertex logNormaliser = twiceBeta.log();

    final DoubleVertex negativeLogProb = scaledDeviation.plus(logNormaliser);
    return negativeLogProb.unaryMinus();
}
/**
 * Builds the log-probability graph
 * {@code (alpha - 1) * log(x) + (beta - 1) * log(1 - x) - logB(alpha, beta)},
 * where {@code logB = logGamma(alpha) + logGamma(beta) - logGamma(alpha + beta)}.
 * This has the form of a Beta log-density.
 *
 * @param x     placeholder for the value being evaluated
 * @param alpha placeholder for the first shape parameter
 * @param beta  placeholder for the second shape parameter
 * @return the vertex holding the log-probability
 */
public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex alpha, DoublePlaceholderVertex beta) {
    final DoubleVertex logGammaOfAlpha = alpha.logGamma();
    final DoubleVertex logGammaOfBeta = beta.logGamma();
    final DoubleVertex logGammaOfSum = (alpha.plus(beta)).logGamma();

    // (alpha - 1) * log(x)
    final DoubleVertex logX = x.log();
    final DoubleVertex alphaTerm = logX.times(alpha.minus(1.));

    // (beta - 1) * log(1 - x)
    final DoubleVertex logOneMinusX = x.unaryMinus().plus(1.).log();
    final DoubleVertex betaTerm = logOneMinusX.times(beta.minus(1.));

    // Log of the normalising Beta function.
    final DoubleVertex logBetaFunction = logGammaOfAlpha.plus(logGammaOfBeta).minus(logGammaOfSum);

    final DoubleVertex unnormalised = alphaTerm.plus(betaTerm);
    return unnormalised.minus(logBetaFunction);
}
/**
 * Builds the log-probability graph
 * {@code -((x - mu)^2 / (2 * sigma^2) + log(sigma) + LN_SQRT_2PI)},
 * which has the form of a Gaussian log-density.
 *
 * <p>{@code LN_SQRT_2PI} is defined outside this method (presumably
 * {@code ln(sqrt(2 * pi))} - confirm at its declaration).
 *
 * @param x     placeholder for the value being evaluated
 * @param mu    placeholder for the mean
 * @param sigma placeholder for the standard deviation
 * @return the vertex holding the log-probability
 */
public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex mu, DoublePlaceholderVertex sigma) {
    final DoubleVertex logSigma = sigma.log();

    final DoubleVertex squaredDeviation = x.minus(mu).pow(2.);
    final DoubleVertex twiceVariance = sigma.pow(2.).times(2.);
    final DoubleVertex quadraticTerm = squaredDeviation.div(twiceVariance);

    final DoubleVertex negativeLogProb = quadraticTerm.plus(logSigma).plus(LN_SQRT_2PI);
    return negativeLogProb.unaryMinus();
}
/**
 * Builds the log-probability graph
 * {@code z + log(1 / s) - 2 * log(1 + exp(z))} with {@code z = (x - mu) / s},
 * which has the form of a Logistic log-density.
 *
 * <p>{@code log(1 / s)} is built as {@code s.reverseDiv(1.).log()}; this assumes
 * {@code reverseDiv(c)} means {@code c / this} (TODO confirm).
 *
 * @param x  placeholder for the value being evaluated
 * @param mu placeholder for the location
 * @param s  placeholder for the scale
 * @return the vertex holding the log-probability
 */
public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex mu, DoublePlaceholderVertex s) {
    final DoubleVertex standardised = x.minus(mu).div(s);
    final DoubleVertex logInverseScale = s.reverseDiv(1.).log();

    final DoubleVertex leadingTerms = standardised.plus(logInverseScale);

    // 2 * log(1 + exp(z))
    final DoubleVertex logOnePlusExp = standardised.exp().plus(1.).log();
    final DoubleVertex twiceLogOnePlusExp = logOnePlusExp.times(2.);

    return leadingTerms.minus(twiceLogOnePlusExp);
}
/**
 * Entry point holding the "Overview" documentation snippet, delimited by the
 * {@code %%SNIPPET_START%%} / {@code %%SNIPPET_END%%} marker comments.
 *
 * <p>The snippet creates a {@code UniformVertex x} from the arguments (1, 2),
 * derives {@code y = x.times(2)}, creates a second {@code UniformVertex} from
 * {@code new long[]{1, 2}}, {@code y} and {@code y.plus(0.5)}, and observes the
 * values {4.0, 4.49} on it. Presumably the {@code long[]} is the tensor shape
 * and the other two arguments are the lower and upper bounds - verify against
 * the {@code UniformVertex} constructors.
 *
 * <p>NOTE(review): the marker comments are presumably read by documentation
 * tooling, so their text is left untouched. In this single-line form everything
 * after the first {@code //} is a line comment; the snippet body presumably
 * belongs on separate lines - confirm the formatting against the original file.
 * The final closing brace on the line closes an enclosing type whose header is
 * not visible here.
 *
 * @param args command-line arguments; not referenced by the snippet
 */
public static void main(String[] args) { //%%SNIPPET_START%% Overview DoubleVertex x = new UniformVertex(1, 2); DoubleVertex y = x.times(2); DoubleVertex observedY = new UniformVertex(new long[]{1, 2}, y, y.plus(0.5)); observedY.observe(new double[]{4.0, 4.49}); //%%SNIPPET_END%% Overview } }
/**
 * Checks the derivative of {@code (vA + vB) * (vA - vB)} via
 * {@code assertDiffIsCorrect}.
 */
@Test
public void diffOverPlusMinusMultiplyCombination() {
    final DoubleVertex sum = vA.plus(vB);
    final DoubleVertex difference = vA.minus(vB);
    final MultiplicationVertex product = sum.multiply(difference);

    assertDiffIsCorrect(vA, vB, product);
}
/**
 * Builds the log-probability graph
 * {@code -((log(x) - mu)^2 / (2 * sigma^2) + log(sigma * x) + LN_SQRT_2PI)},
 * which has the form of a log-normal log-density.
 *
 * <p>{@code LN_SQRT_2PI} is defined outside this method (presumably
 * {@code ln(sqrt(2 * pi))} - confirm at its declaration).
 *
 * @param x     placeholder for the value being evaluated
 * @param mu    placeholder for the mean of {@code log(x)}
 * @param sigma placeholder for the standard deviation of {@code log(x)}
 * @return the vertex holding the log-probability
 */
public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex mu, DoublePlaceholderVertex sigma) {
    final DoubleVertex logSigmaTimesX = sigma.times(x).log();

    final DoubleVertex squaredLogDeviation = x.log().minus(mu).pow(2.);
    final DoubleVertex twiceVariance = sigma.pow(2.).times(2.);
    final DoubleVertex quadraticTerm = squaredLogDeviation.div(twiceVariance);

    final DoubleVertex negativeLogProb = quadraticTerm.plus(logSigmaTimesX).plus(LN_SQRT_2PI);
    return negativeLogProb.unaryMinus();
}
/**
 * Builds the log-probability graph
 * {@code alpha * log(beta) + (-alpha - 1) * log(x) - logGamma(alpha) - beta / x},
 * which has the form of an inverse-gamma log-density.
 *
 * @param x     placeholder for the value being evaluated
 * @param alpha placeholder for the shape parameter
 * @param beta  placeholder for the scale parameter
 * @return the vertex holding the log-probability
 */
public static DoubleVertex logProbOutput(DoublePlaceholderVertex x, DoublePlaceholderVertex alpha, DoublePlaceholderVertex beta) {
    // alpha * log(beta)
    final DoubleVertex logBeta = beta.log();
    final DoubleVertex alphaTimesLogBeta = alpha.times(logBeta);

    // (-alpha - 1) * log(x)
    final DoubleVertex logX = x.log();
    final DoubleVertex negAlphaMinusOne = alpha.unaryMinus().minus(1.);
    final DoubleVertex powerTerm = logX.times(negAlphaMinusOne);

    final DoubleVertex logGammaOfAlpha = alpha.logGamma();

    final DoubleVertex withoutRatio = alphaTimesLogBeta.plus(powerTerm).minus(logGammaOfAlpha);
    final DoubleVertex betaOverX = beta.div(x);
    return withoutRatio.minus(betaOverX);
}
/**
 * Builds the vertex graph used by the tests and records all of its vertices.
 *
 * <p>Graph: {@code C = log(A)}, {@code D = A * B}, {@code E = C + D},
 * {@code F = D - E}, {@code G = log(F)}, with {@code A} and {@code B} both
 * created from the constant 2.0.
 */
@Before
public void setup() {
    // Roots.
    A = ConstantVertex.of(2.0);
    B = ConstantVertex.of(2.0);

    // Derived vertices, each built only from vertices created above it.
    C = A.log();
    D = A.multiply(B);
    E = C.plus(D);
    F = D.minus(E);
    G = F.log();

    allVertices = Arrays.asList(
        A, B, C, D, E, F, G
    );
}
/**
 * Checks the derivative of {@code log((vA + vB) * (vA / vB))} via
 * {@code assertDiffIsCorrect}.
 */
@Test
public void diffOverPlusDivideMultiplyLogCombination() {
    final DoubleVertex sum = vA.plus(vB);
    final DoubleVertex quotient = vA.divideBy(vB);
    final DoubleVertex product = sum.multiply(quotient);

    assertDiffIsCorrect(vA, vB, product.log());
}
/**
 * Builds the test model: two {@code GaussianVertex(20.0, 1.0)} vertices
 * {@code A} and {@code B}, both set to 20.0, and a Gaussian over their sum
 * (second argument 1.0) that is observed at 46.0.
 *
 * <p>The three vertices are wrapped in a {@code BayesianNetwork}, on which
 * {@code probeForNonZeroProbability(100)} is then called.
 */
public SumGaussianTestCase() {
    A = new GaussianVertex(20.0, 1.0);
    B = new GaussianVertex(20.0, 1.0);

    A.setValue(20.0);
    B.setValue(20.0);

    final DoubleVertex sumOfAAndB = A.plus(B);
    DoubleVertex observedSum = new GaussianVertex(sumOfAAndB, 1.0);
    observedSum.observe(46.0);

    model = new BayesianNetwork(Arrays.asList(A, B, observedSum));
    model.probeForNonZeroProbability(100);
}
/**
 * Prepares the fixture: a {@code KeanuRandom} constructed with 1 (presumably
 * the seed - confirm), Gaussians {@code A} (5, 1) and {@code B} (2, 1) set to
 * 5.0 and 2.0, their sum {@code C}, and a Gaussian {@code D} over {@code C}
 * (second argument 1) observed at 7.5.
 */
@Before
public void setup() {
    random = new KeanuRandom(1);

    // Latent A, pinned to 5.0.
    A = new GaussianVertex(5, 1);
    A.setValue(5.0);

    // Latent B, pinned to 2.0.
    B = new GaussianVertex(2, 1);
    B.setValue(2.0);

    // Observed vertex over the sum.
    C = A.plus(B);
    D = new GaussianVertex(C, 1);
    D.observe(7.5);
}
/**
 * Builds the log-probability graph
 * {@code -0.5 * (q + k * LOG_2_PI + log(det(covariance)))}, where {@code k} is
 * {@code numberOfDimensions(mu.getShape())} and {@code q} is the quadratic form
 * of {@code (x - mu)} with the inverse covariance. This has the form of a
 * multivariate Gaussian log-density.
 *
 * <p>When {@code isUnivariate(k)} holds, {@code q} is built with {@code times};
 * otherwise it is built with {@code matrixMultiply}. Both the quadratic form
 * and the final result are reduced with {@code slice(0, 0)}.
 *
 * @param x          placeholder for the value being evaluated
 * @param mu         placeholder for the mean
 * @param covariance placeholder for the covariance matrix
 * @return the vertex holding the log-probability
 */
public static DoubleVertex logProbGraph(DoublePlaceholderVertex x, DoublePlaceholderVertex mu, DoublePlaceholderVertex covariance) {
    final long dimensionCount = numberOfDimensions(mu.getShape());
    final double dimensionsTimesLog2Pi = dimensionCount * LOG_2_PI;
    final DoubleVertex logDeterminant = covariance.matrixDeterminant().log();

    DoubleVertex deviation = x.minus(mu);
    DoubleVertex deviationTransposed = deviation.permute(1, 0);
    DoubleVertex inverseCovariance = covariance.matrixInverse();

    DoubleVertex quadraticForm;
    if (isUnivariate(dimensionCount)) {
        quadraticForm = inverseCovariance.times(deviation).times(deviationTransposed).slice(0, 0);
    } else {
        quadraticForm = deviationTransposed.matrixMultiply(inverseCovariance.matrixMultiply(deviation)).slice(0, 0);
    }

    return quadraticForm
        .plus(dimensionsTimesLog2Pi)
        .plus(logDeterminant)
        .times(-0.5)
        .slice(0, 0);
}
/**
 * Verifies that the default MCMC configuration is a {@code NUTS} instance for a
 * model whose latent vertices also feed {@code greaterThan(...).assertTrue()}
 * vertices.
 */
@Test
public void canUseGradientBasedSamplingWithAssertVertex() {
    final DoubleVertex latentA = new GaussianVertex(20.0, 1.0);
    final DoubleVertex latentB = new GaussianVertex(20.0, 1.0);
    latentA.setValue(21.5);
    latentB.setAndCascade(21.5);

    // Each latent gets its own "greater than 20" assertion vertex.
    final ConstantDoubleVertex thresholdForA = new ConstantDoubleVertex(20);
    latentA.greaterThan(thresholdForA).assertTrue();
    final ConstantDoubleVertex thresholdForB = new ConstantDoubleVertex(20);
    latentB.greaterThan(thresholdForB).assertTrue();

    final DoubleVertex observedSum = new GaussianVertex(latentA.plus(latentB), 1.0);
    observedSum.observe(46.0);

    final KeanuProbabilisticModel model =
        new KeanuProbabilisticModel(Arrays.asList(latentA, latentB, observedSum));
    final PosteriorSamplingAlgorithm sampler = Keanu.Sampling.MCMC.withDefaultConfigFor(model);

    assertThat(sampler, instanceOf(NUTS.class));
}
/**
 * Differentiates {@code ((a * [1,2,3,4] + [2,4,6,8]) * [2,4,6,8] + 5) * 2} with
 * respect to {@code a} using reverse-mode auto-diff and expects
 * {@code [4, 16, 36, 64]}, i.e. {@code 2 * [1,2,3,4] * [2,4,6,8]} element-wise,
 * with matching shape.
 */
@Test
public void diffWrtScalarOverMultipleMultiplies() {
    final DoubleVertex scalarInput = new UniformVertex(0, 1);
    scalarInput.setValue(2);

    final double[] firstFactors = {1, 2, 3, 4};
    final DoubleVertex firstProduct = scalarInput.times(ConstantVertex.of(firstFactors));

    final double[] offsets = {2, 4, 6, 8};
    final DoubleVertex shifted = firstProduct.plus(ConstantVertex.of(offsets));

    // Separate array on purpose: the original built a fresh constant here too.
    final double[] secondFactors = {2, 4, 6, 8};
    final DoubleVertex secondProduct = shifted.times(ConstantVertex.of(secondFactors));

    final MultiplicationVertex output = secondProduct.plus(5).times(2);

    final DoubleTensor actualGradient =
        Differentiator.reverseModeAutoDiff(output, scalarInput).withRespectTo(scalarInput);
    final DoubleTensor expectedGradient = DoubleTensor.create(4, 16, 36, 64);

    final double[] expectedValues = expectedGradient.asFlatDoubleArray();
    final double[] actualValues = actualGradient.asFlatDoubleArray();
    assertArrayEquals(expectedValues, actualValues, 0.0);
    assertArrayEquals(expectedGradient.getShape(), actualGradient.getShape());
}
/**
 * Saves a small graph through {@code DotSaver.save(outputWriter, true)} and
 * compares the written text with the contents of
 * {@code OUTPUT_WITH_VALUES_FILENAME}.
 *
 * <p>The graph is an observed gamma vertex (observed at 2.5) times the constant
 * 4, labelled "Gamma Multiplied", plus an unobserved Gaussian vertex.
 *
 * @throws IOException declared by the original signature (presumably raised by
 *                     the save or file-read helpers)
 */
@Test
public void valuesAreBeingWrittenOut() throws IOException {
    final DoubleVertex latentGaussian = new GaussianVertex(0, 1);
    final DoubleVertex observedGamma = new GammaVertex(2, 3);
    observedGamma.observe(2.5);

    final ConstantDoubleVertex four = new ConstantDoubleVertex(4);
    final DoubleVertex scaledGamma = observedGamma.times(four);
    final Vertex result = scaledGamma.plus(latentGaussian);
    scaledGamma.setLabel("Gamma Multiplied");

    final BayesianNetwork network = new BayesianNetwork(result.getConnectedGraph());
    final DotSaver saver = new DotSaver(network);
    saver.save(outputWriter, true);

    final String expectedDot = readFileToString(OUTPUT_WITH_VALUES_FILENAME);
    checkDotFilesMatch(outputWriter.toString(), expectedDot);
}