/**
 * Registers a single observation of the category found in the given example's target.
 *
 * <p>The target is cast to {@link CategoricalFeature}; its encoding is then passed to
 * {@code update(encoding, 1)}, i.e. a count of one for that category.
 *
 * @param train example whose target supplies the category encoding
 * @throws ClassCastException if the example's target is not a {@link CategoricalFeature}
 */
@Override
public void update(Example train) {
  update(((CategoricalFeature) train.getTarget()).getEncoding(), 1);
}
@SuppressWarnings("unchecked") Map<String,Integer> counts = (Map<String,Integer>) update.get(2); // JSON map keys are always Strings counts.forEach((encoding, count) -> predictionToUpdate.update(Integer.parseInt(encoding), count)); } else { TerminalNode nodeToUpdate = (TerminalNode) forest.getTrees()[treeID].findByID(nodeID);
/**
 * Checks that {@code update(encoding, count)} adds {@code count} to the selected category
 * and that the reported category probabilities follow the new totals.
 */
@Test
public void testUpdate2() {
  CategoricalPrediction prediction = new CategoricalPrediction(new int[] { 0, 1, 3, 0, 4, 0 });

  // Category 0 gains 3 and category 1 gains 9, for a grand total of 20.
  prediction.update(0, 3);
  prediction.update(1, 9);

  double[] expectedCounts = { 3, 10, 3, 0, 4, 0 };
  double[] expectedProbabilities = {0.15, 0.5, 0.15, 0.0, 0.2, 0.0};
  assertArrayEquals(expectedCounts, prediction.getCategoryCounts());
  assertArrayEquals(expectedProbabilities, prediction.getCategoryProbabilities());
}
@Test public void testUpdate() { int[] counts = { 0, 1, 3, 0, 4, 0 }; CategoricalPrediction prediction = new CategoricalPrediction(counts); Example example = new Example(CategoricalFeature.forEncoding(2)); // Yes, called twice prediction.update(example); prediction.update(example); assertEquals(2, prediction.getMostProbableCategoryEncoding()); counts[2] += 2; assertArrayEquals(toDoubles(counts), prediction.getCategoryCounts()); assertArrayEquals(new double[] {0.0, 0.1, 0.5, 0.0, 0.4, 0.0}, prediction.getCategoryProbabilities()); }
/**
 * Registers a single observation of the category found in the given example's target.
 *
 * <p>The target is cast to {@link CategoricalFeature}; its encoding is then passed to
 * {@code update(encoding, 1)}, i.e. a count of one for that category.
 *
 * @param train example whose target supplies the category encoding
 * @throws ClassCastException if the example's target is not a {@link CategoricalFeature}
 */
@Override
public void update(Example train) {
  update(((CategoricalFeature) train.getTarget()).getEncoding(), 1);
}