Skip to content

Commit

Permalink
Prevent clients from sticking to old pools after config update (#113)
Browse files Browse the repository at this point in the history
* Re-acquire pool at the beginning of Protocol loop

* Fix query router + add tests for recycling behavior
  • Loading branch information
drdrsh authored Aug 9, 2022
1 parent 3719c22 commit 7592339
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 38 deletions.
68 changes: 43 additions & 25 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::config::get_config;
use crate::constants::*;
use crate::errors::Error;
use crate::messages::*;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::pool::{get_pool, ClientServerMap};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::stats::{get_reporter, Reporter};
Expand Down Expand Up @@ -73,8 +73,13 @@ pub struct Client<S, T> {
/// Last server process id we talked to.
last_server_id: Option<i32>,

target_pool: ConnectionPool,
/// Name of the server pool for this client (This comes from the database name in the connection string)
target_pool_name: String,

/// Postgres user for this client (This comes from the user in the connection string)
target_user_name: String,

/// Used to notify clients about an impending shutdown
shutdown_event_receiver: Receiver<()>,
}

Expand Down Expand Up @@ -305,19 +310,19 @@ where

trace!("Got StartupMessage");
let parameters = parse_startup(bytes.clone())?;
let database = match parameters.get("database") {
let target_pool_name = match parameters.get("database") {
Some(db) => db,
None => return Err(Error::ClientError),
};

let user = match parameters.get("user") {
let target_user_name = match parameters.get("user") {
Some(user) => user,
None => return Err(Error::ClientError),
};

let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.filter(|db| *db == &target_pool_name)
.count()
== 1;

Expand Down Expand Up @@ -352,31 +357,28 @@ where
Err(_) => return Err(Error::SocketError),
};

let (target_pool, transaction_mode, server_info) = if admin {
let (transaction_mode, server_info) = if admin {
let correct_user = config.general.admin_username.as_str();
let correct_password = config.general.admin_password.as_str();

// Compare server and client hashes.
let password_hash = md5_hash_password(correct_user, correct_password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, user).await?;
wrong_password(&mut write, target_user_name).await?;
return Err(Error::ClientError);
}
(
ConnectionPool::default(),
false,
generate_server_info_for_admin(),
)

(false, generate_server_info_for_admin())
} else {
let target_pool = match get_pool(database.clone(), user.clone()) {
let target_pool = match get_pool(target_pool_name.clone(), target_user_name.clone()) {
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
database, user
target_pool_name, target_user_name
),
)
.await?;
Expand All @@ -387,14 +389,14 @@ where
let server_info = target_pool.server_info();
// Compare server and client hashes.
let correct_password = target_pool.settings.user.password.as_str();
let password_hash = md5_hash_password(user, correct_password, &salt);
let password_hash = md5_hash_password(&target_user_name, correct_password, &salt);

if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, user).await?;
wrong_password(&mut write, &target_user_name).await?;
return Err(Error::ClientError);
}
(target_pool, transaction_mode, server_info)
(transaction_mode, server_info)
};

debug!("Password authentication successful");
Expand Down Expand Up @@ -424,7 +426,8 @@ where
admin: admin,
last_address_id: None,
last_server_id: None,
target_pool: target_pool,
target_pool_name: target_pool_name.clone(),
target_user_name: target_user_name.clone(),
shutdown_event_receiver: shutdown_event_receiver,
});
}
Expand Down Expand Up @@ -455,7 +458,8 @@ where
admin: false,
last_address_id: None,
last_server_id: None,
target_pool: ConnectionPool::default(),
target_pool_name: String::from("undefined"),
target_user_name: String::from("undefined"),
shutdown_event_receiver: shutdown_event_receiver,
});
}
Expand Down Expand Up @@ -494,7 +498,7 @@ where

// The query router determines where the query is going to go,
// e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new(self.target_pool.clone());
let mut query_router = QueryRouter::new();
let mut round_robin = 0;

// Our custom protocol loop.
Expand All @@ -520,11 +524,6 @@ where
message_result = read_message(&mut self.read) => message_result?
};

// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool = self.target_pool.clone();

// Avoid taking a server if the client just wants to disconnect.
if message[0] as char == 'X' {
debug!("Client disconnecting");
Expand All @@ -538,6 +537,25 @@ where
continue;
}

// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool =
match get_pool(self.target_pool_name.clone(), self.target_user_name.clone()) {
Some(pool) => pool,
None => {
error_response(
&mut self.write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
self.target_pool_name, self.target_user_name
),
)
.await?;
return Err(Error::ClientError);
}
};
query_router.update_pool_settings(pool.settings.clone());
let current_shard = query_router.shard();

// Handle all custom protocol commands, if any.
Expand Down
69 changes: 57 additions & 12 deletions src/query_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;

use crate::config::Role;
use crate::pool::{ConnectionPool, PoolSettings};
use crate::pool::PoolSettings;
use crate::sharding::{Sharder, ShardingFunction};

/// Regexes used to parse custom commands.
Expand Down Expand Up @@ -91,16 +91,20 @@ impl QueryRouter {
}

/// Create a new instance of the query router. Each client gets its own.
pub fn new(target_pool: ConnectionPool) -> QueryRouter {
pub fn new() -> QueryRouter {
QueryRouter {
active_shard: None,
active_role: None,
query_parser_enabled: target_pool.settings.query_parser_enabled,
primary_reads_enabled: target_pool.settings.primary_reads_enabled,
pool_settings: target_pool.settings,
query_parser_enabled: false,
primary_reads_enabled: false,
pool_settings: PoolSettings::default(),
}
}

pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) {
self.pool_settings = pool_settings;
}

/// Try to parse a command and execute it.
pub fn try_execute_command(&mut self, mut buf: BytesMut) -> Option<(Command, String)> {
let code = buf.get_u8() as char;
Expand Down Expand Up @@ -363,22 +367,24 @@ impl QueryRouter {

#[cfg(test)]
mod test {
use std::collections::HashMap;

use super::*;
use crate::messages::simple_query;
use bytes::BufMut;

#[test]
fn test_defaults() {
QueryRouter::setup();
let qr = QueryRouter::new(ConnectionPool::default());
let qr = QueryRouter::new();

assert_eq!(qr.role(), None);
}

#[test]
fn test_infer_role_replica() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new();
assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert_eq!(qr.query_parser_enabled(), true);

Expand All @@ -402,7 +408,7 @@ mod test {
#[test]
fn test_infer_role_primary() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new();

let queries = vec![
simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"),
Expand All @@ -421,7 +427,7 @@ mod test {
#[test]
fn test_infer_role_primary_reads_enabled() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new();
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None);

Expand All @@ -432,7 +438,7 @@ mod test {
#[test]
fn test_infer_role_parse_prepared() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new();
qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);

Expand Down Expand Up @@ -523,7 +529,7 @@ mod test {
#[test]
fn test_try_execute_command() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new();

// SetShardingKey
let query = simple_query("SET SHARDING KEY TO 13");
Expand Down Expand Up @@ -600,7 +606,7 @@ mod test {
#[test]
fn test_enable_query_parser() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new();
let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);

Expand All @@ -621,4 +627,43 @@ mod test {
assert!(qr.try_execute_command(query) != None);
assert!(qr.query_parser_enabled());
}

#[test]
fn test_update_from_pool_settings() {
QueryRouter::setup();

let pool_settings = PoolSettings {
pool_mode: "transaction".to_string(),
shards: HashMap::default(),
user: crate::config::User::default(),
default_role: Role::Replica.to_string(),
query_parser_enabled: true,
primary_reads_enabled: false,
sharding_function: "pg_bigint_hash".to_string(),
};
let mut qr = QueryRouter::new();
assert_eq!(qr.active_role, None);
assert_eq!(qr.active_shard, None);
assert_eq!(qr.query_parser_enabled, false);
assert_eq!(qr.primary_reads_enabled, false);

// Internal state must not be changed due to this, only defaults
qr.update_pool_settings(pool_settings.clone());

assert_eq!(qr.active_role, None);
assert_eq!(qr.active_shard, None);
assert_eq!(qr.query_parser_enabled, false);
assert_eq!(qr.primary_reads_enabled, false);

let q1 = simple_query("SET SERVER ROLE TO 'primary'");
assert!(qr.try_execute_command(q1) != None);
assert_eq!(qr.active_role.unwrap(), Role::Primary);

let q2 = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(q2) != None);
assert_eq!(
qr.active_role.unwrap().to_string(),
pool_settings.clone().default_role
);
}
}
3 changes: 2 additions & 1 deletion tests/ruby/.ruby-version
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
2.7.1
3.0.0

1 change: 1 addition & 0 deletions tests/ruby/Gemfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ source "https://rubygems.org"
gem "pg"
gem "activerecord"
gem "rubocop"
gem "toml", "~> 0.3.0"
5 changes: 5 additions & 0 deletions tests/ruby/Gemfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ GEM
parallel (1.22.1)
parser (3.1.2.0)
ast (~> 2.4.1)
parslet (2.0.0)
pg (1.3.2)
rainbow (3.1.1)
regexp_parser (2.3.1)
Expand All @@ -35,17 +36,21 @@ GEM
rubocop-ast (1.17.0)
parser (>= 3.1.1.0)
ruby-progressbar (1.11.0)
toml (0.3.0)
parslet (>= 1.8.0, < 3.0.0)
tzinfo (2.0.4)
concurrent-ruby (~> 1.0)
unicode-display_width (2.1.0)

PLATFORMS
arm64-darwin-21
x86_64-linux

DEPENDENCIES
activerecord
pg
rubocop
toml (~> 0.3.0)

BUNDLED WITH
2.3.7
Loading

0 comments on commit 7592339

Please sign in to comment.