// Fragment; the enclosing statement is reconstructed here, and the source
// DataFrame name "topTweets" is assumed. Column 0 of each Row is read as a String.
JavaRDD<String> topTweetText = topTweets.toJavaRDD().map(new Function<Row, String>() {
  @Override
  public String call(Row row) { return row.getString(0); }
});
System.out.println(topTweetText.collect());
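On Java 8+ the same projection is a one-liner; a minimal sketch, under the same assumption that topTweets is the source DataFrame:

// Lambda form of the same Row-to-String projection.
JavaRDD<String> topTweetText = topTweets.toJavaRDD().map(row -> row.getString(0));
System.out.println(topTweetText.collect());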
@Test
public void testJavaWord2Vec() {
  StructType schema = new StructType(new StructField[]{
    new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
  });
  Dataset<Row> documentDF = spark.createDataFrame(
    Arrays.asList(
      RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
      RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
      RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))),
    schema);
  Word2Vec word2Vec = new Word2Vec()
    .setInputCol("text")
    .setOutputCol("result")
    .setVectorSize(3)
    .setMinCount(0);
  Word2VecModel model = word2Vec.fit(documentDF);
  Dataset<Row> result = model.transform(documentDF);
  for (Row r : result.select("result").collectAsList()) {
    // Each document is mapped to a vector of the configured size.
    double[] wordVector = ((Vector) r.get(0)).toArray();
    Assert.assertEquals(3, wordVector.length);
  }
}
@Test
public void javaCompatibilityTest() {
  double[] input = new double[]{1D, 2D, 3D, 4D};
  Dataset<Row> dataset = spark.createDataFrame(
    Arrays.asList(RowFactory.create(Vectors.dense(input))),
    new StructType(new StructField[]{
      new StructField("vec", new VectorUDT(), false, Metadata.empty())
    }));
  // Expected result: the scaled DCT of the input, computed directly with JTransforms.
  double[] expectedResult = input.clone();
  new DoubleDCT_1D(input.length).forward(expectedResult, true);
  DCT dct = new DCT()
    .setInputCol("vec")
    .setOutputCol("resultVec");
  List<Row> result = dct.transform(dataset).select("resultVec").collectAsList();
  Vector resultVec = result.get(0).getAs("resultVec");
  Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6);
}
// Fragment; each value was stored in the row twice (once primitive, once boxed),
// so the typed getter and the generic get() are checked at consecutive indices.
Assert.assertEquals(byteValue, simpleRow.getByte(0));
Assert.assertEquals(byteValue, simpleRow.get(0));
Assert.assertEquals(byteValue, simpleRow.getByte(1));
Assert.assertEquals(byteValue, simpleRow.get(1));
Assert.assertEquals(shortValue, simpleRow.getShort(2));
Assert.assertEquals(shortValue, simpleRow.get(2));
Assert.assertEquals(shortValue, simpleRow.getShort(3));
Assert.assertEquals(shortValue, simpleRow.get(3));
Assert.assertEquals(intValue, simpleRow.getInt(4));
Assert.assertEquals(intValue, simpleRow.get(4));
Assert.assertEquals(intValue, simpleRow.getInt(5));
Assert.assertEquals(intValue, simpleRow.get(5));
Assert.assertEquals(longValue, simpleRow.getLong(6));
Assert.assertEquals(longValue, simpleRow.get(6));
Assert.assertEquals(longValue, simpleRow.getLong(7));
Assert.assertEquals(longValue, simpleRow.get(7));
Assert.assertEquals(floatValue, simpleRow.getFloat(8), 0);
Assert.assertEquals(floatValue, simpleRow.get(8));
Assert.assertEquals(floatValue, simpleRow.getFloat(9), 0);
Assert.assertEquals(floatValue, simpleRow.get(9));
Assert.assertEquals(doubleValue, simpleRow.getDouble(10), 0);
Assert.assertEquals(doubleValue, simpleRow.get(10));
Assert.assertEquals(doubleValue, simpleRow.getDouble(11), 0);
Assert.assertEquals(doubleValue, simpleRow.get(11));
Assert.assertEquals(decimalValue, simpleRow.get(12));
Assert.assertEquals(booleanValue, simpleRow.getBoolean(13));
Assert.assertEquals(booleanValue, simpleRow.get(13));
Assert.assertEquals(booleanValue, simpleRow.getBoolean(14));
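For context, a minimal sketch of the construction side (the values here are assumed, not taken from the suite above): RowFactory.create builds a generic Row, and the typed getters simply cast what get() returns.

// Minimal sketch (assumed values): RowFactory.create builds a generic Row;
// getByte(i) casts the value at i, get(i) returns it boxed as an Object.
Row simpleRow = RowFactory.create((byte) 7, (short) 7, 7, 7L, 7.0f, 7.0, true);
byte b = simpleRow.getByte(0);   // typed access: ClassCastException on a wrong type
Object o = simpleRow.get(0);     // generic access: returns the boxed value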
@Test
public void bucketizerTest() {
  // Three split points define two buckets: [-0.5, 0.0) -> index 0.0 and [0.0, 0.5] -> 1.0.
  double[] splits = {-0.5, 0.0, 0.5};
  StructType schema = new StructType(new StructField[]{
    new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
  });
  Dataset<Row> dataset = spark.createDataFrame(
    Arrays.asList(
      RowFactory.create(-0.5),
      RowFactory.create(-0.3),
      RowFactory.create(0.0),
      RowFactory.create(0.2)),
    schema);
  Bucketizer bucketizer = new Bucketizer()
    .setInputCol("feature")
    .setOutputCol("result")
    .setSplits(splits);
  List<Row> result = bucketizer.transform(dataset).select("result").collectAsList();
  for (Row r : result) {
    double index = r.getDouble(0);
    Assert.assertTrue((index >= 0) && (index <= 1));
  }
}
@Test
public void dataFrameRDDOperations() {
  List<Person> personList = new ArrayList<>(2);
  Person person1 = new Person();
  person1.setName("Michael");
  person1.setAge(29);
  personList.add(person1);
  Person person2 = new Person();
  person2.setName("Yin");
  person2.setAge(28);
  personList.add(person2);
  JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
    person -> RowFactory.create(person.getName(), person.getAge()));
  List<StructField> fields = new ArrayList<>(2);
  fields.add(DataTypes.createStructField("name", DataTypes.StringType, false));
  fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
  StructType schema = DataTypes.createStructType(fields);
  Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
  df.createOrReplaceTempView("people");
  List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD()
    .map(row -> row.getString(0) + "_" + row.get(1)).collect();
  List<String> expected = new ArrayList<>(2);
  expected.add("Michael_29");
  expected.add("Yin_28");
  Assert.assertEquals(expected, actual);
}
@Test
public void vectorSlice() {
  Attribute[] attrs = new Attribute[]{
    NumericAttribute.defaultAttr().withName("f1"),
    NumericAttribute.defaultAttr().withName("f2"),
    NumericAttribute.defaultAttr().withName("f3")
  };
  AttributeGroup group = new AttributeGroup("userFeatures", attrs);
  List<Row> data = Arrays.asList(
    RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
    RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
  );
  Dataset<Row> dataset =
    spark.createDataFrame(data, new StructType().add(group.toStructField()));
  VectorSlicer vectorSlicer = new VectorSlicer()
    .setInputCol("userFeatures").setOutputCol("features");
  // The slice is the union of the selected indices and names, so two features remain.
  vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
  Dataset<Row> output = vectorSlicer.transform(dataset);
  for (Row r : output.select("userFeatures", "features").takeAsList(2)) {
    Vector features = r.getAs(1);
    Assert.assertEquals(2, features.size());
  }
}
private static void appendValue(ColumnVector dst, DataType t, Row src, int fieldIdx) {
  if (t instanceof ArrayType) {
    ArrayType at = (ArrayType) t;
    if (src.isNullAt(fieldIdx)) {
      dst.appendNull();
    } else {
      // Append the array length, then recurse into the element column for each value.
      List<Object> values = src.getList(fieldIdx);
      dst.appendArray(values.size());
      for (Object o : values) {
        appendValue(dst.arrayData(), at.elementType(), o);
      }
    }
  } else if (t instanceof StructType) {
    StructType st = (StructType) t;
    if (src.isNullAt(fieldIdx)) {
      dst.appendStruct(true);
    } else {
      // Append a non-null struct marker, then recurse into each child column.
      dst.appendStruct(false);
      Row c = src.getStruct(fieldIdx);
      for (int i = 0; i < st.fields().length; i++) {
        appendValue(dst.getChildColumn(i), st.fields()[i].dataType(), c, i);
      }
    }
  } else {
    // Leaf types delegate to the Object-based overload.
    appendValue(dst, t, src.get(fieldIdx));
  }
}
Assert.assertEquals(
  new StructField("e", DataTypes.createDecimalType(38, 0), true, Metadata.empty()),
  schema.apply("e"));
Row first = df.select("a", "b", "c", "d", "e").first();
Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
// Java arrays and lists come back as Scala Seqs; the element-wise checks inside the
// loops below were dropped in the source snippet and are reconstructed here.
Seq<Integer> result = first.getAs(1);
Assert.assertEquals(bean.getB().length, result.length());
for (int i = 0; i < result.length(); i++) {
  Assert.assertEquals(bean.getB()[i], result.apply(i).intValue());
}
Seq<Integer> outputBuffer = (Seq<Integer>) first.getJavaMap(2).get("hello");
Assert.assertArrayEquals(
  bean.getC().get("hello"),
  Ints.toArray(JavaConverters.seqAsJavaListConverter(outputBuffer).asJava()));
Seq<String> d = first.getAs(3);
Assert.assertEquals(bean.getD().size(), d.length());
for (int i = 0; i < d.length(); i++) {
  Assert.assertEquals(bean.getD().get(i), d.apply(i));
}
Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
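A note on the accessors used above: getAs returns the stored Scala Seq unchanged, while Row also offers getList and getJavaMap, which hand back java.util collections directly. A minimal sketch, with this test's column layout assumed:

// Sketch (assumed layout: array<int> at index 1, map<string, array<int>> at 2, decimal at 4).
List<Integer> b = first.getList(1);      // java.util.List view of the array column
Map<String, ?> c = first.getJavaMap(2);  // outer map is java.util.Map; nested arrays remain Scala Seqs
BigDecimal e = first.getDecimal(4);      // decimals come back as java.math.BigDecimal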
@Test
public void testCrosstab() {
  Dataset<Row> df = spark.table("testData2");
  Dataset<Row> crosstab = df.stat().crosstab("a", "b");
  String[] columnNames = crosstab.schema().fieldNames();
  Assert.assertEquals("a_b", columnNames[0]);
  Assert.assertEquals("1", columnNames[1]);
  Assert.assertEquals("2", columnNames[2]);
  List<Row> rows = crosstab.collectAsList();
  rows.sort(crosstabRowComparator);
  Integer count = 1;
  for (Row row : rows) {
    Assert.assertEquals(row.get(0).toString(), count.toString());
    Assert.assertEquals(1L, row.getLong(1));
    Assert.assertEquals(1L, row.getLong(2));
    count++;
  }
}
@SuppressWarnings("unchecked") @Test public void udf3Test() { spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(), DataTypes.IntegerType); Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); // returnType is not provided spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null); result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); }
@Test
public void testExecution() {
  Dataset<Row> df = spark.table("testData").filter("key = 1");
  Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0));
}
@Test
public void testSampleBy() {
  Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
  Dataset<Row> sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
  List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
  Assert.assertEquals(0, actual.get(0).getLong(0));
  // Sampling is probabilistic, so only loose bounds around the expected counts
  // (~3.4 rows for key 0, ~6.6 for key 1) are asserted.
  Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
  Assert.assertEquals(1, actual.get(1).getLong(0));
  Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13);
}
@SuppressWarnings("unchecked") @Test public void udf4Test() { spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType); spark.range(10).toDF("x").createOrReplaceTempView("tmp"); // This tests when Java UDFs are required to be the semantically same (See SPARK-9435). List<Row> results = spark.sql("SELECT inc(x) FROM tmp GROUP BY inc(x)").collectAsList(); Assert.assertEquals(10, results.size()); long sum = 0; for (Row result : results) { sum += result.getLong(0); } Assert.assertEquals(55, sum); }
@Test
public void testUDF() {
  UserDefinedFunction foo =
    udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType);
  Dataset<Row> df = spark.table("testData").select(foo.apply(col("key"), col("value")));
  String[] result = df.collectAsList().stream().map(row -> row.getString(0))
    .toArray(String[]::new);
  String[] expected = spark.table("testData").collectAsList().stream()
    .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new);
  Assert.assertArrayEquals(expected, result);
}
@Test
public void pivot() {
  Dataset<Row> df = spark.table("courseSales");
  List<Row> actual = df.groupBy("year")
    .pivot("course", Arrays.asList("dotNET", "Java"))
    .agg(sum("earnings")).orderBy("year").collectAsList();

  Assert.assertEquals(2012, actual.get(0).getInt(0));
  Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
  Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);

  Assert.assertEquals(2013, actual.get(1).getInt(0));
  Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
  Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
}
@Test
public void verifyLibSVMDF() {
  Dataset<Row> dataset = spark.read().format("libsvm").option("vectorType", "dense")
    .load(path);
  Assert.assertEquals("label", dataset.columns()[0]);
  Assert.assertEquals("features", dataset.columns()[1]);
  Row r = dataset.first();
  Assert.assertEquals(1.0, r.getDouble(0), 1e-15);
  DenseVector v = r.getAs(1);
  Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v);
}
Schema fieldSchema = field.getSchema();
if (row.isNullAt(idx) && !fieldSchema.isNullable()) {
  throw new NullPointerException("Null value is not allowed in record field at " + fieldPath);
}
if (row.isNullAt(idx)) {
  idx++;
  continue;
}
// The opening of this branch was lost in the source snippet; an ARRAY check is
// assumed here, mirroring the MAP branch below.
if (fieldSchema.getType() == Schema.Type.ARRAY) {
  builder.set(field.getName(), fromRowValue(row.getList(idx), fieldSchema, fieldPath));
} else if (fieldSchema.getType() == Schema.Type.MAP) {
  builder.set(field.getName(), fromRowValue(row.getJavaMap(idx), fieldSchema, fieldPath));
} else {
  Object fieldValue = row.get(idx);
  // ... (remainder truncated in the source)
}
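As the converter above illustrates, null handling is the caller's responsibility. A minimal sketch of the usual guard, with a hypothetical row and column indices:

// Minimal sketch (hypothetical row): primitive getters such as getInt() throw
// a NullPointerException for null values, so check isNullAt() first.
Integer age = row.isNullAt(1) ? null : row.getInt(1);
String name = row.getString(0);  // reference-typed getters simply return null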