Skip to content

Commit

Permalink
feat: add ErrorHandler for error processing (#222)
Browse files Browse the repository at this point in the history
* feat: add ErrorHandler for error processing

* test: fix integration test
  • Loading branch information
sunng87 authored Dec 6, 2024
1 parent 4064832 commit b60d23b
Show file tree
Hide file tree
Showing 12 changed files with 92 additions and 11 deletions.
6 changes: 6 additions & 0 deletions examples/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::Arc;
use async_trait::async_trait;
use futures::stream;
use futures::StreamExt;
use pgwire::api::NoopErrorHandler;
use pgwire::api::PgWireHandlerFactory;
use tokio::net::TcpListener;

Expand Down Expand Up @@ -78,6 +79,7 @@ impl PgWireHandlerFactory for DummyProcessorFactory {
type SimpleQueryHandler = DummyProcessor;
type ExtendedQueryHandler = PlaceholderExtendedQueryHandler;
type CopyHandler = NoopCopyHandler;
type ErrorHandler = NoopErrorHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
self.handler.clone()
Expand All @@ -94,6 +96,10 @@ impl PgWireHandlerFactory for DummyProcessorFactory {
fn copy_handler(&self) -> Arc<Self::CopyHandler> {
Arc::new(NoopCopyHandler)
}

fn error_handler(&self) -> Arc<Self::ErrorHandler> {
Arc::new(NoopErrorHandler)
}
}

#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
Expand Down
7 changes: 6 additions & 1 deletion examples/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::copy::CopyHandler;
use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{CopyResponse, Response};
use pgwire::api::{ClientInfo, PgWireConnectionState, PgWireHandlerFactory};
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireConnectionState, PgWireHandlerFactory};
use pgwire::error::ErrorInfo;
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::copy::{CopyData, CopyDone, CopyFail};
Expand Down Expand Up @@ -111,6 +111,7 @@ impl PgWireHandlerFactory for DummyProcessorFactory {
type SimpleQueryHandler = DummyProcessor;
type ExtendedQueryHandler = PlaceholderExtendedQueryHandler;
type CopyHandler = DummyProcessor;
type ErrorHandler = NoopErrorHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
self.handler.clone()
Expand All @@ -127,6 +128,10 @@ impl PgWireHandlerFactory for DummyProcessorFactory {
fn copy_handler(&self) -> Arc<Self::CopyHandler> {
self.handler.clone()
}

fn error_handler(&self) -> Arc<Self::ErrorHandler> {
Arc::new(NoopErrorHandler)
}
}

#[tokio::main]
Expand Down
7 changes: 6 additions & 1 deletion examples/duckdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use pgwire::api::results::{
Response, Tag,
};
use pgwire::api::stmt::{NoopQueryParser, StoredStatement};
use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type};
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::data::DataRow;
use pgwire::tokio::process_socket;
Expand Down Expand Up @@ -334,6 +334,7 @@ impl PgWireHandlerFactory for DuckDBBackendFactory {
type SimpleQueryHandler = DuckDBBackend;
type ExtendedQueryHandler = DuckDBBackend;
type CopyHandler = NoopCopyHandler;
type ErrorHandler = NoopErrorHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
self.handler.clone()
Expand All @@ -353,6 +354,10 @@ impl PgWireHandlerFactory for DuckDBBackendFactory {
fn copy_handler(&self) -> Arc<Self::CopyHandler> {
Arc::new(NoopCopyHandler)
}

fn error_handler(&self) -> Arc<Self::ErrorHandler> {
Arc::new(NoopErrorHandler)
}
}

#[tokio::main]
Expand Down
7 changes: 6 additions & 1 deletion examples/gluesql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::copy::NoopCopyHandler;
use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type};
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory, Type};
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::tokio::process_socket;

Expand Down Expand Up @@ -170,6 +170,7 @@ impl PgWireHandlerFactory for GluesqlHandlerFactory {
type SimpleQueryHandler = GluesqlProcessor;
type ExtendedQueryHandler = PlaceholderExtendedQueryHandler;
type CopyHandler = NoopCopyHandler;
type ErrorHandler = NoopErrorHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
self.processor.clone()
Expand All @@ -186,6 +187,10 @@ impl PgWireHandlerFactory for GluesqlHandlerFactory {
fn copy_handler(&self) -> Arc<Self::CopyHandler> {
Arc::new(NoopCopyHandler)
}

fn error_handler(&self) -> Arc<Self::ErrorHandler> {
Arc::new(NoopErrorHandler)
}
}

#[tokio::main]
Expand Down
7 changes: 6 additions & 1 deletion examples/scram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use pgwire::api::copy::NoopCopyHandler;
use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{Response, Tag};

use pgwire::api::{ClientInfo, PgWireHandlerFactory};
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory};
use pgwire::error::PgWireResult;
use pgwire::tokio::process_socket;

Expand Down Expand Up @@ -83,6 +83,7 @@ impl PgWireHandlerFactory for DummyProcessorFactory {
type SimpleQueryHandler = DummyProcessor;
type ExtendedQueryHandler = PlaceholderExtendedQueryHandler;
type CopyHandler = NoopCopyHandler;
type ErrorHandler = NoopErrorHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
self.handler.clone()
Expand All @@ -108,6 +109,10 @@ impl PgWireHandlerFactory for DummyProcessorFactory {
fn copy_handler(&self) -> Arc<Self::CopyHandler> {
Arc::new(NoopCopyHandler)
}

fn error_handler(&self) -> Arc<Self::ErrorHandler> {
Arc::new(NoopErrorHandler)
}
}

#[tokio::main]
Expand Down
7 changes: 6 additions & 1 deletion examples/secure_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::copy::NoopCopyHandler;
use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type};
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory, Type};
use pgwire::error::PgWireResult;
use pgwire::tokio::process_socket;

Expand Down Expand Up @@ -90,6 +90,7 @@ impl PgWireHandlerFactory for DummyProcessorFactory {
type SimpleQueryHandler = DummyProcessor;
type ExtendedQueryHandler = PlaceholderExtendedQueryHandler;
type CopyHandler = NoopCopyHandler;
type ErrorHandler = NoopErrorHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
self.handler.clone()
Expand All @@ -106,6 +107,10 @@ impl PgWireHandlerFactory for DummyProcessorFactory {
fn copy_handler(&self) -> Arc<Self::CopyHandler> {
Arc::new(NoopCopyHandler)
}

fn error_handler(&self) -> Arc<Self::ErrorHandler> {
Arc::new(NoopErrorHandler)
}
}

#[tokio::main]
Expand Down
7 changes: 6 additions & 1 deletion examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::copy::NoopCopyHandler;
use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type};
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory, Type};
use pgwire::error::ErrorInfo;
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::response::NoticeResponse;
Expand Down Expand Up @@ -80,6 +80,7 @@ impl PgWireHandlerFactory for DummyProcessorFactory {
type SimpleQueryHandler = DummyProcessor;
type ExtendedQueryHandler = PlaceholderExtendedQueryHandler;
type CopyHandler = NoopCopyHandler;
type ErrorHandler = NoopErrorHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
self.handler.clone()
Expand All @@ -96,6 +97,10 @@ impl PgWireHandlerFactory for DummyProcessorFactory {
fn copy_handler(&self) -> Arc<Self::CopyHandler> {
Arc::new(NoopCopyHandler)
}

fn error_handler(&self) -> Arc<Self::ErrorHandler> {
Arc::new(NoopErrorHandler)
}
}

#[tokio::main]
Expand Down
6 changes: 6 additions & 0 deletions examples/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use pgwire::api::results::{
Response, Tag,
};
use pgwire::api::stmt::{NoopQueryParser, StoredStatement};
use pgwire::api::NoopErrorHandler;
use pgwire::api::PgWireHandlerFactory;
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
Expand Down Expand Up @@ -291,6 +292,7 @@ impl PgWireHandlerFactory for SqliteBackendFactory {
type SimpleQueryHandler = SqliteBackend;
type ExtendedQueryHandler = SqliteBackend;
type CopyHandler = NoopCopyHandler;
type ErrorHandler = NoopErrorHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
self.handler.clone()
Expand All @@ -313,6 +315,10 @@ impl PgWireHandlerFactory for SqliteBackendFactory {
fn copy_handler(&self) -> Arc<Self::CopyHandler> {
Arc::new(NoopCopyHandler)
}

fn error_handler(&self) -> Arc<Self::ErrorHandler> {
Arc::new(NoopErrorHandler)
}
}

#[tokio::main]
Expand Down
7 changes: 6 additions & 1 deletion examples/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::copy::NoopCopyHandler;
use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type};
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireHandlerFactory, Type};
use pgwire::error::ErrorInfo;
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::response::NoticeResponse;
Expand Down Expand Up @@ -96,6 +96,7 @@ impl PgWireHandlerFactory for DummyProcessorFactory {
type SimpleQueryHandler = DummyProcessor;
type ExtendedQueryHandler = PlaceholderExtendedQueryHandler;
type CopyHandler = NoopCopyHandler;
type ErrorHandler = NoopErrorHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
self.handler.clone()
Expand All @@ -112,6 +113,10 @@ impl PgWireHandlerFactory for DummyProcessorFactory {
fn copy_handler(&self) -> Arc<Self::CopyHandler> {
Arc::new(NoopCopyHandler)
}

fn error_handler(&self) -> Arc<Self::ErrorHandler> {
Arc::new(NoopErrorHandler)
}
}

#[tokio::main]
Expand Down
22 changes: 22 additions & 0 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::sync::Arc;

pub use postgres_types::Type;

use crate::error::PgWireError;
use crate::messages::response::TransactionStatus;

pub mod auth;
Expand Down Expand Up @@ -126,11 +127,25 @@ impl<S> ClientPortalStore for DefaultClient<S> {
}
}

pub trait ErrorHandler: Send + Sync {
fn on_error<C>(&self, _client: &C, _error: &mut PgWireError)
where
C: ClientInfo,
{
}
}

/// A noop implementation for `ErrorHandler`.
pub struct NoopErrorHandler;

impl ErrorHandler for NoopErrorHandler {}

pub trait PgWireHandlerFactory {
type StartupHandler: auth::StartupHandler;
type SimpleQueryHandler: query::SimpleQueryHandler;
type ExtendedQueryHandler: query::ExtendedQueryHandler;
type CopyHandler: copy::CopyHandler;
type ErrorHandler: ErrorHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler>;

Expand All @@ -139,6 +154,8 @@ pub trait PgWireHandlerFactory {
fn startup_handler(&self) -> Arc<Self::StartupHandler>;

fn copy_handler(&self) -> Arc<Self::CopyHandler>;

fn error_handler(&self) -> Arc<Self::ErrorHandler>;
}

impl<T> PgWireHandlerFactory for Arc<T>
Expand All @@ -149,6 +166,7 @@ where
type SimpleQueryHandler = T::SimpleQueryHandler;
type ExtendedQueryHandler = T::ExtendedQueryHandler;
type CopyHandler = T::CopyHandler;
type ErrorHandler = T::ErrorHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
(**self).simple_query_handler()
Expand All @@ -165,4 +183,8 @@ where
fn copy_handler(&self) -> Arc<Self::CopyHandler> {
(**self).copy_handler()
}

fn error_handler(&self) -> Arc<Self::ErrorHandler> {
(**self).error_handler()
}
}
13 changes: 10 additions & 3 deletions src/tokio/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use crate::api::copy::CopyHandler;
use crate::api::query::SimpleQueryHandler;
use crate::api::query::{send_ready_for_query, ExtendedQueryHandler};
use crate::api::{
ClientInfo, ClientPortalStore, DefaultClient, PgWireConnectionState, PgWireHandlerFactory,
ClientInfo, ClientPortalStore, DefaultClient, ErrorHandler, PgWireConnectionState,
PgWireHandlerFactory,
};
use crate::error::{ErrorInfo, PgWireError, PgWireResult};
use crate::messages::response::ReadyForQuery;
Expand Down Expand Up @@ -318,26 +319,28 @@ async fn peek_for_sslrequest<ST>(
}
}

async fn do_process_socket<S, A, Q, EQ, C>(
async fn do_process_socket<S, A, Q, EQ, C, E>(
socket: &mut Framed<S, PgWireMessageServerCodec<EQ::Statement>>,
startup_handler: Arc<A>,
simple_query_handler: Arc<Q>,
extended_query_handler: Arc<EQ>,
copy_handler: Arc<C>,
error_handler: Arc<E>,
) -> Result<(), IOError>
where
S: AsyncRead + AsyncWrite + Unpin + Send + Sync,
A: StartupHandler,
Q: SimpleQueryHandler,
EQ: ExtendedQueryHandler,
C: CopyHandler,
E: ErrorHandler,
{
while let Some(Ok(msg)) = socket.next().await {
let is_extended_query = match socket.state() {
PgWireConnectionState::CopyInProgress(is_extended_query) => is_extended_query,
_ => msg.is_extended_query(),
};
if let Err(e) = process_message(
if let Err(mut e) = process_message(
msg,
socket,
startup_handler.clone(),
Expand All @@ -347,6 +350,7 @@ where
)
.await
{
error_handler.on_error(socket, &mut e);
process_error(socket, e, is_extended_query).await?;
}
}
Expand Down Expand Up @@ -398,6 +402,7 @@ where
let simple_query_handler = handlers.simple_query_handler();
let extended_query_handler = handlers.extended_query_handler();
let copy_handler = handlers.copy_handler();
let error_handler = handlers.error_handler();

if ssl == SslNegotiationType::None {
// use an already configured socket.
Expand All @@ -409,6 +414,7 @@ where
simple_query_handler,
extended_query_handler,
copy_handler,
error_handler,
)
.await
} else {
Expand All @@ -435,6 +441,7 @@ where
simple_query_handler,
extended_query_handler,
copy_handler,
error_handler,
)
.await
}
Expand Down
Loading

0 comments on commit b60d23b

Please sign in to comment.