Skip to content

Commit

Permalink
fix: DATE_PART SQL syntax/parsing, improve some error messages (#16761)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Jun 6, 2024
1 parent 3b67c25 commit 2398b47
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 19 deletions.
16 changes: 8 additions & 8 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ impl PolarsSQLFunctions {
if ctx.function_registry.contains(other) {
Self::Udf(other.to_string())
} else {
polars_bail!(SQLInterface: "unsupported function: {}", other);
polars_bail!(SQLInterface: "unsupported function '{}'", other);
}
},
})
Expand Down Expand Up @@ -824,13 +824,13 @@ impl SQLFunctionVisitor<'_> {
Ok(e.round(match decimals {
Expr::Literal(LiteralValue::Int(n)) => {
if n >= 0 { n as u32 } else {
polars_bail!(SQLInterface: "ROUND does not (yet) support negative 'n_decimals' ({})", function.args[1])
polars_bail!(SQLInterface: "ROUND does not currently support negative decimals value ({})", function.args[1])
}
},
_ => polars_bail!(SQLSyntax: "invalid 'n_decimals' for ROUND ({})", function.args[1]),
_ => polars_bail!(SQLSyntax: "invalid decimals value for ROUND ({})", function.args[1]),
}))
}),
_ => polars_bail!(SQLSyntax: "invalid number of arguments for ROUND; expected 1 or 2, found {}", function.args.len()),
_ => polars_bail!(SQLSyntax: "invalid number of arguments for ROUND (expected 1-2, found {})", function.args.len()),
},
Sign => self.visit_unary(Expr::sign),
Sqrt => self.visit_unary(Expr::sqrt),
Expand Down Expand Up @@ -887,11 +887,11 @@ impl SQLFunctionVisitor<'_> {
2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
_ => polars_bail!(SQLSyntax: "invalid number of arguments for DATE: {}", function.args.len()),
},
DatePart => self.try_visit_binary(|e, part| {
DatePart => self.try_visit_binary(|part, e| {
match part {
Expr::Literal(LiteralValue::String(p)) => parse_date_part(e, &p),
_ => {
polars_bail!(SQLSyntax: "invalid 'part' for DATE_PART: {}", function.args[1]);
polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART: {}", function.args[1]);
}
}
}),
Expand Down Expand Up @@ -1119,7 +1119,7 @@ impl SQLFunctionVisitor<'_> {
self.apply_cumulative_window(f, cumulative_f, spec)
},
Some(WindowType::NamedWindow(named_window)) => polars_bail!(
SQLInterface: "Named windows are not supported yet; found {:?}",
SQLInterface: "Named windows are not currently supported; found {:?}",
named_window
),
_ => self.visit_unary(f),
Expand Down Expand Up @@ -1294,7 +1294,7 @@ impl SQLFunctionVisitor<'_> {
}
},
Some(WindowType::NamedWindow(named_window)) => polars_bail!(
SQLInterface: "Named windows are not supported yet; found: {:?}",
SQLInterface: "Named windows are not currently supported; found {:?}",
named_window
),
None => expr,
Expand Down
11 changes: 8 additions & 3 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1181,7 +1181,12 @@ fn parse_extract(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
DateTimeField::Year => expr.dt().year(),
DateTimeField::Quarter => expr.dt().quarter(),
DateTimeField::Month => expr.dt().month(),
DateTimeField::Week(_) => expr.dt().week(),
DateTimeField::Week(weekday) => {
if weekday.is_some() {
polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
}
expr.dt().week()
},
DateTimeField::IsoWeek => expr.dt().week(),
DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
DateTimeField::DayOfWeek | DateTimeField::Dow => {
Expand Down Expand Up @@ -1217,7 +1222,7 @@ fn parse_extract(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
+ expr.dt().nanosecond().div(typed_lit(1_000_000_000f64))
},
_ => {
polars_bail!(SQLInterface: "EXTRACT function does not support {}", field)
polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
},
})
}
Expand Down Expand Up @@ -1250,7 +1255,7 @@ pub(crate) fn parse_date_part(expr: Expr, part: &str) -> PolarsResult<Expr> {
"time" => &DateTimeField::Time,
"epoch" => &DateTimeField::Epoch,
_ => {
polars_bail!(SQLInterface: "DATE_PART function does not support '{}'", part)
polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", part)
},
},
)
Expand Down
6 changes: 3 additions & 3 deletions py-polars/docs/source/reference/sql/functions/temporal.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ Extracts a part of a date (or datetime) such as 'year', 'month', etc.
df.sql("""
SELECT
dt,
DATE_PART(dt, 'year') AS year,
DATE_PART(dt, 'month') AS month,
DATE_PART(dt, 'day') AS day
DATE_PART('year', dt) AS year,
DATE_PART('month', dt) AS month,
DATE_PART('day', dt) AS day
FROM self
""")
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/sql/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ def test_round_ndigits_errors() -> None:
df = pl.DataFrame({"n": [99.999]})
with pl.SQLContext(df=df, eager=True) as ctx:
with pytest.raises(
SQLSyntaxError, match=r"invalid 'n_decimals' for ROUND \('!!'\)"
SQLSyntaxError, match=r"invalid decimals value for ROUND \('!!'\)"
):
ctx.execute("SELECT ROUND(n,'!!') AS n FROM df")
with pytest.raises(
SQLInterfaceError, match=r"ROUND .* negative 'n_decimals' \(-1\)"
SQLInterfaceError, match=r"ROUND .* negative decimals value \(-1\)"
):
ctx.execute("SELECT ROUND(n,-1) AS n FROM df")

Expand Down
24 changes: 21 additions & 3 deletions py-polars/tests/unit/sql/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,31 @@ def test_extract(part: str, dtype: pl.DataType, expected: list[Any]) -> None:
}
)
with pl.SQLContext(frame_data=df, eager=True) as ctx:
for func in (f"EXTRACT({part} FROM dt)", f"DATE_PART(dt,'{part}')"):
for func in (f"EXTRACT({part} FROM dt)", f"DATE_PART('{part}',dt)"):
res = ctx.execute(f"SELECT {func} AS {part} FROM frame_data").to_series()

assert res.dtype == dtype
assert res.to_list() == expected


def test_extract_errors() -> None:
    """Check that unsupported EXTRACT/DATE_PART parts raise ``SQLSyntaxError``."""
    frame = pl.DataFrame({"dt": [datetime(2024, 1, 7, 1, 2, 3, 123456)]})

    with pl.SQLContext(frame_data=frame, eager=True) as sql_ctx:
        # Unknown part names passed through the EXTRACT syntax.
        for bad_part in ("femtosecond", "stroopwafel"):
            expected_msg = f"EXTRACT/DATE_PART does not support '{bad_part}' part"
            query = f"SELECT EXTRACT({bad_part} FROM dt) FROM frame_data"
            with pytest.raises(SQLSyntaxError, match=expected_msg):
                sql_ctx.execute(query)

        # The 'week(tuesday)' part passed through DATE_PART must be rejected too.
        weekday_msg = r"EXTRACT/DATE_PART does not support 'week\(tuesday\)' part"
        with pytest.raises(SQLSyntaxError, match=weekday_msg):
            sql_ctx.execute("SELECT DATE_PART('week(tuesday)', dt) FROM frame_data")


@pytest.mark.parametrize(
("dt", "expected"),
[
Expand All @@ -130,9 +148,9 @@ def test_extract_century_millennium(dt: date, expected: list[int]) -> None:
"""
SELECT
EXTRACT(MILLENNIUM FROM dt) AS c1,
DATE_PART(dt,'century') AS c2,
DATE_PART('century',dt) AS c2,
EXTRACT(millennium FROM dt) AS c3,
DATE_PART(dt,'CENTURY') AS c4,
DATE_PART('CENTURY',dt) AS c4,
FROM frame_data
"""
)
Expand Down

0 comments on commit 2398b47

Please sign in to comment.