diff --git a/crates/polars-plan/src/logical_plan/aexpr/schema.rs b/crates/polars-plan/src/logical_plan/aexpr/schema.rs index 965d3ad334ad..d09e6c2329b0 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/schema.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/schema.rs @@ -312,20 +312,62 @@ fn get_arithmetic_field( let mut left_field = left_ae.to_field_impl(schema, arena, nested)?; let super_type = match op { - Operator::Minus if left_field.dtype.is_temporal() => { + Operator::Minus => { let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype; - match (&left_field.dtype, right_type) { + match (&left_field.dtype, &right_type) { + #[cfg(feature = "dtype-struct")] + (Struct(_), Struct(_)) => { + return Ok(left_field); + }, + (Duration(_), Datetime(_, _)) + | (Datetime(_, _), Duration(_)) + | (Duration(_), Date) + | (Date, Duration(_)) + | (Duration(_), Time) + | (Time, Duration(_)) => try_get_supertype(left_field.data_type(), &right_type)?, // T - T != T if T is a datetime / date - (Datetime(tul, _), Datetime(tur, _)) => Duration(get_time_units(tul, &tur)), + (Datetime(tul, _), Datetime(tur, _)) => Duration(get_time_units(tul, tur)), + (_, Datetime(_, _)) | (Datetime(_, _), _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, (Date, Date) => Duration(TimeUnit::Milliseconds), - (left, right) => try_get_supertype(left, &right)?, + (_, Date) | (Date, _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)), + (_, Duration(_)) | (Duration(_), _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (_, Time) | (Time, _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (left, right) => try_get_supertype(left, right)?, } }, - Operator::Plus - if left_field.dtype == Boolean - && right_ae.get_type(schema, Context::Default, arena)? == Boolean => - { - IDX_DTYPE + Operator::Plus => { + let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype; + match (&left_field.dtype, &right_type) { + (Duration(_), Datetime(_, _)) + | (Datetime(_, _), Duration(_)) + | (Duration(_), Date) + | (Date, Duration(_)) + | (Duration(_), Time) + | (Time, Duration(_)) => try_get_supertype(left_field.data_type(), &right_type)?, + (_, Datetime(_, _)) + | (Datetime(_, _), _) + | (_, Date) + | (Date, _) + | (Time, _) + | (_, Time) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)), + (_, Duration(_)) | (Duration(_), _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (Boolean, Boolean) => IDX_DTYPE, + (left, right) => try_get_supertype(left, right)?, + } }, _ => { let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype; @@ -333,9 +375,15 @@ fn get_arithmetic_field( match (&left_field.dtype, &right_type) { #[cfg(feature = "dtype-struct")] (Struct(_), Struct(_)) => { - if op.is_arithmetic() { - return Ok(left_field); - } + return Ok(left_field); + }, + (Datetime(_, _), _) + | (_, Datetime(_, _)) + | (Time, _) + | (_, Time) + | (Date, _) + | (_, Date) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, _ => { // Avoid needlessly type casting numeric columns during arithmetic @@ -381,6 +429,14 @@ fn get_truediv_field( dt if dt.is_numeric() => Float64, #[cfg(feature = "dtype-duration")] Duration(_) => Float64, + #[cfg(feature = "dtype-datetime")] + Datetime(_, _) => { + polars_bail!(InvalidOperation: "division of 'Datetime' datatype is not allowed") + }, + #[cfg(feature = "dtype-time")] + Time => polars_bail!(InvalidOperation: "division of 'Time' datatype is not allowed"), + #[cfg(feature = "dtype-date")] + Date => polars_bail!(InvalidOperation: "division of 'Date' datatype is not allowed"), // we don't know what to do here, best return the dtype dt => dt.clone(), }; diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index e505881c6542..d311dc868931 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -632,3 +632,29 @@ def test_duration_division_schema() -> None: assert q.schema == {"a": pl.Float64} assert q.collect().to_dict(as_series=False) == {"a": [1.0]} + + +@pytest.mark.parametrize( + ("a", "b", "op"), + [ + (pl.Duration, pl.Int32, "+"), + (pl.Int32, pl.Duration, "+"), + (pl.Time, pl.Int32, "+"), + (pl.Int32, pl.Time, "+"), + (pl.Date, pl.Int32, "+"), + (pl.Int32, pl.Date, "+"), + (pl.Datetime, pl.Duration, "*"), + (pl.Duration, pl.Datetime, "*"), + (pl.Date, pl.Duration, "*"), + (pl.Duration, pl.Date, "*"), + (pl.Time, pl.Duration, "*"), + (pl.Duration, pl.Time, "*"), + ], +) +def test_raise_invalid_temporal(a: pl.DataType, b: pl.DataType, op: str) -> None: + a = pl.Series("a", [], dtype=a) # type: ignore[assignment] + b = pl.Series("b", [], dtype=b) # type: ignore[assignment] + _df = pl.DataFrame([a, b]) + + with pytest.raises(pl.InvalidOperationError): + eval(f"_df.select(pl.col('a') {op} pl.col('b'))")