final RelBuilder relBuilder = call.builder(); final Aggregate aggRel = (Aggregate) call.rel(0); final RexBuilder rexBuilder = aggRel.getCluster().getRexBuilder(); final Map<AggregateCall, Integer> mapping = new HashMap<>(); final List<Integer> indexes = new ArrayList<>(); final List<AggregateCall> aggCalls = aggRel.getAggCallList(); final List<AggregateCall> newAggCalls = new ArrayList<>(aggCalls.size()); int nextIdx = aggRel.getGroupCount() + aggRel.getIndicatorCount(); for (int i = 0; i < aggCalls.size(); i++) { AggregateCall aggCall = aggCalls.get(i); if (aggCall.getAggregation().getKind() == SqlKind.COUNT && !aggCall.isDistinct()) { final List<Integer> args = aggCall.getArgList(); final List<Integer> nullableArgs = new ArrayList<>(args.size()); for (int arg : args) { if (aggRel.getInput().getRowType().getFieldList().get(arg).getType().isNullable()) { nullableArgs.add(arg); final Aggregate newAggregate = aggRel.copy(aggRel.getTraitSet(), aggRel.getInput(), aggRel.indicator, aggRel.getGroupSet(), aggRel.getGroupSets(), newAggCalls); if (identity) { call.transformTo(newAggregate); } else { final int offset = aggRel.getGroupCount() + aggRel.getIndicatorCount(); final List<RexNode> projList = Lists.newArrayList(); for (int i = 0; i < offset; ++i) { projList.add( rexBuilder.makeInputRef(
public RelNode align(Aggregate rel, List<RelFieldCollation> collations) { // 1) We extract the group by positions that are part of the collations and // sort them so they respect it LinkedHashSet<Integer> aggregateColumnsOrder = new LinkedHashSet<>(); ImmutableList.Builder<RelFieldCollation> propagateCollations = ImmutableList.builder(); if (rel.getGroupType() == Group.SIMPLE && !collations.isEmpty()) { for (RelFieldCollation c : collations) { if (c.getFieldIndex() < rel.getGroupCount()) { // Group column found if (aggregateColumnsOrder.add(c.getFieldIndex())) { propagateCollations.add(c.copy(rel.getGroupSet().nth(c.getFieldIndex()))); } } } } for (int i = 0; i < rel.getGroupCount(); i++) { if (!aggregateColumnsOrder.contains(i)) { // Not included in the input collations, but can be propagated as this Aggregate // will enforce it propagateCollations.add(new RelFieldCollation(rel.getGroupSet().nth(i))); } } // 2) We propagate final RelNode child = dispatchAlign(rel.getInput(), propagateCollations.build()); // 3) We annotate the Aggregate operator with this info final HiveAggregate newAggregate = (HiveAggregate) rel.copy(rel.getTraitSet(), ImmutableList.of(child)); newAggregate.setAggregateColumnsOrder(aggregateColumnsOrder); return newAggregate; }
private static boolean isEmptyGrpAggr(RelNode gbNode) { // Verify if both groupset and aggrfunction are empty) Aggregate aggrnode = (Aggregate) gbNode; if (aggrnode.getGroupSet().isEmpty() && aggrnode.getAggCallList().isEmpty()) { return true; } return false; }
private ImmutableBitSet generateNewGroupset(Aggregate aggregate, ImmutableBitSet fieldsUsed) { ImmutableBitSet originalGroupSet = aggregate.getGroupSet(); if (aggregate.getGroupSets().size() > 1 || aggregate.getIndicatorCount() > 0 || fieldsUsed.contains(originalGroupSet)) { final RelNode input = aggregate.getInput(); RelMetadataQuery mq = aggregate.getCluster().getMetadataQuery(); if (aggregate.getGroupSet().contains(key)) { groupByUniqueKey = key; break; ImmutableBitSet nonKeyColumns = aggregate.getGroupSet().except(groupByUniqueKey); ImmutableBitSet columnsToRemove = nonKeyColumns.except(fieldsUsed); ImmutableBitSet newGroupSet = aggregate.getGroupSet().except(columnsToRemove);
private boolean isAggWithConstantGbyKeys(final Aggregate aggregate, RelOptRuleCall call) { final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); final RelMetadataQuery mq = call.getMetadataQuery(); final RelOptPredicateList predicates = mq.getPulledUpPredicates(aggregate.getInput()); if (predicates == null) { return false; } final NavigableMap<Integer, RexNode> map = new TreeMap<>(); for (int key : aggregate.getGroupSet()) { final RexInputRef ref = rexBuilder.makeInputRef(aggregate.getInput(), key); if (predicates.constantMap.containsKey(ref)) { map.put(key, predicates.constantMap.get(ref)); } } // None of the group expressions are constant. Nothing to do. if (map.isEmpty()) { return false; } final int groupCount = aggregate.getGroupCount(); if (groupCount == map.size()) { return true; } return false; }
final Aggregate aggregate = call.rel(0); final Join join = call.rel(1); final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); final RelBuilder relBuilder = call.builder(); for (AggregateCall aggregateCall : aggregate.getAggCallList()) { if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) { return; if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) { return; final ImmutableBitSet aggregateColumns = aggregate.getGroupSet(); final RelMetadataQuery mq = call.getMetadataQuery(); final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, final boolean unique; if (!allowFunctions) { assert aggregate.getAggCallList().isEmpty(); : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount); for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) { final SqlAggFunction aggregation = aggCall.e.getAggregation(); final SqlSplittableAggFunction splitter = Preconditions.checkNotNull( aggregation.unwrap(SqlSplittableAggFunction.class)); final AggregateCall call1;
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); final List<AggregateCall> newCalls = new ArrayList<>(aggregate.getAggCallList().size()); final List<RexNode> newProjects = new ArrayList<>(project.getChildExps()); final List<RexNode> newCasts = new ArrayList<>(aggregate.getGroupCount() + aggregate.getAggCallList().size()); final RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory(); for (int fieldNumber : aggregate.getGroupSet()) { newCasts.add(rexBuilder.makeInputRef(project.getChildExps().get(fieldNumber).getType(), fieldNumber)); for (AggregateCall aggregateCall : aggregate.getAggCallList()) { AggregateCall newCall = null; final RexNode rexNode = project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList())); if (aggregateCall.isDistinct()) { final RelDataType oldType = aggregate.getRowType().getFieldList().get(i).getType(); if (!newCalls.equals(aggregate.getAggCallList())) { final RelBuilder relBuilder = call .builder() aggregate.getGroupSet(), aggregate.getGroupSets() );
if (!joinInfo.rightSet().equals( ImmutableBitSet.range(aggregate.getGroupCount()))) { call.transformTo(topOperator.copy(topOperator.getTraitSet(), ImmutableList.of(left))); return; final List<Integer> aggregateKeys = aggregate.getGroupSet().asList(); for (int key : joinInfo.rightKeys) { newRightKeyBuilder.add(aggregateKeys.get(key)); final RelNode newRight = aggregate.getInput(); final RexNode newCondition = RelOptUtil.createEquiJoinCondition(left, joinInfo.leftKeys, newRight, if(aggregate.getInput() instanceof HepRelVertex && ((HepRelVertex)aggregate.getInput()).getCurrentRel() instanceof Join) { Join rightJoin = (Join)(((HepRelVertex)aggregate.getInput()).getCurrentRel()); List<RexNode> projects = new ArrayList<>(); for(int i=0; i<rightJoin.getRowType().getFieldCount(); i++){ projects.add(rexBuilder.makeInputRef(rightJoin, i)); RelNode topProject = call.builder().push(rightJoin).project(projects, rightJoin.getRowType().getFieldNames(), semi = call.builder().push(left).push(topProject).semiJoin(newCondition).build(); } else { semi = call.builder().push(left).push(aggregate.getInput()).semiJoin(newCondition).build(); call.transformTo(topOperator.copy(topOperator.getTraitSet(), ImmutableList.of(semi)));
final RelNode input = agg.getInput(); final RelOptPredicateList inputInfo = mq.getPulledUpPredicates(input); final List<RexNode> aggPullUpPredicates = new ArrayList<>(); final RexBuilder rexBuilder = agg.getCluster().getRexBuilder(); ImmutableBitSet groupKeys = agg.getGroupSet(); Mapping m = Mappings.create(MappingType.PARTIAL_FUNCTION, input.getRowType().getFieldCount(), agg.getRowType().getFieldCount()); if (!rCols.isEmpty() && groupKeys.contains(rCols)) { r = r.accept(new RexPermuteInputsShuttle(m, input)); aggPullUpPredicates.add(r);
final int fieldCount = aggregate.getGroupCount() + aggregate.getAggCallList().size(); if (fieldCount != aggregate.getRowType().getFieldCount()) { throw new ISE( "WTF, expected[%s] to have[%s] fields but it had[%s]", aggregate, fieldCount, aggregate.getRowType().getFieldCount() ); final ImmutableBitSet callsToKeep = projectBits.intersect( ImmutableBitSet.range(aggregate.getGroupCount(), fieldCount) ); if (callsToKeep.cardinality() < aggregate.getAggCallList().size()) { newAggregateCalls.add(aggregate.getAggCallList().get(i - aggregate.getGroupCount())); final Aggregate newAggregate = aggregate.copy( aggregate.getTraitSet(), aggregate.getInput(), aggregate.indicator, aggregate.getGroupSet(), aggregate.getGroupSets(), newAggregateCalls ); final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); for (int i = 0; i < aggregate.getGroupCount(); i++) {
boolean groupingSetsExpression = false; if (groupBy.indicator) { Group aggregateType = Aggregate.Group.induce(groupBy.getGroupSet(), groupBy.getGroupSets()); if (aggregateType == Group.ROLLUP) { b = ASTBuilder.construct(HiveParser.TOK_ROLLUP_GROUPBY, "TOK_ROLLUP_GROUPBY"); RexInputRef iRef = new RexInputRef(groupBy.getGroupSet().nth(pos), groupBy.getCluster().getTypeFactory().createSqlType(SqlTypeName.ANY)); b.add(iRef.accept(new RexVisitor(schema))); for (int pos = 0; pos < groupBy.getGroupCount(); pos++) { if (!hiveAgg.getAggregateColumnsOrder().contains(pos)) { RexInputRef iRef = new RexInputRef(groupBy.getGroupSet().nth(pos), groupBy.getCluster().getTypeFactory().createSqlType(SqlTypeName.ANY)); b.add(iRef.accept(new RexVisitor(schema))); for(ImmutableBitSet groupSet: groupBy.getGroupSets()) { ASTBuilder expression = ASTBuilder.construct( HiveParser.TOK_GROUPING_SETS_EXPRESSION, "TOK_GROUPING_SETS_EXPRESSION"); for (int i : groupSet) { RexInputRef iRef = new RexInputRef(i, groupBy.getCluster().getTypeFactory() .createSqlType(SqlTypeName.ANY)); expression.add(iRef.accept(new RexVisitor(schema))); if (!groupBy.getGroupSet().isEmpty()) { hiveAST.groupBy = b.node();
RexNode joinCond = rexBuilder.makeLiteral(true); if ((joinType != JoinRelType.LEFT) || (joinCond != rexBuilder.makeLiteral(true))) { return; if (!aggregate.getGroupSet().isEmpty()) { return; final List<AggregateCall> aggCalls = aggregate.getAggCallList(); final Set<Integer> isCountStar = Sets.newHashSet(); int nFields = left.getRowType().getFieldCount(); ImmutableBitSet allCols = ImmutableBitSet.range(nFields); int nullIndicatorPos = join.getRowType().getFieldCount() - 1; cluster.getTypeFactory().createTypeWithNullability( argList = Lists.newArrayList(); for (int aggArg : aggCall.getArgList()) { argList.add(aggArg + groupCount); aggregate.getGroupCount(), groupCount)); ImmutableBitSet.range(groupCount); Aggregate newAggregate = (Aggregate) relBuilder.push(joinOutputProject)
Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) { final int nGroups = oldAggRel.getGroupCount(); final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory(); final int iAvgInput = oldCall.getArgList().get(0); final RelDataType sum0InputType = typeFactory.createTypeWithNullability( getFieldType(oldAggRel.getInput(), iAvgInput), true); final RelDataType sumReturnType = getSumReturnType( rexBuilder.getTypeFactory(), sum0InputType, oldCall.getType()); final AggregateCall sumCall = AggregateCall.create( new HiveSqlSumAggFunction( oldCall.isDistinct(), oldCall.getArgList(), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null); rexBuilder.addAggCall(sumCall, nGroups, oldAggRel.indicator,
if ((aggregate.getIndicatorCount() > 0) || (aggregate.getGroupSet().isEmpty()) || fieldsUsed.contains(aggregate.getGroupSet())) { return aggregate; final RelNode input = aggregate.getInput(); final RelDataType rowType = input.getRowType(); RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); final List<RexNode> newProjects = new ArrayList<>(); final List<RexNode> inputExprs = input.getChildExps(); if (inputExprs == null || inputExprs.isEmpty()) { return aggregate; for (int key : aggregate.getGroupSet()) { for (int i = 0; i < rowType.getFieldCount(); i++) { if (aggregate.getGroupSet().get(i)) { newProjects.add(rexBuilder.makeLiteral(true)); } else { newProjects.add(rexBuilder.makeInputRef(input, i)); Aggregate newAggregate = new HiveAggregate(aggregate.getCluster(), aggregate.getTraitSet(), relBuilder.build(), aggregate.getGroupSet(), null, aggregate.getAggCallList()); return newAggregate;
ASTNode cond = where.getCondition().accept(new RexVisitor(schema, false, root.getCluster().getRexBuilder())); hiveAST.where = ASTBuilder.where(cond); ASTBuilder b; boolean groupingSetsExpression = false; Group aggregateType = groupBy.getGroupType(); switch (aggregateType) { case SIMPLE: RexInputRef iRef = new RexInputRef(groupBy.getGroupSet().nth(pos), groupBy.getCluster().getTypeFactory().createSqlType(SqlTypeName.ANY)); b.add(iRef.accept(new RexVisitor(schema, false, root.getCluster().getRexBuilder()))); for (int pos = 0; pos < groupBy.getGroupCount(); pos++) { if (!hiveAgg.getAggregateColumnsOrder().contains(pos)) { RexInputRef iRef = new RexInputRef(groupBy.getGroupSet().nth(pos), groupBy.getCluster().getTypeFactory().createSqlType(SqlTypeName.ANY)); b.add(iRef.accept(new RexVisitor(schema, false, root.getCluster().getRexBuilder()))); for(ImmutableBitSet groupSet: groupBy.getGroupSets()) { ASTBuilder expression = ASTBuilder.construct( HiveParser.TOK_GROUPING_SETS_EXPRESSION, "TOK_GROUPING_SETS_EXPRESSION"); for (int i : groupSet) { RexInputRef iRef = new RexInputRef(i, groupBy.getCluster().getTypeFactory() if (!groupBy.getGroupSet().isEmpty()) { hiveAST.groupBy = b.node();
List<List<Integer>> cleanArgList, Map<Integer, Integer> map, List<Integer> sourceOfForCountDistinct) throws CalciteSemanticException { List<RexNode> originalInputRefs = Lists.transform(aggr.getRowType().getFieldList(), new Function<RelDataTypeField, RexNode>() { @Override RexNode condition = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, originalInputRefs .get(originalInputRefs.size() - 1), rexBuilder.makeExactLiteral(new BigDecimal( getGroupingIdValue(list, sourceOfForCountDistinct, aggr.getGroupCount())))); if (list.size() == 1) { int pos = list.get(0); RexNode notNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, originalInputRefs.get(pos)); condition = rexBuilder.makeCall(SqlStdOperatorTable.AND, condition, notNull); aggregateCalls.add(aggregateCall); Aggregate aggregate = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), gbInputRel, ImmutableBitSet.of(), null, aggregateCalls); return aggregate; } else { List<RexNode> originalAggrRefs = Lists.transform(aggregate.getRowType().getFieldList(), new Function<RelDataTypeField, RexNode>() { @Override
final RexBuilder rexBuilder = filterRel.getCluster().getRexBuilder(); final List<RelDataTypeField> origFields = aggRel.getRowType().getFieldList(); final int[] adjustments = new int[origFields.size()]; int j = 0; for (int i : aggRel.getGroupSet()) { adjustments[j] = i - j; j++; if (canPush(aggRel, rCols)) { pushedConditions.add( condition.accept( new RelOptUtil.RexInputConverter(rexBuilder, origFields, aggRel.getInput(0).getRowType().getFieldList(), adjustments))); } else { builder.push(aggRel.getInput()).filter(pushedConditions).build(); if (rel == aggRel.getInput(0)) { return; rel = aggRel.copy(aggRel.getTraitSet(), ImmutableList.of(rel)); rel = builder.push(rel).filter(remainingConditions).build(); call.transformTo(rel);
final RelDataType rowType = aggregate.getRowType(); aggregate.getGroupSet().rebuild(); for (AggregateCall aggCall : aggregate.getAggCallList()) { for (int i : aggCall.getArgList()) { inputFieldsUsed.set(i); final RelNode input = aggregate.getInput(); final Set<RelDataTypeField> inputExtraFields = Collections.emptySet(); final TrimResult trimResult = ImmutableBitSet originalGroupSet = aggregate.getGroupSet(); ImmutableBitSet updatedGroupSet = generateNewGroupset(aggregate, fieldsUsed); ImmutableBitSet gbKeysDeleted = originalGroupSet.except(updatedGroupSet); ImmutableBitSet updatedGroupFields = ImmutableBitSet.range(originalGroupSet.cardinality()); int originalGroupCount = aggregate.getGroupSet().cardinality(); int j = originalGroupCount; int usedAggCallCount = 0; for (int i = 0; i < aggregate.getAggCallList().size(); i++) { if(!updatedGroupSet.equals(aggregate.getGroupSet())) { newGroupSets = ImmutableList.of(newGroupSet); } else { newGroupSets = ImmutableList.copyOf( Iterables.transform(aggregate.getGroupSets(), input1 -> Mappings.apply(inputMapping, input1)));
relBuilder.push(aggregate.getInput()); final List<AggregateCall> originalAggCalls = aggregate.getAggCallList(); final ImmutableBitSet originalGroupSet = aggregate.getGroupSet(); bottomGroupSet.addAll(aggregate.getGroupSet().asList()); for (AggregateCall aggCall : originalAggCalls) { if (aggCall.isDistinct()) { bottomGroupSet.addAll(aggCall.getArgList()); break; // since we only have single distinct call if (!aggCall.isDistinct()) { final AggregateCall newCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(), aggCall.getArgList(), -1, ImmutableBitSet.of(bottomGroupSet).cardinality(), relBuilder.peek(), null, aggCall.name); bottomAggregateCalls.add(newCall); aggregate.copy( aggregate.getTraitSet(), relBuilder.build(), false, ImmutableBitSet.of(bottomGroupSet), null, bottomAggregateCalls)); aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator, ImmutableBitSet.of(topGroupSet), null, topAggregateCalls));
RelOptRuleCall ruleCall, Aggregate oldAggRel) { RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); List<AggregateCall> oldCalls = oldAggRel.getAggCallList(); final int groupCount = oldAggRel.getGroupCount(); final int indicatorCount = oldAggRel.getIndicatorCount(); rexBuilder.makeInputRef( getFieldType(oldAggRel, i), i)); relBuilder.push(oldAggRel.getInput()); final List<RexNode> inputExprs = new ArrayList<>(relBuilder.fields()); inputExprs.size() - relBuilder.peek().getRowType().getFieldCount(); if (extraArgCount > 0) { relBuilder.project(inputExprs, CompositeList.of( relBuilder.peek().getRowType().getFieldNames(), Collections.<String>nCopies(extraArgCount, null))); relBuilder.project(projList, oldAggRel.getRowType().getFieldNames()) .convert(oldAggRel.getRowType(), false); ruleCall.transformTo(relBuilder.build());