/**
 * Rewrites the source child of a semi join, passing the current context value down to it.
 * The filtering source is not rewritten and is carried over as-is.
 *
 * @param node the semi join being visited
 * @param context rewrite context whose current value is forwarded to the source child
 * @return the original node when the source came back unchanged (same instance),
 *         otherwise a new {@code SemiJoinNode} with the same id and the rewritten source
 */
@Override
public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<LimitContext> context)
{
    PlanNode rewrittenSource = context.rewrite(node.getSource(), context.get());

    // Identity comparison on purpose: an untouched child means there is nothing to rebuild.
    if (rewrittenSource == node.getSource()) {
        return node;
    }

    return new SemiJoinNode(
            node.getId(),
            rewrittenSource,
            node.getFilteringSource(),
            node.getSourceJoinSymbol(),
            node.getFilteringSourceJoinSymbol(),
            node.getSemiJoinOutput(),
            node.getSourceHashSymbol(),
            node.getFilteringSourceHashSymbol(),
            node.getDistributionType());
}
}
/**
 * Validates the symbol wiring of a semi join after visiting both of its children.
 *
 * <p>The following conditions are checked, each failing with an
 * {@code IllegalArgumentException} via {@code checkArgument}:
 * <ul>
 * <li>the source join symbol is one of the source's output symbols;</li>
 * <li>the filtering-source join symbol is one of the filtering source's output symbols;</li>
 * <li>the symbol set returned by {@code createInputs(node, boundSymbols)} contains every
 *     output symbol of the source;</li>
 * <li>that same set contains the semi join output symbol.</li>
 * </ul>
 *
 * @param node the semi join to validate
 * @param boundSymbols symbols passed through to the children and to {@code createInputs}
 * @return always {@code null}
 */
@Override
public Void visitSemiJoin(SemiJoinNode node, Set<Symbol> boundSymbols)
{
    node.getSource().accept(this, boundSymbols);
    node.getFilteringSource().accept(this, boundSymbols);

    checkArgument(
            node.getSource().getOutputSymbols().contains(node.getSourceJoinSymbol()),
            "Symbol from semi join clause (%s) not in source (%s)",
            node.getSourceJoinSymbol(),
            node.getSource().getOutputSymbols());

    // Report the symbol that was actually checked (the filtering-source join symbol),
    // not the source join symbol, so the failure message points at the real culprit.
    checkArgument(
            node.getFilteringSource().getOutputSymbols().contains(node.getFilteringSourceJoinSymbol()),
            "Symbol from semi join clause (%s) not in filtering source (%s)",
            node.getFilteringSourceJoinSymbol(),
            node.getFilteringSource().getOutputSymbols());

    Set<Symbol> outputs = createInputs(node, boundSymbols);

    checkArgument(
            outputs.containsAll(node.getSource().getOutputSymbols()),
            "Semi join output symbols (%s) must contain all of the source symbols (%s)",
            node.getOutputSymbols(),
            node.getSource().getOutputSymbols());
    checkArgument(
            outputs.contains(node.getSemiJoinOutput()),
            "Semi join output symbols (%s) must contain join result (%s)",
            node.getOutputSymbols(),
            node.getSemiJoinOutput());

    return null;
}
/**
 * Handles a semi join by delegating to {@code processJoin}.
 * Note the argument order: the filtering source is passed first, the source second.
 *
 * @param node the semi join being visited
 * @param currentFragmentId fragment id forwarded unchanged to {@code processJoin}
 * @return whatever {@code processJoin} returns for the two children
 */
@Override
public Set<PlanFragmentId> visitSemiJoin(SemiJoinNode node, PlanFragmentId currentFragmentId)
{
    PlanNode filteringSide = node.getFilteringSource();
    PlanNode sourceSide = node.getSource();
    return processJoin(filteringSide, sourceSide, currentFragmentId);
}
/**
 * Computes the exchange cost of a semi join by delegating to
 * {@code calculateJoinExchangeCost}.
 *
 * <p>The join is treated as replicated only when the node's distribution type is present
 * and equal to {@code REPLICATED}; an absent distribution type yields {@code false}.
 *
 * @param node the semi join being costed
 * @param context unused
 * @return the cost estimate produced by {@code calculateJoinExchangeCost}
 */
@Override
public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context)
{
    boolean replicated = Objects.equals(
            node.getDistributionType(),
            Optional.of(SemiJoinNode.DistributionType.REPLICATED));

    return calculateJoinExchangeCost(
            node.getSource(),
            node.getFilteringSource(),
            stats,
            types,
            replicated,
            taskCountEstimator.estimateSourceDistributedTaskCount());
}
/**
 * Prunes unreferenced outputs of a semi join.
 *
 * <p>If the semi join output symbol itself is not referenced, the semi join is dropped and
 * its source is returned. Otherwise the source is restricted (via {@code restrictOutputs})
 * to the referenced symbols other than the semi join output, plus the source join symbol
 * and the source hash symbol when one is present.
 *
 * @param idAllocator allocator forwarded to {@code restrictOutputs}
 * @param semiJoinNode the semi join whose outputs are being pruned
 * @param referencedOutputs output symbols that are referenced
 * @return the replacement plan, or the (possibly empty) result of {@code restrictOutputs}
 *         mapped onto a semi join with the new source and the original filtering source
 */
@Override
protected Optional<PlanNode> pushDownProjectOff(PlanNodeIdAllocator idAllocator, SemiJoinNode semiJoinNode, Set<Symbol> referencedOutputs)
{
    Symbol semiJoinOutput = semiJoinNode.getSemiJoinOutput();
    if (!referencedOutputs.contains(semiJoinOutput)) {
        return Optional.of(semiJoinNode.getSource());
    }

    Stream<Symbol> referencedSourceSymbols = referencedOutputs.stream()
            .filter(symbol -> !symbol.equals(semiJoinOutput));
    Stream<Symbol> joinSymbol = Stream.of(semiJoinNode.getSourceJoinSymbol());
    Stream<Symbol> hashSymbol = semiJoinNode.getSourceHashSymbol().map(Stream::of).orElse(Stream.empty());

    Set<Symbol> requiredSourceInputs = Streams.concat(referencedSourceSymbols, joinSymbol, hashSymbol)
            .collect(toImmutableSet());

    return restrictOutputs(idAllocator, semiJoinNode.getSource(), requiredSourceInputs)
            .map(prunedSource -> semiJoinNode.replaceChildren(
                    ImmutableList.of(prunedSource, semiJoinNode.getFilteringSource())));
}
}
/**
 * Checks that a semi join node uses the expected join symbols.
 *
 * <p>Fails with {@code checkState} if {@code shapeMatches(node)} is false. Returns
 * {@code NO_MATCH} unless the source alias and the filtering alias both resolve to the
 * symbol references of the node's source and filtering-source join symbols. On success,
 * returns a match that associates {@code outputAlias} with the semi join output reference.
 *
 * @param node the plan node, cast to {@code SemiJoinNode}
 * @param stats unused
 * @param session unused
 * @param metadata unused
 * @param symbolAliases aliases used to resolve the expected join symbols
 * @return {@code NO_MATCH} or the result of {@code match(outputAlias, ...)}
 */
@Override
public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases)
{
    checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());

    SemiJoinNode semiJoinNode = (SemiJoinNode) node;

    // Source side is checked first; the filtering alias is only resolved when it matches.
    boolean sourceSymbolMatches = symbolAliases.get(sourceSymbolAlias)
            .equals(semiJoinNode.getSourceJoinSymbol().toSymbolReference());
    if (!sourceSymbolMatches) {
        return NO_MATCH;
    }

    boolean filteringSymbolMatches = symbolAliases.get(filteringSymbolAlias)
            .equals(semiJoinNode.getFilteringSourceJoinSymbol().toSymbolReference());
    if (!filteringSymbolMatches) {
        return NO_MATCH;
    }

    return match(outputAlias, semiJoinNode.getSemiJoinOutput().toSymbolReference());
}
/**
 * Restricts the filtering source of a semi join to the symbols the join needs:
 * the filtering-source join symbol and, when present, the filtering-source hash symbol.
 *
 * @param semiJoinNode the semi join whose filtering source is being restricted
 * @param captures unused
 * @param context supplies the id allocator forwarded to {@code restrictOutputs}
 * @return a result wrapping the semi join with the restricted filtering source, or
 *         {@code Result.empty()} when {@code restrictOutputs} returns nothing
 */
@Override
public Result apply(SemiJoinNode semiJoinNode, Captures captures, Context context)
{
    Stream<Symbol> joinSymbol = Stream.of(semiJoinNode.getFilteringSourceJoinSymbol());
    Stream<Symbol> hashSymbol = semiJoinNode.getFilteringSourceHashSymbol().map(Stream::of).orElse(Stream.empty());

    Set<Symbol> requiredFilteringSourceInputs = Streams.concat(joinSymbol, hashSymbol)
            .collect(toImmutableSet());

    return restrictOutputs(context.getIdAllocator(), semiJoinNode.getFilteringSource(), requiredFilteringSourceInputs)
            .map(prunedFilteringSource -> semiJoinNode.replaceChildren(
                    ImmutableList.of(semiJoinNode.getSource(), prunedFilteringSource)))
            .map(Result::ofPlanNode)
            .orElse(Result.empty());
}
}
/**
 * Replaces an apply node whose single subquery assignment is an {@code InPredicate}
 * with a {@code SemiJoinNode}.
 *
 * <p>The apply node's input becomes the source, its subquery becomes the filtering source,
 * the predicate's value and value list supply the two join symbols (via
 * {@code Symbol.from}), and the assignment's symbol becomes the semi join output.
 * Hash symbols and distribution type are left empty.
 *
 * @param applyNode the apply node to transform
 * @param captures unused
 * @param context supplies the id allocator for the new node
 * @return {@code Result.empty()} when there is not exactly one subquery assignment or it is
 *         not an {@code InPredicate}; otherwise a result wrapping the new semi join
 */
@Override
public Result apply(ApplyNode applyNode, Captures captures, Context context)
{
    if (applyNode.getSubqueryAssignments().size() != 1) {
        return Result.empty();
    }

    Expression onlyExpression = getOnlyElement(applyNode.getSubqueryAssignments().getExpressions());
    if (!(onlyExpression instanceof InPredicate)) {
        return Result.empty();
    }

    InPredicate membershipTest = (InPredicate) onlyExpression;
    Symbol semiJoinOutput = getOnlyElement(applyNode.getSubqueryAssignments().getSymbols());

    return Result.ofPlanNode(new SemiJoinNode(
            context.getIdAllocator().getNextId(),
            applyNode.getInput(),
            applyNode.getSubquery(),
            Symbol.from(membershipTest.getValue()),
            Symbol.from(membershipTest.getValueList()),
            semiJoinOutput,
            Optional.empty(),
            Optional.empty(),
            Optional.empty()));
}
}
/**
 * Creates a copy of this semi join with new children and every other field unchanged.
 *
 * @param newChildren exactly two nodes: the new source followed by the new filtering source
 * @return a new {@code SemiJoinNode} with this node's id
 * @throws IllegalArgumentException if {@code newChildren} does not contain exactly 2 nodes
 */
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
    checkArgument(newChildren.size() == 2, "expected newChildren to contain 2 nodes");

    PlanNode newSource = newChildren.get(0);
    PlanNode newFilteringSource = newChildren.get(1);

    return new SemiJoinNode(
            getId(),
            newSource,
            newFilteringSource,
            sourceJoinSymbol,
            filteringSourceJoinSymbol,
            semiJoinOutput,
            sourceHashSymbol,
            filteringSourceHashSymbol,
            distributionType);
}
}
/**
 * Dispatches a semi join to one of two handlers depending on the inherited predicate.
 *
 * <p>When the semi join output symbol reference appears as one of the conjuncts of the
 * inherited predicate, {@code visitFilteringSemiJoin} handles the node; otherwise
 * {@code visitNonFilteringSemiJoin} does.
 *
 * @param node the semi join being visited
 * @param context rewrite context carrying the inherited predicate
 * @return the plan node produced by the selected handler
 */
@Override
public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Expression> context)
{
    Expression inheritedPredicate = context.get();
    boolean semiJoinOutputIsConjunct = extractConjuncts(inheritedPredicate)
            .contains(node.getSemiJoinOutput().toSymbolReference());

    if (semiJoinOutputIsConjunct) {
        return visitFilteringSemiJoin(node, context);
    }
    return visitNonFilteringSemiJoin(node, context);
}
/**
 * Emits a {@code "semijoin (<distribution type>):"} line at the given indent and then
 * continues with {@code visitPlan} one indent level deeper.
 *
 * <p>The distribution type is optional on the node. It must be present here; when it is
 * missing, the failure names the offending node instead of surfacing the bare
 * "No value present" error of an unchecked {@code Optional.get()}.
 *
 * @param node the semi join being printed
 * @param indent current indent level
 * @return the result of {@code visitPlan(node, indent + 1)}
 * @throws java.util.NoSuchElementException if the node has no distribution type
 */
@Override
public Void visitSemiJoin(final SemiJoinNode node, Integer indent)
{
    // Same exception type as Optional.get(), so existing failure handling is unaffected.
    SemiJoinNode.DistributionType distributionType = node.getDistributionType()
            .orElseThrow(() -> new java.util.NoSuchElementException(
                    "Distribution type is not set for semi join node " + node.getId()));

    output(indent, "semijoin (%s):", distributionType);
    return visitPlan(node, indent + 1);
}
@Override public Optional<PlanNodeStatsEstimate> calculate(SemiJoinNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) { PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource()); // For now we just propagate statistics for source symbols. // Handling semiJoinOutput symbols requires support for correlation for boolean columns. return Optional.of(sourceStats); } }
/**
 * Prunes unreferenced outputs of a semi join.
 *
 * <p>If the semi join output symbol itself is not referenced, the semi join is dropped and
 * its source is returned. Otherwise the source is restricted (via {@code restrictOutputs})
 * to the referenced symbols other than the semi join output, plus the source join symbol
 * and the source hash symbol when one is present.
 *
 * @param idAllocator allocator forwarded to {@code restrictOutputs}
 * @param semiJoinNode the semi join whose outputs are being pruned
 * @param referencedOutputs output symbols that are referenced
 * @return the replacement plan, or the (possibly empty) result of {@code restrictOutputs}
 *         mapped onto a semi join with the new source and the original filtering source
 */
@Override
protected Optional<PlanNode> pushDownProjectOff(PlanNodeIdAllocator idAllocator, SemiJoinNode semiJoinNode, Set<Symbol> referencedOutputs)
{
    Symbol semiJoinOutput = semiJoinNode.getSemiJoinOutput();
    if (!referencedOutputs.contains(semiJoinOutput)) {
        return Optional.of(semiJoinNode.getSource());
    }

    Stream<Symbol> referencedSourceSymbols = referencedOutputs.stream()
            .filter(symbol -> !symbol.equals(semiJoinOutput));
    Stream<Symbol> joinSymbol = Stream.of(semiJoinNode.getSourceJoinSymbol());
    Stream<Symbol> hashSymbol = semiJoinNode.getSourceHashSymbol().map(Stream::of).orElse(Stream.empty());

    Set<Symbol> requiredSourceInputs = Streams.concat(referencedSourceSymbols, joinSymbol, hashSymbol)
            .collect(toImmutableSet());

    return restrictOutputs(idAllocator, semiJoinNode.getSource(), requiredSourceInputs)
            .map(prunedSource -> semiJoinNode.replaceChildren(
                    ImmutableList.of(prunedSource, semiJoinNode.getFilteringSource())));
}
}
/**
 * Checks that a semi join node uses the expected join symbols.
 *
 * <p>Fails with {@code checkState} if {@code shapeMatches(node)} is false. Returns
 * {@code NO_MATCH} unless the source alias and the filtering alias both resolve to the
 * symbol references of the node's source and filtering-source join symbols. On success,
 * returns a match that associates {@code outputAlias} with the semi join output reference.
 *
 * @param node the plan node, cast to {@code SemiJoinNode}
 * @param stats unused
 * @param session unused
 * @param metadata unused
 * @param symbolAliases aliases used to resolve the expected join symbols
 * @return {@code NO_MATCH} or the result of {@code match(outputAlias, ...)}
 */
@Override
public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases)
{
    checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());

    SemiJoinNode semiJoinNode = (SemiJoinNode) node;

    // Source side is checked first; the filtering alias is only resolved when it matches.
    boolean sourceSymbolMatches = symbolAliases.get(sourceSymbolAlias)
            .equals(semiJoinNode.getSourceJoinSymbol().toSymbolReference());
    if (!sourceSymbolMatches) {
        return NO_MATCH;
    }

    boolean filteringSymbolMatches = symbolAliases.get(filteringSymbolAlias)
            .equals(semiJoinNode.getFilteringSourceJoinSymbol().toSymbolReference());
    if (!filteringSymbolMatches) {
        return NO_MATCH;
    }

    return match(outputAlias, semiJoinNode.getSemiJoinOutput().toSymbolReference());
}
/**
 * Restricts the filtering source of a semi join to the symbols the join needs:
 * the filtering-source join symbol and, when present, the filtering-source hash symbol.
 *
 * @param semiJoinNode the semi join whose filtering source is being restricted
 * @param captures unused
 * @param context supplies the id allocator forwarded to {@code restrictOutputs}
 * @return a result wrapping the semi join with the restricted filtering source, or
 *         {@code Result.empty()} when {@code restrictOutputs} returns nothing
 */
@Override
public Result apply(SemiJoinNode semiJoinNode, Captures captures, Context context)
{
    Stream<Symbol> joinSymbol = Stream.of(semiJoinNode.getFilteringSourceJoinSymbol());
    Stream<Symbol> hashSymbol = semiJoinNode.getFilteringSourceHashSymbol().map(Stream::of).orElse(Stream.empty());

    Set<Symbol> requiredFilteringSourceInputs = Streams.concat(joinSymbol, hashSymbol)
            .collect(toImmutableSet());

    return restrictOutputs(context.getIdAllocator(), semiJoinNode.getFilteringSource(), requiredFilteringSourceInputs)
            .map(prunedFilteringSource -> semiJoinNode.replaceChildren(
                    ImmutableList.of(semiJoinNode.getSource(), prunedFilteringSource)))
            .map(Result::ofPlanNode)
            .orElse(Result.empty());
}
}
/**
 * Computes the cost of a semi join by delegating to {@code calculateJoinCost}.
 *
 * <p>An absent distribution type is treated as {@code PARTITIONED}, so the replicated
 * flag is true only when the node is explicitly {@code REPLICATED}.
 *
 * @param node the semi join being costed
 * @param context unused
 * @return the cost estimate produced by {@code calculateJoinCost}
 */
@Override
public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context)
{
    SemiJoinNode.DistributionType effectiveDistribution = node.getDistributionType()
            .orElse(SemiJoinNode.DistributionType.PARTITIONED);
    boolean replicated = effectiveDistribution.equals(SemiJoinNode.DistributionType.REPLICATED);

    return calculateJoinCost(
            node,
            node.getSource(),
            node.getFilteringSource(),
            replicated);
}
/**
 * Handles a semi join by delegating to {@code processJoin}.
 * Note the argument order: the filtering source is passed first, the source second.
 *
 * @param node the semi join being visited
 * @param currentFragmentId fragment id forwarded unchanged to {@code processJoin}
 * @return whatever {@code processJoin} returns for the two children
 */
@Override
public Set<PlanFragmentId> visitSemiJoin(SemiJoinNode node, PlanFragmentId currentFragmentId)
{
    PlanNode filteringSide = node.getFilteringSource();
    PlanNode sourceSide = node.getSource();
    return processJoin(filteringSide, sourceSide, currentFragmentId);
}
public class AggregationBuilder { private PlanNode source; private Map<Symbol, Aggregation> assignments = new HashMap<>(); private AggregationNode.GroupingSetDescriptor groupingSets; private List<Symbol> preGroupedSymbols = new ArrayList<>(); private Step step = Step.SINGLE; private Optional<Symbol> hashSymbol = Optional.empty(); private Optional<Symbol> groupIdSymbol = Optional.empty(); public AggregationBuilder source(PlanNode source) { this.source = source; return this; } public AggregationBuilder addAggregation(Symbol output, Expression expression, List<Type> inputTypes) { return addAggregation(output, expression, inputTypes, Optional.empty()); } public AggregationBuilder addAggregation(Symbol output, Expression expression, List<Type> inputTypes, Symbol mask) { return addAggregation(output, expression, inputTypes, Optional.of(mask)); } private AggregationBuilder addAggregation(Symbol output, Expression expression, List<Type> inputTypes, Optional<Symbol> mask) { checkArgument(expression instanceof FunctionCall); FunctionCall aggregation = (FunctionCall) expression;
/**
 * Creates a copy of this semi join with new children and every other field unchanged.
 *
 * @param newChildren exactly two nodes: the new source followed by the new filtering source
 * @return a new {@code SemiJoinNode} with this node's id
 * @throws IllegalArgumentException if {@code newChildren} does not contain exactly 2 nodes
 */
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
    checkArgument(newChildren.size() == 2, "expected newChildren to contain 2 nodes");

    PlanNode newSource = newChildren.get(0);
    PlanNode newFilteringSource = newChildren.get(1);

    return new SemiJoinNode(
            getId(),
            newSource,
            newFilteringSource,
            sourceJoinSymbol,
            filteringSourceJoinSymbol,
            semiJoinOutput,
            sourceHashSymbol,
            filteringSourceHashSymbol,
            distributionType);
}
}
/**
 * Dispatches a semi join to one of two handlers depending on the inherited predicate.
 *
 * <p>When the semi join output symbol reference appears as one of the conjuncts of the
 * inherited predicate, {@code visitFilteringSemiJoin} handles the node; otherwise
 * {@code visitNonFilteringSemiJoin} does.
 *
 * @param node the semi join being visited
 * @param context rewrite context carrying the inherited predicate
 * @return the plan node produced by the selected handler
 */
@Override
public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Expression> context)
{
    Expression inheritedPredicate = context.get();
    boolean semiJoinOutputIsConjunct = extractConjuncts(inheritedPredicate)
            .contains(node.getSemiJoinOutput().toSymbolReference());

    if (semiJoinOutputIsConjunct) {
        return visitFilteringSemiJoin(node, context);
    }
    return visitNonFilteringSemiJoin(node, context);
}