@Override public Expression visitProject(ProjectNode node, Void context) { // TODO: add simple algebraic solver for projection translation (right now only considers identity projections) Expression underlyingPredicate = node.getSource().accept(this, context); List<Expression> projectionEqualities = node.getAssignments().entrySet().stream() .filter(SYMBOL_MATCHES_EXPRESSION.negate()) .map(ENTRY_TO_EQUALITY) .collect(toImmutableList()); return pullExpressionThroughSymbols(combineConjuncts( ImmutableList.<Expression>builder() .addAll(projectionEqualities) .add(underlyingPredicate) .build()), node.getOutputSymbols()); }
private Expression deriveCommonPredicates(PlanNode node, Function<Integer, Collection<Map.Entry<Symbol, QualifiedNameReference>>> mapping) { // Find the predicates that can be pulled up from each source List<Set<Expression>> sourceOutputConjuncts = new ArrayList<>(); for (int i = 0; i < node.getSources().size(); i++) { Expression underlyingPredicate = node.getSources().get(i).accept(this, null); List<Expression> equalities = mapping.apply(i).stream() .filter(SYMBOL_MATCHES_EXPRESSION.negate()) .map(ENTRY_TO_EQUALITY) .collect(toImmutableList()); sourceOutputConjuncts.add(ImmutableSet.copyOf(extractConjuncts(pullExpressionThroughSymbols(combineConjuncts( ImmutableList.<Expression>builder() .addAll(equalities) .add(underlyingPredicate) .build()), node.getOutputSymbols())))); } // Find the intersection of predicates across all sources // TODO: use a more precise way to determine overlapping conjuncts (e.g. commutative predicates) Iterator<Set<Expression>> iterator = sourceOutputConjuncts.iterator(); Set<Expression> potentialOutputConjuncts = iterator.next(); while (iterator.hasNext()) { potentialOutputConjuncts = Sets.intersection(potentialOutputConjuncts, iterator.next()); } return combineConjuncts(potentialOutputConjuncts); }
/**
 * Computes the effective predicate of an aggregation.
 *
 * <p>The source's predicate is passed through {@code pullExpressionThroughSymbols} using only the
 * aggregation's group-by symbols, so nothing is retained about aggregate outputs.
 *
 * @param node the aggregation being visited
 * @param context unused visitor context, forwarded to the source
 * @return the predicate known to hold over the aggregation's grouping keys
 */
@Override
public Expression visitAggregation(AggregationNode node, Void context)
{
    return pullExpressionThroughSymbols(
            node.getSource().accept(this, context),
            node.getGroupBy());
}