/**
 * Computes the area under the precision/recall curve (AUPRC, also written AUCPR).
 * <p>
 * The value is computed at most once: the first call derives it from
 * {@link #getPrecisionRecallCurve()} and stores it in the {@code auprc} field;
 * later calls return the stored value.
 *
 * @return the area under the precision-recall curve
 */
public double calculateAUCPR() {
    // Only compute when no cached value is present.
    if (auprc == null) {
        auprc = getPrecisionRecallCurve().calculateAUPRC();
    }
    return auprc;
}
/**
 * Returns the (one vs. all) precision-recall curve for a single class.
 * <p>
 * The index is passed to {@code assertIndex} before the per-class evaluator in
 * {@code underlying} is consulted.
 *
 * @param classIdx index of the class whose P-R curve is requested
 * @return the precision-recall curve for the given class
 */
public PrecisionRecallCurve getPrecisionRecallCurve(int classIdx) {
    assertIndex(classIdx);
    final PrecisionRecallCurve curveForClass = underlying[classIdx].getPrecisionRecallCurve();
    return curveForClass;
}
/**
 * Returns the precision-recall curve for a single output.
 * <p>
 * The output number is passed to {@code assertIndex} before the per-output
 * evaluator in {@code underlying} is consulted.
 *
 * @param outputNum number of the output whose P-R curve is requested
 * @return the precision-recall curve for the given output
 */
public PrecisionRecallCurve getPrecisionRecallCurve(int outputNum) {
    assertIndex(outputNum);
    final PrecisionRecallCurve curveForOutput = underlying[outputNum].getPrecisionRecallCurve();
    return curveForOutput;
}
/**
 * Renders the ROC chart and the precision vs. recall charts of a {@link ROC}
 * instance into a single stand-alone HTML page.
 *
 * @param roc ROC to render
 * @return the rendered HTML page, as a String
 */
public static String rocChartToHtml(ROC roc) {
    // ROC chart: curve plus the positive/negative counts and both area metrics.
    Component rocChart = getRocFromPoints(
            ROC_TITLE,
            roc.getRocCurve(),
            roc.getCountActualPositive(),
            roc.getCountActualNegative(),
            roc.calculateAUC(),
            roc.calculateAUCPR());

    // Precision/recall charts built from the P-R curve.
    Component prCharts = getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, roc.getPrecisionRecallCurve());

    return StaticPageUtil.renderHTML(rocChart, prCharts);
}
@Override public void serialize(ROC roc, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException { if (roc.isExact()) { //For exact ROC implementation: force AUC and AUPRC calculation, so result can be stored in JSON, such //that we have them once deserialized. //Due to potentially huge size, exact mode doesn't store the original predictions in JSON roc.calculateAUC(); roc.calculateAUCPR(); } jsonGenerator.writeNumberField("thresholdSteps", roc.getThresholdSteps()); jsonGenerator.writeNumberField("countActualPositive", roc.getCountActualPositive()); jsonGenerator.writeNumberField("countActualNegative", roc.getCountActualNegative()); jsonGenerator.writeObjectField("counts", roc.getCounts()); jsonGenerator.writeNumberField("auc", roc.calculateAUC()); jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR()); if (roc.isExact()) { //Store ROC and PR curves only for exact mode... they are redundant + can be calculated again for thresholded mode jsonGenerator.writeObjectField("rocCurve", roc.getRocCurve()); jsonGenerator.writeObjectField("prCurve", roc.getPrecisionRecallCurve()); } jsonGenerator.writeBooleanField("isExact", roc.isExact()); jsonGenerator.writeNumberField("exampleCount", roc.getExampleCount()); jsonGenerator.writeBooleanField("rocRemoveRedundantPts", roc.isRocRemoveRedundantPts()); }