Skip to content

Commit

Permalink
Handle and track startup parameters (#478)
Browse files Browse the repository at this point in the history
* User server parameters struct instead of server info bytesmut

* Refactor to use hashmap for all params and add server parameters to client

* Sync parameters on client server checkout

* minor refactor

* update client side parameters when changed

* Move the SET statement logic from the C packet to the S packet.

* trigger build

* revert validation changes

* remove comment

* Try fix

* Reset cleanup state after sync

* fix server version test

* Track application name through client life for stats

* Add tests

* minor refactoring

* fmt

* fix

* fmt
  • Loading branch information
zainkabani authored Aug 10, 2023
1 parent 9ab1285 commit f94ce97
Show file tree
Hide file tree
Showing 8 changed files with 308 additions and 123 deletions.
17 changes: 9 additions & 8 deletions src/admin.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::pool::BanReason;
use crate::server::ServerParameters;
use crate::stats::pool::PoolStats;
use bytes::{Buf, BufMut, BytesMut};
use log::{error, info, trace};
Expand All @@ -17,16 +18,16 @@ use crate::pool::ClientServerMap;
use crate::pool::{get_all_pools, get_pool};
use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState};

pub fn generate_server_info_for_admin() -> BytesMut {
let mut server_info = BytesMut::new();
pub fn generate_server_parameters_for_admin() -> ServerParameters {
let mut server_parameters = ServerParameters::new();

server_info.put(server_parameter_message("application_name", ""));
server_info.put(server_parameter_message("client_encoding", "UTF8"));
server_info.put(server_parameter_message("server_encoding", "UTF8"));
server_info.put(server_parameter_message("server_version", VERSION));
server_info.put(server_parameter_message("DateStyle", "ISO, MDY"));
server_parameters.set_param("application_name".to_string(), "".to_string(), true);
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), true);
server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), true);
server_parameters.set_param("server_version".to_string(), VERSION.to_string(), true);
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), true);

server_info
server_parameters
}

/// Handle admin client.
Expand Down
107 changes: 54 additions & 53 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender;

use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::admin::{generate_server_parameters_for_admin, handle_admin};
use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
Expand All @@ -22,7 +22,7 @@ use crate::messages::*;
use crate::plugins::PluginOutput;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::server::{Server, ServerParameters};
use crate::stats::{ClientStats, ServerStats};
use crate::tls::Tls;

Expand Down Expand Up @@ -96,8 +96,8 @@ pub struct Client<S, T> {
/// Postgres user for this client (This comes from the user in the connection string)
username: String,

/// Application name for this client (defaults to pgcat)
application_name: String,
/// Server startup and session parameters that we're going to track
server_parameters: ServerParameters,

/// Used to notify clients about an impending shutdown
shutdown: Receiver<()>,
Expand Down Expand Up @@ -502,7 +502,7 @@ where
};

// Authenticate admin user.
let (transaction_mode, server_info) = if admin {
let (transaction_mode, mut server_parameters) = if admin {
let config = get_config();

// Compare server and client hashes.
Expand All @@ -521,7 +521,7 @@ where
return Err(error);
}

(false, generate_server_info_for_admin())
(false, generate_server_parameters_for_admin())
}
// Authenticate normal user.
else {
Expand Down Expand Up @@ -654,13 +654,16 @@ where
}
}

(transaction_mode, pool.server_info())
(transaction_mode, pool.server_parameters())
};

// Update the parameters to merge what the application sent and what's originally on the server
server_parameters.set_from_hashmap(&parameters, false);

debug!("Password authentication successful");

auth_ok(&mut write).await?;
write_all(&mut write, server_info).await?;
write_all(&mut write, (&server_parameters).into()).await?;
backend_key_data(&mut write, process_id, secret_key).await?;
ready_for_query(&mut write).await?;

Expand Down Expand Up @@ -690,7 +693,7 @@ where
last_server_stats: None,
pool_name: pool_name.clone(),
username: username.clone(),
application_name: application_name.to_string(),
server_parameters,
shutdown,
connected_to_server: false,
prepared_statements: HashMap::new(),
Expand Down Expand Up @@ -725,7 +728,7 @@ where
last_server_stats: None,
pool_name: String::from("undefined"),
username: String::from("undefined"),
application_name: String::from("undefined"),
server_parameters: ServerParameters::new(),
shutdown,
connected_to_server: false,
prepared_statements: HashMap::new(),
Expand Down Expand Up @@ -774,8 +777,11 @@ where
let mut prepared_statement = None;
let mut will_prepare = false;

let client_identifier =
ClientIdentifier::new(&self.application_name, &self.username, &self.pool_name);
let client_identifier = ClientIdentifier::new(
&self.server_parameters.get_application_name(),
&self.username,
&self.pool_name,
);

// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
Expand Down Expand Up @@ -1115,10 +1121,7 @@ where
server.address()
);

// TODO: investigate other parameters and set them too.

// Set application_name.
server.set_name(&self.application_name).await?;
server.sync_parameters(&self.server_parameters).await?;

let mut initial_message = Some(message);

Expand Down Expand Up @@ -1296,7 +1299,9 @@ where
if !server.in_transaction() {
// Report transaction executed statistics.
self.stats.transaction();
server.stats().transaction(&self.application_name);
server
.stats()
.transaction(&self.server_parameters.get_application_name());

// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
Expand Down Expand Up @@ -1446,7 +1451,9 @@ where

if !server.in_transaction() {
self.stats.transaction();
server.stats().transaction(&self.application_name);
server
.stats()
.transaction(&self.server_parameters.get_application_name());

// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
Expand Down Expand Up @@ -1495,7 +1502,9 @@ where

if !server.in_transaction() {
self.stats.transaction();
server.stats().transaction(&self.application_name);
server
.stats()
.transaction(self.server_parameters.get_application_name());

// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
Expand Down Expand Up @@ -1547,7 +1556,9 @@ where

Err(Error::ClientError(format!(
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
self.pool_name, self.username, self.application_name
self.pool_name,
self.username,
self.server_parameters.get_application_name()
)))
}
}
Expand Down Expand Up @@ -1704,7 +1715,7 @@ where
client_stats.query();
server.stats().query(
Instant::now().duration_since(query_start).as_millis() as u64,
&self.application_name,
&self.server_parameters.get_application_name(),
);

Ok(())
Expand Down Expand Up @@ -1733,38 +1744,18 @@ where
pool: &ConnectionPool,
client_stats: &ClientStats,
) -> Result<BytesMut, Error> {
if pool.settings.user.statement_timeout > 0 {
match tokio::time::timeout(
tokio::time::Duration::from_millis(pool.settings.user.statement_timeout),
server.recv(),
)
.await
{
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
)
.await?;
Err(err)
}
},
Err(_) => {
error!(
"Statement timeout while talking to {:?} with user {}",
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
}
} else {
match server.recv().await {
let statement_timeout_duration = match pool.settings.user.statement_timeout {
0 => tokio::time::Duration::MAX,
timeout => tokio::time::Duration::from_millis(timeout),
};

match tokio::time::timeout(
statement_timeout_duration,
server.recv(Some(&mut self.server_parameters)),
)
.await
{
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
Expand All @@ -1775,6 +1766,16 @@ where
.await?;
Err(err)
}
},
Err(_) => {
error!(
"Statement timeout while talking to {:?} with user {}",
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
}
}
Expand Down
19 changes: 19 additions & 0 deletions src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ where
bytes.put_slice(user.as_bytes());
bytes.put_u8(0);

// Application name
bytes.put(&b"application_name\0"[..]);
bytes.put_slice(&b"pgcat\0"[..]);

// Database
bytes.put(&b"database\0"[..]);
bytes.put_slice(database.as_bytes());
Expand Down Expand Up @@ -731,6 +735,21 @@ impl BytesMutReader for Cursor<&BytesMut> {
}
}

impl BytesMutReader for BytesMut {
/// Should only be used when reading strings from the message protocol.
/// Can be used to read multiple strings from the same message which are separated by the null byte
fn read_string(&mut self) -> Result<String, Error> {
let null_index = self.iter().position(|&byte| byte == b'\0');

match null_index {
Some(index) => {
let string_bytes = self.split_to(index + 1);
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
}
None => return Err(Error::ParseBytesError("Could not read string".to_string())),
}
}
}
/// Parse (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
Expand Down
2 changes: 1 addition & 1 deletion src/mirrors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl MirroredClient {
}

// Incoming data from server (we read to clear the socket buffer and discard the data)
recv_result = server.recv() => {
recv_result = server.recv(None) => {
match recv_result {
Ok(message) => trace!("Received from mirror: {} {:?}", String::from_utf8_lossy(&message[..]), address.clone()),
Err(err) => {
Expand Down
26 changes: 12 additions & 14 deletions src/pool.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use arc_swap::ArcSwap;
use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
use bytes::{BufMut, BytesMut};
use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
Expand All @@ -25,7 +24,7 @@ use crate::errors::Error;

use crate::auth_passthrough::AuthPassthrough;
use crate::plugins::prewarmer;
use crate::server::Server;
use crate::server::{Server, ServerParameters};
use crate::sharding::ShardingFunction;
use crate::stats::{AddressStats, ClientStats, ServerStats};

Expand Down Expand Up @@ -196,10 +195,10 @@ pub struct ConnectionPool {
/// that should not be queried.
banlist: BanList,

/// The server information (K messages) have to be passed to the
/// The server information has to be passed to the
/// clients on startup. We pre-connect to all shards and replicas
/// on pool creation and save the K messages here.
server_info: Arc<RwLock<BytesMut>>,
/// on pool creation and save the startup parameters here.
original_server_parameters: Arc<RwLock<ServerParameters>>,

/// Pool configuration.
pub settings: PoolSettings,
Expand Down Expand Up @@ -445,7 +444,7 @@ impl ConnectionPool {
addresses,
banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())),
original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
auth_hash: pool_auth_hash,
settings: PoolSettings {
pool_mode: match user.pool_mode {
Expand Down Expand Up @@ -528,7 +527,7 @@ impl ConnectionPool {
for server in 0..self.servers(shard) {
let databases = self.databases.clone();
let validated = Arc::clone(&validated);
let pool_server_info = Arc::clone(&self.server_info);
let pool_server_parameters = Arc::clone(&self.original_server_parameters);

let task = tokio::task::spawn(async move {
let connection = match databases[shard][server].get().await {
Expand All @@ -541,11 +540,10 @@ impl ConnectionPool {

let proxy = connection;
let server = &*proxy;
let server_info = server.server_info();
let server_parameters: ServerParameters = server.server_parameters();

let mut guard = pool_server_info.write();
guard.clear();
guard.put(server_info.clone());
let mut guard = pool_server_parameters.write();
*guard = server_parameters;
validated.store(true, Ordering::Relaxed);
});

Expand All @@ -557,7 +555,7 @@ impl ConnectionPool {

// TODO: compare server information to make sure
// all shards are running identical configurations.
if self.server_info.read().is_empty() {
if !self.validated() {
error!("Could not validate connection pool");
return Err(Error::AllServersDown);
}
Expand Down Expand Up @@ -917,8 +915,8 @@ impl ConnectionPool {
&self.addresses[shard][server]
}

pub fn server_info(&self) -> BytesMut {
self.server_info.read().clone()
pub fn server_parameters(&self) -> ServerParameters {
self.original_server_parameters.read().clone()
}

fn busy_connection_count(&self, address: &Address) -> u32 {
Expand Down
Loading

0 comments on commit f94ce97

Please sign in to comment.