/**
 * Visits this model and then, unless the visitor stops it, its child elements.
 *
 * <p>Children are traversed in this order:
 * <ol>
 * <li>extensions;</li>
 * <li>mining schema, output, model stats, model explanation, targets, local transformations,
 * kernel and vector dictionary;</li>
 * <li>the support vector machines;</li>
 * <li>the model verification.</li>
 * </ol>
 * As soon as any step yields a status other than {@link VisitorAction#CONTINUE}, the remaining
 * children are not traversed.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if the traversal was terminated, otherwise
 *         {@link VisitorAction#CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction status = visitor.visit(this);

    if (status == VisitorAction.CONTINUE) {
        // This object is the parent of everything traversed until popParent() below.
        visitor.pushParent(this);

        // status is necessarily CONTINUE at this point, so only the presence check matters.
        if (hasExtensions()) {
            status = PMMLObject.traverse(visitor, getExtensions());
        }
        if (status == VisitorAction.CONTINUE) {
            status = PMMLObject.traverse(visitor, getMiningSchema(), getOutput(), getModelStats(), getModelExplanation(), getTargets(), getLocalTransformations(), getKernel(), getVectorDictionary());
        }
        if (status == VisitorAction.CONTINUE && hasSupportVectorMachines()) {
            status = PMMLObject.traverse(visitor, getSupportVectorMachines());
        }
        if (status == VisitorAction.CONTINUE) {
            status = PMMLObject.traverse(visitor, getModelVerification());
        }

        visitor.popParent();
    }

    // Only TERMINATE is propagated; every other status is reported to the caller as CONTINUE.
    return (status == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Creates a new {@link SupportVectorMachineModel} using its no-argument constructor.
 *
 * @return the newly created instance
 */
public SupportVectorMachineModel createSupportVectorMachineModel() {
    final SupportVectorMachineModel instance = new SupportVectorMachineModel();
    return instance;
}
/**
 * Creates an evaluator for the given support vector machine model and rejects model
 * configurations that this evaluator does not handle.
 *
 * @param pmml the enclosing PMML document
 * @param supportVectorMachineModel the model to evaluate
 * @throws UnsupportedAttributeException if {@code maxWins} is set, or if the representation is
 *         anything other than {@code SUPPORT_VECTORS}
 * @throws MissingElementException if the vector dictionary, its vector fields, or the support
 *         vector machines are absent
 */
public SupportVectorMachineModelEvaluator(PMML pmml, SupportVectorMachineModel supportVectorMachineModel){
    super(pmml, supportVectorMachineModel);

    // maxWins == true is not supported.
    if(supportVectorMachineModel.isMaxWins()){
        throw new UnsupportedAttributeException(supportVectorMachineModel, PMMLAttributes.SUPPORTVECTORMACHINEMODEL_MAXWINS, true);
    }

    // A switch (rather than an equality test) is used deliberately: it keeps the original
    // behaviour for a null representation.
    SupportVectorMachineModel.Representation representation = supportVectorMachineModel.getRepresentation();
    switch(representation){
        case SUPPORT_VECTORS:
            break;
        default:
            throw new UnsupportedAttributeException(supportVectorMachineModel, representation);
    }

    // Both the vector dictionary and its VectorFields child are required.
    VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
    if(vectorDictionary == null){
        throw new MissingElementException(supportVectorMachineModel, PMMLElements.SUPPORTVECTORMACHINEMODEL_VECTORDICTIONARY);
    }
    if(vectorDictionary.getVectorFields() == null){
        throw new MissingElementException(vectorDictionary, PMMLElements.VECTORDICTIONARY_VECTORFIELDS);
    }

    // At least one SupportVectorMachine element is required.
    if(!supportVectorMachineModel.hasSupportVectorMachines()){
        throw new MissingElementException(supportVectorMachineModel, PMMLElements.SUPPORTVECTORMACHINEMODEL_SUPPORTVECTORMACHINES);
    }
}
@Override public SupportVectorMachineModel encodeModel(Schema schema){ int[] shape = getSupportVectorsShape(); int numberOfVectors = shape[0]; int numberOfFeatures = shape[1]; List<Integer> support = getSupport(); List<? extends Number> supportVectors = getSupportVectors(); List<Integer> supportSizes = getSupportSizes(); List<? extends Number> dualCoef = getDualCoef(); List<? extends Number> intercept = getIntercept(); SupportVectorMachineModel supportVectorMachineModel = LibSVMUtil.createClassification(new CMatrix<>(ValueUtil.asDoubles(supportVectors), numberOfVectors, numberOfFeatures), supportSizes, SupportVectorMachineUtil.formatIds(support), ValueUtil.asDoubles(intercept), ValueUtil.asDoubles(dualCoef), schema) .setKernel(SupportVectorMachineUtil.createKernel(getKernel(), getDegree(), getGamma(), getCoef0())); List<SupportVectorMachine> supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines(); for(SupportVectorMachine supportVectorMachine : supportVectorMachines){ String category = supportVectorMachine.getTargetCategory(); // LibSVM: (decisionFunction > 0 ? first : second) // PMML: (decisionFunction < 0 ? first : second) supportVectorMachine.setTargetCategory(supportVectorMachine.getAlternateTargetCategory()); supportVectorMachine.setAlternateTargetCategory(category); } return supportVectorMachineModel; }
SupportVectorMachineModel supportVectorMachineModel = getModel(); List<SupportVectorMachine> supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines(); String alternateBinaryTargetCategory = supportVectorMachineModel.getAlternateBinaryTargetCategory(); threshold = supportVectorMachineModel.getThreshold();
.setOutput(ModelUtil.createPredictedOutput(FieldName.create("decisionFunction"), OpType.CONTINUOUS, DataType.DOUBLE, outlier)); RDoubleVector yScaledScale = (RDoubleVector)yScale.getValue("scaled:scale"); supportVectorMachineModel.setTargets(ModelUtil.createRescaleTargets(-1d * yScaledScale.asScalar(), yScaledCenter.asScalar(), (ContinuousLabel)schema.getLabel())); supportVectorMachineModel.setKernel(svmKernel.createKernel(degree.asScalar(), gamma.asScalar(), coef0.asScalar()));
/**
 * Appends the given support vector machines, in argument order, to this model's list of
 * support vector machines.
 *
 * @param supportVectorMachines the machines to append
 * @return this model, to allow call chaining
 */
public SupportVectorMachineModel addSupportVectorMachines(SupportVectorMachine... supportVectorMachines) {
    java.util.List<SupportVectorMachine> machines = getSupportVectorMachines();
    machines.addAll(Arrays.asList(supportVectorMachines));
    return this;
}
List<SupportVectorMachine> supportVectorMachines = supportVectorMachineModel.getSupportVectorMachines(); String alternateBinaryTargetCategory = supportVectorMachineModel.getAlternateBinaryTargetCategory(); if(alternateBinaryTargetCategory != null){
private Object createInput(EvaluationContext context){ SupportVectorMachineModel supportVectorMachineModel = getModel(); VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
/**
 * Encodes the model via the superclass and then replaces its output with a continuous
 * {@code decisionFunction} predicted field that carries an outlier transformation.
 *
 * @param schema the schema to encode against
 * @return the value returned by {@code setOutput} on the superclass-encoded model
 */
@Override
public SupportVectorMachineModel encodeModel(Schema schema){
    // The outlier flag is computed as lessOrEqual(<field>, 0). Presumably <field> is the
    // decisionFunction output below; confirm against OutlierTransformation.
    Transformation outlierFlag = new OutlierTransformation(){

        @Override
        public Expression createExpression(FieldRef fieldRef){
            return PMMLUtil.createApply("lessOrEqual", fieldRef, PMMLUtil.createConstant(0d));
        }
    };

    SupportVectorMachineModel model = super.encodeModel(schema);

    return model.setOutput(ModelUtil.createPredictedOutput(FieldName.create("decisionFunction"), OpType.CONTINUOUS, DataType.DOUBLE, outlierFlag));
}
}
Kernel kernel = supportVectorMachineModel.getKernel(); if(kernel == null){ throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(supportVectorMachineModel.getClass()) + "/<Kernel>"), supportVectorMachine);
/**
 * Encodes this fitted regressor as a PMML {@link SupportVectorMachineModel}.
 *
 * <p>The support vectors are wrapped in a {@code CMatrix} whose dimensions come from
 * {@code getSupportVectorsShape()} (number of vectors, number of features). The single
 * intercept value is extracted with {@code Iterables.getOnlyElement}, which throws if the
 * intercept list does not contain exactly one element.
 *
 * @param schema the schema to encode against
 * @return the encoded model, with its kernel set
 */
@Override
public SupportVectorMachineModel encodeModel(Schema schema){
    int[] shape = getSupportVectorsShape();
    int vectorCount = shape[0];
    int featureCount = shape[1];

    List<Integer> support = getSupport();
    List<? extends Number> supportVectors = getSupportVectors();
    List<? extends Number> dualCoef = getDualCoef();
    List<? extends Number> intercept = getIntercept();

    SupportVectorMachineModel model = LibSVMUtil.createRegression(
            new CMatrix<>(ValueUtil.asDoubles(supportVectors), vectorCount, featureCount),
            SupportVectorMachineUtil.formatIds(support),
            ValueUtil.asDouble(Iterables.getOnlyElement(intercept)),
            ValueUtil.asDoubles(dualCoef),
            schema
        );

    return model.setKernel(SupportVectorMachineUtil.createKernel(getKernel(), getDegree(), getGamma(), getCoef0()));
}
/**
 * Appends the given extensions, in argument order, to this model's extension list.
 *
 * @param extensions the extensions to append
 * @return this model, to allow call chaining
 */
@Override
public SupportVectorMachineModel addExtensions(org.dmg.pmml.Extension... extensions) {
    java.util.List<org.dmg.pmml.Extension> target = getExtensions();
    target.addAll(Arrays.asList(extensions));
    return this;
}
/**
 * Appends the given support vector machines, in argument order, to this model's list of
 * support vector machines.
 *
 * @param supportVectorMachines the machines to append
 * @return this model, to allow call chaining
 */
public SupportVectorMachineModel addSupportVectorMachines(SupportVectorMachine... supportVectorMachines) {
    java.util.List<SupportVectorMachine> machines = getSupportVectorMachines();
    machines.addAll(Arrays.asList(supportVectorMachines));
    return this;
}
static private Map<String, Object> parseVectorDictionary(SupportVectorMachineModel supportVectorMachineModel){ VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
/**
 * Appends the given extensions, in argument order, to this model's extension list.
 *
 * @param extensions the extensions to append
 * @return this model, to allow call chaining
 */
@Override
public SupportVectorMachineModel addExtensions(org.dmg.pmml.Extension... extensions) {
    java.util.List<org.dmg.pmml.Extension> target = getExtensions();
    target.addAll(Arrays.asList(extensions));
    return this;
}
/**
 * Visits this model and then, unless the visitor stops it, its child elements.
 *
 * <p>Children are traversed in this order:
 * <ol>
 * <li>extensions;</li>
 * <li>mining schema, output, model stats, model explanation, targets, local transformations,
 * kernel and vector dictionary;</li>
 * <li>the support vector machines;</li>
 * <li>the model verification.</li>
 * </ol>
 * As soon as any step yields a status other than {@link VisitorAction#CONTINUE}, the remaining
 * children are not traversed.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if the traversal was terminated, otherwise
 *         {@link VisitorAction#CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction status = visitor.visit(this);

    if (status == VisitorAction.CONTINUE) {
        // This object is the parent of everything traversed until popParent() below.
        visitor.pushParent(this);

        // status is necessarily CONTINUE at this point, so only the presence check matters.
        if (hasExtensions()) {
            status = PMMLObject.traverse(visitor, getExtensions());
        }
        if (status == VisitorAction.CONTINUE) {
            status = PMMLObject.traverse(visitor, getMiningSchema(), getOutput(), getModelStats(), getModelExplanation(), getTargets(), getLocalTransformations(), getKernel(), getVectorDictionary());
        }
        if (status == VisitorAction.CONTINUE && hasSupportVectorMachines()) {
            status = PMMLObject.traverse(visitor, getSupportVectorMachines());
        }
        if (status == VisitorAction.CONTINUE) {
            status = PMMLObject.traverse(visitor, getModelVerification());
        }

        visitor.popParent();
    }

    // Only TERMINATE is propagated; every other status is reported to the caller as CONTINUE.
    return (status == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Evaluates the model as a regression by applying its single support vector machine to the
 * input built from the evaluation context.
 *
 * @param valueFactory factory for the numeric value type used during evaluation
 * @param context the evaluation context supplying field values
 * @return the regression result for the model's target field
 * @throws InvalidElementListException if the model does not contain exactly one support
 *         vector machine
 */
@Override
protected <V extends Number> Map<FieldName, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext context){
    SupportVectorMachineModel model = getModel();

    List<SupportVectorMachine> machines = model.getSupportVectorMachines();
    if(machines.size() != 1){
        throw new InvalidElementListException(machines);
    }

    // The arguments are evaluated left to right, so the machine is fetched before the input is
    // built, matching the original statement order.
    Value<V> result = evaluateSupportVectorMachine(valueFactory, machines.get(0), createInput(context));

    return TargetUtil.evaluateRegression(getTargetField(), result);
}
/**
 * Creates a new {@link SupportVectorMachineModel} using its no-argument constructor.
 *
 * @return the newly created instance
 */
public SupportVectorMachineModel createSupportVectorMachineModel() {
    final SupportVectorMachineModel instance = new SupportVectorMachineModel();
    return instance;
}