/**
 * Builds a {@link KeanuSavedBayesNet.NamedParam} carrying the given name, delegating the choice
 * of value to the caller.
 *
 * @param paramName   name recorded on the parameter
 * @param valueSetter callback that populates the parameter's value on the builder; it runs after
 *                    the name has already been set
 * @return the fully built parameter message
 */
private KeanuSavedBayesNet.NamedParam getParam(String paramName, Consumer<KeanuSavedBayesNet.NamedParam.Builder> valueSetter) {
    final KeanuSavedBayesNet.NamedParam.Builder builder =
        KeanuSavedBayesNet.NamedParam.newBuilder().setName(paramName);
    valueSetter.accept(builder);
    return builder.build();
}
/**
 * Loading a saved network must fail when a vertex parameter has the wrong tensor type: a
 * {@code ConstantDoubleVertex} whose "constant" parameter is stored as an integer tensor is
 * expected to be rejected with an {@link IllegalArgumentException} whose message names both the
 * supplied and the expected tensor types.
 */
@Test
public void loadFailsIfWrongArgumentTypeSpecified() throws IOException {
    expectedException.expect(IllegalArgumentException.class);
    expectedException.expectMessage("Incorrect Parameter Type specified. "
        + "Got: class io.improbable.keanu.tensor.intgr.ScalarIntegerTensor, "
        + "Expected: interface io.improbable.keanu.tensor.dbl.DoubleTensor");

    // Empty-shape integer tensor supplied where the vertex's "constant" parameter expects doubles.
    KeanuSavedBayesNet.Vertex constantVertex = KeanuSavedBayesNet.Vertex.newBuilder()
        .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1"))
        .setVertexType(ConstantDoubleVertex.class.getCanonicalName())
        .addParameters(getParam("constant", param -> param.setIntTensorParam(
            KeanuSavedBayesNet.IntegerTensor.newBuilder()
                .addAllShape(Longs.asList())
                .addValues(1)
                .build())))
        .build();

    KeanuSavedBayesNet.BayesianNetwork savedNet = KeanuSavedBayesNet.BayesianNetwork.newBuilder()
        .addVertices(constantVertex)
        .build();

    KeanuSavedBayesNet.Model savedModel = KeanuSavedBayesNet.Model.newBuilder()
        .setNetwork(savedNet)
        .build();

    // toByteArray() yields the same serialized bytes as writeTo(ByteArrayOutputStream).
    // The load is expected to throw, so its result is intentionally not stored.
    new ProtobufLoader().loadNetwork(new ByteArrayInputStream(savedModel.toByteArray()));
}
.setDoubleTensorParam(KeanuSavedBayesNet.DoubleTensor.newBuilder() .addAllShape(Longs.asList(1, 1)).addValues(1.0).build() ).build()) .build(); .setDoubleTensorParam(KeanuSavedBayesNet.DoubleTensor.newBuilder() .addAllShape(Longs.asList(1, 1)).addValues(2.0).build() ).build()) .build(); .setName("mu") .setParentVertex(KeanuSavedBayesNet.VertexID.newBuilder().setId("1").build()) .build() ).addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("sigma") .setParentVertex(KeanuSavedBayesNet.VertexID.newBuilder().setId("2").build()) .build()
.setDoubleTensorParam(KeanuSavedBayesNet.DoubleTensor.newBuilder() .addAllShape(Longs.asList(1, 1)).addValues(1.0).build() ).build()) .build();