/**
 * Finds the index of the highest value along the requested dimension(s).
 *
 * <p>The work is delegated to {@code Nd4j.argMax(this, dimension)}; this method adds
 * no logic of its own.
 *
 * @param dimension the dimension(s) along which the maximum is located
 * @return the array produced by {@code Nd4j.argMax} for this array and {@code dimension}
 */
@Override
public INDArray argMax(int... dimension) {
    final INDArray maxIndices = Nd4j.argMax(this, dimension);
    return maxIndices;
}
// Compute the arg-max and arg-min of originalArray along dimension 0.
// The arg-max goes through the Nd4j.argMax convenience call, while the arg-min
// is obtained by running an IMin op on the executioner.
INDArray argMaxAlongDim0 = Nd4j.argMax(originalArray,0);                      // Index of the max value, along dimension 0
System.out.println("\n\nargmax along dimension 0: " + argMaxAlongDim0);
INDArray argMinAlongDim0 = Nd4j.getExecutioner().exec(new IMin(originalArray),0);   // Index of the min value, along dimension 0
private List<RecognisedObject> predict(INDArray predictions) { List<RecognisedObject> objects = new ArrayList<>(); int[] topNPredictions = new int[topN]; float[] topNProb = new float[topN]; String outLabels[]=new String[topN]; //brute force collect top N int i = 0; for (int batch = 0; batch < predictions.size(0); batch++) { INDArray currentBatch = predictions.getRow(batch).dup(); while (i < topN) { topNPredictions[i] = Nd4j.argMax(currentBatch, 1).getInt(0, 0); topNProb[i] = currentBatch.getFloat(batch, topNPredictions[i]); currentBatch.putScalar(0, topNPredictions[i], 0); outLabels[i]= imageNetLabels.getLabel(topNPredictions[i]); objects.add(new RecognisedObject(outLabels[i], "eng", outLabels[i], topNProb[i])); i++; } } return objects; } }
/**
 * Returns the index of the largest value along the given dimension(s).
 *
 * <p>This is a thin wrapper: the call is forwarded unchanged to
 * {@code Nd4j.argMax(this, dimension)}.
 *
 * @param dimension the dimension(s) to search along
 * @return whatever {@code Nd4j.argMax} yields for this array and {@code dimension}
 */
@Override
public INDArray argMax(int... dimension) {
    final INDArray indexOfMax = Nd4j.argMax(this, dimension);
    return indexOfMax;
}
private static void eval(INDArray labels, INDArray p, INDArray outMask, StringBuilder sb) { Pair<INDArray, INDArray> pair = EvaluationUtils.extractNonMaskedTimeSteps(labels, p, outMask); INDArray realOutcomes = pair.getFirst(); INDArray guesses = pair.getSecond(); // Length of real labels must be same as length of predicted labels if (realOutcomes.length() != guesses.length()) throw new IllegalArgumentException( "Unable to evaluate. Outcome matrices not same length"); INDArray guessIndex = Nd4j.argMax(guesses, 1); INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1); int nExamples = guessIndex.length(); for (int i = 0; i < nExamples; i++) { int actual = (int) realOutcomeIndex.getDouble(i); int predicted = (int) guessIndex.getDouble(i); sb.append(actual + "\t" + predicted + "\n"); } } }
int sampleIdx = 0; for (Sample sample : aData) { INDArray argMax = Nd4j.argMax(predicted, 1);
int sampleIdx = 0; for (Sample sample : aData) { INDArray argMax = Nd4j.argMax(predicted, 1);
private void eval(INDArray labels, INDArray p, INDArray outMask, StringBuilder sb) { Pair<INDArray, INDArray> pair = EvaluationUtils.extractNonMaskedTimeSteps(labels, p, outMask); INDArray realOutcomes = pair.getFirst(); INDArray guesses = pair.getSecond(); // Length of real labels must be same as length of predicted labels if (realOutcomes.length() != guesses.length()) throw new IllegalArgumentException( "Unable to evaluate. Outcome matrices not same length"); INDArray guessIndex = Nd4j.argMax(guesses, 1); INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1); int nExamples = guessIndex.length(); for (int i = 0; i < nExamples; i++) { int actual = (int) realOutcomeIndex.getDouble(i); int predicted = (int) guessIndex.getDouble(i); sb.append( vectorize.getTagset()[actual] + "\t" + vectorize.getTagset()[predicted] + "\n"); } }
INDArray currentBatch = predictions.getRow(batch).dup(); while (i < 5) { top5[i] = Nd4j.argMax(currentBatch, 1).getInt(0, 0); top5Prob[i] = currentBatch.getFloat(batch, top5[i]); currentBatch.putScalar(0, top5[i], 0);
/** * Given predictions from the trained model this method will return a string * listing the top five matches and the respective probabilities * @param predictions * @return */ public String decodePredictions(INDArray predictions) { String predictionDescription = ""; int[] top5 = new int[5]; float[] top5Prob = new float[5]; //brute force collect top 5 int i = 0; for (int batch = 0; batch < predictions.size(0); batch++) { predictionDescription += "Predictions for batch "; if (predictions.size(0) > 1) { predictionDescription += String.valueOf(batch); } predictionDescription += " :"; INDArray currentBatch = predictions.getRow(batch).dup(); while (i < 5) { top5[i] = Nd4j.argMax(currentBatch, 1).getInt(0, 0); top5Prob[i] = currentBatch.getFloat(batch, top5[i]); currentBatch.putScalar(0, top5[i], 0); predictionDescription += "\n\t" + String.format("%3f", top5Prob[i] * 100) + "%, " + predictionLabels.get(top5[i]); i++; } } return predictionDescription; }
} else if (costArray != null) { guessIndex = Nd4j.argMax(guesses.mulRowVector(costArray), 1); } else { guessIndex = Nd4j.argMax(guesses, 1); INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1); int nExamples = guessIndex.length(); INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1); int nExamples = realOutcomeIndex.length(); for (int i = 0; i < nExamples; i++) {
INDArray row = Nd4j.linspace(0, maxTsLength - 1, maxTsLength); INDArray temp = mask.mulRowVector(row); INDArray lastElementIdx = Nd4j.argMax(temp, 1); fwdPassTimeSteps = new int[fwdPassShape[0]]; for (int i = 0; i < fwdPassTimeSteps.length; i++) {
INDArray lastTimeStepIndices; if (labelsMask != null){ lastTimeStepIndices = Nd4j.argMax(labelsMask, 1); } else { lastTimeStepIndices = Nd4j.zeros(features.size(0), 1);