From 1780991363e9eaedc1e25d244a06b9be7dbd05c9 Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Fri, 21 Feb 2025 14:05:39 -0500 Subject: [PATCH] Propagate dtype for dur * primitive --- crates/polars-plan/src/plans/aexpr/schema.rs | 8 ++++++ .../operations/arithmetic/test_arithmetic.py | 26 +++++++++++++++++-- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index ce2ad4669eae..91b6c048c56f 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -595,6 +595,14 @@ fn get_arithmetic_field( polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, }, + (Duration(_), r) if r.is_primitive_numeric() => match op { + Operator::Multiply => { + return Ok(left_field); + }, + _ => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + }, #[cfg(feature = "dtype-decimal")] (Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => { let scale = match op { diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 279042ca2620..4459a428cee0 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -705,8 +705,8 @@ def test_arithmetic_duration_div_multiply() -> None: ("a", pl.Duration(time_unit="us")), ("b", pl.Duration(time_unit="us")), ("c", pl.Duration(time_unit="us")), - ("d", pl.Unknown()), - ("e", pl.Unknown()), + ("d", pl.Duration(time_unit="us")), + ("e", pl.Duration(time_unit="us")), ("f", pl.Float64()), ] ) @@ -824,3 +824,25 @@ def test_raise_invalid_shape() -> None: def test_integer_divide_scalar_zero_lhs_19142() -> None: assert_series_equal(pl.Series([0]) // pl.Series([1, 0]), pl.Series([0, None])) assert_series_equal(pl.Series([0]) % pl.Series([1, 0]), pl.Series([0, None])) + + +def test_compound_duration_21389() -> None: + # test add + lf = pl.LazyFrame( + { + "ts": datetime(2024, 1, 1, 1, 2, 3), + "duration": timedelta(days=1), + } + ) + result = lf.select(pl.col("ts") + pl.col("duration") * 2) + expected_schema = pl.Schema({"ts": pl.Datetime(time_unit="us", time_zone=None)}) + expected = pl.DataFrame({"ts": datetime(2024, 1, 3, 1, 2, 3)}) + assert result.collect_schema() == expected_schema + assert_frame_equal(result.collect(), expected) + + # test subtract + result = lf.select(pl.col("ts") - pl.col("duration") * 2) + expected_schema = pl.Schema({"ts": pl.Datetime(time_unit="us", time_zone=None)}) + expected = pl.DataFrame({"ts": datetime(2023, 12, 30, 1, 2, 3)}) + assert result.collect_schema() == expected_schema + assert_frame_equal(result.collect(), expected)