/**
 * Create an evaluation for {@code size} binary outputs, optionally also collecting ROC statistics.
 * ROC is computed in addition to the standard evaluation metrics only when {@code rocBinarySteps}
 * is non-null. See {@link ROCBinary} for more details.
 *
 * @param size           Number of outputs; one true/false positive/negative counter is allocated per output
 * @param rocBinarySteps Constructor arg for {@link ROCBinary#ROCBinary(int)}, or null to disable ROC collection
 */
public EvaluationBinary(int size, Integer rocBinarySteps) {
    // One counter slot per output, for each cell of the confusion matrix
    this.countTruePositive = new int[size];
    this.countFalsePositive = new int[size];
    this.countTrueNegative = new int[size];
    this.countFalseNegative = new int[size];

    // ROC collection is opt-in: leave rocBinary untouched unless a step count was supplied
    if (rocBinarySteps != null) {
        this.rocBinary = new ROCBinary(rocBinarySteps);
    }
}
/**
 * Compute the macro-averaged AUC: the unweighted mean of {@link #calculateAUC(int)} over every output.
 *
 * @return the (macro-)average AUC across all outputs
 */
public double calculateAverageAuc() {
    int outputCount = numLabels();
    double aucSum = 0.0;
    for (int output = 0; output < outputCount; output++) {
        aucSum += calculateAUC(output);
    }
    // outputCount is widened to double, so this is floating-point division
    return aucSum / outputCount;
}
/**
 * Evaluate the network predictions against the labels without any mask array.
 * Delegates to the three-argument {@code eval} overload with a null {@link INDArray} mask.
 *
 * @param labels             Labels to evaluate against
 * @param networkPredictions Network predictions
 */
@Override
public void eval(INDArray labels, INDArray networkPredictions) {
    // Declaring the mask as INDArray picks the (INDArray, INDArray, INDArray) overload,
    // exactly as the explicit (INDArray) cast did
    final INDArray noMask = null;
    eval(labels, networkPredictions, noMask);
}
double auc = calculateAUC(i); sb.append("\n").append(String.format(pattern, label, auc, getCountActualPositive(i), getCountActualNegative(i)));
/**
 * Return the number of actual (label) positives for one output/column, accounting for any masking.
 *
 * @param outputNum Index of the output (0 to {@link #numLabels()}-1)
 * @return Count of actual positives for that output
 */
public long getCountActualPositive(int outputNum) {
    assertIndex(outputNum);
    long actualPositives = underlying[outputNum].getCountActualPositive();
    return actualPositives;
}
/**
 * Produce the statistics summary using {@code DEFAULT_STATS_PRECISION} as the precision argument.
 *
 * @return Statistics summary as a String
 */
@Override
public String stats() {
    return this.stats(DEFAULT_STATS_PRECISION);
}
if (rocBinary != null) { args = new ArrayList<>(args); args.add(rocBinary.calculateAUC(i));
@Override public void merge(EvaluationBinary other) { if (other.countTruePositive == null) { //Other is empty - no op return; } if (countTruePositive == null) { //This evaluation is empty -> take results from other this.countTruePositive = other.countTruePositive; this.countFalsePositive = other.countFalsePositive; this.countTrueNegative = other.countTrueNegative; this.countFalseNegative = other.countFalseNegative; this.rocBinary = other.rocBinary; } else { if (this.countTruePositive.length != other.countTruePositive.length) { throw new IllegalStateException("Cannot merge EvaluationBinary instances with different sizes. This " + "size: " + this.countTruePositive.length + ", other size: " + other.countTruePositive.length); } //Both have stats addInPlace(this.countTruePositive, other.countTruePositive); addInPlace(this.countTrueNegative, other.countTrueNegative); addInPlace(this.countFalsePositive, other.countFalsePositive); addInPlace(this.countFalseNegative, other.countFalseNegative); if (this.rocBinary != null) { this.rocBinary.merge(other.rocBinary); } } }
evalTimeSeries(labels, networkPredictions, maskArray); return;
/**
 * Return the ROC curve computed for a single output.
 *
 * @param outputNum Index of the output (0 to {@link #numLabels()}-1)
 * @return ROC curve for that output
 */
public RocCurve getRocCurve(int outputNum) {
    assertIndex(outputNum);
    RocCurve curve = underlying[outputNum].getRocCurve();
    return curve;
}
/**
 * Return the precision-recall curve computed for a single output.
 *
 * @param outputNum Index of the output (0 to {@link #numLabels()}-1)
 * @return Precision-recall curve for that output
 */
public PrecisionRecallCurve getPrecisionRecallCurve(int outputNum) {
    assertIndex(outputNum);
    PrecisionRecallCurve curve = underlying[outputNum].getPrecisionRecallCurve();
    return curve;
}
rocBinary.eval(labels, networkPredictions, maskArray);
/**
 * Compute the AUC (Area Under the ROC Curve) for a single output.<br>
 * The value comes from that output's underlying ROC, which uses trapezoidal integration.
 *
 * @param outputNum Index of the output (0 to {@link #numLabels()}-1)
 * @return AUC for that output
 */
public double calculateAUC(int outputNum) {
    assertIndex(outputNum);
    double auc = underlying[outputNum].calculateAUC();
    return auc;
}
/**
 * Return the number of actual (label) negatives for one output/column, accounting for any masking.
 *
 * @param outputNum Index of the output (0 to {@link #numLabels()}-1)
 * @return Count of actual negatives for that output
 */
public long getCountActualNegative(int outputNum) {
    assertIndex(outputNum);
    long actualNegatives = underlying[outputNum].getCountActualNegative();
    return actualNegatives;
}