/**
 * Renders the ROC chart together with the precision-vs-recall charts for a {@link ROC} instance
 * into a stand-alone HTML page.
 *
 * @param roc ROC evaluation to render
 * @return the stand-alone HTML document, as a String
 */
public static String rocChartToHtml(ROC roc) {
    Component rocComponent = getRocFromPoints(
            ROC_TITLE,
            roc.getRocCurve(),
            roc.getCountActualPositive(),
            roc.getCountActualNegative(),
            roc.calculateAUC(),
            roc.calculateAUCPR());
    Component precisionRecallComponent =
            getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, roc.getPrecisionRecallCurve());
    return StaticPageUtil.renderHTML(rocComponent, precisionRecallComponent);
}
/**
 * Evaluates the network on the supplied data with a single {@link ROC} evaluation.
 * The network must be a binary classifier.
 *
 * @param iterator          data to evaluate on
 * @param rocThresholdSteps number of threshold steps handed to the {@link ROC} constructor
 * @return the ROC evaluation populated from the given dataset
 */
public ROC evaluateROC(MultiDataSetIterator iterator, int rocThresholdSteps) {
    ROC rocEvaluation = new ROC(rocThresholdSteps);
    ROC[] completed = doEvaluation(iterator, rocEvaluation);
    return completed[0];
}
/**
 * Writes the fields of a {@link ROC} instance to JSON.
 * <p>
 * Only individual fields are written here; this method makes no writeStartObject/writeEndObject calls.
 * NOTE(review): presumably the enclosing JSON object is opened/closed by the caller - confirm.
 * <p>
 * Fields are written in this order: thresholdSteps, countActualPositive, countActualNegative, counts,
 * auc, auprc, then rocCurve and prCurve (exact mode only), then isExact, exampleCount, rocRemoveRedundantPts.
 *
 * @param roc                ROC instance to serialize
 * @param jsonGenerator      generator the fields are written to
 * @param serializerProvider Jackson serializer provider (not used in this method)
 * @throws IOException propagated from the JsonGenerator write calls
 */
@Override
public void serialize(ROC roc, JsonGenerator jsonGenerator, SerializerProvider serializerProvider)
        throws IOException {
    if (roc.isExact()) {
        //For exact ROC implementation: force AUC and AUPRC calculation, so result can be stored in JSON, such
        //that we have them once deserialized.
        //Due to potentially huge size, exact mode doesn't store the original predictions in JSON
        // NOTE(review): calculateAUC()/calculateAUCPR() are invoked again below when writing "auc"/"auprc",
        // so this pre-calculation looks redundant unless it has other side effects - confirm before removing.
        roc.calculateAUC();
        roc.calculateAUCPR();
    }
    jsonGenerator.writeNumberField("thresholdSteps", roc.getThresholdSteps());
    jsonGenerator.writeNumberField("countActualPositive", roc.getCountActualPositive());
    jsonGenerator.writeNumberField("countActualNegative", roc.getCountActualNegative());
    jsonGenerator.writeObjectField("counts", roc.getCounts());
    // AUC / AUPRC are written in both modes
    jsonGenerator.writeNumberField("auc", roc.calculateAUC());
    jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR());
    if (roc.isExact()) {
        //Store ROC and PR curves only for exact mode... they are redundant + can be calculated again for thresholded mode
        jsonGenerator.writeObjectField("rocCurve", roc.getRocCurve());
        jsonGenerator.writeObjectField("prCurve", roc.getPrecisionRecallCurve());
    }
    jsonGenerator.writeBooleanField("isExact", roc.isExact());
    jsonGenerator.writeNumberField("exampleCount", roc.getExampleCount());
    jsonGenerator.writeBooleanField("rocRemoveRedundantPts", roc.isRocRemoveRedundantPts());
}
// Build a ROC evaluation with 1000 threshold steps, feed it the labels/predictions, and print the resulting AUC.
// NOTE(review): "lables" looks like a misspelling of "labels"; it refers to a variable declared outside this view.
ROC roc = new ROC(1000);
roc.eval(lables, predicted);
System.out.println(roc.calculateAUC());
/**
 * Returns a short summary string containing the AUC, formatted as {@code AUC: [value]}.
 *
 * @return summary string with the value from {@code calculateAUC()}
 */
@Override
public String stats() {
    double areaUnderCurve = calculateAUC();
    return new StringBuilder("AUC: [").append(areaUnderCurve).append(']').toString();
}
/**
 * Returns the one-vs-all ROC curve for a single class.
 *
 * @param classIdx index of the class whose ROC curve is requested; checked via {@code assertIndex}
 * @return the ROC curve of the per-class {@code ROC} instance at {@code classIdx}
 */
public RocCurve getRocCurve(int classIdx) {
    this.assertIndex(classIdx);
    return this.underlying[classIdx].getRocCurve();
}
/**
 * Calculates the area under the precision/recall curve (AUPRC).
 * The value is computed once from {@link #getPrecisionRecallCurve()} and cached in {@code auprc};
 * later calls return the cached value.
 *
 * @return area under the precision/recall curve
 */
public double calculateAUCPR() {
    if (auprc == null) {
        auprc = getPrecisionRecallCurve().calculateAUPRC();
    }
    return auprc;
}
/**
 * Returns the number of actual positives (accounting for any masking) for one output/column.
 *
 * @param outputNum index of the output (0 to {@link #numLabels()}-1); checked via {@code assertIndex}
 * @return actual positive count reported by the underlying {@code ROC} for that output
 */
public long getCountActualPositive(int outputNum) {
    this.assertIndex(outputNum);
    return this.underlying[outputNum].getCountActualPositive();
}
/**
 * Returns the number of actual negatives (accounting for any masking) for one output/class.
 *
 * @param outputNum index of the output/class to query; checked via {@code assertIndex}
 * @return actual negative count reported by the underlying {@code ROC} at {@code outputNum}
 */
public long getCountActualNegative(int outputNum) {
    this.assertIndex(outputNum);
    return this.underlying[outputNum].getCountActualNegative();
}
/**
 * Lazily obtains the AUPRC, caching it in the {@code auprc} field after the first call to
 * {@code calculateAUCPR()}.
 *
 * @return the (possibly cached) area under the precision/recall curve
 */
private double getAuprc() {
    if (auprc == null) {
        auprc = calculateAUCPR();
    }
    return auprc;
}
if (labels.rank() == 3 && predictions.rank() == 3) { evalTimeSeries(labels, predictions);
/**
 * Lazily obtains the AUC, caching it in the {@code auc} field after the first call to
 * {@code calculateAUC()}.
 *
 * @return the (possibly cached) area under the ROC curve
 */
private double getAuc() {
    if (auc == null) {
        auc = calculateAUC();
    }
    return auc;
}
/**
 * Returns the ROC curve for a single output.
 *
 * @param outputNum index of the output whose ROC curve is requested; checked via {@code assertIndex}
 * @return ROC curve of the underlying {@code ROC} for that output
 */
public RocCurve getRocCurve(int outputNum) {
    this.assertIndex(outputNum);
    return this.underlying[outputNum].getRocCurve();
}
/**
 * Returns the one-vs-all precision/recall curve for a single class.
 *
 * @param classIdx index of the class whose P-R curve is requested; checked via {@code assertIndex}
 * @return precision/recall curve of the per-class {@code ROC} instance at {@code classIdx}
 */
public PrecisionRecallCurve getPrecisionRecallCurve(int classIdx) {
    this.assertIndex(classIdx);
    return this.underlying[classIdx].getPrecisionRecallCurve();
}
/**
 * Returns the number of actual positives (accounting for any masking) for one class.
 *
 * @param outputNum index of the class to query; checked via {@code assertIndex}
 * @return actual positive count reported by the per-class {@code ROC} at {@code outputNum}
 */
public long getCountActualPositive(int outputNum) {
    this.assertIndex(outputNum);
    return this.underlying[outputNum].getCountActualPositive();
}
/**
 * Returns the number of actual negatives (accounting for any masking) for one output/column.
 *
 * @param outputNum index of the output (0 to {@link #numLabels()}-1); checked via {@code assertIndex}
 * @return actual negative count reported by the underlying {@code ROC} for that output
 */
public long getCountActualNegative(int outputNum) {
    this.assertIndex(outputNum);
    return this.underlying[outputNum].getCountActualNegative();
}
/**
 * Calculates the AUPRC (area under the precision/recall curve) for a single class by delegating to
 * that class's underlying {@code ROC}.
 *
 * @param classIdx index of the class; checked via {@code assertIndex}
 * @return AUPRC for the given class
 */
public double calculateAUCPR(int classIdx) {
    this.assertIndex(classIdx);
    return this.underlying[classIdx].calculateAUCPR();
}
/**
 * Calculates the AUC (area under the ROC curve) for a single class by delegating to that class's
 * underlying {@code ROC}.
 *
 * @param classIdx index of the class; checked via {@code assertIndex}
 * @return AUC for the given class
 */
public double calculateAUC(int classIdx) {
    this.assertIndex(classIdx);
    return this.underlying[classIdx].calculateAUC();
}