/**
 * Produces a prediction for one example and renders it as a string.
 *
 * <p>When the input schema describes a classification problem, the result is
 * whatever the target feature's encoding-to-value map holds for the most
 * probable category encoding. Otherwise the numeric prediction is formatted
 * with {@link Double#toString(double)}.
 *
 * @param example raw feature values, passed unchanged to {@code makePrediction}
 * @return the decoded target value for classification, or the numeric score
 *  as a string for regression
 */
@Override
public String predict(String[] example) {
  Prediction result = makePrediction(example);

  if (!inputSchema.isClassification()) {
    // Regression: the prediction is a plain number.
    NumericPrediction numeric = (NumericPrediction) result;
    return Double.toString(numeric.getPrediction());
  }

  // Classification: translate the winning encoding back to its original value.
  Map<Integer,String> valueByEncoding =
      encodings.getEncodingValueMap(inputSchema.getTargetFeatureIndex());
  CategoricalPrediction categorical = (CategoricalPrediction) result;
  return valueByEncoding.get(categorical.getMostProbableCategoryEncoding());
}
/**
 * Measures how often the forest's most probable category matches the
 * example's target encoding.
 *
 * @param forest model used to predict each example
 * @param examples examples whose targets are cast to {@code CategoricalFeature}
 * @return the number of matching examples divided by the total number of
 *  examples, or {@code 0.0} when there are no examples
 */
static double accuracy(DecisionForest forest, JavaRDD<Example> examples) {
  long numExamples = examples.count();
  if (numExamples != 0) {
    long numCorrect = examples.filter(candidate -> {
      CategoricalPrediction guess = (CategoricalPrediction) forest.predict(candidate);
      CategoricalFeature truth = (CategoricalFeature) candidate.getTarget();
      return guess.getMostProbableCategoryEncoding() == truth.getEncoding();
    }).count();
    return (double) numCorrect / numExamples;
  }
  // Nothing to evaluate; avoid dividing by zero.
  return 0.0;
}
// The expected label is positive only when f1, f2 and f3 are all equal to 1.
// The assertion looks up the encoding stored under "true" / "false" in
// targetEncoding and requires the prediction's most probable category
// encoding to equal it.
boolean expectedPositive = f1 == 1 && f2 == 1 && f3 == 1; assertEquals(targetEncoding.get(Boolean.toString(expectedPositive)).intValue(), prediction.getMostProbableCategoryEncoding());
/**
 * Votes over three categorical predictions where the second one carries
 * weight 10 and the others weight 1, and expects a categorical result whose
 * most probable category encoding is 0.
 */
@Test
public void testCategoricalVoteWeighted() {
  CategoricalPrediction first = new CategoricalPrediction(new int[]{0, 1, 2});
  CategoricalPrediction heavilyWeighted = new CategoricalPrediction(new int[]{6, 2, 0});
  CategoricalPrediction third = new CategoricalPrediction(new int[]{0, 2, 0});
  List<CategoricalPrediction> voters = Arrays.asList(first, heavilyWeighted, third);

  // Weights line up positionally with the voters above.
  double[] voterWeights = {1.0, 10.0, 1.0};

  CategoricalPrediction outcome =
      (CategoricalPrediction) WeightedPrediction.voteOnFeature(voters, voterWeights);

  assertEquals(FeatureType.CATEGORICAL, outcome.getFeatureType());
  assertEquals(0, outcome.getMostProbableCategoryEncoding());
}
/**
 * Votes over three categorical predictions with equal weights and expects a
 * categorical result whose most probable category encoding is 1.
 */
@Test
public void testCategoricalVote() {
  CategoricalPrediction first = new CategoricalPrediction(new int[]{0, 1, 2});
  CategoricalPrediction second = new CategoricalPrediction(new int[]{6, 2, 0});
  CategoricalPrediction third = new CategoricalPrediction(new int[]{0, 2, 0});
  List<CategoricalPrediction> voters = Arrays.asList(first, second, third);

  // Every voter counts the same.
  double[] equalWeights = {1.0, 1.0, 1.0};

  CategoricalPrediction outcome =
      (CategoricalPrediction) WeightedPrediction.voteOnFeature(voters, equalWeights);

  assertEquals(FeatureType.CATEGORICAL, outcome.getFeatureType());
  assertEquals(1, outcome.getMostProbableCategoryEncoding());
}
/**
 * Builds a prediction directly from a probability array and checks its
 * feature type, that index 4 (probability 0.5) is reported as most probable,
 * and that the same probabilities are returned.
 */
@Test
public void testConstructFromProbability() {
  double[] givenProbabilities = {0.0, 0.125, 0.375, 0.0, 0.5, 0.0 };
  CategoricalPrediction fromProbabilities = new CategoricalPrediction(givenProbabilities);

  assertEquals(FeatureType.CATEGORICAL, fromProbabilities.getFeatureType());
  // 0.5 at index 4 is the largest entry.
  assertEquals(4, fromProbabilities.getMostProbableCategoryEncoding());
  assertArrayEquals(givenProbabilities, fromProbabilities.getCategoryProbabilities());
}
/**
 * Builds a prediction from integer category counts and checks its feature
 * type, the most probable category (index 4, count 4), the reported counts,
 * and the reported probabilities (each count divided by the total of 8).
 */
@Test
public void testConstruct() {
  int[] givenCounts = { 0, 1, 3, 0, 4, 0 };
  double[] expectedProbabilities = {0.0, 0.125, 0.375, 0.0, 0.5, 0.0};

  CategoricalPrediction fromCounts = new CategoricalPrediction(givenCounts);

  assertEquals(FeatureType.CATEGORICAL, fromCounts.getFeatureType());
  assertEquals(4, fromCounts.getMostProbableCategoryEncoding());
  assertArrayEquals(toDoubles(givenCounts), fromCounts.getCategoryCounts());
  assertArrayEquals(expectedProbabilities, fromCounts.getCategoryProbabilities());
}
@Test public void testUpdate() { int[] counts = { 0, 1, 3, 0, 4, 0 }; CategoricalPrediction prediction = new CategoricalPrediction(counts); Example example = new Example(CategoricalFeature.forEncoding(2)); // Yes, called twice prediction.update(example); prediction.update(example); assertEquals(2, prediction.getMostProbableCategoryEncoding()); counts[2] += 2; assertArrayEquals(toDoubles(counts), prediction.getCategoryCounts()); assertArrayEquals(new double[] {0.0, 0.1, 0.5, 0.0, 0.4, 0.0}, prediction.getCategoryProbabilities()); }