From b9903dcaa2fffd5b0d6df39ac6aacb4bae69f7b5 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Thu, 17 Oct 2024 16:39:19 +0200 Subject: [PATCH] fix(schema-engine): Ensure WS migrations can use shadow database (#5021) * fix(schema-engine): Ensure WS migrations can use shadow database Previousy, when creating shadow DB connection over websocket, we connected to the same DB which broke in every `migrate` case except the one that starts with clean migration history. This PR ensures it works normally. Implementation is quite cursed though: for WS we now allow to override db name via `dbname` query string parameter. If set, we ignore `dbname` that we got from migration server and use provided DB with the same username and password. Shadow DB then uses this query string parameter to specify the url. Close ORM-325 * Clippy + rustfmt --- .../connector/postgres/native/websocket.rs | 6 ++++- quaint/src/connector/postgres/url.rs | 21 ++++++++++++++-- .../src/flavour/postgres.rs | 25 ++++++++++++++++--- 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs index 7899e9a22ec0..65b3c74a4fe8 100644 --- a/quaint/src/connector/postgres/native/websocket.rs +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -23,12 +23,16 @@ const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters"; const HOST_HEADER: &str = "Prisma-Db-Host"; pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::Result { + let db_name = url.overriden_db_name().map(ToOwned::to_owned); let (ws_stream, response) = connect_async(url).await?; let connection_params = require_header_value(response.headers(), CONNECTION_PARAMS_HEADER)?; let db_host = require_header_value(response.headers(), HOST_HEADER)?; - let config = Config::from_str(connection_params)?; + let mut config = Config::from_str(connection_params)?; + if let Some(db_name) = db_name { + config.dbname(&db_name); + } let ws_byte_stream = WsStream::new(ws_stream); let tls = TlsConnector::new(native_tls::TlsConnector::new()?, db_host); diff --git a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs index 703aff33ebb5..096484cdc87a 100644 --- a/quaint/src/connector/postgres/url.rs +++ b/quaint/src/connector/postgres/url.rs @@ -81,7 +81,7 @@ impl PostgresUrl { pub fn dbname(&self) -> &str { match self { Self::Native(url) => url.dbname(), - Self::WebSocket(_) => "postgres", + Self::WebSocket(url) => url.dbname(), } } @@ -493,17 +493,34 @@ pub(crate) struct PostgresUrlQueryParams { pub struct PostgresWebSocketUrl { pub(crate) url: Url, pub(crate) api_key: String, + pub(crate) db_name: Option, } impl PostgresWebSocketUrl { pub fn new(url: Url, api_key: String) -> Self { - Self { url, api_key } + Self { + url, + api_key, + db_name: None, + } + } + + pub fn override_db_name(&mut self, name: String) { + self.db_name = Some(name) } pub fn api_key(&self) -> &str { &self.api_key } + pub fn dbname(&self) -> &str { + self.overriden_db_name().unwrap_or("postgres") + } + + pub fn overriden_db_name(&self) -> Option<&str> { + self.db_name.as_deref() + } + pub fn host(&self) -> &str { self.url.host_str().unwrap_or("localhost") } diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs index ac704459b5a9..02752e491eeb 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs @@ -6,7 +6,11 @@ use crate::SqlFlavour; use enumflags2::BitFlags; use indoc::indoc; use once_cell::sync::Lazy; -use quaint::{connector::PostgresUrl, prelude::NativeConnectionInfo, Value}; +use quaint::{ + connector::{PostgresUrl, PostgresWebSocketUrl}, + prelude::NativeConnectionInfo, + Value, +}; use schema_connector::{ migrations_directory::MigrationDirectory, BoxFuture, ConnectorError, ConnectorParams, ConnectorResult, Namespaces, }; @@ -41,6 +45,7 @@ static MIGRATE_WS_BASE_URL: Lazy> = Lazy::new(|| { impl MigratePostgresUrl { const WEBSOCKET_SCHEME: &'static str = "prisma+postgres"; const API_KEY_PARAM: &'static str = "api_key"; + const DBNAME_PARAM: &'static str = "dbname"; fn new(url: Url) -> ConnectorResult { let postgres_url = if url.scheme() == Self::WEBSOCKET_SCHEME { @@ -50,7 +55,14 @@ impl MigratePostgresUrl { "Required `api_key` query string parameter was not provided in a connection URL", )); }; - PostgresUrl::new_websocket(ws_url, api_key.into_owned()) + + let dbname_override = url.query_pairs().find(|(name, _)| name == Self::DBNAME_PARAM); + let mut ws_url = PostgresWebSocketUrl::new(ws_url, api_key.into_owned()); + if let Some((_, dbname_override)) = dbname_override { + ws_url.override_db_name(dbname_override.into_owned()); + } + + Ok(PostgresUrl::WebSocket(ws_url)) } else { PostgresUrl::new_native(url) } @@ -514,7 +526,14 @@ impl SqlFlavour for PostgresFlavour { .connection_string .parse() .map_err(ConnectorError::url_parse_error)?; - shadow_database_url.set_path(&format!("/{shadow_database_name}")); + + if shadow_database_url.scheme() == MigratePostgresUrl::WEBSOCKET_SCHEME { + shadow_database_url + .query_pairs_mut() + .append_pair(MigratePostgresUrl::DBNAME_PARAM, &shadow_database_name); + } else { + shadow_database_url.set_path(&format!("/{shadow_database_name}")); + } let shadow_db_params = ConnectorParams { connection_string: shadow_database_url.to_string(), preview_features: params.connector_params.preview_features,