/**
 * Expand the array dimensions by inserting a new axis.
 * This is equivalent to indexing {@code input} with {@link NDArrayIndex#newAxis()}
 * at position {@code dimension} and {@link NDArrayIndex#all()} at every other position.
 *
 * <p>After normalization, valid values for {@code dimension} are {@code 0} through
 * {@code input.rank()} inclusive. A value of {@code input.rank()} places the new axis
 * after the last existing axis. A negative {@code dimension} is offset by
 * {@code input.rank()} before use, so {@code -1} inserts the new axis just before the
 * last existing axis.
 * NOTE(review): this negative handling differs from NumPy's {@code expand_dims}, where
 * {@code -1} appends at the end. It is kept unchanged so existing callers see the same
 * results.
 *
 * @param input the input array
 * @param dimension the position at which to insert the new axis; may be negative
 *                  (offset by {@code input.rank()})
 * @return the array with the new axis dimension, as returned by {@code input.get(...)}
 * @throws IllegalArgumentException if {@code dimension}, after the negative offset,
 *         is outside {@code [0, input.rank()]}
 */
public static INDArray expandDims(INDArray input, int dimension) {
    final int rank = input.rank();
    final int axis = dimension < 0 ? dimension + rank : dimension;
    if (axis < 0 || axis > rank) {
        throw new IllegalArgumentException("Invalid dimension " + dimension
                + " for array of rank " + rank
                + ": expected a value in [" + (-rank) + ", " + rank + "]");
    }
    // One index per output axis: rank pass-through all() indices plus the inserted newAxis().
    // Sizing by rank + 1 (not rank) lets axis == rank address a trailing new axis, and keeps
    // every input axis explicitly indexed instead of relying on implicit trailing padding.
    final INDArrayIndex[] indexes = new INDArrayIndex[rank + 1];
    for (int i = 0; i < indexes.length; i++) {
        indexes[i] = (i == axis) ? NDArrayIndex.newAxis() : NDArrayIndex.all();
    }
    return input.get(indexes);
}