/**
 * Encodes the given text into a sparse feature vector (plus a bias term) and
 * returns the model's full per-category scores as a list.
 *
 * @param text raw input text to classify; lower-cased before encoding
 * @return one score per category, in category order
 * @throws TException declared for the Thrift service contract
 */
@Override
public List<Double> classify(String text) throws TException {
  // Build the sparse feature vector from the lower-cased text.
  // NOTE(review): toLowerCase() is locale-sensitive; consider Locale.ROOT — confirm intent.
  Vector featureVector = new RandomAccessSparseVector(FEATURES);
  enc.addText(text.toLowerCase());
  enc.flush(1, featureVector);
  bias.addToVector((byte[]) null, 1, featureVector);

  // Score every category and copy the results into a plain list.
  Vector scores = model.classifyFull(featureVector);
  List<Double> result = Lists.newArrayList();
  for (int category = 0; category < scores.size(); category++) {
    result.add(scores.get(category));
  }
  return result;
}
return classifyFull(new DenseVector(numCategories()), instance);
return classifyFull(new DenseVector(numCategories()), instance);
return classifyFull(new DenseVector(numCategories()), instance);
/**
 * Classifies every row of {@code data}, producing one row of category
 * probabilities per input row.
 *
 * @param data matrix whose rows are the instances to classify
 * @return a matrix with one row per input row and one column per category
 */
public Matrix classifyFull(Matrix data) {
  int rows = data.numRows();
  Matrix scores = new DenseMatrix(rows, numCategories());
  for (int i = 0; i < rows; i++) {
    classifyFull(scores.viewRow(i), data.viewRow(i));
  }
  return scores;
}
/**
 * Applies {@link #classifyFull(Vector, Vector)} to each row of the input,
 * collecting the per-category probabilities into a result matrix.
 *
 * @param data matrix whose rows are the input vectors to classify
 * @return a score matrix: one row per input row, one column per category
 */
public Matrix classifyFull(Matrix data) {
  Matrix result = new DenseMatrix(data.numRows(), numCategories());
  for (int rowIndex = 0; rowIndex < data.numRows(); rowIndex++) {
    classifyFull(result.viewRow(rowIndex), data.viewRow(rowIndex));
  }
  return result;
}
/**
 * Batch variant of row classification: each input row is classified in place
 * into the corresponding row of a freshly allocated dense result matrix.
 *
 * @param data matrix whose rows are the instances to classify
 * @return matrix of category probabilities, aligned row-for-row with {@code data}
 */
public Matrix classifyFull(Matrix data) {
  int n = data.numRows();
  Matrix out = new DenseMatrix(n, numCategories());
  for (int row = 0; row < n; row++) {
    Vector target = out.viewRow(row);
    classifyFull(target, data.viewRow(row));
  }
  return out;
}
@Test public void toyData() throws Exception { TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob(); trainNaiveBayes.setConf(conf); trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), "-el", "--tempDir", tempDir.getAbsolutePath() }); NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf); AbstractVectorClassifier classifier = new StandardNaiveBayesClassifier(naiveBayesModel); assertEquals(2, classifier.numCategories()); Vector prediction = classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get()); // should be classified as not stolen assertTrue(prediction.get(0) < prediction.get(1)); }
@Test public void toyDataComplementary() throws Exception { TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob(); trainNaiveBayes.setConf(conf); trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), "-el", "--trainComplementary", "--tempDir", tempDir.getAbsolutePath() }); NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf); AbstractVectorClassifier classifier = new ComplementaryNaiveBayesClassifier(naiveBayesModel); assertEquals(2, classifier.numCategories()); Vector prediction = classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get()); // should be classified as not stolen assertTrue(prediction.get(0) < prediction.get(1)); }
static void test(Matrix input, Vector target, AbstractVectorClassifier lr, double expected_mean_error, double expected_absolute_error) { // now test the accuracy Matrix tmp = lr.classify(input); // mean(abs(tmp - target)) double meanAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60; // max(abs(tmp - target) double maxAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS); System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError); assertEquals(0, meanAbsoluteError , expected_mean_error); assertEquals(0, maxAbsoluteError, expected_absolute_error); // convenience methods should give the same results Vector v = lr.classifyScalar(input); assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-5); v = lr.classifyFull(input).viewColumn(1); assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-4); }