Skip to content

Commit

Permalink
feat: Support SQL Struct/JSON field access operators (#17226)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Jun 28, 2024
1 parent cdfeb4f commit 8b72169
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 7 deletions.
7 changes: 7 additions & 0 deletions crates/polars-sql/src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub fn all_keywords() -> Vec<&'static str> {
let mut keywords = vec![];
keywords.extend_from_slice(PolarsTableFunctions::keywords());
keywords.extend_from_slice(PolarsSQLFunctions::keywords());

use sqlparser::keywords;
let sql_keywords = &[
keywords::AND,
Expand All @@ -30,6 +31,7 @@ pub fn all_keywords() -> Vec<&'static str> {
keywords::DISTINCT,
keywords::DOUBLE,
keywords::DROP,
keywords::EXCEPT,
keywords::EXCLUDE,
keywords::FLOAT,
keywords::FROM,
Expand All @@ -39,6 +41,7 @@ pub fn all_keywords() -> Vec<&'static str> {
keywords::IN,
keywords::INNER,
keywords::INT,
keywords::INTERSECT,
keywords::INTERVAL,
keywords::JOIN,
keywords::LEFT,
Expand All @@ -51,6 +54,8 @@ pub fn all_keywords() -> Vec<&'static str> {
keywords::ORDER,
keywords::OUTER,
keywords::REGEXP,
keywords::RENAME,
keywords::REPLACE,
keywords::RIGHT,
keywords::RLIKE,
keywords::SELECT,
Expand All @@ -60,6 +65,8 @@ pub fn all_keywords() -> Vec<&'static str> {
keywords::TABLES,
keywords::THEN,
keywords::TIME,
keywords::TRUNCATE,
keywords::UNION,
keywords::USING,
keywords::VARCHAR,
keywords::WHEN,
Expand Down
61 changes: 61 additions & 0 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,34 @@ impl SQLExprVisitor<'_> {
}
}

fn struct_field_access_expr(
&mut self,
expr: &Expr,
path: &str,
infer_index: bool,
) -> PolarsResult<Expr> {
let path_elems = if path.starts_with('{') && path.ends_with('}') {
path.trim_matches(|c| c == '{' || c == '}')
} else {
path
}
.split(',');

let mut expr = expr.clone();
for p in path_elems {
let p = p.trim();
expr = if infer_index {
match p.parse::<i64>() {
Ok(idx) => expr.list().get(lit(idx), true),
Err(_) => expr.struct_().field_by_name(p),
}
} else {
expr.struct_().field_by_name(p)
}
}
Ok(expr)
}

/// Visit a SQL binary operator.
///
/// e.g. "column + 1", "column1 <= column2"
Expand Down Expand Up @@ -672,6 +700,39 @@ impl SQLExprVisitor<'_> {
};
self.visit_expr(&expr)?
},
// ----
// JSON/Struct field access operators
// ----
SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs {
Expr::Literal(LiteralValue::String(path)) => {
let mut expr = self.struct_field_access_expr(&lhs, &path, false)?;
if let SQLBinaryOperator::LongArrow = op {
expr = expr.cast(DataType::String);
}
expr
},
Expr::Literal(LiteralValue::Int(idx)) => {
let mut expr = self.struct_field_access_expr(&lhs, &idx.to_string(), true)?;
if let SQLBinaryOperator::LongArrow = op {
expr = expr.cast(DataType::String);
}
expr
},
_ => {
polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
},
},
SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => {
if let Expr::Literal(LiteralValue::String(path)) = rhs {
let mut expr = self.struct_field_access_expr(&lhs, &path, true)?;
if let SQLBinaryOperator::HashLongArrow = op {
expr = expr.cast(DataType::String);
}
expr
} else {
polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", rhs)
}
},
other => {
polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
},
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-sql/tests/simple_exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ fn test_struct_field_selection() {
let sql = r#"
SELECT
json_msg.str AS id,
SUM(json_msg.num) AS sum_n
SUM(json_msg -> 'num') AS sum_n
FROM df
GROUP BY json_msg.str
ORDER BY 1
Expand Down
45 changes: 39 additions & 6 deletions py-polars/tests/unit/sql/test_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,30 @@ def test_struct_field_group_by_errors(df_struct: pl.DataFrame) -> None:
).collect()


@pytest.mark.parametrize(
("expr", "expected"),
[
("nested #> '{c,1}'", 2),
("nested #> '{c,-1}'", 1),
("nested #>> '{c,0}'", "3"),
("nested -> '0' -> 0", "baz"),
("nested -> 'c' -> -1", 1),
("nested -> 'c' ->> 2", "1"),
],
)
def test_struct_field_operator_access(expr: str, expected: int | str) -> None:
df = pl.DataFrame(
{
"nested": {
"0": ["baz"],
"b": ["foo", "bar"],
"c": [3, 2, 1],
},
},
)
assert df.sql(f"SELECT {expr} FROM self").item() == expected


@pytest.mark.parametrize(
("fields", "excluding", "rename"),
[
Expand Down Expand Up @@ -127,15 +151,24 @@ def test_struct_field_selection_wildcards(


@pytest.mark.parametrize(
"invalid_column",
("invalid_column", "error_type"),
[
"json_msg.invalid_column",
"json_msg.other.invalid_column",
"self.json_msg.other.invalid_column",
("json_msg.invalid_column", StructFieldNotFoundError),
("json_msg.other.invalid_column", StructFieldNotFoundError),
("self.json_msg.other.invalid_column", StructFieldNotFoundError),
("json_msg.other -> invalid_column", SQLSyntaxError),
("json_msg -> DATE '2020-09-11'", SQLSyntaxError),
],
)
def test_struct_field_selection_errors(
invalid_column: str, df_struct: pl.DataFrame
invalid_column: str,
error_type: type[Exception],
df_struct: pl.DataFrame,
) -> None:
with pytest.raises(StructFieldNotFoundError, match="invalid_column"):
error_msg = (
"invalid json/struct path-extract"
if ("->" in invalid_column)
else "invalid_column"
)
with pytest.raises(error_type, match=error_msg):
df_struct.sql(f"SELECT {invalid_column} FROM self")

0 comments on commit 8b72169

Please sign in to comment.