@Override protected Optional<PlanNode> pushDownProjectOff( PlanNodeIdAllocator idAllocator, AggregationNode aggregationNode, Set<Symbol> referencedOutputs) { Map<Symbol, AggregationNode.Aggregation> prunedAggregations = Maps.filterKeys( aggregationNode.getAggregations(), referencedOutputs::contains); if (prunedAggregations.size() == aggregationNode.getAggregations().size()) { return Optional.empty(); } // PruneAggregationSourceColumns will subsequently project off any newly unused inputs. return Optional.of( new AggregationNode( aggregationNode.getId(), aggregationNode.getSource(), prunedAggregations, aggregationNode.getGroupingSets(), aggregationNode.getPreGroupedSymbols(), aggregationNode.getStep(), aggregationNode.getHashSymbol(), aggregationNode.getGroupIdSymbol())); } }
@Override public PlanWithProperties visitAggregation(AggregationNode node, HashComputationSet parentPreference) { Optional<HashComputation> groupByHash = Optional.empty(); if (!node.isStreamable() && !canSkipHashGeneration(node.getGroupingKeys())) { groupByHash = computeHash(node.getGroupingKeys()); } // aggregation does not pass through preferred hash symbols HashComputationSet requiredHashes = new HashComputationSet(groupByHash); PlanWithProperties child = planAndEnforce(node.getSource(), requiredHashes, false, requiredHashes); Optional<Symbol> hashSymbol = groupByHash.map(child::getRequiredHashSymbol); return new PlanWithProperties( new AggregationNode( node.getId(), child.getNode(), node.getAggregations(), node.getGroupingSets(), node.getPreGroupedSymbols(), node.getStep(), hashSymbol, node.getGroupIdSymbol()), hashSymbol.isPresent() ? ImmutableMap.of(groupByHash.get(), hashSymbol.get()) : ImmutableMap.of()); }
/**
 * Creates a copy of {@code node} with every symbol passed through this mapper's {@code map}
 * overloads, attached to the given source and carrying the given id.
 *
 * <p>The grouping sets are rebuilt from the mapped, de-duplicated grouping keys while keeping the
 * original grouping-set count and global grouping sets. The pre-grouped symbols of the result are
 * always empty.
 *
 * @param node the aggregation node to copy
 * @param source the source for the new node
 * @param newNodeId the id for the new node
 * @return the mapped aggregation node
 */
private AggregationNode map(AggregationNode node, PlanNode source, PlanNodeId newNodeId)
{
    ImmutableMap.Builder<Symbol, Aggregation> mappedAggregations = ImmutableMap.builder();
    node.getAggregations().forEach(
            (output, aggregation) -> mappedAggregations.put(map(output), map(aggregation)));

    return new AggregationNode(
            newNodeId,
            source,
            mappedAggregations.build(),
            groupingSets(
                    mapAndDistinct(node.getGroupingKeys()),
                    node.getGroupingSetCount(),
                    node.getGlobalGroupingSets()),
            ImmutableList.of(),
            node.getStep(),
            node.getHashSymbol().map(this::map),
            node.getGroupIdSymbol().map(this::map));
}
private static boolean isSupportedAggregationNode(AggregationNode aggregationNode) { // Don't split streaming aggregations if (aggregationNode.isStreamable()) { return false; } if (aggregationNode.getHashSymbol().isPresent()) { // TODO: add support for hash symbol in aggregation node return false; } return aggregationNode.getStep() == PARTIAL && aggregationNode.getGroupingSetCount() == 1; }
/**
 * Builds a copy of {@code aggregation} that reads from {@code source} and groups on a single
 * grouping set made of {@code groupingKeys}.
 *
 * <p>The id, aggregates, step, hash symbol and group-id symbol are carried over unchanged; the
 * pre-grouped symbols of the result are always empty.
 *
 * @param aggregation the aggregation node to copy
 * @param source the new source node
 * @param groupingKeys the grouping keys of the new single grouping set
 * @return the rebuilt aggregation node
 */
private AggregationNode replaceAggregationSource(
        AggregationNode aggregation,
        PlanNode source,
        List<Symbol> groupingKeys)
{
    AggregationNode replaced = new AggregationNode(
            aggregation.getId(),
            source,
            aggregation.getAggregations(),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            aggregation.getStep(),
            aggregation.getHashSymbol(),
            aggregation.getGroupIdSymbol());
    return replaced;
}
/**
 * Returns a copy of this aggregation node whose source is the single element of
 * {@code newChildren}; every other property is taken from this node.
 *
 * @param newChildren the replacement children; passed to {@code Iterables.getOnlyElement}
 * @return the new aggregation node
 */
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
    PlanNode newSource = Iterables.getOnlyElement(newChildren);
    return new AggregationNode(
            getId(),
            newSource,
            aggregations,
            groupingSets,
            preGroupedSymbols,
            step,
            hashSymbol,
            groupIdSymbol);
}
/**
 * Rebuilds the aggregation node on top of the single element of {@code newChildren}, copying
 * every other property from {@code node}.
 *
 * @param node the aggregation node being visited
 * @param newChildren the replacement children; passed to {@code Iterables.getOnlyElement}
 * @return the new aggregation node
 */
@Override
public PlanNode visitAggregation(AggregationNode node, List<PlanNode> newChildren)
{
    PlanNode newSource = Iterables.getOnlyElement(newChildren);
    return new AggregationNode(
            node.getId(),
            newSource,
            node.getGroupBy(),
            node.getAggregations(),
            node.getFunctions(),
            node.getMasks(),
            node.getStep(),
            node.getSampleWeight(),
            node.getConfidence(),
            node.getHashSymbol());
}
checkState(node.getStep() == AggregationNode.Step.SINGLE, "step of aggregation is expected to be SINGLE, but it is %s", node.getStep()); if (node.hasSingleNodeExecutionPreference(metadata.getFunctionRegistry())) { return planAndEnforceChildren(node, singleStream(), defaultParallelism(session)); List<Symbol> groupingKeys = node.getGroupingKeys(); if (node.hasDefaultOutput()) { checkState(node.isDecomposable(metadata.getFunctionRegistry())); PlanWithProperties child = planAndEnforce(node.getSource(), any(), defaultParallelism(session)); PlanWithProperties exchange = deriveProperties( partitionedExchange( .constrainTo(node.getSource().getOutputSymbols()) .withDefaultParallelism(session) .withPartitioning(groupingKeys); PlanWithProperties child = planAndEnforce(node.getSource(), childRequirements, childRequirements); AggregationNode result = new AggregationNode( node.getId(), child.getNode(), node.getAggregations(), node.getGroupingSets(), preGroupedSymbols, node.getStep(), node.getHashSymbol(), node.getGroupIdSymbol());
if (node.hasEmptyGroupingSet()) { return Optional.empty(); Optional<DecorrelationResult> childDecorrelationResultOptional = lookup.resolve(node.getSource()).accept(this, null); if (!childDecorrelationResultOptional.isPresent()) { return Optional.empty(); .map(node, childDecorrelationResult.node); Set<Symbol> groupingKeys = ImmutableSet.copyOf(node.getGroupingKeys()); List<Symbol> symbolsToAdd = childDecorrelationResult.symbolsToPropagate.stream() .filter(symbol -> !groupingKeys.contains(symbol)) decorrelatedAggregation.getId(), decorrelatedAggregation.getSource(), decorrelatedAggregation.getAggregations(), AggregationNode.singleGroupingSet(ImmutableList.<Symbol>builder() .addAll(node.getGroupingKeys()) .addAll(symbolsToAdd) .build()), ImmutableList.of(), decorrelatedAggregation.getStep(), decorrelatedAggregation.getHashSymbol(), decorrelatedAggregation.getGroupIdSymbol()); boolean atMostSingleRow = newAggregation.getGroupingSetCount() == 1 && constantSymbols.containsAll(newAggregation.getGroupingKeys());
List<Symbol> masks = node.getAggregations().values().stream() .map(Aggregation::getMask).filter(Optional::isPresent).map(Optional::get).collect(toImmutableList()); Set<Symbol> uniqueMasks = ImmutableSet.copyOf(masks); if (uniqueMasks.size() != 1 || masks.size() == node.getAggregations().size()) { return context.defaultRewrite(node, Optional.empty()); if (node.getAggregations().values().stream().map(Aggregation::getCall).map(FunctionCall::getFilter).anyMatch(Optional::isPresent)) { if (node.hasOrderings()) { node.getGroupingKeys(), Iterables.getOnlyElement(uniqueMasks), node.getAggregations()); PlanNode source = context.rewrite(node.getSource(), Optional.of(aggregateInfo)); for (Map.Entry<Symbol, Aggregation> entry : node.getAggregations().entrySet()) { FunctionCall functionCall = entry.getValue().getCall(); if (entry.getValue().getMask().isPresent()) { AggregationNode aggregationNode = new AggregationNode( idAllocator.getNextId(), source, aggregations.build(), node.getGroupingSets(), ImmutableList.of(), node.getStep(), Optional.empty(), node.getGroupIdSymbol());
/**
 * Wraps an un-grouped aggregation in a local gathering exchange followed by an
 * {@code INTERMEDIATE} aggregation step.
 *
 * @param aggregation the aggregation to wrap; verified to have no grouping keys
 * @param idAllocator allocator used for the ids of the two new nodes
 * @return the new {@code INTERMEDIATE} aggregation node sitting above the exchange
 */
private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeIdAllocator idAllocator)
{
    verify(aggregation.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation");

    // The exchange takes the first freshly allocated id, the intermediate aggregation the next one.
    ExchangeNode localGather = ExchangeNode.gatheringExchange(
            idAllocator.getNextId(),
            ExchangeNode.Scope.LOCAL,
            aggregation);

    AggregationNode intermediate = new AggregationNode(
            idAllocator.getNextId(),
            localGather,
            outputsAsInputs(aggregation.getAggregations()),
            aggregation.getGroupingSets(),
            aggregation.getPreGroupedSymbols(),
            AggregationNode.Step.INTERMEDIATE,
            aggregation.getHashSymbol(),
            aggregation.getGroupIdSymbol());
    return intermediate;
}
/**
 * Prints an aggregation node: a header line with step, streaming marker, grouping keys, hash
 * symbol and outputs, then stats, then one line per aggregate (with its mask when present),
 * and finally the node's children one indentation level deeper.
 *
 * @param node the aggregation node to print
 * @param indent the current indentation level
 * @return the result of {@code processChildren}
 */
@Override
public Void visitAggregation(AggregationNode node, Integer indent)
{
    // The step is shown only when it differs from SINGLE.
    String type = node.getStep() == AggregationNode.Step.SINGLE
            ? ""
            : format("(%s)", node.getStep().toString());
    if (node.isStreamable()) {
        type = format("%s(STREAMING)", type);
    }

    String key = node.getGroupingKeys().isEmpty() ? "" : node.getGroupingKeys().toString();

    print(indent, "- Aggregate%s%s%s => [%s]", type, key, formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols()));
    printPlanNodesStatsAndCost(indent + 2, node);
    printStats(indent + 2, node.getId());

    for (Map.Entry<Symbol, Aggregation> entry : node.getAggregations().entrySet()) {
        Symbol output = entry.getKey();
        Aggregation aggregation = entry.getValue();
        if (!aggregation.getMask().isPresent()) {
            print(indent + 2, "%s := %s", output, aggregation.getCall());
        }
        else {
            print(indent + 2, "%s := %s (mask = %s)", output, aggregation.getCall(), aggregation.getMask().get());
        }
    }

    return processChildren(node, indent + 1);
}
boolean aggregateWithoutFilterPresent = false; for (Map.Entry<Symbol, Aggregation> entry : aggregation.getAggregations().entrySet()) { Symbol output = entry.getKey(); if (!aggregation.hasNonEmptyGroupingSet() && !aggregateWithoutFilterPresent) { predicate = combineDisjunctsWithDefault(maskSymbols.build(), TRUE_LITERAL); newAssignments.putIdentities(aggregation.getSource().getOutputSymbols()); new AggregationNode( context.getIdAllocator().getNextId(), new FilterNode( new ProjectNode( context.getIdAllocator().getNextId(), aggregation.getSource(), newAssignments.build()), predicate), aggregations.build(), aggregation.getGroupingSets(), ImmutableList.of(), aggregation.getStep(), aggregation.getHashSymbol(), aggregation.getGroupIdSymbol()));
AggregationNode aggregationNode = (AggregationNode) node; if (groupId.isPresent() != aggregationNode.getGroupIdSymbol().isPresent()) { return NO_MATCH; if (!matches(groupingSets.getGroupingKeys(), aggregationNode.getGroupingKeys(), symbolAliases)) { return NO_MATCH; if (groupingSets.getGroupingSetCount() != aggregationNode.getGroupingSetCount()) { return NO_MATCH; if (!groupingSets.getGlobalGroupingSets().equals(aggregationNode.getGlobalGroupingSets())) { return NO_MATCH; List<Symbol> aggregationsWithMask = aggregationNode.getAggregations() .entrySet() .stream() if (step != aggregationNode.getStep()) { return NO_MATCH; if (!matches(preGroupedSymbols, aggregationNode.getPreGroupedSymbols(), symbolAliases)) { return NO_MATCH;
/**
 * Rewrites the source of the aggregation via {@code recurseToPartial} and, when task
 * concurrency is greater than one, stacks a local arbitrary-partitioned exchange, an
 * {@code INTERMEDIATE} aggregation and a local gathering exchange on top of it.
 *
 * @param aggregation the matched aggregation node
 * @param captures pattern captures (not used by this method)
 * @param context rule context providing the lookup, id allocator and session
 * @return the aggregation re-parented onto the rewritten source, or an empty result when
 *         {@code recurseToPartial} yields nothing
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context)
{
    Lookup lookup = context.getLookup();
    PlanNodeIdAllocator idAllocator = context.getIdAllocator();
    Session session = context.getSession();

    Optional<PlanNode> partialRewrite = recurseToPartial(lookup.resolve(aggregation.getSource()), lookup, idAllocator);
    if (!partialRewrite.isPresent()) {
        return Result.empty();
    }

    PlanNode newSource = partialRewrite.get();

    boolean parallelWithinTask = getTaskConcurrency(session) > 1;
    if (parallelWithinTask) {
        // Ids are allocated in this order: partitioned exchange, intermediate aggregation, gathering exchange.
        newSource = ExchangeNode.partitionedExchange(
                idAllocator.getNextId(),
                ExchangeNode.Scope.LOCAL,
                newSource,
                new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), newSource.getOutputSymbols()));
        newSource = new AggregationNode(
                idAllocator.getNextId(),
                newSource,
                inputsAsOutputs(aggregation.getAggregations()),
                aggregation.getGroupingSets(),
                aggregation.getPreGroupedSymbols(),
                AggregationNode.Step.INTERMEDIATE,
                aggregation.getHashSymbol(),
                aggregation.getGroupIdSymbol());
        newSource = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, newSource);
    }

    return Result.ofPlanNode(aggregation.replaceChildren(ImmutableList.of(newSource)));
}
/**
 * Estimates statistics for an aggregation node that has exactly one grouping set and a
 * {@code SINGLE} step; any other aggregation yields no estimate.
 *
 * @param node the aggregation node
 * @param statsProvider provider of the source node's statistics
 * @param lookup plan lookup (not used by this method)
 * @param session the session (not used by this method)
 * @param types the type provider (not used by this method)
 * @return the result of {@code groupBy} over the source statistics, grouping keys and
 *         aggregations, or {@link Optional#empty()} for unsupported nodes
 */
@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(AggregationNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
{
    boolean unsupported = node.getGroupingSetCount() != 1 || node.getStep() != SINGLE;
    if (unsupported) {
        return Optional.empty();
    }

    PlanNodeStatsEstimate estimate = groupBy(
            statsProvider.getStats(node.getSource()),
            node.getGroupingKeys(),
            node.getAggregations());
    return Optional.of(estimate);
}
for (Map.Entry<Symbol, Aggregation> entry : scalarAggregation.getAggregations().entrySet()) { FunctionCall call = entry.getValue().getCall(); Symbol symbol = entry.getKey(); return Optional.of(new AggregationNode( idAllocator.getNextId(), leftOuterJoin, aggregations.build(), singleGroupingSet(leftOuterJoin.getLeft().getOutputSymbols()), ImmutableList.of(), scalarAggregation.getStep(), scalarAggregation.getHashSymbol(), Optional.empty()));
/**
 * Treats an aggregation node as a distinct operator when it has no aggregations.
 *
 * @param node the aggregation node to inspect
 * @return {@code true} when the node's aggregation map is empty
 */
private static boolean isDistinctOperator(AggregationNode node)
{
    boolean hasNoAggregations = node.getAggregations().isEmpty();
    return hasNoAggregations;
}
}
/**
 * Wraps {@code node} in a {@code SINGLE}-step aggregation that has no aggregates and a single
 * grouping set made of all of the node's output symbols.
 *
 * @param node the plan node to wrap
 * @return the new aggregation node, with no pre-grouped symbols, hash symbol or group-id symbol
 */
private PlanNode distinct(PlanNode node)
{
    PlanNode distinctNode = new AggregationNode(
            idAllocator.getNextId(),
            node,
            ImmutableMap.of(),
            singleGroupingSet(node.getOutputSymbols()),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty());
    return distinctNode;
}
private PhysicalOperation planGlobalAggregation(int operatorId, AggregationNode node, PhysicalOperation source) { int outputChannel = 0; ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder(); List<AccumulatorFactory> accumulatorFactories = new ArrayList<>(); for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) { Symbol symbol = entry.getKey(); accumulatorFactories.add(buildAccumulatorFactory(source, node.getFunctions().get(symbol), entry.getValue(), node.getMasks().get(entry.getKey()), Optional.<Integer>empty(), node.getSampleWeight(), node.getConfidence())); outputMappings.put(symbol, outputChannel); // one aggregation per channel outputChannel++; } OperatorFactory operatorFactory = new AggregationOperatorFactory(operatorId, node.getId(), node.getStep(), accumulatorFactories); return new PhysicalOperation(operatorFactory, outputMappings.build(), source); }