diff --git a/.circleci/pgcat.toml b/.circleci/pgcat.toml index 943eff92..eca8f673 100644 --- a/.circleci/pgcat.toml +++ b/.circleci/pgcat.toml @@ -5,21 +5,12 @@ # # General pooler settings [general] - # What IP to run on, 0.0.0.0 means accessible from everywhere. host = "0.0.0.0" # Port to run on, same as PgBouncer used in this example. port = 6432 -# How many connections to allocate per server. -pool_size = 15 - -# Pool mode (see PgBouncer docs for more). -# session: one server connection per connected client -# transaction: one server connection per client transaction -pool_mode = "transaction" - # How long to wait before aborting a server connection (ms). connect_timeout = 100 @@ -29,56 +20,27 @@ healthcheck_timeout = 100 # For how long to ban a server if it fails a health check (seconds). ban_time = 60 # Seconds -# +# Reload config automatically if it changes. autoreload = true +# TLS tls_certificate = ".circleci/server.cert" tls_private_key = ".circleci/server.key" -# -# User to use for authentication against the server. -[user] -name = "sharding_user" -password = "sharding_user" - - -# -# Shards in the cluster -[shards] - -# Shard 0 -[shards.0] - -# [ host, port, role ] -servers = [ - [ "127.0.0.1", 5432, "primary" ], - [ "localhost", 5433, "replica" ], - # [ "127.0.1.1", 5432, "replica" ], -] -# Database name (e.g. "postgres") -database = "shard0" - -[shards.1] -# [ host, port, role ] -servers = [ - [ "127.0.0.1", 5432, "primary" ], - [ "localhost", 5433, "replica" ], - # [ "127.0.1.1", 5432, "replica" ], -] -database = "shard1" - -[shards.2] -# [ host, port, role ] -servers = [ - [ "127.0.0.1", 5432, "primary" ], - [ "localhost", 5433, "replica" ], - # [ "127.0.1.1", 5432, "replica" ], -] -database = "shard2" - +# Credentials to access the virtual administrative database (pgbouncer or pgcat) +# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.. 
+admin_username = "admin_user" +admin_password = "admin_pass" -# Settings for our query routing layer. -[query_router] +# pool +# configs are structured as pool. +# the pool_name is what clients use as database name when connecting +# For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db" +[pools.sharded_db] +# Pool mode (see PgBouncer docs for more). +# session: one server connection per connected client +# transaction: one server connection per client transaction +pool_mode = "transaction" # If the client doesn't specify, route traffic to # this role by default. @@ -88,7 +50,6 @@ database = "shard2" # primary: all queries go to the primary unless otherwise specified. default_role = "any" - # Query parser. If enabled, we'll attempt to parse # every incoming query to determine if it's a read or a write. # If it's a read query, we'll direct it to a replica. Otherwise, if it's a write, @@ -109,3 +70,36 @@ primary_reads_enabled = true # sha1: A hashing function based on SHA1 # sharding_function = "pg_bigint_hash" + +# Credentials for users that may connect to this cluster +[pools.sharded_db.users.0] +username = "sharding_user" +password = "sharding_user" +# Maximum number of server connections that can be established for this user +# The maximum number of connections from a single Pgcat process to any database in the cluster +# is the sum of pool_size across all users. +pool_size = 9 + +# Shard 0 +[pools.sharded_db.shards.0] +# [ host, port, role ] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ] +] +# Database name (e.g. 
"postgres") +database = "shard0" + +[pools.sharded_db.shards.1] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ], +] +database = "shard1" + +[pools.sharded_db.shards.2] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ], +] +database = "shard2" diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index a0e23f0a..c932a86e 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -32,6 +32,7 @@ toxiproxy-cli create -l 127.0.0.1:5433 -u 127.0.0.1:5432 postgres_replica start_pgcat "info" export PGPASSWORD=sharding_user +export PGDATABASE=sharded_db # pgbench test pgbench -U sharding_user -i -h 127.0.0.1 -p 6432 @@ -47,7 +48,7 @@ sleep 1 killall psql -s SIGINT # Reload pool (closing unused server connections) -psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' +PGPASSWORD=admin_pass psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' (psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(50)' || true) & sleep 1 @@ -72,15 +73,17 @@ cd tests/ruby && \ cd ../.. # Admin tests -psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null -psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null -psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null -psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null -psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null -psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null -psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null -psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore -(! 
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null) +export PGPASSWORD=admin_pass +psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null +psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null +psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null +psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null +psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null +psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null +psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null +psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore +(! psql -U admin_user -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null) +export PGPASSWORD=sharding_user # Start PgCat in debug to demonstrate failover better start_pgcat "trace" diff --git a/Cargo.lock b/Cargo.lock index cfaa0fbd..ddab730d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -395,7 +395,7 @@ dependencies = [ [[package]] name = "pgcat" -version = "0.4.0-beta1" +version = "0.6.0-alpha1" dependencies = [ "arc-swap", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index 3f65e90c..8bdeab67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgcat" -version = "0.4.0-beta1" +version = "0.6.0-alpha1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/pgcat.toml b/pgcat.toml index e9dbf075..a1937e6c 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -5,21 +5,12 @@ # # General pooler settings [general] - # What IP to run on, 0.0.0.0 means accessible from everywhere. host = "0.0.0.0" # Port to run on, same as PgBouncer used in this example. port = 6432 -# How many connections to allocate per server. 
-pool_size = 15 - -# Pool mode (see PgBouncer docs for more). -# session: one server connection per connected client -# transaction: one server connection per client transaction -pool_mode = "transaction" - # How long to wait before aborting a server connection (ms). connect_timeout = 5000 @@ -27,7 +18,7 @@ connect_timeout = 5000 healthcheck_timeout = 1000 # For how long to ban a server if it fails a health check (seconds). -ban_time = 60 # Seconds +ban_time = 60 # seconds # Reload config automatically if it changes. autoreload = false @@ -36,50 +27,20 @@ autoreload = false # tls_certificate = "server.cert" # tls_private_key = "server.key" -# -# User to use for authentication against the server. -[user] -name = "sharding_user" -password = "sharding_user" - - -# -# Shards in the cluster -[shards] - -# Shard 0 -[shards.0] - -# [ host, port, role ] -servers = [ - [ "127.0.0.1", 5432, "primary" ], - [ "localhost", 5432, "replica" ], - # [ "127.0.1.1", 5432, "replica" ], -] -# Database name (e.g. "postgres") -database = "shard0" - -[shards.1] -# [ host, port, role ] -servers = [ - [ "127.0.0.1", 5432, "primary" ], - [ "localhost", 5432, "replica" ], - # [ "127.0.1.1", 5432, "replica" ], -] -database = "shard1" +# Credentials to access the virtual administrative database (pgbouncer or pgcat) +# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.. +admin_username = "user" +admin_password = "pass" -[shards.2] -# [ host, port, role ] -servers = [ - [ "127.0.0.1", 5432, "primary" ], - [ "localhost", 5432, "replica" ], - # [ "127.0.1.1", 5432, "replica" ], -] -database = "shard2" - - -# Settings for our query routing layer. -[query_router] +# pool +# configs are structured as pool. +# the pool_name is what clients use as database name when connecting +# For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded" +[pools.sharded] +# Pool mode (see PgBouncer docs for more). 
+# session: one server connection per connected client +# transaction: one server connection per client transaction +pool_mode = "transaction" # If the client doesn't specify, route traffic to # this role by default. @@ -89,7 +50,6 @@ database = "shard2" # primary: all queries go to the primary unless otherwise specified. default_role = "any" - # Query parser. If enabled, we'll attempt to parse # every incoming query to determine if it's a read or a write. # If it's a read query, we'll direct it to a replica. Otherwise, if it's a write, @@ -110,3 +70,61 @@ primary_reads_enabled = true # sha1: A hashing function based on SHA1 # sharding_function = "pg_bigint_hash" + +# Credentials for users that may connect to this cluster +[pools.sharded.users.0] +username = "sharding_user" +password = "sharding_user" +# Maximum number of server connections that can be established for this user +# The maximum number of connection from a single Pgcat process to any database in the cluster +# is the sum of pool_size across all users. +pool_size = 9 + +[pools.sharded.users.1] +username = "other_user" +password = "other_user" +pool_size = 21 + +# Shard 0 +[pools.sharded.shards.0] +# [ host, port, role ] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ] +] +# Database name (e.g. 
"postgres") +database = "shard0" + +[pools.sharded.shards.1] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ], +] +database = "shard1" + +[pools.sharded.shards.2] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ], +] +database = "shard2" + + +[pools.simple_db] +pool_mode = "session" +default_role = "primary" +query_parser_enabled = true +primary_reads_enabled = true +sharding_function = "pg_bigint_hash" + +[pools.simple_db.users.0] +username = "simple_user" +password = "simple_user" +pool_size = 5 + +[pools.simple_db.shards.0] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ] +] +database = "some_db" diff --git a/src/admin.rs b/src/admin.rs index 74acf151..163227db 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -3,10 +3,10 @@ use bytes::{Buf, BufMut, BytesMut}; use log::{info, trace}; use std::collections::HashMap; -use crate::config::{get_config, reload_config}; +use crate::config::{get_config, reload_config, VERSION}; use crate::errors::Error; use crate::messages::*; -use crate::pool::ConnectionPool; +use crate::pool::get_all_pools; use crate::stats::get_stats; use crate::ClientServerMap; @@ -14,7 +14,6 @@ use crate::ClientServerMap; pub async fn handle_admin( stream: &mut T, mut query: BytesMut, - pool: ConnectionPool, client_server_map: ClientServerMap, ) -> Result<(), Error> where @@ -35,7 +34,7 @@ where if query.starts_with("SHOW STATS") { trace!("SHOW STATS"); - show_stats(stream, &pool).await + show_stats(stream).await } else if query.starts_with("RELOAD") { trace!("RELOAD"); reload(stream, client_server_map).await @@ -44,13 +43,13 @@ where show_config(stream).await } else if query.starts_with("SHOW DATABASES") { trace!("SHOW DATABASES"); - show_databases(stream, &pool).await + show_databases(stream).await } else if query.starts_with("SHOW POOLS") { trace!("SHOW POOLS"); - show_pools(stream, &pool).await + show_pools(stream).await } else if 
query.starts_with("SHOW LISTS") { trace!("SHOW LISTS"); - show_lists(stream, &pool).await + show_lists(stream).await } else if query.starts_with("SHOW VERSION") { trace!("SHOW VERSION"); show_version(stream).await @@ -63,7 +62,7 @@ where } /// Column-oriented statistics. -async fn show_lists(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error> +async fn show_lists(stream: &mut T) -> Result<(), Error> where T: tokio::io::AsyncWrite + std::marker::Unpin, { @@ -71,17 +70,20 @@ where let columns = vec![("list", DataType::Text), ("items", DataType::Int4)]; + let mut users = 1; + let mut databases = 1; + for (_, pool) in get_all_pools() { + databases += pool.databases(); + users += 1; // One user per pool + } let mut res = BytesMut::new(); res.put(row_description(&columns)); res.put(data_row(&vec![ "databases".to_string(), - (pool.databases() + 1).to_string(), // see comment below + databases.to_string(), ])); - res.put(data_row(&vec!["users".to_string(), "1".to_string()])); - res.put(data_row(&vec![ - "pools".to_string(), - (pool.databases() + 1).to_string(), // +1 for the pgbouncer admin db pool which isn't real - ])); // but admin tools that work with pgbouncer want this + res.put(data_row(&vec!["users".to_string(), users.to_string()])); + res.put(data_row(&vec!["pools".to_string(), databases.to_string()])); res.put(data_row(&vec![ "free_clients".to_string(), stats @@ -140,7 +142,7 @@ where let mut res = BytesMut::new(); res.put(row_description(&vec![("version", DataType::Text)])); - res.put(data_row(&vec!["PgCat 0.1.0".to_string()])); + res.put(data_row(&vec![format!("PgCat {}", VERSION).to_string()])); res.put(command_complete("SHOW")); res.put_u8(b'Z'); @@ -151,12 +153,11 @@ where } /// Show utilization of connection pools for each shard and replicas. 
-async fn show_pools(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error> +async fn show_pools(stream: &mut T) -> Result<(), Error> where T: tokio::io::AsyncWrite + std::marker::Unpin, { let stats = get_stats(); - let config = get_config(); let columns = vec![ ("database", DataType::Text), @@ -176,24 +177,26 @@ where let mut res = BytesMut::new(); res.put(row_description(&columns)); - - for shard in 0..pool.shards() { - for server in 0..pool.servers(shard) { - let address = pool.address(shard, server); - let stats = match stats.get(&address.id) { - Some(stats) => stats.clone(), - None => HashMap::new(), - }; - - let mut row = vec![address.name(), config.user.name.clone()]; - - for column in &columns[2..columns.len() - 1] { - let value = stats.get(column.0).unwrap_or(&0).to_string(); - row.push(value); + for (_, pool) in get_all_pools() { + let pool_config = &pool.settings; + for shard in 0..pool.shards() { + for server in 0..pool.servers(shard) { + let address = pool.address(shard, server); + let stats = match stats.get(&address.id) { + Some(stats) => stats.clone(), + None => HashMap::new(), + }; + + let mut row = vec![address.name(), pool_config.user.username.clone()]; + + for column in &columns[2..columns.len() - 1] { + let value = stats.get(column.0).unwrap_or(&0).to_string(); + row.push(value); + } + + row.push(pool_config.pool_mode.to_string()); + res.put(data_row(&row)); } - - row.push(config.general.pool_mode.to_string()); - res.put(data_row(&row)); } } @@ -208,12 +211,10 @@ where } /// Show shards and replicas. 
-async fn show_databases(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error> +async fn show_databases(stream: &mut T) -> Result<(), Error> where T: tokio::io::AsyncWrite + std::marker::Unpin, { - let config = get_config(); - // Columns let columns = vec![ ("name", DataType::Text), @@ -235,31 +236,33 @@ where res.put(row_description(&columns)); - for shard in 0..pool.shards() { - let database_name = &config.shards[&shard.to_string()].database; - - for server in 0..pool.servers(shard) { - let address = pool.address(shard, server); - let pool_state = pool.pool_state(shard, server); - - res.put(data_row(&vec![ - address.name(), // name - address.host.to_string(), // host - address.port.to_string(), // port - database_name.to_string(), // database - config.user.name.to_string(), // force_user - config.general.pool_size.to_string(), // pool_size - "0".to_string(), // min_pool_size - "0".to_string(), // reserve_pool - config.general.pool_mode.to_string(), // pool_mode - config.general.pool_size.to_string(), // max_connections - pool_state.connections.to_string(), // current_connections - "0".to_string(), // paused - "0".to_string(), // disabled - ])); + for (_, pool) in get_all_pools() { + let pool_config = pool.settings.clone(); + for shard in 0..pool.shards() { + let database_name = &pool_config.shards[&shard.to_string()].database; + + for server in 0..pool.servers(shard) { + let address = pool.address(shard, server); + let pool_state = pool.pool_state(shard, server); + + res.put(data_row(&vec![ + address.name(), // name + address.host.to_string(), // host + address.port.to_string(), // port + database_name.to_string(), // database + pool_config.user.username.to_string(), // force_user + pool_config.user.pool_size.to_string(), // pool_size + "0".to_string(), // min_pool_size + "0".to_string(), // reserve_pool + pool_config.pool_mode.to_string(), // pool_mode + pool_config.user.pool_size.to_string(), // max_connections + pool_state.connections.to_string(), // 
current_connections + "0".to_string(), // paused + "0".to_string(), // disabled + ])); + } } } - res.put(command_complete("SHOW")); // ReadyForQuery @@ -349,7 +352,7 @@ where } /// Show shard and replicas statistics. -async fn show_stats(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error> +async fn show_stats(stream: &mut T) -> Result<(), Error> where T: tokio::io::AsyncWrite + std::marker::Unpin, { @@ -375,21 +378,23 @@ where let mut res = BytesMut::new(); res.put(row_description(&columns)); - for shard in 0..pool.shards() { - for server in 0..pool.servers(shard) { - let address = pool.address(shard, server); - let stats = match stats.get(&address.id) { - Some(stats) => stats.clone(), - None => HashMap::new(), - }; + for (_, pool) in get_all_pools() { + for shard in 0..pool.shards() { + for server in 0..pool.servers(shard) { + let address = pool.address(shard, server); + let stats = match stats.get(&address.id) { + Some(stats) => stats.clone(), + None => HashMap::new(), + }; - let mut row = vec![address.name()]; + let mut row = vec![address.name()]; - for column in &columns[1..] { - row.push(stats.get(column.0).unwrap_or(&0).to_string()); - } + for column in &columns[1..] { + row.push(stats.get(column.0).unwrap_or(&0).to_string()); + } - res.put(data_row(&row)); + res.put(data_row(&row)); + } } } diff --git a/src/client.rs b/src/client.rs index 05895b6d..4f32d0af 100644 --- a/src/client.rs +++ b/src/client.rs @@ -10,7 +10,7 @@ use crate::config::get_config; use crate::constants::*; use crate::errors::Error; use crate::messages::*; -use crate::pool::{get_pool, ClientServerMap}; +use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; use crate::server::Server; use crate::stats::{get_reporter, Reporter}; @@ -71,6 +71,8 @@ pub struct Client { /// Last server process id we talked to. last_server_id: Option, + + target_pool: ConnectionPool, } /// Client entrypoint. 
@@ -258,11 +260,25 @@ where client_server_map: ClientServerMap, ) -> Result, Error> { let config = get_config(); - let transaction_mode = config.general.pool_mode == "transaction"; let stats = get_reporter(); trace!("Got StartupMessage"); let parameters = parse_startup(bytes.clone())?; + let database = match parameters.get("database") { + Some(db) => db, + None => return Err(Error::ClientError), + }; + + let user = match parameters.get("user") { + Some(user) => user, + None => return Err(Error::ClientError), + }; + + let admin = ["pgcat", "pgbouncer"] + .iter() + .filter(|db| *db == &database) + .count() + == 1; // Generate random backend ID and secret key let process_id: i32 = rand::random(); @@ -295,33 +311,57 @@ where Err(_) => return Err(Error::SocketError), }; - // Compare server and client hashes. - let password_hash = md5_hash_password(&config.user.name, &config.user.password, &salt); + let mut target_pool: ConnectionPool = ConnectionPool::default(); + let mut transaction_mode = false; + + if admin { + let correct_user = config.general.admin_username.as_str(); + let correct_password = config.general.admin_password.as_str(); + + // Compare server and client hashes. + let password_hash = md5_hash_password(correct_user, correct_password, &salt); + if password_hash != password_response { + debug!("Password authentication failed"); + wrong_password(&mut write, user).await?; + return Err(Error::ClientError); + } + } else { + target_pool = match get_pool(database.clone(), user.clone()) { + Some(pool) => pool, + None => { + error_response( + &mut write, + &format!( + "No pool configured for database: {:?}, user: {:?}", + database, user + ), + ) + .await?; + return Err(Error::ClientError); + } + }; + transaction_mode = target_pool.settings.pool_mode == "transaction"; + + // Compare server and client hashes. 
+ let correct_password = target_pool.settings.user.password.as_str(); + let password_hash = md5_hash_password(user, correct_password, &salt); - if password_hash != password_response { - debug!("Password authentication failed"); - wrong_password(&mut write, &config.user.name).await?; - return Err(Error::ClientError); + if password_hash != password_response { + debug!("Password authentication failed"); + wrong_password(&mut write, user).await?; + return Err(Error::ClientError); + } } debug!("Password authentication successful"); auth_ok(&mut write).await?; - write_all(&mut write, get_pool().server_info()).await?; + write_all(&mut write, target_pool.server_info()).await?; backend_key_data(&mut write, process_id, secret_key).await?; ready_for_query(&mut write).await?; trace!("Startup OK"); - let database = parameters - .get("database") - .unwrap_or(parameters.get("user").unwrap()); - let admin = ["pgcat", "pgbouncer"] - .iter() - .filter(|db| *db == &database) - .count() - == 1; - // Split the read and write streams // so we can control buffering. 
@@ -335,11 +375,12 @@ where process_id: process_id, secret_key: secret_key, client_server_map: client_server_map, - parameters: parameters, + parameters: parameters.clone(), stats: stats, admin: admin, last_address_id: None, last_server_id: None, + target_pool: target_pool, }); } @@ -353,26 +394,22 @@ where ) -> Result, Error> { let process_id = bytes.get_i32(); let secret_key = bytes.get_i32(); - - let config = get_config(); - let transaction_mode = config.general.pool_mode == "transaction"; - let stats = get_reporter(); - return Ok(Client { read: BufReader::new(read), write: write, addr, buffer: BytesMut::with_capacity(8196), cancel_mode: true, - transaction_mode: transaction_mode, + transaction_mode: false, process_id: process_id, secret_key: secret_key, client_server_map: client_server_map, parameters: HashMap::new(), - stats: stats, + stats: get_reporter(), admin: false, last_address_id: None, last_server_id: None, + target_pool: ConnectionPool::default(), }); } @@ -410,7 +447,7 @@ where // The query router determines where the query is going to go, // e.g. primary, replica, which shard. - let mut query_router = QueryRouter::new(); + let mut query_router = QueryRouter::new(self.target_pool.clone()); let mut round_robin = 0; // Our custom protocol loop. @@ -432,7 +469,7 @@ where // Get a pool instance referenced by the most up-to-date // pointer. This ensures we always read the latest config // when starting a query. - let mut pool = get_pool(); + let mut pool = self.target_pool.clone(); // Avoid taking a server if the client just wants to disconnect. if message[0] as char == 'X' { @@ -443,13 +480,7 @@ where // Handle admin database queries. 
if self.admin { debug!("Handling admin command"); - handle_admin( - &mut self.write, - message, - pool.clone(), - self.client_server_map.clone(), - ) - .await?; + handle_admin(&mut self.write, message, self.client_server_map.clone()).await?; continue; } diff --git a/src/config.rs b/src/config.rs index da59d2ae..d660fdcf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,6 +4,7 @@ use log::{error, info}; use once_cell::sync::Lazy; use serde_derive::Deserialize; use std::collections::{HashMap, HashSet}; +use std::hash::Hash; use std::path::Path; use std::sync::Arc; use tokio::fs::File; @@ -14,6 +15,8 @@ use crate::errors::Error; use crate::tls::{load_certs, load_keys}; use crate::{ClientServerMap, ConnectionPool}; +pub const VERSION: &str = env!("CARGO_PKG_VERSION"); + /// Globally available configuration. static CONFIG: Lazy> = Lazy::new(|| ArcSwap::from_pointee(Config::default())); @@ -58,6 +61,7 @@ pub struct Address { pub host: String, pub port: String, pub shard: usize, + pub database: String, pub role: Role, pub replica_number: usize, } @@ -70,6 +74,7 @@ impl Default for Address { port: String::from("5432"), shard: 0, replica_number: 0, + database: String::from("database"), role: Role::Replica, } } @@ -79,9 +84,12 @@ impl Address { /// Address name (aka database) used in `SHOW STATS`, `SHOW DATABASES`, and `SHOW POOLS`. pub fn name(&self) -> String { match self.role { - Role::Primary => format!("shard_{}_primary", self.shard), + Role::Primary => format!("{}_shard_{}_primary", self.database, self.shard), - Role::Replica => format!("shard_{}_replica_{}", self.shard, self.replica_number), + Role::Replica => format!( + "{}_shard_{}_replica_{}", + self.database, self.shard, self.replica_number + ), } } } @@ -89,15 +97,17 @@ impl Address { /// PostgreSQL user. 
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Deserialize, Debug)] pub struct User { - pub name: String, + pub username: String, pub password: String, + pub pool_size: u32, } impl Default for User { fn default() -> User { User { - name: String::from("postgres"), + username: String::from("postgres"), password: String::new(), + pool_size: 15, } } } @@ -107,14 +117,14 @@ impl Default for User { pub struct General { pub host: String, pub port: i16, - pub pool_size: u32, - pub pool_mode: String, pub connect_timeout: u64, pub healthcheck_timeout: u64, pub ban_time: i64, pub autoreload: bool, pub tls_certificate: Option, pub tls_private_key: Option, + pub admin_username: String, + pub admin_password: String, } impl Default for General { @@ -122,14 +132,37 @@ impl Default for General { General { host: String::from("localhost"), port: 5432, - pool_size: 15, - pool_mode: String::from("transaction"), connect_timeout: 5000, healthcheck_timeout: 1000, ban_time: 60, autoreload: false, tls_certificate: None, tls_private_key: None, + admin_username: String::from("admin"), + admin_password: String::from("admin"), + } + } +} +#[derive(Deserialize, Debug, Clone, PartialEq)] +pub struct Pool { + pub pool_mode: String, + pub shards: HashMap, + pub users: HashMap, + pub default_role: String, + pub query_parser_enabled: bool, + pub primary_reads_enabled: bool, + pub sharding_function: String, +} +impl Default for Pool { + fn default() -> Pool { + Pool { + pool_mode: String::from("transaction"), + shards: HashMap::from([(String::from("1"), Shard::default())]), + users: HashMap::default(), + default_role: String::from("any"), + query_parser_enabled: false, + primary_reads_enabled: true, + sharding_function: "pg_bigint_hash".to_string(), } } } @@ -137,8 +170,8 @@ impl Default for General { /// Shard configuration. 
#[derive(Deserialize, Debug, Clone, PartialEq)] pub struct Shard { - pub servers: Vec<(String, u16, String)>, pub database: String, + pub servers: Vec<(String, u16, String)>, } impl Default for Shard { @@ -150,26 +183,6 @@ impl Default for Shard { } } -/// Query Router configuration. -#[derive(Deserialize, Debug, Clone, PartialEq)] -pub struct QueryRouter { - pub default_role: String, - pub query_parser_enabled: bool, - pub primary_reads_enabled: bool, - pub sharding_function: String, -} - -impl Default for QueryRouter { - fn default() -> QueryRouter { - QueryRouter { - default_role: String::from("any"), - query_parser_enabled: false, - primary_reads_enabled: true, - sharding_function: "pg_bigint_hash".to_string(), - } - } -} - fn default_path() -> String { String::from("pgcat.toml") } @@ -181,9 +194,7 @@ pub struct Config { pub path: String, pub general: General, - pub user: User, - pub shards: HashMap, - pub query_router: QueryRouter, + pub pools: HashMap, } impl Default for Config { @@ -191,26 +202,58 @@ impl Default for Config { Config { path: String::from("pgcat.toml"), general: General::default(), - user: User::default(), - shards: HashMap::from([(String::from("1"), Shard::default())]), - query_router: QueryRouter::default(), + pools: HashMap::default(), } } } impl From<&Config> for std::collections::HashMap { fn from(config: &Config) -> HashMap { - HashMap::from([ + let mut r: Vec<(String, String)> = config + .pools + .iter() + .flat_map(|(pool_name, pool)| { + [ + ( + format!("pools.{}.pool_mode", pool_name), + pool.pool_mode.clone(), + ), + ( + format!("pools.{}.primary_reads_enabled", pool_name), + pool.primary_reads_enabled.to_string(), + ), + ( + format!("pools.{}.query_parser_enabled", pool_name), + pool.query_parser_enabled.to_string(), + ), + ( + format!("pools.{}.default_role", pool_name), + pool.default_role.clone(), + ), + ( + format!("pools.{}.sharding_function", pool_name), + pool.sharding_function.clone(), + ), + ( + 
format!("pools.{:?}.shard_count", pool_name), + pool.shards.len().to_string(), + ), + ( + format!("pools.{:?}.users", pool_name), + pool.users + .iter() + .map(|(_username, user)| &user.username) + .cloned() + .collect::>() + .join(", "), + ), + ] + }) + .collect(); + + let mut static_settings = vec![ ("host".to_string(), config.general.host.to_string()), ("port".to_string(), config.general.port.to_string()), - ( - "pool_size".to_string(), - config.general.pool_size.to_string(), - ), - ( - "pool_mode".to_string(), - config.general.pool_mode.to_string(), - ), ( "connect_timeout".to_string(), config.general.connect_timeout.to_string(), @@ -220,42 +263,22 @@ impl From<&Config> for std::collections::HashMap { config.general.healthcheck_timeout.to_string(), ), ("ban_time".to_string(), config.general.ban_time.to_string()), - ( - "default_role".to_string(), - config.query_router.default_role.to_string(), - ), - ( - "query_parser_enabled".to_string(), - config.query_router.query_parser_enabled.to_string(), - ), - ( - "primary_reads_enabled".to_string(), - config.query_router.primary_reads_enabled.to_string(), - ), - ( - "sharding_function".to_string(), - config.query_router.sharding_function.to_string(), - ), - ]) + ]; + + r.append(&mut static_settings); + return r.iter().cloned().collect(); } } impl Config { /// Print current configuration. 
pub fn show(&self) { - info!("Pool size: {}", self.general.pool_size); - info!("Pool mode: {}", self.general.pool_mode); info!("Ban time: {}s", self.general.ban_time); info!( "Healthcheck timeout: {}ms", self.general.healthcheck_timeout ); info!("Connection timeout: {}ms", self.general.connect_timeout); - info!("Sharding function: {}", self.query_router.sharding_function); - info!("Primary reads: {}", self.query_router.primary_reads_enabled); - info!("Query router: {}", self.query_router.query_parser_enabled); - info!("Number of shards: {}", self.shards.len()); - match self.general.tls_certificate.clone() { Some(tls_certificate) => { info!("TLS certificate: {}", tls_certificate); @@ -274,6 +297,25 @@ impl Config { info!("TLS support is disabled"); } }; + + for (pool_name, pool_config) in &self.pools { + info!("--- Settings for pool {} ---", pool_name); + info!( + "Pool size from all users: {}", + pool_config + .users + .iter() + .map(|(_, user_cfg)| user_cfg.pool_size) + .sum::() + .to_string() + ); + info!("Pool mode: {}", pool_config.pool_mode); + info!("Sharding function: {}", pool_config.sharding_function); + info!("Primary reads: {}", pool_config.primary_reads_enabled); + info!("Query router: {}", pool_config.query_parser_enabled); + info!("Number of shards: {}", pool_config.shards.len()); + info!("Number of users: {}", pool_config.users.len()); + } } } @@ -311,88 +353,6 @@ pub async fn parse(path: &str) -> Result<(), Error> { } }; - match config.query_router.sharding_function.as_ref() { - "pg_bigint_hash" => (), - "sha1" => (), - _ => { - error!( - "Supported sharding functions are: 'pg_bigint_hash', 'sha1', got: '{}'", - config.query_router.sharding_function - ); - return Err(Error::BadConfig); - } - }; - - // Quick config sanity check. - for shard in &config.shards { - // We use addresses as unique identifiers, - // let's make sure they are unique in the config as well. 
- let mut dup_check = HashSet::new(); - let mut primary_count = 0; - - match shard.0.parse::() { - Ok(_) => (), - Err(_) => { - error!( - "Shard '{}' is not a valid number, shards must be numbered starting at 0", - shard.0 - ); - return Err(Error::BadConfig); - } - }; - - if shard.1.servers.len() == 0 { - error!("Shard {} has no servers configured", shard.0); - return Err(Error::BadConfig); - } - - for server in &shard.1.servers { - dup_check.insert(server); - - // Check that we define only zero or one primary. - match server.2.as_ref() { - "primary" => primary_count += 1, - _ => (), - }; - - // Check role spelling. - match server.2.as_ref() { - "primary" => (), - "replica" => (), - _ => { - error!( - "Shard {} server role must be either 'primary' or 'replica', got: '{}'", - shard.0, server.2 - ); - return Err(Error::BadConfig); - } - }; - } - - if primary_count > 1 { - error!("Shard {} has more than on primary configured", &shard.0); - return Err(Error::BadConfig); - } - - if dup_check.len() != shard.1.servers.len() { - error!("Shard {} contains duplicate server configs", &shard.0); - return Err(Error::BadConfig); - } - } - - match config.query_router.default_role.as_ref() { - "any" => (), - "primary" => (), - "replica" => (), - other => { - error!( - "Query router default_role must be 'primary', 'replica', or 'any', got: '{}'", - other - ); - return Err(Error::BadConfig); - } - }; - // Validate TLS! 
match config.general.tls_certificate.clone() { Some(tls_certificate) => { @@ -424,6 +384,90 @@ pub async fn parse(path: &str) -> Result<(), Error> { None => (), }; + for (pool_name, pool) in &config.pools { + match pool.sharding_function.as_ref() { + "pg_bigint_hash" => (), + "sha1" => (), + _ => { + error!( + "Supported sharding functions are: 'pg_bigint_hash', 'sha1', got: '{}' in pool {} settings", + pool.sharding_function, + pool_name + ); + return Err(Error::BadConfig); + } + }; + + match pool.default_role.as_ref() { + "any" => (), + "primary" => (), + "replica" => (), + other => { + error!( + "Query router default_role must be 'primary', 'replica', or 'any', got: '{}'", + other + ); + return Err(Error::BadConfig); + } + }; + + for shard in &pool.shards { + // We use addresses as unique identifiers, + // let's make sure they are unique in the config as well. + let mut dup_check = HashSet::new(); + let mut primary_count = 0; + + match shard.0.parse::() { + Ok(_) => (), + Err(_) => { + error!( + "Shard '{}' is not a valid number, shards must be numbered starting at 0", + shard.0 + ); + return Err(Error::BadConfig); + } + }; + + if shard.1.servers.len() == 0 { + error!("Shard {} has no servers configured", shard.0); + return Err(Error::BadConfig); + } + + for server in &shard.1.servers { + dup_check.insert(server); + + // Check that we define only zero or one primary. + match server.2.as_ref() { + "primary" => primary_count += 1, + _ => (), + }; + + // Check role spelling. 
+ match server.2.as_ref() { + "primary" => (), + "replica" => (), + _ => { + error!( + "Shard {} server role must be either 'primary' or 'replica', got: '{}'", + shard.0, server.2 + ); + return Err(Error::BadConfig); + } + }; + } + + if primary_count > 1 { + error!("Shard {} has more than one primary configured", &shard.0); + return Err(Error::BadConfig); + } + + if dup_check.len() != shard.1.servers.len() { + error!("Shard {} contains duplicate server configs", &shard.0); + return Err(Error::BadConfig); + } + } + } + + config.path = path.to_string(); // Update the configuration globally. @@ -434,7 +478,6 @@ pub async fn parse(path: &str) -> Result<(), Error> { pub async fn reload_config(client_server_map: ClientServerMap) -> Result { let old_config = get_config(); - match parse(&old_config.path).await { Ok(()) => (), Err(err) => { @@ -442,11 +485,10 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result>>>; pub type ClientServerMap = Arc>>; - +pub type PoolMap = HashMap<(String, String), ConnectionPool>; /// The connection pool, globally available. /// This is atomic and safe and read-optimized. /// The pool is recreated dynamically when the config is reloaded.
-pub static POOL: Lazy> = - Lazy::new(|| ArcSwap::from_pointee(ConnectionPool::default())); +pub static POOLS: Lazy> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default())); + +#[derive(Clone, Debug)] +pub struct PoolSettings { + pub pool_mode: String, + pub shards: HashMap, + pub user: User, + pub default_role: String, + pub query_parser_enabled: bool, + pub primary_reads_enabled: bool, + pub sharding_function: String, +} +impl Default for PoolSettings { + fn default() -> PoolSettings { + PoolSettings { + pool_mode: String::from("transaction"), + shards: HashMap::from([(String::from("1"), Shard::default())]), + user: User::default(), + default_role: String::from("any"), + query_parser_enabled: false, + primary_reads_enabled: true, + sharding_function: "pg_bigint_hash".to_string(), + } + } +} /// The globally accessible connection pool. #[derive(Clone, Debug, Default)] @@ -46,107 +70,124 @@ pub struct ConnectionPool { /// clients on startup. We pre-connect to all shards and replicas /// on pool creation and save the K messages here. server_info: BytesMut, + + pub settings: PoolSettings, } impl ConnectionPool { /// Construct the connection pool from the configuration. pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> { - let reporter = get_reporter(); let config = get_config(); + let mut new_pools = PoolMap::default(); - let mut shards = Vec::new(); - let mut addresses = Vec::new(); - let mut banlist = Vec::new(); let mut address_id = 0; - let mut shard_ids = config - .shards - .clone() - .into_keys() - .map(|x| x.to_string()) - .collect::>(); - - // Sort by shard number to ensure consistency. 
- shard_ids.sort_by_key(|k| k.parse::().unwrap()); - - for shard_idx in shard_ids { - let shard = &config.shards[&shard_idx]; - let mut pools = Vec::new(); - let mut servers = Vec::new(); - let mut replica_number = 0; - - for server in shard.servers.iter() { - let role = match server.2.as_ref() { - "primary" => Role::Primary, - "replica" => Role::Replica, - _ => { - error!("Config error: server role can be 'primary' or 'replica', have: '{}'. Defaulting to 'replica'.", server.2); - Role::Replica - } - }; - - let address = Address { - id: address_id, - host: server.0.clone(), - port: server.1.to_string(), - role: role, - replica_number, - shard: shard_idx.parse::().unwrap(), - }; + for (pool_name, pool_config) in &config.pools { + for (_user_index, user_info) in &pool_config.users { + let mut shards = Vec::new(); + let mut addresses = Vec::new(); + let mut banlist = Vec::new(); + let mut shard_ids = pool_config + .shards + .clone() + .into_keys() + .map(|x| x.to_string()) + .collect::>(); + + // Sort by shard number to ensure consistency. + shard_ids.sort_by_key(|k| k.parse::().unwrap()); + + for shard_idx in shard_ids { + let shard = &pool_config.shards[&shard_idx]; + let mut pools = Vec::new(); + let mut servers = Vec::new(); + let mut replica_number = 0; + + for server in shard.servers.iter() { + let role = match server.2.as_ref() { + "primary" => Role::Primary, + "replica" => Role::Replica, + _ => { + error!("Config error: server role can be 'primary' or 'replica', have: '{}'. 
Defaulting to 'replica'.", server.2); + Role::Replica + } + }; + + let address = Address { + id: address_id, + database: pool_name.clone(), + host: server.0.clone(), + port: server.1.to_string(), + role: role, + replica_number, + shard: shard_idx.parse::().unwrap(), + }; + + address_id += 1; + + if role == Role::Replica { + replica_number += 1; + } + + let manager = ServerPool::new( + address.clone(), + user_info.clone(), + &shard.database, + client_server_map.clone(), + get_reporter(), + ); - address_id += 1; + let pool = Pool::builder() + .max_size(user_info.pool_size) + .connection_timeout(std::time::Duration::from_millis( + config.general.connect_timeout, + )) + .test_on_check_out(false) + .build(manager) + .await + .unwrap(); + + pools.push(pool); + servers.push(address); + } - if role == Role::Replica { - replica_number += 1; + shards.push(pools); + addresses.push(servers); + banlist.push(HashMap::new()); } - let manager = ServerPool::new( - address.clone(), - config.user.clone(), - &shard.database, - client_server_map.clone(), - reporter.clone(), - ); - - let pool = Pool::builder() - .max_size(config.general.pool_size) - .connection_timeout(std::time::Duration::from_millis( - config.general.connect_timeout, - )) - .test_on_check_out(false) - .build(manager) - .await - .unwrap(); - - pools.push(pool); - servers.push(address); - } - - shards.push(pools); - addresses.push(servers); - banlist.push(HashMap::new()); - } - - assert_eq!(shards.len(), addresses.len()); - - let mut pool = ConnectionPool { - databases: shards, - addresses: addresses, - banlist: Arc::new(RwLock::new(banlist)), - stats: reporter, - server_info: BytesMut::new(), - }; + assert_eq!(shards.len(), addresses.len()); + + let mut pool = ConnectionPool { + databases: shards, + addresses: addresses, + banlist: Arc::new(RwLock::new(banlist)), + stats: get_reporter(), + server_info: BytesMut::new(), + settings: PoolSettings { + pool_mode: pool_config.pool_mode.clone(), + shards: 
pool_config.shards.clone(), + user: user_info.clone(), + default_role: pool_config.default_role.clone(), + query_parser_enabled: pool_config.query_parser_enabled.clone(), + primary_reads_enabled: pool_config.primary_reads_enabled, + sharding_function: pool_config.sharding_function.clone(), + }, + }; - // Connect to the servers to make sure pool configuration is valid - // before setting it globally. - match pool.validate().await { - Ok(_) => (), - Err(err) => { - error!("Could not validate connection pool: {:?}", err); - return Err(err); + // Connect to the servers to make sure pool configuration is valid + // before setting it globally. + match pool.validate().await { + Ok(_) => (), + Err(err) => { + error!("Could not validate connection pool: {:?}", err); + return Err(err); + } + }; + new_pools.insert((pool_name.clone(), user_info.username.clone()), pool); } - }; + } - POOL.store(Arc::new(pool.clone())); + POOLS.store(Arc::new(new_pools.clone())); Ok(()) } @@ -474,7 +515,7 @@ impl ManageConnection for ServerPool { info!( "Creating a new connection to {:?} using user {:?}", self.address.name(), - self.user.name + self.user.username ); // Put a temporary process_id into the stats @@ -517,6 +558,20 @@ impl ManageConnection for ServerPool { } /// Get the connection pool -pub fn get_pool() -> ConnectionPool { - (*(*POOL.load())).clone() +pub fn get_pool(db: String, user: String) -> Option { + match get_all_pools().get(&(db, user)) { + Some(pool) => Some(pool.clone()), + None => None, + } +} + +pub fn get_number_of_addresses() -> usize { + get_all_pools() + .iter() + .map(|(_, pool)| pool.databases()) + .sum() +} + +pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> { + return (*(*POOLS.load())).clone(); } diff --git a/src/query_router.rs b/src/query_router.rs index 98a47702..d597b81e 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -8,7 +8,8 @@ use sqlparser::ast::Statement::{Query, StartTransaction}; use 
sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; -use crate::config::{get_config, Role}; +use crate::config::Role; +use crate::pool::{ConnectionPool, PoolSettings}; use crate::sharding::{Sharder, ShardingFunction}; /// Regexes used to parse custom commands. @@ -53,6 +54,8 @@ pub struct QueryRouter { /// Include the primary into the replica pool for reads. primary_reads_enabled: bool, + + pool_settings: PoolSettings, } impl QueryRouter { @@ -88,14 +91,13 @@ impl QueryRouter { } /// Create a new instance of the query router. Each client gets its own. - pub fn new() -> QueryRouter { - let config = get_config(); - + pub fn new(target_pool: ConnectionPool) -> QueryRouter { QueryRouter { active_shard: None, active_role: None, - query_parser_enabled: config.query_router.query_parser_enabled, - primary_reads_enabled: config.query_router.primary_reads_enabled, + query_parser_enabled: target_pool.settings.query_parser_enabled, + primary_reads_enabled: target_pool.settings.primary_reads_enabled, + pool_settings: target_pool.settings, } } @@ -130,15 +132,13 @@ impl QueryRouter { return None; } - let config = get_config(); - - let sharding_function = match config.query_router.sharding_function.as_ref() { + let sharding_function = match self.pool_settings.sharding_function.as_ref() { "pg_bigint_hash" => ShardingFunction::PgBigintHash, "sha1" => ShardingFunction::Sha1, _ => unreachable!(), }; - let default_server_role = match config.query_router.default_role.as_ref() { + let default_server_role = match self.pool_settings.default_role.as_ref() { "any" => None, "primary" => Some(Role::Primary), "replica" => Some(Role::Replica), @@ -196,7 +196,7 @@ impl QueryRouter { match command { Command::SetShardingKey => { - let sharder = Sharder::new(config.shards.len(), sharding_function); + let sharder = Sharder::new(self.pool_settings.shards.len(), sharding_function); let shard = sharder.shard(value.parse::().unwrap()); self.active_shard = Some(shard); value = 
shard.to_string(); @@ -204,7 +204,7 @@ impl QueryRouter { Command::SetShard => { self.active_shard = match value.to_ascii_uppercase().as_ref() { - "ANY" => Some(rand::random::<usize>() % config.shards.len()), + "ANY" => Some(rand::random::<usize>() % self.pool_settings.shards.len()), _ => Some(value.parse::<usize>().unwrap()), }; } @@ -233,7 +233,7 @@ impl QueryRouter { "default" => { self.active_role = default_server_role; - self.query_parser_enabled = config.query_router.query_parser_enabled; + self.query_parser_enabled = self.pool_settings.query_parser_enabled; self.active_role } @@ -250,7 +250,7 @@ impl QueryRouter { self.primary_reads_enabled = false; } else if value == "default" { debug!("Setting primary reads to default"); - self.primary_reads_enabled = config.query_router.primary_reads_enabled; + self.primary_reads_enabled = self.pool_settings.primary_reads_enabled; } } @@ -370,7 +370,7 @@ mod test { #[test] fn test_defaults() { QueryRouter::setup(); - let qr = QueryRouter::new(); + let qr = QueryRouter::new(ConnectionPool::default()); assert_eq!(qr.role(), None); } @@ -378,7 +378,7 @@ mod test { #[test] fn test_infer_role_replica() { QueryRouter::setup(); - let mut qr = QueryRouter::new(); + let mut qr = QueryRouter::new(ConnectionPool::default()); assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); assert_eq!(qr.query_parser_enabled(), true); @@ -402,7 +402,7 @@ mod test { #[test] fn test_infer_role_primary() { QueryRouter::setup(); - let mut qr = QueryRouter::new(); + let mut qr = QueryRouter::new(ConnectionPool::default()); let queries = vec![ simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"), @@ -421,7 +421,7 @@ mod test { #[test] fn test_infer_role_primary_reads_enabled() { QueryRouter::setup(); - let mut qr = QueryRouter::new(); + let mut qr = QueryRouter::new(ConnectionPool::default()); let query = simple_query("SELECT * FROM items WHERE id = 5"); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None); @@ -432,7 
+432,7 @@ mod test { #[test] fn test_infer_role_parse_prepared() { QueryRouter::setup(); - let mut qr = QueryRouter::new(); + let mut qr = QueryRouter::new(ConnectionPool::default()); qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); @@ -523,15 +523,15 @@ mod test { #[test] fn test_try_execute_command() { QueryRouter::setup(); - let mut qr = QueryRouter::new(); + let mut qr = QueryRouter::new(ConnectionPool::default()); // SetShardingKey let query = simple_query("SET SHARDING KEY TO 13"); assert_eq!( qr.try_execute_command(query), - Some((Command::SetShardingKey, String::from("1"))) + Some((Command::SetShardingKey, String::from("0"))) ); - assert_eq!(qr.shard(), 1); + assert_eq!(qr.shard(), 0); // SetShard let query = simple_query("SET SHARD TO '1'"); @@ -600,7 +600,7 @@ mod test { #[test] fn test_enable_query_parser() { QueryRouter::setup(); - let mut qr = QueryRouter::new(); + let mut qr = QueryRouter::new(ConnectionPool::default()); let query = simple_query("SET SERVER ROLE TO 'auto'"); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); diff --git a/src/server.rs b/src/server.rs index b20d1533..d21696f0 100644 --- a/src/server.rs +++ b/src/server.rs @@ -82,7 +82,7 @@ impl Server { trace!("Sending StartupMessage"); // StartupMessage - startup(&mut stream, &user.name, database).await?; + startup(&mut stream, &user.username, database).await?; let mut server_info = BytesMut::new(); let mut process_id: i32 = 0; @@ -127,7 +127,7 @@ impl Server { Err(_) => return Err(Error::SocketError), }; - md5_password(&mut stream, &user.name, &user.password, &salt[..]) + md5_password(&mut stream, &user.username, &user.password, &salt[..]) .await?; } diff --git a/src/stats.rs b/src/stats.rs index 59a03dc7..e0113395 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -6,7 +6,7 @@ use parking_lot::Mutex; use std::collections::HashMap; use 
tokio::sync::mpsc::{channel, Receiver, Sender}; -use crate::pool::get_pool; +use crate::pool::get_number_of_addresses; pub static REPORTER: Lazy> = Lazy::new(|| ArcSwap::from_pointee(Reporter::default())); @@ -331,8 +331,8 @@ impl Collector { tokio::time::interval(tokio::time::Duration::from_millis(STAT_PERIOD / 15)); loop { interval.tick().await; - let addresses = get_pool().databases(); - for address_id in 0..addresses { + let address_count = get_number_of_addresses(); + for address_id in 0..address_count { let _ = tx.try_send(Event { name: EventName::UpdateStats, value: 0, @@ -349,8 +349,8 @@ impl Collector { tokio::time::interval(tokio::time::Duration::from_millis(STAT_PERIOD)); loop { interval.tick().await; - let addresses = get_pool().databases(); - for address_id in 0..addresses { + let address_count = get_number_of_addresses(); + for address_id in 0..address_count { let _ = tx.try_send(Event { name: EventName::UpdateAverages, value: 0, diff --git a/tests/ruby/tests.rb b/tests/ruby/tests.rb index 983619f0..aaabe5e4 100644 --- a/tests/ruby/tests.rb +++ b/tests/ruby/tests.rb @@ -15,7 +15,7 @@ port: 6432, username: 'sharding_user', password: 'sharding_user', - database: 'rails_dev', + database: 'sharded_db', application_name: 'testing_pgcat', prepared_statements: false, # Transaction mode advisory_locks: false # Same @@ -117,7 +117,7 @@ def down # Test evil clients def poorly_behaved_client - conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/rails_dev?application_name=testing_pgcat") + conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat") conn.async_exec 'BEGIN' conn.async_exec 'SELECT 1'