chore(query): optimize window sort (#16355)
* update

* update

* update

* update

* update

* Update column_binding.rs

* update

* update

* update

* update

* update

---------

Co-authored-by: Bohu <overred.shuttler@gmail.com>
sundy-li and BohuTANG authored Sep 1, 2024
1 parent 7cf6152 commit ad366d5
Showing 10 changed files with 182 additions and 48 deletions.
2 changes: 2 additions & 0 deletions src/query/expression/src/function.rs
@@ -404,6 +404,7 @@ impl FunctionRegistry {
candidates
}

// Note that if `additional_cast_rules` is not empty for a function, the default cast rules will not be used.
pub fn get_auto_cast_rules(&self, func_name: &str) -> &[(DataType, DataType)] {
self.additional_cast_rules
.get(func_name)
@@ -455,6 +456,7 @@ impl FunctionRegistry {
self.default_cast_rules.extend(default_cast_rules);
}

// Note that if `additional_cast_rules` is not empty, the default cast rules will not be used.
pub fn register_additional_cast_rules(
&mut self,
fn_name: &str,
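The comments above encode a precedence rule worth calling out: once a function registers any additional cast rules, the defaults are replaced, not extended. A minimal sketch of the assumed lookup, using simplified stand-in types rather than Databend's actual `FunctionRegistry`:

```rust
use std::collections::HashMap;

// Hypothetical, simplified stand-ins for the registry's rule tables.
type CastRule = (String, String); // (from_type, to_type)

struct Registry {
    default_cast_rules: Vec<CastRule>,
    additional_cast_rules: HashMap<String, Vec<CastRule>>,
}

impl Registry {
    // If a function has additional cast rules, the defaults are not merged
    // in -- the additional rules replace them entirely.
    fn get_auto_cast_rules(&self, func_name: &str) -> &[CastRule] {
        self.additional_cast_rules
            .get(func_name)
            .map(Vec::as_slice)
            .unwrap_or(&self.default_cast_rules)
    }
}
```

A consequence is that a function opting into `register_additional_cast_rules` must re-register any default rules it still wants, which is why the `slice`/`get` registration below includes `GENERAL_CAST_RULES` and `CAST_FROM_STRING_RULES` alongside `CAST_INT_TO_UINT64`.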
6 changes: 6 additions & 0 deletions src/query/functions/src/cast_rules.rs
@@ -54,6 +54,12 @@ pub fn register(registry: &mut FunctionRegistry) {
registry.register_additional_cast_rules(func_name, CAST_INT_TO_UINT64.iter().cloned());
}

for func_name in ["slice", "get"] {
registry.register_additional_cast_rules(func_name, GENERAL_CAST_RULES.iter().cloned());
registry.register_additional_cast_rules(func_name, CAST_FROM_STRING_RULES.iter().cloned());
registry.register_additional_cast_rules(func_name, CAST_INT_TO_UINT64.iter().cloned());
}

for func_name in ALL_COMP_FUNC_NAMES {
// Disable auto cast from strings, e.g., `1 < '1'`.
registry.register_additional_cast_rules(func_name, GENERAL_CAST_RULES.iter().cloned());
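The practical effect for `get` and `slice`: an index argument that is an integer expression over a column, not just a literal, can now be auto-cast to the unsigned index type. A rough model of the resolution step, assuming the array accessors are declared over `UInt64` indexes (simplified types, not the real `DataType`):

```rust
// Hypothetical, simplified model; the real rule tables map DataType pairs.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Ty {
    Int64,
    UInt64,
}

// Stand-in for the CAST_INT_TO_UINT64 rule set.
const CAST_INT_TO_UINT64: &[(Ty, Ty)] = &[(Ty::Int64, Ty::UInt64)];

fn can_auto_cast(rules: &[(Ty, Ty)], from: Ty, to: Ty) -> bool {
    from == to || rules.iter().any(|&(f, t)| f == from && t == to)
}

fn main() {
    // e.g. `get(col1, index - 7)` from the updated array tests in this commit:
    // the index expression is Int64 while the parameter is UInt64, so the call
    // only type-checks if the Int64 -> UInt64 rule is registered for `get`.
    assert!(can_auto_cast(CAST_INT_TO_UINT64, Ty::Int64, Ty::UInt64));
}
```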
20 changes: 18 additions & 2 deletions src/query/sql/src/planner/binder/column_binding.rs
@@ -19,11 +19,10 @@ use crate::IndexType;
use crate::Visibility;

// Please use `ColumnBindingBuilder` to construct a new `ColumnBinding`
-#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
+#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, Eq, PartialEq, Hash)]
pub struct ColumnBinding {
/// Database name of this `ColumnBinding` in current context
pub database_name: Option<String>,
/// Table name of this `ColumnBinding` in current context
pub table_name: Option<String>,
/// Column Position of this `ColumnBinding` in current context
pub column_position: Option<usize>,
@@ -41,6 +40,23 @@ pub struct ColumnBinding {
pub virtual_computed_expr: Option<String>,
}

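// Placeholder index for synthesized bindings that do not refer to a real bound column.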
const DUMMY_INDEX: usize = usize::MAX;
impl ColumnBinding {
pub fn new_dummy_column(name: String, data_type: Box<DataType>) -> Self {
ColumnBinding {
database_name: None,
table_name: None,
column_position: None,
table_index: None,
column_name: name,
index: DUMMY_INDEX,
data_type,
visibility: Visibility::Visible,
virtual_computed_expr: None,
}
}
}

impl ColumnIndex for ColumnBinding {}

pub struct ColumnBindingBuilder {
19 changes: 19 additions & 0 deletions src/query/sql/src/planner/binder/window.rs
@@ -262,11 +262,25 @@ pub struct WindowInfo {
pub window_functions_map: HashMap<String, usize>,
}

impl WindowInfo {
pub fn reorder(&mut self) {
self.window_functions
.sort_by(|a, b| b.order_by_items.len().cmp(&a.order_by_items.len()));

self.window_functions_map.clear();
for (i, window) in self.window_functions.iter().enumerate() {
self.window_functions_map
.insert(window.display_name.clone(), i);
}
}
}

#[derive(Clone, PartialEq, Eq, Debug)]
pub struct WindowFunctionInfo {
pub span: Span,
pub index: IndexType,
pub func: WindowFuncType,
pub display_name: String,
pub arguments: Vec<ScalarItem>,
pub partition_by_items: Vec<ScalarItem>,
pub order_by_items: Vec<WindowOrderByInfo>,
@@ -375,6 +389,7 @@ impl<'a> WindowRewriter<'a> {

// resolve order by
let mut order_by_items = vec![];

for (i, order) in window.order_by.iter().enumerate() {
let mut order_expr = order.expr.clone();
let mut aggregate_rewriter = self.as_window_aggregate_rewriter();
@@ -406,6 +421,7 @@ impl<'a> WindowRewriter<'a> {
let window_info = WindowFunctionInfo {
span: window.span,
index,
display_name: window.display_name.clone(),
func: func.clone(),
arguments: window_args,
partition_by_items,
@@ -421,6 +437,9 @@
window_infos.window_functions.len() - 1,
);

// We want the windows with more ORDER BY items to be resolved first,
// so that some useless ORDER BY items can be eliminated.
window_infos.reorder();
let replaced_window = WindowFunc {
span: window.span,
display_name: window.display_name.clone(),
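A self-contained sketch of what the `reorder` step above does: windows with longer ORDER BY lists move to the front of `window_functions`, and the name-to-position map is rebuilt to match. The types here are simplified stand-ins for `WindowFunctionInfo`:

```rust
use std::collections::HashMap;

// Simplified stand-in for WindowFunctionInfo.
struct WinFn {
    display_name: String,
    order_by_len: usize,
}

fn reorder(window_functions: &mut [WinFn], map: &mut HashMap<String, usize>) {
    // Descending by number of ORDER BY items, as in WindowInfo::reorder.
    window_functions.sort_by(|a, b| b.order_by_len.cmp(&a.order_by_len));
    map.clear();
    for (i, w) in window_functions.iter().enumerate() {
        map.insert(w.display_name.clone(), i);
    }
}

fn main() {
    let mut windows = vec![
        WinFn { display_name: "sum(...) over (partition by k)".into(), order_by_len: 0 },
        WinFn { display_name: "rank() over (partition by k order by v)".into(), order_by_len: 1 },
    ];
    let mut map = HashMap::new();
    reorder(&mut windows, &mut map);
    // The window with an ORDER BY is planned first, so the order-less window
    // can reuse its partitioning and skip a redundant sort.
    assert_eq!(map["rank() over (partition by k order by v)"], 0);
}
```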
@@ -57,7 +57,14 @@ impl Rule for RuleEliminateSort {

if !sort.window_partition.is_empty() {
if let Some((partition, ordering)) = &prop.partition_orderings {
-if partition == &sort.window_partition && ordering == &sort.items {
+// The child must provide the same partitioning; the sort can then be
+// eliminated when the orderings match exactly, or when the current sort
+// orders by nothing beyond its partition keys.
+// e.g.: explain select number, sum(number - 1) over (partition by number % 3 order by number + 1),
+// avg(number) over (partition by number % 3 order by number + 1)
+// from numbers(50);
+if partition == &sort.window_partition
+&& (ordering == &sort.items || sort.sort_items_exclude_partition().is_empty())
+{
state.add_result(input.clone());
return Ok(());
}
15 changes: 15 additions & 0 deletions src/query/sql/src/planner/plans/sort.rs
@@ -45,6 +45,21 @@ pub struct Sort {
pub window_partition: Vec<ScalarItem>,
}

impl Sort {
pub fn sort_items_exclude_partition(&self) -> Vec<SortItem> {
self.items
.iter()
.filter(|item| {
!self
.window_partition
.iter()
.any(|partition| partition.index == item.index)
})
.cloned()
.collect()
}
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SortItem {
pub index: IndexType,
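Putting `sort_items_exclude_partition` together with the relaxed check in `RuleEliminateSort` above: a sort feeding a window can be dropped when the child already provides the same partitioning and the sort orders by nothing beyond its partition keys. A simplified, index-based sketch of the assumed condition (not the planner's real `SortItem`/`ScalarItem` types):

```rust
// Simplified stand-ins for the planner's sort and partition items.
#[derive(Clone, PartialEq, Debug)]
struct SortItem {
    index: usize,
    asc: bool,
}

struct Sort {
    items: Vec<SortItem>,
    window_partition: Vec<usize>, // indexes of the partition-by columns
}

impl Sort {
    // Mirrors Sort::sort_items_exclude_partition: keep only sort items that
    // are not already partition keys.
    fn sort_items_exclude_partition(&self) -> Vec<SortItem> {
        self.items
            .iter()
            .filter(|item| !self.window_partition.contains(&item.index))
            .cloned()
            .collect()
    }
}

// Assumed shape of the RuleEliminateSort test: same partitioning, and either
// an identical ordering or no ordering beyond the partition keys.
fn can_eliminate(sort: &Sort, child_partition: &[usize], child_ordering: &[SortItem]) -> bool {
    child_partition == sort.window_partition.as_slice()
        && (child_ordering == sort.items.as_slice()
            || sort.sort_items_exclude_partition().is_empty())
}

fn main() {
    // `sum(...) over (partition by number % 3)`: the sort carries only the
    // partition key, so it is redundant once the child is already
    // partitioned by that key.
    let sort = Sort {
        items: vec![SortItem { index: 1, asc: true }],
        window_partition: vec![1],
    };
    assert!(can_eliminate(&sort, &[1], &[]));
}
```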
45 changes: 12 additions & 33 deletions src/query/sql/src/planner/semantic/lowering.rs
@@ -24,13 +24,11 @@ use databend_common_expression::Expr;
use databend_common_expression::RawExpr;
use databend_common_functions::BUILTIN_FUNCTIONS;

-use crate::binder::ColumnBindingBuilder;
use crate::plans::ScalarExpr;
use crate::ColumnBinding;
use crate::ColumnEntry;
use crate::IndexType;
use crate::Metadata;
-use crate::Visibility;

pub trait TypeProvider<ColumnID: ColumnIndex> {
fn get_type(&self, column_id: &ColumnID) -> Result<DataType>;
@@ -197,25 +195,19 @@ impl ScalarExpr {
},
ScalarExpr::WindowFunction(win) => RawExpr::ColumnRef {
span: None,
-id: ColumnBindingBuilder::new(
+id: ColumnBinding::new_dummy_column(
win.display_name.clone(),
-usize::MAX,
Box::new(win.func.return_type()),
-Visibility::Visible,
-)
-.build(),
+),
data_type: win.func.return_type(),
display_name: win.display_name.clone(),
},
ScalarExpr::AggregateFunction(agg) => RawExpr::ColumnRef {
span: None,
-id: ColumnBindingBuilder::new(
+id: ColumnBinding::new_dummy_column(
agg.display_name.clone(),
-usize::MAX,
Box::new((*agg.return_type).clone()),
-Visibility::Visible,
-)
-.build(),
+),
data_type: (*agg.return_type).clone(),
display_name: agg.display_name.clone(),
},
@@ -241,19 +233,19 @@
},
ScalarExpr::SubqueryExpr(subquery) => RawExpr::ColumnRef {
span: subquery.span,
-id: new_dummy_column(subquery.data_type()),
+id: ColumnBinding::new_dummy_column(
+"DUMMY".to_string(),
+Box::new(subquery.data_type()),
+),
data_type: subquery.data_type(),
display_name: "DUMMY".to_string(),
},
ScalarExpr::UDFCall(udf) => RawExpr::ColumnRef {
span: None,
-id: ColumnBindingBuilder::new(
+id: ColumnBinding::new_dummy_column(
udf.display_name.clone(),
-usize::MAX,
Box::new((*udf.return_type).clone()),
-Visibility::Visible,
-)
-.build(),
+),
data_type: (*udf.return_type).clone(),
display_name: udf.display_name.clone(),
},
@@ -265,13 +257,10 @@

ScalarExpr::AsyncFunctionCall(async_func) => RawExpr::ColumnRef {
span: None,
-id: ColumnBindingBuilder::new(
+id: ColumnBinding::new_dummy_column(
async_func.display_name.clone(),
-usize::MAX,
Box::new(async_func.return_type.as_ref().clone()),
-Visibility::Visible,
-)
-.build(),
+),
data_type: async_func.return_type.as_ref().clone(),
display_name: async_func.display_name.clone(),
},
@@ -286,13 +275,3 @@
matches!(self, ScalarExpr::BoundColumnRef(_))
}
}

-fn new_dummy_column(data_type: DataType) -> ColumnBinding {
-ColumnBindingBuilder::new(
-"DUMMY".to_string(),
-usize::MAX,
-Box::new(data_type),
-Visibility::Visible,
-)
-.build()
-}
88 changes: 88 additions & 0 deletions tests/sqllogictests/suites/mode/standalone/explain/window.test
@@ -581,6 +581,94 @@ CompoundBlockOperator(Project) × 1
CompoundBlockOperator(Map) × 1
NumbersSourceTransform × 1

# A window with no sort items beyond its partition keys can reuse the partitioning of a window over the same keys that does have sort items
## explain select a, sum(number - 1) over (partition by number % 3) from (select number, rank()over (partition by number % 3 order by number + 1) a
## from numbers(50)) t(number, a);
query T
explain select a, sum(number - 1) over (partition by number % 3) from (select number, rank()over (partition by number % 3 order by number + 1) a
from numbers(50));
----
Window
├── output columns: [numbers.number (#0), rank_part_0 (#1), rank() OVER ( PARTITION BY number % 3 ORDER BY number + 1 ) (#3), sum_arg_0 (#4), rank_part_0 (#1), sum(number - 1) OVER ( PARTITION BY number % 3 ) (#5)]
├── aggregate function: [sum(sum_arg_0)]
├── partition by: [rank_part_0]
├── order by: []
├── frame: [Range: Preceding(None) ~ Following(None)]
└── EvalScalar
├── output columns: [numbers.number (#0), rank_part_0 (#1), rank() OVER ( PARTITION BY number % 3 ORDER BY number + 1 ) (#3), sum_arg_0 (#4), rank_part_0 (#1)]
├── expressions: [numbers.number (#0) - 1, numbers.number (#0) % 3]
├── estimated rows: 50.00
└── Window
├── output columns: [numbers.number (#0), rank_part_0 (#1), rank_order_0 (#2), rank() OVER ( PARTITION BY number % 3 ORDER BY number + 1 ) (#3)]
├── aggregate function: [rank]
├── partition by: [rank_part_0]
├── order by: [rank_order_0]
├── frame: [Range: Preceding(None) ~ CurrentRow]
└── WindowPartition
├── output columns: [numbers.number (#0), rank_part_0 (#1), rank_order_0 (#2)]
├── hash keys: [rank_part_0]
├── estimated rows: 50.00
└── EvalScalar
├── output columns: [numbers.number (#0), rank_part_0 (#1), rank_order_0 (#2)]
├── expressions: [numbers.number (#0) % 3, numbers.number (#0) + 1]
├── estimated rows: 50.00
└── TableScan
├── table: default.system.numbers
├── output columns: [number (#0)]
├── read rows: 50
├── read size: < 1 KiB
├── partitions total: 1
├── partitions scanned: 1
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 50.00

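# Note: only one WindowPartition operator appears above. The rank() window,
# which has ORDER BY items, is resolved first; the outer sum() window then
# orders by nothing beyond its partition key, so its sort is eliminated and
# it reports `order by: []`.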
query T
explain select number, avg(number) over (partition by number % 3), rank() over (partition by number % 3 order by number + 1), sum(number) over (partition by number % 3)
from numbers(50);
----
Window
├── output columns: [numbers.number (#0), avg_part_0 (#1), rank() OVER ( PARTITION BY number % 3 ORDER BY number + 1 ) (#4), avg(number) OVER ( PARTITION BY number % 3 ) (#2), avg_part_0 (#1), sum(number) OVER ( PARTITION BY number % 3 ) (#5)]
├── aggregate function: [sum(number)]
├── partition by: [avg_part_0]
├── order by: []
├── frame: [Range: Preceding(None) ~ Following(None)]
└── EvalScalar
├── output columns: [numbers.number (#0), avg_part_0 (#1), rank() OVER ( PARTITION BY number % 3 ORDER BY number + 1 ) (#4), avg(number) OVER ( PARTITION BY number % 3 ) (#2), avg_part_0 (#1)]
├── expressions: [numbers.number (#0) % 3]
├── estimated rows: 50.00
└── Window
├── output columns: [numbers.number (#0), avg_part_0 (#1), rank() OVER ( PARTITION BY number % 3 ORDER BY number + 1 ) (#4), avg_part_0 (#1), avg(number) OVER ( PARTITION BY number % 3 ) (#2)]
├── aggregate function: [avg(number)]
├── partition by: [avg_part_0]
├── order by: []
├── frame: [Range: Preceding(None) ~ Following(None)]
└── EvalScalar
├── output columns: [numbers.number (#0), avg_part_0 (#1), rank() OVER ( PARTITION BY number % 3 ORDER BY number + 1 ) (#4), avg_part_0 (#1)]
├── expressions: [numbers.number (#0) % 3]
├── estimated rows: 50.00
└── Window
├── output columns: [numbers.number (#0), avg_part_0 (#1), rank_order_0 (#3), rank() OVER ( PARTITION BY number % 3 ORDER BY number + 1 ) (#4)]
├── aggregate function: [rank]
├── partition by: [avg_part_0]
├── order by: [rank_order_0]
├── frame: [Range: Preceding(None) ~ CurrentRow]
└── WindowPartition
├── output columns: [numbers.number (#0), avg_part_0 (#1), rank_order_0 (#3)]
├── hash keys: [avg_part_0]
├── estimated rows: 50.00
└── EvalScalar
├── output columns: [numbers.number (#0), avg_part_0 (#1), rank_order_0 (#3)]
├── expressions: [numbers.number (#0) % 3, numbers.number (#0) + 1]
├── estimated rows: 50.00
└── TableScan
├── table: default.system.numbers
├── output columns: [number (#0)]
├── read rows: 50
├── read size: < 1 KiB
├── partitions total: 1
├── partitions scanned: 1
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 50.00

statement ok
DROP DATABASE test_explain_window;
@@ -19,23 +19,23 @@ statement ok
DROP TABLE IF EXISTS t

statement ok
-create table t(col1 Array(Int Null) not null, col2 Array(String) not null, col3 Array(Date) not null, col4 Array(Timestamp) not null, col5 Array(Array(Int null)) not null)
+create table t(index int default 10, col1 Array(Int Null) not null, col2 Array(String) not null, col3 Array(Date) not null, col4 Array(Timestamp) not null, col5 Array(Array(Int null)) not null)

statement ok
-insert into t values([1,2,3,3],['x','x','y','z'], ['2022-02-02'], ['2023-01-01 02:00:01'], [[1,2],[],[null]])
+insert into t(col1, col2, col3, col4, col5) values([1,2,3,3],['x','x','y','z'], ['2022-02-02'], ['2023-01-01 02:00:01'], [[1,2],[],[null]])

query IIII
select length(col1), length(col2), length(col3), length(col4) from t
----
4 4 1 1

query ITT
-select get(col1, 3), get(col2, 2), get(col3, 1) from t
+select get(col1, index - 7), get(col2, index - 8), get(col3, index - 9) from t
----
3 x 2022-02-02

query TTTT
-select slice(col1, 1), slice(col1, 2, 3), slice(col2, 2), slice(col2, 3, 3) from t
+select slice(col1, index - 9), slice(col1, 2, 3), slice(col2, 2), slice(col2, 3, 3) from t
----
[1,2,3,3] [2,3] ['x','y','z'] ['y']
