/**
 * Checks that {@code update(encoding, count)} adds {@code count} to a single category and that
 * the probabilities are recomputed from the new totals.
 */
@Test
public void testUpdate2() {
  int[] initialCounts = {0, 1, 3, 0, 4, 0};
  CategoricalPrediction prediction = new CategoricalPrediction(initialCounts);

  // Category 0 goes 0 -> 3 and category 1 goes 1 -> 10, for 20 observations in total.
  prediction.update(0, 3);
  prediction.update(1, 9);

  double[] expectedCounts = {3, 10, 3, 0, 4, 0};
  // Each probability is its count divided by the new total of 20.
  double[] expectedProbabilities = {0.15, 0.5, 0.15, 0.0, 0.2, 0.0};
  assertArrayEquals(expectedCounts, prediction.getCategoryCounts());
  assertArrayEquals(expectedProbabilities, prediction.getCategoryProbabilities());
}
/**
 * Checks the state of a freshly constructed prediction: its feature type, its most probable
 * category, the stored counts, and the probabilities derived from them.
 */
@Test
public void testConstruct() {
  int[] initialCounts = {0, 1, 3, 0, 4, 0};
  CategoricalPrediction prediction = new CategoricalPrediction(initialCounts);

  assertEquals(FeatureType.CATEGORICAL, prediction.getFeatureType());
  // Category 4 holds the largest count (4 of the 8 observations).
  assertEquals(4, prediction.getMostProbableCategoryEncoding());
  assertArrayEquals(toDoubles(initialCounts), prediction.getCategoryCounts());

  // Each probability is its count divided by the total of 8.
  double[] expectedProbabilities = {0.0, 0.125, 0.375, 0.0, 0.5, 0.0};
  assertArrayEquals(expectedProbabilities, prediction.getCategoryProbabilities());
}
assertEquals(2, leftPrediction.getCategoryCounts()[0]); assertEquals(5, leftPrediction.getCategoryCounts()[1]); assertEquals(3, rightPrediction.getCategoryCounts()[0]); assertEquals(4, rightPrediction.getCategoryCounts()[1]);
@Test public void testUpdate() { int[] counts = { 0, 1, 3, 0, 4, 0 }; CategoricalPrediction prediction = new CategoricalPrediction(counts); Example example = new Example(CategoricalFeature.forEncoding(2)); // Yes, called twice prediction.update(example); prediction.update(example); assertEquals(2, prediction.getMostProbableCategoryEncoding()); counts[2] += 2; assertArrayEquals(toDoubles(counts), prediction.getCategoryCounts()); assertArrayEquals(new double[] {0.0, 0.1, 0.5, 0.0, 0.4, 0.0}, prediction.getCategoryProbabilities()); }