/**
 * Reads an image from {@code stream}, runs it through the model and converts the first
 * model output into recognised objects via {@code predict(...)}.
 *
 * <p>Steps: decode the stream into a matrix with {@code imageLoader.asMatrix}, apply
 * {@code preProcessor.transform} to that matrix (its return value is ignored, so it
 * presumably transforms in place — TODO confirm), call {@code model.output(false, image)}
 * (NOTE(review): the {@code false} flag is assumed to select inference rather than training
 * mode — confirm against the model API), and hand {@code output[0]} to {@code predict}.
 * The {@code handler}, {@code metadata} and {@code context} arguments are not used here.
 *
 * @param stream  image bytes to recognise
 * @param handler unused by this implementation
 * @param metadata unused by this implementation
 * @param context unused by this implementation
 * @return whatever {@code predict(output[0])} returns
 * @throws IOException   as declared; may be raised while loading the image
 * @throws SAXException  as declared by the overridden signature
 * @throws TikaException as declared by the overridden signature
 */
@Override public List<RecognisedObject> recognise(InputStream stream, ContentHandler handler, Metadata metadata, ParseContext context) throws IOException, SAXException, TikaException { INDArray image = imageLoader.asMatrix(stream); preProcessor.transform(image); INDArray[] output = model.output(false, image); return predict(output[0]); } private List<RecognisedObject> predict(INDArray predictions)
/**
 * Creates a record reader for images sized by this loader's {@code height}, {@code width}
 * and {@code channels} fields, with the label count chosen by {@code useSubset}.
 *
 * <p>Note that the full overload receives {@code numExamples} before {@code batchSize}.
 *
 * @param batchSize      batch size forwarded to the full overload
 * @param numExamples    number of examples forwarded to the full overload
 * @param labelGenerator generator used to derive labels from paths
 * @param train          forwarded train flag
 * @param splitTrainTest forwarded train/test split ratio
 * @param rng            random source forwarded to the full overload
 * @return the reader built by the full overload
 */
public RecordReader getRecordReader(int batchSize, int numExamples, PathLabelGenerator labelGenerator,
        boolean train, double splitTrainTest, Random rng) {
    int[] dims = {height, width, channels};
    return getRecordReader(numExamples, batchSize, dims,
            useSubset ? SUB_NUM_LABELS : NUM_LABELS,
            labelGenerator, train, splitTrainTest, rng);
}
/**
 * Clears the batch counter, example counter and overshoot flag, then resets the
 * underlying {@code loader}.
 */
@Override
public void reset() {
    this.batchNum = 0;
    this.exampleCount = 0;
    this.overshot = false;
    // Counters are cleared first; the loader is reset last, as in the original ordering.
    this.loader.reset();
}
@Override public void initialize(Map<String, Param> params) throws TikaConfigException { //STEP 1: resolve weights file, download if necessary modelWeightsPath = mayBeDownloadFile(modelWeightsPath); //STEP 2: Load labels map try (InputStream stream = retrieveResource(mayBeDownloadFile(labelFile))) { this.labelMap = loadClassIndex(stream); } catch (IOException | ParseException e) { LOG.error("Could not load labels map", e); return; } //STEP 3: initialize the graph try { this.imageLoader = new NativeImageLoader(imgHeight, imgWidth, imgChannels); LOG.info("Going to load Inception network..."); long st = System.currentTimeMillis(); KerasModelBuilder builder = new KerasModel().modelBuilder().modelHdf5Filename(modelWeightsPath) .enforceTrainingConfig(false); builder.inputShape(new int[]{imgHeight, imgWidth, 3}); KerasModel model = builder.buildModel(); this.graph = model.getComputationGraph(); long time = System.currentTimeMillis() - st; LOG.info("Loaded the Inception model. Time taken={}ms", time); } catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) { throw new TikaConfigException(e.getMessage(), e); } }
/**
 * Convenience overload: fetches a batch by forwarding to {@code next(batchSize, 0)}.
 *
 * @param batchSize requested batch size
 * @return the data set produced by the two-argument overload
 */
public DataSet next(int batchSize) {
    return this.next(batchSize, 0);
}
/**
 * Exposes the label list held by the underlying {@code loader}.
 *
 * @return the loader's labels, as returned by {@code loader.getLabels()}
 */
@Override
public List<String> getLabels() {
    List<String> labels = loader.getLabels();
    return labels;
}
/**
 * Converts {@code image} into a matrix; equivalent to calling
 * {@code asMatrix(image, false)}.
 *
 * @param image the image to convert
 * @return the matrix produced by the two-argument overload
 * @throws IOException if the two-argument overload fails
 */
public INDArray asMatrix(BufferedImage image) throws IOException {
    return this.asMatrix(image, false);
}
/**
 * Converts {@code image} when it is a {@code Bitmap}; any other type (including
 * {@code null}) yields {@code null}.
 *
 * @param image candidate image object
 * @return the matrix for a {@code Bitmap}, otherwise {@code null}
 * @throws IOException if the {@code Bitmap} conversion fails
 */
@Override
public INDArray asMatrix(Object image) throws IOException {
    if (!(image instanceof Bitmap)) {
        return null;
    }
    Bitmap bitmap = (Bitmap) image;
    return asMatrix(bitmap);
}
/**
 * Produces a row vector when {@code image} is a {@code BufferedImage}; any other type
 * (including {@code null}) yields {@code null}.
 *
 * @param image candidate image object
 * @return the row vector for a {@code BufferedImage}, otherwise {@code null}
 * @throws IOException if the {@code BufferedImage} conversion fails
 */
@Override
public INDArray asRowVector(Object image) throws IOException {
    if (!(image instanceof BufferedImage)) {
        return null;
    }
    BufferedImage buffered = (BufferedImage) image;
    return asRowVector(buffered);
}
/**
 * Zeroes the file and example counters, then calls {@code load()}.
 */
public void reset() {
    fileNum = 0;
    numExamples = 0;
    // Reload only after both counters are back at zero.
    this.load();
}
/**
 * Produces a row vector when {@code image} is a {@code Bitmap}; any other type
 * (including {@code null}) yields {@code null}.
 *
 * @param image candidate image object
 * @return the row vector for a {@code Bitmap}, otherwise {@code null}
 * @throws IOException if the {@code Bitmap} conversion fails
 */
@Override
public INDArray asRowVector(Object image) throws IOException {
    if (!(image instanceof Bitmap)) {
        return null;
    }
    Bitmap bitmap = (Bitmap) image;
    return asRowVector(bitmap);
}
/**
 * Delegates to {@code transformImage(image, view)}; the caller-supplied {@code view}
 * is presumably filled in place (TODO confirm against {@code transformImage}).
 *
 * @param image source image
 * @param view  destination array handed to {@code transformImage}
 * @throws IOException if {@code transformImage} fails
 */
public void asMatrixView(Mat image, INDArray view) throws IOException {
    transformImage(image, view);
}
/**
 * Shorthand for the four-argument {@code scalingIfNeed} that uses this loader's
 * {@code height} and {@code width} fields as the target size.
 *
 * @param image     image to (possibly) scale
 * @param needAlpha forwarded alpha flag
 * @return the image returned by the four-argument overload
 */
protected BufferedImage scalingIfNeed(BufferedImage image, boolean needAlpha) {
    return scalingIfNeed(image,
            height,
            width,
            needAlpha);
}
/**
 * Classifies the image read from {@code stream} with the Inception graph and returns one
 * {@link RecognisedObject} for every class index whose score is strictly greater than
 * {@code minConfidence}.
 *
 * <p>Each result carries the label from {@code labelMap} for that index, the language
 * {@code labelLang}, the index itself as its id, and the score as its confidence.
 *
 * @param stream   image bytes to classify
 * @param handler  unused by this implementation
 * @param metadata unused by this implementation
 * @param context  unused by this implementation
 * @return the recognised objects above the confidence threshold; empty if none
 * @throws IOException   as declared; may be raised while loading the image
 * @throws SAXException  as declared by the overridden signature
 * @throws TikaException as declared by the overridden signature
 */
@Override
public List<RecognisedObject> recognise(
        InputStream stream, ContentHandler handler, Metadata metadata, ParseContext context)
        throws IOException, SAXException, TikaException {
    INDArray image = preProcessImage(imageLoader.asMatrix(stream));
    INDArray scores = graph.outputSingle(image);
    List<RecognisedObject> result = new ArrayList<>();
    for (int i = 0; i < scores.length(); i++) {
        // Read each score once instead of twice per matching index.
        double score = scores.getDouble(i);
        if (score > minConfidence) {
            String label = labelMap.get(i);
            result.add(new RecognisedObject(label, labelLang, String.valueOf(i), score));
            LOG.debug("Found Object {}", label);
        }
    }
    return result;
}
}
/**
 * Creates a record reader for images of shape {@code imgDim}, with the label count
 * chosen by {@code useSubset}.
 *
 * <p>The full overload receives {@code numExamples} before {@code batchSize}.
 *
 * @param batchSize      batch size forwarded to the full overload
 * @param numExamples    number of examples forwarded to the full overload
 * @param imgDim         image dimensions forwarded unchanged
 * @param labelGenerator generator used to derive labels from paths
 * @param train          forwarded train flag
 * @param splitTrainTest forwarded train/test split ratio
 * @param rng            random source forwarded to the full overload
 * @return the reader built by the full overload
 */
public RecordReader getRecordReader(int batchSize, int numExamples, int[] imgDim,
        PathLabelGenerator labelGenerator, boolean train, double splitTrainTest, Random rng) {
    return this.getRecordReader(
            numExamples,
            batchSize,
            imgDim,
            useSubset ? SUB_NUM_LABELS : NUM_LABELS,
            labelGenerator,
            train,
            splitTrainTest,
            rng);
}
/**
 * Clears the batch counter, example counter and overshoot flag, then resets the
 * underlying {@code loader}.
 */
@Override
public void reset() {
    this.batchNum = 0;
    this.exampleCount = 0;
    this.overshot = false;
    // Counters are cleared first; the loader is reset last, as in the original ordering.
    this.loader.reset();
}
/**
 * Exposes the label list held by the underlying {@code loader}.
 *
 * @return the loader's labels, as returned by {@code loader.getLabels()}
 */
@Override
public List<String> getLabels() {
    List<String> labels = loader.getLabels();
    return labels;
}
/**
 * Converts {@code image} when it is a {@code BufferedImage}; any other type (including
 * {@code null}) yields {@code null}.
 *
 * @param image candidate image object
 * @return the matrix for a {@code BufferedImage}, otherwise {@code null}
 * @throws IOException if the {@code BufferedImage} conversion fails
 */
@Override
public INDArray asMatrix(Object image) throws IOException {
    if (!(image instanceof BufferedImage)) {
        return null;
    }
    BufferedImage buffered = (BufferedImage) image;
    return asMatrix(buffered);
}
/**
 * Creates a record reader for images sized by this loader's {@code height}, {@code width}
 * and {@code channels} fields, using {@code LABEL_PATTERN} for labels, a train flag of
 * {@code true} and a split ratio of {@code 1}.
 *
 * @param batchSize   batch size forwarded to the full overload
 * @param numExamples number of examples forwarded to the full overload
 * @param numLabels   label count forwarded unchanged
 * @param rng         random source forwarded to the full overload
 * @return the reader built by the full overload
 */
public RecordReader getRecordReader(int batchSize, int numExamples, int numLabels, Random rng) {
    int[] shape = {height, width, channels};
    // The full overload takes numExamples before batchSize.
    return getRecordReader(numExamples, batchSize, shape, numLabels, LABEL_PATTERN, true, 1, rng);
}
/**
 * Creates a record reader for images of shape {@code imgDim}, using
 * {@code LABEL_PATTERN} for labels and the label count chosen by {@code useSubset}.
 *
 * <p>The full overload receives {@code numExamples} before {@code batchSize}.
 *
 * @param batchSize      batch size forwarded to the full overload
 * @param numExamples    number of examples forwarded to the full overload
 * @param imgDim         image dimensions forwarded unchanged
 * @param train          forwarded train flag
 * @param splitTrainTest forwarded train/test split ratio
 * @param rng            random source forwarded to the full overload
 * @return the reader built by the full overload
 */
public RecordReader getRecordReader(int batchSize, int numExamples, int[] imgDim, boolean train,
        double splitTrainTest, Random rng) {
    return this.getRecordReader(
            numExamples,
            batchSize,
            imgDim,
            useSubset ? SUB_NUM_LABELS : NUM_LABELS,
            LABEL_PATTERN,
            train,
            splitTrainTest,
            rng);
}