/**
 * Validates that the encoded PMML model received matches expected schema.
 *
 * <p>The checks run in this order: the PMML holds exactly one model; that model is a
 * {@link ClusteringModel}; its mining function is {@link MiningFunction#CLUSTERING}; the
 * feature names derived from the {@link DataDictionary} equal the schema's feature names;
 * and the feature names derived from the model's {@link MiningSchema} equal them as well.
 * Each check is made with {@code Preconditions.checkArgument} and carries a message that
 * says which expectation was violated.
 *
 * @param pmml {@link PMML} encoding of KMeans Clustering
 * @param schema expected schema attributes of KMeans Clustering
 */
public static void validatePMMLVsSchema(PMML pmml, InputSchema schema) {
  List<Model> models = pmml.getModels();
  Preconditions.checkArgument(models.size() == 1,
                              "Should have exactly one model, but had %s", models.size());

  Model model = models.get(0);
  // Static message on purpose: formatting the model itself could be large or null.
  Preconditions.checkArgument(model instanceof ClusteringModel,
                              "Model is not a ClusteringModel");
  MiningFunction miningFunction = model.getMiningFunction();
  Preconditions.checkArgument(miningFunction == MiningFunction.CLUSTERING,
                              "Mining function should be CLUSTERING, but was %s", miningFunction);

  DataDictionary dictionary = pmml.getDataDictionary();
  Preconditions.checkArgument(
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(dictionary)),
      "Feature names in schema don't match names in PMML");

  MiningSchema miningSchema = model.getMiningSchema();
  Preconditions.checkArgument(
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(miningSchema)),
      "Feature names in schema don't match names in PMML mining schema");
}
/**
 * Builds a {@link PMML} that carries only a version and a {@link Header}; the third
 * constructor argument is passed as {@code null}.
 *
 * @return {@link PMML} with common {@link Header} fields like {@link Application},
 *  {@link Timestamp}, and version filled out
 */
public static PMML buildSkeletonPMML() {
  // Current time, rendered with the pattern below in the ENGLISH locale.
  SimpleDateFormat timestampFormat =
      new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ssZZ", Locale.ENGLISH);
  Timestamp timestamp = new Timestamp().addContent(timestampFormat.format(new Date()));

  Header header = new Header();
  header.setTimestamp(timestamp);
  header.setApplication(new Application("Oryx"));

  return new PMML(VERSION, header, null);
}
/**
 * Wraps a KMeans model in a PMML document: a skeleton PMML, a data dictionary built from
 * {@code inputSchema}, and a single clustering model.
 *
 * @param model {@link KMeansModel} to translate to PMML
 * @param clusterSizesMap handed, together with {@code model}, to {@code pmmlClusteringModel}
 * @return PMML representation of a KMeans cluster model
 */
private PMML kMeansModelToPMML(KMeansModel model, Map<Integer,Long> clusterSizesMap) {
  ClusteringModel clustering = pmmlClusteringModel(model, clusterSizesMap);

  PMML document = PMMLUtils.buildSkeletonPMML();
  document.setDataDictionary(AppPMMLUtils.buildDataDictionary(inputSchema, null));
  document.addModels(clustering);
  return document;
}
/**
 * Serializes a dummy model to a string, parses it back, and checks that the application
 * name in the header and the mining function of the first model are unchanged.
 */
@Test
public void testFromString() throws Exception {
  PMML original = buildDummyModel();
  String serialized = PMMLUtils.toString(original);
  PMML roundTripped = PMMLUtils.fromString(serialized);

  assertEquals(
      original.getHeader().getApplication().getName(),
      roundTripped.getHeader().getApplication().getName());
  assertEquals(
      original.getModels().get(0).getMiningFunction(),
      roundTripped.getModels().get(0).getMiningFunction());
}
/**
 * Offers this element to the visitor and, if the visitor answers
 * {@link VisitorAction#CONTINUE}, traverses the children with this element pushed as parent:
 * first header, mining build task, data dictionary and transformation dictionary, then the
 * models (if any), then the extensions (if any). Each later group is only traversed while
 * the status is still {@code CONTINUE}.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if the visit or any traversal ended with
 *  {@code TERMINATE}; {@link VisitorAction#CONTINUE} for every other status
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction status = visitor.visit(this);

    if (status == VisitorAction.CONTINUE) {
        visitor.pushParent(this);

        // Status is known to be CONTINUE here, so the first group is always traversed.
        status = org.dmg.pmml.PMMLObject.traverse(visitor, getHeader(), getMiningBuildTask(), getDataDictionary(), getTransformationDictionary());

        if (status == VisitorAction.CONTINUE && hasModels()) {
            status = org.dmg.pmml.PMMLObject.traverse(visitor, getModels());
        }

        if (status == VisitorAction.CONTINUE && hasExtensions()) {
            status = org.dmg.pmml.PMMLObject.traverse(visitor, getExtensions());
        }

        // The parent is popped whatever the outcome of the child traversals.
        visitor.popParent();
    }

    return (status == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
@Test public void copyState(){ PMML pmml = new PMML(Version.PMML_4_3.getVersion(), new Header(), new DataDictionary()); // Initialize the live list instance pmml.getModels(); CustomPMML customPmml = new CustomPMML(); ReflectionUtil.copyState(pmml, customPmml); assertSame(pmml.getVersion(), customPmml.getVersion()); assertSame(pmml.getHeader(), customPmml.getHeader()); assertSame(pmml.getDataDictionary(), customPmml.getDataDictionary()); assertFalse(pmml.hasModels()); assertFalse(customPmml.hasModels()); pmml.addModels(new RegressionModel()); assertTrue(pmml.hasModels()); assertTrue(customPmml.hasModels()); assertSame(pmml.getModels(), customPmml.getModels()); try { ReflectionUtil.copyState(customPmml, pmml); fail(); } catch(IllegalArgumentException iae){ // Ignored } }
/**
 * Reads the clusters of the first model in the given PMML.
 *
 * <p>For every cluster of the model a {@link ClusterInfo} is created from the cluster ID
 * parsed as an int, the cluster's array value split on spaces and parsed as a vector, and
 * the cluster's size.
 *
 * <p>Both preconditions are checked with {@code Preconditions.checkArgument}: the PMML must
 * contain at least one model, and the first model must be a {@link ClusteringModel}.
 *
 * @param pmml PMML representation of Clusters
 * @return List of {@link ClusterInfo}
 */
public static List<ClusterInfo> read(PMML pmml) {
  List<Model> models = pmml.getModels();
  // Fail with a clear message instead of an index error from get(0) on an empty list.
  Preconditions.checkArgument(!models.isEmpty(), "PMML contains no models");

  Model model = models.get(0);
  Preconditions.checkArgument(model instanceof ClusteringModel,
                              "First model is not a ClusteringModel");
  ClusteringModel clusteringModel = (ClusteringModel) model;

  return clusteringModel.getClusters().stream().map(cluster ->
      new ClusterInfo(
          Integer.parseInt(cluster.getId()),
          VectorMath.parseVector(TextUtils.parseDelimited(cluster.getArray().getValue(), ' ')),
          cluster.getSize())
  ).collect(Collectors.toList());
}
PMML pmml = AppPMMLUtils.readPMMLFromUpdateKeyMessage(type, value, null); assertNotNull(pmml); checkHeader(pmml.getHeader()); assertEquals(3, pmml.getExtensions().size()); Map<String,Object> expected = new HashMap<>(); expected.put("maxDepth", MAX_DEPTH); checkExtensions(pmml, expected); checkDataDictionary(schema, pmml.getDataDictionary()); Model rootModel = pmml.getModels().get(0); if (rootModel instanceof TreeModel) { assertEquals(NUM_TREES, 1);
public PMML build(BasicML basicML) { PMML pmml = new PMML(); pmml.setHeader(header); header.setCopyright(" Copyright [2013-2017] PayPal Software Foundation\n" + "\n" + " Licensed under the Apache License, Version 2.0 (the \"License\");\n" pmml.setDataDictionary(dataDictionaryCreator.build(basicML)); List<Model> models = pmml.getModels(); Model miningModel = modelCreator.convert(((TreeModel) basicML).getIndependentTreeModel()); models.add(miningModel);
PMML pmml = AppPMMLUtils.readPMMLFromUpdateKeyMessage(type, value, null); assertNotNull(pmml); checkHeader(pmml.getHeader()); checkDataDictionary(schema, pmml.getDataDictionary()); Model rootModel = pmml.getModels().get(0);
/**
 * Unmarshals the {@code ChainedSegmentationTest} resource through a
 * {@code SkipFilter("Segmentation")} and checks that the dictionaries, the mining schema and
 * the output of the first model are present while its segmentation is {@code null}.
 */
@Test
public void filterChainedSegmentation() throws Exception {
    PMML filtered = ResourceUtil.unmarshal(ChainedSegmentationTest.class, new SkipFilter("Segmentation"));

    assertNotNull(filtered.getDataDictionary());
    assertNotNull(filtered.getTransformationDictionary());

    MiningModel firstModel = (MiningModel)filtered.getModels().get(0);

    assertNotNull(firstModel.getMiningSchema());
    assertNotNull(firstModel.getOutput());

    // The filtered element itself must be gone
    assertNull(firstModel.getSegmentation());
}
/**
 * Declares, against the given PMML element, the data fields of its data dictionary and then
 * the derived fields of its transformation dictionary, skipping a dictionary that is absent
 * or reports having no fields, before delegating to the superclass.
 *
 * @param pmml the PMML element being visited
 * @return whatever {@code super.visit(pmml)} returns
 */
@Override
public VisitorAction visit(PMML pmml){
    DataDictionary dataDictionary = pmml.getDataDictionary();
    boolean hasDataFields = (dataDictionary != null) && dataDictionary.hasDataFields();
    if(hasDataFields){
        declare(pmml, dataDictionary.getDataFields());
    }

    TransformationDictionary transformationDictionary = pmml.getTransformationDictionary();
    boolean hasDerivedFields = (transformationDictionary != null) && transformationDictionary.hasDerivedFields();
    if(hasDerivedFields){
        declare(pmml, transformationDictionary.getDerivedFields());
    }

    return super.visit(pmml);
}
AppPMMLUtils.buildCategoricalValueEncodings(pmml.getDataDictionary()); log.info("{}", encodings); Map<String,Integer> fruitEncoding = encodings.getValueEncodingMap(0);
/**
 * Finds a model of the given PMML that satisfies the predicate.
 *
 * @param pmml the PMML whose models are searched
 * @param predicate the condition a model has to satisfy
 * @param predicateXPath appended, after a {@code "/"}, to the formatted element name of the
 *  PMML class when building the error message
 * @return a model accepted by the predicate
 * @throws MissingElementException if the PMML has no models, or none is accepted
 */
static public Model findModel(PMML pmml, Predicate<Model> predicate, String predicateXPath){

    // The model list is only requested when hasModels() reports that there is one.
    if(pmml.hasModels()){
        Optional<Model> match = pmml.getModels().stream()
            .filter(predicate)
            .findAny();

        if(match.isPresent()){
            return match.get();
        }
    }

    // Same failure for "no models at all" and "no model matched".
    throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(pmml.getClass()) + "/" + predicateXPath), pmml);
}
}
TransformationDictionary transformationDictionary = pmml.getTransformationDictionary(); List<Model> models = pmml.getModels(); models.clear();
/**
 * Looks up the value of the first extension of the PMML whose name equals the given name.
 *
 * @param pmml PMML whose extensions are searched, in list order
 * @param name extension name to look for
 * @return the value of the first extension with that name, or {@code null} if there is no
 *  such extension
 */
public static String getExtensionValue(PMML pmml, String name) {
  for (Extension candidate : pmml.getExtensions()) {
    if (name.equals(candidate.getName())) {
      return candidate.getValue();
    }
  }
  return null;
}
.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT); PMML pmml = new PMML(Version.PMML_4_3.getVersion(), new Header(), new DataDictionary()) .addModels(treeModel); pmml.accept(skipVisitor); pmml.accept(skipLeftVisitor); pmml.accept(terminateVisitor); pmml.accept(terminateLeftVisitor); pmml.accept(parentVisitor);
/**
 * Builds a minimal PMML for tests: a skeleton PMML holding one classification
 * {@link TreeModel} whose root {@link Node} has a record count of 123.0.
 *
 * @return the skeleton PMML with the tree model added
 */
public static PMML buildDummyModel() {
  Node root = new Node().setRecordCount(123.0);
  TreeModel tree = new TreeModel(MiningFunction.CLASSIFICATION, null, root);

  PMML dummy = PMMLUtils.buildSkeletonPMML();
  dummy.addModels(tree);
  return dummy;
}
/**
 * Quite manually write our fake model representation in PMML.
 *
 * <p>A PMML with version "4.2.1" and no header or data dictionary is created; every entry
 * of {@code model.getPathByKey()} is added to it as an {@link Extension} named after the
 * entry key and valued with the entry value. The result is marshalled to {@code out}.
 *
 * @param out stream the PMML is marshalled to
 * @param model description whose key-to-path entries become extensions
 * @throws JAXBException declared for the marshalling call
 */
private static void write(OutputStream out, ALSModelDescription model) throws JAXBException {
  PMML document = new PMML("4.2.1", null, null);

  for (Map.Entry<String,String> pathByKey : model.getPathByKey().entrySet()) {
    String key = pathByKey.getKey();
    String path = pathByKey.getValue();

    Extension extension = new Extension();
    extension.setName(key);
    extension.setValue(path);
    document.getExtensions().add(extension);
  }

  JAXBUtil.marshalPMML(document, new StreamResult(out));
}