From 22dca475c2e7a33be6ce983820012f5c53cdb44f Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 3 Dec 2024 15:37:26 +0000 Subject: [PATCH 1/4] rename tokio-postgres -> postgres_client --- proxy/Cargo.toml | 2 +- proxy/src/auth/backend/classic.rs | 2 +- proxy/src/auth/backend/console_redirect.rs | 2 +- proxy/src/auth/backend/mod.rs | 2 +- proxy/src/auth/flow.rs | 2 +- proxy/src/cancellation.rs | 6 ++--- proxy/src/compute.rs | 26 +++++++++++----------- proxy/src/control_plane/client/mock.rs | 16 ++++++------- proxy/src/control_plane/client/neon.rs | 2 +- proxy/src/control_plane/errors.rs | 2 +- proxy/src/error.rs | 2 +- proxy/src/postgres_rustls/mod.rs | 6 ++--- proxy/src/proxy/retry.rs | 16 ++++++------- proxy/src/proxy/tests/mitm.rs | 22 +++++++++--------- proxy/src/proxy/tests/mod.rs | 20 ++++++++--------- proxy/src/serverless/backend.rs | 18 +++++++-------- proxy/src/serverless/conn_pool.rs | 6 ++--- proxy/src/serverless/conn_pool_lib.rs | 4 ++-- proxy/src/serverless/json.rs | 6 ++--- proxy/src/serverless/local_conn_pool.rs | 10 ++++----- proxy/src/serverless/sql_over_http.rs | 18 +++++++-------- 21 files changed, 95 insertions(+), 95 deletions(-) diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index f5934c8a89dd..60bc08b2d00f 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -55,6 +55,7 @@ parquet.workspace = true parquet_derive.workspace = true pin-project-lite.workspace = true postgres_backend.workspace = true +postgres-client = { package = "tokio-postgres2", path = "../libs/proxy/tokio-postgres2" } postgres-protocol = { package = "postgres-protocol2", path = "../libs/proxy/postgres-protocol2" } pq_proto.workspace = true prometheus.workspace = true @@ -81,7 +82,6 @@ subtle.workspace = true thiserror.workspace = true tikv-jemallocator.workspace = true tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] } -tokio-postgres = { package = "tokio-postgres2", path = "../libs/proxy/tokio-postgres2" } tokio-rustls.workspace = true tokio-util.workspace = true tokio = { workspace = true, features = ["signal"] } diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 491b272ac4e8..5e494dfdd694 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -66,7 +66,7 @@ pub(super) async fn authenticate( Ok(ComputeCredentials { info: creds, - keys: ComputeCredentialKeys::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256( + keys: ComputeCredentialKeys::AuthKeys(postgres_client::config::AuthKeys::ScramSha256( scram_keys, )), }) diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index bf7a1cb0705f..494564de05f0 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -1,8 +1,8 @@ use async_trait::async_trait; +use postgres_client::config::SslMode; use pq_proto::BeMessage as Be; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_postgres::config::SslMode; use tracing::{info, info_span}; use super::ComputeCredentialKeys; diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 7e1b26a11a0d..84a572dcf9f1 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -11,8 +11,8 @@ pub use console_redirect::ConsoleRedirectBackend; pub(crate) use console_redirect::ConsoleRedirectError; use ipnet::{Ipv4Net, Ipv6Net}; use local::LocalBackend; +use postgres_client::config::AuthKeys; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_postgres::config::AuthKeys; use tracing::{debug, info, warn}; use crate::auth::credentials::check_peer_addr_is_in_list; diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 9c6ce151cba9..60d1962d7f78 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -227,7 +227,7 @@ pub(crate) async fn validate_password_and_exchange( }; Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys( - tokio_postgres::config::AuthKeys::ScramSha256(keys), + postgres_client::config::AuthKeys::ScramSha256(keys), ))) } } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 91e198bf88a2..bcb0ef40bd74 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -3,11 +3,11 @@ use std::sync::Arc; use dashmap::DashMap; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; +use postgres_client::{CancelToken, NoTls}; use pq_proto::CancelKeyData; use thiserror::Error; use tokio::net::TcpStream; use tokio::sync::Mutex; -use tokio_postgres::{CancelToken, NoTls}; use tracing::{debug, info}; use uuid::Uuid; @@ -44,7 +44,7 @@ pub(crate) enum CancelError { IO(#[from] std::io::Error), #[error("{0}")] - Postgres(#[from] tokio_postgres::Error), + Postgres(#[from] postgres_client::Error), #[error("rate limit exceeded")] RateLimit, @@ -70,7 +70,7 @@ impl ReportableError for CancelError { impl CancellationHandler

{ /// Run async action within an ephemeral session identified by [`CancelKeyData`]. pub(crate) fn get_session(self: Arc) -> Session

{ - // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't + // HACK: We'd rather get the real backend_pid but postgres_client doesn't // expose it and we don't want to do another roundtrip to query // for it. The client will be able to notice that this is not the // actual backend_pid, but backend_pid is not used for anything diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index b689b97a2100..06bc71c55988 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -6,6 +6,8 @@ use std::time::Duration; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; use once_cell::sync::OnceCell; +use postgres_client::tls::MakeTlsConnect; +use postgres_client::{CancelToken, RawConnection}; use postgres_protocol::message::backend::NoticeResponseBody; use pq_proto::StartupMessageParams; use rustls::client::danger::ServerCertVerifier; @@ -13,8 +15,6 @@ use rustls::crypto::ring; use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::TcpStream; -use tokio_postgres::tls::MakeTlsConnect; -use tokio_postgres::{CancelToken, RawConnection}; use tracing::{debug, error, info, warn}; use crate::auth::parse_endpoint_param; @@ -34,9 +34,9 @@ pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; #[derive(Debug, Error)] pub(crate) enum ConnectionError { /// This error doesn't seem to reveal any secrets; for instance, - /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such. + /// `postgres_client::error::Kind` doesn't contain ip addresses and such. #[error("{COULD_NOT_CONNECT}: {0}")] - Postgres(#[from] tokio_postgres::Error), + Postgres(#[from] postgres_client::Error), #[error("{COULD_NOT_CONNECT}: {0}")] CouldNotConnect(#[from] io::Error), @@ -99,13 +99,13 @@ impl ReportableError for ConnectionError { } /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. -pub(crate) type ScramKeys = tokio_postgres::config::ScramKeys<32>; +pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>; /// A config for establishing a connection to compute node. -/// Eventually, `tokio_postgres` will be replaced with something better. +/// Eventually, `postgres_client` will be replaced with something better. /// Newtype allows us to implement methods on top of it. #[derive(Clone, Default)] -pub(crate) struct ConnCfg(Box); +pub(crate) struct ConnCfg(Box); /// Creation and initialization routines. impl ConnCfg { @@ -126,7 +126,7 @@ impl ConnCfg { pub(crate) fn get_host(&self) -> Result { match self.0.get_hosts() { - [tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()), + [postgres_client::config::Host::Tcp(s)] => Ok(s.into()), // we should not have multiple address or unix addresses. _ => Err(WakeComputeError::BadComputeAddress( "invalid compute address".into(), @@ -160,7 +160,7 @@ impl ConnCfg { // TODO: This is especially ugly... if let Some(replication) = params.get("replication") { - use tokio_postgres::config::ReplicationMode; + use postgres_client::config::ReplicationMode; match replication { "true" | "on" | "yes" | "1" => { self.replication_mode(ReplicationMode::Physical); @@ -182,7 +182,7 @@ impl ConnCfg { } impl std::ops::Deref for ConnCfg { - type Target = tokio_postgres::Config; + type Target = postgres_client::Config; fn deref(&self) -> &Self::Target { &self.0 @@ -199,7 +199,7 @@ impl std::ops::DerefMut for ConnCfg { impl ConnCfg { /// Establish a raw TCP connection to the compute node. async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> { - use tokio_postgres::config::Host; + use postgres_client::config::Host; // wrap TcpStream::connect with timeout let connect_with_timeout = |host, port| { @@ -224,7 +224,7 @@ impl ConnCfg { }) }; - // We can't reuse connection establishing logic from `tokio_postgres` here, + // We can't reuse connection establishing logic from `postgres_client` here, // because it has no means for extracting the underlying socket which we // require for our business. let mut connection_error = None; @@ -272,7 +272,7 @@ type RustlsStream = > pub(crate) struct PostgresConnection { /// Socket connected to a compute node. pub(crate) stream: - tokio_postgres::maybe_tls_stream::MaybeTlsStream, + postgres_client::maybe_tls_stream::MaybeTlsStream, /// PostgreSQL connection parameters. pub(crate) params: std::collections::HashMap, /// Query cancellation token. diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 9537d717a1f1..3ea4e8f2112c 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -4,9 +4,9 @@ use std::str::FromStr; use std::sync::Arc; use futures::TryFutureExt; +use postgres_client::config::SslMode; +use postgres_client::Client; use thiserror::Error; -use tokio_postgres::config::SslMode; -use tokio_postgres::Client; use tracing::{error, info, info_span, warn, Instrument}; use crate::auth::backend::jwt::AuthRule; @@ -29,7 +29,7 @@ use crate::{compute, scram}; #[derive(Debug, Error)] enum MockApiError { #[error("Failed to read password: {0}")] - PasswordNotSet(tokio_postgres::Error), + PasswordNotSet(postgres_client::Error), } impl From for ControlPlaneError { @@ -38,8 +38,8 @@ impl From for ControlPlaneError { } } -impl From for ControlPlaneError { - fn from(e: tokio_postgres::Error) -> Self { +impl From for ControlPlaneError { + fn from(e: postgres_client::Error) -> Self { io_error(e).into() } } @@ -71,7 +71,7 @@ impl MockControlPlane { // write more code for reopening it if it got closed, which doesn't // seem worth it. let (client, connection) = - tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; + postgres_client::connect(self.endpoint.as_str(), postgres_client::NoTls).await?; tokio::spawn(connection); @@ -129,7 +129,7 @@ impl MockControlPlane { endpoint: EndpointId, ) -> Result, GetEndpointJwksError> { let (client, connection) = - tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; + postgres_client::connect(self.endpoint.as_str(), postgres_client::NoTls).await?; let connection = tokio::spawn(connection); @@ -185,7 +185,7 @@ impl MockControlPlane { async fn get_execute_postgres_query( client: &Client, query: &str, - params: &[&(dyn tokio_postgres::types::ToSql + Sync)], + params: &[&(dyn postgres_client::types::ToSql + Sync)], idx: &str, ) -> Result, GetAuthInfoError> { let rows = client.query(query, params).await?; diff --git a/proxy/src/control_plane/client/neon.rs b/proxy/src/control_plane/client/neon.rs index 2cad981d01ac..5a78ec9d32aa 100644 --- a/proxy/src/control_plane/client/neon.rs +++ b/proxy/src/control_plane/client/neon.rs @@ -6,8 +6,8 @@ use std::time::Duration; use ::http::header::AUTHORIZATION; use ::http::HeaderName; use futures::TryFutureExt; +use postgres_client::config::SslMode; use tokio::time::Instant; -use tokio_postgres::config::SslMode; use tracing::{debug, info, info_span, warn, Instrument}; use super::super::messages::{ControlPlaneErrorMessage, GetRoleSecret, WakeCompute}; diff --git a/proxy/src/control_plane/errors.rs b/proxy/src/control_plane/errors.rs index d6f565e34a45..57cc01883fa5 100644 --- a/proxy/src/control_plane/errors.rs +++ b/proxy/src/control_plane/errors.rs @@ -204,7 +204,7 @@ pub enum GetEndpointJwksError { #[cfg(any(test, feature = "testing"))] #[error(transparent)] - TokioPostgres(#[from] tokio_postgres::Error), + TokioPostgres(#[from] postgres_client::Error), #[cfg(any(test, feature = "testing"))] #[error(transparent)] diff --git a/proxy/src/error.rs b/proxy/src/error.rs index 2221aac407fc..6a379499dc62 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -84,7 +84,7 @@ pub(crate) trait ReportableError: fmt::Display + Send + 'static { fn get_error_kind(&self) -> ErrorKind; } -impl ReportableError for tokio_postgres::error::Error { +impl ReportableError for postgres_client::error::Error { fn get_error_kind(&self) -> ErrorKind { if self.as_db_error().is_some() { ErrorKind::Postgres diff --git a/proxy/src/postgres_rustls/mod.rs b/proxy/src/postgres_rustls/mod.rs index 31e7915e89fd..5ef20991c309 100644 --- a/proxy/src/postgres_rustls/mod.rs +++ b/proxy/src/postgres_rustls/mod.rs @@ -1,10 +1,10 @@ use std::convert::TryFrom; use std::sync::Arc; +use postgres_client::tls::MakeTlsConnect; use rustls::pki_types::ServerName; use rustls::ClientConfig; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_postgres::tls::MakeTlsConnect; mod private { use std::future::Future; @@ -12,9 +12,9 @@ mod private { use std::pin::Pin; use std::task::{Context, Poll}; + use postgres_client::tls::{ChannelBinding, TlsConnect}; use rustls::pki_types::ServerName; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - use tokio_postgres::tls::{ChannelBinding, TlsConnect}; use tokio_rustls::client::TlsStream; use tokio_rustls::TlsConnector; @@ -59,7 +59,7 @@ mod private { pub struct RustlsStream(TlsStream); - impl tokio_postgres::tls::TlsStream for RustlsStream + impl postgres_client::tls::TlsStream for RustlsStream where S: AsyncRead + AsyncWrite + Unpin, { diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index d3f0c3e7d471..42d1491782dd 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -31,9 +31,9 @@ impl CouldRetry for io::Error { } } -impl CouldRetry for tokio_postgres::error::DbError { +impl CouldRetry for postgres_client::error::DbError { fn could_retry(&self) -> bool { - use tokio_postgres::error::SqlState; + use postgres_client::error::SqlState; matches!( self.code(), &SqlState::CONNECTION_FAILURE @@ -43,9 +43,9 @@ impl CouldRetry for tokio_postgres::error::DbError { ) } } -impl ShouldRetryWakeCompute for tokio_postgres::error::DbError { +impl ShouldRetryWakeCompute for postgres_client::error::DbError { fn should_retry_wake_compute(&self) -> bool { - use tokio_postgres::error::SqlState; + use postgres_client::error::SqlState; // Here are errors that happens after the user successfully authenticated to the database. // TODO: there are pgbouncer errors that should be retried, but they are not listed here. !matches!( @@ -61,21 +61,21 @@ impl ShouldRetryWakeCompute for tokio_postgres::error::DbError { } } -impl CouldRetry for tokio_postgres::Error { +impl CouldRetry for postgres_client::Error { fn could_retry(&self) -> bool { if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) { io::Error::could_retry(io_err) } else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) { - tokio_postgres::error::DbError::could_retry(db_err) + postgres_client::error::DbError::could_retry(db_err) } else { false } } } -impl ShouldRetryWakeCompute for tokio_postgres::Error { +impl ShouldRetryWakeCompute for postgres_client::Error { fn should_retry_wake_compute(&self) -> bool { if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) { - tokio_postgres::error::DbError::should_retry_wake_compute(db_err) + postgres_client::error::DbError::should_retry_wake_compute(db_err) } else { // likely an IO error. Possible the compute has shutdown and the // cache is stale. diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index fe211adfeb7b..ef351f3b54b2 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -8,9 +8,9 @@ use std::fmt::Debug; use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; +use postgres_client::tls::TlsConnect; use postgres_protocol::message::frontend; use tokio::io::{AsyncReadExt, DuplexStream}; -use tokio_postgres::tls::TlsConnect; use tokio_util::codec::{Decoder, Encoder}; use super::*; @@ -158,8 +158,8 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> { Scram::new("password").await?, )); - let _client_err = tokio_postgres::Config::new() - .channel_binding(tokio_postgres::config::ChannelBinding::Disable) + let _client_err = postgres_client::Config::new() + .channel_binding(postgres_client::config::ChannelBinding::Disable) .user("user") .dbname("db") .password("password") @@ -175,7 +175,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> { async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> { connect_failure( Intercept::None, - tokio_postgres::config::ChannelBinding::Prefer, + postgres_client::config::ChannelBinding::Prefer, ) .await } @@ -185,7 +185,7 @@ async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> { async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> { connect_failure( Intercept::Methods, - tokio_postgres::config::ChannelBinding::Prefer, + postgres_client::config::ChannelBinding::Prefer, ) .await } @@ -195,7 +195,7 @@ async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> { async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> { connect_failure( Intercept::SASLResponse, - tokio_postgres::config::ChannelBinding::Prefer, + postgres_client::config::ChannelBinding::Prefer, ) .await } @@ -205,7 +205,7 @@ async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Resul async fn scram_auth_require_channel_binding() -> anyhow::Result<()> { connect_failure( Intercept::None, - tokio_postgres::config::ChannelBinding::Require, + postgres_client::config::ChannelBinding::Require, ) .await } @@ -215,7 +215,7 @@ async fn scram_auth_require_channel_binding() -> anyhow::Result<()> { async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> { connect_failure( Intercept::Methods, - tokio_postgres::config::ChannelBinding::Require, + postgres_client::config::ChannelBinding::Require, ) .await } @@ -225,14 +225,14 @@ async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> { async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> { connect_failure( Intercept::SASLResponse, - tokio_postgres::config::ChannelBinding::Require, + postgres_client::config::ChannelBinding::Require, ) .await } async fn connect_failure( intercept: Intercept, - channel_binding: tokio_postgres::config::ChannelBinding, + channel_binding: postgres_client::config::ChannelBinding, ) -> anyhow::Result<()> { let (server, client, client_config, server_config) = proxy_mitm(intercept).await; let proxy = tokio::spawn(dummy_proxy( @@ -241,7 +241,7 @@ async fn connect_failure( Scram::new("password").await?, )); - let _client_err = tokio_postgres::Config::new() + let _client_err = postgres_client::Config::new() .channel_binding(channel_binding) .user("user") .dbname("db") diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 15be6c9724e8..c8b742b3ff23 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -7,13 +7,13 @@ use std::time::Duration; use anyhow::{bail, Context}; use async_trait::async_trait; use http::StatusCode; +use postgres_client::config::SslMode; +use postgres_client::tls::{MakeTlsConnect, NoTls}; use retry::{retry_after, ShouldRetryWakeCompute}; use rstest::rstest; use rustls::crypto::ring; use rustls::pki_types; use tokio::io::DuplexStream; -use tokio_postgres::config::SslMode; -use tokio_postgres::tls::{MakeTlsConnect, NoTls}; use super::connect_compute::ConnectMechanism; use super::retry::CouldRetry; @@ -204,7 +204,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> { let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?; let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); - let client_err = tokio_postgres::Config::new() + let client_err = postgres_client::Config::new() .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Disable) @@ -233,7 +233,7 @@ async fn handshake_tls() -> anyhow::Result<()> { generate_tls_config("generic-project-name.localhost", "localhost")?; let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); - let _conn = tokio_postgres::Config::new() + let _conn = postgres_client::Config::new() .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Require) @@ -249,7 +249,7 @@ async fn handshake_raw() -> anyhow::Result<()> { let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth)); - let _conn = tokio_postgres::Config::new() + let _conn = postgres_client::Config::new() .user("john_doe") .dbname("earth") .options("project=generic-project-name") @@ -296,8 +296,8 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { Scram::new(password).await?, )); - let _conn = tokio_postgres::Config::new() - .channel_binding(tokio_postgres::config::ChannelBinding::Require) + let _conn = postgres_client::Config::new() + .channel_binding(postgres_client::config::ChannelBinding::Require) .user("user") .dbname("db") .password(password) @@ -320,8 +320,8 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> { Scram::new("password").await?, )); - let _conn = tokio_postgres::Config::new() - .channel_binding(tokio_postgres::config::ChannelBinding::Disable) + let _conn = postgres_client::Config::new() + .channel_binding(postgres_client::config::ChannelBinding::Disable) .user("user") .dbname("db") .password("password") @@ -348,7 +348,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> { .map(char::from) .collect(); - let _client_err = tokio_postgres::Config::new() + let _client_err = postgres_client::Config::new() .user("user") .dbname("db") .password(&password) // no password will match the mocked secret diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 57846a4c2c51..8c7931907da5 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -37,9 +37,9 @@ use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX}; pub(crate) struct PoolingBackend { pub(crate) http_conn_pool: Arc>>, - pub(crate) local_pool: Arc>, + pub(crate) local_pool: Arc>, pub(crate) pool: - Arc>>, + Arc>>, pub(crate) config: &'static ProxyConfig, pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>, @@ -170,7 +170,7 @@ impl PoolingBackend { conn_info: ConnInfo, keys: ComputeCredentials, force_new: bool, - ) -> Result, HttpConnError> { + ) -> Result, HttpConnError> { let maybe_client = if force_new { debug!("pool: pool is disabled"); None @@ -256,7 +256,7 @@ impl PoolingBackend { &self, ctx: &RequestContext, conn_info: ConnInfo, - ) -> Result, HttpConnError> { + ) -> Result, HttpConnError> { if let Some(client) = self.local_pool.get(ctx, &conn_info)? { return Ok(client); } @@ -315,7 +315,7 @@ impl PoolingBackend { )); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (client, connection) = config.connect(tokio_postgres::NoTls).await?; + let (client, connection) = config.connect(postgres_client::NoTls).await?; drop(pause); let pid = client.get_process_id(); @@ -360,7 +360,7 @@ pub(crate) enum HttpConnError { #[error("pooled connection closed at inconsistent state")] ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError), #[error("could not connection to postgres in compute")] - PostgresConnectionError(#[from] tokio_postgres::Error), + PostgresConnectionError(#[from] postgres_client::Error), #[error("could not connection to local-proxy in compute")] LocalProxyConnectionError(#[from] LocalProxyConnError), #[error("could not parse JWT payload")] @@ -479,7 +479,7 @@ impl ShouldRetryWakeCompute for LocalProxyConnError { } struct TokioMechanism { - pool: Arc>>, + pool: Arc>>, conn_info: ConnInfo, conn_id: uuid::Uuid, @@ -489,7 +489,7 @@ struct TokioMechanism { #[async_trait] impl ConnectMechanism for TokioMechanism { - type Connection = Client; + type Connection = Client; type ConnectError = HttpConnError; type Error = HttpConnError; @@ -509,7 +509,7 @@ impl ConnectMechanism for TokioMechanism { .connect_timeout(timeout); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let res = config.connect(tokio_postgres::NoTls).await; + let res = config.connect(postgres_client::NoTls).await; drop(pause); let (client, connection) = permit.release_result(res)?; diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index c302eac5684b..cac5a173cb16 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -5,11 +5,11 @@ use std::task::{ready, Poll}; use futures::future::poll_fn; use futures::Future; +use postgres_client::tls::NoTlsStream; +use postgres_client::AsyncMessage; use smallvec::SmallVec; use tokio::net::TcpStream; use tokio::time::Instant; -use tokio_postgres::tls::NoTlsStream; -use tokio_postgres::AsyncMessage; use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, warn, Instrument}; #[cfg(test)] @@ -58,7 +58,7 @@ pub(crate) fn poll_client( ctx: &RequestContext, conn_info: ConnInfo, client: C, - mut connection: tokio_postgres::Connection, + mut connection: postgres_client::Connection, conn_id: uuid::Uuid, aux: MetricsAuxInfo, ) -> Client { diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index fe1d2563bca1..2a46c8f9c5cf 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -7,8 +7,8 @@ use std::time::Duration; use dashmap::DashMap; use parking_lot::RwLock; +use postgres_client::ReadyForQueryStatus; use rand::Rng; -use tokio_postgres::ReadyForQueryStatus; use tracing::{debug, info, Span}; use super::backend::HttpConnError; @@ -683,7 +683,7 @@ pub(crate) trait ClientInnerExt: Sync + Send + 'static { fn get_process_id(&self) -> i32; } -impl ClientInnerExt for tokio_postgres::Client { +impl ClientInnerExt for postgres_client::Client { fn is_closed(&self) -> bool { self.is_closed() } diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index 569e2da5715a..25b25c66d3fb 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -1,6 +1,6 @@ +use postgres_client::types::{Kind, Type}; +use postgres_client::Row; use serde_json::{Map, Value}; -use tokio_postgres::types::{Kind, Type}; -use tokio_postgres::Row; // // Convert json non-string types to strings, so that they can be passed to Postgres @@ -61,7 +61,7 @@ fn json_array_to_pg_array(value: &Value) -> Option { #[derive(Debug, thiserror::Error)] pub(crate) enum JsonConversionError { #[error("internal error compute returned invalid data: {0}")] - AsTextError(tokio_postgres::Error), + AsTextError(postgres_client::Error), #[error("parse int error: {0}")] ParseIntError(#[from] std::num::ParseIntError), #[error("parse float error: {0}")] diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index db9ac49dae8f..b84cde9e252a 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -22,13 +22,13 @@ use indexmap::IndexMap; use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding}; use p256::ecdsa::{Signature, SigningKey}; use parking_lot::RwLock; +use postgres_client::tls::NoTlsStream; +use postgres_client::types::ToSql; +use postgres_client::AsyncMessage; use serde_json::value::RawValue; use signature::Signer; use tokio::net::TcpStream; use tokio::time::Instant; -use tokio_postgres::tls::NoTlsStream; -use tokio_postgres::types::ToSql; -use tokio_postgres::AsyncMessage; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, info_span, warn, Instrument}; @@ -164,7 +164,7 @@ pub(crate) fn poll_client( ctx: &RequestContext, conn_info: ConnInfo, client: C, - mut connection: tokio_postgres::Connection, + mut connection: postgres_client::Connection, key: SigningKey, conn_id: uuid::Uuid, aux: MetricsAuxInfo, @@ -280,7 +280,7 @@ pub(crate) fn poll_client( ) } -impl ClientInnerCommon { +impl ClientInnerCommon { pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> { if let ClientDataEnum::Local(local_data) = &mut self.data { local_data.jti += 1; diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index a0ca7cc60d6a..5e85f5ec4019 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -11,12 +11,12 @@ use http_body_util::{BodyExt, Full}; use hyper::body::Incoming; use hyper::http::{HeaderName, HeaderValue}; use hyper::{header, HeaderMap, Request, Response, StatusCode}; +use postgres_client::error::{DbError, ErrorPosition, SqlState}; +use postgres_client::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction}; use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; use tokio::time::{self, Instant}; -use tokio_postgres::error::{DbError, ErrorPosition, SqlState}; -use tokio_postgres::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info}; use typed_json::json; @@ -361,7 +361,7 @@ pub(crate) enum SqlOverHttpError { #[error("invalid isolation level")] InvalidIsolationLevel, #[error("{0}")] - Postgres(#[from] tokio_postgres::Error), + Postgres(#[from] postgres_client::Error), #[error("{0}")] JsonConversion(#[from] JsonConversionError), #[error("{0}")] @@ -986,7 +986,7 @@ async fn query_to_json( // Manually drain the stream into a vector to leave row_stream hanging // around to get a command tag. Also check that the response is not too // big. - let mut rows: Vec = Vec::new(); + let mut rows: Vec = Vec::new(); while let Some(row) = row_stream.next().await { let row = row?; *current_size += row.body_len(); @@ -1063,13 +1063,13 @@ async fn query_to_json( } enum Client { - Remote(conn_pool_lib::Client), - Local(conn_pool_lib::Client), + Remote(conn_pool_lib::Client), + Local(conn_pool_lib::Client), } enum Discard<'a> { - Remote(conn_pool_lib::Discard<'a, tokio_postgres::Client>), - Local(conn_pool_lib::Discard<'a, tokio_postgres::Client>), + Remote(conn_pool_lib::Discard<'a, postgres_client::Client>), + Local(conn_pool_lib::Discard<'a, postgres_client::Client>), } impl Client { @@ -1080,7 +1080,7 @@ impl Client { } } - fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) { + fn inner(&mut self) -> (&mut postgres_client::Client, Discard<'_>) { match self { Client::Remote(client) => { let (c, d) = client.inner(); From 4d09a5a3be518bf1f2809af653be1b954076d67b Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 3 Dec 2024 15:39:20 +0000 Subject: [PATCH 2/4] reintroduce tokio-postgres for mock provider --- Cargo.lock | 1 + proxy/Cargo.toml | 4 +++- proxy/src/control_plane/client/mock.rs | 17 ++++++++--------- proxy/src/control_plane/errors.rs | 2 +- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b2769e59f082..ff5c9b95ae6a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4612,6 +4612,7 @@ dependencies = [ "tikv-jemalloc-ctl", "tikv-jemallocator", "tokio", + "tokio-postgres", "tokio-postgres2", "tokio-rustls 0.26.0", "tokio-tungstenite", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 60bc08b2d00f..2f63ee3acc42 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -6,7 +6,7 @@ license.workspace = true [features] default = [] -testing = [] +testing = ["dep:tokio-postgres"] [dependencies] ahash.workspace = true @@ -82,6 +82,7 @@ subtle.workspace = true thiserror.workspace = true tikv-jemallocator.workspace = true tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] } +tokio-postgres = { workspace = true, optional = true } tokio-rustls.workspace = true tokio-util.workspace = true tokio = { workspace = true, features = ["signal"] } @@ -119,3 +120,4 @@ rcgen.workspace = true rstest.workspace = true walkdir.workspace = true rand_distr = "0.4" +tokio-postgres.workspace = true diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 3ea4e8f2112c..4d55f96ca198 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -4,9 +4,8 @@ use std::str::FromStr; use std::sync::Arc; use futures::TryFutureExt; -use postgres_client::config::SslMode; -use postgres_client::Client; use thiserror::Error; +use tokio_postgres::Client; use tracing::{error, info, info_span, warn, Instrument}; use crate::auth::backend::jwt::AuthRule; @@ -29,7 +28,7 @@ use crate::{compute, scram}; #[derive(Debug, Error)] enum MockApiError { #[error("Failed to read password: {0}")] - PasswordNotSet(postgres_client::Error), + PasswordNotSet(tokio_postgres::Error), } impl From for ControlPlaneError { @@ -38,8 +37,8 @@ impl From for ControlPlaneError { } } -impl From for ControlPlaneError { - fn from(e: postgres_client::Error) -> Self { +impl From for ControlPlaneError { + fn from(e: tokio_postgres::Error) -> Self { io_error(e).into() } } @@ -71,7 +70,7 @@ impl MockControlPlane { // write more code for reopening it if it got closed, which doesn't // seem worth it. let (client, connection) = - postgres_client::connect(self.endpoint.as_str(), postgres_client::NoTls).await?; + tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; tokio::spawn(connection); @@ -129,7 +128,7 @@ impl MockControlPlane { endpoint: EndpointId, ) -> Result, GetEndpointJwksError> { let (client, connection) = - postgres_client::connect(self.endpoint.as_str(), postgres_client::NoTls).await?; + tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; let connection = tokio::spawn(connection); @@ -165,7 +164,7 @@ impl MockControlPlane { config .host(self.endpoint.host_str().unwrap_or("localhost")) .port(self.endpoint.port().unwrap_or(5432)) - .ssl_mode(SslMode::Disable); + .ssl_mode(postgres_client::config::SslMode::Disable); let node = NodeInfo { config, @@ -185,7 +184,7 @@ impl MockControlPlane { async fn get_execute_postgres_query( client: &Client, query: &str, - params: &[&(dyn postgres_client::types::ToSql + Sync)], + params: &[&(dyn tokio_postgres::types::ToSql + Sync)], idx: &str, ) -> Result, GetAuthInfoError> { let rows = client.query(query, params).await?; diff --git a/proxy/src/control_plane/errors.rs b/proxy/src/control_plane/errors.rs index 57cc01883fa5..d6f565e34a45 100644 --- a/proxy/src/control_plane/errors.rs +++ b/proxy/src/control_plane/errors.rs @@ -204,7 +204,7 @@ pub enum GetEndpointJwksError { #[cfg(any(test, feature = "testing"))] #[error(transparent)] - TokioPostgres(#[from] postgres_client::Error), + TokioPostgres(#[from] tokio_postgres::Error), #[cfg(any(test, feature = "testing"))] #[error(transparent)] From f87b34368c69545c435f62bcc12325ed25997b76 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 3 Dec 2024 15:40:53 +0000 Subject: [PATCH 3/4] delete config parse code --- libs/proxy/tokio-postgres2/src/config.rs | 465 +------------------- libs/proxy/tokio-postgres2/src/error/mod.rs | 6 - libs/proxy/tokio-postgres2/src/lib.rs | 20 - 3 files changed, 1 insertion(+), 490 deletions(-) diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index 26124b38ef8f..5dad835c3bdd 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -6,11 +6,9 @@ use crate::connect_raw::RawConnection; use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; use crate::{Client, Connection, Error}; -use std::borrow::Cow; +use std::fmt; use std::str; -use std::str::FromStr; use std::time::Duration; -use std::{error, fmt, iter, mem}; use tokio::io::{AsyncRead, AsyncWrite}; pub use postgres_protocol2::authentication::sasl::ScramKeys; @@ -380,99 +378,6 @@ impl Config { self.max_backend_message_size } - fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { - match key { - "user" => { - self.user(value); - } - "password" => { - self.password(value); - } - "dbname" => { - self.dbname(value); - } - "options" => { - self.options(value); - } - "application_name" => { - self.application_name(value); - } - "sslmode" => { - let mode = match value { - "disable" => SslMode::Disable, - "prefer" => SslMode::Prefer, - "require" => SslMode::Require, - _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))), - }; - self.ssl_mode(mode); - } - "host" => { - for host in value.split(',') { - self.host(host); - } - } - "port" => { - for port in value.split(',') { - let port = if port.is_empty() { - 5432 - } else { - port.parse() - .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))? - }; - self.port(port); - } - } - "connect_timeout" => { - let timeout = value - .parse::() - .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?; - if timeout > 0 { - self.connect_timeout(Duration::from_secs(timeout as u64)); - } - } - "target_session_attrs" => { - let target_session_attrs = match value { - "any" => TargetSessionAttrs::Any, - "read-write" => TargetSessionAttrs::ReadWrite, - _ => { - return Err(Error::config_parse(Box::new(InvalidValue( - "target_session_attrs", - )))); - } - }; - self.target_session_attrs(target_session_attrs); - } - "channel_binding" => { - let channel_binding = match value { - "disable" => ChannelBinding::Disable, - "prefer" => ChannelBinding::Prefer, - "require" => ChannelBinding::Require, - _ => { - return Err(Error::config_parse(Box::new(InvalidValue( - "channel_binding", - )))) - } - }; - self.channel_binding(channel_binding); - } - "max_backend_message_size" => { - let limit = value.parse::().map_err(|_| { - Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) - })?; - if limit > 0 { - self.max_backend_message_size(limit); - } - } - key => { - return Err(Error::config_parse(Box::new(UnknownOption( - key.to_string(), - )))); - } - } - - Ok(()) - } - /// Opens a connection to a PostgreSQL database. /// /// Requires the `runtime` Cargo feature (enabled by default). @@ -499,17 +404,6 @@ impl Config { } } -impl FromStr for Config { - type Err = Error; - - fn from_str(s: &str) -> Result { - match UrlParser::parse(s)? { - Some(config) => Ok(config), - None => Parser::parse(s), - } - } -} - // Omit password from debug output impl fmt::Debug for Config { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -536,360 +430,3 @@ impl fmt::Debug for Config { .finish() } } - -#[derive(Debug)] -struct UnknownOption(String); - -impl fmt::Display for UnknownOption { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "unknown option `{}`", self.0) - } -} - -impl error::Error for UnknownOption {} - -#[derive(Debug)] -struct InvalidValue(&'static str); - -impl fmt::Display for InvalidValue { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "invalid value for option `{}`", self.0) - } -} - -impl error::Error for InvalidValue {} - -struct Parser<'a> { - s: &'a str, - it: iter::Peekable>, -} - -impl<'a> Parser<'a> { - fn parse(s: &'a str) -> Result { - let mut parser = Parser { - s, - it: s.char_indices().peekable(), - }; - - let mut config = Config::new(); - - while let Some((key, value)) = parser.parameter()? { - config.param(key, &value)?; - } - - Ok(config) - } - - fn skip_ws(&mut self) { - self.take_while(char::is_whitespace); - } - - fn take_while(&mut self, f: F) -> &'a str - where - F: Fn(char) -> bool, - { - let start = match self.it.peek() { - Some(&(i, _)) => i, - None => return "", - }; - - loop { - match self.it.peek() { - Some(&(_, c)) if f(c) => { - self.it.next(); - } - Some(&(i, _)) => return &self.s[start..i], - None => return &self.s[start..], - } - } - } - - fn eat(&mut self, target: char) -> Result<(), Error> { - match self.it.next() { - Some((_, c)) if c == target => Ok(()), - Some((i, c)) => { - let m = format!( - "unexpected character at byte {}: expected `{}` but got `{}`", - i, target, c - ); - Err(Error::config_parse(m.into())) - } - None => Err(Error::config_parse("unexpected EOF".into())), - } - } - - fn eat_if(&mut self, target: char) -> bool { - match self.it.peek() { - Some(&(_, c)) if c == target => { - self.it.next(); - true - } - _ => false, - } - } - - fn keyword(&mut self) -> Option<&'a str> { - let s = self.take_while(|c| match c { - c if c.is_whitespace() => false, - '=' => false, - _ => true, - }); - - if s.is_empty() { - None - } else { - Some(s) - } - } - - fn value(&mut self) -> Result { - let value = if self.eat_if('\'') { - let value = self.quoted_value()?; - self.eat('\'')?; - value - } else { - self.simple_value()? - }; - - Ok(value) - } - - fn simple_value(&mut self) -> Result { - let mut value = String::new(); - - while let Some(&(_, c)) = self.it.peek() { - if c.is_whitespace() { - break; - } - - self.it.next(); - if c == '\\' { - if let Some((_, c2)) = self.it.next() { - value.push(c2); - } - } else { - value.push(c); - } - } - - if value.is_empty() { - return Err(Error::config_parse("unexpected EOF".into())); - } - - Ok(value) - } - - fn quoted_value(&mut self) -> Result { - let mut value = String::new(); - - while let Some(&(_, c)) = self.it.peek() { - if c == '\'' { - return Ok(value); - } - - self.it.next(); - if c == '\\' { - if let Some((_, c2)) = self.it.next() { - value.push(c2); - } - } else { - value.push(c); - } - } - - Err(Error::config_parse( - "unterminated quoted connection parameter value".into(), - )) - } - - fn parameter(&mut self) -> Result, Error> { - self.skip_ws(); - let keyword = match self.keyword() { - Some(keyword) => keyword, - None => return Ok(None), - }; - self.skip_ws(); - self.eat('=')?; - self.skip_ws(); - let value = self.value()?; - - Ok(Some((keyword, value))) - } -} - -// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict -struct UrlParser<'a> { - s: &'a str, - config: Config, -} - -impl<'a> UrlParser<'a> { - fn parse(s: &'a str) -> Result, Error> { - let s = match Self::remove_url_prefix(s) { - Some(s) => s, - None => return Ok(None), - }; - - let mut parser = UrlParser { - s, - config: Config::new(), - }; - - parser.parse_credentials()?; - parser.parse_host()?; - parser.parse_path()?; - parser.parse_params()?; - - Ok(Some(parser.config)) - } - - fn remove_url_prefix(s: &str) -> Option<&str> { - for prefix in &["postgres://", "postgresql://"] { - if let Some(stripped) = s.strip_prefix(prefix) { - return Some(stripped); - } - } - - None - } - - fn take_until(&mut self, end: &[char]) -> Option<&'a str> { - match self.s.find(end) { - Some(pos) => { - let (head, tail) = self.s.split_at(pos); - self.s = tail; - Some(head) - } - None => None, - } - } - - fn take_all(&mut self) -> &'a str { - mem::take(&mut self.s) - } - - fn eat_byte(&mut self) { - self.s = &self.s[1..]; - } - - fn parse_credentials(&mut self) -> Result<(), Error> { - let creds = match self.take_until(&['@']) { - Some(creds) => creds, - None => return Ok(()), - }; - self.eat_byte(); - - let mut it = creds.splitn(2, ':'); - let user = self.decode(it.next().unwrap())?; - self.config.user(&user); - - if let Some(password) = it.next() { - let password = Cow::from(percent_encoding::percent_decode(password.as_bytes())); - self.config.password(password); - } - - Ok(()) - } - - fn parse_host(&mut self) -> Result<(), Error> { - let host = match self.take_until(&['/', '?']) { - Some(host) => host, - None => self.take_all(), - }; - - if host.is_empty() { - return Ok(()); - } - - for chunk in host.split(',') { - let (host, port) = if chunk.starts_with('[') { - let idx = match chunk.find(']') { - Some(idx) => idx, - None => return Err(Error::config_parse(InvalidValue("host").into())), - }; - - let host = &chunk[1..idx]; - let remaining = &chunk[idx + 1..]; - let port = if let Some(port) = remaining.strip_prefix(':') { - Some(port) - } else if remaining.is_empty() { - None - } else { - return Err(Error::config_parse(InvalidValue("host").into())); - }; - - (host, port) - } else { - let mut it = chunk.splitn(2, ':'); - (it.next().unwrap(), it.next()) - }; - - self.host_param(host)?; - let port = self.decode(port.unwrap_or("5432"))?; - self.config.param("port", &port)?; - } - - Ok(()) - } - - fn parse_path(&mut self) -> Result<(), Error> { - if !self.s.starts_with('/') { - return Ok(()); - } - self.eat_byte(); - - let dbname = match self.take_until(&['?']) { - Some(dbname) => dbname, - None => self.take_all(), - }; - - if !dbname.is_empty() { - self.config.dbname(&self.decode(dbname)?); - } - - Ok(()) - } - - fn parse_params(&mut self) -> Result<(), Error> { - if !self.s.starts_with('?') { - return Ok(()); - } - self.eat_byte(); - - while !self.s.is_empty() { - let key = match self.take_until(&['=']) { - Some(key) => self.decode(key)?, - None => return Err(Error::config_parse("unterminated parameter".into())), - }; - self.eat_byte(); - - let value = match self.take_until(&['&']) { - Some(value) => { - self.eat_byte(); - value - } - None => self.take_all(), - }; - - if key == "host" { - self.host_param(value)?; - } else { - let value = self.decode(value)?; - self.config.param(&key, &value)?; - } - } - - Ok(()) - } - - fn host_param(&mut self, s: &str) -> Result<(), Error> { - let s = self.decode(s)?; - self.config.param("host", &s) - } - - fn decode(&self, s: &'a str) -> Result, Error> { - percent_encoding::percent_decode(s.as_bytes()) - .decode_utf8() - .map_err(|e| Error::config_parse(e.into())) - } -} diff --git a/libs/proxy/tokio-postgres2/src/error/mod.rs b/libs/proxy/tokio-postgres2/src/error/mod.rs index 651432225009..922c348525c6 100644 --- a/libs/proxy/tokio-postgres2/src/error/mod.rs +++ b/libs/proxy/tokio-postgres2/src/error/mod.rs @@ -349,7 +349,6 @@ enum Kind { Parse, Encode, Authentication, - ConfigParse, Config, Connect, Timeout, @@ -386,7 +385,6 @@ impl fmt::Display for Error { Kind::Parse => fmt.write_str("error parsing response from server")?, Kind::Encode => fmt.write_str("error encoding message to server")?, Kind::Authentication => fmt.write_str("authentication error")?, - Kind::ConfigParse => fmt.write_str("invalid connection string")?, Kind::Config => fmt.write_str("invalid configuration")?, Kind::Connect => fmt.write_str("error connecting to server")?, Kind::Timeout => fmt.write_str("timeout waiting for server")?, @@ -482,10 +480,6 @@ impl Error { Error::new(Kind::Authentication, Some(e)) } - pub(crate) fn config_parse(e: Box) -> Error { - Error::new(Kind::ConfigParse, Some(e)) - } - pub(crate) fn config(e: Box) -> Error { Error::new(Kind::Config, Some(e)) } diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs index 57c639a7de51..901ed0c96c68 100644 --- a/libs/proxy/tokio-postgres2/src/lib.rs +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -13,14 +13,12 @@ pub use crate::query::RowStream; pub use crate::row::{Row, SimpleQueryRow}; pub use crate::simple_query::SimpleQueryStream; pub use crate::statement::{Column, Statement}; -use crate::tls::MakeTlsConnect; pub use crate::tls::NoTls; pub use crate::to_statement::ToStatement; pub use crate::transaction::Transaction; pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder}; use crate::types::ToSql; use postgres_protocol2::message::backend::ReadyForQueryBody; -use tokio::net::TcpStream; /// After executing a query, the connection will be in one of these states #[derive(Clone, Copy, Debug, PartialEq)] @@ -72,24 +70,6 @@ mod transaction; mod transaction_builder; pub mod types; -/// A convenience function which parses a connection string and connects to the database. -/// -/// See the documentation for [`Config`] for details on the connection string format. -/// -/// Requires the `runtime` Cargo feature (enabled by default). -/// -/// [`Config`]: config/struct.Config.html -pub async fn connect( - config: &str, - tls: T, -) -> Result<(Client, Connection), Error> -where - T: MakeTlsConnect, -{ - let config = config.parse::()?; - config.connect(tls).await -} - /// An asynchronous notification. #[derive(Clone, Debug)] pub struct Notification { From 4bbf95d6d1b7bb757c0923f071682b3c9e9d8754 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 3 Dec 2024 15:42:19 +0000 Subject: [PATCH 4/4] remove md5 support --- Cargo.lock | 1 - libs/proxy/postgres-protocol2/Cargo.toml | 1 - .../src/authentication/mod.rs | 35 ------------------- .../postgres-protocol2/src/message/backend.rs | 8 ++--- .../postgres-protocol2/src/password/mod.rs | 18 ---------- .../postgres-protocol2/src/password/test.rs | 8 ----- libs/proxy/tokio-postgres2/src/connect_raw.rs | 19 ++-------- 7 files changed, 4 insertions(+), 86 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ff5c9b95ae6a..5b80ec5e93d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4209,7 +4209,6 @@ dependencies = [ "bytes", "fallible-iterator", "hmac", - "md-5", "memchr", "rand 0.8.5", "sha2", diff --git a/libs/proxy/postgres-protocol2/Cargo.toml b/libs/proxy/postgres-protocol2/Cargo.toml index 284a632954fd..f71c1599c7c2 100644 --- a/libs/proxy/postgres-protocol2/Cargo.toml +++ b/libs/proxy/postgres-protocol2/Cargo.toml @@ -10,7 +10,6 @@ byteorder.workspace = true bytes.workspace = true fallible-iterator.workspace = true hmac.workspace = true -md-5 = "0.10" memchr = "2.0" rand.workspace = true sha2.workspace = true diff --git a/libs/proxy/postgres-protocol2/src/authentication/mod.rs b/libs/proxy/postgres-protocol2/src/authentication/mod.rs index 71afa4b9b60a..0bdc177143fb 100644 --- a/libs/proxy/postgres-protocol2/src/authentication/mod.rs +++ b/libs/proxy/postgres-protocol2/src/authentication/mod.rs @@ -1,37 +1,2 @@ //! Authentication protocol support. -use md5::{Digest, Md5}; - pub mod sasl; - -/// Hashes authentication information in a way suitable for use in response -/// to an `AuthenticationMd5Password` message. -/// -/// The resulting string should be sent back to the database in a -/// `PasswordMessage` message. -#[inline] -pub fn md5_hash(username: &[u8], password: &[u8], salt: [u8; 4]) -> String { - let mut md5 = Md5::new(); - md5.update(password); - md5.update(username); - let output = md5.finalize_reset(); - md5.update(format!("{:x}", output)); - md5.update(salt); - format!("md5{:x}", md5.finalize()) -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn md5() { - let username = b"md5_user"; - let password = b"password"; - let salt = [0x2a, 0x3d, 0x8f, 0xe0]; - - assert_eq!( - md5_hash(username, password, salt), - "md562af4dd09bbb41884907a838a3233294" - ); - } -} diff --git a/libs/proxy/postgres-protocol2/src/message/backend.rs b/libs/proxy/postgres-protocol2/src/message/backend.rs index 33d77fc25261..097964f9c110 100644 --- a/libs/proxy/postgres-protocol2/src/message/backend.rs +++ b/libs/proxy/postgres-protocol2/src/message/backend.rs @@ -79,7 +79,7 @@ pub enum Message { AuthenticationCleartextPassword, AuthenticationGss, AuthenticationKerberosV5, - AuthenticationMd5Password(AuthenticationMd5PasswordBody), + AuthenticationMd5Password, AuthenticationOk, AuthenticationScmCredential, AuthenticationSspi, @@ -191,11 +191,7 @@ impl Message { 0 => Message::AuthenticationOk, 2 => Message::AuthenticationKerberosV5, 3 => Message::AuthenticationCleartextPassword, - 5 => { - let mut salt = [0; 4]; - buf.read_exact(&mut salt)?; - Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt }) - } + 5 => Message::AuthenticationMd5Password, 6 => Message::AuthenticationScmCredential, 7 => Message::AuthenticationGss, 8 => Message::AuthenticationGssContinue, diff --git a/libs/proxy/postgres-protocol2/src/password/mod.rs b/libs/proxy/postgres-protocol2/src/password/mod.rs index e669e80f3f22..38eb31dfcf99 100644 --- a/libs/proxy/postgres-protocol2/src/password/mod.rs +++ b/libs/proxy/postgres-protocol2/src/password/mod.rs @@ -8,7 +8,6 @@ use crate::authentication::sasl; use hmac::{Hmac, Mac}; -use md5::Md5; use rand::RngCore; use sha2::digest::FixedOutput; use sha2::{Digest, Sha256}; @@ -88,20 +87,3 @@ pub(crate) async fn scram_sha_256_salt( base64::encode(server_key) ) } - -/// **Not recommended, as MD5 is not considered to be secure.** -/// -/// Hash password using MD5 with the username as the salt. -/// -/// The client may assume the returned string doesn't contain any -/// special characters that would require escaping. -pub fn md5(password: &[u8], username: &str) -> String { - // salt password with username - let mut salted_password = Vec::from(password); - salted_password.extend_from_slice(username.as_bytes()); - - let mut hash = Md5::new(); - hash.update(&salted_password); - let digest = hash.finalize(); - format!("md5{:x}", digest) -} diff --git a/libs/proxy/postgres-protocol2/src/password/test.rs b/libs/proxy/postgres-protocol2/src/password/test.rs index c9d340f09d80..0692c07adbb1 100644 --- a/libs/proxy/postgres-protocol2/src/password/test.rs +++ b/libs/proxy/postgres-protocol2/src/password/test.rs @@ -9,11 +9,3 @@ async fn test_encrypt_scram_sha_256() { "SCRAM-SHA-256$4096:AQIDBAUGBwgJCgsMDQ4PEA==$8rrDg00OqaiWXJ7p+sCgHEIaBSHY89ZJl3mfIsf32oY=:05L1f+yZbiN8O0AnO40Og85NNRhvzTS57naKRWCcsIA=" ); } - -#[test] -fn test_encrypt_md5() { - assert_eq!( - password::md5(b"secret", "foo"), - "md54ab2c5d00339c4b2a4e921d2dc4edec7" - ); -} diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs index 9c6f1a255200..390f133002be 100644 --- a/libs/proxy/tokio-postgres2/src/connect_raw.rs +++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs @@ -7,7 +7,6 @@ use crate::Error; use bytes::BytesMut; use fallible_iterator::FallibleIterator; use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt}; -use postgres_protocol2::authentication; use postgres_protocol2::authentication::sasl; use postgres_protocol2::authentication::sasl::ScramSha256; use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody}; @@ -174,25 +173,11 @@ where authenticate_password(stream, pass).await?; } - Some(Message::AuthenticationMd5Password(body)) => { - can_skip_channel_binding(config)?; - - let user = config - .user - .as_ref() - .ok_or_else(|| Error::config("user missing".into()))?; - let pass = config - .password - .as_ref() - .ok_or_else(|| Error::config("password missing".into()))?; - - let output = authentication::md5_hash(user.as_bytes(), pass, body.salt()); - authenticate_password(stream, output.as_bytes()).await?; - } Some(Message::AuthenticationSasl(body)) => { authenticate_sasl(stream, body, config).await?; } - Some(Message::AuthenticationKerberosV5) + Some(Message::AuthenticationMd5Password) + | Some(Message::AuthenticationKerberosV5) | Some(Message::AuthenticationScmCredential) | Some(Message::AuthenticationGss) | Some(Message::AuthenticationSspi) => {