@Test
public void testGroupByEmpty()
{
    // Aggregation built with globalAggregation() and no aggregate functions,
    // sitting on top of a filter whose predicate is the FALSE literal
    PlanNode aggregation = new AggregationNode(
            newId(),
            filter(baseTableScan, FALSE_LITERAL),
            ImmutableMap.of(),
            globalAggregation(),
            ImmutableList.of(),
            AggregationNode.Step.FINAL,
            Optional.empty(),
            Optional.empty());

    Expression extracted = effectivePredicateExtractor.extract(aggregation);

    // The FALSE predicate of the source must not be reported for the aggregation output
    assertEquals(extracted, TRUE_LITERAL);
}
public PredicatePushDown(Metadata metadata, SqlParser sqlParser) { this.metadata = requireNonNull(metadata, "metadata is null"); this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); this.effectivePredicateExtractor = new EffectivePredicateExtractor(new DomainTranslator(literalEncoder)); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); }
@Override public Expression visitProject(ProjectNode node, Void context) { // TODO: add simple algebraic solver for projection translation (right now only considers identity projections) Expression underlyingPredicate = node.getSource().accept(this, context); List<Expression> projectionEqualities = node.getAssignments().entrySet().stream() .filter(SYMBOL_MATCHES_EXPRESSION.negate()) .map(ENTRY_TO_EQUALITY) .collect(toImmutableList()); return pullExpressionThroughSymbols(combineConjuncts( ImmutableList.<Expression>builder() .addAll(projectionEqualities) .add(underlyingPredicate) .build()), node.getOutputSymbols()); }
@Override
public Expression visitUnion(UnionNode node, Void context)
{
    // For each source index, the union's output symbol map supplies the
    // (output symbol, source reference) entries used to derive the common predicates
    return deriveCommonPredicates(node, sourceIndex -> node.outputSymbolMap(sourceIndex).entries());
}
return combineConjuncts(ImmutableList.<Expression>builder() .add(leftPredicate) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), in(node.getRight().getOutputSymbols()))) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, in(node.getRight().getOutputSymbols()))) .build()); case RIGHT: return combineConjuncts(ImmutableList.<Expression>builder() .add(rightPredicate) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), in(node.getLeft().getOutputSymbols()))) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, in(node.getLeft().getOutputSymbols()))) .build()); case FULL: return combineConjuncts(ImmutableList.<Expression>builder() .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), in(node.getLeft().getOutputSymbols()))) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), in(node.getRight().getOutputSymbols()))) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, in(node.getLeft().getOutputSymbols()), in(node.getRight().getOutputSymbols()))) .build()); default:
@Override
public Expression visitTableScan(TableScanNode node, Void context)
{
    // The scan maps symbols to columns; invert it so the column-keyed constraint
    // can be re-keyed by symbol. ImmutableBiMap requires the mapping to be one-to-one.
    Map<ColumnHandle, Symbol> symbolsByColumn = ImmutableBiMap.copyOf(node.getAssignments()).inverse();

    return DomainTranslator.toPredicate(
            spanTupleDomain(node.getCurrentConstraint())
                    .transform(symbolsByColumn::get));
}
private Expression deriveCommonPredicates(PlanNode node, Function<Integer, Collection<Map.Entry<Symbol, QualifiedNameReference>>> mapping) { // Find the predicates that can be pulled up from each source List<Set<Expression>> sourceOutputConjuncts = new ArrayList<>(); for (int i = 0; i < node.getSources().size(); i++) { Expression underlyingPredicate = node.getSources().get(i).accept(this, null); List<Expression> equalities = mapping.apply(i).stream() .filter(SYMBOL_MATCHES_EXPRESSION.negate()) .map(ENTRY_TO_EQUALITY) .collect(toImmutableList()); sourceOutputConjuncts.add(ImmutableSet.copyOf(extractConjuncts(pullExpressionThroughSymbols(combineConjuncts( ImmutableList.<Expression>builder() .addAll(equalities) .add(underlyingPredicate) .build()), node.getOutputSymbols())))); } // Find the intersection of predicates across all sources // TODO: use a more precise way to determine overlapping conjuncts (e.g. commutative predicates) Iterator<Set<Expression>> iterator = sourceOutputConjuncts.iterator(); Set<Expression> potentialOutputConjuncts = iterator.next(); while (iterator.hasNext()) { potentialOutputConjuncts = Sets.intersection(potentialOutputConjuncts, iterator.next()); } return combineConjuncts(potentialOutputConjuncts); }
@Override
public Expression visitExchange(ExchangeNode node, Void context)
{
    return deriveCommonPredicates(node, source -> {
        // Pair each output symbol with the input symbol at the same position for this source
        List<Symbol> inputSymbols = node.getInputs().get(source);
        Map<Symbol, QualifiedNameReference> outputToInput = new HashMap<>();
        for (int position = 0; position < inputSymbols.size(); position++) {
            outputToInput.put(
                    node.getOutputSymbols().get(position),
                    inputSymbols.get(position).toQualifiedNameReference());
        }
        return outputToInput.entrySet();
    });
}
@Test
public void testInnerJoinWithFalseFilter()
{
    // Left side scans columns A, B, C; right side scans D, E, F
    Map<Symbol, ColumnHandle> leftColumns = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C)));
    TableScanNode leftScan = tableScanNode(leftColumns);

    Map<Symbol, ColumnHandle> rightColumns = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(D, E, F)));
    TableScanNode rightScan = tableScanNode(rightColumns);

    // Inner join on A = D whose additional join filter is the FALSE literal
    PlanNode join = new JoinNode(
            newId(),
            JoinNode.Type.INNER,
            leftScan,
            rightScan,
            ImmutableList.of(new JoinNode.EquiJoinClause(A, D)),
            ImmutableList.<Symbol>builder()
                    .addAll(leftScan.getOutputSymbols())
                    .addAll(rightScan.getOutputSymbols())
                    .build(),
            Optional.of(FALSE_LITERAL),
            Optional.empty(),
            Optional.empty(),
            Optional.empty());

    Expression extracted = effectivePredicateExtractor.extract(join);

    // The FALSE join filter must be reflected in the effective predicate
    assertEquals(extracted, FALSE_LITERAL);
}
/**
 * Computes the effective predicate of {@code node} by visiting it with a freshly
 * created extractor.
 *
 * @param node plan node to analyze
 * @param symbolTypes symbol-to-type mapping handed to the extractor
 * @return the expression produced by the visitor for {@code node}
 */
public static Expression extract(PlanNode node, Map<Symbol, Type> symbolTypes)
{
    EffectivePredicateExtractor extractor = new EffectivePredicateExtractor(symbolTypes);
    // The visitor takes no context, so null is passed
    return node.accept(extractor, null);
}
@Override
public Expression visitAggregation(AggregationNode node, Void context)
{
    // An aggregation without grouping symbols (GROUP BY ()) always produces a group,
    // regardless of whether there is any input, unlike a keyed aggregation, which
    // produces no output when there is no input. A predicate that holds for the source
    // rows (for example FALSE for an empty source) therefore does not hold for the
    // aggregation output, so nothing can be claimed about it.
    if (node.getGroupBy().isEmpty()) {
        return BooleanLiteral.TRUE_LITERAL;
    }

    Expression underlyingPredicate = node.getSource().accept(this, context);

    // Only the grouping symbols pass through the aggregation unchanged, so the source
    // predicate is re-expressed in terms of those symbols alone
    return pullExpressionThroughSymbols(underlyingPredicate, node.getGroupBy());
}
@Test public void testInnerJoinPropagatesPredicatesViaEquiConditions() { Map<Symbol, ColumnHandle> leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C))); TableScanNode leftScan = tableScanNode(leftAssignments); Map<Symbol, ColumnHandle> rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(D, E, F))); TableScanNode rightScan = tableScanNode(rightAssignments); FilterNode left = filter(leftScan, equals(AE, bigintLiteral(10))); // predicates on "a" column should be propagated to output symbols via join equi conditions PlanNode node = new JoinNode(newId(), JoinNode.Type.INNER, left, rightScan, ImmutableList.of(new JoinNode.EquiJoinClause(A, D)), ImmutableList.<Symbol>builder() .addAll(rightScan.getOutputSymbols()) .build(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); Expression effectivePredicate = effectivePredicateExtractor.extract(node); assertEquals( normalizeConjuncts(effectivePredicate), normalizeConjuncts(equals(DE, bigintLiteral(10)))); }
ImmutableList.copyOf(assignments.keySet()), assignments); Expression effectivePredicate = effectivePredicateExtractor.extract(node); assertEquals(effectivePredicate, BooleanLiteral.TRUE_LITERAL); TupleDomain.none(), TupleDomain.all()); effectivePredicate = effectivePredicateExtractor.extract(node); assertEquals(effectivePredicate, FALSE_LITERAL); TupleDomain.withColumnDomains(ImmutableMap.of(scanAssignments.get(A), Domain.singleValue(BIGINT, 1L))), TupleDomain.all()); effectivePredicate = effectivePredicateExtractor.extract(node); assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(equals(bigintLiteral(1L), AE))); scanAssignments.get(B), Domain.singleValue(BIGINT, 2L))), TupleDomain.all()); effectivePredicate = effectivePredicateExtractor.extract(node); assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(equals(bigintLiteral(2L), BE), equals(bigintLiteral(1L), AE))); TupleDomain.all(), TupleDomain.all()); effectivePredicate = effectivePredicateExtractor.extract(node); assertEquals(effectivePredicate, BooleanLiteral.TRUE_LITERAL);
Expression sourceEffectivePredicate = filterDeterministicConjuncts(effectivePredicateExtractor.extract(node.getSource())); Expression filteringSourceEffectivePredicate = filterDeterministicConjuncts(effectivePredicateExtractor.extract(node.getFilteringSource())); Expression joinExpression = new ComparisonExpression( ComparisonExpression.Operator.EQUAL,
@Test public void testFilter() { PlanNode node = filter(baseTableScan, and( greaterThan(AE, new FunctionCall(QualifiedName.of("rand"), ImmutableList.of())), lessThan(BE, bigintLiteral(10)))); Expression effectivePredicate = effectivePredicateExtractor.extract(node); // Non-deterministic functions should be purged assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(lessThan(BE, bigintLiteral(10)))); }
Optional.empty()); Expression effectivePredicate = effectivePredicateExtractor.extract(node);
Optional.empty()); Expression effectivePredicate = effectivePredicateExtractor.extract(node);
@Test public void testWindow() { PlanNode node = new WindowNode(newId(), filter(baseTableScan, and( equals(AE, BE), equals(BE, CE), lessThan(CE, bigintLiteral(10)))), new WindowNode.Specification( ImmutableList.of(A), Optional.of(new OrderingScheme( ImmutableList.of(A), ImmutableMap.of(A, SortOrder.ASC_NULLS_LAST)))), ImmutableMap.of(), Optional.empty(), ImmutableSet.of(), 0); Expression effectivePredicate = effectivePredicateExtractor.extract(node); // Pass through assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts( equals(AE, BE), equals(BE, CE), lessThan(CE, bigintLiteral(10)))); }
@Test public void testSemiJoin() { PlanNode node = new SemiJoinNode(newId(), filter(baseTableScan, and(greaterThan(AE, bigintLiteral(10)), lessThan(AE, bigintLiteral(100)))), filter(baseTableScan, greaterThan(AE, bigintLiteral(5))), A, B, C, Optional.empty(), Optional.empty(), Optional.empty()); Expression effectivePredicate = effectivePredicateExtractor.extract(node); // Currently, only pull predicates through the source plan assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(and(greaterThan(AE, bigintLiteral(10)), lessThan(AE, bigintLiteral(100))))); }