/**
 * Desugars a {@code TRY(expr)} node into a call to the internal {@code $internal$try} function,
 * whose single argument is a zero-argument lambda wrapping the rewritten inner expression.
 */
@Override
public Expression rewriteTryExpression(TryExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    // Run the inner expression through the tree rewriter first so nested rewrites apply to it too.
    Expression rewrittenInner = treeRewriter.rewrite(node.getInnerExpression(), context);
    LambdaExpression thunk = new LambdaExpression(ImmutableList.of(), rewrittenInner);
    return new FunctionCall(QualifiedName.of("$internal$try"), ImmutableList.of(thunk));
}
}
/**
 * Desugars a {@code TRY(expr)} node into a call to the internal {@code $internal$try} function,
 * whose single argument is a zero-argument lambda wrapping the rewritten inner expression.
 */
@Override
public Expression rewriteTryExpression(TryExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    // Run the inner expression through the tree rewriter first so nested rewrites apply to it too.
    Expression rewrittenInner = treeRewriter.rewrite(node.getInnerExpression(), context);
    LambdaExpression thunk = new LambdaExpression(ImmutableList.of(), rewrittenInner);
    return new FunctionCall(QualifiedName.of("$internal$try"), ImmutableList.of(thunk));
}
}
/**
 * Builds a {@link LambdaExpression} from a parsed lambda: every parameter identifier becomes a
 * {@link LambdaArgumentDeclaration} and the expression after the arrow becomes the body.
 */
@Override
public Node visitLambda(SqlBaseParser.LambdaContext context)
{
    List<Identifier> parameterNames = visit(context.identifier(), Identifier.class);
    List<LambdaArgumentDeclaration> arguments = parameterNames.stream()
            .map(LambdaArgumentDeclaration::new)
            .collect(toList());
    return new LambdaExpression(getLocation(context), arguments, (Expression) visit(context.expression()));
}
/**
 * Builds a {@link LambdaExpression} from a parsed lambda: every parameter identifier becomes a
 * {@link LambdaArgumentDeclaration} and the expression after the arrow becomes the body.
 */
@Override
public Node visitLambda(SqlBaseParser.LambdaContext context)
{
    List<Identifier> parameterNames = visit(context.identifier(), Identifier.class);
    List<LambdaArgumentDeclaration> arguments = parameterNames.stream()
            .map(LambdaArgumentDeclaration::new)
            .collect(toList());
    return new LambdaExpression(getLocation(context), arguments, (Expression) visit(context.expression()));
}
/**
 * Rebuilds the lambda with its original argument declarations and a body that has been passed
 * through the tree rewriter.
 */
@Override
public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    Expression newBody = treeRewriter.rewrite(node.getBody(), context);
    return new LambdaExpression(node.getArguments(), newBody);
}
}, expression);
/**
 * Rebuilds the lambda with its original argument declarations and a body that has been passed
 * through the tree rewriter.
 */
@Override
public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    Expression newBody = treeRewriter.rewrite(node.getBody(), context);
    return new LambdaExpression(node.getArguments(), newBody);
}
}, expression);
/**
 * Rewrites a lambda so that each argument declaration is renamed to the symbol recorded for it in
 * {@code lambdaDeclarationToSymbolMap}, and rewrites the body with the tree rewriter.
 *
 * @throws IllegalStateException if the analysis recorded a coercion for the lambda itself, or if no
 *         symbol was recorded for one of its argument declarations
 */
@Override
public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    checkState(analysis.getCoercion(node) == null, "cannot coerce a lambda expression");

    ImmutableList.Builder<LambdaArgumentDeclaration> newArguments = ImmutableList.builder();
    for (LambdaArgumentDeclaration argument : node.getArguments()) {
        Symbol symbol = lambdaDeclarationToSymbolMap.get(NodeRef.of(argument));
        // Report a missing mapping explicitly instead of an anonymous NPE on symbol.getName().
        checkState(symbol != null, "no symbol allocated for lambda argument %s", argument);
        newArguments.add(new LambdaArgumentDeclaration(new Identifier(symbol.getName())));
    }
    // Propagate the caller's context (always null for Void) instead of a hard-coded null.
    Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), context);
    return new LambdaExpression(newArguments.build(), rewrittenBody);
}
/**
 * Rewrites a lambda so that each argument declaration is renamed to the symbol recorded for it in
 * {@code lambdaDeclarationToSymbolMap}, and rewrites the body with the tree rewriter.
 *
 * @throws IllegalStateException if the analysis recorded a coercion for the lambda itself, or if no
 *         symbol was recorded for one of its argument declarations
 */
@Override
public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    checkState(analysis.getCoercion(node) == null, "cannot coerce a lambda expression");

    ImmutableList.Builder<LambdaArgumentDeclaration> newArguments = ImmutableList.builder();
    for (LambdaArgumentDeclaration argument : node.getArguments()) {
        Symbol symbol = lambdaDeclarationToSymbolMap.get(NodeRef.of(argument));
        // Report a missing mapping explicitly instead of an anonymous NPE on symbol.getName().
        checkState(symbol != null, "no symbol allocated for lambda argument %s", argument);
        newArguments.add(new LambdaArgumentDeclaration(new Identifier(symbol.getName())));
    }
    // Propagate the caller's context (always null for Void) instead of a hard-coded null.
    Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), context);
    return new LambdaExpression(newArguments.build(), rewrittenBody);
}
// Rebuild the lambda from the newly built argument declarations; the body is first passed through
// inlineSymbols(symbolMapping, ...) — presumably substituting symbols per symbolMapping (TODO confirm).
Expression rewrittenExpression = new LambdaExpression(newLambdaArguments.build(), inlineSymbols(symbolMapping, rewrittenBody));
// Rebuild the lambda from the newly built argument declarations; the body is first passed through
// inlineSymbols(symbolMapping, ...) — presumably substituting symbols per symbolMapping (TODO confirm).
Expression rewrittenExpression = new LambdaExpression(newLambdaArguments.build(), inlineSymbols(symbolMapping, rewrittenBody));
/**
 * A lambda that references the outer symbol {@code a} is expected to become a
 * {@link BindExpression} binding {@code a}, with the lambda taking a new leading argument
 * {@code a_0} that replaces {@code a} in its body.
 */
@Test
public void testRewriteBasicLambda()
{
    Map<Symbol, Type> symbolTypes = ImmutableMap.of(new Symbol("a"), BigintType.BIGINT);
    SymbolAllocator symbolAllocator = new SymbolAllocator(symbolTypes);

    LambdaExpression expectedLambda = new LambdaExpression(
            ImmutableList.of(
                    new LambdaArgumentDeclaration(new Identifier("a_0")),
                    new LambdaArgumentDeclaration(new Identifier("x"))),
            expression("a_0 + x"));
    BindExpression expected = new BindExpression(ImmutableList.of(expression("a")), expectedLambda);

    assertEquals(rewrite(expression("x -> a + x"), symbolAllocator.getTypes(), symbolAllocator), expected);
}
}
/**
 * A lambda that references the outer symbol {@code a} is expected to become a
 * {@link BindExpression} binding {@code a}, with the lambda taking a new leading argument
 * {@code a_0} that replaces {@code a} in its body.
 */
@Test
public void testRewriteBasicLambda()
{
    Map<Symbol, Type> symbolTypes = ImmutableMap.of(new Symbol("a"), BigintType.BIGINT);
    SymbolAllocator symbolAllocator = new SymbolAllocator(symbolTypes);

    LambdaExpression expectedLambda = new LambdaExpression(
            ImmutableList.of(
                    new LambdaArgumentDeclaration(new Identifier("a_0")),
                    new LambdaArgumentDeclaration(new Identifier("x"))),
            expression("a_0 + x"));
    BindExpression expected = new BindExpression(ImmutableList.of(expression("a")), expectedLambda);

    assertEquals(rewrite(expression("x -> a + x"), symbolAllocator.getTypes(), symbolAllocator), expected);
}
}
@Test public void testTryExpressionDesugaringRewriter() { // 1 + try(2) Expression before = new ArithmeticBinaryExpression( ADD, new DecimalLiteral("1"), new TryExpression(new DecimalLiteral("2"))); // 1 + try_function(() -> 2) Expression after = new ArithmeticBinaryExpression( ADD, new DecimalLiteral("1"), new FunctionCall( QualifiedName.of("$internal$try"), ImmutableList.of(new LambdaExpression(ImmutableList.of(), new DecimalLiteral("2"))))); assertEquals(DesugarTryExpressionRewriter.rewrite(before), after); } }
@Test public void testTryExpressionDesugaringRewriter() { // 1 + try(2) Expression before = new ArithmeticBinaryExpression( ADD, new DecimalLiteral("1"), new TryExpression(new DecimalLiteral("2"))); // 1 + try_function(() -> 2) Expression after = new ArithmeticBinaryExpression( ADD, new DecimalLiteral("1"), new FunctionCall( QualifiedName.of("$internal$try"), ImmutableList.of(new LambdaExpression(ImmutableList.of(), new DecimalLiteral("2"))))); assertEquals(DesugarTryExpressionRewriter.rewrite(before), after); } }
/**
 * Parses lambdas with zero, one and two arguments and compares each against a hand-built AST.
 */
@Test
public void testLambda()
{
    LambdaExpression noArguments = new LambdaExpression(ImmutableList.of(), new Identifier("x"));
    assertExpression("() -> x", noArguments);

    LambdaExpression singleArgument = new LambdaExpression(
            ImmutableList.of(new LambdaArgumentDeclaration(identifier("x"))),
            new FunctionCall(QualifiedName.of("sin"), ImmutableList.of(new Identifier("x"))));
    assertExpression("x -> sin(x)", singleArgument);

    LambdaExpression twoArguments = new LambdaExpression(
            ImmutableList.of(
                    new LambdaArgumentDeclaration(identifier("x")),
                    new LambdaArgumentDeclaration(identifier("y"))),
            new FunctionCall(
                    QualifiedName.of("mod"),
                    ImmutableList.of(new Identifier("x"), new Identifier("y"))));
    assertExpression("(x, y) -> mod(x, y)", twoArguments);
}