/**
 * Returns an independent copy of this model with identical coefficients.
 * Pending regularization is flushed first so the cloned beta matrix is current.
 */
public OnlineLogisticRegression copy() {
  close();
  OnlineLogisticRegression duplicate =
      new OnlineLogisticRegression(numCategories(), numFeatures(), prior);
  duplicate.copyFrom(this);
  return duplicate;
}
/**
 * Sets the regularization weight on every per-fold model.
 *
 * @param v the new lambda value
 * @return this, for call chaining
 */
public CrossFoldLearner lambda(double v) {
  for (OnlineLogisticRegression member : models) {
    member.lambda(v);
  }
  return this;
}
/**
 * Sets the learning rate on every per-fold model.
 *
 * @param x the new learning rate
 * @return this, for call chaining
 */
public CrossFoldLearner learningRate(double x) {
  for (OnlineLogisticRegression member : models) {
    member.learningRate(x);
  }
  return this;
}
/**
 * Builds a cross-fold learner holding one logistic-regression model per fold.
 * Each member is configured with alpha = 1, step offset = 0 and decay exponent = 0,
 * i.e. its learning rate does not decay over time.
 *
 * @param folds         number of folds (and therefore member models)
 * @param numCategories number of target categories
 * @param numFeatures   dimension of the feature vectors
 * @param prior         prior used to regularize each member model
 */
public CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
  this.numFeatures = numFeatures;
  this.prior = prior;
  for (int fold = 0; fold < folds; fold++) {
    OnlineLogisticRegression member =
        new OnlineLogisticRegression(numCategories, numFeatures, prior);
    member.alpha(1).stepOffset(0).decayExponent(0);
    models.add(member);
  }
}
// NOTE(review): fragment — 'points', 'point' and 'v' are defined outside this excerpt.
OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression();
// The default-constructed instance above is discarded immediately; only the
// 2-category, 3-feature, L1-regularized model below is used.
learningAlgo = new OnlineLogisticRegression(2, 3, new L1());
learningAlgo.lambda(0.1);
learningAlgo.learningRate(10);
learningAlgo.train(points.get(point), v);
// close() flushes any pending regularization before the model is queried.
learningAlgo.close();
v.set(2, 1);
// classifyFull returns one probability per category.
Vector r = learningAlgo.classifyFull(v);
System.out.println(r);
System.out.println("no of categories = " + learningAlgo.numCategories());
System.out.println("no of features = " + learningAlgo.numFeatures());
System.out.println("Probability of cluster 0 = " + r.get(0));
System.out.println("Probability of cluster 1 = " + r.get(1));
// NOTE(review): fragment — 'read' and the surrounding test scaffolding are outside
// this excerpt; presumably 'read' is 'lr' after a serialization round trip.
OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
    .lambda(1 * 1.0e-3)
    .stepOffset(11)
    .alpha(0.01)
    .learningRate(50)
    .decayExponent(-0.02);
lr.close();
// Check that the lambda hyper-parameter survived the round trip.
Assert.assertEquals((1.0e-3), read.getLambda(), 1.0e-7);
// NOTE(review): fragment — the throw below belongs to a validation branch whose
// condition (and closing brace) lies outside this excerpt.
throw new BadClassifierSpecException("Must have more than one target category. Remember that categories is a space separated list");
// Build the model from the parsed option map, removing each option as it is
// consumed — presumably so later code can flag unrecognized leftovers; TODO confirm.
model = new OnlineLogisticRegression(categories.size(), Integer.parseInt(options.get("features")), new L1());
options.remove("categories");
options.remove("features");
model.decayExponent(Double.parseDouble(options.get("decayExponent")));
options.remove("decayExponent");
model.lambda(Double.parseDouble(options.get("lambda")));
options.remove("lambda");
model.stepOffset(Integer.parseInt(options.get("stepOffset")));
options.remove("stepOffset");
model.learningRate(Double.parseDouble(options.get("learningRate")));
options.remove("learningRate");
// NOTE(review): fragment — 'point', 'points' and 'learningAlgo' are defined
// outside this excerpt.
Vector v = getVector(point);
System.out.println(point + " belongs to " + points.get(point));
learningAlgo.train(points.get(point), v);
// Flush pending regularization before classification.
learningAlgo.close();
v.set(2, 1);
// The code below treats r.get(0) as p(category 1) of a binary model, so
// p(category 0) is computed as its complement.
Vector r = learningAlgo.classify(v);
System.out.println(r);
System.out.println("no of categories = " + learningAlgo.numCategories());
System.out.println("no of features = " + learningAlgo.numFeatures());
System.out.println("Probability of cluster 0 = " + (1.0d - r.get(0)));
System.out.println("Probability of cluster 1 = " + r.get(0));
@Test public void testTrain() throws Exception { Vector target = readStandardData(); // lambda here needs to be relatively small to avoid swamping the actual signal, but can be // larger than usual because the data are dense. The learning rate doesn't matter too much // for this example, but should generally be < 1 // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias // --target y --categories 2 --predictors V2 V3 V4 V5 V6 V7 --types n OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1()) .lambda(1 * 1.0e-3) .learningRate(50); train(getInput(), target, lr); test(getInput(), target, lr, 0.05, 0.3); }
/**
 * Trains on one example using cross-fold validation bookkeeping: the example is
 * held out of exactly one fold (chosen by its tracking key) which scores it for
 * the running statistics, while every other fold trains on it.
 *
 * @param trackingKey stable id of the example; determines which fold holds it out
 * @param groupKey    grouping key forwarded to the AUC tracker
 * @param actual      index of the true category
 * @param instance    encoded feature vector
 */
@Override
public void train(long trackingKey, String groupKey, int actual, Vector instance) {
  record++;
  // Hoisted out of the loop: the held-out fold index is loop-invariant.
  int heldOutFold = mod(trackingKey, models.size());
  int k = 0;
  for (OnlineLogisticRegression model : models) {
    if (k == heldOutFold) {
      // Score with the one model that never trains on this example.
      Vector v = model.classifyFull(instance);
      double score = Math.max(v.get(actual), MIN_SCORE);
      // Running averages over a window of at most windowSize records.
      logLikelihood += (Math.log(score) - logLikelihood) / Math.min(record, windowSize);
      int correct = v.maxValueIndex() == actual ? 1 : 0;
      percentCorrect += (correct - percentCorrect) / Math.min(record, windowSize);
      if (numCategories() == 2) {
        // For binary problems, track AUC using p(category 1).
        auc.addSample(actual, groupKey, v.get(1));
      }
    } else {
      model.train(trackingKey, groupKey, actual, instance);
    }
    k++;
  }
}
// NOTE(review): fragment — the enclosing training loop (defining the first 'k'),
// plus 'target', 'data', 'test' and 'x', are outside this excerpt; the for loop
// below is also not closed here.
OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
lr.train(target.get(k), data.get(k));
int[] count = new int[3];
// Tally the predicted class of each held-out example and accumulate the number
// of correct predictions in x.
for (Integer k : test) {
  int r = lr.classifyFull(data.get(k)).maxValueIndex();
  count[r]++;
  x += r == target.get(k) ? 1 : 0;
// NOTE(review): fragment of a test method — assertEquals comes from the enclosing
// test class. Verifies classify()/classifyFull() on a 3-category, 2-feature model.
OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1));
// Coefficients chosen so feature 0 discriminates: beta[0][0] = -1, beta[1][0] = -2.
lr.setBeta(0, 0, -1);
lr.setBeta(1, 0, -2);
// Zero input: all linear scores are 0, so every category gets probability 1/3.
Vector v = lr.classify(new DenseVector(new double[]{0, 0}));
assertEquals(1 / 3.0, v.get(0), 1.0e-8);
assertEquals(1 / 3.0, v.get(1), 1.0e-8);
// classifyFull must return a proper distribution (sums to 1).
v = lr.classifyFull(new DenseVector(new double[]{0, 0}));
assertEquals(1.0, v.zSum(), 1.0e-8);
assertEquals(1 / 3.0, v.get(0), 1.0e-8);
// Feature 1 has zero coefficients, so it should not move the probabilities
// (looser 1e-3 tolerance here).
v = lr.classify(new DenseVector(new double[]{0, 1}));
assertEquals(1 / 3.0, v.get(0), 1.0e-3);
assertEquals(1 / 3.0, v.get(1), 1.0e-3);
v = lr.classifyFull(new DenseVector(new double[]{0, 1}));
assertEquals(1.0, v.zSum(), 1.0e-8);
assertEquals(1 / 3.0, v.get(0), 1.0e-3);
// Feature 0 set: expected softmax values computed from the betas above.
v = lr.classify(new DenseVector(new double[]{1, 0}));
assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
v = lr.classifyFull(new DenseVector(new double[]{1, 0}));
assertEquals(1.0, v.zSum(), 1.0e-8);
assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
/**
 * Round-trips a trained model through serialization and verifies the coefficients
 * survive, both immediately and after identical further training of each copy.
 */
@Test
public void onlineLogisticRegressionRoundTrip() throws IOException {
  OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 5, new L1());
  train(olr, 100);
  OnlineLogisticRegression olr3 = roundTrip(olr, OnlineLogisticRegression.class);
  // FIX: aggregate(MAX, IDENTITY) takes the max of SIGNED differences, so a large
  // negative deviation in beta would still pass assertEquals(0, ..., 1e-6).
  // Use ABS to bound the max absolute difference instead.
  assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.ABS), 1.0e-6);

  // Training both copies identically must keep them in lock step.
  train(olr, 100);
  train(olr3, 100);
  assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.ABS), 1.0e-6);
  olr.close();
  olr3.close();
}
/**
 * Deserializes this learner from a DataInput. The read order is significant and
 * presumably mirrors the corresponding write method field-for-field — do not
 * reorder these reads.
 */
@Override public void readFields(DataInput in) throws IOException {
  record = in.readInt();
  // AUC tracker is read polymorphically so the concrete implementation can vary.
  auc = PolymorphicWritable.read(in, OnlineAuc.class);
  logLikelihood = in.readDouble();
  // Rebuild the per-fold models; their count precedes the model data.
  int n = in.readInt();
  for (int i = 0; i < n; i++) {
    OnlineLogisticRegression olr = new OnlineLogisticRegression();
    olr.readFields(in);
    models.add(olr);
  }
  // Four learning hyper-parameters stored as a fixed-length array.
  parameters = new double[4];
  for (int i = 0; i < 4; i++) {
    parameters[i] = in.readDouble();
  }
  numFeatures = in.readInt();
  prior = PolymorphicWritable.read(in, PriorFunction.class);
  percentCorrect = in.readDouble();
  windowSize = in.readInt();
} }
/**
 * Closes every per-fold model, which lets each one apply any regularization
 * still pending against its coefficients.
 */
@Override
public void close() {
  for (OnlineLogisticRegression member : models) {
    member.close();
  }
}
/**
 * Creates a fresh learner sized from the category dictionary and the encoded
 * vector width, using the configured prior.
 */
OnlineLogisticRegression getLearner() {
  int categoryCount = catToInt.size();
  return new OnlineLogisticRegression(categoryCount, vectorSize, getPrior());
}
/**
 * Sets the learning-rate decay exponent on every per-fold model.
 *
 * @param x the new decay exponent
 * @return this, for call chaining
 */
public CrossFoldLearner decayExponent(double x) {
  for (OnlineLogisticRegression member : models) {
    member.decayExponent(x);
  }
  return this;
}
/**
 * Sets the annealing step offset on every per-fold model.
 *
 * @param x the new step offset
 * @return this, for call chaining
 */
public CrossFoldLearner stepOffset(int x) {
  for (OnlineLogisticRegression member : models) {
    member.stepOffset(x);
  }
  return this;
}
/**
 * Feeds one labeled example to the learner: encodes the feature map into a
 * vector, maps the label string to its integer category id, and performs a
 * single training step.
 *
 * @param output   the target label as a string
 * @param features feature name → value map to encode
 */
public void add(String output, Map<String, Object> features) {
  Vector encoded = encoder.getVector(features);
  int targetCategory = encoder.outputStringToInt(output);
  learn.train(targetCategory, encoded);
}