diff --git a/joinery_macros/src/lib.rs b/joinery_macros/src/lib.rs
index a65d43c..f0cc60a 100644
--- a/joinery_macros/src/lib.rs
+++ b/joinery_macros/src/lib.rs
@@ -2,8 +2,8 @@ use std::borrow::Cow;
 
 use darling::{util::Flag, FromField};
 use proc_macro::TokenStream;
-use proc_macro2::{Span, TokenStream as TokenStream2};
-use quote::quote;
+use proc_macro2::{Delimiter, Span, TokenStream as TokenStream2, TokenTree};
+use quote::{quote, quote_spanned};
 use syn::{spanned::Spanned, Field, Ident};
 
 #[proc_macro_derive(Emit)]
@@ -221,3 +221,92 @@ struct EmitAttr {
     /// Should we omit this field from our output?
     skip: Flag,
 }
+
+#[proc_macro]
+pub fn sql_quote(input: TokenStream) -> TokenStream {
+    let input = TokenStream2::from(input);
+
+    let mut sql_token_exprs = vec![];
+    for token in input {
+        emit_sql_token_exprs(&mut sql_token_exprs, token);
+    }
+    let output = quote! {
+        crate::tokenizer::TokenStream::from_tokens(&[#(#sql_token_exprs),*][..])
+    };
+    output.into()
+}
+
+fn emit_sql_token_exprs(sql_token_exprs: &mut Vec<TokenStream2>, token: TokenTree) {
+    match token {
+        TokenTree::Group(group) => {
+            // We flatten this and use `Punct::new`.
+            let (open, close) = delimiter_pair(group.delimiter());
+            if let Some(open) = open {
+                sql_token_exprs.push(quote! {
+                    crate::tokenizer::Token::Punct(crate::tokenizer::Punct::new(#open))
+                });
+            }
+            for token in group.stream() {
+                emit_sql_token_exprs(sql_token_exprs, token);
+            }
+            if let Some(close) = close {
+                sql_token_exprs.push(quote! {
+                    crate::tokenizer::Token::Punct(crate::tokenizer::Punct::new(#close))
+                });
+            }
+        }
+        TokenTree::Ident(ident) => {
+            let ident_str = ident.to_string();
+            sql_token_exprs.push(quote! {
+                crate::tokenizer::Token::Ident(crate::tokenizer::Ident::new(#ident_str))
+            });
+        }
+        TokenTree::Punct(punct) => {
+            let punct_str = punct.to_string();
+            sql_token_exprs.push(quote! {
+                crate::tokenizer::Token::Punct(crate::tokenizer::Punct::new(#punct_str))
+            });
+        }
+        TokenTree::Literal(lit) => {
+            // There's probably a better way to do this.
+            let lit: syn::Lit = syn::parse_quote!(#lit);
+            match lit {
+                syn::Lit::Int(i) => {
+                    sql_token_exprs.push(quote! {
+                        crate::tokenizer::Token::Literal(crate::tokenizer::Literal::int(#i))
+                    });
+                }
+                syn::Lit::Str(s) => {
+                    sql_token_exprs.push(quote! {
+                        crate::tokenizer::Token::Literal(crate::tokenizer::Literal::string(#s))
+                    });
+                }
+                syn::Lit::Float(f) => {
+                    sql_token_exprs.push(quote! {
+                        crate::tokenizer::Token::Literal(crate::tokenizer::Literal::float(#f))
+                    });
+                }
+                // syn::Lit::ByteStr(_) => todo!(),
+                // syn::Lit::Byte(_) => todo!(),
+                // syn::Lit::Char(_) => todo!(),
+                // syn::Lit::Bool(_) => todo!(),
+                // syn::Lit::Verbatim(_) => todo!(),
+                _ => {
+                    sql_token_exprs.push(quote_spanned! {
+                        lit.span() =>
+                        compile_error!("unsupported literal type")
+                    });
+                }
+            }
+        }
+    }
+}
+
+fn delimiter_pair(d: Delimiter) -> (Option<&'static str>, Option<&'static str>) {
+    match d {
+        Delimiter::Parenthesis => (Some("("), Some(")")),
+        Delimiter::Brace => (Some("{"), Some("}")),
+        Delimiter::Bracket => (Some("["), Some("]")),
+        Delimiter::None => (None, None),
+    }
+}
diff --git a/src/tokenizer.rs b/src/tokenizer.rs
index e0f2eb1..b3cb040 100644
--- a/src/tokenizer.rs
+++ b/src/tokenizer.rs
@@ -327,6 +327,32 @@ pub struct Literal {
     pub value: LiteralValue,
 }
 
+impl Literal {
+    /// Construct a literal containing an integer.
+    pub fn int(i: i64) -> Self {
+        Self {
+            token: RawToken::new(&i.to_string()),
+            value: LiteralValue::Int64(i),
+        }
+    }
+
+    /// Construct a literal containing a floating-point number.
+    pub fn float(d: f64) -> Self {
+        Self {
+            token: RawToken::new(&d.to_string()),
+            value: LiteralValue::Float64(d),
+        }
+    }
+
+    /// Construct a literal containing a string.
+    pub fn string(s: &str) -> Self {
+        Self {
+            token: RawToken::new(&BigQueryString(s).to_string()),
+            value: LiteralValue::String(s.to_owned()),
+        }
+    }
+}
+
 /// A literal value.
 ///
 /// Does not include literals like `TRUE`, `FALSE` or `NULL`, which are parsed
@@ -377,6 +403,14 @@ pub struct TokenStream {
 }
 
 impl TokenStream {
+    /// Create from tokens.
+    #[allow(dead_code)]
+    pub fn from_tokens<Tokens: Into<Vec<Token>>>(tokens: Tokens) -> Self {
+        Self {
+            tokens: tokens.into(),
+        }
+    }
+
     /// Parse a literal.
     pub fn literal(&self, pos: usize) -> RuleResult<Literal> {
         match self.tokens.get(pos) {
@@ -853,6 +887,8 @@ peg::parser! {
 
 #[cfg(test)]
 mod test {
+    use joinery_macros::sql_quote;
+
     use super::*;
 
     #[test]
@@ -976,4 +1012,17 @@ mod test {
         };
         assert_eq!(parsed, expected);
     }
+
+    #[test]
+    fn sql_quote_builds_a_token_stream() {
+        sql_quote! {
+            SELECT
+                generate_uuid() AS id,
+                "hello" AS message,
+                1 AS n,
+                1.0 AS x,
+                true AS t,
+                false AS f,
+        };
+    }
 }
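
Reviewer note (not part of the patch): here is approximately what a small
invocation expands to, assuming the `Token`, `Ident`, and `Literal`
constructors added above. This is a hand-written sketch, not `cargo expand`
output:

    // sql_quote! { SELECT 1 AS n } should expand to roughly:
    crate::tokenizer::TokenStream::from_tokens(
        &[
            crate::tokenizer::Token::Ident(crate::tokenizer::Ident::new("SELECT")),
            crate::tokenizer::Token::Literal(crate::tokenizer::Literal::int(1)),
            crate::tokenizer::Token::Ident(crate::tokenizer::Ident::new("AS")),
            crate::tokenizer::Token::Ident(crate::tokenizer::Ident::new("n")),
        ][..],
    )

Each Rust token maps one-for-one onto a `crate::tokenizer::Token`, which is
why `from_tokens` only needs `Into<Vec<Token>>` and no parsing logic of its
own.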
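One subtlety worth flagging in the new test: bare `true` and `false` reach
the macro as `TokenTree::Ident`, not `TokenTree::Literal`, because Rust's
tokenizer treats them as keywords (`syn::Lit::Bool` only appears once syn
parses the ident). So the commented-out `syn::Lit::Bool` arm is never hit
for them; they flow through the `Ident` arm instead. A quick standalone
check, assuming a dev-dependency on `proc-macro2`:

    use std::str::FromStr;

    use proc_macro2::{TokenStream, TokenTree};

    fn main() {
        let ts = TokenStream::from_str("true 1").unwrap();
        for tt in ts {
            match tt {
                // `true` arrives here, as an identifier.
                TokenTree::Ident(i) => println!("ident: {i}"),
                // `1` arrives here, as a literal.
                TokenTree::Literal(l) => println!("literal: {l}"),
                _ => {}
            }
        }
    }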