diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index 025779e77349..6a6960480c47 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -221,9 +221,9 @@ pub fn sum_horizontal( // If we have any null columns and null strategy is not `Ignore`, we can return immediately. if !ignore_nulls && non_null_cols.len() < columns.len() { - // We must first determine the correct return dtype. + // We must determine the correct return dtype. let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? { - DataType::Boolean => DataType::UInt32, + DataType::Boolean => IDX_DTYPE, dt => dt, }; return Ok(Some(Column::full_null( @@ -244,7 +244,7 @@ pub fn sum_horizontal( }, 1 => Ok(Some( apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean { - non_null_cols[0].cast(&DataType::UInt32)? + non_null_cols[0].cast(&IDX_DTYPE)? } else { non_null_cols[0].clone() })? diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index d45f75c01e9d..beaacac49942 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -331,11 +331,10 @@ impl FunctionExpr { MinHorizontal => mapper.map_to_supertype(), SumHorizontal { .. } => { mapper.map_to_supertype().map(|mut f| { - match f.dtype { - // Booleans sum to UInt32. - DataType::Boolean => { f.dtype = DataType::UInt32; f}, - _ => f, + if f.dtype == DataType::Boolean { + f.dtype = IDX_DTYPE; } + f }) }, MeanHorizontal { .. } => { diff --git a/py-polars/tests/unit/operations/aggregation/test_horizontal.py b/py-polars/tests/unit/operations/aggregation/test_horizontal.py index 3959e15e22ed..bc557a231d75 100644 --- a/py-polars/tests/unit/operations/aggregation/test_horizontal.py +++ b/py-polars/tests/unit/operations/aggregation/test_horizontal.py @@ -319,6 +319,39 @@ def test_sum_single_col() -> None: ) +@pytest.mark.parametrize("ignore_nulls", [False, True]) +def test_sum_correct_supertype(ignore_nulls: bool) -> None: + values = [1, 2] if ignore_nulls else [None, None] # type: ignore[list-item] + lf = pl.LazyFrame( + { + "null": [None, None], + "int": pl.Series(values, dtype=pl.Int32), + "float": pl.Series(values, dtype=pl.Float32), + } + ) + + # null + int32 should produce int32 + out = lf.select(pl.sum_horizontal("null", "int", ignore_nulls=ignore_nulls)) + expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Int32)}) + assert_frame_equal(out.collect(), expected.collect()) + assert out.collect_schema() == expected.collect_schema() + + # null + float32 should produce float32 + out = lf.select(pl.sum_horizontal("null", "float", ignore_nulls=ignore_nulls)) + expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Float32)}) + assert_frame_equal(out.collect(), expected.collect()) + assert out.collect_schema() == expected.collect_schema() + + # null + int32 + float32 should produce float64 + values = [2, 4] if ignore_nulls else [None, None] # type: ignore[list-item] + out = lf.select( + pl.sum_horizontal("null", "int", "float", ignore_nulls=ignore_nulls) + ) + expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Float64)}) + assert_frame_equal(out.collect(), expected.collect()) + assert out.collect_schema() == expected.collect_schema() + + def test_cum_sum_horizontal() -> None: df = pl.DataFrame( { @@ -541,8 +574,8 @@ def test_horizontal_sum_boolean_with_null() -> None: expected_schema = pl.Schema( { - "null_first": pl.UInt32, - "bool_first": pl.UInt32, + "null_first": pl.get_index_type(), + "bool_first": pl.get_index_type(), } ) @@ -550,8 +583,8 @@ def test_horizontal_sum_boolean_with_null() -> None: expected_df = pl.DataFrame( { - "null_first": pl.Series([1, 0], dtype=pl.UInt32), - "bool_first": pl.Series([1, 0], dtype=pl.UInt32), + "null_first": pl.Series([1, 0], dtype=pl.get_index_type()), + "bool_first": pl.Series([1, 0], dtype=pl.get_index_type()), } ) @@ -563,7 +596,7 @@ def test_horizontal_sum_boolean_with_null() -> None: ("dtype_in", "dtype_out"), [ (pl.Null, pl.Null), - (pl.Boolean, pl.UInt32), + (pl.Boolean, pl.get_index_type()), (pl.UInt8, pl.UInt8), (pl.Float32, pl.Float32), (pl.Float64, pl.Float64), @@ -589,6 +622,7 @@ def test_horizontal_sum_with_null_col_ignore_strategy( values = [None, None, None] # type: ignore[list-item] expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out)) assert_frame_equal(result, expected) + assert result.collect_schema() == expected.collect_schema() @pytest.mark.parametrize("ignore_nulls", [True, False])