/**
 * Builds a minimal PMML document for tests: a skeleton PMML to which a single
 * classification {@link TreeModel} is added. The tree consists of one node whose
 * record count is 123.0; the second constructor argument is passed as {@code null}.
 *
 * @return the skeleton PMML with the dummy tree model added
 */
private static PMML buildDummyModel() {
  TreeModel dummyTree =
      new TreeModel(MiningFunction.CLASSIFICATION, null, new Node().setRecordCount(123.0));
  PMML document = PMMLUtils.buildSkeletonPMML();
  document.addModels(dummyTree);
  return document;
}
/**
 * Asserts that the given tree model declares binary splits and the default-child
 * missing value strategy, then validates its node tree via {@code checkNode}.
 *
 * @param treeModel the tree model to verify
 */
private static void checkTreeModel(TreeModel treeModel) {
  TreeModel.SplitCharacteristic actualSplit = treeModel.getSplitCharacteristic();
  assertEquals(TreeModel.SplitCharacteristic.BINARY_SPLIT, actualSplit);

  TreeModel.MissingValueStrategy actualStrategy = treeModel.getMissingValueStrategy();
  assertEquals(TreeModel.MissingValueStrategy.DEFAULT_CHILD, actualStrategy);

  // Structural checks start at the node returned by getNode().
  checkNode(treeModel.getNode());
}
// Excerpt from a tree-building routine (enclosing method header is not visible here).
// 1. Fails fast unless classificationTask agrees with inputSchema.isClassification().
// 2. Creates a root node with id "r".
// 3. Populates modelNode with its predicate, record count, a score distribution and a score
//    (the score is the string form of targetEncodedValue).
// 4. Attaches two children whose ids are modelNode's id suffixed with '+' and '-', added in
//    that order, and picks the default child from defaultRight ('+' when true, '-' otherwise).
// 5. Returns a TreeModel rooted at root, marked BINARY_SPLIT with the DEFAULT_CHILD
//    missing value strategy.
// NOTE(review): modelNode, predicate, nodeCount, distribution, targetEncodedValue and
// defaultRight are declared outside this excerpt; how modelNode relates to root is not shown.
Preconditions.checkState(classificationTask == inputSchema.isClassification()); Node root = new Node(); root.setId("r"); modelNode.setPredicate(predicate); modelNode.setRecordCount((double) nodeCount); modelNode.addScoreDistributions(distribution); modelNode.setScore(Double.toString(targetEncodedValue)); Node positiveModelNode = new Node().setId(modelNode.getId() + '+'); Node negativeModelNode = new Node().setId(modelNode.getId() + '-'); modelNode.addNodes(positiveModelNode, negativeModelNode); modelNode.setDefaultChild(defaultRight ? positiveModelNode.getId() : negativeModelNode.getId()); return new TreeModel() .setNode(root) .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT) .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
// Builds a three-node regression tree fixture:
//   root  "r"  : record count dummyCount, always-true predicate
//   left  "r-" : record count halfCount, always-true predicate, score "-2.0"
//   right "r+" : record count halfCount, predicate foo > "3.14", score "2.0"
// The right child is added before the left one, so it is child index 0.
// The resulting TreeModel is REGRESSION, BINARY_SPLIT, with the DEFAULT_CHILD strategy.
// NOTE(review): miningSchema is passed to the constructor and again via setMiningSchema();
// the second call looks redundant - confirm before removing.
// dummyCount, halfCount and miningSchema are declared outside this excerpt.
Node rootNode = new Node().setId("r").setRecordCount(dummyCount).setPredicate(new True()); Node left = new Node() .setId("r-") .setRecordCount(halfCount) .setPredicate(new True()) .setScore("-2.0"); Node right = new Node().setId("r+").setRecordCount(halfCount) .setPredicate(new SimplePredicate(FieldName.create("foo"), SimplePredicate.Operator.GREATER_THAN).setValue("3.14")) .setScore("2.0"); rootNode.addNodes(right, left); TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, miningSchema, rootNode) .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT) .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD) .setMiningSchema(miningSchema);
/**
 * Verifies that the tree model uses the configuration this visitor accepts and then
 * stores the model's mining function in {@code this.miningFunction}.
 *
 * <p>The checks run in a fixed order (missing value strategy, no-true-child strategy,
 * split characteristic); the first mismatch is reported with the expected and actual
 * values so that a rejected model can be diagnosed from the exception alone.
 *
 * @param treeModel the tree model being entered
 * @throws IllegalArgumentException if the missing value strategy is not {@code NONE},
 *     the no-true-child strategy is not {@code RETURN_NULL_PREDICTION}, or the split
 *     characteristic is not {@code BINARY_SPLIT}
 */
@Override
public void enterTreeModel(TreeModel treeModel){
    TreeModel.MissingValueStrategy missingValueStrategy = treeModel.getMissingValueStrategy();
    TreeModel.NoTrueChildStrategy noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
    TreeModel.SplitCharacteristic splitCharacteristic = treeModel.getSplitCharacteristic();

    // Constant-first equals() keeps each check safe when the getter returns null.
    if(!(TreeModel.MissingValueStrategy.NONE).equals(missingValueStrategy)){
        throw new IllegalArgumentException("Expected missing value strategy " + TreeModel.MissingValueStrategy.NONE + ", got " + missingValueStrategy);
    }

    if(!(TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION).equals(noTrueChildStrategy)){
        throw new IllegalArgumentException("Expected no-true-child strategy " + TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION + ", got " + noTrueChildStrategy);
    }

    if(!(TreeModel.SplitCharacteristic.BINARY_SPLIT).equals(splitCharacteristic)){
        throw new IllegalArgumentException("Expected split characteristic " + TreeModel.SplitCharacteristic.BINARY_SPLIT + ", got " + splitCharacteristic);
    }

    this.miningFunction = treeModel.getMiningFunction();
}
// Builds a three-node classification tree fixture:
//   root  "r"  : record count dummyCount, always-true predicate
//   left  "r-" : record count halfCount, always-true predicate,
//                score distribution ("apple", halfCount)
//   right "r+" : record count halfCount, predicate: color IS_NOT_IN ["red"],
//                score distribution ("banana", halfCount)
// The right child is added before the left one, so it is child index 0.
// The resulting TreeModel is CLASSIFICATION, BINARY_SPLIT, with the DEFAULT_CHILD strategy.
// dummyCount, halfCount and miningSchema are declared outside this excerpt.
Node rootNode = new Node().setId("r").setRecordCount(dummyCount).setPredicate(new True()); Node left = new Node().setId("r-").setRecordCount(halfCount).setPredicate(new True()); left.addScoreDistributions(new ScoreDistribution("apple", halfCount)); Node right = new Node().setId("r+").setRecordCount(halfCount) .setPredicate(new SimpleSetPredicate(FieldName.create("color"), SimpleSetPredicate.BooleanOperator.IS_NOT_IN, new Array(Array.Type.STRING, "red"))); right.addScoreDistributions(new ScoreDistribution("banana", halfCount)); rootNode.addNodes(right, left); TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, miningSchema, rootNode) .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT) .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
String id = root.getId(); List<Node> children = root.getNodes(); if (children.isEmpty()) { Collection<ScoreDistribution> scoreDistributions = root.getScoreDistributions(); Prediction prediction; if (scoreDistributions != null && !scoreDistributions.isEmpty()) { prediction = new NumericPrediction(Double.parseDouble(root.getScore()), (int) Math.round(root.getRecordCount())); Node negativeLeftChild; Node positiveRightChild; if (child1.getPredicate() instanceof True) { negativeLeftChild = child1; positiveRightChild = child2; } else { Preconditions.checkArgument(child2.getPredicate() instanceof True); negativeLeftChild = child2; positiveRightChild = child1; Predicate predicate = positiveRightChild.getPredicate(); boolean defaultDecision = positiveRightChild.getId().equals(root.getDefaultChild());
private static void checkNode(Node node) { assertNotNull(node.getId()); List<ScoreDistribution> scoreDists = node.getScoreDistributions(); int numDists = scoreDists.size(); if (numDists == 0) { List<Node> children = node.getNodes(); assertEquals(2, children.size()); Node rightChild = children.get(0); Node leftChild = children.get(1); assertInstanceOf(leftChild.getPredicate(), True.class); assertEquals(node.getRecordCount().doubleValue(), leftChild.getRecordCount() + rightChild.getRecordCount()); assertEquals(node.getId() + "+", rightChild.getId()); assertEquals(node.getId() + "-", leftChild.getId()); checkNode(rightChild); checkNode(leftChild);
/**
 * Round-trips the dummy model through {@code PMMLUtils.write} and {@code PMMLUtils.read}
 * using a temporary ".pmml" file, then checks that exactly one {@link TreeModel} comes
 * back with the original record count (123.0) and mining function (CLASSIFICATION).
 */
@Test
public void testReadWrite() throws Exception {
  Path modelPath = Files.createTempFile(getTempDir(), "model", ".pmml");
  PMMLUtils.write(buildDummyModel(), modelPath);
  assertTrue(Files.exists(modelPath));

  // Read the file back and inspect the single model it should contain.
  List<Model> readModels = PMMLUtils.read(modelPath).getModels();
  assertEquals(1, readModels.size());

  Model firstModel = readModels.get(0);
  assertInstanceOf(firstModel, TreeModel.class);
  TreeModel readTree = (TreeModel) firstModel;
  assertEquals(123.0, readTree.getNode().getRecordCount().doubleValue());
  assertEquals(MiningFunction.CLASSIFICATION, readTree.getMiningFunction());
}
/**
 * Verifies that the tree model uses the configuration this visitor accepts.
 *
 * <p>The no-true-child strategy is checked first, then the split characteristic; the
 * first mismatch is reported with the expected and actual values so that a rejected
 * model can be diagnosed from the exception alone.
 *
 * @param treeModel the tree model being entered
 * @throws IllegalArgumentException if the no-true-child strategy is not
 *     {@code RETURN_NULL_PREDICTION} or the split characteristic is not
 *     {@code BINARY_SPLIT}
 */
@Override
public void enterTreeModel(TreeModel treeModel){
    TreeModel.NoTrueChildStrategy noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
    TreeModel.SplitCharacteristic splitCharacteristic = treeModel.getSplitCharacteristic();

    // Constant-first equals() keeps each check safe when the getter returns null.
    if(!(TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION).equals(noTrueChildStrategy)){
        throw new IllegalArgumentException("Expected no-true-child strategy " + TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION + ", got " + noTrueChildStrategy);
    }

    if(!(TreeModel.SplitCharacteristic.BINARY_SPLIT).equals(splitCharacteristic)){
        throw new IllegalArgumentException("Expected split characteristic " + TreeModel.SplitCharacteristic.BINARY_SPLIT + ", got " + splitCharacteristic);
    }
}
/**
 * Creates a {@code LeafNode} carrying the given score and predicate.
 *
 * @param score the score to set on the node
 * @param predicate the predicate to set on the node
 * @return the result of the fluent setter chain on the new leaf node
 */
private static Node createLeafNode(String score, Predicate predicate){
    LeafNode leaf = new LeafNode();

    // Score is applied before the predicate, as a single fluent chain.
    return leaf.setScore(score).setPredicate(predicate);
}
/**
 * Creates a {@code BranchNode} carrying the given score and predicate.
 *
 * @param score the score to set on the node
 * @param predicate the predicate to set on the node
 * @return the result of the fluent setter chain on the new branch node
 */
private static Node createBranchNode(String score, Predicate predicate){
    BranchNode branch = new BranchNode();

    // Score is applied before the predicate, as a single fluent chain.
    return branch.setScore(score).setPredicate(predicate);
}
/**
 * Factory method producing a new, default-constructed {@link TreeModel}.
 *
 * @return a freshly created {@link TreeModel} instance
 */
public TreeModel createTreeModel() {
    final TreeModel created = new TreeModel();
    return created;
}
/**
 * Marks the tree model as multi-split on exit and clears the mining function
 * held in {@code this.miningFunction}.
 *
 * @param treeModel the tree model being exited
 */
@Override
public void exitTreeModel(TreeModel treeModel){
    final TreeModel.SplitCharacteristic multiSplit = TreeModel.SplitCharacteristic.MULTI_SPLIT;
    treeModel.setSplitCharacteristic(multiSplit);

    // Reset the per-model state.
    this.miningFunction = null;
}
/**
 * Rejects any tree model that is not configured with the DEFAULT_CHILD missing value
 * strategy, the RETURN_NULL_PREDICTION no-true-child strategy and BINARY_SPLIT splits.
 *
 * @param treeModel the tree model being entered
 * @throws IllegalArgumentException if any of the three attributes has another value
 */
@Override
public void enterTreeModel(TreeModel treeModel){
    TreeModel.MissingValueStrategy actualMissing = treeModel.getMissingValueStrategy();
    TreeModel.NoTrueChildStrategy actualNoTrueChild = treeModel.getNoTrueChildStrategy();
    TreeModel.SplitCharacteristic actualSplit = treeModel.getSplitCharacteristic();

    // All three attributes must match; evaluation stops at the first mismatch.
    boolean supported = (TreeModel.MissingValueStrategy.DEFAULT_CHILD).equals(actualMissing)
        && (TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION).equals(actualNoTrueChild)
        && (TreeModel.SplitCharacteristic.BINARY_SPLIT).equals(actualSplit);

    if(!supported){
        throw new IllegalArgumentException();
    }
}
/**
 * Builds a minimal PMML document: a skeleton PMML to which a single classification
 * {@link TreeModel} is added. The tree has one node with a record count of 123.0,
 * and the second constructor argument is passed as {@code null}.
 *
 * @return the skeleton PMML with the dummy tree model added
 */
public static PMML buildDummyModel() {
  final Node soleNode = new Node().setRecordCount(123.0);
  final TreeModel classifier = new TreeModel(MiningFunction.CLASSIFICATION, null, soleNode);

  final PMML result = PMMLUtils.buildSkeletonPMML();
  result.addModels(classifier);
  return result;
}
/**
 * Guards against tree models this visitor does not handle. The model must use the
 * DEFAULT_CHILD missing value strategy, the RETURN_NULL_PREDICTION no-true-child
 * strategy and BINARY_SPLIT splits.
 *
 * @param treeModel the tree model being entered
 * @throws IllegalArgumentException if any of the three attributes has another value
 */
@Override
public void enterTreeModel(TreeModel treeModel){
    TreeModel.MissingValueStrategy missingHandling = treeModel.getMissingValueStrategy();
    TreeModel.NoTrueChildStrategy noTrueChildHandling = treeModel.getNoTrueChildStrategy();
    TreeModel.SplitCharacteristic splitKind = treeModel.getSplitCharacteristic();

    // One guard clause per attribute, checked in the same order as a single || chain.
    if(!(TreeModel.MissingValueStrategy.DEFAULT_CHILD).equals(missingHandling)){
        throw new IllegalArgumentException();
    }

    if(!(TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION).equals(noTrueChildHandling)){
        throw new IllegalArgumentException();
    }

    if(!(TreeModel.SplitCharacteristic.BINARY_SPLIT).equals(splitKind)){
        throw new IllegalArgumentException();
    }
}
/**
 * Accepts only tree models with the DEFAULT_CHILD missing value strategy, the
 * RETURN_NULL_PREDICTION no-true-child strategy and BINARY_SPLIT splits.
 *
 * @param treeModel the tree model being entered
 * @throws IllegalArgumentException if any of the three attributes has another value
 */
@Override
public void enterTreeModel(TreeModel treeModel){
    // Constant-first equals() tolerates a null attribute value.
    boolean missingOk = (TreeModel.MissingValueStrategy.DEFAULT_CHILD).equals(treeModel.getMissingValueStrategy());
    boolean noTrueChildOk = (TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION).equals(treeModel.getNoTrueChildStrategy());
    boolean splitOk = (TreeModel.SplitCharacteristic.BINARY_SPLIT).equals(treeModel.getSplitCharacteristic());

    if(!(missingOk && noTrueChildOk && splitOk)){
        throw new IllegalArgumentException();
    }
}
/**
 * Validates the tree model's configuration and initializes per-model state.
 *
 * <p>The model must use the NONE missing value strategy, the RETURN_NULL_PREDICTION
 * no-true-child strategy and BINARY_SPLIT splits. On success the model's mining function
 * is stored in {@code this.miningFunction} and {@code this.replacedPredicates} is cleared.
 *
 * @param treeModel the tree model being entered
 * @throws IllegalArgumentException if any of the three attributes has another value
 */
@Override
public void enterTreeModel(TreeModel treeModel){
    TreeModel.MissingValueStrategy actualMissing = treeModel.getMissingValueStrategy();
    TreeModel.NoTrueChildStrategy actualNoTrueChild = treeModel.getNoTrueChildStrategy();
    TreeModel.SplitCharacteristic actualSplit = treeModel.getSplitCharacteristic();

    boolean accepted = (TreeModel.MissingValueStrategy.NONE).equals(actualMissing)
        && (TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION).equals(actualNoTrueChild)
        && (TreeModel.SplitCharacteristic.BINARY_SPLIT).equals(actualSplit);

    if(!accepted){
        throw new IllegalArgumentException();
    }

    // State is only touched once the model has passed validation.
    this.miningFunction = treeModel.getMiningFunction();
    this.replacedPredicates.clear();
}