Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix predicate pushdown for .list.(get|gather) #17511

Merged
merged 3 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};

#[cfg(feature = "dtype-array")]
pub(super) use array::ArrayFunction;
pub(crate) use array::ArrayFunction;
#[cfg(feature = "cov")]
pub(crate) use correlation::CorrelationMethod;
#[cfg(feature = "fused")]
pub(crate) use fused::FusedOperator;
pub(super) use list::ListFunction;
pub(crate) use list::ListFunction;
use polars_core::prelude::*;
#[cfg(feature = "random")]
pub(crate) use random::RandomMethod;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,6 @@ pub(super) fn process_join(
let mut filter_left = false;
let mut filter_right = false;

debug_assert_aexpr_allows_predicate_pushdown(predicate.node(), expr_arena);

if !block_pushdown_left && check_input_node(predicate.node(), &schema_left, expr_arena) {
insert_and_combine_predicate(&mut pushdown_left, &predicate, expr_arena);
filter_left = true;
Expand Down
44 changes: 23 additions & 21 deletions crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl<'a> PredicatePushDown<'a> {
let input = inputs[inputs.len() - 1];

let (eligibility, alias_rename_map) =
pushdown_eligibility(&exprs, &acc_predicates, expr_arena)?;
pushdown_eligibility(&exprs, &[], &acc_predicates, expr_arena)?;

let local_predicates = match eligibility {
PushdownEligibility::Full => vec![],
Expand Down Expand Up @@ -265,22 +265,28 @@ impl<'a> PredicatePushDown<'a> {
let tmp_key = Arc::<str>::from(&*temporary_unique_key(&acc_predicates));
acc_predicates.insert(tmp_key.clone(), predicate.clone());

let local_predicates =
match pushdown_eligibility(&[], &acc_predicates, expr_arena)?.0 {
PushdownEligibility::Full => vec![],
PushdownEligibility::Partial { to_local } => {
let mut out = Vec::with_capacity(to_local.len());
for key in to_local {
out.push(acc_predicates.remove(&key).unwrap());
}
out
},
PushdownEligibility::NoPushdown => {
let out = acc_predicates.drain().map(|t| t.1).collect();
acc_predicates.clear();
out
},
};
let local_predicates = match pushdown_eligibility(
&[],
&[(tmp_key.clone(), predicate.clone())],
&acc_predicates,
expr_arena,
)?
.0
{
PushdownEligibility::Full => vec![],
PushdownEligibility::Partial { to_local } => {
let mut out = Vec::with_capacity(to_local.len());
for key in to_local {
out.push(acc_predicates.remove(&key).unwrap());
}
out
},
PushdownEligibility::NoPushdown => {
let out = acc_predicates.drain().map(|t| t.1).collect();
acc_predicates.clear();
out
},
};

if let Some(predicate) = acc_predicates.remove(&tmp_key) {
insert_and_combine_predicate(&mut acc_predicates, &predicate, expr_arena);
Expand Down Expand Up @@ -327,10 +333,6 @@ impl<'a> PredicatePushDown<'a> {
file_options: options,
output_schema,
} => {
for e in acc_predicates.values() {
debug_assert_aexpr_allows_predicate_pushdown(e.node(), expr_arena);
}

let local_predicates = match &scan_type {
#[cfg(feature = "parquet")]
FileScan::Parquet { .. } => vec![],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ where
local_predicates
}

/// Extends a stack of nodes with new nodes from `ae` (with some filtering), to support traversing
/// an expression tree to check predicate PD eligibility. Generally called repeatedly with the same
/// stack until all nodes are exhausted.
fn check_and_extend_predicate_pd_nodes(
stack: &mut Vec<Node>,
ae: &AExpr,
Expand All @@ -148,6 +151,22 @@ fn check_and_extend_predicate_pd_nodes(
// rely on the height of the dataframe at this level and thus need
// to block pushdown.
AExpr::Literal(lit) => !lit.projects_as_scalar(),
// Rows that go OOB on get/gather may be filtered out in earlier operations,
// so we don't push these down.
AExpr::Function {
function: FunctionExpr::ListExpr(ListFunction::Get(false)),
..
} => true,
#[cfg(feature = "list_gather")]
AExpr::Function {
function: FunctionExpr::ListExpr(ListFunction::Gather(false)),
..
} => true,
#[cfg(feature = "dtype-array")]
AExpr::Function {
function: FunctionExpr::ArrayExpr(ArrayFunction::Get(false)),
..
} => true,
ae => ae.groups_sensitive(),
} {
false
Expand Down Expand Up @@ -185,31 +204,6 @@ fn check_and_extend_predicate_pd_nodes(
}
}

/// An expression blocks predicates from being pushed past it if its results for
/// the subset where the predicate evaluates as true become different compared
/// to if it was performed before the predicate was applied. This is in general
/// any expression that produces outputs based on groups of values
/// (i.e. groups-wise) rather than individual values (i.e. element-wise).
///
/// Examples of expressions whose results would change, and thus block push-down:
/// - any aggregation - sum, mean, first, last, min, max etc.
/// - sorting - as the sort keys would change between filters
///
/// Returns `true` as soon as `check_and_extend_predicate_pd_nodes` rejects a node
/// reached from `node`, and `false` once every queued node has been accepted.
pub(super) fn aexpr_blocks_predicate_pushdown(node: Node, expr_arena: &Arena<AExpr>) -> bool {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused old code

let mut stack = Vec::<Node>::with_capacity(4);
stack.push(node);

// Cannot use `has_aexpr` because we need to ignore any literals in the RHS
// of an `is_in` operation.
// The helper both vets the current node and pushes the follow-up nodes onto
// `stack`, so the loop ends when the stack is drained or a node is rejected.
while let Some(node) = stack.pop() {
let ae = expr_arena.get(node);

if !check_and_extend_predicate_pd_nodes(&mut stack, ae, expr_arena) {
return true;
}
}
false
}

/// * `col(A).alias(B).alias(C) => (C, A)`
/// * `col(A) => (A, A)`
/// * `col(A).sum().alias(B) => None`
Expand Down Expand Up @@ -240,6 +234,7 @@ pub enum PushdownEligibility {
#[allow(clippy::type_complexity)]
pub fn pushdown_eligibility(
projection_nodes: &[ExprIR],
new_predicates: &[(Arc<str>, ExprIR)],
ritchie46 marked this conversation as resolved.
Show resolved Hide resolved
acc_predicates: &PlHashMap<Arc<str>, ExprIR>,
expr_arena: &mut Arena<AExpr>,
) -> PolarsResult<(PushdownEligibility, PlHashMap<Arc<str>, Arc<str>>)> {
Expand Down Expand Up @@ -376,7 +371,7 @@ pub fn pushdown_eligibility(
common_window_inputs = new;
}

for e in acc_predicates.values() {
for (_, e) in new_predicates.iter() {
debug_assert!(ae_nodes_stack.is_empty());
ae_nodes_stack.push(e.node());

Expand Down Expand Up @@ -447,13 +442,3 @@ pub fn pushdown_eligibility(
_ => Ok((PushdownEligibility::Partial { to_local }, alias_to_col_map)),
}
}

/// Debug-build sanity check that `node` holds no expression that blocks predicate pushdown.
///
/// Placed at call sites that handled blocking expressions themselves before the
/// refactoring. It can probably be removed eventually if it isn't catching anything.
#[inline(always)]
pub(super) fn debug_assert_aexpr_allows_predicate_pushdown(node: Node, expr_arena: &Arena<AExpr>) {
    // Same gating as `debug_assert!`: the traversal only runs when debug
    // assertions are enabled.
    if cfg!(debug_assertions) {
        let has_blocking_expr = aexpr_blocks_predicate_pushdown(node, expr_arena);
        assert!(
            !has_blocking_expr,
            "Predicate pushdown: Did not expect blocking exprs at this point, please open an issue."
        );
    }
}
16 changes: 16 additions & 0 deletions py-polars/tests/unit/test_predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,3 +495,19 @@ def test_predicate_push_down_with_alias_15442() -> None:
.collect(predicate_pushdown=True)
)
assert output.to_dict(as_series=False) == {"a": [1]}


def test_predicate_push_down_list_gather_17492() -> None:
    """Regression test for `.list.get` predicates combined with an earlier filter.

    A `list.get(1)` predicate applied after a `len == 2` filter must yield the
    same frame as slicing out the second row. With `null_on_oob=True`, the plan
    text returned by `explain()` must not contain `FILTER`.
    """
    lf = pl.LazyFrame({"val": [[1], [1, 1]], "len": [1, 2]})
    only_long_rows = lf.filter(pl.col("len") == 2)

    strict_result = only_long_rows.filter(pl.col("val").list.get(1) == 1)
    expected = lf.slice(1, 1)
    assert_frame_equal(strict_result, expected)

    # null_on_oob=True can pass
    lenient_plan = only_long_rows.filter(
        pl.col("val").list.get(1, null_on_oob=True) == 1
    ).explain()
    assert "FILTER" not in lenient_plan