Skip to content

Commit

Permalink
Fix bug in optimizing a nested count (apache#8459)
Browse files Browse the repository at this point in the history
* Fix nested count optimization

* fmt

* extend comment

* Clippy

* Update datafusion/optimizer/src/optimize_projections.rs

Co-authored-by: Liang-Chi Hsieh <viirya@gmail.com>

* Add sqllogictests

* Fmt

---------

Co-authored-by: Liang-Chi Hsieh <viirya@gmail.com>
  • Loading branch information
2 people authored and appletreeisyellow committed Dec 15, 2023
1 parent d409d07 commit 5f69fc7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 4 deletions.
40 changes: 36 additions & 4 deletions datafusion/optimizer/src/optimize_projections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ fn optimize_projections(
let new_group_bys = aggregate.group_expr.clone();

// Only use absolutely necessary aggregate expressions required by parent.
let new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs);
let mut new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs);
let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter());
let necessary_indices =
indices_referred_by_exprs(&aggregate.input, all_exprs_iter)?;
Expand All @@ -213,6 +213,16 @@ fn optimize_projections(
let (aggregate_input, _is_added) =
add_projection_on_top_if_helpful(aggregate_input, necessary_exprs, true)?;

// Aggregate always needs at least one aggregate expression.
// With a nested count, we don't require any column as input, but we still need to create a correct aggregate.
// The aggregate may be optimized out later (`select count(*) from (select count(*) from [...])` always returns 1).
if new_aggr_expr.is_empty()
&& new_group_bys.is_empty()
&& !aggregate.aggr_expr.is_empty()
{
new_aggr_expr = vec![aggregate.aggr_expr[0].clone()];
}

// Create new aggregate plan with updated input, and absolutely necessary fields.
return Aggregate::try_new(
Arc::new(aggregate_input),
Expand Down Expand Up @@ -857,10 +867,11 @@ fn rewrite_projection_given_requirements(
#[cfg(test)]
mod tests {
use crate::optimize_projections::OptimizeProjections;
use datafusion_common::Result;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Result, TableReference};
use datafusion_expr::{
binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan,
Operator,
binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder,
table_scan, Expr, LogicalPlan, Operator,
};
use std::sync::Arc;

Expand Down Expand Up @@ -909,4 +920,25 @@ mod tests {
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}
#[test]
fn test_nested_count() -> Result<()> {
    // Neither aggregate groups by anything, so both share one empty GROUP BY list.
    let no_group_by = Vec::<Expr>::new();
    let input_schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]);

    // Inner `COUNT(1)` directly over the unnamed table scan.
    let inner_count = table_scan(TableReference::none(), &input_schema, None)
        .unwrap()
        .aggregate(no_group_by.clone(), vec![count(lit(1))])
        .unwrap();

    // Outer `COUNT(1)` stacked on top of the inner aggregate.
    let plan = inner_count
        .aggregate(no_group_by, vec![count(lit(1))])
        .unwrap()
        .build()
        .unwrap();

    // The optimized plan must still carry the inner COUNT, with an empty
    // projection between the two aggregates and on the table scan.
    let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\
    \n Projection: \
    \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\
    \n TableScan: ?table? projection=[]";
    assert_optimized_plan_equal(&plan, expected)
}
}
13 changes: 13 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -3199,3 +3199,16 @@ FROM my_data
GROUP BY dummy
----
text1, text1, text1


# Queries with nested count(*)

query I
select count(*) from (select count(*) from (select 1));
----
1

query I
select count(*) from (select count(*) a, count(*) b from (select 1));
----
1

0 comments on commit 5f69fc7

Please sign in to comment.