private static CRFTrainerByLabelLikelihood makeNewTrainerSingleThreaded(CRF crf) { CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf); trainer.setGaussianPriorVariance(2); // trainer.setUseHyperbolicPrior(true); trainer.setAddNoFactors(true); trainer.setUseSomeUnsupportedTrick(false); return trainer; }
/** * do the training * * @param instList * @param myPipe */ void train(final InstanceList instList, final Pipe myPipe) { final long s1 = System.currentTimeMillis(); // set up model model = new CRF(myPipe, null); model.addStatesForLabelsConnectedAsIn(instList); // get trainer final CRFTrainerByLabelLikelihood crfTrainer = new CRFTrainerByLabelLikelihood( model); // do the training with unlimited amount of iterations // --> refrained from using modified version of mallet; // it's now the original source final boolean b = crfTrainer.train(instList); LOGGER.info("Tokenizer training: model converged: " + b); final long s2 = System.currentTimeMillis(); // stop growth and set trained model.getInputPipe().getDataAlphabet().stopGrowth(); trained = true; LOGGER.debug("train() - training time: " + ((s2 - s1) / 1000) + " sec"); }
double getLikelihood(CRF crf, InstanceList data) { CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf); Optimizable.ByGradientValue mcrf = crft.getOptimizableCRF(data); // Do this elaborate thing so that crf.cachedValueStale is forced true double[] params = new double[mcrf.getNumParameters()]; mcrf.getParameters(params); mcrf.setParameters(params); return mcrf.getValue(); }
double getLikelihood(CRF crf, InstanceList data) { CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf); Optimizable.ByGradientValue mcrf = crft.getOptimizableCRF(data); // Do this elaborate thing so that crf.cachedValueStale is forced true double[] params = new double[mcrf.getNumParameters()]; mcrf.getParameters(params); mcrf.setParameters(params); return mcrf.getValue(); }
/**
 * Trains a fully connected CRF on the first {@code TestCRF.data} entry,
 * extracts from the second entry and writes the extraction as HTML to
 * {@code htmlFile} and, with the extra {@code true} flag, to
 * {@code latticeFile}.
 *
 * @throws FileNotFoundException if one of the output files cannot be opened
 */
public void testSpaceViewer() throws FileNotFoundException {
    Pipe pipe = TestMEMM.makeSpacePredictionPipe();
    String[] data0 = { TestCRF.data[0] };
    String[] data1 = { TestCRF.data[1] };

    InstanceList training = new InstanceList(pipe);
    training.addThruPipe(new ArrayIterator(data0));
    // The list itself is not read again, but the call is kept because it
    // runs data1 through the shared pipe before training.
    InstanceList testing = new InstanceList(pipe);
    testing.addThruPipe(new ArrayIterator(data1));

    CRF crf = new CRF(pipe, null);
    crf.addFullyConnectedStatesForLabels();
    CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf);
    crft.trainIncremental(training);

    CRFExtractor extor = hackCrfExtor(crf);
    Extraction extraction = extor.extract(new ArrayIterator(data1));

    // Close each stream in a finally block so the file handle is released
    // even when rendering throws.
    PrintStream htmlOut = new PrintStream(new FileOutputStream(htmlFile));
    try {
        LatticeViewer.extraction2html(extraction, extor, htmlOut);
    } finally {
        htmlOut.close();
    }

    PrintStream latticeOut = new PrintStream(new FileOutputStream(latticeFile));
    try {
        LatticeViewer.extraction2html(extraction, extor, latticeOut, true);
    } finally {
        latticeOut.close();
    }
}
/**
 * Trains a fully connected CRF on the first {@code TestCRF.data} entry,
 * extracts from the second entry and writes the extraction as HTML to
 * {@code htmlFile} and, with the extra {@code true} flag, to
 * {@code latticeFile}.
 *
 * @throws FileNotFoundException if one of the output files cannot be opened
 */
public void testSpaceViewer() throws FileNotFoundException {
    Pipe pipe = TestMEMM.makeSpacePredictionPipe();
    String[] data0 = { TestCRF.data[0] };
    String[] data1 = { TestCRF.data[1] };

    InstanceList training = new InstanceList(pipe);
    training.addThruPipe(new ArrayIterator(data0));
    // The list itself is not read again, but the call is kept because it
    // runs data1 through the shared pipe before training.
    InstanceList testing = new InstanceList(pipe);
    testing.addThruPipe(new ArrayIterator(data1));

    CRF crf = new CRF(pipe, null);
    crf.addFullyConnectedStatesForLabels();
    CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf);
    crft.trainIncremental(training);

    CRFExtractor extor = hackCrfExtor(crf);
    Extraction extraction = extor.extract(new ArrayIterator(data1));

    // Close each stream in a finally block so the file handle is released
    // even when rendering throws.
    PrintStream htmlOut = new PrintStream(new FileOutputStream(htmlFile));
    try {
        LatticeViewer.extraction2html(extraction, extor, htmlOut);
    } finally {
        htmlOut.close();
    }

    PrintStream latticeOut = new PrintStream(new FileOutputStream(latticeFile));
    try {
        LatticeViewer.extraction2html(extraction, extor, latticeOut, true);
    } finally {
        latticeOut.close();
    }
}
/**
 * Trains a fully connected CRF on the first {@code TestCRF.data} entry,
 * extracts from the second entry and hands the extraction to
 * {@code DocumentViewer.writeExtraction} together with {@code outputDir}.
 *
 * @throws IOException if {@code outputDir} cannot be created or writing the
 *         extraction fails
 */
public void testSpaceViewer() throws IOException {
    Pipe pipe = TestMEMM.makeSpacePredictionPipe();
    String[] data0 = { TestCRF.data[0] };
    String[] data1 = { TestCRF.data[1] };

    InstanceList training = new InstanceList(pipe);
    training.addThruPipe(new ArrayIterator(data0));
    // The list itself is not read again, but the call is kept because it
    // runs data1 through the shared pipe before training.
    InstanceList testing = new InstanceList(pipe);
    testing.addThruPipe(new ArrayIterator(data1));

    CRF crf = new CRF(pipe, null);
    crf.addFullyConnectedStatesForLabels();
    CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf);
    crft.trainIncremental(training);

    CRFExtractor extor = TestLatticeViewer.hackCrfExtor(crf);
    Extraction extraction = extor.extract(new ArrayIterator(data1));

    // Create the directory including missing parents, and fail loudly instead
    // of ignoring the result. The isDirectory() re-check tolerates the
    // directory being created by someone else in the meantime.
    if (!outputDir.exists() && !outputDir.mkdirs() && !outputDir.isDirectory()) {
        throw new IOException("Could not create output directory: " + outputDir);
    }
    DocumentViewer.writeExtraction(outputDir, extraction);
}
/**
 * Trains a fully connected CRF on the first {@code TestCRF.data} entry,
 * extracts from the second entry and hands the extraction to
 * {@code DocumentViewer.writeExtraction} together with {@code outputDir}.
 *
 * @throws IOException if {@code outputDir} cannot be created or writing the
 *         extraction fails
 */
public void testSpaceViewer() throws IOException {
    Pipe pipe = TestMEMM.makeSpacePredictionPipe();
    String[] data0 = { TestCRF.data[0] };
    String[] data1 = { TestCRF.data[1] };

    InstanceList training = new InstanceList(pipe);
    training.addThruPipe(new ArrayIterator(data0));
    // The list itself is not read again, but the call is kept because it
    // runs data1 through the shared pipe before training.
    InstanceList testing = new InstanceList(pipe);
    testing.addThruPipe(new ArrayIterator(data1));

    CRF crf = new CRF(pipe, null);
    crf.addFullyConnectedStatesForLabels();
    CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf);
    crft.trainIncremental(training);

    CRFExtractor extor = TestLatticeViewer.hackCrfExtor(crf);
    Extraction extraction = extor.extract(new ArrayIterator(data1));

    // Create the directory including missing parents, and fail loudly instead
    // of ignoring the result. The isDirectory() re-check tolerates the
    // directory being created by someone else in the meantime.
    if (!outputDir.exists() && !outputDir.mkdirs() && !outputDir.isDirectory()) {
        throw new IOException("Could not create output directory: " + outputDir);
    }
    DocumentViewer.writeExtraction(outputDir, extraction);
}
/**
 * Builds a CRF over 100 synthetic input features and 5 fully connected
 * states, then runs {@code TestOptimizable.testGetSetParameters} on its
 * label-likelihood optimizable.
 */
public void testGetSetParameters() {
    final int inputVocabSize = 100;
    final int numStates = 5;

    // Input alphabet holding "feature0" .. "feature99".
    Alphabet features = new Alphabet();
    for (int f = 0; f < inputVocabSize; f++) {
        features.lookupIndex("feature" + f);
    }
    Alphabet labels = new Alphabet();
    CRF crf = new CRF(features, labels);

    // States "state0" .. "state4", all connected to each other.
    String[] stateNames = new String[numStates];
    for (int s = 0; s < stateNames.length; s++) {
        stateNames[s] = "state" + s;
    }
    crf.addFullyConnectedStates(stateNames);

    Optimizable.ByGradientValue optimizable =
            new CRFTrainerByLabelLikelihood(crf).getOptimizableCRF(new InstanceList(null));
    TestOptimizable.testGetSetParameters(optimizable);
}
/**
 * Builds a CRF over 100 synthetic input features and 5 fully connected
 * states, then runs {@code TestOptimizable.testGetSetParameters} on its
 * label-likelihood optimizable.
 */
public void testGetSetParameters() {
    final int inputVocabSize = 100;
    final int numStates = 5;

    // Input alphabet holding "feature0" .. "feature99".
    Alphabet features = new Alphabet();
    for (int f = 0; f < inputVocabSize; f++) {
        features.lookupIndex("feature" + f);
    }
    Alphabet labels = new Alphabet();
    CRF crf = new CRF(features, labels);

    // States "state0" .. "state4", all connected to each other.
    String[] stateNames = new String[numStates];
    for (int s = 0; s < stateNames.length; s++) {
        stateNames[s] = "state" + s;
    }
    crf.addFullyConnectedStates(stateNames);

    Optimizable.ByGradientValue optimizable =
            new CRFTrainerByLabelLikelihood(crf).getOptimizableCRF(new InstanceList(null));
    TestOptimizable.testGetSetParameters(optimizable);
}
/**
 * Splits the space-prediction data 50/50 with a fixed seed, trains a
 * sparse-weight CRF on the first half and checks that the token accuracy
 * reported for the second half is 0.9409 (within 0.001).
 */
public void testTokenAccuracy() {
    Pipe pipe = makeSpacePredictionPipe();
    InstanceList all = new InstanceList(pipe);
    all.addThruPipe(new ArrayIterator(data));

    // Fixed seed keeps the train/test split reproducible.
    InstanceList[] halves = all.split(new Random(777), new double[] { .5, .5 });
    InstanceList trainHalf = halves[0];
    InstanceList testHalf = halves[1];

    CRF crf = new CRF(pipe.getDataAlphabet(), pipe.getTargetAlphabet());
    crf.addFullyConnectedStatesForLabels();

    CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
    trainer.setUseSparseWeights(true);
    trainer.trainIncremental(trainHalf);

    String[] descriptions = { "Train", "Test" };
    TokenAccuracyEvaluator evaluator = new TokenAccuracyEvaluator(halves, descriptions);
    evaluator.evaluateInstanceList(trainer, testHalf, "Test");

    assertEquals(0.9409, evaluator.getAccuracy("Test"), 0.001);
}
public void testDenseFeatureSelection() { Pipe p = makeSpacePredictionPipe(); InstanceList instances = new InstanceList(p); instances.addThruPipe(new ArrayIterator(data)); // Test that dense observations wights aren't added for // "default-feature" edges. CRF crf1 = new CRF(p, null); crf1.addOrderNStates(instances, new int[] { 0 }, null, "start", null, null, true); CRFTrainerByLabelLikelihood crft1 = new CRFTrainerByLabelLikelihood( crf1); crft1.setUseSparseWeights(false); crft1.train(instances, 1); // Set weights dimension int nParams1 = crft1.getOptimizableCRF(instances).getNumParameters(); CRF crf2 = new CRF(p, null); crf2.addOrderNStates(instances, new int[] { 0, 1 }, new boolean[] { false, true }, "start", null, null, true); CRFTrainerByLabelLikelihood crft2 = new CRFTrainerByLabelLikelihood( crf2); crft2.setUseSparseWeights(false); crft2.train(instances, 1); // Set weights dimension int nParams2 = crft2.getOptimizableCRF(instances).getNumParameters(); assertEquals(nParams2, nParams1 + 4); }
public void testDenseFeatureSelection() { Pipe p = makeSpacePredictionPipe(); InstanceList instances = new InstanceList(p); instances.addThruPipe(new ArrayIterator(data)); // Test that dense observations wights aren't added for // "default-feature" edges. CRF crf1 = new CRF(p, null); crf1.addOrderNStates(instances, new int[] { 0 }, null, "start", null, null, true); CRFTrainerByLabelLikelihood crft1 = new CRFTrainerByLabelLikelihood( crf1); crft1.setUseSparseWeights(false); crft1.train(instances, 1); // Set weights dimension int nParams1 = crft1.getOptimizableCRF(instances).getNumParameters(); CRF crf2 = new CRF(p, null); crf2.addOrderNStates(instances, new int[] { 0, 1 }, new boolean[] { false, true }, "start", null, null, true); CRFTrainerByLabelLikelihood crft2 = new CRFTrainerByLabelLikelihood( crf2); crft2.setUseSparseWeights(false); crft2.train(instances, 1); // Set weights dimension int nParams2 = crft2.getOptimizableCRF(instances).getNumParameters(); assertEquals(nParams2, nParams1 + 4); }
/**
 * Checks the untrained label-likelihood value of two CRFs built over the
 * {@code toy} data: one with label-connected states plus an explicit start
 * state (expected -1.3862) and one with order-1 states starting in "A"
 * (expected -3.09104245335831), both within 1e-4.
 */
public void testStartState() {
    Pipe[] stages = {
        new LineGroupString2TokenSequence(),
        new TokenSequenceMatchDataAndTarget(Pattern.compile("^(\\S+) (.*)"), 2, 1),
        new TokenSequenceParseFeatureString(false),
        new TokenText(),
        new TokenSequence2FeatureVectorSequence(true, false),
        new Target2LabelSequence(),
        new PrintInputAndTarget(),
    };
    Pipe pipe = new SerialPipes(stages);

    InstanceList data = new InstanceList(pipe);
    data.addThruPipe(new LineGroupIterator(new StringReader(toy), Pattern.compile("\n"), true));

    // First model: states connected as in the data, plus a start state.
    CRF crf = new CRF(pipe, null);
    crf.print();
    crf.addStatesForLabelsConnectedAsIn(data);
    crf.addStartState();
    Optimizable.ByGradientValue maxable =
            new CRFTrainerByLabelLikelihood(crf).getOptimizableCRF(data);
    assertEquals(-1.3862, maxable.getValue(), 1e-4);

    // Second model: order-1 states with "A" as the start label.
    crf = new CRF(pipe, null);
    crf.addOrderNStates(data, new int[] { 1 }, null, "A", null, null, false);
    crf.print();
    maxable = new CRFTrainerByLabelLikelihood(crf).getOptimizableCRF(data);
    assertEquals(-3.09104245335831, maxable.getValue(), 1e-4);
}
/**
 * Splits the space-prediction data 50/50 with a fixed seed, trains a
 * sparse-weight CRF on the first half and checks that the token accuracy
 * reported for the second half is 0.9409 (within 0.001).
 */
public void testTokenAccuracy() {
    Pipe pipe = makeSpacePredictionPipe();
    InstanceList all = new InstanceList(pipe);
    all.addThruPipe(new ArrayIterator(data));

    // Fixed seed keeps the train/test split reproducible.
    InstanceList[] halves = all.split(new Random(777), new double[] { .5, .5 });
    InstanceList trainHalf = halves[0];
    InstanceList testHalf = halves[1];

    CRF crf = new CRF(pipe.getDataAlphabet(), pipe.getTargetAlphabet());
    crf.addFullyConnectedStatesForLabels();

    CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
    trainer.setUseSparseWeights(true);
    trainer.trainIncremental(trainHalf);

    String[] descriptions = { "Train", "Test" };
    TokenAccuracyEvaluator evaluator = new TokenAccuracyEvaluator(halves, descriptions);
    evaluator.evaluateInstanceList(trainer, testHalf, "Test");

    assertEquals(0.9409, evaluator.getAccuracy("Test"), 0.001);
}
/**
 * Checks the untrained label-likelihood value of two CRFs built over the
 * {@code toy} data: one with label-connected states plus an explicit start
 * state (expected -1.3862) and one with order-1 states starting in "A"
 * (expected -3.09104245335831), both within 1e-4.
 */
public void testStartState() {
    Pipe[] stages = {
        new LineGroupString2TokenSequence(),
        new TokenSequenceMatchDataAndTarget(Pattern.compile("^(\\S+) (.*)"), 2, 1),
        new TokenSequenceParseFeatureString(false),
        new TokenText(),
        new TokenSequence2FeatureVectorSequence(true, false),
        new Target2LabelSequence(),
        new PrintInputAndTarget(),
    };
    Pipe pipe = new SerialPipes(stages);

    InstanceList data = new InstanceList(pipe);
    data.addThruPipe(new LineGroupIterator(new StringReader(toy), Pattern.compile("\n"), true));

    // First model: states connected as in the data, plus a start state.
    CRF crf = new CRF(pipe, null);
    crf.print();
    crf.addStatesForLabelsConnectedAsIn(data);
    crf.addStartState();
    Optimizable.ByGradientValue maxable =
            new CRFTrainerByLabelLikelihood(crf).getOptimizableCRF(data);
    assertEquals(-1.3862, maxable.getValue(), 1e-4);

    // Second model: order-1 states with "A" as the start label.
    crf = new CRF(pipe, null);
    crf.addOrderNStates(data, new int[] { 1 }, null, "A", null, null, false);
    crf.print();
    maxable = new CRFTrainerByLabelLikelihood(crf).getOptimizableCRF(data);
    assertEquals(-3.09104245335831, maxable.getValue(), 1e-4);
}
/**
 * Builds a small CRF from the single string "ABCDE", sets every parameter
 * to its own index and prints the model. There are no assertions; the test
 * only exercises {@code CRF.print()}.
 */
public void testPrint() {
    Pipe[] stages = {
        new CharSequence2TokenSequence("."),
        new TokenText(),
        new TestCRFTokenSequenceRemoveSpaces(),
        new TokenSequence2FeatureVectorSequence(),
        new PrintInputAndTarget(),
    };
    Pipe pipe = new SerialPipes(stages);

    InstanceList one = new InstanceList(pipe);
    String[] data = { "ABCDE", };
    one.addThruPipe(new ArrayIterator(data));

    CRF crf = new CRF(pipe, null);
    crf.addFullyConnectedStatesForThreeQuarterLabels(one);
    CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
    crf.setWeightsDimensionAsIn(one, false);
    Optimizable optimizable = trainer.getOptimizableCRF(one);

    // Give each parameter a distinct, recognisable value: its own index.
    double[] indexedParams = new double[optimizable.getNumParameters()];
    int next = 0;
    while (next < indexedParams.length) {
        indexedParams[next] = next;
        next++;
    }
    optimizable.setParameters(indexedParams);

    crf.print();
}
/**
 * Builds a small CRF from the single string "ABCDE", sets every parameter
 * to its own index and prints the model. There are no assertions; the test
 * only exercises {@code CRF.print()}.
 */
public void testPrint() {
    Pipe[] stages = {
        new CharSequence2TokenSequence("."),
        new TokenText(),
        new TestCRFTokenSequenceRemoveSpaces(),
        new TokenSequence2FeatureVectorSequence(),
        new PrintInputAndTarget(),
    };
    Pipe pipe = new SerialPipes(stages);

    InstanceList one = new InstanceList(pipe);
    String[] data = { "ABCDE", };
    one.addThruPipe(new ArrayIterator(data));

    CRF crf = new CRF(pipe, null);
    crf.addFullyConnectedStatesForThreeQuarterLabels(one);
    CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
    crf.setWeightsDimensionAsIn(one, false);
    Optimizable optimizable = trainer.getOptimizableCRF(one);

    // Give each parameter a distinct, recognisable value: its own index.
    double[] indexedParams = new double[optimizable.getNumParameters()];
    int next = 0;
    while (next < indexedParams.length) {
        indexedParams[next] = next;
        next++;
    }
    optimizable.setParameters(indexedParams);

    crf.print();
}
public void testXis() { Pipe p = makeSpacePredictionPipe(); InstanceList instances = new InstanceList(p); instances.addThruPipe(new ArrayIterator(data)); CRF crf1 = new CRF(p, null); crf1.addFullyConnectedStatesForLabels(); CRFTrainerByLabelLikelihood crft1 = new CRFTrainerByLabelLikelihood( crf1); crft1.train(instances, 10); // Let's get some parameters Instance inst = instances.get(0); Sequence input = (Sequence) inst.getData(); SumLatticeDefault lattice = new SumLatticeDefault(crf1, input, (Sequence) inst.getTarget(), null, true); for (int ip = 0; ip < lattice.length() - 1; ip++) { for (int i = 0; i < crf1.numStates(); i++) { Transducer.State state = crf1.getState(i); Transducer.TransitionIterator it = state.transitionIterator( input, ip); double gamma = lattice.getGammaProbability(ip, state); double xiSum = 0; while (it.hasNext()) { Transducer.State dest = it.nextState(); double xi = lattice.getXiProbability(ip, state, dest); xiSum += xi; } assertEquals(gamma, xiSum, 1e-5); } } }
public void testXis() { Pipe p = makeSpacePredictionPipe(); InstanceList instances = new InstanceList(p); instances.addThruPipe(new ArrayIterator(data)); CRF crf1 = new CRF(p, null); crf1.addFullyConnectedStatesForLabels(); CRFTrainerByLabelLikelihood crft1 = new CRFTrainerByLabelLikelihood( crf1); crft1.train(instances, 10); // Let's get some parameters Instance inst = instances.get(0); Sequence input = (Sequence) inst.getData(); SumLatticeDefault lattice = new SumLatticeDefault(crf1, input, (Sequence) inst.getTarget(), null, true); for (int ip = 0; ip < lattice.length() - 1; ip++) { for (int i = 0; i < crf1.numStates(); i++) { Transducer.State state = crf1.getState(i); Transducer.TransitionIterator it = state.transitionIterator( input, ip); double gamma = lattice.getGammaProbability(ip, state); double xiSum = 0; while (it.hasNext()) { Transducer.State dest = it.nextState(); double xi = lattice.getXiProbability(ip, state, dest); xiSum += xi; } assertEquals(gamma, xiSum, 1e-5); } } }