// Hand the learner this object's DAG and the data stream, then call initLearning() followed by runLearning().
learningAlgorithm.setDAG(this.dag); learningAlgorithm.setDataStream(dataStream); learningAlgorithm.initLearning(); learningAlgorithm.runLearning();
public static void main(String[] args) throws Exception { //We can open the data stream using the static class DataStreamLoader DataStream<DataInstance> data = DataStreamLoader.open("datasets/simulated/WasteIncineratorSample.arff"); //We create a ParameterLearningAlgorithm object with the MaximumLikehood builder ParameterLearningAlgorithm parameterLearningAlgorithm = new ParallelMaximumLikelihood(); //We fix the DAG structure parameterLearningAlgorithm.setDAG(getNaiveBayesStructure(data,0)); //We should invoke this method before processing any data parameterLearningAlgorithm.initLearning(); //Then we show how we can perform parameter learnig by a sequential updating of data batches. for (DataOnMemory<DataInstance> batch : data.iterableOverBatches(100)){ parameterLearningAlgorithm.updateModel(batch); } //And we get the model BayesianNetwork bnModel = parameterLearningAlgorithm.getLearntBayesianNetwork(); //We print the model System.out.println(bnModel.toString()); }
/**
 * Builds the clusterer from the given Weka data set.
 * <p>
 * The method converts the Weka attributes, creates a multinomial variable named
 * {@code "clusterVar"} with {@code numberOfClusters()} states, makes it a parent of
 * every other observable variable in the DAG, learns the parameters with an
 * {@link SVB} instance over all instances of {@code data}, and finally hands the
 * learnt network to the importance-sampling inference algorithm.
 * The DAG and the learnt network are printed to standard output.
 *
 * @param data the Weka instances to learn from
 * @throws Exception declared by the overridden method; propagated from the calls made here
 */
@Override
public void buildClusterer(Instances data) throws Exception {
    attributes_ = Converter.convertAttributes(data.enumerateAttributes());
    Variables modelHeader = new Variables(attributes_);
    clusterVar_ = modelHeader.newMultinomialVariable("clusterVar", this.numberOfClusters());

    inferenceAlgorithm_ = new ImportanceSampling();
    inferenceAlgorithm_.setSeed(this.getSeed());

    dag = new DAG(modelHeader);

    /* Set DAG structure. */
    /* Add the hidden cluster variable as a parent of all the predictive attributes. */
    dag.getParentSets().stream()
            .filter(w -> w.getMainVar().getVarID() != clusterVar_.getVarID())
            .filter(w -> w.getMainVar().isObservable())
            .forEach(w -> w.addParent(clusterVar_));

    System.out.println(dag.toString());

    parameterLearningAlgorithm_ = new SVB();
    parameterLearningAlgorithm_.setDAG(dag);

    // Diamond operator instead of a raw type: keeps the container type-checked
    // against its declared element type and avoids an unchecked conversion.
    DataOnMemoryListContainer<DataInstance> batch_ = new DataOnMemoryListContainer<>(attributes_);

    // Wrap every Weka instance as a DataInstance and collect them all in a single batch.
    data.stream().forEach(instance ->
            batch_.add(new DataInstanceFromDataRow(new DataRowWeka(instance, this.attributes_)))
    );

    parameterLearningAlgorithm_.setDataStream(batch_);
    parameterLearningAlgorithm_.initLearning();
    parameterLearningAlgorithm_.runLearning();

    bnModel_ = parameterLearningAlgorithm_.getLearntBayesianNetwork();

    System.out.println(bnModel_);
    inferenceAlgorithm_.setModel(bnModel_);
}
// Call initLearning() on the parameter learner, then runLearning().
parameterLearningAlgorithm_.initLearning(); parameterLearningAlgorithm_.runLearning();
// Call initLearning() on the parameter learner, then runLearning().
parameterLearningAlgorithm_.initLearning(); parameterLearningAlgorithm_.runLearning();
/**
 * Prepares the learning algorithm and marks this object as initialized.
 * <p>
 * If no learning algorithm has been assigned yet, an {@link SVB} instance is created
 * with a window size of 100, ELBO testing disabled, at most 100 VMP iterations and a
 * threshold of 0.00001. The algorithm then receives {@code windowSize} and either the
 * DAG or, when no DAG is available, the plateau structure.
 *
 * @throws IllegalArgumentException if both {@code getDAG()} and
 *         {@code getPlateuStructure()} return {@code null}
 */
protected void initLearning() {
    if (learningAlgorithm == null) {
        SVB defaultLearner = new SVB();
        defaultLearner.setWindowsSize(100);
        defaultLearner.getPlateuStructure().getVMP().setTestELBO(false);
        defaultLearner.getPlateuStructure().getVMP().setMaxIter(100);
        defaultLearner.getPlateuStructure().getVMP().setThreshold(0.00001);
        learningAlgorithm = defaultLearner;
    }

    // Applied to whichever algorithm is in use, including the default created above.
    learningAlgorithm.setWindowsSize(windowSize);

    // The DAG takes precedence; the plateau structure is only used when no DAG is set.
    if (this.getDAG() != null) {
        learningAlgorithm.setDAG(this.getDAG());
    } else if (this.getPlateuStructure() == null) {
        throw new IllegalArgumentException("Non provided dag or PlateauStructure");
    } else {
        ((BayesianParameterLearningAlgorithm) learningAlgorithm).setPlateuStructure(this.getPlateuStructure());
    }

    learningAlgorithm.setOutput(true);
    learningAlgorithm.initLearning();
    initialized = true;
}
/** * {@inheritDoc} */ @Override public void trainOnInstanceImpl(Instance instance) { DataInstance dataInstance = new DataInstanceFromDataRow(new DataRowWeka(instance, this.attributes_)); if(batch_.getNumberOfDataInstances() < getBatchSize_()-1) { //store batch_.add(dataInstance); }else{ //store & learn batch_.add(dataInstance); if(bnModel_==null) { //parameterLearningAlgorithm_.setParallelMode(isParallelMode_()); parameterLearningAlgorithm_.setDAG(dag); parameterLearningAlgorithm_.initLearning(); parameterLearningAlgorithm_.updateModel(batch_); }else{ parameterLearningAlgorithm_.updateModel(batch_); } bnModel_ = parameterLearningAlgorithm_.getLearntBayesianNetwork(); predictions_.setModel(bnModel_); batch_ = new DataOnMemoryListContainer(attributes_); } }
/**
 * Prepares the learning algorithm and marks this object as initialized.
 * <p>
 * If no learning algorithm has been assigned yet, an {@link SVB} instance is created
 * on top of a {@link PlateauLDA} built from {@code atts} with the names
 * {@code "word"} and {@code "count"} and {@code ntopics} topics; its VMP is set to
 * skip ELBO testing, run at most 100 iterations and use a threshold of 0.01.
 * The algorithm in use always gets a window size of 100 and output enabled.
 */
@Override
protected void initLearning() {
    if (learningAlgorithm == null) {
        SVB ldaLearner = new SVB();

        plateauLDA = new PlateauLDA(this.atts, "word", "count");
        plateauLDA.setNTopics(ntopics);
        ldaLearner.setPlateuStructure(plateauLDA);

        ldaLearner.getPlateuStructure().getVMP().setTestELBO(false);
        ldaLearner.getPlateuStructure().getVMP().setMaxIter(100);
        ldaLearner.getPlateuStructure().getVMP().setThreshold(0.01);

        learningAlgorithm = ldaLearner;
    }

    // These settings are applied whether the algorithm was just created or supplied earlier.
    learningAlgorithm.setWindowsSize(100);
    learningAlgorithm.setOutput(true);

    learningAlgorithm.initLearning();
    initialized = true;
}
/** * {@inheritDoc} */ @Override public void trainOnInstanceImpl(Instance instance) { DataInstance dataInstance = new DataInstanceFromDataRow(new DataRowWeka(instance, this.attributes_)); if(batch_.getNumberOfDataInstances() < getBatchSize_()-1) { //store batch_.add(dataInstance); }else{ //store & learn batch_.add(dataInstance); if(bnModel_==null) { //parameterLearningAlgorithm_.setParallelMode(isParallelMode_()); parameterLearningAlgorithm_.setDAG(dag); parameterLearningAlgorithm_.initLearning(); parameterLearningAlgorithm_.updateModel(batch_); }else{ parameterLearningAlgorithm_.updateModel(batch_); } bnModel_ = parameterLearningAlgorithm_.getLearntBayesianNetwork(); predictions_.setModel(bnModel_); batch_ = new DataOnMemoryListContainer(attributes_); } }
// Call initLearning() on the learner, then update the model with the current batch; the alternative branch begins right after.
parameterLearningAlgorithm_.initLearning(); parameterLearningAlgorithm_.updateModel(batch_); }else {
/**
 * Prepares the learning algorithm and marks this object as initialized.
 * <p>
 * A DAG is built first when none is available. If no learning algorithm has been
 * assigned yet, an {@link SVB} instance is created with this object's seed, a
 * {@link PlateuIIDReplication} over {@code hiddenVars}, and a
 * {@link GaussianHiddenTransitionMethod} configured with {@code transitionVariance}
 * and {@code fading}; its VMP runs at most 1000 iterations with a threshold of 0.001.
 * The algorithm in use then receives {@code windowSize} and the DAG, with output
 * disabled.
 *
 * @throws IllegalArgumentException if {@code getDAG()} still returns {@code null}
 *         after the build step
 */
@Override
protected void initLearning() {
    if (this.getDAG() == null) {
        buildDAG();
    }

    if (learningAlgorithm == null) {
        SVB variationalLearner = new SVB();
        variationalLearner.setSeed(this.seed);
        variationalLearner.setPlateuStructure(new PlateuIIDReplication(hiddenVars));

        GaussianHiddenTransitionMethod hiddenTransition =
                new GaussianHiddenTransitionMethod(hiddenVars, 0, this.transitionVariance);
        hiddenTransition.setFading(fading);
        variationalLearner.setTransitionMethod(hiddenTransition);

        variationalLearner.setDAG(dag);
        variationalLearner.setOutput(false);
        variationalLearner.getPlateuStructure().getVMP().setMaxIter(1000);
        variationalLearner.getPlateuStructure().getVMP().setThreshold(0.001);

        learningAlgorithm = variationalLearner;
    }

    learningAlgorithm.setWindowsSize(windowSize);

    // A DAG is mandatory at this point, whether provided or built above.
    if (this.getDAG() == null) {
        throw new IllegalArgumentException("Non provided dag");
    }
    learningAlgorithm.setDAG(this.getDAG());

    learningAlgorithm.setOutput(false);
    learningAlgorithm.initLearning();
    initialized = true;
}