/**
 * Builds the PMML {@link ClusteringModel} that describes the given k-means model.
 *
 * <p>One {@link ClusteringField}, flagged as a center field, is emitted for every feature
 * index that {@code inputSchema} reports as active. One {@link Cluster} is emitted per
 * cluster center, using the center's index as the cluster ID.
 *
 * @param model k-means model whose cluster centers are exported
 * @param clusterSizesMap cluster sizes keyed by cluster index; an index without an entry
 *  is written with size 0 instead of failing
 * @return center-based clustering model using a squared Euclidean distance measure
 */
private ClusteringModel pmmlClusteringModel(KMeansModel model, Map<Integer,Long> clusterSizesMap) {
  Vector[] clusterCenters = model.clusterCenters();

  List<ClusteringField> clusteringFields = new ArrayList<>();
  for (int i = 0; i < inputSchema.getNumFeatures(); i++) {
    if (inputSchema.isActive(i)) {
      FieldName fieldName = FieldName.create(inputSchema.getFeatureNames().get(i));
      ClusteringField clusteringField =
          new ClusteringField(fieldName).setCenterField(ClusteringField.CenterField.TRUE);
      clusteringFields.add(clusteringField);
    }
  }

  List<Cluster> clusters = new ArrayList<>(clusterCenters.length);
  for (int i = 0; i < clusterCenters.length; i++) {
    // The size map may lack an entry for a cluster index (presumably when no point was
    // assigned to that cluster); unboxing the missing value would throw a
    // NullPointerException, so treat it as an empty cluster.
    Long recordedSize = clusterSizesMap.get(i);
    int clusterSize = (recordedSize == null) ? 0 : recordedSize.intValue();
    clusters.add(new Cluster().setId(Integer.toString(i))
                     .setSize(clusterSize)
                     .setArray(AppPMMLUtils.toArray(clusterCenters[i].toArray())));
  }

  return new ClusteringModel(
      MiningFunction.CLUSTERING,
      ClusteringModel.ModelClass.CENTER_BASED,
      clusters.size(),
      AppPMMLUtils.buildMiningSchema(inputSchema),
      new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE).setMeasure(new SquaredEuclidean()),
      clusteringFields,
      clusters);
}
/**
 * Factory method for {@link ClusteringField}.
 *
 * @return a newly constructed {@link ClusteringField} created with its no-argument constructor
 */
public ClusteringField createClusteringField() {
    ClusteringField clusteringField = new ClusteringField();
    return clusteringField;
}
/**
 * Offers this element to the visitor and, if the visitor asks to continue, traverses its
 * extensions (when present) followed by its comparisons.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if the traversal was terminated, otherwise
 *  {@link VisitorAction#CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction outcome = visitor.visit(this);

    if (outcome == VisitorAction.CONTINUE) {
        // This element is the parent for the duration of the child traversal.
        visitor.pushParent(this);
        if (hasExtensions()) {
            outcome = PMMLObject.traverse(visitor, getExtensions());
        }
        if (outcome == VisitorAction.CONTINUE) {
            outcome = PMMLObject.traverse(visitor, getComparisons());
        }
        visitor.popParent();
    }

    // Any non-terminating outcome is reported to the caller as CONTINUE.
    return (outcome == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Passes the field referenced by the clustering field to {@code process}, then delegates
 * to the superclass visit.
 *
 * @param clusteringField the element being visited
 * @return the action returned by the superclass visit
 */
@Override
public VisitorAction visit(ClusteringField clusteringField){
    process(clusteringField.getField());

    VisitorAction status = super.visit(clusteringField);
    return status;
}
/**
 * Appends the given extensions, in argument order, to this element's extension list.
 *
 * @param extensions the extensions to append
 * @return this instance, to allow call chaining
 */
@Override
public ClusteringField addExtensions(Extension... extensions) {
    // Obtain the target list before touching the arguments, as the chained form does.
    List<Extension> target = getExtensions();
    target.addAll(Arrays.asList(extensions));
    return this;
}
/**
 * Selects the model's clustering fields whose {@code centerField} attribute is
 * {@code TRUE}, keeping them in model order.
 *
 * @return the clustering fields flagged as center fields
 * @throws UnsupportedAttributeException if a field carries a {@code centerField} value
 *  other than {@code TRUE} or {@code FALSE}
 */
private List<ClusteringField> getCenterClusteringFields(){
    ClusteringModel model = getModel();

    List<ClusteringField> allFields = model.getClusteringFields();
    List<ClusteringField> centerFields = new ArrayList<>(allFields.size());

    for(ClusteringField field : allFields){
        ClusteringField.CenterField flag = field.getCenterField();

        switch(flag){
            case TRUE:
                centerFields.add(field);
                break;
            case FALSE:
                // Not a center field; leave it out.
                break;
            default:
                throw new UnsupportedAttributeException(field, flag);
        }
    }

    return centerFields;
}
/**
 * Factory method for {@link ClusteringField}.
 *
 * @return a newly constructed {@link ClusteringField} created with its no-argument constructor
 */
public ClusteringField createClusteringField() {
    ClusteringField created = new ClusteringField();
    return created;
}
/**
 * Offers this element to the visitor and, if the visitor asks to continue, traverses its
 * extensions (when present) followed by its comparisons.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if the traversal was terminated, otherwise
 *  {@link VisitorAction#CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction result = visitor.visit(this);

    if (result == VisitorAction.CONTINUE) {
        // This element is the parent for the duration of the child traversal.
        visitor.pushParent(this);
        if (hasExtensions()) {
            result = PMMLObject.traverse(visitor, getExtensions());
        }
        if (result == VisitorAction.CONTINUE) {
            result = PMMLObject.traverse(visitor, getComparisons());
        }
        visitor.popParent();
    }

    // Any non-terminating outcome is reported to the caller as CONTINUE.
    return (result == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Passes the field referenced by the clustering field to {@code process}, then delegates
 * to the superclass visit.
 *
 * @param clusteringField the element being visited
 * @return the action returned by the superclass visit
 */
@Override
public VisitorAction visit(ClusteringField clusteringField){
    process(clusteringField.getField());

    VisitorAction delegated = super.visit(clusteringField);
    return delegated;
}
/**
 * Appends the given extensions, in argument order, to this element's extension list.
 *
 * @param extensions the extensions to append
 * @return this instance, to allow call chaining
 */
@Override
public ClusteringField addExtensions(Extension... extensions) {
    // Obtain the target list before touching the arguments, as the chained form does.
    List<Extension> current = getExtensions();
    current.addAll(Arrays.asList(extensions));
    return this;
}
// Register the fields named "x" and "y", in that order, each flagged as a center field.
for (String centerFieldName : new String[] {"x", "y"}) {
    ClusteringField namedCenterField = new ClusteringField(FieldName.create(centerFieldName))
        .setCenterField(ClusteringField.CenterField.TRUE);
    clusteringFields.add(namedCenterField);
}
@Override protected <V extends Number> Map<FieldName, ClusterAffinityDistribution<V>> evaluateClustering(ValueFactory<V> valueFactory, EvaluationContext context){ ClusteringModel clusteringModel = getModel(); ComparisonMeasure comparisonMeasure = clusteringModel.getComparisonMeasure(); List<ClusteringField> clusteringFields = getCenterClusteringFields(); List<FieldValue> values = new ArrayList<>(clusteringFields.size()); for(int i = 0, max = clusteringFields.size(); i < max; i++){ ClusteringField clusteringField = clusteringFields.get(i); FieldName name = clusteringField.getField(); if(name == null){ throw new MissingAttributeException(clusteringField, PMMLAttributes.CLUSTERINGFIELD_FIELD); } FieldValue value = context.evaluate(name); values.add(value); } ClusterAffinityDistribution<V> result; Measure measure = MeasureUtil.ensureMeasure(comparisonMeasure); if(measure instanceof Similarity){ result = evaluateSimilarity(valueFactory, comparisonMeasure, clusteringFields, values); } else if(measure instanceof Distance){ result = evaluateDistance(valueFactory, comparisonMeasure, clusteringFields, values); } else { throw new UnsupportedElementException(measure); } // "For clustering models, the identifier of the winning cluster is returned as the predictedValue" result.computeResult(DataType.STRING); return Collections.singletonMap(getTargetName(), result); }