/**
 * Asks every tree in the ensemble for its prediction on {@code test} and merges the results.
 *
 * @param test example to predict
 * @return the weighted combination of the per-tree predictions; {@code weights} must hold one
 *  entry per tree, as {@link WeightedPrediction#voteOnFeature} checks the sizes match
 */
@Override
public Prediction predict(Example test) {
  return WeightedPrediction.voteOnFeature(
      Arrays.stream(trees)
          .collect(Collectors.mapping(tree -> tree.predict(test), Collectors.toList())),
      weights);
}
/**
 * Combines several {@link Prediction}s of the same feature into one weighted prediction.
 *
 * @param predictions {@link Prediction}s from individuals; must be non-empty, and every element
 *  must have the same feature type as the first
 * @param weights weights to apply to the corresponding element of {@code predictions}; must have
 *  exactly one entry per prediction
 * @param <T> type of {@link Prediction} to vote on
 * @return a single {@link Prediction} representing a weighted combination of the inputs
 * @throws IllegalArgumentException if {@code predictions} is empty, if its size differs from
 *  {@code weights.length}, or if its elements do not all share one feature type
 * @throws IllegalStateException if the shared feature type is neither numeric nor categorical
 */
public static <T extends Prediction> Prediction voteOnFeature(List<T> predictions,
                                                              double[] weights) {
  Preconditions.checkArgument(!predictions.isEmpty(), "No predictions");
  Preconditions.checkArgument(predictions.size() == weights.length,
                              "%s predictions but %s weights?",
                              predictions.size(), weights.length);
  T first = predictions.get(0);
  // The unchecked list casts below are only sound if every element has the same feature type
  // as the first one; verify that here instead of failing later with a ClassCastException.
  for (T prediction : predictions) {
    Preconditions.checkArgument(prediction.getFeatureType() == first.getFeatureType(),
                                "Mixed feature types: %s and %s",
                                first.getFeatureType(), prediction.getFeatureType());
  }
  switch (first.getFeatureType()) {
    case NUMERIC:
      @SuppressWarnings("unchecked")
      List<NumericPrediction> numericVotes = (List<NumericPrediction>) predictions;
      return voteOnNumericFeature(numericVotes, weights);
    case CATEGORICAL:
      @SuppressWarnings("unchecked")
      List<CategoricalPrediction> categoricalVotes = (List<CategoricalPrediction>) predictions;
      return voteOnCategoricalFeature(categoricalVotes, weights);
    default:
      throw new IllegalStateException("Unsupported feature type: " + first.getFeatureType());
  }
}
/**
 * A heavily weighted voter should decide the categorical outcome: the second voter has weight
 * 10 and its counts peak at category 0, so category 0 is expected to win.
 */
@Test
public void testCategoricalVoteWeighted() {
  CategoricalPrediction firstVoter = new CategoricalPrediction(new int[] {0, 1, 2});
  CategoricalPrediction heavyVoter = new CategoricalPrediction(new int[] {6, 2, 0});
  CategoricalPrediction thirdVoter = new CategoricalPrediction(new int[] {0, 2, 0});
  List<CategoricalPrediction> voters = Arrays.asList(firstVoter, heavyVoter, thirdVoter);

  double[] skewedWeights = {1.0, 10.0, 1.0};
  CategoricalPrediction combined =
      (CategoricalPrediction) WeightedPrediction.voteOnFeature(voters, skewedWeights);

  assertEquals(FeatureType.CATEGORICAL, combined.getFeatureType());
  assertEquals(0, combined.getMostProbableCategoryEncoding());
}
/**
 * Combines several {@link Prediction}s of the same feature into one weighted prediction.
 *
 * @param predictions {@link Prediction}s from individuals; must be non-empty, and every element
 *  must have the same feature type as the first
 * @param weights weights to apply to the corresponding element of {@code predictions}; must have
 *  exactly one entry per prediction
 * @param <T> type of {@link Prediction} to vote on
 * @return a single {@link Prediction} representing a weighted combination of the inputs
 * @throws IllegalArgumentException if {@code predictions} is empty, if its size differs from
 *  {@code weights.length}, or if its elements do not all share one feature type
 * @throws IllegalStateException if the shared feature type is neither numeric nor categorical
 */
public static <T extends Prediction> Prediction voteOnFeature(List<T> predictions,
                                                              double[] weights) {
  Preconditions.checkArgument(!predictions.isEmpty(), "No predictions");
  Preconditions.checkArgument(predictions.size() == weights.length,
                              "%s predictions but %s weights?",
                              predictions.size(), weights.length);
  T first = predictions.get(0);
  // The unchecked list casts below are only sound if every element has the same feature type
  // as the first one; verify that here instead of failing later with a ClassCastException.
  for (T prediction : predictions) {
    Preconditions.checkArgument(prediction.getFeatureType() == first.getFeatureType(),
                                "Mixed feature types: %s and %s",
                                first.getFeatureType(), prediction.getFeatureType());
  }
  switch (first.getFeatureType()) {
    case NUMERIC:
      @SuppressWarnings("unchecked")
      List<NumericPrediction> numericVotes = (List<NumericPrediction>) predictions;
      return voteOnNumericFeature(numericVotes, weights);
    case CATEGORICAL:
      @SuppressWarnings("unchecked")
      List<CategoricalPrediction> categoricalVotes = (List<CategoricalPrediction>) predictions;
      return voteOnCategoricalFeature(categoricalVotes, weights);
    default:
      throw new IllegalStateException("Unsupported feature type: " + first.getFeatureType());
  }
}
/**
 * Numeric votes combine as a weighted mean of the predicted values.
 */
@Test
public void testNumericVoteWeighted() {
  List<NumericPrediction> predictions = Arrays.asList(
      new NumericPrediction(1.0, 1),
      new NumericPrediction(3.0, 2),
      new NumericPrediction(6.0, 3));
  double[] weights = {3.0, 2.0, 1.0};
  NumericPrediction vote =
      (NumericPrediction) WeightedPrediction.voteOnFeature(predictions, weights);
  assertEquals(FeatureType.NUMERIC, vote.getFeatureType());
  // Expected: (3*1.0 + 2*3.0 + 1*6.0) / (3 + 2 + 1) = 15/6.
  // Compare with a tolerance: exact equality on computed doubles is fragile, and in JUnit 4
  // the two-argument double overload is deprecated and always fails.
  assertEquals(15.0 / 6.0, vote.getPrediction(), 1.0e-12);
}
/**
 * With equal weights, the numeric vote is the plain mean of the predicted values.
 */
@Test
public void testNumericVote() {
  List<NumericPrediction> predictions = Arrays.asList(
      new NumericPrediction(1.0, 1),
      new NumericPrediction(3.0, 2),
      new NumericPrediction(6.0, 3));
  double[] weights = {1.0, 1.0, 1.0};
  NumericPrediction vote =
      (NumericPrediction) WeightedPrediction.voteOnFeature(predictions, weights);
  assertEquals(FeatureType.NUMERIC, vote.getFeatureType());
  // Expected: (1.0 + 3.0 + 6.0) / 3. 10/3 is not exactly representable as a double, so
  // compare with a tolerance instead of exact equality.
  assertEquals(10.0 / 3.0, vote.getPrediction(), 1.0e-12);
}
/**
 * With equal weights, category 1 is expected to win overall, even though the second voter's
 * counts peak at category 0.
 */
@Test
public void testCategoricalVote() {
  double[] equalWeights = {1.0, 1.0, 1.0};
  CategoricalPrediction combined = (CategoricalPrediction) WeightedPrediction.voteOnFeature(
      Arrays.asList(
          new CategoricalPrediction(new int[] {0, 1, 2}),
          new CategoricalPrediction(new int[] {6, 2, 0}),
          new CategoricalPrediction(new int[] {0, 2, 0})),
      equalWeights);

  assertEquals(FeatureType.CATEGORICAL, combined.getFeatureType());
  assertEquals(1, combined.getMostProbableCategoryEncoding());
}
/**
 * Gathers one prediction per tree, in tree order, and merges them with the ensemble weights.
 *
 * @param test example to predict
 * @return the weighted combination of the per-tree predictions; {@code weights} must hold one
 *  entry per tree, as {@link WeightedPrediction#voteOnFeature} checks the sizes match
 */
@Override
public Prediction predict(Example test) {
  Prediction[] perTree = new Prediction[trees.length];
  for (int i = 0; i < trees.length; i++) {
    perTree[i] = trees[i].predict(test);
  }
  return WeightedPrediction.voteOnFeature(Arrays.asList(perTree), weights);
}