/**
 * Builds a {@link Max} op over {@code i_x} for the given dimensions and
 * returns that op's first output variable.
 *
 * @param i_x        input variable handed to the Max op
 * @param dimensions dimensions forwarded unchanged to the Max op
 * @return element 0 of the new op's output variables
 */
public SDVariable max(SDVariable i_x, int... dimensions) {
    Max maxOp = new Max(sameDiff(), i_x, dimensions);
    SDVariable[] opOutputs = maxOp.outputVariables();
    return opOutputs[0];
}
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //TODO do we need to handle the "multiple equal maximums" case? //TODO code duplication (min/max) SDVariable out = outputVariables()[0]; int origRank = Shape.rankFromShape(arg().getShape()); SDVariable expandedOut = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, out); expandedOut = sameDiff.onesLike(arg()).mul(expandedOut); SDVariable expandedGrad = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0)); SDVariable eq = sameDiff.eq(arg(), expandedOut); SDVariable ret = eq.mul(expandedGrad); return Arrays.asList(ret); }