diff --git a/src/errors.rs b/src/errors.rs index c076a3f9..014a1340 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -12,6 +12,7 @@ pub enum Error { ProtocolSyncError(String), BadQuery(String), ServerError, + ServerMessageParserError(String), ServerStartupError(String, ServerIdentifier), ServerAuthError(String, ServerIdentifier), BadConfig, diff --git a/src/messages.rs b/src/messages.rs index 8ebc00a3..1f40f1df 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -11,10 +11,13 @@ use crate::client::PREPARED_STATEMENT_COUNTER; use crate::config::get_config; use crate::errors::Error; +use crate::constants::MESSAGE_TERMINATOR; use std::collections::HashMap; use std::ffi::CString; +use std::fmt::{Display, Formatter}; use std::io::{BufRead, Cursor}; use std::mem; +use std::str::FromStr; use std::sync::atomic::Ordering; use std::time::Duration; @@ -1098,3 +1101,298 @@ pub fn prepared_statement_name() -> String { PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst) ) } + +// from https://www.postgresql.org/docs/12/protocol-error-fields.html +#[derive(Debug, Default, PartialEq)] +pub struct PgErrorMsg { + pub severity_localized: String, // S + pub severity: String, // V + pub code: String, // C + pub message: String, // M + pub detail: Option, // D + pub hint: Option, // H + pub position: Option, // P + pub internal_position: Option, // p + pub internal_query: Option, // q + pub where_context: Option, // W + pub schema_name: Option, // s + pub table_name: Option, // t + pub column_name: Option, // c + pub data_type_name: Option, // d + pub constraint_name: Option, // n + pub file_name: Option, // F + pub line: Option, // L + pub routine: Option, // R +} + +// TODO: implement with https://docs.rs/derive_more/latest/derive_more/ +impl Display for PgErrorMsg { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "[severity: {}]", self.severity)?; + write!(f, "[code: {}]", self.code)?; + write!(f, "[message: {}]", self.message)?; + if let Some(val) = &self.detail { + write!(f, "[detail: {val}]")?; + } + if let Some(val) = &self.hint { + write!(f, "[hint: {val}]")?; + } + if let Some(val) = &self.position { + write!(f, "[position: {val}]")?; + } + if let Some(val) = &self.internal_position { + write!(f, "[internal_position: {val}]")?; + } + if let Some(val) = &self.internal_query { + write!(f, "[internal_query: {val}]")?; + } + if let Some(val) = &self.internal_query { + write!(f, "[internal_query: {val}]")?; + } + if let Some(val) = &self.where_context { + write!(f, "[where: {val}]")?; + } + if let Some(val) = &self.schema_name { + write!(f, "[schema_name: {val}]")?; + } + if let Some(val) = &self.table_name { + write!(f, "[table_name: {val}]")?; + } + if let Some(val) = &self.column_name { + write!(f, "[column_name: {val}]")?; + } + if let Some(val) = &self.data_type_name { + write!(f, "[data_type_name: {val}]")?; + } + if let Some(val) = &self.constraint_name { + write!(f, "[constraint_name: {val}]")?; + } + if let Some(val) = &self.file_name { + write!(f, "[file_name: {val}]")?; + } + if let Some(val) = &self.line { + write!(f, "[line: {val}]")?; + } + if let Some(val) = &self.routine { + write!(f, "[routine: {val}]")?; + } + + write!(f, " ")?; + + Ok(()) + } +} + +impl PgErrorMsg { + pub fn parse(error_msg: Vec) -> Result { + let mut out = PgErrorMsg { + severity_localized: "".to_string(), + severity: "".to_string(), + code: "".to_string(), + message: "".to_string(), + detail: None, + hint: None, + position: None, + internal_position: None, + internal_query: None, + where_context: None, + schema_name: None, + table_name: None, + column_name: None, + data_type_name: None, + constraint_name: None, + file_name: None, + line: None, + routine: None, + }; + for msg_part in error_msg.split(|v| *v == MESSAGE_TERMINATOR) { + if msg_part.is_empty() { + continue; + } + + let msg_content = match String::from_utf8_lossy(&msg_part[1..]).parse() { + Ok(c) => c, + Err(err) => { + return Err(Error::ServerMessageParserError(format!( + "could not parse server message field. err {:?}", + err + ))) + } + }; + + match &msg_part[0] { + b'S' => { + out.severity_localized = msg_content; + } + b'V' => { + out.severity = msg_content; + } + b'C' => { + out.code = msg_content; + } + b'M' => { + out.message = msg_content; + } + b'D' => { + out.detail = Some(msg_content); + } + b'H' => { + out.hint = Some(msg_content); + } + b'P' => out.position = Some(u32::from_str(msg_content.as_str()).unwrap_or(0)), + b'p' => { + out.internal_position = Some(u32::from_str(msg_content.as_str()).unwrap_or(0)) + } + b'q' => { + out.internal_query = Some(msg_content); + } + b'W' => { + out.where_context = Some(msg_content); + } + b's' => { + out.schema_name = Some(msg_content); + } + b't' => { + out.table_name = Some(msg_content); + } + b'c' => { + out.column_name = Some(msg_content); + } + b'd' => { + out.data_type_name = Some(msg_content); + } + b'n' => { + out.constraint_name = Some(msg_content); + } + b'F' => { + out.file_name = Some(msg_content); + } + b'L' => out.line = Some(u32::from_str(msg_content.as_str()).unwrap_or(0)), + b'R' => { + out.routine = Some(msg_content); + } + _ => {} + } + } + + Ok(out) + } +} + +#[cfg(test)] +mod tests { + use crate::messages::PgErrorMsg; + use log::{error, info}; + + fn field(kind: char, content: &str) -> Vec { + format!("{kind}{content}\0").as_bytes().to_vec() + } + + #[test] + fn parse_fields() { + let mut complete_msg = vec![]; + let severity = "FATAL"; + complete_msg.extend(field('S', &severity)); + complete_msg.extend(field('V', &severity)); + + let error_code = "29P02"; + complete_msg.extend(field('C', &error_code)); + let message = "password authentication failed for user \"wrong_user\""; + complete_msg.extend(field('M', &message)); + let detail_msg = "super detailed message"; + complete_msg.extend(field('D', &detail_msg)); + let hint_msg = "hint detail here"; + complete_msg.extend(field('H', &hint_msg)); + complete_msg.extend(field('P', "123")); + complete_msg.extend(field('p', "234")); + let internal_query = "SELECT * from foo;"; + complete_msg.extend(field('q', &internal_query)); + let where_msg = "where goes here"; + complete_msg.extend(field('W', &where_msg)); + let schema_msg = "schema_name"; + complete_msg.extend(field('s', &schema_msg)); + let table_msg = "table_name"; + complete_msg.extend(field('t', &table_msg)); + let column_msg = "column_name"; + complete_msg.extend(field('c', &column_msg)); + let data_type_msg = "type_name"; + complete_msg.extend(field('d', &data_type_msg)); + let constraint_msg = "constraint_name"; + complete_msg.extend(field('n', &constraint_msg)); + let file_msg = "pgcat.c"; + complete_msg.extend(field('F', &file_msg)); + complete_msg.extend(field('L', "335")); + let routine_msg = "my_failing_routine"; + complete_msg.extend(field('R', &routine_msg)); + + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .with_ansi(true) + .init(); + + info!( + "full message: {}", + PgErrorMsg::parse(complete_msg.clone()).unwrap() + ); + assert_eq!( + PgErrorMsg { + severity_localized: severity.to_string(), + severity: severity.to_string(), + code: error_code.to_string(), + message: message.to_string(), + detail: Some(detail_msg.to_string()), + hint: Some(hint_msg.to_string()), + position: Some(123), + internal_position: Some(234), + internal_query: Some(internal_query.to_string()), + where_context: Some(where_msg.to_string()), + schema_name: Some(schema_msg.to_string()), + table_name: Some(table_msg.to_string()), + column_name: Some(column_msg.to_string()), + data_type_name: Some(data_type_msg.to_string()), + constraint_name: Some(constraint_msg.to_string()), + file_name: Some(file_msg.to_string()), + line: Some(335), + routine: Some(routine_msg.to_string()), + }, + PgErrorMsg::parse(complete_msg).unwrap() + ); + + let mut only_mandatory_msg = vec![]; + only_mandatory_msg.extend(field('S', &severity)); + only_mandatory_msg.extend(field('V', &severity)); + only_mandatory_msg.extend(field('C', &error_code)); + only_mandatory_msg.extend(field('M', &message)); + only_mandatory_msg.extend(field('D', &detail_msg)); + + let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap(); + info!("only mandatory fields: {}", &err_fields); + error!( + "server error: {}: {}", + err_fields.severity, err_fields.message + ); + assert_eq!( + PgErrorMsg { + severity_localized: severity.to_string(), + severity: severity.to_string(), + code: error_code.to_string(), + message: message.to_string(), + detail: Some(detail_msg.to_string()), + hint: None, + position: None, + internal_position: None, + internal_query: None, + where_context: None, + schema_name: None, + table_name: None, + column_name: None, + data_type_name: None, + constraint_name: None, + file_name: None, + line: None, + routine: None, + }, + PgErrorMsg::parse(only_mandatory_msg).unwrap() + ); + } +} diff --git a/src/server.rs b/src/server.rs index afa1c09d..9d0beaac 100644 --- a/src/server.rs +++ b/src/server.rs @@ -588,8 +588,7 @@ impl Server { // An error message will be present. _ => { - // Read the error message without the terminating null character. - let mut error = vec![0u8; len as usize - 4 - 1]; + let mut error = vec![0u8; len as usize]; match stream.read_exact(&mut error).await { Ok(_) => (), @@ -601,10 +600,14 @@ impl Server { } }; - // TODO: the error message contains multiple fields; we can decode them and - // present a prettier message to the user. - // See: https://www.postgresql.org/docs/12/protocol-error-fields.html - error!("Server error: {}", String::from_utf8_lossy(&error)); + let fields = match PgErrorMsg::parse(error) { + Ok(f) => f, + Err(err) => { + return Err(err); + } + }; + trace!("error fields: {}", &fields); + error!("server error: {}: {}", fields.severity, fields.message); } };