diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index a3a6a9bc1303..739330ad0aed 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -191,10 +191,9 @@ pub fn min_horizontal(columns: &[Column]) -> PolarsResult> { } } -fn null_with_supertype( - columns: Vec<&Series>, - date_to_datetime: bool, -) -> PolarsResult> { +// Return a full-null column with dtype determined by supertype of supplied columns. +// Name of returned column is the left-month input column. +fn null_with_supertype(columns: &[Column], date_to_datetime: bool) -> PolarsResult> { // We must first determine the correct return dtype. let mut return_dtype = dtypes_to_supertype(columns.iter().map(|c| c.dtype()))?; if return_dtype == DataType::Boolean { @@ -209,19 +208,6 @@ fn null_with_supertype( ))) } -fn null_with_supertype_from_series( - columns: &[Column], - date_to_datetime: bool, -) -> PolarsResult> { - null_with_supertype( - columns - .iter() - .map(|c| c.as_materialized_series()) - .collect::>(), - date_to_datetime, - ) -} - pub fn sum_horizontal( columns: &[Column], null_strategy: NullStrategy, @@ -327,7 +313,8 @@ pub fn mean_horizontal( }; if first_non_null_idx > 0 && !ignore_nulls { - return null_with_supertype_from_series(columns, true); + // We have null columns; return immediately + return null_with_supertype(columns, true); } // Ensure column dtypes are all valid @@ -338,15 +325,15 @@ pub fn mean_horizontal( for col in &columns[first_non_null_idx + 1..] { let dtype = col.dtype(); if !ignore_nulls && dtype == &DataType::Null { - // A null column guarantees null output. - return null_with_supertype_from_series(columns, true); + // The presence of a single null column guarantees the output is all-null. + return null_with_supertype(columns, true); } else if dtype != first_dtype && dtype != &DataType::Null { polars_bail!( - InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})", + InvalidOperation: "'mean_horizontal' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})", columns[first_non_null_idx].name(), first_dtype, - dtype, col.name(), + dtype, ); }; } @@ -370,7 +357,7 @@ pub fn mean_horizontal( || dtype.is_null()) { polars_bail!( - InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})", + InvalidOperation: "'mean_horizontal' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})", columns[first_non_null_idx].name(), first_dtype, col.name(), @@ -381,7 +368,7 @@ pub fn mean_horizontal( columns[first_non_null_idx..].to_vec() } else { polars_bail!( - InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={})", + InvalidOperation: "'mean_horizontal' expects all numeric or all temporal expressions, found {:?} (dtype={})", columns[first_non_null_idx].name(), first_dtype, ); diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 9bd0f4fbaf0c..18c68c626fce 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -338,16 +338,20 @@ impl FunctionExpr { }) }, MeanHorizontal { .. } => { - mapper.map_to_supertype().map(|mut f| { - match f.dtype { - DataType::Boolean => { f.dtype = DataType::Float64; }, - DataType::Float32 => { f.dtype = DataType::Float32; }, - DataType::Date => { f.dtype = DataType::Datetime(TimeUnit::Milliseconds, None); } - dt if dt.is_temporal() => { f.dtype = dt; } - _ => { f.dtype = DataType::Float64; }, - }; - f - }) + let out = match mapper.map_to_supertype() { + Ok(mut field) => { + match field.dtype { + DataType::Boolean => { field.dtype = DataType::Float64; }, + DataType::Float32 => { field.dtype = DataType::Float32; }, + DataType::Date => { field.dtype = DataType::Datetime(TimeUnit::Milliseconds, None); } + dt if dt.is_temporal() => { field.dtype = dt; } + _ => { field.dtype = DataType::Float64; }, + }; + field + }, + Err(_) => polars_bail!(InvalidOperation: "'mean_horizontal' expects all numeric or all temporal expressions"), + }; + Ok(out) } #[cfg(feature = "ewma")] EwmMean { .. } => mapper.map_to_float_dtype(), diff --git a/py-polars/tests/unit/operations/aggregation/test_horizontal.py b/py-polars/tests/unit/operations/aggregation/test_horizontal.py index 645f822cb968..bef1f28d4f01 100644 --- a/py-polars/tests/unit/operations/aggregation/test_horizontal.py +++ b/py-polars/tests/unit/operations/aggregation/test_horizontal.py @@ -8,7 +8,7 @@ import polars as pl import polars.selectors as cs -from polars.exceptions import ComputeError, PolarsError +from polars.exceptions import ComputeError, InvalidOperationError, PolarsError from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: @@ -555,6 +555,47 @@ def test_mean_horizontal_temporal(tu: TimeUnit, tz: str, ignore_nulls: bool) -> assert out.collect_schema() == expected.collect_schema() +@pytest.mark.parametrize( + ("dtype1", "dtype2"), + [ + (pl.Date, pl.Datetime), + (pl.Date, pl.Time), + (pl.Date, pl.Duration), + (pl.Datetime, pl.Date), + (pl.Datetime, pl.Time), + (pl.Datetime, pl.Duration), + (pl.Time, pl.Date), + (pl.Time, pl.Datetime), + (pl.Time, pl.Duration), + (pl.Duration, pl.Date), + (pl.Duration, pl.Datetime), + (pl.Duration, pl.Time), + ], +) +@pytest.mark.parametrize("with_null", [False, True]) +def test_mean_horizontal_mismatched_types( + with_null: bool, + dtype1: PolarsDataType, + dtype2: PolarsDataType, +) -> None: + df = pl.DataFrame( + { + "null": [None, None], + "a": pl.Series([1, 2]).cast(dtype1), + "b": pl.Series([1, 2]).cast(dtype2), + } + ) + with pytest.raises( + InvalidOperationError, + match="'mean_horizontal' expects all numeric or all temporal expressions", + ): + df.select( + pl.mean_horizontal("null", "a", "b") + if with_null + else pl.mean_horizontal("a", "b") + ) + + @pytest.mark.parametrize( ("in_dtype", "out_dtype"), [ @@ -569,6 +610,15 @@ def test_mean_horizontal_temporal(tu: TimeUnit, tz: str, ignore_nulls: bool) -> (pl.Int64, pl.Float64), (pl.Float32, pl.Float32), (pl.Float64, pl.Float64), + (pl.Date, pl.Datetime("ms")), + (pl.Datetime("ms"), pl.Datetime("ms")), + (pl.Datetime("us"), pl.Datetime("us")), + (pl.Datetime("ns"), pl.Datetime("ns")), + (pl.Datetime("ns", "Asia/Kathmandu"), pl.Datetime("ns", "Asia/Kathmandu")), + (pl.Duration("ms"), pl.Duration("ms")), + (pl.Duration("us"), pl.Duration("us")), + (pl.Duration("ns"), pl.Duration("ns")), + (pl.Time, pl.Time), ], ) def test_schema_mean_horizontal_single_column(