/**
 * Creates a {@code NumericFeature} for a value, handing back the shared {@code ZERO}
 * instance instead of allocating when the value is zero.
 *
 * @param value value to represent as a feature
 * @return {@code NumericFeature} representing the given numeric value
 */
public static NumericFeature forValue(double value) {
  if (value == 0.0) {
    // -0.0 == 0.0 in Java, so negative zero also maps to the shared instance
    return ZERO;
  }
  return new NumericFeature(value);
}
/**
 * Tests the example's numeric feature against this decision's threshold.
 *
 * @param example example whose feature at {@code getFeatureNumber()} is inspected
 * @return {@code defaultDecision} if that feature is missing ({@code null}); otherwise
 *     whether its value is at least {@code threshold} (a NaN value compares false)
 */
@Override
public boolean isPositive(Example example) {
  NumericFeature numeric = (NumericFeature) example.getFeature(getFeatureNumber());
  if (numeric == null) {
    return defaultDecision;
  }
  return numeric.getValue() >= threshold;
}
@Test
public void testDecision() {
  Decision decision = new NumericDecision(0, -3.1, true);
  // Strictly below the threshold: negative
  assertFalse(decision.isPositive(new Example(null, NumericFeature.forValue(-3.5))));
  // Equal to or above the threshold: positive
  for (double value : new double[] {-3.1, -3.0, 3.1}) {
    assertTrue(decision.isPositive(new Example(null, NumericFeature.forValue(value))));
  }
  // Missing feature: falls back to the default decision (true here)
  assertTrue(decision.isPositive(new Example(null, new Feature[] {null})));
}
@Test public void testFeature() { CategoricalFeature f = CategoricalFeature.forEncoding(1); assertEquals(FeatureType.CATEGORICAL, f.getFeatureType()); assertEquals(1, f.getEncoding()); assertEquals(f, CategoricalFeature.forEncoding(1)); // Not necessary for correctness to assert this, but fairly important for performance assertSame(f, CategoricalFeature.forEncoding(1)); }
@Test
public void testFeature() {
  NumericFeature feature = NumericFeature.forValue(1.5);
  assertEquals(FeatureType.NUMERIC, feature.getFeatureType());
  assertEquals(1.5, feature.getValue());
  // Features built from the same value are equal
  assertEquals(feature, NumericFeature.forValue(1.5));
  // ...but a NaN-valued feature is not equal to an ordinary one
  assertNotEquals(feature, NumericFeature.forValue(Double.NaN));
}
/**
 * Folds a training example into this prediction by delegating to {@code update(int, int)}
 * with the example's target category encoding and a count of 1.
 *
 * @param train training example whose target is expected to be a {@code CategoricalFeature}
 */
@Override
public void update(Example train) {
  int encoding = ((CategoricalFeature) train.getTarget()).getEncoding();
  update(encoding, 1);
}
@Test
public void testDecision() {
  BitSet active = new BitSet(10);
  active.set(2);
  active.set(5);
  Decision decision = new CategoricalDecision(0, active, true);
  // Within 0..9, an encoding is positive exactly when its bit is set
  for (int encoding = 0; encoding < 10; encoding++) {
    Example example = new Example(null, CategoricalFeature.forEncoding(encoding));
    assertEquals(active.get(encoding), decision.isPositive(example));
  }
  // A missing feature uses the default decision (true here)
  assertTrue(decision.isPositive(new Example(null, new Feature[] {null})));
}
/**
 * Tests whether the example's categorical feature is one of the active categories.
 *
 * @param example example whose feature at {@code getFeatureNumber()} is inspected
 * @return the bit for the feature's encoding in {@code activeCategoryEncodings}, or
 *     {@code defaultDecision} when the feature is missing or its encoding is not below
 *     {@code activeCategoryEncodings.size()}
 */
@Override
public boolean isPositive(Example example) {
  CategoricalFeature categorical = (CategoricalFeature) example.getFeature(getFeatureNumber());
  if (categorical != null) {
    int encoding = categorical.getEncoding();
    // NOTE(review): assuming activeCategoryEncodings is a java.util.BitSet, size() is its bit
    // capacity (not length()), so encodings between length() and size() yield false rather
    // than defaultDecision. Confirm this cutoff is intended.
    if (encoding < activeCategoryEncodings.size()) {
      return activeCategoryEncodings.get(encoding);
    }
  }
  return defaultDecision;
}
@Test
public void testToString() {
  // A numeric feature renders as just its value
  assertEquals("1.5", NumericFeature.forValue(1.5).toString());
}
@Test
public void testToString() {
  // A categorical feature renders as ':' followed by its encoding
  assertEquals(":1", CategoricalFeature.forEncoding(1).toString());
}
/**
 * Checks that numeric features built from the same value are equal and, as the
 * {@link Object#hashCode()} contract requires, have equal hash codes. The previous version
 * only asserted {@code equals}, so hash-code consistency was never actually tested.
 */
@Test
public void testHashCode() {
  for (double value : new double[] {1.5, Double.MIN_VALUE}) {
    NumericFeature first = NumericFeature.forValue(value);
    NumericFeature second = NumericFeature.forValue(value);
    assertEquals(first, second);
    assertEquals(first.hashCode(), second.hashCode());
  }
}
/** * @param encoding category value ID to create {@code CategoricalFeature} for * @return {@code CategoricalFeature} representing the category value specified by ID */ public static CategoricalFeature forEncoding(int encoding) { Preconditions.checkArgument(encoding >= 0); // Not important if several threads get here return FEATURE_CACHE.computeIfAbsent(encoding, k -> new CategoricalFeature(encoding)); }
/**
 * Computes the root-mean-squared error of the forest's numeric predictions.
 *
 * @param forest model whose predictions are evaluated
 * @param examples examples whose targets are {@code NumericFeature}s
 * @return square root of the mean of (prediction - target)^2 over {@code examples}
 */
static double rmse(DecisionForest forest, JavaRDD<Example> examples) {
  double meanSquaredError = examples.mapToDouble(example -> {
    NumericPrediction predicted = (NumericPrediction) forest.predict(example);
    NumericFeature actual = (NumericFeature) example.getTarget();
    double error = predicted.getPrediction() - actual.getValue();
    return error * error;
  }).mean();
  return Math.sqrt(meanSquaredError);
}
/**
 * Computes the fraction of examples whose most probable predicted category equals the
 * target category.
 *
 * @param forest model whose predictions are evaluated
 * @param examples examples whose targets are {@code CategoricalFeature}s
 * @return correct / total, or 0.0 when {@code examples} is empty
 */
static double accuracy(DecisionForest forest, JavaRDD<Example> examples) {
  long total = examples.count();
  if (total == 0) {
    // Avoid dividing by zero on an empty dataset
    return 0.0;
  }
  long correct = examples.filter(example -> {
    CategoricalPrediction predicted = (CategoricalPrediction) forest.predict(example);
    CategoricalFeature actual = (CategoricalFeature) example.getTarget();
    return predicted.getMostProbableCategoryEncoding() == actual.getEncoding();
  }).count();
  return correct / (double) total;
}
@Test public void testUpdate() { int[] counts = { 0, 1, 3, 0, 4, 0 }; CategoricalPrediction prediction = new CategoricalPrediction(counts); Example example = new Example(CategoricalFeature.forEncoding(2)); // Yes, called twice prediction.update(example); prediction.update(example); assertEquals(2, prediction.getMostProbableCategoryEncoding()); counts[2] += 2; assertArrayEquals(toDoubles(counts), prediction.getCategoryCounts()); assertArrayEquals(new double[] {0.0, 0.1, 0.5, 0.0, 0.4, 0.0}, prediction.getCategoryProbabilities()); }