private static JavaPairRDD<String,float[]> readFeaturesRDD(JavaSparkContext sparkContext, Path path) {
  log.info("Loading features RDD from {}", path);
  JavaRDD<String> featureLines = sparkContext.textFile(path.toString());
  return featureLines.mapToPair(line -> {
    List<?> update = TextUtils.readJSON(line, List.class);
    String key = update.get(0).toString();
    float[] vector = TextUtils.convertViaJSON(update.get(1), float[].class);
    return new Tuple2<>(key, vector);
  });
}
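// Illustration of the input format readFeaturesRDD expects: each line is a two-element
// JSON array of [key, vector]. The key "user-42" and the vector values below are made up
// for this sketch; only the shape of the data is taken from the method above.
String line = "[\"user-42\",[0.5,1.25,-3.0]]";
List<?> parsed = TextUtils.readJSON(line, List.class);
String key = parsed.get(0).toString();                                    // "user-42"
float[] vector = TextUtils.convertViaJSON(parsed.get(1), float[].class);  // {0.5f, 1.25f, -3.0f}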
@Override
public void consumeKeyMessage(String key, String message, Configuration hadoopConf) throws IOException {
  switch (key) {
    case "UP":
      if (model == null) {
        return; // No model to interpret with yet, so skip it
      }
      // Update message: [id, center, count]
      List<?> update = TextUtils.readJSON(message, List.class);
      int id = Integer.parseInt(update.get(0).toString());
      double[] center = TextUtils.convertViaJSON(update.get(1), double[].class);
      long count = Long.parseLong(update.get(2).toString());
      model.update(id, center, count);
      break;

    case "MODEL":
    case "MODEL-REF":
      log.info("Loading new model");
      PMML pmml = AppPMMLUtils.readPMMLFromUpdateKeyMessage(key, message, hadoopConf);
      if (pmml == null) {
        return;
      }
      KMeansPMMLUtils.validatePMMLVsSchema(pmml, inputSchema);
      List<ClusterInfo> clusters = KMeansPMMLUtils.read(pmml);
      model = new KMeansServingModel(clusters, inputSchema);
      log.info("New model: {}", model);
      break;

    default:
      throw new IllegalArgumentException("Bad key: " + key);
  }
}
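// Sketch of an "UP" message as consumed above: a three-element JSON array of
// [clusterID, center, count]. The concrete values here are invented for illustration.
String upMessage = "[3,[0.1,0.2,0.3],42]";
List<?> parsedUp = TextUtils.readJSON(upMessage, List.class);
int clusterId = Integer.parseInt(parsedUp.get(0).toString());                        // 3
double[] clusterCenter = TextUtils.convertViaJSON(parsedUp.get(1), double[].class);  // {0.1, 0.2, 0.3}
long clusterCount = Long.parseLong(parsedUp.get(2).toString());                      // 42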
@Test
public void testConvertViaJSON() {
  assertEquals(3, TextUtils.convertViaJSON("3", Long.class).longValue());
  assertArrayEquals(new float[] { 1.0f, 2.0f },
                    TextUtils.convertViaJSON(new double[] { 1.0, 2.0 }, float[].class));
}
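// The round-trip semantics exercised by the test above can be sketched with Jackson:
// serialize the value to JSON text, then read it back as the requested type. This is a
// minimal illustration assuming an ObjectMapper-based approach; it is not the actual
// TextUtils source, and the JsonConvert class name is made up.
import com.fasterxml.jackson.databind.ObjectMapper;

final class JsonConvert {
  private static final ObjectMapper MAPPER = new ObjectMapper();

  static <T> T convertViaJSON(Object value, Class<T> clazz) {
    try {
      // e.g. a double[] becomes "[1.0,2.0]" and is re-read as float[]
      return MAPPER.readValue(MAPPER.writeValueAsString(value), clazz);
    } catch (java.io.IOException e) {
      throw new IllegalStateException(e);
    }
  }
}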
float[] vector = TextUtils.convertViaJSON(update.get(2), float[].class);
switch (update.get(0).toString()) {
  case "X":
private static Collection<String> checkFeatures(Path path, Collection<String> previousIDs) throws IOException {
  Collection<String> seenIDs = new HashSet<>();
  for (Path file : IOUtils.listFiles(path, "part-*")) {
    Path uncompressedFile = copyAndUncompress(file);
    // Close the stream so the temporary uncompressed file can be deleted afterwards
    try (Stream<String> lines = Files.lines(uncompressedFile)) {
      lines.forEach(line -> {
        List<?> update = TextUtils.readJSON(line, List.class);
        seenIDs.add(update.get(0).toString());
        assertEquals(FEATURES, TextUtils.convertViaJSON(update.get(1), float[].class).length);
      });
    }
    Files.delete(uncompressedFile);
  }
  assertNotEquals(0, seenIDs.size());
  assertTrue(seenIDs.containsAll(previousIDs));
  return seenIDs;
}
String id = update.get(1).toString();
float[] expected =
    (isX ? MockALSModelUpdateGenerator.X : MockALSModelUpdateGenerator.Y).get(id);
assertArrayEquals(expected, TextUtils.convertViaJSON(update.get(2), float[].class));
@SuppressWarnings("unchecked")
Collection<String> knownUsersItems = (Collection<String>) update.get(3);

String id = update.get(1).toString();
float[] expected = (isX ? X : Y).get(id);
assertArrayEquals(expected, TextUtils.convertViaJSON(update.get(2), float[].class), 1.0e-5f);
String otherID = ALSUtilsTest.idToStringID(ALSUtilsTest.stringIDtoID(id) - 99);
@SuppressWarnings("unchecked")
// Every recovered feature value must be a finite float
for (float f : TextUtils.convertViaJSON(update.get(2), float[].class)) {
  assertTrue(!Float.isNaN(f) && !Float.isInfinite(f));
}
List<?> fields = TextUtils.readJSON(update.getMessage(), List.class);
int clusterID = (Integer) fields.get(0);
double[] updatedCenter = TextUtils.convertViaJSON(fields.get(1), double[].class);
int updatedClusterSize = (Integer) fields.get(2);
clusterInfos.put(clusterID, new ClusterInfo(clusterID, updatedCenter, updatedClusterSize));