/**
 * Command-line entry point: trains a model using the ML configuration found in the given directory.
 *
 * @param args {@code args[0]} is the ML configuration directory
 * @throws Exception if initialization or training fails
 */
public static void main(String[] args) throws Exception {
  if (args.length < 1) {
    System.out.println("Usage: " + MLRunner.class.getName() + " <ml-conf-dir>");
    System.exit(-1);
  }
  String confDir = args[0];
  LensMLClient client = new LensMLClient(new LensClient());
  try {
    MLRunner runner = new MLRunner();
    runner.init(client, confDir);
    runner.train();
    System.out.println("Created the Model successfully. Output Table: " + runner.outputTable);
  } finally {
    // Always release the ML client, including when init/train throws.
    client.close();
  }
}
}
/**
 * Verifies that the decision tree algorithm exposes a non-empty parameter description map, and
 * logs each parameter with its help text.
 */
@Test
public void testGetAlgoParams() throws Exception {
  Map<String, String> params =
      mlClient.getAlgoParamDescription(MLUtils.getAlgoName(DecisionTreeAlgo.class));
  Assert.assertNotNull(params);
  Assert.assertFalse(params.isEmpty());
  // Iterate entries directly instead of keySet() + get(key) to avoid a second lookup per key.
  for (Map.Entry<String, String> param : params.entrySet()) {
    log.info("## Param " + param.getKey() + " help = " + param.getValue());
  }
}
/**
 * Runs an online prediction: looks up the model identified by {@code algorithm} and
 * {@code modelID}, then applies it to the supplied feature values.
 *
 * @param algorithm name of the algorithm the model was trained with
 * @param modelID   identifier of the trained model
 * @param features  sample feature values to score
 * @return whatever the model's {@code predict} call returns for {@code features}
 * @throws LensException propagated from the model lookup
 */
@Override
public Object predict(String algorithm, String modelID, Object[] features) throws LensException {
  final Object prediction = getModel(algorithm, modelID).predict(features);
  return prediction;
}
@AfterTest public void tearDown() throws Exception { super.tearDown(); Hive hive = Hive.get(new HiveConf()); try { hive.dropDatabase(TEST_DB); } catch (Exception exc) { // Ignore drop db exception log.error("Exception while dropping database.", exc); } mlClient.close(); }
/**
 * Evaluates a trained model against the data in {@code table} and fetches the resulting report.
 *
 * <p>The {@code session} argument is not used by this implementation.
 *
 * @param session     the session (unused here)
 * @param table       table holding the test data
 * @param algorithm   name of the algorithm the model was trained with
 * @param modelID     identifier of the model to evaluate
 * @param outputTable table the test run writes its output into
 * @return the test report for the run, as returned by {@code getTestReport}
 * @throws LensException the lens exception
 */
@Override
public MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
    String outputTable) throws LensException {
  return getTestReport(algorithm, client.testModel(table, algorithm, modelID, outputTable));
}
/** * Train an ML model, with specified algorithm and input data. Do model evaluation using the evaluation data and print * evaluation result * * @throws Exception */ private void runTask() throws Exception { if (mlClient != null) { // Connect to a remote Lens server ml = mlClient; log.info("Working in client mode. Lens session handle {}", mlClient.getSessionHandle().getPublicId()); } else { // In server mode session handle has to be passed by the user as a request parameter ml = MLUtils.getMLService(); log.info("Working in Lens server"); } String[] algoArgs = buildTrainingArgs(); log.info("Starting task {} algo args: {} ", taskID, Arrays.toString(algoArgs)); modelID = ml.train(trainingTable, algorithm, algoArgs); printModelMetadata(taskID, modelID); log.info("Starting test {}", taskID); testTable = (testTable != null) ? testTable : trainingTable; MLTestReport testReport = ml.testModel(mlClient.getSessionHandle(), testTable, algorithm, modelID, outputTable); reportID = testReport.getReportID(); printTestReport(taskID, testReport); saveTask(); }
/**
 * Checks that the algorithm list returned by the client contains the naive Bayes, SVM, logistic
 * regression and decision tree algorithms. Each failure message is the missing algorithm's name.
 */
@Test
public void testGetAlgos() throws Exception {
  List<String> available = mlClient.getAlgorithms();
  Assert.assertNotNull(available);
  String[] expectedNames = {
    MLUtils.getAlgoName(NaiveBayesAlgo.class),
    MLUtils.getAlgoName(SVMAlgo.class),
    MLUtils.getAlgoName(LogisticRegressionAlgo.class),
    MLUtils.getAlgoName(DecisionTreeAlgo.class),
  };
  for (String expected : expectedNames) {
    Assert.assertTrue(available.contains(expected), expected);
  }
}
/**
 * Tears down the test environment: runs the parent teardown, drops the test database, and always
 * closes the ML client, even when an earlier step throws.
 *
 * @throws Exception if the parent teardown, the database drop or closing the client fails
 */
@AfterTest
public void tearDown() throws Exception {
  try {
    super.tearDown();
    Hive hive = Hive.get(new HiveConf());
    hive.dropDatabase(TEST_DB);
  } finally {
    // Without this, a failing teardown step would leak the client; it may be null if setUp failed.
    if (mlClient != null) {
      mlClient.close();
    }
  }
}
/**
 * Prepares the test environment: runs the parent setup, creates the test database in Hive (the
 * {@code true} flag is passed through to {@code createDatabase}), and builds an ML client pointed
 * at the local test server and the test database.
 *
 * @throws Exception if setup, database creation or client construction fails
 */
@BeforeTest
public void setUp() throws Exception {
  super.setUp();

  Database testDatabase = new Database();
  testDatabase.setName(TEST_DB);
  Hive.get(new HiveConf()).createDatabase(testDatabase, true);

  String serverUrl = "http://localhost:" + getTestPort() + "/lensapi";
  LensClientConfig clientConf = new LensClientConfig();
  clientConf.setLensDatabase(TEST_DB);
  clientConf.set(LensConfConstants.SERVER_BASE_URL, serverUrl);

  mlClient = new LensMLClient(new LensClient(clientConf));
}
/**
 * Test fixture setup. Runs the parent setup, creates the test database via Hive, and then creates
 * the {@code mlClient} used by the tests, bound to the test database and the local server URL.
 *
 * @throws Exception propagated from any setup step
 */
@BeforeTest
public void setUp() throws Exception {
  super.setUp();
  final Hive metastore = Hive.get(new HiveConf());
  final Database db = new Database();
  db.setName(TEST_DB);
  metastore.createDatabase(db, true);

  final LensClientConfig config = new LensClientConfig();
  config.setLensDatabase(TEST_DB);
  config.set(LensConfConstants.SERVER_BASE_URL, "http://localhost:" + getTestPort() + "/lensapi");
  final LensClient lensClient = new LensClient(config);
  mlClient = new LensMLClient(lensClient);
}