Skip to content

Commit

Permalink
feat(rust,python,cli): add SQL engine support for EXTRACT and `DATE…
Browse files Browse the repository at this point in the history
…_PART` (#13603)
  • Loading branch information
alexander-beedie authored Jan 10, 2024
1 parent 21bb7b8 commit a8bdc76
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 14 deletions.
31 changes: 25 additions & 6 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use sqlparser::ast::{
WindowSpec, WindowType,
};

use crate::sql_expr::parse_sql_expr;
use crate::sql_expr::{parse_date_part, parse_sql_expr};
use crate::SQLContext;

pub(crate) struct SQLFunctionVisitor<'a> {
Expand Down Expand Up @@ -247,6 +247,12 @@ pub(crate) enum PolarsSQLFunctions {
/// SELECT DATE('2021-03', '%Y-%m') from df;
/// ```
Date,
/// SQL 'date_part' function.
/// Extracts a part of a date (or datetime) such as 'year', 'month', etc.
/// The expression is the first argument; the part name (a string literal) is the second.
/// ```sql
/// SELECT DATE_PART(column_1, 'year') from df;
/// SELECT DATE_PART(column_1, 'day') from df;
/// ```
DatePart,

// ----
// String functions
Expand Down Expand Up @@ -664,6 +670,7 @@ impl PolarsSQLFunctions {
// Date functions
// ----
"date" => Self::Date,
"date_part" => Self::DatePart,

// ----
// String functions
Expand Down Expand Up @@ -808,6 +815,23 @@ impl SQLFunctionVisitor<'_> {
},
NullIf => self.visit_binary(|l: Expr, r: Expr| when(l.clone().eq(r)).then(lit(LiteralValue::Null)).otherwise(l)),

// ----
// Date functions
// ----
Date => match function.args.len() {
1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())),
2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for Date: {}", function.args.len()),
},
DatePart => self.try_visit_binary(|e, part| {
match part {
Expr::Literal(LiteralValue::String(p)) => parse_date_part(e, &p),
_ => {
polars_bail!(InvalidOperation: "Invalid 'part' for DatePart: {}", function.args[1]);
}
}
}),

// ----
// String functions
// ----
Expand All @@ -827,11 +851,6 @@ impl SQLFunctionVisitor<'_> {
}
})
},
Date => match function.args.len() {
1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())),
2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for Date: {}", function.args.len()),
},
EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)),
#[cfg(feature = "nightly")]
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
Expand Down
92 changes: 85 additions & 7 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::Div;

use polars_core::export::regex;
use polars_core::prelude::*;
use polars_error::to_compute_err;
Expand All @@ -9,8 +11,9 @@ use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use sqlparser::ast::{
ArrayAgg, ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat,
DataType as SQLDataType, Expr as SQLExpr, Function as SQLFunction, Ident, JoinConstraint,
OrderByExpr, Query as Subquery, SelectItem, TrimWhereField, UnaryOperator, Value as SQLValue,
DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident,
JoinConstraint, OrderByExpr, Query as Subquery, SelectItem, TrimWhereField, UnaryOperator,
Value as SQLValue,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserOptions};
Expand Down Expand Up @@ -101,6 +104,7 @@ impl SQLExprVisitor<'_> {
} => self.visit_cast(expr, data_type, format),
SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
SQLExpr::Extract { field, expr } => parse_extract(self.visit_expr(expr)?, field),
SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
SQLExpr::Function(function) => self.visit_function(function),
SQLExpr::Identifier(ident) => self.visit_identifier(ident),
Expand Down Expand Up @@ -759,11 +763,6 @@ impl SQLExprVisitor<'_> {
}
}

/// Convert a parsed SQL expression into a Polars `Expr`.
///
/// Wraps `ctx` in a throwaway `SQLExprVisitor` and returns the result of
/// visiting `expr` with it, errors included.
pub(crate) fn parse_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Expr> {
    SQLExprVisitor { ctx }.visit_expr(expr)
}

pub(super) fn process_join(
left_tbl: LazyFrame,
right_tbl: LazyFrame,
Expand Down Expand Up @@ -912,3 +911,82 @@ pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
_ => polars_bail!(InvalidOperation: "Unable to parse '{}' as Expr", s.as_ref()),
})
}

/// Convert a parsed SQL expression into a Polars `Expr`.
///
/// Wraps `ctx` in a throwaway `SQLExprVisitor` and returns the result of
/// visiting `expr` with it, errors included.
pub(crate) fn parse_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Expr> {
    SQLExprVisitor { ctx }.visit_expr(expr)
}

/// Build the expression that extracts `field` from the temporal expression `expr`.
///
/// Each supported `DateTimeField` maps onto one or more calls in the `dt()`
/// namespace of `expr`. Fields without an arm below produce a `ComputeError`.
fn parse_extract(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
    let extracted = match field {
        // ---- calendar components ----
        DateTimeField::Decade => expr.dt().year() / lit(10),
        DateTimeField::Isoyear => expr.dt().iso_year(),
        DateTimeField::Year => expr.dt().year(),
        DateTimeField::Quarter => expr.dt().quarter(),
        DateTimeField::Month => expr.dt().month(),
        // both spellings resolve to the same `week()` call
        DateTimeField::Week | DateTimeField::IsoWeek => expr.dt().week(),
        DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
        DateTimeField::DayOfWeek | DateTimeField::Dow => {
            // `weekday()` is used as-is for ISODOW; for DOW the value 7 is remapped to 0
            let weekday = expr.dt().weekday();
            let is_seven = weekday.clone().eq(lit(7i8));
            when(is_seven).then(lit(0i8)).otherwise(weekday)
        },
        DateTimeField::Isodow => expr.dt().weekday(),
        DateTimeField::Day => expr.dt().day(),

        // ---- clock components ----
        DateTimeField::Hour => expr.dt().hour(),
        DateTimeField::Minute => expr.dt().minute(),
        DateTimeField::Second => expr.dt().second(),

        // ---- sub-second components: seconds scaled up, plus the fractional part ----
        DateTimeField::Millisecond | DateTimeField::Milliseconds => {
            let from_seconds = expr.clone().dt().second() * lit(1_000);
            let from_nanos = expr.dt().nanosecond() / lit(1_000_000f64);
            from_seconds + from_nanos
        },
        DateTimeField::Microsecond | DateTimeField::Microseconds => {
            let from_seconds = expr.clone().dt().second() * lit(1_000_000);
            let from_nanos = expr.dt().nanosecond() / lit(1_000f64);
            from_seconds + from_nanos
        },
        DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
            let from_seconds = expr.clone().dt().second() * lit(1_000_000_000f64);
            from_seconds + expr.dt().nanosecond()
        },

        DateTimeField::Time => expr.dt().time(),
        DateTimeField::Epoch => {
            // nanosecond timestamp divided by an integer literal, then the
            // sub-second remainder is added back as a float fraction
            let whole_seconds =
                expr.clone().dt().timestamp(TimeUnit::Nanoseconds) / lit(1_000_000_000i64);
            let fraction = expr.dt().nanosecond() / lit(1_000_000_000f64);
            whole_seconds + fraction
        },
        _ => {
            polars_bail!(ComputeError: "Extract function does not yet support {}", field)
        },
    };
    Ok(extracted)
}

/// Extract a date/time component from `expr`, selected by its textual name.
///
/// `part` is matched ASCII case-insensitively (it is lowercased first) against
/// the names below, and the resulting `DateTimeField` is handed to
/// `parse_extract`. An unrecognised name produces a `ComputeError` that quotes
/// the lowercased name.
pub(crate) fn parse_date_part(expr: Expr, part: &str) -> PolarsResult<Expr> {
    let part = part.to_ascii_lowercase();
    let field = match part.as_str() {
        "decade" => DateTimeField::Decade,
        "isoyear" => DateTimeField::Isoyear,
        "year" => DateTimeField::Year,
        "quarter" => DateTimeField::Quarter,
        "month" => DateTimeField::Month,
        "dayofyear" | "doy" => DateTimeField::DayOfYear,
        "dayofweek" | "dow" => DateTimeField::DayOfWeek,
        "isoweek" | "week" => DateTimeField::IsoWeek,
        "isodow" => DateTimeField::Isodow,
        "day" => DateTimeField::Day,
        "hour" => DateTimeField::Hour,
        "minute" => DateTimeField::Minute,
        "second" => DateTimeField::Second,
        "millisecond" | "milliseconds" => DateTimeField::Millisecond,
        "microsecond" | "microseconds" => DateTimeField::Microsecond,
        "nanosecond" | "nanoseconds" => DateTimeField::Nanosecond,
        "time" => DateTimeField::Time,
        "epoch" => DateTimeField::Epoch,
        _ => {
            polars_bail!(ComputeError: "Date part '{}' not supported", part)
        },
    };
    parse_extract(expr, &field)
}
57 changes: 56 additions & 1 deletion py-polars/tests/unit/sql/test_temporal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

from datetime import date
from datetime import date, datetime, time
from typing import Any

import pytest

import polars as pl
from polars.testing import assert_frame_equal
Expand All @@ -26,3 +29,55 @@ def test_date() -> None:
result = pl.select(pl.sql_expr("""CAST(DATE('2023-03', '%Y-%m') as STRING)"""))
expected = pl.DataFrame({"literal": ["2023-03-01"]})
assert_frame_equal(result, expected)


@pytest.mark.parametrize(
    ("part", "dtype", "expected"),
    [
        ("decade", pl.Int32, [202, 202, 200]),
        ("isoyear", pl.Int32, [2024, 2020, 2005]),
        ("year", pl.Int32, [2024, 2020, 2006]),
        ("quarter", pl.Int8, [1, 4, 1]),
        ("month", pl.Int8, [1, 12, 1]),
        ("week", pl.Int8, [1, 53, 52]),
        ("doy", pl.Int16, [7, 365, 1]),
        ("isodow", pl.Int8, [7, 3, 7]),
        ("dow", pl.Int8, [0, 3, 0]),
        ("day", pl.Int8, [7, 30, 1]),
        ("hour", pl.Int8, [1, 10, 23]),
        ("minute", pl.Int8, [2, 30, 59]),
        ("second", pl.Int8, [3, 45, 59]),
        ("millisecond", pl.Float64, [3123.456, 45987.654, 59555.555]),
        ("microsecond", pl.Float64, [3123456.0, 45987654.0, 59555555.0]),
        ("nanosecond", pl.Float64, [3123456000.0, 45987654000.0, 59555555000.0]),
        (
            "time",
            pl.Time,
            [time(1, 2, 3, 123456), time(10, 30, 45, 987654), time(23, 59, 59, 555555)],
        ),
        (
            "epoch",
            pl.Float64,
            [1704589323.123456, 1609324245.987654, 1136159999.555555],
        ),
    ],
)
def test_extract_datepart(part: str, dtype: pl.DataType, expected: list[Any]) -> None:
    """Check that EXTRACT and DATE_PART return the same dtype and values for each part."""
    # These datetimes cover several edge cases (isoyear, the mon/sun wrapping
    # of dow vs isodow, epoch rounding, ...); the expected results above were
    # validated against postgresql.
    sample_datetimes = [
        datetime(2024, 1, 7, 1, 2, 3, 123456),
        datetime(2020, 12, 30, 10, 30, 45, 987654),
        datetime(2006, 1, 1, 23, 59, 59, 555555),
    ]
    df = pl.DataFrame({"dt": sample_datetimes})

    # Both spellings must behave identically.
    sql_forms = (f"EXTRACT({part} FROM dt)", f"DATE_PART(dt,'{part}')")

    with pl.SQLContext(frame_data=df, eager_execution=True) as ctx:
        for func in sql_forms:
            query = f"SELECT {func} AS {part} FROM frame_data"
            res = ctx.execute(query).to_series()

            assert res.dtype == dtype
            assert res.to_list() == expected

0 comments on commit a8bdc76

Please sign in to comment.