/**
 * Scores one raw input record with the decision forest.
 *
 * @param example raw feature values for a single record; its length must equal
 *     {@code inputSchema.getNumFeatures()}
 * @return the forest's prediction for the record after it is converted with
 *     {@code ExampleUtils.dataToExample} using this object's schema and encodings
 * @throws IllegalArgumentException if {@code example} has the wrong number of values
 */
public Prediction makePrediction(String[] example) {
  int expectedFeatures = inputSchema.getNumFeatures();
  if (example.length != expectedFeatures) {
    // Report both counts so a malformed input line can be diagnosed from the message alone.
    throw new IllegalArgumentException(
        "Wrong number of features: expected " + expectedFeatures + " but got " + example.length);
  }
  return forest.predict(ExampleUtils.dataToExample(example, inputSchema, encodings));
}
// Turn each incoming record into an Example, using the model's categorical value encodings.
// NOTE(review): newData looks like a pair RDD (it exposes values()); the values are parsed by
// MLFunctions.PARSE_FN, presumably into String[] fields for dataToExample. Confirm against the caller.
// NOTE(review): if inputSchema is an instance field here (not a local), this lambda references
// `this`, so Spark would serialize the enclosing object. evaluate() avoids that by copying the
// field into a local first. Confirm which applies here.
CategoricalValueEncodings valueEncodings = model.getEncodings(); JavaRDD<Example> examplesRDD = newData.values().map(MLFunctions.PARSE_FN). map(data -> ExampleUtils.dataToExample(data, inputSchema, valueEncodings));
/**
 * Scores a candidate PMML forest on held-out data.
 *
 * <p>The model is first checked against the configured input schema. The forest and its
 * categorical encodings are then read from the PMML, and every test line is parsed into an
 * {@code Example}. Classification models are scored by accuracy. Regression models are scored
 * by negated RMSE, so in both cases a larger returned value means a better fit.
 *
 * @param sparkContext unused here
 * @param model PMML model to evaluate
 * @param modelParentPath unused here
 * @param testData raw test records
 * @param trainData unused here
 * @return accuracy for classification, or {@code -RMSE} for regression
 */
@Override
public double evaluate(JavaSparkContext sparkContext,
                       PMML model,
                       Path modelParentPath,
                       JavaRDD<String> testData,
                       JavaRDD<String> trainData) {
  RDFPMMLUtils.validatePMMLVsSchema(model, inputSchema);
  Pair<DecisionForest,CategoricalValueEncodings> parsedModel = RDFPMMLUtils.read(model);
  DecisionForest decisionForest = parsedModel.getFirst();
  CategoricalValueEncodings categoricalEncodings = parsedModel.getSecond();

  // Capture the schema in a local so the lambda below does not reference the enclosing
  // instance (presumably to keep `this` out of the serialized Spark closure).
  InputSchema schema = this.inputSchema;
  JavaRDD<Example> testExamples = testData
      .map(MLFunctions.PARSE_FN)
      .map(fields -> ExampleUtils.dataToExample(fields, schema, categoricalEncodings));

  if (!schema.isClassification()) {
    double rmse = Evaluation.rmse(decisionForest, testExamples);
    log.info("RMSE: {}", rmse);
    // Negate so that a lower error gives a higher score.
    return -rmse;
  }
  double accuracy = Evaluation.accuracy(decisionForest, testExamples);
  log.info("Accuracy: {}", accuracy);
  return accuracy;
}
/**
 * Checks that a raw five-column row is converted into an Example with a categorical target,
 * no predictors for the ID and target columns, and numeric values for the other columns.
 */
@Test
public void testToExample() {
  // Schema: 5 features, column 0 is an ID, column 4 is the categorical target.
  Map<String,Object> schemaOverrides = new HashMap<>();
  schemaOverrides.put("oryx.input-schema.num-features", 5);
  schemaOverrides.put("oryx.input-schema.categorical-features", "[\"4\"]");
  schemaOverrides.put("oryx.input-schema.id-features", "[\"0\"]");
  schemaOverrides.put("oryx.input-schema.target-feature", "\"4\"");
  Config mergedConfig = ConfigUtils.overlayOn(schemaOverrides, ConfigUtils.getDefault());
  InputSchema inputSchema = new InputSchema(mergedConfig);

  // Target column 4 has values A, B, C; "B" is expected to encode to 1.
  CategoricalValueEncodings targetEncodings =
      new CategoricalValueEncodings(Collections.singletonMap(4, Arrays.asList("A", "B", "C")));
  String[] row = {"foo", "1", "2.5", "-3.2", "B"};
  Example converted = ExampleUtils.dataToExample(row, inputSchema, targetEncodings);

  assertEquals(CategoricalFeature.forEncoding(1), converted.getTarget());
  // The ID column yields no predictor.
  assertNull(converted.getFeature(0));
  double[] numericValues = {1.0, 2.5, -3.2};
  for (int column = 1; column <= numericValues.length; column++) {
    assertEquals(NumericFeature.forValue(numericValues[column - 1]), converted.getFeature(column));
  }
  // The target column also yields no predictor.
  assertNull(converted.getFeature(4));
  assertTrue(converted.toString().contains("2.5"));
}
// Turn each incoming record into an Example, using the model's categorical value encodings.
// NOTE(review): newData looks like a pair RDD (it exposes values()); the values are parsed by
// MLFunctions.PARSE_FN, presumably into String[] fields for dataToExample. Confirm against the caller.
// NOTE(review): if inputSchema is an instance field here (not a local), this lambda references
// `this`, so Spark would serialize the enclosing object. evaluate() avoids that by copying the
// field into a local first. Confirm which applies here.
CategoricalValueEncodings valueEncodings = model.getEncodings(); JavaRDD<Example> examplesRDD = newData.values().map(MLFunctions.PARSE_FN). map(data -> ExampleUtils.dataToExample(data, inputSchema, valueEncodings));