/**
 * Runs the MetropolisHastings algorithm over the latent variables of the model built by
 * {@code buildBayesianNetwork()} and saves the resulting samples to {@code results}.
 * <p>
 * A fifth of the requested sample count is dropped and the remainder is down-sampled by a fixed interval.
 */
public void run() {
    KeanuProbabilisticModel model = new KeanuProbabilisticModel(buildBayesianNetwork());

    // Primitive int: there is no reason to box these counts.
    final int numSamples = 500;
    final int dropCount = numSamples / 5;
    final int downSampleInterval = 3;

    results = Keanu.Sampling.MetropolisHastings.withDefaultConfigFor(model)
        .generatePosteriorSamples(model, model.getLatentVariables())
        .dropCount(dropCount)
        .downSampleInterval(downSampleInterval)
        .generate(numSamples);
}
/**
 * A negative drop count is rejected with an {@link IllegalArgumentException}.
 */
@Test(expected = IllegalArgumentException.class)
public void doesNotAllowNegativeDropCount() {
    final int negativeDrop = -10;
    final int requested = 100;

    TestSamplingAlgorithm testAlgorithm =
        new TestSamplingAlgorithm(new AtomicInteger(0), new AtomicInteger(0));
    NetworkSamplesGenerator generator = new NetworkSamplesGenerator(testAlgorithm, StatusBar::new);

    generator
        .dropCount(negativeDrop)
        .generate(requested);
}
/**
 * Creates a lazily evaluated, infinite stream of samples.
 * <p>
 * The first {@code dropCount} samples are dropped eagerly, when this method is called. Down-sampling is
 * applied inside the stream: every element is produced by stepping the algorithm
 * {@code downSampleInterval - 1} times and then taking one sample. The stream is therefore already the
 * final result after dropping and down-sampling.
 * <p>
 * Because the stream is infinite, callers must bound it (e.g. with {@link Stream#limit(long)}). They should
 * also close it (e.g. with try-with-resources), because the status bar is only finished by the stream's
 * close handler.
 *
 * @return A stream of samples starting after dropping, already down-sampled
 */
public Stream<NetworkSample> stream() {
    StatusBar statusBar = statusBarSupplier.get();
    dropSamples(dropCount, statusBar);

    final AtomicInteger sampleNumber = new AtomicInteger(0);
    return Stream.generate(() -> {
        // Capture this element's number once, so the message always matches the sample being produced.
        int currentSampleNumber = sampleNumber.incrementAndGet();
        for (int i = 0; i < downSampleInterval - 1; i++) {
            algorithm.step();
        }
        NetworkSample sample = algorithm.sample();
        statusBar.setMessage(String.format("Sample #%,d completed", currentSampleNumber));
        return sample;
    }).onClose(statusBar::finish);
}
/**
 * A down-sample interval of zero is rejected with an {@link IllegalArgumentException} by the time a
 * stream is requested.
 */
@Test(expected = IllegalArgumentException.class)
public void doesNotAllowZeroDownSample() {
    final int zeroInterval = 0;

    TestSamplingAlgorithm testAlgorithm =
        new TestSamplingAlgorithm(new AtomicInteger(0), new AtomicInteger(0));
    NetworkSamplesGenerator generator = new NetworkSamplesGenerator(testAlgorithm, StatusBar::new);

    generator
        .downSampleInterval(zeroInterval)
        .stream();
}
/**
 * {@code generate(n)} performs exactly {@code n} algorithm calls (steps plus samples), dropped ones
 * included, and keeps every {@code downSampleInterval}-th of the non-dropped iterations.
 */
@Test
public void dropsAndSamplesExpectedNumberOfStepsOnGeneration() {
    final int totalGenerated = 12;
    final int dropCount = 3;
    final int downSampleInterval = 2;

    AtomicInteger stepCount = new AtomicInteger(0);
    AtomicInteger sampleCount = new AtomicInteger(0);
    TestSamplingAlgorithm algorithm = new TestSamplingAlgorithm(stepCount, sampleCount);
    NetworkSamplesGenerator unitUnderTest = new NetworkSamplesGenerator(algorithm, StatusBar::new);
    unitUnderTest.dropCount(dropCount).downSampleInterval(downSampleInterval);

    NetworkSamples samples = unitUnderTest.generate(totalGenerated);

    int algorithmCalls = algorithm.stepCount.get() + algorithm.sampleCount.get();
    int keptIterations = totalGenerated - dropCount;
    int expectedCollected = (int) Math.ceil(keptIterations / (double) downSampleInterval);

    assertEquals(totalGenerated, algorithmCalls);
    assertEquals(expectedCollected, samples.size());
}
@Test public void streamsExpectedNumberOfSamples() { AtomicInteger stepCount = new AtomicInteger(0); AtomicInteger sampleCount = new AtomicInteger(0); TestSamplingAlgorithm algorithm = new TestSamplingAlgorithm(stepCount, sampleCount); NetworkSamplesGenerator unitUnderTest = new NetworkSamplesGenerator(algorithm, StatusBar::new); int totalCollected = 5; int dropCount = 3; int downSampleInterval = 2; unitUnderTest.dropCount(dropCount).downSampleInterval(downSampleInterval); unitUnderTest.stream() .limit(totalCollected) .collect(Collectors.toList()); //expected step + sample count differs from generate case due to different behaviour int expectedTotal = dropCount + totalCollected * downSampleInterval; assertEquals(expectedTotal, algorithm.stepCount.get() + algorithm.sampleCount.get()); assertEquals(totalCollected, algorithm.sampleCount.get()); }
/**
 * A single {@code generate} call sets the status message once and finishes the status bar.
 */
@Test
public void doesUpdateStatusAndFinishStatusOnGeneration() {
    StatusBar mockedStatusBar = mock(StatusBar.class);
    TestSamplingAlgorithm testAlgorithm =
        new TestSamplingAlgorithm(new AtomicInteger(0), new AtomicInteger(0));
    NetworkSamplesGenerator generator = new NetworkSamplesGenerator(testAlgorithm, () -> mockedStatusBar);

    generator.generate(10);

    Mockito.verify(mockedStatusBar, times(1)).setMessage(anyString());
    Mockito.verify(mockedStatusBar).finish();
}
/**
 * Streaming posterior samples of an unobserved standard Gaussian yields a mean close to zero.
 */
@Test
public void canStreamSamples() {
    int sampleCount = 1000;
    int dropCount = 100;
    int downSampleInterval = 1;
    GaussianVertex A = new GaussianVertex(0, 1);
    KeanuProbabilisticModel model = new KeanuProbabilisticModel(A.getConnectedGraph());
    MetropolisHastings algo = Keanu.Sampling.MetropolisHastings.withDefaultConfigFor(model);

    // The sample stream is infinite and only finishes its status bar when closed. Bound it with limit()
    // and close it with try-with-resources; closing any stage runs the pipeline's close handlers.
    double averageA;
    try (java.util.stream.DoubleStream samplesOfA = algo.generatePosteriorSamples(model, model.getLatentVariables())
            .dropCount(dropCount)
            .downSampleInterval(downSampleInterval)
            .stream()
            .limit(sampleCount)
            .mapToDouble(networkState -> networkState.get(A).scalar())) {
        averageA = samplesOfA.average().getAsDouble();
    }

    assertEquals(0.0, averageA, 0.1);
}
/**
 * Draws a fixed number of posterior samples by running the sampler built by
 * {@code generatePosteriorSamples} to completion.
 *
 * @param model                 a probabilistic model containing latent variables
 * @param variablesToSampleFrom the variables whose values are recorded in the returned samples
 * @param sampleCount           number of samples to take using the algorithm
 * @return Samples for each variable, ordered by MCMC iteration
 */
@Override
public NetworkSamples getPosteriorSamples(ProbabilisticModel model,
                                          List<? extends Variable> variablesToSampleFrom,
                                          int sampleCount) {
    final NetworkSamples samples = generatePosteriorSamples(model, variablesToSampleFrom)
        .generate(sampleCount);
    return samples;
}
/**
 * Each streamed sample updates the status message, and closing the stream finishes the status bar.
 */
@Test
public void doesUpdateProgressAndFinishProgressWhenStreaming() {
    StatusBar progressBar = mock(StatusBar.class);
    TestSamplingAlgorithm testAlgorithm =
        new TestSamplingAlgorithm(new AtomicInteger(0), new AtomicInteger(0));
    NetworkSamplesGenerator generator = new NetworkSamplesGenerator(testAlgorithm, () -> progressBar);

    // try-with-resources closes the stream after consumption, which triggers the finish handler.
    try (Stream<NetworkSample> sampleStream = generator.stream()) {
        sampleStream.limit(10).count();
    }

    Mockito.verify(progressBar, times(10)).setMessage(anyString());
    Mockito.verify(progressBar).finish();
}
/**
 * Builds a NUTS sample generator for the given variables.
 * <p>
 * NUTS needs gradients, so {@code model} must be a {@link ProbabilisticModelWithGradient}; any other model is
 * rejected with an {@link IllegalArgumentException}.
 */
@Override
public NetworkSamplesGenerator generatePosteriorSamples(final ProbabilisticModel model,
                                                        final List<? extends Variable> fromVariables) {
    Preconditions.checkArgument(
        model instanceof ProbabilisticModelWithGradient,
        "NUTS requires a model on which gradients can be calculated."
    );
    final ProbabilisticModelWithGradient gradientModel = (ProbabilisticModelWithGradient) model;
    return new NetworkSamplesGenerator(setupSampler(gradientModel, fromVariables), StatusBar::new);
}
// NOTE(review): judging by its name, this presumably collects one log master-probability value per
// collected sample; it is filled later in the enclosing method (not visible here) — confirm.
List<Double> logOfMasterPForEachSample = new ArrayList<>();
// Advance the algorithm dropCount times before any samples are collected, reporting progress on statusBar.
dropSamples(dropCount, statusBar);
// Status bar components created once dropping is done; presumably they report progress of the sampling
// phase that follows. The remaining-time estimate is sized by totalSampleCount.
PercentageComponent statusPercentage = newPercentageComponentAndAddToStatusBar(statusBar);
RemainingTimeComponent remainingTimeComponent = new RemainingTimeComponent(totalSampleCount);
statusBar.addComponent(remainingTimeComponent);
/**
 * Estimates the probability that a boolean vertex is true by streaming Metropolis-Hastings samples
 * from the vertex's connected graph and counting how often it samples as true.
 *
 * @param vertex      the boolean vertex to estimate
 * @param sampleCount number of streamed samples to inspect; the result is NaN when this is 0
 * @param random      the random source passed to the sampler
 * @return the fraction of inspected samples in which {@code vertex} was true
 */
public static double priorProbabilityTrue(Vertex<? extends Tensor<Boolean>> vertex, int sampleCount, KeanuRandom random) {
    KeanuProbabilisticModel model = new KeanuProbabilisticModel(vertex.getConnectedGraph());

    // The sample stream is infinite and only finishes its status bar when closed, so it is closed via
    // try-with-resources. Closing any stage of the pipeline runs the registered close handlers.
    try (java.util.stream.Stream<?> trueSamples = MetropolisHastings.withDefaultConfigFor(model, random)
            .generatePosteriorSamples(model, Collections.singletonList(vertex)).stream()
            .limit(sampleCount)
            .filter(state -> state.get(vertex).scalar())) {
        long trueCount = trueSamples.count();
        return trueCount / (double) sampleCount;
    }
}
/**
 * Advances the algorithm {@code dropCount} times without keeping any samples.
 * <p>
 * While dropping, a "Dropping samples..." message and a temporary percentage component are shown on the
 * status bar; the component is removed once dropping completes. Does nothing when {@code dropCount} is 0.
 */
private void dropSamples(int dropCount, StatusBar statusBar) {
    if (dropCount == 0) {
        return;
    }

    statusBar.setMessage("Dropping samples...");
    PercentageComponent dropProgress = newPercentageComponentAndAddToStatusBar(statusBar);

    int dropped = 0;
    while (dropped < dropCount) {
        algorithm.step();
        dropped++;
        dropProgress.progress(dropped / (double) dropCount);
    }

    statusBar.removeComponent(dropProgress);
}
/**
 * Each {@code generate} call obtains a fresh status bar from the supplier, updates it once and finishes it.
 */
@Test
public void doesCreateNewStatusBarOnGenerationFinish() {
    StatusBar firstStatusBar = mock(StatusBar.class);
    StatusBar secondStatusBar = mock(StatusBar.class);
    AtomicInteger supplierCalls = new AtomicInteger(0);

    TestSamplingAlgorithm testAlgorithm =
        new TestSamplingAlgorithm(new AtomicInteger(0), new AtomicInteger(0));
    // The first supplier call yields firstStatusBar; every later call yields secondStatusBar.
    NetworkSamplesGenerator generator = new NetworkSamplesGenerator(
        testAlgorithm,
        () -> supplierCalls.getAndIncrement() == 0 ? firstStatusBar : secondStatusBar
    );

    generator.generate(10);
    Mockito.verify(firstStatusBar, times(1)).setMessage(anyString());
    Mockito.verify(firstStatusBar).finish();

    generator.generate(8);
    Mockito.verify(secondStatusBar, times(1)).setMessage(anyString());
    Mockito.verify(secondStatusBar).finish();
}
/**
 * Sample from the posterior of a probabilistic model using the No-U-Turn-Sampling algorithm.
 *
 * @param model                 the probabilistic model to sample from
 * @param variablesToSampleFrom the variables inside the probabilistic model to sample from
 * @param sampleCount           number of samples to take using the algorithm
 * @return Samples taken with NUTS
 */
@Override
public NetworkSamples getPosteriorSamples(final ProbabilisticModel model,
                                          final List<? extends Variable> variablesToSampleFrom,
                                          final int sampleCount) {
    return generatePosteriorSamples(model, variablesToSampleFrom).generate(sampleCount);
}
/**
 * Wraps the sampler built by {@code setupSampler} for the given variables in a sample generator that
 * creates a new {@link StatusBar} on each run.
 *
 * @param model                 the probabilistic model to sample from
 * @param variablesToSampleFrom the variables to include in generated samples
 * @return a generator that can produce or stream posterior samples
 */
@Override
public NetworkSamplesGenerator generatePosteriorSamples(final ProbabilisticModel model,
                                                        final List<? extends Variable> variablesToSampleFrom) {
    return new NetworkSamplesGenerator(
        setupSampler(model, variablesToSampleFrom),
        StatusBar::new
    );
}
/**
 * Estimates the mean of an integer vertex by averaging 2000 streamed Metropolis-Hastings samples
 * drawn with the default random source.
 */
private static double calculateMeanOfVertex(IntegerVertex vertex) {
    final int sampleCount = 2000;
    KeanuProbabilisticModel model = new KeanuProbabilisticModel(vertex.getConnectedGraph());

    // The sample stream only finishes its status bar when closed, so it is closed via try-with-resources.
    try (java.util.stream.Stream<NetworkSample> samples =
             MetropolisHastings.withDefaultConfigFor(model, KeanuRandom.getDefaultRandom())
                 .generatePosteriorSamples(model, Collections.singletonList(vertex))
                 .stream()) {
        return samples
            .limit(sampleCount)
            .collect(Collectors.averagingInt((NetworkSample state) -> state.get(vertex).scalar()));
    }
}
}
/**
 * Infers the probability of cheating from the observed number of "yes" answers, modelled as a binomial
 * whose per-student probability is {@code 0.5 * probabilityOfCheating + 0.25}.
 *
 * @param numberOfStudents   number of students answering
 * @param numberOfYesAnswers observed count of "yes" answers
 * @return the posterior sample average of the probability of cheating
 */
public static double runUsingBinomial(int numberOfStudents, int numberOfYesAnswers) {
    final int numberOfSamples = 100;

    UniformVertex probabilityOfCheating = new UniformVertex(0.0, 1.0);
    DoubleVertex pYesAnswer = probabilityOfCheating.times(0.5).plus(0.25);
    BinomialVertex answerTotal = new BinomialVertex(pYesAnswer, numberOfStudents);
    answerTotal.observe(numberOfYesAnswers);

    KeanuProbabilisticModel model = new KeanuProbabilisticModel(answerTotal.getConnectedGraph());
    NetworkSamplesGenerator generator = Keanu.Sampling.MetropolisHastings
        .withDefaultConfigFor(model)
        .generatePosteriorSamples(model, singletonList(probabilityOfCheating));

    int samplesToDrop = numberOfSamples / 10;
    int interval = model.getLatentVariables().size();
    NetworkSamples posterior = generator
        .dropCount(samplesToDrop)
        .downSampleInterval(interval)
        .generate(numberOfSamples);

    return posterior
        .getDoubleTensorSamples(probabilityOfCheating)
        .getAverages()
        .scalar();
}
/**
 * Asking to drop more samples than are generated is rejected with an {@link IllegalArgumentException}.
 */
@Test(expected = IllegalArgumentException.class)
public void doesNotAllowDroppingMoreThanRequesting() {
    final int toDrop = 200;
    final int requested = 100;

    TestSamplingAlgorithm testAlgorithm =
        new TestSamplingAlgorithm(new AtomicInteger(0), new AtomicInteger(0));
    NetworkSamplesGenerator generator = new NetworkSamplesGenerator(testAlgorithm, StatusBar::new);

    generator
        .dropCount(toDrop)
        .generate(requested);
}