/**
 * Builds a {@code Max} reduction of {@code i_x} over the given dimensions and
 * returns its first output variable.
 *
 * @param i_x        the variable to reduce
 * @param dimensions the dimensions passed to the {@code Max} op
 * @return the first output variable of the created {@code Max} op
 */
public SDVariable max(SDVariable i_x, int... dimensions) {
    Max reduction = new Max(sameDiff(), i_x, dimensions);
    SDVariable[] produced = reduction.outputVariables();
    return produced[0];
}
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //TODO do we need to handle the "multiple equal maximums" case? //TODO code duplication (min/max) SDVariable out = outputVariables()[0]; int origRank = Shape.rankFromShape(arg().getShape()); SDVariable expandedOut = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, out); expandedOut = sameDiff.onesLike(arg()).mul(expandedOut); SDVariable expandedGrad = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0)); SDVariable eq = sameDiff.eq(arg(), expandedOut); SDVariable ret = eq.mul(expandedGrad); return Arrays.asList(ret); }
/**
 * Creates a {@code Max} op that operates on the {@code index}-th vector of {@code x}
 * along {@code dimension}.
 *
 * <p>When {@code y()} is non-null, the matching vector of {@code y} is paired with the
 * x vector, and the x vector's length is passed as the op length.
 *
 * @param index     which vector along the dimension to use
 * @param dimension the dimension to take vectors along
 * @return a new {@code Max} op over the selected vector(s)
 */
@Override
public Op opForDimension(int index, int dimension) {
    // Extract the x view once and reuse it in both branches instead of computing it twice.
    INDArray xAlongDimension = x.vectorAlongDimension(index, dimension);
    if (y() != null) {
        INDArray yAlongDimension = y.vectorAlongDimension(index, dimension);
        return new Max(xAlongDimension, yAlongDimension, xAlongDimension.length());
    }
    return new Max(xAlongDimension);
}
/**
 * Returns the maximum of this ndarray along the specified dimension(s).
 *
 * @param dimension the dimension(s) to compute the max along
 * @return the result of executing a {@code Max} reduction over {@code dimension}
 */
@Override
public INDArray max(int... dimension) {
    return Nd4j.getExecutioner().exec(new Max(this), dimension);
}
/**
 * Creates a {@code Max} op that operates on the {@code index}-th tensor of {@code x}
 * along {@code dimension}.
 *
 * <p>When {@code y()} is non-null, the matching tensor of {@code y} is paired with the
 * x tensor, and the x tensor's length is passed as the op length.
 *
 * @param index     which tensor along the dimension(s) to use
 * @param dimension the dimension(s) to take tensors along
 * @return a new {@code Max} op over the selected tensor(s)
 */
@Override
public Op opForDimension(int index, int... dimension) {
    // Extract the x view once and reuse it in both branches instead of computing it twice.
    INDArray xAlongDimension = x.tensorAlongDimension(index, dimension);
    if (y() != null) {
        INDArray yAlongDimension = y.tensorAlongDimension(index, dimension);
        return new Max(xAlongDimension, yAlongDimension, xAlongDimension.length());
    }
    return new Max(xAlongDimension);
}
}
/**
 * Returns the maximum of this ndarray along the specified dimension(s).
 *
 * @param dimension the dimension(s) to compute the max along
 * @return the result of executing a {@code Max} reduction over {@code dimension}
 */
@Override
public INDArray max(int... dimension) {
    return Nd4j.getExecutioner().exec(new Max(this), dimension);
}
/**
 * Takes the max of {@code array} along dimension 0, keeps the leading {@code k} entries
 * (where {@code k} is the number of strictly positive maxima), and returns their mean.
 * If no maximum is positive, all of the maxima are averaged instead.
 *
 * <p>NOTE(review): truncating to the first {@code k} entries assumes the positive maxima
 * form a prefix (e.g. trailing zero padding). Confirm this against callers. The two-index
 * {@code get} also assumes the reduced result is rank 2.
 *
 * @param array the array to reduce
 * @return the mean of the selected per-column maxima
 */
private static double getAverageFloatMaxArray(INDArray array) {
    INDArray columnMax = Nd4j.getExecutioner().exec(new Max(array), 0);
    // gt(0) marks strictly positive maxima; summing the mask counts them.
    int positiveCount = columnMax.gt(0).sumNumber().intValue();

    INDArray selected;
    if (positiveCount <= 0) {
        selected = columnMax;
    } else {
        selected = columnMax.get(NDArrayIndex.all(), NDArrayIndex.interval(0, positiveCount));
    }

    Mean average = new Mean(selected);
    return Nd4j.getExecutioner().execAndReturn(average).getFinalResult().doubleValue();
}
break; case "max": ret = new Max(x, y, z,x.length()); break; case "min":