CrossFoldLearner state = best.getPayload().getLearner(); averageCorrect = state.percentCorrect(); averageLL = state.logLikelihood(); if (learningAlgorithm.getBest() != null) { ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model", learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
// Passes the current state to the static Wrapper.freeze helper.
// NOTE(review): presumably this stops further adaptation of the state's parameters; confirm against Wrapper.freeze.
Wrapper.freeze(state);
AdaptiveLogisticRegression.Wrapper w = new AdaptiveLogisticRegression.Wrapper(2, 200, new L1()); for (int i = 0; i < 3000; i++) { AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta); w.train(r); if (i % 1000 == 0) { System.out.printf("%10d %.3f\n", i, w.getLearner().auc()); System.out.printf("%10d %.3f\n", 3000, w.getLearner().auc()); double auc1 = w.getLearner().auc(); AdaptiveLogisticRegression.Wrapper w2 = w.copy(); assertEquals("Should have started with no data", 0.5, w2.getLearner().auc(), 0.0001); double auc2 = w2.getLearner().auc(); assertTrue("Should have had head-start", Math.abs(auc2 - 0.5) > 0.1); assertTrue("AUC should improve quickly on copy", auc1 < auc2); System.out.printf("%10d %.3f\n", i, w2.getLearner().auc()); w2.train(r); assertEquals("Original should not change after copy is updated", auc1, w.getLearner().auc(), 1.0e-5); assertTrue("AUC should improve significantly on copy", auc1 < w2.getLearner().auc() - 0.05); assertEquals(auc1, w.getLearner().auc(), 0);
AdaptiveLogisticRegression.Wrapper cl = new AdaptiveLogisticRegression.Wrapper(2, 200, new L1()); cl.update(new double[]{1.0e-5, 1}); cl.train(r); if (i % 1000 == 0) { System.out.printf("%10d %10.3f\n", i, cl.getLearner().auc()); assertEquals(1, cl.getLearner().auc(), 0.1);
// Passes the current state to the static Wrapper.freeze helper.
// NOTE(review): presumably this stops further adaptation of the state's parameters; confirm against Wrapper.freeze.
Wrapper.freeze(state);
// Passes the current state to the static Wrapper.freeze helper.
// NOTE(review): presumably this stops further adaptation of the state's parameters; confirm against Wrapper.freeze.
Wrapper.freeze(state);
/**
 * Builds an adaptive learner with an explicit thread count and population size.
 *
 * <p>The constructor records the sizing parameters, creates the seed state with a fresh
 * {@code Wrapper} as its payload, installs the parameter mappings on that seed and finally
 * applies the pool size through {@code setPoolSize}.
 *
 * @param numCategories The number of categories (labels) to train on
 * @param numFeatures   The number of features used in creating the vectors (i.e. the cardinality
 *                      of the vector)
 * @param prior         The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
 * @param threadCount   The number of threads to use for training
 * @param poolSize      The number of {@link org.apache.mahout.classifier.sgd.CrossFoldLearner}
 *                      to use.
 */
public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior, int threadCount,
    int poolSize) {
  this.numFeatures = numFeatures;
  this.threadCount = threadCount;
  this.poolSize = poolSize;

  // Seed state: two parameters, second constructor argument 10.
  seed = new State<>(new double[2], 10);
  Wrapper wrapper = new Wrapper(numCategories, numFeatures, prior);
  seed.setPayload(wrapper);

  Wrapper.setMappings(seed);
  // The payload is assigned again once the mappings are installed.
  // NOTE(review): this second assignment looks redundant unless setMappings touches the
  // payload; confirm before removing it.
  seed.setPayload(wrapper);

  // The field was just set from the parameter, so either name carries the same value.
  setPoolSize(poolSize);
}
/**
 * Prints a summary of the features that carry the most weight in the best model found so far.
 *
 * <p>The best learner is fetched from {@code learningAlgorithm} and closed. A shuffled sample of
 * at most 500 of the given files is then re-encoded with tracing enabled on the shared
 * {@code encoder} and {@code bias} encoders, and each encoded vector is fed to a
 * {@code ModelDissector}. Finally {@code summary(100)} is requested from the dissector and one
 * tab-separated line per returned weight is written to standard output.
 *
 * @param newsGroups        dictionary whose values supply the news group names used in the output
 * @param learningAlgorithm the trained learner whose best member is dissected
 * @param files             the candidate files to sample from; fewer than 500 files is allowed,
 *                          in which case all of them are used
 * @throws IOException declared for the file encoding performed via {@code encodeFeatureVector}
 */
private static void dissect(Dictionary newsGroups, AdaptiveLogisticRegression learningAlgorithm,
    Iterable<File> files) throws IOException {
  // NOTE(review): getBest() is dereferenced without a null check; callers are assumed to
  // invoke this only after training has produced a best member; confirm.
  CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
  // The model is still read after close().
  // NOTE(review): presumably close() finalizes training rather than releasing the model; confirm.
  model.close();

  Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
  ModelDissector md = new ModelDissector();

  // Both encoders record the features they touch into the same trace dictionary.
  encoder.setTraceDictionary(traceDictionary);
  bias.setTraceDictionary(traceDictionary);

  // Clamp the sample size so that a corpus with fewer than 500 files does not make
  // subList throw IndexOutOfBoundsException.
  List<File> shuffled = permute(files, rand);
  int sampleSize = Math.min(500, shuffled.size());
  for (File file : shuffled.subList(0, sampleSize)) {
    // Start from an empty trace so each update only sees the current file's features.
    traceDictionary.clear();
    Vector v = encodeFeatureVector(file);
    md.update(v, traceDictionary, model);
  }

  List<String> ngNames = Lists.newArrayList(newsGroups.values());
  List<ModelDissector.Weight> weights = md.summary(100);
  for (ModelDissector.Weight w : weights) {
    // NOTE(review): the 4th and 6th specifiers pair %.1f with getCategory(...) while the 5th
    // and 7th pair %s with getWeight(...). If getCategory returns an integer type, %.1f will
    // throw IllegalFormatConversionException; confirm the intended pairing.
    System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n",
        w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1),
        w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
  }
}
/**
 * Builds an adaptive learner with an explicit thread count and population size.
 *
 * <p>The constructor records the sizing parameters, creates the seed state with a fresh
 * {@code Wrapper} as its payload, installs the parameter mappings on that seed and finally
 * applies the pool size through {@code setPoolSize}.
 *
 * @param numCategories The number of categories (labels) to train on
 * @param numFeatures   The number of features used in creating the vectors (i.e. the cardinality
 *                      of the vector)
 * @param prior         The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
 * @param threadCount   The number of threads to use for training
 * @param poolSize      The number of {@link org.apache.mahout.classifier.sgd.CrossFoldLearner}
 *                      to use.
 */
public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior, int threadCount,
    int poolSize) {
  this.numFeatures = numFeatures;
  this.threadCount = threadCount;
  this.poolSize = poolSize;

  // Seed state: two parameters, second constructor argument 10.
  seed = new State<Wrapper, CrossFoldLearner>(new double[2], 10);
  Wrapper wrapper = new Wrapper(numCategories, numFeatures, prior);
  seed.setPayload(wrapper);

  Wrapper.setMappings(seed);
  // The payload is assigned again once the mappings are installed.
  // NOTE(review): this second assignment looks redundant unless setMappings touches the
  // payload; confirm before removing it.
  seed.setPayload(wrapper);

  // The field was just set from the parameter, so either name carries the same value.
  setPoolSize(poolSize);
}
/**
 * Builds an adaptive learner with an explicit thread count and population size.
 *
 * <p>The constructor records the sizing parameters, creates the seed state with a fresh
 * {@code Wrapper} as its payload, installs the parameter mappings on that seed and finally
 * applies the pool size through {@code setPoolSize}.
 *
 * @param numCategories The number of categories (labels) to train on
 * @param numFeatures   The number of features used in creating the vectors (i.e. the cardinality
 *                      of the vector)
 * @param prior         The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
 * @param threadCount   The number of threads to use for training
 * @param poolSize      The number of {@link org.apache.mahout.classifier.sgd.CrossFoldLearner}
 *                      to use.
 */
public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior, int threadCount,
    int poolSize) {
  this.numFeatures = numFeatures;
  this.threadCount = threadCount;
  this.poolSize = poolSize;

  // Seed state: two parameters, second constructor argument 10.
  seed = new State<Wrapper, CrossFoldLearner>(new double[2], 10);
  Wrapper wrapper = new Wrapper(numCategories, numFeatures, prior);
  seed.setPayload(wrapper);

  Wrapper.setMappings(seed);
  // The payload is assigned again once the mappings are installed.
  // NOTE(review): this second assignment looks redundant unless setMappings touches the
  // payload; confirm before removing it.
  seed.setPayload(wrapper);

  // The field was just set from the parameter, so either name carries the same value.
  setPoolSize(poolSize);
}
/**
 * Returns the prior function held by the learner of the seed state's payload.
 *
 * @return the prior reported by the seed learner
 */
public PriorFunction getPrior() {
  return this.seed.getPayload().getLearner().getPrior();
}
/**
 * Returns the prior function held by the learner of the seed state's payload.
 *
 * @return the prior reported by the seed learner
 */
public PriorFunction getPrior() {
  return this.seed.getPayload().getLearner().getPrior();
}
/**
 * Returns the prior function held by the learner of the seed state's payload.
 *
 * @return the prior reported by the seed learner
 */
public PriorFunction getPrior() {
  return this.seed.getPayload().getLearner().getPrior();
}
/**
 * Returns the number of categories reported by the learner of the seed state's payload.
 *
 * @return the seed learner's category count
 */
public int getNumCategories() {
  return this.seed.getPayload().getLearner().numCategories();
}
/**
 * Reports the AUC of the current best member of the population.
 *
 * <p>When no member has been selected as best (usually because no training has happened yet)
 * there is nothing to measure and NaN is returned instead.
 *
 * @return The AUC of the best member of the population, or NaN if there is no best member.
 */
public double auc() {
  // Guard: without a best member there is no learner to query.
  if (best == null) {
    return Double.NaN;
  }
  return best.getPayload().getLearner().auc();
}
/**
 * Returns the number of categories reported by the learner of the seed state's payload.
 *
 * @return the seed learner's category count
 */
public int getNumCategories() {
  return this.seed.getPayload().getLearner().numCategories();
}
/**
 * Reports the AUC of the current best member of the population.
 *
 * <p>When no member has been selected as best (usually because no training has happened yet)
 * there is nothing to measure and NaN is returned instead.
 *
 * @return The AUC of the best member of the population, or NaN if there is no best member.
 */
public double auc() {
  // Guard: without a best member there is no learner to query.
  if (best == null) {
    return Double.NaN;
  }
  return best.getPayload().getLearner().auc();
}
/**
 * Reports the AUC of the current best member of the population.
 *
 * <p>When no member has been selected as best (usually because no training has happened yet)
 * there is nothing to measure and NaN is returned instead.
 *
 * @return The AUC of the best member of the population, or NaN if there is no best member.
 */
public double auc() {
  // Guard: without a best member there is no learner to query.
  if (best == null) {
    return Double.NaN;
  }
  return best.getPayload().getLearner().auc();
}
/**
 * Installs the given AUC evaluator on the seed state's payload and then re-runs
 * {@code setupOptimizer} with the current pool size.
 *
 * @param auc the evaluator handed to the seed payload
 */
public void setAucEvaluator(OnlineAuc auc) {
  this.seed.getPayload().setAucEvaluator(auc);
  // NOTE(review): presumably the optimizer is rebuilt so the population picks up the new
  // evaluator from the seed; confirm against setupOptimizer.
  setupOptimizer(this.poolSize);
}
/**
 * Sets the window size on the seed learner and then re-runs {@code setupOptimizer} with the
 * current pool size.
 *
 * @param averagingWindow the window size passed to the seed learner's {@code setWindowSize}
 */
public void setAveragingWindow(int averagingWindow) {
  this.seed.getPayload().getLearner().setWindowSize(averagingWindow);
  // NOTE(review): presumably the optimizer is rebuilt so the population picks up the new
  // window size from the seed; confirm against setupOptimizer.
  setupOptimizer(this.poolSize);
}