/**
 * Computes the F-beta score directly from raw confusion-matrix counts.
 * <p>
 * Precision is taken as {@code tp / (tp + fp)} and recall as {@code tp / (tp + fn)}, both in
 * floating point, and the pair is then passed to {@code fBeta(beta, precision, recall)}.
 * A zero denominator is not special-cased here: 0/0 yields NaN before delegation.
 *
 * @param beta beta weighting to use
 * @param tp   true positive count
 * @param fp   false positive count
 * @param fn   false negative count
 * @return the F-beta value
 */
public static double fBeta(double beta, long tp, long fp, long fn) {
    final double truePos = tp;
    final double precision = truePos / (truePos + fp);
    final double recall = truePos / (truePos + fn);
    return fBeta(beta, precision, recall);
}
/**
 * Computes the false negative rate for a single label from its false negative and
 * true positive counts.
 *
 * @param classLabel the label to compute the rate for
 * @param edgeCase   value to return when the rate is 0/0
 * @return the false negative rate as a double
 */
public double falseNegativeRate(Integer classLabel, double edgeCase) {
    long fn = (long) falseNegatives(classLabel);
    long tp = (long) truePositives(classLabel);
    return EvaluationUtils.falseNegativeRate(fn, tp, edgeCase);
}
/**
 * Computes the false positive rate for a single label from its false positive and
 * true negative counts.
 *
 * @param classLabel the label to compute the rate for
 * @param edgeCase   value to return when the rate is 0/0
 * @return the false positive rate as a double
 */
public double falsePositiveRate(int classLabel, double edgeCase) {
    long fp = (long) falsePositives(classLabel);
    long tn = (long) trueNegatives(classLabel);
    return EvaluationUtils.falsePositiveRate(fp, tn, edgeCase);
}
fnCount += falseNegatives.getCount(i); double precision = EvaluationUtils.precision(tpCount, fpCount, DEFAULT_EDGE_VALUE); double recall = EvaluationUtils.recall(tpCount, fnCount, DEFAULT_EDGE_VALUE); return EvaluationUtils.gMeasure(precision, recall); } else { throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);
/**
 * Evaluates time series output by first discarding masked time steps.
 * <p>
 * {@code EvaluationUtils.extractNonMaskedTimeSteps} produces a (labels, predictions) pair of
 * the unmasked steps, which is then passed to {@code eval(labels, predictions)}.
 *
 * @param labels      labels array (time series)
 * @param predictions predictions array (time series)
 * @param labelsMask  mask applied to the labels
 */
@Override
public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray labelsMask) {
    Pair<INDArray, INDArray> unmasked =
            EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, labelsMask);
    eval(unmasked.getFirst(), unmasked.getSecond());
}
/**
 * Computes the binary Matthews correlation coefficient for one class.<br>
 * MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))<br>
 * The per-class counts are read from the TP/FP/FN/TN counters, truncated to {@code long},
 * and passed to {@code EvaluationUtils.matthewsCorrelation}.
 *
 * @param classIdx index of the class to compute the Matthews correlation coefficient for
 * @return the Matthews correlation coefficient for {@code classIdx}
 */
public double matthewsCorrelation(int classIdx) {
    long tp = (long) truePositives.getCount(classIdx);
    long fp = (long) falsePositives.getCount(classIdx);
    long fn = (long) falseNegatives.getCount(classIdx);
    long tn = (long) trueNegatives.getCount(classIdx);
    return EvaluationUtils.matthewsCorrelation(tp, fp, fn, tn);
}
/**
 * Computes the G-measure for one output by combining its precision and recall
 * via {@code EvaluationUtils.gMeasure}.
 *
 * @param output index of the output
 * @return the G-measure for {@code output}
 */
public double gMeasure(int output) {
    return EvaluationUtils.gMeasure(precision(output), recall(output));
}
/**
 * Computes precision for a single label from its true positive and false positive counters.
 *
 * @param classLabel the label to compute precision for
 * @param edgeCase   value to return when precision is 0/0
 * @return the precision for the label
 */
public double precision(Integer classLabel, double edgeCase) {
    long tp = (long) truePositives.getCount(classLabel);
    long fp = (long) falsePositives.getCount(classLabel);
    return EvaluationUtils.precision(tp, fp, edgeCase);
}
/**
 * Computes recall for a single label from its true positive and false negative counters.
 *
 * @param classLabel the label to compute recall for
 * @param edgeCase   value to return when recall is 0/0
 * @return the recall as a double
 */
public double recall(int classLabel, double edgeCase) {
    long tp = (long) truePositives.getCount(classLabel);
    long fn = (long) falseNegatives.getCount(classLabel);
    return EvaluationUtils.recall(tp, fn, edgeCase);
}
@Override public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray labelsMask) { if (labelsMask == null || labelsMask.rank() == 2) { super.evalTimeSeries(labels, predictions, labelsMask); return; } else if (labelsMask.rank() != 3) { throw new IllegalArgumentException("Labels must: must be rank 2 or 3. Got: " + labelsMask.rank()); } //Per output time series masking INDArray l2d = EvaluationUtils.reshapeTimeSeriesTo2d(labels); INDArray p2d = EvaluationUtils.reshapeTimeSeriesTo2d(predictions); INDArray m2d = EvaluationUtils.reshapeTimeSeriesTo2d(labelsMask); eval(l2d, p2d, m2d); }
private static void eval(INDArray labels, INDArray p, INDArray outMask, StringBuilder sb) { Pair<INDArray, INDArray> pair = EvaluationUtils.extractNonMaskedTimeSteps(labels, p, outMask); INDArray realOutcomes = pair.getFirst(); INDArray guesses = pair.getSecond(); // Length of real labels must be same as length of predicted labels if (realOutcomes.length() != guesses.length()) throw new IllegalArgumentException( "Unable to evaluate. Outcome matrices not same length"); INDArray guessIndex = Nd4j.argMax(guesses, 1); INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1); int nExamples = guessIndex.length(); for (int i = 0; i < nExamples; i++) { int actual = (int) realOutcomeIndex.getDouble(i); int predicted = (int) guessIndex.getDouble(i); sb.append(actual + "\t" + predicted + "\n"); } } }
/**
 * Computes the Matthews correlation coefficient for one output.
 * <p>
 * The output index is validated with {@code assertIndex} before the output's TP/FP/FN/TN
 * counts are passed to {@code EvaluationUtils.matthewsCorrelation}.
 *
 * @param outputNum index of the output
 * @return the Matthews correlation coefficient for {@code outputNum}
 */
public double matthewsCorrelation(int outputNum) {
    assertIndex(outputNum);
    return EvaluationUtils.matthewsCorrelation(
            truePositives(outputNum),
            falsePositives(outputNum),
            falseNegatives(outputNum),
            trueNegatives(outputNum));
}
/**
 * Computes the G-measure for one output by combining its precision and recall
 * via {@code EvaluationUtils.gMeasure}.
 *
 * @param output index of the output
 * @return the G-measure for {@code output}
 */
public double gMeasure(int output) {
    return EvaluationUtils.gMeasure(precision(output), recall(output));
}
fpCount += falsePositives.getCount(i); return EvaluationUtils.precision(tpCount, fpCount, DEFAULT_EDGE_VALUE); } else { throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);
fnCount += falseNegatives.getCount(i); return EvaluationUtils.recall(tpCount, fnCount, DEFAULT_EDGE_VALUE); } else { throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);
predicted = predicted.dup('f'); INDArray labels2d = EvaluationUtils.reshapeTimeSeriesTo2d(labels); INDArray predicted2d = EvaluationUtils.reshapeTimeSeriesTo2d(predicted);
private void eval(INDArray labels, INDArray p, INDArray outMask, StringBuilder sb) { Pair<INDArray, INDArray> pair = EvaluationUtils.extractNonMaskedTimeSteps(labels, p, outMask); INDArray realOutcomes = pair.getFirst(); INDArray guesses = pair.getSecond(); // Length of real labels must be same as length of predicted labels if (realOutcomes.length() != guesses.length()) throw new IllegalArgumentException( "Unable to evaluate. Outcome matrices not same length"); INDArray guessIndex = Nd4j.argMax(guesses, 1); INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1); int nExamples = guessIndex.length(); for (int i = 0; i < nExamples; i++) { int actual = (int) realOutcomeIndex.getDouble(i); int predicted = (int) guessIndex.getDouble(i); sb.append( vectorize.getTagset()[actual] + "\t" + vectorize.getTagset()[predicted] + "\n"); } }
/**
 * Computes F-beta for one class, where F-beta is defined as:<br>
 * (1+beta^2) * (precision * recall) / (beta^2 * precision + recall).<br>
 * F1 is the special case beta = 1.0.
 * <p>
 * Precision and recall are requested with a sentinel of -1 for the 0/0 case. If either comes
 * back as the sentinel, {@code defaultValue} is returned instead of computing F-beta.
 *
 * @param beta         beta value to use
 * @param classLabel   class label
 * @param defaultValue value returned when precision or recall is undefined (0/0)
 * @return the F-beta value, or {@code defaultValue}
 */
public double fBeta(double beta, int classLabel, double defaultValue) {
    final int undefined = -1;
    double prec = precision(classLabel, undefined);
    double rec = recall(classLabel, undefined);
    boolean eitherUndefined = prec == undefined || rec == undefined;
    return eitherUndefined ? defaultValue : EvaluationUtils.fBeta(beta, prec, rec);
}
/**
 * Computes the false positive rate for a single label from its false positive and
 * true negative counters.
 *
 * @param classLabel the label to compute the rate for
 * @param edgeCase   value to return when the rate is 0/0
 * @return the false positive rate as a double
 */
public double falsePositiveRate(int classLabel, double edgeCase) {
    long fp = (long) falsePositives.getCount(classLabel);
    long tn = (long) trueNegatives.getCount(classLabel);
    return EvaluationUtils.falsePositiveRate(fp, tn, edgeCase);
}
/**
 * Computes the false negative rate for a single label from its false negative and
 * true positive counters.
 *
 * @param classLabel the label to compute the rate for
 * @param edgeCase   value to return when the rate is 0/0
 * @return the false negative rate as a double
 */
public double falseNegativeRate(Integer classLabel, double edgeCase) {
    long fn = (long) falseNegatives.getCount(classLabel);
    long tp = (long) truePositives.getCount(classLabel);
    return EvaluationUtils.falseNegativeRate(fn, tp, edgeCase);
}