Skip to content

Commit

Permalink
fix: Output index type instead of u32 for sum_horizontal with boolean inputs (#20531)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Jan 3, 2025
1 parent 409f091 commit da0b589
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 12 deletions.
6 changes: 3 additions & 3 deletions crates/polars-ops/src/series/ops/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ pub fn sum_horizontal(

// If we have any null columns and null strategy is not `Ignore`, we can return immediately.
if !ignore_nulls && non_null_cols.len() < columns.len() {
// We must first determine the correct return dtype.
// We must determine the correct return dtype.
let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
DataType::Boolean => DataType::UInt32,
DataType::Boolean => IDX_DTYPE,
dt => dt,
};
return Ok(Some(Column::full_null(
Expand All @@ -244,7 +244,7 @@ pub fn sum_horizontal(
},
1 => Ok(Some(
apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
non_null_cols[0].cast(&DataType::UInt32)?
non_null_cols[0].cast(&IDX_DTYPE)?
} else {
non_null_cols[0].clone()
})?
Expand Down
7 changes: 3 additions & 4 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,10 @@ impl FunctionExpr {
MinHorizontal => mapper.map_to_supertype(),
SumHorizontal { .. } => {
mapper.map_to_supertype().map(|mut f| {
match f.dtype {
// Booleans sum to UInt32.
DataType::Boolean => { f.dtype = DataType::UInt32; f},
_ => f,
if f.dtype == DataType::Boolean {
f.dtype = IDX_DTYPE;
}
f
})
},
MeanHorizontal { .. } => {
Expand Down
44 changes: 39 additions & 5 deletions py-polars/tests/unit/operations/aggregation/test_horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,39 @@ def test_sum_single_col() -> None:
)


@pytest.mark.parametrize("ignore_nulls", [False, True])
def test_sum_correct_supertype(ignore_nulls: bool) -> None:
    """Check that `sum_horizontal` yields the supertype of its input columns.

    A null-typed column is summed together with Int32 and/or Float32 columns.
    With `ignore_nulls=True` the numeric columns hold real values and the sums
    are compared; with `ignore_nulls=False` the numeric columns are all-null
    and the expected result is an all-null column. In both modes the output
    dtype and the lazily resolved schema must match the expected frame.
    """
    if ignore_nulls:
        single_values = [1, 2]
        combined_values = [2, 4]
    else:
        single_values = [None, None]  # type: ignore[list-item]
        combined_values = [None, None]  # type: ignore[list-item]

    frame = pl.LazyFrame(
        {
            "null": [None, None],
            "int": pl.Series(single_values, dtype=pl.Int32),
            "float": pl.Series(single_values, dtype=pl.Float32),
        }
    )

    # (input columns, expected output dtype, expected output values)
    cases = [
        # null + int32 should produce int32
        (("null", "int"), pl.Int32, single_values),
        # null + float32 should produce float32
        (("null", "float"), pl.Float32, single_values),
        # null + int32 + float32 should produce float64
        (("null", "int", "float"), pl.Float64, combined_values),
    ]

    for column_names, want_dtype, want_values in cases:
        result = frame.select(
            pl.sum_horizontal(*column_names, ignore_nulls=ignore_nulls)
        )
        want = pl.LazyFrame({"null": pl.Series(want_values, dtype=want_dtype)})
        # Compare both the materialized data and the schema resolved lazily.
        assert_frame_equal(result.collect(), want.collect())
        assert result.collect_schema() == want.collect_schema()


def test_cum_sum_horizontal() -> None:
df = pl.DataFrame(
{
Expand Down Expand Up @@ -541,17 +574,17 @@ def test_horizontal_sum_boolean_with_null() -> None:

expected_schema = pl.Schema(
{
"null_first": pl.UInt32,
"bool_first": pl.UInt32,
"null_first": pl.get_index_type(),
"bool_first": pl.get_index_type(),
}
)

assert out.collect_schema() == expected_schema

expected_df = pl.DataFrame(
{
"null_first": pl.Series([1, 0], dtype=pl.UInt32),
"bool_first": pl.Series([1, 0], dtype=pl.UInt32),
"null_first": pl.Series([1, 0], dtype=pl.get_index_type()),
"bool_first": pl.Series([1, 0], dtype=pl.get_index_type()),
}
)

Expand All @@ -563,7 +596,7 @@ def test_horizontal_sum_boolean_with_null() -> None:
("dtype_in", "dtype_out"),
[
(pl.Null, pl.Null),
(pl.Boolean, pl.UInt32),
(pl.Boolean, pl.get_index_type()),
(pl.UInt8, pl.UInt8),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
Expand All @@ -589,6 +622,7 @@ def test_horizontal_sum_with_null_col_ignore_strategy(
values = [None, None, None] # type: ignore[list-item]
expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out))
assert_frame_equal(result, expected)
assert result.collect_schema() == expected.collect_schema()


@pytest.mark.parametrize("ignore_nulls", [True, False])
Expand Down

0 comments on commit da0b589

Please sign in to comment.