/**
 * Creates a tree node attached to {@code parent} whose cell is described by {@code corner} and
 * {@code width}.
 *
 * <p>All state is set up by {@code init}; this constructor only forwards its arguments.
 *
 * @param parent             the parent node
 * @param data               the data matrix whose rows are the points
 * @param corner             cell position passed through to {@code init}
 * @param width              cell extent passed through to {@code init}
 * @param indices            set receiving the points stored in the tree
 * @param similarityFunction name of the similarity function, passed through to {@code init}
 */
public SpTree(SpTree parent,
              INDArray data,
              INDArray corner,
              INDArray width,
              Set<INDArray> indices,
              String similarityFunction) {
    init(parent, data, corner, width, indices, similarityFunction);
}
/**
 * Inserts rows {@code 0 .. n-1} of the data into this tree, one index at a time.
 *
 * <p>Only acts when {@code indices} is still empty and this node has no parent. Otherwise it
 * logs a warning and returns without inserting anything. The boolean result of each
 * {@code insert} call is not checked.
 *
 * @param n number of leading rows to insert
 */
private void fill(int n) {
    boolean canFill = indices.isEmpty() && parent == null;
    if (!canFill) {
        log.warn("Called fill already");
        return;
    }
    int next = 0;
    while (next < n) {
        log.trace("Inserted " + next);
        insert(next);
        next++;
    }
}
/**
 * Verifies the structure of the tree by bounds-checking every node.
 *
 * <p>Each point stored in this node (rows {@code data.slice(index[n])} for {@code n < size})
 * must be contained in this node's {@code boundary}. If this node is not a leaf, every child
 * must pass the same check recursively. Checking stops at the first failing child.
 *
 * @return {@code true} if this node and all of its descendants pass the bounds check
 */
public boolean isCorrect() {
    int n = 0;
    while (n < size) {
        INDArray stored = data.slice(index[n]);
        if (!boundary.contains(stored)) {
            return false;
        }
        n++;
    }

    if (isLeaf()) {
        return true;
    }

    for (int child = 0; child < numChildren; child++) {
        if (!children[child].isCorrect()) {
            return false;
        }
    }
    return true;
}
/**
 * Computes the Barnes-Hut approximation of the t-SNE gradient with respect to {@code Y}.
 *
 * <p>The result is {@code posF - negF / sumQ}. Here {@code posF} are the edge forces computed
 * from the sparse matrix ({@code rows}, {@code cols}, {@code vals}). {@code negF} and
 * {@code sumQ} are accumulated by the tree's non-edge force query for every point.
 *
 * @return a gradient holding the result under {@code Y_GRAD}
 */
@Override
public Gradient gradient() {
    // Fields that outlive this call are initialised before entering the workspace scope.
    // Arrays allocated inside the scope are only valid until the scope closes.
    if (yIncs == null) {
        yIncs = zeros(Y.shape());
    }
    if (gains == null) {
        gains = ones(Y.shape());
    }

    // Y is modified between calls, so the tree (cell bounds and centres of mass) must be rebuilt
    // from the current embedding. A tree cached from an earlier Y gives stale non-edge forces.
    tree = new SpTree(Y);
    tree.setWorkspaceMode(workspaceMode);

    MemoryWorkspace workspace = workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                    : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                    workspaceConfigurationExternal, workspaceExternal);

    try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
        AtomicDouble sumQ = new AtomicDouble(0);
        /* Calculate gradient based on barnes hut approximation with positive and negative forces */
        INDArray posF = Nd4j.create(Y.shape());
        INDArray negF = Nd4j.create(Y.shape());

        tree.computeEdgeForces(rows, cols, vals, N, posF);
        for (int n = 0; n < N; n++) {
            tree.computeNonEdgeForces(n, theta, negF.slice(n), sumQ);
        }

        // dC is backed by workspace memory. Detach it so the returned gradient stays valid
        // after the scope closes. For arrays not attached to a workspace this is a no-op.
        INDArray dC = posF.subi(negF.divi(sumQ)).detach();
        Gradient ret = new DefaultGradient();
        ret.gradientForVariable().put(Y_GRAD, dC);
        return ret;
    }
}
/**
 * Computes the Barnes-Hut approximation of the t-SNE gradient with respect to {@code Y}.
 *
 * <p>The result is {@code posF - negF / sumQ}. Here {@code posF} are the edge forces computed
 * from the sparse matrix ({@code rows}, {@code cols}, {@code vals}). {@code negF} and
 * {@code sumQ} are accumulated by the tree's non-edge force query for every point.
 * {@code yIncs} and {@code gains} are lazily created here but not otherwise used by this method.
 *
 * @return a gradient holding the result under {@code Y_GRAD}
 */
@Override
public Gradient gradient() {
    if (yIncs == null) {
        yIncs = zeros(Y.shape());
    }
    if (gains == null) {
        gains = ones(Y.shape());
    }

    AtomicDouble sumQ = new AtomicDouble(0);
    /* Calculate gradient based on barnes hut approximation with positive and negative forces */
    INDArray posF = Nd4j.create(Y.shape());
    INDArray negF = Nd4j.create(Y.shape());

    // Y is modified between calls, so the tree (cell bounds and centres of mass) must be rebuilt
    // from the current embedding. Reusing a tree built from an earlier Y gives stale
    // non-edge forces.
    tree = new SpTree(Y);

    tree.computeEdgeForces(rows, cols, vals, N, posF);
    for (int n = 0; n < N; n++) {
        tree.computeNonEdgeForces(n, theta, negF.slice(n), sumQ);
    }

    INDArray dC = posF.subi(negF.divi(sumQ));
    Gradient ret = new DefaultGradient();
    ret.gradientForVariable().put(Y_GRAD, dC);
    return ret;
}
if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex)) return; if (isLeaf() || maxWidth / FastMath.sqrt(D) < theta) { children[i].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
@Override public double score() { // Get estimate of normalization term INDArray buff = Nd4j.create(numDimensions); AtomicDouble sum_Q = new AtomicDouble(0.0); for (int n = 0; n < N; n++) tree.computeNonEdgeForces(n, theta, buff, sum_Q); // Loop over all edges to compute t-SNE error double C = .0; INDArray linear = Y; for (int n = 0; n < N; n++) { int begin = rows.getInt(n); int end = rows.getInt(n + 1); int ind1 = n; for (int i = begin; i < end; i++) { int ind2 = cols.getInt(i); buff.assign(linear.slice(ind1)); buff.subi(linear.slice(ind2)); double Q = pow(buff, 2).sum(Integer.MAX_VALUE).getDouble(0); Q = (1.0 / (1.0 + Q)) / sum_Q.doubleValue(); C += vals.getDouble(i) * FastMath.log(vals.getDouble(i) + Nd4j.EPS_THRESHOLD) / (Q + Nd4j.EPS_THRESHOLD); } } return C; }
private boolean insert(int index) { INDArray point = data.slice(index); if (!boundary.contains(point)) return false; cumSize++; double mult1 = (double) (cumSize - 1) / (double) cumSize; double mult2 = 1.0 / (double) cumSize; centerOfMass.muli(mult1); centerOfMass.addi(point.mul(mult2)); // If there is space in this quad tree and it is a leaf, add the object here if (isLeaf() && size < nodeCapacity) { this.index[size] = index; indices.add(point); size++; return true; } for (int i = 0; i < size; i++) { INDArray compPoint = data.slice(this.index[i]); if (compPoint.equals(point)) return true; } if (isLeaf()) subDivide(); // Find out where the point can be inserted for (int i = 0; i < numChildren; i++) { if (children[i].insert(index)) return true; } throw new IllegalStateException("Shouldn't reach this state"); }
/**
 * Builds the root node over all rows of {@code data} and inserts every row.
 *
 * <p>The column-wise mean is passed to {@code init} as the cell's corner argument. For each
 * dimension, the width is the larger of {@code max - mean} and {@code mean - min}, plus the
 * margin {@code Nd4j.EPS_THRESHOLD}.
 *
 * @param data               matrix whose rows are the points to index
 * @param indices            set receiving the points stored in the tree
 * @param similarityFunction name of the similarity function, passed through to {@code init}
 */
public SpTree(INDArray data, Set<INDArray> indices, String similarityFunction) {
    this.indices = indices;
    this.N = data.rows();
    this.D = data.columns();
    this.similarityFunction = similarityFunction;

    INDArray meanY = data.mean(0);
    INDArray minY = data.min(0);
    INDArray maxY = data.max(0);
    INDArray width = Nd4j.create(meanY.shape());
    for (int i = 0; i < width.length(); i++) {
        double mean = meanY.getDouble(i);
        double aboveMean = maxY.getDouble(i) - mean;
        double belowMean = mean - minY.getDouble(i);
        // The margin is added to the chosen extent, as in the reference bhtsne implementation.
        // Previously it was added only to the below-mean side, so when the above-mean side was
        // larger the root cell had no slack at all.
        width.putScalar(i, FastMath.max(aboveMean, belowMean) + Nd4j.EPS_THRESHOLD);
    }

    init(null, data, meanY, width, indices, similarityFunction);
    fill(N);
}
// Accumulate sum_Q by running the tree's non-edge force query for every point n in [0, N).
// buff is passed as the force output argument and receives each call's output; presumably it
// is only scratch space here and the caller reads sum_Q (the normalisation term) — verify
// against the surrounding code.
AtomicDouble sum_Q = new AtomicDouble(0.0); for (int n = 0; n < N; n++) tree.computeNonEdgeForces(n, theta, buff, sum_Q);