/**
 * Creates a fresh, untrained classifier configured like this one.
 *
 * Only the input label and the target label are carried over to the new
 * instance; no other state is copied by this method.
 *
 * @return a new {@code NaiveBayesClassifier} with this instance's input and
 *         target labels
 */
public Classifier emptyCopy() {
    final NaiveBayesClassifier copy = new NaiveBayesClassifier();
    copy.setInputLabel(getInputLabel());
    copy.setTargetLabel(getTargetLabel());
    return copy;
}
}
/**
 * Produces a result map containing a newly constructed classifier.
 *
 * The {@code input} map is not read by this method.
 *
 * @param input
 *            incoming objects (unused here)
 * @return a new {@code HashMap} with a single entry: a new
 *         {@code NaiveBayesClassifier} stored under the {@code TARGET} key
 */
public Map<String, Object> calculateObjects(Map<String, Object> input) {
    final Map<String, Object> output = new HashMap<String, Object>();
    final NaiveBayesClassifier classifier = new NaiveBayesClassifier();
    output.put(TARGET, classifier);
    return output;
}
public static void main(String[] args) { // load example data set ListDataSet dataSet = DataSet.Factory.IRIS(); // create a classifier NaiveBayesClassifier classifier = new NaiveBayesClassifier(); // train the classifier using all data classifier.trainAll(dataSet); // use the classifier to make predictions classifier.predictAll(dataSet); // get the results double accurary = dataSet.getAccuracy(); System.out.println("accuracy: " + accurary); } }
public void trainAll(ListDataSet dataSet) { System.out.println("training started"); int featureCount = (int) dataSet.get(0).getAsMatrix(getInputLabel()).getValueCount(); boolean discrete = isDiscrete(dataSet); classCount = getClassCount(dataSet); for (Sample s : dataSet) { final Matrix sampleInput = s.getAsMatrix(getInputLabel()).toColumnVector(Ret.LINK); final Matrix sampleTarget = s.getAsMatrix(getTargetLabel()).toColumnVector(Ret.LINK); final double weight = s.getWeight();
/**
 * Trains a new {@code NaiveBayesClassifier} on the core object (cast to
 * {@code ListDataSet}) and then runs prediction on it.
 *
 * Failures are handled best-effort: any {@code Exception} raised while
 * creating, training or predicting is caught and its stack trace printed;
 * it is not propagated to the caller.
 *
 * @return always {@code null}
 */
public Object call() {
    try {
        final Classifier classifier = new NaiveBayesClassifier();
        // getCoreObject() is queried separately for each step, as before.
        classifier.trainAll((ListDataSet) getCoreObject());
        classifier.predictAll((ListDataSet) getCoreObject());
    } catch (Exception e) {
        e.printStackTrace();
    }
    return null;
}
}
/**
 * Cross-validates a {@code NaiveBayesClassifier} on the IRIS data set and
 * checks that the mean of the returned results is within 0.04 of 0.959.
 *
 * NOTE(review): the numeric arguments 10, 10, 0 passed to
 * {@code CrossValidation.run} are presumably runs, folds and a random seed;
 * confirm against that method's signature.
 *
 * @throws Exception
 *             if cross validation fails
 */
@Test
public void testIrisClassification() throws Exception {
    final ListDataSet dataSet = ListDataSet.Factory.IRIS();
    final Classifier classifier = new NaiveBayesClassifier();

    final ListMatrix<Double> scores = CrossValidation.run(classifier, dataSet, 10, 10, 0);

    final double meanScore = scores.getMeanValue();
    assertEquals(0.959, meanScore, 0.04);
}