/**
 * Computes the perplexity score without holding any document out.
 *
 * <p>This simply forwards to the three-argument overload with a test fraction of zero, which
 * makes that overload score every row of {@code matrix}.
 *
 * @param matrix rows treated as documents
 * @param docTopicCounts per-document topic rows, consumed in lockstep with {@code matrix}
 * @return whatever the three-argument overload returns for a zero test fraction
 */
public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts) {
  final double noHeldOutFraction = 0;
  return calculatePerplexity(matrix, docTopicCounts, noHeldOutFraction);
}
/**
 * Persists the current model by handing the destination to the model trainer.
 *
 * @param outputPath destination forwarded unchanged to {@code modelTrainer.persist}
 * @throws IOException propagated from {@code modelTrainer.persist}
 */
public void writeModel(Path outputPath) throws IOException {
  final Path destination = outputPath;
  this.modelTrainer.persist(destination);
}
public void trainDocuments(double testFraction) { long start = System.nanoTime(); modelTrainer.start(); for (int docId = 0; docId < corpusWeights.numRows(); docId++) { if (testFraction == 0 || docId % (1 / testFraction) != 0) { Vector docTopics = new DenseVector(numTopics).assign(1.0 / numTopics); // docTopicCounts.getRow(docId) modelTrainer.trainSync(corpusWeights.viewRow(docId), docTopics , true, 10); } } modelTrainer.stop(); logTime("train documents", System.nanoTime() - start); }
public void train(VectorIterable matrix, VectorIterable docTopicCounts, int numDocTopicIters) { start(); Iterator<MatrixSlice> docIterator = matrix.iterator(); Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator(); batchTrain(batch, true, numDocTopicIters); long time = System.nanoTime(); log.debug("trained {} docs with {} tokens, start time {}, end time {}", train(document, topicDist, true, numDocTopicIters); if (log.isDebugEnabled()) { times[i % times.length] = stop();
/**
 * Finishes training, emits the trained topics, and then stops both models.
 *
 * <p>Every slice of the trainer's read model is written as one record. The key is the slice's
 * index and the value is its vector.
 *
 * <p>NOTE(review): setup in this class assigns {@code writeModel = readModel} whenever the model
 * weight is not 1. In that case {@code stop()} runs twice on the same instance; confirm that
 * {@code TopicModel.stop()} tolerates this.
 */
@Override
protected void cleanup(Context context) throws IOException, InterruptedException {
  log.info("Stopping model trainer");
  modelTrainer.stop();

  log.info("Writing model");
  TopicModel trainedModel = modelTrainer.getReadModel();
  for (MatrixSlice topicRow : trainedModel) {
    IntWritable key = new IntWritable(topicRow.index());
    VectorWritable value = new VectorWritable(topicRow.vector());
    context.write(key, value);
  }

  readModel.stop();
  writeModel.stop();
}
}
/**
 * Reads the CVB0 job configuration, builds the read model, write model and model trainer, and
 * starts the trainer.
 *
 * <p>The read model is loaded from the configured model paths when any exist. Otherwise it is
 * initialised from a random generator seeded with {@code RANDOM_SEED} (default 1234). A separate
 * write model is created only when the model weight is exactly 1; for any other weight the read
 * model also serves as the write model.
 *
 * @throws IOException propagated from model loading, if any
 * @throws InterruptedException declared by the Hadoop {@code setup} contract
 */
@Override
protected void setup(Context context) throws IOException, InterruptedException {
  log.info("Retrieving configuration");
  Configuration conf = context.getConfiguration();
  float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN);
  float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN);
  long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L);
  numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1);
  int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1);
  int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1);
  int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4);
  maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10);
  float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f);

  log.info("Initializing read model");
  Path[] modelPaths = CVB0Driver.getModelPaths(conf);
  if (modelPaths != null && modelPaths.length > 0) {
    readModel = new TopicModel(conf, eta, alpha, null, numUpdateThreads, modelWeight, modelPaths);
  } else {
    log.info("No model files found");
    // The argument after the (null) dictionary is the model's own thread count. The other
    // TopicModel constructors in this method pass numUpdateThreads there, so this branch
    // does too; numTrainThreads belongs to the ModelTrainer below.
    readModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null,
        numUpdateThreads, modelWeight);
  }

  log.info("Initializing write model");
  // With weight 1 the updates go to a fresh model; otherwise they are applied to the read model.
  writeModel = modelWeight == 1
      ? new TopicModel(numTopics, numTerms, eta, alpha, null, numUpdateThreads)
      : readModel;

  log.info("Initializing model trainer");
  modelTrainer = new ModelTrainer(readModel, writeModel, numTrainThreads, numTopics, numTerms);
  modelTrainer.start();
}
trainDocuments(testFraction); if (verbose) { log.info("model after: {}: {}", iter, modelTrainer.getReadModel()); oldPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, testFraction); log.info("{} = perplexity", oldPerplexity); trainDocuments(); if (verbose) { log.info("model after: {}: {}", iter, modelTrainer.getReadModel()); newPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, testFraction); log.info("{} = perplexity", newPerplexity);
/**
 * Trains with a single doc-topic iteration per document.
 *
 * <p>This forwards to the three-argument overload with {@code numDocTopicIters == 1}.
 *
 * @param matrix rows treated as documents
 * @param docTopicCounts per-document topic rows paired with {@code matrix}
 */
public void train(VectorIterable matrix, VectorIterable docTopicCounts) {
  final int singleIteration = 1;
  train(matrix, docTopicCounts, singleIteration);
}
/**
 * Reads the CVB0 job configuration and prepares a read-only model trainer plus a reusable topic
 * vector.
 *
 * <p>This also starts the memory logger, with 5000 as its rate argument (presumably
 * milliseconds). The read model is loaded from the configured model paths when any exist;
 * otherwise it is initialised from a random generator seeded with {@code RANDOM_SEED}. The
 * trainer is built with a {@code null} write model and is not started here.
 *
 * <p>NOTE(review): nothing in this method stops the memory logger; confirm that the matching
 * cleanup does.
 *
 * @throws IOException propagated from model loading, if any
 * @throws InterruptedException declared by the Hadoop {@code setup} contract
 */
@Override
protected void setup(Context context) throws IOException, InterruptedException {
  MemoryUtil.startMemoryLogger(5000);

  log.info("Retrieving configuration");
  Configuration conf = context.getConfiguration();
  float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN);
  float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN);
  long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L);
  random = RandomUtils.getRandom(seed);
  numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1);
  int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1);
  int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1);
  int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4);
  maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10);
  float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f);
  testFraction = conf.getFloat(CVB0Driver.TEST_SET_FRACTION, 0.1f);

  log.info("Initializing read model");
  Path[] modelPaths = CVB0Driver.getModelPaths(conf);
  if (modelPaths != null && modelPaths.length > 0) {
    readModel = new TopicModel(conf, eta, alpha, null, numUpdateThreads, modelWeight, modelPaths);
  } else {
    log.info("No model files found");
    // Pass the model's own thread count (numUpdateThreads), as the model-path constructor above
    // does; numTrainThreads is meant for the ModelTrainer below.
    readModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null,
        numUpdateThreads, modelWeight);
  }

  log.info("Initializing model trainer");
  modelTrainer = new ModelTrainer(readModel, null, numTrainThreads, numTopics, numTerms);

  log.info("Initializing topic vector");
  topicVector = new DenseVector(new double[numTopics]);
}
/**
 * Stops the model trainer returned by {@link #getModelTrainer()} once the task is done.
 *
 * @param context the Hadoop task context (unused)
 */
@Override
protected void cleanup(Context context) {
  this.getModelTrainer().stop();
}
}
/**
 * Infers one document's topic distribution and emits it under the same document id.
 *
 * <p>The distribution starts uniform over {@code getNumTopics()} topics. The read model's
 * {@code trainDocTopicModel} is then applied {@code getMaxIters()} times, using a scratch sparse
 * matrix sized topics x document length. The result is written through the reused
 * {@code topics} writable.
 *
 * @throws IOException propagated from {@code context.write}
 * @throws InterruptedException propagated from {@code context.write}
 */
@Override
public void map(IntWritable docId, VectorWritable doc, Context context)
    throws IOException, InterruptedException {
  int topicCount = getNumTopics();
  Vector docTopics = new DenseVector(topicCount).assign(1.0 / topicCount);
  Vector document = doc.get();
  Matrix scratchModel = new SparseRowMatrix(topicCount, document.size());

  int iterations = getMaxIters();
  ModelTrainer trainer = getModelTrainer();
  for (int remaining = iterations; remaining > 0; remaining--) {
    trainer.getReadModel().trainDocTopicModel(document, docTopics, scratchModel);
  }

  topics.set(docTopics);
  context.write(docId, topics);
}
/**
 * Scores the test rows and returns the summed model score divided by their total L1 norm.
 *
 * <p>Rows of {@code matrix} and {@code docTopicCounts} are consumed in pairs until either
 * iterator runs out. A row is scored when {@code testFraction} is zero, or when its index
 * satisfies {@code index % (1 / testFraction) == 0}. For each scored row:
 * <ul>
 *   <li>its topic row is first refined via {@code trainSync(..., false, 10)}, which presumably
 *       does not update the model (flag {@code false}; confirm against {@code trainSync});</li>
 *   <li>{@code readModel.perplexity} is then added to the total.</li>
 * </ul>
 *
 * <p>If no row is scored (or all scored rows have zero L1 norm), the result is {@code 0/0},
 * i.e. NaN.
 *
 * @param matrix rows treated as documents
 * @param docTopicCounts per-document topic rows, paired positionally with {@code matrix}
 * @param testFraction fraction selecting the test rows; {@code 0} scores every row
 * @return summed score of the scored rows divided by their summed L1 norm
 */
public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts,
    double testFraction) {
  Iterator<MatrixSlice> rows = matrix.iterator();
  Iterator<MatrixSlice> topicRows = docTopicCounts.iterator();
  boolean scoreEveryRow = testFraction == 0;
  // Infinite when testFraction is zero, but then scoreEveryRow short-circuits its use.
  double period = 1 / testFraction;

  double totalScore = 0;
  double totalWeight = 0;
  while (rows.hasNext() && topicRows.hasNext()) {
    MatrixSlice row = rows.next();
    MatrixSlice topicRow = topicRows.next();
    int rowIndex = row.index();
    Vector document = row.vector();
    Vector topicDist = topicRow.vector();

    boolean isTestRow = scoreEveryRow || rowIndex % period == 0;
    if (!isTestRow) {
      continue;
    }
    trainSync(document, topicDist, false, 10);
    totalScore += readModel.perplexity(document, topicDist);
    totalWeight += document.norm(1);
  }
  return totalScore / totalWeight;
}
public void train(VectorIterable matrix, VectorIterable docTopicCounts, int numDocTopicIters) { start(); Iterator<MatrixSlice> docIterator = matrix.iterator(); Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator(); batchTrain(batch, true, numDocTopicIters); long time = System.nanoTime(); log.debug("trained {} docs with {} tokens, start time {}, end time {}", train(document, topicDist, true, numDocTopicIters); if (log.isDebugEnabled()) { times[i % times.length] = stop();
/**
 * Finishes training, emits the trained topics, and then stops both models.
 *
 * <p>Every slice of the trainer's read model is written as one record. The key is the slice's
 * index and the value is its vector.
 *
 * <p>NOTE(review): setup in this class assigns {@code writeModel = readModel} whenever the model
 * weight is not 1. In that case {@code stop()} runs twice on the same instance; confirm that
 * {@code TopicModel.stop()} tolerates this.
 */
@Override
protected void cleanup(Context context) throws IOException, InterruptedException {
  log.info("Stopping model trainer");
  modelTrainer.stop();

  log.info("Writing model");
  TopicModel trainedModel = modelTrainer.getReadModel();
  for (MatrixSlice topicRow : trainedModel) {
    IntWritable key = new IntWritable(topicRow.index());
    VectorWritable value = new VectorWritable(topicRow.vector());
    context.write(key, value);
  }

  readModel.stop();
  writeModel.stop();
}
}
/**
 * Reads the CVB0 job configuration, builds the read model, write model and model trainer, and
 * starts the trainer.
 *
 * <p>The read model is loaded from the configured model paths when any exist. Otherwise it is
 * initialised from a random generator seeded with {@code RANDOM_SEED} (default 1234). A separate
 * write model is created only when the model weight is exactly 1; for any other weight the read
 * model also serves as the write model.
 *
 * @throws IOException propagated from model loading, if any
 * @throws InterruptedException declared by the Hadoop {@code setup} contract
 */
@Override
protected void setup(Context context) throws IOException, InterruptedException {
  log.info("Retrieving configuration");
  Configuration conf = context.getConfiguration();
  float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN);
  float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN);
  long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L);
  numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1);
  int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1);
  int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1);
  int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4);
  maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10);
  float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f);

  log.info("Initializing read model");
  Path[] modelPaths = CVB0Driver.getModelPaths(conf);
  if (modelPaths != null && modelPaths.length > 0) {
    readModel = new TopicModel(conf, eta, alpha, null, numUpdateThreads, modelWeight, modelPaths);
  } else {
    log.info("No model files found");
    // The argument after the (null) dictionary is the model's own thread count. The other
    // TopicModel constructors in this method pass numUpdateThreads there, so this branch
    // does too; numTrainThreads belongs to the ModelTrainer below.
    readModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null,
        numUpdateThreads, modelWeight);
  }

  log.info("Initializing write model");
  // With weight 1 the updates go to a fresh model; otherwise they are applied to the read model.
  writeModel = modelWeight == 1
      ? new TopicModel(numTopics, numTerms, eta, alpha, null, numUpdateThreads)
      : readModel;

  log.info("Initializing model trainer");
  modelTrainer = new ModelTrainer(readModel, writeModel, numTrainThreads, numTopics, numTerms);
  modelTrainer.start();
}
trainDocuments(testFraction); if (verbose) { log.info("model after: {}: {}", iter, modelTrainer.getReadModel()); oldPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, testFraction); log.info("{} = perplexity", oldPerplexity); trainDocuments(); if (verbose) { log.info("model after: {}: {}", iter, modelTrainer.getReadModel()); newPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, testFraction); log.info("{} = perplexity", newPerplexity);
/**
 * Trains with a single doc-topic iteration per document.
 *
 * <p>This forwards to the three-argument overload with {@code numDocTopicIters == 1}.
 *
 * @param matrix rows treated as documents
 * @param docTopicCounts per-document topic rows paired with {@code matrix}
 */
public void train(VectorIterable matrix, VectorIterable docTopicCounts) {
  final int singleIteration = 1;
  train(matrix, docTopicCounts, singleIteration);
}
/**
 * Reads the CVB0 job configuration and prepares a read-only model trainer plus a reusable topic
 * vector.
 *
 * <p>This also starts the memory logger, with 5000 as its rate argument (presumably
 * milliseconds). The read model is loaded from the configured model paths when any exist;
 * otherwise it is initialised from a random generator seeded with {@code RANDOM_SEED}. The
 * trainer is built with a {@code null} write model and is not started here.
 *
 * <p>NOTE(review): nothing in this method stops the memory logger; confirm that the matching
 * cleanup does.
 *
 * @throws IOException propagated from model loading, if any
 * @throws InterruptedException declared by the Hadoop {@code setup} contract
 */
@Override
protected void setup(Context context) throws IOException, InterruptedException {
  MemoryUtil.startMemoryLogger(5000);

  log.info("Retrieving configuration");
  Configuration conf = context.getConfiguration();
  float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN);
  float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN);
  long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L);
  random = RandomUtils.getRandom(seed);
  numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1);
  int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1);
  int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1);
  int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4);
  maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10);
  float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f);
  testFraction = conf.getFloat(CVB0Driver.TEST_SET_FRACTION, 0.1f);

  log.info("Initializing read model");
  Path[] modelPaths = CVB0Driver.getModelPaths(conf);
  if (modelPaths != null && modelPaths.length > 0) {
    readModel = new TopicModel(conf, eta, alpha, null, numUpdateThreads, modelWeight, modelPaths);
  } else {
    log.info("No model files found");
    // Pass the model's own thread count (numUpdateThreads), as the model-path constructor above
    // does; numTrainThreads is meant for the ModelTrainer below.
    readModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null,
        numUpdateThreads, modelWeight);
  }

  log.info("Initializing model trainer");
  modelTrainer = new ModelTrainer(readModel, null, numTrainThreads, numTopics, numTerms);

  log.info("Initializing topic vector");
  topicVector = new DenseVector(new double[numTopics]);
}
/**
 * Stops the model trainer returned by {@link #getModelTrainer()} once the task is done.
 *
 * @param context the Hadoop task context (unused)
 */
@Override
protected void cleanup(Context context) {
  this.getModelTrainer().stop();
}
}
/**
 * Infers one document's topic distribution and emits it under the same document id.
 *
 * <p>The distribution starts uniform over {@code getNumTopics()} topics. The read model's
 * {@code trainDocTopicModel} is then applied {@code getMaxIters()} times, using a scratch sparse
 * matrix sized topics x document length. The result is written through the reused
 * {@code topics} writable.
 *
 * @throws IOException propagated from {@code context.write}
 * @throws InterruptedException propagated from {@code context.write}
 */
@Override
public void map(IntWritable docId, VectorWritable doc, Context context)
    throws IOException, InterruptedException {
  int topicCount = getNumTopics();
  Vector docTopics = new DenseVector(topicCount).assign(1.0 / topicCount);
  Vector document = doc.get();
  Matrix scratchModel = new SparseRowMatrix(topicCount, document.size());

  int iterations = getMaxIters();
  ModelTrainer trainer = getModelTrainer();
  for (int remaining = iterations; remaining > 0; remaining--) {
    trainer.getReadModel().trainDocTopicModel(document, docTopics, scratchModel);
  }

  topics.set(docTopics);
  context.write(docId, topics);
}