/**
 * Builds a fresh online logistic-regression learner sized to the current
 * category mapping and feature-vector dimensionality, using the configured prior.
 */
OnlineLogisticRegression getLearner() {
  OnlineLogisticRegression learner =
      new OnlineLogisticRegression(catToInt.size(), vectorSize, getPrior());
  return learner;
}
// Deserializes this learner's full state, field by field, in the exact order the
// matching write() emitted it: record count, AUC tracker, log-likelihood,
// per-fold models, four tracking parameters, feature count, prior,
// percent-correct, and window size. Do not reorder these reads.
@Override public void readFields(DataInput in) throws IOException {
  record = in.readInt();
  // AUC tracker is written polymorphically so the concrete subclass is restored.
  auc = PolymorphicWritable.read(in, OnlineAuc.class);
  logLikelihood = in.readDouble();
  // Model count precedes the serialized models themselves.
  int n = in.readInt();
  for (int i = 0; i < n; i++) {
    OnlineLogisticRegression olr = new OnlineLogisticRegression();
    olr.readFields(in);
    // NOTE(review): appends without clearing — assumes this instance's model
    // list is empty (freshly constructed) when readFields is called; confirm.
    models.add(olr);
  }
  // Fixed-size block of tracking parameters; the size 4 must match write().
  parameters = new double[4];
  for (int i = 0; i < 4; i++) {
    parameters[i] = in.readDouble();
  }
  numFeatures = in.readInt();
  // Prior is also polymorphic (L1, L2, etc. are all PriorFunctions).
  prior = PolymorphicWritable.read(in, PriorFunction.class);
  percentCorrect = in.readDouble();
  windowSize = in.readInt();
}
}
/**
 * Creates a cross-fold learner that holds one online logistic-regression model
 * per fold, each configured for a constant, non-decaying learning rate.
 *
 * @param folds         number of cross-validation folds (models) to create
 * @param numCategories number of target categories for every fold model
 * @param numFeatures   dimensionality of the input feature vectors
 * @param prior         regularizing prior shared by all fold models
 */
public CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
  this.numFeatures = numFeatures;
  this.prior = prior;
  for (int fold = 0; fold < folds; fold++) {
    OnlineLogisticRegression foldModel =
        new OnlineLogisticRegression(numCategories, numFeatures, prior);
    // alpha=1, stepOffset=0, decayExponent=0 => learning rate never decays.
    foldModel.alpha(1);
    foldModel.stepOffset(0);
    foldModel.decayExponent(0);
    models.add(foldModel);
  }
}
// Deserializes this learner's full state, field by field, in the exact order the
// matching write() emitted it: record count, AUC tracker, log-likelihood,
// per-fold models, four tracking parameters, feature count, prior,
// percent-correct, and window size. Do not reorder these reads.
@Override public void readFields(DataInput in) throws IOException {
  record = in.readInt();
  // AUC tracker is written polymorphically so the concrete subclass is restored.
  auc = PolymorphicWritable.read(in, OnlineAuc.class);
  logLikelihood = in.readDouble();
  // Model count precedes the serialized models themselves.
  int n = in.readInt();
  for (int i = 0; i < n; i++) {
    OnlineLogisticRegression olr = new OnlineLogisticRegression();
    olr.readFields(in);
    // NOTE(review): appends without clearing — assumes this instance's model
    // list is empty (freshly constructed) when readFields is called; confirm.
    models.add(olr);
  }
  // Fixed-size block of tracking parameters; the size 4 must match write().
  parameters = new double[4];
  for (int i = 0; i < 4; i++) {
    parameters[i] = in.readDouble();
  }
  numFeatures = in.readInt();
  // Prior is also polymorphic (L1, L2, etc. are all PriorFunctions).
  prior = PolymorphicWritable.read(in, PriorFunction.class);
  percentCorrect = in.readDouble();
  windowSize = in.readInt();
}
}
// Deserializes this learner's full state, field by field, in the exact order the
// matching write() emitted it: record count, AUC tracker, log-likelihood,
// per-fold models, four tracking parameters, feature count, prior,
// percent-correct, and window size. Do not reorder these reads.
@Override public void readFields(DataInput in) throws IOException {
  record = in.readInt();
  // AUC tracker is written polymorphically so the concrete subclass is restored.
  auc = PolymorphicWritable.read(in, OnlineAuc.class);
  logLikelihood = in.readDouble();
  // Model count precedes the serialized models themselves.
  int n = in.readInt();
  for (int i = 0; i < n; i++) {
    OnlineLogisticRegression olr = new OnlineLogisticRegression();
    olr.readFields(in);
    // NOTE(review): appends without clearing — assumes this instance's model
    // list is empty (freshly constructed) when readFields is called; confirm.
    models.add(olr);
  }
  // Fixed-size block of tracking parameters; the size 4 must match write().
  parameters = new double[4];
  for (int i = 0; i < 4; i++) {
    parameters[i] = in.readDouble();
  }
  numFeatures = in.readInt();
  // Prior is also polymorphic (L1, L2, etc. are all PriorFunctions).
  prior = PolymorphicWritable.read(in, PriorFunction.class);
  percentCorrect = in.readDouble();
  windowSize = in.readInt();
}
}
/**
 * Creates a cross-fold learner that holds one online logistic-regression model
 * per fold, each configured for a constant, non-decaying learning rate.
 *
 * @param folds         number of cross-validation folds (models) to create
 * @param numCategories number of target categories for every fold model
 * @param numFeatures   dimensionality of the input feature vectors
 * @param prior         regularizing prior shared by all fold models
 */
public CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
  this.numFeatures = numFeatures;
  this.prior = prior;
  for (int fold = 0; fold < folds; fold++) {
    OnlineLogisticRegression foldModel =
        new OnlineLogisticRegression(numCategories, numFeatures, prior);
    // alpha=1, stepOffset=0, decayExponent=0 => learning rate never decays.
    foldModel.alpha(1);
    foldModel.stepOffset(0);
    foldModel.decayExponent(0);
    models.add(foldModel);
  }
}
/**
 * Creates a cross-fold learner that holds one online logistic-regression model
 * per fold, each configured for a constant, non-decaying learning rate.
 *
 * @param folds         number of cross-validation folds (models) to create
 * @param numCategories number of target categories for every fold model
 * @param numFeatures   dimensionality of the input feature vectors
 * @param prior         regularizing prior shared by all fold models
 */
public CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
  this.numFeatures = numFeatures;
  this.prior = prior;
  for (int fold = 0; fold < folds; fold++) {
    OnlineLogisticRegression foldModel =
        new OnlineLogisticRegression(numCategories, numFeatures, prior);
    // alpha=1, stepOffset=0, decayExponent=0 => learning rate never decays.
    foldModel.alpha(1);
    foldModel.stepOffset(0);
    foldModel.decayExponent(0);
    models.add(foldModel);
  }
}
/**
 * Returns a deep copy of this learner. Note the side effect: each fold model
 * is closed (finalizing its coefficients) before it is duplicated.
 */
public CrossFoldLearner copy() {
  CrossFoldLearner result =
      new CrossFoldLearner(models.size(), numCategories(), numFeatures, prior);
  // Discard the models the constructor built; replace them with copies of ours.
  result.models.clear();
  for (OnlineLogisticRegression source : models) {
    source.close();
    OnlineLogisticRegression duplicate =
        new OnlineLogisticRegression(source.numCategories(), source.numFeatures(), source.prior);
    duplicate.copyFrom(source);
    result.models.add(duplicate);
  }
  return result;
}
// NOTE(review): fragment — this throw presumably sits inside a validation
// branch whose condition is outside the visible excerpt; confirm against the
// full method before editing.
throw new BadClassifierSpecException("Must have more than one target category. Remember that categories is a space separated list");
// Build the model from the parsed spec (L1 prior), then consume the spec
// options so later option handling does not see them again.
model = new OnlineLogisticRegression(categories.size(), Integer.parseInt(options.get("features")), new L1());
options.remove("categories");
options.remove("features");
/**
 * Returns a deep copy of this learner. Note the side effect: each fold model
 * is closed (finalizing its coefficients) before it is duplicated.
 */
public CrossFoldLearner copy() {
  CrossFoldLearner result =
      new CrossFoldLearner(models.size(), numCategories(), numFeatures, prior);
  // Discard the models the constructor built; replace them with copies of ours.
  result.models.clear();
  for (OnlineLogisticRegression source : models) {
    source.close();
    OnlineLogisticRegression duplicate =
        new OnlineLogisticRegression(source.numCategories(), source.numFeatures(), source.prior);
    duplicate.copyFrom(source);
    result.models.add(duplicate);
  }
  return result;
}
/**
 * Returns a deep copy of this learner. Note the side effect: each fold model
 * is closed (finalizing its coefficients) before it is duplicated.
 */
public CrossFoldLearner copy() {
  CrossFoldLearner result =
      new CrossFoldLearner(models.size(), numCategories(), numFeatures, prior);
  // Discard the models the constructor built; replace them with copies of ours.
  result.models.clear();
  for (OnlineLogisticRegression source : models) {
    source.close();
    OnlineLogisticRegression duplicate =
        new OnlineLogisticRegression(source.numCategories(), source.numFeatures(), source.prior);
    duplicate.copyFrom(source);
    result.models.add(duplicate);
  }
  return result;
}
/**
 * Returns a copy of this model. Note the side effect: this model is closed
 * (pending updates applied) before its state is duplicated.
 */
public OnlineLogisticRegression copy() {
  close();
  OnlineLogisticRegression duplicate =
      new OnlineLogisticRegression(numCategories(), numFeatures(), prior);
  duplicate.copyFrom(this);
  return duplicate;
}
/**
 * Returns a copy of this model. Note the side effect: this model is closed
 * (pending updates applied) before its state is duplicated.
 */
public OnlineLogisticRegression copy() {
  close();
  OnlineLogisticRegression duplicate =
      new OnlineLogisticRegression(numCategories(), numFeatures(), prior);
  duplicate.copyFrom(this);
  return duplicate;
}
/**
 * Returns a copy of this model. Note the side effect: this model is closed
 * (pending updates applied) before its state is duplicated.
 */
public OnlineLogisticRegression copy() {
  close();
  OnlineLogisticRegression duplicate =
      new OnlineLogisticRegression(numCategories(), numFeatures(), prior);
  duplicate.copyFrom(this);
  return duplicate;
}
// Fix: the original constructed a default OnlineLogisticRegression and then
// immediately overwrote the reference — a dead store and a wasted allocation.
// Construct the configured model directly: 2 categories, 3 features, L1 prior,
// then set the regularization weight.
OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 3, new L1());
learningAlgo.lambda(0.1);
// NOTE(review): fragment — the fluent builder chain continues past this
// excerpt (no terminating semicolon visible).
// 20 target categories over FEATURES-dimensional input with an L1 prior;
// alpha=1 with stepOffset=1000 keeps the learning rate high early on.
new OnlineLogisticRegression( 20, FEATURES, new L1()) .alpha(1).stepOffset(1000)
// NOTE(review): fragment — the fluent chain continues past this excerpt.
// Binary classifier over 8 features with L1 regularization; lambda evaluates
// to 1.0e-3 (the "1 *" factor is a no-op multiplication).
OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1()) .lambda(1 * 1.0e-3) .stepOffset(11)
// Three target categories over two features with a ridge (L2) prior of scale 1.
// NOTE(review): `lr` is presumably configured/used on lines outside this excerpt.
OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1));
@Test public void testTrain() throws Exception { Vector target = readStandardData(); // lambda here needs to be relatively small to avoid swamping the actual signal, but can be // larger than usual because the data are dense. The learning rate doesn't matter too much // for this example, but should generally be < 1 // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias // --target y --categories 2 --predictors V2 V3 V4 V5 V6 V7 --types n OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1()) .lambda(1 * 1.0e-3) .learningRate(50); train(getInput(), target, lr); test(getInput(), target, lr, 0.05, 0.3); }
// Round-trips a trained model through serialization and verifies that (a) the
// deserialized beta matrix matches the original exactly, and (b) both copies
// evolve identically under further training.
@Test
public void onlineLogisticRegressionRoundTrip() throws IOException {
  OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 5, new L1());
  train(olr, 100);
  OnlineLogisticRegression olr3 = roundTrip(olr, OnlineLogisticRegression.class);
  // Max absolute element-wise difference of the coefficient matrices must be ~0.
  assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6);
  // NOTE(review): comparing after further training assumes train() feeds both
  // models an identical, deterministic example stream — TODO confirm.
  train(olr, 100);
  train(olr3, 100);
  assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6);
  olr.close();
  olr3.close();
}