/**
 * Computes the pairwise similarity matrix for a set of phrases.
 * Each phrase is wrapped in a {@code LocalString} in this metric's language and resolved with
 * {@code disambiguator.disambiguateTop}; the resulting page ids are passed to
 * {@code cosimilarity(int[])}.
 *
 * NOTE(review): every resolved entry is dereferenced, so a phrase the disambiguator cannot
 * resolve (a null entry) raises a NullPointerException here, whereas
 * {@code similarity(String, String, boolean)} guards against null resolutions. Confirm which
 * behavior callers expect.
 *
 * @param phrases phrases to compare
 * @return the matrix produced by {@code cosimilarity(int[])} for the resolved ids
 * @throws DaoException if disambiguation or the id-based computation fails
 */
@Override
public double[][] cosimilarity(String[] phrases) throws DaoException {
    List<LocalString> queries = new ArrayList<LocalString>(phrases.length);
    for (String p : phrases) {
        queries.add(new LocalString(getLanguage(), p));
    }
    List<LocalId> resolved = disambiguator.disambiguateTop(queries, null);

    int[] pageIds = new int[phrases.length];
    for (int k = 0; k < pageIds.length; k++) {
        pageIds[k] = resolved.get(k).getId();
    }
    return cosimilarity(pageIds);
}
/**
 * Returns the file that backs the most-similar cache.
 *
 * @return {@code mostSimilar.matrix} located in {@code getDataDir()}
 */
protected File getMostSimilarMatrixPath() {
    File matrixFile = new File(getDataDir(), "mostSimilar.matrix");
    return matrixFile;
}
/**
 * Discards the most-similar cache.
 * The in-memory matrix is closed (close errors are swallowed by {@code IOUtils.closeQuietly})
 * before its backing file is deleted ({@code FileUtils.deleteQuietly} does not throw), and the
 * reference is then cleared.
 */
public void clearMostSimilarCache() {
    final File matrixFile = getMostSimilarMatrixPath();
    // Close first so the file is no longer held open when we delete it.
    IOUtils.closeQuietly(mostSimilarCache);
    FileUtils.deleteQuietly(matrixFile);
    mostSimilarCache = null;
}
/**
 * Trains the most-similar normalizers on {@code dataset} and, if {@code buildMostSimilarCache}
 * is set, writes the most-similar cache via {@code writeMostSimilarCache}.
 *
 * @param dataset    training dataset; must share this metric's language
 * @param numResults forwarded to normalizer training and to the cache writer
 *                   (presumably the number of neighbors per item — confirm)
 * @param validIds   forwarded to normalizer training and to the cache writer
 * @throws IllegalArgumentException if the dataset language differs from this metric's
 */
@Override
public synchronized void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) {
    if (!dataset.getLanguage().equals(getLanguage())) {
        throw new IllegalArgumentException("SR metric has language " + getLanguage()
                + " but dataset has language " + dataset.getLanguage());
    }
    normalizers.trainMostSimilar(this, disambiguator, dataset, validIds, numResults);

    if (!buildMostSimilarCache) {
        return;
    }
    try {
        writeMostSimilarCache(numResults, mostSimilarCacheRowIds, validIds);
    } catch (Exception e) {
        // Best effort: a failed cache write is logged rather than propagated to the caller.
        LOG.error("writing most similar cache failed:", e);
    }
}
protected static void configureBase(Configurator configurator, BaseSRMetric sr, Config config) throws ConfigurationException { Config rootConfig = configurator.getConf().get(); File path = new File(rootConfig.getString("sr.metric.path")); sr.setDataDir(FileUtils.getFile(path, sr.getName(), sr.getLanguage().getLangCode())); // initialize normalizers sr.setSimilarityNormalizer(configurator.get(Normalizer.class, config.getString("similaritynormalizer"))); sr.setMostSimilarNormalizer(configurator.get(Normalizer.class, config.getString("mostsimilarnormalizer"))); boolean isTraining = rootConfig.getBoolean("sr.metric.training"); if (isTraining) { sr.setReadNormalizers(false); } if (config.hasPath("buildMostSimilarCache")) { sr.setBuildMostSimilarCache(config.getBoolean("buildMostSimilarCache")); } try { sr.read(); } catch (IOException e){ throw new ConfigurationException(e); } LOG.info("finished base configuration of metric " + sr.getName()); } }
.setLanguages(getLanguage()) .setNameSpaces(NameSpace.ARTICLE) .setDisambig(false) if (colIds == null) colIds = allPageIds; getDataDir().mkdirs(); IOUtils.closeQuietly(mostSimilarCache); SRConfig config = getConfig(); final AtomicInteger idCounter = new AtomicInteger(); final AtomicLong cellCounter = new AtomicLong(); ValueConf vconf = new ValueConf(config.minScore, config.maxScore); final SparseMatrixWriter writer = new SparseMatrixWriter(getMostSimilarMatrixPath(), vconf); final TIntSet colIdSet = colIds == null ? null : new TIntHashSet(colIds); Normalizer simNormalizer = getSimilarityNormalizer(); Normalizer mostSimNormalizer = getMostSimilarNormalizer(); setMostSimilarNormalizer(new IdentityNormalizer()); setSimilarityNormalizer(new IdentityNormalizer()); try { ParallelForEach.loop( setSimilarityNormalizer(simNormalizer); setMostSimilarNormalizer(mostSimNormalizer); mostSimilarCache = new SparseMatrix(getMostSimilarMatrixPath());
/**
 * Resolves {@code phrase} to its top local page and returns the most similar pages for it,
 * delegating to {@code mostSimilar(int, int, TIntSet)}.
 * If the phrase cannot be resolved, returns a one-element list holding a default
 * {@code SRResult}.
 *
 * @param phrase     phrase in this metric's language
 * @param maxResults forwarded to the id-based overload
 * @param validIds   forwarded to the id-based overload
 * @return the id-based result list, or the one-element placeholder list
 * @throws DaoException if disambiguation or the lookup fails
 */
@Override
public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException {
    LocalString query = new LocalString(getLanguage(), phrase);
    LocalId top = disambiguator.disambiguateTop(query, null);
    if (top != null) {
        return mostSimilar(top.getId(), maxResults, validIds);
    }
    // Unresolvable phrase: a single default-constructed result stands in for an answer.
    SRResultList placeholder = new SRResultList(1);
    placeholder.set(0, new SRResult());
    return placeholder;
}
@Override public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException { Language language = getLanguage(); List<LocalString> phrases = Arrays.asList( new LocalString(language, phrase1), new LocalString(language, phrase2)); // debugSimilarityDisambiguator(phrases); List<LocalId> resolutions = disambiguator.disambiguateTop(phrases, null); if (resolutions.get(0) == null || resolutions.get(1) == null) { return new SRResult(); } // LocalPage lp1 = localPageDao.getById(language, resolutions.get(0).getId()); // LocalPage lp2 = localPageDao.getById(language, resolutions.get(1).getId()); // System.out.println("resolved " + phrase1 + ", " + phrase2 + " to " + lp1 + ", " + lp2); return similarity(resolutions.get(0).getId(), resolutions.get(1).getId(), explanations); }
return super.cosimilarity(rowPhrases, colPhrases); } else { return cosimilarity(rowVectors, colVectors);
/**
 * Trains the similarity normalizers on {@code dataset}.
 *
 * @param dataset training dataset; must share this metric's language
 * @throws IllegalArgumentException if the dataset language differs from this metric's
 * @throws DaoException if normalizer training fails
 */
@Override
public synchronized void trainSimilarity(Dataset dataset) throws DaoException {
    boolean sameLanguage = dataset.getLanguage().equals(getLanguage());
    if (!sameLanguage) {
        throw new IllegalArgumentException("SR metric has language " + getLanguage()
                + " but dataset has language " + dataset.getLanguage());
    }
    normalizers.trainSimilarity(this, dataset);
}
/**
 * Normalizes a most-similar result list.
 * Calls {@code ensureMostSimilarTrained()} first, then applies the most-similar normalizer
 * obtained from {@code normalizers.getMostSimilarNormalizer()}.
 *
 * @param srl raw most-similar results to normalize
 * @return the list produced by the most-similar normalizer
 */
protected SRResultList normalize(SRResultList srl) {
    ensureMostSimilarTrained();
    return normalizers.getMostSimilarNormalizer().normalize(srl);
}
/**
 * Normalizes a single raw similarity score.
 * Calls {@code ensureSimilarityTrained()} first, then applies the similarity normalizer
 * obtained from {@code normalizers.getSimilarityNormalizer()}.
 *
 * @param score raw similarity score
 * @return the normalized score
 */
protected double normalize(double score) {
    ensureSimilarityTrained();
    double normalized = normalizers.getSimilarityNormalizer().normalize(score);
    return normalized;
}
/**
 * Builds a {@code DirectLinkMetric} when {@code config} declares {@code type = "directlink"};
 * returns null for any other type.
 * The metric is wired with the configured {@code LocalPageDao}, {@code LocalLinkDao}, and the
 * disambiguator named in {@code config} for the requested language, then passed through
 * {@code configureBase}.
 *
 * @param name          metric name
 * @param config        metric configuration
 * @param runtimeParams must contain a {@code language} entry (language code)
 * @return the configured metric, or null if {@code type} is not {@code directlink}
 * @throws IllegalArgumentException if the {@code language} runtime parameter is missing
 * @throws ConfigurationException if a dependency cannot be configured
 */
@Override
public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
    boolean isDirectLink = config.getString("type").equals("directlink");
    if (!isDirectLink) {
        return null;
    }
    boolean hasLanguage = runtimeParams != null && runtimeParams.containsKey("language");
    if (!hasLanguage) {
        throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
    }
    Language language = Language.getByLangCode(runtimeParams.get("language"));

    Configurator configurator = getConfigurator();
    LocalLinkDao linkDao = configurator.get(LocalLinkDao.class);
    Disambiguator dab = configurator.get(Disambiguator.class, config.getString("disambiguator"),
            "language", language.getLangCode());
    DirectLinkMetric metric = new DirectLinkMetric(
            name,
            language,
            configurator.get(LocalPageDao.class),
            linkDao,
            dab);
    configureBase(configurator, metric, config);
    return metric;
}
}
/**
 * Resolves {@code phrase} to its top local page and returns the most similar pages for it,
 * delegating to {@code mostSimilar(int, int)}.
 * If the phrase cannot be resolved, returns a one-element list holding a default
 * {@code SRResult}.
 *
 * @param phrase     phrase in this metric's language
 * @param maxResults forwarded to the id-based overload
 * @return the id-based result list, or the one-element placeholder list
 * @throws DaoException if disambiguation or the lookup fails
 */
@Override
public SRResultList mostSimilar(String phrase, int maxResults) throws DaoException {
    LocalId top = disambiguator.disambiguateTop(new LocalString(getLanguage(), phrase), null);
    if (top == null) {
        SRResultList fallback = new SRResultList(1);
        fallback.set(0, new SRResult());
        return fallback;
    }
    return mostSimilar(top.getId(), maxResults);
}
return super.cosimilarity(rowPhrases, colPhrases);
/**
 * Loads persisted state from the data directory.
 * <ul>
 *   <li>Normalizers are read when {@code shouldReadNormalizers} is set and readable
 *       normalizers exist in the directory.</li>
 *   <li>Any open most-similar cache is closed and cleared. It is reopened only if the matrix
 *       file exists, so a missing file leaves the cache null rather than pointing at a closed
 *       matrix.</li>
 * </ul>
 * If the data directory does not exist, a warning is logged and nothing is changed.
 *
 * @throws IOException propagated from reading normalizers or opening the cache matrix
 */
@Override
public void read() throws IOException {
    if (!dataDir.isDirectory()) {
        LOG.warn("directory " + dataDir + " does not exist; cannot read files");
        return;
    }
    if (shouldReadNormalizers && normalizers.hasReadableNormalizers(dataDir)) {
        normalizers.read(dataDir);
    }
    // Reset the reference after closing. Otherwise a missing matrix file, or a failure while
    // reopening it, would leave a closed matrix in mostSimilarCache.
    IOUtils.closeQuietly(mostSimilarCache);
    mostSimilarCache = null;
    File matrixPath = getMostSimilarMatrixPath();
    if (matrixPath.isFile()) {
        mostSimilarCache = new SparseMatrix(matrixPath);
    }
}