// Output shape returned by Shape.broadcastOutputShape for the two argument shapes.
val outShape = Shape.broadcastOutputShape(firstArgShape, secondArgShape);
/**
 * Calculate the output shape for this op.
 *
 * @return a list that holds the broadcast of the left and right argument shapes
 *         when both are non-null, only the left argument's shape when the right
 *         one is null, and no entries when the left argument's shape is null
 */
public List<long[]> calculateOutputShape() {
    List<long[]> shapes = new ArrayList<>();

    long[] leftShape = larg().getShape();
    if (leftShape == null) {
        // Without a left shape nothing can be reported.
        return shapes;
    }

    // The right argument is only consulted once the left shape is known.
    long[] rightShape = rarg().getShape();
    if (rightShape == null) {
        shapes.add(leftShape);
    } else {
        shapes.add(Shape.broadcastOutputShape(leftShape, rightShape));
    }
    return shapes;
}
/**
 * In-place addition of two ndarrays.
 *
 * <p>Dispatch order:
 * <ol>
 *   <li>{@code other} is a scalar: its value is added to this array and written to {@code result};</li>
 *   <li>this array is a scalar: its value is added to {@code other} and written to {@code result};</li>
 *   <li>the shapes differ: a broadcast add is executed into a newly created array,
 *       which is returned <em>instead of</em> the supplied {@code result};</li>
 *   <li>otherwise: an element-wise add of this array and {@code other} is written to {@code result}.</li>
 * </ol>
 *
 * @param other  the second ndarray to add
 * @param result the result ndarray
 * @return the ndarray holding the result of the addition
 */
@Override
public INDArray addi(INDArray other, INDArray result) {
    if (other.isScalar()) {
        // The scalar must be added to this array, not to result: the two only
        // coincide when the caller passes this array as result.
        return this.addi(other.getDouble(0), result);
    }

    if (isScalar()) {
        return other.addi(getDouble(0), result);
    }

    if (!Shape.shapeEquals(this.shape(), other.shape())) {
        int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(), other.shape());
        // NOTE: the caller-supplied result is not written in this branch; the sum
        // goes into a freshly created array of the broadcast output shape.
        INDArray broadcastResult =
                Nd4j.createUninitialized(Shape.broadcastOutputShape(this.shape(), other.shape()));
        Nd4j.getExecutioner().exec(
                new BroadcastAddOp(this, other, broadcastResult, broadcastDimensions),
                broadcastDimensions);
        return broadcastResult;
    }

    LinAlgExceptions.assertSameShape(other, result);
    Nd4j.getExecutioner().exec(new OldAddOp(this, other, result, length()));

    if (Nd4j.ENFORCE_NUMERICAL_STABILITY) {
        Nd4j.clearNans(result);
    }
    return result;
}
/**
 * Delegates to {@code Shape.broadcastOutputShape(left, right)} and hands back its result.
 *
 * <p>NOTE(review): the method name suggests the delegate rejects shapes that cannot
 * be broadcast together; confirm against {@code Shape.broadcastOutputShape}.
 *
 * @param left  the first shape
 * @param right the second shape
 * @return the shape returned by {@code Shape.broadcastOutputShape} for the two inputs
 */
public static long[] checkIsBroadcastable(long[] left, long[] right) {
    final long[] broadcastShape = Shape.broadcastOutputShape(left, right);
    return broadcastShape;
}
/**
 * Applies {@code operation} to {@code tensor} and the first element of
 * {@code scalarTensor}, then reshapes the outcome to the broadcast of the two
 * operands' shapes.
 *
 * @param tensor       the first operand passed to {@code operation}
 * @param scalarTensor the operand whose element 0 (via {@code getScalar(0)}) is used
 * @param operation    the binary operation to apply
 * @return the operation's result reshaped to
 *         {@code Shape.broadcastOutputShape(tensor.shape(), scalarTensor.shape())}
 */
private static INDArray applyScalarTensorOperationWithPreservedShape(INDArray tensor,
                                                                     INDArray scalarTensor,
                                                                     BiFunction<INDArray, INDArray, INDArray> operation) {
    // The operation runs before the target shape is computed, as in the original ordering.
    INDArray scalar = scalarTensor.getScalar(0);
    INDArray rawOutcome = operation.apply(tensor, scalar);

    long[] targetShape = Shape.broadcastOutputShape(tensor.shape(), scalarTensor.shape());
    return rawOutcome.reshape(targetShape);
}
/**
 * Asserts that the shape of {@code result} equals the broadcast of the shapes of
 * {@code input1} and {@code input2}.
 *
 * @param result the tensor whose shape is checked
 * @param input1 the first operand
 * @param input2 the second operand
 */
private static void resultShapeMatchesBroadcastShape(Tensor result, Tensor input1, Tensor input2) {
    long[] broadcastShape = Shape.broadcastOutputShape(input1.getShape(), input2.getShape());
    long[] resultShape = result.getShape();
    // The value under test goes first and the expectation goes in the matcher, so a
    // failure reports the broadcast shape as "Expected" and the result shape as "was".
    assertThat(resultShape, equalTo(broadcastShape));
}
}
/**
 * Applies {@code operation} to {@code left} and {@code right}. When either operand
 * has length 1, that operand is first expanded with {@code Nd4j.valueArrayOf} to the
 * other operand's shape (filled with its element 0), and the outcome is reshaped to
 * the broadcast of the two original shapes.
 *
 * @param left      the left operand
 * @param right     the right operand
 * @param operation the binary operation to apply
 * @return the operation's result; reshaped to the broadcast output shape when one
 *         of the operands has length 1
 */
private static INDArray performOperationWithScalarTensorPreservingShape(INDArray left,
                                                                        INDArray right,
                                                                        BiFunction<INDArray, INDArray, INDArray> operation) {
    final boolean leftIsSingleElement = left.length() == 1;

    // Neither operand is a single element: apply the operation directly.
    if (!leftIsSingleElement && right.length() != 1) {
        return operation.apply(left, right);
    }

    long[] targetShape = Shape.broadcastOutputShape(left.shape(), right.shape());

    INDArray outcome;
    if (leftIsSingleElement) {
        INDArray filledLeft = Nd4j.valueArrayOf(right.shape(), left.getDouble(0));
        outcome = operation.apply(filledLeft, right);
    } else {
        INDArray filledRight = Nd4j.valueArrayOf(left.shape(), right.getDouble(0));
        outcome = operation.apply(left, filledRight);
    }
    return outcome.reshape(targetShape);
}
/**
 * Runs {@code baseBroadcastOp} on {@code left} and {@code right}, handing it a newly
 * created array of the broadcast output shape and the broadcast dimensions
 * computed by {@code getBroadcastDimensions}.
 *
 * @param left            the left operand
 * @param right           the right operand
 * @param baseBroadcastOp function invoked as {@code (left, right, newArray, dimensions)}
 * @return whatever {@code baseBroadcastOp} returns
 */
private static INDArray applyBroadcastOperation(INDArray left,
                                                INDArray right,
                                                QuadFunction<INDArray, INDArray, INDArray, List<Integer>, INDArray> baseBroadcastOp) {
    // Dimensions are resolved before the output array is created.
    List<Integer> dims = getBroadcastDimensions(left.shape(), right.shape());
    INDArray target = Nd4j.create(
        Shape.broadcastOutputShape(left.shape(), right.shape())
    );
    return baseBroadcastOp.apply(left, right, target, dims);
}
/**
 * Asserts that the shape of {@code result} equals the broadcast of the shapes of
 * {@code input1} and {@code input2}.
 *
 * @param result the tensor whose shape is checked
 * @param input1 the first operand
 * @param input2 the second operand
 */
private static void resultShapeMatchesBroadcastShape(Tensor result, Tensor input1, Tensor input2) {
    // Expected shape is computed first; the actual shape is read afterwards.
    final long[] expectedShape =
        Shape.broadcastOutputShape(input1.getShape(), input2.getShape());
    assertArrayEquals(expectedShape, result.getShape());
}
}
/**
 * Executes the op built by {@code baseTransformOpConstructor} for {@code mask} and
 * {@code right}, passing the mask as both the first and the third constructor
 * argument and its length as the fourth.
 *
 * <p>When either operand has length 1, that operand is expanded with the typed
 * {@code valueArrayOf} helper to the other operand's shape (filled with its
 * element 0), and the returned array is reshaped to the broadcast of the two
 * original shapes. When {@code mask} is the single-element operand, the op runs on
 * the expanded copy rather than on the caller's {@code mask}.
 *
 * @param mask                      the first operand
 * @param right                     the second operand
 * @param bufferType                buffer type used for any expanded single-element operand
 * @param baseTransformOpConstructor builds the op from {@code (mask, right, mask, length)}
 * @return the array passed as the op's third argument, reshaped to the broadcast
 *         output shape when one of the operands has length 1
 */
private static INDArray executeNd4jTransformOpWithPreservedScalarTensorShape(INDArray mask,
                                                                             INDArray right,
                                                                             DataBuffer.Type bufferType,
                                                                             QuadFunction<INDArray, INDArray, INDArray, Long, BaseTransformOp> baseTransformOpConstructor) {
    if (mask.length() == 1 || right.length() == 1) {
        long[] resultShape = Shape.broadcastOutputShape(mask.shape(), right.shape());

        if (mask.length() == 1) {
            // Expand with the requested buffer type, matching the single-element
            // right-hand branch below (previously bufferType was ignored here).
            mask = valueArrayOf(right.shape(), mask.getDouble(0), bufferType);
            Nd4j.getExecutioner().exec(
                baseTransformOpConstructor.apply(mask, right, mask, mask.length())
            );
        } else {
            Nd4j.getExecutioner().exec(
                baseTransformOpConstructor.apply(
                    mask,
                    valueArrayOf(mask.shape(), right.getDouble(0), bufferType),
                    mask,
                    mask.length()
                )
            );
        }

        return mask.reshape(resultShape);
    } else {
        Nd4j.getExecutioner().exec(
            baseTransformOpConstructor.apply(mask, right, mask, mask.length())
        );
        return mask;
    }
}