feature = null; } else if (inputSchema.isNumeric(featureIndex)) { feature = NumericFeature.forValue(Double.parseDouble(dataAtIndex)); } else if (inputSchema.isCategorical(featureIndex)) { int encoding = valueEncodings.getValueEncodingMap(featureIndex).get(dataAtIndex);
/**
 * Verifies that the string form of a {@code NumericFeature} created from
 * {@code 1.5} is exactly {@code "1.5"}.
 */
@Test
public void testToString() {
  String rendered = NumericFeature.forValue(1.5).toString();
  assertEquals("1.5", rendered);
}
/**
 * Exercises {@code isPositive} on a decision built as
 * {@code new NumericDecision(0, -3.1, true)}.
 *
 * <p>Expected outcomes: a feature value of -3.5 is not positive, while -3.1, -3.0
 * and 3.1 are positive. An example whose only feature is {@code null} is also
 * expected to be positive.
 */
@Test
public void testDecision() {
  Decision underTest = new NumericDecision(0, -3.1, true);

  // Only this value is expected to be rejected.
  Example rejected = new Example(null, NumericFeature.forValue(-3.5));
  assertFalse(underTest.isPositive(rejected));

  // Checked in the same order as before: -3.1, then -3.0, then 3.1.
  for (double accepted : new double[] {-3.1, -3.0, 3.1}) {
    assertTrue(underTest.isPositive(new Example(null, NumericFeature.forValue(accepted))));
  }

  // A missing (null) feature must still be positive.
  // NOTE(review): presumably driven by the constructor's final 'true' argument -- confirm.
  Example missingFeature = new Example(null, new Feature[] {null});
  assertTrue(underTest.isPositive(missingFeature));
}
/**
 * Checks {@code ExampleUtils.dataToExample} against a 5-feature input schema in
 * which feature 0 is configured as an ID feature and feature 4 is configured as
 * both categorical and the target.
 *
 * <p>For the input {@code {"foo", "1", "2.5", "-3.2", "B"}} with encodings built
 * from the values {@code A, B, C} for feature 4, the test expects:
 * <ul>
 *   <li>the target to equal {@code CategoricalFeature.forEncoding(1)};</li>
 *   <li>features 0 and 4 to be {@code null};</li>
 *   <li>features 1 to 3 to be numeric features 1.0, 2.5 and -3.2;</li>
 *   <li>the example's string form to contain {@code "2.5"}.</li>
 * </ul>
 */
@Test
public void testToExample() {
  // Schema settings layered over the default configuration.
  Map<String,Object> schemaOverrides = new HashMap<>();
  schemaOverrides.put("oryx.input-schema.num-features", 5);
  schemaOverrides.put("oryx.input-schema.categorical-features", "[\"4\"]");
  schemaOverrides.put("oryx.input-schema.id-features", "[\"0\"]");
  schemaOverrides.put("oryx.input-schema.target-feature", "\"4\"");
  InputSchema schema =
      new InputSchema(ConfigUtils.overlayOn(schemaOverrides, ConfigUtils.getDefault()));

  CategoricalValueEncodings encodings = new CategoricalValueEncodings(
      Collections.singletonMap(4, Arrays.asList("A", "B", "C")));

  String[] rawData = {"foo", "1", "2.5", "-3.2", "B"};
  Example converted = ExampleUtils.dataToExample(rawData, schema, encodings);

  assertEquals(CategoricalFeature.forEncoding(1), converted.getTarget());
  assertNull(converted.getFeature(0));

  // Features 1..3 hold the parsed numeric values, in input order.
  double[] expectedNumeric = {1.0, 2.5, -3.2};
  for (int offset = 0; offset < expectedNumeric.length; offset++) {
    assertEquals(NumericFeature.forValue(expectedNumeric[offset]),
                 converted.getFeature(offset + 1));
  }

  assertNull(converted.getFeature(4));
  assertTrue(converted.toString().contains("2.5"));
}
/**
 * Verifies that the forest returned by {@code buildTestForest()} predicts
 * {@code 1.0} for an example with no target and a single numeric feature of 0.5.
 */
@Test
public void testPredict() {
  DecisionForest testForest = buildTestForest();
  Example query = new Example(null, NumericFeature.forValue(0.5));
  // The test forest is expected to yield a numeric prediction, hence the cast.
  NumericPrediction result = (NumericPrediction) testForest.predict(query);
  assertEquals(1.0, result.getPrediction());
}
/**
 * Verifies that the tree returned by {@code buildTestTree()} predicts
 * {@code 1.0} for an example with no target and a single numeric feature of 0.5.
 */
@Test
public void testPredict() {
  DecisionTree testTree = buildTestTree();
  Example query = new Example(null, NumericFeature.forValue(0.5));
  // The test tree is expected to yield a numeric prediction, hence the cast.
  NumericPrediction result = (NumericPrediction) testTree.predict(query);
  assertEquals(1.0, result.getPrediction());
}
/**
 * Verifies that updating a {@code NumericPrediction} created as
 * {@code new NumericPrediction(1.5, 1)} with an example built from the numeric
 * feature 2.5 results in a prediction of {@code 2.0}.
 */
@Test
public void testUpdate() {
  NumericPrediction running = new NumericPrediction(1.5, 1);
  // NOTE(review): the single constructor argument is presumably the example's
  // target, and 2.0 the mean of 1.5 and 2.5 -- confirm against Example/NumericPrediction.
  running.update(new Example(NumericFeature.forValue(2.5)));
  assertEquals(2.0, running.getPrediction());
}
/**
 * Verifies that {@code findTerminal} on the tree from {@code buildTestTree()},
 * given an example with no target and a single numeric feature of 0.5, returns
 * a terminal node whose prediction is {@code 1.0}.
 */
@Test
public void testFindTerminal() {
  DecisionTree testTree = buildTestTree();
  Example query = new Example(null, NumericFeature.forValue(0.5));
  TerminalNode leaf = testTree.findTerminal(query);
  // The node's prediction is expected to be numeric, hence the cast.
  double predicted = ((NumericPrediction) leaf.getPrediction()).getPrediction();
  assertEquals(1.0, predicted);
}
feature = null; } else if (inputSchema.isNumeric(featureIndex)) { feature = NumericFeature.forValue(Double.parseDouble(dataAtIndex)); } else if (inputSchema.isCategorical(featureIndex)) { int encoding = valueEncodings.getValueEncodingMap(featureIndex).get(dataAtIndex);