/**
 * Creates an analyzer backed by a fresh confusion matrix and a fresh log-likelihood summarizer.
 *
 * @param labelSet the labels the confusion matrix is built over
 * @param defaultLabel the default label handed to the confusion matrix
 */
public ResultAnalyzer(Collection<String> labelSet, String defaultLabel) {
  this.confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel);
  this.summarizer = new OnlineSummarizer();
}
/**
 * Returns the mean of the per-label accuracies over every label except the default label.
 *
 * <p>The default label is excluded from both the sum and the count. Previously it was still
 * counted in the denominator, which biased the mean downward. If there are no non-default
 * labels the result is {@code 0.0 / 0}, i.e. {@code NaN}. If {@code getAccuracy} returns
 * {@code NaN} for any included label, the result is also {@code NaN}.
 *
 * @return the average accuracy of the non-default labels, in the same units as
 *     {@code getAccuracy}
 */
public double getReliability() {
  int count = 0;
  double accuracy = 0;
  for (String label : labelMap.keySet()) {
    if (label.equals(defaultLabel)) {
      // The default label contributes nothing to the sum, so it must not count toward the mean.
      continue;
    }
    accuracy += getAccuracy(label);
    count++;
  }
  return accuracy / count;
}
/**
 * Adds every cell count of {@code b} into this matrix, in place.
 *
 * <p>NOTE(review): only the number of labels is compared, not their identity; cells are looked
 * up in {@code b} using this matrix's labels.
 *
 * @param b the matrix whose counts are added to this one
 * @return this matrix, after the merge
 * @throws IllegalArgumentException if the two matrices have a different number of labels
 */
public ConfusionMatrix merge(ConfusionMatrix b) {
  Preconditions.checkArgument(labelMap.size() == b.getLabels().size(), "The label sizes do not match");
  for (String actual : labelMap.keySet()) {
    for (String predicted : labelMap.keySet()) {
      incrementCount(actual, predicted, b.getCount(actual, predicted));
    }
  }
  return this;
}
returnString.append("-------------------------------------------------------\n"); RunningAverageAndStdDev normStats = confusionMatrix.getNormalizedStats(); returnString.append(StringUtils.rightPad("Kappa", 40)).append( StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getKappa()), 10)).append('\n'); returnString.append(StringUtils.rightPad("Accuracy", 40)).append( StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getAccuracy()), 10)).append("%\n"); returnString.append(StringUtils.rightPad("Reliability", 40)).append( StringUtils.leftPad(decimalFormatter.format(normStats.getAverage() * 100.00000001), 10)).append("%\n"); StringUtils.leftPad(decimalFormatter.format(normStats.getStandardDeviation()), 10)).append('\n'); returnString.append(StringUtils.rightPad("Weighted precision", 40)).append( StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedPrecision()), 10)).append('\n'); returnString.append(StringUtils.rightPad("Weighted recall", 40)).append( StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedRecall()), 10)).append('\n'); returnString.append(StringUtils.rightPad("Weighted F1 score", 40)).append( StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedF1score()), 10)).append('\n');
/**
 * Checks weighted precision, recall and F1 against the multiclass example in the scikit-learn
 * documentation:
 * http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
 */
@Test
public void testPrecisionRecallAndF1ScoreAsScikitLearn() {
  Collection<String> labels = Arrays.asList("0", "1", "2");
  ConfusionMatrix matrix = new ConfusionMatrix(labels, "DEFAULT");

  // Each call is putCount(correct, classified, count).
  matrix.putCount("0", "0", 2);
  matrix.putCount("1", "0", 1);
  matrix.putCount("1", "2", 1);
  matrix.putCount("2", "1", 2);

  final double tolerance = 0.001;
  assertEquals(0.222, matrix.getWeightedPrecision(), tolerance);
  assertEquals(0.333, matrix.getWeightedRecall(), tolerance);
  assertEquals(0.266, matrix.getWeightedF1score(), tolerance);
}
returnString.append("-------------------------------------------------------").append('\n'); int unclassified = getTotal(defaultLabel); for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) { if (entry.getKey().equals(defaultLabel) && unclassified == 0) { returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t'); StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t'); labelTotal += getCount(correctLabel, classifiedLabel); .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)) .append(" = ").append(correctLabel).append('\n');
/**
 * Command-line entry point. Reads comma-separated lines whose first field is the true label and
 * whose second field is the estimated label, then prints the resulting confusion matrix.
 *
 * <p>The file is read twice. The first pass collects the distinct true labels in first-seen
 * order. The second pass tallies each (true, estimated) pair. Each reader is closed even if
 * reading fails.
 *
 * <p>NOTE(review): the input path is taken from {@code args[1]}, and {@code args[0]} is unused.
 * Confirm this is intentional. A line without a comma makes the second pass throw
 * {@code ArrayIndexOutOfBoundsException}.
 *
 * @param args command-line arguments; {@code args[1]} is the input file path
 * @throws IOException if the input file cannot be opened or read
 */
public static void main(String[] args) throws IOException {
  String inputFile = args[1];

  // First pass: distinct labels from the true-value column, in first-seen order.
  List<String> symbols = new ArrayList<String>();
  BufferedReader in = new BufferedReader(new FileReader(inputFile));
  try {
    for (String line = in.readLine(); line != null; line = in.readLine()) {
      String[] pieces = line.split(",");
      if (!symbols.contains(pieces[0])) {
        symbols.add(pieces[0]);
      }
    }
  } finally {
    in.close();
  }

  ConfusionMatrix x2 = new ConfusionMatrix(symbols, "unknown");

  // Second pass: record every (true, estimated) pair.
  in = new BufferedReader(new FileReader(inputFile));
  try {
    for (String line = in.readLine(); line != null; line = in.readLine()) {
      String[] pieces = line.split(",");
      String trueValue = pieces[0];
      String estimatedValue = pieces[1];
      x2.addInstance(trueValue, estimatedValue);
    }
  } finally {
    in.close();
  }

  System.out.printf("%s\n\n", x2.toString());
}
}
/**
 * Builds a confusion matrix over two labels plus a default label.
 *
 * <p>Uses the supplied {@code labels} and {@code defaultLabel} rather than hard-coded names, so
 * the helper works for any label pair. Counts for instances left in the default column come
 * from the class constant {@code OTHER}.
 *
 * @param values 2x2 counts; {@code values[i][j]} is stored as correct label {@code labels[i]}
 *     classified as {@code labels[j]}
 * @param labels the two non-default labels; only indices 0 and 1 are read
 * @param defaultLabel the matrix's default label, also used as the column for {@code OTHER}
 * @return the populated matrix
 */
private static ConfusionMatrix fillConfusionMatrix(int[][] values, String[] labels, String defaultLabel) {
  Collection<String> labelList = Lists.newArrayList();
  labelList.add(labels[0]);
  labelList.add(labels[1]);
  ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, defaultLabel);
  confusionMatrix.putCount(labels[0], labels[0], values[0][0]);
  confusionMatrix.putCount(labels[0], labels[1], values[0][1]);
  confusionMatrix.putCount(labels[1], labels[0], values[1][0]);
  confusionMatrix.putCount(labels[1], labels[1], values[1][1]);
  confusionMatrix.putCount(labels[0], defaultLabel, OTHER[0]);
  confusionMatrix.putCount(labels[1], defaultLabel, OTHER[1]);
  return confusionMatrix;
}
returnString.append("-------------------------------------------------------\n"); RunningAverageAndStdDev normStats = confusionMatrix.getNormalizedStats(); returnString.append(StringUtils.rightPad("Kappa", 40)).append( StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getKappa()), 10)).append('\n'); returnString.append(StringUtils.rightPad("Accuracy", 40)).append( StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getAccuracy()), 10)).append("%\n"); returnString.append(StringUtils.rightPad("Reliability", 40)).append( StringUtils.leftPad(decimalFormatter.format(normStats.getAverage() * 100.00000001), 10)).append("%\n");
/**
 * Asserts the per-label accuracies of a matrix built from the shared test fixture.
 *
 * @param cm the matrix under test; expected to have exactly three labels
 */
private static void checkAccuracy(ConfusionMatrix cm) {
  assertEquals(3, cm.getLabels().size());
  assertEquals(25.0, cm.getAccuracy("Label1"), EPSILON);
  assertEquals(55.5555555, cm.getAccuracy("Label2"), EPSILON);
  // Presumably "other" has no instances, so its accuracy is undefined (NaN).
  assertTrue(Double.isNaN(cm.getAccuracy("other")));
}
private static void checkValues(ConfusionMatrix cm) { int[][] counts = cm.getConfusionMatrix(); cm.toString(); assertEquals(counts.length, counts[0].length); assertEquals(3, counts.length); assertEquals(VALUES[0][0], counts[0][0]); assertEquals(VALUES[0][1], counts[0][1]); assertEquals(VALUES[1][0], counts[1][0]); assertEquals(VALUES[1][1], counts[1][1]); assertTrue(Arrays.equals(new int[3], counts[2])); // zeros assertEquals(OTHER[0], counts[0][2]); assertEquals(OTHER[1], counts[1][2]); assertEquals(3, cm.getLabels().size()); assertTrue(cm.getLabels().contains(LABELS[0])); assertTrue(cm.getLabels().contains(LABELS[1])); assertTrue(cm.getLabels().contains(DEFAULT_LABEL)); }
/**
 * Verifies that the Matrix view exposes one column per label, binds every label as a row, and
 * reports the expected diagonal (correct) counts.
 */
@Test
public void testGetMatrix() {
  ConfusionMatrix cm = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL);
  Matrix matrix = cm.getMatrix();
  Map<String, Integer> rowBindings = matrix.getRowLabelBindings();
  assertEquals(cm.getLabels().size(), matrix.numCols());

  String[] expectedRows = {LABELS[0], LABELS[1], DEFAULT_LABEL};
  for (String row : expectedRows) {
    assertTrue(rowBindings.containsKey(row));
  }

  assertEquals(2, cm.getCorrect(LABELS[0]));
  assertEquals(20, cm.getCorrect(LABELS[1]));
  assertEquals(0, cm.getCorrect(DEFAULT_LABEL));
}
/**
 * Records one classification outcome.
 *
 * <p>Updates the correct/incorrect tallies and the confusion matrix. The log-likelihood is fed
 * into the summarizer unless it equals {@code Double.MAX_VALUE}, which is presumably a
 * "not set" sentinel.
 *
 * @param correctLabel the correct label
 * @param classifiedResult the classified result
 * @return {@code true} if the classified label equals {@code correctLabel}
 */
public boolean addInstance(String correctLabel, ClassifierResult classifiedResult) {
  final boolean correct = correctLabel.equals(classifiedResult.getLabel());
  if (correct) {
    correctlyClassified++;
  } else {
    incorrectlyClassified++;
  }
  confusionMatrix.addInstance(correctLabel, classifiedResult);

  final double logLikelihood = classifiedResult.getLogLikelihood();
  if (logLikelihood != Double.MAX_VALUE) {
    summarizer.add(logLikelihood);
    hasLL = true;
  }
  return correct;
}
returnString.append("-------------------------------------------------------\n"); RunningAverageAndStdDev normStats = confusionMatrix.getNormalizedStats(); returnString.append(StringUtils.rightPad("Kappa", 40)).append( StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getKappa()), 10)).append('\n'); returnString.append(StringUtils.rightPad("Accuracy", 40)).append( StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getAccuracy()), 10)).append("%\n"); returnString.append(StringUtils.rightPad("Reliability", 40)).append( StringUtils.leftPad(decimalFormatter.format(normStats.getAverage() * 100.00000001), 10)).append("%\n"); StringUtils.leftPad(decimalFormatter.format(normStats.getStandardDeviation()), 10)).append('\n'); returnString.append(StringUtils.rightPad("Weighted precision", 40)).append( StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedPrecision()), 10)).append('\n'); returnString.append(StringUtils.rightPad("Weighted recall", 40)).append( StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedRecall()), 10)).append('\n'); returnString.append(StringUtils.rightPad("Weighted F1 score", 40)).append( StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedF1score()), 10)).append('\n');
returnString.append("-------------------------------------------------------").append('\n'); int unclassified = getTotal(defaultLabel); for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) { if (entry.getKey().equals(defaultLabel) && unclassified == 0) { returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t'); StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t'); labelTotal += getCount(correctLabel, classifiedLabel); .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)) .append(" = ").append(correctLabel).append('\n');
/**
 * Adds a single classified instance to this analyzer's statistics.
 *
 * @param correctLabel the correct label
 * @param classifiedResult the classified result
 * @return whether the classified label matched {@code correctLabel}
 */
public boolean addInstance(String correctLabel, ClassifierResult classifiedResult) {
  String predictedLabel = classifiedResult.getLabel();
  boolean matched = correctLabel.equals(predictedLabel);

  if (!matched) {
    incorrectlyClassified++;
  } else {
    correctlyClassified++;
  }

  confusionMatrix.addInstance(correctLabel, classifiedResult);

  // Values equal to Double.MAX_VALUE are skipped (presumably "no log-likelihood available").
  double ll = classifiedResult.getLogLikelihood();
  if (ll != Double.MAX_VALUE) {
    summarizer.add(ll);
    hasLL = true;
  }
  return matched;
}
/**
 * Accumulates the counts of another confusion matrix into this one.
 *
 * <p>NOTE(review): the check compares label counts only, not label names.
 *
 * @param b the matrix to fold into this one
 * @return {@code this}, for chaining
 * @throws IllegalArgumentException if {@code b} has a different number of labels
 */
public ConfusionMatrix merge(ConfusionMatrix b) {
  if (labelMap.size() != b.getLabels().size()) {
    throw new IllegalArgumentException("The label sizes do not match");
  }
  for (String rowLabel : labelMap.keySet()) {
    for (String columnLabel : labelMap.keySet()) {
      incrementCount(rowLabel, columnLabel, b.getCount(rowLabel, columnLabel));
    }
  }
  return this;
}
returnString.append("-------------------------------------------------------").append('\n'); int unclassified = getTotal(defaultLabel); for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) { if (entry.getKey().equals(defaultLabel) && unclassified == 0) { returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t'); StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t'); labelTotal += getCount(correctLabel, classifiedLabel); .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)) .append(" = ").append(correctLabel).append('\n');