/**
 * Tells whether the vertex is a {@link DoubleVertex} or, failing that, is observed.
 *
 * @param v the vertex to inspect
 * @return {@code true} if {@code v} is a {@code DoubleVertex} or {@code v.isObserved()} is true
 */
private static boolean isDoubleOrObserved(Vertex v) {
    if (v instanceof DoubleVertex) {
        return true;
    }
    return v.isObserved();
}
/**
 * Decides whether a vertex's value can be treated as constant.
 *
 * <p>The checks run in this order and stop at the first hit: the vertex is a
 * {@link ConstantVertex}, the vertex is observed, or the vertex is present in the supplied cache.
 *
 * @param vertex                      the vertex to inspect
 * @param constantValueVerticesCache  vertices already recorded as having a constant value
 * @return {@code true} if any of the three conditions holds
 */
private static boolean isValueKnownToBeConstant(Vertex vertex, Set<Vertex> constantValueVerticesCache) {
    if (vertex instanceof ConstantVertex) {
        return true;
    }
    if (vertex.isObserved()) {
        return true;
    }
    return constantValueVerticesCache.contains(vertex);
}
}
/**
 * Checks that a vertex is unobserved, is contained in {@code wrtVertices}, and is a
 * {@link DoubleVertex}. The three checks are evaluated in that order.
 *
 * @param v the vertex to inspect
 * @return {@code true} only when all three conditions hold
 */
private boolean isLatentDoubleVertexAndInWrtTo(Vertex v) {
    if (v.isObserved()) {
        return false;
    }
    if (!wrtVertices.contains(v)) {
        return false;
    }
    return v instanceof DoubleVertex;
}
/**
 * Narrows a set of dependencies down to those that are probabilistic and not observed.
 *
 * @param dependencies the vertices to filter
 * @return a new set holding only the probabilistic, unobserved members of {@code dependencies}
 */
private static Set<Vertex> getLatentDependencies(Set<Vertex> dependencies) {
    return dependencies.stream()
        .filter(Vertex::isProbabilistic)
        .filter(candidate -> !candidate.isObserved())
        .collect(Collectors.toSet());
}
/**
 * Tells whether a vertex is probabilistic and has not been observed.
 *
 * @param vertex the vertex to inspect
 * @return {@code true} if {@code vertex.isProbabilistic()} is true and {@code vertex.isObserved()} is false
 */
private static boolean isUnobservedProbabilistic(Vertex vertex) {
    if (!vertex.isProbabilistic()) {
        return false;
    }
    return !vertex.isObserved();
}
/**
 * Builds, for every observed vertex reported by {@code TopologicalSort.mapDependencies}, the subset of
 * its dependencies selected by {@code getLatentDependencies}.
 *
 * @param vertices the vertices handed to {@code TopologicalSort.mapDependencies}
 * @return a new map keyed by observed vertex; unobserved vertices get no entry
 */
private static Map<Vertex, Set<Vertex>> getObservedVertexLatentDependencies(Collection<? extends Vertex> vertices) {
    Map<Vertex, Set<Vertex>> allDependencies = TopologicalSort.mapDependencies(vertices);
    Map<Vertex, Set<Vertex>> latentByObserved = new HashMap<>();

    allDependencies.forEach((candidate, candidateDependencies) -> {
        // Only observed vertices are recorded in the result.
        if (!candidate.isObserved()) {
            return;
        }
        latentByObserved.put(candidate, getLatentDependencies(candidateDependencies));
    });

    return latentByObserved;
}
/**
 * Collects the members of {@code vertices} accepted by the given filter.
 *
 * <p>For each vertex the filter receives, in this order, the results of {@code isProbabilistic()},
 * {@code isObserved()} and {@code getIndentation()}.
 *
 * @param filter the predicate deciding which vertices are kept
 * @return a list of the accepted vertices
 */
private List<Vertex> getFilteredVertexList(VertexFilter filter) {
    return vertices.stream()
        .filter(candidate -> {
            boolean probabilistic = candidate.isProbabilistic();
            boolean observed = candidate.isObserved();
            return filter.filter(probabilistic, observed, candidate.getIndentation());
        })
        .collect(Collectors.toList());
}
/**
 * The dLogProb(x) method on Vertex returns a partial derivative of the Log Prob with respect to each
 * of its arguments and with respect to its value, x. This method searches these partials for any that
 * are parents of the vertices we are taking the derivative with respect to.
 *
 * <p>For each vertex in {@code ofVertices} the resulting set holds those of its parents that appear as
 * keys in {@code parentToWrtVertices}, plus the vertex itself when it is not observed.
 *
 * @param ofVertices          the vertices that the derivative is being calculated "of" with respect to the wrtVertices
 * @param parentToWrtVertices a lookup; only its key set is consulted by this method
 * @return a map from each of the {@code ofVertices} to the set described above
 * @throws ClassCastException if a parent that is a key of {@code parentToWrtVertices}, or an unobserved
 *                            member of {@code ofVertices}, is not a {@link DoubleVertex}
 */
private Map<Vertex, Set<DoubleVertex>> getVerticesWithNonzeroDiffWrt(Set<? extends Vertex<?>> ofVertices,
                                                                     Map<Vertex, Set<DoubleVertex>> parentToWrtVertices) {
    return ofVertices.stream()
        .collect(Collectors.toMap(
            v -> v,
            v -> {
                // Filter before casting: parents that are not in the lookup are discarded anyway, so they
                // must not be cast to DoubleVertex first (a non-double parent would otherwise throw).
                Set<DoubleVertex> parents = v.getParents().stream()
                    .filter(parentToWrtVertices::containsKey)
                    .map(parent -> (DoubleVertex) parent)
                    .collect(Collectors.toSet());

                // An unobserved vertex also contributes itself.
                if (!v.isObserved()) {
                    parents.add((DoubleVertex) v);
                }

                return parents;
            }
        ));
}
/**
 * Default forward-mode result for this vertex.
 *
 * <p>Returns {@code PartialDerivative.EMPTY} when this vertex is observed; otherwise returns
 * {@code withRespectToSelf} applied to this vertex's shape. The supplied parent derivatives are
 * not read by this default implementation.
 *
 * @param derivativeOfParentsWithRespectToInput partial derivatives of the parents (unused here)
 * @return {@code PartialDerivative.EMPTY} or the result of {@code withRespectToSelf(getShape())}
 */
default PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
    final Vertex self = (Vertex) this;
    if (self.isObserved()) {
        return PartialDerivative.EMPTY;
    }
    return withRespectToSelf(self.getShape());
}
if (visiting.isObserved() || visiting.isProbabilistic()) { continue;
/**
 * Entry point for reverse-mode differentiation of a single vertex.
 *
 * <p>An observed vertex yields a {@code PartialsOf} backed by an empty map. Any other vertex is
 * delegated to the three-argument overload, seeded with
 * {@code Differentiable.withRespectToSelf(ofVertex.getShape())}.
 *
 * @param ofVertex the vertex being differentiated
 * @param wrt      the vertices to differentiate with respect to
 * @return the partials of {@code ofVertex}
 */
public static PartialsOf reverseModeAutoDiff(Vertex ofVertex, Set<DoubleVertex> wrt) {
    if (!ofVertex.isObserved()) {
        return reverseModeAutoDiff(ofVertex, Differentiable.withRespectToSelf(ofVertex.getShape()), wrt);
    }
    return new PartialsOf(ofVertex, Collections.emptyMap());
}
@Test
public void itRestoresTheValueOfAnObservedVertex() {
    DoubleTensor valueAtSnapshot = DoubleTensor.create(1., 2., 3.);
    DoubleTensor valueAfterSnapshot = DoubleTensor.create(4., 5., 6.);

    // Start from a vertex observed at the first value.
    Vertex observedVertex = new GaussianVertex(1., 0.);
    observedVertex.observe(valueAtSnapshot);
    assertThat(observedVertex.getValue(), equalTo(valueAtSnapshot));
    assertThat(observedVertex.isObserved(), is(true));

    NetworkSnapshot snapshot = NetworkSnapshot.create(ImmutableSet.of(observedVertex));

    // Re-observe with a different value after the snapshot was taken.
    observedVertex.observe(valueAfterSnapshot);
    assertThat(observedVertex.getValue(), equalTo(valueAfterSnapshot));
    assertThat(observedVertex.isObserved(), is(true));

    // Applying the snapshot must bring back the first value and keep the vertex observed.
    snapshot.apply();
    assertThat(observedVertex.getValue(), equalTo(valueAtSnapshot));
    assertThat(observedVertex.isObserved(), is(true));
}
@Test
public void itRestoresTheObservedStatusOfAnObservedVertex() {
    DoubleTensor valueAtSnapshot = DoubleTensor.create(1., 2., 3.);
    DoubleTensor valueAfterSnapshot = DoubleTensor.create(4., 5., 6.);

    // Start from a vertex observed at the first value.
    Vertex trackedVertex = new GaussianVertex(1., 0.);
    trackedVertex.observe(valueAtSnapshot);
    assertThat(trackedVertex.getValue(), equalTo(valueAtSnapshot));
    assertThat(trackedVertex.isObserved(), is(true));

    NetworkSnapshot snapshot = NetworkSnapshot.create(ImmutableSet.of(trackedVertex));

    // Dropping the observation leaves the value in place but clears the observed flag.
    trackedVertex.unobserve();
    assertThat(trackedVertex.getValue(), equalTo(valueAtSnapshot));
    assertThat(trackedVertex.isObserved(), is(false));

    // Overwrite the value while the vertex is unobserved.
    trackedVertex.setValue(valueAfterSnapshot);
    assertThat(trackedVertex.getValue(), equalTo(valueAfterSnapshot));
    assertThat(trackedVertex.isObserved(), is(false));

    // Applying the snapshot must restore both the first value and the observed flag.
    snapshot.apply();
    assertThat(trackedVertex.getValue(), equalTo(valueAtSnapshot));
    assertThat(trackedVertex.isObserved(), is(true));
}
}
/**
 * Gives a vertex a value where one is needed.
 *
 * <p>A probabilistic vertex is assigned a fresh sample only if it has no value yet. A
 * non-probabilistic vertex is assigned the result of {@code NonProbabilistic.calculate()} unless it
 * is observed, in which case it is left untouched.
 *
 * @param vertex the vertex to update
 * @param <T>    the value type of the vertex
 */
private static <T> void updateVertexValue(Vertex<T> vertex) {
    if (vertex.isProbabilistic()) {
        if (!vertex.hasValue()) {
            vertex.setValue(vertex.sample());
        }
        return;
    }

    if (vertex.isObserved()) {
        return;
    }

    vertex.setValue(((NonProbabilistic<T>) vertex).calculate());
}
}
@Test
public void itRestoresTheValueOfAnUnobservedVertex() {
    DoubleTensor valueAtSnapshot = DoubleTensor.create(1., 2., 3.);
    DoubleTensor valueAfterSnapshot = DoubleTensor.create(4., 5., 6.);

    // Start from an unobserved vertex holding the first value.
    Vertex latentVertex = new GaussianVertex(1., 0.);
    latentVertex.setValue(valueAtSnapshot);

    NetworkSnapshot snapshot = NetworkSnapshot.create(ImmutableSet.of(latentVertex));

    // Overwrite the value after the snapshot was taken.
    latentVertex.setValue(valueAfterSnapshot);
    assertThat(latentVertex.getValue(), equalTo(valueAfterSnapshot));
    assertThat(latentVertex.isObserved(), is(false));

    // Applying the snapshot must bring back the first value and leave the vertex unobserved.
    snapshot.apply();
    assertThat(latentVertex.getValue(), equalTo(valueAtSnapshot));
    assertThat(latentVertex.isObserved(), is(false));
}
/**
 * Assembles a {@code StoredValue} message for a vertex.
 *
 * <p>The message carries the vertex id (as the string form of {@code vertex.getId()}), the supplied
 * value, and the vertex's current observed flag.
 *
 * @param vertex the vertex whose id and observed flag are recorded
 * @param value  the already-built value to store
 * @return the built {@code StoredValue}
 */
private KeanuSavedBayesNet.StoredValue getStoredValue(Vertex vertex, KeanuSavedBayesNet.VertexValue value) {
    KeanuSavedBayesNet.StoredValue.Builder storedValue = KeanuSavedBayesNet.StoredValue.newBuilder();

    KeanuSavedBayesNet.VertexID vertexId = KeanuSavedBayesNet.VertexID.newBuilder()
        .setId(vertex.getId().toString())
        .build();

    storedValue.setId(vertexId);
    storedValue.setValue(value);
    storedValue.setIsObserved(vertex.isObserved());

    return storedValue.build();
}
}