diff --git a/pgrx-sql-entity-graph/src/pg_extern/returning.rs b/pgrx-sql-entity-graph/src/pg_extern/returning.rs index 72683af11..fc884da11 100644 --- a/pgrx-sql-entity-graph/src/pg_extern/returning.rs +++ b/pgrx-sql-entity-graph/src/pg_extern/returning.rs @@ -22,7 +22,7 @@ use std::convert::TryFrom; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::{GenericArgument, PathArguments, Token, Type}; +use syn::{Error, GenericArgument, PathArguments, Token, Type}; #[derive(Debug, Clone)] pub struct ReturningIteratedItem { @@ -54,297 +54,278 @@ impl Returning { )), } } -} -impl TryFrom<&syn::ReturnType> for Returning { - type Error = syn::Error; + fn match_type(ty: &Box) -> Result { + let mut ty = *ty.clone(); - fn try_from(value: &syn::ReturnType) -> Result { - match &value { - syn::ReturnType::Default => Ok(Returning::None), - syn::ReturnType::Type(_, ty) => { - let mut ty = *ty.clone(); + match ty { + syn::Type::Path(mut typepath) => { + let path = &mut typepath.path; + let mut saw_option_ident = false; + let mut saw_result_ident = false; + let mut saw_setof_iterator = false; + let mut saw_table_iterator = false; - match ty { - syn::Type::Path(mut typepath) => { - let path = &mut typepath.path; - let mut saw_option_ident = false; - let mut saw_result_ident = false; - let mut saw_setof_iterator = false; - let mut saw_table_iterator = false; - - for segment in &mut path.segments { - let ident_string = segment.ident.to_string(); - match ident_string.as_str() { - "Option" => saw_option_ident = true, - "Result" => saw_result_ident = true, - "SetOfIterator" => saw_setof_iterator = true, - "TableIterator" => saw_table_iterator = true, - _ => (), - }; - } - if saw_option_ident - || saw_result_ident - || saw_setof_iterator - || saw_table_iterator - { - let option_inner_path = if saw_option_ident || saw_result_ident { - match path.segments.last_mut().map(|s| &mut s.arguments) { - Some(syn::PathArguments::AngleBracketed(args)) => { - let args_span = args.span(); - match args.args.first_mut() { - Some(syn::GenericArgument::Type(syn::Type::Path(syn::TypePath { qself: _, path }))) => path.clone(), - Some(syn::GenericArgument::Type(_)) => { - let used_ty = UsedType::new(syn::Type::Path(typepath.clone()))?; - return Ok(Returning::Type(used_ty)) - }, - other => { - return Err(syn::Error::new( - other.as_ref().map(|s| s.span()).unwrap_or(args_span), - &format!( - "Got unexpected generic argument for Option inner: {other:?}" - ), - )) - } - } - } + for segment in &mut path.segments { + let ident_string = segment.ident.to_string(); + match ident_string.as_str() { + "Option" => saw_option_ident = true, + "Result" => saw_result_ident = true, + "SetOfIterator" => saw_setof_iterator = true, + "TableIterator" => saw_table_iterator = true, + _ => (), + }; + } + if saw_option_ident || saw_result_ident || saw_setof_iterator || saw_table_iterator + { + let option_inner_path = if saw_option_ident || saw_result_ident { + match path.segments.last_mut().map(|s| &mut s.arguments) { + Some(syn::PathArguments::AngleBracketed(args)) => { + let args_span = args.span(); + match args.args.first_mut() { + Some(syn::GenericArgument::Type(syn::Type::Path(syn::TypePath { qself: _, path }))) => path.clone(), + Some(syn::GenericArgument::Type(_)) => { + let used_ty = UsedType::new(syn::Type::Path(typepath.clone()))?; + return Ok(Returning::Type(used_ty)) + }, other => { return Err(syn::Error::new( - other.span(), + other.as_ref().map(|s| s.span()).unwrap_or(args_span), &format!( - "Got unexpected path argument for Option inner: {other:?}" + "Got unexpected generic argument for Option inner: {other:?}" ), )) } } - } else { - path.clone() - }; + } + other => { + return Err(syn::Error::new( + other.span(), + &format!( + "Got unexpected path argument for Option inner: {other:?}" + ), + )) + } + } + } else { + path.clone() + }; - let mut segments = option_inner_path.segments.clone(); - let mut found_option = false; - 'outer: loop { - for segment in &segments { - let ident_string = segment.ident.to_string(); - match ident_string.as_str() { - "Option" => match &segment.arguments { - PathArguments::AngleBracketed(bracketed) => { - match bracketed.args.first().unwrap() { - GenericArgument::Type(ty) => match ty { - Type::Path(this_path) => { - segments = - this_path.path.segments.clone(); - saw_option_ident = true; - found_option = true; - continue 'outer; - } - _ => continue, - }, - _ => continue, - }; - } + let mut segments = option_inner_path.segments.clone(); + let mut found_option = false; + 'outer: loop { + for segment in &segments { + let ident_string = segment.ident.to_string(); + match ident_string.as_str() { + "Option" => match &segment.arguments { + PathArguments::AngleBracketed(bracketed) => { + match bracketed.args.first().unwrap() { + GenericArgument::Type(ty) => match ty { + Type::Path(this_path) => { + segments = this_path.path.segments.clone(); + saw_option_ident = true; + found_option = true; + continue 'outer; + } + _ => continue, + }, _ => continue, - }, - "SetOfIterator" => saw_setof_iterator = true, - "TableIterator" => { - if found_option { - segments = Punctuated::from_iter(std::iter::once( - segment.clone(), - )); - found_option = false; - continue 'outer; - } - saw_table_iterator = true - } - _ => (), - }; + }; + } + _ => continue, + }, + "SetOfIterator" => saw_setof_iterator = true, + "TableIterator" => { + if found_option { + segments = + Punctuated::from_iter(std::iter::once(segment.clone())); + found_option = false; + continue 'outer; + } + saw_table_iterator = true } - break; - } + _ => (), + }; + } + break; + } - if saw_setof_iterator { - let last_path_segment = option_inner_path.segments.last(); - let (used_ty, optional) = match &last_path_segment.map(|ps| &ps.arguments) { - Some(syn::PathArguments::AngleBracketed(args)) => { - match args.args.last().unwrap() { - syn::GenericArgument::Type(ty) => { - match &ty { - syn::Type::Path(path) => { - (UsedType::new(syn::Type::Path(path.clone()))?, saw_option_ident) - } - syn::Type::Macro(type_macro) => { - (UsedType::new(syn::Type::Macro(type_macro.clone()),)?, saw_option_ident) - }, - reference @ syn::Type::Reference(_) => { - (UsedType::new((*reference).clone(),)?, saw_option_ident) - }, - ty => return Err(syn::Error::new( - ty.span(), - "SetOf Iterator must have an item", - )), - } - } - other => { - return Err(syn::Error::new( - other.span(), - &format!( - "Got unexpected generic argument for SetOfIterator: {other:?}" - ), - )) + if saw_setof_iterator { + let last_path_segment = option_inner_path.segments.last(); + let (used_ty, optional) = match &last_path_segment.map(|ps| &ps.arguments) { + Some(syn::PathArguments::AngleBracketed(args)) => { + match args.args.last().unwrap() { + syn::GenericArgument::Type(ty) => { + match &ty { + syn::Type::Path(path) => { + (UsedType::new(syn::Type::Path(path.clone()))?, saw_option_ident) } + syn::Type::Macro(type_macro) => { + (UsedType::new(syn::Type::Macro(type_macro.clone()), )?, saw_option_ident) + }, + reference @ syn::Type::Reference(_) => { + (UsedType::new((*reference).clone(), )?, saw_option_ident) + }, + ty => return Err(syn::Error::new( + ty.span(), + "SetOf Iterator must have an item", + )), } } other => { return Err(syn::Error::new( - other.map(|s| s.span()).unwrap_or_else(proc_macro2::Span::call_site), + other.span(), &format!( - "Got unexpected path argument for SetOfIterator: {other:?}" + "Got unexpected generic argument for SetOfIterator: {other:?}" ), )) } - }; - Ok(Returning::SetOf { - ty: used_ty, - optional, - result: saw_result_ident, - }) - } else if saw_table_iterator { - let last_path_segment = segments.last_mut().unwrap(); - let mut iterated_items = vec![]; + } + } + other => { + return Err(syn::Error::new( + other + .map(|s| s.span()) + .unwrap_or_else(proc_macro2::Span::call_site), + &format!( + "Got unexpected path argument for SetOfIterator: {other:?}" + ), + )) + } + }; + Ok(Returning::SetOf { ty: used_ty, optional, result: saw_result_ident }) + } else if saw_table_iterator { + let last_path_segment = segments.last_mut().unwrap(); + let mut iterated_items = vec![]; - match &mut last_path_segment.arguments { - syn::PathArguments::AngleBracketed(args) => { - match args.args.last_mut().unwrap() { - syn::GenericArgument::Type(syn::Type::Tuple( - type_tuple, - )) => { - for elem in &type_tuple.elems { - match &elem { - syn::Type::Path(path) => { + match &mut last_path_segment.arguments { + syn::PathArguments::AngleBracketed(args) => { + match args.args.last_mut().unwrap() { + syn::GenericArgument::Type(syn::Type::Tuple(type_tuple)) => { + for elem in &type_tuple.elems { + match &elem { + syn::Type::Path(path) => { + let iterated_item = ReturningIteratedItem { + name: None, + used_ty: UsedType::new(syn::Type::Path( + path.clone(), + ))?, + }; + iterated_items.push(iterated_item); + } + syn::Type::Macro(type_macro) => { + let mac = &type_macro.mac; + let archetype = + mac.path.segments.last().unwrap(); + match archetype.ident.to_string().as_str() { + "name" => { + let out: NameMacro = + mac.parse_body()?; let iterated_item = ReturningIteratedItem { - name: None, - used_ty: UsedType::new( - syn::Type::Path( - path.clone(), - ), - )?, + name: Some(out.ident), + used_ty: out.used_ty, }; - iterated_items.push(iterated_item); + iterated_items.push(iterated_item) } - syn::Type::Macro(type_macro) => { - let mac = &type_macro.mac; - let archetype = - mac.path.segments.last().unwrap(); - match archetype - .ident - .to_string() - .as_str() - { - "name" => { - let out: NameMacro = - mac.parse_body()?; - let iterated_item = - ReturningIteratedItem { - name: Some(out.ident), - used_ty: out.used_ty, - }; - iterated_items - .push(iterated_item) - } - _ => { - let iterated_item = - ReturningIteratedItem { - name: None, - used_ty: UsedType::new( - syn::Type::Macro( - type_macro - .clone(), - ), - )?, - }; - iterated_items - .push(iterated_item); - } - } - } - reference @ syn::Type::Reference(_) => { + _ => { let iterated_item = ReturningIteratedItem { name: None, used_ty: UsedType::new( - (*reference).clone(), + syn::Type::Macro( + type_macro.clone(), + ), )?, }; iterated_items.push(iterated_item); } - ty => { - return Err(syn::Error::new( - ty.span(), - "Table Iterator must have an item", - )); - } + } + } + reference @ syn::Type::Reference(_) => { + let iterated_item = ReturningIteratedItem { + name: None, + used_ty: UsedType::new( + (*reference).clone(), + )?, }; + iterated_items.push(iterated_item); } - } - syn::GenericArgument::Lifetime(_) => (), - other => { - return Err(syn::Error::new( - other.span(), - &format!( - "Got unexpected generic argument: {other:?}" - ), - )) - } - }; + ty => { + return Err(syn::Error::new( + ty.span(), + "Table Iterator must have an item", + )); + } + }; + } } + syn::GenericArgument::Lifetime(_) => (), other => { return Err(syn::Error::new( other.span(), - &format!("Got unexpected path argument: {other:?}"), + &format!("Got unexpected generic argument: {other:?}"), )) } }; - Ok(Returning::Iterated { - tys: iterated_items, - optional: saw_option_ident, - result: saw_result_ident, - }) - } else { - let used_ty = UsedType::new(syn::Type::Path(typepath.clone()))?; - Ok(Returning::Type(used_ty)) } - } else { - let used_ty = UsedType::new(syn::Type::Path(typepath.clone()))?; - Ok(Returning::Type(used_ty)) - } - } - syn::Type::Reference(ty_ref) => { - let used_ty = UsedType::new(syn::Type::Reference(ty_ref.clone()))?; + other => { + return Err(syn::Error::new( + other.span(), + &format!("Got unexpected path argument: {other:?}"), + )) + } + }; + Ok(Returning::Iterated { + tys: iterated_items, + optional: saw_option_ident, + result: saw_result_ident, + }) + } else { + let used_ty = UsedType::new(syn::Type::Path(typepath.clone()))?; Ok(Returning::Type(used_ty)) } - syn::Type::Macro(ref mut type_macro) => Self::parse_type_macro(type_macro), - syn::Type::Paren(ref mut type_paren) => match &mut *type_paren.elem { - syn::Type::Macro(ref mut type_macro) => Self::parse_type_macro(type_macro), - other => { - return Err(syn::Error::new( - other.span(), - &format!("Got unknown return type: {type_paren:?}"), - )) - } - }, - other => { - return Err(syn::Error::new( - other.span(), - &format!("Got unknown return type: {other:?}"), - )) - } + } else { + let used_ty = UsedType::new(syn::Type::Path(typepath.clone()))?; + Ok(Returning::Type(used_ty)) + } + } + syn::Type::Reference(ty_ref) => { + let used_ty = UsedType::new(syn::Type::Reference(ty_ref.clone()))?; + Ok(Returning::Type(used_ty)) + } + syn::Type::Macro(ref mut type_macro) => Self::parse_type_macro(type_macro), + syn::Type::Paren(ref mut type_paren) => match &mut *type_paren.elem { + syn::Type::Macro(ref mut type_macro) => Self::parse_type_macro(type_macro), + other => { + return Err(syn::Error::new( + other.span(), + &format!("Got unknown return type (type_paren): {type_paren:?}"), + )) } + }, + syn::Type::Group(tg) => return Self::match_type(&tg.elem), + other => { + return Err(syn::Error::new( + other.span(), + &format!("Got unknown return type (other): {other:?}"), + )) } } } } +impl TryFrom<&syn::ReturnType> for Returning { + type Error = syn::Error; + + fn try_from(value: &syn::ReturnType) -> Result { + match &value { + syn::ReturnType::Default => Ok(Returning::None), + syn::ReturnType::Type(_, ty) => Self::match_type(ty), + } + } +} + impl ToTokens for Returning { fn to_tokens(&self, tokens: &mut TokenStream2) { let quoted = match self { diff --git a/pgrx-tests/src/framework.rs b/pgrx-tests/src/framework.rs index 82b16c3e6..69b6d9f72 100644 --- a/pgrx-tests/src/framework.rs +++ b/pgrx-tests/src/framework.rs @@ -19,7 +19,6 @@ use pgrx_pg_config::{ }; use postgres::error::DbError; use std::collections::HashMap; -use std::fmt::Write as _; use std::io::{BufRead, BufReader, Write}; use std::path::PathBuf; use std::sync::{Arc, Mutex}; @@ -66,42 +65,46 @@ where match result { Ok(result) => Ok(result), Err(e) => { - let dberror = e.as_db_error().unwrap(); - let query = query.unwrap(); - let query_message = dberror.message(); - - let code = dberror.code().code(); - let severity = dberror.severity(); - - let mut message = format!("{} SQLSTATE[{}]", severity, code).bold().red().to_string(); - - message.push_str(format!(": {}", query_message.bold().white()).as_str()); - message.push_str(format!("\nquery: {}", query.bold().white()).as_str()); - message.push_str( - format!( - "\nparams: {}", - match query_params { - Some(params) => format!("{:?}", params), - None => "None".to_string(), + if let Some(dberror) = e.as_db_error() { + let query = query.unwrap(); + let query_message = dberror.message(); + + let code = dberror.code().code(); + let severity = dberror.severity(); + + let mut message = + format!("{} SQLSTATE[{}]", severity, code).bold().red().to_string(); + + message.push_str(format!(": {}", query_message.bold().white()).as_str()); + message.push_str(format!("\nquery: {}", query.bold().white()).as_str()); + message.push_str( + format!( + "\nparams: {}", + match query_params { + Some(params) => format!("{:?}", params), + None => "None".to_string(), + } + ) + .as_str(), + ); + + if let Ok(var) = std::env::var("RUST_BACKTRACE") { + if var.eq("1") { + let detail = dberror.detail().unwrap_or("None"); + let hint = dberror.hint().unwrap_or("None"); + let schema = dberror.hint().unwrap_or("None"); + let table = dberror.table().unwrap_or("None"); + let more_info = format!( + "\ndetail: {detail}\nhint: {hint}\nschema: {schema}\ntable: {table}" + ); + message.push_str(more_info.as_str()); } - ) - .as_str(), - ); - - if let Ok(var) = std::env::var("RUST_BACKTRACE") { - if var.eq("1") { - let detail = dberror.detail().unwrap_or("None"); - let hint = dberror.hint().unwrap_or("None"); - let schema = dberror.hint().unwrap_or("None"); - let table = dberror.table().unwrap_or("None"); - let more_info = format!( - "\ndetail: {detail}\nhint: {hint}\nschema: {schema}\ntable: {table}" - ); - message.push_str(more_info.as_str()); } - } - Err(eyre!(message)) + Err(eyre!(message)) + } else { + return Err(e).wrap_err("non-DbError"); + } } } } @@ -115,71 +118,59 @@ pub fn run_test( let (mut client, session_id) = client()?; - let schema = "tests"; // get_extension_schema(); - let result = match client.transaction() { - // run the test function in a transaction - Ok(mut tx) => { - let result = tx.simple_query(&format!("SELECT \"{schema}\".\"{sql_funcname}\"();")); - - if result.is_ok() { - // and abort the transaction when complete - tx.rollback().expect("test rollback didn't work"); - } + let result = client.transaction().map(|mut tx| { + let schema = "tests"; // get_extension_schema(); + let result = tx.simple_query(&format!("SELECT \"{schema}\".\"{sql_funcname}\"();")); - result + if result.is_ok() { + // and abort the transaction when complete + tx.rollback()?; } - Err(e) => panic!("attempt to run test tx failed:\n{e}"), + result + }); + + // flatten the above result + let result = match result { + Err(e) => Err(e), + Ok(Err(e)) => Err(e), + Ok(_) => Ok(()), }; if let Err(e) = result { - let error_as_string = format!("error in test tx: {e}"); - + let error_as_string = format!("{e}"); let cause = e.into_source(); - if let Some(e) = cause { - if let Some(dberror) = e.downcast_ref::() { - // we got an ERROR - let received_error_message: &str = dberror.message(); - - if let Some(expected_error_message) = expected_error { - // and we expected an error, so assert what we got is what we expect - assert_eq!(received_error_message, expected_error_message); - Ok(()) - } else { - // we weren't expecting an error - // wait a second for Postgres to get log messages written to stderr - std::thread::sleep(std::time::Duration::from_millis(1000)); - - let mut pg_location = String::from("Postgres location: "); - pg_location.push_str(match dberror.file() { - Some(file) => file, - None => "", - }); - if let Some(ln) = dberror.line() { - let _ = write!(pg_location, ":{ln}"); - }; - - let mut rust_location = String::from("Rust location: "); - rust_location.push_str(match dberror.where_() { - Some(place) => place, - None => "", - }); - // then we can panic with those messages plus those that belong to the system - panic!( - "\n{sys}...\n{sess}\n{e}\n{pg}\n{rs}\n\n", - sys = format_loglines(&system_session_id, &loglines), - sess = format_loglines(&session_id, &loglines), - e = received_error_message.bold().red(), - pg = pg_location.dimmed().white(), - rs = rust_location.yellow() - ); + + let (pg_location, rust_location, message) = + if let Some(Some(dberror)) = cause.map(|e| e.downcast_ref::().cloned()) { + let received_error_message = dberror.message(); + + if Some(received_error_message) == expected_error { + // the error received is the one we expected, so just return if they match + return Ok(()); } + + let pg_location = dberror.file().unwrap_or("").to_string(); + let rust_location = dberror.where_().unwrap_or("").to_string(); + + (pg_location, rust_location, received_error_message.to_string()) } else { - panic!("Failed downcast to DbError:\n{e}") - } - } else { - panic!("Error without deeper source cause:\n{e}\n", e = error_as_string.bold().red()) - } + ("".to_string(), "".to_string(), format!("{error_as_string}")) + }; + + // wait a second for Postgres to get log messages written to stderr + std::thread::sleep(std::time::Duration::from_millis(1000)); + + let system_loglines = format_loglines(&system_session_id, &loglines); + let session_loglines = format_loglines(&session_id, &loglines); + panic!( + "\n\nPostgres Messages:\n{system_loglines}\n\nTest Function Messages:\n{session_loglines}\n\nClient Error:\n{message}\npostgres location: {pg_location}\nrust location: {rust_location}\n\n", + system_loglines = system_loglines.dimmed().white(), + session_loglines = session_loglines.cyan(), + message = message.bold().red(), + pg_location = pg_location.dimmed().white(), + rust_location = rust_location.yellow() + ); } else if let Some(message) = expected_error { // we expected an ERROR, but didn't get one return Err(eyre!("Expected error: {message}")); @@ -252,7 +243,7 @@ pub fn client() -> eyre::Result<(postgres::Client, String)> { .user(&get_pg_user()) .dbname(&get_pg_dbname()) .connect(postgres::NoTls) - .unwrap(); + .wrap_err("Error connecting to Postgres")?; let sid_query_result = query_wrapper( Some("SELECT to_hex(trunc(EXTRACT(EPOCH FROM backend_start))::integer) || '.' || to_hex(pid) AS sid FROM pg_stat_activity WHERE pid = pg_backend_pid();".to_string()),