/**
 * Builds a network of two Gaussian latents, each guarded by a "greater than 20"
 * assert vertex, plus an observed Gaussian centred on their sum, and checks that
 * the default MCMC configuration chosen for that network is {@link NUTS}.
 */
@Test
public void canUseGradientBasedSamplingWithAssertVertex() {
    DoubleVertex firstLatent = new GaussianVertex(20.0, 1.0);
    DoubleVertex secondLatent = new GaussianVertex(20.0, 1.0);

    // Both latents start at 21.5 so each "greater than 20" assertion holds.
    firstLatent.setValue(21.5);
    secondLatent.setAndCascade(21.5);

    // Each latent is compared against its own constant vertex, as separate graph nodes.
    ConstantDoubleVertex firstLowerBound = new ConstantDoubleVertex(20);
    firstLatent.greaterThan(firstLowerBound).assertTrue();
    ConstantDoubleVertex secondLowerBound = new ConstantDoubleVertex(20);
    secondLatent.greaterThan(secondLowerBound).assertTrue();

    DoubleVertex observedSum = new GaussianVertex(firstLatent.plus(secondLatent), 1.0);
    observedSum.observe(46.0);

    KeanuProbabilisticModel model =
        new KeanuProbabilisticModel(Arrays.asList(firstLatent, secondLatent, observedSum));
    PosteriorSamplingAlgorithm chosenAlgorithm = Keanu.Sampling.MCMC.withDefaultConfigFor(model);

    assertThat(chosenAlgorithm, instanceOf(NUTS.class));
}
/**
 * Checks that the default MCMC configuration is {@link MetropolisHastings} when a
 * {@link FloorVertex} sits between two Gaussian vertices in the network.
 */
@Test
public void checkMHIsRunForNonDifferentiableNetwork() {
    GaussianVertex upstreamLatent = new GaussianVertex(5.0, 1.0);
    FloorVertex flooredLatent = new FloorVertex(upstreamLatent);
    GaussianVertex downstreamLatent = new GaussianVertex(flooredLatent, 1.0);

    KeanuProbabilisticModel network =
        new KeanuProbabilisticModel(downstreamLatent.getConnectedGraph());
    PosteriorSamplingAlgorithm chosenAlgorithm = Keanu.Sampling.MCMC.withDefaultConfigFor(network);

    boolean choseMetropolisHastings = chosenAlgorithm instanceof MetropolisHastings;
    assertTrue(choseMetropolisHastings);
}
/**
 * Chooses a sampling algorithm for the given network, using the random source
 * returned by {@code KeanuRandom.getDefaultRandom()}. Delegates to the
 * two-argument {@code withDefaultConfigFor(model, random)} overload.
 *
 * @param model network for which to choose sampling algorithm.
 * @return recommended sampling algorithm for this network.
 */
public PosteriorSamplingAlgorithm withDefaultConfigFor(KeanuProbabilisticModel model) {
    KeanuRandom defaultRandom = KeanuRandom.getDefaultRandom();
    return withDefaultConfigFor(model, defaultRandom);
}
/**
 * Checks that the default MCMC configuration is {@link NUTS} for a network made of
 * two chained Gaussian vertices.
 */
@Test
public void checkNUTSIsRunForDifferentiableNetwork() {
    GaussianVertex parentLatent = new GaussianVertex(5.0, 1.0);
    GaussianVertex childLatent = new GaussianVertex(parentLatent, 1.0);

    KeanuProbabilisticModel network =
        new KeanuProbabilisticModel(childLatent.getConnectedGraph());
    PosteriorSamplingAlgorithm chosenAlgorithm = Keanu.Sampling.MCMC.withDefaultConfigFor(network);

    boolean choseNuts = chosenAlgorithm instanceof NUTS;
    assertTrue(choseNuts);
}