/**
 * Copies the currently held entries into a new result list and returns it
 * sorted in descending order.
 *
 * Entries are read from positions 1..size of the backing {@code keys} and
 * {@code values} arrays (position 0 is not read here) and written to
 * positions 0..size-1 of the result.
 *
 * @return a new result list with {@code size} entries, sorted descending
 */
public SRResultList getTop() {
    SRResultList top = new SRResultList(size);
    int slot = 0;
    while (slot < size) {
        // The backing arrays are read one position ahead of the output slot.
        top.set(slot, keys[slot + 1], values[slot + 1]);
        slot++;
    }
    top.sortDescending();
    return top;
}
/**
 * Builds the vector for a page from the base metric's most-similar results,
 * restricted to {@code conceptIds} and limited to {@code numConcepts} entries.
 *
 * @param pageId id of the page whose vector is requested
 * @return the most-similar results as a Trove map, or {@code null} when the
 *         base metric returns no result list
 * @throws DaoException if the base metric fails
 */
@Override
public TIntFloatMap getVector(int pageId) throws DaoException {
    SRResultList neighbors = baseMetric.mostSimilar(pageId, numConcepts, conceptIds);
    return (neighbors == null) ? null : neighbors.asTroveMap();
}
/**
 * Computes the most similar entries for one page and writes them as a single
 * sparse matrix row.
 *
 * @param writer        destination for the row
 * @param wpId          id of the page (used as the row id)
 * @param colIds        candidate ids passed through to {@code mostSimilar}
 * @param maxSimsPerDoc maximum number of results requested for the page
 * @param idCounter     shared counter of pages processed; incremented once per call
 * @param cellCounter   shared counter incremented once for each row written
 * @throws IOException  if writing the row fails
 * @throws DaoException if computing the similar pages fails
 */
private void writeSim(SparseMatrixWriter writer, Integer wpId, TIntSet colIds, int maxSimsPerDoc, AtomicInteger idCounter, AtomicLong cellCounter) throws IOException, DaoException {
    // Capture the incremented value once: re-reading the shared counter for the
    // log message could report a value bumped by another caller in the meantime.
    int processed = idCounter.incrementAndGet();
    if (processed % 10000 == 0) {
        LOG.info("finding matches for page " + processed);
    }
    SRResultList scores = mostSimilar(wpId, maxSimsPerDoc, colIds);
    if (scores == null) {
        return;     // nothing to write for this page
    }
    int[] ids = scores.getIds();
    // NOTE(review): this adds one per row, not one per cell (ids.length) -- confirm
    // which is intended given the counter's name.
    cellCounter.getAndIncrement();
    writer.writeRow(new SparseMatrixRow(writer.getValueConf(), wpId, ids, scores.getScoresAsFloat()));
}
public static void main(String[] args) throws Exception{ // Initialize the WikiBrain environment and get the local page dao Env env = EnvBuilder.envFromArgs(args); Configurator conf = env.getConfigurator(); LocalPageDao lpDao = conf.get(LocalPageDao.class); Language simple = env.getDefaultLanguage(); // Retrieve the "milnewitten" sr metric for simple english SRMetric sr = conf.get( SRMetric.class, "prebuiltword2vec", "language", simple.getLangCode()); //Similarity between strings for (String phrase : Arrays.asList("Barack Obama", "US", "Canada", "vim")) { SRResultList similar = sr.mostSimilar(phrase, 3); List<String> pages = new ArrayList<String>(); for (int i = 0; i < similar.numDocs(); i++) { LocalPage page = lpDao.getById(simple, similar.getId(i)); pages.add((i+1) + ") " + page.getTitle()); } System.out.println("'" + phrase + "' is similar to " + StringUtils.join(pages, ", ")); } } }
for (int i = 0; i < l1.numDocs(); i++) { double s = l1.getScore(i); if (!Double.isInfinite(s) && !Double.isNaN(s)) { scores.adjustOrPutValue(l1.getId(i), 0.5 * s, 0.5 * s); inList1.add(l1.getId(i)); TIntSet inList2 = new TIntHashSet(); if (l2 != null) { for (int i = 0; i < l2.numDocs(); i++) { double s = l2.getScore(i); if (!Double.isInfinite(s) && !Double.isNaN(s)) { scores.adjustOrPutValue(l2.getId(i), 0.5 * s, 0.5 * s); inList2.add(l2.getId(i)); double missingScore1 = (l1 == null) ? 0.0 : l1.getMissingScore(); double missingScore2 = (l2 == null) ? 0.0 : l2.getMissingScore();
for (int j = 0; j < resultList.numDocs(); j++) { int rank = (int) ((j + 1) * k); SRResult result = resultList.get(j); unknownIds.remove(result.getId()); double value = c1 * result.getScore() + c2 * Math.log(rank); interpolatedRank = (int) Math.max(interpolatedRank, k * resultList.numDocs() * 5 / 4); Collections.reverse(resultList); int size = maxResults>resultList.size()? resultList.size() : maxResults; SRResultList result = new SRResultList(size); for (i=0; i<size;i++){ result.set(i,resultList.get(i));
for (LocalId id1 : candidates.keySet()) { SRResultList sr = metric.mostSimilar(id1.getId(), numCands * 2); if (sr != null && sr.numDocs() > 0) { for (int j = 0; j < numPerCand && j < sr.numDocs(); j++) { expanded.put(new LocalId(language, sr.getId(j)), (float)(sr.getScore(j) * candidates.get(id1)));
/**
 * Records how a guessed most-similar list lines up with a known one.
 *
 * Stores the guess's length and its minimum and maximum scores, then adds one
 * observation for every guessed result whose id also appears in the known
 * list. Each observation carries the 1-based position in the guess, the id,
 * the guessed score, and the known similarity.
 *
 * @param known the known most-similar data
 * @param guess the guessed result list to compare against it
 */
public MostSimilarGuess(KnownMostSim known, SRResultList guess) {
    this.known = known;
    length = guess.numDocs();
    minScore = guess.minScore();
    maxScore = guess.maxScore();

    // Index the known similarities by the id of the second page.
    TIntDoubleMap knownById = new TIntDoubleHashMap();
    for (KnownSim gold : known.getMostSimilar()) {
        knownById.put(gold.wpId2, gold.similarity);
    }

    for (int rank = 1; rank <= guess.numDocs(); rank++) {
        SRResult candidate = guess.get(rank - 1);
        if (!knownById.containsKey(candidate.getId())) {
            continue;   // guessed id has no known similarity
        }
        observations.add(new Observation(
                rank,
                candidate.getId(),
                candidate.getScore(),
                knownById.get(candidate.getId())));
    }
}
(int) Math.ceil(candidateSet.size() * 0.8), candidateSet); if (rl != null && rl.numDocs() > 0) { TIntFloatMap subscores = rl.asTroveMap(); double minScore = rl.getScore(rl.numDocs() - 1) * 0.99; for (int id : subscores.keys()) { double s = minScore;
/**
 * Looks up the position of a result by its id using a linear scan.
 *
 * @param id the id to search for
 * @return the index of the first result whose id equals {@code id},
 *         or -1 if no result has that id
 */
public int getIndexForId(int id) {
    int position = 0;
    while (position < numDocs()) {
        if (results[position].id == id) {
            return position;
        }
        position++;
    }
    return -1;
}
/**
 * Combines the result lists of several metrics by averaging their scores.
 *
 * Each result contributes {@code score / scores.size()} to the total for its
 * id, so an id that is missing from a metric's list (or a metric whose list is
 * {@code null}) contributes nothing for that metric. The combined results are
 * ordered from highest to lowest and truncated to {@code maxResults}.
 *
 * @param scores     one result list per metric; individual entries may be {@code null}
 * @param maxResults maximum number of results to return
 * @param validIds   not consulted by this implementation
 *                   (NOTE(review): presumably the input lists are already restricted -- confirm)
 * @return a result list holding min(maxResults, number of distinct ids) entries
 */
@Override
public SRResultList predictMostSimilar(List<SRResultList> scores, int maxResults, TIntSet validIds) {
    int numMetrics = scores.size();
    TIntDoubleHashMap scoreMap = new TIntDoubleHashMap();
    for (SRResultList metricResults : scores) {
        if (metricResults == null) {
            continue;   // a metric that produced no list contributes nothing
        }
        for (SRResult result : metricResults) {
            double value = result.getScore() / numMetrics;
            scoreMap.adjustOrPutValue(result.getId(), value, value);
        }
    }

    List<SRResult> ranked = new ArrayList<SRResult>();
    for (int id : scoreMap.keys()) {
        ranked.add(new SRResult(id, scoreMap.get(id)));
    }
    // Sort ascending, then reverse, so the highest combined scores come first.
    Collections.sort(ranked);
    Collections.reverse(ranked);

    // Size the output to what is actually available so no slot is left unset
    // when fewer than maxResults ids were scored.
    int size = Math.min(maxResults, ranked.size());
    SRResultList result = new SRResultList(size);
    for (int i = 0; i < size; i++) {
        result.set(i, ranked.get(i));
    }
    return result;
}
results.sortDescending(); for (SRResult hit : results) { System.out.println(hit.getScore() + ": " + lpd.getById(lang, hit.getId())); results.sortDescending(); for (SRResult hit : results) { System.out.println(hit.getScore() + ": " + lpd.getById(lang, hit.getId()));
/**
 * Builds one ensemble example for a known similarity pair.
 *
 * Both phrases are disambiguated together. For every metric, the pages most
 * similar to the first phrase's page are computed, and the score and 0-based
 * index of the second phrase's page in that list are recorded. When the
 * second page is unavailable, is absent from the list, or the metric fails,
 * {@code NaN} and {@code -1} are recorded for that metric instead.
 *
 * @param ks the known similarity pair
 * @return the example, or {@code null} if the first phrase cannot be resolved
 *         to a page with a positive id
 * @throws DaoException if disambiguation fails
 */
public EnsembleSim call(KnownSim ks) throws DaoException {
    List<LocalString> localStrings = Arrays.asList(
            new LocalString(ks.language, ks.phrase1),
            new LocalString(ks.language, ks.phrase2)
    );
    List<LocalId> ids = getDisambiguator().disambiguateTop(localStrings, null);
    if (ids == null || ids.isEmpty() || ids.get(0) == null || ids.get(0).getId() <= 0) {
        return null;
    }
    int pageId = ids.get(0).getId();
    // The second phrase may not have been resolved. Check once here so a
    // missing entry is not misreported below as a failure of every metric.
    LocalId target = (ids.size() > 1) ? ids.get(1) : null;
    EnsembleSim es = new EnsembleSim(ks);
    for (SRMetric metric : metrics) {
        double score = Double.NaN;
        int rank = -1;
        try {
            if (target != null) {
                SRResultList dsl = metric.mostSimilar(pageId, getMaxResults(numResults), validIds);
                if (dsl != null) {
                    // Look the index up once instead of rescanning the list three times.
                    int index = dsl.getIndexForId(target.getId());
                    if (index >= 0) {
                        score = dsl.getScore(index);
                        rank = index;
                    }
                }
            }
        } catch (Exception e) {
            LOG.warn("Local sr metric " + metric.getName() + " failed for " + pageId, e);
        } finally {
            // Always record one entry per metric, even when it failed.
            es.add(score, rank);
        }
    }
    return es;
}
}, 100);
/**
 * Records one training observation and then forwards it to the superclass.
 *
 * The index, the score found at that index, and {@code y} are stored only
 * when {@code index} is non-negative and the score is neither NaN nor
 * infinite. The three additions happen together while holding the lock on
 * {@code ranks}. The superclass is notified in every case.
 *
 * @param list  result list the observation refers to
 * @param index position of the observed item in {@code list}; negative values are not recorded here
 * @param y     the value observed for the item
 */
@Override
public void observe(SRResultList list, int index, double y) {
    if (index >= 0) {
        final double observed = list.getScore(index);
        boolean unusable = Double.isNaN(observed) || Double.isInfinite(observed);
        if (!unusable) {
            synchronized (ranks) {
                ranks.add(index);
                scores.add(observed);
                ys.add(y);
            }
        }
    }
    super.observe(list, index, y);
}
/**
 * Converts a ranked observation into a score observation.
 *
 * @param sims result list the rank refers to
 * @param rank index into {@code sims}; a negative value is observed as {@code NaN}
 * @param y    the value observed for the item
 */
@Override
public void observe(SRResultList sims, int rank, double y) {
    double score = (rank >= 0) ? sims.get(rank).getScore() : Double.NaN;
    observe(score, y);
}
/**
 * Feeds one known similarity pair to the trainee.
 *
 * Both phrases are disambiguated together. If exactly two non-null ids come
 * back, the pages most similar to the first are computed and the trainee
 * observes that list, the index of the second id within it (as returned by
 * {@code getIndexForId}), and the known similarity. In every other case the
 * pair is silently skipped.
 *
 * @param ks the known similarity pair; {@code maybeSwap()} is invoked on it first
 * @throws IOException  declared by the enclosing callback interface
 * @throws DaoException if disambiguation or the metric fails
 */
public void call(KnownSim ks) throws IOException, DaoException {
    ks.maybeSwap();

    List<LocalString> phrases = new ArrayList<LocalString>();
    phrases.add(new LocalString(ks.language, ks.phrase1));
    phrases.add(new LocalString(ks.language, ks.phrase2));

    List<LocalId> resolved = disambiguator.disambiguateTop(phrases, null);
    if (resolved == null || resolved.size() != 2) {
        return;
    }
    if (resolved.get(0) == null || resolved.get(1) == null) {
        return;
    }
    LocalId source = resolved.get(0);
    LocalId target = resolved.get(1);

    SRResultList neighbors = metric.mostSimilar(source.getId(), maxResults, validIds);
    if (neighbors == null) {
        return;
    }
    trainee.observe(neighbors, neighbors.getIndexForId(target.getId()), ks.similarity);
}
}, 100);
public static void main(String[] args) throws Exception{ // Initialize the WikiBrain environment and get the local page dao Env env = EnvBuilder.envFromArgs(args); Configurator conf = env.getConfigurator(); LocalPageDao lpDao = conf.get(LocalPageDao.class); Language simple = env.getDefaultLanguage(); // Retrieve the "milnewitten" sr metric for simple english SRMetric sr = conf.get( SRMetric.class, "prebuiltword2vec", "language", simple.getLangCode()); //Similarity between strings for (String phrase : Arrays.asList("Barack Obama", "US", "Canada", "vim")) { SRResultList similar = sr.mostSimilar(phrase, 3); List<String> pages = new ArrayList<String>(); for (int i = 0; i < similar.numDocs(); i++) { LocalPage page = lpDao.getById(simple, similar.getId(i)); pages.add((i+1) + ") " + page.getTitle()); } System.out.println("'" + phrase + "' is similar to " + StringUtils.join(pages, ", ")); } } }
double c2 = mostSimilarCoefficients.get(i+1); // rank coefficient if (resultList != null) { for (int j = 0; j < resultList.numDocs(); j++) { int rank = j + 1; rank = (int) (rank * k); SRResult result = resultList.get(j); unknownIds.remove(result.getId()); double value = c1 * result.getScore() + c2 * Math.log(rank); Collections.reverse(resultList); int size = maxResults>resultList.size()? resultList.size() : maxResults; SRResultList result = new SRResultList(size); for (i=0; i<size;i++){ result.set(i,resultList.get(i));