/**
 * Creates an instance of {@link GeneralRegressionModel}.
 *
 * @return a newly constructed {@link GeneralRegressionModel} (no-arg constructor)
 */
public GeneralRegressionModel createGeneralRegressionModel() {
    GeneralRegressionModel instance = new GeneralRegressionModel();
    return instance;
}
/**
 * Accepts a visitor. The visitor first visits this model. If it answers CONTINUE, this model is
 * pushed as the parent while the extensions (if any) and then the child elements are traversed,
 * and it is popped again afterwards.
 *
 * @param visitor the visitor to accept
 * @return {@link VisitorAction#TERMINATE} if any step terminated the traversal,
 *         otherwise {@link VisitorAction#CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction status = visitor.visit(this);

    if (status == VisitorAction.CONTINUE) {
        visitor.pushParent(this);

        // status is known to be CONTINUE at this point, so only the extensions flag needs checking
        if (hasExtensions()) {
            status = PMMLObject.traverse(visitor, getExtensions());
        }

        if (status == VisitorAction.CONTINUE) {
            status = PMMLObject.traverse(visitor,
                getMiningSchema(), getOutput(), getModelStats(), getModelExplanation(), getTargets(),
                getLocalTransformations(), getParameterList(), getFactorList(), getCovariateList(),
                getPPMatrix(), getPCovMatrix(), getParamMatrix(), getEventValues(),
                getBaseCumHazardTables(), getModelVerification());
        }

        visitor.popParent();
    }

    return (status == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
@Override public VisitorAction visit(GeneralRegressionModel generalRegressionModel){ GeneralRegressionModel.ModelType modelType = generalRegressionModel.getModelType(); switch(modelType){ case COX_REGRESSION: process(generalRegressionModel.getBaselineStrataVariable()); process(generalRegressionModel.getEndTimeVariable()); process(generalRegressionModel.getStartTimeVariable()); process(generalRegressionModel.getStatusVariable()); process(generalRegressionModel.getSubjectIDVariable()); // Falls through default: process(generalRegressionModel.getOffsetVariable()); process(generalRegressionModel.getTrialsVariable()); break; } return super.visit(generalRegressionModel); }
/**
 * Replaces the model's BaseCumHazardTables element (when present) with a
 * RichBaseCumHazardTables wrapper built from it, then delegates to the superclass.
 */
@Override
public VisitorAction visit(GeneralRegressionModel generalRegressionModel){
    BaseCumHazardTables tables = generalRegressionModel.getBaseCumHazardTables();

    if(tables == null){
        return super.visit(generalRegressionModel);
    }

    generalRegressionModel.setBaseCumHazardTables(new RichBaseCumHazardTables(tables));

    return super.visit(generalRegressionModel);
}
/**
 * Resolves the offset for a model evaluation.
 *
 * @param generalRegressionModel the model whose offset attributes are consulted
 * @param context the evaluation context used to look up the offset variable
 * @return the offset variable's value as a double when an offset variable is declared,
 *         otherwise the model's static offset value (which may be null)
 */
static private Double getOffset(GeneralRegressionModel generalRegressionModel, EvaluationContext context){
    FieldName offsetVariable = generalRegressionModel.getOffsetVariable();

    if(offsetVariable == null){
        return generalRegressionModel.getOffsetValue();
    }

    return getVariable(offsetVariable, context).asDouble();
}
/**
 * Encodes the transformer's generalized linear regression model as a PMML GeneralRegressionModel
 * of type GENERALIZED_LINEAR.
 *
 * <p>For classification, the label must be categorical with exactly two values. The second value
 * (index 1) is used as the target category of the regression table.
 *
 * @param schema the schema providing the label and features
 * @return the encoded model
 * @throws IllegalArgumentException if a classification label does not have exactly two values
 */
@Override
public GeneralRegressionModel encodeModel(Schema schema){
    GeneralizedLinearRegressionModel model = getTransformer();

    String targetCategory = null;

    MiningFunction miningFunction = getMiningFunction();
    switch(miningFunction){
        case CLASSIFICATION:
            CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

            // Only binary classification can be expressed with a single target category
            if(categoricalLabel.size() != 2){
                throw new IllegalArgumentException("Expected a categorical label with exactly 2 values, got " + categoricalLabel.size());
            }

            targetCategory = categoricalLabel.getValue(1);
            break;
        default:
            break;
    }

    // Mutable copies: both lists are handed to RegressionTableUtil.simplify(), which presumably edits them in place
    List<Feature> features = new ArrayList<>(schema.getFeatures());
    List<Double> coefficients = new ArrayList<>(VectorUtil.toList(model.coefficients()));

    RegressionTableUtil.simplify(this, targetCategory, features, coefficients);

    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), null, null, null)
        .setDistribution(parseFamily(model.getFamily()))
        .setLinkFunction(parseLinkFunction(model.getLink()))
        .setLinkParameter(parseLinkParameter(model.getLink()));

    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, coefficients, model.intercept(), targetCategory);

    return generalRegressionModel;
}
/**
 * Creates an evaluator for the given GeneralRegressionModel.
 *
 * @param pmml the enclosing PMML document
 * @param generalRegressionModel the model to evaluate
 * @throws MissingAttributeException if the modelType attribute is absent
 * @throws MissingElementException if the ParameterList, PPMatrix or ParamMatrix element is absent
 */
public GeneralRegressionModelEvaluator(PMML pmml, GeneralRegressionModel generalRegressionModel){
    super(pmml, generalRegressionModel);

    // Required attribute
    if(generalRegressionModel.getModelType() == null){
        throw new MissingAttributeException(generalRegressionModel, PMMLAttributes.GENERALREGRESSIONMODEL_MODELTYPE);
    }

    // Required child elements, checked in this order
    if(generalRegressionModel.getParameterList() == null){
        throw new MissingElementException(generalRegressionModel, PMMLElements.GENERALREGRESSIONMODEL_PARAMETERLIST);
    }

    if(generalRegressionModel.getPPMatrix() == null){
        throw new MissingElementException(generalRegressionModel, PMMLElements.GENERALREGRESSIONMODEL_PPMATRIX);
    }

    if(generalRegressionModel.getParamMatrix() == null){
        throw new MissingElementException(generalRegressionModel, PMMLElements.GENERALREGRESSIONMODEL_PARAMMATRIX);
    }
}
private <V extends Number> Map<FieldName, ?> evaluateCoxRegression(ValueFactory<V> valueFactory, EvaluationContext context){ GeneralRegressionModel generalRegressionModel = getModel(); FieldName startTimeVariable = generalRegressionModel.getStartTimeVariable(); FieldName endTimeVariable = generalRegressionModel.getEndTimeVariable(); if(endTimeVariable == null){ throw new MissingAttributeException(generalRegressionModel, PMMLAttributes.GENERALREGRESSIONMODEL_ENDTIMEVARIABLE); BaseCumHazardTables baseCumHazardTables = generalRegressionModel.getBaseCumHazardTables(); if(baseCumHazardTables == null){ throw new MissingElementException(generalRegressionModel, PMMLElements.GENERALREGRESSIONMODEL_BASECUMHAZARDTABLES); FieldName baselineStrataVariable = generalRegressionModel.getBaselineStrataVariable();
/**
 * Encodes one column of the fitted coefficients as a GENERAL_LINEAR regression model
 * with a NORMAL distribution.
 *
 * @param a0 intercepts, one value per column
 * @param beta coefficient matrix (cast to S4Object before the coefficients are extracted)
 * @param column the column to encode
 * @param schema the schema providing the label and features
 * @return the encoded GeneralRegressionModel
 */
@Override
public Model encodeModel(RDoubleVector a0, RExp beta, int column, Schema schema){
    Double columnIntercept = a0.getValue(column);
    List<Double> columnCoefficients = getCoefficients((S4Object)beta, column);

    GeneralRegressionModel result = new GeneralRegressionModel(
            GeneralRegressionModel.ModelType.GENERAL_LINEAR,
            MiningFunction.REGRESSION,
            ModelUtil.createMiningSchema(schema.getLabel()),
            null, null, null)
        .setDistribution(GeneralRegressionModel.Distribution.NORMAL);

    GeneralRegressionModelUtil.encodeRegressionTable(result, schema.getFeatures(), columnCoefficients, columnIntercept, null);

    return result;
}
// Closes the enclosing declaration that begins before this method
}
/**
 * Encodes the "earth" object's linear coefficients as a GENERALIZED_LINEAR regression model
 * with an IDENTITY link function.
 *
 * <p>The "coefficients" vector is read as the intercept at index 0, followed by one coefficient
 * per schema feature.
 *
 * @param schema the schema providing the label and features
 * @return the encoded model
 * @throws IllegalArgumentException if the number of coefficients is not the number of features plus one
 */
@Override
public GeneralRegressionModel encodeModel(Schema schema){
    RGenericVector earth = getObject();

    RDoubleVector coefficients = (RDoubleVector)earth.getValue("coefficients");

    List<? extends Feature> features = schema.getFeatures();

    // Validate before indexing, so a malformed vector is reported as such rather than as an indexing error
    int expectedSize = features.size() + 1;
    if(coefficients.size() != expectedSize){
        throw new IllegalArgumentException("Expected " + expectedSize + " coefficients (intercept plus one per feature), got " + coefficients.size());
    }

    Double intercept = coefficients.getValue(0);

    List<Double> featureCoefficients = (coefficients.getValues()).subList(1, expectedSize);

    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), null, null, null)
        .setLinkFunction(GeneralRegressionModel.LinkFunction.IDENTITY);

    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, featureCoefficients, intercept, null);

    return generalRegressionModel;
}
GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), null, null, null) .setLinkFunction(GeneralRegressionModel.LinkFunction.LOGIT) .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
GeneralRegressionModel.ModelType modelType = generalRegressionModel.getModelType(); PPMatrix ppMatrix = generalRegressionModel.getPPMatrix(); ParamMatrix paramMatrix = generalRegressionModel.getParamMatrix(); ParamMatrix paramMatrix = generalRegressionModel.getParamMatrix(); ParamMatrix paramMatrix = generalRegressionModel.getParamMatrix();
/**
 * Checks the PMML version range reported for a GeneralRegressionModel in three states:
 * with only Targets set, after a PPMatrix is added, and after the Target's field is cleared.
 */
@Test
public void inspectValueAnnotations(){
    PMML pmml = createPMML();

    FieldName targetField = FieldName.create("y");

    Target target = new Target()
        .setField(targetField)
        .addTargetValues(createTargetValue("no event"), createTargetValue("event"));

    GeneralRegressionModel model = new GeneralRegressionModel()
        .setTargets(new Targets().addTargets(target));

    pmml.addModels(model);

    // Stage 1: Targets only
    assertVersionRange(pmml, Version.PMML_3_0, Version.PMML_3_0);

    // Stage 2: add a PPMatrix holding two PPCells
    model.setPPMatrix(new PPMatrix().addPPCells(new PPCell(), new PPCell()));

    assertVersionRange(pmml, Version.PMML_3_0, Version.PMML_4_3);

    // Stage 3: clear the Target's field reference
    target.setField(null);

    assertVersionRange(pmml, Version.PMML_4_3, Version.PMML_4_3);
}
String targetReferenceCategory = generalRegressionModel.getTargetReferenceCategory(); GeneralRegressionModel.ModelType modelType = generalRegressionModel.getModelType(); switch(modelType){ case GENERALIZED_LINEAR: ParamMatrix paramMatrix = generalRegressionModel.getParamMatrix();
/**
 * Computes the linear predictor for a model whose parameter matrix has a single,
 * target-category-independent entry (keyed by null).
 *
 * @param valueFactory factory for the numeric value type
 * @param context the evaluation context
 * @return the dot product of the parameter cells with the predictor rows
 * @throws InvalidElementException if the PPMatrix is non-empty but has no null-keyed entry,
 *         or if the ParamMatrix does not consist of exactly one null-keyed entry
 */
private <V extends Number> Value<V> computeDotProduct(ValueFactory<V> valueFactory, EvaluationContext context){
    GeneralRegressionModel generalRegressionModel = getModel();

    Map<String, Map<String, Row>> ppMatrixMap = getPPMatrixMap();

    // An empty PPMatrix map means there are no parameter-predictor rows at all
    Map<String, Row> rows = ppMatrixMap.isEmpty() ? Collections.<String, Row>emptyMap() : ppMatrixMap.get(null);
    if(rows == null){
        throw new InvalidElementException(generalRegressionModel.getPPMatrix());
    }

    Map<String, List<PCell>> paramMatrixMap = getParamMatrixMap();

    List<PCell> cells = paramMatrixMap.get(null);
    boolean singleNullKeyedEntry = (paramMatrixMap.size() == 1) && (cells != null);
    if(!singleNullKeyedEntry){
        throw new InvalidElementException(generalRegressionModel.getParamMatrix());
    }

    return computeDotProduct(valueFactory, cells, rows, context);
}
/**
 * Returns a short human-readable description of the model kind.
 *
 * @return "Cox regression" for COX_REGRESSION models, otherwise "General regression"
 */
@Override
public String getSummary(){
    GeneralRegressionModel.ModelType modelType = getModel().getModelType();

    // A null model type fails fast with a NullPointerException
    if(modelType.equals(GeneralRegressionModel.ModelType.COX_REGRESSION)){
        return "Cox regression";
    }

    return "General regression";
}
/**
 * Groups the model's ParamMatrix cells by target category.
 *
 * @param generalRegressionModel the model whose ParamMatrix is read
 * @return a map from target category to the PCells belonging to it
 */
static private Map<String, List<PCell>> parseParamMatrix(GeneralRegressionModel generalRegressionModel){
    ParamMatrix matrix = generalRegressionModel.getParamMatrix();

    return asMap(groupByTargetCategory(matrix.getPCells()));
}
// Loader callback (presumably a Guava CacheLoader#load override; confirm against the enclosing declaration).
// Builds an immutable FieldName -> Predictor bimap from the model's FactorList via parsePredictorRegistry.
// The trailing "});" closes the anonymous class/expression that begins before this line.
@Override public BiMap<FieldName, Predictor> load(GeneralRegressionModel generalRegressionModel){ return ImmutableBiMap.copyOf(parsePredictorRegistry(generalRegressionModel.getFactorList())); } });
// Loader callback (presumably a Guava CacheLoader#load override; confirm against the enclosing declaration).
// Builds an immutable FieldName -> Predictor bimap from the model's CovariateList via parsePredictorRegistry.
// The trailing "});" closes the anonymous class/expression that begins before this line.
@Override public BiMap<FieldName, Predictor> load(GeneralRegressionModel generalRegressionModel){ return ImmutableBiMap.copyOf(parsePredictorRegistry(generalRegressionModel.getCovariateList())); } });
PPMatrix ppMatrix = generalRegressionModel.getPPMatrix();