diff --git a/src/admin.rs b/src/admin.rs index 6c83f9b2..f27b2a0d 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -1,4 +1,5 @@ use crate::pool::BanReason; +use crate::server::ServerParameters; use crate::stats::pool::PoolStats; use bytes::{Buf, BufMut, BytesMut}; use log::{error, info, trace}; @@ -17,16 +18,16 @@ use crate::pool::ClientServerMap; use crate::pool::{get_all_pools, get_pool}; use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState}; -pub fn generate_server_info_for_admin() -> BytesMut { - let mut server_info = BytesMut::new(); +pub fn generate_server_parameters_for_admin() -> ServerParameters { + let mut server_parameters = ServerParameters::new(); - server_info.put(server_parameter_message("application_name", "")); - server_info.put(server_parameter_message("client_encoding", "UTF8")); - server_info.put(server_parameter_message("server_encoding", "UTF8")); - server_info.put(server_parameter_message("server_version", VERSION)); - server_info.put(server_parameter_message("DateStyle", "ISO, MDY")); + server_parameters.set_param("application_name".to_string(), "".to_string(), true); + server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), true); + server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), true); + server_parameters.set_param("server_version".to_string(), VERSION.to_string(), true); + server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), true); - server_info + server_parameters } /// Handle admin client. diff --git a/src/client.rs b/src/client.rs index 4f5e6c96..6cdea987 100644 --- a/src/client.rs +++ b/src/client.rs @@ -12,7 +12,7 @@ use tokio::net::TcpStream; use tokio::sync::broadcast::Receiver; use tokio::sync::mpsc::Sender; -use crate::admin::{generate_server_info_for_admin, handle_admin}; +use crate::admin::{generate_server_parameters_for_admin, handle_admin}; use crate::auth_passthrough::refetch_auth_hash; use crate::config::{ get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode, @@ -22,7 +22,7 @@ use crate::messages::*; use crate::plugins::PluginOutput; use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; -use crate::server::Server; +use crate::server::{Server, ServerParameters}; use crate::stats::{ClientStats, ServerStats}; use crate::tls::Tls; @@ -96,8 +96,8 @@ pub struct Client { /// Postgres user for this client (This comes from the user in the connection string) username: String, - /// Application name for this client (defaults to pgcat) - application_name: String, + /// Server startup and session parameters that we're going to track + server_parameters: ServerParameters, /// Used to notify clients about an impending shutdown shutdown: Receiver<()>, @@ -502,7 +502,7 @@ where }; // Authenticate admin user. - let (transaction_mode, server_info) = if admin { + let (transaction_mode, mut server_parameters) = if admin { let config = get_config(); // Compare server and client hashes. @@ -521,7 +521,7 @@ where return Err(error); } - (false, generate_server_info_for_admin()) + (false, generate_server_parameters_for_admin()) } // Authenticate normal user. else { @@ -654,13 +654,16 @@ where } } - (transaction_mode, pool.server_info()) + (transaction_mode, pool.server_parameters()) }; + // Update the parameters to merge what the application sent and what's originally on the server + server_parameters.set_from_hashmap(¶meters, false); + debug!("Password authentication successful"); auth_ok(&mut write).await?; - write_all(&mut write, server_info).await?; + write_all(&mut write, (&server_parameters).into()).await?; backend_key_data(&mut write, process_id, secret_key).await?; ready_for_query(&mut write).await?; @@ -690,7 +693,7 @@ where last_server_stats: None, pool_name: pool_name.clone(), username: username.clone(), - application_name: application_name.to_string(), + server_parameters, shutdown, connected_to_server: false, prepared_statements: HashMap::new(), @@ -725,7 +728,7 @@ where last_server_stats: None, pool_name: String::from("undefined"), username: String::from("undefined"), - application_name: String::from("undefined"), + server_parameters: ServerParameters::new(), shutdown, connected_to_server: false, prepared_statements: HashMap::new(), @@ -774,8 +777,11 @@ where let mut prepared_statement = None; let mut will_prepare = false; - let client_identifier = - ClientIdentifier::new(&self.application_name, &self.username, &self.pool_name); + let client_identifier = ClientIdentifier::new( + &self.server_parameters.get_application_name(), + &self.username, + &self.pool_name, + ); // Our custom protocol loop. // We expect the client to either start a transaction with regular queries @@ -1115,10 +1121,7 @@ where server.address() ); - // TODO: investigate other parameters and set them too. - - // Set application_name. - server.set_name(&self.application_name).await?; + server.sync_parameters(&self.server_parameters).await?; let mut initial_message = Some(message); @@ -1296,7 +1299,9 @@ where if !server.in_transaction() { // Report transaction executed statistics. self.stats.transaction(); - server.stats().transaction(&self.application_name); + server + .stats() + .transaction(&self.server_parameters.get_application_name()); // Release server back to the pool if we are in transaction mode. // If we are in session mode, we keep the server until the client disconnects. @@ -1446,7 +1451,9 @@ where if !server.in_transaction() { self.stats.transaction(); - server.stats().transaction(&self.application_name); + server + .stats() + .transaction(&self.server_parameters.get_application_name()); // Release server back to the pool if we are in transaction mode. // If we are in session mode, we keep the server until the client disconnects. @@ -1495,7 +1502,9 @@ where if !server.in_transaction() { self.stats.transaction(); - server.stats().transaction(&self.application_name); + server + .stats() + .transaction(self.server_parameters.get_application_name()); // Release server back to the pool if we are in transaction mode. // If we are in session mode, we keep the server until the client disconnects. @@ -1547,7 +1556,9 @@ where Err(Error::ClientError(format!( "Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}", - self.pool_name, self.username, self.application_name + self.pool_name, + self.username, + self.server_parameters.get_application_name() ))) } } @@ -1704,7 +1715,7 @@ where client_stats.query(); server.stats().query( Instant::now().duration_since(query_start).as_millis() as u64, - &self.application_name, + &self.server_parameters.get_application_name(), ); Ok(()) @@ -1733,38 +1744,18 @@ where pool: &ConnectionPool, client_stats: &ClientStats, ) -> Result { - if pool.settings.user.statement_timeout > 0 { - match tokio::time::timeout( - tokio::time::Duration::from_millis(pool.settings.user.statement_timeout), - server.recv(), - ) - .await - { - Ok(result) => match result { - Ok(message) => Ok(message), - Err(err) => { - pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats)); - error_response_terminal( - &mut self.write, - &format!("error receiving data from server: {:?}", err), - ) - .await?; - Err(err) - } - }, - Err(_) => { - error!( - "Statement timeout while talking to {:?} with user {}", - address, pool.settings.user.username - ); - server.mark_bad(); - pool.ban(address, BanReason::StatementTimeout, Some(client_stats)); - error_response_terminal(&mut self.write, "pool statement timeout").await?; - Err(Error::StatementTimeout) - } - } - } else { - match server.recv().await { + let statement_timeout_duration = match pool.settings.user.statement_timeout { + 0 => tokio::time::Duration::MAX, + timeout => tokio::time::Duration::from_millis(timeout), + }; + + match tokio::time::timeout( + statement_timeout_duration, + server.recv(Some(&mut self.server_parameters)), + ) + .await + { + Ok(result) => match result { Ok(message) => Ok(message), Err(err) => { pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats)); @@ -1775,6 +1766,16 @@ where .await?; Err(err) } + }, + Err(_) => { + error!( + "Statement timeout while talking to {:?} with user {}", + address, pool.settings.user.username + ); + server.mark_bad(); + pool.ban(address, BanReason::StatementTimeout, Some(client_stats)); + error_response_terminal(&mut self.write, "pool statement timeout").await?; + Err(Error::StatementTimeout) } } } diff --git a/src/messages.rs b/src/messages.rs index 1f40f1df..07fe9317 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -144,6 +144,10 @@ where bytes.put_slice(user.as_bytes()); bytes.put_u8(0); + // Application name + bytes.put(&b"application_name\0"[..]); + bytes.put_slice(&b"pgcat\0"[..]); + // Database bytes.put(&b"database\0"[..]); bytes.put_slice(database.as_bytes()); @@ -731,6 +735,21 @@ impl BytesMutReader for Cursor<&BytesMut> { } } +impl BytesMutReader for BytesMut { + /// Should only be used when reading strings from the message protocol. + /// Can be used to read multiple strings from the same message which are separated by the null byte + fn read_string(&mut self) -> Result { + let null_index = self.iter().position(|&byte| byte == b'\0'); + + match null_index { + Some(index) => { + let string_bytes = self.split_to(index + 1); + Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string()) + } + None => return Err(Error::ParseBytesError("Could not read string".to_string())), + } + } +} /// Parse (F) message. /// See: #[derive(Clone, Debug)] diff --git a/src/mirrors.rs b/src/mirrors.rs index 0f2b02c0..7922e6f8 100644 --- a/src/mirrors.rs +++ b/src/mirrors.rs @@ -78,7 +78,7 @@ impl MirroredClient { } // Incoming data from server (we read to clear the socket buffer and discard the data) - recv_result = server.recv() => { + recv_result = server.recv(None) => { match recv_result { Ok(message) => trace!("Received from mirror: {} {:?}", String::from_utf8_lossy(&message[..]), address.clone()), Err(err) => { diff --git a/src/pool.rs b/src/pool.rs index dddb3ebe..b3627448 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,7 +1,6 @@ use arc_swap::ArcSwap; use async_trait::async_trait; use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy}; -use bytes::{BufMut, BytesMut}; use chrono::naive::NaiveDateTime; use log::{debug, error, info, warn}; use once_cell::sync::Lazy; @@ -25,7 +24,7 @@ use crate::errors::Error; use crate::auth_passthrough::AuthPassthrough; use crate::plugins::prewarmer; -use crate::server::Server; +use crate::server::{Server, ServerParameters}; use crate::sharding::ShardingFunction; use crate::stats::{AddressStats, ClientStats, ServerStats}; @@ -196,10 +195,10 @@ pub struct ConnectionPool { /// that should not be queried. banlist: BanList, - /// The server information (K messages) have to be passed to the + /// The server information has to be passed to the /// clients on startup. We pre-connect to all shards and replicas - /// on pool creation and save the K messages here. - server_info: Arc>, + /// on pool creation and save the startup parameters here. + original_server_parameters: Arc>, /// Pool configuration. pub settings: PoolSettings, @@ -445,7 +444,7 @@ impl ConnectionPool { addresses, banlist: Arc::new(RwLock::new(banlist)), config_hash: new_pool_hash_value, - server_info: Arc::new(RwLock::new(BytesMut::new())), + original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())), auth_hash: pool_auth_hash, settings: PoolSettings { pool_mode: match user.pool_mode { @@ -528,7 +527,7 @@ impl ConnectionPool { for server in 0..self.servers(shard) { let databases = self.databases.clone(); let validated = Arc::clone(&validated); - let pool_server_info = Arc::clone(&self.server_info); + let pool_server_parameters = Arc::clone(&self.original_server_parameters); let task = tokio::task::spawn(async move { let connection = match databases[shard][server].get().await { @@ -541,11 +540,10 @@ impl ConnectionPool { let proxy = connection; let server = &*proxy; - let server_info = server.server_info(); + let server_parameters: ServerParameters = server.server_parameters(); - let mut guard = pool_server_info.write(); - guard.clear(); - guard.put(server_info.clone()); + let mut guard = pool_server_parameters.write(); + *guard = server_parameters; validated.store(true, Ordering::Relaxed); }); @@ -557,7 +555,7 @@ impl ConnectionPool { // TODO: compare server information to make sure // all shards are running identical configurations. - if self.server_info.read().is_empty() { + if !self.validated() { error!("Could not validate connection pool"); return Err(Error::AllServersDown); } @@ -917,8 +915,8 @@ impl ConnectionPool { &self.addresses[shard][server] } - pub fn server_info(&self) -> BytesMut { - self.server_info.read().clone() + pub fn server_parameters(&self) -> ServerParameters { + self.original_server_parameters.read().clone() } fn busy_connection_count(&self, address: &Address) -> u32 { diff --git a/src/server.rs b/src/server.rs index 9d0beaac..c4d7a1af 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,12 +3,13 @@ use bytes::{Buf, BufMut, BytesMut}; use fallible_iterator::FallibleIterator; use log::{debug, error, info, trace, warn}; +use once_cell::sync::Lazy; use parking_lot::{Mutex, RwLock}; use postgres_protocol::message; -use std::collections::{BTreeSet, HashMap}; -use std::io::Read; +use std::collections::{BTreeSet, HashMap, HashSet}; +use std::mem; use std::net::IpAddr; -use std::sync::Arc; +use std::sync::{Arc, Once}; use std::time::SystemTime; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream}; use tokio::net::TcpStream; @@ -19,6 +20,7 @@ use crate::config::{get_config, get_prepared_statements_cache_size, Address, Use use crate::constants::*; use crate::dns_cache::{AddrSet, CACHED_RESOLVER}; use crate::errors::{Error, ServerIdentifier}; +use crate::messages::BytesMutReader; use crate::messages::*; use crate::mirrors::MirroringManager; use crate::pool::ClientServerMap; @@ -145,6 +147,124 @@ impl std::fmt::Display for CleanupState { } } +static TRACKED_PARAMETERS: Lazy> = Lazy::new(|| { + let mut set = HashSet::new(); + set.insert("client_encoding".to_string()); + set.insert("DateStyle".to_string()); + set.insert("TimeZone".to_string()); + set.insert("standard_conforming_strings".to_string()); + set.insert("application_name".to_string()); + set +}); + +#[derive(Debug, Clone)] +pub struct ServerParameters { + parameters: HashMap, +} + +impl Default for ServerParameters { + fn default() -> Self { + Self::new() + } +} + +impl ServerParameters { + pub fn new() -> Self { + let mut server_parameters = ServerParameters { + parameters: HashMap::new(), + }; + + server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), false); + server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), false); + server_parameters.set_param("TimeZone".to_string(), "Etc/UTC".to_string(), false); + server_parameters.set_param( + "standard_conforming_strings".to_string(), + "on".to_string(), + false, + ); + server_parameters.set_param("application_name".to_string(), "pgcat".to_string(), false); + + server_parameters + } + + /// returns true if a tracked parameter was set, false if it was a non-tracked parameter + /// if startup is false, then then only tracked parameters will be set + pub fn set_param(&mut self, mut key: String, value: String, startup: bool) { + // The startup parameter will send uncapitalized keys but parameter status packets will send capitalized keys + if key == "timezone" { + key = "TimeZone".to_string(); + } else if key == "datestyle" { + key = "DateStyle".to_string(); + }; + + if TRACKED_PARAMETERS.contains(&key) { + self.parameters.insert(key, value); + } else { + if startup { + self.parameters.insert(key, value); + } + } + } + + pub fn set_from_hashmap(&mut self, parameters: &HashMap, startup: bool) { + // iterate through each and call set_param + for (key, value) in parameters { + self.set_param(key.to_string(), value.to_string(), startup); + } + } + + // Gets the diff of the parameters + fn compare_params(&self, incoming_parameters: &ServerParameters) -> HashMap { + let mut diff = HashMap::new(); + + // iterate through tracked parameters + for key in TRACKED_PARAMETERS.iter() { + if let Some(incoming_value) = incoming_parameters.parameters.get(key) { + if let Some(value) = self.parameters.get(key) { + if value != incoming_value { + diff.insert(key.to_string(), incoming_value.to_string()); + } + } + } + } + + diff + } + + pub fn get_application_name(&self) -> &String { + // Can unwrap because we set it in the constructor + self.parameters.get("application_name").unwrap() + } + + fn add_parameter_message(key: &str, value: &str, buffer: &mut BytesMut) { + buffer.put_u8(b'S'); + + // 4 is len of i32, the plus for the null terminator + let len = 4 + key.len() + 1 + value.len() + 1; + + buffer.put_i32(len as i32); + + buffer.put_slice(key.as_bytes()); + buffer.put_u8(0); + buffer.put_slice(value.as_bytes()); + buffer.put_u8(0); + } +} + +impl From<&ServerParameters> for BytesMut { + fn from(server_parameters: &ServerParameters) -> Self { + let mut bytes = BytesMut::new(); + + for (key, value) in &server_parameters.parameters { + ServerParameters::add_parameter_message(key, value, &mut bytes); + } + + bytes + } +} + +// pub fn compare + /// Server state. pub struct Server { /// Server host, e.g. localhost, @@ -158,7 +278,7 @@ pub struct Server { buffer: BytesMut, /// Server information the server sent us over on startup. - server_info: BytesMut, + server_parameters: ServerParameters, /// Backend id and secret key used for query cancellation. process_id: i32, @@ -347,7 +467,6 @@ impl Server { startup(&mut stream, username, database).await?; - let mut server_info = BytesMut::new(); let mut process_id: i32 = 0; let mut secret_key: i32 = 0; let server_identifier = ServerIdentifier::new(username, &database); @@ -359,6 +478,8 @@ impl Server { None => None, }; + let mut server_parameters = ServerParameters::new(); + loop { let code = match stream.read_u8().await { Ok(code) => code as char, @@ -616,9 +737,10 @@ impl Server { // ParameterStatus 'S' => { - let mut param = vec![0u8; len as usize - 4]; + let mut bytes = BytesMut::with_capacity(len as usize - 4); + bytes.resize(len as usize - mem::size_of::(), b'0'); - match stream.read_exact(&mut param).await { + match stream.read_exact(&mut bytes[..]).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -628,12 +750,13 @@ impl Server { } }; + let key = bytes.read_string().unwrap(); + let value = bytes.read_string().unwrap(); + // Save the parameter so we can pass it to the client later. // These can be server_encoding, client_encoding, server timezone, Postgres version, // and many more interesting things we should know about the Postgres server we are talking to. - server_info.put_u8(b'S'); - server_info.put_i32(len); - server_info.put_slice(¶m[..]); + server_parameters.set_param(key, value, true); } // BackendKeyData @@ -675,11 +798,11 @@ impl Server { } }; - let mut server = Server { + let server = Server { address: address.clone(), stream: BufStream::new(stream), buffer: BytesMut::with_capacity(8196), - server_info, + server_parameters, process_id, secret_key, in_transaction: false, @@ -691,7 +814,7 @@ impl Server { addr_set, connected_at: chrono::offset::Utc::now().naive_utc(), stats, - application_name: String::new(), + application_name: "pgcat".to_string(), last_activity: SystemTime::now(), mirror_manager: match address.mirrors.len() { 0 => None, @@ -705,8 +828,6 @@ impl Server { prepared_statements: BTreeSet::new(), }; - server.set_name("pgcat").await?; - return Ok(server); } @@ -776,7 +897,10 @@ impl Server { /// Receive data from the server in response to a client request. /// This method must be called multiple times while `self.is_data_available()` is true /// in order to receive all data the server has to offer. - pub async fn recv(&mut self) -> Result { + pub async fn recv( + &mut self, + mut client_server_parameters: Option<&mut ServerParameters>, + ) -> Result { loop { let mut message = match read_message(&mut self.stream).await { Ok(message) => message, @@ -848,14 +972,13 @@ impl Server { self.in_copy_mode = false; } - let mut command_tag = String::new(); - match message.reader().read_to_string(&mut command_tag) { - Ok(_) => { + match message.read_string() { + Ok(command) => { // Non-exhaustive list of commands that are likely to change session variables/resources // which can leak between clients. This is a best effort to block bad clients // from poisoning a transaction-mode pool by setting inappropriate session variables - match command_tag.as_str() { - "SET\0" => { + match command.as_str() { + "SET" => { // We don't detect set statements in transactions // No great way to differentiate between set and set local // As a result, we will miss cases when set statements are used in transactions @@ -865,7 +988,8 @@ impl Server { self.cleanup_state.needs_cleanup_set = true; } } - "PREPARE\0" => { + + "PREPARE" => { debug!("Server connection marked for clean up"); self.cleanup_state.needs_cleanup_prepare = true; } @@ -879,6 +1003,17 @@ impl Server { } } + 'S' => { + let key = message.read_string().unwrap(); + let value = message.read_string().unwrap(); + + if let Some(client_server_parameters) = client_server_parameters.as_mut() { + client_server_parameters.set_param(key.clone(), value.clone(), false); + } + + self.server_parameters.set_param(key, value, false); + } + // DataRow 'D' => { // More data is available after this message, this is not the end of the reply. @@ -1089,9 +1224,28 @@ impl Server { } /// Get server startup information to forward it to the client. - /// Not used at the moment. - pub fn server_info(&self) -> BytesMut { - self.server_info.clone() + pub fn server_parameters(&self) -> ServerParameters { + self.server_parameters.clone() + } + + pub async fn sync_parameters(&mut self, parameters: &ServerParameters) -> Result<(), Error> { + let parameter_diff = self.server_parameters.compare_params(parameters); + + if parameter_diff.is_empty() { + return Ok(()); + } + + let mut query = String::from(""); + + for (key, value) in parameter_diff { + query.push_str(&format!("SET {} TO '{}';", key, value)); + } + + let res = self.query(&query).await; + + self.cleanup_state.reset(); + + res } /// Indicate that this server connection cannot be re-used and must be discarded. @@ -1125,7 +1279,7 @@ impl Server { self.send(&query).await?; loop { - let _ = self.recv().await?; + let _ = self.recv(None).await?; if !self.data_available { break; @@ -1166,24 +1320,6 @@ impl Server { Ok(()) } - /// A shorthand for `SET application_name = $1`. - pub async fn set_name(&mut self, name: &str) -> Result<(), Error> { - if self.application_name != name { - self.application_name = name.to_string(); - // We don't want `SET application_name` to mark the server connection - // as needing cleanup - let needs_cleanup_before = self.cleanup_state; - - let result = Ok(self - .query(&format!("SET application_name = '{}'", name)) - .await?); - self.cleanup_state = needs_cleanup_before; - result - } else { - Ok(()) - } - } - /// get Server stats pub fn stats(&self) -> Arc { self.stats.clone() @@ -1241,7 +1377,7 @@ impl Server { .await?; debug!("Connected!, sending query."); server.send(&simple_query(query)).await?; - let mut message = server.recv().await?; + let mut message = server.recv(None).await?; Ok(parse_query_message(&mut message).await?) } diff --git a/tests/ruby/helpers/pgcat_process.rb b/tests/ruby/helpers/pgcat_process.rb index e1dbea8b..dd3fd052 100644 --- a/tests/ruby/helpers/pgcat_process.rb +++ b/tests/ruby/helpers/pgcat_process.rb @@ -112,10 +112,16 @@ def admin_connection_string "postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat" end - def connection_string(pool_name, username, password = nil) + def connection_string(pool_name, username, password = nil, parameters: {}) cfg = current_config user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username } - "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}" + connection_string = "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}" + + # Add the additional parameters to the connection string + parameter_string = parameters.map { |key, value| "#{key}=#{value}" }.join("&") + connection_string += "?#{parameter_string}" unless parameter_string.empty? + + connection_string end def example_connection_string diff --git a/tests/ruby/misc_spec.rb b/tests/ruby/misc_spec.rb index fe216e5b..628680bd 100644 --- a/tests/ruby/misc_spec.rb +++ b/tests/ruby/misc_spec.rb @@ -294,6 +294,30 @@ expect(processes.primary.count_query("DISCARD ALL")).to eq(10) end + + it "Respects tracked parameters on startup" do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user", parameters: { "application_name" => "my_pgcat_test" })) + + expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test") + conn.close + end + + it "Respect tracked parameter on set statemet" do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + + conn.async_exec("SET application_name to 'my_pgcat_test'") + expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test") + end + + + it "Ignore untracked parameter on set statemet" do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + orignal_statement_timeout = conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"] + + conn.async_exec("SET statement_timeout to 1500") + expect(conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]).to eq(orignal_statement_timeout) + end + end context "transaction mode with transactions" do