/**
 * Collects the continuous variables of a factor.
 *
 * @param fg factor whose variable set is scanned
 * @return a freshly allocated array holding every variable of {@code fg.varSet ()} for which
 *     {@code isContinuous ()} is true, in the order the VarSet enumerates them; empty if none
 */
public static Variable[] continuousVarsOf (Factor fg) {
  VarSet domain = fg.varSet ();
  // Upper bound: at most every variable is continuous.
  Variable[] scratch = new Variable [domain.size ()];
  int found = 0;
  for (int k = 0; k < domain.size (); k++) {
    Variable candidate = domain.get (k);
    if (!candidate.isContinuous ()) {
      continue;
    }
    scratch[found++] = candidate;
  }
  // Trim the buffer to exactly the number of continuous variables found.
  Variable[] result = new Variable [found];
  System.arraycopy (scratch, 0, result, 0, found);
  return result;
}
/**
 * Adds a factor via the superclass and, when the factor touches exactly two variables,
 * also records its variable set in {@code edges}.
 *
 * @param factor the factor to add
 */
public void addFactor (Factor factor) {
  super.addFactor (factor);
  boolean isPairwise = factor.varSet ().size () == 2;
  if (!isPairwise) {
    return;
  }
  edges.add (factor.varSet ());
}
/**
 * Reports whether every variable in {@code xs} takes the same value in an assignment.
 *
 * <p>Values are compared with {@link Object#equals}. Two {@code null} values count as equal;
 * a {@code null} and a non-null value do not. An empty {@code xs} is treated as vacuously
 * all-equal rather than failing on {@code xs.get (0)}.
 *
 * @param assn assignment from which the value of each variable in {@code xs} is read
 * @return {@code true} iff all variables in {@code xs} have equal values in {@code assn}
 */
private boolean isAllEqual (Assignment assn) {
  if (xs.size () == 0) {
    return true;
  }
  Object first = assn.getObject (xs.get (0));
  for (int i = 1; i < xs.size (); i++) {
    Object other = assn.getObject (xs.get (i));
    // Null-safe comparison: getObject may return null for some variables (not visible here).
    boolean same = (first == null) ? (other == null) : first.equals (other);
    if (!same) {
      return false;
    }
  }
  return true;
}
/**
 * Calls {@code union (variable, factor)} for each variable in the factor's variable set,
 * in VarSet order.
 *
 * @param factor factor whose variables are each unioned with it
 */
public void unionAll (Factor factor) {
  VarSet members = factor.varSet ();
  int pos = 0;
  while (pos < members.size ()) {
    Variable member = members.get (pos);
    pos++;
    union (member, factor);
  }
}
/**
 * Fills {@code sizes} with the outcome count of each variable in {@code vertsList} and sets
 * {@code max} to {@code vertsList.weight ()}.
 *
 * @throws UnsupportedOperationException if any variable in {@code vertsList} is continuous;
 *     {@code sizes} has already been reallocated (and is partially filled) at that point
 */
private void initSizes () {
  final int count = vertsList.size ();
  sizes = new int [count];
  int pos = 0;
  while (pos < count) {
    Variable current = vertsList.get (pos);
    if (current.isContinuous ()) {
      throw new UnsupportedOperationException ("Attempt to create AssignmentIterator over "+vertsList+", but "+current+" is continuous.");
    }
    sizes[pos] = current.getNumOutcomes ();
    pos++;
  }
  max = vertsList.weight ();
}
/**
 * Returns the number of outcomes of each variable of a factor.
 *
 * @param result factor whose variables are inspected
 * @return array whose entry {@code k} is {@code result.getVariable (k).getNumOutcomes ()},
 *     with one entry per variable in {@code result.varSet ()}
 */
public static int[] computeSizes (Factor result) {
  int[] outcomeCounts = new int [result.varSet ().size ()];
  for (int k = 0; k < outcomeCounts.length; ++k) {
    outcomeCounts[k] = result.getVariable (k).getNumOutcomes ();
  }
  return outcomeCounts;
}
/**
 * Reports whether every variable in {@code xs} takes the same value in an assignment.
 *
 * <p>Values are compared with {@link Object#equals}. Two {@code null} values count as equal;
 * a {@code null} and a non-null value do not. An empty {@code xs} is treated as vacuously
 * all-equal rather than failing on {@code xs.get (0)}.
 *
 * @param assn assignment from which the value of each variable in {@code xs} is read
 * @return {@code true} iff all variables in {@code xs} have equal values in {@code assn}
 */
private boolean isAllEqual (Assignment assn) {
  if (xs.size () == 0) {
    return true;
  }
  Object first = assn.getObject (xs.get (0));
  for (int i = 1; i < xs.size (); i++) {
    Object other = assn.getObject (xs.get (i));
    // Null-safe comparison: getObject may return null for some variables (not visible here).
    boolean same = (first == null) ? (other == null) : first.equals (other);
    if (!same) {
      return false;
    }
  }
  return true;
}
/**
 * Calls {@code union (variable, factor)} for each variable in the factor's variable set,
 * in VarSet order.
 *
 * @param factor factor whose variables are each unioned with it
 */
public void unionAll (Factor factor) {
  VarSet members = factor.varSet ();
  int pos = 0;
  while (pos < members.size ()) {
    Variable member = members.get (pos);
    pos++;
    union (member, factor);
  }
}
/**
 * Ensures each variable of {@code varSet} is cached.
 *
 * <p>For each variable in order: if {@code universe} is still unset it is taken from that
 * variable's {@code getUniverse ()}; then, if {@code getIndex} reports the variable as unknown
 * (negative index), {@code cacheVariable} is called for it.
 *
 * @param varSet variables to register
 */
private void addVarsIfNecessary (VarSet varSet) {
  int idx = 0;
  while (idx < varSet.size ()) {
    Variable v = varSet.get (idx++);
    if (universe == null) {
      universe = v.getUniverse ();
    }
    boolean notYetCached = getIndex (v) < 0;
    if (notYetCached) {
      cacheVariable (v);
    }
  }
}
/**
 * Ensures each variable of {@code varSet} is cached.
 *
 * <p>For each variable in order: if {@code universe} is still unset it is taken from that
 * variable's {@code getUniverse ()}; then, if {@code getIndex} reports the variable as unknown
 * (negative index), {@code cacheVariable} is called for it.
 *
 * @param varSet variables to register
 */
private void addVarsIfNecessary (VarSet varSet) {
  int idx = 0;
  while (idx < varSet.size ()) {
    Variable v = varSet.get (idx++);
    if (universe == null) {
      universe = v.getUniverse ();
    }
    boolean notYetCached = getIndex (v) < 0;
    if (notYetCached) {
      cacheVariable (v);
    }
  }
}
/** A UniNormalFactor built over one continuous variable has exactly that variable in its VarSet. */
public void testVarSet () {
  Variable x = new Variable (Variable.CONTINUOUS);
  Factor normal = new UniNormalFactor (x, -1.0, 1.5);

  assertEquals (1, normal.varSet ().size ());
  assertTrue (normal.varSet ().contains (x));
}
/** A UniformFactor built over one continuous variable has exactly that variable in its VarSet. */
public void testVarSet () {
  Variable x = new Variable (Variable.CONTINUOUS);
  Factor uniform = new UniformFactor (x, -1.0, 1.5);

  assertEquals (1, uniform.varSet ().size ());
  assertTrue (uniform.varSet ().contains (x));
}
/** A BetaFactor built over one continuous variable has exactly that variable in its VarSet. */
public void testVarSet () {
  Variable x = new Variable (Variable.CONTINUOUS);
  Factor beta = new BetaFactor (x, 0.5, 0.5);

  assertEquals (1, beta.varSet ().size ());
  assertTrue (beta.varSet ().contains (x));
}
public void sendMessage (FactorGraph mdl, Variable from, Factor to) { // System.err.println ("...max-prod message"); int fromIdx = messages.getIndex (from); int toIdx = messages.getIndex (to); Factor msg = msgProduct (null, fromIdx, toIdx); msg.normalize (); assert msg.varSet ().size () == 1; assert msg.varSet ().contains (from); messages.put (fromIdx, toIdx, msg); }
public void sendMessage (FactorGraph mdl, Variable from, Factor to) { // System.err.println ("...sum-prod message"); int fromIdx = messages.getIndex (from); int toIdx = messages.getIndex (to); Factor msg = msgProduct (null, fromIdx, toIdx); msg.normalize (); assert msg.varSet ().size () == 1; assert msg.varSet ().contains (from); messages.put (fromIdx, toIdx, msg); }
/**
 * Multiplies three fixture factors ({@code tbl1}, {@code tbl2}, {@code ltbl1}) into an empty
 * FactorGraph. Checks that the graph has 3 variables and 3 factors, and that the all-zero
 * assignment has value 0.128 (within 1e-5).
 */
public void testRedundantDomains () {
  FactorGraph graph = new FactorGraph ();
  graph.multiplyBy (tbl1);
  graph.multiplyBy (tbl2);
  graph.multiplyBy (ltbl1);

  assertEquals (3, graph.varSet ().size ());

  String failureMsg = "Wrong factors in FG, was " + graph.dumpToString ();
  assertEquals (failureMsg, 3, graph.factors ().size ());

  int[] allZero = new int [3];
  Assignment zeroAssn = new Assignment (graph.varSet ().toVariableArray (), allZero);
  assertEquals (0.128, graph.value (zeroAssn), 1e-5);
}
/**
 * Marginalizing a 2x2 TableFactor onto its second variable yields a one-variable factor over
 * that variable with values {4, 6}.
 *
 * <p>The expected values are {1+3, 2+4}. This presumes the value array is laid out with
 * vars[0] as the slower-varying index, so summing out vars[0] adds the two rows.
 */
public void testMarginalize () {
  Variable[] vars = new Variable[] { new Variable (2), new Variable (2) };
  TableFactor ptl = new TableFactor (vars, new double[] { 1, 2, 3, 4 });
  TableFactor ptl2 = (TableFactor) ptl.marginalize (vars[1]);

  assertEquals ("FAILURE: Potential has too many vars.\n " + ptl2, 1, ptl2.varSet ().size ());
  assertTrue ("FAILURE: Potential does not contain " + vars[1] + ":\n " + ptl2,
              ptl2.varSet ().contains (vars[1]));

  double[] expected = new double[] { 4, 6 };
  // Separator before "was" so the failure message does not run the array into the word.
  assertTrue ("FAILURE: Potential has incorrect values. Expected " + ArrayUtils.toString (expected)
              + " was " + ptl2,
              Maths.almostEquals (ptl2.toValueArray (), expected, 1e-5));
}