Skip to content

Commit

Permalink
Update & add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Jan 3, 2025
1 parent b502cab commit de6240e
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 35 deletions.
35 changes: 11 additions & 24 deletions crates/polars-ops/src/series/ops/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,9 @@ pub fn min_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
}
}

fn null_with_supertype(
columns: Vec<&Series>,
date_to_datetime: bool,
) -> PolarsResult<Option<Column>> {
// Return a full-null column with dtype determined by supertype of supplied columns.
// Name of returned column is the left-most input column.
fn null_with_supertype(columns: &[Column], date_to_datetime: bool) -> PolarsResult<Option<Column>> {
// We must first determine the correct return dtype.
let mut return_dtype = dtypes_to_supertype(columns.iter().map(|c| c.dtype()))?;
if return_dtype == DataType::Boolean {
Expand All @@ -209,19 +208,6 @@ fn null_with_supertype(
)))
}

/// Adapter for callers that hold `Column`s: borrows the materialized series of
/// every column and forwards them, together with `date_to_datetime`, to
/// `null_with_supertype`.
fn null_with_supertype_from_series(
    columns: &[Column],
    date_to_datetime: bool,
) -> PolarsResult<Option<Column>> {
    // Collect the borrowed series first so the forwarding call stays on one line.
    let materialized: Vec<_> = columns
        .iter()
        .map(|column| column.as_materialized_series())
        .collect();
    null_with_supertype(materialized, date_to_datetime)
}

pub fn sum_horizontal(
columns: &[Column],
null_strategy: NullStrategy,
Expand Down Expand Up @@ -327,7 +313,8 @@ pub fn mean_horizontal(
};

if first_non_null_idx > 0 && !ignore_nulls {
return null_with_supertype_from_series(columns, true);
// We have null columns; return immediately
return null_with_supertype(columns, true);
}

// Ensure column dtypes are all valid
Expand All @@ -338,15 +325,15 @@ pub fn mean_horizontal(
for col in &columns[first_non_null_idx + 1..] {
let dtype = col.dtype();
if !ignore_nulls && dtype == &DataType::Null {
// A null column guarantees null output.
return null_with_supertype_from_series(columns, true);
// The presence of a single null column guarantees the output is all-null.
return null_with_supertype(columns, true);
} else if dtype != first_dtype && dtype != &DataType::Null {
polars_bail!(
InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})",
InvalidOperation: "'mean_horizontal' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})",
columns[first_non_null_idx].name(),
first_dtype,
dtype,
col.name(),
dtype,
);
};
}
Expand All @@ -370,7 +357,7 @@ pub fn mean_horizontal(
|| dtype.is_null())
{
polars_bail!(
InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})",
InvalidOperation: "'mean_horizontal' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})",
columns[first_non_null_idx].name(),
first_dtype,
col.name(),
Expand All @@ -381,7 +368,7 @@ pub fn mean_horizontal(
columns[first_non_null_idx..].to_vec()
} else {
polars_bail!(
InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={})",
InvalidOperation: "'mean_horizontal' expects all numeric or all temporal expressions, found {:?} (dtype={})",
columns[first_non_null_idx].name(),
first_dtype,
);
Expand Down
24 changes: 14 additions & 10 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,16 +338,20 @@ impl FunctionExpr {
})
},
MeanHorizontal { .. } => {
mapper.map_to_supertype().map(|mut f| {
match f.dtype {
DataType::Boolean => { f.dtype = DataType::Float64; },
DataType::Float32 => { f.dtype = DataType::Float32; },
DataType::Date => { f.dtype = DataType::Datetime(TimeUnit::Milliseconds, None); }
dt if dt.is_temporal() => { f.dtype = dt; }
_ => { f.dtype = DataType::Float64; },
};
f
})
let out = match mapper.map_to_supertype() {
Ok(mut field) => {
match field.dtype {
DataType::Boolean => { field.dtype = DataType::Float64; },
DataType::Float32 => { field.dtype = DataType::Float32; },
DataType::Date => { field.dtype = DataType::Datetime(TimeUnit::Milliseconds, None); }
dt if dt.is_temporal() => { field.dtype = dt; }
_ => { field.dtype = DataType::Float64; },
};
field
},
Err(_) => polars_bail!(InvalidOperation: "'mean_horizontal' expects all numeric or all temporal expressions"),
};
Ok(out)
}
#[cfg(feature = "ewma")]
EwmMean { .. } => mapper.map_to_float_dtype(),
Expand Down
52 changes: 51 additions & 1 deletion py-polars/tests/unit/operations/aggregation/test_horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import polars as pl
import polars.selectors as cs
from polars.exceptions import ComputeError, PolarsError
from polars.exceptions import ComputeError, InvalidOperationError, PolarsError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
Expand Down Expand Up @@ -555,6 +555,47 @@ def test_mean_horizontal_temporal(tu: TimeUnit, tz: str, ignore_nulls: bool) ->
assert out.collect_schema() == expected.collect_schema()


@pytest.mark.parametrize(
    ("dtype1", "dtype2"),
    [
        # Every ordered pair of two *different* temporal dtypes.
        (left, right)
        for left in (pl.Date, pl.Datetime, pl.Time, pl.Duration)
        for right in (pl.Date, pl.Datetime, pl.Time, pl.Duration)
        if left is not right
    ],
)
@pytest.mark.parametrize("with_null", [False, True])
def test_mean_horizontal_mismatched_types(
    with_null: bool,
    dtype1: PolarsDataType,
    dtype2: PolarsDataType,
) -> None:
    """Check that `mean_horizontal` rejects two different temporal dtypes.

    The error must be raised both when only the two mismatched columns are
    passed and when an all-null column is passed in front of them.
    """
    # Decide up front which columns take part in the horizontal mean.
    selected = ["null", "a", "b"] if with_null else ["a", "b"]

    frame = pl.DataFrame(
        {
            "null": [None, None],
            "a": pl.Series([1, 2]).cast(dtype1),
            "b": pl.Series([1, 2]).cast(dtype2),
        }
    )

    expected_message = (
        "'mean_horizontal' expects all numeric or all temporal expressions"
    )
    with pytest.raises(InvalidOperationError, match=expected_message):
        frame.select(pl.mean_horizontal(*selected))


@pytest.mark.parametrize(
("in_dtype", "out_dtype"),
[
Expand All @@ -569,6 +610,15 @@ def test_mean_horizontal_temporal(tu: TimeUnit, tz: str, ignore_nulls: bool) ->
(pl.Int64, pl.Float64),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
(pl.Date, pl.Datetime("ms")),
(pl.Datetime("ms"), pl.Datetime("ms")),
(pl.Datetime("us"), pl.Datetime("us")),
(pl.Datetime("ns"), pl.Datetime("ns")),
(pl.Datetime("ns", "Asia/Kathmandu"), pl.Datetime("ns", "Asia/Kathmandu")),
(pl.Duration("ms"), pl.Duration("ms")),
(pl.Duration("us"), pl.Duration("us")),
(pl.Duration("ns"), pl.Duration("ns")),
(pl.Time, pl.Time),
],
)
def test_schema_mean_horizontal_single_column(
Expand Down

0 comments on commit de6240e

Please sign in to comment.