private Optional<PlanNodeStatsEstimate> calculate(FilterNode filterNode, SemiJoinNode semiJoinNode, StatsProvider statsProvider, Session session, TypeProvider types) { PlanNodeStatsEstimate sourceStats = statsProvider.getStats(semiJoinNode.getSource()); PlanNodeStatsEstimate filteringSourceStats = statsProvider.getStats(semiJoinNode.getFilteringSource()); Symbol filteringSourceJoinSymbol = semiJoinNode.getFilteringSourceJoinSymbol(); Symbol sourceJoinSymbol = semiJoinNode.getSourceJoinSymbol(); Optional<SemiJoinOutputFilter> semiJoinOutputFilter = extractSemiJoinOutputFilter(filterNode.getPredicate(), semiJoinNode.getSemiJoinOutput()); if (!semiJoinOutputFilter.isPresent()) { return Optional.empty(); } PlanNodeStatsEstimate semiJoinStats; if (semiJoinOutputFilter.get().isNegated()) { semiJoinStats = computeAntiJoin(sourceStats, filteringSourceStats, sourceJoinSymbol, filteringSourceJoinSymbol); } else { semiJoinStats = computeSemiJoin(sourceStats, filteringSourceStats, sourceJoinSymbol, filteringSourceJoinSymbol); } if (semiJoinStats.isOutputRowCountUnknown()) { return Optional.of(PlanNodeStatsEstimate.unknown()); } // apply remaining predicate PlanNodeStatsEstimate filteredStats = filterStatsCalculator.filterStats(semiJoinStats, semiJoinOutputFilter.get().getRemainingPredicate(), session, types); if (filteredStats.isOutputRowCountUnknown()) { return Optional.of(semiJoinStats.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT)); } return Optional.of(filteredStats); }
/**
 * Estimates semi-join statistics by delegating to {@code compute} with a callback
 * that yields the smaller of the two join symbols' distinct values counts.
 *
 * <p>NOTE(review): the callback result is presumably the number of distinct source
 * join values that survive the semi join — confirm against {@code compute}.
 *
 * @param sourceStats statistics of the semi-join source
 * @param filteringSourceStats statistics of the filtering source
 * @param sourceJoinSymbol join symbol on the source side
 * @param filteringSourceJoinSymbol join symbol on the filtering-source side
 * @return the estimate produced by {@code compute}
 */
public static PlanNodeStatsEstimate computeSemiJoin(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate filteringSourceStats, Symbol sourceJoinSymbol, Symbol filteringSourceJoinSymbol)
{
    return compute(
            sourceStats,
            filteringSourceStats,
            sourceJoinSymbol,
            filteringSourceJoinSymbol,
            (sourceSide, filteringSide) -> {
                double filteringDistinctValues = filteringSide.getDistinctValuesCount();
                double sourceDistinctValues = sourceSide.getDistinctValuesCount();
                return min(filteringDistinctValues, sourceDistinctValues);
            });
}
assertThat(computeSemiJoin(inputStatistics, inputStatistics, x, w)) .symbolStats(x, stats -> stats .lowValue(xStats.getLowValue()) assertThat(computeSemiJoin(inputStatistics, inputStatistics, x, u)) .symbolStats(x, stats -> stats .lowValue(xStats.getLowValue()) assertThat(computeSemiJoin(inputStatistics, inputStatistics, unknown, u)) .symbolStats(unknown, stats -> stats .nullsFraction(0) assertThat(computeSemiJoin(inputStatistics, inputStatistics, x, unknown)) .symbolStats(x, stats -> stats .nullsFraction(0) assertThat(computeSemiJoin(inputStatistics, inputStatistics, emptyRange, emptyRange)) .outputRowsCount(0); assertThat(computeSemiJoin(inputStatistics, inputStatistics, fractionalNdv, fractionalNdv)) .outputRowsCount(1000) .symbolStats(fractionalNdv, stats -> stats
assertThat(computeAntiJoin(inputStatistics, inputStatistics, u, x)) .symbolStats(u, stats -> stats .lowValue(uStats.getLowValue()) assertThat(computeAntiJoin(inputStatistics, inputStatistics, x, u)) .symbolStats(x, stats -> stats .lowValue(xStats.getLowValue()) assertThat(computeAntiJoin(inputStatistics, inputStatistics, unknown, u)) .symbolStats(unknown, stats -> stats .nullsFraction(0) assertThat(computeAntiJoin(inputStatistics, inputStatistics, x, unknown)) .symbolStats(x, stats -> stats .nullsFraction(0) assertThat(computeAntiJoin(inputStatistics, inputStatistics, emptyRange, emptyRange)) .outputRowsCount(0); assertThat(computeAntiJoin(inputStatistics, inputStatistics, fractionalNdv, fractionalNdv)) .outputRowsCount(500) .symbolStats(fractionalNdv, stats -> stats
/**
 * Estimates anti-join statistics by delegating to {@code compute} with a callback
 * that yields the larger of two values: the source join symbol's distinct values
 * count scaled by {@code MIN_ANTI_JOIN_FILTER_COEFFICIENT}, and the source count
 * minus the filtering-source count.
 *
 * <p>The scaled term acts as a lower bound on the callback result, so the
 * subtraction cannot push it below that fraction of the source count.
 *
 * @param sourceStats statistics of the anti-join source
 * @param filteringSourceStats statistics of the filtering source
 * @param sourceJoinSymbol join symbol on the source side
 * @param filteringSourceJoinSymbol join symbol on the filtering-source side
 * @return the estimate produced by {@code compute}
 */
public static PlanNodeStatsEstimate computeAntiJoin(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate filteringSourceStats, Symbol sourceJoinSymbol, Symbol filteringSourceJoinSymbol)
{
    return compute(
            sourceStats,
            filteringSourceStats,
            sourceJoinSymbol,
            filteringSourceJoinSymbol,
            (sourceSide, filteringSide) -> {
                double sourceDistinctValues = sourceSide.getDistinctValuesCount();
                double lowerBound = sourceDistinctValues * MIN_ANTI_JOIN_FILTER_COEFFICIENT;
                double unmatchedDistinctValues = sourceDistinctValues - filteringSide.getDistinctValuesCount();
                return max(lowerBound, unmatchedDistinctValues);
            });
}