/**
 * Panics if the given array contains infinite values.
 *
 * <p>The check only runs when the executioner's profiling mode is
 * {@code INF_PANIC} or {@code ANY_PANIC}; in every other mode this method
 * returns immediately without inspecting the array.
 *
 * @param z array to inspect (the error message refers to it as {@code Op.Z()})
 * @throws ND4JIllegalStateException if at least one infinite value is found
 */
public static void checkForInf(INDArray z) {
    OpExecutioner.ProfilingMode mode = Nd4j.getExecutioner().getProfilingMode();
    boolean panicOnInf = mode == OpExecutioner.ProfilingMode.INF_PANIC
                    || mode == OpExecutioner.ProfilingMode.ANY_PANIC;
    if (!panicOnInf) {
        return;
    }

    int infCount;
    if (z.isScalar()) {
        // Scalars are tested directly: as double for DOUBLE buffers, as float otherwise.
        boolean infinite = z.data().dataType() == DataBuffer.Type.DOUBLE
                        ? Double.isInfinite(z.getDouble(0))
                        : Float.isInfinite(z.getFloat(0));
        infCount = infinite ? 1 : 0;
    } else {
        // Non-scalars: run a MatchCondition op and read its result as the match count.
        MatchCondition infMatcher = new MatchCondition(z, Conditions.isInfinite());
        infCount = Nd4j.getExecutioner().exec(infMatcher, Integer.MAX_VALUE).getInt(0);
    }

    if (infCount > 0) {
        throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + infCount + " Inf value(s)");
    }
}
private INDArray calcLogProbArray(INDArray x, INDArray preOutDistributionParams) { INDArray output = preOutDistributionParams.dup(); activationFn.getActivation(output, false); INDArray logOutput = Transforms.log(output, true); INDArray log1SubOut = Transforms.log(output.rsubi(1.0), false); //For numerical stability: if output = 0, then log(output) == -infinity //then x * log(output) = NaN, but lim(x->0, output->0)[ x * log(output) ] == 0 // therefore: want 0*log(0) = 0, NOT 0*log(0) = NaN by default BooleanIndexing.replaceWhere(logOutput, 0.0, Conditions.isInfinite()); //log(out)= +/- inf -> x == 0.0 -> 0 * log(0) = 0 BooleanIndexing.replaceWhere(log1SubOut, 0.0, Conditions.isInfinite()); //log(out)= +/- inf -> x == 0.0 -> 0 * log(0) = 0 return logOutput.muli(x).addi(x.rsub(1.0).muli(log1SubOut)); }
private static INDArray logZ(INDArray z) { INDArray log = log(z, true); // log approaches -Infinity as z approaches zero. Replace -Infinity with the least possible value. // Caveat: does not handle +Infinity since z is assumed to be 0 <= z <= 1. switch (log.data().dataType()) { case FLOAT: BooleanIndexing.applyWhere(log, new Or(Conditions.isNan(), Conditions.isInfinite()), new StableNumber(StableNumber.Type.FLOAT)); break; case DOUBLE: BooleanIndexing.applyWhere(log, new Or(Conditions.isNan(), Conditions.isInfinite()), new StableNumber(StableNumber.Type.DOUBLE)); break; case INT: BooleanIndexing.applyWhere(log, new Or(Conditions.isNan(), Conditions.isInfinite()), new Value(-Integer.MAX_VALUE)); break; default: throw new RuntimeException("unsupported data type: " + log.data().dataType()); } return log; }
/**
 * Verifies that {@code z} holds no infinite values, throwing if it does.
 *
 * <p>This is a no-op unless the executioner's profiling mode is {@code INF_PANIC}
 * or {@code ANY_PANIC}.
 *
 * @param z array to scan for infinite values
 * @throws ND4JIllegalStateException if one or more infinite values are present
 */
public static void checkForInf(INDArray z) {
    boolean panicEnabled =
                    Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.INF_PANIC
                    || Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ANY_PANIC;
    if (!panicEnabled) {
        return;
    }

    int hits = 0;
    if (z.isScalar()) {
        // A scalar is read as double for DOUBLE buffers and as float for everything else.
        boolean isDoubleBuffer = z.data().dataType() == DataBuffer.Type.DOUBLE;
        if (isDoubleBuffer) {
            if (Double.isInfinite(z.getDouble(0))) {
                hits = 1;
            }
        } else if (Float.isInfinite(z.getFloat(0))) {
            hits = 1;
        }
    } else {
        // For non-scalars, the first element of the MatchCondition result is used as the count.
        MatchCondition scan = new MatchCondition(z, Conditions.isInfinite());
        hits = Nd4j.getExecutioner().exec(scan, Integer.MAX_VALUE).getInt(0);
    }

    if (hits > 0) {
        throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + hits + " Inf value(s)");
    }
}