/**
 * Builds a CRF over the pipe of {@code training} with order-N states and a
 * single permitted start state.
 *
 * <p>Every state's initial weight is first set to
 * {@link Transducer#IMPOSSIBLE_WEIGHT}; the state whose name is returned by
 * {@code addOrderNStates} then gets an initial weight of {@code 0.0}.
 * Finally the weights dimension is set densely.
 *
 * @param training     instances whose pipe and labels define the model
 * @param orders       label orders handed to {@code addOrderNStates}
 * @param defaultLabel label handed to {@code addOrderNStates} as the start label
 * @param forbidden    regex for forbidden transitions, or {@code null} for none
 * @param allowed      regex for allowed transitions, or {@code null} for no restriction
 * @param connected    passed through to {@code addOrderNStates}
 * @return the configured CRF
 */
public static CRF getCRF(InstanceList training, int[] orders, String defaultLabel, String forbidden, String allowed, boolean connected) {
    // Pattern.compile(null) throws NullPointerException, so a missing regex is
    // forwarded as a null Pattern instead.
    // NOTE(review): assumes addOrderNStates treats a null Pattern as "no constraint" - confirm.
    Pattern forbiddenPat = (forbidden == null) ? null : Pattern.compile(forbidden);
    Pattern allowedPat = (allowed == null) ? null : Pattern.compile(allowed);

    CRF crf = new CRF(training.getPipe(), (Pipe) null);
    String startName = crf.addOrderNStates(training, orders, null, defaultLabel, forbiddenPat, allowedPat, connected);

    // Rule out every state as a starting point, then re-enable only the start state.
    for (int i = 0; i < crf.numStates(); i++) {
        crf.getState(i).setInitialWeight(Transducer.IMPOSSIBLE_WEIGHT);
    }
    crf.getState(startName).setInitialWeight(0.0);

    crf.setWeightsDimensionDensely();
    return crf;
}
/**
 * Copy constructor: builds a CRF over the same input and output pipes as
 * {@code other}, then copies its states and weights.
 *
 * <p>This assumes that {@code other} has a non-null input pipe and output
 * pipe; a separate constructor would be needed to support the other case.
 *
 * @param other the CRF to copy states and weights from
 */
public CRF(CRF other) {
    this(other.getInputPipe(), other.getOutputPipe());
    copyStatesAndWeightsFrom(other);
    assertWeightsLength();
}
/**
 * Returns the output alphabet of the wrapped CRF, cast to a label alphabet.
 *
 * @return the CRF's output alphabet as a {@code LabelAlphabet}
 */
public LabelAlphabet getTargetAlphabet() {
    final LabelAlphabet targets = (LabelAlphabet) crf.getOutputAlphabet();
    return targets;
}
/**
 * Replaces {@code this.crf} with a fresh first-order CRF built from
 * {@code examples}: order-1 states, a start state, and a weights dimension
 * derived from the examples. When {@code crfFrom} is set, applicable
 * parameters are then initialized from it.
 *
 * @param examples instances whose pipe and contents shape the new CRF
 */
private void initializeFor(InstanceList examples) {
    final int[] firstOrder = {1};

    CRF model = new CRF(examples.getPipe(), null);
    this.crf = model;

    model.addOrderNStates(examples, firstOrder, null, null, null, null, false);
    model.addStartState();
    model.setWeightsDimensionAsIn(examples, false);

    if (crfFrom == null) {
        return;
    }
    model.initializeApplicableParametersFrom(crfFrom);
}
CRF crf = new CRF(inputAlphabet, outputAlphabet); crf.addFullyConnectedStates(stateNames); crf.setWeightsDimensionDensely(); crf.getState(0).setInitialWeight(1.0); crf.getState(1).setInitialWeight(Transducer.IMPOSSIBLE_WEIGHT); crf.getState(0).setFinalWeight(0.0); crf.getState(1).setFinalWeight(0.0); crf.setParameter(0, 0, 0, Transducer.IMPOSSIBLE_WEIGHT); // state0 crf.setParameter(0, 1, 0, 1.0); // state0->state1 crf.setParameter(1, 1, 0, 1.0); // state1 self-transition crf.setParameter(1, 0, 0, Transducer.IMPOSSIBLE_WEIGHT); // state1->state0 new FeatureVector((Alphabet) crf.getInputAlphabet(), new double[] { 1 }), new FeatureVector((Alphabet) crf.getInputAlphabet(), new double[] { 1 }), new FeatureVector((Alphabet) crf.getInputAlphabet(), new double[] { 1 }), }); assertTrue(viterbiPath.get(0) == crf.getState(0)); assertTrue(viterbiPath.get(1) == crf.getState(1)); assertTrue(viterbiPath.get(2) == crf.getState(1));
private TransducerTrainer trainOnce(Pipe pipe, InstanceList examples) { Stopwatch watch = Stopwatch.createStarted(); CRF crf = new CRF(pipe, null); crf.addOrderNStates(examples, new int[]{1}, null, null, null, null, false); crf.addStartState(); crf.setWeightsDimensionAsIn(examples, true); if (initFrom != null) { crf.initializeApplicableParametersFrom(initFrom); } log.info("Starting syllchain training..."); CRFTrainerByThreadedLabelLikelihood trainer = new CRFTrainerByThreadedLabelLikelihood(crf, 8); trainer.setGaussianPriorVariance(2); trainer.setAddNoFactors(true); // trainer.setUseSomeUnsupportedTrick(true); trainer.train(examples); trainer.shutdown(); watch.stop(); log.info("SyllChain CRF Training took " + watch.toString()); crf.getInputAlphabet().stopGrowth(); crf.getOutputAlphabet().stopGrowth(); return trainer; }
CRF crf = new CRF (p, null); if (labelGramOption.value == 1) crf.addStatesForLabelsConnectedAsIn (trainingData); else if (labelGramOption.value == 2) crf.addStatesForBiLabelsConnectedAsIn (trainingData); crft.setGaussianPriorVariance (gaussianVarianceOption.value); for (int i = 0; i < crf.numStates(); i++) { Transducer.State s = crf.getState (i); if (s.getName().charAt(0) == 'I') s.setInitialWeight (Double.POSITIVE_INFINITY);
private TransducerTrainer trainOnce(Pipe pipe, InstanceList trainData) { Stopwatch watch = Stopwatch.createStarted(); CRF crf = new CRF(pipe, null); crf.addOrderNStates(trainData, new int[]{1}, null, null, null, null, false); crf.addStartState(); log.info("Starting alignTag training..."); CRFTrainerByThreadedLabelLikelihood trainer = new CRFTrainerByThreadedLabelLikelihood(crf, 8); trainer.setGaussianPriorVariance(2); // trainer.setUseSomeUnsupportedTrick(false); trainer.train(trainData); trainer.shutdown(); watch.stop(); log.info("Align Tag CRF Training took " + watch.toString()); crf.getInputAlphabet().stopGrowth(); crf.getOutputAlphabet().stopGrowth(); return trainer; }
CRF crf = new CRF(p, null); crf.addFullyConnectedStatesForLabels(); CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf); crft.trainIncremental(training); CRF.State state = crf.getState("notstart"); int widx = crf.getWeightsIndex("BadBad"); int numFeatures = crf.getInputAlphabet().size(); SparseVector w = new SparseVector(new double[numFeatures]); w.setAll(Double.NEGATIVE_INFINITY); crf.setWeights(widx, w);
/**
 * Checks the initial objective value of two CRFs built over the {@code toy}
 * data: one with label-connected states plus an explicit start state, and
 * one with order-1 states using "A" as the start label.
 */
public void testStartState() {
    Pipe[] stages = new Pipe[] {
        new LineGroupString2TokenSequence(),
        new TokenSequenceMatchDataAndTarget(Pattern.compile("^(\\S+) (.*)"), 2, 1),
        new TokenSequenceParseFeatureString(false),
        new TokenText(),
        new TokenSequence2FeatureVectorSequence(true, false),
        new Target2LabelSequence(),
        new PrintInputAndTarget(),
    };
    Pipe pipeline = new SerialPipes(stages);

    InstanceList instances = new InstanceList(pipeline);
    LineGroupIterator lineGroups =
        new LineGroupIterator(new StringReader(toy), Pattern.compile("\n"), true);
    instances.addThruPipe(lineGroups);

    // Model 1: states connected as in the data, plus an added start state.
    CRF labelCrf = new CRF(pipeline, null);
    labelCrf.print();
    labelCrf.addStatesForLabelsConnectedAsIn(instances);
    labelCrf.addStartState();
    CRFTrainerByLabelLikelihood labelTrainer = new CRFTrainerByLabelLikelihood(labelCrf);
    Optimizable.ByGradientValue labelObjective = labelTrainer.getOptimizableCRF(instances);
    assertEquals(-1.3862, labelObjective.getValue(), 1e-4);

    // Model 2: order-1 states with "A" given as the start label.
    CRF orderOneCrf = new CRF(pipeline, null);
    orderOneCrf.addOrderNStates(instances, new int[] { 1 }, null, "A", null, null, false);
    orderOneCrf.print();
    CRFTrainerByLabelLikelihood orderOneTrainer = new CRFTrainerByLabelLikelihood(orderOneCrf);
    Optimizable.ByGradientValue orderOneObjective = orderOneTrainer.getOptimizableCRF(instances);
    assertEquals(-3.09104245335831, orderOneObjective.getValue(), 1e-4);
}
public void testTrainStochasticGradient() { Pipe p = makeSpacePredictionPipe(); Pipe p2 = new TestCRF2String(); InstanceList instances = new InstanceList(p); instances.addThruPipe(new ArrayIterator(data)); InstanceList[] lists = instances.split(new double[] { .5, .5 }); CRF crf = new CRF(p, p2); crf.addFullyConnectedStatesForLabels(); crf.setWeightsDimensionAsIn(lists[0], false); CRFTrainerByStochasticGradient crft = new CRFTrainerByStochasticGradient( crf, 0.0001); System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy before training = " + crf.averageTokenAccuracy(lists[1])); System.out.println("Training..."); // either fixed learning rate or selected on a sample crft.setLearningRateByLikelihood(lists[0]); // crft.setLearningRate(0.01); crft.train(lists[0], 100); crf.print(); System.out.println("Training Accuracy after training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy after training = " + crf.averageTokenAccuracy(lists[1])); }
/**
 * Loads a serialized CRF from the named resource and prepares it for
 * prediction: alphabet growth is stopped, target processing on the input
 * pipe is turned off, and a Viterbi confidence estimator is created.
 *
 * @param trainingdata name of the resource holding the serialized CRF
 * @throws IllegalStateException if the resource cannot be read or does not
 *         deserialize; the underlying exception is attached as the cause
 */
private void setupClassifier(String trainingdata) {
    try {
        crf_input = new ObjectInputStream(ResourceUtils.loadResource(
                trainingdata, this.getClass()));
        try {
            crf = (CRF) crf_input.readObject();
        } finally {
            // Close the stream even when deserialization fails.
            crf_input.close();
        }
    } catch (IOException e) {
        // Previously the stack trace was printed and execution continued with a
        // null (or stale) model; fail here with the real cause instead.
        throw new IllegalStateException(
                "Could not read CRF model from resource '" + trainingdata + "'", e);
    } catch (ClassNotFoundException e) {
        throw new IllegalStateException(
                "Could not deserialize CRF model from resource '" + trainingdata + "'", e);
    }

    crf.getInputAlphabet().stopGrowth();
    crf.getOutputAlphabet().stopGrowth();
    crf_pipe = crf.getInputPipe();
    crf_pipe.setTargetProcessing(false);
    crf_estimator = new ViterbiConfidenceEstimator(crf);
}
public void testDenseFeatureSelection() { Pipe p = makeSpacePredictionPipe(); InstanceList instances = new InstanceList(p); instances.addThruPipe(new ArrayIterator(data)); // Test that dense observations wights aren't added for // "default-feature" edges. CRF crf1 = new CRF(p, null); crf1.addOrderNStates(instances, new int[] { 0 }, null, "start", null, null, true); CRFTrainerByLabelLikelihood crft1 = new CRFTrainerByLabelLikelihood( crf1); crft1.setUseSparseWeights(false); crft1.train(instances, 1); // Set weights dimension int nParams1 = crft1.getOptimizableCRF(instances).getNumParameters(); CRF crf2 = new CRF(p, null); crf2.addOrderNStates(instances, new int[] { 0, 1 }, new boolean[] { false, true }, "start", null, null, true); CRFTrainerByLabelLikelihood crft2 = new CRFTrainerByLabelLikelihood( crf2); crft2.setUseSparseWeights(false); crft2.train(instances, 1); // Set weights dimension int nParams2 = crft2.getOptimizableCRF(instances).getNumParameters(); assertEquals(nParams2, nParams1 + 4); }
testingInstances.addThruPipe(new LineGroupIterator(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(testingFilename)))), Pattern.compile("^\\s*$"), true)); CRF crf = new CRF(pipe, null); crf.addStatesForThreeQuarterLabelsConnectedAsIn(trainingInstances); crf.addStartState();
/** * do the training * * @param instList * @param myPipe */ void train(final InstanceList instList, final Pipe myPipe) { final long s1 = System.currentTimeMillis(); // set up model model = new CRF(myPipe, null); model.addStatesForLabelsConnectedAsIn(instList); // get trainer final CRFTrainerByLabelLikelihood crfTrainer = new CRFTrainerByLabelLikelihood( model); // do the training with unlimited amount of iterations // --> refrained from using modified version of mallet; // it's now the original source final boolean b = crfTrainer.train(instList); LOGGER.info("Tokenizer training: model converged: " + b); final long s2 = System.currentTimeMillis(); // stop growth and set trained model.getInputPipe().getDataAlphabet().stopGrowth(); trained = true; LOGGER.debug("train() - training time: " + ((s2 - s1) / 1000) + " sec"); }
/**
 * Smoke test for {@code CRF.print()}: builds a CRF from the single string
 * "ABCDE", sets parameter i to the value i, and prints the model.
 * Makes no assertions.
 */
public void testPrint() {
    Pipe[] stages = {
        new CharSequence2TokenSequence("."),
        new TokenText(),
        new TestCRFTokenSequenceRemoveSpaces(),
        new TokenSequence2FeatureVectorSequence(),
        new PrintInputAndTarget(),
    };
    Pipe pipeline = new SerialPipes(stages);

    InstanceList single = new InstanceList(pipeline);
    String[] inputs = { "ABCDE" };
    single.addThruPipe(new ArrayIterator(inputs));

    CRF crf = new CRF(pipeline, null);
    crf.addFullyConnectedStatesForThreeQuarterLabels(single);
    CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
    crf.setWeightsDimensionAsIn(single, false);

    Optimizable objective = trainer.getOptimizableCRF(single);
    double[] values = new double[objective.getNumParameters()];
    int idx = 0;
    while (idx < values.length) {
        values[idx] = idx;
        idx++;
    }
    objective.setParameters(values);

    crf.print();
}
/** * do the prediction * * @param an * instance for prediction * @return an ArrayList of Unit objects containing the predicted label */ @SuppressWarnings("unchecked") ArrayList<Unit> predict(final Instance inst) { if ((trained == false) || (model == null)) throw new IllegalStateException( "No model available. Train or load trained model first."); final ArrayList<Unit> units = (ArrayList<Unit>) inst.getName(); if (units.size() > 0) { // get sequence final Sequence<?> input = (Sequence<?>) inst.getData(); // transduce and generate output final Sequence<?> crfOutput = model.transduce(input); for (int j = 0; j < crfOutput.size(); j++) units.get(j).label = (String) crfOutput.get(j); } return units; }
InstanceList[] lists = instances.split(new Random(1), new double[] { .5, .5 }); CRF crf = new CRF(p, p2); crf.addFullyConnectedStatesForLabels(); CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf); if (testValueAndGradient) { + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy before training = " + crf.averageTokenAccuracy(lists[1])); System.out.println("Training..."); crft.trainIncremental(lists[0]); System.out.println("Training Accuracy after training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy after training = " + crf.averageTokenAccuracy(lists[1])); System.out.println("Training results:"); for (int i = 0; i < lists[0].size(); i++) { Instance inst = lists[0].get(i); Sequence input = (Sequence) inst.getData(); Sequence output = crf.transduce(input); System.out.println(output); Instance inst = lists[1].get(i); Sequence input = (Sequence) inst.getData(); Sequence output = crf.transduce(input); System.out.println(output);
/**
 * Trains a fully connected CRF on the first TestCRF data string, extracts
 * from the second, and writes the extraction to {@code outputDir} through
 * the document viewer. Makes no assertions.
 *
 * @throws IOException declared by the original signature; not thrown
 *         explicitly in this method
 */
public void testSpaceViewer() throws IOException {
    Pipe pipe = TestMEMM.makeSpacePredictionPipe();
    String[] trainText = { TestCRF.data[0] };
    String[] testText = { TestCRF.data[1] };

    InstanceList training = new InstanceList(pipe);
    training.addThruPipe(new ArrayIterator(trainText));

    // Not read again below, but the test data is still run through the pipe here.
    InstanceList testing = new InstanceList(pipe);
    testing.addThruPipe(new ArrayIterator(testText));

    CRF crf = new CRF(pipe, null);
    crf.addFullyConnectedStatesForLabels();
    CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
    trainer.trainIncremental(training);

    CRFExtractor extractor = TestLatticeViewer.hackCrfExtor(crf);
    Extraction extraction = extractor.extract(new ArrayIterator(testText));

    if (!outputDir.exists()) {
        outputDir.mkdir();
    }
    DocumentViewer.writeExtraction(outputDir, extraction);
}
/**
 * Creates an extractor around an existing CRF.
 *
 * @param crf           the CRF to use; its input pipe becomes the feature pipe
 * @param tokpipe       stored as the tokenization pipe
 * @param filter        stored as the tokenization filter
 * @param backgroundTag stored as the background tag
 */
public CRFExtractor(CRF crf, Pipe tokpipe, TokenizationFilter filter, String backgroundTag) {
    this.crf = crf;
    this.filter = filter;
    this.backgroundTag = backgroundTag;
    this.tokenizationPipe = tokpipe;
    this.featurePipe = (Pipe) crf.getInputPipe();
}