@Test
public void saveAndLoad() {
  Map<String, String> options = new HashMap<>();
  options.put("path", path.toString());
  df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save();
  Dataset<Row> loadedDF = spark.read().format("json").options(options).load();
  checkAnswer(loadedDF, df.collectAsList());
}
@Test
public void testExecution() {
  Dataset<Row> df = spark.table("testData").filter("key = 1");
  Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0));
}
LOG.debug("Found {} telemetry record(s)", telemetry.cache().count()); .flatMap(messageRouterFunction(profilerProps, profiles, globals), Encoders.bean(MessageRoute.class)); LOG.debug("Generated {} message route(s)", routes.cache().count()); .groupByKey(new GroupByPeriodFunction(profilerProps), Encoders.STRING()) .mapGroups(new ProfileBuilderFunction(profilerProps, globals), Encoders.bean(ProfileMeasurementAdapter.class)); LOG.debug("Produced {} profile measurement(s)", measurements.cache().count()); .mapPartitions(new HBaseWriterFunction(profilerProps), Encoders.INT()) .agg(sum("value")) .head() .getLong(0); LOG.debug("{} profile measurement(s) written to HBase", count);
@Test
public void testCollectAndTake() {
  Dataset<Row> df = spark.table("testData").filter("key = 1 or key = 2 or key = 3");
  Assert.assertEquals(3, df.select("key").collectAsList().size());
  Assert.assertEquals(2, df.select("key").takeAsList(2).size());
}
@Test
public void javaCompatibilityTest() {
  double[] input = new double[]{1D, 2D, 3D, 4D};
  Dataset<Row> dataset = spark.createDataFrame(
    Arrays.asList(RowFactory.create(Vectors.dense(input))),
    new StructType(new StructField[]{
      new StructField("vec", new VectorUDT(), false, Metadata.empty())
    }));
  double[] expectedResult = input.clone();
  new DoubleDCT_1D(input.length).forward(expectedResult, true);
  DCT dct = new DCT()
    .setInputCol("vec")
    .setOutputCol("resultVec");
  List<Row> result = dct.transform(dataset).select("resultVec").collectAsList();
  Vector resultVec = result.get(0).getAs("resultVec");
  Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6);
}
@Test
public void applySchemaToJSON() {
  Dataset<String> jsonDS = spark.createDataset(Arrays.asList(
    "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " +
      "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " +
      "\"boolean\":true, \"null\":null}",
    "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " +
      "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " +
      "\"boolean\":false, \"null\":null}"),
    Encoders.STRING());
  StructType expectedSchema = DataTypes.createStructType(Arrays.asList(
    DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0), true),
    DataTypes.createStructField("boolean", DataTypes.BooleanType, true),
    DataTypes.createStructField("double", DataTypes.DoubleType, true),
    DataTypes.createStructField("integer", DataTypes.LongType, true),
    DataTypes.createStructField("long", DataTypes.LongType, true),
    DataTypes.createStructField("null", DataTypes.StringType, true),
    DataTypes.createStructField("string", DataTypes.StringType, true)));
  List<Row> expectedResult = new ArrayList<>(2);
  expectedResult.add(
    RowFactory.create(
      new BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10L,
      21474836470L, null, "this is a simple string."));
  expectedResult.add(
    RowFactory.create(
      new BigDecimal("92233720368547758069"), false, 1.7976931348623157E305, 11L,
      21474836469L, null, "this is another simple string."));

  Dataset<Row> df1 = spark.read().json(jsonDS);
  StructType actualSchema1 = df1.schema();
  Assert.assertEquals(expectedSchema, actualSchema1);
  df1.createOrReplaceTempView("jsonTable1");
  List<Row> actual1 = spark.sql("select * from jsonTable1").collectAsList();
  Assert.assertEquals(expectedResult, actual1);

  Dataset<Row> df2 = spark.read().schema(expectedSchema).json(jsonDS);
  StructType actualSchema2 = df2.schema();
  Assert.assertEquals(expectedSchema, actualSchema2);
  df2.createOrReplaceTempView("jsonTable2");
  List<Row> actual2 = spark.sql("select * from jsonTable2").collectAsList();
  Assert.assertEquals(expectedResult, actual2);
}
@Test
public void dataFrameRDDOperations() {
  List<Person> personList = new ArrayList<>(2);
  Person person1 = new Person();
  person1.setName("Michael");
  person1.setAge(29);
  personList.add(person1);
  Person person2 = new Person();
  person2.setName("Yin");
  person2.setAge(28);
  personList.add(person2);
  JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
    person -> RowFactory.create(person.getName(), person.getAge()));
  List<StructField> fields = new ArrayList<>(2);
  fields.add(DataTypes.createStructField("name", DataTypes.StringType, false));
  fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
  StructType schema = DataTypes.createStructType(fields);
  Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
  df.createOrReplaceTempView("people");
  List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD()
    .map(row -> row.getString(0) + "_" + row.get(1)).collect();
  List<String> expected = new ArrayList<>(2);
  expected.add("Michael_29");
  expected.add("Yin_28");
  Assert.assertEquals(expected, actual);
}
.builder() .appName("SparkSQLRelativeFrequency") .config(sparkConf) .getOrCreate(); JavaSparkContext sc = new JavaSparkContext(spark.sparkContext()); int neighborWindow = Integer.parseInt(args[0]); String input = args[1]; StructType rfSchema = new StructType(new StructField[]{ new StructField("word", DataTypes.StringType, false, Metadata.empty()), new StructField("neighbour", DataTypes.StringType, false, Metadata.empty()), new StructField("frequency", DataTypes.IntegerType, false, Metadata.empty())}); Dataset<Row> rfDataset = spark.createDataFrame(rowRDD, rfSchema); rfDataset.createOrReplaceTempView("rfTable"); Dataset<Row> sqlResult = spark.sql(query); sqlResult.show(); // print first 20 records on the console sqlResult.write().parquet(output + "/parquetFormat"); // saves output in compressed Parquet format, recommended for large projects. sqlResult.rdd().saveAsTextFile(output + "/textFormat"); // to see output via cat command
void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
  StructType schema = df.schema();
  Assert.assertEquals(
    new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
    schema.apply("a"));
  Assert.assertEquals(
    new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
    schema.apply("b"));
  ArrayType valueType = new ArrayType(DataTypes.IntegerType, false);
  MapType mapType = new MapType(DataTypes.StringType, valueType, true);
  Assert.assertEquals(
    new StructField("c", mapType, true, Metadata.empty()),
    schema.apply("c"));
  Assert.assertEquals(
    new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
    schema.apply("d"));
  Assert.assertEquals(
    new StructField("e", DataTypes.createDecimalType(38, 0), true, Metadata.empty()),
    schema.apply("e"));
  Row first = df.select("a", "b", "c", "d", "e").first();
  Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
  // the bean's array field comes back as a Scala Seq
  Seq<Integer> result = first.getAs(1);
  Assert.assertEquals(bean.getB().length, result.length());
  for (int i = 0; i < result.length(); i++) {
    Assert.assertEquals(bean.getB()[i], result.apply(i));
  }
  // the map's array values come back as Seqs as well
  @SuppressWarnings("unchecked")
  Seq<Integer> outputBuffer = (Seq<Integer>) first.getJavaMap(2).get("hello");
  Assert.assertArrayEquals(
    bean.getC().get("hello"),
    Ints.toArray(JavaConverters.seqAsJavaListConverter(outputBuffer).asJava()));
}
.builder() .appName("knn2") .getOrCreate(); JavaSparkContext context = JavaSparkContext.fromSparkContext(session.sparkContext()); JavaRDD<String> R = session.read().textFile(datasetR).javaRDD(); R.saveAsTextFile(outputPath+"/R"); JavaRDD<String> S = session.read().textFile(datasetS).javaRDD(); S.saveAsTextFile(outputPath+"/S");
@Test
public void testCrosstab() {
  Dataset<Row> df = spark.table("testData2");
  Dataset<Row> crosstab = df.stat().crosstab("a", "b");
  String[] columnNames = crosstab.schema().fieldNames();
  Assert.assertEquals("a_b", columnNames[0]);
  Assert.assertEquals("1", columnNames[1]);
  Assert.assertEquals("2", columnNames[2]);
  List<Row> rows = crosstab.collectAsList();
  rows.sort(crosstabRowComparator);
  Integer count = 1;
  for (Row row : rows) {
    Assert.assertEquals(row.get(0).toString(), count.toString());
    Assert.assertEquals(1L, row.getLong(1));
    Assert.assertEquals(1L, row.getLong(2));
    count++;
  }
}
@Test
public void applySchema() {
  List<Person> personList = new ArrayList<>(2);
  Person person1 = new Person();
  person1.setName("Michael");
  person1.setAge(29);
  personList.add(person1);
  Person person2 = new Person();
  person2.setName("Yin");
  person2.setAge(28);
  personList.add(person2);
  JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
    person -> RowFactory.create(person.getName(), person.getAge()));
  List<StructField> fields = new ArrayList<>(2);
  fields.add(DataTypes.createStructField("name", DataTypes.StringType, false));
  fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
  StructType schema = DataTypes.createStructType(fields);
  Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
  df.createOrReplaceTempView("people");
  List<Row> actual = spark.sql("SELECT * FROM people").collectAsList();
  List<Row> expected = new ArrayList<>(2);
  expected.add(RowFactory.create("Michael", 29));
  expected.add(RowFactory.create("Yin", 28));
  Assert.assertEquals(expected, actual);
}
@Test
public void bucketizerTest() {
  double[] splits = {-0.5, 0.0, 0.5};
  StructType schema = new StructType(new StructField[]{
    new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
  });
  Dataset<Row> dataset = spark.createDataFrame(
    Arrays.asList(
      RowFactory.create(-0.5),
      RowFactory.create(-0.3),
      RowFactory.create(0.0),
      RowFactory.create(0.2)),
    schema);
  Bucketizer bucketizer = new Bucketizer()
    .setInputCol("feature")
    .setOutputCol("result")
    .setSplits(splits);
  List<Row> result = bucketizer.transform(dataset).select("result").collectAsList();
  for (Row r : result) {
    // three split points define two buckets, so every bucket index must be 0 or 1
    double index = r.getDouble(0);
    Assert.assertTrue((index >= 0) && (index <= 1));
  }
}
@SuppressWarnings("unchecked") @Test public void udf4Test() { spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType); spark.range(10).toDF("x").createOrReplaceTempView("tmp"); // This tests when Java UDFs are required to be the semantically same (See SPARK-9435). List<Row> results = spark.sql("SELECT inc(x) FROM tmp GROUP BY inc(x)").collectAsList(); Assert.assertEquals(10, results.size()); long sum = 0; for (Row result : results) { sum += result.getLong(0); } Assert.assertEquals(55, sum); }
@Override
public Dataset<Row> check(Dataset<Row> dataset, Map<String, Dataset<Row>> stepDependencies) {
  if (isDependency()) {
    // the dependency must supply exactly one long value to use as the expected count
    Dataset<Row> expectedDependency = stepDependencies.get(dependency);
    if (expectedDependency.count() == 1
        && expectedDependency.schema().fields().length == 1
        && expectedDependency.schema().apply(0).dataType() == DataTypes.LongType) {
      expected = expectedDependency.collectAsList().get(0).getLong(0);
    } else {
      throw new RuntimeException(
        "Step dependency for count rule must have one row with a single field of long type");
    }
  }
  if (expected < 0) {
    throw new RuntimeException(
      "Failed to determine expected count: must be specified either as literal or step dependency");
  }
  return dataset.groupBy().count().map(new CheckCount(expected, name), RowEncoder.apply(SCHEMA));
}
@Before
public void setUp() throws IOException {
  spark = SparkSession.builder()
    .master("local[*]")
    .appName("testing")
    .getOrCreate();
  path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile();
  if (path.exists()) {
    path.delete();
  }
  List<String> jsonObjects = new ArrayList<>(10);
  for (int i = 0; i < 10; i++) {
    jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}");
  }
  Dataset<String> ds = spark.createDataset(jsonObjects, Encoders.STRING());
  df = spark.read().json(ds);
  df.createOrReplaceTempView("jsonTable");
}
@Test
public void testSampleBy() {
  Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
  Dataset<Row> sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
  List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
  // key 0 occurs 34 times and key 1 occurs 33 times; the bounds leave room for sampling variance
  Assert.assertEquals(0, actual.get(0).getLong(0));
  Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
  Assert.assertEquals(1, actual.get(1).getLong(0));
  Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13);
}
@Test
public void testUDF() {
  UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType);
  Dataset<Row> df = spark.table("testData").select(foo.apply(col("key"), col("value")));
  String[] result = df.collectAsList().stream().map(row -> row.getString(0))
    .toArray(String[]::new);
  String[] expected = spark.table("testData").collectAsList().stream()
    .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new);
  Assert.assertArrayEquals(expected, result);
}
@Test
public void testBeanWithoutGetter() {
  BeanWithoutGetter bean = new BeanWithoutGetter();
  List<BeanWithoutGetter> data = Arrays.asList(bean);
  Dataset<Row> df = spark.createDataFrame(data, BeanWithoutGetter.class);
  Assert.assertEquals(0, df.schema().length());
  Assert.assertEquals(1, df.collectAsList().size());
}