/**
 * Checks that a weighted vote over numeric predictions produces their weighted mean:
 * (3 * 1.0 + 2 * 3.0 + 1 * 6.0) / (3 + 2 + 1) = 15 / 6.
 * The per-prediction counts (1, 2, 3) play no part in the expected value; only the weights do.
 */
@Test
public void testNumericVoteWeighted() {
  double[] weights = {3.0, 2.0, 1.0};
  List<NumericPrediction> predictions = Arrays.asList(
      new NumericPrediction(1.0, 1),
      new NumericPrediction(3.0, 2),
      new NumericPrediction(6.0, 3));

  NumericPrediction vote =
      (NumericPrediction) WeightedPrediction.voteOnFeature(predictions, weights);

  assertEquals(FeatureType.NUMERIC, vote.getFeatureType());
  double expectedMean = 15.0 / 6.0;
  assertEquals(expectedMean, vote.getPrediction());
}
/**
 * Checks that merging a (mean, count) pair into an existing prediction yields the
 * count-weighted mean: (1.5 * 1 + 3.5 * 3) / (1 + 3) = 3.0.
 */
@Test
public void testUpdate2() {
  NumericPrediction merged = new NumericPrediction(1.5, 1);
  merged.update(3.5, 3);
  double expectedMean = 3.0;
  assertEquals(expectedMean, merged.getPrediction());
}
/**
 * Combines numeric predictions into one by taking the weighted mean of their values.
 *
 * @param predictions predictions to combine; the i-th prediction is paired with {@code weights[i]}
 * @param weights weight for each prediction; an array shorter than {@code predictions} causes an
 *     {@link ArrayIndexOutOfBoundsException}, and extra trailing weights are ignored
 * @return a {@link NumericPrediction} whose value is the weighted mean and whose count is
 *     {@code DoubleWeightedMean.getN()} cast to int. NOTE(review): getN() presumably counts the
 *     increments (number of votes), not the sum of the individual prediction counts — confirm.
 */
private static Prediction voteOnNumericFeature(List<NumericPrediction> predictions, double[] weights) {
  DoubleWeightedMean weightedMean = new DoubleWeightedMean();
  int weightIndex = 0;
  for (NumericPrediction vote : predictions) {
    weightedMean.increment(vote.getPrediction(), weights[weightIndex++]);
  }
  double combinedValue = weightedMean.getResult();
  int voteCount = (int) weightedMean.getN();
  return new NumericPrediction(combinedValue, voteCount);
}
/**
 * Scores one example and returns the result as a string.
 * For regression schemas this is {@code Double.toString} of the numeric prediction. For
 * classification schemas it is the target category name mapped from the most probable category
 * encoding; {@code Map.get} returns null if that encoding is not in the map.
 */
@Override
public String predict(String[] example) {
  Prediction prediction = makePrediction(example);
  if (!inputSchema.isClassification()) {
    NumericPrediction numeric = (NumericPrediction) prediction;
    return Double.toString(numeric.getPrediction());
  }
  Map<Integer,String> encodingToName =
      encodings.getEncodingValueMap(inputSchema.getTargetFeatureIndex());
  int bestEncoding = ((CategoricalPrediction) prediction).getMostProbableCategoryEncoding();
  return encodingToName.get(bestEncoding);
}
/**
 * Pins {@code NumericPrediction.hashCode()} for a fresh prediction and again after an update
 * changes its state.
 */
@Test
public void testHashCode() {
  NumericPrediction numeric = new NumericPrediction(1.5, 1);
  int initialHash = numeric.hashCode();
  assertEquals(1073217536, initialHash);

  numeric.update(2.0, 2);
  int updatedHash = numeric.hashCode();
  assertEquals(1789394944, updatedHash);
}
/**
 * Exercises {@code NumericPrediction.equals}.
 * Two predictions with the same value but different counts (1 vs 2) are expected to compare
 * equal. Once updates move the second prediction's value away from 1.5, they are not equal.
 */
@Test
public void testEquals() {
  NumericPrediction reference = new NumericPrediction(1.5, 1);
  NumericPrediction candidate = new NumericPrediction(1.5, 2);
  assertEquals(reference, candidate);

  // After this update candidate's value becomes (1.5 * 2 + 2.0 * 2) / 4 = 1.75.
  candidate.update(2.0, 2);
  assertNotEquals(reference, candidate);

  // After this update candidate's value becomes (1.75 * 4 + 1.5 * 4) / 8 = 1.625.
  candidate.update(1.5, 4);
  assertNotEquals(reference, candidate);
}
/**
 * Builds a small fixed decision tree for tests.
 * Node ids are "r" (root), "r-" (inner node), "r--", "r-+" and "r+" (leaves). Leaves "r--",
 * "r-+" and "r+" predict 0.0, 1.0 and 2.0 respectively, each with count 1. Both decisions are
 * {@code NumericDecision}s on feature 0, constructed with arguments (0, -1.0, false) and
 * (0, 1.0, false).
 */
static DecisionTree buildTestTree() {
  DecisionNode innerNode = new DecisionNode(
      "r-",
      new NumericDecision(0, -1.0, false),
      new TerminalNode("r--", new NumericPrediction(0.0, 1)),
      new TerminalNode("r-+", new NumericPrediction(1.0, 1)));
  TerminalNode rightLeaf = new TerminalNode("r+", new NumericPrediction(2.0, 1));
  DecisionNode rootNode =
      new DecisionNode("r", new NumericDecision(0, 1.0, false), innerNode, rightLeaf);
  return new DecisionTree(rootNode);
}
/**
 * Folds one training example into this prediction.
 * Reads the example's target as a {@link NumericFeature} and delegates to
 * {@code update(double, int)} with that value and a count of 1. A non-numeric target causes a
 * {@link ClassCastException}.
 */
@Override
public void update(Example train) {
  double observedValue = ((NumericFeature) train.getTarget()).getValue();
  update(observedValue, 1);
}
/**
 * Computes the root-mean-squared error of the forest's numeric predictions against each
 * example's numeric target. The mean is computed by Spark's {@code mean()} over the squared
 * errors.
 */
static double rmse(DecisionForest forest, JavaRDD<Example> examples) {
  double meanSquaredError = examples.mapToDouble(example -> {
    double predicted = ((NumericPrediction) forest.predict(example)).getPrediction();
    double actual = ((NumericFeature) example.getTarget()).getValue();
    double error = predicted - actual;
    return error * error;
  }).mean();
  return Math.sqrt(meanSquaredError);
}
/**
 * Combines numeric predictions into one by taking the weighted mean of their values.
 *
 * @param predictions predictions to combine; the i-th prediction is paired with {@code weights[i]}
 * @param weights weight for each prediction; an array shorter than {@code predictions} causes an
 *     {@link ArrayIndexOutOfBoundsException}, and extra trailing weights are ignored
 * @return a {@link NumericPrediction} whose value is the weighted mean and whose count is
 *     {@code DoubleWeightedMean.getN()} cast to int. NOTE(review): getN() presumably counts the
 *     increments (number of votes), not the sum of the individual prediction counts — confirm.
 */
private static Prediction voteOnNumericFeature(List<NumericPrediction> predictions, double[] weights) {
  DoubleWeightedMean weightedMean = new DoubleWeightedMean();
  int weightIndex = 0;
  for (NumericPrediction vote : predictions) {
    weightedMean.increment(vote.getPrediction(), weights[weightIndex++]);
  }
  double combinedValue = weightedMean.getResult();
  int voteCount = (int) weightedMean.getN();
  return new NumericPrediction(combinedValue, voteCount);
}
/**
 * Builds a small {@link RDFServingModel} fixture for tests.
 * <p>
 * The input schema has 3 features. Feature 0 is categorical, with distinct values "A", "B" and
 * "C", and feature 2 is the target. The forest holds two single-split trees weighted 1.0 and 2.0,
 * with feature importances {0.1, 0.3}:
 * <ul>
 *   <li>tree 1 uses {@code CategoricalDecision(0, {1}, true)}; its leaves predict 1.0 ("r-") and
 *       10.0 ("r+");</li>
 *   <li>tree 2 uses {@code NumericDecision(1, -3.0, false)}; its leaves predict 100.0 ("r-") and
 *       1000.0 ("r+").</li>
 * </ul>
 */
public static RDFServingModel buildTestModel() {
  Map<Integer,Collection<String>> distinctValuesByFeature = new HashMap<>();
  distinctValuesByFeature.put(0, Arrays.asList("A", "B", "C"));
  CategoricalValueEncodings valueEncodings =
      new CategoricalValueEncodings(distinctValuesByFeature);

  // Tree 1: categorical split on feature 0 with category encoding 1 active.
  BitSet activeSet = new BitSet(2);
  activeSet.set(1);
  Decision categoricalSplit = new CategoricalDecision(0, activeSet, true);
  TreeNode categoricalRoot = new DecisionNode(
      "r",
      categoricalSplit,
      new TerminalNode("r-", new NumericPrediction(1.0, 1)),
      new TerminalNode("r+", new NumericPrediction(10.0, 1)));

  // Tree 2: numeric split on feature 1.
  Decision numericSplit = new NumericDecision(1, -3.0, false);
  TreeNode numericRoot = new DecisionNode(
      "r",
      numericSplit,
      new TerminalNode("r-", new NumericPrediction(100.0, 1)),
      new TerminalNode("r+", new NumericPrediction(1000.0, 1)));

  DecisionTree[] forestTrees = { new DecisionTree(categoricalRoot), new DecisionTree(numericRoot) };
  double[] treeWeights = { 1.0, 2.0 };
  double[] importances = { 0.1, 0.3 };
  DecisionForest forest = new DecisionForest(forestTrees, treeWeights, importances);

  Map<String,Object> schemaOverrides = new HashMap<>();
  schemaOverrides.put("oryx.input-schema.num-features", 3);
  schemaOverrides.put("oryx.input-schema.categorical-features", "[\"0\"]");
  schemaOverrides.put("oryx.input-schema.target-feature", "\"2\"");
  Config schemaConfig = ConfigUtils.overlayOn(schemaOverrides, ConfigUtils.getDefault());

  return new RDFServingModel(forest, valueEncodings, new InputSchema(schemaConfig));
}
// Reads a replacement mean (element 2) and count (element 3) from `update`. Each element is
// parsed from its toString() form, then both are applied to predictionToUpdate via
// update(double, int).
// NOTE(review): assumes `update` is a List-like message whose elements 2 and 3 are the mean and
// count; a malformed value throws NumberFormatException — confirm against the message producer.
double mean = Double.parseDouble(update.get(2).toString()); int count = Integer.parseInt(update.get(3).toString()); predictionToUpdate.update(mean, count);
/**
 * Checks that an equally weighted vote over numeric predictions produces the plain mean of their
 * values: (1.0 + 3.0 + 6.0) / 3 = 10 / 3. The per-prediction counts (1, 2, 3) do not enter the
 * expected value.
 */
@Test
public void testNumericVote() {
  double[] weights = {1.0, 1.0, 1.0};
  List<NumericPrediction> predictions = Arrays.asList(
      new NumericPrediction(1.0, 1),
      new NumericPrediction(3.0, 2),
      new NumericPrediction(6.0, 3));

  NumericPrediction vote =
      (NumericPrediction) WeightedPrediction.voteOnFeature(predictions, weights);

  assertEquals(FeatureType.NUMERIC, vote.getFeatureType());
  double expectedMean = 10.0 / 3.0;
  assertEquals(expectedMean, vote.getPrediction());
}
/**
 * Checks that updating with a training example folds in its numeric target with count 1:
 * (1.5 * 1 + 2.5 * 1) / 2 = 2.0.
 */
@Test
public void testUpdate() {
  NumericPrediction running = new NumericPrediction(1.5, 1);
  running.update(new Example(NumericFeature.forValue(2.5)));
  double expectedMean = 2.0;
  assertEquals(expectedMean, running.getPrediction());
}
/**
 * Checks that looking up the leaf with id "r-+" in the test tree returns the node whose numeric
 * prediction is 1.0.
 */
@Test
public void testFindByID() {
  DecisionTree tree = buildTestTree();
  TerminalNode leaf = (TerminalNode) tree.findByID("r-+");
  NumericPrediction leafPrediction = (NumericPrediction) leaf.getPrediction();
  assertEquals(1.0, leafPrediction.getPrediction());
}
// Builds a numeric prediction from root's score (parsed from its string form) and its record
// count rounded to the nearest integer.
// NOTE(review): if getRecordCount() returns a double, Math.round yields a long. The (int) cast
// would then silently wrap for counts above Integer.MAX_VALUE — confirm the expected range.
prediction = new NumericPrediction(Double.parseDouble(root.getScore()), (int) Math.round(root.getRecordCount()));
/**
 * Folds one training example into this prediction.
 * Reads the example's target as a {@link NumericFeature} and delegates to
 * {@code update(double, int)} with that value and a count of 1. A non-numeric target causes a
 * {@link ClassCastException}.
 */
@Override
public void update(Example train) {
  double observedValue = ((NumericFeature) train.getTarget()).getValue();
  update(observedValue, 1);
}
/**
 * Checks that a freshly constructed numeric prediction reports the NUMERIC feature type and the
 * value it was built with.
 */
@Test
public void testConstruct() {
  NumericPrediction fresh = new NumericPrediction(1.5, 1);
  FeatureType type = fresh.getFeatureType();
  assertEquals(FeatureType.NUMERIC, type);
  double value = fresh.getPrediction();
  assertEquals(1.5, value);
}