/**
 * Creates an instance that records the feature count, thread count and pool size, and
 * initialises the seed state with a freshly constructed {@code Wrapper} payload.
 *
 * @param numCategories The number of categories (labels) to train on
 * @param numFeatures   The number of features used in creating the vectors (i.e. the cardinality of the vector)
 * @param prior         The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
 * @param threadCount   The number of threads to use for training
 * @param poolSize      The number of {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} to use.
 */
public AdaptiveLogisticRegression(int numCategories,
                                  int numFeatures,
                                  PriorFunction prior,
                                  int threadCount,
                                  int poolSize) {
  this.numFeatures = numFeatures;
  this.threadCount = threadCount;
  this.poolSize = poolSize;

  // Seed state: a two-element parameter array, with 10 as the second State constructor argument.
  seed = new State<Wrapper, CrossFoldLearner>(new double[2], 10);

  Wrapper learnerWrapper = new Wrapper(numCategories, numFeatures, prior);
  seed.setPayload(learnerWrapper);
  Wrapper.setMappings(seed);
  // NOTE(review): the payload is assigned a second time after setMappings; this looks redundant
  // unless setMappings replaces the payload -- confirm before removing.
  seed.setPayload(learnerWrapper);

  setPoolSize(this.poolSize);
}
/**
 * Constructs the learner: stores the sizing parameters, creates the seed state, attaches a new
 * {@code Wrapper} to it as payload, and finally applies the pool size via {@code setPoolSize}.
 *
 * @param numCategories The number of categories (labels) to train on
 * @param numFeatures   The number of features used in creating the vectors (i.e. the cardinality of the vector)
 * @param prior         The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
 * @param threadCount   The number of threads to use for training
 * @param poolSize      The number of {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} to use.
 */
public AdaptiveLogisticRegression(
    int numCategories, int numFeatures, PriorFunction prior, int threadCount, int poolSize) {
  // Plain configuration captured from the arguments.
  this.numFeatures = numFeatures;
  this.threadCount = threadCount;
  this.poolSize = poolSize;

  // The seed starts from new double[2] and the constant 10 passed to the State constructor.
  seed = new State<Wrapper, CrossFoldLearner>(new double[2], 10);
  Wrapper payload = new Wrapper(numCategories, numFeatures, prior);

  seed.setPayload(payload);
  Wrapper.setMappings(seed);
  // NOTE(review): same payload is set again after setMappings; presumably redundant, but
  // setMappings is not visible here -- verify before dropping this call.
  seed.setPayload(payload);

  setPoolSize(this.poolSize);
}
/**
 * Sets up the learner from explicit sizing parameters and a prior: the arguments are stored,
 * a seed state is created with a new {@code Wrapper} as its payload, and the pool size is applied.
 *
 * @param numCategories The number of categories (labels) to train on
 * @param numFeatures   The number of features used in creating the vectors (i.e. the cardinality of the vector)
 * @param prior         The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
 * @param threadCount   The number of threads to use for training
 * @param poolSize      The number of {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} to use.
 */
public AdaptiveLogisticRegression(int numCategories,
                                  int numFeatures,
                                  PriorFunction prior,
                                  int threadCount,
                                  int poolSize) {
  this.numFeatures = numFeatures;
  this.threadCount = threadCount;
  this.poolSize = poolSize;

  // Seed state built from a two-element parameter array and the constant 10.
  seed = new State<>(new double[2], 10);

  Wrapper seedPayload = new Wrapper(numCategories, numFeatures, prior);
  seed.setPayload(seedPayload);
  Wrapper.setMappings(seed);
  // NOTE(review): payload is re-assigned after setMappings; appears redundant unless
  // setMappings alters the payload -- confirm against Wrapper.setMappings.
  seed.setPayload(seedPayload);

  setPoolSize(this.poolSize);
}
/**
 * Checks that the evolutionary process approaches the optimum of a simple weighted quadratic.
 * <p>
 * The objective evaluated for each candidate is {@code -sum_k k * (x_k - k)^2} for
 * {@code k = 1..params.length}; it is never positive and equals 0 exactly when {@code x_k == k}.
 * After 20 rounds of {@code parallelDo} followed by {@code mutatePopulation(3)}, the best value
 * found must be within 0.02 of 0.
 *
 * @throws Exception if evaluating, mutating or closing the process fails
 */
@Test
public void testConverges() throws Exception {
  State<Foo, Double> s0 = new State<Foo, Double>(new double[5], 1);
  s0.setPayload(new Foo());

  // The objective is stateless, so one instance is shared by every round.
  EvolutionaryProcess.Function<Payload<Double>> objective =
      new EvolutionaryProcess.Function<Payload<Double>>() {
        @Override
        public double apply(Payload<Double> payload, double[] params) {
          // weight is the 1-based position k; named distinctly so it cannot be
          // confused with the round counter of the enclosing loop.
          int weight = 1;
          double sum = 0;
          for (double x : params) {
            sum += weight * (x - weight) * (x - weight);
            weight++;
          }
          // Negated so that a larger value means a better candidate.
          return -sum;
        }
      };

  EvolutionaryProcess<Foo, Double> ep = new EvolutionaryProcess<Foo, Double>(10, 100, s0);
  State<Foo, Double> best = null;
  try {
    for (int round = 0; round < 20; round++) {
      best = ep.parallelDo(objective);
      ep.mutatePopulation(3);
      System.out.printf("%10.3f %.3f\n", best.getValue(), best.getOmni());
    }
  } finally {
    // Always close the process, even when a round throws, so it is not left open
    // after a failing test.
    ep.close();
  }

  assertNotNull(best);
  assertEquals(0.0, best.getValue(), 0.02);
}