.setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) .addAggregation(p.symbol("b"), expression("count(*)"), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a")))))); }))
/**
 * Creates an {@link AggregationNode} from a caller-configured {@link AggregationBuilder}.
 * <p>
 * A fresh builder is instantiated, passed to the supplied callback so the caller can
 * configure it, and then built.
 *
 * @param aggregationBuilderConsumer callback that receives the new builder and configures it
 * @return the node produced by calling {@code build()} on the configured builder
 */
public AggregationNode aggregation(Consumer<AggregationBuilder> aggregationBuilderConsumer)
{
    final AggregationBuilder builder = new AggregationBuilder();
    // Let the caller populate the builder before it is finalized.
    aggregationBuilderConsumer.accept(builder);
    final AggregationNode node = builder.build();
    return node;
}
/**
 * Asserts that {@link PruneCountAggregationOverScalar} does not fire when a global, argument-less
 * {@code count} aggregation sits on top of an aggregation that has a non-empty grouping set
 * (here a single grouping set on {@code orderkey}).
 */
@Test
public void testDoesNotFireOnNestedCountAggregateWithNonEmptyGroupBy()
{
    tester().assertThat(new PruneCountAggregationOverScalar())
            .on(p ->
                    p.aggregation((a) -> a
                            .addAggregation(
                                    p.symbol("count_1", BigintType.BIGINT),
                                    new FunctionCall(QualifiedName.of("count"), ImmutableList.of()),
                                    ImmutableList.of(BigintType.BIGINT))
                            .step(AggregationNode.Step.SINGLE)
                            .globalGrouping()
                            .source(
                                    // Inner aggregation: grouped by orderkey over an empty table scan.
                                    // NOTE(review): assumes AggregationBuilder.source is a plain setter, so
                                    // configuring the source once is sufficient - confirm against the builder.
                                    p.aggregation(aggregationBuilder -> {
                                        aggregationBuilder
                                                .source(p.tableScan(ImmutableList.of(), ImmutableMap.of()))
                                                .groupingSets(singleGroupingSet(ImmutableList.of(p.symbol("orderkey"))));
                                    }))))
            .doesNotFire();
}
/**
 * Applies the rule to a FINAL global aggregation calling {@code spatial_partitioning}, once with a
 * bare {@code geometry} argument and once with {@code ST_Envelope(geometry)}. Both assertions expect
 * the same plan pattern: an aggregation over {@code envelope} and {@code partition_count}, fed by a
 * projection that computes those two symbols.
 */
@Test
public void test()
{
    // Case 1: the aggregation argument is the raw geometry symbol.
    assertRuleApplication()
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(
                                planBuilder.symbol("sp"),
                                PlanBuilder.expression("spatial_partitioning(geometry)"),
                                ImmutableList.of(GEOMETRY))
                        .source(planBuilder.values(planBuilder.symbol("geometry")));
            }))
            .matches(
                    aggregation(
                            ImmutableMap.of(
                                    "sp",
                                    functionCall("spatial_partitioning", ImmutableList.of("envelope", "partition_count"))),
                            project(
                                    ImmutableMap.of(
                                            "partition_count", expression("100"),
                                            "envelope", expression("ST_Envelope(geometry)")),
                                    values("geometry"))));

    // Case 2: the aggregation argument is already wrapped in ST_Envelope.
    assertRuleApplication()
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(
                                planBuilder.symbol("sp"),
                                PlanBuilder.expression("spatial_partitioning(ST_Envelope(geometry))"),
                                ImmutableList.of(GEOMETRY))
                        .source(planBuilder.values(planBuilder.symbol("geometry")));
            }))
            .matches(
                    aggregation(
                            ImmutableMap.of(
                                    "sp",
                                    functionCall("spatial_partitioning", ImmutableList.of("envelope", "partition_count"))),
                            project(
                                    ImmutableMap.of(
                                            "partition_count", expression("100"),
                                            "envelope", expression("ST_Envelope(geometry)")),
                                    values("geometry"))));
}
/**
 * Asserts that the rule under test does not fire for a FINAL global aggregation whose
 * {@code spatial_partitioning} call is given two arguments ({@code geometry, 10}).
 */
@Test
public void testDoesNotFire()
{
    assertRuleApplication()
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(
                                planBuilder.symbol("sp"),
                                PlanBuilder.expression("spatial_partitioning(geometry, 10)"),
                                ImmutableList.of(GEOMETRY))
                        .source(planBuilder.values(planBuilder.symbol("geometry")));
            }))
            .doesNotFire();
}
.setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.project( Assignments.identity(p.symbol("b")), p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a"))))))); }))
.addAggregation(pb.symbol("sum", BIGINT), expression("sum(x)"), ImmutableList.of(BIGINT)) .addAggregation(pb.symbol("count", BIGINT), expression("count()"), ImmutableList.of()) .addAggregation(pb.symbol("count_on_x", BIGINT), expression("count(x)"), ImmutableList.of(BIGINT)) .singleGroupingSet(pb.symbol("y", BIGINT), pb.symbol("z", BIGINT)) .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT), pb.symbol("z", BIGINT))))) .withSourceStats(PlanNodeStatsEstimate.builder() .setOutputRowCount(100)
.setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a"))))))); }))
.setSystemProperty(TASK_CONCURRENCY, "1") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a")))))); }))
.setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a")))))); }))
.source( p.join( JoinNode.Type.LEFT, Optional.empty(), Optional.empty())) .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("COL1")))) .matches( project(ImmutableMap.of(
.source(p.join( JoinNode.Type.LEFT, p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(expressions("10"), expressions("11"))), Optional.empty(), Optional.empty())) .addAggregation(new Symbol("AVG"), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) .singleGroupingSet(new Symbol("COL1")))) .doesNotFire(); .source( p.join( JoinNode.Type.LEFT, .build(), p.aggregation(builder -> builder.singleGroupingSet(p.symbol("COL1"), p.symbol("unused")) .source( p.values( ImmutableList.of(p.symbol("COL1"), p.symbol("unused")), Optional.empty(), Optional.empty())) .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("COL1")))) .doesNotFire();
.source(p.join( JoinNode.Type.RIGHT, p.values(p.symbol("COL2")), Optional.empty(), Optional.empty())) .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("COL1")))) .matches( project(ImmutableMap.of(
/**
 * Applies {@link SingleDistinctAggregationToGroupBy} to a global aggregation with two
 * {@code corr(DISTINCT ...)} calls over the same pair of inputs in swapped order.
 * The expected plan is a global SINGLE aggregation calling {@code corr(x, y)} and {@code corr(y, x)}
 * on top of a SINGLE aggregation grouped by {@code x, y} over the values node.
 */
@Test
public void testMultipleInputs()
{
    tester().assertThat(new SingleDistinctAggregationToGroupBy())
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping()
                        .addAggregation(
                                planBuilder.symbol("output1"),
                                expression("corr(DISTINCT x, y)"),
                                ImmutableList.of(REAL, REAL))
                        .addAggregation(
                                planBuilder.symbol("output2"),
                                expression("corr(DISTINCT y, x)"),
                                ImmutableList.of(REAL, REAL))
                        .source(planBuilder.values(planBuilder.symbol("x"), planBuilder.symbol("y")));
            }))
            .matches(
                    aggregation(
                            globalAggregation(),
                            // Outer aggregation: both outputs, arguments in their original order.
                            ImmutableMap.<Optional<String>, ExpectedValueProvider<FunctionCall>>builder()
                                    .put(Optional.of("output1"), functionCall("corr", ImmutableList.of("x", "y")))
                                    .put(Optional.of("output2"), functionCall("corr", ImmutableList.of("y", "x")))
                                    .build(),
                            ImmutableMap.of(),
                            Optional.empty(),
                            SINGLE,
                            // Inner aggregation: grouping on both inputs, no aggregate functions.
                            aggregation(
                                    singleGroupingSet("x", "y"),
                                    ImmutableMap.of(),
                                    ImmutableMap.of(),
                                    Optional.empty(),
                                    SINGLE,
                                    values("x", "y"))));
}
}
/**
 * Applies {@link SingleDistinctAggregationToGroupBy} to a global aggregation with
 * {@code count(DISTINCT input)} and {@code sum(DISTINCT input)}.
 * The expected plan is a global SINGLE aggregation calling {@code count(input)} and {@code sum(input)}
 * on top of a SINGLE aggregation grouped by {@code input} over the values node.
 */
@Test
public void testMultipleAggregations()
{
    tester().assertThat(new SingleDistinctAggregationToGroupBy())
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping()
                        .addAggregation(
                                planBuilder.symbol("output1"),
                                expression("count(DISTINCT input)"),
                                ImmutableList.of(BIGINT))
                        .addAggregation(
                                planBuilder.symbol("output2"),
                                expression("sum(DISTINCT input)"),
                                ImmutableList.of(BIGINT))
                        .source(planBuilder.values(planBuilder.symbol("input")));
            }))
            .matches(
                    aggregation(
                            globalAggregation(),
                            // Outer aggregation: both outputs computed over the same input symbol.
                            ImmutableMap.<Optional<String>, ExpectedValueProvider<FunctionCall>>builder()
                                    .put(Optional.of("output1"), functionCall("count", ImmutableList.of("input")))
                                    .put(Optional.of("output2"), functionCall("sum", ImmutableList.of("input")))
                                    .build(),
                            ImmutableMap.of(),
                            Optional.empty(),
                            SINGLE,
                            // Inner aggregation: grouping on the input, no aggregate functions.
                            aggregation(
                                    singleGroupingSet("input"),
                                    ImmutableMap.of(),
                                    ImmutableMap.of(),
                                    Optional.empty(),
                                    SINGLE,
                                    values("input"))));
}
.setSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, "true") .on(p -> p.aggregation(ab -> ab .source( p.join( INNER, Optional.of(p.symbol("LEFT_HASH")), Optional.of(p.symbol("RIGHT_HASH")))) .addAggregation(p.symbol("AVG", DOUBLE), expression("AVG(LEFT_AGGR)"), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("RIGHT_GROUP_BY")) .step(PARTIAL))) .matches(project(ImmutableMap.of( "LEFT_GROUP_BY", PlanMatchPattern.expression("LEFT_GROUP_BY"),
/**
 * Runs {@code validatePlan} on two SINGLE-step aggregations over a scan of the nation table:
 * one grouped by {@code nationkey}, and one grouped and pre-grouped by {@code unique, nationkey}
 * where the scan is wrapped in an assign-unique-id node.
 */
@Test
public void testValidateSuccessful()
{
    // Plan 1: group by nationkey directly over the table scan.
    validatePlan(planBuilder -> planBuilder.aggregation(agg -> {
        agg.step(SINGLE)
                .singleGroupingSet(planBuilder.symbol("nationkey"))
                .source(planBuilder.tableScan(
                        nationTableHandle,
                        ImmutableList.of(planBuilder.symbol("nationkey", BIGINT)),
                        ImmutableMap.of(planBuilder.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)),
                        Optional.of(nationTableLayoutHandle)));
    }));

    // Plan 2: group by (unique, nationkey) with both symbols declared as pre-grouped.
    validatePlan(planBuilder -> planBuilder.aggregation(agg -> {
        agg.step(SINGLE)
                .singleGroupingSet(planBuilder.symbol("unique"), planBuilder.symbol("nationkey"))
                .preGroupedSymbols(planBuilder.symbol("unique"), planBuilder.symbol("nationkey"))
                .source(planBuilder.assignUniqueId(
                        planBuilder.symbol("unique"),
                        planBuilder.tableScan(
                                nationTableHandle,
                                ImmutableList.of(planBuilder.symbol("nationkey", BIGINT)),
                                ImmutableMap.of(planBuilder.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)),
                                Optional.of(nationTableLayoutHandle))));
    }));
}
/**
 * Applies the rule to a FINAL global aggregation calling {@code spatial_partitioning}, once with a
 * bare {@code geometry} argument and once with {@code ST_Envelope(geometry)}. Both assertions expect
 * the same plan pattern: an aggregation over {@code envelope} and {@code partition_count}, fed by a
 * projection that computes those two symbols.
 */
@Test
public void test()
{
    // Case 1: the aggregation argument is the raw geometry symbol.
    assertRuleApplication()
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(
                                planBuilder.symbol("sp"),
                                PlanBuilder.expression("spatial_partitioning(geometry)"),
                                ImmutableList.of(GEOMETRY))
                        .source(planBuilder.values(planBuilder.symbol("geometry")));
            }))
            .matches(
                    aggregation(
                            ImmutableMap.of(
                                    "sp",
                                    functionCall("spatial_partitioning", ImmutableList.of("envelope", "partition_count"))),
                            project(
                                    ImmutableMap.of(
                                            "partition_count", expression("100"),
                                            "envelope", expression("ST_Envelope(geometry)")),
                                    values("geometry"))));

    // Case 2: the aggregation argument is already wrapped in ST_Envelope.
    assertRuleApplication()
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(
                                planBuilder.symbol("sp"),
                                PlanBuilder.expression("spatial_partitioning(ST_Envelope(geometry))"),
                                ImmutableList.of(GEOMETRY))
                        .source(planBuilder.values(planBuilder.symbol("geometry")));
            }))
            .matches(
                    aggregation(
                            ImmutableMap.of(
                                    "sp",
                                    functionCall("spatial_partitioning", ImmutableList.of("envelope", "partition_count"))),
                            project(
                                    ImmutableMap.of(
                                            "partition_count", expression("100"),
                                            "envelope", expression("ST_Envelope(geometry)")),
                                    values("geometry"))));
}
Symbol totalPrice = p.symbol("total_price", DOUBLE); AggregationNode inner = p.aggregation((a) -> a .addAggregation(totalPrice, new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(new SymbolReference("totalprice"))), ImmutableList.of(DOUBLE)) .globalGrouping() .source( p.project( Assignments.of(totalPrice, totalPrice.toSymbolReference()), .addAggregation( p.symbol("sum_outer", DOUBLE), new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(new SymbolReference("sum_inner"))), ImmutableList.of(DOUBLE)) .globalGrouping() .source(inner)); }).doesNotFire();
/**
 * Applies {@link TransformCorrelatedScalarAggregationToJoin} to a lateral node whose subquery is a
 * projection ({@code sum + 1}) over a global {@code sum(a)} aggregation.
 * The expected plan is a projection over an aggregation over a LEFT join between the
 * unique-id-tagged input and a projection that adds a {@code non_null} marker to the subquery source.
 */
@Test
public void rewritesOnSubqueryWithProjection()
{
    tester().assertThat(new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionRegistry()))
            .on(planBuilder -> planBuilder.lateral(
                    ImmutableList.of(planBuilder.symbol("corr")),
                    planBuilder.values(planBuilder.symbol("corr")),
                    planBuilder.project(
                            Assignments.of(planBuilder.symbol("expr"), planBuilder.expression("sum + 1")),
                            planBuilder.aggregation(agg -> {
                                agg.source(planBuilder.values(planBuilder.symbol("a"), planBuilder.symbol("b")))
                                        .addAggregation(
                                                planBuilder.symbol("sum"),
                                                PlanBuilder.expression("sum(a)"),
                                                ImmutableList.of(BIGINT))
                                        .globalGrouping();
                            }))))
            .matches(
                    project(
                            ImmutableMap.of(
                                    "corr", expression("corr"),
                                    "expr", expression("(\"sum_1\" + 1)")),
                            aggregation(
                                    ImmutableMap.of("sum_1", functionCall("sum", ImmutableList.of("a"))),
                                    join(
                                            JoinNode.Type.LEFT,
                                            ImmutableList.of(),
                                            assignUniqueId(
                                                    "unique",
                                                    values(ImmutableMap.of("corr", 0))),
                                            project(
                                                    ImmutableMap.of("non_null", expression("true")),
                                                    values(ImmutableMap.of("a", 0, "b", 1)))))));
}
}