/**
 * Creates a JSON saver for the supplied network.
 *
 * <p>The saver is backed by a {@code ProtobufSaver} built for the same network.
 *
 * @param net network that will be saved
 */
public JsonSaver(BayesianNetwork net) {
    this.protobufSaver = new ProtobufSaver(net);
}
public void saveNetToProtobuf(BayesianNetwork net, OutputStream outputStream, boolean saveValuesAndObservations) throws IOException {
    // Any NetworkSaver implementation can be used here; ProtobufSaver writes the binary protobuf format.
    NetworkSaver networkSaver = new ProtobufSaver(net);
    networkSaver.save(outputStream, saveValuesAndObservations);
}
//%%SNIPPET_END%% SaveToProtobuf
/**
 * Round-trips a network through protobuf: saves it (including values and observations)
 * into an in-memory buffer, checks something was written, then loads it back.
 *
 * @param net the network to serialize
 * @return the network reconstructed from the serialized bytes
 * @throws IOException if saving or loading fails
 */
private BayesianNetwork saveLoad(final BayesianNetwork net) throws IOException {
    final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    new ProtobufSaver(net).save(buffer, true);
    assertThat(buffer.size(), greaterThan(0));

    final byte[] serialized = buffer.toByteArray();
    final ByteArrayInputStream source = new ByteArrayInputStream(serialized);
    return new ProtobufLoader().loadNetwork(source);
}
/**
 * Checks that the metadata map passed to {@code ProtobufSaver.save} appears unchanged
 * in the metadata section of the serialized model.
 */
@Test
public void metadataCanBeSavedToProtobuf() throws IOException {
    Vertex vertex = new ConstantIntegerVertex(1);
    BayesianNetwork net = new BayesianNetwork(vertex.getConnectedGraph());
    Map<String, String> metadata = ImmutableMap.of("Author", "Some Author", "Tag", "MyBayesNet");

    ByteArrayOutputStream writer = new ByteArrayOutputStream();
    ProtobufSaver protobufSaver = new ProtobufSaver(net);
    protobufSaver.save(writer, true, metadata);

    KeanuSavedBayesNet.Model parsedModel = KeanuSavedBayesNet.Model.parseFrom(writer.toByteArray());

    // Map.equals compares entries regardless of iteration order, so no sorting or
    // rebuilding is needed. Expected value goes first so failure messages are labelled correctly.
    assertEquals(metadata, parsedModel.getMetadata().getMetadataInfoMap());
}
/**
 * Asserts that asking a {@code TestNonSaveableVertex} to save itself through a
 * {@code ProtobufSaver} results in an {@link IllegalArgumentException}.
 *
 * <p>NOTE(review): the expectation covers the whole method body, so an
 * IllegalArgumentException thrown while building the network would also pass.
 */
@Test(expected = IllegalArgumentException.class)
public void nonSaveableVertexThrowsExceptionOnSave() {
    DoubleVertex unsaveable = new TestNonSaveableVertex();
    ProtobufSaver saver = new ProtobufSaver(new BayesianNetwork(unsaveable.getConnectedGraph()));
    unsaveable.save(saver);
}
/**
 * Saves a Gaussian whose mu is a concatenation of two constant vectors, reloads it,
 * and verifies the reloaded vertex keeps its label, its mu entries and its sigma.
 */
@Test
public void youCanSaveAndLoadANetworkWithValues() throws IOException {
    final String gaussianLabel = "Gaussian";
    DoubleVertex mu1 = new ConstantDoubleVertex(new double[]{3.0, 1.0});
    DoubleVertex mu2 = new ConstantDoubleVertex(new double[]{5.0, 6.0});
    DoubleVertex finalMu = new ConcatenationVertex(0, mu1, mu2);
    DoubleVertex gaussianVertex = new GaussianVertex(finalMu, 1.0);
    gaussianVertex.setLabel(gaussianLabel);
    BayesianNetwork net = new BayesianNetwork(gaussianVertex.getConnectedGraph());

    ByteArrayOutputStream output = new ByteArrayOutputStream();
    ProtobufSaver protobufSaver = new ProtobufSaver(net);
    protobufSaver.save(output, true);
    assertThat(output.size(), greaterThan(0));

    ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray());
    ProtobufLoader loader = new ProtobufLoader();
    BayesianNetwork readNet = loader.loadNetwork(input);

    // Exactly one latent vertex is expected after reload, and it must be the Gaussian.
    assertThat(readNet.getLatentVertices().size(), is(1));
    assertThat(readNet.getLatentVertices().get(0), instanceOf(GaussianVertex.class));

    // The vertex found as "the latent vertex" and the one found by label must agree.
    GaussianVertex latentGaussianVertex = (GaussianVertex) readNet.getLatentVertices().get(0);
    GaussianVertex labelGaussianVertex = (GaussianVertex) readNet.getVertexByLabel(new VertexLabel(gaussianLabel));
    assertThat(latentGaussianVertex, equalTo(labelGaussianVertex));

    // mu was built from {3, 1} concatenated with {5, 6}: index 0 -> 3.0, index 2 -> 5.0.
    assertThat(latentGaussianVertex.getMu().getValue(0), closeTo(3.0, 1e-10));
    assertThat(labelGaussianVertex.getMu().getValue(2), closeTo(5.0, 1e-10));
    assertThat(latentGaussianVertex.getSigma().getValue().scalar(), closeTo(1.0, 1e-10));

    // Only checks that sampling the reloaded vertex does not throw; the sample is discarded.
    latentGaussianVertex.sample();
}
/**
 * Two labelled Gaussians with non-square shapes are multiplied, the resulting network is
 * saved (without values) and reloaded, and each labelled vertex must keep its original shape.
 */
@Test
public void shapeIsCorrectlySavedAndLoaded() throws IOException {
    final long[] firstShape = {2, 3};
    final long[] secondShape = {3, 2};
    final VertexLabel firstLabel = new VertexLabel("Vertex1");
    final VertexLabel secondLabel = new VertexLabel("Vertex2");

    DoubleVertex first = new GaussianVertex(firstShape, 0.0, 1.0);
    first.setLabel(firstLabel);
    DoubleVertex second = new GaussianVertex(secondShape, 0.0, 1.0);
    second.setLabel(secondLabel);
    DoubleVertex product = first.matrixMultiply(second);

    BayesianNetwork bayesNet = new BayesianNetwork(product.getConnectedGraph());
    ByteArrayOutputStream sink = new ByteArrayOutputStream();
    new ProtobufSaver(bayesNet).save(sink, false);

    ByteArrayInputStream source = new ByteArrayInputStream(sink.toByteArray());
    BayesianNetwork readNet = new ProtobufLoader().loadNetwork(source);

    Vertex reloadedFirst = readNet.getVertexByLabel(firstLabel);
    assertThat(reloadedFirst.getShape(), is(firstShape));
    Vertex reloadedSecond = readNet.getVertexByLabel(secondLabel);
    assertThat(reloadedSecond.getShape(), is(secondShape));
}
/**
 * The partial derivative of the output vertex with respect to the input vertex must be the
 * same before and after a protobuf round trip, for both forward- and reverse-mode autodiff.
 */
@Test
public void saveLoadGradientTest() throws IOException {
    BayesianNetwork originalNet = createComplexNet();
    ByteArrayOutputStream serialized = new ByteArrayOutputStream();
    new ProtobufSaver(originalNet).save(serialized, true);

    DoubleIfVertex originalOutput = (DoubleIfVertex) originalNet.getVertexByLabel(new VertexLabel(OUTPUT_NAME));
    GaussianVertex originalInput = (GaussianVertex) originalNet.getVertexByLabel(new VertexLabel(INPUT_NAME));

    ByteArrayInputStream source = new ByteArrayInputStream(serialized.toByteArray());
    BayesianNetwork loadedNet = new ProtobufLoader().loadNetwork(source);
    DoubleIfVertex loadedOutput = (DoubleIfVertex) loadedNet.getVertexByLabel(new VertexLabel(OUTPUT_NAME));
    GaussianVertex loadedInput = (GaussianVertex) loadedNet.getVertexByLabel(new VertexLabel(INPUT_NAME));

    // Forward mode: d(output)/d(input) computed from the input side.
    DoubleTensor forwardBefore = Differentiator.forwardModeAutoDiff(originalInput, originalOutput).of(originalOutput);
    DoubleTensor forwardAfter = Differentiator.forwardModeAutoDiff(loadedInput, loadedOutput).of(loadedOutput);
    assertEquals(forwardBefore, forwardAfter);

    // Reverse mode: the same derivative computed from the output side.
    DoubleTensor reverseBefore = Differentiator.reverseModeAutoDiff(originalOutput, originalInput).withRespectTo(originalInput);
    DoubleTensor reverseAfter = Differentiator.reverseModeAutoDiff(loadedOutput, loadedInput).withRespectTo(loadedInput);
    assertEquals(reverseBefore, reverseAfter);
}