logProbabilityBeforeStep, temperature ).getLogProbabilityAfterStep();
/**
 * Performs one Metropolis-Hastings step over the given variables.
 *
 * <p>A proposal is requested for {@code chosenVariables} and the model's log probability at the
 * proposed state is computed. If that log probability is not impossible, the log acceptance
 * ratio is
 * {@code (1 / temperature) * (logProbAfter - logProbabilityBeforeStep)
 * + (logProbAtFromGivenTo - logProbAtToGivenFrom)}.
 * The proposal is accepted when {@code exp} of that ratio is at least {@code random.nextDouble()}.
 * Otherwise the proposal is rejected. On rejection, both the proposal distribution and the
 * rejection strategy have their {@code onProposalRejected} callbacks invoked.
 *
 * @param chosenVariables          variables to get a proposed change for
 * @param logProbabilityBeforeStep the log of the previous state's probability
 * @param temperature              temperature for simulated annealing; this should be constant
 *                                 if no annealing is wanted
 * @return a {@link StepResult} recording whether the proposal was accepted, together with the
 *     log probability after the step: the proposed state's log probability when accepted,
 *     otherwise {@code logProbabilityBeforeStep}
 */
public StepResult step(final Set<Variable> chosenVariables,
                       final double logProbabilityBeforeStep,
                       final double temperature) {
    final Proposal candidate = proposalDistribution.getProposal(chosenVariables, random);
    rejectionStrategy.onProposalCreated(candidate);

    final double candidateLogProb =
        model.logProbAfter(candidate.getProposalTo(), logProbabilityBeforeStep);

    if (!ProbabilityCalculator.isImpossibleLogProb(candidateLogProb)) {
        // Reverse (from | to) minus forward (to | from) proposal log densities: the Hastings term.
        final double logReverse = proposalDistribution.logProbAtFromGivenTo(candidate);
        final double logForward = proposalDistribution.logProbAtToGivenFrom(candidate);
        final double logAcceptance =
            (1.0 / temperature) * (candidateLogProb - logProbabilityBeforeStep)
                + (logReverse - logForward);

        if (Math.exp(logAcceptance) >= random.nextDouble()) {
            return new StepResult(true, candidateLogProb);
        }
    }

    // Reached for impossible proposals and for proposals that lost the random draw.
    proposalDistribution.onProposalRejected();
    rejectionStrategy.onProposalRejected(candidate);
    return new StepResult(false, logProbabilityBeforeStep);
}
@Test
public void doesCalculateCorrectLogProbAfterAcceptingStep() {
    // Latent A ~ N(0, 1), deterministic B = 2A, and an observed child N(B, 1) fixed at 5.
    DoubleVertex latent = new GaussianVertex(0, 1);
    latent.setValue(1.0);
    DoubleVertex doubled = latent.times(2);
    DoubleVertex evidence = new GaussianVertex(doubled, 1);
    evidence.observe(5);

    BayesianNetwork network = new BayesianNetwork(latent.getConnectedGraph());
    KeanuProbabilisticModel model = new KeanuProbabilisticModel(network);
    double initialLogProb = model.logProb();

    MetropolisHastingsStep stepFunction = new MetropolisHastingsStep(
        model,
        new PriorProposalDistribution(network.getAllVertices()),
        new RollBackToCachedValuesOnRejection(network.getLatentVertices()),
        alwaysAccept
    );

    MetropolisHastingsStep.StepResult outcome =
        stepFunction.step(Collections.singleton(latent), initialLogProb);

    assertTrue(outcome.isAccepted());
    // The reported post-step log probability must match a fresh evaluation of the model.
    assertEquals(model.logProb(), outcome.getLogProbabilityAfterStep(), 1e-10);
}
@Test
public void doesRejectWhenRejectProbabilityIsOne() {
    DoubleVertex latent = new GaussianVertex(0, 1);
    latent.setValue(0.5);
    DoubleVertex scaled = latent.times(2);
    DoubleVertex evidence = new GaussianVertex(scaled, 1);
    evidence.observe(5.0);

    ProbabilisticModel model = new KeanuProbabilisticModel(latent.getConnectedGraph());
    MetropolisHastingsStep rejectingStep =
        stepFunctionWithConstantProposal(model, 10, alwaysReject);

    double logProbBefore = model.logProb();
    MetropolisHastingsStep.StepResult outcome =
        rejectingStep.step(Collections.singleton(latent), logProbBefore);

    assertFalse(outcome.isAccepted());
    // After a rejection the latent vertex must still hold its starting value.
    assertEquals(0.5, latent.getValue(0), 1e-10);
}
@Category(Slow.class)
@Test
public void doesAllowCustomProposalDistribution() {
    DoubleVertex latent = new GaussianVertex(0, 1);
    latent.setValue(0.0);

    ProbabilisticModel model = new KeanuProbabilisticModel(latent.getConnectedGraph());
    MetropolisHastingsStep acceptingStep =
        stepFunctionWithConstantProposal(model, 1.0, alwaysAccept);

    double logProbBefore = model.logProb();
    MetropolisHastingsStep.StepResult outcome =
        acceptingStep.step(Collections.singleton(latent), logProbBefore);

    assertTrue(outcome.isAccepted());
    // The constant proposal's value should now be held by the latent vertex.
    assertEquals(1.0, latent.getValue(0), 1e-10);
}
@Test
public void doesRejectOnImpossibleProposal() {
    // Uniform(0, 1) prior; the constant proposal of -1 lies outside its support.
    DoubleVertex bounded = new UniformVertex(0, 1);
    bounded.setValue(0.5);

    ProbabilisticModel model = new KeanuProbabilisticModel(bounded.getConnectedGraph());
    MetropolisHastingsStep stepFunction =
        stepFunctionWithConstantProposal(model, -1, alwaysAccept);

    double logProbBefore = model.logProb();
    MetropolisHastingsStep.StepResult outcome =
        stepFunction.step(Collections.singleton(bounded), logProbBefore);

    assertFalse(outcome.isAccepted());
    assertEquals(0.5, bounded.getValue(0), 1e-10);
}
/**
 * Advances the sampler by one Metropolis-Hastings step over the variables picked by the
 * variable selector for the current sample number.
 */
@Override
public void step() {
    final Set<Variable> selection = variableSelector.select(latentVariables, sampleNum);

    // The log probability reported by this step becomes the starting point of the next one.
    logProbabilityBeforeStep = mhStep
        .step(selection, logProbabilityBeforeStep)
        .getLogProbabilityAfterStep();

    sampleNum += 1;
}