Skip to content

Commit

Permalink
fix: Fix output type for list.eval in certain cases (#18570)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Sep 6, 2024
1 parent ef295d2 commit 325dfd1
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 106 deletions.
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ fn run_on_group_by_engine(
let state = ExecutionState::new();
let mut ac = phys_expr.evaluate_on_groups(&df_context, &groups, &state)?;
let out = match ac.agg_state() {
AggState::AggregatedScalar(_) | AggState::Literal(_) => {
AggState::AggregatedScalar(_) => {
let out = ac.aggregated();
out.as_list().into_series()
},
Expand Down
132 changes: 132 additions & 0 deletions py-polars/tests/unit/operations/namespaces/list/test_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

from typing import Any

import pytest

import polars as pl
from polars.exceptions import (
StructFieldNotFoundError,
)
from polars.testing import assert_frame_equal, assert_series_equal


def test_list_eval_dtype_inference() -> None:
    """Chain `list.first()` after `list.eval` to verify the inferred output type."""
    scores = pl.DataFrame(
        {
            "student": ["bas", "laura", "tim", "jenny"],
            "arithmetic": [10, 5, 6, 8],
            "biology": [4, 6, 2, 7],
            "geography": [8, 4, 9, 7],
        }
    )

    element = pl.col("")
    pct_rank = element.rank(descending=True) / element.count().cast(pl.UInt16)

    with_lists = scores.with_columns(
        pl.concat_list(pl.all().exclude("student")).alias("all_grades")
    )
    # `.list.first()` would fail if `.list.eval` did not correctly infer
    # the output type.
    first_rank = (
        pl.col("all_grades")
        .list.eval(pct_rank, parallel=True)
        .alias("grades_rank")
        .list.first()
    )
    observed = with_lists.select(first_rank).to_series().to_list()

    assert observed == [
        0.3333333333333333,
        0.6666666666666666,
        0.6666666666666666,
        0.3333333333333333,
    ]


def test_list_eval_categorical() -> None:
    """Dropping nulls inside `list.eval` must keep the `List(Categorical)` dtype."""
    list_of_cat = pl.List(pl.Categorical)
    frame = pl.DataFrame({"test": [["a", None]]}, schema={"test": list_of_cat})

    keep_non_null = pl.element().filter(pl.element().is_not_null())
    frame = frame.select(pl.col("test").list.eval(keep_non_null))

    expected = pl.Series("test", [["a"]], dtype=list_of_cat)
    assert_series_equal(frame.get_column("test"), expected)


def test_list_eval_type_coercion() -> None:
    """`fill_null(3).last()` evaluated over `[1, None]` must produce `[3]`."""
    filled_last = pl.element().fill_null(3).last()
    frame = pl.DataFrame({"array_cols": [[1, None]]})

    projection = (
        pl.col("array_cols")
        .list.eval(filled_last, parallel=False)
        .alias("col_last")
    )
    observed = frame.select(projection).to_dict(as_series=False)

    assert observed == {"col_last": [[3]]}


def test_list_eval_all_null() -> None:
    """An all-null `List(String)` column stays all-null through `list.eval`."""
    frame = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, None, None]})
    frame = frame.with_columns(pl.col("bar").cast(pl.List(pl.String)))

    evaluated = frame.select(pl.col("bar").list.eval(pl.element()))

    assert evaluated.to_dict(as_series=False) == {"bar": [None, None, None]}


def test_empty_eval_dtype_5546() -> None:
    """On a zero-row frame, a filtering `list.eval` must keep the column dtype."""
    # https://github.com/pola-rs/polars/issues/5546
    frame = pl.DataFrame([{"a": [{"name": 1}, {"name": 2}]}])
    list_dtype = frame.dtypes[0]

    name_is_one = pl.first().struct.field("name") == 1
    filtered = (
        pl.col("a")
        .list.eval(pl.element().filter(name_is_one))
        .alias("a_filtered")
    )
    empty_result = frame.limit(0).with_columns(filtered)

    assert empty_result.dtypes == [list_dtype, list_dtype]


def test_list_eval_gather_every_13410() -> None:
    """`gather_every(2)` inside `list.eval` keeps every second element per list."""
    frame = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]]})

    every_second = pl.col("a").list.eval(pl.element().gather_every(2))
    observed = frame.with_columns(result=every_second)

    expected = pl.DataFrame(
        {
            "a": [[1, 2, 3], [4, 5, 6]],
            "result": [[1, 3], [4, 6]],
        }
    )
    assert_frame_equal(observed, expected)


def test_list_eval_err_raise_15653() -> None:
    """Requesting struct field "baz" inside `list.eval` must raise."""
    frame = pl.DataFrame({"foo": [[]]})

    with pytest.raises(StructFieldNotFoundError):
        # The expression is built inside the context manager, as in the
        # original, so any error raised at construction is also caught.
        baz_field = pl.element().struct.field("baz")
        frame.with_columns(bar=pl.col("foo").list.eval(baz_field))


def test_list_eval_type_cast_11188() -> None:
    """Casting elements to String on a single null row yields `List(String)`."""
    frame = pl.DataFrame([{"a": None}], schema={"a": pl.List(pl.Int64)})

    as_strings = pl.col("a").list.eval(pl.element().cast(pl.String)).alias("a_str")
    result_schema = frame.select(as_strings).schema

    assert result_schema == {"a_str": pl.List(pl.String)}


@pytest.mark.parametrize(
    "data",
    [
        {"a": [["0"], ["1"]]},
        {"a": [["0", "1"], ["2", "3"]]},
        {"a": [["0", "1"]]},
        {"a": [["0"]]},
    ],
)
@pytest.mark.parametrize(
    "expr",
    [
        pl.lit(""),
        pl.format("test: {}", pl.element()),
    ],
)
def test_list_eval_list_output_18510(data: dict[str, Any], expr: pl.Expr) -> None:
    """
    Check that `list.eval` returns `List(String)` for string-producing expressions.

    Parameters
    ----------
    data
        Column data for the frame; covers one or two rows holding one or two
        elements per list.
    expr
        The expression evaluated inside `list.eval`: a string literal or a
        `pl.format` call on the list element.
    """
    df = pl.DataFrame(data)
    # Evaluate the parametrized expression. Hard-coding `pl.lit("")` here
    # would leave the `pl.format` case untested.
    result = df.select(pl.col("a").list.eval(expr))
    assert result.to_series().dtype == pl.List(pl.String)
106 changes: 1 addition & 105 deletions py-polars/tests/unit/operations/namespaces/list/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,7 @@
import pytest

import polars as pl
from polars.exceptions import (
ComputeError,
OutOfBoundsError,
SchemaError,
StructFieldNotFoundError,
)
from polars.exceptions import ComputeError, OutOfBoundsError, SchemaError
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -342,44 +337,6 @@ def test_slice() -> None:
assert s.list.slice(-5, 2).to_list() == [[1], []]


def test_list_eval_dtype_inference() -> None:
    """Chain `list.first()` after `list.eval` to verify the inferred output type."""
    scores = pl.DataFrame(
        {
            "student": ["bas", "laura", "tim", "jenny"],
            "arithmetic": [10, 5, 6, 8],
            "biology": [4, 6, 2, 7],
            "geography": [8, 4, 9, 7],
        }
    )

    element = pl.col("")
    pct_rank = element.rank(descending=True) / element.count().cast(pl.UInt16)

    with_lists = scores.with_columns(
        pl.concat_list(pl.all().exclude("student")).alias("all_grades")
    )
    # `.list.first()` would fail if `.list.eval` did not correctly infer
    # the output type.
    first_rank = (
        pl.col("all_grades")
        .list.eval(pct_rank, parallel=True)
        .alias("grades_rank")
        .list.first()
    )
    observed = with_lists.select(first_rank).to_series().to_list()

    assert observed == [
        0.3333333333333333,
        0.6666666666666666,
        0.6666666666666666,
        0.3333333333333333,
    ]


def test_list_eval_categorical() -> None:
    """Dropping nulls inside `list.eval` must keep the `List(Categorical)` dtype."""
    list_of_cat = pl.List(pl.Categorical)
    frame = pl.DataFrame({"test": [["a", None]]}, schema={"test": list_of_cat})

    keep_non_null = pl.element().filter(pl.element().is_not_null())
    frame = frame.select(pl.col("test").list.eval(keep_non_null))

    expected = pl.Series("test", [["a"]], dtype=list_of_cat)
    assert_series_equal(frame.get_column("test"), expected)


def test_list_ternary_concat() -> None:
df = pl.DataFrame(
{
Expand Down Expand Up @@ -423,17 +380,6 @@ def test_arr_contains_categorical() -> None:
assert result.to_dict(as_series=False) == expected


def test_list_eval_type_coercion() -> None:
    """`fill_null(3).last()` evaluated over `[1, None]` must produce `[3]`."""
    filled_last = pl.element().fill_null(3).last()
    frame = pl.DataFrame({"array_cols": [[1, None]]})

    projection = (
        pl.col("array_cols")
        .list.eval(filled_last, parallel=False)
        .alias("col_last")
    )
    observed = frame.select(projection).to_dict(as_series=False)

    assert observed == {"col_last": [[3]]}


def test_list_slice() -> None:
df = pl.DataFrame(
{
Expand Down Expand Up @@ -476,21 +422,6 @@ def test_list_sliced_get_5186() -> None:
assert_frame_equal(out1, out2)


def test_empty_eval_dtype_5546() -> None:
    """On a zero-row frame, a filtering `list.eval` must keep the column dtype."""
    # https://github.com/pola-rs/polars/issues/5546
    frame = pl.DataFrame([{"a": [{"name": 1}, {"name": 2}]}])
    list_dtype = frame.dtypes[0]

    name_is_one = pl.first().struct.field("name") == 1
    filtered = (
        pl.col("a")
        .list.eval(pl.element().filter(name_is_one))
        .alias("a_filtered")
    )
    empty_result = frame.limit(0).with_columns(filtered)

    assert empty_result.dtypes == [list_dtype, list_dtype]


def test_list_amortized_apply_explode_5812() -> None:
s = pl.Series([None, [1, 3], [0, -3], [1, 2, 2]])
assert s.list.sum().to_list() == [None, 4, -3, 5]
Expand Down Expand Up @@ -548,16 +479,6 @@ def test_list_gather() -> None:
]


def test_list_eval_all_null() -> None:
    """An all-null `List(String)` column stays all-null through `list.eval`."""
    frame = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, None, None]})
    frame = frame.with_columns(pl.col("bar").cast(pl.List(pl.String)))

    evaluated = frame.select(pl.col("bar").list.eval(pl.element()))

    assert evaluated.to_dict(as_series=False) == {"bar": [None, None, None]}


def test_list_function_group_awareness() -> None:
df = pl.DataFrame(
{
Expand Down Expand Up @@ -825,13 +746,6 @@ def test_list_get_logical_type() -> None:
assert_series_equal(out, expected)


def test_list_eval_gater_every_13410() -> None:
    """`gather_every(2)` inside `list.eval` keeps every second element per list."""
    frame = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]]})

    every_second = pl.col("a").list.eval(pl.element().gather_every(2))
    observed = frame.with_columns(result=every_second)

    expected = pl.DataFrame(
        {
            "a": [[1, 2, 3], [4, 5, 6]],
            "result": [[1, 3], [4, 6]],
        }
    )
    assert_frame_equal(observed, expected)


def test_list_gather_every() -> None:
df = pl.DataFrame(
{
Expand Down Expand Up @@ -896,24 +810,6 @@ def test_list_get_with_null() -> None:
assert_frame_equal(out, expected)


def test_list_eval_err_raise_15653() -> None:
    """Requesting struct field "baz" inside `list.eval` must raise."""
    frame = pl.DataFrame({"foo": [[]]})

    with pytest.raises(StructFieldNotFoundError):
        # The expression is built inside the context manager, as in the
        # original, so any error raised at construction is also caught.
        baz_field = pl.element().struct.field("baz")
        frame.with_columns(bar=pl.col("foo").list.eval(baz_field))


def test_list_sum_bool_schema() -> None:
    """The lazy schema of `list.sum()` over a boolean list reports `UInt32`."""
    lazy = pl.LazyFrame({"x": [[True, True, False]]})

    summed_schema = lazy.select(pl.col("x").list.sum()).collect_schema()

    assert summed_schema["x"] == pl.UInt32


def test_list_eval_type_cast_11188() -> None:
    """Casting elements to String on a single null row yields `List(String)`."""
    frame = pl.DataFrame([{"a": None}], schema={"a": pl.List(pl.Int64)})

    as_strings = pl.col("a").list.eval(pl.element().cast(pl.String)).alias("a_str")
    result_schema = frame.select(as_strings).schema

    assert result_schema == {"a_str": pl.List(pl.String)}

0 comments on commit 325dfd1

Please sign in to comment.