private List<Expression> getJoinPredicates(Set<Symbol> leftSymbols, Set<Symbol> rightSymbols) { ImmutableList.Builder<Expression> joinPredicatesBuilder = ImmutableList.builder(); // This takes all conjuncts that were part of allFilters that // could not be used for equality inference. // If they use both the left and right symbols, we add them to the list of joinPredicates stream(nonInferrableConjuncts(allFilter)) .map(conjunct -> allFilterInference.rewriteExpression(conjunct, symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol))) .filter(Objects::nonNull) // filter expressions that contain only left or right symbols .filter(conjunct -> allFilterInference.rewriteExpression(conjunct, leftSymbols::contains) == null) .filter(conjunct -> allFilterInference.rewriteExpression(conjunct, rightSymbols::contains) == null) .forEach(joinPredicatesBuilder::add); // create equality inference on available symbols // TODO: make generateEqualitiesPartitionedBy take left and right scope List<Expression> joinEqualities = allFilterInference.generateEqualitiesPartitionedBy(symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol)).getScopeEqualities(); EqualityInference joinInference = createEqualityInference(joinEqualities.toArray(new Expression[0])); joinPredicatesBuilder.addAll(joinInference.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeStraddlingEqualities()); return joinPredicatesBuilder.build(); }
/**
 * Restricts {@code expression} to what can be stated using only the given symbols.
 *
 * <p>Deterministic conjuncts that could not be used for equality inference are rewritten into the
 * symbol scope (and dropped when no rewrite exists); the equalities scoped to the given symbols are
 * appended afterwards, and everything is combined into a single conjunction.
 *
 * @param expression the predicate to restrict
 * @param symbols the symbols the result may refer to
 * @return the conjunction of the surviving conjuncts and the scope equalities
 */
private static Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> symbols)
{
    EqualityInference inference = createEqualityInference(expression);
    ImmutableList.Builder<Expression> kept = ImmutableList.builder();

    for (Expression candidate : EqualityInference.nonInferrableConjuncts(expression)) {
        // Non-deterministic conjuncts are never carried over.
        if (!DeterminismEvaluator.isDeterministic(candidate)) {
            continue;
        }
        Expression inScope = inference.rewriteExpression(candidate, in(symbols));
        if (inScope == null) {
            // Cannot be expressed with the given symbols only.
            continue;
        }
        kept.add(inScope);
    }

    kept.addAll(inference.generateEqualitiesPartitionedBy(in(symbols)).getScopeEqualities());

    return combineConjuncts(kept.build());
}
}
/**
 * Produces the enumeration result for the given set of nodes.
 *
 * <p>More than one node is delegated to {@code chooseJoinOrder}. A single node is wrapped in a
 * {@code FilterNode} carrying every predicate of {@code allFilter} that can be expressed with the
 * output symbols, unless that combined predicate is {@code TRUE_LITERAL}.
 *
 * @param nodes the nodes to produce a source for
 * @param outputSymbols the symbols used as the scope for the predicates and passed on to {@code chooseJoinOrder}
 * @return the enumeration result for the (possibly filtered) single node or for the chosen join order
 */
private JoinEnumerationResult getJoinSource(LinkedHashSet<PlanNode> nodes, List<Symbol> outputSymbols)
{
    if (nodes.size() != 1) {
        return chooseJoinOrder(nodes, outputSymbols);
    }

    PlanNode source = getOnlyElement(nodes);

    // Inferred equalities over the output symbols come first ...
    ImmutableList.Builder<Expression> applicable = ImmutableList.builder();
    applicable.addAll(allFilterInference.generateEqualitiesPartitionedBy(outputSymbols::contains).getScopeEqualities());

    // ... followed by the remaining conjuncts that can be rewritten over the output symbols.
    for (Expression conjunct : nonInferrableConjuncts(allFilter)) {
        Expression rewritten = allFilterInference.rewriteExpression(conjunct, outputSymbols::contains);
        if (rewritten != null) {
            applicable.add(rewritten);
        }
    }

    Expression combined = combineConjuncts(applicable.build());
    if (TRUE_LITERAL.equals(combined)) {
        // Nothing to filter on; use the node as is.
        return createJoinEnumerationResult(source);
    }
    return createJoinEnumerationResult(new FilterNode(idAllocator.getNextId(), source, combined));
}
// Distribute the inferred equalities over the three conjunct lists (one statement per list):
//  - equalities scoped to leftSymbols are added to leftPushDownConjuncts,
//  - equalities scoped to everything that is NOT in leftSymbols are added to rightPushDownConjuncts,
//  - equalities straddling the leftSymbols partition are added to joinConjuncts.
// NOTE(review): the three inference objects are built outside this excerpt; their contents cannot be confirmed from here.
leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeEqualities()); rightPushDownConjuncts.addAll(allInferenceWithoutRightInferred.generateEqualitiesPartitionedBy(not(in(leftSymbols))).getScopeEqualities()); joinConjuncts.addAll(allInference.generateEqualitiesPartitionedBy(in(leftSymbols)::apply).getScopeStraddlingEqualities()); // scope straddling equalities get dropped in as part of the join predicate
// Excerpt of several statements; the enclosing method is not visible here.
//  1. outerInference is built from inheritedPredicate and outerEffectivePredicate.
//  2. The equalities of inheritedInference are partitioned by outerSymbols and the in-scope ones are
//     combined into outerOnlyInheritedEqualities.
//  3. potentialNullSymbolInference is built from those equalities together with the outer and inner
//     effective predicates and joinPredicate.
//  4. Equalities scoped to symbols NOT in outerSymbols are added to innerPushdownConjuncts, first from
//     potentialNullSymbolInferenceWithoutInnerInferred and then from an inference over joinPredicate alone;
//     the complement-scope equalities of that joinPredicate partition are added to joinConjuncts.
// NOTE(review): inheritedInference and potentialNullSymbolInferenceWithoutInnerInferred are defined outside
// this excerpt, and the last statement has no terminating semicolon here — presumably lines were dropped
// when the excerpt was assembled; confirm against the full method.
EqualityInference outerInference = createEqualityInference(inheritedPredicate, outerEffectivePredicate); EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(in(outerSymbols)); Expression outerOnlyInheritedEqualities = combineConjuncts(equalityPartition.getScopeEqualities()); EqualityInference potentialNullSymbolInference = createEqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, innerEffectivePredicate, joinPredicate); innerPushdownConjuncts.addAll(potentialNullSymbolInferenceWithoutInnerInferred.generateEqualitiesPartitionedBy(not(in(outerSymbols))).getScopeEqualities()); EqualityInference.EqualityPartition joinEqualityPartition = createEqualityInference(joinPredicate).generateEqualitiesPartitionedBy(not(in(outerSymbols))); innerPushdownConjuncts.addAll(joinEqualityPartition.getScopeEqualities()); joinConjuncts.addAll(joinEqualityPartition.getScopeComplementEqualities())
private Set<Expression> normalizeConjuncts(Expression predicate) { // Normalize the predicate by identity so that the EqualityInference will produce stable rewrites in this test // and thereby produce comparable Sets of conjuncts from this method. predicate = expressionNormalizer.normalize(predicate); // Equality inference rewrites and equality generation will always be stable across multiple runs in the same JVM EqualityInference inference = EqualityInference.createEqualityInference(predicate); Set<Expression> rewrittenSet = new HashSet<>(); for (Expression expression : EqualityInference.nonInferrableConjuncts(predicate)) { Expression rewritten = inference.rewriteExpression(expression, Predicates.alwaysTrue()); Preconditions.checkState(rewritten != null, "Rewrite with full symbol scope should always be possible"); rewrittenSet.add(rewritten); } rewrittenSet.addAll(inference.generateEqualitiesPartitionedBy(Predicates.alwaysTrue()).getScopeEqualities()); return rewrittenSet; }
// Partition the inferred equalities by the node's replicate symbols: equalities inside that scope are
// added to pushdownConjuncts, equalities in the complement scope are added to postUnnestConjuncts.
// The trailing ::apply adapts the predicate returned by in(...) to the functional type the method accepts.
EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(node.getReplicateSymbols())::apply); pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
// Partition the inferred equalities by the node's grouping keys: equalities inside that scope are
// added to pushdownConjuncts, equalities in the complement scope are added to postAggregationConjuncts.
EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(node.getGroupingKeys())::apply); pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); postAggregationConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
// Partition the inherited equalities by the output symbols of the node's source and add the
// in-scope equalities to sourceConjuncts.
// NOTE(review): what happens to the complement / straddling equalities is not visible in this excerpt.
EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(in(node.getSource() .getOutputSymbols())::apply); sourceConjuncts.addAll(equalityPartition.getScopeEqualities());
// Add the inferred equalities back to each side: those scoped to sourceSymbols go to sourceConjuncts,
// those scoped to filteringSourceSymbols go to filteringSourceConjuncts.
// NOTE(review): both inference objects are built outside this excerpt; their names suggest each one
// excludes the conjuncts inferred for its own side — confirm against the full method.
sourceConjuncts.addAll(allInferenceWithoutSourceInferred.generateEqualitiesPartitionedBy(in(sourceSymbols)).getScopeEqualities()); filteringSourceConjuncts.addAll(allInferenceWithoutFilteringSourceInferred.generateEqualitiesPartitionedBy(in(filteringSourceSymbols)).getScopeEqualities());
// Test excerpt: partitions `inference` once with an always-false scope (no symbol is in scope) and once
// with the scope matchesSymbols("c1"), then partitions `newInference` with that same "c1" scope.
// NOTE(review): the stray `.build();` has no visible receiver — the statements that build `newInference`
// appear to have been dropped from this excerpt, so this line does not compile as shown.
EqualityInference.EqualityPartition emptyScopePartition = inference.generateEqualitiesPartitionedBy(Predicates.alwaysFalse()); EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("c1")); .build(); EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(matchesSymbols("c1"));
// Test excerpt: partitions `inference` and then `newInference` with the same scope,
// symbolBeginsWith("a", "b").
// NOTE(review): the stray `.build();` has no visible receiver — the statements that build `newInference`
// appear to have been dropped from this excerpt, so this line does not compile as shown.
EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(symbolBeginsWith("a", "b")); .build(); EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(symbolBeginsWith("a", "b"));
@Test public void testConstantEqualities() { EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); builder.addEquality(nameReference("c1"), number(1)); EqualityInference inference = builder.build(); // Should always prefer a constant if available (constant is part of all scopes) assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1)); // All scope equalities should utilize the constant if possible EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1")); assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()), set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1)))); assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()), set(set(nameReference("c1"), number(1)))); // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty()); }
@Test public void testExpressionsThatMayReturnNullOnNonNullInput() { List<Expression> candidates = ImmutableList.of( new Cast(nameReference("b"), "BIGINT", true), // try_cast new FunctionCall(QualifiedName.of("try"), ImmutableList.of(nameReference("b"))), new NullIfExpression(nameReference("b"), number(1)), new IfExpression(nameReference("b"), number(1), new NullLiteral()), new DereferenceExpression(nameReference("b"), identifier("x")), new InPredicate(nameReference("b"), new InListExpression(ImmutableList.of(new NullLiteral()))), new SearchedCaseExpression(ImmutableList.of(new WhenClause(new IsNotNullPredicate(nameReference("b")), new NullLiteral())), Optional.empty()), new SimpleCaseExpression(nameReference("b"), ImmutableList.of(new WhenClause(number(1), new NullLiteral())), Optional.empty()), new SubscriptExpression(new ArrayConstructor(ImmutableList.of(new NullLiteral())), nameReference("b"))); for (Expression candidate : candidates) { EqualityInference.Builder builder = new EqualityInference.Builder(); builder.extractInferenceCandidates(equals(nameReference("b"), nameReference("x"))); builder.extractInferenceCandidates(equals(nameReference("a"), candidate)); EqualityInference inference = builder.build(); List<Expression> equalities = inference.generateEqualitiesPartitionedBy(matchesSymbols("b")).getScopeStraddlingEqualities(); assertEquals(equalities.size(), 1); assertTrue(equalities.get(0).equals(equals(nameReference("x"), nameReference("b"))) || equalities.get(0).equals(equals(nameReference("b"), nameReference("x")))); } }
/**
 * Restricts {@code expression} to what can be stated using only the given symbols.
 *
 * <p>Deterministic conjuncts that could not be used for equality inference are rewritten into the
 * symbol scope (and dropped when no rewrite exists); the equalities scoped to the given symbols are
 * appended afterwards, and everything is combined into a single conjunction.
 *
 * @param expression the predicate to restrict
 * @param symbols the symbols the result may refer to
 * @return the conjunction of the surviving conjuncts and the scope equalities
 */
private static Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> symbols)
{
    EqualityInference inference = createEqualityInference(expression);
    ImmutableList.Builder<Expression> kept = ImmutableList.builder();

    for (Expression candidate : EqualityInference.nonInferrableConjuncts(expression)) {
        // Non-deterministic conjuncts are never carried over.
        if (!DeterminismEvaluator.isDeterministic(candidate)) {
            continue;
        }
        Expression inScope = inference.rewriteExpression(candidate, in(symbols));
        if (inScope == null) {
            // Cannot be expressed with the given symbols only.
            continue;
        }
        kept.add(inScope);
    }

    kept.addAll(inference.generateEqualitiesPartitionedBy(in(symbols)).getScopeEqualities());

    return combineConjuncts(kept.build());
}
}
// Partition the inferred equalities by the node's replicate symbols: equalities inside that scope are
// added to pushdownConjuncts, equalities in the complement scope are added to postUnnestConjuncts.
EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(node.getReplicateSymbols())); pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
private Set<Expression> normalizeConjuncts(Expression predicate) { // Normalize the predicate by identity so that the EqualityInference will produce stable rewrites in this test // and thereby produce comparable Sets of conjuncts from this method. predicate = expressionNormalizer.normalize(predicate); // Equality inference rewrites and equality generation will always be stable across multiple runs in the same JVM EqualityInference inference = EqualityInference.createEqualityInference(predicate); Set<Expression> rewrittenSet = new HashSet<>(); for (Expression expression : EqualityInference.nonInferrableConjuncts(predicate)) { Expression rewritten = inference.rewriteExpression(expression, Predicates.<Symbol>alwaysTrue()); Preconditions.checkState(rewritten != null, "Rewrite with full symbol scope should always be possible"); rewrittenSet.add(rewritten); } rewrittenSet.addAll(inference.generateEqualitiesPartitionedBy(Predicates.<Symbol>alwaysTrue()).getScopeEqualities()); return rewrittenSet; }
// Two partitions are taken here:
//  1. joinInference is partitioned by the single filtering-source join symbol; each in-scope equality is
//     passed through expressionOrNullSymbols(...) for that same symbol and the results are copied into
//     filteringSourceConjuncts.
//  2. inheritedInference is partitioned by the output symbols of the node's source; in-scope equalities
//     are added to sourceConjuncts and complement-scope equalities to postJoinConjuncts.
// NOTE(review): expressionOrNullSymbols is defined elsewhere; its name suggests it also admits rows where
// the given symbols are null — confirm against its definition.
EqualityInference.EqualityPartition joinInferenceEqualityPartition = joinInference.generateEqualitiesPartitionedBy(equalTo(node.getFilteringSourceJoinSymbol())); filteringSourceConjuncts.addAll(ImmutableList.copyOf(transform(joinInferenceEqualityPartition.getScopeEqualities(), expressionOrNullSymbols(equalTo(node.getFilteringSourceJoinSymbol()))))); EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(in(node.getSource().getOutputSymbols())); sourceConjuncts.addAll(equalityPartition.getScopeEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
// Test excerpt: partitions `inference` and then `newInference` with the same scope,
// symbolBeginsWith("a", "b").
// NOTE(review): the stray `.build();` has no visible receiver — the statements that build `newInference`
// appear to have been dropped from this excerpt, so this line does not compile as shown.
EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(symbolBeginsWith("a", "b")); .build(); EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(symbolBeginsWith("a", "b"));
@Test public void testConstantEqualities() throws Exception { EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); builder.addEquality(nameReference("c1"), number(1)); EqualityInference inference = builder.build(); // Should always prefer a constant if available (constant is part of all scopes) assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1)); // All scope equalities should utilize the constant if possible EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1")); assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()), set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1)))); assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()), set(set(nameReference("c1"), number(1)))); // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty()); }