@Test
public void testSelect() {
  List<Integer> data = Arrays.asList(2, 6);
  Dataset<Integer> ds = spark.createDataset(data, Encoders.INT());

  Dataset<Tuple2<Integer, String>> selected = ds.select(
      expr("value + 1"),
      col("value").cast("string")).as(Encoders.tuple(Encoders.INT(), Encoders.STRING()));

  Assert.assertEquals(
      Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
      selected.collectAsList());
}
@Test
public void testJoin() {
  List<Integer> data = Arrays.asList(1, 2, 3);
  Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()).as("a");
  List<Integer> data2 = Arrays.asList(2, 3, 4);
  Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT()).as("b");

  Dataset<Tuple2<Integer, Integer>> joined = ds.joinWith(ds2,
      col("a.value").equalTo(col("b.value")));

  Assert.assertEquals(
      Arrays.asList(tuple2(2, 2), tuple2(3, 3)),
      joined.collectAsList());
}
@Test
public void testSampleBy() {
  Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
  Dataset<Row> sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
  List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();

  Assert.assertEquals(0, actual.get(0).getLong(0));
  Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
  Assert.assertEquals(1, actual.get(1).getLong(0));
  Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13);
}
@Test
public void isInCollectionCheckExceptionMessage() {
  List<Row> rows = Arrays.asList(
      RowFactory.create(1, Arrays.asList(1)),
      RowFactory.create(2, Arrays.asList(2)),
      RowFactory.create(3, Arrays.asList(3)));
  StructType schema = createStructType(Arrays.asList(
      createStructField("a", IntegerType, false),
      createStructField("b", createArrayType(IntegerType, false), false)));
  Dataset<Row> df = spark.createDataFrame(rows, schema);

  try {
    df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b"))));
    Assert.fail("Expected org.apache.spark.sql.AnalysisException");
  } catch (Exception e) {
    Arrays.asList("cannot resolve",
        "due to data type mismatch: Arguments must be same type but were")
        .forEach(s -> Assert.assertTrue(
            e.getMessage().toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))));
  }
}
/**
 * Returns the mappings for the given URI and version.
 *
 * @param uri the URI of the concept map for which to get mappings
 * @param version the version of the concept map for which to get mappings
 * @return a dataset of mappings for the given URI and version
 */
public Dataset<Mapping> getMappings(String uri, String version) {
  return this.mappings.where(functions.col("conceptmapuri").equalTo(lit(uri))
      .and(functions.col("conceptmapversion").equalTo(lit(version))));
}
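// A minimal usage sketch, assuming "conceptMaps" is a loaded instance of the class that
// defines getMappings; the concept map URI and version below are hypothetical examples:
Dataset<Mapping> exampleMappings =
    conceptMaps.getMappings("http://example.org/fhir/ConceptMap/example", "1.0");
exampleMappings.show();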
protected C withConceptMaps(Dataset<T> newMaps, Dataset<Mapping> newMappings) {

  Dataset<UrlAndVersion> newMembers = getUrlAndVersions(newMaps);

  // Instantiating a new composite ConceptMaps requires a new timestamp
  Timestamp timestamp = new Timestamp(System.currentTimeMillis());

  Dataset<T> newMapsWithTimestamp = newMaps
      .withColumn("timestamp", lit(timestamp.toString()).cast("timestamp"))
      .as(conceptMapEncoder);

  return newInstance(spark,
      this.members.union(newMembers),
      this.conceptMaps.union(newMapsWithTimestamp),
      this.mappings.union(newMappings));
}
// Fragment: join the value set references to load against those already present
// (the datasets these calls are chained onto are elided in the original).
    .as("toload");

    .as("present")
    .join(
        referencesToLoad,
        col("present.valueSetUri").equalTo(col("toload.valueSetUri"))
            .and(col("present.valueSetVersion").equalTo(col("toload.valueSetVersion"))))
    .select("referenceName", "system", "value")
    .collectAsList();

// Fragment: join the ancestors to load against those already present.
    .join(
        ancestorsToLoad,
        col("present.uri").equalTo(col("toload.uri"))
            .and(col("present.version").equalTo(col("toload.version")))
            .and(col("present.ancestorSystem").equalTo(col("toload.ancestorSystem")))
            .and(col("present.ancestorValue").equalTo(col("toload.ancestorValue"))))
    .select("referenceName", "descendantSystem", "descendantValue")
    .collectAsList();
@Test
public void testUDAF() {
  Dataset<Row> df = hc.range(0, 100).union(hc.range(0, 100)).select(col("id").as("value"));
  UserDefinedAggregateFunction udaf = new MyDoubleSum();
  UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf);

  // Create Columns for the UDAF. For now, callUDF does not take an argument to specify
  // whether we want to use distinct aggregation.
  Dataset<Row> aggregatedDF = df.groupBy()
      .agg(
          udaf.distinct(col("value")),
          udaf.apply(col("value")),
          registeredUDAF.apply(col("value")),
          callUDF("mydoublesum", col("value")));

  List<Row> expectedResult = new ArrayList<>();
  expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0));
  checkAnswer(
      aggregatedDF,
      expectedResult);
}
/**
 * Reads a SNOMED relationship file and converts it to a {@link HierarchicalElement} dataset.
 *
 * @param spark the Spark session
 * @param snomedRelationshipPath path to the SNOMED relationship file
 * @return a dataset of {@link HierarchicalElement} representing the hierarchical relationships
 */
public static Dataset<HierarchicalElement> readRelationshipFile(SparkSession spark,
    String snomedRelationshipPath) {

  return spark.read()
      .option("header", true)
      .option("delimiter", "\t")
      .csv(snomedRelationshipPath)
      .where(col("typeId").equalTo(lit(SNOMED_ISA_RELATIONSHIP_ID)))
      .where(col("active").equalTo(lit("1")))
      .select(col("destinationId"), col("sourceId"))
      .where(col("destinationId").isNotNull()
          .and(col("destinationId").notEqual(lit(""))))
      .where(col("sourceId").isNotNull()
          .and(col("sourceId").notEqual(lit(""))))
      .map((MapFunction<Row, HierarchicalElement>) row -> {

        HierarchicalElement element = new HierarchicalElement();

        element.setAncestorSystem(SNOMED_CODE_SYSTEM_URI);
        element.setAncestorValue(row.getString(0));

        element.setDescendantSystem(SNOMED_CODE_SYSTEM_URI);
        element.setDescendantValue(row.getString(1));

        return element;
      }, Hierarchies.getHierarchicalElementEncoder());
}
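// A minimal usage sketch; the relationship file path below is a hypothetical location of a
// tab-delimited SNOMED relationship snapshot:
Dataset<HierarchicalElement> relationships =
    readRelationshipFile(spark, "/path/to/sct2_Relationship_Snapshot.txt");
relationships.show();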
/**
 * Returns the value set with the given URI and version, or null if there is no such value set.
 *
 * @param uri the URI of the value set to return
 * @param version the version of the value set to return
 * @return the specified value set, or null if it does not exist
 */
public T getValueSet(String uri, String version) {

  // Load the value sets, which may contain zero items if the value set does not exist
  // Typecast necessary to placate the Java compiler calling this Scala function
  T[] valueSets = (T[]) this.valueSets.filter(
      col("url").equalTo(lit(uri))
          .and(col("version").equalTo(lit(version))))
      .head(1);

  if (valueSets.length == 0) {
    return null;
  } else {
    T valueSet = valueSets[0];
    Dataset<Value> filteredValues = getValues(uri, version);
    addToValueSet(valueSet, filteredValues);
    return valueSet;
  }
}
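// A minimal usage sketch, assuming "valueSets" is an instance of this collection class and T is
// a FHIR ValueSet type; the URL and version below are hypothetical:
ValueSet valueSet = valueSets.getValueSet("http://example.org/fhir/ValueSet/example-codes", "1.0");

if (valueSet != null) {
  // The returned value set has its filtered values added before being returned.
  System.out.println(valueSet.getUrl());
}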
/**
 * Returns all value sets that are disjoint with value sets stored in the given database and
 * adds them to our collection. The directory may be anything readable from a Spark path,
 * including local filesystems, HDFS, S3, or others.
 *
 * @param path a path from which disjoint value sets will be loaded
 * @param database the database to check value sets against
 * @return an instance of ValueSets that includes content from that directory that is disjoint
 *     with content already contained in the given database
 */
public C withDisjointValueSetsFromDirectory(String path, String database) {

  Dataset<UrlAndVersion> currentMembers = this.spark.table(database + "." + VALUE_SETS_TABLE)
      .select("url", "version")
      .distinct()
      .as(URL_AND_VERSION_ENCODER)
      .alias("current");

  Dataset<T> valueSets = valueSetDatasetFromDirectory(path)
      .alias("new")
      .join(currentMembers,
          col("new.url").equalTo(col("current.url"))
              .and(col("new.version").equalTo(col("current.version"))),
          "leftanti")
      .as(valueSetEncoder);

  return withValueSets(valueSets);
}
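// A minimal usage sketch, assuming the concrete subclass is named ValueSets; the directory
// path and database name below are hypothetical:
ValueSets withNewContent = valueSets.withDisjointValueSetsFromDirectory(
    "path/to/valueset/directory", "ontologies");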
// Fragment: load the authors CSV (the spark.read() call these options are chained onto is
// elided in the original).
    .option("header", "true")
    .load(filename);
authorsDf.show();
authorsDf.printSchema();

// Fragment: load the books CSV.
    .option("header", "true")
    .load(filename);
booksDf.show();
booksDf.printSchema();

// Fragment: left join authors to their books; a grouping step preceding count() appears to be
// elided in the original, since libraryDf is later ordered by a "count" column.
    .join(
        booksDf,
        authorsDf.col("id").equalTo(booksDf.col("authorId")),
        "left")
    .withColumn("bookId", booksDf.col("id"))
    .count();

libraryDf = libraryDf.orderBy(libraryDf.col("count").desc());
libraryDf.show();
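// A self-contained sketch of the same pattern under assumed inputs: two hypothetical CSV files,
// data/authors.csv with columns id,name and data/books.csv with columns id,authorId,title.
SparkSession spark = SparkSession.builder()
    .appName("Authors and books").master("local").getOrCreate();

Dataset<Row> authorsDf = spark.read().format("csv")
    .option("header", "true")
    .load("data/authors.csv");

Dataset<Row> booksDf = spark.read().format("csv")
    .option("header", "true")
    .load("data/books.csv");

// Left join keeps authors without books; group by author to count their books.
Dataset<Row> libraryDf = authorsDf
    .join(booksDf, authorsDf.col("id").equalTo(booksDf.col("authorId")), "left")
    .groupBy(authorsDf.col("id"), authorsDf.col("name"))
    .count()
    .orderBy(col("count").desc());

libraryDf.show();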
ds.printSchema();
ds.show();

// Filter on a field nested inside a struct column.
ds.where(col("from").getField("x").gt(7.0)).select(col("to")).show();

// Fragment: filter and project against a nested "points" collection column
// (the dataset these calls are chained onto is elided in the original).
    .where(col("points").getItem(2).getField("y").gt(7.0))
    .select(col("name"), size(col("points")).as("count")).show();

    .where(size(col("points")).gt(1))
    .select(col("name"), size(col("points")).as("count"), col("points").getItem("p1")).show();
// Inspect the receiver-operating characteristic and the area under the ROC curve.
roc.show();
roc.select("FPR").show();
System.out.println(binarySummary.areaUnderROC());

// Set the model threshold to the one that maximizes F-Measure.
double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
    .select("threshold").head().getDouble(0);
lrModel.setThreshold(bestThreshold);
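// A sketch of how the objects referenced above are typically obtained, assuming lrModel is a
// fitted org.apache.spark.ml.classification.LogisticRegressionModel (recent Spark versions):
BinaryLogisticRegressionTrainingSummary binarySummary = lrModel.binarySummary();
Dataset<Row> roc = binarySummary.roc();
Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();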
@Test
public void saveTableAndQueryIt() {
  checkAnswer(
      df.select(avg("key").over(
          Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))),
      hc.sql("SELECT avg(key) "
          + "OVER (PARTITION BY value "
          + " ORDER BY key "
          + " ROWS BETWEEN 1 preceding and 1 following) "
          + "FROM window_table").collectAsList());
}
private void start() {
  SparkSession spark = SparkSession.builder().appName("CSV to Dataset")
      .master("local").getOrCreate();

  // Register the x2Multiplier UDF (implemented by the Multiplier2 class) with an integer
  // return type.
  spark.udf().register("x2Multiplier", new Multiplier2(), DataTypes.IntegerType);

  String filename = "data/tuple-data-file.csv";
  Dataset<Row> df = spark.read().format("csv").option("inferSchema", "true")
      .option("header", "false").load(filename);

  df = df.withColumn("label", df.col("_c0")).drop("_c0");
  df = df.withColumn("value", df.col("_c1")).drop("_c1");
  df = df.withColumn("x2", callUDF("x2Multiplier", df.col("value").cast(
      DataTypes.IntegerType)));
  df.show();
}
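// A plausible sketch of the Multiplier2 UDF used above; its actual implementation is not shown
// in this snippet, and it is assumed here to double its integer input.
import org.apache.spark.sql.api.java.UDF1;

public class Multiplier2 implements UDF1<Integer, Integer> {
  private static final long serialVersionUID = 1L;

  @Override
  public Integer call(Integer value) throws Exception {
    // Assumed behavior: multiply the input by two, matching the "x2Multiplier" name.
    return value * 2;
  }
}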