/**
 * Wraps a trained k-means model in a complete PMML document: a skeleton PMML,
 * a data dictionary built from {@code inputSchema}, and one clustering model.
 *
 * @param model {@link KMeansModel} to translate to PMML
 * @param clusterSizesMap handed unchanged to {@code pmmlClusteringModel}; presumably
 *  maps a cluster ID to the number of points in that cluster (confirm against callers)
 * @return PMML representation of a KMeans cluster model
 */
private PMML kMeansModelToPMML(KMeansModel model, Map<Integer,Long> clusterSizesMap) {
  ClusteringModel clusters = pmmlClusteringModel(model, clusterSizesMap);
  PMML document = PMMLUtils.buildSkeletonPMML();
  document.setDataDictionary(AppPMMLUtils.buildDataDictionary(inputSchema, null));
  document.addModels(clusters);
  return document;
}
public static PMML readPMMLFromUpdateKeyMessage(String key, String message, Configuration hadoopConf) throws IOException { String pmmlString; switch (key) { case "MODEL": pmmlString = message; break; case "MODEL-REF": // Allowing null is mostly for integration tests if (hadoopConf == null) { hadoopConf = new Configuration(); } Path messagePath = new Path(message); FileSystem fs = FileSystem.get(messagePath.toUri(), hadoopConf); try (InputStreamReader in = new InputStreamReader(fs.open(messagePath), StandardCharsets.UTF_8)) { pmmlString = CharStreams.toString(in); } catch (FileNotFoundException fnfe) { log.warn("Unable to load model file at {}; ignoring", messagePath); return null; } break; default: throw new IllegalArgumentException("Unknown key " + key); } return PMMLUtils.fromString(pmmlString); }
/**
 * Opens the file at {@code path} and delegates to the stream-based {@code read}
 * overload; the stream is closed before this method returns.
 *
 * @param path file to read PMML from
 * @return {@link PMML} model file from path
 * @throws IOException if an error occurs while reading the model from storage
 */
public static PMML read(Path path) throws IOException {
  try (InputStream modelStream = Files.newInputStream(path)) {
    PMML model = read(modelStream);
    return model;
  }
}
/**
 * Checks the three paths of {@code readPMMLFromUpdateKeyMessage}: an inline model,
 * a model referenced by file path, and a reference to a file that does not exist.
 */
@Test
public void testReadPMMLFromMessage() throws Exception {
  String serialized = PMMLUtils.toString(PMMLUtils.buildSkeletonPMML());

  // "MODEL": the message is the PMML document itself
  PMML inlineModel = AppPMMLUtils.readPMMLFromUpdateKeyMessage("MODEL", serialized, null);
  assertEquals(PMMLUtils.VERSION, inlineModel.getVersion());

  // "MODEL-REF": the message is the location of a file containing the document
  Path modelFile = getTempDir().resolve("out.pmml");
  Files.write(modelFile, Collections.singleton(serialized));
  String modelLocation = modelFile.toAbsolutePath().toString();
  PMML referencedModel =
      AppPMMLUtils.readPMMLFromUpdateKeyMessage("MODEL-REF", modelLocation, null);
  assertEquals(PMMLUtils.VERSION, referencedModel.getVersion());

  // A reference to a missing file yields null rather than an exception
  assertNull(AppPMMLUtils.readPMMLFromUpdateKeyMessage("MODEL-REF", "no-such-path", null));
}
/**
 * Verifies that an extension value is absent before it is added, and is still
 * readable after the model has been serialized and parsed again.
 */
@Test
public void testExtensionValue() throws Exception {
  PMML original = buildDummyModel();
  assertNull(AppPMMLUtils.getExtensionValue(original, "foo"));

  AppPMMLUtils.addExtension(original, "foo", "bar");

  String xml = PMMLUtils.toString(original);
  PMML roundTripped = PMMLUtils.fromString(xml);
  assertEquals("bar", AppPMMLUtils.getExtensionValue(roundTripped, "foo"));
}
/**
 * Writes the dummy model to a temp file, reads it back, and checks that the single
 * tree model it contains kept its record count and mining function.
 */
@Test
public void testReadWrite() throws Exception {
  Path modelFile = Files.createTempFile(getTempDir(), "model", ".pmml");
  PMMLUtils.write(buildDummyModel(), modelFile);
  assertTrue(Files.exists(modelFile));

  List<Model> restoredModels = PMMLUtils.read(modelFile).getModels();
  assertEquals(1, restoredModels.size());

  Model onlyModel = restoredModels.get(0);
  assertInstanceOf(onlyModel, TreeModel.class);
  TreeModel tree = (TreeModel) onlyModel;
  assertEquals(123.0, tree.getNode().getRecordCount().doubleValue());
  assertEquals(MiningFunction.CLASSIFICATION, tree.getMiningFunction());
}
/**
 * Always emits a {@code "MODEL"} update carrying the serialized dummy regression model;
 * neither argument influences the result.
 *
 * @param id ignored
 * @param random ignored
 * @return pair of the key {@code "MODEL"} and the model as a PMML string
 */
@Override
public Pair<String,String> generate(int id, RandomGenerator random) {
  String serialized = PMMLUtils.toString(RDFPMMLUtilsTest.buildDummyRegressionModel());
  return new Pair<>("MODEL", serialized);
}
/**
 * Opens the file at {@code path} for writing and delegates to the stream-based
 * {@code write} overload; the stream is closed before this method returns.
 *
 * @param pmml {@link PMML} model to write
 * @param path file to write the model to
 * @throws IOException if an error occurs while writing the model to storage
 */
public static void write(PMML pmml, Path path) throws IOException {
  try (OutputStream modelStream = Files.newOutputStream(path)) {
    write(pmml, modelStream);
  }
}
/**
 * Verifies extension content round-trips through serialization: a plain list, an
 * empty list (which reads back as {@code null}), and values containing quotes and
 * surrounding spaces.
 */
@Test
public void testExtensionContent() throws Exception {
  PMML original = buildDummyModel();
  assertNull(AppPMMLUtils.getExtensionContent(original, "foo"));

  AppPMMLUtils.addExtensionContent(original, "foo1", Arrays.asList("bar", "baz"));
  AppPMMLUtils.addExtensionContent(original, "foo2", Collections.emptyList());
  AppPMMLUtils.addExtensionContent(original, "foo3",
                                   Arrays.asList(" c\" d \"e ", " c\" d \"e "));

  String xml = PMMLUtils.toString(original);
  PMML roundTripped = PMMLUtils.fromString(xml);

  assertEquals(Arrays.asList("bar", "baz"),
               AppPMMLUtils.getExtensionContent(roundTripped, "foo1"));
  assertNull(AppPMMLUtils.getExtensionContent(roundTripped, "foo2"));
  assertEquals(Arrays.asList(" c\" d \"e ", " c\" d \"e "),
               AppPMMLUtils.getExtensionContent(roundTripped, "foo3"));
}
/**
 * Produces a mock update for the given sequence number. Every tenth {@code id}
 * (remainder 0 modulo 10) yields a {@code "MODEL"} message listing the known X and Y
 * IDs; any other {@code id} yields an {@code "UP"} message, tagged {@code "X"} when
 * the remainder is 6 or more and {@code "Y"} otherwise.
 *
 * @param id sequence number of the update to generate
 * @param random ignored
 * @return pair of message key and message body
 */
@Override
public Pair<String,String> generate(int id, RandomGenerator random) {
  int remainder = id % 10;

  if (remainder == 0) {
    PMML skeleton = PMMLUtils.buildSkeletonPMML();
    AppPMMLUtils.addExtension(skeleton, "features", 2);
    AppPMMLUtils.addExtension(skeleton, "implicit", true);
    AppPMMLUtils.addExtensionContent(skeleton, "XIDs", X.keySet());
    AppPMMLUtils.addExtensionContent(skeleton, "YIDs", Y.keySet());
    return new Pair<>("MODEL", PMMLUtils.toString(skeleton));
  }

  // The string ID is derived from the full id, not from the remainder
  String stringID = ALSUtilsTest.idToStringID(id);

  if (remainder >= 6) {
    String xMessage = TextUtils.joinJSON(Arrays.asList(
        "X", stringID, X.get(stringID), A.get(stringID)));
    return new Pair<>("UP", xMessage);
  }

  String yMessage = TextUtils.joinJSON(Arrays.asList(
      "Y", stringID, Y.get(stringID), At.get(stringID)));
  return new Pair<>("UP", yMessage);
}
/**
 * Produces a mock update for the given sequence number: the serialized dummy
 * classification model for {@code id} 0, otherwise an {@code "UP"} message naming
 * node {@code "r-"} for even ids or {@code "r+"} for odd ids, with a fixed count map.
 *
 * @param id sequence number of the update to generate
 * @param random ignored
 * @return pair of message key and message body
 */
@Override
public Pair<String,String> generate(int id, RandomGenerator random) {
  if (id == 0) {
    String serialized = PMMLUtils.toString(RDFPMMLUtilsTest.buildDummyClassificationModel());
    return new Pair<>("MODEL", serialized);
  }

  String nodeID = (id % 2 == 0) ? "r-" : "r+";

  Map<Integer,Integer> counts = new HashMap<>();
  counts.put(0, 1);
  counts.put(1, 2);

  String update = TextUtils.joinJSON(Arrays.asList(0, nodeID, counts));
  return new Pair<>("UP", update);
}
fs.mkdirs(candidatePath); try (OutputStream out = fs.create(modelPath)) { PMMLUtils.write(model, out);
/**
 * Serializes the dummy model to a string, parses it back, and checks that the
 * application name and the first model's mining function are unchanged.
 */
@Test
public void testFromString() throws Exception {
  PMML expected = buildDummyModel();

  String xml = PMMLUtils.toString(expected);
  PMML actual = PMMLUtils.fromString(xml);

  assertEquals(
      expected.getHeader().getApplication().getName(),
      actual.getHeader().getApplication().getName());
  assertEquals(
      expected.getModels().get(0).getMiningFunction(),
      actual.getModels().get(0).getMiningFunction());
}
/**
 * Builds a minimal PMML document for tests: a skeleton PMML holding one
 * classification {@link TreeModel} whose only node has a record count of 123.0.
 *
 * @return the newly built test model
 */
public static PMML buildDummyModel() {
  TreeModel tree = new TreeModel(
      MiningFunction.CLASSIFICATION,
      null,
      new Node().setRecordCount(123.0));

  PMML document = PMMLUtils.buildSkeletonPMML();
  document.addModels(tree);
  return document;
}
assertTrue("No such model file: " + modelFile, Files.exists(modelFile)); PMML pmml = PMMLUtils.read(modelFile); Model rootModel = pmml.getModels().get(0); ClusteringModel clusteringModel = (ClusteringModel) rootModel;
@Test public void testPreviousPMMLVersion() throws Exception { String pmml42 = "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"yes\"?>\n" + "<PMML xmlns=\"http://www.dmg.org/PMML-4_2\" version=\"4.2.1\">\n" + " <Header>\n" + " <Application name=\"Oryx\"/>\n" + " </Header>\n" + " <TreeModel functionName=\"classification\">\n" + " <Node recordCount=\"123.0\"/>\n" + " </TreeModel>\n" + "</PMML>\n"; PMML model = PMMLUtils.fromString(pmml42); // Actually transforms to latest version: assertEquals(PMMLUtils.VERSION, model.getVersion()); }
/**
 * Produces a mock update for the given sequence number: the serialized dummy
 * clustering model for {@code id} 0, otherwise an {@code "UP"} message whose JSON
 * body is {@code [id % 3, [id, id], id]}.
 *
 * @param id sequence number of the update to generate
 * @param random ignored
 * @return pair of message key and message body
 */
@Override
public Pair<String,String> generate(int id, RandomGenerator random) {
  if (id != 0) {
    // Presumably (cluster ID, center, count) - confirm against the consumer of "UP"
    List<?> update = Arrays.asList(id % 3, Arrays.asList(id, id), id);
    return new Pair<>("UP", TextUtils.joinJSON(update));
  }

  PMML model = KMeansPMMLUtilsTest.buildDummyClusteringModel();
  return new Pair<>("MODEL", PMMLUtils.toString(model));
}
/**
 * Opens the file at {@code path} for writing and delegates to the stream-based
 * {@code write} overload; the stream is closed before this method returns.
 *
 * @param pmml {@link PMML} model to write
 * @param path file to write the model to
 * @throws IOException if an error occurs while writing the model to storage
 */
public static void write(PMML pmml, Path path) throws IOException {
  try (OutputStream modelStream = Files.newOutputStream(path)) {
    write(pmml, modelStream);
  }
}