/**
 * Normalizes a language-tagged string by handing its language and its raw
 * text to the {@code normalize(language, text)} overload.
 *
 * @param string the language-tagged text to normalize
 * @return whatever the two-argument {@code normalize} overload returns
 */
@Override
public String normalize(LocalString string) {
    final String raw = string.getString();
    return normalize(string.getLanguage(), raw);
}
/**
 * Computes one vector per phrase.
 *
 * Every phrase is wrapped in a {@link LocalString} using this object's
 * {@code language}, all of them are disambiguated in a single batch call
 * (with a {@code null} context), and each phrase together with its candidate
 * map is then passed to {@code getPhraseVector(phrase, candidates)}.
 *
 * @param phrases the phrases to convert
 * @return an array parallel to {@code phrases} holding the per-phrase results
 * @throws DaoException if disambiguation or vector construction fails
 * @throws IllegalStateException if the disambiguator does not return exactly
 *         one candidate map per phrase
 */
public TIntFloatMap[] getPhraseVectors(String... phrases) throws DaoException {
    final int n = phrases.length;

    List<LocalString> tagged = new ArrayList<LocalString>(n);
    for (int i = 0; i < n; i++) {
        tagged.add(new LocalString(language, phrases[i]));
    }

    List<LinkedHashMap<LocalId, Float>> perPhrase = disambig.disambiguate(tagged, null);
    if (perPhrase.size() != n) {
        throw new IllegalStateException();
    }

    TIntFloatMap[] vectors = new TIntFloatMap[n];
    int slot = 0;
    for (String phrase : phrases) {
        vectors[slot] = getPhraseVector(phrase, perPhrase.get(slot));
        slot++;
    }
    return vectors;
}
/**
 * Identity normalization: the text carried by the {@link LocalString} is
 * returned as-is; its language is not consulted.
 *
 * @param text the language-tagged text
 * @return {@code text.getString()}, unmodified
 */
@Override
public String normalize(LocalString text) {
    final String unchanged = text.getString();
    return unchanged;
}
private void debugSimilarityDisambiguator(List<LocalString> phrases) throws DaoException { String last = null; boolean same = true; StringBuffer b = new StringBuffer("results for " + phrases.get(0).getString() + ", " + phrases.get(1).getString() + "\n"); for (SimilarityDisambiguator.Criteria c : SimilarityDisambiguator.Criteria.values()) { if (c == SimilarityDisambiguator.Criteria.SIMILARITY) { continue; // weird, so skip for now. } List<LocalId> resolutions; synchronized (disambiguator) { ((SimilarityDisambiguator)disambiguator).setCriteria(c); resolutions = disambiguator.disambiguateTop(phrases, null); } String page1 = resolutions.get(0) == null ? "null" : localPageDao.getById(language, resolutions.get(0).getId()).toString(); String page2 = resolutions.get(1) == null ? "null" : localPageDao.getById(language, resolutions.get(1).getId()).toString(); b.append("\t" + c + ": " + page1 + ", " + page2 + "\n"); if (last == null) last = page1+page2; if (!last.equals(page1+page2)) { same = false; } } if (!same) { System.out.println(b.toString()); } }
/**
 * Resolves every phrase independently through {@code phraseAnalyzer},
 * requesting up to 10 candidates per phrase.
 *
 * The {@code context} argument is not used by this implementation.
 *
 * @param phrases the phrases to resolve
 * @param context ignored
 * @return a list parallel to {@code phrases}; an element is {@code null}
 *         when the analyzer returned {@code null} for that phrase, otherwise
 *         a copy of the analyzer's candidate map in the same iteration order
 * @throws DaoException if the analyzer fails
 */
@Override
public List<LinkedHashMap<LocalId, Float>> disambiguate(List<LocalString> phrases, Set<LocalString> context) throws DaoException {
    List<LinkedHashMap<LocalId, Float>> results = new ArrayList<LinkedHashMap<LocalId, Float>>();
    for (LocalString phrase : phrases) {
        LinkedHashMap<LocalId, Float> resolved =
                phraseAnalyzer.resolve(phrase.getLanguage(), phrase.getString(), 10);
        // Hand back a private copy so callers cannot mutate the analyzer's map.
        results.add(resolved == null ? null : new LinkedHashMap<LocalId, Float>(resolved));
    }
    return results;
}
/**
 * Computes the vector for a single phrase.
 *
 * The phrase is tagged with this object's {@code language}, disambiguated
 * with a {@code null} context, and the resulting candidate map is passed to
 * {@code getPhraseVector(phrase, candidates)}.
 *
 * @param phrase the phrase to convert
 * @return the result of {@code getPhraseVector(phrase, candidates)}
 * @throws DaoException if disambiguation or vector construction fails
 */
public TIntFloatMap getPhraseVector(String phrase) throws DaoException {
    return getPhraseVector(
            phrase,
            disambig.disambiguate(new LocalString(language, phrase), null));
}
/**
 * Picks a single page for a phrase by majority vote among the configured
 * phrase analyzers.
 *
 * Each analyzer is asked for one candidate; the first key of the map it
 * returns counts as its vote. Analyzers that return {@code null} or an empty
 * map abstain. The id with the most votes wins; on a tie the id that was
 * voted for earliest wins. The {@code context} argument is not used.
 *
 * @param phrase the phrase to resolve
 * @param context ignored
 * @return the winning id, or {@code null} when no analyzer voted
 * @throws DaoException if an analyzer fails
 */
public LocalId disambiguateTop(LocalString phrase, Set<LocalString> context) throws DaoException {
    // Insertion-ordered so that ties are broken by who was voted for first.
    LinkedHashMap<LocalId, Integer> votes = new LinkedHashMap<LocalId, Integer>();

    for (PhraseAnalyzer phraseAnalyzer : phraseAnalyzers) {
        LinkedHashMap<LocalId, Float> top =
                phraseAnalyzer.resolve(phrase.getLanguage(), phrase.getString(), 1);
        if (top != null && !top.isEmpty()) {
            LocalId choice = top.keySet().iterator().next();
            Integer soFar = votes.get(choice);
            votes.put(choice, soFar == null ? 1 : soFar + 1);
        }
    }

    LocalId winner = null;
    int winningCount = 0;
    for (Map.Entry<LocalId, Integer> vote : votes.entrySet()) {
        int count = vote.getValue();
        if (count > winningCount) {
            winningCount = count;
            winner = vote.getKey();
        }
    }
    // Remains null when nobody voted.
    return winner;
}
/**
 * Feeds one known-similarity example to the trainee.
 *
 * After {@code ks.maybeSwap()}, both phrases are disambiguated together. The
 * example is silently skipped unless exactly two non-null ids come back and
 * the metric produces a most-similar list for the first id. Otherwise the
 * trainee observes that list, the index of the second id within it, and the
 * known similarity value.
 *
 * @param ks the known-similarity example
 * @throws IOException declared by the callback contract
 * @throws DaoException if disambiguation or the metric fails
 */
public void call(KnownSim ks) throws IOException, DaoException {
    ks.maybeSwap();

    List<LocalString> pair = new ArrayList<LocalString>();
    pair.add(new LocalString(ks.language, ks.phrase1));
    pair.add(new LocalString(ks.language, ks.phrase2));

    List<LocalId> ids = disambiguator.disambiguateTop(pair, null);
    if (ids == null || ids.size() != 2) {
        return;
    }
    LocalId lid1 = ids.get(0);
    if (lid1 == null) {
        return;
    }
    LocalId lid2 = ids.get(1);
    if (lid2 == null) {
        return;
    }

    SRResultList dsl = metric.mostSimilar(lid1.getId(), maxResults, validIds);
    if (dsl == null) {
        return;
    }
    trainee.observe(dsl, dsl.getIndexForId(lid2.getId()), ks.similarity);
}
}, 100);
/**
 * Resolves every phrase by summing, per page id, the scores reported by all
 * configured phrase analyzers (each asked for up to 20 candidates).
 *
 * An analyzer that returns {@code null} for a phrase is skipped instead of
 * causing a {@link NullPointerException}. The ids in each result map carry
 * the language of the phrase they were resolved for, so mixed-language input
 * is tagged correctly. The {@code context} argument is not used.
 *
 * @param phrases the phrases to resolve; may be empty
 * @param context ignored
 * @return a list parallel to {@code phrases}; each element maps page ids to
 *         their summed score, in the key order produced by
 *         {@code WpCollectionUtils.sortMapKeys(pageSums, true)}
 *         (presumably highest score first; confirm against that helper)
 * @throws DaoException if an analyzer fails
 */
@Override
public List<LinkedHashMap<LocalId, Float>> disambiguate(List<LocalString> phrases, Set<LocalString> context) throws DaoException {
    List<LinkedHashMap<LocalId, Float>> results = new ArrayList<LinkedHashMap<LocalId, Float>>();

    for (LocalString phrase : phrases) {
        // Use the phrase's own language for both the lookup and the result ids.
        Language lang = phrase.getLanguage();

        Map<Integer, Double> pageSums = new HashMap<Integer, Double>();
        for (PhraseAnalyzer pa : phraseAnalyzers) {
            LinkedHashMap<LocalId, Float> probs = pa.resolve(lang, phrase.getString(), 20);
            if (probs == null) {
                // This analyzer has nothing for the phrase; it contributes no score.
                continue;
            }
            for (Map.Entry<LocalId, Float> entry : probs.entrySet()) {
                int id = entry.getKey().getId();
                Double sum = pageSums.get(id);
                pageSums.put(id, (sum == null ? 0.0 : sum) + entry.getValue());
            }
        }

        LinkedHashMap<LocalId, Float> pageResult = new LinkedHashMap<LocalId, Float>();
        for (Integer key : WpCollectionUtils.sortMapKeys(pageSums, true)) {
            pageResult.put(new LocalId(lang, key), pageSums.get(key).floatValue());
        }
        results.add(pageResult);
    }
    return results;
}
@Override public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException { Language language = getLanguage(); List<LocalString> phrases = Arrays.asList( new LocalString(language, phrase1), new LocalString(language, phrase2)); // debugSimilarityDisambiguator(phrases); List<LocalId> resolutions = disambiguator.disambiguateTop(phrases, null); if (resolutions.get(0) == null || resolutions.get(1) == null) { return new SRResult(); } // LocalPage lp1 = localPageDao.getById(language, resolutions.get(0).getId()); // LocalPage lp2 = localPageDao.getById(language, resolutions.get(1).getId()); // System.out.println("resolved " + phrase1 + ", " + phrase2 + " to " + lp1 + ", " + lp2); return similarity(resolutions.get(0).getId(), resolutions.get(1).getId(), explanations); }
candidates.put(s, phraseAnalyzer.resolve(s.getLanguage(), s.getString(), numCandidates));
/**
 * Computes the cosimilarity matrix for a set of phrases by resolving each
 * phrase to a page id and delegating to the id-based {@code cosimilarity}.
 *
 * NOTE(review): the resolved ids are dereferenced without a null check, so
 * a null entry from the disambiguator (which sibling phrase-based methods
 * guard against) raises a {@link NullPointerException} here.
 *
 * @param phrases the phrases to compare
 * @return the result of {@code cosimilarity(int[])} for the resolved ids
 * @throws DaoException if disambiguation or the computation fails
 */
@Override
public double[][] cosimilarity(String[] phrases) throws DaoException {
    final int count = phrases.length;

    List<LocalString> tagged = new ArrayList<LocalString>();
    for (int i = 0; i < count; i++) {
        tagged.add(new LocalString(getLanguage(), phrases[i]));
    }

    List<LocalId> resolved = disambiguator.disambiguateTop(tagged, null);

    int[] ids = new int[count];
    int slot = 0;
    while (slot < count) {
        ids[slot] = resolved.get(slot).getId();
        slot++;
    }
    return cosimilarity(ids);
}
if (!s.getLanguage().equals(language)) { throw new IllegalArgumentException("Disambiguator only supports language " + language); candidates.put(s, analyzer.resolve(s.getLanguage(), s.getString(), 100));
for (String field : contextFields) { if (row.get(field) != null) { context.add(new LocalString(language, row.get(field))); for (int i = 0; i < titles.size(); i++) { double weight = 1.0 * Math.pow(0.7, i); LinkedHashMap<LocalId, Float> result = disambig.disambiguate(new LocalString(language, titles.get(i)), context); if (result == null) { continue;
for (String field : contextFields) { if (row.get(field) != null) { context.add(new LocalString(language, row.get(field))); for (int i = 0; i < titles.size(); i++) { double weight = 1.0 * Math.pow(0.7, i); LinkedHashMap<LocalId, Float> result = disambig.disambiguate(new LocalString(language, titles.get(i)), context); if (result == null) { continue;
); if (resolvePhrases) { LocalId id1 = disambiguator.disambiguateTop(new LocalString(language, ks.phrase1), null); LocalId id2 = disambiguator.disambiguateTop(new LocalString(language, ks.phrase2), null); if (id1 != null) { ks.wpId1 = id1.getId(); } if (id2 != null) { ks.wpId2 = id2.getId(); }
/**
 * Builds the per-metric feature row for one known-similarity example.
 *
 * Both phrases are disambiguated together. If the first phrase cannot be
 * resolved (no result, a null entry, or a non-positive id) the example is
 * dropped by returning {@code null}. Otherwise every metric contributes one
 * (score, rank) pair: the score and index of the second page inside the
 * metric's most-similar list for the first page, or {@code (NaN, -1)} when
 * the second page is unresolved, absent from the list, or the metric fails.
 *
 * An unresolved second phrase is detected up front, so it no longer raises
 * an exception inside the metric loop that is then logged as a metric
 * failure for every metric.
 *
 * @param ks the known-similarity example
 * @return the populated {@code EnsembleSim}, or {@code null} when the first
 *         phrase cannot be resolved
 * @throws DaoException if disambiguation fails
 */
public EnsembleSim call(KnownSim ks) throws DaoException {
    List<LocalString> localStrings = Arrays.asList(
            new LocalString(ks.language, ks.phrase1),
            new LocalString(ks.language, ks.phrase2)
    );
    List<LocalId> ids = getDisambiguator().disambiguateTop(localStrings, null);

    LocalId first = (ids == null || ids.isEmpty()) ? null : ids.get(0);
    if (first == null || first.getId() <= 0) {
        return null;
    }
    // May legitimately be missing; every metric then reports (NaN, -1).
    LocalId second = (ids.size() > 1) ? ids.get(1) : null;

    int pageId = first.getId();
    EnsembleSim es = new EnsembleSim(ks);
    for (SRMetric metric : metrics) {
        double score = Double.NaN;
        int rank = -1;
        if (second != null) {
            try {
                SRResultList dsl = metric.mostSimilar(pageId, getMaxResults(numResults), validIds);
                if (dsl != null) {
                    // Look the index up once and reuse it for both score and rank.
                    int index = dsl.getIndexForId(second.getId());
                    if (index >= 0) {
                        score = dsl.getScore(index);
                        rank = index;
                    }
                }
            } catch (Exception e) {
                // Best effort: a failing metric contributes a missing value.
                LOG.warn("Local sr metric " + metric.getName() + " failed for " + pageId, e);
            }
        }
        es.add(score, rank);
    }
    return es;
}
}, 100);
public static TIntFloatMap getPhraseVector(SparseVectorSRMetric metric, String phrase) { // try using phrase generator directly try { return metric.getGenerator().getVector(phrase); } catch (UnsupportedOperationException e) { // try using other methods } try { Language lang = metric.getLanguage(); LocalId best = metric.getDisambiguator().disambiguateTop(new LocalString(lang, phrase), null); if (best == null) { return null; } return metric.getPageVector(best.getId()); } catch (DaoException e) { throw new RuntimeException(e); } catch (IOException e) { throw new RuntimeException(e); } }
/**
 * Finds the pages most similar to a phrase, restricted to {@code validIds},
 * by resolving the phrase to a page id and delegating to the id-based
 * overload.
 *
 * @param phrase the query phrase
 * @param maxResults forwarded unchanged to the id-based overload
 * @param validIds forwarded unchanged to the id-based overload
 * @return the id-based result, or, when the phrase cannot be resolved, a
 *         list created with size argument 1 whose slot 0 holds a
 *         default-constructed {@link SRResult}
 * @throws DaoException if disambiguation or the lookup fails
 */
@Override
public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException {
    LocalString tagged = new LocalString(getLanguage(), phrase);
    LocalId similar = disambiguator.disambiguateTop(tagged, null);
    if (similar != null) {
        return mostSimilar(similar.getId(), maxResults, validIds);
    }

    // Unresolvable phrase: return a one-slot placeholder list.
    SRResultList placeholder = new SRResultList(1);
    placeholder.set(0, new SRResult());
    return placeholder;
}
/**
 * Finds the pages most similar to a phrase by resolving the phrase to a
 * page id and delegating to the id-based overload.
 *
 * @param phrase the query phrase
 * @param maxResults forwarded unchanged to the id-based overload
 * @return the id-based result, or, when the phrase cannot be resolved, a
 *         list created with size argument 1 whose slot 0 holds a
 *         default-constructed {@link SRResult}
 * @throws DaoException if disambiguation or the lookup fails
 */
@Override
public SRResultList mostSimilar(String phrase, int maxResults) throws DaoException {
    LocalString tagged = new LocalString(getLanguage(), phrase);
    LocalId similar = disambiguator.disambiguateTop(tagged, null);
    if (similar != null) {
        return mostSimilar(similar.getId(), maxResults);
    }

    // Unresolvable phrase: return a one-slot placeholder list.
    SRResultList placeholder = new SRResultList(1);
    placeholder.set(0, new SRResult());
    return placeholder;
}