/**
 * Gathers the parents of every vertex in the given collection into one set.
 *
 * @param vertices the vertices whose parents are collected
 * @return a new set holding every parent returned by {@code getParents()} for any of the vertices
 */
private static Set<Vertex> allParentsOf(Collection<Vertex> vertices) {
    final Set<Vertex> combinedParents = new HashSet<>();
    vertices.forEach(each -> combinedParents.addAll(each.getParents()));
    return combinedParents;
}
/**
 * Appends the given vertices to this vertex's parents and registers this vertex as a child of each of them.
 *
 * @param parents the vertices to add as parents; the stored parent set is rebuilt as an immutable set
 *                containing the current parents followed by these
 */
public void addParents(Collection<? extends Vertex> parents) {
    ImmutableSet.Builder<Vertex> combined = ImmutableSet.<Vertex>builder();
    combined.addAll(this.getParents());
    combined.addAll(parents);
    this.parents = combined.build();

    // Keep the reverse (child) links in step with the new parent links.
    for (Vertex newParent : parents) {
        newParent.addChild(this);
    }
}
/**
 * Returns the parents of the vertex being visited, unless {@code isValueKnownToBeConstant} reports
 * the vertex's value as constant, in which case an empty set is returned.
 *
 * @param visiting                    the vertex whose parents are requested
 * @param constantValueVerticesCache  cache handed through to {@code isValueKnownToBeConstant}
 * @return the vertex's parents, or an empty set when its value is known to be constant
 */
private static Collection<Vertex> getParentsIfValueNotKnownToBeConstant(Vertex visiting, Set<Vertex> constantValueVerticesCache) {
    if (isValueKnownToBeConstant(visiting, constantValueVerticesCache)) {
        return Collections.emptySet();
    }
    return visiting.getParents();
}
/**
 * Returns the parents of the vertex being visited, or an empty set when the vertex is probabilistic.
 *
 * @param visiting the vertex whose parents are requested
 * @return the vertex's parents if it is not probabilistic; otherwise an empty set
 */
private static Collection<Vertex> getParentsIfVertexIsNotProbabilistic(Vertex visiting) {
    if (visiting.isProbabilistic()) {
        return Collections.emptySet();
    }
    return visiting.getParents();
}
/**
 * Records in {@code dependencies} the dependency set of a vertex, recursing into any parent that has
 * no entry yet.
 *
 * <p>A vertex's dependency set is the union of its parents' dependency sets, plus each parent that is
 * itself a member of {@code verticesToCount}.
 *
 * @param aVertex         the vertex whose dependencies are being recorded
 * @param dependencies    the map being filled in; an entry for {@code aVertex} is created if absent
 * @param verticesToCount the vertices that are added to a dependency set when they appear as a parent
 */
private static void insertParentDependencies(Vertex<?> aVertex, Map<Vertex, Set<Vertex>> dependencies, Set<Vertex> verticesToCount) {
    // The entry is created before recursing, so a vertex reached again already has a key.
    final Set<Vertex> ownDependencies = dependencies.computeIfAbsent(aVertex, v -> new HashSet<>());

    for (Vertex parent : aVertex.getParents()) {
        if (!dependencies.containsKey(parent)) {
            insertParentDependencies(parent, dependencies, verticesToCount);
        }

        final Set<Vertex> inheritedFromParent = dependencies.get(parent);
        ownDependencies.addAll(inheritedFromParent);
        if (verticesToCount.contains(parent)) {
            ownDependencies.add(parent);
        }
    }
}
}
/**
 * The dLogProb(x) method on Vertex returns a partial derivative of the Log Prob with respect to each
 * of its arguments and with respect to its value, x. This method searches these partials for any that
 * are parents of the vertices we are taking the derivative with respect to.
 *
 * <p>For each vertex in {@code ofVertices} the resulting set holds those of its parents that are keys
 * of {@code parentToWrtVertices}, plus the vertex itself when it is not observed.
 *
 * @param ofVertices          the vertices that the derivative is being calculated "of" with respect to the wrtVertices
 * @param parentToWrtVertices a lookup keyed by parent vertex; only its keys are consulted here
 * @return a map for a given vertex to a set of the wrt vertices that it is connected to
 * @throws ClassCastException if a parent present in the lookup, or an unobserved member of
 *                            {@code ofVertices}, is not a {@code DoubleVertex}
 */
private Map<Vertex, Set<DoubleVertex>> getVerticesWithNonzeroDiffWrt(Set<? extends Vertex<?>> ofVertices,
                                                                     Map<Vertex, Set<DoubleVertex>> parentToWrtVertices) {
    return ofVertices.stream()
        .collect(Collectors.toMap(
            v -> v,
            v -> {
                Set<DoubleVertex> parents = v.getParents().stream()
                    // Filter first: parents missing from the lookup are discarded anyway, so they
                    // must not be able to fail the cast below.
                    .filter(parentToWrtVertices::containsKey)
                    .map(parent -> (DoubleVertex) parent)
                    // The set is mutated below; Collectors.toSet() does not promise a mutable result.
                    .collect(Collectors.toCollection(java.util.HashSet::new));

                if (!v.isObserved()) {
                    parents.add((DoubleVertex) v);
                }

                return parents;
            }
        ));
}
Stream<Vertex> verticesToAdd = Stream.concat(v.getParents().stream(), v.getChildren().stream()); verticesToAdd .filter(a -> !subgraphVertices.contains(a))
/**
 * Appends the vertex's parents to the description and reports whether the parent matcher accepts them.
 *
 * @param vertex      the vertex under test
 * @param description receives the text "vertex with parents " followed by the vertex's parents
 * @return the result of applying {@code parentMatcher} to the vertex's parents
 */
@Override
protected boolean matchesSafely(Vertex<T> vertex, Description description) {
    final Description prefixed = description.appendText("vertex with parents ");
    prefixed.appendValue(vertex.getParents());

    final boolean parentsAccepted = parentMatcher.matches(vertex.getParents());
    return parentsAccepted;
}
/**
 * Finds the connections between the parents of the given vertices and the vertices that the
 * derivative is being taken with respect to.
 *
 * <p>For every parent, the latent and observed vertices of its upstream lambda section are filtered
 * with {@code isLatentDoubleVertexAndInWrtTo}; a parent is recorded only when that leaves at least one
 * vertex.
 *
 * @param ofVertices the vertices that the derivative is being calculated "of" with respect to the wrtVertices
 * @return a map from a parent vertex to the non-empty set of upstream vertices accepted by the filter;
 *         parents with no such vertices are absent from the map
 */
private Map<Vertex, Set<DoubleVertex>> getParentsThatAreConnectedToWrtVertices(Set<? extends Vertex> ofVertices) {
    final Map<Vertex, Set<DoubleVertex>> connectedParents = new HashMap<>();

    for (Vertex<?> ofVertex : ofVertices) {
        for (Vertex parent : ofVertex.getParents()) {
            Set<Vertex> upstream = LambdaSection
                .getUpstreamLambdaSection(parent, false)
                .getLatentAndObservedVertices();

            Set<DoubleVertex> wrtLatents = upstream.stream()
                .filter(this::isLatentDoubleVertexAndInWrtTo)
                .map(candidate -> (DoubleVertex) candidate)
                .collect(Collectors.toSet());

            if (wrtLatents.isEmpty()) {
                continue;
            }
            connectedParents.put(parent, wrtLatents);
        }
    }

    return connectedParents;
}
for (Vertex<?> parent : visiting.getParents()) { if (!discoveredGraph.contains(parent)) { stack.addFirst(parent);
private Set<GraphEdge> getParentEdges(Vertex vertex) { Set<GraphEdge> edges = new HashSet<>(); for (Object v : vertex.getParents()) { edges.add(new GraphEdge((Vertex) v, vertex)); } // Check if any of the edges represent a connection between the vertex and its hyperparameter and annotate it accordingly. Class vertexClass = vertex.getClass(); Method[] methods = vertexClass.getMethods(); for (Method method : methods) { SaveVertexParam annotation = method.getAnnotation(SaveVertexParam.class); if (annotation != null && Vertex.class.isAssignableFrom(method.getReturnType())) { String parentName = annotation.value(); try { Vertex parentVertex = (Vertex) method.invoke(vertex); GraphEdge parentEdge = new GraphEdge(vertex, parentVertex); GraphEdge foundEdge = edges.stream().filter(parentEdge::equals).findFirst() .orElseThrow(() -> new IllegalStateException("Did not find parent edge " + parentName)); foundEdge.appendToLabel(parentName); } catch (Exception e) { throw new IllegalArgumentException("Invalid parent retrieval function specified", e); } } } return edges; } }
/**
 * Updates the value of each given vertex, first updating any parents that have not been calculated
 * during this call.
 *
 * <p>Work is driven by an explicit stack. A vertex is updated as soon as it is probabilistic or none
 * of its parents are outstanding; otherwise its outstanding parents are pushed on top of it and
 * handled first.
 *
 * @param vertices the vertices to evaluate
 */
public static void eval(Collection<? extends Vertex> vertices) {
    final Deque<Vertex> pending = asDeque(vertices);
    final Set<Vertex<?>> evaluated = new HashSet<>();

    while (!pending.isEmpty()) {
        final Vertex<?> current = pending.peek();
        final Set<Vertex<?>> outstandingParents = parentsThatAreNotCalculated(evaluated, current.getParents());

        final boolean mustEvaluateParentsFirst = !current.isProbabilistic() && !outstandingParents.isEmpty();
        if (mustEvaluateParentsFirst) {
            // Leave the current vertex on the stack; it is revisited once its parents are done.
            outstandingParents.forEach(pending::push);
            continue;
        }

        final Vertex<?> ready = pending.pop();
        updateVertexValue(ready);
        evaluated.add(ready);
    }
}
/**
 * Updates the value of each given vertex, first updating any parents that
 * {@code parentsThatAreNotCalculated} reports as outstanding.
 *
 * <p>Work is driven by an explicit stack. A vertex is updated as soon as it is probabilistic or has
 * no outstanding parents; otherwise its outstanding parents are pushed on top of it and handled first.
 *
 * @param vertices the vertices to evaluate
 */
public static void lazyEval(Collection<? extends Vertex> vertices) {
    final Deque<Vertex> pending = asDeque(vertices);

    while (!pending.isEmpty()) {
        final Vertex<?> current = pending.peek();
        final Set<Vertex<?>> outstandingParents = parentsThatAreNotCalculated(current.getParents());

        final boolean mustEvaluateParentsFirst = !current.isProbabilistic() && !outstandingParents.isEmpty();
        if (mustEvaluateParentsFirst) {
            // Leave the current vertex on the stack; it is revisited once its parents are done.
            outstandingParents.forEach(pending::push);
            continue;
        }

        final Vertex<?> ready = pending.pop();
        updateVertexValue(ready);
    }
}
collectPartials(partialDerivatives, dwrtOf); for (Vertex parent : visiting.getParents()) { if (!alreadyQueued.contains(parent) && parent.isDifferentiable()) { priorityQueue.offer(parent);
/**
 * Builds a mocked five-vertex chain (grandParent -> parent -> v -> child -> grandChild) and a
 * network over the connected graph reported by {@code v}.
 */
@BeforeClass
public static void setUp() {
    // Create every mock up front.
    v = Mockito.mock(Vertex.class);
    parent = Mockito.mock(Vertex.class);
    child = Mockito.mock(Vertex.class);
    grandParent = Mockito.mock(Vertex.class);
    grandChild = Mockito.mock(Vertex.class);

    // One-element neighbour sets for each stubbed relationship.
    Set<Vertex> parentsOfV = new HashSet<>();
    parentsOfV.add(parent);
    Set<Vertex> childrenOfV = new HashSet<>();
    childrenOfV.add(child);
    Set<Vertex> parentsOfParent = new HashSet<>();
    parentsOfParent.add(grandParent);
    Set<Vertex> childrenOfChild = new HashSet<>();
    childrenOfChild.add(grandChild);

    // Only these four relationships are stubbed.
    when(v.getParents()).thenReturn(parentsOfV);
    when(v.getChildren()).thenReturn(childrenOfV);
    when(parent.getParents()).thenReturn(parentsOfParent);
    when(child.getChildren()).thenReturn(childrenOfChild);

    when(v.getConnectedGraph()).thenReturn(ImmutableSet.of(v, parent, child, grandChild, grandParent));

    network = new BayesianNetwork(v.getConnectedGraph());
}
Vertex<DoubleTensor> x = plate.get(xLabel); Vertex<DoubleTensor> y = plate.get(yLabel); assertThat(xPreviousProxy.getParents(), contains(previousX)); assertThat(x.getParents(), contains(xPreviousProxy)); assertThat(y.getParents(), contains(x)); previousX = x;
/**
 * Every vertex in the outer network must have an id that compares greater than the id of each of
 * its parents.
 */
@Test
public void idOrderingStillImpliesTopologicalOrdering() {
    for (Vertex vertex : outerNet.getVertices()) {
        Set<Vertex> parentsOfVertex = vertex.getParents();
        for (Vertex itsParent : parentsOfVertex) {
            int vertexRelativeToParent = vertex.getId().compareTo(itsParent.getId());
            assertTrue(vertexRelativeToParent > 0);
        }
    }
}
/**
 * Plates built from an iterator can all share one parameter vertex: the labelled vertex in every
 * plate must have the common parameter as its parent.
 */
@Test
public void youCanCreateASetOfPlatesWithACommonParameterFromAnIterator() {
    GaussianVertex commonTheta = new GaussianVertex(0.5, 0.01);
    VertexLabel label = new VertexLabel("flip");

    Plates plates = new PlateBuilder<Bean>()
        .fromIterator(ROWS.iterator())
        .withFactory((plate, bean) -> {
            BooleanVertex flip = new BernoulliVertex(commonTheta);
            flip.observe(false);
            plate.add(label, flip);
        })
        .build();

    for (Plate plate : plates) {
        // The factory stored a BooleanVertex under this label, so the lookup uses a wildcard
        // instead of claiming a DoubleTensor value type.
        Vertex<?> flip = plate.get(label);
        assertThat(flip.getParents(), contains(commonTheta));
    }
}
/**
 * Plates built from a count can all share one parameter vertex: the labelled vertex in every plate
 * must have the common parameter as its parent.
 */
@Test
public void youCanCreateASetOfPlatesWithACommonParameterFromACount() {
    GaussianVertex commonTheta = new GaussianVertex(0.5, 0.01);
    VertexLabel label = new VertexLabel("flip");

    Plates plates = new PlateBuilder<Bean>()
        .count(10)
        .withFactory((plate) -> {
            BooleanVertex flip = new BernoulliVertex(commonTheta);
            flip.observe(false);
            plate.add(label, flip);
        })
        .build();

    for (Plate plate : plates) {
        // The factory stored a BooleanVertex under this label, so the lookup uses a wildcard
        // instead of claiming a DoubleTensor value type.
        Vertex<?> flip = plate.get(label);
        assertThat(flip.getParents(), contains(commonTheta));
    }
}
/**
 * A vertex added to a plate without an explicit label is retrievable by its own label, and its single
 * parent carries the same unqualified label name.
 */
@Test
public void ifAVertexIsLabeledThatIsWhatsUsedToReferToItInThePlate() {
    VertexLabel label = new VertexLabel("label");
    Vertex<?> initialVertex = ConstantVertex.of(1.).setLabel(label);

    Plates plates = new PlateBuilder<Integer>()
        .withInitialState(initialVertex)
        .withTransitionMapping(ImmutableMap.of(label, label))
        .count(10)
        .withFactory((plate) -> {
            DoubleVertex proxy = new DoubleProxyVertex(label);
            plate.add(proxy);
        })
        .build();

    for (Plate plate : plates) {
        Vertex<?> labelled = plate.get(label);
        assertThat(labelled, hasLabel(hasUnqualifiedName(label.getUnqualifiedName())));

        Vertex<?> onlyParent = Iterables.getOnlyElement(labelled.getParents());
        assertThat(onlyParent, hasLabel(hasUnqualifiedName(label.getUnqualifiedName())));
    }
}