/**
 * Validates that the encoded PMML model received matches expected schema.
 *
 * <p>The checks run in this order: the PMML holds exactly one model; that model is a
 * {@link ClusteringModel}; its mining function is {@link MiningFunction#CLUSTERING}; the
 * feature names derived from the data dictionary equal the schema's feature names; and the
 * feature names derived from the model's mining schema equal the schema's feature names.
 * Every check is made through {@code Preconditions.checkArgument} and carries a message
 * naming what did not match.</p>
 *
 * @param pmml {@link PMML} encoding of KMeans Clustering
 * @param schema expected schema attributes of KMeans Clustering
 */
public static void validatePMMLVsSchema(PMML pmml, InputSchema schema) {
  List<Model> models = pmml.getModels();
  Preconditions.checkArgument(models.size() == 1,
                              "Should have exactly one model, but had %s",
                              models.size());

  Model model = models.get(0);
  Preconditions.checkArgument(model instanceof ClusteringModel,
                              "Model in PMML is not a ClusteringModel");

  MiningFunction miningFunction = model.getMiningFunction();
  Preconditions.checkArgument(miningFunction == MiningFunction.CLUSTERING,
                              "Expected mining function %s, but had %s",
                              MiningFunction.CLUSTERING,
                              miningFunction);

  DataDictionary dictionary = pmml.getDataDictionary();
  Preconditions.checkArgument(
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(dictionary)),
      "Feature names in schema don't match names in PMML");

  MiningSchema miningSchema = model.getMiningSchema();
  Preconditions.checkArgument(
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(miningSchema)),
      "Feature names in schema don't match names in PMML mining schema");
}
/**
 * Creates a model whose model-level properties are taken from another model.
 *
 * <p>Each property is read with the given model's getter and handed unchanged to the matching
 * setter of this object; this constructor itself makes no copies of the returned values. The
 * properties transferred are: model name, mining function, algorithm name, scorable flag,
 * math context, mining schema, output, model stats, model explanation, targets,
 * local transformations and model verification.</p>
 *
 * @param model The model whose properties are read.
 */
public JavaModel(Model model){ setModelName(model.getModelName()); setMiningFunction(model.getMiningFunction()); setAlgorithmName(model.getAlgorithmName()); setScorable(model.isScorable()); setMathContext(model.getMathContext()); setMiningSchema(model.getMiningSchema()); setOutput(model.getOutput()); setModelStats(model.getModelStats()); setModelExplanation(model.getModelExplanation()); setTargets(model.getTargets()); setLocalTransformations(model.getLocalTransformations()); setModelVerification(model.getModelVerification()); }
MiningSchema miningSchema = model.getMiningSchema(); int targetIndex = Objects.requireNonNull(AppPMMLUtils.findTargetIndex(miningSchema));
MiningFunction miningFunction = model.getMiningFunction(); if(miningFunction == null){ throw new MissingAttributeException(MissingAttributeException.formatMessage(XPathUtil.formatElement(model.getClass()) + "@functionName"), model); MiningSchema miningSchema = model.getMiningSchema(); if(miningSchema == null){ throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(model.getClass()) + "/" + XPathUtil.formatElement(MiningSchema.class)), model); LocalTransformations localTransformations = model.getLocalTransformations(); if(localTransformations != null && localTransformations.hasDerivedFields()){ this.localDerivedFields = CacheUtil.getValue(localTransformations, ModelEvaluator.localDerivedFieldCache); Targets targets = model.getTargets(); if(targets != null && targets.hasTargets()){ this.targets = CacheUtil.getValue(targets, ModelEvaluator.targetCache); Output output = model.getOutput(); if(output != null && output.hasOutputFields()){ this.outputFields = CacheUtil.getValue(output, ModelEvaluator.outputFieldCache);
if((transformer instanceof GeneralizedLinearRegressionModel) && (MiningFunction.CLASSIFICATION).equals(model.getMiningFunction())){ break hasPredictionCol; MiningSchema miningSchema = model.getMiningSchema(); model.setModelVerification(ModelUtil.createModelVerification(data));
/**
 * Serializes a dummy model with {@code PMMLUtils.toString}, parses it back with
 * {@code PMMLUtils.fromString}, and checks that the header's application name and the first
 * model's mining function are the same before and after.
 */
@Test
public void testFromString() throws Exception {
  PMML original = buildDummyModel();

  String serialized = PMMLUtils.toString(original);
  PMML restored = PMMLUtils.fromString(serialized);

  assertEquals(original.getHeader().getApplication().getName(),
               restored.getHeader().getApplication().getName());
  assertEquals(original.getModels().get(0).getMiningFunction(),
               restored.getModels().get(0).getMiningFunction());
}
/**
 * Wraps a list of models into a single {@link MiningModel} whose segmentation uses the
 * {@code MODEL_CHAIN} multiple model method.
 *
 * <p>The mining function and the (simplified) math context of the result are taken from the
 * last model of the list; the mining schema is created from the label of the schema.</p>
 *
 * @param models The models to chain. Must contain at least one element.
 * @param schema The schema whose label is used for creating the mining schema.
 *
 * @return The newly created mining model.
 *
 * @throws IllegalArgumentException If the list of models is empty.
 */
static
public MiningModel createModelChain(List<? extends Model> models, Schema schema){

	if(models.isEmpty()){
		throw new IllegalArgumentException();
	}

	Segmentation chain = createSegmentation(Segmentation.MultipleModelMethod.MODEL_CHAIN, models);

	Model finalModel = Iterables.getLast(models);

	MiningModel chainModel = new MiningModel(finalModel.getMiningFunction(), ModelUtil.createMiningSchema(schema.getLabel()));
	chainModel.setMathContext(ModelUtil.simplifyMathContext(finalModel.getMathContext()));
	chainModel.setSegmentation(chain);

	return chainModel;
}
public Map<FieldName, ?> evaluate(ModelEvaluationContext context){ M model = getModel(); if(!model.isScorable()){ throw new EvaluationException("Model is not scorable", model); MathContext mathContext = model.getMathContext(); switch(mathContext){ case FLOAT: MiningFunction miningFunction = model.getMiningFunction(); switch(miningFunction){ case REGRESSION:
MiningFunction miningFunction = segmentModel.getMiningFunction(); switch(miningFunction){ case REGRESSION: .addOutputFields(outputField); segmentModel.setOutput(output);
/**
 * Collects the {@link Output} elements of the segments that come before the target segment.
 *
 * <p>Segments are visited in the order returned by {@code Segmentation#getSegments()}; the walk
 * stops at the first segment that equals the target segment. If the target segment is
 * {@code null}, all segments are visited. Segments whose model has no output are skipped.</p>
 *
 * @param segmentation The segmentation to walk.
 * @param targetSegment The segment at which to stop. May be {@code null}.
 *
 * @return The collected outputs, or an empty list if the multiple model method is anything
 * other than {@code MODEL_CHAIN}.
 */
static
private List<Output> getEarlierOutputs(Segmentation segmentation, Segment targetSegment){

	// Any method other than MODEL_CHAIN yields an empty list
	switch(segmentation.getMultipleModelMethod()){
		case MODEL_CHAIN:
			break;
		default:
			return Collections.emptyList();
	}

	List<Output> earlierOutputs = new ArrayList<>();

	for(Segment candidate : segmentation.getSegments()){
		Model candidateModel = candidate.getModel();

		boolean reachedTarget = (targetSegment != null) && (targetSegment).equals(candidate);
		if(reachedTarget){
			break;
		}

		Output candidateOutput = candidateModel.getOutput();
		if(candidateOutput == null){
			continue;
		}

		earlierOutputs.add(candidateOutput);
	}

	return earlierOutputs;
}
}
ModelVerification modelVerification = model.getModelVerification(); if(modelVerification != null){ throw new IllegalArgumentException("Model verification data is already defined"); model.setModelVerification(modelVerification);
/**
 * Finds the model with the given name.
 *
 * <p>Models are examined in the order returned by {@code pmml.getModels()} and the first one
 * whose name equals {@code name} is returned. A model without a name never matches; it is
 * skipped instead of causing a {@link NullPointerException}.</p>
 *
 * @param pmml PMML whose models are searched
 * @param name model name to look for
 * @return the first model whose name equals {@code name}
 * @throws RuntimeException if no model has the given name
 */
public static Model getModelByName(PMML pmml, String name) {
  for (Model model : pmml.getModels()) {
    // Compare from the side of the requested name, so that a model whose
    // name is null is treated as a non-match rather than dereferenced.
    if (name != null && name.equals(model.getModelName())) {
      return model;
    }
  }
  throw new RuntimeException("No such model: " + name);
}
/**
 * Pops the parent object, and post-processes the transformations that it holds.
 *
 * <p>For a {@link Model} parent, the local transformations (if any) are processed, and then
 * unset if they have no derived fields left. For a {@link PMML} parent, the transformation
 * dictionary (if any) is processed, and then unset if it has neither define functions nor
 * derived fields left.</p>
 *
 * @return The object returned by the superclass implementation.
 */
@Override
public PMMLObject popParent(){
	PMMLObject popped = super.popParent();

	if(popped instanceof Model){
		Model model = (Model)popped;

		LocalTransformations transformations = model.getLocalTransformations();
		if(transformations != null){
			processLocalTransformations(transformations);

			boolean hasContent = transformations.hasDerivedFields();
			if(!hasContent){
				model.setLocalTransformations(null);
			}
		}
	} else

	if(popped instanceof PMML){
		PMML pmml = (PMML)popped;

		TransformationDictionary dictionary = pmml.getTransformationDictionary();
		if(dictionary != null){
			processTransformationDictionary(dictionary);

			boolean hasContent = (dictionary.hasDefineFunctions() || dictionary.hasDerivedFields());
			if(!hasContent){
				pmml.setTransformationDictionary(null);
			}
		}
	}

	return popped;
}
model.setModelVerification(ModelUtil.createModelVerification(data));
/**
 * Finds a model either by name, or by scorability.
 *
 * @param pmml The PMML to search.
 * @param modelName The name of the model. If {@code null}, the search is made for a model
 * that is scorable instead.
 *
 * @return The model returned by the predicate-based {@code PMMLUtil.findModel} lookup.
 */
static
public Model findModel(PMML pmml, String modelName){

	if(modelName == null){
		return PMMLUtil.findModel(pmml, (Model candidate) -> candidate.isScorable(), "<Model>@isScorable=true");
	}

	return PMMLUtil.findModel(pmml, (Model candidate) -> Objects.equals(candidate.getModelName(), modelName), "<Model>@modelName=" + modelName);
}
/**
 * Declares the derived fields of the model's local transformations (if there are any),
 * and then delegates to the superclass implementation.
 *
 * @param model The model being visited.
 *
 * @return The action returned by the superclass implementation.
 */
@Override
public VisitorAction visit(Model model){
	LocalTransformations transformations = model.getLocalTransformations();

	boolean hasLocalFields = (transformations != null) && transformations.hasDerivedFields();
	if(hasLocalFields){
		declare(model, transformations.getDerivedFields());
	}

	return super.visit(model);
}
/**
 * @return The math context of the model returned by {@code getModel()}.
 */
public MathContext getMathContext(){
	return getModel().getMathContext();
}
/**
 * Creates (but does not throw) an evaluation exception whose context is the mining schema
 * of the model returned by {@code getModel()}.
 *
 * @param message The exception message.
 *
 * @return The newly created exception.
 */
protected EvaluationException createMiningSchemaException(String message){
	MiningSchema context = getModel().getMiningSchema();

	return new EvaluationException(message, context);
}
/**
 * @return The mining function of the model returned by {@code getModel()}.
 */
@Override
public MiningFunction getMiningFunction(){
	return getModel().getMiningFunction();
}
/**
 * Collects the {@link Output} elements of the segments that come before the target segment.
 *
 * <p>Segments are visited in the order returned by {@code Segmentation#getSegments()}; the walk
 * stops at the first segment that equals the target segment. If the target segment is
 * {@code null}, all segments are visited. Segments whose model has no output are skipped.</p>
 *
 * @param segmentation The segmentation to walk.
 * @param targetSegment The segment at which to stop. May be {@code null}.
 *
 * @return The collected outputs, or an empty list if the multiple model method is anything
 * other than {@code MODEL_CHAIN}.
 */
static
private List<Output> getEarlierOutputs(Segmentation segmentation, Segment targetSegment){

	// Any method other than MODEL_CHAIN yields an empty list
	switch(segmentation.getMultipleModelMethod()){
		case MODEL_CHAIN:
			break;
		default:
			return Collections.emptyList();
	}

	List<Output> earlierOutputs = new ArrayList<>();

	for(Segment candidate : segmentation.getSegments()){
		Model candidateModel = candidate.getModel();

		boolean reachedTarget = (targetSegment != null) && (targetSegment).equals(candidate);
		if(reachedTarget){
			break;
		}

		Output candidateOutput = candidateModel.getOutput();
		if(candidateOutput == null){
			continue;
		}

		earlierOutputs.add(candidateOutput);
	}

	return earlierOutputs;
}
}