/**
 * Builds a named parameter message, leaving the choice of value field to the caller.
 *
 * @param paramName   name stored on the resulting parameter
 * @param valueSetter callback invoked once, after the name has been set, to populate
 *                    the value on the builder
 * @return the built parameter message
 */
private KeanuSavedBayesNet.NamedParam getParam(String paramName,
                                               Consumer<KeanuSavedBayesNet.NamedParam.Builder> valueSetter) {
    KeanuSavedBayesNet.NamedParam.Builder namedParam =
        KeanuSavedBayesNet.NamedParam.newBuilder().setName(paramName);
    valueSetter.accept(namedParam);
    return namedParam.build();
}
/**
 * Converts an arbitrary parameter value into its saved representation by dispatching
 * on the value's runtime type.
 *
 * <p>Handled types: {@code Vertex}, {@code DoubleTensor}, {@code IntegerTensor},
 * {@code BooleanTensor}, {@code Double}, {@code Integer}, {@code Long}, {@code String},
 * {@code Boolean}, {@code Vertex[]}, and both primitive and boxed {@code long}/{@code int}
 * arrays.
 *
 * @param paramName name to store on the resulting parameter
 * @param param     value to save; must not be null
 * @return the named parameter carrying {@code param} in the field matching its type
 * @throws IllegalArgumentException if {@code param} is null or of an unsupported type
 */
private KeanuSavedBayesNet.NamedParam getTypedParam(String paramName, Object param) {
    if (param == null) {
        // Fail with context rather than an anonymous NullPointerException from getClass().
        throw new IllegalArgumentException("Cannot save null value for parameter: " + paramName);
    }

    if (param instanceof Vertex) {
        return getParam(paramName, (Vertex) param);
    } else if (param instanceof DoubleTensor) {
        return getParam(paramName, builder -> builder.setDoubleTensorParam(getTensor((DoubleTensor) param)));
    } else if (param instanceof IntegerTensor) {
        return getParam(paramName, builder -> builder.setIntTensorParam(getTensor((IntegerTensor) param)));
    } else if (param instanceof BooleanTensor) {
        return getParam(paramName, builder -> builder.setBoolTensorParam(getTensor((BooleanTensor) param)));
    } else if (param instanceof Double) {
        return getParam(paramName, builder -> builder.setDoubleParam((double) param));
    } else if (param instanceof Integer) {
        return getParam(paramName, builder -> builder.setIntParam((int) param));
    } else if (param instanceof Long) {
        return getParam(paramName, builder -> builder.setLongParam((long) param));
    } else if (param instanceof String) {
        return getParam(paramName, builder -> builder.setStringParam((String) param));
    } else if (param instanceof Boolean) {
        return getParam(paramName, builder -> builder.setBoolParam((boolean) param));
    } else if (param instanceof long[]) {
        // Primitive arrays must be matched as long[]; Long[] and long[] are unrelated types.
        return getParam(paramName, (long[]) param);
    } else if (param instanceof Long[]) {
        // A boxed array cannot be cast to long[]; unbox element by element instead.
        long[] unboxed = java.util.Arrays.stream((Long[]) param).mapToLong(Long::longValue).toArray();
        return getParam(paramName, unboxed);
    } else if (param instanceof Vertex[]) {
        return getParam(paramName, (Vertex[]) param);
    } else if (param instanceof int[]) {
        return getParam(paramName, (int[]) param);
    } else if (param instanceof Integer[]) {
        // Same reasoning as for Long[]: Integer[] is not castable to int[].
        int[] unboxed = java.util.Arrays.stream((Integer[]) param).mapToInt(Integer::intValue).toArray();
        return getParam(paramName, unboxed);
    } else {
        throw new IllegalArgumentException("Unknown Parameter Type to Save: " + param.getClass().toString());
    }
}
.setVertexType(ConstantDoubleVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("constant") .setDoubleTensorParam(KeanuSavedBayesNet.DoubleTensor.newBuilder() .addAllShape(Longs.asList(1, 1)).addValues(1.0).build() ).build()) .build(); .setVertexType(ConstantDoubleVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("constant") .setDoubleTensorParam(KeanuSavedBayesNet.DoubleTensor.newBuilder() .addAllShape(Longs.asList(1, 1)).addValues(2.0).build() ).build()) .build(); .setVertexType(GaussianVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("mu") .setParentVertex(KeanuSavedBayesNet.VertexID.newBuilder().setId("1").build()) .build() ).addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("sigma") .setParentVertex(KeanuSavedBayesNet.VertexID.newBuilder().setId("2").build()) .build()
.setVertexType(ConstantDoubleVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("constant") .setDoubleTensorParam(KeanuSavedBayesNet.DoubleTensor.newBuilder() .addAllShape(Longs.asList(1, 1)).addValues(1.0).build() ).build()) .build();
.setVertexType(ConstantDoubleVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("constant") .setDoubleTensorParam(KeanuSavedBayesNet.DoubleTensor.newBuilder() .addAllShape(Longs.asList(1, 1)) .addValues(0.0).build()) .build()) .build(); .setVertexType(GaussianVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("mu") .setParentVertex(KeanuSavedBayesNet.VertexID.newBuilder().setId("1").build()) .build()
/**
 * Saves a ConstantDoubleVertex whose "constant" parameter is written as an integer
 * tensor and checks that loading it is rejected with an IllegalArgumentException
 * naming the received and the expected parameter types.
 */
@Test
public void loadFailsIfWrongArgumentTypeSpecified() throws IOException {
    expectedException.expect(IllegalArgumentException.class);
    expectedException.expectMessage("Incorrect Parameter Type specified. "
        + "Got: class io.improbable.keanu.tensor.intgr.ScalarIntegerTensor, "
        + "Expected: interface io.improbable.keanu.tensor.dbl.DoubleTensor");

    // Deliberately the wrong tensor kind for a double-valued constant.
    KeanuSavedBayesNet.IntegerTensor wrongTypedValue = KeanuSavedBayesNet.IntegerTensor.newBuilder()
        .addAllShape(Longs.asList())
        .addValues(1)
        .build();

    KeanuSavedBayesNet.NamedParam constantParam = KeanuSavedBayesNet.NamedParam.newBuilder()
        .setName("constant")
        .setIntTensorParam(wrongTypedValue)
        .build();

    KeanuSavedBayesNet.Vertex constantVertex = KeanuSavedBayesNet.Vertex.newBuilder()
        .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1"))
        .setVertexType(ConstantDoubleVertex.class.getCanonicalName())
        .addParameters(constantParam)
        .build();

    KeanuSavedBayesNet.BayesianNetwork savedNet = KeanuSavedBayesNet.BayesianNetwork.newBuilder()
        .addVertices(constantVertex)
        .build();

    KeanuSavedBayesNet.Model savedModel = KeanuSavedBayesNet.Model.newBuilder()
        .setNetwork(savedNet)
        .build();

    ByteArrayOutputStream serialised = new ByteArrayOutputStream();
    savedModel.writeTo(serialised);

    // Loading is expected to throw, so the returned network is not kept.
    ProtobufLoader loader = new ProtobufLoader();
    loader.loadNetwork(new ByteArrayInputStream(serialised.toByteArray()));
}
/**
 * Builds a named parameter whose value is an int array.
 *
 * @param paramName name stored on the resulting parameter
 * @param param     values copied, in order, into the saved int array
 * @return the built parameter message
 */
private KeanuSavedBayesNet.NamedParam getParam(String paramName, int[] param) {
    return getParam(paramName, builder -> {
        KeanuSavedBayesNet.IntArray.Builder savedValues = KeanuSavedBayesNet.IntArray.newBuilder();
        savedValues.addAllValues(Ints.asList(param));
        builder.setIntArrayParam(savedValues);
    });
}
/**
 * Builds a named parameter that refers to a parent vertex by the string form of its id.
 *
 * @param paramName name stored on the resulting parameter
 * @param parent    vertex whose id is recorded as the parameter's value
 * @return the built parameter message
 */
private KeanuSavedBayesNet.NamedParam getParam(String paramName, Vertex parent) {
    return getParam(paramName, builder -> {
        String parentId = parent.getId().toString();
        builder.setParentVertex(KeanuSavedBayesNet.VertexID.newBuilder().setId(parentId));
    });
}