/**
 * Renders the Java source text for the given relational plan.
 *
 * <p>Output is produced in a fixed order: prologue, the stages emitted while
 * a {@code RelNodeCompiler} traverses {@code root}, the main section, and
 * finally the epilogue.
 *
 * @param root root node of the plan to compile
 * @return everything that was written, as a single string
 * @throws Exception if any of the emitting steps fails
 */
private String generateJavaSource(RelNode root) throws Exception {
    final StringWriter buffer = new StringWriter();
    // The writer is closed before the buffer is read back.
    try (PrintWriter out = new PrintWriter(buffer)) {
        final RelNodeCompiler relCompiler = new RelNodeCompiler(out, typeFactory);
        printPrologue(out);
        relCompiler.traverse(root);
        printMain(out, root);
        printEpilogue(out);
    }
    return buffer.toString();
}
/**
 * Emits the stage for an aggregate node.
 *
 * <p>After the aggregate stage prologue, the generated code (guarded by a
 * null check on {@code _data}) records the incoming value under its group
 * key, creates the per-group accumulator map on first use, and then contains
 * whatever {@code aggregate(call)} emits for each aggregate call.
 *
 * @param aggregate the aggregate node being compiled
 * @param inputStreams not read by this method
 * @return always {@code null}
 */
@Override
public Void visitAggregate(Aggregate aggregate, List<Void> inputStreams) throws Exception {
    beginAggregateStage(aggregate);

    // Generated lines that precede the per-call accumulation code.
    final String[] groupingPreamble = {
        " if (_data != null) {",
        " List<Object> curGroupValues = getGroupValues(_data);",
        " if (!correlatedGroupedValues.containsKey(curGroupValues)) {",
        " correlatedGroupedValues.put(curGroupValues, new ArrayList<CorrelatedValues>());",
        " }",
        " correlatedGroupedValues.get(curGroupValues).add(_data);",
        " if (!state.containsKey(curGroupValues)) {",
        " state.put(curGroupValues, new HashMap<String, Object>());",
        " }",
        " Map<String, Object> accumulators = state.get(curGroupValues);"
    };
    for (String generatedLine : groupingPreamble) {
        pw.println(generatedLine);
    }

    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        aggregate(aggCall);
    }

    // Closes the generated "if (_data != null)" block.
    pw.println(" }");
    endStage();
    return null;
}
/**
 * Builds the result-emitting statements for an aggregate node.
 *
 * <p>Each aggregate call contributes statements (written through a
 * {@code PrintWriter} over a shared buffer) and one result expression. The
 * returned text is those statements joined by {@code NEW_LINE_JOINER} with a
 * final {@code ctx.emit(...)} statement that passes the group values and the
 * comma-separated result expressions.
 *
 * @param aggregate the aggregate node being compiled
 * @return the generated statements as a single string
 */
private String emitAggregateStmts(Aggregate aggregate) {
    final StringWriter stmtBuffer = new StringWriter();
    final List<String> resultExprs = new ArrayList<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        // A fresh writer per call, all appending to the same buffer.
        PrintWriter stmtWriter = new PrintWriter(stmtBuffer);
        String resultExpr = aggregateResult(aggCall, stmtWriter);
        resultExprs.add(resultExpr);
    }

    final String resultStmts = stmtBuffer.toString();
    final String groupArgs =
        groupValueEmitStr("groupValues", aggregate.getGroupSet().cardinality());
    final String aggArgs = Joiner.on(", ").join(resultExprs);
    final String emitStmt = String.format(
        " ctx.emit(new CorrelatedValues(correlatedEvents, %s, %s));", groupArgs, aggArgs);

    return NEW_LINE_JOINER.join(resultStmts, emitStmt);
}
/**
 * Prints the opening of an aggregate stage.
 *
 * <p>Fills {@code AGGREGATE_STAGE_PROLOGUE} with, in order, the stage name,
 * the group-by indices and the statements from {@code emitAggregateStmts}.
 *
 * @param n the aggregate node whose stage is being opened
 */
private void beginAggregateStage(Aggregate n) {
    // Array initializers evaluate left to right, matching the template's argument order.
    final Object[] prologueArgs = {
        getStageName(n),
        getGroupByIndices(n),
        emitAggregateStmts(n)
    };
    final String prologue = String.format(AGGREGATE_STAGE_PROLOGUE, prologueArgs);
    pw.print(prologue);
}
/**
 * Emits the stage for a join node.
 *
 * <p>The generated body appends {@code _data} to {@code leftRows} when the
 * source is {@code left}, and to {@code rightRows} when it is {@code right}.
 *
 * @param join the join node being compiled
 * @param inputStreams not read by this method
 * @return always {@code null}
 */
@Override
public Void visitJoin(Join join, List<Void> inputStreams) {
    beginJoinStage(join);

    final String[] routingLines = {
        " if (source == left) {",
        " leftRows.add(_data);",
        " } else if (source == right) {",
        " rightRows.add(_data);",
        " }"
    };
    for (String generatedLine : routingLines) {
        pw.println(generatedLine);
    }

    endStage();
    return null;
}
@Test public void testFilter() throws Exception { String sql = "SELECT ID + 1 FROM FOO WHERE ID > 3"; TestCompilerUtils.CalciteState state = TestCompilerUtils.sqlOverDummyTable(sql); JavaTypeFactory typeFactory = new JavaTypeFactoryImpl( RelDataTypeSystem.DEFAULT); LogicalProject project = (LogicalProject) state.tree(); LogicalFilter filter = (LogicalFilter) project.getInput(); try (StringWriter sw = new StringWriter(); PrintWriter pw = new PrintWriter(sw) ) { RelNodeCompiler compiler = new RelNodeCompiler(pw, typeFactory); // standalone mode doesn't use inputstreams argument compiler.visitFilter(filter, Collections.EMPTY_LIST); pw.flush(); Assert.assertThat(sw.toString(), containsString("> 3")); } try (StringWriter sw = new StringWriter(); PrintWriter pw = new PrintWriter(sw) ) { RelNodeCompiler compiler = new RelNodeCompiler(pw, typeFactory); // standalone mode doesn't use inputstreams argument compiler.visitProject(project, Collections.EMPTY_LIST); pw.flush(); Assert.assertThat(sw.toString(), containsString(" + 1")); } } }
/**
 * Produces the result expression for one aggregate call.
 *
 * <p>A user-defined aggregate uses the function it carries; any other
 * aggregate is looked up by name in {@code BuiltinAggregateFunctions.TABLE}
 * and matched against the call's Java type. The work is then delegated to
 * {@code doAggregateResult}.
 *
 * @param call the aggregate call to compile
 * @param pw writer handed on to {@code doAggregateResult}
 * @return the value returned by {@code doAggregateResult}
 * @throws UnsupportedOperationException if the aggregate is not user-defined
 *     and has no entry in the built-in table
 */
private String aggregateResult(AggregateCall call, PrintWriter pw) {
    final SqlAggFunction sqlAgg = call.getAggregation();
    final String aggName = sqlAgg.getName();
    final Type resultType = typeFactory.getJavaClass(call.getType());

    final AggregateFunctionImpl impl;
    if (sqlAgg instanceof SqlUserDefinedAggFunction) {
        impl = (AggregateFunctionImpl) ((SqlUserDefinedAggFunction) sqlAgg).function;
    } else {
        final List<BuiltinAggregateFunctions.TypeClass> candidates =
            BuiltinAggregateFunctions.TABLE.get(aggName);
        if (candidates == null) {
            throw new UnsupportedOperationException(aggName + " Not implemented");
        }
        impl = AggregateFunctionImpl.create(findMatchingClass(aggName, candidates, resultType));
    }

    // The function is resolved before the variable name is reserved.
    return doAggregateResult(impl, reserveAggVarName(call), resultType, pw);
}
/**
 * Emits the accumulation code for one aggregate call.
 *
 * <p>{@code COUNT} is accepted only with zero or one argument. A user-defined
 * aggregate uses the function it carries; any other aggregate is looked up by
 * name in {@code BuiltinAggregateFunctions.TABLE} and matched against the
 * call's Java type. The work is then delegated to {@code doAggregate}.
 *
 * @param call the aggregate call to compile
 * @throws UnsupportedOperationException if {@code COUNT} has more than one
 *     argument, or if the aggregate is not user-defined and has no entry in
 *     the built-in table
 */
private void aggregate(AggregateCall call) {
    final SqlAggFunction sqlAgg = call.getAggregation();
    final String aggName = sqlAgg.getName();
    final Type resultType = typeFactory.getJavaClass(call.getType());

    // Argument-count restriction applies to COUNT only.
    final int argCount = call.getArgList().size();
    if (argCount != 1 && aggName.equals("COUNT") && argCount != 0) {
        throw new UnsupportedOperationException("Count with nullable fields");
    }

    final AggregateFunctionImpl impl;
    if (sqlAgg instanceof SqlUserDefinedAggFunction) {
        impl = (AggregateFunctionImpl) ((SqlUserDefinedAggFunction) sqlAgg).function;
    } else {
        final List<BuiltinAggregateFunctions.TypeClass> candidates =
            BuiltinAggregateFunctions.TABLE.get(aggName);
        if (candidates == null) {
            throw new UnsupportedOperationException(aggName + " Not implemented");
        }
        impl = AggregateFunctionImpl.create(findMatchingClass(aggName, candidates, resultType));
    }

    // The function is resolved before the variable name is reserved.
    doAggregate(impl, reserveAggVarName(call), resultType, call.getArgList());
}
/**
 * Emits the stage for a project node.
 *
 * <p>The generated body creates a context over the incoming values,
 * allocates an output array with one slot per output field, runs the block
 * compiled from the project's expressions, and emits the output values
 * together with the input's correlated data.
 *
 * @param project the project node being compiled
 * @param inputStreams not read by this method
 * @return always {@code null}
 */
@Override
public Void visitProject(Project project, List<Void> inputStreams) throws Exception {
    beginStage(project);

    final List<RexNode> projections = project.getChildExps();
    final RelDataType inputType = project.getInput(0).getRowType();
    final int fieldCount = project.getRowType().getFieldCount();

    // Stage header is written before the expressions are compiled.
    final String header =
        "Context context = new StreamlineContext(Processor.dataContext);\n"
            + "context.values = _data.toArray();\n"
            + String.format("Object[] outputValues = new Object[%d];\n", fieldCount);
    pw.print(header);

    final String compiledBlock = rexCompiler.compileToBlock(projections, inputType).toString();
    pw.write(compiledBlock);
    pw.print(" ctx.emit(new CorrelatedValues(_data.getCorrelated(), outputValues));\n");

    endStage();
    return null;
}
/**
 * Emits the stage for a filter node.
 *
 * <p>The generated body creates a context over the incoming values, runs the
 * block compiled from the filter's expressions into a one-element output
 * array, and emits {@code _data} only when that element is true. When the
 * condition's type is nullable, the generated check also guards against
 * {@code null}.
 *
 * @param filter the filter node being compiled
 * @param inputStreams not read by this method
 * @return always {@code null}
 */
@Override
public Void visitFilter(Filter filter, List<Void> inputStreams) throws Exception {
    beginStage(filter);

    List<RexNode> childExps = filter.getChildExps();
    RelDataType inputRowType = filter.getInput(0).getRowType();

    pw.print("Context context = new StreamlineContext(Processor.dataContext);\n");
    pw.print("context.values = _data.toArray();\n");
    pw.print("Object[] outputValues = new Object[1];\n");
    pw.write(rexCompiler.compileToBlock(childExps, inputRowType).toString());

    // Expression, in the generated code, that reads the evaluated condition.
    String r = "((Boolean) outputValues[0])";
    if (filter.getCondition().getType().isNullable()) {
        pw.print(String.format(" if (%s != null && %s) { ctx.emit(_data); }\n", r, r));
    } else {
        // One placeholder, one argument (a surplus second argument was dropped).
        pw.print(String.format(" if (%s) { ctx.emit(_data); }\n", r));
    }

    endStage();
    return null;
}