/** * Rewrite assignments so that inputs are in terms of the output symbols. * <p> * Example: * 'a' := sum('b') => 'a' := sum('a') * 'a' := count(*) => 'a' := count('a') */ private static Map<Symbol, AggregationNode.Aggregation> outputsAsInputs(Map<Symbol, AggregationNode.Aggregation> assignments) { ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> builder = ImmutableMap.builder(); for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : assignments.entrySet()) { Symbol output = entry.getKey(); AggregationNode.Aggregation aggregation = entry.getValue(); checkState(!aggregation.getCall().getOrderBy().isPresent(), "Intermediate aggregation does not support ORDER BY"); builder.put( output, new AggregationNode.Aggregation( new FunctionCall(QualifiedName.of(aggregation.getSignature().getName()), ImmutableList.of(output.toSymbolReference())), aggregation.getSignature(), Optional.empty())); // No mask for INTERMEDIATE } return builder.build(); }
/**
 * Prunes aggregations whose outputs are not referenced by the parent, then asks the source
 * to produce only the symbols the remaining aggregations (and the grouping) actually read.
 */
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Set<Symbol>> context)
{
    // Grouping keys, plus the hash symbol when present, are always required from the source.
    ImmutableSet.Builder<Symbol> requiredInputs = ImmutableSet.builder();
    requiredInputs.addAll(node.getGroupingKeys());
    node.getHashSymbol().ifPresent(requiredInputs::add);

    // Retain only aggregations whose output is referenced; record what each retained one reads
    // (symbols in its call, plus its mask symbol if it has one).
    ImmutableMap.Builder<Symbol, Aggregation> retained = ImmutableMap.builder();
    for (Map.Entry<Symbol, Aggregation> assignment : node.getAggregations().entrySet()) {
        Symbol output = assignment.getKey();
        if (!context.get().contains(output)) {
            continue;
        }
        Aggregation kept = assignment.getValue();
        requiredInputs.addAll(SymbolsExtractor.extractUnique(kept.getCall()));
        kept.getMask().ifPresent(requiredInputs::add);
        retained.put(output, kept);
    }

    PlanNode prunedSource = context.rewrite(node.getSource(), requiredInputs.build());

    // NOTE(review): the fifth argument is always an empty list here — confirm which
    // AggregationNode property it populates before relying on it downstream.
    return new AggregationNode(
            node.getId(),
            prunedSource,
            retained.build(),
            node.getGroupingSets(),
            ImmutableList.of(),
            node.getStep(),
            node.getHashSymbol(),
            node.getGroupIdSymbol());
}
aggregationsBuilder.put(newSymbol, new Aggregation((FunctionCall) rewritten, analysis.getFunctionSignature(aggregate), Optional.empty()));
aggregationsBuilder.put(newSymbol, new Aggregation((FunctionCall) rewritten, analysis.getFunctionSignature(aggregate), Optional.empty()));
.getAggregateFunctionImplementation(aggregation.getSignature()); for (Expression argument : aggregation.getCall().getArguments()) { if (!(argument instanceof LambdaExpression)) { Symbol argumentSymbol = Symbol.from(argument); List<LambdaExpression> lambdaExpressions = aggregation.getCall().getArguments().stream() .filter(LambdaExpression.class::isInstance) .map(LambdaExpression.class::cast) .collect(toImmutableList()); if (!lambdaExpressions.isEmpty()) { List<FunctionType> functionTypes = aggregation.getSignature().getArgumentTypes().stream() .filter(typeSignature -> typeSignature.getBase().equals(FunctionType.NAME)) .map(typeSignature -> (FunctionType) (metadata.getTypeManager().getType(typeSignature))) Optional<Integer> maskChannel = aggregation.getMask().map(value -> source.getLayout().get(value)); List<SortOrder> sortOrders = ImmutableList.of(); List<Symbol> sortKeys = ImmutableList.of(); if (aggregation.getCall().getOrderBy().isPresent()) { OrderBy orderBy = aggregation.getCall().getOrderBy().get(); sortOrders, pagesIndexFactory, aggregation.getCall().isDistinct(), joinCompiler, lambdaProviders,
.getAggregateFunctionImplementation(aggregation.getSignature()); for (Expression argument : aggregation.getCall().getArguments()) { if (!(argument instanceof LambdaExpression)) { Symbol argumentSymbol = Symbol.from(argument); List<LambdaExpression> lambdaExpressions = aggregation.getCall().getArguments().stream() .filter(LambdaExpression.class::isInstance) .map(LambdaExpression.class::cast) .collect(toImmutableList()); if (!lambdaExpressions.isEmpty()) { List<FunctionType> functionTypes = aggregation.getSignature().getArgumentTypes().stream() .filter(typeSignature -> typeSignature.getBase().equals(FunctionType.NAME)) .map(typeSignature -> (FunctionType) (metadata.getTypeManager().getType(typeSignature))) Optional<Integer> maskChannel = aggregation.getMask().map(value -> source.getLayout().get(value)); List<SortOrder> sortOrders = ImmutableList.of(); List<Symbol> sortKeys = ImmutableList.of(); if (aggregation.getCall().getOrderBy().isPresent()) { OrderBy orderBy = aggregation.getCall().getOrderBy().get(); sortOrders, pagesIndexFactory, aggregation.getCall().isDistinct(), joinCompiler, lambdaProviders,
FunctionCall functionCall = entry.getValue().getCall(); if (entry.getValue().getMask().isPresent()) { aggregations.put(entry.getKey(), new Aggregation( new FunctionCall( functionCall.getName(), false, ImmutableList.of(aggregateInfo.getNewDistinctAggregateSymbol().toSymbolReference())), entry.getValue().getSignature(), Optional.empty())); String signatureName = entry.getValue().getSignature().getName(); Aggregation aggregation = new Aggregation( new FunctionCall(functionName, functionCall.getWindow(), false, ImmutableList.of(argument.toSymbolReference())), getFunctionSignature(functionName, argument),
FunctionCall functionCall = entry.getValue().getCall(); if (entry.getValue().getMask().isPresent()) { aggregations.put(entry.getKey(), new Aggregation( new FunctionCall( functionCall.getName(), false, ImmutableList.of(aggregateInfo.getNewDistinctAggregateSymbol().toSymbolReference())), entry.getValue().getSignature(), Optional.empty())); String signatureName = entry.getValue().getSignature().getName(); Aggregation aggregation = new Aggregation( new FunctionCall(functionName, functionCall.getWindow(), false, ImmutableList.of(argument.toSymbolReference())), getFunctionSignature(functionName, argument),
FunctionCall call = aggregation.getCall(); if (call.isDistinct() && !call.getFilter().isPresent() && !aggregation.getMask().isPresent()) { Set<Symbol> inputs = call.getArguments().stream() .map(Symbol::from) new Aggregation( new FunctionCall( call.getName(), false, call.getArguments()), aggregation.getSignature(), Optional.of(marker)));
FunctionCall call = aggregation.getCall(); if (call.isDistinct() && !call.getFilter().isPresent() && !aggregation.getMask().isPresent()) { Set<Symbol> inputs = call.getArguments().stream() .map(Symbol::from) new Aggregation( new FunctionCall( call.getName(), false, call.getArguments()), aggregation.getSignature(), Optional.of(marker)));
FunctionCall call = entry.getValue().getCall(); Optional<Symbol> mask = entry.getValue().getMask(); aggregations.put(output, new Aggregation( new FunctionCall(call.getName(), call.getWindow(), Optional.empty(), call.getOrderBy(), call.isDistinct(), call.getArguments()), entry.getValue().getSignature(), mask));
subqueryPlan, ImmutableMap.of( minValue, new Aggregation( new FunctionCall(MIN, outputColumnReferences), functionRegistry.resolveFunction(MIN, fromTypeSignatures(outputColumnTypeSignature)), Optional.empty()), maxValue, new Aggregation( new FunctionCall(MAX, outputColumnReferences), functionRegistry.resolveFunction(MAX, fromTypeSignatures(outputColumnTypeSignature)), Optional.empty()), countAllValue, new Aggregation( new FunctionCall(COUNT, emptyList()), functionRegistry.resolveFunction(COUNT, emptyList()), Optional.empty()), countNonNullValue, new Aggregation( new FunctionCall(COUNT, outputColumnReferences), functionRegistry.resolveFunction(COUNT, fromTypeSignatures(outputColumnTypeSignature)),
for (Map.Entry<Symbol, Aggregation> entry : node.getAggregations().entrySet()) { Aggregation aggregation = entry.getValue(); FunctionCall call = aggregation.getCall(); QualifiedName name = call.getName(); if (name.toString().equals(NAME) && call.getArguments().size() == 1) { new Aggregation( new FunctionCall(name, ImmutableList.of(envelopeSymbol.toSymbolReference(), partitionCountSymbol.toSymbolReference())), INTERNAL_SIGNATURE, aggregation.getMask()));
AggregationNode.Aggregation overNullAggregation = new AggregationNode.Aggregation( (FunctionCall) inlineSymbols(sourcesSymbolMapping, aggregation.getCall()), aggregation.getSignature(), aggregation.getMask().map(x -> Symbol.from(sourcesSymbolMapping.get(x)))); Symbol overNullSymbol = symbolAllocator.newSymbol(overNullAggregation.getCall(), symbolAllocator.getTypes().get(aggregationSymbol)); aggregationsOverNullBuilder.put(overNullSymbol, overNullAggregation); aggregationsSymbolMappingBuilder.put(aggregationSymbol, overNullSymbol);
for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) { AggregationNode.Aggregation originalAggregation = entry.getValue(); Signature signature = originalAggregation.getSignature(); InternalAggregationFunction function = functionRegistry.getAggregateFunctionImplementation(signature); Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(signature.getName(), function.getIntermediateType()); checkState(!originalAggregation.getCall().getOrderBy().isPresent(), "Aggregate with ORDER BY does not support partial aggregation"); intermediateAggregation.put(intermediateSymbol, new AggregationNode.Aggregation(originalAggregation.getCall(), signature, originalAggregation.getMask())); new AggregationNode.Aggregation( new FunctionCall( QualifiedName.of(signature.getName()), ImmutableList.<Expression>builder() .add(intermediateSymbol.toSymbolReference()) .addAll(originalAggregation.getCall().getArguments().stream() .filter(LambdaExpression.class::isInstance) .collect(toImmutableList()))
subqueryPlan, ImmutableMap.of( minValue, new Aggregation( new FunctionCall(MIN, outputColumnReferences), functionRegistry.resolveFunction(MIN, fromTypeSignatures(outputColumnTypeSignature)), Optional.empty()), maxValue, new Aggregation( new FunctionCall(MAX, outputColumnReferences), functionRegistry.resolveFunction(MAX, fromTypeSignatures(outputColumnTypeSignature)), Optional.empty()), countAllValue, new Aggregation( new FunctionCall(COUNT, emptyList()), functionRegistry.resolveFunction(COUNT, emptyList()), Optional.empty()), countNonNullValue, new Aggregation( new FunctionCall(COUNT, outputColumnReferences), functionRegistry.resolveFunction(COUNT, fromTypeSignatures(outputColumnTypeSignature)),
AggregationNode.Aggregation overNullAggregation = new AggregationNode.Aggregation( (FunctionCall) inlineSymbols(sourcesSymbolMapping, aggregation.getCall()), aggregation.getSignature(), aggregation.getMask().map(x -> Symbol.from(sourcesSymbolMapping.get(x)))); Symbol overNullSymbol = symbolAllocator.newSymbol(overNullAggregation.getCall(), symbolAllocator.getTypes().get(aggregationSymbol)); aggregationsOverNullBuilder.put(overNullSymbol, overNullAggregation); aggregationsSymbolMappingBuilder.put(aggregationSymbol, overNullSymbol);
FunctionCall call = entry.getValue().getCall(); Optional<Symbol> mask = entry.getValue().getMask(); aggregations.put(output, new Aggregation( new FunctionCall(call.getName(), call.getWindow(), Optional.empty(), call.getOrderBy(), call.isDistinct(), call.getArguments()), entry.getValue().getSignature(), mask));