/**
 * Applies a parameter vector to the wrapped learner.
 *
 * <p>{@code params[0]} is handed to {@code lambda} and {@code params[1]} to
 * {@code learningRate}. The remaining settings are reset to fixed values on every call:
 * step offset 1, alpha 1 and decay exponent 0.
 *
 * @param params parameter values; only the first two entries are read
 */
@Override
public void update(double[] params) {
  // Tunable settings, taken positionally from the parameter vector.
  wrapped.lambda(params[0]);
  wrapped.learningRate(params[1]);

  // Settings that are pinned to constants rather than read from params.
  wrapped.stepOffset(1);
  wrapped.alpha(1);
  wrapped.decayExponent(0);
}
/**
 * Applies a parameter vector to the wrapped learner.
 *
 * <p>{@code params[0]} is handed to {@code lambda} and {@code params[1]} to
 * {@code learningRate}. The remaining settings are reset to fixed values on every call:
 * step offset 1, alpha 1 and decay exponent 0.
 *
 * @param params parameter values; only the first two entries are read
 */
@Override
public void update(double[] params) {
  // Tunable settings, taken positionally from the parameter vector.
  wrapped.lambda(params[0]);
  wrapped.learningRate(params[1]);

  // Settings that are pinned to constants rather than read from params.
  wrapped.stepOffset(1);
  wrapped.alpha(1);
  wrapped.decayExponent(0);
}
/**
 * Applies a parameter vector to the wrapped learner.
 *
 * <p>{@code params[0]} is handed to {@code lambda} and {@code params[1]} to
 * {@code learningRate}. The remaining settings are reset to fixed values on every call:
 * step offset 1, alpha 1 and decay exponent 0.
 *
 * @param params parameter values; only the first two entries are read
 */
@Override
public void update(double[] params) {
  // Tunable settings, taken positionally from the parameter vector.
  wrapped.lambda(params[0]);
  wrapped.learningRate(params[1]);

  // Settings that are pinned to constants rather than read from params.
  wrapped.stepOffset(1);
  wrapped.alpha(1);
  wrapped.decayExponent(0);
}
/**
 * Trains a {@code CrossFoldLearner} configured with an {@code L1} prior on the vector returned
 * by {@code readStandardData()}, prints its AUC and log-likelihood, and then passes it to
 * {@code test} for checking.
 *
 * <p>The CrossFoldLearner is probably the best learner to use for new applications.
 *
 * @throws IOException If test resources aren't readable.
 */
@Test
public void crossValidation() throws IOException {
  Vector expected = readStandardData();

  CrossFoldLearner learner =
      new CrossFoldLearner(5, 2, 8, new L1())
          .lambda(1.0e-3)
          .learningRate(50);

  train(getInput(), expected, learner);

  // Report the metrics before the pass/fail check so they are visible even on failure.
  double auc = learner.auc();
  double logLikelihood = learner.logLikelihood();
  System.out.printf("%.2f %.5f\n", auc, logLikelihood);

  // NOTE(review): 0.05 and 0.3 are tolerances interpreted by test(); confirm their meaning there.
  test(getInput(), expected, learner, 0.05, 0.3);
}
/**
 * Trains a {@code CrossFoldLearner} on the rows of {@code cancer.csv} for 100 passes and checks
 * that its AUC approaches 1.
 *
 * <p>The row order is obtained once from {@code permute} and reused for every pass. After each
 * training step a line of the form {@code epoch,step,auc} is printed. The AUC must be within
 * 0.2 of 1 at the end of every pass and within 0.1 of 1 after the final pass.
 *
 * @throws IOException declared by the method; presumably raised when the CSV cannot be read
 */
@Test
public void crossValidatedAuc() throws IOException {
  // Use the test seed before fetching the generator (presumably for a reproducible ordering).
  RandomUtils.useTestSeed();
  Random rng = RandomUtils.getRandom();

  Matrix cancer = readCsv("cancer.csv");
  CrossFoldLearner learner =
      new CrossFoldLearner(5, 2, 10, new L1())
          .stepOffset(10)
          .decayExponent(0.7)
          .lambda(1.0e-3)
          .learningRate(5);

  int[] rowOrder = permute(rng, cancer.numRows());
  int step = 0;
  int epoch = 0;
  while (epoch < 100) {
    for (int row : rowOrder) {
      // Column 9 of each row is used as the training target.
      int label = (int) cancer.get(row, 9);
      learner.train(row, label, cancer.viewRow(row));
      System.out.printf("%d,%d,%.3f\n", epoch, step, learner.auc());
      step++;
    }
    // Loose bound that has to hold after every pass, including the first.
    assertEquals(1, learner.auc(), 0.2);
    epoch++;
  }

  // Tighter bound once all passes have completed.
  assertEquals(1, learner.auc(), 0.1);
}