Skip to content

Commit

Permalink
fix: Fix various Int128 operations (#20515)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored Jan 2, 2025
1 parent 9d7a7d3 commit 11fa6de
Show file tree
Hide file tree
Showing 17 changed files with 86 additions and 23 deletions.
2 changes: 2 additions & 0 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,8 @@ impl Series {
},
Int64 => Ok(self.i64().unwrap().prod_reduce()),
UInt64 => Ok(self.u64().unwrap().prod_reduce()),
#[cfg(feature = "dtype-i128")]
Int128 => Ok(self.i128().unwrap().prod_reduce()),
Float32 => Ok(self.f32().unwrap().prod_reduce()),
Float64 => Ok(self.f64().unwrap().prod_reduce()),
dt => {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/series/ops/downcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ impl Series {
.ok_or_else(|| unpack_chunked_err!(self => "Int64"))
}

/// Unpack to [`ChunkedArray`] of dtype [`DataType::Int64`]
/// Unpack to [`ChunkedArray`] of dtype [`DataType::Int128`]
#[cfg(feature = "dtype-i128")]
pub fn i128(&self) -> PolarsResult<&Int128Chunked> {
self.try_i128()
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/src/chunked_array/list/sum_mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ pub(super) fn mean_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Se
Int16 => dispatch_mean::<i16, f64>(values, offsets, arr.validity()),
Int32 => dispatch_mean::<i32, f64>(values, offsets, arr.validity()),
Int64 => dispatch_mean::<i64, f64>(values, offsets, arr.validity()),
Int128 => dispatch_mean::<i128, f64>(values, offsets, arr.validity()),
UInt8 => dispatch_mean::<u8, f64>(values, offsets, arr.validity()),
UInt16 => dispatch_mean::<u16, f64>(values, offsets, arr.validity()),
UInt32 => dispatch_mean::<u32, f64>(values, offsets, arr.validity()),
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-ops/src/frame/join/hash_join/sort_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ pub(super) fn par_sorted_merge_inner_no_nulls(
DataType::Int64 => {
par_sorted_merge_inner_impl(s_left.i64().unwrap(), s_right.i64().unwrap())
},
#[cfg(feature = "dtype-i128")]
DataType::Int128 => {
par_sorted_merge_inner_impl(s_left.i128().unwrap(), s_right.i128().unwrap())
},
DataType::Float32 => {
par_sorted_merge_inner_impl(s_left.f32().unwrap(), s_right.f32().unwrap())
},
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-ops/src/series/ops/abs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub fn abs(s: &Series) -> PolarsResult<Series> {
Int16 => s.i16().unwrap().wrapping_abs().into_series(),
Int32 => s.i32().unwrap().wrapping_abs().into_series(),
Int64 => s.i64().unwrap().wrapping_abs().into_series(),
#[cfg(feature = "dtype-i128")]
Int128 => s.i128().unwrap().wrapping_abs().into_series(),
Float32 => s.f32().unwrap().wrapping_abs().into_series(),
Float64 => s.f64().unwrap().wrapping_abs().into_series(),
#[cfg(feature = "dtype-decimal")]
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-ops/src/series/ops/cum_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ pub fn cum_prod(s: &Series, reverse: bool) -> PolarsResult<Series> {
},
Int64 => cum_prod_numeric(s.i64()?, reverse).into_series(),
UInt64 => cum_prod_numeric(s.u64()?, reverse).into_series(),
#[cfg(feature = "dtype-i128")]
Int128 => cum_prod_numeric(s.i128()?, reverse).into_series(),
Float32 => cum_prod_numeric(s.f32()?, reverse).into_series(),
Float64 => cum_prod_numeric(s.f64()?, reverse).into_series(),
dt => polars_bail!(opq = cum_prod, dt),
Expand All @@ -213,6 +215,8 @@ pub fn cum_sum(s: &Series, reverse: bool) -> PolarsResult<Series> {
UInt32 => cum_sum_numeric(s.u32()?, reverse).into_series(),
Int64 => cum_sum_numeric(s.i64()?, reverse).into_series(),
UInt64 => cum_sum_numeric(s.u64()?, reverse).into_series(),
#[cfg(feature = "dtype-i128")]
Int128 => cum_sum_numeric(s.i128()?, reverse).into_series(),
Float32 => cum_sum_numeric(s.f32()?, reverse).into_series(),
Float64 => cum_sum_numeric(s.f64()?, reverse).into_series(),
#[cfg(feature = "dtype-duration")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ fn interpolate_linear(s: &Series) -> Series {
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Int128
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/cum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub(super) mod dtypes {
match dt {
Boolean => UInt32,
Int32 => Int32,
Int128 => Int128,
UInt32 => UInt32,
UInt64 => UInt64,
Float32 => Float32,
Expand All @@ -56,6 +57,7 @@ pub(super) mod dtypes {
match dt {
Boolean => Int64,
UInt64 => UInt64,
Int128 => Int128,
Float32 => Float32,
Float64 => Float64,
_ => Int64,
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,8 @@ impl Expr {
T::Float32 => T::Float32,
T::Float64 => T::Float64,
T::UInt64 => T::UInt64,
#[cfg(feature = "dtype-i128")]
T::Int128 => T::Int128,
_ => T::Int64,
})
}),
Expand Down
12 changes: 12 additions & 0 deletions crates/polars-python/src/expr/rolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,18 @@ impl PyExpr {
})
}
},
Int128 => {
if is_float {
let v = obj.extract::<f64>(py).unwrap();
Ok(Int128Chunked::from_slice(PlSmallStr::EMPTY, &[v as i128])
.into_series())
} else {
obj.extract::<i128>(py).map(|v| {
Int128Chunked::from_slice(PlSmallStr::EMPTY, &[v])
.into_series()
})
}
},
Float32 => obj.extract::<f32>(py).map(|v| {
Float32Chunked::from_slice(PlSmallStr::EMPTY, &[v]).into_series()
}),
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-python/src/series/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ impl_eq_num!(eq_i8, i8);
impl_eq_num!(eq_i16, i16);
impl_eq_num!(eq_i32, i32);
impl_eq_num!(eq_i64, i64);
impl_eq_num!(eq_i128, i128);
impl_eq_num!(eq_f32, f32);
impl_eq_num!(eq_f64, f64);
impl_eq_num!(eq_str, &str);
Expand Down Expand Up @@ -98,6 +99,7 @@ impl_neq_num!(neq_i8, i8);
impl_neq_num!(neq_i16, i16);
impl_neq_num!(neq_i32, i32);
impl_neq_num!(neq_i64, i64);
impl_neq_num!(neq_i128, i128);
impl_neq_num!(neq_f32, f32);
impl_neq_num!(neq_f64, f64);
impl_neq_num!(neq_str, &str);
Expand All @@ -124,6 +126,7 @@ impl_gt_num!(gt_i8, i8);
impl_gt_num!(gt_i16, i16);
impl_gt_num!(gt_i32, i32);
impl_gt_num!(gt_i64, i64);
impl_gt_num!(gt_i128, i128);
impl_gt_num!(gt_f32, f32);
impl_gt_num!(gt_f64, f64);
impl_gt_num!(gt_str, &str);
Expand All @@ -150,6 +153,7 @@ impl_gt_eq_num!(gt_eq_i8, i8);
impl_gt_eq_num!(gt_eq_i16, i16);
impl_gt_eq_num!(gt_eq_i32, i32);
impl_gt_eq_num!(gt_eq_i64, i64);
impl_gt_eq_num!(gt_eq_i128, i128);
impl_gt_eq_num!(gt_eq_f32, f32);
impl_gt_eq_num!(gt_eq_f64, f64);
impl_gt_eq_num!(gt_eq_str, &str);
Expand Down Expand Up @@ -177,6 +181,7 @@ impl_lt_num!(lt_i8, i8);
impl_lt_num!(lt_i16, i16);
impl_lt_num!(lt_i32, i32);
impl_lt_num!(lt_i64, i64);
impl_lt_num!(lt_i128, i128);
impl_lt_num!(lt_f32, f32);
impl_lt_num!(lt_f64, f64);
impl_lt_num!(lt_str, &str);
Expand All @@ -203,6 +208,7 @@ impl_lt_eq_num!(lt_eq_i8, i8);
impl_lt_eq_num!(lt_eq_i16, i16);
impl_lt_eq_num!(lt_eq_i32, i32);
impl_lt_eq_num!(lt_eq_i64, i64);
impl_lt_eq_num!(lt_eq_i128, i128);
impl_lt_eq_num!(lt_eq_f32, f32);
impl_lt_eq_num!(lt_eq_f64, f64);
impl_lt_eq_num!(lt_eq_str, &str);
Expand Down
3 changes: 3 additions & 0 deletions py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Int16,
Int32,
Int64,
Int128,
List,
Null,
Object,
Expand Down Expand Up @@ -149,6 +150,7 @@ def DTYPE_TO_FFINAME(self) -> dict[PolarsDataType, str]:
Duration: "duration",
Float32: "f32",
Float64: "f64",
Int128: "i128",
Int16: "i16",
Int32: "i32",
Int64: "i64",
Expand Down Expand Up @@ -177,6 +179,7 @@ def DTYPE_TO_PY_TYPE(self) -> dict[PolarsDataType, PythonDataType]:
Duration: timedelta,
Float32: float,
Float64: float,
Int128: int,
Int16: int,
Int32: int,
Int64: int,
Expand Down
29 changes: 22 additions & 7 deletions py-polars/tests/unit/lazyframe/test_lazyframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
PolarsInefficientMapWarning,
)
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.conftest import FLOAT_DTYPES
from tests.unit.conftest import FLOAT_DTYPES, NUMERIC_DTYPES

if TYPE_CHECKING:
from _pytest.capture import CaptureFixture
Expand Down Expand Up @@ -488,19 +488,34 @@ def test_len() -> None:
assert cast(int, ldf.select(pl.col("nrs").len()).collect().item()) == 3


def test_cum_agg() -> None:
ldf = pl.LazyFrame({"a": [1, 2, 3, 2]})
@pytest.mark.parametrize("dtype", NUMERIC_DTYPES)
def test_cum_agg(dtype: PolarsDataType) -> None:
ldf = pl.LazyFrame({"a": [1, 2, 3, 2]}, schema={"a": dtype})
assert_series_equal(
ldf.select(pl.col("a").cum_sum()).collect()["a"], pl.Series("a", [1, 3, 6, 8])
ldf.select(pl.col("a").cum_min()).collect()["a"],
pl.Series("a", [1, 1, 1, 1], dtype=dtype),
)
assert_series_equal(
ldf.select(pl.col("a").cum_min()).collect()["a"], pl.Series("a", [1, 1, 1, 1])
ldf.select(pl.col("a").cum_max()).collect()["a"],
pl.Series("a", [1, 2, 3, 3], dtype=dtype),
)

expected_dtype = (
pl.Int64 if dtype in [pl.Int8, pl.Int16, pl.UInt8, pl.UInt16] else dtype
)
assert_series_equal(
ldf.select(pl.col("a").cum_max()).collect()["a"], pl.Series("a", [1, 2, 3, 3])
ldf.select(pl.col("a").cum_sum()).collect()["a"],
pl.Series("a", [1, 3, 6, 8], dtype=expected_dtype),
)

expected_dtype = (
pl.Int64
if dtype in [pl.Int8, pl.Int16, pl.Int32, pl.UInt8, pl.UInt16, pl.UInt32]
else dtype
)
assert_series_equal(
ldf.select(pl.col("a").cum_prod()).collect()["a"], pl.Series("a", [1, 2, 6, 12])
ldf.select(pl.col("a").cum_prod()).collect()["a"],
pl.Series("a", [1, 2, 6, 12], dtype=expected_dtype),
)


Expand Down
13 changes: 8 additions & 5 deletions py-polars/tests/unit/operations/rolling/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import polars as pl
from polars.testing import assert_series_equal
from tests.unit.conftest import INTEGER_DTYPES

if TYPE_CHECKING:
from polars._typing import PolarsDataType
Expand Down Expand Up @@ -82,17 +83,19 @@ def test_rolling_map_std_weights(dtype: PolarsDataType) -> None:
assert_series_equal(result, expected)


def test_rolling_map_sum_int() -> None:
s = pl.Series("A", [1, 2, 9, 2, 13], dtype=pl.Int32)
@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_rolling_map_sum_int(dtype: PolarsDataType) -> None:
s = pl.Series("A", [1, 2, 9, 2, 13], dtype=dtype)

result = s.rolling_map(function=lambda s: s.sum(), window_size=3)

expected = pl.Series("A", [None, None, 12, 13, 24], dtype=pl.Int32)
expected = pl.Series("A", [None, None, 12, 13, 24], dtype=dtype)
assert_series_equal(result, expected)


def test_rolling_map_sum_int_cast_to_float() -> None:
s = pl.Series("A", [1, 2, 9, None, 13], dtype=pl.Int32)
@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_rolling_map_sum_int_cast_to_float(dtype: PolarsDataType) -> None:
s = pl.Series("A", [1, 2, 9, None, 13], dtype=dtype)

result = s.rolling_map(
function=lambda s: s.sum(), window_size=3, weights=[1.0, 2.0, 3.0]
Expand Down
5 changes: 2 additions & 3 deletions py-polars/tests/unit/operations/test_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.conftest import FLOAT_DTYPES, SIGNED_INTEGER_DTYPES

if TYPE_CHECKING:
from polars._typing import PolarsDataType
Expand Down Expand Up @@ -47,9 +48,7 @@ def test_builtin_abs() -> None:
assert abs(s).to_list() == [1, 0, 1, None]


@pytest.mark.parametrize(
"dtype", [pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.Float32, pl.Float64]
)
@pytest.mark.parametrize("dtype", [*FLOAT_DTYPES, *SIGNED_INTEGER_DTYPES])
def test_abs_builtin(dtype: PolarsDataType) -> None:
lf = pl.LazyFrame({"a": [-1, 0, 1, None]}, schema={"a": dtype})
result = lf.select(abs(pl.col("a")))
Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/operations/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
(pl.Int16, pl.Float64),
(pl.Int32, pl.Float64),
(pl.Int64, pl.Float64),
(pl.Int128, pl.Float64),
(pl.UInt8, pl.Float64),
(pl.UInt16, pl.Float64),
(pl.UInt32, pl.Float64),
Expand Down
20 changes: 13 additions & 7 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ShapeError,
)
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.conftest import FLOAT_DTYPES, INTEGER_DTYPES
from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder

if TYPE_CHECKING:
Expand Down Expand Up @@ -1717,23 +1718,28 @@ def test_trigonometric_invalid_input() -> None:
s.cosh()


def test_product() -> None:
a = pl.Series("a", [1, 2, 3])
@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_product_ints(dtype: PolarsDataType) -> None:
a = pl.Series("a", [1, 2, 3], dtype=dtype)
out = a.product()
assert out == 6
a = pl.Series("a", [1, 2, None])
a = pl.Series("a", [1, 2, None], dtype=dtype)
out = a.product()
assert out == 2
a = pl.Series("a", [None, 2, 3])
a = pl.Series("a", [None, 2, 3], dtype=dtype)
out = a.product()
assert out == 6
a = pl.Series("a", [], dtype=pl.Float32)


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_product_floats(dtype: PolarsDataType) -> None:
a = pl.Series("a", [], dtype=dtype)
out = a.product()
assert out == 1
a = pl.Series("a", [None, None], dtype=pl.Float32)
a = pl.Series("a", [None, None], dtype=dtype)
out = a.product()
assert out == 1
a = pl.Series("a", [3.0, None, float("nan")])
a = pl.Series("a", [3.0, None, float("nan")], dtype=dtype)
out = a.product()
assert math.isnan(out)

Expand Down

0 comments on commit 11fa6de

Please sign in to comment.