/**
 * Predicts a label for a single example.
 *
 * <p>The feature map is converted to a vector by {@code encoder}, scored with
 * {@code learn.classifyFull}, and the index of the largest score is translated back into a
 * label by {@code encoder.outputIntToString}.
 *
 * @param features feature values keyed by String (presumably feature names, as interpreted by
 *     {@code encoder})
 * @return the label corresponding to the highest-scoring output index
 */
public String predict(Map<String, Object> features) {
  final Vector encoded = encoder.getVector(features);
  final int bestIndex = learn.classifyFull(encoded).maxValueIndex();
  return encoder.outputIntToString(bestIndex);
}
/**
 * Uses one example to evaluate a single model and to train all the others.
 *
 * <p>The model at position {@code mod(trackingKey, models.size())} is not trained on the
 * example. Instead it scores the example, and its output updates the running log-likelihood,
 * the running accuracy and, for two-category problems, the AUC tracker. Every other model is
 * trained on the example.
 *
 * <p>The running statistics use {@code stat += (sample - stat) / min(record, windowSize)}.
 * This is a cumulative mean until {@code record} reaches {@code windowSize*}. After that it
 * becomes an exponentially weighted average with weight {@code 1 / windowSize}.
 *
 * @param trackingKey key used to choose which model evaluates (rather than trains on) this example
 * @param groupKey grouping key passed to the trained models and to the AUC tracker
 * @param actual index of the true category
 * @param instance feature vector of the example
 */
@Override
public void train(long trackingKey, String groupKey, int actual, Vector instance) {
  record++;
  int position = 0;
  for (OnlineLogisticRegression foldModel : models) {
    boolean evaluateOnly = position == mod(trackingKey, models.size());
    position++;
    if (!evaluateOnly) {
      foldModel.train(trackingKey, groupKey, actual, instance);
      continue;
    }
    Vector scores = foldModel.classifyFull(instance);
    // Floor the probability of the true class at MIN_SCORE (presumably positive) so the log stays finite.
    double probability = Math.max(scores.get(actual), MIN_SCORE);
    logLikelihood += (Math.log(probability) - logLikelihood) / Math.min(record, windowSize);
    int hit = (scores.maxValueIndex() == actual) ? 1 : 0;
    percentCorrect += (hit - percentCorrect) / Math.min(record, windowSize);
    if (numCategories() == 2) {
      // Binary case: feed the score for category 1 to the AUC tracker.
      auc.addSample(actual, groupKey, scores.get(1));
    }
  }
}
/**
 * Uses one example to evaluate a single model and to train all the others.
 *
 * <p>The model at position {@code mod(trackingKey, models.size())} is not trained on the
 * example. Instead it scores the example, and its output updates the running log-likelihood,
 * the running accuracy and, for two-category problems, the AUC tracker. Every other model is
 * trained on the example.
 *
 * <p>The running statistics use {@code stat += (sample - stat) / min(record, windowSize)}.
 * This is a cumulative mean until {@code record} reaches {@code windowSize}. After that it
 * becomes an exponentially weighted average with weight {@code 1 / windowSize}.
 *
 * @param trackingKey key used to choose which model evaluates (rather than trains on) this example
 * @param groupKey grouping key passed to the trained models and to the AUC tracker
 * @param actual index of the true category
 * @param instance feature vector of the example
 */
@Override
public void train(long trackingKey, String groupKey, int actual, Vector instance) {
  record++;
  int position = 0;
  for (OnlineLogisticRegression foldModel : models) {
    boolean evaluateOnly = position == mod(trackingKey, models.size());
    position++;
    if (!evaluateOnly) {
      foldModel.train(trackingKey, groupKey, actual, instance);
      continue;
    }
    Vector scores = foldModel.classifyFull(instance);
    // Floor the probability of the true class at MIN_SCORE (presumably positive) so the log stays finite.
    double probability = Math.max(scores.get(actual), MIN_SCORE);
    logLikelihood += (Math.log(probability) - logLikelihood) / Math.min(record, windowSize);
    int hit = (scores.maxValueIndex() == actual) ? 1 : 0;
    percentCorrect += (hit - percentCorrect) / Math.min(record, windowSize);
    if (numCategories() == 2) {
      // Binary case: feed the score for category 1 to the AUC tracker.
      auc.addSample(actual, groupKey, scores.get(1));
    }
  }
}
/**
 * Uses one example to evaluate a single model and to train all the others.
 *
 * <p>The model at position {@code mod(trackingKey, models.size())} is not trained on the
 * example. Instead it scores the example, and its output updates the running log-likelihood,
 * the running accuracy and, for two-category problems, the AUC tracker. Every other model is
 * trained on the example.
 *
 * <p>The running statistics use {@code stat += (sample - stat) / min(record, windowSize)}.
 * This is a cumulative mean until {@code record} reaches {@code windowSize}. After that it
 * becomes an exponentially weighted average with weight {@code 1 / windowSize}.
 *
 * @param trackingKey key used to choose which model evaluates (rather than trains on) this example
 * @param groupKey grouping key passed to the trained models and to the AUC tracker
 * @param actual index of the true category
 * @param instance feature vector of the example
 */
@Override
public void train(long trackingKey, String groupKey, int actual, Vector instance) {
  record++;
  int position = 0;
  for (OnlineLogisticRegression foldModel : models) {
    boolean evaluateOnly = position == mod(trackingKey, models.size());
    position++;
    if (!evaluateOnly) {
      foldModel.train(trackingKey, groupKey, actual, instance);
      continue;
    }
    Vector scores = foldModel.classifyFull(instance);
    // Floor the probability of the true class at MIN_SCORE (presumably positive) so the log stays finite.
    double probability = Math.max(scores.get(actual), MIN_SCORE);
    logLikelihood += (Math.log(probability) - logLikelihood) / Math.min(record, windowSize);
    int hit = (scores.maxValueIndex() == actual) ? 1 : 0;
    percentCorrect += (hit - percentCorrect) / Math.min(record, windowSize);
    if (numCategories() == 2) {
      // Binary case: feed the score for category 1 to the AUC tracker.
      auc.addSample(actual, groupKey, scores.get(1));
    }
  }
}
// Set element 2 of v to 1, score v with learningAlgo.classifyFull, and print the score vector.
// NOTE(review): output goes to System.out; if this is more than a throwaway demo, consider a logger.
v.set(2, 1); Vector r = learningAlgo.classifyFull(v); System.out.println(r);
// Take the index of the highest-scoring output of the wrapped model and return the category at that index.
// NOTE(review): assumes model.getCategories() is ordered to match the model's output indices — confirm.
final int target = model.getModel().classifyFull(instance).maxValueIndex(); return model.getCategories().get(target);
// For each k in test, classify data.get(k), tally the predicted index in count, and add 1 to x
// when the prediction equals target.get(k). The loop body continues beyond this line.
// NOTE(review): count has a fixed length of 3, so a predicted index >= 3 would throw
// ArrayIndexOutOfBoundsException. This assumes the classifier has at most 3 categories — confirm.
int[] count = new int[3]; for (Integer k : test) { int r = lr.classifyFull(data.get(k)).maxValueIndex(); count[r]++; x += r == target.get(k) ? 1 : 0;
// Checks lr.classifyFull on small 2-element inputs. Each returned vector must sum to 1 (zSum),
// and individual entries must match expected probabilities of the form exp(s_i) / sum_j exp(s_j).
// NOTE(review): the two classifyFull calls on {1, 1} expect different values for v.get(1):
// exp(0)/(1+exp(0)+exp(-2)) and exp(0)/(1+exp(0)+exp(1)). Both can pass only if lr's coefficients
// change between the calls, which is not visible here — confirm.
// NOTE(review): tolerances for the same kind of quantity vary between 1e-8 and 1e-3 — confirm this is intended.
assertEquals(1 / 3.0, v.get(1), 1.0e-8); v = lr.classifyFull(new DenseVector(new double[]{0, 0})); assertEquals(1.0, v.zSum(), 1.0e-8); assertEquals(1 / 3.0, v.get(0), 1.0e-8); assertEquals(1 / 3.0, v.get(1), 1.0e-3); v = lr.classifyFull(new DenseVector(new double[]{0, 1})); assertEquals(1.0, v.zSum(), 1.0e-8); assertEquals(1 / 3.0, v.get(0), 1.0e-3); assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8); v = lr.classifyFull(new DenseVector(new double[]{1, 0})); assertEquals(1.0, v.zSum(), 1.0e-8); assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8); v = lr.classifyFull(new DenseVector(new double[]{1, 1})); assertEquals(1.0, v.zSum(), 1.0e-8); assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1.0e-3); v = lr.classifyFull(new DenseVector(new double[]{1, 1})); assertEquals(1.0, v.zSum(), 1.0e-8); assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1.0e-8);