/**
 * For ten random edge potentials over the first two variables of models[0], checks that
 * normalizing a LogTableFactor built from the potential gives almost the same factor as
 * normalizing a plain duplicate of it.
 */
public void testLogNormalize ()
{
  FactorGraph mdl = models [0];
  Iterator vars = mdl.variablesIterator ();
  Variable first = (Variable) vars.next ();
  Variable second = (Variable) vars.next ();

  Random rand = new Random (3214123);
  for (int trial = 0; trial != 10; ++trial) {
    Factor ptl = randomEdgePotential (rand, first, second);
    Factor logNormed = new LogTableFactor ((AbstractTableFactor) ptl);
    Factor expected = ptl.duplicate ();
    logNormed.normalize ();
    expected.normalize ();
    assertTrue ("LogNormalize failed! Correct: " + expected + " Log-normed: " + logNormed,
                logNormed.almostEquals (expected));
  }
}
/**
 * Computes the L1 distance between two factors defined over the same variables:
 * the sum, over every assignment produced by {@code bel1}'s assignment iterator,
 * of the absolute difference between the two factors' values.
 *
 * @param bel1 first factor; its assignment iterator drives the sum
 * @param bel2 second factor; its variable set must equal {@code bel1}'s
 * @return the sum of |bel1(a) - bel2(a)| over the assignments a of {@code bel1}
 * @throws IllegalArgumentException if the two variable sets differ
 */
public static double oneDistance (Factor bel1, Factor bel2)
{
  Set vs1 = bel1.varSet ();
  Set vs2 = bel2.varSet ();
  if (!vs1.equals (vs2)) {
    throw new IllegalArgumentException ("Attempt to take distance between mismatching potentials "+bel1+" and "+bel2);
  }

  double dist = 0;
  for (AssignmentIterator it = bel1.assignmentIterator (); it.hasNext ();) {
    Assignment assn = it.assignment ();
    dist += Math.abs (bel1.value (assn) - bel2.value (assn));
    it.advance ();
  }
  return dist;
}
/**
 * Removes from {@code allPhi} every factor that either mentions {@code node} or has an
 * empty variable set, and returns {@code TableFactor.multiplyAll} of the removed factors.
 * Note that {@code allPhi} is modified in place.
 */
private Factor eliminate (Collection allPhi, Variable node)
{
  HashSet taken = new HashSet ();
  Iterator factors = allPhi.iterator ();
  while (factors.hasNext ()) {
    Factor candidate = (Factor) factors.next ();
    // constant (empty-scope) factors are swept up together with those touching node
    boolean take = candidate.varSet ().isEmpty () || candidate.containsVar (node);
    if (take) {
      taken.add (candidate);
      factors.remove ();
    }
  }
  return TableFactor.multiplyAll (taken);
}
/**
 * Reports whether multiplying {@code product} by {@code otherMsg} would produce a
 * factor whose {@code isNaN()} is true. The multiplication is performed on
 * {@code product.duplicate()} rather than on {@code product} itself.
 */
private boolean willBeNaN2 (Factor product, Factor otherMsg)
{
  Factor trial = product.duplicate ();
  trial.multiplyBy (otherMsg);
  return trial.isNaN ();
}
/**
 * Asserts that each factor in {@code pre} almost equals (tolerance 1e-3) the factor at the
 * same index in {@code post}. On failure, the message contains {@code msg} followed by both
 * factors' dumps. The loop runs over {@code pre.length} entries, so {@code post} must be at
 * least as long.
 */
private void compareMarginals (String msg, Factor[] pre, Factor[] post)
{
  int idx = 0;
  while (idx < pre.length) {
    Factor before = pre[idx];
    Factor after = post[idx];
    String detail = msg + "\n" + before.dumpToString () + "\n" + after.dumpToString ();
    assertTrue (detail, before.almostEquals (after, 1e-3));
    idx++;
  }
}
/**
 * Checks junction-tree max-product inference against brute force. For every model and every
 * variable, the normalized marginal from a JunctionTreeInferencer created with
 * createForMaxProduct() must almost equal (tolerance 0.01) the normalized extractMax of the
 * brute-force joint.
 */
public void testJtViterbi()
{
  for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) {
    UndirectedModel mdl = models[mdlIdx];
    BruteForceInferencer brute = new BruteForceInferencer ();
    JunctionTreeInferencer maxprod = JunctionTreeInferencer.createForMaxProduct ();
    JunctionTree jt = maxprod.buildJunctionTree (mdl);
    Factor joint = brute.joint (mdl);
    maxprod.computeMarginals (jt);

    for (Iterator it = mdl.variablesIterator (); it.hasNext ();) {
      Variable var = (Variable) it.next ();
      Factor maxPotRaw = maxprod.lookupMarginal (var);
      Factor trueMaxPotRaw = joint.extractMax (var);
      // normalize copies so the factors returned above are not normalized in place
      Factor maxPot = maxPotRaw.duplicate().normalize ();
      Factor trueMaxPot = trueMaxPotRaw.duplicate().normalize ();
      assertTrue ("Maximization failed on model " + mdlIdx + " ! Normalized returns:\n" + maxPot.dumpToString () + "\nTrue: " + trueMaxPot.dumpToString (),
                  maxPot.almostEquals (trueMaxPot, 0.01));
    }
  }
  logger.info("Test jtViterbi passed.");
}
public void sendMessage (FactorGraph mdl, Factor from, Variable to) { // System.err.println ("...max-prod message"); int fromIdx = messages.getIndex (from); int toIdx = messages.getIndex (to); Factor product = from.duplicate (); msgProduct (product, fromIdx, toIdx); Factor msg = product.extractMax (to); msg.normalize (); assert msg.varSet ().size () == 1; assert msg.varSet ().contains (to); messages.put (fromIdx, toIdx, msg); }
/** * This sends a max-product message. */ public void sendMessage (JunctionTree jt, VarSet from, VarSet to) { // System.err.println ("Send message "+from+" --> "+to); Collection sepset = jt.getSepset (from, to); Factor fromCpf = jt.getCPF (from); Factor toCpf = jt.getCPF (to); Factor oldSepsetPot = jt.getSepsetPot (from, to); Factor lambda = fromCpf.extractMax (sepset); lambda.normalize (); jt.setSepsetPot (lambda, from, to); toCpf = toCpf.multiply (lambda); toCpf.divideBy (oldSepsetPot); toCpf.normalize (); jt.setCPF (to, toCpf); }
Region region = (Region) it.next(); Factor belief = computeBelief(region); double thisEntropy = belief.entropy(); DiscreteFactor product = new LogTableFactor(belief.varSet()); for (Iterator ptlIt = region.factors.iterator(); ptlIt.hasNext();) { Factor ptl = (Factor) ptlIt.next(); for (AssignmentIterator assnIt = belief.assignmentIterator(); assnIt.hasNext();) { Assignment assn = assnIt.assignment(); double thisBel = belief.value(assn); thisAvgEnergy += thisBel * thisEnergy; assnIt.advance();
for (Iterator i = model.factorsIterator (); i.hasNext(); ){ Factor factor = (Factor) i.next (); allPhi.add(factor.duplicate()); if(newCPF.varSet().size() == 1) { singleCPF = newCPF; } else { singleCPF = newCPF.marginalizeOut (node); assert marginal.containsVar (query); assert marginal.varSet().size() == 1;
/**
 * Adds {@code f} into this factor in place.
 * <p>
 * The addition depends on the type of {@code f}:
 * <ul>
 * <li>A {@code DiscreteFactor} operand is passed through {@code expandToContain} and
 * {@code ensureOperandCompatible}, and then {@code plusEqualsInternal}.</li>
 * <li>A {@code ConstantFactor} contributes its value at the empty assignment via the
 * scalar {@code plusEquals} overload.</li>
 * <li>Any other factor is converted with {@code asTable()} and added recursively.</li>
 * </ul>
 *
 * @param f the factor to add
 * @throws UnsupportedOperationException if {@code f} is neither discrete nor constant
 *         and its {@code asTable()} conversion is unsupported; the conversion failure
 *         is attached as the cause
 */
public void plusEquals (Factor f)
{
  if (f instanceof DiscreteFactor) {
    DiscreteFactor factor = (DiscreteFactor) f;
    expandToContain (factor);
    factor = ensureOperandCompatible (factor);
    plusEqualsInternal (factor);
  } else if (f instanceof ConstantFactor) {
    plusEquals (f.value (new Assignment ()));
  } else {
    AbstractTableFactor tbl;
    try {
      tbl = f.asTable ();
    } catch (UnsupportedOperationException e) {
      UnsupportedOperationException wrapped =
          new UnsupportedOperationException ("Don't know how to add "+this+" by "+f);
      // keep the original failure so the reason asTable() refused is not lost
      wrapped.initCause (e);
      throw wrapped;
    }
    plusEquals (tbl);
  }
}
/**
 * Builds a PottsTableFactor over two Variable(2)s and a continuous alpha, slices it at
 * alpha = 1.0, and checks the slice's log values in assignment-iterator order against
 * {0, -1, -1, 0}. It also checks that exactly that many assignments are visited, so a
 * table with too few or too many entries cannot pass silently.
 */
public void testSparseMatrixN ()
{
  Variable x1 = new Variable (2);
  Variable x2 = new Variable (2);
  Variable alpha = new Variable (Variable.CONTINUOUS);
  Factor potts = new PottsTableFactor (x1, x2, alpha);
  Assignment alphAssn = new Assignment (alpha, 1.0);
  Factor tbl = potts.slice (alphAssn);
  System.out.println (tbl.dumpToString ());

  int j = 0;
  double[] vals = new double[] { 0, -1, -1, 0 };
  for (AssignmentIterator it = tbl.assignmentIterator (); it.hasNext ();) {
    assertTrue ("more assignments than expected values", j < vals.length);
    assertEquals (vals[j++], tbl.logValue (it), 1e-5);
    it.advance ();
  }
  assertEquals ("fewer assignments than expected values", vals.length, j);
}
/**
 * Computes E[XY] - E[X]E[Y] for a factor over exactly two variables, using the integer
 * value each variable takes in an assignment. Despite the name, the result is the
 * covariance: it is not divided by standard deviations.
 * <p>
 * E[XY] weights each assignment by {@code factor.value} directly. E[X] and E[Y] come
 * from {@code mean} of each single-variable marginal.
 * NOTE(review): presumably the factor is expected to be normalized; confirm with callers.
 *
 * @throws IllegalArgumentException if the factor does not have exactly two variables
 */
public static double corr (Factor factor)
{
  if (factor.varSet ().size() != 2) {
    throw new IllegalArgumentException ("corr() only works on Factors of size 2, tried "+factor);
  }
  Variable first = factor.varSet ().get (0);
  Variable second = factor.varSet ().get (1);

  double expectedProduct = 0.0;
  AssignmentIterator assignments = factor.assignmentIterator ();
  while (assignments.hasNext ()) {
    Assignment current = (Assignment) assignments.next ();
    int a = current.get (first);
    int b = current.get (second);
    expectedProduct += factor.value (current) * a * b;
  }

  double expectedFirst = mean (factor.marginalize (first));
  double expectedSecond = mean (factor.marginalize (second));
  return expectedProduct - expectedFirst * expectedSecond;
}
/**
 * Sends a sum-product message across the junction-tree edge from clique {@code from}
 * to clique {@code to}, normalizing along the way to avoid underflow.
 * <p>
 * First, the sending clique's potential is marginalized onto the separator, normalized,
 * and stored as the new separator potential. Then the receiving clique's potential is
 * multiplied by the new separator potential, divided by the previous one, normalized,
 * and written back.
 */
public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
{
  Collection separator = jt.getSepset (from, to);
  Factor source = jt.getCPF (from);
  Factor target = jt.getCPF (to);
  // must be read before setSepsetPot replaces it below
  Factor previousSep = jt.getSepsetPot (from, to);

  Factor newSep = source.marginalize (separator);
  newSep.normalize ();
  jt.setSepsetPot (newSep, from, to);

  Factor updated = target.multiply (newSep);
  updated.divideBy (previousSep);
  updated.normalize ();
  jt.setCPF (to, updated);
}
/**
 * Checks tree max-product BP against brute force. For every tree model and every variable,
 * the normalized marginal from TreeBP.createForMaxProduct() must almost equal the normalized
 * extractMax of the brute-force joint.
 */
public void testTreeViterbi()
{
  for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) {
    FactorGraph mdl = trees[mdlIdx];
    BruteForceInferencer brute = new BruteForceInferencer ();
    Inferencer maxprod = TreeBP.createForMaxProduct ();
    Factor joint = brute.joint (mdl);
    maxprod.computeMarginals (mdl);

    for (Iterator it = mdl.variablesIterator (); it.hasNext ();) {
      Variable var = (Variable) it.next ();
      // normalize duplicates so the factor handed back by lookupMarginal is not
      // modified in place
      Factor maxPot = maxprod.lookupMarginal (var).duplicate ().normalize ();
      Factor trueMaxPot = joint.extractMax (var).duplicate ().normalize ();
      assertTrue ("Maximization failed! Normalized returns:\n" + maxPot + "\nTrue: " + trueMaxPot,
                  maxPot.almostEquals (trueMaxPot));
    }
  }
  logger.info("Test treeViterbi passed: " + trees.length + " models.");
}
/**
 * Returns the value of the first entry minus the value of the second entry of the marginal
 * that {@code inferencer} reports for {@code var}. Entries are taken in the order produced
 * by the marginal's assignment iterator.
 *
 * @param inferencer an inferencer on which lookupMarginal(var) can be called
 * @param var a variable with exactly two outcomes
 * @return marginal(first assignment) - marginal(second assignment)
 * @throws IllegalArgumentException if {@code var} does not have exactly two outcomes
 */
public static double localMagnetization (Inferencer inferencer, Variable var)
{
  if (var.getNumOutcomes () != 2) {
    throw new IllegalArgumentException ("localMagnetization requires a variable with 2 outcomes, but "
                                        + var + " has " + var.getNumOutcomes ());
  }
  Factor marg = inferencer.lookupMarginal (var);
  AssignmentIterator it = marg.assignmentIterator ();
  double v1 = marg.value (it);
  it.advance ();
  double v2 = marg.value (it);
  return v1 - v2;
}
/**
 * Builds the factor product used for the message along {@code edge}. It starts from a
 * LogTableFactor over {@code edge.from.vars}. Next, it multiplies in the message from
 * {@code oldMessages} for each edge in {@code edge.neighboringParents}. Finally, it divides
 * out the message from {@code newMessages} for each edge in {@code edge.loopingMessages}.
 */
Factor msgProduct (RegionEdge edge)
{
  Factor result = new LogTableFactor (edge.from.vars);

  Iterator parents = edge.neighboringParents.iterator ();
  while (parents.hasNext ()) {
    RegionEdge parentEdge = (RegionEdge) parents.next ();
    result.multiplyBy (oldMessages.getMessage (parentEdge.from, parentEdge.to));
  }

  Iterator loops = edge.loopingMessages.iterator ();
  while (loops.hasNext ()) {
    RegionEdge loopEdge = (RegionEdge) loops.next ();
    result.divideBy (newMessages.getMessage (loopEdge.from, loopEdge.to));
  }

  return result;
}
}