/**
 * Loads a network from a JSON-encoded saved model.
 *
 * <p>The whole stream is read as UTF-8 text and parsed into a {@code KeanuSavedBayesNet.Model}
 * using protobuf's {@code JsonFormat}. Building the network is delegated to the wrapped
 * protobuf loader.
 *
 * @param input stream containing the JSON representation of the saved model
 * @return the network reconstructed by {@code protobufLoader}
 * @throws IOException if the stream cannot be read or its contents are not valid JSON for the model
 */
@Override
public BayesianNetwork loadNetwork(InputStream input) throws IOException {
    final String json = IOUtils.toString(input, StandardCharsets.UTF_8);
    final KeanuSavedBayesNet.Model.Builder builder = KeanuSavedBayesNet.Model.newBuilder();
    final JsonFormat.Parser jsonParser = JsonFormat.parser();
    jsonParser.merge(json, builder);
    final KeanuSavedBayesNet.Model parsedModel = builder.build();
    return protobufLoader.loadNetwork(parsedModel);
}
/**
 * Verifies that the JSON written to {@code outputStream} contains a metadata section whose
 * entries match {@code someMetadata} exactly: same number of entries, and every parsed entry
 * is one of the expected entries.
 */
@Test
public void jsonSaverSavesMetadata() throws IOException {
    // Route the expected entries through the protobuf metadata builder so they are compared
    // in the same map representation the parsed model exposes.
    KeanuSavedBayesNet.Metadata.Builder metadataBuilder = KeanuSavedBayesNet.Metadata.newBuilder();
    for (Map.Entry<String, String> entry : someMetadata.entrySet()) {
        metadataBuilder.putMetadataInfo(entry.getKey(), entry.getValue());
    }
    Map<String, String> expectedMetadata = metadataBuilder.getMetadataInfoMap();

    KeanuSavedBayesNet.Model.Builder modelBuilder = KeanuSavedBayesNet.Model.newBuilder();
    // NOTE(review): toString() decodes with the platform default charset; confirm the saved
    // JSON is ASCII-only or switch to an explicit UTF-8 decode.
    JsonFormat.parser().merge(outputStream.toString(), modelBuilder);
    KeanuSavedBayesNet.Model parsedModel = modelBuilder.build();

    assertTrue(parsedModel.hasMetadata());
    Map<String, String> actualMetadata = parsedModel.getMetadata().getMetadataInfoMap();
    // JUnit's assertEquals is (expected, actual); the order determines the failure message.
    assertEquals(expectedMetadata.size(), actualMetadata.size());
    assertThat(actualMetadata.entrySet(), everyItem(isIn(expectedMetadata.entrySet())));
}
/**
 * Loading a saved ConstantDoubleVertex that carries no parameters must fail with an
 * IllegalArgumentException reporting the missing "constant" parent.
 */
@Test
public void loadFailsIfNoConstantSpecified() throws IOException {
    KeanuSavedBayesNet.Vertex constantVertex = KeanuSavedBayesNet.Vertex.newBuilder()
        .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1"))
        .setVertexType(ConstantDoubleVertex.class.getCanonicalName())
        .build();
    KeanuSavedBayesNet.BayesianNetwork savedNet = KeanuSavedBayesNet.BayesianNetwork.newBuilder()
        .addVertices(constantVertex)
        .build();
    KeanuSavedBayesNet.Model savedModel = KeanuSavedBayesNet.Model.newBuilder()
        .setNetwork(savedNet)
        .build();
    ByteArrayOutputStream serialized = new ByteArrayOutputStream();
    savedModel.writeTo(serialized);
    ProtobufLoader loader = new ProtobufLoader();

    // Arm the expectation only now. If it were armed earlier, an IllegalArgumentException
    // thrown while building the fixture would satisfy the rule and the test would pass spuriously.
    expectedException.expect(IllegalArgumentException.class);
    expectedException.expectMessage("Failed to create vertex due to missing parent: constant");
    loader.loadNetwork(new ByteArrayInputStream(serialized.toByteArray()));
}
.setNetwork(savedNet) .setNetworkState(savedNetState) .build();
.setNetwork(savedNet) .setNetworkState(savedNetState) .build();
/**
 * Loading a saved ConstantDoubleVertex whose "constant" parameter is an integer tensor,
 * rather than a double tensor, must fail with an IllegalArgumentException describing the
 * type mismatch.
 */
@Test
public void loadFailsIfWrongArgumentTypeSpecified() throws IOException {
    KeanuSavedBayesNet.Vertex constantVertex = KeanuSavedBayesNet.Vertex.newBuilder()
        .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1"))
        .setVertexType(ConstantDoubleVertex.class.getCanonicalName())
        .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder()
            .setName("constant")
            .setIntTensorParam(KeanuSavedBayesNet.IntegerTensor.newBuilder()
                .addAllShape(Longs.asList())
                .addValues(1)
                .build())
            .build())
        .build();
    KeanuSavedBayesNet.BayesianNetwork savedNet = KeanuSavedBayesNet.BayesianNetwork.newBuilder()
        .addVertices(constantVertex)
        .build();
    KeanuSavedBayesNet.Model savedModel = KeanuSavedBayesNet.Model.newBuilder()
        .setNetwork(savedNet)
        .build();
    ByteArrayOutputStream serialized = new ByteArrayOutputStream();
    savedModel.writeTo(serialized);
    ProtobufLoader loader = new ProtobufLoader();

    // Arm the expectation only now so that only the load call can satisfy it.
    expectedException.expect(IllegalArgumentException.class);
    expectedException.expectMessage("Incorrect Parameter Type specified. " +
        "Got: class io.improbable.keanu.tensor.intgr.ScalarIntegerTensor, " +
        "Expected: interface io.improbable.keanu.tensor.dbl.DoubleTensor");
    loader.loadNetwork(new ByteArrayInputStream(serialized.toByteArray()));
}
/**
 * Loading a saved vertex whose type name does not resolve to a known vertex class must fail
 * with an IllegalArgumentException naming the unknown type.
 */
@Test
public void loadFailsIfInvalidVertexSpecified() throws IOException {
    KeanuSavedBayesNet.Vertex unknownVertex = KeanuSavedBayesNet.Vertex.newBuilder()
        .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1"))
        .setVertexType("made.up.vertex.NonExistentVertex")
        .build();
    KeanuSavedBayesNet.BayesianNetwork savedNet = KeanuSavedBayesNet.BayesianNetwork.newBuilder()
        .addVertices(unknownVertex)
        .build();
    KeanuSavedBayesNet.Model savedModel = KeanuSavedBayesNet.Model.newBuilder()
        .setNetwork(savedNet)
        .build();
    ByteArrayOutputStream serialized = new ByteArrayOutputStream();
    savedModel.writeTo(serialized);
    ProtobufLoader loader = new ProtobufLoader();

    // Arm the expectation only now so that only the load call can satisfy it.
    expectedException.expect(IllegalArgumentException.class);
    expectedException.expectMessage("Unknown Vertex Type Specified: made.up.vertex.NonExistentVertex");
    loader.loadNetwork(new ByteArrayInputStream(serialized.toByteArray()));
}
/**
 * Populates the shared model builder via {@code createProtobufModel} and returns the
 * resulting model.
 *
 * @param withSavedValues forwarded unchanged to {@code createProtobufModel}
 * @param metadata        forwarded unchanged to {@code createProtobufModel}
 * @return the model built from the {@code modelBuilder} field after population
 */
protected KeanuSavedBayesNet.Model getModel(boolean withSavedValues, Map<String, String> metadata) {
    this.createProtobufModel(withSavedValues, metadata);
    final KeanuSavedBayesNet.Model built = this.modelBuilder.build();
    return built;
}