/**
 * Compute the Matthews correlation coefficient (MCC) for a single output.
 *
 * @param outputNum Index of the output to evaluate
 * @return Result of {@code EvaluationUtils.matthewsCorrelation} applied to this output's
 *         true positive, false positive, false negative and true negative counts
 */
public double matthewsCorrelation(int outputNum) {
    assertIndex(outputNum);
    final int tp = truePositives(outputNum);
    final int fp = falsePositives(outputNum);
    final int fn = falseNegatives(outputNum);
    final int tn = trueNegatives(outputNum);
    // Argument order expected by the helper: TP, FP, FN, TN
    return EvaluationUtils.matthewsCorrelation(tp, fp, fn, tn);
}
/**
 * Evaluate the given predictions against the labels without any mask array.
 * Delegates to the three-argument overload, passing a {@code null} mask.
 *
 * @param labels             Label array
 * @param networkPredictions Predictions to evaluate against the labels
 */
@Override
public void eval(INDArray labels, INDArray networkPredictions) {
    // Typed local (rather than a bare null) so the INDArray-mask overload is selected
    final INDArray noMask = null;
    eval(labels, networkPredictions, noMask);
}
/**
 * Accuracy for the specified output: (true positives + true negatives) divided by the
 * total count for that output.
 *
 * @param outputNum Index of the output
 * @return Accuracy as a double
 */
public double accuracy(int outputNum) {
    assertIndex(outputNum);
    final int correct = countTruePositive[outputNum] + countTrueNegative[outputNum];
    final double total = totalCount(outputNum);
    return correct / total;
}
int totalCount = totalCount(i); double acc = accuracy(i); double f1 = f1(i); double precision = precision(i); double recall = recall(i); truePositives(i), trueNegatives(i), falsePositives(i), falseNegatives(i)); if (rocBinary != null) { args = new ArrayList<>(args);
/**
 * Compute the F-beta score for one output from that output's precision and recall.
 *
 * @param beta      Beta value passed through to {@code EvaluationUtils.fBeta}
 * @param outputNum Index of the output
 * @return F-beta score for the given output
 */
public double fBeta(double beta, int outputNum) {
    assertIndex(outputNum);
    return EvaluationUtils.fBeta(beta, precision(outputNum), recall(outputNum));
}
/**
 * Number of false positives recorded for the specified output.
 *
 * @param outputNum Index of the output
 * @return False positive count
 */
public int falsePositives(int outputNum) {
    assertIndex(outputNum);
    final int fp = countFalsePositive[outputNum];
    return fp;
}
/**
 * False positive rate for a given label, computed by
 * {@code EvaluationUtils.falsePositiveRate} from the false positive and true negative counts.
 *
 * @param classLabel the label (output index)
 * @param edgeCase   Value to output in the 0/0 case
 * @return fpr as a double
 */
public double falsePositiveRate(int classLabel, double edgeCase) {
    final long fp = falsePositives(classLabel);
    final long tn = trueNegatives(classLabel);
    return EvaluationUtils.falsePositiveRate(fp, tn, edgeCase);
}
/**
 * False negative rate for a given label, computed by
 * {@code EvaluationUtils.falseNegativeRate} from the false negative and true positive counts.
 *
 * @param classLabel the label (output index); it is unboxed when the counts are looked up,
 *                   so it must not be {@code null}
 * @param edgeCase   Value to output in the 0/0 case
 * @return fnr as a double
 */
public double falseNegativeRate(Integer classLabel, double edgeCase) {
    final long fn = falseNegatives(classLabel);
    final long tp = truePositives(classLabel);
    return EvaluationUtils.falseNegativeRate(fn, tp, edgeCase);
}
evalTimeSeries(labels, networkPredictions, maskArray); return; addInPlace(countTruePositive, tpCount); addInPlace(countFalsePositive, fpCount); addInPlace(countTrueNegative, tnCount); addInPlace(countFalseNegative, fnCount);
/**
 * F1 score for the specified output, i.e. the F-beta score with beta = 1.0.
 *
 * @param outputNum Index of the output
 * @return F1 score
 */
public double f1(int outputNum) {
    final double beta = 1.0;
    return fBeta(beta, outputNum);
}
/**
 * False negative rate for a given label, using {@code DEFAULT_EDGE_VALUE} as the
 * result for the 0/0 case.
 *
 * @param classLabel the label (output index)
 * @return fnr as a double
 */
public double falseNegativeRate(Integer classLabel) {
    final double edgeCase = DEFAULT_EDGE_VALUE;
    return falseNegativeRate(classLabel, edgeCase);
}
@Override public void merge(EvaluationBinary other) { if (other.countTruePositive == null) { //Other is empty - no op return; } if (countTruePositive == null) { //This evaluation is empty -> take results from other this.countTruePositive = other.countTruePositive; this.countFalsePositive = other.countFalsePositive; this.countTrueNegative = other.countTrueNegative; this.countFalseNegative = other.countFalseNegative; this.rocBinary = other.rocBinary; } else { if (this.countTruePositive.length != other.countTruePositive.length) { throw new IllegalStateException("Cannot merge EvaluationBinary instances with different sizes. This " + "size: " + this.countTruePositive.length + ", other size: " + other.countTruePositive.length); } //Both have stats addInPlace(this.countTruePositive, other.countTruePositive); addInPlace(this.countTrueNegative, other.countTrueNegative); addInPlace(this.countFalsePositive, other.countFalsePositive); addInPlace(this.countFalseNegative, other.countFalseNegative); if (this.rocBinary != null) { this.rocBinary.merge(other.rocBinary); } } }
/**
 * Number of true negatives recorded for the specified output.
 *
 * @param outputNum Index of the output
 * @return True negative count
 */
public int trueNegatives(int outputNum) {
    assertIndex(outputNum);
    final int tn = countTrueNegative[outputNum];
    return tn;
}
/**
 * Number of false negatives recorded for the specified output.
 *
 * @param outputNum Index of the output
 * @return False negative count
 */
public int falseNegatives(int outputNum) {
    assertIndex(outputNum);
    final int fn = countFalseNegative[outputNum];
    return fn;
}
@Override public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray labelsMask) { if (labelsMask == null || labelsMask.rank() == 2) { super.evalTimeSeries(labels, predictions, labelsMask); return; } else if (labelsMask.rank() != 3) { throw new IllegalArgumentException("Labels must: must be rank 2 or 3. Got: " + labelsMask.rank()); } //Per output time series masking INDArray l2d = EvaluationUtils.reshapeTimeSeriesTo2d(labels); INDArray p2d = EvaluationUtils.reshapeTimeSeriesTo2d(predictions); INDArray m2d = EvaluationUtils.reshapeTimeSeriesTo2d(labelsMask); eval(l2d, p2d, m2d); }
/** * Get the precision (tp / (tp + fp)) for the specified output */ public double precision(int outputNum) { assertIndex(outputNum); //double precision = tp / (double) (tp + fp); return countTruePositive[outputNum] / (double) (countTruePositive[outputNum] + countFalsePositive[outputNum]); }
/**
 * Get the recall (tp / (tp + fn)) for the specified output.
 * <p>
 * When both the true positive and false negative counts are zero, the division is
 * 0 / 0.0 and the result is {@code NaN}.
 *
 * @param outputNum Index of the output
 * @return Recall as a double
 */
public double recall(int outputNum) {
    assertIndex(outputNum);
    final int tp = countTruePositive[outputNum];
    final int fn = countFalseNegative[outputNum];
    // Widen to double before adding, so a large tp + fn cannot overflow int arithmetic
    return tp / ((double) tp + fn);
}
/**
 * Total number of values counted for the specified output: the sum of its true positive,
 * true negative, false negative and false positive counts.
 *
 * @param outputNum Index of the output
 * @return Sum of the four counts for that output
 */
public int totalCount(int outputNum) {
    assertIndex(outputNum);
    int total = countTruePositive[outputNum];
    total += countTrueNegative[outputNum];
    total += countFalseNegative[outputNum];
    total += countFalsePositive[outputNum];
    return total;
}