/**
 * Flattens the entries of {@code map} into a list of {@link Pair}s.
 * <p>
 * One pair (key, value) is produced per entry, in the map's own iteration order.
 *
 * @param map the map whose entries are collected
 * @param <K> the key type
 * @param <V> the value type
 * @return a new, mutable {@link List} (an {@link ArrayList}) with one pair per entry
 */
public static <K, V> List<Pair<K, V>> mapToPair(Map<K, V> map) {
    final List<Pair<K, V>> pairs = new ArrayList<>(map.size());
    map.forEach((key, value) -> pairs.add(Pair.of(key, value)));
    return pairs;
}
// NOTE(review): free-standing fragment with no enclosing definition visible here; it relies on
// `key`, `ret` and the type parameter V from surrounding code. Each copy builds a Pair of two new,
// empty ArrayLists and stores it in `ret` under `key`.
// The same four statements appear twice. If both copies share one scope, the repeated
// declarations of leftListPair / rightListPair / p will not compile — confirm whether they
// originally lived in separate branches/scopes, or drop the duplicate.
List<V> leftListPair = new ArrayList<>(); List<V> rightListPair = new ArrayList<>(); Pair<List<V>,List<V>> p = Pair.of(leftListPair,rightListPair); ret.put(key,p); List<V> leftListPair = new ArrayList<>(); List<V> rightListPair = new ArrayList<>(); Pair<List<V>,List<V>> p = Pair.of(leftListPair,rightListPair); ret.put(key,p);
/**
 * Finds the cluster whose center is "nearest" to {@code point}.
 * <p>
 * Each cluster is scored with {@code cluster.getDistanceToCenter(point)}. When
 * {@link #isInverse()} is {@code true}, the cluster with the <em>largest</em> score wins
 * (presumably the score is then a similarity — confirm). Otherwise the cluster with the
 * <em>smallest</em> score wins. Ties keep the earlier cluster.
 *
 * @param point the point to assign to a cluster
 * @return the winning cluster paired with its score. The cluster is {@code null}, and the score
 *         is {@code -Infinity} (inverse) or {@code +Infinity} (otherwise), when there are no
 *         clusters or no score compares better than that starting value (e.g. all scores are NaN).
 */
public Pair<Cluster, Double> nearestCluster(Point point) {
    final boolean inverse = isInverse();
    Cluster nearestCluster = null;
    // Start from the worst possible value so that any real score can win. Float.MIN_VALUE is the
    // smallest *positive* float, not the most negative value, so it would reject scores of 0 or
    // below in the inverse case.
    double bestDistance = inverse ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
    for (Cluster cluster : getClusters()) {
        double currentDistance = cluster.getDistanceToCenter(point);
        boolean better = inverse ? currentDistance > bestDistance : currentDistance < bestDistance;
        if (better) {
            bestDistance = currentDistance;
            nearestCluster = cluster;
        }
    }
    return Pair.of(nearestCluster, bestDistance);
}
/**
 * Recursively collects every point in the subtree rooted at {@code node} whose Euclidean
 * distance to {@code point} is at most {@code dist}.
 *
 * @param node  root of the subtree to search; {@code null} ends the recursion
 * @param point the query point
 * @param rect  region associated with {@code node}'s subtree. The subtree is skipped when
 *              {@code rect.minDistance(point)} exceeds {@code dist}.
 * @param dist  search radius (inclusive)
 * @param best  output list; every match is appended as (distance, point)
 * @param _disc the dimension {@code node} splits on
 */
private void knn(KDNode node, INDArray point, HyperRect rect, double dist, List<Pair<Double, INDArray>> best, int _disc) {
    if (node == null || rect.minDistance(point) > dist)
        return;
    int _discNext = (_disc + 1) % dims;
    // Measure the query against this node's point. Passing only `point` would make the result
    // independent of the node being visited.
    double distance = Nd4j.getExecutioner()
            .execAndReturn(new EuclideanDistance(point, node.getPoint()))
            .getFinalResult()
            .doubleValue();
    if (distance <= dist) {
        best.add(Pair.of(distance, node.getPoint()));
    }
    // Split the region at the node's coordinate, not the query's. The tree partitions space by
    // node points (nn() splits the same way).
    HyperRect lower = rect.getLower(node.getPoint(), _disc);
    HyperRect upper = rect.getUpper(node.getPoint(), _disc);
    knn(node.getLeft(), point, lower, dist, best, _discNext);
    knn(node.getRight(), point, upper, dist, best, _discNext);
}
/**
 * Recursively searches the subtree rooted at {@code node} for the point nearest to
 * {@code point} (Euclidean distance), improving on the current best candidate.
 *
 * @param node  root of the subtree to search; {@code null} ends the recursion
 * @param point the query point
 * @param rect  region associated with {@code node}'s subtree. The subtree is pruned when
 *              {@code rect.minDistance(point)} exceeds {@code dist}.
 * @param dist  distance of the best candidate found so far
 * @param best  best candidate found so far (may be {@code null})
 * @param _disc the dimension {@code node} splits on
 * @return (distance, point) of the best candidate after searching this subtree. When the subtree
 *         is empty or pruned, returns ({@code +Infinity}, {@code null}).
 */
private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best, int _disc) {
    if (node == null || rect.minDistance(point) > dist)
        return Pair.of(Double.POSITIVE_INFINITY, null);
    int _discNext = (_disc + 1) % dims;
    // Distance between the query and this node's point. Passing only `point` would make the
    // result independent of the node.
    double dist2 = Nd4j.getExecutioner()
            .execAndReturn(new EuclideanDistance(point, node.getPoint()))
            .getFinalResult()
            .doubleValue();
    if (dist2 < dist) {
        best = node.getPoint();
        dist = dist2;
    }
    HyperRect lower = rect.getLower(node.point, _disc);
    HyperRect upper = rect.getUpper(node.point, _disc);

    // Search the side of the split that contains the query first, so its result tightens the
    // bound used to prune the other side.
    boolean queryOnLeft = point.getDouble(_disc) < node.point.getDouble(_disc);
    KDNode nearChild = queryOnLeft ? node.getLeft() : node.getRight();
    KDNode farChild = queryOnLeft ? node.getRight() : node.getLeft();
    HyperRect nearRect = queryOnLeft ? lower : upper;
    HyperRect farRect = queryOnLeft ? upper : lower;

    Pair<Double, INDArray> near = nn(nearChild, point, nearRect, dist, best, _discNext);
    if (near.getKey() < dist) {
        dist = near.getKey();
        best = near.getValue();
    }
    // Keep whichever side is actually closer, rather than returning the first improvement found.
    Pair<Double, INDArray> far = nn(farChild, point, farRect, dist, best, _discNext);
    if (far.getKey() < dist) {
        dist = far.getKey();
        best = far.getValue();
    }
    return Pair.of(dist, best);
}
/** * Read a datavec schema and record set * from the given arrow file. * @param input the input to read * @return the associated datavec schema and record */ public static Pair<Schema,ArrowWritableRecordBatch> readFromFile(FileInputStream input) throws IOException { BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); Schema retSchema = null; ArrowWritableRecordBatch ret = null; SeekableReadChannel channel = new SeekableReadChannel(input.getChannel()); ArrowFileReader reader = new ArrowFileReader(channel, allocator); reader.loadNextBatch(); retSchema = toDatavecSchema(reader.getVectorSchemaRoot().getSchema()); //load the batch VectorUnloader unloader = new VectorUnloader(reader.getVectorSchemaRoot()); VectorLoader vectorLoader = new VectorLoader(reader.getVectorSchemaRoot()); ArrowRecordBatch recordBatch = unloader.getRecordBatch(); vectorLoader.load(recordBatch); ret = asDataVecBatch(recordBatch,retSchema,reader.getVectorSchemaRoot()); ret.setUnloader(unloader); return Pair.of(retSchema,ret); }
/**
 * Finds the node with the smallest coordinate along dimension {@code disc} in the subtree rooted
 * at {@code node}.
 *
 * @param node  root of the subtree to search; must not be {@code null}
 * @param disc  the dimension whose minimum is wanted
 * @param _disc the dimension {@code node} splits on
 * @return the minimal node paired with the split dimension at its depth in the tree. Ties keep
 *         the child result over {@code node}.
 */
private Pair<KDNode, Integer> min(KDNode node, int disc, int _disc) {
    int discNext = (_disc + 1) % dims;
    if (_disc == disc) {
        // This node splits on the target dimension, so anything smaller can only be in the left
        // subtree. (Assumes smaller coordinates go left, as nn() routes queries.)
        KDNode child = node.getLeft();
        return child != null ? min(child, disc, discNext) : Pair.of(node, _disc);
    }
    // Otherwise the minimum may be in either subtree, or be this node itself.
    Pair<KDNode, Integer> left = node.getLeft() != null ? min(node.getLeft(), disc, discNext) : null;
    Pair<KDNode, Integer> right = node.getRight() != null ? min(node.getRight(), disc, discNext) : null;
    Pair<KDNode, Integer> best;
    if (left != null && right != null) {
        double pointLeft = left.getKey().getPoint().getDouble(disc);
        double pointRight = right.getKey().getPoint().getDouble(disc);
        best = pointLeft < pointRight ? left : right;
    } else {
        best = left != null ? left : right;
    }
    // Include the node itself; previously it was never considered once it had children.
    if (best == null || node.getPoint().getDouble(disc) < best.getKey().getPoint().getDouble(disc)) {
        return Pair.of(node, _disc);
    }
    return best;
}
/** * Read a datavec schema and record set * from the given bytes (usually expected to be an arrow format file) * @param input the input to read * @return the associated datavec schema and record */ public static Pair<Schema,ArrowWritableRecordBatch> readFromBytes(byte[] input) throws IOException { BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); Schema retSchema = null; ArrowWritableRecordBatch ret = null; SeekableReadChannel channel = new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(input)); ArrowFileReader reader = new ArrowFileReader(channel, allocator); reader.loadNextBatch(); retSchema = toDatavecSchema(reader.getVectorSchemaRoot().getSchema()); //load the batch VectorUnloader unloader = new VectorUnloader(reader.getVectorSchemaRoot()); VectorLoader vectorLoader = new VectorLoader(reader.getVectorSchemaRoot()); ArrowRecordBatch recordBatch = unloader.getRecordBatch(); vectorLoader.load(recordBatch); ret = asDataVecBatch(recordBatch,retSchema,reader.getVectorSchemaRoot()); ret.setUnloader(unloader); return Pair.of(retSchema,ret); }
/**
 * Finds the node with the largest coordinate along dimension {@code disc} in the subtree rooted
 * at {@code node}.
 *
 * @param node  root of the subtree to search; must not be {@code null}
 * @param disc  the dimension whose maximum is wanted
 * @param _disc the dimension {@code node} splits on
 * @return the maximal node paired with the split dimension at its depth in the tree. Ties keep
 *         the child result over {@code node}.
 */
private Pair<KDNode, Integer> max(KDNode node, int disc, int _disc) {
    int discNext = (_disc + 1) % dims;
    if (_disc == disc) {
        // This node splits on the target dimension, so anything larger can only be in the RIGHT
        // subtree. (Assumes larger-or-equal coordinates go right, as nn() routes queries.) The
        // left subtree only holds values no larger than this node's, so recursing there was wrong.
        KDNode child = node.getRight();
        return child != null ? max(child, disc, discNext) : Pair.of(node, _disc);
    }
    // Otherwise the maximum may be in either subtree, or be this node itself.
    Pair<KDNode, Integer> left = node.getLeft() != null ? max(node.getLeft(), disc, discNext) : null;
    Pair<KDNode, Integer> right = node.getRight() != null ? max(node.getRight(), disc, discNext) : null;
    Pair<KDNode, Integer> best;
    if (left != null && right != null) {
        double pointLeft = left.getKey().getPoint().getDouble(disc);
        double pointRight = right.getKey().getPoint().getDouble(disc);
        best = pointLeft > pointRight ? left : right;
    } else {
        best = left != null ? left : right;
    }
    // Include the node itself; previously it was never considered once it had children.
    if (best == null || node.getPoint().getDouble(disc) > best.getKey().getPoint().getDouble(disc)) {
        return Pair.of(node, _disc);
    }
    return best;
}