/**
 * Builds the confidence estimator for this component.
 *
 * <p>Every call constructs a fresh {@code LinearBaselineEstimation} from the
 * {@code lmFactory} and {@code sourceFactory} in scope; nothing is cached.
 *
 * @return a new linear baseline confidence estimator
 */
@Override
public TranslationConfidence getConfidence() {
    final TranslationConfidence estimator = new LinearBaselineEstimation(lmFactory, sourceFactory);
    return estimator;
}
public LinearBaselineEstimation(LanguageModelFactory lmFactory, TranslationSourceFactory sourceFactory) { this.baselineFeatures = new BaselineFeatures(lmFactory, sourceFactory); } // These weights are a linear regression on the WMT-12 training set
/**
 * Computes the full 17-element baseline feature vector for a translation.
 *
 * <p>Slot layout, in order:
 * <ol start="0">
 *   <li>{@code countTksInSrc}</li>
 *   <li>{@code countTksInTrg}</li>
 *   <li>{@code aveSrcTkLen}</li>
 *   <li>{@code sourceLMProb}</li>
 *   <li>{@code targetLMProb}</li>
 *   <li>{@code aveOccurencesInTarget}</li>
 *   <li>{@code aveTranslationCount} with threshold 0.2</li>
 *   <li>{@code aveTranslationCount} with threshold 0.01</li>
 *   <li>8-13: elements [0] and [1] of {@code percentNGramsInTopBotQuartile}
 *       for n = 1, 2, 3 (two slots per n)</li>
 *   <li>14: {@code percentUnigramsInLM}</li>
 *   <li>15: {@code countPunctuationInSource}</li>
 *   <li>16: {@code countPunctuationInTarget}</li>
 * </ol>
 *
 * @param translation the translation to featurize
 * @return a new array of length 17; individual entries may be NaN or infinite
 */
public double[] allFeatures(Translation translation) {
    final double[] out = new double[17];
    int k = 0;
    out[k++] = countTksInSrc(translation);
    out[k++] = countTksInTrg(translation);
    out[k++] = aveSrcTkLen(translation);
    out[k++] = sourceLMProb(translation);
    out[k++] = targetLMProb(translation);
    out[k++] = aveOccurencesInTarget(translation);
    out[k++] = aveTranslationCount(translation, 0.2);
    out[k++] = aveTranslationCount(translation, 0.01);
    // Two slots per n-gram order, filled for orders 1, 2 and 3 in sequence.
    for (int order = 1; order <= 3; order++) {
        final double[] quartiles = percentNGramsInTopBotQuartile(translation, order);
        out[k++] = quartiles[0];
        out[k++] = quartiles[1];
    }
    out[k++] = percentUnigramsInLM(translation);
    out[k++] = countPunctuationInSource(translation);
    out[k++] = countPunctuationInTarget(translation);
    return out;
}
/**
 * Returns the average, per source token, of phrase-table candidates whose
 * feature [2] score is at least {@code log(minProb)}.
 *
 * <p>Returns 0.0 when no translation source exists for the language pair. If
 * the source token list is empty, the result is 0/0 = NaN.
 *
 * @param translation the translation whose source tokens are looked up
 * @param minProb     probability threshold; converted with {@code Math.log} before comparison
 * @return matching candidate count divided by source token count
 */
public double aveTranslationCount(Translation translation, double minProb) {
    final TranslationSource phraseSource = getSource(
            translation.getSourceLabel().getLanguage(),
            translation.getTargetLabel().getLanguage());
    if (phraseSource == null) {
        return 0.0;
    }
    // NOTE(review): presumably feature [2] holds a log-probability, hence the
    // log-space threshold -- confirm against the phrase-table format.
    final double logThreshold = Math.log(minProb);
    final List<String> srcTokens = getTokens(translation.getSourceLabel());
    int matches = 0;
    for (final String tk : srcTokens) {
        for (final PhraseTableEntry candidate : phraseSource.candidates(new ChunkImpl(tk))) {
            if (candidate.getFeatures()[2].score >= logThreshold) {
                matches++;
            }
        }
    }
    return (double) matches / srcTokens.size();
}
/**
 * Returns the fraction of source tokens that receive a finite unigram score from
 * the source-language model.
 *
 * <p>Returns 1.0 when no model exists for the source language. If the source
 * token list is empty, the result is 0/0 = NaN.
 *
 * @param translation the translation whose source tokens are scored
 * @return the count of tokens with a non-infinite score divided by the token count
 */
public double percentUnigramsInLM(Translation translation) {
    final LanguageModel lm = getModel(translation.getSourceLabel().getLanguage());
    if (lm == null) {
        return 1.0;
    }
    final List<String> srcTokens = getTokens(translation.getSourceLabel());
    int known = 0;
    for (final String tok : srcTokens) {
        final double unigramScore = lm.score(Arrays.asList(tok));
        // An infinite score marks a token the model does not know.
        if (!Double.isInfinite(unigramScore)) {
            known++;
        }
    }
    return (double) known / srcTokens.size();
}
/**
 * Counts punctuation in the target side of a translation.
 *
 * @param translation the translation whose target label is inspected
 * @return the value of {@code countPunctuation} applied to the target label
 */
public static double countPunctuationInTarget(Translation translation) {
    return countPunctuation(
            translation.getTargetLabel());
}

// Character class matching one or more characters from: the ASCII ranges
// '!'..'/', ':'..'@', '['..'`', '{'..'~' (all printable ASCII that is neither a
// letter, a digit, nor a space), plus the Unicode General Punctuation block
// U+2000..U+206F.
private static final String puncClass = "[!-/:-@\\[-`\\{-~\u2000-\u206f]+";
/**
 * Scores a translation by linear regression over its baseline features.
 *
 * <p>The score starts at {@code offset}. Each feature that is neither NaN nor
 * infinite adds {@code wts[i] * feature[i]}; non-finite features are skipped.
 * The total is clamped to the range [0, 1].
 *
 * @param translation the translation to score
 * @return the clamped linear score
 */
@Override
public double confidence(Translation translation) {
    final double[] featureValues = baselineFeatures.allFeatures(translation);
    double total = offset;
    assert (featureValues.length == wts.length);
    for (int idx = 0; idx < featureValues.length; idx++) {
        final double v = featureValues[idx];
        if (Double.isNaN(v) || Double.isInfinite(v)) {
            continue; // undefined feature: contributes nothing
        }
        total += wts[idx] * v;
    }
    if (total < 0) {
        return 0;
    }
    if (total > 1) {
        return 1;
    }
    return total;
}
} // closes the enclosing class
final BufferedReader annotations = new BufferedReader(new FileReader(args[4])); final PrintWriter out = new PrintWriter(args[5]); final BaselineFeatures features = new BaselineFeatures(Services.get(LanguageModelFactory.class), Services.get(TranslationSourceFactory.class)); out.println( "TKS_S,TKS_T,TKLEN_S,LM_S,LM_T,OCC_T,TC_20,TC_1,TOP1G,BOT1G,TOP2G,BOT2G,TOP3G,BOT3G,UNK,PUNC_S,PUNC_T,C"); out.print(countTksInSrc(translation)); out.print(","); out.print(countTksInTrg(translation)); out.print(","); out.print(aveSrcTkLen(translation)); out.print(","); out.print(features.sourceLMProb(translation)); out.print(","); out.print(features.targetLMProb(translation)); out.print(","); out.print(aveOccurencesInTarget(translation)); out.print(","); out.print(features.aveTranslationCount(translation, 0.2)); out.print(","); out.print(features.aveTranslationCount(translation, 0.01)); out.print(","); final double[] oneGramScores = features.percentNGramsInTopBotQuartile(translation, 1); out.print(oneGramScores[0]); out.print(","); out.print(oneGramScores[1]); out.print(","); final double[] twoGramScores = features.percentNGramsInTopBotQuartile(translation, 2); out.print(twoGramScores[0]);
public double[] percentNGramsInTopBotQuartile(Translation translation, int n) { final LanguageModel nGramSource = getModel(translation.getSourceLabel().getLanguage()); if (nGramSource == null) { return new double[]{0.0, 0.0}; final List<String> tokens = getTokens(translation.getSourceLabel()); int botCount = 0; int topCount = 0;
/**
 * Counts punctuation in the source side of a translation.
 *
 * @param translation the translation whose source label is inspected
 * @return the value of {@code countPunctuation} applied to the source label
 */
public static double countPunctuationInSource(Translation translation) {
    return countPunctuation(
            translation.getSourceLabel());
}