/**
 * Draws {@code sampleCount} posterior samples from {@code model} by building a samples generator
 * for the requested variables and running it.
 *
 * @param model                 a probabilistic model containing latent variables
 * @param variablesToSampleFrom the variables to include in the returned samples
 * @param sampleCount           number of samples to take using the algorithm
 * @return samples for each variable, ordered by MCMC iteration
 */
@Override
public NetworkSamples getPosteriorSamples(final ProbabilisticModel model,
                                          final List<? extends Variable> variablesToSampleFrom,
                                          final int sampleCount) {
    return generatePosteriorSamples(model, variablesToSampleFrom)
        .generate(sampleCount);
}
/**
 * Samples from the posterior of a probabilistic model using the No-U-Turn Sampling (NUTS) algorithm.
 *
 * @param model                 the probabilistic model to sample from
 * @param variablesToSampleFrom the variables inside the probabilistic model to sample from
 * @param sampleCount           the number of samples to generate
 * @return samples taken with NUTS
 */
@Override
public NetworkSamples getPosteriorSamples(final ProbabilisticModel model,
                                          final List<? extends Variable> variablesToSampleFrom,
                                          final int sampleCount) {
    return generatePosteriorSamples(model, variablesToSampleFrom)
        .generate(sampleCount);
}
/**
 * A negative drop count must be rejected with an {@link IllegalArgumentException}.
 */
@Test(expected = IllegalArgumentException.class)
public void doesNotAllowNegativeDropCount() {
    AtomicInteger steps = new AtomicInteger(0);
    AtomicInteger samples = new AtomicInteger(0);
    NetworkSamplesGenerator generator =
        new NetworkSamplesGenerator(new TestSamplingAlgorithm(steps, samples), StatusBar::new);

    generator.dropCount(-10).generate(100);
}
/**
 * A single generate() call should set the status message exactly once and then finish the status bar.
 */
@Test
public void doesUpdateStatusAndFinishStatusOnGeneration() {
    StatusBar mockedStatusBar = mock(StatusBar.class);
    TestSamplingAlgorithm algorithm =
        new TestSamplingAlgorithm(new AtomicInteger(0), new AtomicInteger(0));

    // The supplier always hands back the same mock so its interactions can be verified afterwards.
    NetworkSamplesGenerator unitUnderTest = new NetworkSamplesGenerator(algorithm, () -> mockedStatusBar);
    unitUnderTest.generate(10);

    Mockito.verify(mockedStatusBar, times(1)).setMessage(anyString());
    Mockito.verify(mockedStatusBar).finish();
}
/**
 * Asking to drop more samples (200) than are generated (100) must raise an {@link IllegalArgumentException}.
 */
@Test(expected = IllegalArgumentException.class)
public void doesNotAllowDroppingMoreThanRequesting() {
    AtomicInteger steps = new AtomicInteger(0);
    AtomicInteger samples = new AtomicInteger(0);
    NetworkSamplesGenerator generator =
        new NetworkSamplesGenerator(new TestSamplingAlgorithm(steps, samples), StatusBar::new);

    generator.dropCount(200).generate(100);
}
/**
 * Each generate() call should use its own status bar: both calls must set a message once and finish
 * on the status bar they obtained from the supplier.
 */
@Test
public void doesCreateNewStatusBarOnGenerationFinish() {
    AtomicInteger stepCount = new AtomicInteger(0);
    AtomicInteger sampleCount = new AtomicInteger(0);
    StatusBar firstStatusBar = mock(StatusBar.class);
    StatusBar secondStatusBar = mock(StatusBar.class);
    AtomicInteger statusBarsCreated = new AtomicInteger(0);
    TestSamplingAlgorithm algorithm = new TestSamplingAlgorithm(stepCount, sampleCount);

    // The supplier returns firstStatusBar on its first invocation and secondStatusBar on every later one.
    NetworkSamplesGenerator unitUnderTest = new NetworkSamplesGenerator(
        algorithm,
        () -> statusBarsCreated.getAndIncrement() == 0 ? firstStatusBar : secondStatusBar
    );

    unitUnderTest.generate(10);
    Mockito.verify(firstStatusBar, times(1)).setMessage(anyString());
    Mockito.verify(firstStatusBar).finish();

    unitUnderTest.generate(8);
    Mockito.verify(secondStatusBar, times(1)).setMessage(anyString());
    Mockito.verify(secondStatusBar).finish();
}
/**
 * Runs the Metropolis-Hastings algorithm over the latent variables of the network returned by
 * {@code buildBayesianNetwork()} and saves the resulting samples to {@code results}.
 */
public void run() {
    KeanuProbabilisticModel model = new KeanuProbabilisticModel(buildBayesianNetwork());

    // Primitive int: the previous boxed Integer served no purpose and was unboxed at every use.
    final int numSamples = 500;
    // Drop the first fifth of the generated samples.
    final int dropCount = numSamples / 5;
    // After dropping, keep one sample out of every three.
    final int downSampleInterval = 3;

    results = Keanu.Sampling.MetropolisHastings.withDefaultConfigFor(model)
        .generatePosteriorSamples(model, model.getLatentVariables())
        .dropCount(dropCount)
        .downSampleInterval(downSampleInterval)
        .generate(numSamples);
}
/**
 * Every requested sample must be either a plain step or a collected sample. The number collected
 * is the number of post-drop samples divided by the down-sample interval, rounded up.
 */
@Test
public void dropsAndSamplesExpectedNumberOfStepsOnGeneration() {
    final int totalGenerated = 12;
    final int dropCount = 3;
    final int downSampleInterval = 2;

    TestSamplingAlgorithm algorithm =
        new TestSamplingAlgorithm(new AtomicInteger(0), new AtomicInteger(0));
    NetworkSamplesGenerator unitUnderTest = new NetworkSamplesGenerator(algorithm, StatusBar::new);

    unitUnderTest.dropCount(dropCount).downSampleInterval(downSampleInterval);
    NetworkSamples samples = unitUnderTest.generate(totalGenerated);

    // Integer ceiling division: ceil((totalGenerated - dropCount) / downSampleInterval).
    int remaining = totalGenerated - dropCount;
    int expectedCollected = (remaining + downSampleInterval - 1) / downSampleInterval;

    assertEquals(totalGenerated, algorithm.stepCount.get() + algorithm.sampleCount.get());
    assertEquals(expectedCollected, samples.size());
}
/**
 * Estimates the probability of cheating from an observed total of "yes" answers using a binomial
 * likelihood and Metropolis-Hastings sampling.
 *
 * @param numberOfStudents   number of trials of the binomial answer-total vertex
 * @param numberOfYesAnswers observed value of the answer total
 * @return the average of the sampled values of the cheating-probability vertex
 */
public static double runUsingBinomial(int numberOfStudents, int numberOfYesAnswers) {
    final int numberOfSamples = 100;

    // Uniform(0, 1) prior on the cheating probability p; a "yes" answer then has probability 0.25 + 0.5 * p.
    UniformVertex probabilityOfCheating = new UniformVertex(0.0, 1.0);
    DoubleVertex pYesAnswer = probabilityOfCheating.times(0.5).plus(0.25);

    BinomialVertex answerTotal = new BinomialVertex(pYesAnswer, numberOfStudents);
    answerTotal.observe(numberOfYesAnswers);

    KeanuProbabilisticModel model = new KeanuProbabilisticModel(answerTotal.getConnectedGraph());

    // Drop the first tenth of the samples, then keep one sample per latent variable.
    NetworkSamples networkSamples = Keanu.Sampling.MetropolisHastings.withDefaultConfigFor(model)
        .generatePosteriorSamples(model, singletonList(probabilityOfCheating))
        .dropCount(numberOfSamples / 10)
        .downSampleInterval(model.getLatentVariables().size())
        .generate(numberOfSamples);

    return networkSamples.getDoubleTensorSamples(probabilityOfCheating).getAverages().scalar();
}
.dropCount(numberOfSamples / 2) .downSampleInterval(model.getLatentVariables().size()) .generate(numberOfSamples);
/**
 * A model containing an assertion (gaussian &gt; 1000) should cause a {@code GraphAssertionException}
 * during this test's model construction or sampling.
 */
@Test
public void samplingWithAssertionWorks() {
    thrown.expect(GraphAssertionException.class);

    GaussianVertex gaussian = new GaussianVertex(5, 1);
    ConstantDoubleVertex threshold = new ConstantDoubleVertex(1000);
    gaussian.greaterThan(threshold).assertTrue();

    KeanuProbabilisticModel model = new KeanuProbabilisticModel(gaussian.getConnectedGraph());
    MetropolisHastings
        .withDefaultConfigFor(model)
        .generatePosteriorSamples(model, model.getLatentVariables())
        .generate(10);
}
.dropCount(sampleCount / 2) .downSampleInterval(model.getLatentVariables().size()) .generate(sampleCount);
.dropCount(numSamples / 2) .downSampleInterval(model.getLatentVariables().size()) .generate(numSamples);
/**
 * Dropped samples must not be stored. The result should contain only the post-drop samples thinned
 * by {@code downSampleInterval}, and their mean should be close to the Gaussian's mean of 0.
 */
@Test
public void doesNotStoreSamplesThatWillBeDropped() {
    int sampleCount = 1000;
    int dropCount = 100;
    int downSampleInterval = 2;

    GaussianVertex A = new GaussianVertex(0, 1);
    KeanuProbabilisticModel model = new KeanuProbabilisticModel(A.getConnectedGraph());

    NetworkSamples samples = Keanu.Sampling.MetropolisHastings.withDefaultConfigFor(model)
        .generatePosteriorSamples(model, model.getLatentVariables())
        .dropCount(dropCount)
        .downSampleInterval(downSampleInterval)
        .generate(sampleCount);

    // Round up. The generator keeps ceil(remaining / interval) samples, so floor division would only
    // be correct when (sampleCount - dropCount) is an exact multiple of downSampleInterval.
    int expectedSampleCount = (int) Math.ceil((sampleCount - dropCount) / (double) downSampleInterval);

    assertEquals(expectedSampleCount, samples.size());
    assertEquals(0.0, samples.getDoubleTensorSamples(A).getAverages().scalar(), 0.1);
}