public static RerankerModel trainRerankerModel(double C, int n_thread, StructuredProblem train) throws Exception { RerankerModel model = new RerankerModel(); model.para = new JLISParameters(); // para.total_number_features = train.label_mapping.size() * // train.n_base_feature_in_train; model.para.c_struct = C; model.para.TRAINMINI = true; // play with the following two parameters if you want to solve SSVM more // tightly model.para.DUAL_GAP = 0.01; model.para.WORKINGSETSVM_STOP = 0.01; System.out.println("Initializing Solvers..."); System.out.flush(); AbstractLossSensitiveStructureFinder[] s_finder_list = new AbstractLossSensitiveStructureFinder[n_thread]; for (int i = 0; i < s_finder_list.length; i++) { s_finder_list[i] = new RerankerBestItemFinder(); } System.out.println("Done!"); System.out.flush(); L2LossParallelJLISLearner learner = new L2LossParallelJLISLearner(); // train model model.wv = learner.parallelTrainStructuredSVM(s_finder_list, train, model.para); return model; }
final BinaryProblem bp, final JLISParameters para) throws Exception { return multiThreadGetJointWeightVector(init_wv, struct_finder_list, empty_s, bp, para);
int n_p_update = multiThreadUpdateStructuresforPositiveExamples( alpha_ins_list, new_wv, struct_finder_list, para.verbose_level); Pair<Integer, Integer> count = multiThreadUpdateStructuresforNegativeAndStructuredExamples( alpha_ins_list, new_wv, struct_finder_list, para.verbose_level); WorkingSetSVMResult svm_res = getWeightVectorWithWorkingSetCDSVM( alpha_ins_list, new_wv.isExtendable(), para.WORKINGSETSVM_STOP, WorkingSetSVMResult svm_res = getWeightVectorWithWorkingSetCDSVM( alpha_ins_list, new_wv.isExtendable(), para.WORKINGSETSVM_STOP, para.MAX_SVM_ITER, printTotalNumberofAlphas(alpha_ins_list); printTotalNumberofAlphas(alpha_ins_list); System.out .println("The real objective value : " + getPrimalObjective( alpha_ins_list, new_wv,
WeightVector init_wv = parallelTrainStructuredSVM(struct_finder_list, sp, para); WeightVector res_wv = multiThreadGetJointWeightVector(init_wv, struct_finder_list, sp, bp, para);
int total_size = struct_size + binary_size; L2LossInstanceWithAlphas[] alpha_ins_list = initArrayOfInstances(sp, empty_b, para.c_struct, para.c_binary, struct_size, total_size); + " #binary: " + binary_size); Pair<WeightVector, Double> res = multitreadTrainJLIS( struct_finder_list, para.MAX_OUT_ITER, struct_size, total_size, wv, alpha_ins_list, para); System.out .println("primal: " + getPrimalObjective( alpha_ins_list, wv, System.out .println("primal: " + getPrimalObjective( alpha_ins_list, wv,
protected WeightVector multiThreadGetJointWeightVector(WeightVector old_wv, final AbstractStructureFinder[] struct_finder_list, StructuredProblem sp, BinaryProblem bp, JLISParameters para) throws Exception { int struct_size = sp.size(); int binary_size = bp.size(); int total_size = struct_size + binary_size; System.out.println("Number of traing data: #struct: " + struct_size + " #binary: " + binary_size); WeightVector new_wv = new WeightVector(old_wv, 0); // allocate bias term // for indirect // supervision L2LossInstanceWithAlphas[] alpha_ins_list = initArrayOfInstances(sp, bp, para.c_struct, para.c_binary, struct_size, total_size); return multitreadTrainJLIS(struct_finder_list, para.MAX_OUT_ITER, struct_size, total_size, new_wv, alpha_ins_list, para) .getFirst(); }
public static MulticlassModel trainMultiClassModel(double C, int n_thread, LabeledMulticlassData train) throws Exception { MulticlassModel model = new MulticlassModel(); model.lab_mapping = train.label_mapping; // for the bias term model.n_base_feature_in_train = train.n_base_feature_in_train; model.para = new JLISParameters(); // para.total_number_features = train.label_mapping.size() * // train.n_base_feature_in_train; model.para.c_struct = C; model.para.TRAINMINI = true; // play with the following two parameters if you want to solve SSVM more // tightly model.para.DUAL_GAP = 0.01; model.para.WORKINGSETSVM_STOP = 0.01; System.out.println("Initializing Solvers..."); System.out.flush(); AbstractLossSensitiveStructureFinder[] s_finder_list = new AbstractLossSensitiveStructureFinder[n_thread]; for (int i = 0; i < s_finder_list.length; i++) { s_finder_list[i] = new MultiClassStructureFinder(); } System.out.println("Done!"); System.out.flush(); model.s_finder = s_finder_list[0]; L2LossParallelJLISLearner learner = new L2LossParallelJLISLearner(); // train model model.wv = learner.parallelTrainStructuredSVM(s_finder_list, train.sp, model.para); return model; }
public WeightVector parallelTrainLatentStructuredSVMWithInitStructures_old( final AbstractLatentLossSensitiveStructureFinder[] struct_finder_list, final StructuredProblem sp, final JLISParameters para) throws Exception { WeightVector wv = new WeightVector(para.total_number_features + 1); // +1 for (int i = 0; i < para.MAX_OUT_ITER; i++) { wv = multiThreadGetJointWeightVector(wv, struct_finder_list, sp, empty_b, para); for (int j = 0; j < sp.size(); j++) { IStructure newLatentStructureWithSameOutputStructure = struct_finder_list[0] .getBestLatentStructure(wv, sp.input_list.get(j), sp.output_list.get(j)); sp.output_list .set(j, newLatentStructureWithSameOutputStructure); } } return wv; }
L2LossParallelJLISLearner learner = new L2LossParallelJLISLearner(); model.wv = learner.parallelTrainStructuredSVM(s_finder_list, train.sp, model.para); return model;
wv = multiThreadGetJointWeightVector(wv, struct_finder_list, minisp, empty_b, para); return multiThreadGetJointWeightVector(wv, struct_finder_list, sp, empty_b, para);