private List<Expression> getJoinPredicates(Set<Symbol> leftSymbols, Set<Symbol> rightSymbols) { ImmutableList.Builder<Expression> joinPredicatesBuilder = ImmutableList.builder(); // This takes all conjuncts that were part of allFilters that // could not be used for equality inference. // If they use both the left and right symbols, we add them to the list of joinPredicates stream(nonInferrableConjuncts(allFilter)) .map(conjunct -> allFilterInference.rewriteExpression(conjunct, symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol))) .filter(Objects::nonNull) // filter expressions that contain only left or right symbols .filter(conjunct -> allFilterInference.rewriteExpression(conjunct, leftSymbols::contains) == null) .filter(conjunct -> allFilterInference.rewriteExpression(conjunct, rightSymbols::contains) == null) .forEach(joinPredicatesBuilder::add); // create equality inference on available symbols // TODO: make generateEqualitiesPartitionedBy take left and right scope List<Expression> joinEqualities = allFilterInference.generateEqualitiesPartitionedBy(symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol)).getScopeEqualities(); EqualityInference joinInference = createEqualityInference(joinEqualities.toArray(new Expression[0])); joinPredicatesBuilder.addAll(joinInference.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeStraddlingEqualities()); return joinPredicatesBuilder.build(); }
/**
 * Restricts {@code expression} to the given symbols.
 * <p>
 * Each deterministic, non-inferrable conjunct is kept if equality inference can rewrite it using
 * only {@code symbols}; non-deterministic conjuncts are dropped. The equalities inferable within
 * that scope are appended, and everything is merged with {@code combineConjuncts}.
 *
 * @param expression the predicate to restrict
 * @param symbols the symbols the result may reference
 * @return the combined, restricted predicate
 */
private static Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> symbols)
{
    EqualityInference inference = createEqualityInference(expression);
    ImmutableList.Builder<Expression> retained = ImmutableList.builder();

    for (Expression conjunct : EqualityInference.nonInferrableConjuncts(expression)) {
        if (!DeterminismEvaluator.isDeterministic(conjunct)) {
            // non-deterministic conjuncts are never carried over
            continue;
        }
        Expression rewritten = inference.rewriteExpression(conjunct, in(symbols));
        if (rewritten != null) {
            retained.add(rewritten);
        }
    }

    retained.addAll(inference.generateEqualitiesPartitionedBy(in(symbols)).getScopeEqualities());
    return combineConjuncts(retained.build());
}
}
/**
 * Builds the plan for a set of join sources.
 * <p>
 * Several nodes are delegated to {@code chooseJoinOrder}. A single node is returned as-is, or
 * wrapped in a {@link FilterNode} when the parts of {@code allFilter} expressible over
 * {@code outputSymbols} (the scope equalities plus every non-inferrable conjunct that rewrites
 * into that scope) do not combine to {@code TRUE_LITERAL}.
 *
 * @param nodes the sources to join
 * @param outputSymbols the symbols the resulting plan exposes
 * @return the enumeration result for the chosen plan
 */
private JoinEnumerationResult getJoinSource(LinkedHashSet<PlanNode> nodes, List<Symbol> outputSymbols)
{
    if (nodes.size() != 1) {
        return chooseJoinOrder(nodes, outputSymbols);
    }

    PlanNode source = getOnlyElement(nodes);

    ImmutableList.Builder<Expression> localPredicates = ImmutableList.builder();
    localPredicates.addAll(allFilterInference.generateEqualitiesPartitionedBy(outputSymbols::contains).getScopeEqualities());
    for (Expression conjunct : nonInferrableConjuncts(allFilter)) {
        Expression rewritten = allFilterInference.rewriteExpression(conjunct, outputSymbols::contains);
        if (rewritten != null) {
            localPredicates.add(rewritten);
        }
    }

    Expression filter = combineConjuncts(localPredicates.build());
    if (TRUE_LITERAL.equals(filter)) {
        // nothing to apply: avoid a trivial FilterNode
        return createJoinEnumerationResult(source);
    }
    return createJoinEnumerationResult(new FilterNode(idAllocator.getNextId(), source, filter));
}
// NOTE(review): fragment of a larger test. The receiver of the `.addAllEqualities(...)` chain and the
// definitions of emptyScopePartition / newEqualityPartition are outside this view.
// Asserts: the empty-scope partition has no scope equalities; the "c1"-scoped partition has scope
// equalities, all within the "c1" symbol scope and all inference candidates; and, after re-adding the
// scope, complement and straddling equalities (presumably to a fresh builder), the new partition's three
// lists match the original ones as sets.
assertTrue(emptyScopePartition.getScopeEqualities().isEmpty()); assertFalse(equalityPartition.getScopeEqualities().isEmpty()); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(matchesSymbols("c1")))); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate())); .addAllEqualities(equalityPartition.getScopeEqualities()) .addAllEqualities(equalityPartition.getScopeComplementEqualities()) .addAllEqualities(equalityPartition.getScopeStraddlingEqualities()) assertEquals(setCopy(equalityPartition.getScopeEqualities()), setCopy(newEqualityPartition.getScopeEqualities())); assertEquals(setCopy(equalityPartition.getScopeComplementEqualities()), setCopy(newEqualityPartition.getScopeComplementEqualities())); assertEquals(setCopy(equalityPartition.getScopeStraddlingEqualities()), setCopy(newEqualityPartition.getScopeStraddlingEqualities()));
// NOTE(review): fragment of a larger test. The receiver of the `.addAllEqualities(...)` chain and the
// definition of newEqualityPartition are outside this view.
// Asserts the partition has scope equalities, all over symbols beginning with "a" or "b" and all
// inference candidates; after re-adding all three equality lists (presumably to a fresh builder), the
// new partition's scope, complement and straddling lists match the originals as sets.
assertFalse(equalityPartition.getScopeEqualities().isEmpty()); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(symbolBeginsWith("a", "b")))); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate())); .addAllEqualities(equalityPartition.getScopeEqualities()) .addAllEqualities(equalityPartition.getScopeComplementEqualities()) .addAllEqualities(equalityPartition.getScopeStraddlingEqualities()) assertEquals(setCopy(equalityPartition.getScopeEqualities()), setCopy(newEqualityPartition.getScopeEqualities())); assertEquals(setCopy(equalityPartition.getScopeComplementEqualities()), setCopy(newEqualityPartition.getScopeComplementEqualities())); assertEquals(setCopy(equalityPartition.getScopeStraddlingEqualities()), setCopy(newEqualityPartition.getScopeStraddlingEqualities()));
// Fragment: distributes inferred equalities. The filtering side receives the join-inference scope
// equalities mapped through expressionOrNullSymbols(equalTo(filtering-source join symbol)); the source
// side receives the scope equalities; complement and straddling equalities stay above the join.
filteringSourceConjuncts.addAll(ImmutableList.copyOf(transform(joinInferenceEqualityPartition.getScopeEqualities(), expressionOrNullSymbols(equalTo(node.getFilteringSourceJoinSymbol()))))); sourceConjuncts.addAll(equalityPartition.getScopeEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
// Fragment: scope equalities are pushed down; complement and straddling equalities are kept after the unnest.
pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
@Test public void testConstantEqualities() throws Exception { EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); builder.addEquality(nameReference("c1"), number(1)); EqualityInference inference = builder.build(); // Should always prefer a constant if available (constant is part of all scopes) assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1)); // All scope equalities should utilize the constant if possible EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1")); assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()), set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1)))); assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()), set(set(nameReference("c1"), number(1)))); // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty()); }
private Set<Expression> normalizeConjuncts(Expression predicate) { // Normalize the predicate by identity so that the EqualityInference will produce stable rewrites in this test // and thereby produce comparable Sets of conjuncts from this method. predicate = expressionNormalizer.normalize(predicate); // Equality inference rewrites and equality generation will always be stable across multiple runs in the same JVM EqualityInference inference = EqualityInference.createEqualityInference(predicate); Set<Expression> rewrittenSet = new HashSet<>(); for (Expression expression : EqualityInference.nonInferrableConjuncts(predicate)) { Expression rewritten = inference.rewriteExpression(expression, Predicates.<Symbol>alwaysTrue()); Preconditions.checkState(rewritten != null, "Rewrite with full symbol scope should always be possible"); rewrittenSet.add(rewritten); } rewrittenSet.addAll(inference.generateEqualitiesPartitionedBy(Predicates.<Symbol>alwaysTrue()).getScopeEqualities()); return rewrittenSet; }
/**
 * Restricts {@code expression} to the given symbols.
 * <p>
 * Each deterministic, non-inferrable conjunct is kept if equality inference can rewrite it using
 * only {@code symbols}; non-deterministic conjuncts are dropped. The equalities inferable within
 * that scope are appended, and everything is merged with {@code combineConjuncts}.
 *
 * @param expression the predicate to restrict
 * @param symbols the symbols the result may reference
 * @return the combined, restricted predicate
 */
private static Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> symbols)
{
    EqualityInference inference = createEqualityInference(expression);
    ImmutableList.Builder<Expression> retained = ImmutableList.builder();

    for (Expression conjunct : EqualityInference.nonInferrableConjuncts(expression)) {
        if (!DeterminismEvaluator.isDeterministic(conjunct)) {
            // non-deterministic conjuncts are never carried over
            continue;
        }
        Expression rewritten = inference.rewriteExpression(conjunct, in(symbols));
        if (rewritten != null) {
            retained.add(rewritten);
        }
    }

    retained.addAll(inference.generateEqualitiesPartitionedBy(in(symbols)).getScopeEqualities());
    return combineConjuncts(retained.build());
}
}
// Fragment: each side receives only the equalities whose symbols all lie in that side's own symbol set,
// generated from an inference that (judging by the names) excludes what that side already inferred.
sourceConjuncts.addAll(allInferenceWithoutSourceInferred.generateEqualitiesPartitionedBy(in(sourceSymbols)).getScopeEqualities()); filteringSourceConjuncts.addAll(allInferenceWithoutFilteringSourceInferred.generateEqualitiesPartitionedBy(in(filteringSourceSymbols)).getScopeEqualities());
// Fragment: outer-side scope equalities are pushed to the outer side, and their conjunction also seeds
// potentialNullSymbolInference together with both effective predicates and the join predicate.
// Complement and straddling equalities stay after the join. The inner side receives equalities scoped to
// non-outer symbols (from potentialNullSymbolInferenceWithoutInnerInferred, defined outside this view)
// plus joinEqualityPartition's scope equalities; that partition's complement and straddling equalities
// become join conjuncts.
Expression outerOnlyInheritedEqualities = combineConjuncts(equalityPartition.getScopeEqualities()); EqualityInference potentialNullSymbolInference = createEqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, innerEffectivePredicate, joinPredicate); outerPushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); innerPushdownConjuncts.addAll(potentialNullSymbolInferenceWithoutInnerInferred.generateEqualitiesPartitionedBy(not(in(outerSymbols))).getScopeEqualities()); innerPushdownConjuncts.addAll(joinEqualityPartition.getScopeEqualities()); joinConjuncts.addAll(joinEqualityPartition.getScopeComplementEqualities()) .addAll(joinEqualityPartition.getScopeStraddlingEqualities());
// Fragment: the left side receives equalities scoped to leftSymbols. The right side's scope is written
// as the complement of leftSymbols (not(in(leftSymbols))), not as an explicit right-symbol set.
leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeEqualities()); rightPushDownConjuncts.addAll(allInferenceWithoutRightInferred.generateEqualitiesPartitionedBy(not(in(leftSymbols))).getScopeEqualities());
// Fragment: scope equalities are pushed down; complement and straddling equalities are kept after the aggregation.
pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); postAggregationConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postAggregationConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
// NOTE(review): fragment of a larger test. The receiver of the `.addAllEqualities(...)` chain and the
// definitions of emptyScopePartition / newEqualityPartition are outside this view.
// Asserts: the empty-scope partition has no scope equalities; the "c1"-scoped partition has scope
// equalities, all within the "c1" symbol scope and all inference candidates; and, after re-adding the
// scope, complement and straddling equalities (presumably to a fresh builder), the new partition's three
// lists match the original ones as sets.
assertTrue(emptyScopePartition.getScopeEqualities().isEmpty()); assertFalse(equalityPartition.getScopeEqualities().isEmpty()); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(matchesSymbols("c1")))); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate())); .addAllEqualities(equalityPartition.getScopeEqualities()) .addAllEqualities(equalityPartition.getScopeComplementEqualities()) .addAllEqualities(equalityPartition.getScopeStraddlingEqualities()) assertEquals(setCopy(equalityPartition.getScopeEqualities()), setCopy(newEqualityPartition.getScopeEqualities())); assertEquals(setCopy(equalityPartition.getScopeComplementEqualities()), setCopy(newEqualityPartition.getScopeComplementEqualities())); assertEquals(setCopy(equalityPartition.getScopeStraddlingEqualities()), setCopy(newEqualityPartition.getScopeStraddlingEqualities()));
// NOTE(review): fragment of a larger test. The receiver of the `.addAllEqualities(...)` chain and the
// definition of newEqualityPartition are outside this view.
// Asserts the partition has scope equalities, all over symbols beginning with "a" or "b" and all
// inference candidates; after re-adding all three equality lists (presumably to a fresh builder), the
// new partition's scope, complement and straddling lists match the originals as sets.
assertFalse(equalityPartition.getScopeEqualities().isEmpty()); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(symbolBeginsWith("a", "b")))); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate())); .addAllEqualities(equalityPartition.getScopeEqualities()) .addAllEqualities(equalityPartition.getScopeComplementEqualities()) .addAllEqualities(equalityPartition.getScopeStraddlingEqualities()) assertEquals(setCopy(equalityPartition.getScopeEqualities()), setCopy(newEqualityPartition.getScopeEqualities())); assertEquals(setCopy(equalityPartition.getScopeComplementEqualities()), setCopy(newEqualityPartition.getScopeComplementEqualities())); assertEquals(setCopy(equalityPartition.getScopeStraddlingEqualities()), setCopy(newEqualityPartition.getScopeStraddlingEqualities()));
// Fragment: scope equalities are pushed down; complement and straddling equalities are kept after the unnest.
pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
// Fragment: scope equalities go to the source side; complement and straddling equalities are kept after the join.
sourceConjuncts.addAll(equalityPartition.getScopeEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
@Test public void testConstantEqualities() { EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); builder.addEquality(nameReference("c1"), number(1)); EqualityInference inference = builder.build(); // Should always prefer a constant if available (constant is part of all scopes) assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1)); // All scope equalities should utilize the constant if possible EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1")); assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()), set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1)))); assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()), set(set(nameReference("c1"), number(1)))); // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty()); }
private Set<Expression> normalizeConjuncts(Expression predicate) { // Normalize the predicate by identity so that the EqualityInference will produce stable rewrites in this test // and thereby produce comparable Sets of conjuncts from this method. predicate = expressionNormalizer.normalize(predicate); // Equality inference rewrites and equality generation will always be stable across multiple runs in the same JVM EqualityInference inference = EqualityInference.createEqualityInference(predicate); Set<Expression> rewrittenSet = new HashSet<>(); for (Expression expression : EqualityInference.nonInferrableConjuncts(predicate)) { Expression rewritten = inference.rewriteExpression(expression, Predicates.alwaysTrue()); Preconditions.checkState(rewritten != null, "Rewrite with full symbol scope should always be possible"); rewrittenSet.add(rewritten); } rewrittenSet.addAll(inference.generateEqualitiesPartitionedBy(Predicates.alwaysTrue()).getScopeEqualities()); return rewrittenSet; }