From 68687d9736b5a1b85ecc328738f57192290168ff Mon Sep 17 00:00:00 2001 From: ritchie Date: Thu, 13 Jun 2024 15:30:16 +0200 Subject: [PATCH 1/3] feat: Raise on invalid temporal arithmetic --- .../src/logical_plan/aexpr/schema.rs | 72 +++++++++++++++---- 1 file changed, 60 insertions(+), 12 deletions(-) diff --git a/crates/polars-plan/src/logical_plan/aexpr/schema.rs b/crates/polars-plan/src/logical_plan/aexpr/schema.rs index 965d3ad334ad..845806679cc1 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/schema.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/schema.rs @@ -312,20 +312,54 @@ fn get_arithmetic_field( let mut left_field = left_ae.to_field_impl(schema, arena, nested)?; let super_type = match op { - Operator::Minus if left_field.dtype.is_temporal() => { + Operator::Minus => { let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype; - match (&left_field.dtype, right_type) { + match (&left_field.dtype, &right_type) { + #[cfg(feature = "dtype-struct")] + (Struct(_), Struct(_)) => { + return Ok(left_field); + }, + (Duration(_), Datetime(_, _)) + | (Datetime(_, _), Duration(_)) + | (Duration(_), Date) + | (Date, Duration(_)) + | (Duration(_), Time) + | (Time, Duration(_)) => try_get_supertype(left_field.data_type(), &right_type)?, // T - T != T if T is a datetime / date - (Datetime(tul, _), Datetime(tur, _)) => Duration(get_time_units(tul, &tur)), + (Datetime(tul, _), Datetime(tur, _)) => Duration(get_time_units(tul, tur)), + (_, Datetime(_, _)) | (Datetime(_, _), _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, (Date, Date) => Duration(TimeUnit::Milliseconds), - (left, right) => try_get_supertype(left, &right)?, + (_, Date) | (Date, _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)), + (_, Duration(_)) | (Duration(_), _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (left, right) => try_get_supertype(left, right)?, } }, - Operator::Plus - if left_field.dtype == Boolean - && right_ae.get_type(schema, Context::Default, arena)? == Boolean => - { - IDX_DTYPE + Operator::Plus => { + let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype; + match (&left_field.dtype, &right_type) { + (Duration(_), Datetime(_, _)) + | (Datetime(_, _), Duration(_)) + | (Duration(_), Date) + | (Date, Duration(_)) + | (Duration(_), Time) + | (Time, Duration(_)) => try_get_supertype(left_field.data_type(), &right_type)?, + (_, Datetime(_, _)) | (Datetime(_, _), _) | (_, Date) | (Date, _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)), + (_, Duration(_)) | (Duration(_), _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (Boolean, Boolean) => IDX_DTYPE, + (left, right) => try_get_supertype(left, right)?, + } }, _ => { let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype; @@ -333,9 +367,15 @@ fn get_arithmetic_field( match (&left_field.dtype, &right_type) { #[cfg(feature = "dtype-struct")] (Struct(_), Struct(_)) => { - if op.is_arithmetic() { - return Ok(left_field); - } + return Ok(left_field); + }, + (Datetime(_, _), _) + | (_, Datetime(_, _)) + | (Time, _) + | (_, Time) + | (Date, _) + | (_, Date) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, _ => { // Avoid needlessly type casting numeric columns during arithmetic @@ -381,6 +421,14 @@ fn get_truediv_field( dt if dt.is_numeric() => Float64, #[cfg(feature = "dtype-duration")] Duration(_) => Float64, + #[cfg(feature = "dtype-datetime")] + Datetime(_, _) => { + polars_bail!(InvalidOperation: "division of 'Datetime' datatype is not allowed") + }, + #[cfg(feature = "dtype-time")] + Time => polars_bail!(InvalidOperation: "division of 'Time' datatype is not allowed"), + #[cfg(feature = "dtype-date")] + Date => polars_bail!(InvalidOperation: "division of 'Date' datatype is not allowed"), // we don't know what to do here, best return the dtype dt => dt.clone(), }; From 8e12d4f0e0998874c82edc52566774da5115f32e Mon Sep 17 00:00:00 2001 From: ritchie Date: Thu, 13 Jun 2024 16:03:57 +0200 Subject: [PATCH 2/3] test --- .../src/logical_plan/aexpr/schema.rs | 5 +++- .../operations/arithmetic/test_arithmetic.py | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/crates/polars-plan/src/logical_plan/aexpr/schema.rs b/crates/polars-plan/src/logical_plan/aexpr/schema.rs index 845806679cc1..2aad112a4514 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/schema.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/schema.rs @@ -338,6 +338,9 @@ fn get_arithmetic_field( (_, Duration(_)) | (Duration(_), _) => { polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, + (_, Time) | (Time, _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, (left, right) => try_get_supertype(left, right)?, } }, @@ -350,7 +353,7 @@ fn get_arithmetic_field( | (Date, Duration(_)) | (Duration(_), Time) | (Time, Duration(_)) => try_get_supertype(left_field.data_type(), &right_type)?, - (_, Datetime(_, _)) | (Datetime(_, _), _) | (_, Date) | (Date, _) => { + (_, Datetime(_, _)) | (Datetime(_, _), _) | (_, Date) | (Date, _) | (Time, _) | (_, Time) => { polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, (Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)), diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index e505881c6542..d134d574f4a7 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -632,3 +632,29 @@ def test_duration_division_schema() -> None: assert q.schema == {"a": pl.Float64} assert q.collect().to_dict(as_series=False) == {"a": [1.0]} + + +@pytest.mark.parametrize( + ("a", "b", "op"), + [ + (pl.Duration, pl.Int32, "+"), + (pl.Int32, pl.Duration, "+"), + (pl.Time, pl.Int32, "+"), + (pl.Int32, pl.Time, "+"), + (pl.Date, pl.Int32, "+"), + (pl.Int32, pl.Date, "+"), + (pl.Datetime, pl.Duration, "*"), + (pl.Duration, pl.Datetime, "*"), + (pl.Date, pl.Duration, "*"), + (pl.Duration, pl.Date, "*"), + (pl.Time, pl.Duration, "*"), + (pl.Duration, pl.Time, "*"), + ] +) +def test_raise_invalid_temporal(a: pl.DataType, b: pl.DataType, op: str) -> None: + a = pl.Series("a", [], dtype=a) + b = pl.Series("b", [], dtype=b) + df = pl.DataFrame([a, b]) + + with pytest.raises(pl.InvalidOperationError): + eval(f"df.select(pl.col('a') {op} pl.col('b'))") From 119c42fffcccc1030dd573ca7163fffb464473f6 Mon Sep 17 00:00:00 2001 From: ritchie Date: Thu, 13 Jun 2024 16:46:27 +0200 Subject: [PATCH 3/3] lint --- crates/polars-plan/src/logical_plan/aexpr/schema.rs | 7 ++++++- .../unit/operations/arithmetic/test_arithmetic.py | 10 +++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/crates/polars-plan/src/logical_plan/aexpr/schema.rs b/crates/polars-plan/src/logical_plan/aexpr/schema.rs index 2aad112a4514..d09e6c2329b0 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/schema.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/schema.rs @@ -353,7 +353,12 @@ fn get_arithmetic_field( | (Date, Duration(_)) | (Duration(_), Time) | (Time, Duration(_)) => try_get_supertype(left_field.data_type(), &right_type)?, - (_, Datetime(_, _)) | (Datetime(_, _), _) | (_, Date) | (Date, _) | (Time, _) | (_, Time) => { + (_, Datetime(_, _)) + | (Datetime(_, _), _) + | (_, Date) + | (Date, _) + | (Time, _) + | (_, Time) => { polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, (Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)), diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index d134d574f4a7..d311dc868931 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -649,12 +649,12 @@ def test_duration_division_schema() -> None: (pl.Duration, pl.Date, "*"), (pl.Time, pl.Duration, "*"), (pl.Duration, pl.Time, "*"), - ] + ], ) def test_raise_invalid_temporal(a: pl.DataType, b: pl.DataType, op: str) -> None: - a = pl.Series("a", [], dtype=a) - b = pl.Series("b", [], dtype=b) - df = pl.DataFrame([a, b]) + a = pl.Series("a", [], dtype=a) # type: ignore[assignment] + b = pl.Series("b", [], dtype=b) # type: ignore[assignment] + _df = pl.DataFrame([a, b]) with pytest.raises(pl.InvalidOperationError): - eval(f"df.select(pl.col('a') {op} pl.col('b'))") + eval(f"_df.select(pl.col('a') {op} pl.col('b'))")