/**
 * Builds a {@link LogisticRegressionModelInfo} populated from the given Spark
 * logistic regression model.
 *
 * <p>The returned object carries the model's coefficients (via {@code toArray()}),
 * intercept, number of classes, number of features and threshold. Its input keys
 * hold the model's features column; its output keys hold the prediction column
 * followed by the probability column.
 *
 * @param sparkLRModel the Spark model whose parameters and column names are read
 * @return a newly created info object filled in from {@code sparkLRModel}
 */
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel) {
    final LogisticRegressionModelInfo info = new LogisticRegressionModelInfo();

    // Numeric model parameters.
    info.setWeights(sparkLRModel.coefficients().toArray());
    info.setIntercept(sparkLRModel.intercept());
    info.setNumClasses(sparkLRModel.numClasses());
    info.setNumFeatures(sparkLRModel.numFeatures());
    info.setThreshold(sparkLRModel.getThreshold());

    // The single input column the model reads its feature vector from.
    final String featuresColumn = sparkLRModel.getFeaturesCol();
    final Set<String> consumedColumns = new LinkedHashSet<>();
    consumedColumns.add(featuresColumn);
    info.setInputKeys(consumedColumns);

    // LinkedHashSet keeps insertion order: prediction first, then probability.
    final String predictionColumn = sparkLRModel.getPredictionCol();
    final String probabilityColumn = sparkLRModel.getProbabilityCol();
    final Set<String> producedColumns = new LinkedHashSet<>();
    producedColumns.add(predictionColumn);
    producedColumns.add(probabilityColumn);
    info.setOutputKeys(producedColumns);

    return info;
}
/**
 * Builds a {@link LogisticRegressionModelInfo} populated from the given Spark
 * logistic regression model.
 *
 * <p>The returned object carries the model's coefficients (via {@code toArray()}),
 * intercept, number of classes, number of features, threshold and probability key
 * (the model's probability column). Its input keys hold the model's features
 * column; its output keys hold the prediction column followed by the probability
 * column.
 *
 * @param sparkLRModel the Spark model whose parameters and column names are read
 * @param df           accepted as part of the overridden signature; not read by
 *                     this implementation
 * @return a newly created info object filled in from {@code sparkLRModel}
 */
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel, DataFrame df) {
    final LogisticRegressionModelInfo info = new LogisticRegressionModelInfo();

    // Numeric model parameters.
    info.setWeights(sparkLRModel.coefficients().toArray());
    info.setIntercept(sparkLRModel.intercept());
    info.setNumClasses(sparkLRModel.numClasses());
    info.setNumFeatures(sparkLRModel.numFeatures());
    info.setThreshold(sparkLRModel.getThreshold());

    // The probability column is recorded on its own as well as among the output keys below.
    info.setProbabilityKey(sparkLRModel.getProbabilityCol());

    // The single input column the model reads its feature vector from.
    final String featuresColumn = sparkLRModel.getFeaturesCol();
    final Set<String> consumedColumns = new LinkedHashSet<>();
    consumedColumns.add(featuresColumn);
    info.setInputKeys(consumedColumns);

    // LinkedHashSet keeps insertion order: prediction first, then probability.
    final String predictionColumn = sparkLRModel.getPredictionCol();
    final String probabilityColumn = sparkLRModel.getProbabilityCol();
    final Set<String> producedColumns = new LinkedHashSet<>();
    producedColumns.add(predictionColumn);
    producedColumns.add(probabilityColumn);
    info.setOutputKeys(producedColumns);

    return info;
}
// Copy the model's coefficients into a fresh ArrayList so this list is independent of
// whatever VectorUtil.toList returns (presumably a List<Double> view/conversion of the
// coefficient vector — TODO confirm against VectorUtil).
List<Double> coefficients = new ArrayList<>(VectorUtil.toList(model.coefficients()));
+ lrModel.coefficients() + " Intercept: " + lrModel.intercept());