/**
 * Convenience overload that forwards {@code trainingData} to the two-argument
 * {@code setWeightsDimensionAsIn} with the {@code useSomeUnsupportedTrick} flag
 * switched off.
 *
 * @param trainingData instance list passed through unchanged to the two-argument overload
 */
public void setWeightsDimensionAsIn(InstanceList trainingData) {
    final boolean useSomeUnsupportedTrick = false;
    setWeightsDimensionAsIn(trainingData, useSomeUnsupportedTrick);
}
/**
 * Convenience overload that forwards {@code trainingData} to the two-argument
 * {@code setWeightsDimensionAsIn} with the {@code useSomeUnsupportedTrick} flag
 * switched off.
 *
 * @param trainingData instance list passed through unchanged to the two-argument overload
 */
public void setWeightsDimensionAsIn(InstanceList trainingData) {
    final boolean useSomeUnsupportedTrick = false;
    setWeightsDimensionAsIn(trainingData, useSomeUnsupportedTrick);
}
/**
 * Two-argument overload that delegates to the three-argument
 * {@code setWeightsDimensionAsIn}, supplying {@code null} for the third argument.
 *
 * @param trainingData            instance list passed through unchanged
 * @param useSomeUnsupportedTrick flag passed through unchanged
 */
public void setWeightsDimensionAsIn(
        InstanceList trainingData,
        boolean useSomeUnsupportedTrick) {
    // The third argument is always null on this path.
    setWeightsDimensionAsIn(
            trainingData,
            useSomeUnsupportedTrick,
            null);
}
/**
 * Convenience overload that forwards {@code trainingData} to the two-argument
 * {@code setWeightsDimensionAsIn} with the {@code useSomeUnsupportedTrick} flag
 * switched off.
 *
 * @param trainingData instance list passed through unchanged to the two-argument overload
 */
public void setWeightsDimensionAsIn(InstanceList trainingData) {
    final boolean useSomeUnsupportedTrick = false;
    setWeightsDimensionAsIn(trainingData, useSomeUnsupportedTrick);
}
/**
 * Replaces {@code this.crf} with a freshly constructed CRF built on the pipe of
 * {@code examples}, adds order-1 states and a start state, sets the weights
 * dimension from {@code examples} (unsupported trick off), and finally copies
 * applicable parameters from {@code crfFrom} when that field is non-null.
 *
 * @param examples instance list supplying both the pipe and the data used to set up the CRF
 */
private void initializeFor(InstanceList examples) {
    final CRF model = new CRF(examples.getPipe(), null);
    // Publish the new model on the field first, exactly as before, then configure it.
    this.crf = model;

    final int[] orders = {1};
    model.addOrderNStates(examples, orders, null, null, null, null, false);
    model.addStartState();
    model.setWeightsDimensionAsIn(examples, false);

    if (crfFrom == null) {
        return;
    }
    model.initializeApplicableParametersFrom(crfFrom);
}
private TransducerTrainer trainOnce(Pipe pipe, InstanceList examples) { Stopwatch watch = Stopwatch.createStarted(); CRF crf = new CRF(pipe, null); crf.addOrderNStates(examples, new int[]{1}, null, null, null, null, false); crf.addStartState(); crf.setWeightsDimensionAsIn(examples, true); if (initFrom != null) { crf.initializeApplicableParametersFrom(initFrom); } log.info("Starting syllchain training..."); CRFTrainerByThreadedLabelLikelihood trainer = new CRFTrainerByThreadedLabelLikelihood(crf, 8); trainer.setGaussianPriorVariance(2); trainer.setAddNoFactors(true); // trainer.setUseSomeUnsupportedTrick(true); trainer.train(examples); trainer.shutdown(); watch.stop(); log.info("SyllChain CRF Training took " + watch.toString()); crf.getInputAlphabet().stopGrowth(); crf.getOutputAlphabet().stopGrowth(); return trainer; }
/**
 * Builds a new order-1 CRF over {@code pipe}, trains it on {@code trainData} with a
 * threaded label-likelihood trainer (8 threads, Gaussian prior variance 2,
 * add-no-factors enabled, unsupported trick disabled), logs the elapsed time, stops
 * growth of the CRF's input and output alphabets, and returns the trainer.
 *
 * @param pipe      pipe the new CRF is constructed with
 * @param trainData training instances
 * @return the trainer that was used, after {@code shutdown()} has been called on it
 */
private TransducerTrainer trainOnce(Pipe pipe, InstanceList trainData) {
    final Stopwatch timer = Stopwatch.createStarted();

    // Model set-up: order-1 states plus a start state; weights sized with the trick flag off.
    final CRF model = new CRF(pipe, null);
    final int[] orders = {1};
    model.addOrderNStates(trainData, orders, null, null, null, null, false);
    model.addStartState();
    model.setWeightsDimensionAsIn(trainData, false);
    if (initFrom != null) {
        model.initializeApplicableParametersFrom(initFrom);
    }

    log.info("Starting alignTag training...");
    final CRFTrainerByThreadedLabelLikelihood threadedTrainer =
            new CRFTrainerByThreadedLabelLikelihood(model, 8);
    threadedTrainer.setGaussianPriorVariance(2);
    threadedTrainer.setAddNoFactors(true);
    threadedTrainer.setUseSomeUnsupportedTrick(false);
    threadedTrainer.train(trainData);
    threadedTrainer.shutdown();

    timer.stop();
    log.info("Syll align Tag CRF Training took " + timer.toString());

    model.getInputAlphabet().stopGrowth();
    model.getOutputAlphabet().stopGrowth();
    return threadedTrainer;
}
public void testTrainStochasticGradient() { Pipe p = makeSpacePredictionPipe(); Pipe p2 = new TestCRF2String(); InstanceList instances = new InstanceList(p); instances.addThruPipe(new ArrayIterator(data)); InstanceList[] lists = instances.split(new double[] { .5, .5 }); CRF crf = new CRF(p, p2); crf.addFullyConnectedStatesForLabels(); crf.setWeightsDimensionAsIn(lists[0], false); CRFTrainerByStochasticGradient crft = new CRFTrainerByStochasticGradient( crf, 0.0001); System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy before training = " + crf.averageTokenAccuracy(lists[1])); System.out.println("Training..."); // either fixed learning rate or selected on a sample crft.setLearningRateByLikelihood(lists[0]); // crft.setLearningRate(0.01); crft.train(lists[0], 100); crf.print(); System.out.println("Training Accuracy after training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy after training = " + crf.averageTokenAccuracy(lists[1])); }
public void testTrainStochasticGradient() { Pipe p = makeSpacePredictionPipe(); Pipe p2 = new TestCRF2String(); InstanceList instances = new InstanceList(p); instances.addThruPipe(new ArrayIterator(data)); InstanceList[] lists = instances.split(new double[] { .5, .5 }); CRF crf = new CRF(p, p2); crf.addFullyConnectedStatesForLabels(); crf.setWeightsDimensionAsIn(lists[0], false); CRFTrainerByStochasticGradient crft = new CRFTrainerByStochasticGradient( crf, 0.0001); System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy before training = " + crf.averageTokenAccuracy(lists[1])); System.out.println("Training..."); // either fixed learning rate or selected on a sample crft.setLearningRateByLikelihood(lists[0]); // crft.setLearningRate(0.01); crft.train(lists[0], 100); crf.print(); System.out.println("Training Accuracy after training = " + crf.averageTokenAccuracy(lists[0])); System.out.println("Testing Accuracy after training = " + crf.averageTokenAccuracy(lists[1])); }
public CRFOptimizableByLabelLikelihood getOptimizableCRF (InstanceList trainingSet) { if (cachedWeightsStructureStamp != crf.weightsStructureChangeStamp) { if (!useNoWeights) { if (useSparseWeights) crf.setWeightsDimensionAsIn (trainingSet, useSomeUnsupportedTrick); else crf.setWeightsDimensionDensely (); } //reallocateSufficientStatistics(); // Not necessary here because it is done in the constructor for OptimizableCRF ocrf = null; cachedWeightsStructureStamp = crf.weightsStructureChangeStamp; } if (ocrf == null || ocrf.trainingSet != trainingSet) { //ocrf = new OptimizableCRF (crf, trainingSet); ocrf = new CRFOptimizableByLabelLikelihood(crf, trainingSet); ocrf.setGaussianPriorVariance(gaussianPriorVariance); ocrf.setHyperbolicPriorSharpness(hyperbolicPriorSharpness); ocrf.setHyperbolicPriorSlope(hyperbolicPriorSlope); ocrf.setUseHyperbolicPrior(usingHyperbolicPrior); opt = null; } return ocrf; }
public CRFOptimizableByLabelLikelihood getOptimizableCRF (InstanceList trainingSet) { if (cachedWeightsStructureStamp != crf.weightsStructureChangeStamp) { if (!useNoWeights) { if (useSparseWeights) crf.setWeightsDimensionAsIn (trainingSet, useSomeUnsupportedTrick); else crf.setWeightsDimensionDensely (); } //reallocateSufficientStatistics(); // Not necessary here because it is done in the constructor for OptimizableCRF ocrf = null; cachedWeightsStructureStamp = crf.weightsStructureChangeStamp; } if (ocrf == null || ocrf.trainingSet != trainingSet) { //ocrf = new OptimizableCRF (crf, trainingSet); ocrf = new CRFOptimizableByLabelLikelihood(crf, trainingSet); ocrf.setGaussianPriorVariance(gaussianPriorVariance); ocrf.setHyperbolicPriorSharpness(hyperbolicPriorSharpness); ocrf.setHyperbolicPriorSlope(hyperbolicPriorSlope); ocrf.setUseHyperbolicPrior(usingHyperbolicPrior); opt = null; } return ocrf; }
private TransducerTrainer trainOnce(Pipe pipe, InstanceList trainData) { Stopwatch watch = Stopwatch.createStarted(); CRF crf = new CRF(pipe, null); // O,O O,N -O,C- // N,O N,N N,C // C,O ?C,N? C,C Pattern forbidden = null; if (USE_ONC_CODING) { forbidden = Pattern.compile("(O,C|<START>,C|O,<END>)", Pattern.CASE_INSENSITIVE); } crf.addOrderNStates(trainData, new int[]{1}, null, null, forbidden, null, false); crf.addStartState(); crf.setWeightsDimensionAsIn(trainData); if (this.pullFrom != null) { crf.initializeApplicableParametersFrom(pullFrom); } log.info("Starting syll phone training..."); CRFTrainerByThreadedLabelLikelihood trainer = new CRFTrainerByThreadedLabelLikelihood(crf, 8); trainer.setGaussianPriorVariance(2); trainer.setAddNoFactors(false); trainer.setUseSomeUnsupportedTrick(true); trainer.train(trainData); trainer.shutdown(); watch.stop(); pipe.getAlphabet().stopGrowth(); pipe.getTargetAlphabet().stopGrowth(); log.info("Align Tag CRF Training took " + watch.toString()); return trainer; }
public CRFOptimizableByLabelLikelihood getOptimizableCRF(InstanceList trainingSet) { if (cachedWeightsStructureStamp != crf.weightsStructureChangeStamp) { if (!useNoWeights) { if (useSparseWeights) { crf.setWeightsDimensionAsIn(trainingSet, useSomeUnsupportedTrick); } else { crf.setWeightsDimensionDensely(); } } // reallocateSufficientStatistics(); // Not necessary here because it is done in the constructor for // OptimizableCRF ocrf = null; cachedWeightsStructureStamp = crf.weightsStructureChangeStamp; } if (ocrf == null || ocrf.trainingSet != trainingSet) { // ocrf = new OptimizableCRF (crf, trainingSet); ocrf = new CRFOptimizableByLabelLikelihood(crf, trainingSet); ocrf.setGaussianPriorVariance(gaussianPriorVariance); ocrf.setHyperbolicPriorSharpness(hyperbolicPriorSharpness); ocrf.setHyperbolicPriorSlope(hyperbolicPriorSlope); ocrf.setUseHyperbolicPrior(usingHyperbolicPrior); opt = null; } return ocrf; }
/**
 * Returns the batch label-likelihood optimizable for {@code trainingSet}, reusing
 * the cached one when it is still valid.
 *
 * <p>When the CRF's {@code weightsStructureChangeStamp} no longer matches the cached
 * stamp, the weights are re-dimensioned (sparsely from {@code trainingSet} or densely,
 * unless {@code useNoWeights} is set), the cached optimizable is discarded and the
 * stamp is re-read. A new optimizable and a new {@code ThreadedOptimizable} wrapper
 * are then created whenever none is cached or the cached one was built for a
 * different {@code InstanceList} object (reference comparison); doing so also clears
 * {@code optimizer}.
 *
 * <p>Any previously created {@code ThreadedOptimizable} is shut down before it is
 * replaced, so that it is not simply dropped while still holding its resources.
 *
 * @param trainingSet instances the optimizable is built on
 * @return the cached or newly created optimizable
 */
public CRFOptimizableByBatchLabelLikelihood getOptimizableCRF (InstanceList trainingSet) {
    if (cachedWeightsStructureStamp != crf.weightsStructureChangeStamp) {
        if (!useNoWeights) {
            if (useSparseWeights) {
                crf.setWeightsDimensionAsIn (trainingSet, useSomeUnsupportedTrick);
            } else {
                crf.setWeightsDimensionDensely ();
            }
        }
        optimizable = null;
        // Re-read the stamp only after the weights have been re-dimensioned.
        cachedWeightsStructureStamp = crf.weightsStructureChangeStamp;
    }
    if (optimizable == null || optimizable.trainingSet != trainingSet) {
        optimizable = new CRFOptimizableByBatchLabelLikelihood(crf, trainingSet, numThreads);
        optimizable.setGaussianPriorVariance(gaussianPriorVariance);
        // Shut down the wrapper we are about to replace; otherwise every rebuild
        // would abandon the previous one without releasing it.
        if (threadedOptimizable != null) {
            threadedOptimizable.shutdown();
        }
        threadedOptimizable = new ThreadedOptimizable(optimizable, trainingSet,
                crf.getParameters().getNumFactors(), new CRFCacheStaleIndicator(crf));
        optimizer = null;
    }
    return optimizable;
}
/**
 * Returns the batch label-likelihood optimizable for {@code trainingSet}, reusing
 * the cached one when it is still valid.
 *
 * <p>When the CRF's {@code weightsStructureChangeStamp} no longer matches the cached
 * stamp, the weights are re-dimensioned (sparsely from {@code trainingSet} or densely,
 * unless {@code useNoWeights} is set), the cached optimizable is discarded and the
 * stamp is re-read. A new optimizable and a new {@code ThreadedOptimizable} wrapper
 * are then created whenever none is cached or the cached one was built for a
 * different {@code InstanceList} object (reference comparison); doing so also clears
 * {@code optimizer}.
 *
 * <p>Any previously created {@code ThreadedOptimizable} is shut down before it is
 * replaced, so that it is not simply dropped while still holding its resources.
 *
 * @param trainingSet instances the optimizable is built on
 * @return the cached or newly created optimizable
 */
public CRFOptimizableByBatchLabelLikelihood getOptimizableCRF (InstanceList trainingSet) {
    if (cachedWeightsStructureStamp != crf.weightsStructureChangeStamp) {
        if (!useNoWeights) {
            if (useSparseWeights) {
                crf.setWeightsDimensionAsIn (trainingSet, useSomeUnsupportedTrick);
            } else {
                crf.setWeightsDimensionDensely ();
            }
        }
        optimizable = null;
        // Re-read the stamp only after the weights have been re-dimensioned.
        cachedWeightsStructureStamp = crf.weightsStructureChangeStamp;
    }
    if (optimizable == null || optimizable.trainingSet != trainingSet) {
        optimizable = new CRFOptimizableByBatchLabelLikelihood(crf, trainingSet, numThreads);
        optimizable.setGaussianPriorVariance(gaussianPriorVariance);
        // Shut down the wrapper we are about to replace; otherwise every rebuild
        // would abandon the previous one without releasing it.
        if (threadedOptimizable != null) {
            threadedOptimizable.shutdown();
        }
        threadedOptimizable = new ThreadedOptimizable(optimizable, trainingSet,
                crf.getParameters().getNumFactors(), new CRFCacheStaleIndicator(crf));
        optimizer = null;
    }
    return optimizable;
}
// Excerpt from a larger method body: adds fully connected label states to crf, creates a label-likelihood trainer for it, sets the weights dimension from `one` (unsupported trick off), obtains the trainer's optimizable for `one`, and allocates a parameter array with one slot per optimizable parameter.
crf.addFullyConnectedStatesForLabels(); CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf); crf.setWeightsDimensionAsIn(one, false); Optimizable.ByGradientValue mcrf = crft.getOptimizableCRF(one); double[] params = new double[mcrf.getNumParameters()];
public CRFOptimizableByBatchLabelLikelihood getOptimizableCRF (InstanceList trainingSet) { if (cachedWeightsStructureStamp != crf.weightsStructureChangeStamp) { if (!useNoWeights) { if (useSparseWeights) { crf.setWeightsDimensionAsIn (trainingSet, useSomeUnsupportedTrick); } else { crf.setWeightsDimensionDensely (); } } optimizable = null; cachedWeightsStructureStamp = crf.weightsStructureChangeStamp; } if (optimizable == null || optimizable.trainingSet != trainingSet) { optimizable = new CRFOptimizableByBatchLabelLikelihood(crf, trainingSet, numThreads); optimizable.setGaussianPriorVariance(gaussianPriorVariance); // must shutdown existing thread pool before making a new one if (threadedOptimizable != null) { threadedOptimizable.shutdown(); } threadedOptimizable = new ThreadedOptimizable(optimizable, trainingSet, crf.getParameters().getNumFactors(), new CRFCacheStaleIndicator(crf)); optimizer = null; } return optimizable; }
/**
 * Converts every sentence of {@code textBlock} into an instance, runs it through the
 * pipes into a fresh instance list, and trains on that list.
 *
 * <p>On the first call (while {@code crf} is still {@code null}) a fully connected
 * CRF is created, its weights dimension is set from the new instance list, and a
 * value-gradient trainer over a single label-likelihood optimizable is installed with
 * max resets set to 0. Later calls reuse the existing {@code crf} and
 * {@code crfTrainer}.
 *
 * @param textBlock sentences to train on
 * @throws Exception declared by the signature; propagated from the calls made here
 */
public void train(TextBlock textBlock) throws Exception {
    final InstanceList instances = new InstanceList(getPipes());
    for (TextSentence sentence : textBlock) {
        Instance sentenceInstance = new TextInstance(sentence, getTargetAlphabet());
        instances.addThruPipe(sentenceInstance);
    }

    final boolean firstRun = (crf == null);
    if (firstRun) {
        crf = new CRF(getPipes(), null);
        crf.addFullyConnectedStatesForLabels();
        crf.setWeightsDimensionAsIn(instances, false);

        CRFOptimizableByLabelLikelihood labelLikelihood =
                new CRFOptimizableByLabelLikelihood(crf, instances);
        Optimizable.ByGradientValue[] objectives = { labelLikelihood };

        crfTrainer = new CRFTrainerByValueGradients(crf, objectives);
        crfTrainer.setMaxResets(0);
    }

    // Integer.MAX_VALUE is passed as the second argument of train(), as before.
    crfTrainer.train(instances, Integer.MAX_VALUE);
}
/**
 * Builds a small CRF from the single string {@code "ABCDE"}, overwrites the
 * optimizable's parameters so that parameter {@code i} holds the value {@code i},
 * and prints the CRF. Makes no assertions.
 */
public void testPrint() {
    final Pipe[] stages = {
        new CharSequence2TokenSequence("."),
        new TokenText(),
        new TestCRFTokenSequenceRemoveSpaces(),
        new TokenSequence2FeatureVectorSequence(),
        new PrintInputAndTarget(),
    };
    final Pipe pipeline = new SerialPipes(stages);

    final InstanceList single = new InstanceList(pipeline);
    final String[] inputs = { "ABCDE", };
    single.addThruPipe(new ArrayIterator(inputs));

    final CRF model = new CRF(pipeline, null);
    model.addFullyConnectedStatesForThreeQuarterLabels(single);
    final CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(model);
    model.setWeightsDimensionAsIn(single, false);

    final Optimizable objective = trainer.getOptimizableCRF(single);
    final double[] values = new double[objective.getNumParameters()];
    // Give each parameter its own index as its value.
    int idx = 0;
    while (idx < values.length) {
        values[idx] = idx;
        idx++;
    }
    objective.setParameters(values);

    model.print();
}
/**
 * Builds a small CRF from the single string {@code "ABCDE"}, overwrites the
 * optimizable's parameters so that parameter {@code i} holds the value {@code i},
 * and prints the CRF. Makes no assertions.
 */
public void testPrint() {
    final Pipe[] stages = {
        new CharSequence2TokenSequence("."),
        new TokenText(),
        new TestCRFTokenSequenceRemoveSpaces(),
        new TokenSequence2FeatureVectorSequence(),
        new PrintInputAndTarget(),
    };
    final Pipe pipeline = new SerialPipes(stages);

    final InstanceList single = new InstanceList(pipeline);
    final String[] inputs = { "ABCDE", };
    single.addThruPipe(new ArrayIterator(inputs));

    final CRF model = new CRF(pipeline, null);
    model.addFullyConnectedStatesForThreeQuarterLabels(single);
    final CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(model);
    model.setWeightsDimensionAsIn(single, false);

    final Optimizable objective = trainer.getOptimizableCRF(single);
    final double[] values = new double[objective.getNumParameters()];
    // Give each parameter its own index as its value.
    int idx = 0;
    while (idx < values.length) {
        values[idx] = idx;
        idx++;
    }
    objective.setParameters(values);

    model.print();
}