/**
 * Combines numeric predictions into one by taking their weighted mean.
 *
 * @param predictions predictions to combine; element {@code i} is weighted by {@code weights[i]}
 * @param weights per-prediction weights, read positionally alongside {@code predictions}
 * @return a {@link NumericPrediction} holding the weighted mean, with its count set to
 *     {@code getN()} of the accumulator
 */
private static Prediction voteOnNumericFeature(List<NumericPrediction> predictions, double[] weights) {
  DoubleWeightedMean weightedMean = new DoubleWeightedMean();
  int index = 0;
  for (NumericPrediction candidate : predictions) {
    weightedMean.increment(candidate.getPrediction(), weights[index++]);
  }
  double combined = weightedMean.getResult();
  int count = (int) weightedMean.getN();
  return new NumericPrediction(combined, count);
}
/**
 * Builds a small fixed tree for tests.
 *
 * <p>Node "r" uses {@code NumericDecision(0, 1.0, false)} and has children "r-" and "r+".
 * Node "r-" uses {@code NumericDecision(0, -1.0, false)} and has leaves "r--" (0.0) and
 * "r-+" (1.0). Leaf "r+" predicts 2.0. Every leaf has count 1.
 */
static DecisionTree buildTestTree() {
  TerminalNode leafNegNeg = new TerminalNode("r--", new NumericPrediction(0.0, 1));
  TerminalNode leafNegPos = new TerminalNode("r-+", new NumericPrediction(1.0, 1));
  DecisionNode negativeBranch =
      new DecisionNode("r-", new NumericDecision(0, -1.0, false), leafNegNeg, leafNegPos);
  TerminalNode leafPos = new TerminalNode("r+", new NumericPrediction(2.0, 1));
  DecisionNode top =
      new DecisionNode("r", new NumericDecision(0, 1.0, false), negativeBranch, leafPos);
  return new DecisionTree(top);
}
/** With equal weights the vote should be the plain mean (1 + 3 + 6) / 3. */
@Test
public void testNumericVote() {
  double[] uniformWeights = {1.0, 1.0, 1.0};
  List<NumericPrediction> inputs = Arrays.asList(
      new NumericPrediction(1.0, 1),
      new NumericPrediction(3.0, 2),
      new NumericPrediction(6.0, 3));
  Prediction result = WeightedPrediction.voteOnFeature(inputs, uniformWeights);
  NumericPrediction numericResult = (NumericPrediction) result;
  assertEquals(FeatureType.NUMERIC, numericResult.getFeatureType());
  assertEquals(10.0 / 3.0, numericResult.getPrediction());
}
/** With weights 3, 2, 1 the vote should be (3*1 + 2*3 + 1*6) / (3 + 2 + 1) = 15 / 6. */
@Test
public void testNumericVoteWeighted() {
  double[] descendingWeights = {3.0, 2.0, 1.0};
  List<NumericPrediction> inputs = Arrays.asList(
      new NumericPrediction(1.0, 1),
      new NumericPrediction(3.0, 2),
      new NumericPrediction(6.0, 3));
  Prediction result = WeightedPrediction.voteOnFeature(inputs, descendingWeights);
  NumericPrediction numericResult = (NumericPrediction) result;
  assertEquals(FeatureType.NUMERIC, numericResult.getFeatureType());
  assertEquals(15.0 / 6.0, numericResult.getPrediction());
}
/**
 * Builds a small {@link RDFServingModel} fixture for tests.
 *
 * <p>The forest holds two single-split trees with weights 1.0 and 2.0 and feature
 * importances 0.1 and 0.3.
 * <ul>
 *   <li>Tree 1 applies a {@code CategoricalDecision} on feature 0 with only bit 1 set. Its
 *       leaves are "r-" (1.0) and "r+" (10.0).</li>
 *   <li>Tree 2 applies {@code NumericDecision(1, -3.0, false)}. Its leaves are "r-" (100.0)
 *       and "r+" (1000.0).</li>
 * </ul>
 *
 * <p>The schema overlay declares 3 features, with "0" categorical (values "A", "B", "C")
 * and "2" as the target.
 */
public static RDFServingModel buildTestModel() {
  Map<Integer,Collection<String>> valuesByFeature = new HashMap<>();
  valuesByFeature.put(0, Arrays.asList("A", "B", "C"));
  CategoricalValueEncodings valueEncodings = new CategoricalValueEncodings(valuesByFeature);

  // First tree: categorical split on feature 0.
  TerminalNode firstNeg = new TerminalNode("r-", new NumericPrediction(1.0, 1));
  TerminalNode firstPos = new TerminalNode("r+", new NumericPrediction(10.0, 1));
  // Only bit 1 is set; presumably this is the encoding of "B" — confirm against CategoricalValueEncodings.
  BitSet categoryMask = new BitSet(2);
  categoryMask.set(1);
  Decision firstSplit = new CategoricalDecision(0, categoryMask, true);
  TreeNode firstRoot = new DecisionNode("r", firstSplit, firstNeg, firstPos);

  // Second tree: numeric split on feature 1.
  TerminalNode secondNeg = new TerminalNode("r-", new NumericPrediction(100.0, 1));
  TerminalNode secondPos = new TerminalNode("r+", new NumericPrediction(1000.0, 1));
  Decision secondSplit = new NumericDecision(1, -3.0, false);
  TreeNode secondRoot = new DecisionNode("r", secondSplit, secondNeg, secondPos);

  DecisionForest forest = new DecisionForest(
      new DecisionTree[] { new DecisionTree(firstRoot), new DecisionTree(secondRoot) },
      new double[] { 1.0, 2.0 },
      new double[] { 0.1, 0.3 });

  Map<String,Object> schemaOverrides = new HashMap<>();
  schemaOverrides.put("oryx.input-schema.num-features", 3);
  schemaOverrides.put("oryx.input-schema.categorical-features", "[\"0\"]");
  schemaOverrides.put("oryx.input-schema.target-feature", "\"2\"");
  Config mergedConfig = ConfigUtils.overlayOn(schemaOverrides, ConfigUtils.getDefault());
  InputSchema schema = new InputSchema(mergedConfig);

  return new RDFServingModel(forest, valueEncodings, schema);
}
/**
 * Checks NumericPrediction equality.
 *
 * <p>Two predictions with the same value but different counts compare equal. After each
 * update, the mutated prediction no longer equals the untouched one.
 */
@Test
public void testEquals() {
  NumericPrediction reference = new NumericPrediction(1.5, 1);
  NumericPrediction candidate = new NumericPrediction(1.5, 2);
  assertEquals(reference, candidate);

  candidate.update(2.0, 2);
  assertNotEquals(reference, candidate);

  // Updating with 1.5 is still expected to leave candidate unequal to reference.
  candidate.update(1.5, 4);
  assertNotEquals(reference, candidate);
}
// Builds the numeric prediction from root's score and record count.
// - The score is parsed with Double.parseDouble, so getScore() must yield a String.
//   parseDouble throws NullPointerException on null and NumberFormatException on non-numeric text.
// - The record count is rounded with Math.round and then narrowed to int.
// NOTE(review): root looks like a PMML tree node; assumes the score is always present and numeric — confirm.
prediction = new NumericPrediction(Double.parseDouble(root.getScore()), (int) Math.round(root.getRecordCount()));
/**
 * Checks that two terminal nodes with different ids compare equal and hash equally when
 * their predictions are equal.
 */
@Test
public void testEquals() {
  Prediction firstPrediction = new NumericPrediction(1.5, 10);
  Prediction secondPrediction = new NumericPrediction(1.5, 10);
  TerminalNode firstNode = new TerminalNode("a", firstPrediction);
  TerminalNode secondNode = new TerminalNode("b", secondPrediction);
  int firstHash = firstNode.hashCode();
  int secondHash = secondNode.hashCode();
  assertEquals(firstHash, secondHash);
  assertEquals(firstNode, secondNode);
}
/**
 * Checks update(3.5, 3) on a prediction of 1.5 with count 1.
 *
 * <p>The expected 3.0 is consistent with a count-weighted mean: (1.5*1 + 3.5*3) / 4.
 */
@Test
public void testUpdate2() {
  NumericPrediction running = new NumericPrediction(1.5, 1);
  running.update(3.5, 3);
  double updatedValue = running.getPrediction();
  assertEquals(3.0, updatedValue);
}
/** Checks that a freshly constructed prediction reports NUMERIC type and its initial value. */
@Test
public void testConstruct() {
  NumericPrediction fresh = new NumericPrediction(1.5, 1);
  FeatureType type = fresh.getFeatureType();
  assertEquals(FeatureType.NUMERIC, type);
  assertEquals(1.5, fresh.getPrediction());
}
/**
 * Checks updating a 1.5 prediction with an example whose numeric feature is 2.5.
 *
 * <p>The expected 2.0 is the mean of 1.5 and 2.5.
 */
@Test
public void testUpdate() {
  NumericPrediction running = new NumericPrediction(1.5, 1);
  NumericFeature observed = NumericFeature.forValue(2.5);
  running.update(new Example(observed));
  assertEquals(2.0, running.getPrediction());
}
/**
 * Pins NumericPrediction hash codes before and after an update.
 *
 * <p>1073217536 happens to equal {@code Double.hashCode(1.5)}, which suggests the hash
 * derives from the value. Confirm that before relying on it.
 */
@Test
public void testHashCode() {
  NumericPrediction subject = new NumericPrediction(1.5, 1);
  int initialHash = subject.hashCode();
  assertEquals(1073217536, initialHash);

  subject.update(2.0, 2);
  int updatedHash = subject.hashCode();
  assertEquals(1789394944, updatedHash);
}
/**
 * Checks that a TerminalNode is terminal, returns the exact Prediction instance it was
 * given, and reports that prediction's count.
 */
@Test
public void testNode() {
  Prediction leafValue = new NumericPrediction(1.2, 3);
  TerminalNode leaf = new TerminalNode("1", leafValue);
  assertTrue(leaf.isTerminal());
  assertSame(leafValue, leaf.getPrediction());
  assertEquals(3, leaf.getCount());
}
/**
 * Checks that two decision nodes with different ids compare equal and hash equally when
 * they have equal decisions and the same children.
 */
@Test
public void testEquals() {
  Decision firstDecision = new NumericDecision(1, 2.0, true);
  Decision secondDecision = new NumericDecision(1, 2.0, true);
  Prediction sharedPrediction = new NumericPrediction(-1.0, 1);
  TreeNode leftChild = new TerminalNode("2", sharedPrediction);
  TreeNode rightChild = new TerminalNode("3", sharedPrediction);
  DecisionNode firstNode = new DecisionNode("a", firstDecision, leftChild, rightChild);
  DecisionNode secondNode = new DecisionNode("b", secondDecision, leftChild, rightChild);
  assertEquals(firstNode, secondNode);
  assertEquals(firstNode.hashCode(), secondNode.hashCode());
}
/**
 * Returns the weighted mean of the given numeric predictions as a single prediction.
 *
 * @param predictions the predictions to combine, in the same order as {@code weights}
 * @param weights weight applied to the prediction at the same index
 * @return a {@link NumericPrediction} with the weighted mean as its value and the
 *     accumulator's {@code getN()} (narrowed to int) as its count
 */
private static Prediction voteOnNumericFeature(List<NumericPrediction> predictions, double[] weights) {
  DoubleWeightedMean accumulator = new DoubleWeightedMean();
  int total = predictions.size();
  int idx = 0;
  while (idx < total) {
    double value = predictions.get(idx).getPrediction();
    accumulator.increment(value, weights[idx]);
    idx++;
  }
  double meanValue = accumulator.getResult();
  return new NumericPrediction(meanValue, (int) accumulator.getN());
}
// Builds the numeric prediction from root's textual score and its record count.
// - The score goes through Double.parseDouble, which throws NumberFormatException on
//   non-numeric text and NullPointerException on null.
// - The record count is rounded to the nearest whole number with Math.round and then
//   narrowed to int.
// NOTE(review): assumes root (apparently a PMML tree node) always carries a numeric score — confirm.
prediction = new NumericPrediction(Double.parseDouble(root.getScore()), (int) Math.round(root.getRecordCount()));