searcher.clear(); searcher.addAll(dataPoints); Pair<List<List<WeightedThing<Vector>>>, Long> results = getResultsAndRuntime(searcher, queries);
/**
 * Checks the searcher's single-closest-vector results, and its average query time, against the
 * brute-force reference computed for the same queries.
 *
 * <p>The searcher is cleared and reloaded with {@code dataPoints}, then every row of
 * {@code queries} is run through {@code getResultsAndRuntimeSearchFirst}. The test passes only if
 * the value returned for every query equals the value in {@code referenceSearchFirst}, and the
 * searcher's average time per query is strictly lower than the brute-force average.
 */
@Test
public void testOverlapAndRuntimeSearchFirst() {
  searcher.clear();
  searcher.addAll(dataPoints);
  Pair<List<WeightedThing<Vector>>, Long> results = getResultsAndRuntimeSearchFirst(searcher, queries);

  int numQueries = queries.numRows();
  int numFirstMatches = 0;
  for (int i = 0; i < numQueries; ++i) {
    WeightedThing<Vector> referenceVector = referenceSearchFirst.getFirst().get(i);
    WeightedThing<Vector> resultVector = results.getFirst().get(i);
    // Only the returned vectors are compared; the weights attached to them are ignored.
    if (referenceVector.getValue().equals(resultVector.getValue())) {
      ++numFirstMatches;
    }
  }

  // The baseline must be the brute-force timing of the same operation (search-first), not the
  // timing recorded by getResultsAndRuntime, otherwise the speed assertion compares two
  // different kinds of query.
  double bruteSearchAvgTime = referenceSearchFirst.getSecond() / (numQueries * 1.0);
  double searcherAvgTime = results.getSecond() / (numQueries * 1.0);
  // NOTE(review): the message labels the times "(s)"; the unit actually returned by the timing
  // helper is not visible here - confirm.
  System.out.printf("%s: first matches %d [%d]; avg_time(1 query) %f(s) [%f]\n",
      searcher.getClass().getName(), numFirstMatches, numQueries,
      searcherAvgTime, bruteSearchAvgTime);

  assertEquals("Closest vector returned doesn't match", numQueries, numFirstMatches);
  assertTrue("Searcher " + searcher.getClass().getName() + " slower than brute",
      bruteSearchAvgTime > searcherAvgTime);
}

@Test
@Parameterized.Parameters public static List<Object[]> generateData() { RandomUtils.useTestSeed(); Matrix dataPoints = LumpyData.lumpyRandomData(NUM_DATA_POINTS, NUM_DIMENSIONS); Matrix queries = LumpyData.lumpyRandomData(NUM_QUERIES, NUM_DIMENSIONS); DistanceMeasure distanceMeasure = new CosineDistanceMeasure(); Searcher bruteSearcher = new BruteSearch(distanceMeasure); bruteSearcher.addAll(dataPoints); Pair<List<List<WeightedThing<Vector>>>, Long> reference = getResultsAndRuntime(bruteSearcher, queries); Pair<List<WeightedThing<Vector>>, Long> referenceSearchFirst = getResultsAndRuntimeSearchFirst(bruteSearcher, queries); double bruteSearchAvgTime = reference.getSecond() / (queries.numRows() * 1.0); System.out.printf("BruteSearch: avg_time(1 query) %f[s]\n", bruteSearchAvgTime); return Arrays.asList(new Object[][]{ // NUM_PROJECTIONS = 3 // SEARCH_SIZE = 10 {new ProjectionSearch(distanceMeasure, 3, 10), dataPoints, queries, reference, referenceSearchFirst}, {new FastProjectionSearch(distanceMeasure, 3, 10), dataPoints, queries, reference, referenceSearchFirst}, // NUM_PROJECTIONS = 5 // SEARCH_SIZE = 5 {new ProjectionSearch(distanceMeasure, 5, 5), dataPoints, queries, reference, referenceSearchFirst}, {new FastProjectionSearch(distanceMeasure, 5, 5), dataPoints, queries, reference, referenceSearchFirst}, } ); }