IEvaluation[] evals = net.doEvaluation(test, new Evaluation(), new ROCMultiClass()); long endEval = System.currentTimeMillis(); .append(" trainMS ").append(end - start).append(" evalMS ").append(endEval - startEval) .append(" accuracy ").append(e.accuracy()).append(" f1 ").append(e.f1()) .append(" AvgAUC ").append(r.calculateAverageAUC()).append(" AvgAUPRC ").append(r.calculateAverageAUCPR()).append("\n");
int n = rocMultiClass.getNumClasses(); RocCurve roc = rocMultiClass.getRocCurve(i); String headerText = "Class " + i; if (classNames != null && classNames.size() > i) { Component c = getRocFromPoints(ROC_TITLE, roc, rocMultiClass.getCountActualPositive(i), rocMultiClass.getCountActualNegative(i), rocMultiClass.calculateAUC(i), rocMultiClass.calculateAUCPR(i)); Component c2 = getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, rocMultiClass.getPrecisionRecallCurve(i)); components.add(headerDivLeft); components.add(headerDiv);
public String stats(int printPrecision) { StringBuilder sb = new StringBuilder(); int maxLabelsLength = 15; if (labels != null) { for (String s : labels) { maxLabelsLength = Math.max(s.length(), maxLabelsLength); } } String patternHeader = "%-" + (maxLabelsLength + 5) + "s%-12s%-10s%-10s"; String header = String.format(patternHeader, "Label", "AUC", "# Pos", "# Neg"); String pattern = "%-" + (maxLabelsLength + 5) + "s" //Label + "%-12." + printPrecision + "f" //AUC + "%-10d%-10d"; //Count pos, count neg sb.append(header); if (underlying != null) { for (int i = 0; i < underlying.length; i++) { double auc = calculateAUC(i); String label = (labels == null ? String.valueOf(i) : labels.get(i)); sb.append("\n").append(String.format(pattern, label, auc, getCountActualPositive(i), getCountActualNegative(i))); } sb.append("Average AUC: ").append(String.format("%-12." + printPrecision + "f", calculateAverageAUC())); } else { //Empty evaluation sb.append("\n-- No Data --\n"); } return sb.toString(); }
/**
 * Run a multi-class ROC evaluation of the network over the supplied data.
 * A new {@link ROCMultiClass} instance is created with the requested number of threshold steps
 * and handed to {@code doEvaluation}; the first element of the array it returns is the result.
 *
 * @param iterator          Data to evaluate on
 * @param rocThresholdSteps Number of threshold steps to use with {@link ROCMultiClass}
 * @return Multi-class ROC evaluation on the given dataset
 */
public ROCMultiClass evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) {
    ROCMultiClass rocEvaluation = new ROCMultiClass(rocThresholdSteps);
    return doEvaluation(iterator, rocEvaluation)[0];
}
/**
 * Compute the macro-averaged (one-vs-all) AUC: the unweighted mean of the per-class AUC values.
 *
 * @return Sum of {@link #calculateAUC(int)} over every class, divided by the number of classes
 */
public double calculateAverageAUC() {
    // Index 0 is checked via assertIndex before the per-class array is touched
    assertIndex(0);

    double total = 0.0;
    int classIdx = 0;
    while (classIdx < underlying.length) {
        total += calculateAUC(classIdx);
        classIdx++;
    }

    return total / underlying.length;
}
/**
 * Fetch the one-vs-all Precision-Recall curve of a single class.
 * The index is checked with {@code assertIndex} before the per-class evaluation is queried.
 *
 * @param classIdx Class to get the P-R curve for
 * @return Precision recall curve for the given class
 */
public PrecisionRecallCurve getPrecisionRecallCurve(int classIdx) {
    assertIndex(classIdx);

    final PrecisionRecallCurve curve = underlying[classIdx].getPrecisionRecallCurve();
    return curve;
}
if (labels.rank() == 3 && predictions.rank() == 3) { evalTimeSeries(labels, predictions);
/**
 * Run a multi-class ROC evaluation of the network over the supplied multi-dataset data.
 * A new {@link ROCMultiClass} instance is created with the requested number of threshold steps
 * and handed to {@code doEvaluation}; the first element of the array it returns is the result.
 *
 * @param iterator          Data to evaluate on
 * @param rocThresholdSteps Number of threshold steps to use with {@link ROCMultiClass}
 * @return Multi-class ROC evaluation on the given dataset
 */
public ROCMultiClass evaluateROCMultiClass(MultiDataSetIterator iterator, int rocThresholdSteps) {
    ROCMultiClass rocEvaluation = new ROCMultiClass(rocThresholdSteps);
    return doEvaluation(iterator, rocEvaluation)[0];
}
/**
 * Compute the AUC (Area Under the ROC Curve) of a single class.<br>
 * Utilizes trapezoidal integration internally
 *
 * @param classIdx Index of the class to compute the AUC for
 * @return AUC of the given class
 */
public double calculateAUC(int classIdx) {
    assertIndex(classIdx);

    final double auc = underlying[classIdx].calculateAUC();
    return auc;
}
/**
 * Run a multi-class ROC evaluation of the network over the supplied data.
 * A new {@link ROCMultiClass} instance is created with the requested number of threshold steps
 * and handed to {@code doEvaluation}; the first element of the array it returns is the result.
 *
 * @param iterator          Data to evaluate on
 * @param rocThresholdSteps Number of threshold steps to use with {@link ROCMultiClass}
 * @return Multi-class ROC evaluation on the given dataset
 */
public ROCMultiClass evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) {
    ROCMultiClass rocEvaluation = new ROCMultiClass(rocThresholdSteps);
    return doEvaluation(iterator, rocEvaluation)[0];
}
/**
 * Compute the AUPRC (Area Under the Precision-Recall Curve) of a single class.<br>
 * Utilizes trapezoidal integration internally
 *
 * @param classIdx Index of the class to compute the AUPRC for
 * @return AUPRC of the given class
 */
public double calculateAUCPR(int classIdx) {
    assertIndex(classIdx);

    final double aucpr = underlying[classIdx].calculateAUCPR();
    return aucpr;
}
/**
 * Number of actual negative examples (accounting for any masking) recorded for one output/column.
 * The index is checked with {@code assertIndex} before the per-class evaluation is queried.
 *
 * @param outputNum Index of the class
 * @return Actual negative count for the given class
 */
public long getCountActualNegative(int outputNum) {
    assertIndex(outputNum);

    final long actualNegatives = underlying[outputNum].getCountActualNegative();
    return actualNegatives;
}
/**
 * Fetch the one-vs-all ROC curve of a single class.
 * The index is checked with {@code assertIndex} before the per-class evaluation is queried.
 *
 * @param classIdx Class index to get the ROC curve for
 * @return ROC curve for the given class
 */
public RocCurve getRocCurve(int classIdx) {
    assertIndex(classIdx);

    final RocCurve curve = underlying[classIdx].getRocCurve();
    return curve;
}
/**
 * Number of actual positive examples (accounting for any masking) recorded for one class.
 * The index is checked with {@code assertIndex} before the per-class evaluation is queried.
 *
 * @param outputNum Index of the class
 * @return Actual positive count for the given class
 */
public long getCountActualPositive(int outputNum) {
    assertIndex(outputNum);

    final long actualPositives = underlying[outputNum].getCountActualPositive();
    return actualPositives;
}