/**
 * Rewrites a lambda by rewriting only its body; the argument declarations are reused as-is.
 */
@Override
public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    Expression newBody = treeRewriter.rewrite(node.getBody(), context);
    return new LambdaExpression(node.getArguments(), newBody);
}
}, expression);
/**
 * Formats a lambda as {@code (arg1, arg2) -> body}.
 * Each argument is rendered with its own {@code toString()}; the body is formatted recursively.
 */
@Override
protected String visitLambdaExpression(LambdaExpression node, Void context)
{
    String argumentList = Joiner.on(", ").join(node.getArguments());
    return "(" + argumentList + ") -> " + process(node.getBody(), context);
}
/**
 * Desugars a TryExpression into a call to {@code $internal$try}.
 * The call's single argument is a zero-argument lambda that wraps the rewritten inner expression.
 */
@Override
public Expression rewriteTryExpression(TryExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    Expression rewrittenInner = treeRewriter.rewrite(node.getInnerExpression(), context);
    Expression thunk = new LambdaExpression(ImmutableList.of(), rewrittenInner);
    return new FunctionCall(QualifiedName.of("$internal$try"), ImmutableList.of(thunk));
}
}
/**
 * A lambda's result is exactly the result of visiting its body; the argument declarations are not inspected.
 */
@Override
protected Boolean visitLambdaExpression(LambdaExpression node, Void context)
{
    Boolean bodyResult = process(node.getBody(), context);
    return bodyResult;
}
@Override public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) { for (LambdaArgumentDeclaration argument : node.getArguments()) { String argumentName = argument.getName().getValue(); // Symbol names are unique. As a result, a symbol should never be excluded multiple times. checkArgument(!excludedNames.contains(argumentName)); excludedNames.add(argumentName); } Expression result = treeRewriter.defaultRewrite(node, context); for (LambdaArgumentDeclaration argument : node.getArguments()) { excludedNames.remove(argument.getName().getValue()); } return result; } }
/**
 * Formats a lambda as {@code (arg1, arg2) -> body}.
 * Each argument is rendered with its own {@code toString()}; the body is formatted recursively.
 */
@Override
protected String visitLambdaExpression(LambdaExpression node, Void context)
{
    String argumentList = Joiner.on(", ").join(node.getArguments());
    return "(" + argumentList + ") -> " + process(node.getBody(), context);
}
/**
 * Desugars a TryExpression into a call to {@code $internal$try}.
 * The call's single argument is a zero-argument lambda that wraps the rewritten inner expression.
 */
@Override
public Expression rewriteTryExpression(TryExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    Expression rewrittenInner = treeRewriter.rewrite(node.getInnerExpression(), context);
    Expression thunk = new LambdaExpression(ImmutableList.of(), rewrittenInner);
    return new FunctionCall(QualifiedName.of("$internal$try"), ImmutableList.of(thunk));
}
}
/**
 * A lambda's result is exactly the result of visiting its body; the argument declarations are not inspected.
 */
@Override
protected Boolean visitLambdaExpression(LambdaExpression node, Void context)
{
    Boolean bodyResult = process(node.getBody(), context);
    return bodyResult;
}
@Override public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) { for (LambdaArgumentDeclaration argument : node.getArguments()) { String argumentName = argument.getName().getValue(); // Symbol names are unique. As a result, a symbol should never be excluded multiple times. checkArgument(!excludedNames.contains(argumentName)); excludedNames.add(argumentName); } Expression result = treeRewriter.defaultRewrite(node, context); for (LambdaArgumentDeclaration argument : node.getArguments()) { excludedNames.remove(argument.getName().getValue()); } return result; } }
/**
 * Rewrites a lambda by rewriting only its body; the argument declarations are reused as-is.
 */
@Override
public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    Expression newBody = treeRewriter.rewrite(node.getBody(), context);
    return new LambdaExpression(node.getArguments(), newBody);
}
}, expression);
/**
 * Visits the lambda body with an extended set of in-scope lambda argument names.
 * The new set contains the names already in scope, followed by this lambda's own argument names.
 */
@Override
protected Void visitLambdaExpression(LambdaExpression node, Set<String> lambdaArgumentNames)
{
    ImmutableSet.Builder<String> scope = ImmutableSet.<String>builder().addAll(lambdaArgumentNames);
    for (LambdaArgumentDeclaration argument : node.getArguments()) {
        scope.add(argument.getName().getValue());
    }
    return process(node.getBody(), scope.build());
}
}
/**
 * Builds a LambdaExpression AST node from the parse tree.
 * Each parsed identifier becomes one argument declaration, and the parsed expression becomes the body.
 */
@Override
public Node visitLambda(SqlBaseParser.LambdaContext context)
{
    List<Identifier> argumentIdentifiers = visit(context.identifier(), Identifier.class);
    List<LambdaArgumentDeclaration> declarations = argumentIdentifiers.stream()
            .map(LambdaArgumentDeclaration::new)
            .collect(toList());
    Expression lambdaBody = (Expression) visit(context.expression());
    return new LambdaExpression(getLocation(context), declarations, lambdaBody);
}
/**
 * Translates a lambda by renaming each argument to the symbol recorded for its declaration,
 * then rewriting the body.
 *
 * @throws IllegalStateException if the analysis recorded a coercion for the lambda itself
 */
@Override
public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    checkState(analysis.getCoercion(node) == null, "cannot coerce a lambda expression");

    ImmutableList.Builder<LambdaArgumentDeclaration> renamedArguments = ImmutableList.builder();
    node.getArguments().forEach(argument -> {
        // NOTE(review): a declaration missing from lambdaDeclarationToSymbolMap would surface as an NPE here.
        Symbol mapped = lambdaDeclarationToSymbolMap.get(NodeRef.of(argument));
        renamedArguments.add(new LambdaArgumentDeclaration(new Identifier(mapped.getName())));
    });

    Expression translatedBody = treeRewriter.rewrite(node.getBody(), null);
    return new LambdaExpression(renamedArguments.build(), translatedBody);
}
/**
 * Visits the lambda body with an extended set of in-scope lambda argument names.
 * The new set contains the names already in scope, followed by this lambda's own argument names.
 */
@Override
protected Void visitLambdaExpression(LambdaExpression node, Set<String> lambdaArgumentNames)
{
    ImmutableSet.Builder<String> scope = ImmutableSet.<String>builder().addAll(lambdaArgumentNames);
    for (LambdaArgumentDeclaration argument : node.getArguments()) {
        scope.add(argument.getName().getValue());
    }
    return process(node.getBody(), scope.build());
}
}
/**
 * Builds a LambdaExpression AST node from the parse tree.
 * Each parsed identifier becomes one argument declaration, and the parsed expression becomes the body.
 */
@Override
public Node visitLambda(SqlBaseParser.LambdaContext context)
{
    List<Identifier> argumentIdentifiers = visit(context.identifier(), Identifier.class);
    List<LambdaArgumentDeclaration> declarations = argumentIdentifiers.stream()
            .map(LambdaArgumentDeclaration::new)
            .collect(toList());
    Expression lambdaBody = (Expression) visit(context.expression());
    return new LambdaExpression(getLocation(context), declarations, lambdaBody);
}
/**
 * Translates a lambda by renaming each argument to the symbol recorded for its declaration,
 * then rewriting the body.
 *
 * @throws IllegalStateException if the analysis recorded a coercion for the lambda itself
 */
@Override
public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    checkState(analysis.getCoercion(node) == null, "cannot coerce a lambda expression");

    ImmutableList.Builder<LambdaArgumentDeclaration> renamedArguments = ImmutableList.builder();
    node.getArguments().forEach(argument -> {
        // NOTE(review): a declaration missing from lambdaDeclarationToSymbolMap would surface as an NPE here.
        Symbol mapped = lambdaDeclarationToSymbolMap.get(NodeRef.of(argument));
        renamedArguments.add(new LambdaArgumentDeclaration(new Identifier(mapped.getName())));
    });

    Expression translatedBody = treeRewriter.rewrite(node.getBody(), null);
    return new LambdaExpression(renamedArguments.build(), translatedBody);
}
/**
 * Translates a lambda AST node into a LambdaDefinitionExpression.
 * The result carries the argument types, the argument names, and the translated body.
 */
@Override
protected RowExpression visitLambdaExpression(LambdaExpression node, Void context)
{
    RowExpression translatedBody = process(node.getBody(), context);

    // Every type parameter except the last is used as an argument type
    // (the last is presumably the lambda's return type — confirm against the function type).
    List<Type> typeParameters = getType(node).getTypeParameters();
    int argumentCount = typeParameters.size() - 1;
    List<Type> argumentTypes = typeParameters.subList(0, argumentCount);

    List<String> argumentNames = node.getArguments().stream()
            .map(argument -> argument.getName().getValue())
            .collect(toImmutableList());

    return new LambdaDefinitionExpression(argumentTypes, argumentNames, translatedBody);
}
/**
 * Rewriting {@code x -> a + x} should produce a BindExpression.
 * The outer symbol {@code a} is passed as the bound value, and inside the lambda it is
 * referenced through a new leading argument {@code a_0}.
 */
@Test
public void testRewriteBasicLambda()
{
    Map<Symbol, Type> symbols = ImmutableMap.of(new Symbol("a"), BigintType.BIGINT);
    SymbolAllocator allocator = new SymbolAllocator(symbols);

    LambdaExpression expectedLambda = new LambdaExpression(
            Stream.of("a_0", "x")
                    .map(Identifier::new)
                    .map(LambdaArgumentDeclaration::new)
                    .collect(toList()),
            expression("a_0 + x"));
    BindExpression expected = new BindExpression(ImmutableList.of(expression("a")), expectedLambda);

    assertEquals(rewrite(expression("x -> a + x"), allocator.getTypes(), allocator), expected);
}
}
Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), context.withReferencedSymbols(referencedSymbols)); List<Symbol> lambdaArguments = node.getArguments().stream() .map(LambdaArgumentDeclaration::getName) .map(Identifier::getValue) newLambdaArguments.add(new LambdaArgumentDeclaration(new Identifier(extraSymbol.getName()))); newLambdaArguments.addAll(node.getArguments()); Expression rewrittenExpression = new LambdaExpression(newLambdaArguments.build(), inlineSymbols(symbolMapping, rewrittenBody));
/**
 * Translates a lambda AST node into a LambdaDefinitionExpression.
 * The result carries the argument types, the argument names, and the translated body.
 */
@Override
protected RowExpression visitLambdaExpression(LambdaExpression node, Void context)
{
    RowExpression translatedBody = process(node.getBody(), context);

    // Every type parameter except the last is used as an argument type
    // (the last is presumably the lambda's return type — confirm against the function type).
    List<Type> typeParameters = getType(node).getTypeParameters();
    int argumentCount = typeParameters.size() - 1;
    List<Type> argumentTypes = typeParameters.subList(0, argumentCount);

    List<String> argumentNames = node.getArguments().stream()
            .map(argument -> argument.getName().getValue())
            .collect(toImmutableList());

    return new LambdaDefinitionExpression(argumentTypes, argumentNames, translatedBody);
}