Skip to content

Commit

Permalink
fix: More Int128 testing and related fixes (#20494)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored Dec 30, 2024
1 parent b430f64 commit c88911f
Show file tree
Hide file tree
Showing 17 changed files with 83 additions and 162 deletions.
9 changes: 9 additions & 0 deletions crates/polars-ops/src/chunked_array/array/sum_mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub(super) fn sum_array_numerical(ca: &ArrayChunked, inner_type: &DataType) -> S
Int16 => dispatch_sum::<i16, i64>(values, width, arr.validity()),
Int32 => dispatch_sum::<i32, i32>(values, width, arr.validity()),
Int64 => dispatch_sum::<i64, i64>(values, width, arr.validity()),
Int128 => dispatch_sum::<i128, i128>(values, width, arr.validity()),
UInt8 => dispatch_sum::<u8, i64>(values, width, arr.validity()),
UInt16 => dispatch_sum::<u16, i64>(values, width, arr.validity()),
UInt32 => dispatch_sum::<u32, u32>(values, width, arr.validity()),
Expand Down Expand Up @@ -96,6 +97,14 @@ pub(super) fn sum_with_nulls(ca: &ArrayChunked, inner_dtype: &DataType) -> Polar
.collect();
out.into_series()
},
#[cfg(feature = "dtype-i128")]
Int128 => {
let out: Int128Chunked = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().sum().ok()))
.collect();
out.into_series()
},
Float32 => {
let out: Float32Chunked = ca
.amortized_iter()
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/src/chunked_array/list/sum_mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub(super) fn sum_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Ser
Int16 => dispatch_sum::<i16, i64>(values, offsets, arr.validity()),
Int32 => dispatch_sum::<i32, i32>(values, offsets, arr.validity()),
Int64 => dispatch_sum::<i64, i64>(values, offsets, arr.validity()),
Int128 => dispatch_sum::<i128, i128>(values, offsets, arr.validity()),
UInt8 => dispatch_sum::<u8, i64>(values, offsets, arr.validity()),
UInt16 => dispatch_sum::<u16, i64>(values, offsets, arr.validity()),
UInt32 => dispatch_sum::<u32, u32>(values, offsets, arr.validity()),
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-python/src/map/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ pub fn apply_lambda_with_primitive_out_type<'a, D>(
first_value: Option<D::Native>,
) -> PyResult<ChunkedArray<D>>
where
D: PyArrowPrimitiveType,
D: PyPolarsNumericType,
D::Native: IntoPyObject<'a> + FromPyObject<'a>,
{
let skip = usize::from(first_value.is_some());
Expand Down
25 changes: 13 additions & 12 deletions crates/polars-python/src/map/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@ use crate::error::PyPolarsErr;
use crate::prelude::ObjectValue;
use crate::{PySeries, Wrap};

pub trait PyArrowPrimitiveType: PolarsNumericType {}
/// Marker trait (no methods of its own) layered on top of [`PolarsNumericType`].
///
/// It is used as the bound on the primitive-output helpers in this module tree
/// (e.g. `iterator_to_primitive` and `apply_lambda_with_primitive_out_type`), so
/// only the dtypes that explicitly implement it can be produced by those helpers.
pub trait PyPolarsNumericType: PolarsNumericType {}

impl PyArrowPrimitiveType for UInt8Type {}
impl PyArrowPrimitiveType for UInt16Type {}
impl PyArrowPrimitiveType for UInt32Type {}
impl PyArrowPrimitiveType for UInt64Type {}
impl PyArrowPrimitiveType for Int8Type {}
impl PyArrowPrimitiveType for Int16Type {}
impl PyArrowPrimitiveType for Int32Type {}
impl PyArrowPrimitiveType for Int64Type {}
impl PyArrowPrimitiveType for Float32Type {}
impl PyArrowPrimitiveType for Float64Type {}
// Opt-in list of numeric dtypes for `PyPolarsNumericType`: the unsigned integers
// (8-64 bit), the signed integers (8-128 bit, including `Int128Type`), and both
// float widths. The impl bodies are empty because the trait declares no items.
impl PyPolarsNumericType for UInt8Type {}
impl PyPolarsNumericType for UInt16Type {}
impl PyPolarsNumericType for UInt32Type {}
impl PyPolarsNumericType for UInt64Type {}
impl PyPolarsNumericType for Int8Type {}
impl PyPolarsNumericType for Int16Type {}
impl PyPolarsNumericType for Int32Type {}
impl PyPolarsNumericType for Int64Type {}
impl PyPolarsNumericType for Int128Type {}
impl PyPolarsNumericType for Float32Type {}
impl PyPolarsNumericType for Float64Type {}

fn iterator_to_struct<'a>(
py: Python,
Expand Down Expand Up @@ -141,7 +142,7 @@ fn iterator_to_primitive<T>(
capacity: usize,
) -> PyResult<ChunkedArray<T>>
where
T: PyArrowPrimitiveType,
T: PyPolarsNumericType,
{
let mut error = None;
// SAFETY: we know the iterators len.
Expand Down
18 changes: 9 additions & 9 deletions crates/polars-python/src/map/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ pub trait ApplyLambda<'a> {
first_value: Option<D::Native>,
) -> PyResult<ChunkedArray<D>>
where
D: PyArrowPrimitiveType,
D: PyPolarsNumericType,
D::Native: IntoPyObject<'a> + FromPyObject<'a>;

/// Apply a lambda with a boolean output type
Expand Down Expand Up @@ -343,7 +343,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked {
first_value: Option<D::Native>,
) -> PyResult<ChunkedArray<D>>
where
D: PyArrowPrimitiveType,
D: PyPolarsNumericType,
D::Native: IntoPyObject<'a> + FromPyObject<'a>,
{
let skip = usize::from(first_value.is_some());
Expand Down Expand Up @@ -570,7 +570,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked {

impl<'a, T> ApplyLambda<'a> for ChunkedArray<T>
where
T: PyArrowPrimitiveType + PolarsNumericType,
T: PyPolarsNumericType,
T::Native: IntoPyObject<'a> + FromPyObject<'a>,
ChunkedArray<T>: IntoSeries,
{
Expand Down Expand Up @@ -643,7 +643,7 @@ where
first_value: Option<D::Native>,
) -> PyResult<ChunkedArray<D>>
where
D: PyArrowPrimitiveType,
D: PyPolarsNumericType,
D::Native: IntoPyObject<'a> + FromPyObject<'a>,
{
let skip = usize::from(first_value.is_some());
Expand Down Expand Up @@ -936,7 +936,7 @@ impl<'a> ApplyLambda<'a> for StringChunked {
first_value: Option<D::Native>,
) -> PyResult<ChunkedArray<D>>
where
D: PyArrowPrimitiveType,
D: PyPolarsNumericType,
D::Native: IntoPyObject<'a> + FromPyObject<'a>,
{
let skip = usize::from(first_value.is_some());
Expand Down Expand Up @@ -1285,7 +1285,7 @@ impl<'a> ApplyLambda<'a> for ListChunked {
first_value: Option<D::Native>,
) -> PyResult<ChunkedArray<D>>
where
D: PyArrowPrimitiveType,
D: PyPolarsNumericType,
D::Native: IntoPyObject<'a> + FromPyObject<'a>,
{
let skip = usize::from(first_value.is_some());
Expand Down Expand Up @@ -1728,7 +1728,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked {
first_value: Option<D::Native>,
) -> PyResult<ChunkedArray<D>>
where
D: PyArrowPrimitiveType,
D: PyPolarsNumericType,
D::Native: IntoPyObject<'a> + FromPyObject<'a>,
{
let skip = usize::from(first_value.is_some());
Expand Down Expand Up @@ -2122,7 +2122,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked<ObjectValue> {
first_value: Option<D::Native>,
) -> PyResult<ChunkedArray<D>>
where
D: PyArrowPrimitiveType,
D: PyPolarsNumericType,
D::Native: IntoPyObject<'a> + FromPyObject<'a>,
{
let skip = usize::from(first_value.is_some());
Expand Down Expand Up @@ -2398,7 +2398,7 @@ impl<'a> ApplyLambda<'a> for StructChunked {
first_value: Option<D::Native>,
) -> PyResult<ChunkedArray<D>>
where
D: PyArrowPrimitiveType,
D: PyPolarsNumericType,
D::Native: IntoPyObject<'a> + FromPyObject<'a>,
{
let skip = usize::from(first_value.is_some());
Expand Down
11 changes: 11 additions & 0 deletions crates/polars-python/src/series/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,17 @@ impl PySeries {
)?;
ca.into_series()
},
Some(DataType::Int128) => {
let ca: Int128Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
function,
0,
None
)?;
ca.into_series()
},
Some(DataType::UInt8) => {
let ca: UInt8Chunked = dispatch_apply!(
series,
Expand Down
1 change: 1 addition & 0 deletions crates/polars-python/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ macro_rules! apply_method_all_arrow_series2 {
DataType::Int16 => $self.i16().unwrap().$method($($args),*),
DataType::Int32 => $self.i32().unwrap().$method($($args),*),
DataType::Int64 => $self.i64().unwrap().$method($($args),*),
DataType::Int128 => $self.i128().unwrap().$method($($args),*),
DataType::Float32 => $self.f32().unwrap().$method($($args),*),
DataType::Float64 => $self.f64().unwrap().$method($($args),*),
DataType::Date => $self.date().unwrap().$method($($args),*),
Expand Down
14 changes: 3 additions & 11 deletions py-polars/tests/unit/dataframe/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal
from polars.testing.parametric import column, dataframes
from tests.unit.conftest import INTEGER_DTYPES, SIGNED_INTEGER_DTYPES


@given(
Expand Down Expand Up @@ -309,16 +310,7 @@ def test_df_getitem() -> None:
assert_frame_equal(df[pl.Series("", ["a", "b"])], df)

# pl.Series: positive idxs or empty idxs for row selection.
for pl_dtype in (
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
):
for pl_dtype in INTEGER_DTYPES:
assert_frame_equal(
df[pl.Series("", [1, 0, 3, 2, 3, 0], dtype=pl_dtype)],
pl.DataFrame(
Expand All @@ -328,7 +320,7 @@ def test_df_getitem() -> None:
assert df[pl.Series("", [], dtype=pl_dtype)].columns == ["a", "b"]

# pl.Series: positive and negative idxs for row selection.
for pl_dtype in (pl.Int8, pl.Int16, pl.Int32, pl.Int64):
for pl_dtype in SIGNED_INTEGER_DTYPES:
assert_frame_equal(
df[pl.Series("", [-1, 0, -3, -2, 3, -4], dtype=pl_dtype)],
pl.DataFrame(
Expand Down
20 changes: 3 additions & 17 deletions py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
SchemaError,
)
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.conftest import INTEGER_DTYPES

if sys.version_info >= (3, 11):
from enum import StrEnum
Expand Down Expand Up @@ -498,10 +499,7 @@ def test_enum_categories_series_zero_copy() -> None:
assert result_dtype == dtype


@pytest.mark.parametrize(
"dtype",
[pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64, pl.Int8, pl.Int16, pl.Int32, pl.Int64],
)
@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_enum_cast_from_other_integer_dtype(dtype: pl.DataType) -> None:
enum_dtype = pl.Enum(["a", "b", "c", "d"])
series = pl.Series([1, 2, 3, 3, 2, 1], dtype=dtype)
Expand Down Expand Up @@ -585,19 +583,7 @@ def test_category_comparison_subset() -> None:
assert out["dt1"].dtype != out["dt2"].dtype


@pytest.mark.parametrize(
"dt",
[
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
],
)
@pytest.mark.parametrize("dt", INTEGER_DTYPES)
def test_integer_cast_to_enum_15738(dt: pl.DataType) -> None:
s = pl.Series([0, 1, 2], dtype=dt).cast(pl.Enum(["a", "b", "c"]))
assert s.to_list() == ["a", "b", "c"]
Expand Down
13 changes: 2 additions & 11 deletions py-polars/tests/unit/operations/map/test_map_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import polars as pl
from polars.exceptions import PolarsInefficientMapWarning
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.conftest import INTEGER_DTYPES

pytestmark = pytest.mark.filterwarnings(
"ignore::polars.exceptions.PolarsInefficientMapWarning"
Expand Down Expand Up @@ -129,18 +130,8 @@ def test_map_elements_list_any_value_fallback() -> None:


def test_map_elements_all_types() -> None:
dtypes = [
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
]
# test we don't panic
for dtype in dtypes:
for dtype in INTEGER_DTYPES:
pl.Series([1, 2, 3, 4, 5], dtype=dtype).map_elements(lambda x: x)


Expand Down
28 changes: 6 additions & 22 deletions py-polars/tests/unit/operations/test_bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.conftest import INTEGER_DTYPES


@pytest.mark.parametrize("op", ["and_", "or_"])
Expand Down Expand Up @@ -80,20 +81,7 @@ def trailing_ones(v: int | None) -> int | None:
None,
],
)
@pytest.mark.parametrize(
"dtype",
[
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
pl.Boolean,
],
)
@pytest.mark.parametrize("dtype", [*INTEGER_DTYPES, pl.Boolean])
@pytest.mark.skipif(sys.version_info < (3, 10), reason="bit_count introduced in 3.10")
@typing.no_type_check
def test_bit_counts(value: int, dtype: pl.DataType) -> None:
Expand All @@ -106,6 +94,8 @@ def test_bit_counts(value: int, dtype: pl.DataType) -> None:
bitsize = 32
elif "64" in str(dtype):
bitsize = 64
elif "128" in str(dtype):
bitsize = 128

if bitsize == 1 and value is not None:
value = value & 1 != 0
Expand Down Expand Up @@ -150,10 +140,7 @@ def test_bit_counts(value: int, dtype: pl.DataType) -> None:
)


@pytest.mark.parametrize(
"dtype",
[pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64],
)
@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_bit_aggregations(dtype: pl.DataType) -> None:
s = pl.Series("a", [0x74, 0x1C, 0x05], dtype)

Expand All @@ -175,10 +162,7 @@ def test_bit_aggregations(dtype: pl.DataType) -> None:
)


@pytest.mark.parametrize(
"dtype",
[pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64],
)
@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_bit_group_by(dtype: pl.DataType) -> None:
df = pl.DataFrame(
[
Expand Down
Loading

0 comments on commit c88911f

Please sign in to comment.