/**
 * Sets whether the op call should be performed in place.
 *
 * @param reallyCall {@code true} to request an in-place call, {@code false} otherwise
 * @return this builder, so calls can be chained
 * @throws ND4JIllegalStateException if an in-place call is requested while {@code inplaceAllowed} is false
 */
public DynamicCustomOpsBuilder callInplace(boolean reallyCall) {
    boolean forbidden = reallyCall && !inplaceAllowed;
    if (forbidden) {
        throw new ND4JIllegalStateException("Requested op can't be called inplace");
    }
    inplaceCall = reallyCall;
    return this;
}
/**
 * Returns the rank described by a shape array, i.e. its number of entries.
 *
 * @param shape the shape array
 * @return {@code shape.length}
 * @throws ND4JIllegalStateException if {@code shape} is null
 */
public static int rankFromShape(long[] shape) {
    if (shape != null) {
        return shape.length;
    }
    throw new ND4JIllegalStateException("Cannot get rank from null shape array");
}
/**
 * Verifies that every dimension in a requested shape is at least 1.
 *
 * @param shape the requested shape
 * @throws ND4JIllegalStateException if any dimension is smaller than 1
 */
public static void checkShapeValues(int[] shape) {
    for (int idx = 0; idx < shape.length; idx++) {
        if (shape[idx] < 1) {
            throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape)
                    + " contains dimension size values < 1 (all dimensions must be 1 or more)");
        }
    }
}
/**
 * Appends the given arrays to this op's output arguments.
 * Every element is checked for null before any of them is added.
 *
 * @param arg the output arrays to append
 * @throws ND4JIllegalStateException naming the index of the first null element
 */
@Override
public void addOutputArgument(INDArray... arg) {
    int idx = 0;
    for (INDArray output : arg) {
        if (output == null) {
            throw new ND4JIllegalStateException("Output " + idx + " was null!");
        }
        idx++;
    }
    outputArguments.addAll(Arrays.asList(arg));
}
/**
 * Looks up a function by its {@link DifferentialFunction#getOwnName()} id.
 *
 * @param id the id of the function
 * @return the function registered under {@code id}
 * @throws ND4JIllegalStateException if no entry exists for {@code id}
 */
public DifferentialFunction getFunctionById(String id) {
    boolean known = functionInstancesById.containsKey(id);
    if (known) {
        return functionInstancesById.get(id);
    }
    throw new ND4JIllegalStateException("No function with id " + id + " found!");
}
/**
 * Sums the given arrays by delegating to the varargs {@code accumulate} overload.
 *
 * @param arrays the arrays to sum
 * @return the result of {@code accumulate(INDArray...)}
 * @throws ND4JIllegalStateException if {@code arrays} is null or empty
 */
public static INDArray accumulate(Collection<INDArray> arrays) {
    boolean missing = arrays == null || arrays.isEmpty();
    if (missing) {
        throw new ND4JIllegalStateException("Input for accumulation is null or empty");
    }
    INDArray[] asArray = arrays.toArray(new INDArray[0]);
    return accumulate(asArray);
}
/**
 * Fails when two int shapes are not broadcastable according to {@code areShapesBroadcastable}.
 *
 * @param x first shape
 * @param y second shape
 * @throws ND4JIllegalStateException naming both shapes if they are not broadcastable
 */
public static void assertBroadcastable(@NonNull int[] x, @NonNull int[] y) {
    if (areShapesBroadcastable(x, y)) {
        return;
    }
    String msg = "Arrays are different shape and are not broadcastable."
            + " Array 1 shape = " + Arrays.toString(x)
            + ", array 2 shape = " + Arrays.toString(y);
    throw new ND4JIllegalStateException(msg);
}
/**
 * Creates a complex array of shape {@code [1, n]} holding the given numbers.
 *
 * @param iComplexNumbers the numbers to use; must be non-null and non-empty
 * @return the complex array
 * @throws ND4JIllegalStateException if the input is null or empty
 */
public static IComplexNDArray createComplex(IComplexNumber[] iComplexNumbers) {
    if (iComplexNumbers != null && iComplexNumbers.length >= 1) {
        int[] rowShape = {1, iComplexNumbers.length};
        return createComplex(iComplexNumbers, rowShape);
    }
    throw new ND4JIllegalStateException("Number of complex numbers can't be < 1 for new INDArray");
}
private static long[] targetShape(long[] shape, double eps, int targetDimension, boolean auto){ long components = targetDimension; if (auto) components = johnsonLindenStraussMinDim(shape[0], eps).get(0); // JL or user spec edge cases if (auto && (components <= 0 || components > shape[1])){ throw new ND4JIllegalStateException(String.format("Estimation led to a target dimension of %d, which is invalid", components)); } return new long[]{ shape[1], components}; }
/**
 * Fails when two long shapes are not broadcastable according to {@code areShapesBroadcastable}.
 *
 * @param x first shape
 * @param y second shape
 * @throws ND4JIllegalStateException naming both shapes if they are not broadcastable
 */
public static void assertBroadcastable(@NonNull long[] x, @NonNull long[] y) {
    if (areShapesBroadcastable(x, y)) {
        return;
    }
    String msg = "Arrays are different shape and are not broadcastable."
            + " Array 1 shape = " + Arrays.toString(x)
            + ", array 2 shape = " + Arrays.toString(y);
    throw new ND4JIllegalStateException(msg);
}
/**
 * Returns the left (first) argument of this function.
 *
 * @return element 0 of {@code args()}
 * @throws ND4JIllegalStateException if {@code args()} is null or empty
 */
public SDVariable larg() {
    val args = args();
    if (args == null || args.length == 0) {
        throw new ND4JIllegalStateException("No arguments found.");
    }
    // Return from the array that was just validated instead of calling args() again.
    return args[0];
}
/**
 * Creates an array of the given shape, in {@code order()}, with every element set to 0.0.
 * {@code paramsView} is not used by this implementation.
 *
 * @param shape      shape of the array to create
 * @param paramsView unused here
 * @return the zero-filled array
 * @throws ND4JIllegalStateException if {@code shape} is null
 */
@Override
public INDArray doCreate(long[] shape, INDArray paramsView) {
    if (shape != null) {
        INDArray created = Nd4j.createUninitialized(shape, order());
        return created.assign(0.0);
    }
    throw new ND4JIllegalStateException("Shape must not be null!");
}
/**
 * Reports the output shape as the shape of the input variable's array, and records that
 * array's length in {@code n}.
 *
 * @return a single-element list with the input array's shape, or an empty list if the
 *         input array is not available yet
 * @throws ND4JIllegalStateException if there is no input argument
 */
@Override
public List<long[]> calculateOutputShape() {
    val input = arg();
    if (input == null) {
        throw new ND4JIllegalStateException("No arg found for op!");
    }
    val arr = sameDiff.getArrForVarName(input.getVarName());
    if (arr == null) {
        return Collections.emptyList();
    }
    this.n = arr.length();
    List<long[]> shapes = new ArrayList<>(1);
    shapes.add(arr.shape());
    return shapes;
}
/**
 * Returns the string value of the named attribute of an ONNX node.
 *
 * @param nodeProto the node to search
 * @param key       the attribute name
 * @return the attribute's {@code s} field decoded as UTF-8 text
 * @throws ND4JIllegalStateException if the node has no attribute named {@code key}
 */
@Override
public String getAttrValueFromNode(OnnxProto3.NodeProto nodeProto, String key) {
    for (OnnxProto3.AttributeProto attributeProto : nodeProto.getAttributeList()) {
        if (attributeProto.getName().equals(key)) {
            // ONNX declares AttributeProto.s as `bytes`, so getS() is a protobuf ByteString.
            // ByteString.toString() yields a debug description ("<ByteString@... size=N ...>"),
            // not the payload; toStringUtf8() decodes the actual text.
            return attributeProto.getS().toStringUtf8();
        }
    }
    throw new ND4JIllegalStateException("No key found for " + key);
}
/**
 * Builds a {@code Tile} op over {@code iX} and returns its first output variable.
 *
 * @param iX     the input variable
 * @param repeat per-dimension repeat counts; must not be null
 * @return the first output of the tile op
 * @throws ND4JIllegalStateException if {@code repeat} is null
 */
public SDVariable tile(SDVariable iX, int[] repeat) {
    if (repeat != null) {
        Tile tileOp = new Tile(sameDiff(), iX, repeat);
        return tileOp.outputVariables()[0];
    }
    throw new ND4JIllegalStateException("Repeat must not be null!");
}
/**
 * Computes the reduced output shape of this op using
 * {@code Shape.getReducedShape(inputShape, dimensions, isKeepDims(), newFormat)}.
 *
 * @return a single-element list with the reduced shape, or an empty list if the input
 *         shape is not known yet
 * @throws ND4JIllegalStateException if {@code args()} is null or empty
 */
@Override
public List<long[]> calculateOutputShape() {
    val args = args();
    // Guard against a null argument array as well as an empty one, instead of hitting an NPE.
    if (args == null || args.length < 1) {
        throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found.");
    }
    val inputShape = arg().getShape();
    if (inputShape == null) {
        return Collections.emptyList();
    }
    List<long[]> ret = new ArrayList<>(1);
    ret.add(Shape.getReducedShape(inputShape, dimensions, isKeepDims(), newFormat));
    return ret;
}
/**
 * Returns the value at the given percentile of this array.
 *
 * @param quantile the percentile, in the closed range [0, 100]
 * @return element 0 for a scalar array; otherwise {@code getPercentile(quantile, sorted)}
 *         computed on an ascending-sorted copy of this array
 * @throws ND4JIllegalStateException if {@code quantile} is outside [0, 100] or is NaN
 */
@Override
public Number percentileNumber(Number quantile) {
    // Compare as double: intValue() truncates toward zero, so 100.5 or -0.5 would pass the check.
    double q = quantile.doubleValue();
    // Written in negated form so that NaN (for which every comparison is false) is rejected too.
    if (!(q >= 0 && q <= 100)) {
        throw new ND4JIllegalStateException("Percentile value should be in 0...100 range");
    }
    if (isScalar()) {
        return this.getDouble(0);
    }
    INDArray sorted = Nd4j.sort(this.dup(this.ordering()), true);
    return getPercentile(quantile, sorted);
}
/**
 * Sums the given arrays into a newly created target array.
 * The target has the same shape and ordering as the first input.
 *
 * @param arrays the arrays to sum; must be non-null and non-empty
 * @return the result of {@code accumulate(target, arrays)}
 * @throws ND4JIllegalStateException if {@code arrays} is null or empty
 */
public static INDArray accumulate(INDArray... arrays) {
    if (arrays != null && arrays.length > 0) {
        INDArray first = arrays[0];
        INDArray target = Nd4j.create(first.shape(), first.ordering());
        return accumulate(target, arrays);
    }
    throw new ND4JIllegalStateException("Input for accumulation is null or empty");
}
/**
 * Checks that every non-null operand has the expected data type.
 * Null operands are skipped. A null or empty operand array is accepted.
 *
 * @param expectedType the required data type
 * @param operands     the arrays to check
 * @throws ND4JIllegalStateException naming the position of the first operand whose data type differs
 */
public static void validateDataType(DataBuffer.Type expectedType, INDArray... operands) {
    if (operands == null || operands.length == 0) {
        return;
    }
    // Use the loop index. The previous counter was only incremented inside the throw,
    // so the reported index was always 0.
    for (int i = 0; i < operands.length; i++) {
        INDArray operand = operands[i];
        if (operand == null) {
            continue;
        }
        DataBuffer.Type actual = operand.data().dataType();
        if (actual != expectedType) {
            throw new ND4JIllegalStateException("INDArray [" + i + "] dataType is [" + actual
                    + "] instead of expected [" + expectedType + "]");
        }
    }
}
/**
 * Returns the elements of this vector as a {@code double[]}, read from the buffer of a
 * duplicate of this array.
 *
 * @return the vector's values
 * @throws ND4JIllegalStateException if this array is not a vector
 */
@Override
public double[] toDoubleVector() {
    if (isVector()) {
        double[] values = dup().data().asDouble();
        return values;
    }
    throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector!");
}