/**
 * Wraps the sampler produced by {@code setupSampler} in a {@link NetworkSamplesGenerator}
 * that uses {@code StatusBar::new} for progress reporting.
 *
 * @param model                 the probabilistic model to sample from
 * @param variablesToSampleFrom the variables whose values are recorded in each sample
 * @return a generator backed by the configured sampler
 */
@Override
public NetworkSamplesGenerator generatePosteriorSamples(final ProbabilisticModel model,
                                                        final List<? extends Variable> variablesToSampleFrom) {
    final SamplingAlgorithm configuredSampler = setupSampler(model, variablesToSampleFrom);
    return new NetworkSamplesGenerator(configuredSampler, StatusBar::new);
}
/**
 * Draws a fixed number of posterior samples.
 *
 * @param model                 a probabilistic model containing latent variables
 * @param variablesToSampleFrom the variables to include in the returned samples
 * @param sampleCount           number of samples to take using the algorithm
 * @return samples for each variable, ordered by MCMC iteration
 */
@Override
public NetworkSamples getPosteriorSamples(ProbabilisticModel model,
                                          List<? extends Variable> variablesToSampleFrom,
                                          int sampleCount) {
    NetworkSamplesGenerator samplesGenerator = generatePosteriorSamples(model, variablesToSampleFrom);
    return samplesGenerator.generate(sampleCount);
}
/**
 * Runs the annealing search with a default exponential schedule.
 *
 * @param model       the model whose maximum a posteriori state is sought
 * @param sampleCount number of iterations; also passed to the schedule
 * @return the state found by {@code getMaxAPosteriori(model, sampleCount, schedule)}
 */
public NetworkState getMaxAPosteriori(ProbabilisticModel model, int sampleCount) {
    // NOTE(review): 2 and 0.01 presumably are the start and end temperatures of the
    // exponential schedule; confirm against exponentialSchedule's signature.
    return getMaxAPosteriori(model, sampleCount, exponentialSchedule(sampleCount, 2, 0.01));
}
/**
 * A negative drop count must cause an IllegalArgumentException somewhere in the
 * {@code dropCount(-10).generate(100)} chain.
 */
@Test(expected = IllegalArgumentException.class)
public void doesNotAllowNegativeDropCount() {
    final AtomicInteger firstCounter = new AtomicInteger(0);
    final AtomicInteger secondCounter = new AtomicInteger(0);
    final TestSamplingAlgorithm stubAlgorithm = new TestSamplingAlgorithm(firstCounter, secondCounter);
    final NetworkSamplesGenerator generator = new NetworkSamplesGenerator(stubAlgorithm, StatusBar::new);

    generator.dropCount(-10).generate(100);
}
/**
 * Builds a simulated-annealing algorithm with the default configuration.
 *
 * <p>Both the {@link PriorProposalDistribution} and the {@link RollbackAndCascadeOnRejection}
 * rejection strategy are constructed over the model's latent vertices.
 *
 * @param model  the model whose latent vertices drive proposals and rollback
 * @param random the random source handed to the builder
 * @return the configured algorithm
 */
public static io.improbable.keanu.algorithms.mcmc.SimulatedAnnealing withDefaultConfigFor(KeanuProbabilisticModel model,
                                                                                          KeanuRandom random) {
    return builder()
        .proposalDistribution(new PriorProposalDistribution(model.getLatentVertices()))
        .rejectionStrategy(new RollbackAndCascadeOnRejection(model.getLatentVertices()))
        .random(random)
        .build();
}
/**
 * A down-sample interval of 0 must cause an IllegalArgumentException somewhere in the
 * {@code downSampleInterval(0).stream()} chain.
 */
@Test(expected = IllegalArgumentException.class)
public void doesNotAllowZeroDownSample() {
    final AtomicInteger firstCounter = new AtomicInteger(0);
    final AtomicInteger secondCounter = new AtomicInteger(0);
    final TestSamplingAlgorithm stubAlgorithm = new TestSamplingAlgorithm(firstCounter, secondCounter);
    final NetworkSamplesGenerator generator = new NetworkSamplesGenerator(stubAlgorithm, StatusBar::new);

    generator.downSampleInterval(0).stream();
}
/**
 * Performs one Metropolis-Hastings iteration and advances the sample counter.
 *
 * <p>The variable selector picks which latent variables to perturb for this iteration.
 * The log probability returned by the step becomes the starting point for the next one.
 */
@Override
public void step() {
    Set<Variable> variablesToPerturb = variableSelector.select(latentVariables, sampleNum);
    logProbabilityBeforeStep = mhStep
        .step(variablesToPerturb, logProbabilityBeforeStep)
        .getLogProbabilityAfterStep();
    sampleNum += 1;
}
/**
 * Creates a Metropolis-Hastings step that uses a constant proposal.
 *
 * <p>On rejection, the step rolls the model's latent variables back to their cached values.
 *
 * @param model    the model to step over
 * @param constant the value passed to the constant proposal distribution
 * @param random   the random source for accept/reject decisions
 * @return the configured step function
 */
private MetropolisHastingsStep stepFunctionWithConstantProposal(ProbabilisticModel model,
                                                                double constant,
                                                                KeanuRandom random) {
    // Unchecked cast: assumes every latent variable of this model is a Vertex.
    List<Vertex> latents = (List<Vertex>) model.getLatentVariables();
    ProposalDistribution proposal = constantProposal(constant);
    RollBackToCachedValuesOnRejection rollback = new RollBackToCachedValuesOnRejection(latents);
    return new MetropolisHastingsStep(model, proposal, rollback, random);
}
/**
 * Builds the sampling algorithm used by {@code generatePosteriorSamples}.
 *
 * <p>A {@link MetropolisHastingsStep} is built from this instance's proposal distribution,
 * rejection strategy and random source. It is then wrapped in a
 * {@link MetropolisHastingsSampler} over the model's latent variables.
 *
 * @param model                 the model to sample from
 * @param variablesToSampleFrom the variables recorded in each sample
 * @return the configured sampler
 */
private SamplingAlgorithm setupSampler(final ProbabilisticModel model,
                                       final List<? extends Variable> variablesToSampleFrom) {
    final MetropolisHastingsStep stepFunction =
        new MetropolisHastingsStep(model, proposalDistribution, rejectionStrategy, random);

    // NOTE(review): model.logProb() presumably seeds the sampler's initial log probability; confirm.
    return new MetropolisHastingsSampler(
        model.getLatentVariables(),
        variablesToSampleFrom,
        stepFunction,
        variableSelector,
        model.logProb()
    );
}
/**
 * Performs a step at {@code DEFAULT_TEMPERATURE}.
 *
 * <p>Convenience overload of the three-argument {@code step}.
 *
 * @param chosenVariables          the variables to propose new values for
 * @param logProbabilityBeforeStep the model's log probability before this step
 * @return the result of the temperature-aware overload
 */
public StepResult step(final Set<Variable> chosenVariables,
                       final double logProbabilityBeforeStep) {
    return step(
        chosenVariables,
        logProbabilityBeforeStep,
        DEFAULT_TEMPERATURE
    );
}
/**
 * Returns the builder of {@link io.improbable.keanu.algorithms.mcmc.MetropolisHastings}.
 *
 * @return a builder obtained from {@code io.improbable.keanu.algorithms.mcmc.MetropolisHastings.builder()}
 */
public static io.improbable.keanu.algorithms.mcmc.MetropolisHastings.MetropolisHastingsBuilder builder() {
    return io.improbable.keanu.algorithms.mcmc.MetropolisHastings
        .builder();
}
}
/**
 * Takes one sample: advances the chain with {@link #step()}, then records the current values
 * of {@code variablesToSampleFrom} together with {@code logProbabilityBeforeStep}.
 *
 * <p>NOTE(review): the field is read after {@code step()}. If {@code step()} is the override that
 * assigns {@code getLogProbabilityAfterStep()} to this field, the recorded value belongs to the
 * new state despite the field's name. The cast of {@code variablesToSampleFrom} is unchecked.
 */
@Override public NetworkSample sample() { step(); return new NetworkSample(SamplingAlgorithm.takeSample((List<? extends Variable<Object, ?>>) variablesToSampleFrom), logProbabilityBeforeStep); } private static void takeSamples(Map<VariableReference, List<?>> samples, List<? extends Variable> fromVariables) {
/**
 * Advances the underlying algorithm {@code dropCount} times without recording any samples,
 * reporting progress on the given status bar.
 *
 * @param dropCount number of steps to discard; 0 returns immediately without touching the status bar
 * @param statusBar the status bar that shows the message and percentage progress
 */
private void dropSamples(int dropCount, StatusBar statusBar) {
    if (dropCount == 0) {
        return;
    }

    statusBar.setMessage("Dropping samples...");
    PercentageComponent statusPercent = newPercentageComponentAndAddToStatusBar(statusBar);
    try {
        for (int i = 0; i < dropCount; i++) {
            algorithm.step();
            statusPercent.progress((i + 1) / (double) dropCount);
        }
    } finally {
        // Remove the progress component even if a step throws, so the status bar
        // is not left displaying a stale percentage indicator.
        statusBar.removeComponent(statusPercent);
    }
}
/**
 * Returns the builder of {@link io.improbable.keanu.algorithms.mcmc.SimulatedAnnealing}.
 *
 * @return a builder obtained from {@code io.improbable.keanu.algorithms.mcmc.SimulatedAnnealing.builder()}
 */
public static io.improbable.keanu.algorithms.mcmc.SimulatedAnnealing.SimulatedAnnealingBuilder builder() {
    return io.improbable.keanu.algorithms.mcmc.SimulatedAnnealing
        .builder();
}
}
/**
 * Creates a new {@link MetropolisHastingsBuilder}.
 *
 * @return a freshly constructed builder
 */
public static MetropolisHastingsBuilder builder() {
    final MetropolisHastingsBuilder freshBuilder = new MetropolisHastingsBuilder();
    return freshBuilder;
}
/**
 * Builds a {@link MetropolisHastings} from the values collected by this builder.
 *
 * @return a new algorithm instance
 */
public MetropolisHastings build() {
    return new MetropolisHastings(
        random,
        proposalDistribution,
        variableSelector,
        rejectionStrategy
    );
}
/**
 * Creates a new {@link SimulatedAnnealingBuilder}.
 *
 * @return a freshly constructed builder
 */
public static SimulatedAnnealingBuilder builder() {
    final SimulatedAnnealingBuilder freshBuilder = new SimulatedAnnealingBuilder();
    return freshBuilder;
}
/**
 * Builds a {@link SimulatedAnnealing} from the values collected by this builder.
 *
 * @return a new algorithm instance
 */
public SimulatedAnnealing build() {
    return new SimulatedAnnealing(
        random,
        proposalDistribution,
        variableSelector,
        rejectionStrategy
    );
}
/**
 * Wraps {@code constant} in a {@link ConstantProposalDistribution}.
 *
 * @param constant the value handed to the proposal distribution
 * @return the proposal distribution
 */
private ProposalDistribution constantProposal(double constant) {
    final ProposalDistribution proposal = new ConstantProposalDistribution(constant);
    return proposal;
}
/**
 * Dropping 200 samples while requesting only 100 must cause an IllegalArgumentException
 * somewhere in the {@code dropCount(200).generate(100)} chain.
 */
@Test(expected = IllegalArgumentException.class)
public void doesNotAllowDroppingMoreThanRequesting() {
    final AtomicInteger firstCounter = new AtomicInteger(0);
    final AtomicInteger secondCounter = new AtomicInteger(0);
    final TestSamplingAlgorithm stubAlgorithm = new TestSamplingAlgorithm(firstCounter, secondCounter);
    final NetworkSamplesGenerator generator = new NetworkSamplesGenerator(stubAlgorithm, StatusBar::new);

    generator.dropCount(200).generate(100);
}