/**
 * Rebuilds a lambda with its argument list untouched and its body passed through the tree rewriter.
 */
@Override
public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), context);
    return new LambdaExpression(node.getArguments(), rewrittenBody);
}
}, expression);
/**
 * Formats a lambda as {@code (arg1, arg2) -> body}. Arguments are rendered with their
 * {@code toString()}; the body is formatted recursively with the same context.
 */
@Override
protected String visitLambdaExpression(LambdaExpression node, Void context)
{
    String argumentList = Joiner.on(", ").join(node.getArguments());
    String renderedBody = process(node.getBody(), context);
    return "(" + argumentList + ") -> " + renderedBody;
}
/**
 * Desugars {@code try(x)} into {@code $internal$try(() -> x)}: the rewritten inner expression
 * becomes the body of a zero-argument lambda passed to the internal try function.
 */
@Override
public Expression rewriteTryExpression(TryExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    Expression rewrittenInner = treeRewriter.rewrite(node.getInnerExpression(), context);
    Expression thunk = new LambdaExpression(ImmutableList.of(), rewrittenInner);
    return new FunctionCall(QualifiedName.of("$internal$try"), ImmutableList.of(thunk));
}
}
/**
 * The result for a lambda is whatever this visitor yields for the lambda's body.
 */
@Override
protected Boolean visitLambdaExpression(LambdaExpression node, Void context)
{
    return this.process(node.getBody(), context);
}
@Override public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) { for (LambdaArgumentDeclaration argument : node.getArguments()) { String argumentName = argument.getName().getValue(); // Symbol names are unique. As a result, a symbol should never be excluded multiple times. checkArgument(!excludedNames.contains(argumentName)); excludedNames.add(argumentName); } Expression result = treeRewriter.defaultRewrite(node, context); for (LambdaArgumentDeclaration argument : node.getArguments()) { excludedNames.remove(argument.getName().getValue()); } return result; } }
/**
 * Visits the lambda body with an extended set of names in scope. The set contains the
 * names already in scope, followed by the argument names declared by this lambda.
 */
@Override
protected Void visitLambdaExpression(LambdaExpression node, Set<String> lambdaArgumentNames)
{
    ImmutableSet.Builder<String> namesInScope = ImmutableSet.builder();
    namesInScope.addAll(lambdaArgumentNames);
    for (LambdaArgumentDeclaration declaration : node.getArguments()) {
        namesInScope.add(declaration.getName().getValue());
    }
    return process(node.getBody(), namesInScope.build());
}
}
/**
 * Builds a LambdaExpression from the parse tree. Each parsed identifier becomes one
 * argument declaration, and the expression child becomes the body.
 */
@Override
public Node visitLambda(SqlBaseParser.LambdaContext context)
{
    List<LambdaArgumentDeclaration> parameters = visit(context.identifier(), Identifier.class)
            .stream()
            .map(identifier -> new LambdaArgumentDeclaration(identifier))
            .collect(toList());
    Expression lambdaBody = (Expression) visit(context.expression());
    return new LambdaExpression(getLocation(context), parameters, lambdaBody);
}
/**
 * Translates a lambda into a LambdaDefinitionExpression. The translation combines the
 * translated body, the argument types taken from the lambda's type, and the declared
 * argument names.
 */
@Override
protected RowExpression visitLambdaExpression(LambdaExpression node, Void context)
{
    RowExpression translatedBody = process(node.getBody(), context);

    // All but the last type parameter are taken as argument types
    // (the last one is presumably the return type of the function type).
    List<Type> functionTypeParameters = getType(node).getTypeParameters();
    int argumentCount = functionTypeParameters.size() - 1;
    List<Type> argumentTypes = functionTypeParameters.subList(0, argumentCount);

    List<String> argumentNames = node.getArguments().stream()
            .map(argument -> argument.getName().getValue())
            .collect(toImmutableList());

    return new LambdaDefinitionExpression(argumentTypes, argumentNames, translatedBody);
}
/**
 * Renames every lambda argument to the name of the symbol recorded for its declaration in
 * {@code lambdaDeclarationToSymbolMap}, and rewrites the body.
 *
 * @throws IllegalStateException if the analysis recorded a coercion for this lambda
 */
@Override
public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    checkState(analysis.getCoercion(node) == null, "cannot coerce a lambda expression");

    ImmutableList.Builder<LambdaArgumentDeclaration> renamedArguments = ImmutableList.builder();
    for (LambdaArgumentDeclaration declaration : node.getArguments()) {
        String symbolName = lambdaDeclarationToSymbolMap.get(NodeRef.of(declaration)).getName();
        renamedArguments.add(new LambdaArgumentDeclaration(new Identifier(symbolName)));
    }

    // context is Void, so it is always null here, exactly like the literal null it replaces.
    return new LambdaExpression(renamedArguments.build(), treeRewriter.rewrite(node.getBody(), context));
}
@Test
public void testRewriteBasicLambda()
{
    // Symbol "a" (BIGINT) is referenced from inside the lambda "x -> a + x".
    Map<Symbol, Type> symbols = ImmutableMap.of(new Symbol("a"), BigintType.BIGINT);
    SymbolAllocator allocator = new SymbolAllocator(symbols);

    // Expected: "a" is bound into a fresh leading argument "a_0" and the body refers to it.
    BindExpression expected = new BindExpression(
            ImmutableList.of(expression("a")),
            new LambdaExpression(
                    Stream.of("a_0", "x")
                            .map(name -> new LambdaArgumentDeclaration(new Identifier(name)))
                            .collect(toList()),
                    expression("a_0 + x")));

    assertEquals(rewrite(expression("x -> a + x"), allocator.getTypes(), allocator), expected);
}
}
@Override protected Object visitLambdaExpression(LambdaExpression node, Object context) { if (optimize) { // TODO: enable optimization related to lambda expression // A mechanism to convert function type back into lambda expression need to exist to enable optimization return node; } Expression body = node.getBody(); List<String> argumentNames = node.getArguments().stream() .map(LambdaArgumentDeclaration::getName) .map(Identifier::getValue) .collect(toImmutableList()); FunctionType functionType = (FunctionType) expressionTypes.get(NodeRef.<Expression>of(node)); checkArgument(argumentNames.size() == functionType.getArgumentTypes().size()); return generateVarArgsToMapAdapter( Primitives.wrap(functionType.getReturnType().getJavaType()), functionType.getArgumentTypes().stream() .map(Type::getJavaType) .map(Primitives::wrap) .collect(toImmutableList()), argumentNames, map -> process(body, new LambdaSymbolResolver(map))); }
// NOTE(review): incomplete excerpt. The stream pipeline assigning lambdaArguments has no terminal
// operation (and no mapping to Symbol), and newLambdaArguments, extraSymbol, referencedSymbols and
// symbolMapping are declared outside this view. It does not compile as shown.
// Visible steps:
// - rewrite the lambda body with a context that carries referencedSymbols;
// - collect the declared argument names;
// - build the new argument list with extraSymbol's name first, followed by the original arguments;
// - inline symbolMapping into the rewritten body.
Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), context.withReferencedSymbols(referencedSymbols)); List<Symbol> lambdaArguments = node.getArguments().stream() .map(LambdaArgumentDeclaration::getName) .map(Identifier::getValue) newLambdaArguments.add(new LambdaArgumentDeclaration(new Identifier(extraSymbol.getName()))); newLambdaArguments.addAll(node.getArguments()); Expression rewrittenExpression = new LambdaExpression(newLambdaArguments.build(), inlineSymbols(symbolMapping, rewrittenBody));
@Test public void testTryExpressionDesugaringRewriter() { // 1 + try(2) Expression before = new ArithmeticBinaryExpression( ADD, new DecimalLiteral("1"), new TryExpression(new DecimalLiteral("2"))); // 1 + try_function(() -> 2) Expression after = new ArithmeticBinaryExpression( ADD, new DecimalLiteral("1"), new FunctionCall( QualifiedName.of("$internal$try"), ImmutableList.of(new LambdaExpression(ImmutableList.of(), new DecimalLiteral("2"))))); assertEquals(DesugarTryExpressionRewriter.rewrite(before), after); } }
// NOTE(review): incomplete excerpt. The for-loop body is not closed, lambdaArgumentSymbolTypes is never
// populated in the visible text, and the trailing "sqlParser, ..., NOOP))" is the tail of a call whose
// head is missing. It does not compile as shown.
// Visible steps:
// - verify that the lambda declares as many arguments as its FunctionType has argument types;
// - pair the j-th declared argument with the j-th argument type, keyed by NodeRef of the declaration.
verify(lambdaExpression.getArguments().size() == functionType.getArgumentTypes().size()); Map<NodeRef<Expression>, Type> lambdaArgumentExpressionTypes = new HashMap<>(); Map<Symbol, Type> lambdaArgumentSymbolTypes = new HashMap<>(); for (int j = 0; j < lambdaExpression.getArguments().size(); j++) { LambdaArgumentDeclaration argument = lambdaExpression.getArguments().get(j); Type type = functionType.getArgumentTypes().get(j); lambdaArgumentExpressionTypes.put(NodeRef.of(argument), type); sqlParser, TypeProvider.copyOf(lambdaArgumentSymbolTypes), lambdaExpression.getBody(), emptyList(), NOOP))
@Test
public void testLambda()
{
    // Zero-argument lambda.
    assertExpression("() -> x", new LambdaExpression(ImmutableList.of(), new Identifier("x")));

    // Single argument, no parentheses.
    LambdaExpression sinOfX = new LambdaExpression(
            ImmutableList.of(new LambdaArgumentDeclaration(identifier("x"))),
            new FunctionCall(QualifiedName.of("sin"), ImmutableList.of(new Identifier("x"))));
    assertExpression("x -> sin(x)", sinOfX);

    // Two arguments, parenthesized.
    LambdaExpression modOfXY = new LambdaExpression(
            ImmutableList.of(
                    new LambdaArgumentDeclaration(identifier("x")),
                    new LambdaArgumentDeclaration(identifier("y"))),
            new FunctionCall(
                    QualifiedName.of("mod"),
                    ImmutableList.of(new Identifier("x"), new Identifier("y"))));
    assertExpression("(x, y) -> mod(x, y)", modOfXY);
}
// NOTE(review): incomplete excerpt. The method's opening brace, the closing brace of the if-block, and the
// declarations of lambdaScope, types and fieldToLambdaArgumentDeclaration are missing. It does not
// compile as shown.
// Visible steps:
// - reject aggregate, window and grouping functions inside the lambda body;
// - throw STANDALONE_LAMBDA when the surrounding context is not expecting a lambda;
// - analyze the body in a lambda context and record FunctionType(types, returnType) as this node's type.
@Override protected Type visitLambdaExpression(LambdaExpression node, StackableAstVisitorContext<Context> context) verifyNoAggregateWindowOrGroupingFunctions(functionRegistry, node.getBody(), "Lambda expression"); if (!context.getContext().isExpectingLambda()) { throw new SemanticException(STANDALONE_LAMBDA, node, "Lambda expression should always be used inside a function"); List<LambdaArgumentDeclaration> lambdaArguments = node.getArguments(); Type returnType = process(node.getBody(), new StackableAstVisitorContext<>(Context.inLambda(lambdaScope, fieldToLambdaArgumentDeclaration.build()))); FunctionType functionType = new FunctionType(types, returnType); return setExpressionType(node, functionType);
/**
 * Builds a LambdaExpression from the parse tree. Argument names are the raw source text
 * of each identifier, and the expression child becomes the body.
 */
@Override
public Node visitLambda(SqlBaseParser.LambdaContext context)
{
    List<String> parameterNames = context.identifier().stream()
            .map(identifier -> identifier.getText())
            .collect(toList());
    Expression lambdaBody = (Expression) visit(context.expression());
    return new LambdaExpression(parameterNames, lambdaBody);
}
/**
 * Formats a lambda as {@code (arg1, arg2) -> body}. Arguments are rendered with their
 * {@code toString()}; the body is formatted recursively with the same context.
 */
@Override
protected String visitLambdaExpression(LambdaExpression node, Void context)
{
    String argumentList = Joiner.on(", ").join(node.getArguments());
    String renderedBody = process(node.getBody(), context);
    return "(" + argumentList + ") -> " + renderedBody;
}
/**
 * Builds a LambdaExpression from the parse tree. Each parsed identifier becomes one
 * argument declaration, and the expression child becomes the body.
 */
@Override
public Node visitLambda(SqlBaseParser.LambdaContext context)
{
    List<LambdaArgumentDeclaration> parameters = visit(context.identifier(), Identifier.class)
            .stream()
            .map(identifier -> new LambdaArgumentDeclaration(identifier))
            .collect(toList());
    Expression lambdaBody = (Expression) visit(context.expression());
    return new LambdaExpression(getLocation(context), parameters, lambdaBody);
}
/**
 * Formats a lambda as {@code (arg1, arg2) -> body}. The body is formatted recursively with
 * {@code unmangleNames}; the arguments are rendered with their {@code toString()}.
 * NOTE(review): unmangleNames is not applied to the argument names — confirm this is intended.
 */
@Override
protected String visitLambdaExpression(LambdaExpression node, Boolean unmangleNames)
{
    String argumentList = Joiner.on(", ").join(node.getArguments());
    String renderedBody = process(node.getBody(), unmangleNames);
    return "(" + argumentList + ") -> " + renderedBody;
}