@Override public Iterator<Row> call(String line) throws Exception { List<Row> list = new ArrayList<>(); String[] tokens = line.split("\\s"); for (int i = 0; i < tokens.length; i++) { int start = (i - brodcastWindow.value() < 0) ? 0 : i - brodcastWindow.value(); int end = (i + brodcastWindow.value() >= tokens.length) ? tokens.length - 1 : i + brodcastWindow.value(); for (int j = start; j <= end; j++) { if (j != i) { list.add(RowFactory.create(tokens[i], tokens[j], 1)); } else { // do nothing continue; } } } return list.iterator(); } });
@Test public void isInCollectionCheckExceptionMessage() { List<Row> rows = Arrays.asList( RowFactory.create(1, Arrays.asList(1)), RowFactory.create(2, Arrays.asList(2)), RowFactory.create(3, Arrays.asList(3))); StructType schema = createStructType(Arrays.asList( createStructField("a", IntegerType, false), createStructField("b", createArrayType(IntegerType, false), false))); Dataset<Row> df = spark.createDataFrame(rows, schema); try { df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b")))); Assert.fail("Expected org.apache.spark.sql.AnalysisException"); } catch (Exception e) { Arrays.asList("cannot resolve", "due to data type mismatch: Arguments must be same type but were") .forEach(s -> Assert.assertTrue( e.getMessage().toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))); } } }
@Test public void isInCollectionCheckExceptionMessage() { List<Row> rows = Arrays.asList( RowFactory.create(1, Arrays.asList(1)), RowFactory.create(2, Arrays.asList(2)), RowFactory.create(3, Arrays.asList(3))); StructType schema = createStructType(Arrays.asList( createStructField("a", IntegerType, false), createStructField("b", createArrayType(IntegerType, false), false))); Dataset<Row> df = spark.createDataFrame(rows, schema); try { df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b")))); Assert.fail("Expected org.apache.spark.sql.AnalysisException"); } catch (Exception e) { Arrays.asList("cannot resolve", "due to data type mismatch: Arguments must be same type but were") .forEach(s -> Assert.assertTrue( e.getMessage().toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))); } } }
@Test
public void testCreateDataFromFromList() {
  // A one-column, one-row frame built from a local List round-trips:
  // collecting it yields exactly one row.
  StructType schema =
      createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
  Dataset<Row> df =
      spark.createDataFrame(Arrays.asList(RowFactory.create(0)), schema);
  Assert.assertEquals(1, df.collectAsList().size());
}
@Test
public void testCreateDataFromFromList() {
  // Single nullable int column, single row: createDataFrame from a local
  // List must produce exactly one collected row.
  StructField field = createStructField("i", IntegerType, true);
  StructType schema = createStructType(Arrays.asList(field));
  List<Row> input = Arrays.asList(RowFactory.create(0));
  List<Row> collected = spark.createDataFrame(input, schema).collectAsList();
  Assert.assertEquals(1, collected.size());
}
@Test
public void testCreateDataFromFromList() {
  // Round-trip a one-row local dataset through createDataFrame/collect.
  List<Row> data = new ArrayList<>();
  data.add(RowFactory.create(0));
  List<StructField> fields = new ArrayList<>();
  fields.add(createStructField("i", IntegerType, true));
  Dataset<Row> frame = spark.createDataFrame(data, createStructType(fields));
  Assert.assertEquals(1, frame.collectAsList().size());
}
@Test
public void applySchema() {
  // Two beans converted to Rows, paired with an explicit schema, and read
  // back via SQL over a temp view must round-trip unchanged.
  Person michael = new Person();
  michael.setName("Michael");
  michael.setAge(29);
  Person yin = new Person();
  yin.setName("Yin");
  yin.setAge(28);
  List<Person> people = new ArrayList<>(2);
  people.add(michael);
  people.add(yin);
  JavaRDD<Row> rowRDD = jsc.parallelize(people)
      .map(p -> RowFactory.create(p.getName(), p.getAge()));
  StructType schema = DataTypes.createStructType(Arrays.asList(
      DataTypes.createStructField("name", DataTypes.StringType, false),
      DataTypes.createStructField("age", DataTypes.IntegerType, false)));
  Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
  df.createOrReplaceTempView("people");
  List<Row> actual = spark.sql("SELECT * FROM people").collectAsList();
  List<Row> expected = Arrays.asList(
      RowFactory.create("Michael", 29),
      RowFactory.create("Yin", 28));
  Assert.assertEquals(expected, actual);
}
@Test
public void applySchema() {
  // Round-trip: beans -> Row RDD + explicit schema -> temp view -> SQL.
  List<Person> people = new ArrayList<>(2);
  Person first = new Person();
  first.setName("Michael");
  first.setAge(29);
  people.add(first);
  Person second = new Person();
  second.setName("Yin");
  second.setAge(28);
  people.add(second);
  JavaRDD<Row> rows = jsc.parallelize(people)
      .map(p -> RowFactory.create(p.getName(), p.getAge()));
  List<StructField> columns = new ArrayList<>(2);
  columns.add(DataTypes.createStructField("name", DataTypes.StringType, false));
  columns.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
  Dataset<Row> df =
      spark.createDataFrame(rows, DataTypes.createStructType(columns));
  df.createOrReplaceTempView("people");
  List<Row> expected = Arrays.asList(
      RowFactory.create("Michael", 29),
      RowFactory.create("Yin", 28));
  Assert.assertEquals(expected, spark.sql("SELECT * FROM people").collectAsList());
}
@Test
public void applySchema() {
  // Verify that an explicit StructType applied to a JavaRDD<Row> built
  // from beans survives a SQL query over a temp view.
  List<Person> persons = new ArrayList<>(2);
  Person p1 = new Person();
  p1.setName("Michael");
  p1.setAge(29);
  persons.add(p1);
  Person p2 = new Person();
  p2.setName("Yin");
  p2.setAge(28);
  persons.add(p2);
  StructType schema = DataTypes.createStructType(new StructField[] {
      DataTypes.createStructField("name", DataTypes.StringType, false),
      DataTypes.createStructField("age", DataTypes.IntegerType, false)});
  JavaRDD<Row> rowRdd =
      jsc.parallelize(persons).map(p -> RowFactory.create(p.getName(), p.getAge()));
  Dataset<Row> frame = spark.createDataFrame(rowRdd, schema);
  frame.createOrReplaceTempView("people");
  List<Row> actual = spark.sql("SELECT * FROM people").collectAsList();
  List<Row> expected = new ArrayList<>(2);
  expected.add(RowFactory.create("Michael", 29));
  expected.add(RowFactory.create("Yin", 28));
  Assert.assertEquals(expected, actual);
}
@Test public void isInCollectionWorksCorrectlyOnJava() { List<Row> rows = Arrays.asList( RowFactory.create(1, "x"), RowFactory.create(2, "y"), RowFactory.create(3, "z")); StructType schema = createStructType(Arrays.asList( createStructField("a", IntegerType, false), createStructField("b", StringType, false))); Dataset<Row> df = spark.createDataFrame(rows, schema); // Test with different types of collections Assert.assertTrue(Arrays.equals( (Row[]) df.filter(df.col("a").isInCollection(Arrays.asList(1, 2))).collect(), (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() )); Assert.assertTrue(Arrays.equals( (Row[]) df.filter(df.col("a").isInCollection(new HashSet<>(Arrays.asList(1, 2)))).collect(), (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() )); Assert.assertTrue(Arrays.equals( (Row[]) df.filter(df.col("a").isInCollection(new ArrayList<>(Arrays.asList(3, 1)))).collect(), (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 3 || r.getInt(0) == 1).collect() )); }
@Test public void isInCollectionWorksCorrectlyOnJava() { List<Row> rows = Arrays.asList( RowFactory.create(1, "x"), RowFactory.create(2, "y"), RowFactory.create(3, "z")); StructType schema = createStructType(Arrays.asList( createStructField("a", IntegerType, false), createStructField("b", StringType, false))); Dataset<Row> df = spark.createDataFrame(rows, schema); // Test with different types of collections Assert.assertTrue(Arrays.equals( (Row[]) df.filter(df.col("a").isInCollection(Arrays.asList(1, 2))).collect(), (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() )); Assert.assertTrue(Arrays.equals( (Row[]) df.filter(df.col("a").isInCollection(new HashSet<>(Arrays.asList(1, 2)))).collect(), (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() )); Assert.assertTrue(Arrays.equals( (Row[]) df.filter(df.col("a").isInCollection(new ArrayList<>(Arrays.asList(3, 1)))).collect(), (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 3 || r.getInt(0) == 1).collect() )); }
Row simpleStruct = RowFactory.create( doubleValue, stringValue, timestampValue, null); Row complexStruct = RowFactory.create( simpleStringArray, simpleMap, Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); Assert.assertEquals(arrayOfMaps, complexRow.get(0)); Assert.assertEquals(arrayOfRows, complexRow.get(1));
Row simpleStruct = RowFactory.create( doubleValue, stringValue, timestampValue, null); Row complexStruct = RowFactory.create( simpleStringArray, simpleMap, Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); Assert.assertEquals(arrayOfMaps, complexRow.get(0)); Assert.assertEquals(arrayOfRows, complexRow.get(1));
Row simpleStruct = RowFactory.create( doubleValue, stringValue, timestampValue, null); Row complexStruct = RowFactory.create( simpleStringArray, simpleMap, Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); Assert.assertEquals(arrayOfMaps, complexRow.get(0)); Assert.assertEquals(arrayOfRows, complexRow.get(1));
@Test
public void dataFrameRDDOperations() {
  // Same round-trip as applySchema, but reads the result back through
  // toJavaRDD().map(...) instead of collectAsList().
  List<Person> personList = new ArrayList<>(2);
  Person person1 = new Person();
  person1.setName("Michael");
  person1.setAge(29);
  personList.add(person1);
  Person person2 = new Person();
  person2.setName("Yin");
  person2.setAge(28);
  personList.add(person2);
  JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
      person -> RowFactory.create(person.getName(), person.getAge()));
  List<StructField> fields = new ArrayList<>(2);
  // Fixed: the first field was declared with an empty name (""); name it
  // "name" for consistency with applySchema. Rows are read positionally
  // below, so the observable result is unchanged.
  fields.add(DataTypes.createStructField("name", DataTypes.StringType, false));
  fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
  StructType schema = DataTypes.createStructType(fields);
  Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
  df.createOrReplaceTempView("people");
  List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD()
      .map(row -> row.getString(0) + "_" + row.get(1)).collect();
  List<String> expected = new ArrayList<>(2);
  expected.add("Michael_29");
  expected.add("Yin_28");
  Assert.assertEquals(expected, actual);
}
@Test
public void dataFrameRDDOperations() {
  // Build the frame from beans, then verify via an RDD map over the SQL
  // query result.
  Person michael = new Person();
  michael.setName("Michael");
  michael.setAge(29);
  Person yin = new Person();
  yin.setName("Yin");
  yin.setAge(28);
  List<Person> people = new ArrayList<>(2);
  people.add(michael);
  people.add(yin);
  JavaRDD<Row> rows =
      jsc.parallelize(people).map(p -> RowFactory.create(p.getName(), p.getAge()));
  List<StructField> fields = new ArrayList<>(2);
  // NOTE(review): the first column's name is an empty string in the
  // original; reproduced as-is since rows are read positionally below.
  fields.add(DataTypes.createStructField("", DataTypes.StringType, false));
  fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
  Dataset<Row> df =
      spark.createDataFrame(rows, DataTypes.createStructType(fields));
  df.createOrReplaceTempView("people");
  List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD()
      .map(r -> r.getString(0) + "_" + r.get(1)).collect();
  List<String> expected = Arrays.asList("Michael_29", "Yin_28");
  Assert.assertEquals(expected, actual);
}
@Test
public void dataFrameRDDOperations() {
  // Exercise Dataset -> JavaRDD interop: query the temp view and format
  // each row as "<string>_<age>" on the RDD side.
  List<Person> persons = new ArrayList<>(2);
  Person p1 = new Person();
  p1.setName("Michael");
  p1.setAge(29);
  persons.add(p1);
  Person p2 = new Person();
  p2.setName("Yin");
  p2.setAge(28);
  persons.add(p2);
  // Schema reproduced verbatim, including the empty first column name.
  StructType schema = DataTypes.createStructType(new StructField[] {
      DataTypes.createStructField("", DataTypes.StringType, false),
      DataTypes.createStructField("age", DataTypes.IntegerType, false)});
  JavaRDD<Row> rowRdd =
      jsc.parallelize(persons).map(p -> RowFactory.create(p.getName(), p.getAge()));
  spark.createDataFrame(rowRdd, schema).createOrReplaceTempView("people");
  List<String> formatted = spark.sql("SELECT * FROM people").toJavaRDD()
      .map(row -> row.getString(0) + "_" + row.get(1)).collect();
  List<String> expected = new ArrayList<>(2);
  expected.add("Michael_29");
  expected.add("Yin_28");
  Assert.assertEquals(expected, formatted);
}