Refine search
/**
 * Shifts the elements stored after {@code index} one slot to the left, overwriting the
 * element at {@code index}.
 *
 * <p>The container is addressed as row 0 of a 2-D array ({@code point(0)}). Slots
 * {@code [index + 1, size)} are copied to {@code [index, size - 1)}.
 *
 * @param index the slot being closed up
 */
private void moveBackward(int index) {
    int numMoved = size - index - 1;
    if (numMoved <= 0) {
        // Nothing sits after index, so there is nothing to shift. Returning early also
        // avoids building an empty interval.
        return;
    }
    INDArrayIndex[] first = new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(index, index + numMoved)};
    INDArrayIndex[] getRange = new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(index + 1, index + 1 + numMoved)};
    // Source and destination overlap inside the same buffer. Copy the source first so put()
    // never reads a value it has already overwritten.
    INDArray get = container.get(getRange).dup();
    container.put(first, get);
}
/**
 * Removes the first element equal to {@code o}, compared as a {@code double}.
 *
 * <p>Elements after the removed one are shifted one slot to the left and {@code size} is
 * decremented. The container is addressed as row 0 of a 2-D array ({@code point(0)}),
 * matching the shifting helpers of this class.
 *
 * @param o the value to remove; only {@link Number} instances can match
 * @return {@code true} if an element among the first {@code size} slots was removed
 */
@Override
public boolean remove(Object o) {
    // The unboxing cast "(double) o" threw ClassCastException for anything but Double.
    // Non-numbers can never be stored here, so they are simply reported as "not found".
    if (!(o instanceof Number)) {
        return false;
    }
    int idx = BooleanIndexing.firstIndex(container, new EqualsCondition(((Number) o).doubleValue())).getInt(0);
    // firstIndex scans the whole backing buffer. A hit at or beyond size lies in unused
    // capacity and is not an element of the list.
    if (idx < 0 || idx >= size) {
        return false;
    }
    int numMoved = size - idx - 1;
    if (numMoved > 0) {
        // The source [idx + 1, size) and destination [idx, size - 1) have equal length.
        // dup() because the two ranges overlap in the same buffer.
        INDArray tail = container.get(NDArrayIndex.point(0), NDArrayIndex.interval(idx + 1, size)).dup();
        container.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(idx, size - 1)}, tail);
    }
    size--;
    return true;
}
// Fragment of a larger demo method: fourByFiveRandomZeroToOne and print(...) are declared
// outside this view. Each statement slices fourByFiveRandomZeroToOne with NDArrayIndex and
// prints the result with a label.
// NDArrayIndex.interval(a, b) selects indices a..b-1 along one dimension (matching the
// "[0:3]" / "[:2]" slice notation in the labels), point(i) selects the single index i, and
// all() keeps the whole dimension.
// NOTE(review): if fourByFiveRandomZeroToOne really is 4x5 as its name suggests, interval(0, 3)
// selects three rows with all columns, so "threeValuesArray" holds more than three values, and
// "threeValuesArrayAgain" (interval(0, 2)) holds two rows. Consider clearer names if these
// locals are not referenced elsewhere.
print("Assigned value to array (1x3 => 10)", fourByFiveRandomZeroToOne); INDArray threeValuesArray = fourByFiveRandomZeroToOne.get(NDArrayIndex.interval(0, 3)); print("Get interval from array ([0:3])", threeValuesArray); INDArray threeValuesArrayColumnFour = fourByFiveRandomZeroToOne.get(NDArrayIndex.interval(0, 3), NDArrayIndex.point(4)); print("Get interval from array ([0:3,4])", threeValuesArrayColumnFour); INDArray threeValuesArrayAgain = fourByFiveRandomZeroToOne.get(NDArrayIndex.interval(0, 2)); print("Get interval from array ([:2])", threeValuesArrayAgain); INDArray allRowsIndexOne = fourByFiveRandomZeroToOne.get(NDArrayIndex.all(), NDArrayIndex.point(1)); print("Get interval from array ([:,1])", allRowsIndexOne);
// Fragment of a larger demo method: originalArray and zerosColumn are declared outside this view.
// Statements are on separate lines because, in the single-line form, the trailing "//" comment
// after the put(...) call swallowed the final println, so it never ran.

// All rows, columns 0..2 (interval end is exclusive).
INDArray first3Columns = originalArray.get(NDArrayIndex.all(), NDArrayIndex.interval(0,3));
System.out.println("first 3 columns:\n" + first3Columns);
// Overwrite column index 2 of every row of originalArray with zerosColumn.
originalArray.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point(2)}, zerosColumn); //All rows, column index 2
System.out.println("\n\n\nOriginal array, after put operation:\n" + originalArray);
/**
 * Interprets one example of a masked time-series batch.
 *
 * <p>The number of positive entries in row {@code predictionIndex} of {@code trueMasks} is
 * used as the sequence length. Both {@code trueLabels} and {@code output} are trimmed to that
 * many steps along their third dimension, for every example in the batch. The result is then
 * passed to the three-argument {@code interpret} overload.
 *
 * <p>NOTE(review): the trimming assumes the mask's positive entries form a prefix of the row
 * (steps {@code 0..maxTrueLabels-1}); a mask with gaps would be trimmed incorrectly. Confirm
 * against how the masks are produced.
 *
 * @param trueMasks 2-D mask array indexed by example on its first dimension
 * @param trueLabels 3-D label array indexed by example on its first dimension
 * @param output 3-D network output indexed by example on its first dimension
 * @param predictionIndex index of the example whose mask determines the trim length
 * @return the prediction produced by the three-argument overload
 */
public TimeSeriesPrediction interpret(INDArray trueMasks, INDArray trueLabels, INDArray output, int predictionIndex) {
    assert trueLabels.shape().length == 3 : "True labels should be a 3D array";
    assert trueMasks.shape().length == 2 : "True masks should be a 2D array";
    assert output.shape().length == 3 : "Output should be a 3D array";
    assert predictionIndex >= 0 : "prediction index must not be negative";
    assert predictionIndex < trueMasks.shape()[0] : "prediction index is out of bounds for true masks";
    assert predictionIndex < trueLabels.shape()[0] : "prediction index is out of bounds for true labels";
    assert predictionIndex < output.shape()[0] : "prediction index is out of bounds for output";

    // Count the active time steps of the selected example.
    int maxTrueLabels = trueMasks.getRow(predictionIndex).gt(0).sumNumber().intValue();
    // Keep every example and every feature, but only the first maxTrueLabels time steps.
    INDArray trimmedTrueLabels = trueLabels.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, maxTrueLabels));
    INDArray trimmedOutput = output.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, maxTrueLabels));
    return interpret(trimmedTrueLabels, trimmedOutput, predictionIndex);
}
}
/**
 * Shifts the elements stored after {@code index} one slot to the left, overwriting the
 * element at {@code index}.
 *
 * <p>The container is addressed as row 0 of a 2-D array ({@code point(0)}). Slots
 * {@code [index + 1, size)} are copied to {@code [index, size - 1)}.
 *
 * @param index the slot being closed up
 */
private void moveBackward(int index) {
    int numMoved = size - index - 1;
    if (numMoved <= 0) {
        // Nothing after index: no shift is needed, and no empty interval is built.
        return;
    }
    INDArrayIndex[] first = new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(index, index + numMoved)};
    INDArrayIndex[] getRange = new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(index + 1, index + 1 + numMoved)};
    // The ranges overlap within one buffer. Snapshot the source before writing.
    container.put(first, container.get(getRange).dup());
}
// NOTE(review): this line is an incomplete excerpt and does not compile as shown. The single
// `if` is followed by two `else` branches, and the last `else` block is never closed, so lines
// appear to have been elided. The code is left untouched until the full source is available.
// What the visible code does:
//  - `indices` is a 4-element index array for example i, with point(0) on the second
//    dimension. When sentencesAlongHeight is set, the third index becomes
//    point(sentenceLength) and the fourth becomes all(); otherwise both stay null here.
//  - One featuresMask branch assigns 1.0 to the entire row i. The other assigns 1.0 only to
//    columns 0..sentenceLength-1 of row i.
INDArrayIndex[] indices = new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.point(0), null, null}; if(this.sentencesAlongHeight) { indices[2] = NDArrayIndex.point(sentenceLength); indices[3] = NDArrayIndex.all(); } else { featuresMask.getRow(i).assign(Double.valueOf(1.0D)); } else { featuresMask.get(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.interval(0, sentenceLength)}).assign(Double.valueOf(1.0D));
/**
 * Removes the first element equal to {@code o}, compared as a {@code double}.
 *
 * <p>Elements after the removed one are shifted one slot to the left and {@code size} is
 * decremented. The container is addressed as row 0 of a 2-D array ({@code point(0)}),
 * matching the shifting helpers of this class.
 *
 * @param o the value to remove; only {@link Number} instances can match
 * @return {@code true} if an element among the first {@code size} slots was removed
 */
@Override
public boolean remove(Object o) {
    // The unboxing cast "(double) o" threw ClassCastException for anything but Double.
    // Non-numbers can never be stored here, so they are simply reported as "not found".
    if (!(o instanceof Number)) {
        return false;
    }
    int idx = BooleanIndexing.firstIndex(container, new EqualsCondition(((Number) o).doubleValue())).getInt(0);
    // firstIndex scans the whole backing buffer. A hit at or beyond size lies in unused
    // capacity and is not an element of the list.
    if (idx < 0 || idx >= size) {
        return false;
    }
    int numMoved = size - idx - 1;
    if (numMoved > 0) {
        // The source [idx + 1, size) and destination [idx, size - 1) have equal length.
        // dup() because the two ranges overlap in the same buffer.
        INDArray tail = container.get(NDArrayIndex.point(0), NDArrayIndex.interval(idx + 1, size)).dup();
        container.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(idx, size - 1)}, tail);
    }
    size--;
    return true;
}
/**
 * Shifts the slots {@code [index, index + numMoved)} one position to the right, where
 * {@code numMoved = size - index - 1}, freeing slot {@code index}.
 *
 * <p>The container is addressed as row 0 of a 2-D array ({@code point(0)}).
 *
 * <p>NOTE(review): with {@code numMoved = size - index - 1}, the element at {@code size - 1}
 * is not moved. That is only correct if callers increment {@code size} before calling this;
 * confirm against the insert path.
 *
 * @param index the slot to free
 */
private void moveForward(int index) {
    int numMoved = size - index - 1;
    if (numMoved <= 0) {
        // Nothing to shift. Returning early also avoids building an empty interval.
        return;
    }
    INDArrayIndex[] getRange = new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(index, index + numMoved)};
    // Right-shifting within one buffer: the destination overlaps the source, so without a copy
    // put() would read values it has already overwritten.
    INDArray get = container.get(getRange).dup();
    INDArrayIndex[] first = new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(index + 1, index + 1 + get.length())};
    container.put(first, get);
}
/**
 * Builds a DataSet from examples {@code from} (inclusive) to {@code to} (exclusive, as with
 * {@code interval}) along the first dimension.
 *
 * <p>When this DataSet has mask arrays, the same range of each non-null mask is carried
 * over; a null mask stays null.
 *
 * @param from first example index, inclusive
 * @param to   end example index, exclusive
 * @return a new DataSet over the requested range
 */
@Override
public org.nd4j.linalg.dataset.api.DataSet getRange(int from, int to) {
    if (!hasMaskArrays()) {
        return new DataSet(features.get(interval(from, to)), labels.get(interval(from, to)));
    }

    INDArray featureMaskSlice = null;
    if (featuresMask != null) {
        featureMaskSlice = featuresMask.get(interval(from, to));
    }
    INDArray labelMaskSlice = null;
    if (labelsMask != null) {
        labelMaskSlice = labelsMask.get(interval(from, to));
    }

    INDArray featureSlice = features.get(interval(from, to));
    INDArray labelSlice = labels.get(interval(from, to));
    return new DataSet(featureSlice, labelSlice, featureMaskSlice, labelMaskSlice);
}
/**
 * Choose from the inputs based on the given condition.
 * This returns a rank-1 vector of all elements fulfilling the
 * condition listed within the array for input.
 * The double and integer arguments are only relevant
 * for scalar operations (like when you have a scalar
 * you are trying to compare each element in your input against)
 *
 * @param input the input to filter
 * @param tArgs the double args
 * @param iArgs the integer args
 * @param condition the condition to filter based on
 * @return a rank-1 vector of the input elements that are true
 * for the given conditions, or {@code null} if no element matched
 */
public static INDArray chooseFrom(@NonNull INDArray[] input, @NonNull List<Double> tArgs, @NonNull List<Integer> iArgs, @NonNull Condition condition) {
    Choose choose = new Choose(input, iArgs, tArgs, condition);
    Nd4j.getExecutioner().exec(choose);
    // Output 1 is read as the number of matches: it bounds the slice taken from output 0,
    // and a value below 1 means "nothing matched".
    int numResults = choose.getOutputArgument(1).getInt(0);
    if (numResults < 1) {
        return null;
    }
    INDArray ret = choose.getOutputArgument(0).get(NDArrayIndex.interval(0, numResults));
    // Flatten to a 1-D vector of length numResults (not a [1, n] row vector).
    ret = ret.reshape(ret.length());
    return ret;
}
/**
 * Shifts the slots {@code [index, index + numMoved)} one position to the right, where
 * {@code numMoved = size - index - 1}, freeing slot {@code index}.
 *
 * <p>The container is addressed as row 0 of a 2-D array ({@code point(0)}).
 *
 * <p>NOTE(review): with {@code numMoved = size - index - 1}, the element at {@code size - 1}
 * is not moved. That is only correct if callers increment {@code size} before calling this;
 * confirm against the insert path.
 *
 * @param index the slot to free
 */
private void moveForward(int index) {
    int numMoved = size - index - 1;
    if (numMoved <= 0) {
        // Nothing to shift. Returning early also avoids building an empty interval.
        return;
    }
    INDArrayIndex[] getRange = new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(index, index + numMoved)};
    // Copy first: the destination range overlaps the source within the same buffer.
    INDArray get = container.get(getRange).dup();
    INDArrayIndex[] first = new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(index + 1, index + 1 + get.length())};
    container.put(first, get);
}
/**
 * Returns the part of the backing array that holds the list's contents.
 *
 * <p>The internal container may hold more slots than the list currently uses. This slices
 * the first {@code size} entries and reshapes them to {@code [1, size]}, so callers can compute
 * directly on the list's real contents.
 *
 * @return the first {@code size} entries of the backing array, shaped {@code [1, size]}
 */
public INDArray array() {
    INDArray used = container.get(NDArrayIndex.interval(0, size));
    return used.reshape(1, size);
}
@Override public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) { if (!viewArray.isRowVector()) throw new IllegalArgumentException("Invalid input: expect row vector input"); if (initialize) viewArray.assign(0); long length = viewArray.length(); this.msg = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2)); this.msdx = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length)); //Reshape to match the expected shape of the input gradient arrays this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, gradientOrder == 'f'); this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, gradientOrder == 'f'); if (msg == null || msdx == null) throw new IllegalStateException("Could not correctly reshape gradient view arrays"); }
@Override public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) { if (!viewArray.isRowVector()) throw new IllegalArgumentException("Invalid input: expect row vector input"); if (initialize) viewArray.assign(0); long length = viewArray.length(); this.m = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2)); this.u = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length)); //Reshape to match the expected shape of the input gradient arrays this.m = Shape.newShapeNoCopy(this.m, gradientShape, gradientOrder == 'f'); this.u = Shape.newShapeNoCopy(this.u, gradientShape, gradientOrder == 'f'); if (m == null || u == null) throw new IllegalStateException("Could not correctly reshape gradient view arrays"); this.gradientReshapeOrder = gradientOrder; }
/**
 * Returns the part of the backing array that holds the list's contents.
 *
 * <p>The internal container may hold more slots than the list currently uses. This slices
 * out only the first {@code size} entries, so callers can compute directly on the list's
 * real contents.
 *
 * @return the first {@code size} entries of the backing array
 * @throws ND4JIllegalStateException if the list is empty
 */
public INDArray array() {
    if (!isEmpty()) {
        return container.get(NDArrayIndex.interval(0, size));
    }
    throw new ND4JIllegalStateException("Array is empty!");
}
@Override public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) { if (!viewArray.isRowVector()) throw new IllegalArgumentException("Invalid input: expect row vector input"); if (initialize) viewArray.assign(0); val n = viewArray.length() / 3; this.m = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, n)); this.v = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(n, 2*n)); this.vHat = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(2*n, 3*n)); //Reshape to match the expected shape of the input gradient arrays this.m = Shape.newShapeNoCopy(this.m, gradientShape, gradientOrder == 'f'); this.v = Shape.newShapeNoCopy(this.v, gradientShape, gradientOrder == 'f'); this.vHat = Shape.newShapeNoCopy(this.vHat, gradientShape, gradientOrder == 'f'); if (m == null || v == null || vHat == null) throw new IllegalStateException("Could not correctly reshape gradient view arrays"); this.gradientReshapeOrder = gradientOrder; }