/**
 * Validates that the encoded PMML model received matches expected schema.
 *
 * <p>The PMML must contain exactly one model, that model must be a
 * {@link ClusteringModel} whose mining function is
 * {@link MiningFunction#CLUSTERING}, and the feature names derived from both
 * the data dictionary and the model's mining schema must equal the feature
 * names of the schema.
 *
 * @param pmml {@link PMML} encoding of KMeans Clustering
 * @param schema expected schema attributes of KMeans Clustering
 */
public static void validatePMMLVsSchema(PMML pmml, InputSchema schema) {
  List<Model> models = pmml.getModels();
  Preconditions.checkArgument(models.size() == 1,
                              "Should have exactly one model, but had %s", models.size());

  Model model = models.get(0);
  // Every check carries a message so that a failure identifies which condition was violated.
  Preconditions.checkArgument(model instanceof ClusteringModel,
                              "Model should be a ClusteringModel");
  Preconditions.checkArgument(model.getMiningFunction() == MiningFunction.CLUSTERING,
                              "Mining function should be CLUSTERING, but was %s",
                              model.getMiningFunction());

  DataDictionary dictionary = pmml.getDataDictionary();
  Preconditions.checkArgument(
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(dictionary)),
      "Feature names in schema don't match names in PMML");

  MiningSchema miningSchema = model.getMiningSchema();
  Preconditions.checkArgument(
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(miningSchema)),
      "Feature names in schema don't match names in PMML mining schema");
}
DataDictionary dictionary = pmml.getDataDictionary(); Preconditions.checkArgument( schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(dictionary)),
DataDictionary dictionary = pmml.getDataDictionary(); List<String> featureNames = AppPMMLUtils.getFeatureNames(dictionary); CategoricalValueEncodings categoricalValueEncodings =
AppPMMLUtils.buildCategoricalValueEncodings(pmml.getDataDictionary()); log.info("{}", encodings); Map<String,Integer> fruitEncoding = encodings.getValueEncodingMap(0);
// Passes the PMML header to checkHeader, then the schema and the PMML data dictionary to checkDataDictionary.
checkHeader(pmml.getHeader()); checkDataDictionary(schema, pmml.getDataDictionary());
// Passes the PMML and the expected value to checkExtensions, then the schema and the PMML data dictionary to checkDataDictionary.
checkExtensions(pmml, expected); checkDataDictionary(schema, pmml.getDataDictionary());
/**
 * Collects the names of all data fields declared in the PMML data dictionary.
 *
 * @param pmml
 *            the pmml model
 * @return the field names, in the order the data dictionary lists its fields
 */
public static String[] getDataDicHeaders(final PMML pmml) {
    List<DataField> dataFields = pmml.getDataDictionary().getDataFields();
    String[] names = new String[dataFields.size()];
    int pos = 0;
    for(DataField dataField: dataFields) {
        names[pos++] = dataField.getName().getValue();
    }
    return names;
}
/**
 * Validates that the encoded PMML model received matches expected schema.
 *
 * <p>The PMML must contain exactly one model, that model must be a
 * {@link ClusteringModel} whose mining function is
 * {@link MiningFunction#CLUSTERING}, and the feature names derived from both
 * the data dictionary and the model's mining schema must equal the feature
 * names of the schema.
 *
 * @param pmml {@link PMML} encoding of KMeans Clustering
 * @param schema expected schema attributes of KMeans Clustering
 */
public static void validatePMMLVsSchema(PMML pmml, InputSchema schema) {
  List<Model> models = pmml.getModels();
  Preconditions.checkArgument(models.size() == 1,
                              "Should have exactly one model, but had %s", models.size());

  Model model = models.get(0);
  // Every check carries a message so that a failure identifies which condition was violated.
  Preconditions.checkArgument(model instanceof ClusteringModel,
                              "Model should be a ClusteringModel");
  Preconditions.checkArgument(model.getMiningFunction() == MiningFunction.CLUSTERING,
                              "Mining function should be CLUSTERING, but was %s",
                              model.getMiningFunction());

  DataDictionary dictionary = pmml.getDataDictionary();
  Preconditions.checkArgument(
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(dictionary)),
      "Feature names in schema don't match names in PMML");

  MiningSchema miningSchema = model.getMiningSchema();
  Preconditions.checkArgument(
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(miningSchema)),
      "Feature names in schema don't match names in PMML mining schema");
}
/**
 * Pops the current parent via the superclass and post-processes it:
 * a {@link Model} is handed to {@code processModel}; for a {@link PMML},
 * its data dictionary (when present) is handed to {@code processDataDictionary}.
 *
 * @return the object returned by {@code super.popParent()}
 */
@Override
public PMMLObject popParent(){
	PMMLObject popped = super.popParent();

	if(popped instanceof Model){
		processModel((Model)popped);
		return popped;
	}

	if(popped instanceof PMML){
		DataDictionary dictionary = ((PMML)popped).getDataDictionary();
		// A PMML without a data dictionary needs no processing
		if(dictionary != null){
			processDataDictionary(dictionary);
		}
	}

	return popped;
}
/**
 * Pops the current parent via the superclass and post-processes it:
 * a {@link Model} is handed to {@code processModel}; for a {@link PMML},
 * its data dictionary (when present) is handed to {@code processDataDictionary}.
 *
 * @return the object returned by {@code super.popParent()}
 */
@Override
public PMMLObject popParent(){
	PMMLObject popped = super.popParent();

	if(popped instanceof Model){
		processModel((Model)popped);
		return popped;
	}

	if(popped instanceof PMML){
		DataDictionary dictionary = ((PMML)popped).getDataDictionary();
		// A PMML without a data dictionary needs no processing
		if(dictionary != null){
			processDataDictionary(dictionary);
		}
	}

	return popped;
}
/**
 * Looks up, for every mining field of the requested usage type, the position
 * of the same-named field in the PMML data dictionary.
 *
 * <p>Only the mining schema of the first model in the PMML is consulted.
 *
 * @param pmml
 *            the pmml model
 * @param type
 *            the usage type a mining field must have to be selected
 * @return data dictionary positions of the selected mining fields, in mining
 *         schema order
 */
public static int[] getDicFieldIDViaType(PMML pmml, FieldUsageType type) {
    // Position of each data field within the data dictionary, keyed by field name
    HashMap<String, Integer> positionByName = new HashMap<String, Integer>();
    List<DataField> dataFields = pmml.getDataDictionary().getDataFields();
    for(int pos = 0; pos < dataFields.size(); pos++) {
        positionByName.put(dataFields.get(pos).getName().getValue(), pos);
    }

    List<Integer> selected = new ArrayList<Integer>();
    for(MiningField miningField: pmml.getModels().get(0).getMiningSchema().getMiningFields()) {
        if(miningField.getUsageType() != type) {
            continue;
        }
        // NOTE(review): a mining field whose name is absent from the data dictionary
        // contributes null here - confirm how Ints.toArray handles that case
        selected.add(positionByName.get(miningField.getName().getValue()));
    }

    return Ints.toArray(selected);
}
/**
 * Declares the fields of the PMML's data dictionary and the derived fields of
 * its transformation dictionary (each only when the dictionary exists and is
 * non-empty), then delegates to the superclass.
 *
 * @param pmml the PMML being visited
 * @return the result of {@code super.visit(pmml)}
 */
@Override
public VisitorAction visit(PMML pmml){
	DataDictionary dictionary = pmml.getDataDictionary();
	boolean declareDataFields = (dictionary != null) && dictionary.hasDataFields();
	if(declareDataFields){
		declare(pmml, dictionary.getDataFields());
	}

	TransformationDictionary transformations = pmml.getTransformationDictionary();
	boolean declareDerivedFields = (transformations != null) && transformations.hasDerivedFields();
	if(declareDerivedFields){
		declare(pmml, transformations.getDerivedFields());
	}

	return super.visit(pmml);
}
/**
 * Declares the fields of the PMML's data dictionary and the derived fields of
 * its transformation dictionary (each only when the dictionary exists and is
 * non-empty), then delegates to the superclass.
 *
 * @param pmml the PMML being visited
 * @return the result of {@code super.visit(pmml)}
 */
@Override
public VisitorAction visit(PMML pmml){
	DataDictionary dictionary = pmml.getDataDictionary();
	boolean declareDataFields = (dictionary != null) && dictionary.hasDataFields();
	if(declareDataFields){
		declare(pmml, dictionary.getDataFields());
	}

	TransformationDictionary transformations = pmml.getTransformationDictionary();
	boolean declareDerivedFields = (transformations != null) && transformations.hasDerivedFields();
	if(declareDerivedFields){
		declare(pmml, transformations.getDerivedFields());
	}

	return super.visit(pmml);
}
/**
 * Applies {@link DataDictionaryCleaner} to the chained segmentation fixture
 * and checks the data fields before cleaning, after cleaning, and after
 * cleaning again once all models have been removed.
 */
@Test
public void cleanChained() throws Exception {
	PMML pmml = ResourceUtil.unmarshal(ChainedSegmentationTest.class);
	DataDictionary dictionary = pmml.getDataDictionary();

	// Fields as loaded from the fixture
	checkFields(FieldNameUtil.create("y", "x1", "x2", "x3", "x4"), dictionary.getDataFields());

	DataDictionaryCleaner cleaner = new DataDictionaryCleaner();

	// After the first pass "x4" is expected to be gone
	cleaner.applyTo(pmml);
	checkFields(FieldNameUtil.create("y", "x1", "x2", "x3"), dictionary.getDataFields());

	// With no models left, a second pass is expected to leave no fields
	pmml.getModels().clear();
	cleaner.applyTo(pmml);
	checkFields(Collections.emptySet(), dictionary.getDataFields());
}
/**
 * Applies {@link DataDictionaryCleaner} to the nested segmentation fixture
 * and checks the data fields before cleaning, after cleaning, and after
 * cleaning again once all models have been removed.
 */
@Test
public void cleanNested() throws Exception {
	PMML pmml = ResourceUtil.unmarshal(NestedSegmentationTest.class);
	DataDictionary dictionary = pmml.getDataDictionary();

	// Fields as loaded from the fixture
	checkFields(FieldNameUtil.create("y", "x1", "x2", "x3", "x4", "x5"), dictionary.getDataFields());

	DataDictionaryCleaner cleaner = new DataDictionaryCleaner();

	// After the first pass "y" is expected to be gone
	cleaner.applyTo(pmml);
	checkFields(FieldNameUtil.create("x1", "x2", "x3", "x4", "x5"), dictionary.getDataFields());

	// With no models left, a second pass is expected to leave no fields
	pmml.getModels().clear();
	cleaner.applyTo(pmml);
	checkFields(Collections.emptySet(), dictionary.getDataFields());
}
/**
 * Offers this object to the visitor and, if the visitor asks to continue,
 * traverses the children with this object pushed as parent: first the header,
 * mining build task, data dictionary and transformation dictionary, then the
 * models, then the extensions. Each later group is traversed only while the
 * status is still {@code CONTINUE}.
 *
 * @param visitor the visitor to accept
 * @return {@code TERMINATE} if the traversal ended with that status,
 *         otherwise {@code CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction status = visitor.visit(this);

    if (status == VisitorAction.CONTINUE) {
        visitor.pushParent(this);

        status = org.dmg.pmml.PMMLObject.traverse(visitor, getHeader(), getMiningBuildTask(), getDataDictionary(), getTransformationDictionary());

        if (status == VisitorAction.CONTINUE && hasModels()) {
            status = org.dmg.pmml.PMMLObject.traverse(visitor, getModels());
        }

        if (status == VisitorAction.CONTINUE && hasExtensions()) {
            status = org.dmg.pmml.PMMLObject.traverse(visitor, getExtensions());
        }

        visitor.popParent();
    }

    // Any status other than TERMINATE is reported to the caller as CONTINUE
    return (status == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Offers this object to the visitor and, if the visitor asks to continue,
 * traverses the children with this object pushed as parent: first the header,
 * mining build task, data dictionary and transformation dictionary, then the
 * models, then the extensions. Each later group is traversed only while the
 * status is still {@code CONTINUE}.
 *
 * @param visitor the visitor to accept
 * @return {@code TERMINATE} if the traversal ended with that status,
 *         otherwise {@code CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction status = visitor.visit(this);

    if (status == VisitorAction.CONTINUE) {
        visitor.pushParent(this);

        status = org.dmg.pmml.PMMLObject.traverse(visitor, getHeader(), getMiningBuildTask(), getDataDictionary(), getTransformationDictionary());

        if (status == VisitorAction.CONTINUE && hasModels()) {
            status = org.dmg.pmml.PMMLObject.traverse(visitor, getModels());
        }

        if (status == VisitorAction.CONTINUE && hasExtensions()) {
            status = org.dmg.pmml.PMMLObject.traverse(visitor, getExtensions());
        }

        visitor.popParent();
    }

    // Any status other than TERMINATE is reported to the caller as CONTINUE
    return (status == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Unmarshals the chained segmentation fixture through a {@link SkipFilter}
 * configured with the name "Segmentation" and checks that the segmentation of
 * the first model is absent while the other inspected parts are present.
 */
@Test
public void filterChainedSegmentation() throws Exception {
	SkipFilter skipSegmentation = new SkipFilter("Segmentation");
	PMML pmml = ResourceUtil.unmarshal(ChainedSegmentationTest.class, skipSegmentation);

	assertNotNull(pmml.getDataDictionary());
	assertNotNull(pmml.getTransformationDictionary());

	MiningModel firstModel = (MiningModel)pmml.getModels().get(0);

	assertNotNull(firstModel.getMiningSchema());
	assertNotNull(firstModel.getOutput());

	// The filtered element must not have been unmarshalled
	assertNull(firstModel.getSegmentation());
}
/**
 * Unmarshals the nested segmentation fixture through a {@link SkipFilter}
 * configured with {@code Segmentation.class} and checks that the segmentation
 * of the first model is absent while the other inspected parts are present.
 */
@Test
public void filterNestedSegmentation() throws Exception {
	SkipFilter skipSegmentation = new SkipFilter(Segmentation.class);
	PMML pmml = ResourceUtil.unmarshal(NestedSegmentationTest.class, skipSegmentation);

	assertNotNull(pmml.getDataDictionary());

	MiningModel firstModel = (MiningModel)pmml.getModels().get(0);

	assertNotNull(firstModel.getMiningSchema());
	assertNotNull(firstModel.getLocalTransformations());
	assertNotNull(firstModel.getOutput());

	// The filtered element must not have been unmarshalled
	assertNull(firstModel.getSegmentation());
}
@Test public void copyState(){ PMML pmml = new PMML(Version.PMML_4_3.getVersion(), new Header(), new DataDictionary()); // Initialize the live list instance pmml.getModels(); CustomPMML customPmml = new CustomPMML(); ReflectionUtil.copyState(pmml, customPmml); assertSame(pmml.getVersion(), customPmml.getVersion()); assertSame(pmml.getHeader(), customPmml.getHeader()); assertSame(pmml.getDataDictionary(), customPmml.getDataDictionary()); assertFalse(pmml.hasModels()); assertFalse(customPmml.hasModels()); pmml.addModels(new RegressionModel()); assertTrue(pmml.hasModels()); assertTrue(customPmml.hasModels()); assertSame(pmml.getModels(), customPmml.getModels()); try { ReflectionUtil.copyState(customPmml, pmml); fail(); } catch(IllegalArgumentException iae){ // Ignored } }