// Stop as soon as j runs past the last available result, so results.get(j) is never called out of range.
if (j >= results.size()) break;
// Store the index of the j-th result at position j of indices.
indices.putScalar(j, results.get(j).getIndex());
/**
 * Searches the tree for up to {@code k} matches for {@code target} and writes them into the
 * supplied output lists.
 * <p>
 * Both lists are cleared first and then filled in lock-step: {@code distances.get(i)} is the
 * distance reported for {@code results.get(i)}. Entries are appended in the order in which the
 * priority queue (ordered by {@code HeapObjectComparator}) yields them; when {@code invert} is
 * set, both lists are reversed afterwards.
 *
 * @param target    the query point; must be a vector with a single row and as many columns as
 *                  the stored items
 * @param k         the maximum number of results to return; capped at the number of stored rows
 * @param results   output list receiving the matched data points (cleared before use)
 * @param distances output list receiving the corresponding distances (cleared before use)
 * @throws ND4JIllegalStateException if {@code target} does not have the expected shape
 */
public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances) {
    // NOTE(review): the shape check is skipped when items is null, but items.rows() is
    // dereferenced immediately below, so a null items still fails with a NullPointerException.
    if (items != null) {
        if (!target.isVector() || target.columns() != items.columns() || target.rows() > 1) {
            throw new ND4JIllegalStateException("Target for search should have shape of [" + 1 + ", "
                    + items.columns() + "] but got " + Arrays.toString(target.shape()) + " instead");
        }
    }

    k = Math.min(k, items.rows());
    results.clear();
    distances.clear();

    PriorityQueue<HeapObject> pq = new PriorityQueue<>(items.rows(), new HeapObjectComparator());

    // The recursive search is asked for one candidate more than requested.
    search(root, target, k + 1, pq, Double.MAX_VALUE);

    // Drop every surplus candidate, not just one, so at most k results are ever returned.
    while (pq.size() > k) {
        pq.poll();
    }

    // Drain the queue in priority order, keeping results and distances aligned.
    while (!pq.isEmpty()) {
        HeapObject ho = pq.peek();
        results.add(new DataPoint(ho.getIndex(), ho.getPoint()));
        distances.add(ho.getDistance());
        pq.poll();
    }

    if (invert) {
        Collections.reverse(results);
        Collections.reverse(distances);
    }
}
/** * * @param items the items to use * @param similarityFunction the similarity function to use * @param workers number of parallel workers for tree building (increases memory requirements!) * @param invert whether to invert the metric (different optimization objective) */ public VPTree(List<DataPoint> items, String similarityFunction, int workers, boolean invert) { if (this.items == null) { this.items = Nd4j.create(items.size(), items.get(0).getPoint().columns()); } this.workers = workers; for (int i = 0; i < items.size(); i++) { //itemsList.add(items.get(i).getPoint()); this.items.putRow(i, items.get(i).getPoint()); } this.invert = invert; this.similarityFunction = similarityFunction; root = buildFromPoints(this.items); }
// Stop as soon as j runs past the last available result, so results.get(j) is never called out of range.
if (j >= results.size()) break;
// Store the index of the j-th result at position j of indices.
indices.putScalar(j, results.get(j).getIndex());
public void search() { results = new ArrayList<>(); distances = new ArrayList<>(); //initial search //vpTree.search(target,k,results,distances); //fill till there is k results //by going down the list // if(results.size() < k) { INDArray distancesArr = Nd4j.create(vpTree.getItems().rows(), 1); vpTree.calcDistancesRelativeTo(target, distancesArr); INDArray[] sortWithIndices = Nd4j.sortWithIndices(distancesArr, 0, !vpTree.isInvert()); results.clear(); distances.clear(); if (vpTree.getItems().isVector()) { for (int i = 0; i < k; i++) { int idx = sortWithIndices[0].getInt(i); results.add(new DataPoint(idx, Nd4j.scalar(vpTree.getItems().getDouble(idx)))); distances.add(sortWithIndices[1].getDouble(idx)); } } else { for (int i = 0; i < k; i++) { int idx = sortWithIndices[0].getInt(i); results.add(new DataPoint(idx, vpTree.getItems().getRow(idx))); distances.add(sortWithIndices[1].getDouble(idx)); } } }