Refine search
/**
 * Rebuilds a {@link GroupIdNode} keeping only the aggregation arguments and
 * grouping-set outputs that appear in the set of symbols supplied by the
 * rewrite context.
 *
 * <p>The source is rewritten with the retained aggregation arguments plus the
 * input column mapped to each retained grouping output.
 *
 * @param node the group-id node being rewritten
 * @param context rewrite context whose value is the set of symbols to keep
 * @return a new {@code GroupIdNode} with the same id and group-id symbol
 */
@Override
public PlanNode visitGroupId(GroupIdNode node, RewriteContext<Set<Symbol>> context)
{
    Set<Symbol> requiredOutputs = context.get();
    Map<Symbol, Symbol> groupingColumns = node.getGroupingColumns();

    List<Symbol> keptAggregationArguments = node.getAggregationArguments().stream()
            .filter(requiredOutputs::contains)
            .collect(Collectors.toList());

    // Aggregation arguments are registered first, then grouping inputs in encounter order.
    ImmutableSet.Builder<Symbol> sourceSymbols = ImmutableSet.builder();
    sourceSymbols.addAll(keptAggregationArguments);

    Map<Symbol, Symbol> keptGroupingColumns = new HashMap<>();
    ImmutableList.Builder<List<Symbol>> keptGroupingSets = ImmutableList.builder();

    for (List<Symbol> groupingSet : node.getGroupingSets()) {
        ImmutableList.Builder<Symbol> keptGroupingSet = ImmutableList.builder();
        for (Symbol output : groupingSet) {
            if (!requiredOutputs.contains(output)) {
                continue;
            }
            Symbol input = groupingColumns.get(output);
            keptGroupingSet.add(output);
            // The same output may occur in several grouping sets; record its mapping once.
            keptGroupingColumns.putIfAbsent(output, input);
            sourceSymbols.add(input);
        }
        // A grouping set is kept even when every one of its symbols was dropped.
        keptGroupingSets.add(keptGroupingSet.build());
    }

    PlanNode rewrittenSource = context.rewrite(node.getSource(), sourceSymbols.build());
    return new GroupIdNode(
            node.getId(),
            rewrittenSource,
            keptGroupingSets.build(),
            keptGroupingColumns,
            keptAggregationArguments,
            node.getGroupIdSymbol());
}
/**
 * Rebuilds an {@link UnnestNode}, dropping replicate symbols and the
 * ordinality symbol when they are not in the set of symbols supplied by the
 * rewrite context. The unnest symbol mapping is passed through unchanged.
 *
 * @param node the unnest node being rewritten
 * @param context rewrite context whose value is the set of symbols to keep
 * @return a new {@code UnnestNode} with the same id and unnest symbols
 */
@Override
public PlanNode visitUnnest(UnnestNode node, RewriteContext<Set<Symbol>> context)
{
    Set<Symbol> requiredOutputs = context.get();

    List<Symbol> keptReplicateSymbols = node.getReplicateSymbols().stream()
            .filter(requiredOutputs::contains)
            .collect(toImmutableList());

    // Keep the ordinality symbol only if it is present and requested.
    Optional<Symbol> keptOrdinalitySymbol = node.getOrdinalitySymbol()
            .filter(requiredOutputs::contains);

    Map<Symbol, List<Symbol>> unnestSymbols = node.getUnnestSymbols();

    // The source must supply the retained replicate symbols and every unnested input.
    Set<Symbol> sourceSymbols = ImmutableSet.<Symbol>builder()
            .addAll(keptReplicateSymbols)
            .addAll(unnestSymbols.keySet())
            .build();

    PlanNode rewrittenSource = context.rewrite(node.getSource(), sourceSymbols);
    return new UnnestNode(node.getId(), rewrittenSource, keptReplicateSymbols, unnestSymbols, keptOrdinalitySymbol);
}
private static void flattenSetOperation(SetOperationNode node, RewriteContext<Boolean> context, ImmutableList.Builder<PlanNode> flattenedSources, ImmutableListMultimap.Builder<Symbol, Symbol> flattenedSymbolMap) { for (int i = 0; i < node.getSources().size(); i++) { PlanNode subplan = node.getSources().get(i); PlanNode rewrittenSource = context.rewrite(subplan, context.get()); Class<?> setOperationClass = node.getClass(); if (setOperationClass.isInstance(rewrittenSource) && (!setOperationClass.equals(ExceptNode.class) || i == 0)) { // Absorb source's subplans if it is also a SetOperation of the same type // ExceptNodes can only flatten their first source because except is not associative SetOperationNode rewrittenSetOperation = (SetOperationNode) rewrittenSource; flattenedSources.addAll(rewrittenSetOperation.getSources()); for (Map.Entry<Symbol, Collection<Symbol>> entry : node.getSymbolMapping().asMap().entrySet()) { Symbol inputSymbol = Iterables.get(entry.getValue(), i); flattenedSymbolMap.putAll(entry.getKey(), rewrittenSetOperation.getSymbolMapping().get(inputSymbol)); } } else { flattenedSources.add(rewrittenSource); for (Map.Entry<Symbol, Collection<Symbol>> entry : node.getSymbolMapping().asMap().entrySet()) { flattenedSymbolMap.put(entry.getKey(), Iterables.get(entry.getValue(), i)); } } } }
private PlanNode visitFilteringSemiJoin(SemiJoinNode node, RewriteContext<Expression> context) Expression inheritedPredicate = context.get(); Expression deterministicInheritedPredicate = filterDeterministicConjuncts(inheritedPredicate); Expression sourceEffectivePredicate = filterDeterministicConjuncts(effectivePredicateExtractor.extract(node.getSource())); filteringSourceConjuncts.addAll(allInferenceWithoutFilteringSourceInferred.generateEqualitiesPartitionedBy(in(filteringSourceSymbols)).getScopeEqualities()); PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(sourceConjuncts)); PlanNode rewrittenFilteringSource = context.rewrite(node.getFilteringSource(), combineConjuncts(filteringSourceConjuncts));
@Override public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext<Expression> context) Expression inheritedPredicate = context.get(); verify(!newJoinPredicate.equals(BooleanLiteral.FALSE_LITERAL), "Spatial join predicate is missing"); PlanNode leftSource = context.rewrite(node.getLeft(), leftPredicate); PlanNode rightSource = context.rewrite(node.getRight(), rightPredicate);
@Override public PlanNode visitExchange(ExchangeNode node, RewriteContext<Set<Symbol>> context) Set<Symbol> expectedOutputSymbols = Sets.newHashSet(context.get()); node.getPartitioningScheme().getHashColumn().ifPresent(expectedOutputSymbols::add); node.getPartitioningScheme().getPartitioning().getColumns().stream() .addAll(inputsBySource.get(i)); rewrittenSources.add(context.rewrite( node.getSources().get(i), expectedInputs.build()));
expectedFilterInputs = ImmutableSet.<Symbol>builder() .addAll(SymbolsExtractor.extractUnique(node.getFilter().get())) .addAll(context.get()) .build(); leftInputsBuilder.addAll(context.get()).addAll(Iterables.transform(node.getCriteria(), JoinNode.EquiJoinClause::getLeft)); if (node.getLeftHashSymbol().isPresent()) { leftInputsBuilder.add(node.getLeftHashSymbol().get()); rightInputsBuilder.addAll(context.get()).addAll(Iterables.transform(node.getCriteria(), JoinNode.EquiJoinClause::getRight)); if (node.getRightHashSymbol().isPresent()) { rightInputsBuilder.add(node.getRightHashSymbol().get()); Set<Symbol> rightInputs = rightInputsBuilder.build(); PlanNode left = context.rewrite(node.getLeft(), leftInputs); PlanNode right = context.rewrite(node.getRight(), rightInputs);
.addAll(context.get()) .addAll(node.getPartitionBy()); WindowNode.Function function = entry.getValue(); if (context.get().contains(symbol)) { FunctionCall call = function.getFunctionCall(); expectedInputs.addAll(SymbolsExtractor.extractUnique(call)); PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
@Override public PlanNode visitUnnest(UnnestNode node, RewriteContext<Expression> context) Expression inheritedPredicate = context.get(); postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(pushdownConjuncts));
private PlanNode visitNonFilteringSemiJoin(SemiJoinNode node, RewriteContext<Expression> context) Expression inheritedPredicate = context.get(); List<Expression> sourceConjuncts = new ArrayList<>(); List<Expression> postJoinConjuncts = new ArrayList<>(); postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(sourceConjuncts));
if (intersection(node.getSubqueryAssignments().getSymbols(), context.get()).isEmpty()) { return context.rewrite(node.getInput(), context.get()); Symbol output = entry.getKey(); Expression expression = entry.getValue(); if (context.get().contains(output)) { subqueryAssignmentsSymbolsBuilder.addAll(SymbolsExtractor.extractUnique(expression)); subqueryAssignments.put(output, expression); PlanNode subquery = context.rewrite(node.getSubquery(), subqueryAssignmentsSymbols); PlanNode input = context.rewrite(node.getInput(), inputContext); return new ApplyNode(node.getId(), input, subquery, subqueryAssignments.build(), newCorrelation, node.getOriginSubquery());
/**
 * Rebuilds an {@link AggregationNode} keeping only the aggregations whose
 * output symbol is in the set supplied by the rewrite context.
 *
 * <p>The source is rewritten with the grouping keys, the hash symbol (if any),
 * and the symbols referenced by each retained aggregation's call and mask.
 *
 * @param node the aggregation node being rewritten
 * @param context rewrite context whose value is the set of symbols to keep
 * @return a new {@code AggregationNode} with the same id, grouping sets, step,
 *         hash symbol and group-id symbol, and an empty list passed as the
 *         fifth constructor argument
 */
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Set<Symbol>> context)
{
    Set<Symbol> requiredOutputs = context.get();

    ImmutableSet.Builder<Symbol> sourceSymbols = ImmutableSet.builder();
    sourceSymbols.addAll(node.getGroupingKeys());
    node.getHashSymbol().ifPresent(sourceSymbols::add);

    ImmutableMap.Builder<Symbol, Aggregation> keptAggregations = ImmutableMap.builder();
    for (Map.Entry<Symbol, Aggregation> candidate : node.getAggregations().entrySet()) {
        Symbol output = candidate.getKey();
        if (!requiredOutputs.contains(output)) {
            continue;
        }
        Aggregation aggregation = candidate.getValue();
        // A retained aggregation needs every symbol in its call plus its optional mask.
        sourceSymbols.addAll(SymbolsExtractor.extractUnique(aggregation.getCall()));
        aggregation.getMask().ifPresent(sourceSymbols::add);
        keptAggregations.put(output, aggregation);
    }

    PlanNode rewrittenSource = context.rewrite(node.getSource(), sourceSymbols.build());

    return new AggregationNode(
            node.getId(),
            rewrittenSource,
            keptAggregations.build(),
            node.getGroupingSets(),
            ImmutableList.of(),
            node.getStep(),
            node.getHashSymbol(),
            node.getGroupIdSymbol());
}
/**
 * Rewrites both sides of a {@link SemiJoinNode} with the symbols each side
 * must produce, then rebuilds the node with its other properties unchanged.
 *
 * <p>The source side is asked for the symbols supplied by the rewrite context
 * plus its join symbol and optional hash symbol. The filtering side is asked
 * only for its join symbol and optional hash symbol.
 *
 * @param node the semi-join node being rewritten
 * @param context rewrite context whose value is the set of symbols to keep
 * @return a new {@code SemiJoinNode} over the rewritten children
 */
@Override
public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Set<Symbol>> context)
{
    ImmutableSet.Builder<Symbol> sourceSymbols = ImmutableSet.<Symbol>builder()
            .addAll(context.get())
            .add(node.getSourceJoinSymbol());
    node.getSourceHashSymbol().ifPresent(sourceSymbols::add);
    Set<Symbol> requiredFromSource = sourceSymbols.build();

    ImmutableSet.Builder<Symbol> filteringSymbols = ImmutableSet.<Symbol>builder()
            .add(node.getFilteringSourceJoinSymbol());
    node.getFilteringSourceHashSymbol().ifPresent(filteringSymbols::add);
    Set<Symbol> requiredFromFilteringSource = filteringSymbols.build();

    PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredFromSource);
    PlanNode rewrittenFilteringSource = context.rewrite(node.getFilteringSource(), requiredFromFilteringSource);

    return new SemiJoinNode(
            node.getId(),
            rewrittenSource,
            rewrittenFilteringSource,
            node.getSourceJoinSymbol(),
            node.getFilteringSourceJoinSymbol(),
            node.getSemiJoinOutput(),
            node.getSourceHashSymbol(),
            node.getFilteringSourceHashSymbol(),
            node.getDistributionType());
}
@Override public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext<Set<Symbol>> context) { PlanNode subquery = context.rewrite(node.getSubquery(), context.get()); // remove unused lateral nodes if (intersection(ImmutableSet.copyOf(subquery.getOutputSymbols()), context.get()).isEmpty() && isScalar(subquery)) { return context.rewrite(node.getInput(), context.get()); } // prune not used correlation symbols Set<Symbol> subquerySymbols = SymbolsExtractor.extractUnique(subquery); List<Symbol> newCorrelation = node.getCorrelation().stream() .filter(subquerySymbols::contains) .collect(toImmutableList()); Set<Symbol> inputContext = ImmutableSet.<Symbol>builder() .addAll(context.get()) .addAll(newCorrelation) .build(); PlanNode input = context.rewrite(node.getInput(), inputContext); // remove unused lateral nodes if (intersection(ImmutableSet.copyOf(input.getOutputSymbols()), inputContext).isEmpty() && isScalar(input)) { return subquery; } return new LateralJoinNode(node.getId(), input, subquery, newCorrelation, node.getType(), node.getOriginSubquery()); } }
/**
 * Rewrites both sides of an {@link IndexJoinNode} with the symbols each side
 * must produce, then rebuilds the node with its type, criteria and hash
 * symbols unchanged.
 *
 * <p>Each side is asked for the symbols supplied by the rewrite context, its
 * own side of every equi-join clause, and its optional hash symbol.
 *
 * @param node the index-join node being rewritten
 * @param context rewrite context whose value is the set of symbols to keep
 * @return a new {@code IndexJoinNode} over the rewritten children
 */
@Override
public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext<Set<Symbol>> context)
{
    ImmutableSet.Builder<Symbol> probeSymbols = ImmutableSet.<Symbol>builder()
            .addAll(context.get())
            .addAll(Iterables.transform(node.getCriteria(), IndexJoinNode.EquiJoinClause::getProbe));
    node.getProbeHashSymbol().ifPresent(probeSymbols::add);
    Set<Symbol> requiredFromProbe = probeSymbols.build();

    ImmutableSet.Builder<Symbol> indexSymbols = ImmutableSet.<Symbol>builder()
            .addAll(context.get())
            .addAll(Iterables.transform(node.getCriteria(), IndexJoinNode.EquiJoinClause::getIndex));
    node.getIndexHashSymbol().ifPresent(indexSymbols::add);
    Set<Symbol> requiredFromIndex = indexSymbols.build();

    PlanNode rewrittenProbe = context.rewrite(node.getProbeSource(), requiredFromProbe);
    PlanNode rewrittenIndex = context.rewrite(node.getIndexSource(), requiredFromIndex);

    return new IndexJoinNode(
            node.getId(),
            node.getType(),
            rewrittenProbe,
            rewrittenIndex,
            node.getCriteria(),
            node.getProbeHashSymbol(),
            node.getIndexHashSymbol());
}
/**
 * Rewrites both children of a {@link SemiJoinNode} and rebuilds the node.
 * When {@code isDeleteQuery} is set, the rebuilt node is returned with its
 * distribution type switched to {@code REPLICATED}.
 *
 * @param node the semi-join node being rewritten
 * @param context rewrite context; its value is forwarded unchanged to both children
 * @return the rebuilt semi-join, possibly with a replicated distribution type
 */
@Override
public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Void> context)
{
    PlanNode rewrittenSource = context.rewrite(node.getSource(), context.get());
    PlanNode rewrittenFilteringSource = context.rewrite(node.getFilteringSource(), context.get());

    SemiJoinNode rebuilt = new SemiJoinNode(
            node.getId(),
            rewrittenSource,
            rewrittenFilteringSource,
            node.getSourceJoinSymbol(),
            node.getFilteringSourceJoinSymbol(),
            node.getSemiJoinOutput(),
            node.getSourceHashSymbol(),
            node.getFilteringSourceHashSymbol(),
            node.getDistributionType());

    if (!isDeleteQuery) {
        return rebuilt;
    }
    return rebuilt.withDistributionType(REPLICATED);
}
/**
 * Rewrites both sides of a {@link SpatialJoinNode} with the symbols they must
 * produce and rebuilds the node with its output symbols restricted to those
 * in the set supplied by the rewrite context.
 *
 * <p>Each side is asked for its optional partition symbol followed by every
 * symbol referenced by the join filter and every symbol from the context.
 *
 * @param node the spatial-join node being rewritten
 * @param context rewrite context whose value is the set of symbols to keep
 * @return a new {@code SpatialJoinNode} with the same id, type, filter,
 *         partition symbols and KDB tree
 */
@Override
public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext<Set<Symbol>> context)
{
    Set<Symbol> requiredInputs = ImmutableSet.<Symbol>builder()
            .addAll(SymbolsExtractor.extractUnique(node.getFilter()))
            .addAll(context.get())
            .build();

    // ifPresent (not map) because the call is made only for its side effect on the builder.
    ImmutableSet.Builder<Symbol> leftInputs = ImmutableSet.builder();
    node.getLeftPartitionSymbol().ifPresent(leftInputs::add);

    ImmutableSet.Builder<Symbol> rightInputs = ImmutableSet.builder();
    node.getRightPartitionSymbol().ifPresent(rightInputs::add);

    PlanNode left = context.rewrite(node.getLeft(), leftInputs.addAll(requiredInputs).build());
    PlanNode right = context.rewrite(node.getRight(), rightInputs.addAll(requiredInputs).build());

    // Keep only requested outputs, preserving the node's order and removing duplicates.
    List<Symbol> outputSymbols = node.getOutputSymbols().stream()
            .filter(context.get()::contains)
            .distinct()
            .collect(toImmutableList());

    return new SpatialJoinNode(
            node.getId(),
            node.getType(),
            left,
            right,
            outputSymbols,
            node.getFilter(),
            node.getLeftPartitionSymbol(),
            node.getRightPartitionSymbol(),
            node.getKdbTree());
}