Skip to content

Commit

Permalink
fix(query): disallow push down filters that use columns in window order by columns or function argument columns (databendlabs#17353)
Browse files Browse the repository at this point in the history

fix
  • Loading branch information
forsaken628 authored Jan 22, 2025
1 parent 17226b8 commit 97e2381
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;
use std::sync::Arc;

use crate::optimizer::extract::Matcher;
Expand Down Expand Up @@ -76,54 +77,55 @@ impl Rule for RulePushDownFilterWindow {
s_expr: &SExpr,
state: &mut TransformResult,
) -> databend_common_exception::Result<()> {
let filter: Filter = s_expr.plan().clone().try_into()?;
let Filter { predicates } = s_expr.plan().clone().try_into()?;
let window_expr = s_expr.child(0)?;
let window: Window = window_expr.plan().clone().try_into()?;
let partition_by_columns = window.partition_by_columns()?;
let allowed = window.partition_by_columns()?;
let rejected = HashSet::from_iter(
window
.order_by_columns()?
.into_iter()
.chain(window.function.used_columns()),
);

let mut pushed_down_predicates = vec![];
let mut remaining_predicates = vec![];
for predicate in filter.predicates.into_iter() {
let predicate_used_columns = predicate.used_columns();
if predicate_used_columns.is_subset(&partition_by_columns) {
pushed_down_predicates.push(predicate);
} else {
remaining_predicates.push(predicate)
}
let (pushed_down, remaining): (Vec<_>, Vec<_>) =
predicates.into_iter().partition(|predicate| {
let used = predicate.used_columns();
used.is_subset(&allowed) && used.is_disjoint(&rejected)
});
if pushed_down.is_empty() {
return Ok(());
}

if !pushed_down_predicates.is_empty() {
let pushed_down_filter = Filter {
predicates: pushed_down_predicates,
let pushed_down_filter = Filter {
predicates: pushed_down,
};
let result = if remaining.is_empty() {
SExpr::create_unary(
Arc::new(window.into()),
Arc::new(SExpr::create_unary(
Arc::new(pushed_down_filter.into()),
Arc::new(window_expr.child(0)?.clone()),
)),
)
} else {
let remaining_filter = Filter {
predicates: remaining,
};
let result = if remaining_predicates.is_empty() {
SExpr::create_unary(
let mut s_expr = SExpr::create_unary(
Arc::new(remaining_filter.into()),
Arc::new(SExpr::create_unary(
Arc::new(window.into()),
Arc::new(SExpr::create_unary(
Arc::new(pushed_down_filter.into()),
Arc::new(window_expr.child(0)?.clone()),
)),
)
} else {
let remaining_filter = Filter {
predicates: remaining_predicates,
};
let mut s_expr = SExpr::create_unary(
Arc::new(remaining_filter.into()),
Arc::new(SExpr::create_unary(
Arc::new(window.into()),
Arc::new(SExpr::create_unary(
Arc::new(pushed_down_filter.into()),
Arc::new(window_expr.child(0)?.clone()),
)),
)),
);
s_expr.set_applied_rule(&self.id);
s_expr
};
state.add_result(result);
}

)),
);
s_expr.set_applied_rule(&self.id);
s_expr
};
state.add_result(result);
Ok(())
}

Expand Down
34 changes: 20 additions & 14 deletions src/query/sql/src/planner/plans/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,20 @@ impl Window {

used_columns.insert(self.index);
used_columns.extend(self.function.used_columns());
used_columns.extend(self.arguments_columns()?);
used_columns.extend(self.partition_by_columns()?);
used_columns.extend(self.order_by_columns()?);

for arg in self.arguments.iter() {
used_columns.insert(arg.index);
used_columns.extend(arg.scalar.used_columns())
}

for part in self.partition_by.iter() {
used_columns.insert(part.index);
used_columns.extend(part.scalar.used_columns())
}
Ok(used_columns)
}

for sort in self.order_by.iter() {
used_columns.insert(sort.order_by_item.index);
used_columns.extend(sort.order_by_item.scalar.used_columns())
pub fn arguments_columns(&self) -> Result<ColumnSet> {
let mut col_set = ColumnSet::new();
for arg in self.arguments.iter() {
col_set.insert(arg.index);
col_set.extend(arg.scalar.used_columns())
}

Ok(used_columns)
Ok(col_set)
}

// `Window.partition_by_columns` used in `RulePushDownFilterWindow` only consider `partition_by` field,
Expand All @@ -101,6 +98,15 @@ impl Window {
}
Ok(col_set)
}

/// Returns the set of columns referenced by this window's ORDER BY items:
/// each sort item's own column index plus every column used inside the
/// item's scalar expression.
pub fn order_by_columns(&self) -> Result<ColumnSet> {
    let mut columns = ColumnSet::new();
    for item in self.order_by.iter().map(|sort| &sort.order_by_item) {
        columns.insert(item.index);
        columns.extend(item.scalar.used_columns());
    }
    Ok(columns)
}
}

impl Operator for Window {
Expand Down
75 changes: 75 additions & 0 deletions tests/sqllogictests/suites/mode/standalone/explain/window.test
Original file line number Diff line number Diff line change
Expand Up @@ -733,5 +733,80 @@ Window
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 50.00

query T
explain with test as ( select number % 10 as id, number as full_matched, max(number) OVER ( PARTITION BY id ) max from numbers(1000)) select * from test where full_matched = 3;
----
EvalScalar
├── output columns: [numbers.number (#0), max(number) OVER (PARTITION BY id) (#2), id (#3)]
├── expressions: [numbers.number (#0) % 10]
├── estimated rows: 0.40
└── Filter
├── output columns: [numbers.number (#0), max(number) OVER (PARTITION BY id) (#2)]
├── filters: [numbers.number (#0) = 3]
├── estimated rows: 0.40
└── Window
├── output columns: [numbers.number (#0), max_part_0 (#1), max(number) OVER (PARTITION BY id) (#2)]
├── aggregate function: [max(number)]
├── partition by: [max_part_0]
├── order by: []
├── frame: [Range: Preceding(None) ~ Following(None)]
└── WindowPartition
├── output columns: [numbers.number (#0), max_part_0 (#1)]
├── hash keys: [max_part_0]
├── estimated rows: 1000.00
└── EvalScalar
├── output columns: [numbers.number (#0), max_part_0 (#1)]
├── expressions: [numbers.number (#0) % 10]
├── estimated rows: 1000.00
└── TableScan
├── table: default.system.numbers
├── output columns: [number (#0)]
├── read rows: 1000
├── read size: 7.81 KiB
├── partitions total: 1
├── partitions scanned: 1
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 1000.00

query T
explain with test as (select number % 10 as id, number as full_matched from numbers(1000) QUALIFY row_number() OVER (PARTITION BY id ORDER BY number DESC)=1) select full_matched, count() from test group by full_matched having full_matched = 3;
----
AggregateFinal
├── output columns: [count() (#3), numbers.number (#0)]
├── group by: [number]
├── aggregate functions: [count()]
├── estimated rows: 0.40
└── AggregatePartial
├── group by: [number]
├── aggregate functions: [count()]
├── estimated rows: 0.40
└── Filter
├── output columns: [numbers.number (#0)]
├── filters: [numbers.number (#0) = 3, row_number() OVER (PARTITION BY id ORDER BY number DESC) (#2) = 1]
├── estimated rows: 0.40
└── Window
├── output columns: [numbers.number (#0), id (#1), row_number() OVER (PARTITION BY id ORDER BY number DESC) (#2)]
├── aggregate function: [row_number]
├── partition by: [id]
├── order by: [number]
├── frame: [Range: Preceding(None) ~ CurrentRow]
└── WindowPartition
├── output columns: [numbers.number (#0), id (#1)]
├── hash keys: [id]
├── estimated rows: 1000.00
└── EvalScalar
├── output columns: [numbers.number (#0), id (#1)]
├── expressions: [numbers.number (#0) % 10]
├── estimated rows: 1000.00
└── TableScan
├── table: default.system.numbers
├── output columns: [number (#0)]
├── read rows: 1000
├── read size: 7.81 KiB
├── partitions total: 1
├── partitions scanned: 1
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 1000.00

statement ok
DROP DATABASE test_explain_window;

0 comments on commit 97e2381

Please sign in to comment.