/**
 * Computes the macro-average (one-vs-all) AUC: the unweighted mean of the
 * per-class AUC values returned by {@code calculateAUC(int)}.
 *
 * <p>{@code assertIndex(0)} is invoked before any class is evaluated.
 *
 * @return the sum of the per-class AUC values divided by the number of classes
 */
public double calculateAverageAUC() {
    assertIndex(0);

    final int numClasses = underlying.length;
    double total = 0.0;
    int classIdx = 0;
    // Accumulate in ascending class order.
    while (classIdx < numClasses) {
        total += calculateAUC(classIdx);
        classIdx++;
    }
    return total / numClasses;
}
// Header element: the header text wrapped in a styled div.
Component headerDiv = new ComponentDiv(HEADER_DIV_STYLE, new ComponentText(headerText, HEADER_TEXT_STYLE));
// ROC component for index i (presumably the class index - confirm against the enclosing loop),
// built from the ROC points plus that index's actual positive/negative counts, AUC and AUCPR.
Component c = getRocFromPoints(ROC_TITLE, roc, rocMultiClass.getCountActualPositive(i), rocMultiClass.getCountActualNegative(i), rocMultiClass.calculateAUC(i), rocMultiClass.calculateAUCPR(i));
// Precision-recall chart component(s) for the same index, built from its precision-recall curve.
Component c2 = getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, rocMultiClass.getPrecisionRecallCurve(i));
public String stats(int printPrecision) { StringBuilder sb = new StringBuilder(); int maxLabelsLength = 15; if (labels != null) { for (String s : labels) { maxLabelsLength = Math.max(s.length(), maxLabelsLength); } } String patternHeader = "%-" + (maxLabelsLength + 5) + "s%-12s%-10s%-10s"; String header = String.format(patternHeader, "Label", "AUC", "# Pos", "# Neg"); String pattern = "%-" + (maxLabelsLength + 5) + "s" //Label + "%-12." + printPrecision + "f" //AUC + "%-10d%-10d"; //Count pos, count neg sb.append(header); if (underlying != null) { for (int i = 0; i < underlying.length; i++) { double auc = calculateAUC(i); String label = (labels == null ? String.valueOf(i) : labels.get(i)); sb.append("\n").append(String.format(pattern, label, auc, getCountActualPositive(i), getCountActualNegative(i))); } sb.append("Average AUC: ").append(String.format("%-12." + printPrecision + "f", calculateAverageAUC())); } else { //Empty evaluation sb.append("\n-- No Data --\n"); } return sb.toString(); }