/**
 * Verifies that streaming ten samples reports progress ten times and that closing the
 * stream finishes the progress bar.
 */
@Test
public void doesUpdateProgressAndFinishProgressWhenStreaming() {
    StatusBar progressBar = mock(StatusBar.class);
    TestSamplingAlgorithm algorithm = new TestSamplingAlgorithm(new AtomicInteger(0), new AtomicInteger(0));

    // try-with-resources guarantees close() -- and therefore the finish() call under test --
    // runs even if consuming the stream throws.
    try (Stream<NetworkSample> sampleStream = new NetworkSamplesGenerator(algorithm, () -> progressBar).stream()) {
        // forEach always traverses the pipeline. count() may legally skip traversal when the
        // size is computable from the source (Java 9+), which would hide the side effects
        // this test asserts on.
        sampleStream.limit(10).forEach(sample -> { });
    }

    Mockito.verify(progressBar, times(10)).setMessage(anyString());
    Mockito.verify(progressBar).finish();
}
/**
 * A down-sample interval of zero is invalid: configuring it and requesting a stream must
 * raise {@link IllegalArgumentException}.
 */
@Test(expected = IllegalArgumentException.class)
public void doesNotAllowZeroDownSample() {
    NetworkSamplesGenerator generator = new NetworkSamplesGenerator(
        new TestSamplingAlgorithm(new AtomicInteger(0), new AtomicInteger(0)),
        StatusBar::new
    );

    generator
        .downSampleInterval(0)
        .stream();
}
/**
 * Estimates the probability that {@code vertex} is true.
 *
 * <p>The estimate is the fraction of the first {@code sampleCount} Metropolis-Hastings samples
 * (default configuration, drawn over the vertex's connected graph) in which the vertex's scalar
 * value is {@code true}. NOTE(review): this presumably equals the prior only when nothing in the
 * connected graph is observed -- confirm with callers.
 *
 * @param vertex      the boolean vertex whose probability of being true is estimated
 * @param sampleCount number of samples to draw; {@code 0} yields {@code NaN}
 * @param random      source of randomness for the sampler
 * @return fraction of samples in which {@code vertex} is true
 * @throws IllegalArgumentException if {@code sampleCount} is negative (from {@code Stream.limit})
 */
public static double priorProbabilityTrue(Vertex<? extends Tensor<Boolean>> vertex, int sampleCount, KeanuRandom random) {
    KeanuProbabilisticModel model = new KeanuProbabilisticModel(vertex.getConnectedGraph());

    // Close the sample stream so any close handlers the generator registered (such as
    // finishing its progress reporting) run. Closing any stage closes the whole pipeline.
    try (java.util.stream.Stream<Boolean> outcomes = MetropolisHastings.withDefaultConfigFor(model, random)
            .generatePosteriorSamples(model, Collections.singletonList(vertex))
            .stream()
            .limit(sampleCount)
            .map(state -> state.get(vertex).scalar())) {
        long trueCount = outcomes.filter(Boolean::booleanValue).count();
        return trueCount / (double) sampleCount;
    }
}
@Test public void streamsExpectedNumberOfSamples() { AtomicInteger stepCount = new AtomicInteger(0); AtomicInteger sampleCount = new AtomicInteger(0); TestSamplingAlgorithm algorithm = new TestSamplingAlgorithm(stepCount, sampleCount); NetworkSamplesGenerator unitUnderTest = new NetworkSamplesGenerator(algorithm, StatusBar::new); int totalCollected = 5; int dropCount = 3; int downSampleInterval = 2; unitUnderTest.dropCount(dropCount).downSampleInterval(downSampleInterval); unitUnderTest.stream() .limit(totalCollected) .collect(Collectors.toList()); //expected step + sample count differs from generate case due to different behaviour int expectedTotal = dropCount + totalCollected * downSampleInterval; assertEquals(expectedTotal, algorithm.stepCount.get() + algorithm.sampleCount.get()); assertEquals(totalCollected, algorithm.sampleCount.get()); }
/**
 * Estimates the mean of an integer vertex by averaging its scalar value over the first 2000
 * Metropolis-Hastings samples (default configuration, default random source) drawn over the
 * vertex's connected graph.
 *
 * @param vertex the integer vertex to average
 * @return the average scalar value across the drawn samples
 */
private static double calculateMeanOfVertex(IntegerVertex vertex) {
    KeanuProbabilisticModel model = new KeanuProbabilisticModel(vertex.getConnectedGraph());

    // Close the sample stream so any close handlers the generator registered (such as
    // finishing its progress reporting) run.
    try (java.util.stream.Stream<NetworkSample> samples = MetropolisHastings
            .withDefaultConfigFor(model, KeanuRandom.getDefaultRandom())
            .generatePosteriorSamples(model, Collections.singletonList(vertex))
            .stream()) {
        return samples
            .limit(2000)
            .collect(Collectors.averagingInt((NetworkSample state) -> state.get(vertex).scalar()));
    }
}
}
/**
 * Verifies that streaming Metropolis-Hastings samples of a standard Gaussian vertex (mean 0,
 * sigma 1) yields a sample mean within 0.1 of zero.
 */
@Test
public void canStreamSamples() {
    int sampleCount = 1000;
    int dropCount = 100;
    int downSampleInterval = 1;
    GaussianVertex vertexA = new GaussianVertex(0, 1);
    KeanuProbabilisticModel model = new KeanuProbabilisticModel(vertexA.getConnectedGraph());
    MetropolisHastings algo = Keanu.Sampling.MetropolisHastings.withDefaultConfigFor(model);

    double averageA;
    // Closing the derived DoubleStream closes the whole pipeline, so the generator's close
    // handlers run once the average has been computed.
    try (java.util.stream.DoubleStream samplesOfA = algo.generatePosteriorSamples(model, model.getLatentVariables())
            .dropCount(dropCount)
            .downSampleInterval(downSampleInterval)
            .stream()
            .limit(sampleCount)
            .mapToDouble(networkState -> networkState.get(vertexA).scalar())) {
        averageA = samplesOfA.average().getAsDouble();
    }

    assertEquals(0.0, averageA, 0.1);
}
model, Arrays.asList(oRingFailure, residualFuel, alarm1FalsePositive) ).stream();