/**
 * Lists the names of every field declared in a PMML data dictionary.
 *
 * @param dictionary {@link DataDictionary} from model
 * @return feature names, in the same order as the dictionary's data fields
 * @throws IllegalArgumentException if the dictionary has no data fields
 */
public static List<String> getFeatureNames(DataDictionary dictionary) {
  List<DataField> fields = dictionary.getDataFields();
  boolean hasFields = fields != null && !fields.isEmpty();
  Preconditions.checkArgument(hasFields, "No fields in DataDictionary");
  return fields.stream()
      .map(DataField::getName)
      .map(fieldName -> fieldName.getValue())
      .collect(Collectors.toList());
}
return new DataDictionary(dataFields).setNumberOfFields(dataFields.size());
/**
 * Builds a data dictionary from the test schema plus categorical encodings for feature 1.
 * Checks the field count, each field's expectations via {@code checkDataField}, and that
 * feature 1 carries the categorical values in their original order.
 */
@Test
public void testBuildDataDictionary() {
  Map<Integer,Collection<String>> distinctValuesByFeature = new HashMap<>();
  distinctValuesByFeature.put(1, Arrays.asList("one", "two", "three", "four", "five"));
  CategoricalValueEncodings encodings = new CategoricalValueEncodings(distinctValuesByFeature);
  DataDictionary dictionary = AppPMMLUtils.buildDataDictionary(buildTestSchema(), encodings);

  assertEquals(4, dictionary.getNumberOfFields().intValue());
  checkDataField(dictionary.getDataFields().get(0), "foo", null);
  checkDataField(dictionary.getDataFields().get(1), "bar", true);
  checkDataField(dictionary.getDataFields().get(2), "baz", null);
  checkDataField(dictionary.getDataFields().get(3), "bing", false);

  List<Value> barValues = dictionary.getDataFields().get(1).getValues();
  assertEquals(5, barValues.size());
  String[] expectedValues = { "one", "two", "three", "four", "five" };
  int position = 0;
  for (String expected : expectedValues) {
    assertEquals(expected, barValues.get(position).getValue());
    position++;
  }
}
/**
 * Records every data field of the visited dictionary in {@code dataFields}, then defers to
 * the superclass for the returned action.
 *
 * @param dataDictionary the dictionary being visited
 * @return whatever {@code super.visit(dataDictionary)} returns
 */
@Override
public VisitorAction visit(DataDictionary dataDictionary) {
  boolean hasFields = dataDictionary.hasDataFields();
  if (hasFields) {
    dataFields.addAll(dataDictionary.getDataFields());
  }
  return super.visit(dataDictionary);
}
/**
 * Applies {@code visitor} to this object. If the visitor answers
 * {@link VisitorAction#CONTINUE}, this method pushes this object as the visitor's parent.
 * It then traverses the extensions, the data fields and the taxonomies, in that order,
 * stopping at the first child traversal that does not answer CONTINUE. Finally it pops the
 * parent.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if any step terminated, otherwise
 *     {@link VisitorAction#CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
  VisitorAction action = visitor.visit(this);
  if (action != VisitorAction.CONTINUE) {
    return (action == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
  }
  visitor.pushParent(this);
  if (hasExtensions()) {
    action = org.dmg.pmml.PMMLObject.traverse(visitor, getExtensions());
  }
  // Each later child group is only inspected while traversal is still continuing.
  if (action == VisitorAction.CONTINUE && hasDataFields()) {
    action = org.dmg.pmml.PMMLObject.traverse(visitor, getDataFields());
  }
  if (action == VisitorAction.CONTINUE && hasTaxonomies()) {
    action = org.dmg.pmml.PMMLObject.traverse(visitor, getTaxonomies());
  }
  visitor.popParent();
  return (action == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
.setDescription("A very small binary tree model to show structure."); DataDictionary dataDictionary = new DataDictionary() .addDataFields( new DataField(temperature, OpType.CONTINUOUS, DataType.DOUBLE), new DataField(humidity, OpType.CONTINUOUS, DataType.DOUBLE), ); dataDictionary.setNumberOfFields((dataDictionary.getDataFields()).size());
/**
 * Creates a new, empty {@link DataDictionary}.
 *
 * @return a freshly constructed instance
 */
public DataDictionary createDataDictionary() {
  final DataDictionary instance = new DataDictionary();
  return instance;
}
@Override public DataDictionary build(BasicML basicML) { DataDictionary dict = new DataDictionary(); List<DataField> fields = new ArrayList<DataField>(); dict.withDataFields(fields); dict.withNumberOfFields(fields.size()); return dict;
/**
 * Reports the dictionary's declared {@code numberOfFields} value.
 *
 * @return the value of {@code dataDictionary.getNumberOfFields()}, passed through unchanged
 */
@Override
public Integer getSize() {
  final Integer declaredCount = dataDictionary.getNumberOfFields();
  return declaredCount;
}
if(dataDictionary.hasDataFields()){ this.dataFields = CacheUtil.getValue(dataDictionary, ModelEvaluator.dataFieldCache);
/**
 * Prunes {@code dataDictionary} in place, keeping only the data fields contained in
 * {@code getUsedDataFields()}.
 *
 * <p>If the dictionary carries an explicit {@code numberOfFields} value, it is updated to the
 * number of remaining fields. Otherwise it would no longer match
 * {@code getDataFields().size()}. An absent value is left absent.
 *
 * @param dataDictionary the dictionary to prune
 */
private void processDataDictionary(DataDictionary dataDictionary){
  if(dataDictionary.hasDataFields()){
    List<DataField> dataFields = dataDictionary.getDataFields();
    Set<DataField> usedDataFields = getUsedDataFields();

    // Mutates the dictionary's own field list.
    dataFields.retainAll(usedDataFields);

    // Keep the declared count consistent with the pruned list.
    if(dataDictionary.getNumberOfFields() != null){
      dataDictionary.setNumberOfFields(dataFields.size());
    }
  }
}
/**
 * Applies {@code visitor} to this object. If the visitor answers
 * {@link VisitorAction#CONTINUE}, this method pushes this object as the visitor's parent.
 * It then traverses the extensions, the data fields and the taxonomies, in that order,
 * stopping at the first child traversal that does not answer CONTINUE. Finally it pops the
 * parent.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if any step terminated, otherwise
 *     {@link VisitorAction#CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
  VisitorAction action = visitor.visit(this);
  if (action != VisitorAction.CONTINUE) {
    return (action == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
  }
  visitor.pushParent(this);
  if (hasExtensions()) {
    action = org.dmg.pmml.PMMLObject.traverse(visitor, getExtensions());
  }
  // Each later child group is only inspected while traversal is still continuing.
  if (action == VisitorAction.CONTINUE && hasDataFields()) {
    action = org.dmg.pmml.PMMLObject.traverse(visitor, getDataFields());
  }
  if (action == VisitorAction.CONTINUE && hasTaxonomies()) {
    action = org.dmg.pmml.PMMLObject.traverse(visitor, getTaxonomies());
  }
  visitor.popParent();
  return (action == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Creates a new, empty {@link DataDictionary}.
 *
 * @return a freshly constructed instance
 */
public DataDictionary createDataDictionary() {
  final DataDictionary instance = new DataDictionary();
  return instance;
}
/**
 * Records every data field of the visited dictionary in {@code dataFields}, then defers to
 * the superclass for the returned action.
 *
 * @param dataDictionary the dictionary being visited
 * @return whatever {@code super.visit(dataDictionary)} returns
 */
@Override
public VisitorAction visit(DataDictionary dataDictionary) {
  boolean hasFields = dataDictionary.hasDataFields();
  if (hasFields) {
    dataFields.addAll(dataDictionary.getDataFields());
  }
  return super.visit(dataDictionary);
}
/**
 * Derives categorical value encodings from a PMML data dictionary.
 *
 * <p>Every data field that declares at least one {@link Value} contributes an entry. The key
 * is the field's position in the dictionary; the value is the list of its value strings, in
 * declaration order. Fields without values are omitted.
 *
 * @param dictionary dictionary to read field values from
 * @return encodings keyed by feature index
 */
public static CategoricalValueEncodings buildCategoricalValueEncodings(
    DataDictionary dictionary) {
  Map<Integer,Collection<String>> valuesByIndex = new HashMap<>();
  int index = 0;
  for (DataField field : dictionary.getDataFields()) {
    Collection<Value> declared = field.getValues();
    boolean hasValues = declared != null && !declared.isEmpty();
    if (hasValues) {
      valuesByIndex.put(index, declared.stream().map(Value::getValue).collect(Collectors.toList()));
    }
    index++;
  }
  return new CategoricalValueEncodings(valuesByIndex);
}
/**
 * Asserts that {@code dataDictionary} matches {@code schema}.
 *
 * <p>Checks that the declared field count and the actual field list both equal the schema's
 * feature count. Numeric features must be CONTINUOUS/DOUBLE and categorical features must be
 * CATEGORICAL/STRING. Any other feature must have neither an op type nor a data type.
 *
 * @param schema expected input schema
 * @param dataDictionary dictionary under test
 */
protected static void checkDataDictionary(InputSchema schema, DataDictionary dataDictionary) {
  assertNotNull(dataDictionary);
  assertEquals("Wrong number of features",
      schema.getNumFeatures(), dataDictionary.getNumberOfFields().intValue());
  List<DataField> fields = dataDictionary.getDataFields();
  assertEquals(schema.getNumFeatures(), fields.size());

  for (DataField field : fields) {
    String name = field.getName().getValue();
    OpType expectedOpType;
    DataType expectedDataType;
    if (schema.isNumeric(name)) {
      expectedOpType = OpType.CONTINUOUS;
      expectedDataType = DataType.DOUBLE;
    } else if (schema.isCategorical(name)) {
      expectedOpType = OpType.CATEGORICAL;
      expectedDataType = DataType.STRING;
    } else {
      // Features that are neither numeric nor categorical carry no type information.
      assertNull(field.getOpType());
      assertNull(field.getDataType());
      continue;
    }
    assertEquals("Wrong op type for feature " + name, expectedOpType, field.getOpType());
    assertEquals("Wrong data type for feature " + name, expectedDataType, field.getDataType());
  }
}
/**
 * Verifies that categorical encodings follow the order in which values are declared on a
 * data field ("b" before "a"). Also checks that the numeric field "foo" contributes no
 * encoding.
 */
@Test
public void testBuildCategoricalEncoding() {
  DataField fooField = new DataField(FieldName.create("foo"), OpType.CONTINUOUS, DataType.DOUBLE);
  List<DataField> fields = new ArrayList<>();
  fields.add(fooField);

  DataField barField = new DataField(FieldName.create("bar"), OpType.CATEGORICAL, DataType.STRING);
  barField.addValues(new Value("b"), new Value("a"));
  fields.add(barField);

  DataDictionary dictionary = new DataDictionary(fields).setNumberOfFields(fields.size());
  CategoricalValueEncodings encodings = AppPMMLUtils.buildCategoricalValueEncodings(dictionary);

  assertEquals(2, encodings.getValueCount(1));
  assertEquals(0, encodings.getValueEncodingMap(1).get("b").intValue());
  assertEquals(1, encodings.getValueEncodingMap(1).get("a").intValue());
  assertEquals("b", encodings.getEncodingValueMap(1).get(0));
  assertEquals("a", encodings.getEncodingValueMap(1).get(1));
  assertEquals(Collections.singletonMap(1, 2), encodings.getCategoryCounts());
}
/**
 * Builds a minimal PMML document. It contains a {@link Header} with an "ACME Corporation"
 * copyright notice and an empty {@link DataDictionary}. The document's version string comes
 * from {@code Version.PMML_4_3}.
 *
 * @return the newly created PMML object
 */
private static PMML createPMML() {
  Header header = new Header()
      .setCopyright("ACME Corporation");
  DataDictionary dataDictionary = new DataDictionary();
  return new PMML(Version.PMML_4_3.getVersion(), header, dataDictionary);
}
/**
 * Prunes {@code dataDictionary} in place, keeping only the data fields contained in
 * {@code getUsedDataFields()}.
 *
 * <p>If the dictionary carries an explicit {@code numberOfFields} value, it is updated to the
 * number of remaining fields. Otherwise it would no longer match
 * {@code getDataFields().size()}. An absent value is left absent.
 *
 * @param dataDictionary the dictionary to prune
 */
private void processDataDictionary(DataDictionary dataDictionary){
  if(dataDictionary.hasDataFields()){
    List<DataField> dataFields = dataDictionary.getDataFields();
    Set<DataField> usedDataFields = getUsedDataFields();

    // Mutates the dictionary's own field list.
    dataFields.retainAll(usedDataFields);

    // Keep the declared count consistent with the pruned list.
    if(dataDictionary.getNumberOfFields() != null){
      dataDictionary.setNumberOfFields(dataFields.size());
    }
  }
}
@Override public Collection<?> getCollection(){ return dataDictionary.getDataFields(); } });