/**
 * Computes, by forward-mode automatic differentiation, the partial derivative of every vertex in
 * {@code of} with respect to {@code wrt}.
 *
 * <p>Starting at {@code wrt}, vertices are visited in ascending {@link VertexId} order. Each
 * visited vertex computes its partial from the partials already gathered for earlier vertices. The
 * traversal only steps into children that are non-probabilistic, differentiable and not yet
 * queued. It relies on ids increasing from parents to children; presumably guaranteed by how
 * vertices are created (TODO confirm).
 *
 * <p>Requested vertices are expanded like any other vertex. This ensures a requested vertex that is
 * only reachable through another requested vertex is still found. The traversal stops as soon as
 * every requested vertex has been recorded.
 *
 * @param wrt the vertex to differentiate with respect to
 * @param of  the vertices whose partials with respect to {@code wrt} are wanted
 * @return the partials of every reachable vertex of {@code of}, keyed by vertex id; vertices of
 *         {@code of} that the traversal cannot reach are absent
 */
public static <V extends Vertex & Differentiable> PartialsWithRespectTo forwardModeAutoDiff(V wrt, Collection<V> of) {
    // Hash the targets once: of.contains(...) on a List would be a linear scan per visited vertex.
    Set<Vertex> targets = new HashSet<>(of);

    PriorityQueue<V> priorityQueue = new PriorityQueue<>(Comparator.comparing(Vertex::getId, Comparator.naturalOrder()));
    priorityQueue.add(wrt);

    Set<Vertex> alreadyQueued = new HashSet<>();
    alreadyQueued.add(wrt);

    Map<Vertex, PartialDerivative> partials = new HashMap<>();
    Map<VertexId, PartialDerivative> ofWrt = new HashMap<>();

    // Once every target is recorded, further traversal cannot add to ofWrt, so stop early.
    while (!priorityQueue.isEmpty() && ofWrt.size() < targets.size()) {
        V visiting = priorityQueue.poll();

        PartialDerivative partialOfVisiting = visiting.forwardModeAutoDifferentiation(partials);
        partials.put(visiting, partialOfVisiting);

        if (targets.contains(visiting)) {
            ofWrt.put(visiting.getId(), partialOfVisiting);
            // No `continue` here: other targets may be reachable only through this vertex.
        }

        for (Vertex child : (Set<Vertex<?>>) visiting.getChildren()) {
            if (!child.isProbabilistic() && !alreadyQueued.contains(child) && child.isDifferentiable()) {
                priorityQueue.offer((V) child);
                alreadyQueued.add(child);
            }
        }
    }

    return new PartialsWithRespectTo(wrt, ofWrt);
}
/**
 * Builds a BOBYQA {@link NonGradientOptimizer} for the whole graph that {@code vertexFromNetwork}
 * belongs to.
 *
 * <p>Every vertex connected to {@code vertexFromNetwork} is gathered and handed to {@code of}. The
 * resulting optimizer therefore works on the Bayesian network formed by that connected graph and
 * can tune its latent variables to maximise probability.
 *
 * @param vertexFromNetwork any vertex belonging to the graph to optimize
 * @return a {@link NonGradientOptimizer} over the connected graph of {@code vertexFromNetwork}
 */
public NonGradientOptimizer ofConnectedGraph(Vertex<?> vertexFromNetwork) {
    return this.of(vertexFromNetwork.getConnectedGraph());
}
/**
 * Renders the vertex as {@code id[ (label)]: SimpleClassName[(value)]}.
 *
 * <p>The label part appears only when a label is set. The value part appears only when the vertex
 * currently holds a value.
 */
@Override
public String toString() {
    StringBuilder description = new StringBuilder();
    description.append(this.getId());

    if (this.getLabel() != null) {
        description.append(" (").append(this.getLabel()).append(')');
    }

    description.append(": ").append(this.getClass().getSimpleName());

    if (hasValue()) {
        // Append pieces directly rather than concatenating a temporary String first.
        description.append('(').append(getValue()).append(')');
    }

    return description.toString();
}
/**
 * Tells whether {@code vertex} is probabilistic and has not been observed.
 *
 * @param vertex the vertex to test
 * @return {@code true} only for probabilistic vertices that are not observed
 */
private static boolean isUnobservedProbabilistic(Vertex vertex) {
    if (!vertex.isProbabilistic()) {
        // Observation status is irrelevant (and not queried) for non-probabilistic vertices.
        return false;
    }
    return !vertex.isObserved();
}
/**
 * Gives {@code vertex} a value where one is missing or derivable.
 *
 * <ul>
 *   <li>Non-probabilistic, unobserved vertices are recomputed through
 *       {@link NonProbabilistic#calculate()}; observed ones are left alone.</li>
 *   <li>Probabilistic vertices without a value get a fresh {@code sample()}; those that already
 *       hold a value are left alone.</li>
 * </ul>
 *
 * @param vertex the vertex to update
 * @param <T>    the vertex's value type
 */
private static <T> void updateVertexValue(Vertex<T> vertex) {
    if (!vertex.isProbabilistic()) {
        if (!vertex.isObserved()) {
            NonProbabilistic<T> deterministic = (NonProbabilistic<T>) vertex;
            vertex.setValue(deterministic.calculate());
        }
        return;
    }

    if (!vertex.hasValue()) {
        vertex.setValue(vertex.sample());
    }
}
}
/**
 * A snapshot taken while a vertex is observed must restore both its value and its observed flag,
 * even after the vertex was unobserved and given a different value.
 */
@Test
public void itRestoresTheObservedStatusOfAnObservedVertex() {
    DoubleTensor snapshotValue = DoubleTensor.create(1., 2., 3.);
    DoubleTensor laterValue = DoubleTensor.create(4., 5., 6.);
    Vertex gaussian = new GaussianVertex(1., 0.);

    // Given: an observed vertex, captured in a snapshot.
    gaussian.observe(snapshotValue);
    assertThat(gaussian.getValue(), equalTo(snapshotValue));
    assertThat(gaussian.isObserved(), is(true));

    NetworkSnapshot snapshot = NetworkSnapshot.create(ImmutableSet.of(gaussian));

    // When: it is unobserved (keeping its value) and then overwritten.
    gaussian.unobserve();
    assertThat(gaussian.getValue(), equalTo(snapshotValue));
    assertThat(gaussian.isObserved(), is(false));

    gaussian.setValue(laterValue);
    assertThat(gaussian.getValue(), equalTo(laterValue));
    assertThat(gaussian.isObserved(), is(false));

    // Then: applying the snapshot brings back the value and the observed status.
    snapshot.apply();
    assertThat(gaussian.getValue(), equalTo(snapshotValue));
    assertThat(gaussian.isObserved(), is(true));
}
}
/**
 * Convenience constructor that takes the output shape from {@code inputVertex}.
 *
 * <p>It simply delegates to the shape-taking constructor, passing
 * {@code inputVertex.getShape()} as the shape. All other arguments are forwarded unchanged.
 *
 * @param inputVertex               the single input vertex; also supplies the shape
 * @param op                        the value function, forwarded as-is
 * @param forwardModeAutoDiffLambda the forward-mode derivative function, forwarded as-is
 * @param reverseModeAutoDiffLambda the reverse-mode derivative function, forwarded as-is
 */
public DoubleUnaryOpLambda(Vertex<IN> inputVertex,
                           Function<IN, DoubleTensor> op,
                           Function<Map<Vertex, PartialDerivative>, PartialDerivative> forwardModeAutoDiffLambda,
                           Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiffLambda) {
    this(
        inputVertex.getShape(),
        inputVertex,
        op,
        forwardModeAutoDiffLambda,
        reverseModeAutoDiffLambda
    );
}
/**
 * Converts {@code vertex} into its {@link KeanuSavedBayesNet.Vertex} saved representation.
 *
 * <p>The message records the label (only if present), the id, the vertex class's canonical name
 * and the shape. Parameters are then added by {@code saveParams}.
 *
 * @param vertex the vertex to serialise
 * @return the built saved-network vertex message
 */
private KeanuSavedBayesNet.Vertex buildVertex(Vertex vertex) {
    KeanuSavedBayesNet.Vertex.Builder builder = KeanuSavedBayesNet.Vertex.newBuilder();

    if (vertex.getLabel() != null) {
        builder = builder.setLabel(vertex.getLabel().toString());
    }

    // NOTE(review): Class#getCanonicalName returns null for anonymous/local classes and uses '.'
    // rather than '$' for nested classes; confirm the loader can resolve these type names.
    builder = builder
        .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId(vertex.getId().toString()))
        .setVertexType(vertex.getClass().getCanonicalName())
        .addAllShape(Longs.asList(vertex.getShape()));

    saveParams(builder, vertex);
    return builder.build();
}
/**
 * A snapshot of an observed vertex must restore the originally observed value after the vertex
 * is re-observed with something else, and the vertex must stay observed.
 */
@Test
public void itRestoresTheValueOfAnObservedVertex() {
    DoubleTensor snapshotValue = DoubleTensor.create(1., 2., 3.);
    DoubleTensor laterValue = DoubleTensor.create(4., 5., 6.);
    Vertex gaussian = new GaussianVertex(1., 0.);

    // Given: an observed vertex, captured in a snapshot.
    gaussian.observe(snapshotValue);
    assertThat(gaussian.getValue(), equalTo(snapshotValue));
    assertThat(gaussian.isObserved(), is(true));

    NetworkSnapshot snapshot = NetworkSnapshot.create(ImmutableSet.of(gaussian));

    // When: it is re-observed with a different value.
    gaussian.observe(laterValue);
    assertThat(gaussian.getValue(), equalTo(laterValue));
    assertThat(gaussian.isObserved(), is(true));

    // Then: applying the snapshot restores the first observed value.
    snapshot.apply();
    assertThat(gaussian.getValue(), equalTo(snapshotValue));
    assertThat(gaussian.isObserved(), is(true));
}
/**
 * A snapshot of an unobserved vertex must restore its earlier value without marking it observed.
 */
@Test
public void itRestoresTheValueOfAnUnobservedVertex() {
    DoubleTensor snapshotValue = DoubleTensor.create(1., 2., 3.);
    DoubleTensor laterValue = DoubleTensor.create(4., 5., 6.);
    Vertex gaussian = new GaussianVertex(1., 0.);

    // Given: an unobserved vertex holding a value, captured in a snapshot.
    gaussian.setValue(snapshotValue);
    NetworkSnapshot snapshot = NetworkSnapshot.create(ImmutableSet.of(gaussian));

    // When: its value is replaced.
    gaussian.setValue(laterValue);
    assertThat(gaussian.getValue(), equalTo(laterValue));
    assertThat(gaussian.isObserved(), is(false));

    // Then: applying the snapshot restores the old value and leaves it unobserved.
    snapshot.apply();
    assertThat(gaussian.getValue(), equalTo(snapshotValue));
    assertThat(gaussian.isObserved(), is(false));
}
/**
 * Builds the shared mock fixture and the {@code network} under test.
 *
 * <p>Stubs wired here:
 * <ul>
 *   <li>{@code v}'s parents are {parent}; its children are {child}.</li>
 *   <li>{@code parent}'s parents are {grandParent}.</li>
 *   <li>{@code child}'s children are {grandChild}.</li>
 *   <li>{@code v.getConnectedGraph()} returns all five mocks.</li>
 * </ul>
 * Any other relation uses Mockito's default answer.
 */
@BeforeClass
public static void setUp() {
    v = Mockito.mock(Vertex.class);
    parent = Mockito.mock(Vertex.class);
    child = Mockito.mock(Vertex.class);
    grandParent = Mockito.mock(Vertex.class);
    grandChild = Mockito.mock(Vertex.class);

    Set<Vertex> parents = new HashSet<>();
    parents.add(parent);
    Set<Vertex> children = new HashSet<>();
    children.add(child);
    Set<Vertex> grandParents = new HashSet<>();
    grandParents.add(grandParent);
    Set<Vertex> grandChildren = new HashSet<>();
    grandChildren.add(grandChild);

    when(v.getParents()).thenReturn(parents);
    when(v.getChildren()).thenReturn(children);
    when(parent.getParents()).thenReturn(grandParents);
    when(child.getChildren()).thenReturn(grandChildren);
    when(v.getConnectedGraph()).thenReturn(ImmutableSet.of(v, parent, child, grandChild, grandParent));

    network = new BayesianNetwork(v.getConnectedGraph());
}
/**
 * Returns this vertex's id, which serves as its {@link VariableReference}.
 *
 * @return the value of {@code getId()}
 */
@Override
public VariableReference getReference() {
    return this.getId();
}
/**
 * Default forward-mode step: the derivative of this vertex with respect to itself.
 *
 * <p>The parent-derivative map is not consulted by this default. An observed vertex yields
 * {@link PartialDerivative#EMPTY}. Otherwise the result is {@code withRespectToSelf} for this
 * vertex's shape.
 *
 * @param derivativeOfParentsWithRespectToInput partials gathered so far (unused here)
 * @return the self-partial, or {@code PartialDerivative.EMPTY} when observed
 */
default PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
    // Implementations of this interface are expected to be Vertex instances.
    Vertex self = (Vertex) this;
    if (self.isObserved()) {
        return PartialDerivative.EMPTY;
    }
    return withRespectToSelf(self.getShape());
}
/**
 * Collects the direct parents of every vertex in {@code vertices} into one set.
 *
 * @param vertices the vertices whose parents are gathered
 * @return a new mutable {@link HashSet} holding the union of their parents
 */
private static Set<Vertex> allParentsOf(Collection<Vertex> vertices) {
    Set<Vertex> parents = new HashSet<>();
    vertices.forEach(vertex -> parents.addAll(vertex.getParents()));
    return parents;
}
// Candidate neighbours of v: its direct parents followed by its direct children, keeping only
// those not already in subgraphVertices.
Stream<Vertex> verticesToAdd = Stream.concat(v.getParents().stream(), v.getChildren().stream()); verticesToAdd .filter(a -> !subgraphVertices.contains(a))
/**
 * Returns the parents of {@code visiting}, or nothing if it is probabilistic.
 *
 * @param visiting the vertex being inspected
 * @return an empty set for probabilistic vertices, otherwise {@code visiting.getParents()}
 */
private static Collection<Vertex> getParentsIfVertexIsNotProbabilistic(Vertex visiting) {
    if (visiting.isProbabilistic()) {
        return Collections.emptySet();
    }
    return visiting.getParents();
}
/**
 * Selects the entries of {@code vertices} accepted by {@code filter}, preserving stream order.
 *
 * <p>Each candidate's probabilistic flag, observed flag and indentation are passed to
 * {@code filter.filter}.
 *
 * @param filter decides which vertices to keep
 * @return a new list of the accepted vertices
 */
private List<Vertex> getFilteredVertexList(VertexFilter filter) {
    return vertices.stream()
        .filter(candidate -> {
            boolean probabilistic = candidate.isProbabilistic();
            boolean observed = candidate.isObserved();
            return filter.filter(probabilistic, observed, candidate.getIndentation());
        })
        .collect(Collectors.toList());
}