diff --git a/quaint/src/connector/metrics.rs b/quaint/src/connector/metrics.rs index 37143866a67d..9f33de723978 100644 --- a/quaint/src/connector/metrics.rs +++ b/quaint/src/connector/metrics.rs @@ -106,20 +106,19 @@ struct QueryForTracing<'a>(&'a str); impl fmt::Display for QueryForTracing<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let query = self - .0 - .split_once("/* traceparent=") - .map_or(self.0, |(str, remainder)| { - if remainder - .split_once("*/") - .is_some_and(|(_, suffix)| suffix.trim_end().is_empty()) - { - str - } else { - self.0 - } - }) - .trim(); - write!(f, "{query}") + write!(f, "{}", strip_query_traceparent(self.0)) } } + +pub(super) fn strip_query_traceparent(query: &str) -> &str { + query.rsplit_once("/* traceparent=").map_or(query, |(str, remainder)| { + if remainder + .split_once("*/") + .is_some_and(|(_, suffix)| suffix.trim_end().is_empty()) + { + str.trim_end() + } else { + query + } + }) +} diff --git a/quaint/src/connector/postgres/native/cache.rs b/quaint/src/connector/postgres/native/cache.rs new file mode 100644 index 000000000000..447326608fa7 --- /dev/null +++ b/quaint/src/connector/postgres/native/cache.rs @@ -0,0 +1,428 @@ +use std::{ + hash::{BuildHasher, Hash, Hasher, RandomState}, + sync::Arc, +}; + +use async_trait::async_trait; +use lru_cache::LruCache; +use postgres_types::Type; +use tokio::sync::Mutex; +use tokio_postgres::{Client, Error, Statement}; + +use crate::connector::metrics::strip_query_traceparent; + +use super::query::{PreparedQuery, QueryMetadata, TypedQuery}; + +/// Types that can be used as a cache for prepared queries and statements. +#[async_trait] +pub trait QueryCache: From + Send + Sync { + /// The type that is returned when a prepared query is requested from the cache. + type Query<'a>: PreparedQuery; + + /// Retrieve a prepared query. + async fn get_query<'a>(&self, client: &Client, sql: &'a str, types: &[Type]) -> Result, Error>; + + /// Retrieve a prepared statement. + /// + /// This is useful in scenarios that require direct access to a prepared statement, + /// e.g. describing a query. + async fn get_statement(&self, client: &Client, sql: &str, types: &[Type]) -> Result; +} + +/// A no-op cache that creates a new prepared statement for every requested query. +/// Useful when we don't need caching. +#[derive(Debug, Default)] +pub struct NoOpCache; + +#[async_trait] +impl QueryCache for NoOpCache { + type Query<'a> = Statement; + + #[inline] + async fn get_query<'a>(&self, client: &Client, sql: &'a str, types: &[Type]) -> Result { + self.get_statement(client, sql, types).await + } + + #[inline] + async fn get_statement(&self, client: &Client, sql: &str, types: &[Type]) -> Result { + client.prepare_typed(sql, types).await + } +} + +impl From for NoOpCache { + fn from(_: CacheSettings) -> Self { + Self + } +} + +/// An LRU cache that creates a prepared statement for every query that is not in the cache. +#[derive(Debug)] +pub struct PreparedStatementLruCache { + cache: InnerLruCache, +} + +impl PreparedStatementLruCache { + pub fn with_capacity(capacity: usize) -> Self { + Self { + cache: InnerLruCache::with_capacity(capacity), + } + } +} + +#[async_trait] +impl QueryCache for PreparedStatementLruCache { + type Query<'a> = Statement; + + #[inline] + async fn get_query<'a>(&self, client: &Client, sql: &'a str, types: &[Type]) -> Result { + self.get_statement(client, sql, types).await + } + + async fn get_statement(&self, client: &Client, sql: &str, types: &[Type]) -> Result { + match self.cache.get(sql, types).await { + Some(statement) => Ok(statement), + None => { + let stmt = client.prepare_typed(sql, types).await?; + self.cache.insert(sql, types, stmt.clone()).await; + Ok(stmt) + } + } + } +} + +impl From for PreparedStatementLruCache { + fn from(settings: CacheSettings) -> Self { + Self::with_capacity(settings.capacity) + } +} + +/// An LRU cache that creates and stores query type information rather than prepared statements. +/// Queries are identified by their content with tracing information removed (which makes it +/// possible to cache traced queries at all) and returned as instances of [`TypedQuery`]. The +/// caching behavior is implemented in [`get_query`](Self::get_query), while statements returned +/// from [`get_statement`](Self::get_statement) are always freshly prepared, because statements +/// cannot be re-used when tracing information is present. +#[derive(Debug)] +pub struct TracingLruCache { + cache: InnerLruCache>, +} + +impl TracingLruCache { + pub fn with_capacity(capacity: usize) -> Self { + Self { + cache: InnerLruCache::with_capacity(capacity), + } + } +} + +#[async_trait] +impl QueryCache for TracingLruCache { + type Query<'a> = TypedQuery<'a>; + + async fn get_query<'a>(&self, client: &Client, sql: &'a str, types: &[Type]) -> Result, Error> { + let sql_without_traceparent = strip_query_traceparent(sql); + + let metadata = match self.cache.get(sql_without_traceparent, types).await { + Some(metadata) => metadata, + None => { + let stmt = client.prepare_typed(sql_without_traceparent, types).await?; + let metadata = Arc::new(QueryMetadata::from(&stmt)); + self.cache + .insert(sql_without_traceparent, types, metadata.clone()) + .await; + metadata + } + }; + Ok(TypedQuery::from_sql_and_metadata(sql, metadata)) + } + + async fn get_statement(&self, client: &Client, sql: &str, types: &[Type]) -> Result { + client.prepare_typed(sql, types).await + } +} + +impl From for TracingLruCache { + fn from(settings: CacheSettings) -> Self { + Self::with_capacity(settings.capacity) + } +} + +/// Settings related to query caching. +#[derive(Debug)] +pub struct CacheSettings { + pub capacity: usize, +} + +/// Key uniquely representing an SQL statement in the prepared statements cache. +#[derive(Debug, PartialEq, Eq, Hash)] +struct QueryKey(u64); + +impl QueryKey { + fn new(st: &S, sql: &str, params: &[Type]) -> Self { + Self(st.hash_one((sql, params))) + } +} + +#[derive(Debug)] +struct InnerLruCache { + cache: Mutex>, + state: RandomState, +} + +impl InnerLruCache { + fn with_capacity(capacity: usize) -> Self { + Self { + cache: Mutex::new(LruCache::with_hasher(capacity, NoOpHasherBuilder)), + state: RandomState::new(), + } + } + + async fn get(&self, sql: &str, types: &[Type]) -> Option + where + V: Clone, + { + let mut cache = self.cache.lock().await; + let capacity = cache.capacity(); + let stored = cache.len(); + + let key = QueryKey::new(&self.state, sql, types); + // we call `get_mut` because LRU requires mutable access for lookups + match cache.get_mut(&key) { + Some(value) => { + tracing::trace!( + message = "query cache hit", + query = sql, + capacity = capacity, + stored = stored, + ); + Some(value.clone()) + } + None => { + tracing::trace!( + message = "query cache miss", + query = sql, + capacity = capacity, + stored = stored, + ); + None + } + } + } + + pub async fn insert(&self, sql: &str, types: &[Type], value: V) { + let key = QueryKey::new(&self.state, sql, types); + self.cache.lock().await.insert(key, value); + } +} + +struct NoOpHasherBuilder; + +impl BuildHasher for NoOpHasherBuilder { + type Hasher = NoOpHasher; + + fn build_hasher(&self) -> Self::Hasher { + NoOpHasher(None) + } +} + +/// A hasher that expects to be called with a single u64 and returns it as the hash. +struct NoOpHasher(Option); + +impl Hasher for NoOpHasher { + fn finish(&self) -> u64 { + self.0.expect("NoopHasher should have been called with a single u64") + } + + fn write(&mut self, _bytes: &[u8]) { + panic!("NoopHasher should only be called with u64") + } + + fn write_u64(&mut self, i: u64) { + assert!(self.0.is_none(), "NoopHasher should only be called once"); + self.0 = Some(i); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::future::Future; + + pub(crate) use crate::connector::postgres::url::PostgresNativeUrl; + use crate::{ + connector::{MakeTlsConnectorManager, PostgresFlavour}, + tests::test_api::postgres::CONN_STR, + }; + use url::Url; + + #[tokio::test] + async fn noop_cache_returns_new_queries_every_time() { + run_with_client(|client| async move { + let cache = NoOpCache; + let sql = "SELECT $1"; + let types = [Type::INT4]; + + let stmt1 = cache.get_query(&client, sql, &types).await.unwrap(); + let stmt2 = cache.get_query(&client, sql, &types).await.unwrap(); + assert_ne!(stmt1.name(), stmt2.name()); + }) + .await; + } + + #[tokio::test] + async fn noop_cache_returns_new_statements_every_time() { + run_with_client(|client| async move { + let cache = NoOpCache; + let sql = "SELECT $1"; + let types = [Type::INT4]; + + let stmt1 = cache.get_statement(&client, sql, &types).await.unwrap(); + let stmt2 = cache.get_statement(&client, sql, &types).await.unwrap(); + assert_ne!(stmt1.name(), stmt2.name()); + }) + .await; + } + + #[tokio::test] + async fn prepared_statement_lru_cache_reuses_queries_within_capacity() { + run_with_client(|client| async move { + let cache = PreparedStatementLruCache::with_capacity(3); + let sql = "SELECT $1"; + let types = [Type::INT4]; + + let stmt1 = cache.get_query(&client, sql, &types).await.unwrap(); + let stmt2 = cache.get_query(&client, sql, &types).await.unwrap(); + assert_eq!(stmt1.name(), stmt2.name()); + + // fill the cache with different types, causing the first query to be evicted + for typ in [Type::INT8, Type::INT4_ARRAY, Type::INT8_ARRAY] { + cache.get_query(&client, sql, &[typ]).await.unwrap(); + } + + // the old statement should be re-created + let stmt3 = cache.get_query(&client, sql, &types).await.unwrap(); + assert_ne!(stmt1.name(), stmt3.name()); + }) + .await; + } + + #[tokio::test] + async fn prepared_statement_lru_cache_reuses_statements_within_capacity() { + run_with_client(|client| async move { + let cache = PreparedStatementLruCache::with_capacity(3); + let sql = "SELECT $1"; + let types = [Type::INT4]; + + let stmt1 = cache.get_statement(&client, sql, &types).await.unwrap(); + let stmt2 = cache.get_statement(&client, sql, &types).await.unwrap(); + assert_eq!(stmt1.name(), stmt2.name()); + + // fill the cache with different types, causing the first query to be evicted + for typ in [Type::INT8, Type::INT4_ARRAY, Type::INT8_ARRAY] { + cache.get_query(&client, sql, &[typ]).await.unwrap(); + } + + // the old statement should be re-created + let stmt3 = cache.get_statement(&client, sql, &types).await.unwrap(); + assert_ne!(stmt1.name(), stmt3.name()); + }) + .await; + } + + #[tokio::test] + async fn tracing_lru_cache_reuses_queries_within_capacity() { + run_with_client(|client| async move { + let cache = TracingLruCache::with_capacity(3); + let sql = "SELECT $1"; + let types = [Type::INT4]; + + let q1 = cache.get_query(&client, sql, &types).await.unwrap(); + let q2 = cache.get_query(&client, sql, &types).await.unwrap(); + assert!( + Arc::ptr_eq(&q1.metadata, &q2.metadata), + "q1 and q2 should re-use the same metadata" + ); + + // fill the cache with different types, causing the first query to be evicted + for typ in [Type::INT8, Type::INT4_ARRAY, Type::INT8_ARRAY] { + cache.get_query(&client, sql, &[typ]).await.unwrap(); + } + + // the old query should be re-created + let q3 = cache.get_query(&client, sql, &types).await.unwrap(); + assert!( + !Arc::ptr_eq(&q1.metadata, &q3.metadata), + "q1 and q3 should not re-use the same metadata" + ); + }) + .await; + } + + #[tokio::test] + async fn tracing_lru_cache_reuses_queries_with_different_traceparent() { + run_with_client(|client| async move { + let cache = TracingLruCache::with_capacity(1); + let sql1 = "SELECT $1 /* traceparent=00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01 */"; + let sql2 = "SELECT $1 /* traceparent=00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-02 */"; + let types = [Type::INT4]; + + let q1 = cache.get_query(&client, sql1, &types).await.unwrap(); + assert_eq!(q1.sql, sql1); + let q2 = cache.get_query(&client, sql2, &types).await.unwrap(); + // the requested query traceparent should be preserved + assert_eq!(q2.sql, sql2); + + assert!( + Arc::ptr_eq(&q1.metadata, &q2.metadata), + "q1 and q2 should re-use the same metadata" + ); + }) + .await; + } + + #[tokio::test] + async fn tracing_lru_cache_returns_new_statements_every_time() { + run_with_client(|client| async move { + let cache = TracingLruCache::with_capacity(1); + let sql = "SELECT $1"; + let types = [Type::INT4]; + + let q1 = cache.get_statement(&client, sql, &types).await.unwrap(); + let q2 = cache.get_statement(&client, sql, &types).await.unwrap(); + assert_ne!(q1.name(), q2.name()); + }) + .await; + } + + #[test] + fn noop_hasher_returns_the_same_hash_the_input() { + assert_eq!(NoOpHasherBuilder.hash_one(0xdeadc0deu64), 0xdeadc0de); + assert_eq!(NoOpHasherBuilder.hash_one(0xcafeu64), 0xcafe); + } + + #[test] + #[should_panic(expected = "NoopHasher should only be called with u64")] + fn noop_hasher_doesnt_accept_non_u64_input() { + NoOpHasherBuilder.hash_one("hello"); + } + + async fn run_with_client(test: Func) + where + Func: FnOnce(Client) -> Fut, + Fut: Future, + { + let url = Url::parse(&CONN_STR).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + let tls = tls_manager.get_connector().await.unwrap(); + + let (client, conn) = pg_url.to_config().connect(tls).await.unwrap(); + + let set = tokio::task::LocalSet::new(); + set.spawn_local(conn); + set.run_until(test(client)).await + } +} diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index eb6618ce9dc7..2e8f62b424d3 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -1,19 +1,23 @@ //! Definitions for the Postgres connector. //! This module is not compatible with wasm32-* targets. //! This module is only available with the `postgresql-native` feature. +mod cache; pub(crate) mod column_type; mod conversion; mod error; mod explain; +mod query; mod websocket; pub(crate) use crate::connector::postgres::url::PostgresNativeUrl; use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams}; use crate::connector::{ timeout, ColumnType, DescribedColumn, DescribedParameter, DescribedQuery, IsolationLevel, Transaction, + TransactionOptions, }; use crate::error::NativeErrorKind; +use crate::prelude::DefaultTransaction; use crate::ValueType; use crate::{ ast::{Query, Value}, @@ -22,14 +26,15 @@ use crate::{ visitor::{self, Visitor}, }; use async_trait::async_trait; +use cache::{CacheSettings, NoOpCache, PreparedStatementLruCache, QueryCache, TracingLruCache}; use column_type::PGColumnType; -use futures::{future::FutureExt, lock::Mutex}; -use lru_cache::LruCache; +use futures::future::FutureExt; +use futures::StreamExt; use native_tls::{Certificate, Identity, TlsConnector}; use postgres_native_tls::MakeTlsConnector; use postgres_types::{Kind as PostgresKind, Type as PostgresType}; use prisma_metrics::WithMetricsInstrumentation; -use std::hash::{DefaultHasher, Hash, Hasher}; +use query::PreparedQuery; use std::{ fmt::{Debug, Display}, fs, @@ -49,7 +54,7 @@ pub use tokio_postgres; use super::PostgresWebSocketUrl; -struct PostgresClient(Client); +pub(super) struct PostgresClient(Client); impl Debug for PostgresClient { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -61,47 +66,33 @@ const DB_SYSTEM_NAME_POSTGRESQL: &str = "postgresql"; const DB_SYSTEM_NAME_COCKROACHDB: &str = "cockroachdb"; /// A connector interface for the PostgreSQL database. +/// +/// # Type parameters +/// - `Cache`: The cache used for prepared queries. #[derive(Debug)] -pub struct PostgreSql { +pub struct PostgreSql { client: PostgresClient, pg_bouncer: bool, socket_timeout: Option, - statement_cache: Mutex, + cache: Cache, is_healthy: AtomicBool, is_cockroachdb: bool, is_materialize: bool, db_system_name: &'static str, } -/// Key uniquely representing an SQL statement in the prepared statements cache. -#[derive(PartialEq, Eq, Hash)] -pub(crate) struct StatementKey { - /// Hash of a string with SQL query. - sql: u64, - /// Combined hash of types for all parameters from the query. - types_hash: u64, -} +/// A [`PostgreSql`] interface with the default caching strategy, which involves storing all +/// queries as prepared statements in an LRU cache. +pub type PostgreSqlWithDefaultCache = PostgreSql; -pub(crate) type StatementCache = LruCache; +/// A [`PostgreSql`] interface which executes all queries as prepared statements without caching +/// them. +pub type PostgreSqlWithNoCache = PostgreSql; -impl StatementKey { - fn new(sql: &str, params: &[Value<'_>]) -> Self { - Self { - sql: { - let mut hasher = DefaultHasher::new(); - sql.hash(&mut hasher); - hasher.finish() - }, - types_hash: { - let mut hasher = DefaultHasher::new(); - for param in params { - std::mem::discriminant(¶m.typed).hash(&mut hasher); - } - hasher.finish() - }, - } - } -} +/// A [`PostgreSql`] interface with the tracing caching strategy, which involves storing query +/// type information in a dedicated LRU cache for applicable queries and not re-using any prepared +/// statements. +pub type PostgreSqlWithTracingCache = PostgreSql; #[derive(Debug)] struct SslAuth { @@ -171,11 +162,13 @@ impl SslParams { } impl PostgresNativeUrl { - pub(crate) fn cache(&self) -> StatementCache { + pub(crate) fn cache_settings(&self) -> CacheSettings { if self.query_params.pg_bouncer { - StatementCache::new(0) + CacheSettings { capacity: 0 } } else { - StatementCache::new(self.query_params.statement_cache_size) + CacheSettings { + capacity: self.query_params.statement_cache_size, + } } } @@ -236,7 +229,25 @@ impl PostgresNativeUrl { } } -impl PostgreSql { +impl PostgreSqlWithNoCache { + /// Create a new websocket connection to managed database + pub async fn new_with_websocket(url: PostgresWebSocketUrl) -> crate::Result { + let client = connect_via_websocket(url).await?; + + Ok(Self { + client: PostgresClient(client), + socket_timeout: None, + pg_bouncer: false, + cache: NoOpCache, + is_healthy: AtomicBool::new(true), + is_cockroachdb: false, + is_materialize: false, + db_system_name: DB_SYSTEM_NAME_POSTGRESQL, + }) + } +} + +impl PostgreSql { /// Create a new connection to the database. pub async fn new(url: PostgresNativeUrl, tls_manager: &MakeTlsConnectorManager) -> crate::Result { let config = url.to_config(); @@ -287,7 +298,7 @@ impl PostgreSql { client: PostgresClient(client), socket_timeout: url.query_params.socket_timeout, pg_bouncer: url.query_params.pg_bouncer, - statement_cache: Mutex::new(url.cache()), + cache: url.cache_settings().into(), is_healthy: AtomicBool::new(true), is_cockroachdb, is_materialize, @@ -295,22 +306,6 @@ impl PostgreSql { }) } - /// Create a new websocket connection to managed database - pub async fn new_with_websocket(url: PostgresWebSocketUrl) -> crate::Result { - let client = connect_via_websocket(url).await?; - - Ok(Self { - client: PostgresClient(client), - socket_timeout: None, - pg_bouncer: false, - statement_cache: Mutex::new(StatementCache::new(0)), - is_healthy: AtomicBool::new(true), - is_cockroachdb: false, - is_materialize: false, - db_system_name: DB_SYSTEM_NAME_POSTGRESQL, - }) - } - /// The underlying tokio_postgres::Client. Only available with the /// `expose-drivers` Cargo feature. This is a lower level API when you need /// to get into database specific features. @@ -319,41 +314,6 @@ impl PostgreSql { &self.client.0 } - async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - let statement_key = StatementKey::new(sql, params); - let mut cache = self.statement_cache.lock().await; - let capacity = cache.capacity(); - let stored = cache.len(); - - match cache.get_mut(&statement_key) { - Some(stmt) => { - tracing::trace!( - message = "CACHE HIT!", - query = sql, - capacity = capacity, - stored = stored, - ); - - Ok(stmt.clone()) // arc'd - } - None => { - tracing::trace!( - message = "CACHE MISS!", - query = sql, - capacity = capacity, - stored = stored, - ); - - let param_types = conversion::params_to_types(params); - let stmt = self.perform_io(self.client.0.prepare_typed(sql, ¶m_types)).await?; - - cache.insert(statement_key, stmt.clone()); - - Ok(stmt) - } - } - } - async fn perform_io(&self, fut: F) -> crate::Result where F: Future>, @@ -487,6 +447,53 @@ impl PostgreSql { Ok(nullables) } + + async fn query_raw_impl( + &self, + sql: &str, + params: &[Value<'_>], + types: &[PostgresType], + ) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query( + "postgres.query_raw", + self.db_system_name, + sql, + params, + move || async move { + let query = self.cache.get_query(&self.client.0, sql, types).await?; + + if query.param_types().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: query.param_types().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let mut rows = Box::pin( + self.perform_io(query.dispatch(&self.client.0, conversion::conv_params(params))) + .await?, + ); + + let types = query + .column_types() + .map(PGColumnType::from_pg_type) + .map(ColumnType::from) + .collect::>(); + let names = query.column_names().map(|name| name.to_string()).collect::>(); + let mut result = ResultSet::new(names, types, Vec::new()); + + while let Some(row) = rows.next().await { + result.rows.push(row?.get_result_row()?); + } + Ok(result) + }, + ) + .await + } } // A SearchPath connection parameter (Display-impl) for connection initialization. @@ -526,10 +533,22 @@ impl Display for SetSearchPath<'_> { } } -impl_default_TransactionCapable!(PostgreSql); +#[async_trait] +impl TransactionCapable for PostgreSql { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> crate::Result> { + let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + + Ok(Box::new( + DefaultTransaction::new(self, self.begin_statement(), opts).await?, + )) + } +} #[async_trait] -impl Queryable for PostgreSql { +impl Queryable for PostgreSql { async fn query(&self, q: Query<'_>) -> crate::Result { let (sql, params) = visitor::Postgres::build(q)?; @@ -537,91 +556,16 @@ impl Queryable for PostgreSql { } async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query( - "postgres.query_raw", - self.db_system_name, - sql, - params, - move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - let col_types = stmt - .columns() - .iter() - .map(|c| PGColumnType::from_pg_type(c.type_())) - .map(ColumnType::from) - .collect::>(); - let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); - - for row in rows { - result.rows.push(row.get_result_row()?); - } - - Ok(result) - }, - ) - .await + self.query_raw_impl(sql, params, &[]).await } async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query( - "postgres.query_raw", - self.db_system_name, - sql, - params, - move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let col_types = stmt - .columns() - .iter() - .map(|c| PGColumnType::from_pg_type(c.type_())) - .map(ColumnType::from) - .collect::>(); - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); - - for row in rows { - result.rows.push(row.get_result_row()?); - } - - Ok(result) - }, - ) - .await + self.query_raw_impl(sql, params, &conversion::params_to_types(params)) + .await } async fn describe_query(&self, sql: &str) -> crate::Result { - let stmt = self.fetch_cached(sql, &[]).await?; + let stmt = self.cache.get_statement(&self.client.0, sql, &[]).await?; let mut columns: Vec = Vec::with_capacity(stmt.columns().len()); let mut parameters: Vec = Vec::with_capacity(stmt.params().len()); @@ -710,7 +654,7 @@ impl Queryable for PostgreSql { sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; + let stmt = self.cache.get_statement(&self.client.0, sql, &[]).await?; if stmt.params().len() != params.len() { let kind = ErrorKind::IncorrectNumberOfParameters { @@ -740,7 +684,8 @@ impl Queryable for PostgreSql { sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; + let types = conversion::params_to_types(params); + let stmt = self.cache.get_statement(&self.client.0, sql, &types).await?; if stmt.params().len() != params.len() { let kind = ErrorKind::IncorrectNumberOfParameters { @@ -1012,7 +957,7 @@ mod tests { let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); - let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); + let client = PostgreSqlWithDefaultCache::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -1066,7 +1011,7 @@ mod tests { let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); - let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); + let client = PostgreSqlWithDefaultCache::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -1119,7 +1064,7 @@ mod tests { let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); - let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); + let client = PostgreSqlWithDefaultCache::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -1172,7 +1117,7 @@ mod tests { let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); - let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); + let client = PostgreSqlWithDefaultCache::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -1225,7 +1170,7 @@ mod tests { let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); - let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); + let client = PostgreSqlWithDefaultCache::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); diff --git a/quaint/src/connector/postgres/native/query.rs b/quaint/src/connector/postgres/native/query.rs new file mode 100644 index 000000000000..3efbae04873d --- /dev/null +++ b/quaint/src/connector/postgres/native/query.rs @@ -0,0 +1,181 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use postgres_types::{BorrowToSql, Type}; +use tokio_postgres::{Client, Error, RowStream, Statement}; + +/// Types that can be dispatched to the database as a query and carry the necessary type +/// information about its parameters and columns to interpret the results. +#[async_trait] +pub trait PreparedQuery: Send { + fn param_types(&self) -> impl ExactSizeIterator; + fn column_names(&self) -> impl ExactSizeIterator; + fn column_types(&self) -> impl ExactSizeIterator; + + async fn dispatch(&self, client: &Client, args: Args) -> Result + where + Args: IntoIterator + Send, + Args::Item: BorrowToSql, + Args::IntoIter: ExactSizeIterator + Send; +} + +#[async_trait] +impl PreparedQuery for Statement { + fn param_types(&self) -> impl ExactSizeIterator { + self.params().iter() + } + + fn column_names(&self) -> impl ExactSizeIterator { + self.columns().iter().map(|c| c.name()) + } + + fn column_types(&self) -> impl ExactSizeIterator { + self.columns().iter().map(|c| c.type_()) + } + + async fn dispatch(&self, client: &Client, args: Args) -> Result + where + Args: IntoIterator + Send, + Args::Item: BorrowToSql, + Args::IntoIter: ExactSizeIterator + Send, + { + client.query_raw(self, args).await + } +} + +/// A query combined with the relevant type information about its parameters and columns. +#[derive(Debug)] +pub struct TypedQuery<'a> { + pub(super) sql: &'a str, + pub(super) metadata: Arc, +} + +impl<'a> TypedQuery<'a> { + /// Create a new typed query from an SQL string and metadata. + pub fn from_sql_and_metadata(sql: &'a str, metadata: impl Into>) -> Self { + Self { + sql, + metadata: metadata.into(), + } + } +} + +#[async_trait] +impl<'a> PreparedQuery for TypedQuery<'a> { + fn param_types(&self) -> impl ExactSizeIterator { + self.metadata.param_types.iter() + } + + fn column_names(&self) -> impl ExactSizeIterator { + self.metadata.column_names.iter().map(|s| s.as_str()) + } + + fn column_types(&self) -> impl ExactSizeIterator { + self.metadata.column_types.iter() + } + + async fn dispatch(&self, client: &Client, args: Args) -> Result + where + Args: IntoIterator + Send, + Args::Item: BorrowToSql, + Args::IntoIter: ExactSizeIterator + Send, + { + let typed_args = args.into_iter().zip(self.metadata.param_types.iter().cloned()); + client.query_typed_raw(self.sql, typed_args).await + } +} + +#[derive(Debug)] +pub struct QueryMetadata { + param_types: Vec, + column_names: Vec, + column_types: Vec, +} + +impl From<&Statement> for QueryMetadata { + fn from(statement: &Statement) -> Self { + Self { + param_types: statement.params().to_vec(), + column_names: statement.columns().iter().map(|c| c.name().to_owned()).collect(), + column_types: statement.columns().iter().map(|c| c.type_().clone()).collect(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::future::Future; + + pub(crate) use crate::connector::postgres::url::PostgresNativeUrl; + use crate::{ + connector::{MakeTlsConnectorManager, PostgresFlavour}, + tests::test_api::postgres::CONN_STR, + }; + use url::Url; + + #[tokio::test] + async fn typed_query_matches_statement_and_dispatches() { + run_with_client(|client| async move { + let query = "SELECT $1"; + let stmt = client.prepare_typed(query, &[Type::INT4]).await.unwrap(); + let typed = TypedQuery::from_sql_and_metadata(query, QueryMetadata::from(&stmt)); + + assert_eq!(typed.param_types().cloned().collect::>(), stmt.params()); + assert_eq!( + typed.column_names().collect::>(), + stmt.columns().iter().map(|c| c.name()).collect::>() + ); + assert_eq!( + typed.column_types().collect::>(), + stmt.columns().iter().map(|c| c.type_()).collect::>() + ); + + let result = typed.dispatch(&client, &[&1i32]).await; + assert!(result.is_ok(), "{:?}", result.err()); + }) + .await; + } + + #[tokio::test] + async fn statement_trait_methods_match_statement_and_dispatch() { + run_with_client(|client| async move { + let query = "SELECT $1"; + let stmt = client.prepare_typed(query, &[Type::INT4]).await.unwrap(); + + assert_eq!(stmt.param_types().cloned().collect::>(), stmt.params()); + assert_eq!( + stmt.column_names().collect::>(), + stmt.columns().iter().map(|c| c.name()).collect::>() + ); + assert_eq!( + stmt.column_types().collect::>(), + stmt.columns().iter().map(|c| c.type_()).collect::>() + ); + + let result = stmt.dispatch(&client, &[&1i32]).await; + assert!(result.is_ok(), "{:?}", result.err()); + }) + .await; + } + + async fn run_with_client(test: Func) + where + Func: FnOnce(Client) -> Fut, + Fut: Future, + { + let url = Url::parse(&CONN_STR).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + let tls = tls_manager.get_connector().await.unwrap(); + + let (client, conn) = pg_url.to_config().connect(tls).await.unwrap(); + + let set = tokio::task::LocalSet::new(); + set.spawn_local(conn); + set.run_until(test(client)).await + } +} diff --git a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs index 096484cdc87a..c35c123b4395 100644 --- a/quaint/src/connector/postgres/url.rs +++ b/quaint/src/connector/postgres/url.rs @@ -559,13 +559,13 @@ mod tests { let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()) .unwrap(); - assert_eq!(420, url.cache().capacity()); + assert_eq!(420, url.cache_settings().capacity); } #[test] fn should_have_default_cache_size() { let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); + assert_eq!(100, url.cache_settings().capacity); } #[test] @@ -598,7 +598,7 @@ mod tests { fn should_not_enable_caching_with_pgbouncer() { let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); - assert_eq!(0, url.cache().capacity()); + assert_eq!(0, url.cache_settings().capacity); } #[test] diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 389005ab7bd3..3da2c6998371 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -356,10 +356,12 @@ impl Builder { impl Quaint { /// Creates a new builder for a Quaint connection pool with the given - /// connection string. See the [module level documentation] for details. - /// - /// [module level documentation]: index.html - pub fn builder(url_str: &str) -> crate::Result { + /// connection string and a tracing flag. + /// See the [module level documentation] for details. + pub fn builder_with_tracing( + url_str: &str, + #[allow(unused_variables)] is_tracing_enabled: bool, + ) -> crate::Result { match url_str { #[cfg(feature = "sqlite")] s if s.starts_with("file") => { @@ -424,7 +426,11 @@ impl Quaint { let max_idle_connection_lifetime = url.max_idle_connection_lifetime(); let tls_manager = crate::connector::MakeTlsConnectorManager::new(url.clone()); - let manager = QuaintManager::Postgres { url, tls_manager }; + let manager = QuaintManager::Postgres { + url, + tls_manager, + is_tracing_enabled, + }; let mut builder = Builder::new(s, manager)?; if let Some(limit) = connection_limit { @@ -478,6 +484,14 @@ impl Quaint { } } + /// Creates a new builder for a Quaint connection pool with the given + /// connection string. See the [module level documentation] for details. + /// + /// [module level documentation]: index.html + pub fn builder(url_str: &str) -> crate::Result { + Self::builder_with_tracing(url_str, false) + } + /// The number of connections in the pool. pub async fn capacity(&self) -> u32 { self.inner.state().await.max_open as u32 diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index 8a9715640579..4b13ca26bd24 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -101,6 +101,7 @@ pub enum QuaintManager { Postgres { url: PostgresNativeUrl, tls_manager: MakeTlsConnectorManager, + is_tracing_enabled: bool, }, #[cfg(feature = "sqlite")] @@ -133,9 +134,23 @@ impl Manager for QuaintManager { } #[cfg(feature = "postgresql-native")] - QuaintManager::Postgres { url, tls_manager } => { - use crate::connector::PostgreSql; - Ok(Box::new(PostgreSql::new(url.clone(), tls_manager).await?) as Self::Connection) + QuaintManager::Postgres { + url, + tls_manager, + is_tracing_enabled: false, + } => { + use crate::connector::PostgreSqlWithDefaultCache; + Ok(Box::new(PostgreSqlWithDefaultCache::new(url.clone(), tls_manager).await?) as Self::Connection) + } + + #[cfg(feature = "postgresql-native")] + QuaintManager::Postgres { + url, + tls_manager, + is_tracing_enabled: true, + } => { + use crate::connector::PostgreSqlWithTracingCache; + Ok(Box::new(PostgreSqlWithTracingCache::new(url.clone(), tls_manager).await?) as Self::Connection) } #[cfg(feature = "mssql-native")] diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 004b84e0da54..d4d48eff73dc 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -158,7 +158,7 @@ impl Quaint { s if s.starts_with("postgres") || s.starts_with("postgresql") => { let url = connector::PostgresNativeUrl::new(url::Url::parse(s)?)?; let tls_manager = connector::MakeTlsConnectorManager::new(url.clone()); - let psql = connector::PostgreSql::new(url, &tls_manager).await?; + let psql = connector::PostgreSqlWithDefaultCache::new(url, &tls_manager).await?; Arc::new(psql) as Arc } #[cfg(feature = "mssql-native")] diff --git a/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs index 9e59e4f232c7..171181297bfa 100644 --- a/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs @@ -6,6 +6,7 @@ use connector_interface::{ Connection, Connector, }; use psl::builtin_connectors::COCKROACH; +use psl::PreviewFeature; use quaint::{connector::PostgresFlavour, pooled::Quaint, prelude::ConnectionInfo}; use std::time::Duration; @@ -40,7 +41,7 @@ impl FromSource for PostgreSql { }) })?; - let mut builder = Quaint::builder(url) + let mut builder = Quaint::builder_with_tracing(url, features.contains(PreviewFeature::Tracing)) .map_err(SqlError::from) .map_err(|sql_error| sql_error.into_connector_error(&err_conn_info))?; diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs index 3a8f9fb6517a..59325574d794 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs @@ -15,7 +15,7 @@ use crate::sql_renderer::IteratorJoin; use super::MigratePostgresUrl; -pub(super) struct Connection(connector::PostgreSql); +pub(super) struct Connection(connector::PostgreSqlWithNoCache); impl Connection { pub(super) async fn new(url: url::Url) -> ConnectorResult { @@ -24,7 +24,7 @@ impl Connection { let quaint = match url.0 { PostgresUrl::Native(ref native_url) => { let tls_manager = MakeTlsConnectorManager::new(native_url.as_ref().clone()); - connector::PostgreSql::new(native_url.as_ref().clone(), &tls_manager).await + connector::PostgreSqlWithNoCache::new(native_url.as_ref().clone(), &tls_manager).await } PostgresUrl::WebSocket(ref ws_url) => connector::PostgreSql::new_with_websocket(ws_url.clone()).await, }