diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 6205fe4a3f2a..be78be02baff 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -786,7 +786,7 @@ impl PolarsSQLFunctions { if ctx.function_registry.contains(other) { Self::Udf(other.to_string()) } else { - polars_bail!(SQLInterface: "unsupported function: {}", other); + polars_bail!(SQLInterface: "unsupported function '{}'", other); } }, }) @@ -824,13 +824,13 @@ impl SQLFunctionVisitor<'_> { Ok(e.round(match decimals { Expr::Literal(LiteralValue::Int(n)) => { if n >= 0 { n as u32 } else { - polars_bail!(SQLInterface: "ROUND does not (yet) support negative 'n_decimals' ({})", function.args[1]) + polars_bail!(SQLInterface: "ROUND does not currently support negative decimals value ({})", function.args[1]) } }, - _ => polars_bail!(SQLSyntax: "invalid 'n_decimals' for ROUND ({})", function.args[1]), + _ => polars_bail!(SQLSyntax: "invalid decimals value for ROUND ({})", function.args[1]), })) }), - _ => polars_bail!(SQLSyntax: "invalid number of arguments for ROUND; expected 1 or 2, found {}", function.args.len()), + _ => polars_bail!(SQLSyntax: "invalid number of arguments for ROUND (expected 1-2, found {})", function.args.len()), }, Sign => self.visit_unary(Expr::sign), Sqrt => self.visit_unary(Expr::sqrt), @@ -887,11 +887,11 @@ impl SQLFunctionVisitor<'_> { 2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)), _ => polars_bail!(SQLSyntax: "invalid number of arguments for DATE: {}", function.args.len()), }, - DatePart => self.try_visit_binary(|e, part| { + DatePart => self.try_visit_binary(|part, e| { match part { Expr::Literal(LiteralValue::String(p)) => parse_date_part(e, &p), _ => { - polars_bail!(SQLSyntax: "invalid 'part' for DATE_PART: {}", function.args[1]); + polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART: {}", function.args[1]); } } }), @@ -1119,7 +1119,7 @@ impl SQLFunctionVisitor<'_> { self.apply_cumulative_window(f, cumulative_f, spec) }, Some(WindowType::NamedWindow(named_window)) => polars_bail!( - SQLInterface: "Named windows are not supported yet; found {:?}", + SQLInterface: "Named windows are not currently supported; found {:?}", named_window ), _ => self.visit_unary(f), @@ -1294,7 +1294,7 @@ impl SQLFunctionVisitor<'_> { } }, Some(WindowType::NamedWindow(named_window)) => polars_bail!( - SQLInterface: "Named windows are not supported yet; found: {:?}", + SQLInterface: "Named windows are not currently supported; found {:?}", named_window ), None => expr, diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index b9d38e761a82..e5349d3252bd 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -1181,7 +1181,12 @@ fn parse_extract(expr: Expr, field: &DateTimeField) -> PolarsResult { DateTimeField::Year => expr.dt().year(), DateTimeField::Quarter => expr.dt().quarter(), DateTimeField::Month => expr.dt().month(), - DateTimeField::Week(_) => expr.dt().week(), + DateTimeField::Week(weekday) => { + if weekday.is_some() { + polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field) + } + expr.dt().week() + }, DateTimeField::IsoWeek => expr.dt().week(), DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(), DateTimeField::DayOfWeek | DateTimeField::Dow => { @@ -1217,7 +1222,7 @@ fn parse_extract(expr: Expr, field: &DateTimeField) -> PolarsResult { + expr.dt().nanosecond().div(typed_lit(1_000_000_000f64)) }, _ => { - polars_bail!(SQLInterface: "EXTRACT function does not support {}", field) + polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field) }, }) } @@ -1250,7 +1255,7 @@ pub(crate) fn parse_date_part(expr: Expr, part: &str) -> PolarsResult { "time" => &DateTimeField::Time, "epoch" => &DateTimeField::Epoch, _ => { - polars_bail!(SQLInterface: "DATE_PART function does not support '{}'", part) + polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", part) }, }, ) diff --git a/py-polars/docs/source/reference/sql/functions/temporal.rst b/py-polars/docs/source/reference/sql/functions/temporal.rst index 11a00c1e2412..d54ee6e5693f 100644 --- a/py-polars/docs/source/reference/sql/functions/temporal.rst +++ b/py-polars/docs/source/reference/sql/functions/temporal.rst @@ -82,9 +82,9 @@ Extracts a part of a date (or datetime) such as 'year', 'month', etc. df.sql(""" SELECT dt, - DATE_PART(dt, 'year') AS year, - DATE_PART(dt, 'month') AS month, - DATE_PART(dt, 'day') AS day + DATE_PART('year', dt) AS year, + DATE_PART('month', dt) AS month, + DATE_PART('day', dt) AS day FROM self """) diff --git a/py-polars/tests/unit/sql/test_numeric.py b/py-polars/tests/unit/sql/test_numeric.py index c903160a5623..f075949a3864 100644 --- a/py-polars/tests/unit/sql/test_numeric.py +++ b/py-polars/tests/unit/sql/test_numeric.py @@ -118,11 +118,11 @@ def test_round_ndigits_errors() -> None: df = pl.DataFrame({"n": [99.999]}) with pl.SQLContext(df=df, eager=True) as ctx: with pytest.raises( - SQLSyntaxError, match=r"invalid 'n_decimals' for ROUND \('!!'\)" + SQLSyntaxError, match=r"invalid decimals value for ROUND \('!!'\)" ): ctx.execute("SELECT ROUND(n,'!!') AS n FROM df") with pytest.raises( - SQLInterfaceError, match=r"ROUND .* negative 'n_decimals' \(-1\)" + SQLInterfaceError, match=r"ROUND .* negative decimals value \(-1\)" ): ctx.execute("SELECT ROUND(n,-1) AS n FROM df") diff --git a/py-polars/tests/unit/sql/test_temporal.py b/py-polars/tests/unit/sql/test_temporal.py index f6bf073b7141..cd4919a13485 100644 --- a/py-polars/tests/unit/sql/test_temporal.py +++ b/py-polars/tests/unit/sql/test_temporal.py @@ -100,13 +100,31 @@ def test_extract(part: str, dtype: pl.DataType, expected: list[Any]) -> None: } ) with pl.SQLContext(frame_data=df, eager=True) as ctx: - for func in (f"EXTRACT({part} FROM dt)", f"DATE_PART(dt,'{part}')"): + for func in (f"EXTRACT({part} FROM dt)", f"DATE_PART('{part}',dt)"): res = ctx.execute(f"SELECT {func} AS {part} FROM frame_data").to_series() assert res.dtype == dtype assert res.to_list() == expected +def test_extract_errors() -> None: + df = pl.DataFrame({"dt": [datetime(2024, 1, 7, 1, 2, 3, 123456)]}) + + with pl.SQLContext(frame_data=df, eager=True) as ctx: + for part in ("femtosecond", "stroopwafel"): + with pytest.raises( + SQLSyntaxError, + match=f"EXTRACT/DATE_PART does not support '{part}' part", + ): + ctx.execute(f"SELECT EXTRACT({part} FROM dt) FROM frame_data") + + with pytest.raises( + SQLSyntaxError, + match=r"EXTRACT/DATE_PART does not support 'week\(tuesday\)' part", + ): + ctx.execute("SELECT DATE_PART('week(tuesday)', dt) FROM frame_data") + + @pytest.mark.parametrize( ("dt", "expected"), [ @@ -130,9 +148,9 @@ def test_extract_century_millennium(dt: date, expected: list[int]) -> None: """ SELECT EXTRACT(MILLENNIUM FROM dt) AS c1, - DATE_PART(dt,'century') AS c2, + DATE_PART('century',dt) AS c2, EXTRACT(millennium FROM dt) AS c3, - DATE_PART(dt,'CENTURY') AS c4, + DATE_PART('CENTURY',dt) AS c4, FROM frame_data """ )