/**
 * Builds the objective function from the data held by the given indexer.
 *
 * <p>Per-context feature values are read from the indexer only when it is an
 * {@code OnePassRealValueDataIndexer}; for any other indexer type {@code values}
 * is set to {@code null}.
 *
 * @param indexer supplies the contexts, outcome list, event counts, and the
 *                outcome/predicate labels used to size the internal buffers
 */
public NegLogLikelihood(DataIndexer indexer) {
  // Only the real-value indexer is asked for feature values.
  this.values = (indexer instanceof OnePassRealValueDataIndexer) ? indexer.getValues() : null;

  // Indexed training data.
  this.contexts = indexer.getContexts();
  this.outcomeList = indexer.getOutcomeList();
  this.numTimesEventsSeen = indexer.getNumTimesEventsSeen();

  // Problem sizes; the parameter vector has one slot per (outcome, feature) pair.
  this.numOutcomes = indexer.getOutcomeLabels().length;
  this.numFeatures = indexer.getPredLabels().length;
  this.numContexts = this.contexts.length;
  this.dimension = numOutcomes * numFeatures;

  // Work buffers sized from the counts above.
  this.expectation = new double[numOutcomes];
  this.tempSums = new double[numOutcomes];
  this.gradient = new double[dimension];
}
.build(); indexer.index(eventStream); Assert.assertEquals(3, indexer.getContexts().length); Assert.assertArrayEquals(new int[]{0}, indexer.getContexts()[0]); Assert.assertArrayEquals(new int[]{0}, indexer.getContexts()[1]); Assert.assertArrayEquals(new int[]{0}, indexer.getContexts()[2]); Assert.assertEquals(3, indexer.getValues().length); Assert.assertNull(indexer.getValues()[0]); Assert.assertNull(indexer.getValues()[1]); Assert.assertNull(indexer.getValues()[2]); Assert.assertEquals(5, indexer.getNumEvents()); Assert.assertArrayEquals(new int[]{0, 1, 2}, indexer.getOutcomeList()); Assert.assertArrayEquals(new int[]{3, 1, 1}, indexer.getNumTimesEventsSeen()); Assert.assertArrayEquals(new String[]{"ppo=other"}, indexer.getPredLabels()); Assert.assertArrayEquals(new String[]{"other", "org-start", "org-cont"}, indexer.getOutcomeLabels()); Assert.assertArrayEquals(new int[]{5}, indexer.getPredCounts());
trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false); DataIndexer di = new OnePassDataIndexer(); di.init(trainingParameters,reportMap); di.index(new SequenceStreamEventStream(sequenceStream)); numSequences = 0; outcomeList = di.getOutcomeList(); predLabels = di.getPredLabels(); pmap = new HashMap<>(); numEvents = di.getNumEvents(); outcomeLabels = di.getOutcomeLabels(); omap = new HashMap<>(); for (int oli = 0; oli < outcomeLabels.length; oli++) { omap.put(outcomeLabels[oli], oli); outcomeList = di.getOutcomeList();
/**
 * Compares an expected vector against an actual vector after the actual vector
 * has been passed through {@code alignDoubleArrayForTestData} together with the
 * indexer's predicate and outcome labels.
 *
 * <p>Two elements agree when they are the same value according to
 * {@link Double#compare(double, double)} (so matching NaN or matching infinities
 * agree), or when their absolute difference is at most {@code tolerance}. A NaN
 * on only one side, or any other difference that evaluates to NaN, is a mismatch.
 *
 * @param expected  the reference values
 * @param actual    the values to check, before alignment
 * @param indexer   provides the predicate and outcome labels used for alignment
 * @param tolerance the largest absolute difference that still counts as equal
 * @return {@code true} if the lengths match and every element pair agrees
 */
private boolean compareDoubleArray(double[] expected, double[] actual,
    DataIndexer indexer, double tolerance) {
  double[] alignedActual = alignDoubleArrayForTestData(
      actual, indexer.getPredLabels(), indexer.getOutcomeLabels());

  if (expected.length != alignedActual.length) {
    return false;
  }

  for (int i = 0; i < alignedActual.length; i++) {
    // Same value on both sides (including NaN/NaN and inf/inf) always agrees.
    if (Double.compare(expected[i], alignedActual[i]) == 0) {
      continue;
    }
    // Written as !(diff <= tolerance) rather than (diff > tolerance): every
    // comparison against NaN is false, so the latter form would let a NaN
    // difference slip through as "equal".
    if (!(Math.abs(alignedActual[i] - expected[i]) <= tolerance)) {
      return false;
    }
  }
  return true;
}
}
/**
 * Trains a GIS model on the events in the specified event stream, using the
 * specified number of iterations and the specified count cutoff.
 *
 * <p>The events are indexed with a freshly created {@code OnePassDataIndexer}
 * whose parameters carry the cutoff and iteration count; training is then
 * delegated to {@code trainModel(iterations, indexer)}.
 *
 * @param eventStream A stream of all events.
 * @param iterations  The number of iterations to use for GIS.
 * @param cutoff      The number of times a feature must occur to be included.
 * @return The GIS model returned by the delegated training call.
 * @throws IOException propagated from indexing the stream or from the delegated
 *                     training call
 */
public GISModel trainModel(ObjectStream<Event> eventStream, int iterations, int cutoff)
    throws IOException {
  // Parameters handed to the indexer.
  TrainingParameters indexingParameters = new TrainingParameters();
  indexingParameters.put(GISTrainer.CUTOFF_PARAM, cutoff);
  indexingParameters.put(GISTrainer.ITERATIONS_PARAM, iterations);

  // The report map is local and not read back after indexing.
  Map<String, String> reportMap = new HashMap<>();

  DataIndexer onePassIndexer = new OnePassDataIndexer();
  onePassIndexer.init(indexingParameters, reportMap);
  onePassIndexer.index(eventStream);

  return trainModel(iterations, onePassIndexer);
}
public DataIndexer getDataIndexer(ObjectStream<Event> events) throws IOException { trainingParameters.put(AbstractDataIndexer.SORT_PARAM, isSortAndMerge()); // If the cutoff was set, don't overwrite the value. if (trainingParameters.getIntParameter(CUTOFF_PARAM, -1) == -1) { trainingParameters.put(CUTOFF_PARAM, 5); } DataIndexer indexer = DataIndexerFactory.getDataIndexer(trainingParameters, reportMap); indexer.index(events); return indexer; }
@Test public void testDomainDimensionSanity() throws IOException { // given RealValueFileEventStream rvfes1 = new RealValueFileEventStream( "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt", "UTF-8"); testDataIndexer.index(rvfes1); NegLogLikelihood objectFunction = new NegLogLikelihood(testDataIndexer); // when int correctDomainDimension = testDataIndexer.getPredLabels().length * testDataIndexer.getOutcomeLabels().length; // then Assert.assertEquals(correctDomainDimension, objectFunction.getDimension()); }
DataIndexer di = DataIndexerFactory.getDataIndexer(parameters, myReportMap); Assert.assertEquals("opennlp.tools.ml.model.OnePassDataIndexer", di.getClass().getName()); di.index(eventStream); Assert.assertEquals(3, di.getNumEvents()); Assert.assertEquals(2, di.getOutcomeLabels().length); Assert.assertEquals(6, di.getPredLabels().length); di = DataIndexerFactory.getDataIndexer(parameters, myReportMap); Assert.assertEquals("opennlp.tools.ml.model.TwoPassDataIndexer", di.getClass().getName()); di.index(eventStream); Assert.assertEquals(3, di.getNumEvents()); Assert.assertEquals(2, di.getOutcomeLabels().length); Assert.assertEquals(6, di.getPredLabels().length);
/**
 * Indexes the events of a single name sample whose token array contains a
 * newline token, and checks that the indexer reports five contexts.
 */
@Test
public void testIndexWithNewline() throws IOException {
  // Splitting on single spaces leaves "\n" as a token of its own.
  String[] tokens = "He belongs to Apache \n Software Foundation .".split(" ");
  Span[] nameSpans = new Span[] { new Span(3, 7) };

  NameContextGenerator contextGenerator = new DefaultNameContextGenerator(
      (AdaptiveFeatureGenerator[]) null);
  NameSample sample = new NameSample(tokens, nameSpans, false);

  ObjectStream<Event> events = new NameFinderEventStream(
      ObjectStreamUtils.createObjectStream(sample), "org", contextGenerator, null);

  DataIndexer twoPassIndexer = new TwoPassDataIndexer();
  twoPassIndexer.init(new TrainingParameters(Collections.emptyMap()), null);
  twoPassIndexer.index(events);

  Assert.assertEquals(5, twoPassIndexer.getContexts().length);
}
}
indexer.init(parameters, reportMap);
Assert.assertEquals(3, di.getNumEvents()); Assert.assertEquals(2, di.getOutcomeLabels().length); Assert.assertEquals(6, di.getPredLabels().length);
/**
 * Indexes the "ok" data file and then the "broken" data file with the shared
 * indexer, asserting after each one that exactly one outcome label is reported.
 */
@Test
public void testLastLineBug() throws IOException {
  // Processed in this order; the same indexer instance is reused for both.
  String[] dataFiles = {
      "src/test/resources/data/opennlp/maxent/io/rvfes-bug-data-ok.txt",
      "src/test/resources/data/opennlp/maxent/io/rvfes-bug-data-broken.txt"
  };

  for (String dataFile : dataFiles) {
    try (RealValueFileEventStream rvfes = new RealValueFileEventStream(dataFile)) {
      indexer.index(rvfes);
    }
    Assert.assertEquals(1, indexer.getOutcomeLabels().length);
  }
}
}
/**
 * Validates this trainer, checks that the indexed data has more than one
 * outcome, delegates to {@code doTrain}, and records the trainer type in the
 * report.
 *
 * @param indexer the indexed training data
 * @return the model returned by {@code doTrain(indexer)}
 * @throws InsufficientTrainingDataException if the indexer reports one outcome
 *         label or none
 * @throws IOException propagated from the delegated training call
 */
public final MaxentModel train(DataIndexer indexer) throws IOException {
  validate();

  int outcomeCount = indexer.getOutcomeLabels().length;
  if (outcomeCount < 2) {
    throw new InsufficientTrainingDataException("Training data must contain more than one outcome");
  }

  MaxentModel trained = doTrain(indexer);
  // The report entry is written only after training has completed.
  addToReport(AbstractTrainer.TRAINER_TYPE_PARAM, EventTrainer.EVENT_VALUE);
  return trained;
}
/**
 * Round-trips a trained Naive Bayes model through the plain-text writer and
 * reader and asserts that writing the re-read model produces the same text as
 * writing the original.
 */
@Test
public void testPlainTextModel() throws IOException {
  testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream());
  NaiveBayesModel trainedModel =
      (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);

  // First pass: write the freshly trained model.
  StringWriter firstDump = new StringWriter();
  NaiveBayesModelWriter firstWriter =
      new PlainTextNaiveBayesModelWriter(trainedModel, new BufferedWriter(firstDump));
  firstWriter.persist();

  // Read the written text back into a second model.
  NaiveBayesModelReader reader = new PlainTextNaiveBayesModelReader(
      new BufferedReader(new StringReader(firstDump.toString())));
  reader.checkModelType();
  NaiveBayesModel reloadedModel = (NaiveBayesModel) reader.constructModel();

  // Second pass: write the re-read model.
  StringWriter secondDump = new StringWriter();
  NaiveBayesModelWriter secondWriter =
      new PlainTextNaiveBayesModelWriter(reloadedModel, new BufferedWriter(secondDump));
  secondWriter.persist();

  System.out.println(firstDump.toString());
  Assert.assertEquals(firstDump.toString(), secondDump.toString());
}
@Test public void testValueAtNonInitialPoint02() throws IOException { // given RealValueFileEventStream rvfes1 = new RealValueFileEventStream( "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt", "UTF-8"); testDataIndexer.index(rvfes1); NegLogLikelihood objectFunction = new NegLogLikelihood(testDataIndexer); // when double[] nonInitialPoint = new double[] { 3, 2, 3, 2, 3, 2, 3, 2, 3, 2 }; double value = objectFunction.valueAt(dealignDoubleArrayForTestData(nonInitialPoint, testDataIndexer.getPredLabels(), testDataIndexer.getOutcomeLabels())); double expectedValue = 53.163219721099026; // then Assert.assertEquals(expectedValue, value, TOLERANCE02); }
/**
 * Trains a QN model on the prep-attach training stream, indexed by a
 * {@code TwoPassDataIndexer} with cutoff 1 and sorting disabled, and hands the
 * model to {@code PrepAttachDataUtil.testModel} with a fixed reference value
 * (presumably the expected accuracy; confirm against that helper).
 */
@Test
public void testQNOnPrepAttachData() throws IOException {
  TrainingParameters indexingParameters = new TrainingParameters();
  indexingParameters.put(AbstractTrainer.CUTOFF_PARAM, 1);
  indexingParameters.put(AbstractDataIndexer.SORT_PARAM, false);

  DataIndexer twoPassIndexer = new TwoPassDataIndexer();
  twoPassIndexer.init(indexingParameters, new HashMap<>());
  twoPassIndexer.index(PrepAttachDataUtil.createTrainingStream());

  QNTrainer trainer = new QNTrainer(true);
  AbstractModel model = trainer.trainModel(100, twoPassIndexer);

  PrepAttachDataUtil.testModel(model, 0.8155484030700668);
}
String[] predLabels = indexer.getPredLabels(); int nPredLabels = predLabels.length; String[] outcomeNames = indexer.getOutcomeLabels(); int nOutcomes = outcomeNames.length;
/**
 * Replaces the shared {@code indexer} with a new
 * {@code OnePassRealValueDataIndexer} initialized from an empty parameter map
 * and a {@code null} report map.
 */
@Before
public void setUp() throws Exception {
  TrainingParameters emptyParameters = new TrainingParameters(Collections.emptyMap());

  indexer = new OnePassRealValueDataIndexer();
  indexer.init(emptyParameters, null);
}
/**
 * Runs event-based training on already indexed data.
 *
 * <p>Steps, in order: {@code validate()}, a check that the indexer reports at
 * least two outcome labels, {@code doTrain(indexer)}, and finally a report
 * entry recording the trainer type.
 *
 * @param indexer the indexed training data
 * @return the model produced by {@code doTrain(indexer)}
 * @throws InsufficientTrainingDataException if fewer than two outcome labels
 *         are present
 * @throws IOException propagated from the delegated training call
 */
public final MaxentModel train(DataIndexer indexer) throws IOException {
  validate();

  boolean tooFewOutcomes = indexer.getOutcomeLabels().length <= 1;
  if (tooFewOutcomes) {
    throw new InsufficientTrainingDataException("Training data must contain more than one outcome");
  }

  MaxentModel result = doTrain(indexer);
  addToReport(AbstractTrainer.TRAINER_TYPE_PARAM, EventTrainer.EVENT_VALUE);
  return result;
}
/**
 * Evaluates the model described by {@code parameters} on the indexed training
 * data.
 *
 * <p>For each context the outcome with the highest probability, as computed by
 * {@code QNModel.eval}, is compared with the recorded outcome. Every context is
 * weighted by the number of times its event was seen. If the total weight is
 * zero the floating-point division yields NaN.
 *
 * @param parameters the model parameters passed to {@code QNModel.eval}
 * @return weighted count of correct predictions divided by the total weight
 */
@Override
public double evaluate(double[] parameters) {
  int[][] contexts = indexer.getContexts();
  float[][] values = indexer.getValues();
  int[] eventCounts = indexer.getNumTimesEventsSeen();
  int[] goldOutcomes = indexer.getOutcomeList();
  int outcomeCount = indexer.getOutcomeLabels().length;
  int predicateCount = indexer.getPredLabels().length;

  int correct = 0;
  int total = 0;

  for (int ei = 0; ei < contexts.length; ei++) {
    // Feature values are optional; null is passed through when absent.
    float[] featureValues = (values != null) ? values[ei] : null;

    // A fresh output buffer is handed to eval for every context.
    double[] probs = new double[outcomeCount];
    QNModel.eval(contexts[ei], featureValues, probs, outcomeCount, predicateCount, parameters);

    int predicted = ArrayMath.argmax(probs);
    boolean hit = predicted == goldOutcomes[ei];

    int seen = eventCounts[ei];
    if (hit) {
      correct += seen;
    }
    total += seen;
  }

  return (double) correct / total;
}
}