Skip to content

Commit

Permalink
fix: Fix rolling aggregations for various integer types (#20512)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored Dec 31, 2024
1 parent 4c14e70 commit 25fab78
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 7 deletions.
2 changes: 1 addition & 1 deletion crates/polars-core/src/frame/group_by/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ impl_take_extremum!(i8);
impl_take_extremum!(i16);
impl_take_extremum!(i32);
impl_take_extremum!(i64);
#[cfg(feature = "dtype-decimal")]
#[cfg(any(feature = "dtype-decimal", feature = "dtype-i128"))]
impl_take_extremum!(i128);
impl_take_extremum!(float: f32);
impl_take_extremum!(float: f64);
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/hashing/vector_hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ vec_hash_numeric!(UInt16Chunked);
vec_hash_numeric!(UInt8Chunked);
vec_hash_numeric!(Float64Chunked);
vec_hash_numeric!(Float32Chunked);
#[cfg(feature = "dtype-decimal")]
#[cfg(any(feature = "dtype-decimal", feature = "dtype-i128"))]
vec_hash_numeric!(Int128Chunked);

impl VecHash for StringChunked {
Expand Down
1 change: 1 addition & 0 deletions crates/polars-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dtype-full = [
"dtype-decimal",
"dtype-duration",
"dtype-i16",
"dtype-i128",
"dtype-i8",
"dtype-struct",
"dtype-time",
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ dtype-full = [
"dtype-decimal",
"dtype-duration",
"dtype-i16",
"dtype-i128",
"dtype-i8",
"dtype-struct",
"dtype-time",
Expand Down Expand Up @@ -144,6 +145,7 @@ dtype-duration = [
"polars-mem-engine/dtype-duration",
]
dtype-i16 = ["polars-plan/dtype-i16", "polars-pipe?/dtype-i16", "polars-expr/dtype-i16", "polars-mem-engine/dtype-i16"]
dtype-i128 = ["polars-plan/dtype-i128", "polars-pipe?/dtype-i128", "polars-expr/dtype-i128"]
dtype-i8 = ["polars-plan/dtype-i8", "polars-pipe?/dtype-i8", "polars-expr/dtype-i8", "polars-mem-engine/dtype-i8"]
dtype-struct = [
"polars-plan/dtype-struct",
Expand Down
11 changes: 11 additions & 0 deletions crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ dtype-full = [
"dtype-array",
"dtype-i8",
"dtype-i16",
"dtype-i128",
"dtype-decimal",
"dtype-u8",
"dtype-u16",
Expand Down Expand Up @@ -318,12 +319,20 @@ dtype-i8 = [
"polars-io/dtype-i8",
"polars-lazy?/dtype-i8",
"polars-ops/dtype-i8",
"polars-time?/dtype-i8",
]
dtype-i16 = [
"polars-core/dtype-i16",
"polars-io/dtype-i16",
"polars-lazy?/dtype-i16",
"polars-ops/dtype-i16",
"polars-time?/dtype-i16",
]
dtype-i128 = [
"polars-core/dtype-i128",
"polars-lazy?/dtype-i128",
"polars-ops/dtype-i128",
"polars-time?/dtype-i128",
]
dtype-decimal = [
"polars-core/dtype-decimal",
Expand All @@ -337,12 +346,14 @@ dtype-u8 = [
"polars-io/dtype-u8",
"polars-lazy?/dtype-u8",
"polars-ops/dtype-u8",
"polars-time?/dtype-u8",
]
dtype-u16 = [
"polars-core/dtype-u16",
"polars-io/dtype-u16",
"polars-lazy?/dtype-u16",
"polars-ops/dtype-u16",
"polars-time?/dtype-u16",
]
dtype-categorical = [
"polars-core/dtype-categorical",
Expand Down
18 changes: 13 additions & 5 deletions py-polars/tests/unit/operations/rolling/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from polars.testing import assert_frame_equal, assert_series_equal
from polars.testing.parametric import column, dataframes
from polars.testing.parametric.strategies.dtype import _time_units
from tests.unit.conftest import INTEGER_DTYPES

if TYPE_CHECKING:
from hypothesis.strategies import SearchStrategy
Expand Down Expand Up @@ -739,11 +740,18 @@ def test_rolling_aggregations_with_over_11225() -> None:
assert_frame_equal(result, expected)


def test_rolling() -> None:
s = pl.Series("a", [1, 2, 3, 2, 1])
assert_series_equal(s.rolling_min(2), pl.Series("a", [None, 1, 2, 2, 1]))
assert_series_equal(s.rolling_max(2), pl.Series("a", [None, 2, 3, 3, 2]))
assert_series_equal(s.rolling_sum(2), pl.Series("a", [None, 3, 5, 5, 3]))
@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_rolling(dtype: PolarsDataType) -> None:
s = pl.Series("a", [1, 2, 3, 2, 1], dtype=dtype)
assert_series_equal(
s.rolling_min(2), pl.Series("a", [None, 1, 2, 2, 1], dtype=dtype)
)
assert_series_equal(
s.rolling_max(2), pl.Series("a", [None, 2, 3, 3, 2], dtype=dtype)
)
assert_series_equal(
s.rolling_sum(2), pl.Series("a", [None, 3, 5, 5, 3], dtype=dtype)
)
assert_series_equal(s.rolling_mean(2), pl.Series("a", [None, 1.5, 2.5, 2.5, 1.5]))

assert s.rolling_std(2).to_list()[1] == pytest.approx(0.7071067811865476)
Expand Down

0 comments on commit 25fab78

Please sign in to comment.