/**
 * Re-evaluates the given vertices and cascades the update to every non-probabilistic vertex reachable
 * from them through {@code getChildren()}.
 *
 * <p>Vertices are processed in ascending {@link Vertex#getId()} order using a priority queue, and each
 * vertex is queued (and therefore passed to {@code updateVertexValue}) at most once, even if it appears
 * several times in {@code cascadeFrom}. Probabilistic children are never queued, so the cascade stops at
 * them; vertices in {@code cascadeFrom} are updated regardless of whether they are probabilistic.
 *
 * <p>NOTE(review): visiting in id order presumably relies on ids increasing from parents to children, so
 * that a vertex is updated after its queued parents — confirm against how ids are assigned.
 *
 * @param cascadeFrom A collection that contains the vertices that have been updated.
 */
public static void cascadeUpdate(Collection<? extends Vertex> cascadeFrom) {
    // Deduplicate the starting vertices before seeding the queue so that a vertex listed more than once
    // in cascadeFrom is not polled and updated repeatedly.
    HashSet<Vertex> alreadyQueued = new HashSet<>(cascadeFrom);

    PriorityQueue<Vertex> priorityQueue = new PriorityQueue<>(Comparator.comparing(Vertex::getId, Comparator.naturalOrder()));
    priorityQueue.addAll(alreadyQueued);

    while (!priorityQueue.isEmpty()) {
        Vertex<?> visiting = priorityQueue.poll();

        updateVertexValue(visiting);

        for (Vertex<?> child : visiting.getChildren()) {
            // Set.add returns false when the child was already queued, so each child is offered only once.
            if (!child.isProbabilistic() && alreadyQueued.add(child)) {
                priorityQueue.offer(child);
            }
        }
    }
}
Stream<Vertex> verticesToAdd = Stream.concat(v.getParents().stream(), v.getChildren().stream()); verticesToAdd .filter(a -> !subgraphVertices.contains(a))
for (Vertex<?> child : visiting.getChildren()) { if (!discoveredGraph.contains(child)) { stack.addFirst(child);
/**
 * Computes, with forward-mode automatic differentiation, the partial derivatives of the vertices in
 * {@code of} with respect to {@code wrt}.
 *
 * <p>Starting at {@code wrt}, vertices are visited in ascending {@link Vertex#getId()} order. Each visited
 * vertex computes its partial from the partials gathered so far. Traversal follows only children that are
 * non-probabilistic and differentiable, and each vertex is queued at most once.
 *
 * <p>When a visited vertex is contained in {@code of}, its partial is recorded and its children are not
 * explored.
 * NOTE(review): if one {@code of} vertex lies downstream of another, the downstream one is not reached
 * through it — confirm that callers only pass vertices for which this is acceptable.
 *
 * @param wrt the vertex to differentiate with respect to
 * @param of  the vertices whose partials with respect to {@code wrt} are wanted
 * @return the partials that were reached, keyed by vertex id, paired with {@code wrt}
 */
public static <V extends Vertex & Differentiable> PartialsWithRespectTo forwardModeAutoDiff(V wrt, Collection<V> of) {
    Map<Vertex, PartialDerivative> partials = new HashMap<>();
    Map<VertexId, PartialDerivative> ofWrt = new HashMap<>();

    HashSet<Vertex> queued = new HashSet<>();
    queued.add(wrt);

    PriorityQueue<V> toVisit = new PriorityQueue<>(Comparator.comparing(Vertex::getId, Comparator.naturalOrder()));
    toVisit.add(wrt);

    while (!toVisit.isEmpty()) {
        V current = toVisit.poll();

        PartialDerivative currentPartial = current.forwardModeAutoDifferentiation(partials);
        partials.put(current, currentPartial);

        if (of.contains(current)) {
            // Target reached: record its partial and do not descend further from it.
            ofWrt.put(current.getId(), currentPartial);
        } else {
            for (Vertex child : (Set<Vertex<?>>) current.getChildren()) {
                boolean shouldQueue = !child.isProbabilistic()
                    && !queued.contains(child)
                    && child.isDifferentiable();
                if (shouldQueue) {
                    toVisit.offer((V) child);
                    queued.add(child);
                }
            }
        }
    }

    return new PartialsWithRespectTo(wrt, ofWrt);
}
/**
 * Builds the shared mock fixture: a chain of five mocked vertices
 * grandParent -> parent -> v -> child -> grandChild, plus a BayesianNetwork over them.
 *
 * <p>Only the links listed below are stubbed: v's parents and children, parent's parents, child's
 * children, and v's connected graph. Every other accessor (for example {@code parent.getChildren()}) is
 * left to Mockito's default answer.
 */
@BeforeClass
public static void setUp() {
    v = Mockito.mock(Vertex.class);
    parent = Mockito.mock(Vertex.class);
    child = Mockito.mock(Vertex.class);
    grandParent = Mockito.mock(Vertex.class);
    grandChild = Mockito.mock(Vertex.class);

    Set<Vertex> parentsOfV = new HashSet<>();
    parentsOfV.add(parent);
    Set<Vertex> childrenOfV = new HashSet<>();
    childrenOfV.add(child);
    Set<Vertex> parentsOfParent = new HashSet<>();
    parentsOfParent.add(grandParent);
    Set<Vertex> childrenOfChild = new HashSet<>();
    childrenOfChild.add(grandChild);

    when(v.getParents()).thenReturn(parentsOfV);
    when(v.getChildren()).thenReturn(childrenOfV);
    when(parent.getParents()).thenReturn(parentsOfParent);
    when(child.getChildren()).thenReturn(childrenOfChild);
    when(v.getConnectedGraph()).thenReturn(ImmutableSet.of(v, parent, child, grandChild, grandParent));

    // Must come after the getConnectedGraph stub above so that the network sees all five vertices.
    network = new BayesianNetwork(v.getConnectedGraph());
}
/**
 * Checks that LambdaSection.getVertices, when walking children from {@code a} with an always-true
 * predicate, returns exactly the four vertices built here. It also checks that the exploration function
 * and the predicate are each invoked exactly three times.
 */
@Test
public void doesNotVisitVerticesMoreThanOnce() {
    GaussianVertex a = new GaussianVertex(0, 1);
    SinVertex sinOfA = a.sin();
    AdditionVertex aPlusSinA = a.plus(sinOfA);
    GaussianVertex c = new GaussianVertex(aPlusSinA, 1);

    MutableInt nextCalls = new MutableInt(0);
    MutableInt predicateCalls = new MutableInt(0);

    Set<Vertex> visited = LambdaSection.getVertices(
        a,
        vertex -> {
            nextCalls.increment();
            return vertex.getChildren();
        },
        vertex -> {
            predicateCalls.increment();
            return true;
        }
    );

    assertEquals(4, visited.size());
    assertThat(visited, containsInAnyOrder(a, sinOfA, aPlusSinA, c));
    assertEquals(3, nextCalls.intValue());
    assertEquals(3, predicateCalls.intValue());
}
}