private AggregationNode map(AggregationNode node, PlanNode source, PlanNodeId newNodeId)
{
    ImmutableMap.Builder<Symbol, Aggregation> aggregations = ImmutableMap.builder();
    for (Entry<Symbol, Aggregation> entry : node.getAggregations().entrySet()) {
        aggregations.put(map(entry.getKey()), map(entry.getValue()));
    }

    return new AggregationNode(
            newNodeId,
            source,
            aggregations.build(),
            groupingSets(
                    mapAndDistinct(node.getGroupingKeys()),
                    node.getGroupingSetCount(),
                    node.getGlobalGroupingSets()),
            ImmutableList.of(),
            node.getStep(),
            node.getHashSymbol().map(this::map),
            node.getGroupIdSymbol().map(this::map));
}
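In `map` above, every symbol goes through the planner's symbol mapping, but the grouping keys go through `mapAndDistinct`: once aliases are resolved, two keys that referred to the same column become one symbol, and a grouping set should not list it twice. A minimal sketch of that behavior, assuming symbols are plain strings and the mapping is an alias table with no cycles other than self-mappings; `SymbolMapperSketch` is a hypothetical class, not Presto's `SymbolMapper`.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for a symbol mapper; symbols are plain strings here.
final class SymbolMapperSketch {
    private final Map<String, String> mapping;

    SymbolMapperSketch(Map<String, String> mapping) {
        this.mapping = mapping;
    }

    // Follow alias links until reaching a symbol with no further (non-self) mapping.
    // Assumes the mapping is acyclic apart from self-mappings.
    String map(String symbol) {
        String canonical = symbol;
        while (mapping.containsKey(canonical) && !mapping.get(canonical).equals(canonical)) {
            canonical = mapping.get(canonical);
        }
        return canonical;
    }

    // Map each key, then drop duplicates while keeping first-seen order,
    // since two aliases of the same column collapse to one grouping key.
    List<String> mapAndDistinct(List<String> symbols) {
        LinkedHashSet<String> distinct = new LinkedHashSet<>();
        for (String symbol : symbols) {
            distinct.add(map(symbol));
        }
        return List.copyOf(distinct);
    }
}
```

For example, with `b -> a` and `c -> b`, the keys `[a, b, c, d]` map to `[a, d]`.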
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Set<Symbol>> context)
{
    ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
            .addAll(node.getGroupingKeys());
    if (node.getHashSymbol().isPresent()) {
        expectedInputs.add(node.getHashSymbol().get());
    }

    ImmutableMap.Builder<Symbol, Aggregation> aggregations = ImmutableMap.builder();
    for (Map.Entry<Symbol, Aggregation> entry : node.getAggregations().entrySet()) {
        Symbol symbol = entry.getKey();

        if (context.get().contains(symbol)) {
            Aggregation aggregation = entry.getValue();
            expectedInputs.addAll(SymbolsExtractor.extractUnique(aggregation.getCall()));
            aggregation.getMask().ifPresent(expectedInputs::add);
            aggregations.put(symbol, aggregation);
        }
    }

    PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());

    return new AggregationNode(
            node.getId(),
            source,
            aggregations.build(),
            node.getGroupingSets(),
            ImmutableList.of(),
            node.getStep(),
            node.getHashSymbol(),
            node.getGroupIdSymbol());
}
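`visitAggregation` prunes in two directions at once: it keeps only the aggregations whose output symbols the parent still needs, and from those (plus the grouping keys and the optional hash symbol) it builds the set of inputs to request from the source. The sketch below reproduces that bookkeeping with plain collections. `AggSketch`, `PruneResult`, and `PruneAggregationSketch` are hypothetical; each aggregation is reduced to its argument symbols and an optional mask symbol, standing in for `SymbolsExtractor.extractUnique(aggregation.getCall())` and `getMask()`.

```java
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

// Hypothetical aggregation model: argument symbols plus an optional mask symbol.
record AggSketch(List<String> arguments, Optional<String> mask) {}

record PruneResult(Map<String, AggSketch> kept, Set<String> expectedInputs) {}

final class PruneAggregationSketch {
    static PruneResult prune(
            List<String> groupingKeys,
            Optional<String> hashSymbol,
            Map<String, AggSketch> aggregations,
            Set<String> referencedOutputs) {
        // Grouping keys and the hash symbol are always required from the source.
        Set<String> expectedInputs = new LinkedHashSet<>(groupingKeys);
        hashSymbol.ifPresent(expectedInputs::add);

        Map<String, AggSketch> kept = new LinkedHashMap<>();
        for (Map.Entry<String, AggSketch> entry : aggregations.entrySet()) {
            // Drop aggregations whose output no parent reads; keep the rest
            // and record the symbols they consume.
            if (referencedOutputs.contains(entry.getKey())) {
                AggSketch aggregation = entry.getValue();
                expectedInputs.addAll(aggregation.arguments());
                aggregation.mask().ifPresent(expectedInputs::add);
                kept.put(entry.getKey(), aggregation);
            }
        }
        return new PruneResult(kept, expectedInputs);
    }
}
```

The `pushDownProjectOff` rule applies the same output filter but leaves input pruning to a later rule, as its comment notes.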
    @Override
    protected Optional<PlanNode> pushDownProjectOff(
            PlanNodeIdAllocator idAllocator,
            AggregationNode aggregationNode,
            Set<Symbol> referencedOutputs)
    {
        Map<Symbol, AggregationNode.Aggregation> prunedAggregations = Maps.filterKeys(
                aggregationNode.getAggregations(),
                referencedOutputs::contains);

        if (prunedAggregations.size() == aggregationNode.getAggregations().size()) {
            return Optional.empty();
        }

        // PruneAggregationSourceColumns will subsequently project off any newly unused inputs.
        return Optional.of(
                new AggregationNode(
                        aggregationNode.getId(),
                        aggregationNode.getSource(),
                        prunedAggregations,
                        aggregationNode.getGroupingSets(),
                        aggregationNode.getPreGroupedSymbols(),
                        aggregationNode.getStep(),
                        aggregationNode.getHashSymbol(),
                        aggregationNode.getGroupIdSymbol()));
    }
}
private AggregationNode replaceAggregationSource(
        AggregationNode aggregation,
        PlanNode source,
        List<Symbol> groupingKeys)
{
    return new AggregationNode(
            aggregation.getId(),
            source,
            aggregation.getAggregations(),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            aggregation.getStep(),
            aggregation.getHashSymbol(),
            aggregation.getGroupIdSymbol());
}
@Override
public PlanWithProperties visitAggregation(AggregationNode node, HashComputationSet parentPreference)
{
    Optional<HashComputation> groupByHash = Optional.empty();
    if (!node.isStreamable() && !canSkipHashGeneration(node.getGroupingKeys())) {
        groupByHash = computeHash(node.getGroupingKeys());
    }

    // aggregation does not pass through preferred hash symbols
    HashComputationSet requiredHashes = new HashComputationSet(groupByHash);
    PlanWithProperties child = planAndEnforce(node.getSource(), requiredHashes, false, requiredHashes);

    Optional<Symbol> hashSymbol = groupByHash.map(child::getRequiredHashSymbol);

    return new PlanWithProperties(
            new AggregationNode(
                    node.getId(),
                    child.getNode(),
                    node.getAggregations(),
                    node.getGroupingSets(),
                    node.getPreGroupedSymbols(),
                    node.getStep(),
                    hashSymbol,
                    node.getGroupIdSymbol()),
            hashSymbol.isPresent() ? ImmutableMap.of(groupByHash.get(), hashSymbol.get()) : ImmutableMap.of());
}
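The visitor requests a precomputed group-by hash only when the aggregation is not streamable and `canSkipHashGeneration` returns false. The decision can be sketched with types as plain strings. The skip rule used here (no grouping keys, or a single `bigint` key) is an assumption for illustration, since the body of `canSkipHashGeneration` is not part of this code; `HashDecisionSketch` is hypothetical.

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;

final class HashDecisionSketch {
    // Assumed skip rule: no grouping keys, or a single bigint key
    // (assumed to have a specialized group-by path that hashes it directly).
    static boolean canSkipHashGeneration(List<String> keys, Map<String, String> types) {
        return keys.isEmpty()
                || (keys.size() == 1 && "bigint".equals(types.get(keys.get(0))));
    }

    // Returns the keys the precomputed hash should cover, or empty when the
    // aggregation is streamable (input already grouped) or hashing can be skipped.
    static Optional<List<String>> groupByHash(
            boolean streamable, List<String> keys, Map<String, String> types) {
        if (!streamable && !canSkipHashGeneration(keys, types)) {
            return Optional.of(List.copyOf(keys));
        }
        return Optional.empty();
    }
}
```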
    @Override
    public Result apply(AggregationNode aggregationNode, Captures captures, Context context)
    {
        boolean anyRewritten = false;
        ImmutableMap.Builder<Symbol, Aggregation> aggregations = ImmutableMap.builder();
        for (Map.Entry<Symbol, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            Aggregation aggregation = entry.getValue();
            FunctionCall call = (FunctionCall) rewriter.rewrite(aggregation.getCall(), context);
            aggregations.put(
                    entry.getKey(),
                    new Aggregation(call, aggregation.getSignature(), aggregation.getMask()));
            if (!aggregation.getCall().equals(call)) {
                anyRewritten = true;
            }
        }
        if (anyRewritten) {
            return Result.ofPlanNode(new AggregationNode(
                    aggregationNode.getId(),
                    aggregationNode.getSource(),
                    aggregations.build(),
                    aggregationNode.getGroupingSets(),
                    aggregationNode.getPreGroupedSymbols(),
                    aggregationNode.getStep(),
                    aggregationNode.getHashSymbol(),
                    aggregationNode.getGroupIdSymbol()));
        }
        return Result.empty();
    }
}
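This rule rewrites every aggregation call but only returns a new node when at least one call actually changed; otherwise it returns `Result.empty()`, so the optimizer can tell the rule made no progress. The same change-tracking pattern, stripped down to a generic map rewrite; `RewriteIfChangedSketch` is hypothetical and `Optional.empty()` plays the role of `Result.empty()`.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.UnaryOperator;

final class RewriteIfChangedSketch {
    // Applies `rewriter` to every value; returns the rebuilt map only if at
    // least one value changed, mirroring the anyRewritten / Result.empty() pattern.
    static <K, V> Optional<Map<K, V>> rewriteAll(Map<K, V> input, UnaryOperator<V> rewriter) {
        boolean anyRewritten = false;
        Map<K, V> output = new LinkedHashMap<>();
        for (Map.Entry<K, V> entry : input.entrySet()) {
            V rewritten = rewriter.apply(entry.getValue());
            output.put(entry.getKey(), rewritten);
            if (!entry.getValue().equals(rewritten)) {
                anyRewritten = true;
            }
        }
        return anyRewritten ? Optional.of(output) : Optional.empty();
    }
}
```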
@Override
public Result apply(AggregationNode parent, Captures captures, Context context)
{
    ProjectNode child = captures.get(CHILD);

    boolean changed = false;
    Map<Symbol, AggregationNode.Aggregation> aggregations = new LinkedHashMap<>(parent.getAggregations());

    for (Entry<Symbol, AggregationNode.Aggregation> entry : parent.getAggregations().entrySet()) {
        Symbol symbol = entry.getKey();
        AggregationNode.Aggregation aggregation = entry.getValue();

        if (isCountOverConstant(aggregation, child.getAssignments())) {
            changed = true;
            aggregations.put(symbol, new AggregationNode.Aggregation(
                    new FunctionCall(QualifiedName.of("count"), ImmutableList.of()),
                    new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT)),
                    aggregation.getMask()));
        }
    }

    if (!changed) {
        return Result.empty();
    }

    return Result.ofPlanNode(new AggregationNode(
            parent.getId(),
            child,
            aggregations,
            parent.getGroupingSets(),
            ImmutableList.of(),
            parent.getStep(),
            parent.getHashSymbol(),
            parent.getGroupIdSymbol()));
}
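The rule replaces `count(x)` with `count()` when `x` is produced by the child `ProjectNode` as a constant, because counting a non-null value on every row is the same as counting rows. The predicate can be sketched as follows, assuming the child's assignments are reduced to a map from symbol to literal value. `CallSketch` and `CountOverConstantSketch` are hypothetical. The sketch also rejects a `NULL` literal, since `count(NULL)` is always 0; that check follows from SQL semantics, and the body of `isCountOverConstant` is not shown here.

```java
import java.util.List;
import java.util.Map;

// Hypothetical aggregate call: function name plus argument symbols.
record CallSketch(String name, List<String> arguments) {}

final class CountOverConstantSketch {
    // `literals` maps projected symbols to the literal they are assigned;
    // symbols absent from the map are non-literal expressions, and a null
    // value stands for a NULL literal.
    static boolean isCountOverConstant(CallSketch call, Map<String, Object> literals) {
        if (!call.name().equals("count") || call.arguments().size() != 1) {
            return false;
        }
        // count(NULL) is always 0, so only a non-null literal may be rewritten.
        return literals.get(call.arguments().get(0)) != null;
    }

    // count(<non-null constant>) counts every row, exactly like count().
    static CallSketch simplify(CallSketch call, Map<String, Object> literals) {
        return isCountOverConstant(call, literals) ? new CallSketch("count", List.of()) : call;
    }
}
```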
private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeIdAllocator idAllocator)
{
    verify(aggregation.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation");
    ExchangeNode gatheringExchange = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, aggregation);
    return new AggregationNode(
            idAllocator.getNextId(),
            gatheringExchange,
            outputsAsInputs(aggregation.getAggregations()),
            aggregation.getGroupingSets(),
            aggregation.getPreGroupedSymbols(),
            AggregationNode.Step.INTERMEDIATE,
            aggregation.getHashSymbol(),
            aggregation.getGroupIdSymbol());
}
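`addGatheringIntermediate` can feed the aggregation's outputs back in as inputs (`outputsAsInputs`) because an INTERMEDIATE step merges partial states into a state of the same shape. A small illustration with `avg` modeled as a `(sum, count)` state; `AvgState` and its methods are hypothetical and do not reflect Presto's accumulator interfaces.

```java
import java.util.List;

// Hypothetical state for avg(x): a (sum, count) pair that every step can merge.
record AvgState(double sum, long count) {
    // PARTIAL: fold raw input values into a state.
    static AvgState partial(List<Double> values) {
        double sum = 0;
        for (double value : values) {
            sum += value;
        }
        return new AvgState(sum, values.size());
    }

    // INTERMEDIATE: merge states into one state of the same shape, so its
    // output can be consumed by another intermediate or final step.
    static AvgState intermediate(List<AvgState> states) {
        double sum = 0;
        long count = 0;
        for (AvgState state : states) {
            sum += state.sum();
            count += state.count();
        }
        return new AvgState(sum, count);
    }

    // FINAL: turn the merged state into the result (NaN if no rows were seen).
    double finish() {
        return sum / count;
    }
}
```

Merging `partial([1, 2])` and `partial([3, 6])` gives `(12, 4)`, and `finish()` returns `3.0`; re-merging a single merged state leaves it unchanged, which is what makes an extra intermediate step safe.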
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Boolean> context)
{
    boolean distinct = isDistinctOperator(node);

    PlanNode rewrittenNode = context.rewrite(node.getSource(), distinct);

    if (context.get() && distinct) {
        // Assumes underlying node has same output symbols as this distinct node
        return rewrittenNode;
    }

    return new AggregationNode(
            node.getId(),
            rewrittenNode,
            node.getAggregations(),
            node.getGroupingSets(),
            ImmutableList.of(),
            node.getStep(),
            node.getHashSymbol(),
            node.getGroupIdSymbol());
}
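When `context.get()` is true, an enclosing distinct aggregation will deduplicate rows anyway, so a nested distinct can be dropped and its rewritten source returned. The sketch below shows that on a hypothetical plan tree (`PlanSketch`, `RemoveNestedDistinctSketch`), where `"distinct"` stands for what `isDistinctOperator` is assumed to detect: an aggregation with grouping keys and no aggregate functions. For simplicity it only drops a distinct whose direct parent is a distinct; how the original rewriter propagates the flag through other node types is not shown in this code.

```java
import java.util.List;

// Hypothetical plan tree: "distinct" stands for an aggregation with grouping
// keys and no aggregate functions; any other kind is an ordinary operator.
record PlanSketch(String kind, List<PlanSketch> sources) {
    static PlanSketch of(String kind, PlanSketch... sources) {
        return new PlanSketch(kind, List.of(sources));
    }
}

final class RemoveNestedDistinctSketch {
    // parentIsDistinct is true when the direct parent already deduplicates rows.
    static PlanSketch rewrite(PlanSketch node, boolean parentIsDistinct) {
        boolean distinct = node.kind().equals("distinct");
        List<PlanSketch> sources = node.sources().stream()
                .map(source -> rewrite(source, distinct))
                .toList();
        if (parentIsDistinct && distinct) {
            // As in the original, assumes the single source produces the same
            // output columns as the dropped distinct.
            return sources.get(0);
        }
        return new PlanSketch(node.kind(), sources);
    }
}
```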
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context)
{
    Lookup lookup = context.getLookup();
    PlanNodeIdAllocator idAllocator = context.getIdAllocator();
    Session session = context.getSession();
    Optional<PlanNode> rewrittenSource = recurseToPartial(lookup.resolve(aggregation.getSource()), lookup, idAllocator);

    if (!rewrittenSource.isPresent()) {
        return Result.empty();
    }

    PlanNode source = rewrittenSource.get();

    if (getTaskConcurrency(session) > 1) {
        source = ExchangeNode.partitionedExchange(
                idAllocator.getNextId(),
                ExchangeNode.Scope.LOCAL,
                source,
                new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), source.getOutputSymbols()));
        source = new AggregationNode(
                idAllocator.getNextId(),
                source,
                inputsAsOutputs(aggregation.getAggregations()),
                aggregation.getGroupingSets(),
                aggregation.getPreGroupedSymbols(),
                AggregationNode.Step.INTERMEDIATE,
                aggregation.getHashSymbol(),
                aggregation.getGroupIdSymbol());
        source = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, source);
    }

    return Result.ofPlanNode(aggregation.replaceChildren(ImmutableList.of(source)));
}
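When task concurrency is above 1, the rule places an arbitrary local partitioned exchange, an INTERMEDIATE aggregation, and a local gathering exchange between the partial aggregations and the original node, so several drivers can merge partial states before the final merge. The sketch below models that dataflow for a `sum` whose states are `long` values. It runs sequentially, uses round-robin assignment as a stand-in for `FIXED_ARBITRARY_DISTRIBUTION`, and `LocalIntermediateSketch` is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

final class LocalIntermediateSketch {
    // Sums long partial states. With taskConcurrency > 1 the states are spread
    // round-robin over buckets (standing in for an arbitrary local exchange),
    // each bucket is merged on its own (INTERMEDIATE), and the bucket results
    // are gathered and merged once more. Runs sequentially; it models the
    // dataflow, not the threading.
    static long merge(List<Long> partialStates, int taskConcurrency) {
        if (taskConcurrency <= 1) {
            return sum(partialStates);
        }
        List<List<Long>> buckets = new ArrayList<>();
        for (int i = 0; i < taskConcurrency; i++) {
            buckets.add(new ArrayList<>());
        }
        for (int i = 0; i < partialStates.size(); i++) {
            buckets.get(i % taskConcurrency).add(partialStates.get(i));
        }
        List<Long> intermediateStates = new ArrayList<>();
        for (List<Long> bucket : buckets) {
            intermediateStates.add(sum(bucket));
        }
        return sum(intermediateStates);
    }

    private static long sum(List<Long> states) {
        long total = 0;
        for (long state : states) {
            total += state;
        }
        return total;
    }
}
```

Because summing is associative and commutative, `merge(states, 3)` and `merge(states, 1)` agree; the extra step only changes how the merging work is split.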