Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Jan 3, 2025
1 parent df74858 commit a2b1bc5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 35 deletions.
45 changes: 10 additions & 35 deletions crates/polars-ops/src/series/ops/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,40 +190,6 @@ pub fn min_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
}
}

// Given a Series sequence, determines the correct supertype and returns full-null Column.
// The `date_to_datetime` flag indicates that Date supertypes should be converted to Datetime("ms").
fn null_with_supertype(
columns: Vec<&Series>,
date_to_datetime: bool,
) -> PolarsResult<Option<Column>> {
// We must first determine the correct return dtype.
let mut return_dtype = dtypes_to_supertype(columns.iter().map(|c| c.dtype()))?;
if return_dtype == DataType::Boolean {
return_dtype = IDX_DTYPE;
} else if date_to_datetime && return_dtype == DataType::Date {
return_dtype = DataType::Datetime(TimeUnit::Milliseconds, None);
}
Ok(Some(Column::full_null(
columns[0].name().clone(),
columns[0].len(),
&return_dtype,
)))
}

// Apply `null_with_supertype` to a sequence of Columns.
fn null_with_supertype_from_columns(
columns: &[Column],
date_to_datetime: bool,
) -> PolarsResult<Option<Column>> {
null_with_supertype(
columns
.iter()
.map(|c| c.as_materialized_series())
.collect::<Vec<_>>(),
date_to_datetime,
)
}

pub fn sum_horizontal(
columns: &[Column],
null_strategy: NullStrategy,
Expand Down Expand Up @@ -255,7 +221,16 @@ pub fn sum_horizontal(

// If we have any null columns and null strategy is not `Ignore`, we can return immediately.
if !ignore_nulls && non_null_cols.len() < columns.len() {
return null_with_supertype_from_columns(columns, false);
// We must determine the correct return dtype.
let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
DataType::Boolean => IDX_DTYPE,
dt => dt,
};
return Ok(Some(Column::full_null(
columns[0].name().clone(),
columns[0].len(),
&return_dtype,
)));
}

match non_null_cols.len() {
Expand Down
33 changes: 33 additions & 0 deletions py-polars/tests/unit/operations/aggregation/test_horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,39 @@ def test_sum_single_col() -> None:
)


@pytest.mark.parametrize("ignore_nulls", [False, True])
def test_sum_correct_supertype(ignore_nulls: bool) -> None:
values = [1, 2] if ignore_nulls else [None, None]
lf = pl.LazyFrame(
{
"null": [None, None],
"int": pl.Series(values, dtype=pl.Int32),
"float": pl.Series(values, dtype=pl.Float32),
}
)

# null + int32 should produce int32
out = lf.select(pl.sum_horizontal("null", "int", ignore_nulls=ignore_nulls))
expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Int32)})
assert_frame_equal(out.collect(), expected.collect())
assert out.collect_schema() == expected.collect_schema()

# null + float32 should produce float32
out = lf.select(pl.sum_horizontal("null", "float", ignore_nulls=ignore_nulls))
expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Float32)})
assert_frame_equal(out.collect(), expected.collect())
assert out.collect_schema() == expected.collect_schema()

# null + int32 + float32 should produce float64
values = [2, 4] if ignore_nulls else [None, None]
out = lf.select(
pl.sum_horizontal("null", "int", "float", ignore_nulls=ignore_nulls)
)
expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Float64)})
assert_frame_equal(out.collect(), expected.collect())
assert out.collect_schema() == expected.collect_schema()


def test_cum_sum_horizontal() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit a2b1bc5

Please sign in to comment.