private PhysicalOperation planGlobalAggregation(int operatorId, AggregationNode node, PhysicalOperation source) { int outputChannel = 0; ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder(); List<AccumulatorFactory> accumulatorFactories = new ArrayList<>(); for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) { Symbol symbol = entry.getKey(); accumulatorFactories.add(buildAccumulatorFactory(source, node.getFunctions().get(symbol), entry.getValue(), node.getMasks().get(entry.getKey()), Optional.<Integer>empty(), node.getSampleWeight(), node.getConfidence())); outputMappings.put(symbol, outputChannel); // one aggregation per channel outputChannel++; } OperatorFactory operatorFactory = new AggregationOperatorFactory(operatorId, node.getId(), node.getStep(), accumulatorFactories); return new PhysicalOperation(operatorFactory, outputMappings.build(), source); }
/**
 * Prints an aggregation node followed by one indented line per aggregation, then its children.
 *
 * <p>The header shows the step (omitted for {@code SINGLE}), the group-by keys (omitted when
 * empty), the sample weight (when present) and the output symbols. An aggregation with a mask
 * also prints the mask symbol.
 *
 * @param node the aggregation node to print
 * @param indent current indentation level
 * @return the result of visiting the children at {@code indent + 1}
 */
@Override
public Void visitAggregation(AggregationNode node, Integer indent)
{
    String type = (node.getStep() == AggregationNode.Step.SINGLE) ? "" : format("(%s)", node.getStep().toString());
    String key = node.getGroupBy().isEmpty() ? "" : node.getGroupBy().toString();
    String sampleWeight = node.getSampleWeight().isPresent()
            ? format("[sampleWeight = %s]", node.getSampleWeight().get())
            : "";

    print(indent, "- Aggregate%s%s%s => [%s]", type, key, sampleWeight, formatOutputs(node.getOutputSymbols()));

    Map<Symbol, Symbol> masks = node.getMasks();
    for (Map.Entry<Symbol, FunctionCall> aggregation : node.getAggregations().entrySet()) {
        Symbol output = aggregation.getKey();
        if (!masks.containsKey(output)) {
            print(indent + 2, "%s := %s", output, aggregation.getValue());
            continue;
        }
        print(indent + 2, "%s := %s (mask = %s)", output, aggregation.getValue(), masks.get(output));
    }

    return processChildren(node, indent + 1);
}
@Override public Void visitAggregation(AggregationNode node, Void context) { PlanNode source = node.getSource(); source.accept(this, context); // visit child verifyUniqueId(node); Set<Symbol> inputs = ImmutableSet.copyOf(source.getOutputSymbols()); checkDependencies(inputs, node.getGroupBy(), "Invalid node. Group by symbols (%s) not in source plan output (%s)", node.getGroupBy(), node.getSource().getOutputSymbols()); if (node.getSampleWeight().isPresent()) { checkArgument(inputs.contains(node.getSampleWeight().get()), "Invalid node. Sample weight symbol (%s) is not in source plan output (%s)", node.getSampleWeight().get(), node.getSource().getOutputSymbols()); } for (FunctionCall call : node.getAggregations().values()) { Set<Symbol> dependencies = DependencyExtractor.extractUnique(call); checkDependencies(inputs, dependencies, "Invalid node. Aggregation dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols()); } return null; }
if (node.getSampleWeight().isPresent()) { expectedInputs.add(node.getSampleWeight().get()); masks.build(), node.getStep(), node.getSampleWeight(), node.getConfidence(), node.getHashSymbol());
/**
 * Rewrites an aggregation so that all symbols it references or produces are canonicalized.
 *
 * <p>The source is rewritten first. Then the aggregation output symbols, the calls, the mask
 * entries, the group-by keys (canonicalized and de-duplicated), the sample weight and the hash
 * symbol are mapped through {@code canonicalize}. The node id, step and confidence are kept.
 *
 * @param node the aggregation to rewrite
 * @param context rewrite context used to rewrite the source
 * @return a new aggregation node over the rewritten source
 */
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
{
    PlanNode rewrittenSource = context.rewrite(node.getSource());

    ImmutableMap.Builder<Symbol, Signature> canonicalFunctions = ImmutableMap.builder();
    ImmutableMap.Builder<Symbol, FunctionCall> canonicalCalls = ImmutableMap.builder();
    node.getAggregations().forEach((output, call) -> {
        Symbol canonicalOutput = canonicalize(output);
        canonicalCalls.put(canonicalOutput, (FunctionCall) canonicalize(call));
        canonicalFunctions.put(canonicalOutput, node.getFunctions().get(output));
    });

    ImmutableMap.Builder<Symbol, Symbol> canonicalMasks = ImmutableMap.builder();
    node.getMasks().forEach((output, mask) -> canonicalMasks.put(canonicalize(output), canonicalize(mask)));

    List<Symbol> canonicalGroupBy = canonicalizeAndDistinct(node.getGroupBy());

    return new AggregationNode(
            node.getId(),
            rewrittenSource,
            canonicalGroupBy,
            canonicalCalls.build(),
            canonicalFunctions.build(),
            canonicalMasks.build(),
            node.getStep(),
            canonicalize(node.getSampleWeight()),
            node.getConfidence(),
            canonicalize(node.getHashSymbol()));
}
@Override public PlanNode visitAggregation(AggregationNode node, RewriteContext<Optional<Symbol>> context) { // optimize if and only if // all aggregation functions have a single common distinct mask symbol // AND all aggregation functions have mask Set<Symbol> masks = ImmutableSet.copyOf(node.getMasks().values()); if (masks.size() != 1 || node.getMasks().size() != node.getAggregations().size()) { return context.defaultRewrite(node, Optional.empty()); } PlanNode source = context.rewrite(node.getSource(), Optional.of(Iterables.getOnlyElement(masks))); Map<Symbol, FunctionCall> aggregations = ImmutableMap.copyOf(Maps.transformValues(node.getAggregations(), call -> new FunctionCall(call.getName(), call.getWindow(), false, call.getArguments()))); return new AggregationNode(idAllocator.getNextId(), source, node.getGroupBy(), aggregations, node.getFunctions(), Collections.emptyMap(), node.getStep(), node.getSampleWeight(), node.getConfidence(), node.getHashSymbol()); }
@Override public PlanNode visitAggregation(AggregationNode node, RewriteContext<LimitContext> context) { LimitContext limit = context.get(); if (limit != null && node.getAggregations().isEmpty() && node.getOutputSymbols().size() == node.getGroupBy().size() && node.getOutputSymbols().containsAll(node.getGroupBy())) { checkArgument(!node.getSampleWeight().isPresent(), "DISTINCT aggregation has sample weight symbol"); PlanNode rewrittenSource = context.rewrite(node.getSource()); return new DistinctLimitNode(idAllocator.getNextId(), rewrittenSource, limit.getCount(), Optional.empty()); } PlanNode rewrittenNode = context.defaultRewrite(node); if (limit != null) { // Drop in a LimitNode b/c limits cannot be pushed through aggregations rewrittenNode = new LimitNode(idAllocator.getNextId(), rewrittenNode, limit.getCount()); } return rewrittenNode; }
initialMask, PARTIAL, node.getSampleWeight(), node.getConfidence(), node.getHashSymbol());
/**
 * Returns a copy of {@code node} whose source is replaced by the single new child.
 *
 * <p>All other properties are carried over unchanged: id, group-by, aggregations, functions,
 * masks, step, sample weight, confidence and hash symbol.
 *
 * @param node the aggregation to copy
 * @param newChildren replacement children; must contain exactly one element
 * @return the copied aggregation node
 */
@Override
public PlanNode visitAggregation(AggregationNode node, List<PlanNode> newChildren)
{
    // getOnlyElement rejects any child count other than one
    PlanNode newSource = Iterables.getOnlyElement(newChildren);
    return new AggregationNode(
            node.getId(),
            newSource,
            node.getGroupBy(),
            node.getAggregations(),
            node.getFunctions(),
            node.getMasks(),
            node.getStep(),
            node.getSampleWeight(),
            node.getConfidence(),
            node.getHashSymbol());
}
/**
 * Rewrites count-of-constant aggregations into zero-argument {@code count} calls.
 *
 * <p>This applies only when the rewritten source is a {@code ProjectNode}. For each aggregation
 * that {@code isCountConstant} accepts, the call is replaced by one with the same name and
 * distinct flag and no arguments. Its signature becomes {@code count} returning BIGINT. All
 * other aggregations and node properties are kept unchanged.
 *
 * @param node the aggregation to rewrite
 * @param context rewrite context used to rewrite the source
 * @return a new aggregation node over the rewritten source
 */
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
{
    PlanNode source = context.rewrite(node.getSource());

    // mutable copies, preserving iteration order, that receive the replacements
    Map<Symbol, FunctionCall> rewrittenCalls = new LinkedHashMap<>(node.getAggregations());
    Map<Symbol, Signature> rewrittenSignatures = new LinkedHashMap<>(node.getFunctions());

    if (source instanceof ProjectNode) {
        ProjectNode project = (ProjectNode) source;
        for (Map.Entry<Symbol, FunctionCall> aggregation : node.getAggregations().entrySet()) {
            Symbol output = aggregation.getKey();
            FunctionCall call = aggregation.getValue();
            if (!isCountConstant(project, call, node.getFunctions().get(output))) {
                continue;
            }
            rewrittenCalls.put(output, new FunctionCall(call.getName(), call.isDistinct(), ImmutableList.<Expression>of()));
            rewrittenSignatures.put(output, new Signature("count", AGGREGATE, StandardTypes.BIGINT));
        }
    }

    return new AggregationNode(
            node.getId(),
            source,
            node.getGroupBy(),
            rewrittenCalls,
            rewrittenSignatures,
            node.getMasks(),
            node.getStep(),
            node.getSampleWeight(),
            node.getConfidence(),
            node.getHashSymbol());
}
intermediateMask, PARTIAL, node.getSampleWeight(), node.getConfidence(), node.getHashSymbol()),
node.getMasks(), node.getStep(), node.getSampleWeight(), node.getConfidence(), node.getHashSymbol());
node.getMasks(), node.getStep(), node.getSampleWeight(), node.getConfidence(), Optional.empty()); node.getMasks(), node.getStep(), node.getSampleWeight(), node.getConfidence(), Optional.of(hashSymbol));
@Override public PlanNode visitAggregation(AggregationNode node, RewriteContext<Boolean> context) { boolean distinct = isDistinctOperator(node); PlanNode rewrittenNode = context.rewrite(node.getSource(), distinct); if (context.get() && distinct) { // Assumes underlying node has same output symbols as this distinct node return rewrittenNode; } return new AggregationNode( node.getId(), rewrittenNode, node.getGroupBy(), node.getAggregations(), node.getFunctions(), node.getMasks(), node.getStep(), node.getSampleWeight(), node.getConfidence(), node.getHashSymbol()); }