Skip to content

Commit

Permalink
Add support for search macro substitution
Browse files Browse the repository at this point in the history
  • Loading branch information
scnerd committed Sep 13, 2024
1 parent 28d694c commit 28432ce
Show file tree
Hide file tree
Showing 10 changed files with 537 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ classifiers = [
]
dependencies = [
"black",
"pydantic>=2",
]

[project.optional-dependencies]
test = [
"pytest"
"pytest",
"pytest-dependency"
]
cli = [
"textual[syntax]"
Expand Down
5 changes: 5 additions & 0 deletions python/spl_transpiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
ast,
)

from .macros import substitute_macros, parse_with_macros, MacroDefinition

__all__ = (
"__version__",
"parse",
"render_pyspark",
"convert_spl_to_pyspark",
"ast",
"MacroDefinition",
"substitute_macros",
"parse_with_macros",
)
66 changes: 66 additions & 0 deletions python/spl_transpiler/macros.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import re
from typing import Protocol

from pydantic import BaseModel

from spl_transpiler.spl_transpiler import detect_macros, parse


class MacroDefinition(BaseModel):
arguments: list[str] | None = None
definition: str


class MacroLoader(Protocol):
def __getitem__(self, item: str) -> dict | MacroDefinition:
pass


def substitute_macros(code, macros: MacroLoader):
(chunks, suffix) = detect_macros(code)
query_parts = []
for prefix, macro_call in chunks:
query_parts.append(prefix)

macro_name = macro_call.macro_name
args = macro_call.args

macro_spec = MacroDefinition.model_validate(macros[macro_name])

macro_final_value = macro_spec.definition

if args:
if all(name is None for name, _ in args):
if macro_spec.arguments is None or len(macro_spec.arguments) != len(
args
):
raise ValueError(
f"Mismatched number of arguments in macro call: expected {len(macro_spec.arguments or [])}, got {len(args)}"
)
args = dict(zip(macro_spec.arguments, [value for _, value in args]))
elif all(name is not None for name, _ in args):
if macro_spec.arguments is None or len(macro_spec.arguments) != len(
args
):
raise ValueError(
f"Mismatched number of arguments in macro call: expected {len(macro_spec.arguments or [])}, got {len(args)}"
)
args = dict(args)
else:
raise ValueError(
"Mixture of named and positional arguments in macro call"
)

for arg_name, arg_substitute_value in args.items():
macro_final_value = re.sub(
f"\\${arg_name}\\$", arg_substitute_value, macro_final_value
)

query_parts.append(macro_final_value)

query_parts.append(suffix)
return "".join(query_parts)


def parse_with_macros(code: str, macros: MacroLoader, *args, **kwargs):
return parse(substitute_macros(code, macros), *args, **kwargs)
12 changes: 12 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ fn spl_transpiler(m: &Bound<'_, PyModule>) -> PyResult<()> {
}
}

#[pyfn(m)]
/// Parses SPL query code into a syntax tree.
fn detect_macros(spl_code: &str) -> PyResult<(Vec<(&str, spl::macros::MacroCall)>, &str)> {
match spl::macros::spl_macros(spl_code) {
Ok(("", res)) => Ok(res),
Ok(_) => Err(PyValueError::new_err("Failed to fully parse input")),
Err(e) => Err(PyValueError::new_err(format!("Error parsing SPL: {}", e))),
}
}

#[pyfn(m)]
#[pyo3(signature = (pipeline, format=true))]
/// Renders a parsed SPL syntax tree into equivalent PySpark query code, if possible.
Expand All @@ -42,6 +52,8 @@ fn spl_transpiler(m: &Bound<'_, PyModule>) -> PyResult<()> {
render_pyspark(&pipeline, format)
}

m.add_class::<spl::macros::MacroCall>()?;

let ast_m = PyModule::new_bound(m.py(), "ast")?;
spl::python::ast(&ast_m)?;
m.add_submodule(&ast_m)?;
Expand Down
33 changes: 33 additions & 0 deletions src/pyspark/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ pub enum DataFrame {
Source {
name: String,
},
Named {
source: Box<DataFrame>,
name: String,
},
Select {
source: Box<DataFrame>,
columns: Vec<ColumnLike>,
Expand Down Expand Up @@ -253,6 +257,12 @@ impl DataFrame {
name: name.to_string(),
}
}
pub fn named(&self, name: impl ToString) -> DataFrame {
DataFrame::Named {
source: Box::new(self.clone()),
name: name.to_string(),
}
}
pub fn select(&self, columns: Vec<ColumnLike>) -> DataFrame {
Self::Select {
source: Box::new(self.clone()),
Expand Down Expand Up @@ -338,6 +348,12 @@ impl TemplateNode for DataFrame {
fn to_spark_query(&self) -> Result<String> {
match self {
DataFrame::Source { name } => Ok(format!("spark.table('{}')", name)),
DataFrame::Named { source, name } => Ok(format!(
"{}.write.saveAsTable('{}')\n\n{}",
source.to_spark_query()?,
name,
DataFrame::source(name).to_spark_query()?,
)),
DataFrame::Select { source, columns } => {
let columns: Result<Vec<String>> =
columns.iter().map(|col| col.to_spark_query()).collect();
Expand Down Expand Up @@ -525,4 +541,21 @@ mod tests {
r#"spark.table("main").withColumn("final_name", F.col("orig_name"))"#,
)
}

#[test]
fn test_named() {
generates(
DataFrame::source("main")
.with_column(
"final_name",
column_like!([col("orig_name")].alias("alias_name")),
)
.named("prev"),
r#"
spark.table("main").withColumn("final_name", F.col("orig_name")).write.saveAsTable("prev")
spark.table("prev")
"#,
)
}
}
1 change: 1 addition & 0 deletions src/spl/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ use float_derive::FloatHash;
use pyo3::prelude::*;
use std::collections::HashMap;
use std::fmt::Debug;
// use crate::spl::to_spl::ToSpl;

/// Syntax tree element representing a null literal value.
#[derive(Debug, PartialEq, Clone, Hash)]
Expand Down
188 changes: 188 additions & 0 deletions src/spl/macros.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
// use crate::spl::ast;
use crate::spl::parser::{expr, token, ws};
use nom::bytes::complete::{tag, take_until};
use nom::character::complete::none_of;
use nom::combinator::{all_consuming, into, map, opt, recognize};
use nom::multi::{many0, separated_list0};
use nom::sequence::{delimited, pair, terminated};
use nom::IResult;
use pyo3::pyclass;

#[derive(Debug, PartialEq, Clone, Hash)]
#[pyclass(frozen, eq, hash)]
pub struct MacroCall {
#[pyo3(get)]
pub macro_name: String,
#[pyo3(get)]
pub args: Vec<(Option<String>, String)>,
}

type ChunkedQuery<'a> = (Vec<(&'a str, MacroCall)>, &'a str);

pub fn spl_macros(input: &str) -> IResult<&str, ChunkedQuery> {
all_consuming(pair(
many0(pair(
take_until("`"),
delimited(
tag("`"),
ws(map(
pair(
ws(token),
opt(delimited(
ws(tag("(")),
separated_list0(
ws(tag(",")),
pair(
opt(into(terminated(token, tag("=")))),
into(recognize(expr)),
),
),
ws(tag(")")),
)),
),
|(name, args)| MacroCall {
macro_name: name.to_string(),
args: args.unwrap_or(Vec::new()),
},
)),
tag("`"),
),
)),
recognize(many0(none_of("`"))),
))(input)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_spl_macros_simple_no_args() {
let input = r#"`foo`"#;
let result = spl_macros(input).unwrap();
assert_eq!(
result,
(
"",
(
vec![(
"",
MacroCall {
macro_name: "foo".to_string(),
args: Vec::new()
}
)],
""
)
)
);
}

#[test]
fn test_spl_macros_simple_with_args() {
let input = r#"`foo(bar, 1, "s")`"#;
let result = spl_macros(input).unwrap();
assert_eq!(
result,
(
"",
(
vec![(
"",
MacroCall {
macro_name: "foo".to_string(),
args: vec![
(None, "bar".to_string()),
(None, "1".to_string()),
(None, "\"s\"".to_string()),
]
}
)],
""
)
)
);
}

#[test]
fn test_spl_macros_simple_named_args() {
let input = r#"`foo(foo=bar, baz=1)`"#;
let result = spl_macros(input).unwrap();
assert_eq!(
result,
(
"",
(
vec![(
"",
MacroCall {
macro_name: "foo".to_string(),
args: vec![
(Some("foo".to_string()), "bar".to_string()),
(Some("baz".to_string()), "1".to_string()),
]
}
)],
""
)
)
);
}

#[test]
fn test_spl_macros_multiple() {
let input = r#"index=main | `foo(bar, 1, "s")` x=`f` y=3"#;
let result = spl_macros(input).unwrap();
assert_eq!(
result,
(
"",
(
vec![
(
"index=main | ",
MacroCall {
macro_name: "foo".to_string(),
args: vec![
(None, "bar".to_string()),
(None, "1".to_string()),
(None, "\"s\"".to_string()),
]
}
),
(
" x=",
MacroCall {
macro_name: "f".to_string(),
args: Vec::new(),
}
)
],
" y=3"
)
)
);
}

#[test]
fn test_spl_macros_quoted_backtick() {
let input = r#"`foo("`")`"#;
let result = spl_macros(input).unwrap();
assert_eq!(
result,
(
"",
(
vec![(
"",
MacroCall {
macro_name: "foo".to_string(),
args: vec![(None, "\"`\"".to_string()),]
}
)],
""
)
)
);
}
}
2 changes: 2 additions & 0 deletions src/spl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub mod ast;
pub mod macros;
pub mod operators;
pub(crate) mod parser;
pub mod python;
// pub mod to_spl;
Loading

0 comments on commit 28432ce

Please sign in to comment.