Skip to content

Commit

Permalink
fix: Fix filter incorrectly pushed past struct unnest when unnested c…
Browse files Browse the repository at this point in the history
…olumn name matches upper column name (pola-rs#19638)
  • Loading branch information
nameexhaustion authored and tylerriccio33 committed Nov 8, 2024
1 parent a6d6047 commit 5b8c689
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
22 changes: 22 additions & 0 deletions crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,28 @@ impl<'a> PredicatePushDown<'a> {
expr_arena,
))
},
FunctionIR::Unnest { columns } => {
let exclude = columns.iter().cloned().collect::<PlHashSet<_>>();

let local_predicates =
transfer_to_local_by_name(expr_arena, &mut acc_predicates, |x| {
exclude.contains(x)
});

let lp = self.pushdown_and_continue(
lp,
acc_predicates,
lp_arena,
expr_arena,
false,
)?;
Ok(self.optional_apply_predicate(
lp,
local_predicates,
lp_arena,
expr_arena,
))
},
_ => self.pushdown_and_continue(
lp,
acc_predicates,
Expand Down
42 changes: 42 additions & 0 deletions py-polars/tests/unit/test_predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,45 @@ def test_predicate_push_down_list_gather_17492() -> None:
.filter(pl.col("val").list.get(1, null_on_oob=True) == 1)
.explain()
)


def test_predicate_pushdown_struct_unnest_19632() -> None:
lf = pl.LazyFrame({"a": [{"a": 1, "b": 2}]}).unnest("a")

q = lf.filter(pl.col("a") == 1)
plan = q.explain()

assert "FILTER" in plan
assert plan.index("FILTER") < plan.index("UNNEST")

assert_frame_equal(
q.collect(),
pl.DataFrame({"a": 1, "b": 2}),
)

# With `pl.struct()`
lf = pl.LazyFrame({"a": 1, "b": 2}).select(pl.struct(pl.all())).unnest("a")

q = lf.filter(pl.col("a") == 1)
plan = q.explain()

assert "FILTER" in plan
assert plan.index("FILTER") < plan.index("UNNEST")

assert_frame_equal(
q.collect(),
pl.DataFrame({"a": 1, "b": 2}),
)

# With `value_counts()`
lf = pl.LazyFrame({"a": [1]}).select(pl.col("a").value_counts()).unnest("a")

q = lf.filter(pl.col("a") == 1)
plan = q.explain()

assert plan.index("FILTER") < plan.index("UNNEST")

assert_frame_equal(
q.collect(),
pl.DataFrame({"a": 1, "count": 1}, schema={"a": pl.Int64, "count": pl.UInt32}),
)

0 comments on commit 5b8c689

Please sign in to comment.