/**
 * Returns the index of the last occurrence of the given element in the backing
 * container, or -1 if it is not present (assuming {@code BooleanIndexing.lastIndex}
 * yields -1 on no match — TODO confirm).
 *
 * @param o element to search for; any boxed numeric type
 * @return the last index of {@code o}, or -1 when absent
 */
@Override
public int lastIndexOf(Object o) {
    // ((Number) o).doubleValue() accepts any boxed numeric type; the previous
    // (double) cast threw ClassCastException for every box other than Double.
    return BooleanIndexing.lastIndex(container, new EqualsCondition(((Number) o).doubleValue())).getInt(0);
}
/**
 * Returns the index of the first occurrence of the given element in the backing
 * container, or -1 if it is not present (assuming {@code BooleanIndexing.firstIndex}
 * yields -1 on no match — TODO confirm).
 *
 * @param o element to search for; any boxed numeric type
 * @return the first index of {@code o}, or -1 when absent
 */
@Override
public int indexOf(Object o) {
    // ((Number) o).doubleValue() accepts any boxed numeric type; the previous
    // (double) cast threw ClassCastException for every box other than Double.
    return BooleanIndexing.firstIndex(container, new EqualsCondition(((Number) o).doubleValue())).getInt(0);
}
/**
 * Returns the index of the first occurrence of the given element in the backing
 * container, or -1 if it is not present (assuming {@code BooleanIndexing.firstIndex}
 * yields -1 on no match — TODO confirm).
 *
 * @param o element to search for; any boxed numeric type
 * @return the first index of {@code o}, or -1 when absent
 */
@Override
public int indexOf(Object o) {
    // ((Number) o).doubleValue() accepts any boxed numeric type; the previous
    // (double) cast threw ClassCastException for every box other than Double.
    return BooleanIndexing.firstIndex(container, new EqualsCondition(((Number) o).doubleValue())).getInt(0);
}
/**
 * Returns the index of the last occurrence of the given element in the backing
 * container, or -1 if it is not present (assuming {@code BooleanIndexing.lastIndex}
 * yields -1 on no match — TODO confirm).
 *
 * @param o element to search for; any boxed numeric type
 * @return the last index of {@code o}, or -1 when absent
 */
@Override
public int lastIndexOf(Object o) {
    // ((Number) o).doubleValue() accepts any boxed numeric type; the previous
    // (double) cast threw ClassCastException for every box other than Double.
    return BooleanIndexing.lastIndex(container, new EqualsCondition(((Number) o).doubleValue())).getInt(0);
}
/**
 * Imports a TensorFlow Split node: reads the {@code num_split} attribute and,
 * when the split dimension is available as a constant array on the node's
 * first input, records it as a second integer argument.
 *
 * NOTE(review): here {@code num_split} is appended as an iArg BEFORE the split
 * dimension; a sibling implementation appends it after. Confirm which ordering
 * the native op actually expects before unifying.
 */
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { val numSplits = (int) attributesForNode.get("num_split").getI(); this.numSplit = numSplits; addIArgument(numSplits); val splitDim = TFGraphMapper.getInstance().getArrayFrom(TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,nodeDef.getInput(0)),graph); if(splitDim != null) { this.splitDim = splitDim.getInt(0); addIArgument(splitDim.getInt(0)); } }
/**
 * Imports a TensorFlow Split node: reads the {@code num_split} attribute and,
 * when the split dimension is available as a constant array on the node's
 * first input, records it as an integer argument.
 *
 * NOTE(review): here {@code num_split} is appended as an iArg AFTER the split
 * dimension; a sibling implementation appends it before. Confirm which ordering
 * the native op actually expects before unifying.
 */
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { val numSplits = (int) attributesForNode.get("num_split").getI(); this.numSplit = numSplits; val splitDim = TFGraphMapper.getInstance().getArrayFrom(TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,nodeDef.getInput(0)),graph); if(splitDim != null) { this.splitDim = splitDim.getInt(0); addIArgument(splitDim.getInt(0)); } addIArgument(numSplits); }
/**
 * Derives the output shape from the first input argument's array.
 * Returns an empty list when there are no args or the shape array is unset;
 * a length-1 shape descriptor is promoted to a row vector.
 *
 * @return a single-element list holding the output shape, or an empty list
 *         when the shape cannot be determined yet
 */
@Override
public List<long[]> calculateOutputShape() {
    if (args().length < 1)
        return Collections.emptyList();

    INDArray shapeArr = args()[0].getArr();
    if (shapeArr == null)
        return Collections.emptyList();

    // A single-element shape descriptor describes a row vector; values below 1
    // collapse to a 1x1 scalar shape.
    if (shapeArr.length() == 1) {
        return shapeArr.getDouble(0) < 1
                ? Arrays.asList(new long[] {1, 1})
                : Arrays.asList(new long[] {1, shapeArr.getInt(0)});
    }

    return Arrays.asList(shapeArr.data().asLong());
}
/**
 * Creates an n-dimensional index from paired start/end vectors: element i of
 * the result is the interval built from {@code start.getInt(i)} and
 * {@code end.getInt(i)}.
 *
 * @param start vector of interval start positions
 * @param end   vector of interval end positions; must match {@code start} in length
 * @return one {@link INDArrayIndex} interval per element of the input vectors
 * @throws IllegalArgumentException  if the vectors differ in length
 * @throws ND4JIllegalStateException if the vectors exceed Integer.MAX_VALUE elements
 */
public static INDArrayIndex[] createFromStartAndEnd(INDArray start, INDArray end) {
    if (start.length() != end.length())
        throw new IllegalArgumentException("Start length must be equal to end length");
    if (start.length() > Integer.MAX_VALUE)
        throw new ND4JIllegalStateException("Can't proceed with INDArray with length > Integer.MAX_VALUE");

    final int count = (int) start.length();
    INDArrayIndex[] intervals = new INDArrayIndex[count];
    for (int d = 0; d < count; d++) {
        intervals[d] = NDArrayIndex.interval(start.getInt(d), end.getInt(d));
    }
    return intervals;
}
/**
 * Filters the inputs by the given condition, producing a row vector of the
 * elements that satisfy it.
 *
 * @param input     the arrays to filter
 * @param condition the condition each element is tested against
 * @return a row vector of the matching elements, or null when nothing matched
 */
public static INDArray chooseFrom(@NonNull INDArray[] input, @NonNull Condition condition) {
    Choose op = new Choose(input, condition);
    Nd4j.getExecutioner().exec(op);

    // Output 1 appears to carry the match count — a value below 1 means no
    // element satisfied the condition.
    int matched = op.getOutputArgument(1).getInt(0);
    return matched < 1 ? null : op.getOutputArgument(0);
}
@Override public void potrf(INDArray A, boolean lower) { // FIXME: int cast if (A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); byte uplo = (byte) (lower ? 'L' : 'U'); // upper or lower part of the factor desired ? int n = (int) A.columns(); INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, 1}).getFirst()); if (A.data().dataType() == DataBuffer.Type.DOUBLE) dpotrf(uplo, n, A, INFO); else if (A.data().dataType() == DataBuffer.Type.FLOAT) spotrf(uplo, n, A, INFO); else throw new UnsupportedOperationException(); if (INFO.getInt(0) < 0) { throw new Error("Parameter #" + INFO.getInt(0) + " to potrf() was not valid"); } else if (INFO.getInt(0) > 0) { throw new Error("The matrix is not positive definite! (potrf fails @ order " + INFO.getInt(0) + ")"); } return; }
@Override public void gesvd(INDArray A, INDArray S, INDArray U, INDArray VT) { // FIXME: int cast if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); int m = (int) A.rows(); int n = (int) A.columns(); byte jobu = (byte) (U == null ? 'N' : 'A'); byte jobvt = (byte) (VT == null ? 'N' : 'A'); INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, 1}).getFirst()); if (A.data().dataType() == DataBuffer.Type.DOUBLE) dgesvd(jobu, jobvt, m, n, A, S, U, VT, INFO); else if (A.data().dataType() == DataBuffer.Type.FLOAT) sgesvd(jobu, jobvt, m, n, A, S, U, VT, INFO); else throw new UnsupportedOperationException(); if (INFO.getInt(0) < 0) { throw new Error("Parameter #" + INFO.getInt(0) + " to gesvd() was not valid"); } else if (INFO.getInt(0) > 0) { log.warn("The matrix contains singular elements. Check S matrix at row " + INFO.getInt(0)); } }
/**
 * Imports the op's integer argument from the TF graph: when the node's second
 * input resolves to a constant array, its first element is recorded as an iArg.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final String secondInput = nodeDef.getInput(1);
    // Resolved but otherwise unused here; retained in case getVariable has a
    // registration side effect — NOTE(review): confirm whether it is needed.
    initWith.getVariable(secondInput);

    final NodeDef sourceNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, secondInput);
    final INDArray constArr = TFGraphMapper.getInstance().getArrayFrom(sourceNode, graph);
    if (constArr != null) {
        addIArgument(constArr.getInt(0));
    }
}
/**
 * Imports the op's integer argument from the TF graph: locates the graph node
 * named by this node's LAST input and, when its "value" tensor is available,
 * records that tensor's first element as an iArg.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final String lastInputName = nodeDef.getInput(nodeDef.getInputCount() - 1);

    // Scan from the end so the LAST node with a matching name wins — identical
    // to the original forward scan that kept overwriting its match.
    NodeDef valueNode = null;
    for (int i = graph.getNodeCount() - 1; i >= 0; i--) {
        if (graph.getNode(i).getName().equals(lastInputName)) {
            valueNode = graph.getNode(i);
            break;
        }
    }

    final INDArray valueArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", valueNode, graph);
    if (valueArr != null) {
        addIArgument(valueArr.getInt(0));
    }
}
@Override public INDArray getPFactor(int M, INDArray ipiv) { // The simplest permutation is the identity matrix INDArray P = Nd4j.eye(M); // result is a square matrix with given size for (int i = 0; i < ipiv.length(); i++) { int pivot = ipiv.getInt(i) - 1; // Did we swap row #i with anything? if (pivot > i) { // don't reswap when we get lower down in the vector INDArray v1 = P.getColumn(i).dup(); // because of row vs col major order we'll ... INDArray v2 = P.getColumn(pivot); // ... make a transposed matrix immediately P.putColumn(i, v2); P.putColumn(pivot, v1); // note dup() above is required - getColumn() is a 'view' } } return P; // the permutation matrix - contains a single 1 in any row and column }
@Override public void geqrf(INDArray A, INDArray R) { // FIXME: int cast if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); int m = (int) A.rows(); int n = (int) A.columns(); INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, 1}).getFirst()); if (R.rows() != A.columns() || R.columns() != A.columns()) { throw new Error("geqrf: R must be N x N (n = columns in A)"); } if (A.data().dataType() == DataBuffer.Type.DOUBLE) { dgeqrf(m, n, A, R, INFO); } else if (A.data().dataType() == DataBuffer.Type.FLOAT) { sgeqrf(m, n, A, R, INFO); } else { throw new UnsupportedOperationException(); } if (INFO.getInt(0) < 0) { throw new Error("Parameter #" + INFO.getInt(0) + " to getrf() was not valid"); } }
/**
 * Resolves the concat dimension from SameDiff state just before execution:
 * the first unresolved property names a variable whose array's first element
 * supplies the dimension, which is then appended as an iArg. When the axis
 * also arrived as a trailing input array (inputs == args), that trailing input
 * is removed so the axis is not passed to libnd4j twice.
 *
 * NOTE(review): only propertiesToResolve.get(0) is consumed — assumes at most
 * one property needs resolution here; confirm against callers.
 *
 * @throws ND4JIllegalStateException if the variable is missing or its array is unset
 */
@Override public void resolvePropertiesFromSameDiffBeforeExecution() { val propertiesToResolve = sameDiff.propertiesToResolveForFunction(this); if(!propertiesToResolve.isEmpty()) { val varName = propertiesToResolve.get(0); val var = sameDiff.getVariable(varName); if(var == null) { throw new ND4JIllegalStateException("No variable found with name " +varName); } else if(var.getArr() == null) { throw new ND4JIllegalStateException("Array with variable name " + varName + " unset!"); } concatDimension = var.getArr().getInt(0); addIArgument(concatDimension); } //don't pass both iArg and last axis down to libnd4j if(inputArguments().length == args().length) { val inputArgs = inputArguments(); removeInputArgument(inputArgs[inputArguments().length - 1]); } }
/**
 * Removes the first occurrence of the given numeric value from the backing
 * row-vector container.
 *
 * @param o element to remove; any boxed numeric type
 * @return true if a matching element was found and removed, false otherwise
 */
@Override
public boolean remove(Object o) {
    // ((Number) o).doubleValue() accepts any boxed numeric type; the previous
    // (double) cast threw ClassCastException for every box other than Double.
    int idx = BooleanIndexing.firstIndex(container, new EqualsCondition(((Number) o).doubleValue())).getInt(0);
    if (idx < 0)
        return false;
    // Shift the tail left by one position to overwrite the removed element.
    // NOTE(review): the target interval is one element longer than the source,
    // and `size` is not decremented here — confirm both against the container's
    // put/reshape semantics and the sibling add/remove methods.
    container.put(new INDArrayIndex[]{NDArrayIndex.interval(idx, container.length())},
            container.get(NDArrayIndex.interval(idx + 1, container.length())));
    container = container.reshape(1, size);
    return true;
}
/**
 * Removes the first occurrence of the given numeric value from the backing
 * row-vector container.
 *
 * @param o element to remove; any boxed numeric type
 * @return true if a matching element was found and removed, false otherwise
 */
@Override
public boolean remove(Object o) {
    // ((Number) o).doubleValue() accepts any boxed numeric type; the previous
    // (double) cast threw ClassCastException for every box other than Double.
    int idx = BooleanIndexing.firstIndex(container, new EqualsCondition(((Number) o).doubleValue())).getInt(0);
    if (idx < 0)
        return false;
    // Shift the tail left by one position to overwrite the removed element.
    // NOTE(review): the target interval is one element longer than the source,
    // and `size` is not decremented here — confirm both against the container's
    // put/reshape semantics and the sibling add/remove methods.
    container.put(new INDArrayIndex[]{NDArrayIndex.interval(idx, container.length())},
            container.get(NDArrayIndex.interval(idx + 1, container.length())));
    container = container.reshape(1, size);
    return true;
}
/**
 * Profiling-mode panic check: when the executioner runs in INF_PANIC or
 * ANY_PANIC mode, throws if {@code z} contains any infinite value; otherwise
 * does nothing.
 *
 * @param z the op output to inspect
 * @throws ND4JIllegalStateException when at least one Inf is found
 */
public static void checkForInf(INDArray z) {
    OpExecutioner.ProfilingMode mode = Nd4j.getExecutioner().getProfilingMode();
    if (mode != OpExecutioner.ProfilingMode.INF_PANIC && mode != OpExecutioner.ProfilingMode.ANY_PANIC)
        return;

    int infCount;
    if (z.isScalar()) {
        // Inspect the single element directly, honouring the buffer dtype.
        boolean isInf = z.data().dataType() == DataBuffer.Type.DOUBLE
                ? Double.isInfinite(z.getDouble(0))
                : Float.isInfinite(z.getFloat(0));
        infCount = isInf ? 1 : 0;
    } else {
        MatchCondition condition = new MatchCondition(z, Conditions.isInfinite());
        infCount = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
    }

    if (infCount > 0)
        throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + infCount + " Inf value(s)");
}
/**
 * Profiling-mode panic check: when the executioner runs in NAN_PANIC or
 * ANY_PANIC mode, throws if {@code z} contains any NaN value; otherwise does
 * nothing.
 *
 * @param z the op output to inspect
 * @throws ND4JIllegalStateException when at least one NaN is found
 */
public static void checkForNaN(INDArray z) {
    OpExecutioner.ProfilingMode mode = Nd4j.getExecutioner().getProfilingMode();
    if (mode != OpExecutioner.ProfilingMode.NAN_PANIC && mode != OpExecutioner.ProfilingMode.ANY_PANIC)
        return;

    int nanCount;
    if (z.isScalar()) {
        // Inspect the single element directly, honouring the buffer dtype.
        boolean isNan = z.data().dataType() == DataBuffer.Type.DOUBLE
                ? Double.isNaN(z.getDouble(0))
                : Float.isNaN(z.getFloat(0));
        nanCount = isNan ? 1 : 0;
    } else {
        MatchCondition condition = new MatchCondition(z, Conditions.isNan());
        nanCount = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
    }

    if (nanCount > 0)
        throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + nanCount + " NaN value(s): ");
}