JoinNodeFlattener(JoinNode node, Lookup lookup, int sourceLimit) { requireNonNull(node, "node is null"); checkState(node.getType() == INNER, "join type must be INNER"); this.outputSymbols = node.getOutputSymbols(); this.lookup = requireNonNull(lookup, "lookup is null"); flattenNode(node, sourceLimit); }
.addAll(joinNode.getOutputSymbols()) .addAll( joinNode.getFilter()
/**
 * Tries to replace the join with a spatial join derived from its filter.
 *
 * <p>Supported spatial function calls found in the filter are tried first, in the order
 * returned by {@code extractSupportedSpatialFunctions}; supported spatial comparisons are
 * tried afterwards. The first non-empty result wins.
 *
 * @return the first non-empty result of {@code tryCreateSpatialJoin}, or an empty result
 *         when no candidate produced one
 */
@Override
public Result apply(JoinNode joinNode, Captures captures, Context context)
{
    // The filter is unwrapped unconditionally, exactly as before.
    Expression joinFilter = joinNode.getFilter().get();

    for (FunctionCall candidate : extractSupportedSpatialFunctions(joinFilter)) {
        Result attempt = tryCreateSpatialJoin(
                context,
                joinNode,
                joinFilter,
                joinNode.getId(),
                joinNode.getOutputSymbols(),
                candidate,
                Optional.empty(),
                metadata,
                splitManager,
                pageSourceManager,
                sqlParser);
        if (!attempt.isEmpty()) {
            return attempt;
        }
    }

    for (ComparisonExpression candidate : extractSupportedSpatialComparisons(joinFilter)) {
        Result attempt = tryCreateSpatialJoin(
                context,
                joinNode,
                joinFilter,
                joinNode.getId(),
                joinNode.getOutputSymbols(),
                candidate,
                metadata,
                splitManager,
                pageSourceManager,
                sqlParser);
        if (!attempt.isEmpty()) {
            return attempt;
        }
    }

    return Result.empty();
}
}
.add(combineConjuncts(joinConjuncts)) .add(node.getFilter().orElse(TRUE_LITERAL)) .build()), node.getOutputSymbols()); case LEFT: return combineConjuncts(ImmutableList.<Expression>builder() .add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .build()); case RIGHT: return combineConjuncts(ImmutableList.<Expression>builder() .add(pullExpressionThroughSymbols(rightPredicate, node.getOutputSymbols())) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) .build()); case FULL: return combineConjuncts(ImmutableList.<Expression>builder() .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains, node.getRight().getOutputSymbols()::contains)) .build()); default:
.filter(symbol -> node.getOutputSymbols().contains(symbol) || hashSymbolsWithParentPreferences.values().contains(symbol)) .collect(toImmutableList());
/**
 * Creates the lookup-join operator factory that corresponds to the join type of {@code node}.
 *
 * <p>The probe output channels are resolved from those output symbols of the join that are
 * also produced by the join's left child; the probe join channels are resolved from
 * {@code probeSymbols}. Both are looked up in the layout of {@code probeSource}.
 *
 * @param node the join being planned
 * @param probeSource the physical operation supplying types and layout for channel resolution
 * @param probeSymbols symbols resolved into the probe join channels
 * @param probeHashSymbol optional symbol resolved into the probe hash channel
 * @param lookupSourceFactoryManager passed through to the created operator factory
 * @param context supplies the next operator id
 * @return the operator factory for INNER, LEFT, RIGHT or FULL joins
 * @throws UnsupportedOperationException for any other join type
 */
private OperatorFactory createLookupJoin(
        JoinNode node,
        PhysicalOperation probeSource,
        List<Symbol> probeSymbols,
        Optional<Symbol> probeHashSymbol,
        JoinBridgeManager<? extends LookupSourceFactory> lookupSourceFactoryManager,
        LocalExecutionPlanContext context)
{
    List<Type> probeTypes = probeSource.getTypes();

    // Keep only the join outputs that the left child also produces.
    List<Symbol> probeOutputSymbols = node.getOutputSymbols().stream()
            .filter(node.getLeft().getOutputSymbols()::contains)
            .collect(toImmutableList());

    List<Integer> probeOutputChannels = ImmutableList.copyOf(getChannelsForSymbols(probeOutputSymbols, probeSource.getLayout()));
    List<Integer> probeJoinChannels = ImmutableList.copyOf(getChannelsForSymbols(probeSymbols, probeSource.getLayout()));

    OptionalInt probeHashChannel = probeHashSymbol
            .map(channelGetter(probeSource))
            .map(OptionalInt::of)
            .orElseGet(OptionalInt::empty);

    OptionalInt totalOperatorsCount = getJoinOperatorsCountForSpill(context, session);

    JoinNode.Type joinType = node.getType();
    // The operator id is requested inside each case so no id is consumed for unsupported types.
    switch (joinType) {
        case INNER:
            return lookupJoinOperators.innerJoin(
                    context.getNextOperatorId(),
                    node.getId(),
                    lookupSourceFactoryManager,
                    probeTypes,
                    probeJoinChannels,
                    probeHashChannel,
                    Optional.of(probeOutputChannels),
                    totalOperatorsCount,
                    partitioningSpillerFactory);
        case LEFT:
            return lookupJoinOperators.probeOuterJoin(
                    context.getNextOperatorId(),
                    node.getId(),
                    lookupSourceFactoryManager,
                    probeTypes,
                    probeJoinChannels,
                    probeHashChannel,
                    Optional.of(probeOutputChannels),
                    totalOperatorsCount,
                    partitioningSpillerFactory);
        case RIGHT:
            return lookupJoinOperators.lookupOuterJoin(
                    context.getNextOperatorId(),
                    node.getId(),
                    lookupSourceFactoryManager,
                    probeTypes,
                    probeJoinChannels,
                    probeHashChannel,
                    Optional.of(probeOutputChannels),
                    totalOperatorsCount,
                    partitioningSpillerFactory);
        case FULL:
            return lookupJoinOperators.fullOuterJoin(
                    context.getNextOperatorId(),
                    node.getId(),
                    lookupSourceFactoryManager,
                    probeTypes,
                    probeJoinChannels,
                    probeHashChannel,
                    Optional.of(probeOutputChannels),
                    totalOperatorsCount,
                    partitioningSpillerFactory);
        default:
            throw new UnsupportedOperationException("Unsupported join type: " + joinType);
    }
}
outputSymbols = node.getOutputSymbols().stream() .filter(context.get()::contains) .distinct()
/**
 * Prints a join node followed by its two children.
 *
 * <p>The join line lists the equi-join clauses (as expressions) and the optional filter joined
 * with {@code " AND "}. A cross join is printed with its outputs only and is expected to have
 * no join expressions. Distribution type and sort expression are printed when present,
 * followed by stats/cost, and then the left and right children one indent level deeper.
 *
 * @param node the join to print
 * @param indent the indentation level of the join line
 * @return always {@code null}
 */
@Override
public Void visitJoin(JoinNode node, Integer indent)
{
    List<Expression> conditions = new ArrayList<>();
    for (JoinNode.EquiJoinClause equiClause : node.getCriteria()) {
        conditions.add(equiClause.toExpression());
    }
    if (node.getFilter().isPresent()) {
        conditions.add(node.getFilter().get());
    }

    if (!node.isCrossJoin()) {
        print(
                indent,
                "- %s[%s]%s => [%s]",
                node.getType().getJoinLabel(),
                Joiner.on(" AND ").join(conditions),
                formatHash(node.getLeftHashSymbol(), node.getRightHashSymbol()),
                formatOutputs(node.getOutputSymbols()));
    }
    else {
        // A cross join must not carry criteria or a filter.
        checkState(conditions.isEmpty());
        print(indent, "- CrossJoin => [%s]", formatOutputs(node.getOutputSymbols()));
    }

    if (node.getDistributionType().isPresent()) {
        print(indent + 2, "Distribution: %s", node.getDistributionType().get());
    }
    if (node.getSortExpressionContext().isPresent()) {
        print(indent + 2, "SortExpression[%s]", node.getSortExpressionContext().get().getSortExpression());
    }

    printPlanNodesStatsAndCost(indent + 2, node);
    printStats(indent + 2, node.getId());

    node.getLeft().accept(this, indent + 1);
    node.getRight().accept(this, indent + 1);
    return null;
}
if (!indexJoinNode.getOutputSymbols().equals(node.getOutputSymbols())) { indexJoinNode = new ProjectNode( idAllocator.getNextId(), indexJoinNode, Assignments.identity(node.getOutputSymbols())); return createIndexJoinWithExpectedOutputs(node.getOutputSymbols(), IndexJoinNode.Type.SOURCE_OUTER, leftRewritten, rightIndexCandidate.get(), createEquiJoinClause(leftJoinSymbols, rightJoinSymbols), idAllocator); return createIndexJoinWithExpectedOutputs(node.getOutputSymbols(), IndexJoinNode.Type.SOURCE_OUTER, rightRewritten, leftIndexCandidate.get(), createEquiJoinClause(rightJoinSymbols, leftJoinSymbols), idAllocator); return new JoinNode(node.getId(), node.getType(), leftRewritten, rightRewritten, node.getCriteria(), node.getOutputSymbols(), node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType());
/**
 * Returns a new join with the left and right children swapped.
 *
 * <p>The join type, criteria and output symbols are passed through {@code flipType},
 * {@code flipJoinCriteria} and {@code flipOutputSymbols} respectively, and the two hash
 * symbols trade places. The id, filter and distribution type are carried over unchanged.
 *
 * @return the flipped join node
 */
public JoinNode flipChildren()
{
    return new JoinNode(
            getId(),
            flipType(type),
            // children swap sides
            right,
            left,
            flipJoinCriteria(criteria),
            flipOutputSymbols(getOutputSymbols(), left, right),
            filter,
            // hash symbols follow their children
            rightHashSymbol,
            leftHashSymbol,
            distributionType);
}
case INNER: return leftProperties .translate(column -> PropertyDerivations.filterOrRewrite(node.getOutputSymbols(), node.getCriteria(), column)) .unordered(unordered); case LEFT: return leftProperties .translate(column -> PropertyDerivations.filterIfMissing(node.getOutputSymbols(), column)) .unordered(unordered); case RIGHT:
/**
 * Validates the symbol dependencies of a join after visiting both of its children.
 *
 * <p>Checks that every equi-join clause refers to a left symbol available from the left
 * source and a right symbol available from the right source, that every symbol used by the
 * optional filter is available from either source, and finally delegates to
 * {@code checkLeftOutputSymbolsBeforeRight} with the left child's outputs and the join's
 * outputs.
 *
 * @param node the join to validate
 * @param boundSymbols symbols passed to the children and to {@code createInputs}
 * @return always {@code null}
 */
@Override
public Void visitJoin(JoinNode node, Set<Symbol> boundSymbols)
{
    node.getLeft().accept(this, boundSymbols);
    node.getRight().accept(this, boundSymbols);

    Set<Symbol> fromLeft = createInputs(node.getLeft(), boundSymbols);
    Set<Symbol> fromRight = createInputs(node.getRight(), boundSymbols);
    Set<Symbol> fromEitherSide = ImmutableSet.<Symbol>builder()
            .addAll(fromLeft)
            .addAll(fromRight)
            .build();

    for (JoinNode.EquiJoinClause equiClause : node.getCriteria()) {
        checkArgument(
                fromLeft.contains(equiClause.getLeft()),
                "Symbol from join clause (%s) not in left source (%s)",
                equiClause.getLeft(),
                node.getLeft().getOutputSymbols());
        checkArgument(
                fromRight.contains(equiClause.getRight()),
                "Symbol from join clause (%s) not in right source (%s)",
                equiClause.getRight(),
                node.getRight().getOutputSymbols());
    }

    if (node.getFilter().isPresent()) {
        Set<Symbol> filterSymbols = SymbolsExtractor.extractUnique(node.getFilter().get());
        checkArgument(
                fromEitherSide.containsAll(filterSymbols),
                "Symbol from filter (%s) not in sources (%s)",
                filterSymbols,
                fromEitherSide);
    }

    checkLeftOutputSymbolsBeforeRight(node.getLeft().getOutputSymbols(), node.getOutputSymbols());
    return null;
}
/**
 * Rewrites a join with canonicalized criteria, filter, hash symbols and output symbols.
 *
 * <p>Both children are rewritten first. For INNER joins, each canonical equi-join clause whose
 * two sides have equal types and whose left symbol is among the join's output symbols
 * registers a mapping from the clause's right symbol to its left symbol. This happens before
 * the output symbols are canonicalized.
 *
 * @param node the join to rewrite
 * @param context used to rewrite the two children
 * @return a new join node with the same id, type and distribution type
 */
@Override
public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context)
{
    PlanNode rewrittenLeft = context.rewrite(node.getLeft());
    PlanNode rewrittenRight = context.rewrite(node.getRight());

    List<JoinNode.EquiJoinClause> criteria = canonicalizeJoinCriteria(node.getCriteria());
    Optional<Expression> filter = node.getFilter().map(this::canonicalize);
    Optional<Symbol> leftHash = canonicalize(node.getLeftHashSymbol());
    Optional<Symbol> rightHash = canonicalize(node.getRightHashSymbol());

    if (node.getType().equals(INNER)) {
        for (JoinNode.EquiJoinClause equiClause : criteria) {
            boolean sameType = types.get(equiClause.getLeft()).equals(types.get(equiClause.getRight()));
            if (sameType && node.getOutputSymbols().contains(equiClause.getLeft())) {
                map(equiClause.getRight(), equiClause.getLeft());
            }
        }
    }

    // Output symbols are canonicalized only after the mappings above have been registered.
    return new JoinNode(
            node.getId(),
            node.getType(),
            rewrittenLeft,
            rewrittenRight,
            criteria,
            canonicalizeAndDistinct(node.getOutputSymbols()),
            filter,
            leftHash,
            rightHash,
            node.getDistributionType());
}
/**
 * Rebuilds {@code node} on top of new children with an explicit distribution type.
 *
 * <p>All other attributes (id, type, criteria, output symbols, filter and hash symbols) are
 * copied from {@code node}. The properties of the result are derived from the properties of
 * the two new children.
 *
 * @param node the join to copy attributes from
 * @param newLeft the replacement left child with its properties
 * @param newRight the replacement right child with its properties
 * @param newDistributionType the distribution type to set on the rebuilt join
 * @return the rebuilt join paired with its derived properties
 */
private PlanWithProperties buildJoin(JoinNode node, PlanWithProperties newLeft, PlanWithProperties newRight, JoinNode.DistributionType newDistributionType)
{
    JoinNode rebuilt = new JoinNode(
            node.getId(),
            node.getType(),
            newLeft.getNode(),
            newRight.getNode(),
            node.getCriteria(),
            node.getOutputSymbols(),
            node.getFilter(),
            node.getLeftHashSymbol(),
            node.getRightHashSymbol(),
            Optional.of(newDistributionType));

    return new PlanWithProperties(
            rebuilt,
            deriveProperties(rebuilt, ImmutableList.of(newLeft.getProperties(), newRight.getProperties())));
}
/**
 * Rule over {@link JoinNode}s whose pattern matches joins that have no equi-join criteria,
 * have a filter present, and are of type {@code LEFT}.
 *
 * <p>The rule holds the metadata, split manager, page source manager and SQL parser it is
 * constructed with (all must be non-null), and reports itself enabled for a session exactly
 * when {@code isSpatialJoinEnabled(session)} returns true.
 */
@VisibleForTesting public static final class ExtractSpatialLeftJoin implements Rule<JoinNode> { private static final Pattern<JoinNode> PATTERN = join().matching(node -> node.getCriteria().isEmpty() && node.getFilter().isPresent() && node.getType() == LEFT); private final Metadata metadata; private final SplitManager splitManager; private final PageSourceManager pageSourceManager; private final SqlParser sqlParser; public ExtractSpatialLeftJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, SqlParser sqlParser) { this.metadata = requireNonNull(metadata, "metadata is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceManager = requireNonNull(pageSourceManager, "pageSourceManager is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); } @Override public boolean isEnabled(Session session) { return isSpatialJoinEnabled(session); } @Override public Pattern<JoinNode> getPattern() { return PATTERN; }
/**
 * Produces a copy of the join whose output symbols are restricted to the referenced ones.
 *
 * <p>Every other attribute of the join is copied unchanged. The output symbols are filtered
 * with {@code filteredCopy}, keeping those contained in {@code referencedOutputs}.
 *
 * @param idAllocator not used by this implementation
 * @param joinNode the join to prune
 * @param referencedOutputs the symbols that must be kept
 * @return the pruned join; never empty
 */
@Override
protected Optional<PlanNode> pushDownProjectOff(PlanNodeIdAllocator idAllocator, JoinNode joinNode, Set<Symbol> referencedOutputs)
{
    PlanNode prunedJoin = new JoinNode(
            joinNode.getId(),
            joinNode.getType(),
            joinNode.getLeft(),
            joinNode.getRight(),
            joinNode.getCriteria(),
            filteredCopy(joinNode.getOutputSymbols(), referencedOutputs::contains),
            joinNode.getFilter(),
            joinNode.getLeftHashSymbol(),
            joinNode.getRightHashSymbol(),
            joinNode.getDistributionType());

    return Optional.of(prunedJoin);
}
}
/**
 * Applies the expression rewriter to the join filter, if any.
 *
 * @return a result holding a copy of the join with the rewritten filter when rewriting changed
 *         the filter; an empty result when the filter is absent or unchanged
 */
@Override
public Result apply(JoinNode joinNode, Captures captures, Context context)
{
    Optional<Expression> rewrittenFilter = joinNode.getFilter().map(x -> rewriter.rewrite(x, context));

    // Nothing to do when the rewrite left the filter as it was.
    if (joinNode.getFilter().equals(rewrittenFilter)) {
        return Result.empty();
    }

    return Result.ofPlanNode(new JoinNode(
            joinNode.getId(),
            joinNode.getType(),
            joinNode.getLeft(),
            joinNode.getRight(),
            joinNode.getCriteria(),
            joinNode.getOutputSymbols(),
            rewrittenFilter,
            joinNode.getLeftHashSymbol(),
            joinNode.getRightHashSymbol(),
            joinNode.getDistributionType()));
}
}
probeProperties = probeProperties.translate(column -> filterOrRewrite(node.getOutputSymbols(), node.getCriteria(), column)); buildProperties = buildProperties.translate(column -> filterOrRewrite(node.getOutputSymbols(), node.getCriteria(), column)); .build(); case LEFT: return ActualProperties.builderFrom(probeProperties.translate(column -> filterIfMissing(node.getOutputSymbols(), column))) .unordered(unordered) .build(); case RIGHT: buildProperties = buildProperties.translate(column -> filterIfMissing(node.getOutputSymbols(), column)); return ActualProperties.builderFrom(buildProperties.translate(column -> filterIfMissing(node.getOutputSymbols(), column))) .local(ImmutableList.of()) .unordered(unordered)
/**
 * Rebuilds the join tree in the order chosen by {@code getJoinOrder}.
 *
 * <p>A shallow join graph is built from {@code node}. Nothing is changed when the graph has
 * fewer than three entries or when the computed order is the original one.
 *
 * @return a result holding the rebuilt join tree, or an empty result when no change applies
 */
@Override
public Result apply(JoinNode node, Captures captures, Context context)
{
    JoinGraph graph = JoinGraph.buildShallowFrom(node, context.getLookup());

    if (graph.size() >= 3) {
        List<Integer> order = getJoinOrder(graph);
        if (!isOriginalOrder(order)) {
            return Result.ofPlanNode(buildJoinTree(node.getOutputSymbols(), graph, order, context.getIdAllocator()));
        }
    }

    return Result.empty();
}
/**
 * Narrows the type of an outer join when {@code canConvertOuterToInner} allows it for the
 * inherited predicate.
 *
 * <ul>
 * <li>INNER joins are returned as is.</li>
 * <li>FULL joins: the check is made against the left child's output symbols and then the
 * right child's. If both checks pass the join becomes INNER; if only the left check passes it
 * becomes LEFT; if only the right check passes it becomes RIGHT; otherwise the node is
 * returned unchanged.</li>
 * <li>LEFT joins become INNER when the check passes for the right child's output symbols;
 * RIGHT joins become INNER when it passes for the left child's. Otherwise the node is
 * returned unchanged.</li>
 * </ul>
 *
 * @param node the join to normalize; its type must be INNER, RIGHT, LEFT or FULL
 * @param inheritedPredicate the predicate passed to {@code canConvertOuterToInner}
 * @return the original node, or a copy that differs only in join type
 */
private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheritedPredicate)
{
    checkArgument(EnumSet.of(INNER, RIGHT, LEFT, FULL).contains(node.getType()), "Unsupported join type: %s", node.getType());

    JoinNode.Type joinType = node.getType();
    if (joinType == JoinNode.Type.INNER) {
        return node;
    }

    JoinNode.Type normalizedType = INNER;
    if (joinType == JoinNode.Type.FULL) {
        boolean leftCheckPasses = canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate);
        boolean rightCheckPasses = canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate);

        if (leftCheckPasses && rightCheckPasses) {
            normalizedType = INNER;
        }
        else if (leftCheckPasses) {
            normalizedType = LEFT;
        }
        else if (rightCheckPasses) {
            normalizedType = RIGHT;
        }
        else {
            return node;
        }
    }
    else if (joinType == JoinNode.Type.LEFT) {
        if (!canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate)) {
            return node;
        }
    }
    else if (joinType == JoinNode.Type.RIGHT) {
        if (!canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate)) {
            return node;
        }
    }

    // Only the join type differs from the original node.
    return new JoinNode(
            node.getId(),
            normalizedType,
            node.getLeft(),
            node.getRight(),
            node.getCriteria(),
            node.getOutputSymbols(),
            node.getFilter(),
            node.getLeftHashSymbol(),
            node.getRightHashSymbol(),
            node.getDistributionType());
}