diff --git a/examples/bench.rs b/examples/bench.rs index da684bc..624fc5f 100644 --- a/examples/bench.rs +++ b/examples/bench.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use async_trait::async_trait; use futures::stream; use futures::StreamExt; +use pgwire::api::NoopErrorHandler; use pgwire::api::PgWireHandlerFactory; use tokio::net::TcpListener; @@ -78,6 +79,7 @@ impl PgWireHandlerFactory for DummyProcessorFactory { type SimpleQueryHandler = DummyProcessor; type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; type CopyHandler = NoopCopyHandler; + type ErrorHandler = NoopErrorHandler; fn simple_query_handler(&self) -> Arc { self.handler.clone() @@ -94,6 +96,10 @@ impl PgWireHandlerFactory for DummyProcessorFactory { fn copy_handler(&self) -> Arc { Arc::new(NoopCopyHandler) } + + fn error_handler(&self) -> Arc { + Arc::new(NoopErrorHandler) + } } #[tokio::main(flavor = "multi_thread", worker_threads = 10)] diff --git a/examples/copy.rs b/examples/copy.rs index d6bd461..aae182d 100644 --- a/examples/copy.rs +++ b/examples/copy.rs @@ -9,7 +9,7 @@ use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::copy::CopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{CopyResponse, Response}; -use pgwire::api::{ClientInfo, PgWireConnectionState, PgWireHandlerFactory}; +use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireConnectionState, PgWireHandlerFactory}; use pgwire::error::ErrorInfo; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::copy::{CopyData, CopyDone, CopyFail}; @@ -111,6 +111,7 @@ impl PgWireHandlerFactory for DummyProcessorFactory { type SimpleQueryHandler = DummyProcessor; type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; type CopyHandler = DummyProcessor; + type ErrorHandler = NoopErrorHandler; fn simple_query_handler(&self) -> Arc { self.handler.clone() @@ -127,6 +128,10 @@ impl PgWireHandlerFactory for DummyProcessorFactory { fn copy_handler(&self) -> Arc { self.handler.clone() } + + fn error_handler(&self) -> Arc { + Arc::new(NoopErrorHandler) + } } #[tokio::main] diff --git a/examples/duckdb.rs b/examples/duckdb.rs index 3c976fb..3cebcf7 100644 --- a/examples/duckdb.rs +++ b/examples/duckdb.rs @@ -16,7 +16,7 @@ use pgwire::api::results::{ Response, Tag, }; use pgwire::api::stmt::{NoopQueryParser, StoredStatement}; -use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type}; +use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::data::DataRow; use pgwire::tokio::process_socket; @@ -334,6 +334,7 @@ impl PgWireHandlerFactory for DuckDBBackendFactory { type SimpleQueryHandler = DuckDBBackend; type ExtendedQueryHandler = DuckDBBackend; type CopyHandler = NoopCopyHandler; + type ErrorHandler = NoopErrorHandler; fn simple_query_handler(&self) -> Arc { self.handler.clone() @@ -353,6 +354,10 @@ impl PgWireHandlerFactory for DuckDBBackendFactory { fn copy_handler(&self) -> Arc { Arc::new(NoopCopyHandler) } + + fn error_handler(&self) -> Arc { + Arc::new(NoopErrorHandler) + } } #[tokio::main] diff --git a/examples/gluesql.rs b/examples/gluesql.rs index d052712..0aa1ff1 100644 --- a/examples/gluesql.rs +++ b/examples/gluesql.rs @@ -9,7 +9,7 @@ use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::copy::NoopCopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; -use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type}; +use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory, Type}; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::tokio::process_socket; @@ -170,6 +170,7 @@ impl PgWireHandlerFactory for GluesqlHandlerFactory { type SimpleQueryHandler = GluesqlProcessor; type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; type CopyHandler = NoopCopyHandler; + type ErrorHandler = NoopErrorHandler; fn simple_query_handler(&self) -> Arc { self.processor.clone() @@ -186,6 +187,10 @@ impl PgWireHandlerFactory for GluesqlHandlerFactory { fn copy_handler(&self) -> Arc { Arc::new(NoopCopyHandler) } + + fn error_handler(&self) -> Arc { + Arc::new(NoopErrorHandler) + } } #[tokio::main] diff --git a/examples/scram.rs b/examples/scram.rs index c05d3a9..854a059 100644 --- a/examples/scram.rs +++ b/examples/scram.rs @@ -16,7 +16,7 @@ use pgwire::api::copy::NoopCopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{Response, Tag}; -use pgwire::api::{ClientInfo, PgWireHandlerFactory}; +use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory}; use pgwire::error::PgWireResult; use pgwire::tokio::process_socket; @@ -83,6 +83,7 @@ impl PgWireHandlerFactory for DummyProcessorFactory { type SimpleQueryHandler = DummyProcessor; type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; type CopyHandler = NoopCopyHandler; + type ErrorHandler = NoopErrorHandler; fn simple_query_handler(&self) -> Arc { self.handler.clone() @@ -108,6 +109,10 @@ impl PgWireHandlerFactory for DummyProcessorFactory { fn copy_handler(&self) -> Arc { Arc::new(NoopCopyHandler) } + + fn error_handler(&self) -> Arc { + Arc::new(NoopErrorHandler) + } } #[tokio::main] diff --git a/examples/secure_server.rs b/examples/secure_server.rs index 4f4956a..0630cc7 100644 --- a/examples/secure_server.rs +++ b/examples/secure_server.rs @@ -14,7 +14,7 @@ use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::copy::NoopCopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; -use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type}; +use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory, Type}; use pgwire::error::PgWireResult; use pgwire::tokio::process_socket; @@ -90,6 +90,7 @@ impl PgWireHandlerFactory for DummyProcessorFactory { type SimpleQueryHandler = DummyProcessor; type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; type CopyHandler = NoopCopyHandler; + type ErrorHandler = NoopErrorHandler; fn simple_query_handler(&self) -> Arc { self.handler.clone() @@ -106,6 +107,10 @@ impl PgWireHandlerFactory for DummyProcessorFactory { fn copy_handler(&self) -> Arc { Arc::new(NoopCopyHandler) } + + fn error_handler(&self) -> Arc { + Arc::new(NoopErrorHandler) + } } #[tokio::main] diff --git a/examples/server.rs b/examples/server.rs index f176fda..dd8ab11 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -9,7 +9,7 @@ use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::copy::NoopCopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; -use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type}; +use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory, Type}; use pgwire::error::ErrorInfo; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::response::NoticeResponse; @@ -80,6 +80,7 @@ impl PgWireHandlerFactory for DummyProcessorFactory { type SimpleQueryHandler = DummyProcessor; type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; type CopyHandler = NoopCopyHandler; + type ErrorHandler = NoopErrorHandler; fn simple_query_handler(&self) -> Arc { self.handler.clone() @@ -96,6 +97,10 @@ impl PgWireHandlerFactory for DummyProcessorFactory { fn copy_handler(&self) -> Arc { Arc::new(NoopCopyHandler) } + + fn error_handler(&self) -> Arc { + Arc::new(NoopErrorHandler) + } } #[tokio::main] diff --git a/examples/sqlite.rs b/examples/sqlite.rs index 308c2c3..6c53386 100644 --- a/examples/sqlite.rs +++ b/examples/sqlite.rs @@ -14,6 +14,7 @@ use pgwire::api::results::{ Response, Tag, }; use pgwire::api::stmt::{NoopQueryParser, StoredStatement}; +use pgwire::api::NoopErrorHandler; use pgwire::api::PgWireHandlerFactory; use pgwire::api::{ClientInfo, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; @@ -291,6 +292,7 @@ impl PgWireHandlerFactory for SqliteBackendFactory { type SimpleQueryHandler = SqliteBackend; type ExtendedQueryHandler = SqliteBackend; type CopyHandler = NoopCopyHandler; + type ErrorHandler = NoopErrorHandler; fn simple_query_handler(&self) -> Arc { self.handler.clone() @@ -313,6 +315,10 @@ impl PgWireHandlerFactory for SqliteBackendFactory { fn copy_handler(&self) -> Arc { Arc::new(NoopCopyHandler) } + + fn error_handler(&self) -> Arc { + Arc::new(NoopErrorHandler) + } } #[tokio::main] diff --git a/examples/transaction.rs b/examples/transaction.rs index e58f0f4..ce579f8 100644 --- a/examples/transaction.rs +++ b/examples/transaction.rs @@ -9,7 +9,7 @@ use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::copy::NoopCopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; -use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type}; +use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory, Type}; use pgwire::error::ErrorInfo; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::response::NoticeResponse; @@ -96,6 +96,7 @@ impl PgWireHandlerFactory for DummyProcessorFactory { type SimpleQueryHandler = DummyProcessor; type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; type CopyHandler = NoopCopyHandler; + type ErrorHandler = NoopErrorHandler; fn simple_query_handler(&self) -> Arc { self.handler.clone() @@ -112,6 +113,10 @@ impl PgWireHandlerFactory for DummyProcessorFactory { fn copy_handler(&self) -> Arc { Arc::new(NoopCopyHandler) } + + fn error_handler(&self) -> Arc { + Arc::new(NoopErrorHandler) + } } #[tokio::main] diff --git a/src/api/mod.rs b/src/api/mod.rs index 3adad68..5b9b983 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -6,6 +6,7 @@ use std::sync::Arc; pub use postgres_types::Type; +use crate::error::PgWireError; use crate::messages::response::TransactionStatus; pub mod auth; @@ -126,11 +127,25 @@ impl ClientPortalStore for DefaultClient { } } +pub trait ErrorHandler: Send + Sync { + fn on_error(&self, _client: &C, _error: &mut PgWireError) + where + C: ClientInfo, + { + } +} + +/// A noop implementation for `ErrorHandler`. +pub struct NoopErrorHandler; + +impl ErrorHandler for NoopErrorHandler {} + pub trait PgWireHandlerFactory { type StartupHandler: auth::StartupHandler; type SimpleQueryHandler: query::SimpleQueryHandler; type ExtendedQueryHandler: query::ExtendedQueryHandler; type CopyHandler: copy::CopyHandler; + type ErrorHandler: ErrorHandler; fn simple_query_handler(&self) -> Arc; @@ -139,6 +154,8 @@ pub trait PgWireHandlerFactory { fn startup_handler(&self) -> Arc; fn copy_handler(&self) -> Arc; + + fn error_handler(&self) -> Arc; } impl PgWireHandlerFactory for Arc @@ -149,6 +166,7 @@ where type SimpleQueryHandler = T::SimpleQueryHandler; type ExtendedQueryHandler = T::ExtendedQueryHandler; type CopyHandler = T::CopyHandler; + type ErrorHandler = T::ErrorHandler; fn simple_query_handler(&self) -> Arc { (**self).simple_query_handler() @@ -165,4 +183,8 @@ where fn copy_handler(&self) -> Arc { (**self).copy_handler() } + + fn error_handler(&self) -> Arc { + (**self).error_handler() + } } diff --git a/src/tokio/server.rs b/src/tokio/server.rs index 097110c..29d5fe6 100644 --- a/src/tokio/server.rs +++ b/src/tokio/server.rs @@ -14,7 +14,8 @@ use crate::api::copy::CopyHandler; use crate::api::query::SimpleQueryHandler; use crate::api::query::{send_ready_for_query, ExtendedQueryHandler}; use crate::api::{ - ClientInfo, ClientPortalStore, DefaultClient, PgWireConnectionState, PgWireHandlerFactory, + ClientInfo, ClientPortalStore, DefaultClient, ErrorHandler, PgWireConnectionState, + PgWireHandlerFactory, }; use crate::error::{ErrorInfo, PgWireError, PgWireResult}; use crate::messages::response::ReadyForQuery; @@ -318,12 +319,13 @@ async fn peek_for_sslrequest( } } -async fn do_process_socket( +async fn do_process_socket( socket: &mut Framed>, startup_handler: Arc, simple_query_handler: Arc, extended_query_handler: Arc, copy_handler: Arc, + error_handler: Arc, ) -> Result<(), IOError> where S: AsyncRead + AsyncWrite + Unpin + Send + Sync, @@ -331,13 +333,14 @@ where Q: SimpleQueryHandler, EQ: ExtendedQueryHandler, C: CopyHandler, + E: ErrorHandler, { while let Some(Ok(msg)) = socket.next().await { let is_extended_query = match socket.state() { PgWireConnectionState::CopyInProgress(is_extended_query) => is_extended_query, _ => msg.is_extended_query(), }; - if let Err(e) = process_message( + if let Err(mut e) = process_message( msg, socket, startup_handler.clone(), @@ -347,6 +350,7 @@ where ) .await { + error_handler.on_error(socket, &mut e); process_error(socket, e, is_extended_query).await?; } } @@ -398,6 +402,7 @@ where let simple_query_handler = handlers.simple_query_handler(); let extended_query_handler = handlers.extended_query_handler(); let copy_handler = handlers.copy_handler(); + let error_handler = handlers.error_handler(); if ssl == SslNegotiationType::None { // use an already configured socket. @@ -409,6 +414,7 @@ where simple_query_handler, extended_query_handler, copy_handler, + error_handler, ) .await } else { @@ -435,6 +441,7 @@ where simple_query_handler, extended_query_handler, copy_handler, + error_handler, ) .await } diff --git a/tests-integration/test-server/src/main.rs b/tests-integration/test-server/src/main.rs index f74caaf..df288ff 100644 --- a/tests-integration/test-server/src/main.rs +++ b/tests-integration/test-server/src/main.rs @@ -21,7 +21,7 @@ use pgwire::api::results::{ Response, Tag, }; use pgwire::api::stmt::{NoopQueryParser, StoredStatement}; -use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type}; +use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory, Type}; use pgwire::error::PgWireResult; use pgwire::tokio::process_socket; use tokio::net::TcpListener; @@ -223,6 +223,7 @@ impl PgWireHandlerFactory for DummyDatabaseFactory { type SimpleQueryHandler = DummyDatabase; type ExtendedQueryHandler = DummyDatabase; type CopyHandler = NoopCopyHandler; + type ErrorHandler = NoopErrorHandler; fn simple_query_handler(&self) -> Arc { self.0.clone() @@ -245,6 +246,10 @@ impl PgWireHandlerFactory for DummyDatabaseFactory { fn copy_handler(&self) -> Arc { Arc::new(NoopCopyHandler) } + + fn error_handler(&self) -> Arc { + Arc::new(NoopErrorHandler) + } } fn setup_tls() -> Result {