diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index 645ff94bf..4202eb5ed 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -19,8 +19,8 @@ PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard2 -i # Install Toxiproxy to simulate a downed/slow database -wget -O toxiproxy-2.1.4.deb https://github.com/Shopify/toxiproxy/releases/download/v2.1.4/toxiproxy_2.1.4_amd64.deb -sudo dpkg -i toxiproxy-2.1.4.deb +wget -O toxiproxy-2.4.0.deb https://github.com/Shopify/toxiproxy/releases/download/v2.4.0/toxiproxy_2.4.0_linux_$(dpkg --print-architecture).deb +sudo dpkg -i toxiproxy-2.4.0.deb # Start Toxiproxy toxiproxy-server & @@ -129,11 +129,14 @@ toxiproxy-cli toxic remove --toxicName latency_downstream postgres_replica start_pgcat "info" # Test session mode (and config reload) -sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' .circleci/pgcat.toml +sed -i '0,/simple_db/s/pool_mode = "transaction"/pool_mode = "session"/' .circleci/pgcat.toml # Reload config test kill -SIGHUP $(pgrep pgcat) +# Revert settings after reload. Makes test runs idempotent +sed -i '0,/simple_db/s/pool_mode = "session"/pool_mode = "transaction"/' .circleci/pgcat.toml + sleep 1 # Prepared statements that will only work in session mode diff --git a/src/admin.rs b/src/admin.rs index 6a79e49ee..4576d1681 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -265,11 +265,11 @@ where for (_, pool) in get_all_pools() { let pool_config = pool.settings.clone(); for shard in 0..pool.shards() { - let database_name = &pool_config.shards[&shard.to_string()].database; + let database_name = &pool.address(shard, 0).database; for server in 0..pool.servers(shard) { let address = pool.address(shard, server); let pool_state = pool.pool_state(shard, server); - let banned = pool.is_banned(address, shard, Some(address.role)); + let banned = pool.is_banned(address, Some(address.role)); res.put(data_row(&vec![ address.name(), // name diff --git a/src/client.rs b/src/client.rs index c4866a089..419448fb8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,13 +5,14 @@ use std::collections::HashMap; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; use tokio::net::TcpStream; use tokio::sync::broadcast::Receiver; +use tokio::sync::mpsc::Sender; use crate::admin::{generate_server_info_for_admin, handle_admin}; use crate::config::{get_config, Address}; use crate::constants::*; use crate::errors::Error; use crate::messages::*; -use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; +use crate::pool::{get_pool, ClientServerMap, ConnectionPool, PoolMode}; use crate::query_router::{Command, QueryRouter}; use crate::server::Server; use crate::stats::{get_reporter, Reporter}; @@ -58,7 +59,6 @@ pub struct Client { client_server_map: ClientServerMap, /// Client parameters, e.g. user, client_encoding, etc. - #[allow(dead_code)] parameters: HashMap, /// Statistics @@ -77,20 +77,22 @@ pub struct Client { connected_to_server: bool, /// Name of the server pool for this client (This comes from the database name in the connection string) - target_pool_name: String, + pool_name: String, /// Postgres user for this client (This comes from the user in the connection string) - target_user_name: String, + username: String, /// Used to notify clients about an impending shutdown - shutdown_event_receiver: Receiver<()>, + shutdown: Receiver<()>, } /// Client entrypoint. 
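The `client_entrypoint` change below threads two new arguments through client startup: an `admin_only` flag and a `drain` handle (an `mpsc::Sender<i32>`, per the channel created in main.rs). Every non-admin client sends +1 before `client.handle()` runs and -1 after it returns, which is how the main loop knows when the last tracked client has drained during shutdown. A minimal sketch of that accounting pattern, using a hypothetical `with_drain_accounting` wrapper rather than pgcat's actual inline code:

```rust
use tokio::sync::mpsc::Sender;

// Sketch only: pgcat performs these sends inline around `client.handle()`;
// this wrapper just isolates the pattern.
async fn with_drain_accounting<F, T>(drain: Sender<i32>, is_admin: bool, fut: F) -> T
where
    F: std::future::Future<Output = T>,
{
    // Admin connections are excluded so they keep working while draining.
    if !is_admin {
        let _ = drain.send(1).await; // one more tracked client
    }
    let result = fut.await;
    if !is_admin {
        let _ = drain.send(-1).await; // client gone; main loop can exit at zero
    }
    result
}
```
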
pub async fn client_entrypoint( mut stream: TcpStream, client_server_map: ClientServerMap, - shutdown_event_receiver: Receiver<()>, + shutdown: Receiver<()>, + drain: Sender, + admin_only: bool, ) -> Result<(), Error> { // Figure out if the client wants TLS or not. let addr = stream.peer_addr().unwrap(); @@ -109,11 +111,21 @@ pub async fn client_entrypoint( write_all(&mut stream, yes).await?; // Negotiate TLS. - match startup_tls(stream, client_server_map, shutdown_event_receiver).await { + match startup_tls(stream, client_server_map, shutdown, admin_only).await { Ok(mut client) => { info!("Client {:?} connected (TLS)", addr); - client.handle().await + if !client.is_admin() { + let _ = drain.send(1).await; + } + + let result = client.handle().await; + + if !client.is_admin() { + let _ = drain.send(-1).await; + } + + result } Err(err) => Err(err), } @@ -139,14 +151,25 @@ pub async fn client_entrypoint( addr, bytes, client_server_map, - shutdown_event_receiver, + shutdown, + admin_only, ) .await { Ok(mut client) => { info!("Client {:?} connected (plain)", addr); - client.handle().await + if !client.is_admin() { + let _ = drain.send(1).await; + } + + let result = client.handle().await; + + if !client.is_admin() { + let _ = drain.send(-1).await; + } + + result } Err(err) => Err(err), } @@ -169,14 +192,25 @@ pub async fn client_entrypoint( addr, bytes, client_server_map, - shutdown_event_receiver, + shutdown, + admin_only, ) .await { Ok(mut client) => { info!("Client {:?} connected (plain)", addr); - client.handle().await + if client.is_admin() { + let _ = drain.send(1).await; + } + + let result = client.handle().await; + + if !client.is_admin() { + let _ = drain.send(-1).await; + } + + result } Err(err) => Err(err), } @@ -187,20 +221,21 @@ pub async fn client_entrypoint( let (read, write) = split(stream); // Continue with cancel query request. - match Client::cancel( - read, - write, - addr, - bytes, - client_server_map, - shutdown_event_receiver, - ) - .await - { + match Client::cancel(read, write, addr, bytes, client_server_map, shutdown).await { Ok(mut client) => { info!("Client {:?} issued a cancel query request", addr); - client.handle().await + if client.is_admin() { + let _ = drain.send(1).await; + } + + let result = client.handle().await; + + if !client.is_admin() { + let _ = drain.send(-1).await; + } + + result } Err(err) => Err(err), @@ -253,7 +288,8 @@ where pub async fn startup_tls( stream: TcpStream, client_server_map: ClientServerMap, - shutdown_event_receiver: Receiver<()>, + shutdown: Receiver<()>, + admin_only: bool, ) -> Result>, WriteHalf>>, Error> { // Negotiate TLS. let tls = Tls::new()?; @@ -283,7 +319,8 @@ pub async fn startup_tls( addr, bytes, client_server_map, - shutdown_event_receiver, + shutdown, + admin_only, ) .await } @@ -298,6 +335,10 @@ where S: tokio::io::AsyncRead + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin, { + pub fn is_admin(&self) -> bool { + self.admin + } + /// Handle Postgres client startup after TLS negotiation is complete /// or over plain text. pub async fn startup( @@ -306,29 +347,44 @@ where addr: std::net::SocketAddr, bytes: BytesMut, // The rest of the startup message. 
client_server_map: ClientServerMap, - shutdown_event_receiver: Receiver<()>, + shutdown: Receiver<()>, + admin_only: bool, ) -> Result, Error> { let config = get_config(); let stats = get_reporter(); - - trace!("Got StartupMessage"); let parameters = parse_startup(bytes.clone())?; - let target_pool_name = match parameters.get("database") { + + // These two parameters are mandatory by the protocol. + let pool_name = match parameters.get("database") { Some(db) => db, None => return Err(Error::ClientError), }; - let target_user_name = match parameters.get("user") { + let username = match parameters.get("user") { Some(user) => user, None => return Err(Error::ClientError), }; let admin = ["pgcat", "pgbouncer"] .iter() - .filter(|db| *db == &target_pool_name) + .filter(|db| *db == &pool_name) .count() == 1; + // Kick any client that's not admin while we're in admin-only mode. + if !admin && admin_only { + debug!( + "Rejecting non-admin connection to {} when in admin only mode", + pool_name + ); + error_response_terminal( + &mut write, + &format!("terminating connection due to administrator command"), + ) + .await?; + return Err(Error::ShuttingDown); + } + // Generate random backend ID and secret key let process_id: i32 = rand::random(); let secret_key: i32 = rand::random(); @@ -360,46 +416,55 @@ where Err(_) => return Err(Error::SocketError), }; + // Authenticate admin user. let (transaction_mode, server_info) = if admin { - let correct_user = config.general.admin_username.as_str(); - let correct_password = config.general.admin_password.as_str(); - // Compare server and client hashes. - let password_hash = md5_hash_password(correct_user, correct_password, &salt); + let password_hash = md5_hash_password( + &config.general.admin_username, + &config.general.admin_password, + &salt, + ); + if password_hash != password_response { debug!("Password authentication failed"); - wrong_password(&mut write, target_user_name).await?; + wrong_password(&mut write, username).await?; + return Err(Error::ClientError); } (false, generate_server_info_for_admin()) - } else { - let target_pool = match get_pool(target_pool_name.clone(), target_user_name.clone()) { + } + // Authenticate normal user. + else { + let pool = match get_pool(pool_name.clone(), username.clone()) { Some(pool) => pool, None => { error_response( &mut write, &format!( "No pool configured for database: {:?}, user: {:?}", - target_pool_name, target_user_name + pool_name, username ), ) .await?; + return Err(Error::ClientError); } }; - let transaction_mode = target_pool.settings.pool_mode == "transaction"; - let server_info = target_pool.server_info(); + // Compare server and client hashes. - let correct_password = target_pool.settings.user.password.as_str(); - let password_hash = md5_hash_password(&target_user_name, correct_password, &salt); + let password_hash = md5_hash_password(&username, &pool.settings.user.password, &salt); if password_hash != password_response { debug!("Password authentication failed"); - wrong_password(&mut write, &target_user_name).await?; + wrong_password(&mut write, username).await?; + return Err(Error::ClientError); } - (transaction_mode, server_info) + + let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; + + (transaction_mode, pool.server_info()) }; debug!("Password authentication successful"); @@ -411,27 +476,24 @@ where trace!("Startup OK"); - // Split the read and write streams - // so we can control buffering. 
- return Ok(Client { read: BufReader::new(read), write: write, addr, buffer: BytesMut::with_capacity(8196), cancel_mode: false, - transaction_mode: transaction_mode, - process_id: process_id, - secret_key: secret_key, - client_server_map: client_server_map, + transaction_mode, + process_id, + secret_key, + client_server_map, parameters: parameters.clone(), stats: stats, admin: admin, last_address_id: None, last_server_id: None, - target_pool_name: target_pool_name.clone(), - target_user_name: target_user_name.clone(), - shutdown_event_receiver: shutdown_event_receiver, + pool_name: pool_name.clone(), + username: username.clone(), + shutdown, connected_to_server: false, }); } @@ -443,7 +505,7 @@ where addr: std::net::SocketAddr, mut bytes: BytesMut, // The rest of the startup message. client_server_map: ClientServerMap, - shutdown_event_receiver: Receiver<()>, + shutdown: Receiver<()>, ) -> Result, Error> { let process_id = bytes.get_i32(); let secret_key = bytes.get_i32(); @@ -454,17 +516,17 @@ where buffer: BytesMut::with_capacity(8196), cancel_mode: true, transaction_mode: false, - process_id: process_id, - secret_key: secret_key, - client_server_map: client_server_map, + process_id, + secret_key, + client_server_map, parameters: HashMap::new(), stats: get_reporter(), admin: false, last_address_id: None, last_server_id: None, - target_pool_name: String::from("undefined"), - target_user_name: String::from("undefined"), - shutdown_event_receiver: shutdown_event_receiver, + pool_name: String::from("undefined"), + username: String::from("undefined"), + shutdown, connected_to_server: false, }); } @@ -486,7 +548,7 @@ where process_id.clone(), secret_key.clone(), address.clone(), - port.clone(), + *port, ), // The client doesn't know / got the wrong server, @@ -498,7 +560,7 @@ where // Opens a new separate connection to the server, sends the backend_id // and secret_key and then closes it for security reasons. No other interactions // take place. - return Ok(Server::cancel(&address, &port, process_id, secret_key).await?); + return Ok(Server::cancel(&address, port, process_id, secret_key).await?); } // The query router determines where the query is going to go, @@ -521,9 +583,19 @@ where // SET SHARDING KEY TO 'bigint'; let mut message = tokio::select! { - _ = self.shutdown_event_receiver.recv() => { - error_response_terminal(&mut self.write, &format!("terminating connection due to administrator command")).await?; - return Ok(()) + _ = self.shutdown.recv() => { + if !self.admin { + error_response_terminal( + &mut self.write, + &format!("terminating connection due to administrator command") + ).await?; + return Ok(()) + } + + // Admin clients ignore shutdown. + else { + read_message(&mut self.read).await? + } }, message_result = read_message(&mut self.read) => message_result? }; @@ -544,15 +616,14 @@ where // Get a pool instance referenced by the most up-to-date // pointer. This ensures we always read the latest config // when starting a query. - let pool = match get_pool(self.target_pool_name.clone(), self.target_user_name.clone()) - { + let pool = match get_pool(self.pool_name.clone(), self.username.clone()) { Some(pool) => pool, None => { error_response( &mut self.write, &format!( "No pool configured for database: {:?}, user: {:?}", - self.target_pool_name, self.target_user_name + self.pool_name, self.username ), ) .await?; @@ -643,9 +714,22 @@ where conn } Err(err) => { + // Clients do not expect to get SystemError followed by ReadyForQuery in the middle + // of extended protocol submission. 
So we will hold off on sending the actual error + // message to the client until we get 'S' message + match message[0] as char { + 'P' | 'B' | 'E' | 'D' => (), + _ => { + error_response( + &mut self.write, + "could not get connection from the pool", + ) + .await?; + } + }; + error!("Could not get connection from pool: {:?}", err); - error_response(&mut self.write, "could not get connection from the pool") - .await?; + continue; } }; @@ -728,15 +812,8 @@ where 'Q' => { debug!("Sending query to server"); - self.send_and_receive_loop( - code, - original, - server, - &address, - query_router.shard(), - &pool, - ) - .await?; + self.send_and_receive_loop(code, original, server, &address, &pool) + .await?; if !server.in_transaction() { // Report transaction executed statistics. @@ -803,7 +880,6 @@ where self.buffer.clone(), server, &address, - query_router.shard(), &pool, ) .await?; @@ -825,32 +901,18 @@ where 'd' => { // Forward the data to the server, // don't buffer it since it can be rather large. - self.send_server_message( - server, - original, - &address, - query_router.shard(), - &pool, - ) - .await?; + self.send_server_message(server, original, &address, &pool) + .await?; } // CopyDone or CopyFail // Copy is done, successfully or not. 'c' | 'f' => { - self.send_server_message( - server, - original, - &address, - query_router.shard(), - &pool, - ) - .await?; - - let response = self - .receive_server_message(server, &address, query_router.shard(), &pool) + self.send_server_message(server, original, &address, &pool) .await?; + let response = self.receive_server_message(server, &address, &pool).await?; + match write_all_half(&mut self.write, response).await { Ok(_) => (), Err(err) => { @@ -899,20 +961,17 @@ where message: BytesMut, server: &mut Server, address: &Address, - shard: usize, pool: &ConnectionPool, ) -> Result<(), Error> { debug!("Sending {} to server", code); - self.send_server_message(server, message, &address, shard, &pool) + self.send_server_message(server, message, &address, &pool) .await?; // Read all data the server has to offer, which can be multiple messages // buffered in 8196 bytes chunks. 
loop { - let response = self - .receive_server_message(server, &address, shard, &pool) - .await?; + let response = self.receive_server_message(server, &address, &pool).await?; match write_all_half(&mut self.write, response).await { Ok(_) => (), @@ -938,13 +997,12 @@ where server: &mut Server, message: BytesMut, address: &Address, - shard: usize, pool: &ConnectionPool, ) -> Result<(), Error> { match server.send(message).await { Ok(_) => Ok(()), Err(err) => { - pool.ban(address, shard, self.process_id); + pool.ban(address, self.process_id); Err(err) } } @@ -954,7 +1012,6 @@ where &mut self, server: &mut Server, address: &Address, - shard: usize, pool: &ConnectionPool, ) -> Result { if pool.settings.user.statement_timeout > 0 { @@ -967,7 +1024,7 @@ where Ok(result) => match result { Ok(message) => Ok(message), Err(err) => { - pool.ban(address, shard, self.process_id); + pool.ban(address, self.process_id); error_response_terminal( &mut self.write, &format!("error receiving data from server: {:?}", err), @@ -982,7 +1039,7 @@ where address, pool.settings.user.username ); server.mark_bad(); - pool.ban(address, shard, self.process_id); + pool.ban(address, self.process_id); error_response_terminal(&mut self.write, "pool statement timeout").await?; Err(Error::StatementTimeout) } @@ -991,7 +1048,7 @@ where match server.recv().await { Ok(message) => Ok(message), Err(err) => { - pool.ban(address, shard, self.process_id); + pool.ban(address, self.process_id); error_response_terminal( &mut self.write, &format!("error receiving data from server: {:?}", err), diff --git a/src/config.rs b/src/config.rs index b75169373..5c1226117 100644 --- a/src/config.rs +++ b/src/config.rs @@ -57,15 +57,34 @@ impl PartialEq for Option { /// Address identifying a PostgreSQL server uniquely. #[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)] pub struct Address { + /// Unique ID per addressable Postgres server. pub id: usize, + + /// Server host. pub host: String, - pub port: String, + + /// Server port. + pub port: u16, + + /// Shard number of this Postgres server. pub shard: usize, + + /// The name of the Postgres database. pub database: String, + + /// Server role: replica, primary. pub role: Role, + + /// If it's a replica, number it for reference and failover. pub replica_number: usize, + + /// Position of the server in the pool for failover. pub address_index: usize, + + /// The name of the user configured to use this pool. pub username: String, + + /// The name of this pool (i.e. database name visible to the client). pub pool_name: String, } @@ -74,7 +93,7 @@ impl Default for Address { Address { id: 0, host: String::from("127.0.0.1"), - port: String::from("5432"), + port: 5432, shard: 0, address_index: 0, replica_number: 0, @@ -341,9 +360,9 @@ impl Config { for (pool_name, pool_config) in &self.pools { // TODO: Make this output prettier (maybe a table?) - info!("--- Settings for pool {} ---", pool_name); info!( - "Pool size from all users: {}", + "[pool: {}] Maximum user connections: {}", + pool_name, pool_config .users .iter() @@ -351,20 +370,39 @@ impl Config { .sum::() .to_string() ); - info!("Pool mode: {}", pool_config.pool_mode); - info!("Sharding function: {}", pool_config.sharding_function); - info!("Primary reads: {}", pool_config.primary_reads_enabled); - info!("Query router: {}", pool_config.query_parser_enabled); - - // TODO: Make this prettier. 
- info!("Number of shards: {}", pool_config.shards.len()); - info!("Number of users: {}", pool_config.users.len()); + info!("[pool: {}] Pool mode: {}", pool_name, pool_config.pool_mode); + info!( + "[pool: {}] Sharding function: {}", + pool_name, pool_config.sharding_function + ); + info!( + "[pool: {}] Primary reads: {}", + pool_name, pool_config.primary_reads_enabled + ); + info!( + "[pool: {}] Query router: {}", + pool_name, pool_config.query_parser_enabled + ); + info!( + "[pool: {}] Number of shards: {}", + pool_name, + pool_config.shards.len() + ); + info!( + "[pool: {}] Number of users: {}", + pool_name, + pool_config.users.len() + ); for user in &pool_config.users { info!( - "{} pool size: {}, statement timeout: {}", - user.1.username, user.1.pool_size, user.1.statement_timeout + "[pool: {}][user: {}] Pool size: {}", + pool_name, user.1.username, user.1.pool_size, ); + info!( + "[pool: {}][user: {}] Statement timeout: {}", + pool_name, user.1.username, user.1.statement_timeout + ) } } } @@ -462,6 +500,18 @@ pub async fn parse(path: &str) -> Result<(), Error> { } }; + match pool.pool_mode.as_ref() { + "transaction" => (), + "session" => (), + other => { + error!( + "pool_mode can be 'session' or 'transaction', got: '{}'", + other + ); + return Err(Error::BadConfig); + } + }; + for shard in &pool.shards { // We use addresses as unique identifiers, // let's make sure they are unique in the config as well. diff --git a/src/errors.rs b/src/errors.rs index 06371fd19..50301f366 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -12,4 +12,5 @@ pub enum Error { ClientError, TlsError, StatementTimeout, + ShuttingDown, } diff --git a/src/main.rs b/src/main.rs index 0b2e1d59f..9aad61a13 100644 --- a/src/main.rs +++ b/src/main.rs @@ -66,6 +66,7 @@ mod stats; mod tls; use crate::config::{get_config, reload_config, VERSION}; +use crate::errors::Error; use crate::pool::{ClientServerMap, ConnectionPool}; use crate::prometheus::start_metric_server; use crate::stats::{Collector, Reporter, REPORTER}; @@ -133,8 +134,8 @@ async fn main() { let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new())); // Statistics reporting. - let (tx, rx) = mpsc::channel(100_000); - REPORTER.store(Arc::new(Reporter::new(tx.clone()))); + let (stats_tx, stats_rx) = mpsc::channel(100_000); + REPORTER.store(Arc::new(Reporter::new(stats_tx.clone()))); // Connection pool that allows to query all shards and replicas. match ConnectionPool::from_config(client_server_map.clone()).await { @@ -145,159 +146,148 @@ async fn main() { } }; - // Statistics collector task. 
- let collector_tx = tx.clone(); - - // Save these for reloading - let reload_client_server_map = client_server_map.clone(); - let autoreload_client_server_map = client_server_map.clone(); - tokio::task::spawn(async move { - let mut stats_collector = Collector::new(rx, collector_tx); + let mut stats_collector = Collector::new(stats_rx, stats_tx.clone()); stats_collector.collect().await; }); + info!("Config autoreloader: {}", config.general.autoreload); + + let mut term_signal = unix_signal(SignalKind::terminate()).unwrap(); + let mut interrupt_signal = unix_signal(SignalKind::interrupt()).unwrap(); + let mut sighup_signal = unix_signal(SignalKind::hangup()).unwrap(); + let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000)); + let (shutdown_tx, _) = broadcast::channel::<()>(1); + let (drain_tx, mut drain_rx) = mpsc::channel::(2048); + let (exit_tx, mut exit_rx) = mpsc::channel::<()>(1); + info!("Waiting for clients"); - let (shutdown_event_tx, mut shutdown_event_rx) = broadcast::channel::<()>(1); + let mut admin_only = false; + let mut total_clients = 0; - let shutdown_event_tx_clone = shutdown_event_tx.clone(); + loop { + tokio::select! { + // Reload config: + // kill -SIGHUP $(pgrep pgcat) + _ = sighup_signal.recv() => { + info!("Reloading config"); - // Client connection loop. - tokio::task::spawn(async move { - // Creates event subscriber for shutdown event, this is dropped when shutdown event is broadcast - let mut listener_shutdown_event_rx = shutdown_event_tx_clone.subscribe(); - loop { - let client_server_map = client_server_map.clone(); - - // Listen for shutdown event and client connection at the same time - let (socket, addr) = tokio::select! { - _ = listener_shutdown_event_rx.recv() => { - // Exits client connection loop which drops listener, listener_shutdown_event_rx and shutdown_event_tx_clone - break; - } + match reload_config(client_server_map.clone()).await { + Ok(_) => (), + Err(_) => (), + }; - listener_response = listener.accept() => { - match listener_response { - Ok((socket, addr)) => (socket, addr), - Err(err) => { - error!("{:?}", err); - continue; + get_config().show(); + }, + + _ = autoreload_interval.tick() => { + if config.general.autoreload { + info!("Automatically reloading config"); + + match reload_config(client_server_map.clone()).await { + Ok(changed) => { + if changed { + get_config().show() + } } - } + Err(_) => (), + }; } - }; - - // Used to signal shutdown - let client_shutdown_handler_rx = shutdown_event_tx_clone.subscribe(); - - // Used to signal that the task has completed - let dummy_tx = shutdown_event_tx_clone.clone(); - - // Handle client. 
- tokio::task::spawn(async move { - let start = chrono::offset::Utc::now().naive_utc(); - - match client::client_entrypoint( - socket, - client_server_map, - client_shutdown_handler_rx, - ) - .await - { - Ok(_) => { - let duration = chrono::offset::Utc::now().naive_utc() - start; - - info!( - "Client {:?} disconnected, session duration: {}", - addr, - format_duration(&duration) - ); - } + }, - Err(err) => { - debug!("Client disconnected with error {:?}", err); - } - }; - // Drop this transmitter so receiver knows that the task is completed - drop(dummy_tx); - }); - } - }); + // Initiate graceful shutdown sequence on sig int + _ = interrupt_signal.recv() => { + info!("Got SIGINT, waiting for client connection drain now"); + admin_only = true; - // Reload config: - // kill -SIGHUP $(pgrep pgcat) - tokio::task::spawn(async move { - let mut stream = unix_signal(SignalKind::hangup()).unwrap(); + // Broadcast that client tasks need to finish + let _ = shutdown_tx.send(()); + let exit_tx = exit_tx.clone(); + let _ = drain_tx.send(0).await; - loop { - stream.recv().await; + tokio::task::spawn(async move { + let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(config.general.shutdown_timeout)); - info!("Reloading config"); + // First tick fires immediately. + interval.tick().await; - match reload_config(reload_client_server_map.clone()).await { - Ok(_) => (), - Err(_) => continue, - }; + // Second one in the interval time. + interval.tick().await; - get_config().show(); - } - }); + // We're done waiting. + error!("Timed out waiting for clients"); - if config.general.autoreload { - let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000)); + let _ = exit_tx.send(()).await; + }); + }, - tokio::task::spawn(async move { - info!("Config autoreloader started"); - - loop { - interval.tick().await; - match reload_config(autoreload_client_server_map.clone()).await { - Ok(changed) => { - if changed { - get_config().show() - } + _ = term_signal.recv() => break, + + new_client = listener.accept() => { + let (socket, addr) = match new_client { + Ok((socket, addr)) => (socket, addr), + Err(err) => { + error!("{:?}", err); + continue; } - Err(_) => (), }; + + let shutdown_rx = shutdown_tx.subscribe(); + let drain_tx = drain_tx.clone(); + let client_server_map = client_server_map.clone(); + + tokio::task::spawn(async move { + let start = chrono::offset::Utc::now().naive_utc(); + + match client::client_entrypoint( + socket, + client_server_map, + shutdown_rx, + drain_tx, + admin_only, + ) + .await + { + Ok(()) => { + + let duration = chrono::offset::Utc::now().naive_utc() - start; + + info!( + "Client {:?} disconnected, session duration: {}", + addr, + format_duration(&duration) + ); + } + + Err(err) => { + match err { + // Don't count the clients we rejected. + Error::ShuttingDown => (), + _ => { + // drain_tx.send(-1).await.unwrap(); + } + } + + debug!("Client disconnected with error {:?}", err); + } + }; + }); } - }); - } - let mut term_signal = unix_signal(SignalKind::terminate()).unwrap(); - let mut interrupt_signal = unix_signal(SignalKind::interrupt()).unwrap(); + _ = exit_rx.recv() => { + break; + } - tokio::select! 
{ - // Initiate graceful shutdown sequence on sig int - _ = interrupt_signal.recv() => { - info!("Got SIGINT, waiting for client connection drain now"); - - // Broadcast that client tasks need to finish - shutdown_event_tx.send(()).unwrap(); - // Closes transmitter - drop(shutdown_event_tx); - - // This is in a loop because the first event that the receiver receives will be the shutdown event - // This is not what we are waiting for instead, we want the receiver to send an error once all senders are closed which is reached after the shutdown event is received - loop { - match tokio::time::timeout( - tokio::time::Duration::from_millis(config.general.shutdown_timeout), - shutdown_event_rx.recv(), - ) - .await - { - Ok(res) => match res { - Ok(_) => {} - Err(_) => break, - }, - Err(_) => { - info!("Timed out while waiting for clients to shutdown"); - break; - } + client_ping = drain_rx.recv() => { + let client_ping = client_ping.unwrap(); + total_clients += client_ping; + + if total_clients == 0 && admin_only { + let _ = exit_tx.send(()).await; } } - }, - _ = term_signal.recv() => (), + } } info!("Shutting down..."); diff --git a/src/pool.rs b/src/pool.rs index ac5bc9110..99cccaf10 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -12,40 +12,74 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; -use crate::config::{get_config, Address, Role, Shard, User}; +use crate::config::{get_config, Address, Role, User}; use crate::errors::Error; use crate::server::Server; +use crate::sharding::ShardingFunction; use crate::stats::{get_reporter, Reporter}; pub type BanList = Arc>>>; -pub type ClientServerMap = Arc>>; +pub type ClientServerMap = Arc>>; pub type PoolMap = HashMap<(String, String), ConnectionPool>; /// The connection pool, globally available. /// This is atomic and safe and read-optimized. /// The pool is recreated dynamically when the config is reloaded. pub static POOLS: Lazy> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default())); +/// Pool mode: +/// - transaction: server serves one transaction, +/// - session: server is attached to the client. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum PoolMode { + Session, + Transaction, +} + +impl std::fmt::Display for PoolMode { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match *self { + PoolMode::Session => write!(f, "session"), + PoolMode::Transaction => write!(f, "transaction"), + } + } +} + +/// Pool settings. #[derive(Clone, Debug)] pub struct PoolSettings { - pub pool_mode: String, - pub shards: HashMap, + /// Transaction or Session. + pub pool_mode: PoolMode, + + // Number of shards. + pub shards: usize, + + // Connecting user. pub user: User, - pub default_role: String, + + // Default server role to connect to. + pub default_role: Option, + + // Enable/disable query parser. pub query_parser_enabled: bool, + + // Read from the primary as well or not. pub primary_reads_enabled: bool, - pub sharding_function: String, + + // Sharding function. 
+ pub sharding_function: ShardingFunction, } + impl Default for PoolSettings { fn default() -> PoolSettings { PoolSettings { - pool_mode: String::from("transaction"), - shards: HashMap::from([(String::from("1"), Shard::default())]), + pool_mode: PoolMode::Transaction, + shards: 1, user: User::default(), - default_role: String::from("any"), + default_role: None, query_parser_enabled: false, primary_reads_enabled: true, - sharding_function: "pg_bigint_hash".to_string(), + sharding_function: ShardingFunction::PgBigintHash, } } } @@ -73,6 +107,7 @@ pub struct ConnectionPool { /// on pool creation and save the K messages here. server_info: BytesMut, + /// Pool configuration. pub settings: PoolSettings, } @@ -80,11 +115,13 @@ impl ConnectionPool { /// Construct the connection pool from the configuration. pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> { let config = get_config(); - let mut new_pools = PoolMap::default(); + let mut new_pools = HashMap::new(); let mut address_id = 0; + for (pool_name, pool_config) in &config.pools { - for (_user_index, user_info) in &pool_config.users { + // There is one pool per database/user pair. + for (_, user) in &pool_config.users { let mut shards = Vec::new(); let mut addresses = Vec::new(); let mut banlist = Vec::new(); @@ -98,8 +135,8 @@ impl ConnectionPool { // Sort by shard number to ensure consistency. shard_ids.sort_by_key(|k| k.parse::().unwrap()); - for shard_idx in shard_ids { - let shard = &pool_config.shards[&shard_idx]; + for shard_idx in &shard_ids { + let shard = &pool_config.shards[shard_idx]; let mut pools = Vec::new(); let mut servers = Vec::new(); let mut address_index = 0; @@ -119,12 +156,12 @@ impl ConnectionPool { id: address_id, database: shard.database.clone(), host: server.0.clone(), - port: server.1.to_string(), + port: server.1 as u16, role: role, address_index, replica_number, shard: shard_idx.parse::().unwrap(), - username: user_info.username.clone(), + username: user.username.clone(), pool_name: pool_name.clone(), }; @@ -137,14 +174,14 @@ impl ConnectionPool { let manager = ServerPool::new( address.clone(), - user_info.clone(), + user.clone(), &shard.database, client_server_map.clone(), get_reporter(), ); let pool = Pool::builder() - .max_size(user_info.pool_size) + .max_size(user.pool_size) .connection_timeout(std::time::Duration::from_millis( config.general.connect_timeout, )) @@ -171,13 +208,27 @@ impl ConnectionPool { stats: get_reporter(), server_info: BytesMut::new(), settings: PoolSettings { - pool_mode: pool_config.pool_mode.clone(), - shards: pool_config.shards.clone(), - user: user_info.clone(), - default_role: pool_config.default_role.clone(), + pool_mode: match pool_config.pool_mode.as_str() { + "transaction" => PoolMode::Transaction, + "session" => PoolMode::Session, + _ => unreachable!(), + }, + // shards: pool_config.shards.clone(), + shards: shard_ids.len(), + user: user.clone(), + default_role: match pool_config.default_role.as_str() { + "any" => None, + "replica" => Some(Role::Replica), + "primary" => Some(Role::Primary), + _ => unreachable!(), + }, query_parser_enabled: pool_config.query_parser_enabled.clone(), primary_reads_enabled: pool_config.primary_reads_enabled, - sharding_function: pool_config.sharding_function.clone(), + sharding_function: match pool_config.sharding_function.as_str() { + "pg_bigint_hash" => ShardingFunction::PgBigintHash, + "sha1" => ShardingFunction::Sha1, + _ => unreachable!(), + }, }, }; @@ -190,7 +241,9 @@ impl ConnectionPool { return Err(err); } 
}; - new_pools.insert((pool_name.clone(), user_info.username.clone()), pool); + + // There is one pool per database/user pair. + new_pools.insert((pool_name.clone(), user.username.clone()), pool); } } @@ -207,8 +260,8 @@ impl ConnectionPool { async fn validate(&mut self) -> Result<(), Error> { let mut server_infos = Vec::new(); for shard in 0..self.shards() { - for index in 0..self.servers(shard) { - let connection = match self.databases[shard][index].get().await { + for server in 0..self.servers(shard) { + let connection = match self.databases[shard][server].get().await { Ok(conn) => conn, Err(err) => { error!("Shard {} down or misconfigured: {:?}", shard, err); @@ -229,6 +282,7 @@ impl ConnectionPool { ); } } + server_infos.push(server_info); } } @@ -239,6 +293,8 @@ impl ConnectionPool { return Err(Error::AllServersDown); } + // We're assuming all servers are identical. + // TODO: not true. self.server_info = server_infos[0].clone(); Ok(()) @@ -252,9 +308,8 @@ impl ConnectionPool { process_id: i32, // client id ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { let now = Instant::now(); - let mut candidates: Vec
= self.addresses[shard] - .clone() - .into_iter() + let mut candidates: Vec<&Address> = self.addresses[shard] + .iter() .filter(|address| address.role == role) .collect(); @@ -271,7 +326,8 @@ impl ConnectionPool { None => break, }; - if self.is_banned(&address, address.shard, role) { + if self.is_banned(&address, role) { + debug!("Address {:?} is banned", address); continue; } @@ -286,8 +342,7 @@ impl ConnectionPool { Ok(conn) => conn, Err(err) => { error!("Banning instance {:?}, error: {:?}", address, err); - self.ban(&address, address.shard, process_id); - self.stats.client_disconnecting(process_id, address.id); + self.ban(&address, process_id); self.stats .checkout_time(now.elapsed().as_micros(), process_id, address.id); continue; @@ -301,6 +356,9 @@ impl ConnectionPool { let require_healthcheck = server.last_activity().elapsed().unwrap().as_millis() > healthcheck_delay; + // Do not issue a health check unless it's been a little while + // since we last checked the server is ok. + // Health checks are pretty expensive. if !require_healthcheck { self.stats .checkout_time(now.elapsed().as_micros(), process_id, address.id); @@ -314,7 +372,7 @@ impl ConnectionPool { match tokio::time::timeout( tokio::time::Duration::from_millis(healthcheck_timeout), - server.query(";"), + server.query(";"), // Cheap query (query parser not used in PG) ) .await { @@ -337,7 +395,7 @@ impl ConnectionPool { // Don't leave a bad connection in the pool. server.mark_bad(); - self.ban(&address, address.shard, process_id); + self.ban(&address, process_id); continue; } }, @@ -351,44 +409,44 @@ impl ConnectionPool { // Don't leave a bad connection in the pool. server.mark_bad(); - self.ban(&address, address.shard, process_id); + self.ban(&address, process_id); continue; } } } - return Err(Error::AllServersDown); + + Err(Error::AllServersDown) } /// Ban an address (i.e. replica). It no longer will serve /// traffic for any new transactions. Existing transactions on that replica /// will finish successfully or error out to the clients. - pub fn ban(&self, address: &Address, shard: usize, process_id: i32) { + pub fn ban(&self, address: &Address, process_id: i32) { self.stats.client_disconnecting(process_id, address.id); - self.stats - .checkout_time(Instant::now().elapsed().as_micros(), process_id, address.id); error!("Banning {:?}", address); + let now = chrono::offset::Utc::now().naive_utc(); let mut guard = self.banlist.write(); - guard[shard].insert(address.clone(), now); + guard[address.shard].insert(address.clone(), now); } /// Clear the replica to receive traffic again. Takes effect immediately /// for all new transactions. - pub fn _unban(&self, address: &Address, shard: usize) { + pub fn _unban(&self, address: &Address) { let mut guard = self.banlist.write(); - guard[shard].remove(address); + guard[address.shard].remove(address); } /// Check if a replica can serve traffic. If all replicas are banned, /// we unban all of them. Better to try then not to. - pub fn is_banned(&self, address: &Address, shard: usize, role: Option) -> bool { + pub fn is_banned(&self, address: &Address, role: Option) -> bool { let replicas_available = match role { - Some(Role::Replica) => self.addresses[shard] + Some(Role::Replica) => self.addresses[address.shard] .iter() .filter(|addr| addr.role == Role::Replica) .count(), - None => self.addresses[shard].len(), + None => self.addresses[address.shard].len(), Some(Role::Primary) => return false, // Primary cannot be banned. 
}; @@ -397,17 +455,17 @@ impl ConnectionPool { let guard = self.banlist.read(); // Everything is banned = nothing is banned. - if guard[shard].len() == replicas_available { + if guard[address.shard].len() == replicas_available { drop(guard); let mut guard = self.banlist.write(); - guard[shard].clear(); + guard[address.shard].clear(); drop(guard); warn!("Unbanning all replicas."); return false; } // I expect this to miss 99.9999% of the time. - match guard[shard].get(address) { + match guard[address.shard].get(address) { Some(timestamp) => { let now = chrono::offset::Utc::now().naive_utc(); let config = get_config(); @@ -417,7 +475,7 @@ impl ConnectionPool { drop(guard); warn!("Unbanning {:?}", address); let mut guard = self.banlist.write(); - guard[shard].remove(address); + guard[address.shard].remove(address); false } else { debug!("{:?} is banned", address); @@ -554,6 +612,7 @@ pub fn get_pool(db: String, user: String) -> Option { } } +/// How many total servers we have in the config. pub fn get_number_of_addresses() -> usize { get_all_pools() .iter() @@ -561,6 +620,7 @@ pub fn get_number_of_addresses() -> usize { .sum() } +/// Get a pointer to all configured pools. pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> { return (*(*POOLS.load())).clone(); } diff --git a/src/query_router.rs b/src/query_router.rs index 6b377684f..f9d5f0b37 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -10,7 +10,7 @@ use sqlparser::parser::Parser; use crate::config::Role; use crate::pool::PoolSettings; -use crate::sharding::{Sharder, ShardingFunction}; +use crate::sharding::Sharder; /// Regexes used to parse custom commands. const CUSTOM_SQL_REGEXES: [&str; 7] = [ @@ -55,11 +55,13 @@ pub struct QueryRouter { /// Include the primary into the replica pool for reads. primary_reads_enabled: bool, + /// Pool configuration. pool_settings: PoolSettings, } impl QueryRouter { - /// One-time initialization of regexes. + /// One-time initialization of regexes + /// that parse our custom SQL protocol. pub fn setup() -> bool { let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) { Ok(rgx) => rgx, @@ -74,10 +76,7 @@ impl QueryRouter { .map(|rgx| Regex::new(rgx).unwrap()) .collect(); - // Impossible - if list.len() != set.len() { - return false; - } + assert_eq!(list.len(), set.len()); match CUSTOM_SQL_REGEX_LIST.set(list) { Ok(_) => true, @@ -90,7 +89,8 @@ impl QueryRouter { } } - /// Create a new instance of the query router. Each client gets its own. + /// Create a new instance of the query router. + /// Each client gets its own. pub fn new() -> QueryRouter { QueryRouter { active_shard: None, @@ -101,6 +101,7 @@ impl QueryRouter { } } + /// Pool settings can change because of a config reload. 
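The `update_pool_settings` hook below is called as queries arrive, so a client picks up new pool settings after a config reload without reconnecting. Because `PoolSettings` now stores typed values (`PoolMode`, `ShardingFunction`, `Option<Role>`), the query router no longer re-parses config strings on every command; the conversion happens once when the pools are built. The PR does that conversion with inline `match` arms in pool.rs (the `unreachable!()` fallbacks are backed by the validation added in config.rs); a `FromStr`-style helper is one alternative way to express the same idea, sketched here as an assumption rather than pgcat's API:

```rust
use std::str::FromStr;

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PoolMode {
    Session,
    Transaction,
}

impl FromStr for PoolMode {
    type Err = String;

    // Convert the config string once, at pool construction time.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "session" => Ok(PoolMode::Session),
            "transaction" => Ok(PoolMode::Transaction),
            other => Err(format!(
                "pool_mode must be 'session' or 'transaction', got '{}'",
                other
            )),
        }
    }
}
```
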
pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) { self.pool_settings = pool_settings; } @@ -136,19 +137,6 @@ impl QueryRouter { return None; } - let sharding_function = match self.pool_settings.sharding_function.as_ref() { - "pg_bigint_hash" => ShardingFunction::PgBigintHash, - "sha1" => ShardingFunction::Sha1, - _ => unreachable!(), - }; - - let default_server_role = match self.pool_settings.default_role.as_ref() { - "any" => None, - "primary" => Some(Role::Primary), - "replica" => Some(Role::Replica), - _ => unreachable!(), - }; - let command = match matches[0] { 0 => Command::SetShardingKey, 1 => Command::SetShard, @@ -200,7 +188,10 @@ impl QueryRouter { match command { Command::SetShardingKey => { - let sharder = Sharder::new(self.pool_settings.shards.len(), sharding_function); + let sharder = Sharder::new( + self.pool_settings.shards, + self.pool_settings.sharding_function, + ); let shard = sharder.shard(value.parse::().unwrap()); self.active_shard = Some(shard); value = shard.to_string(); @@ -208,7 +199,7 @@ impl QueryRouter { Command::SetShard => { self.active_shard = match value.to_ascii_uppercase().as_ref() { - "ANY" => Some(rand::random::() % self.pool_settings.shards.len()), + "ANY" => Some(rand::random::() % self.pool_settings.shards), _ => Some(value.parse::().unwrap()), }; } @@ -236,7 +227,7 @@ impl QueryRouter { } "default" => { - self.active_role = default_server_role; + self.active_role = self.pool_settings.default_role; self.query_parser_enabled = self.query_parser_enabled; self.active_role } @@ -367,10 +358,10 @@ impl QueryRouter { #[cfg(test)] mod test { - use std::collections::HashMap; - use super::*; use crate::messages::simple_query; + use crate::pool::PoolMode; + use crate::sharding::ShardingFunction; use bytes::BufMut; #[test] @@ -633,13 +624,13 @@ mod test { QueryRouter::setup(); let pool_settings = PoolSettings { - pool_mode: "transaction".to_string(), - shards: HashMap::default(), + pool_mode: PoolMode::Transaction, + shards: 0, user: crate::config::User::default(), - default_role: Role::Replica.to_string(), + default_role: Some(Role::Replica), query_parser_enabled: true, primary_reads_enabled: false, - sharding_function: "pg_bigint_hash".to_string(), + sharding_function: ShardingFunction::PgBigintHash, }; let mut qr = QueryRouter::new(); assert_eq!(qr.active_role, None); @@ -661,9 +652,6 @@ mod test { let q2 = simple_query("SET SERVER ROLE TO 'default'"); assert!(qr.try_execute_command(q2) != None); - assert_eq!( - qr.active_role.unwrap().to_string(), - pool_settings.clone().default_role - ); + assert_eq!(qr.active_role.unwrap(), pool_settings.clone().default_role); } } diff --git a/src/server.rs b/src/server.rs index ddf95ce62..3134a65df 100644 --- a/src/server.rs +++ b/src/server.rs @@ -75,7 +75,7 @@ impl Server { stats: Reporter, ) -> Result { let mut stream = - match TcpStream::connect(&format!("{}:{}", &address.host, &address.port)).await { + match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await { Ok(stream) => stream, Err(err) => { error!("Could not connect to server: {}", err); @@ -342,7 +342,7 @@ impl Server { /// Uses a separate connection that's not part of the connection pool. 
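In server.rs below, `Server::cancel` now takes the port as a `u16`, matching the new `Address.port` type. For context, query cancellation uses a dedicated CancelRequest frame sent over a fresh connection: 16 bytes carrying the frame length, the cancel request code, and the process_id/secret_key pair issued to the client at startup. A sketch of building that frame, with the standalone helper and its name being illustrative assumptions rather than pgcat's actual API:

```rust
use bytes::{BufMut, BytesMut};

// Build a PostgreSQL CancelRequest frame (wire format per the protocol docs).
fn cancel_request(process_id: i32, secret_key: i32) -> BytesMut {
    let mut buf = BytesMut::with_capacity(16);
    buf.put_i32(16);         // total frame length, including this field
    buf.put_i32(80877102);   // CancelRequest code: 1234 << 16 | 5678
    buf.put_i32(process_id); // backend PID handed to the client at startup
    buf.put_i32(secret_key); // secret key handed to the client at startup
    buf
}
```

No response is sent for a cancel request; the connection is simply closed afterwards, which is why the diff's comment notes that no other interactions take place.
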
pub async fn cancel( host: &str, - port: &str, + port: u16, process_id: i32, secret_key: i32, ) -> Result<(), Error> { @@ -529,7 +529,7 @@ impl Server { self.process_id, self.secret_key, self.address.host.clone(), - self.address.port.clone(), + self.address.port, ), ); } diff --git a/tests/python/tests.py b/tests/python/tests.py index a674cee67..092fc8cc9 100644 --- a/tests/python/tests.py +++ b/tests/python/tests.py @@ -14,6 +14,7 @@ def pgcat_start(): pg_cat_send_signal(signal.SIGTERM) os.system("./target/debug/pgcat .circleci/pgcat.toml &") + time.sleep(2) def pg_cat_send_signal(signal: signal.Signals): @@ -27,11 +28,23 @@ def pg_cat_send_signal(signal: signal.Signals): raise Exception("pgcat not closed after SIGTERM") -def connect_normal_db( - autocommit: bool = False, +def connect_db( + autocommit: bool = True, + admin: bool = False, ) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]: + + if admin: + user = "admin_user" + password = "admin_pass" + db = "pgcat" + else: + user = "sharding_user" + password = "sharding_user" + db = "sharded_db" + conn = psycopg2.connect( - f"postgres://sharding_user:sharding_user@{PGCAT_HOST}:{PGCAT_PORT}/sharded_db?application_name=testing_pgcat" + f"postgres://{user}:{password}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat", + connect_timeout=2, ) conn.autocommit = autocommit cur = conn.cursor() @@ -45,7 +58,7 @@ def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions. def test_normal_db_access(): - conn, cur = connect_normal_db() + conn, cur = connect_db(autocommit=False) cur.execute("SELECT 1") res = cur.fetchall() print(res) @@ -53,11 +66,7 @@ def test_normal_db_access(): def test_admin_db_access(): - conn = psycopg2.connect( - f"postgres://admin_user:admin_pass@{PGCAT_HOST}:{PGCAT_PORT}/pgcat" - ) - conn.autocommit = True # BEGIN/COMMIT is not supported by admin db - cur = conn.cursor() + conn, cur = connect_db(admin=True) cur.execute("SHOW POOLS") res = cur.fetchall() @@ -67,15 +76,14 @@ def test_admin_db_access(): def test_shutdown_logic(): - ##### NO ACTIVE QUERIES SIGINT HANDLING ##### + # - - - - - - - - - - - - - - - - - - + # NO ACTIVE QUERIES SIGINT HANDLING + # Start pgcat pgcat_start() - # Wait for server to fully start up - time.sleep(2) - # Create client connection and send query (not in transaction) - conn, cur = connect_normal_db(True) + conn, cur = connect_db() cur.execute("BEGIN;") cur.execute("SELECT 1;") @@ -97,17 +105,14 @@ def test_shutdown_logic(): cleanup_conn(conn, cur) pg_cat_send_signal(signal.SIGTERM) - ##### END ##### + # - - - - - - - - - - - - - - - - - - + # HANDLE TRANSACTION WITH SIGINT - ##### HANDLE TRANSACTION WITH SIGINT ##### # Start pgcat pgcat_start() - # Wait for server to fully start up - time.sleep(2) - # Create client connection and begin transaction - conn, cur = connect_normal_db(True) + conn, cur = connect_db() cur.execute("BEGIN;") cur.execute("SELECT 1;") @@ -126,17 +131,97 @@ def test_shutdown_logic(): cleanup_conn(conn, cur) pg_cat_send_signal(signal.SIGTERM) - ##### END ##### + # - - - - - - - - - - - - - - - - - - + # NO NEW NON-ADMIN CONNECTIONS DURING SHUTDOWN + # Start pgcat + pgcat_start() + + # Create client connection and begin transaction + transaction_conn, transaction_cur = connect_db() + + transaction_cur.execute("BEGIN;") + transaction_cur.execute("SELECT 1;") + + # Send sigint to pgcat while still in transaction + pg_cat_send_signal(signal.SIGINT) + time.sleep(1) - ##### HANDLE SHUTDOWN TIMEOUT WITH SIGINT ##### + start = 
time.perf_counter() + try: + conn, cur = connect_db() + cur.execute("SELECT 1;") + cleanup_conn(conn, cur) + except psycopg2.OperationalError as e: + time_taken = time.perf_counter() - start + if time_taken > 0.1: + raise Exception( + "Failed to reject connection within 0.1 seconds, got", time_taken, "seconds") + pass + else: + raise Exception("Able connect to database during shutdown") + + cleanup_conn(transaction_conn, transaction_cur) + pg_cat_send_signal(signal.SIGTERM) + + # - - - - - - - - - - - - - - - - - - + # ALLOW NEW ADMIN CONNECTIONS DURING SHUTDOWN # Start pgcat pgcat_start() - # Wait for server to fully start up - time.sleep(3) + # Create client connection and begin transaction + transaction_conn, transaction_cur = connect_db() + + transaction_cur.execute("BEGIN;") + transaction_cur.execute("SELECT 1;") + + # Send sigint to pgcat while still in transaction + pg_cat_send_signal(signal.SIGINT) + time.sleep(1) + + try: + conn, cur = connect_db(admin=True) + cur.execute("SHOW DATABASES;") + cleanup_conn(conn, cur) + except psycopg2.OperationalError as e: + raise Exception(e) + + cleanup_conn(transaction_conn, transaction_cur) + pg_cat_send_signal(signal.SIGTERM) + + # - - - - - - - - - - - - - - - - - - + # ADMIN CONNECTIONS CONTINUING TO WORK AFTER SHUTDOWN + # Start pgcat + pgcat_start() + + # Create client connection and begin transaction + transaction_conn, transaction_cur = connect_db() + transaction_cur.execute("BEGIN;") + transaction_cur.execute("SELECT 1;") + + admin_conn, admin_cur = connect_db(admin=True) + admin_cur.execute("SHOW DATABASES;") + + # Send sigint to pgcat while still in transaction + pg_cat_send_signal(signal.SIGINT) + time.sleep(1) + + try: + admin_cur.execute("SHOW DATABASES;") + except psycopg2.OperationalError as e: + raise Exception("Could not execute admin command:", e) + + cleanup_conn(transaction_conn, transaction_cur) + cleanup_conn(admin_conn, admin_cur) + pg_cat_send_signal(signal.SIGTERM) + + # - - - - - - - - - - - - - - - - - - + # HANDLE SHUTDOWN TIMEOUT WITH SIGINT + + # Start pgcat + pgcat_start() # Create client connection and begin transaction, which should prevent server shutdown unless shutdown timeout is reached - conn, cur = connect_normal_db(True) + conn, cur = connect_db() cur.execute("BEGIN;") cur.execute("SELECT 1;") @@ -159,7 +244,7 @@ def test_shutdown_logic(): cleanup_conn(conn, cur) pg_cat_send_signal(signal.SIGTERM) - ##### END ##### + # - - - - - - - - - - - - - - - - - - test_normal_db_access() diff --git a/tests/ruby/tests.rb b/tests/ruby/tests.rb index ba9476f43..b665dcb1f 100644 --- a/tests/ruby/tests.rb +++ b/tests/ruby/tests.rb @@ -5,6 +5,89 @@ require 'toml' $stdout.sync = true +$stderr.sync = true + +class ConfigEditor + def initialize + @original_config_text = File.read('../../.circleci/pgcat.toml') + text_to_load = @original_config_text.gsub("5432", "\"5432\"") + + @original_configs = TOML.load(text_to_load) + end + + def original_configs + TOML.load(TOML::Generator.new(@original_configs).body) + end + + def with_modified_configs(new_configs) + text_to_write = TOML::Generator.new(new_configs).body + text_to_write = text_to_write.gsub("\"5432\"", "5432") + File.write('../../.circleci/pgcat.toml', text_to_write) + yield + ensure + File.write('../../.circleci/pgcat.toml', @original_config_text) + end +end + +def with_captured_stdout_stderr + sout = STDOUT.clone + serr = STDERR.clone + STDOUT.reopen("/tmp/out.txt", "w+") + STDERR.reopen("/tmp/err.txt", "w+") + STDOUT.sync = true + STDERR.sync = true + yield + return 
File.read('/tmp/out.txt'), File.read('/tmp/err.txt') +ensure + STDOUT.reopen(sout) + STDERR.reopen(serr) +end + + +def test_extended_protocol_pooler_errors + admin_conn = PG::connect("postgres://admin_user:admin_pass@127.0.0.1:6432/pgcat") + + conf_editor = ConfigEditor.new + new_configs = conf_editor.original_configs + + # shorter timeouts + new_configs["general"]["connect_timeout"] = 500 + new_configs["general"]["ban_time"] = 1 + new_configs["general"]["shutdown_timeout"] = 1 + new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1 + new_configs["pools"]["sharded_db"]["users"]["1"]["pool_size"] = 1 + + conf_editor.with_modified_configs(new_configs) { admin_conn.async_exec("RELOAD") } + + conn_str = "postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db" + 10.times do + Thread.new do + conn = PG::connect(conn_str) + conn.async_exec("SELECT pg_sleep(5)") rescue PG::SystemError + ensure + conn&.close + end + end + + sleep(0.5) + conn_under_test = PG::connect(conn_str) + stdout, stderr = with_captured_stdout_stderr do + 5.times do |i| + conn_under_test.async_exec("SELECT 1") rescue PG::SystemError + conn_under_test.exec_params("SELECT #{i} + $1", [i]) rescue PG::SystemError + sleep 1 + end + end + + raise StandardError, "Libpq got unexpected messages while idle" if stderr.include?("arrived from server while idle") + puts "Pool checkout errors not breaking clients passed" +ensure + sleep 1 + admin_conn.async_exec("RELOAD") # Reset state + conn_under_test&.close +end + +test_extended_protocol_pooler_errors # Uncomment these two to see all queries. # ActiveRecord.verbose_query_logs = true @@ -144,30 +227,6 @@ def test_server_parameters end -class ConfigEditor - def initialize - @original_config_text = File.read('../../.circleci/pgcat.toml') - text_to_load = @original_config_text.gsub("5432", "\"5432\"") - - @original_configs = TOML.load(text_to_load) - end - - def original_configs - TOML.load(TOML::Generator.new(@original_configs).body) - end - - def with_modified_configs(new_configs) - text_to_write = TOML::Generator.new(new_configs).body - text_to_write = text_to_write.gsub("\"5432\"", "5432") - File.write('../../.circleci/pgcat.toml', text_to_write) - yield - ensure - File.write('../../.circleci/pgcat.toml', @original_config_text) - end - -end - - def test_reload_pool_recycling admin_conn = PG::connect("postgres://admin_user:admin_pass@127.0.0.1:6432/pgcat") server_conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat") @@ -201,3 +260,5 @@ def test_reload_pool_recycling end test_reload_pool_recycling + +
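
Taken together, the main.rs changes replace the old spawn-per-concern setup (separate tasks for SIGHUP handling, autoreload, and the accept loop) with a single `tokio::select!` loop that owns the signals, the listener, and the drain bookkeeping. A condensed sketch of the shutdown-relevant branches, with the listener and reload branches omitted and a hard-coded timeout standing in for `config.general.shutdown_timeout` (names follow the diff, but this is illustrative, not the exact code):

```rust
use tokio::signal::unix::{signal as unix_signal, SignalKind};
use tokio::sync::{broadcast, mpsc};

async fn shutdown_loop() {
    let (shutdown_tx, _) = broadcast::channel::<()>(1);          // tells clients to wrap up
    let (_drain_tx, mut drain_rx) = mpsc::channel::<i32>(2048);  // cloned into each client task
    let (exit_tx, mut exit_rx) = mpsc::channel::<()>(1);         // final exit signal

    let mut interrupt = unix_signal(SignalKind::interrupt()).unwrap();
    let mut admin_only = false;
    let mut total_clients: i32 = 0;

    loop {
        tokio::select! {
            _ = interrupt.recv() => {
                // SIGINT: reject new non-admin clients, ask existing ones to finish,
                // and arm a timeout in case they never do.
                admin_only = true;
                let _ = shutdown_tx.send(());
                let exit_tx = exit_tx.clone();
                tokio::spawn(async move {
                    tokio::time::sleep(std::time::Duration::from_millis(5_000)).await;
                    let _ = exit_tx.send(()).await; // give up waiting
                });
            }

            ping = drain_rx.recv() => {
                // Each client task reports +1 on connect and -1 on disconnect.
                total_clients += ping.unwrap_or(0);
                if admin_only && total_clients == 0 {
                    let _ = exit_tx.send(()).await; // everyone drained; exit cleanly
                }
            }

            _ = exit_rx.recv() => break,
        }
    }
}
```

One consequence exercised by tests/python/tests.py above: while `admin_only` is set, new non-admin connections are rejected immediately with `Error::ShuttingDown`, but new and existing admin connections keep working so operators can still inspect the pooler during the drain.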