return new Example(target, features); } catch (NumberFormatException | ArrayIndexOutOfBoundsException e) { log.warn("Bad input: {}", Arrays.toString(data));
/**
 * Checks a CategoricalDecision over encodings 0-9 whose active set is {2, 5}.
 * Each encoding must be positive exactly when its bit is set. An example whose
 * single feature is null is expected to be positive (presumably the constructor's
 * third argument, {@code true}, is the default for missing values; confirm).
 */
@Test
public void testDecision() {
  BitSet positiveCategories = new BitSet(10);
  for (int category : new int[] {2, 5}) {
    positiveCategories.set(category);
  }
  Decision categoricalDecision = new CategoricalDecision(0, positiveCategories, true);

  int encoding = 0;
  while (encoding < 10) {
    boolean expectedPositive = positiveCategories.get(encoding);
    Example singleFeature = new Example(null, CategoricalFeature.forEncoding(encoding));
    assertEquals(expectedPositive, categoricalDecision.isPositive(singleFeature));
    encoding++;
  }

  // A missing (null) feature value.
  Feature[] missingValue = {null};
  assertTrue(categoricalDecision.isPositive(new Example(null, missingValue)));
}
/**
 * Checks a NumericDecision with threshold -3.1. Values below the threshold are
 * negative, while the threshold itself and larger values are positive. A null
 * feature is expected to be positive (presumably the constructor's third
 * argument, {@code true}, is the missing-value default; confirm).
 */
@Test
public void testDecision() {
  Decision numericDecision = new NumericDecision(0, -3.1, true);

  // Strictly below the threshold.
  assertFalse(numericDecision.isPositive(new Example(null, NumericFeature.forValue(-3.5))));

  // Equal to, just above, and well above the threshold.
  for (double value : new double[] {-3.1, -3.0, 3.1}) {
    assertTrue(numericDecision.isPositive(new Example(null, NumericFeature.forValue(value))));
  }

  // A missing (null) feature value.
  Feature[] missingValue = {null};
  assertTrue(numericDecision.isPositive(new Example(null, missingValue)));
}
for (int f3 : zeroOne) { CategoricalPrediction prediction = (CategoricalPrediction) forest.predict(new Example(null, null, NumericFeature.forValue(f1),
CategoricalFeature feature3 = CategoricalFeature.forEncoding( encoding.getValueEncodingMap(3).get(f3 == 1 ? "A" : "B")); Example toPredict = new Example(null, null, feature1, feature2, feature3); double prediction = ((NumericPrediction) forest.predict(toPredict)).getPrediction(); assertEquals("Incorrect prediction " + prediction + " for " + toPredict,
/**
 * Predicting a single numeric feature of 0.5 with the forest from
 * {@code buildTestForest()} should yield a numeric prediction of 1.0.
 */
@Test
public void testPredict() {
  DecisionForest testForest = buildTestForest();
  Example input = new Example(null, NumericFeature.forValue(0.5));
  NumericPrediction forestPrediction = (NumericPrediction) testForest.predict(input);
  assertEquals(1.0, forestPrediction.getPrediction());
}
/**
 * Predicting a single numeric feature of 0.5 with the tree from
 * {@code buildTestTree()} should yield a numeric prediction of 1.0.
 */
@Test
public void testPredict() {
  DecisionTree testTree = buildTestTree();
  Example input = new Example(null, NumericFeature.forValue(0.5));
  NumericPrediction treePrediction = (NumericPrediction) testTree.predict(input);
  assertEquals(1.0, treePrediction.getPrediction());
}
/**
 * Updating a prediction of 1.5 (second constructor argument 1, presumably the
 * example count) with an example built from the value 2.5 is expected to give 2.0,
 * the mean of the two values.
 */
@Test
public void testUpdate() {
  NumericPrediction runningPrediction = new NumericPrediction(1.5, 1);
  // Single-argument Example: presumably only the target is set; confirm.
  Example observation = new Example(NumericFeature.forValue(2.5));
  runningPrediction.update(observation);
  assertEquals(2.0, runningPrediction.getPrediction());
}
/**
 * Routing a single numeric feature of 0.5 through the tree from
 * {@code buildTestTree()} should reach a terminal node whose numeric
 * prediction is 1.0.
 */
@Test
public void testFindTerminal() {
  DecisionTree testTree = buildTestTree();
  Example input = new Example(null, NumericFeature.forValue(0.5));
  TerminalNode terminal = testTree.findTerminal(input);
  NumericPrediction leafPrediction = (NumericPrediction) terminal.getPrediction();
  assertEquals(1.0, leafPrediction.getPrediction());
}
@Test public void testUpdate() { int[] counts = { 0, 1, 3, 0, 4, 0 }; CategoricalPrediction prediction = new CategoricalPrediction(counts); Example example = new Example(CategoricalFeature.forEncoding(2)); // Yes, called twice prediction.update(example); prediction.update(example); assertEquals(2, prediction.getMostProbableCategoryEncoding()); counts[2] += 2; assertArrayEquals(toDoubles(counts), prediction.getCategoryCounts()); assertArrayEquals(new double[] {0.0, 0.1, 0.5, 0.0, 0.4, 0.0}, prediction.getCategoryProbabilities()); }
return new Example(target, features); } catch (NumberFormatException | ArrayIndexOutOfBoundsException e) { log.warn("Bad input: {}", Arrays.toString(data));