Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle and track startup parameters #478

Merged
merged 22 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/admin.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::pool::BanReason;
use crate::server::ServerParameters;
use crate::stats::pool::PoolStats;
use bytes::{Buf, BufMut, BytesMut};
use log::{error, info, trace};
Expand All @@ -17,16 +18,16 @@ use crate::pool::ClientServerMap;
use crate::pool::{get_all_pools, get_pool};
use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState};

pub fn generate_server_info_for_admin() -> BytesMut {
let mut server_info = BytesMut::new();
pub fn generate_server_parameters_for_admin() -> ServerParameters {
let mut server_parameters = ServerParameters::new();

server_info.put(server_parameter_message("application_name", ""));
server_info.put(server_parameter_message("client_encoding", "UTF8"));
server_info.put(server_parameter_message("server_encoding", "UTF8"));
server_info.put(server_parameter_message("server_version", VERSION));
server_info.put(server_parameter_message("DateStyle", "ISO, MDY"));
server_parameters.set_param("application_name".to_string(), "".to_string(), true);
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), true);
server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), true);
server_parameters.set_param("server_version".to_string(), VERSION.to_string(), true);
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), true);

server_info
server_parameters
}

/// Handle admin client.
Expand Down
107 changes: 54 additions & 53 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender;

use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::admin::{generate_server_parameters_for_admin, handle_admin};
use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
Expand All @@ -22,7 +22,7 @@ use crate::messages::*;
use crate::plugins::PluginOutput;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::server::{Server, ServerParameters};
use crate::stats::{ClientStats, ServerStats};
use crate::tls::Tls;

Expand Down Expand Up @@ -96,8 +96,8 @@ pub struct Client<S, T> {
/// Postgres user for this client (This comes from the user in the connection string)
username: String,

/// Application name for this client (defaults to pgcat)
application_name: String,
/// Server startup and session parameters that we're going to track
server_parameters: ServerParameters,

/// Used to notify clients about an impending shutdown
shutdown: Receiver<()>,
Expand Down Expand Up @@ -502,7 +502,7 @@ where
};

// Authenticate admin user.
let (transaction_mode, server_info) = if admin {
let (transaction_mode, mut server_parameters) = if admin {
let config = get_config();

// Compare server and client hashes.
Expand All @@ -521,7 +521,7 @@ where
return Err(error);
}

(false, generate_server_info_for_admin())
(false, generate_server_parameters_for_admin())
}
// Authenticate normal user.
else {
Expand Down Expand Up @@ -654,13 +654,16 @@ where
}
}

(transaction_mode, pool.server_info())
(transaction_mode, pool.server_parameters())
};

// Update the parameters to merge what the application sent and what's originally on the server
server_parameters.set_from_hashmap(&parameters, false);

debug!("Password authentication successful");

auth_ok(&mut write).await?;
write_all(&mut write, server_info).await?;
write_all(&mut write, (&server_parameters).into()).await?;
backend_key_data(&mut write, process_id, secret_key).await?;
ready_for_query(&mut write).await?;

Expand Down Expand Up @@ -690,7 +693,7 @@ where
last_server_stats: None,
pool_name: pool_name.clone(),
username: username.clone(),
application_name: application_name.to_string(),
server_parameters,
shutdown,
connected_to_server: false,
prepared_statements: HashMap::new(),
Expand Down Expand Up @@ -725,7 +728,7 @@ where
last_server_stats: None,
pool_name: String::from("undefined"),
username: String::from("undefined"),
application_name: String::from("undefined"),
server_parameters: ServerParameters::new(),
shutdown,
connected_to_server: false,
prepared_statements: HashMap::new(),
Expand Down Expand Up @@ -774,8 +777,11 @@ where
let mut prepared_statement = None;
let mut will_prepare = false;

let client_identifier =
ClientIdentifier::new(&self.application_name, &self.username, &self.pool_name);
let client_identifier = ClientIdentifier::new(
&self.server_parameters.get_application_name(),
&self.username,
&self.pool_name,
);

// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
Expand Down Expand Up @@ -1115,10 +1121,7 @@ where
server.address()
);

// TODO: investigate other parameters and set them too.

// Set application_name.
server.set_name(&self.application_name).await?;
server.sync_parameters(&self.server_parameters).await?;

let mut initial_message = Some(message);

Expand Down Expand Up @@ -1296,7 +1299,9 @@ where
if !server.in_transaction() {
// Report transaction executed statistics.
self.stats.transaction();
server.stats().transaction(&self.application_name);
server
.stats()
.transaction(&self.server_parameters.get_application_name());

// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
Expand Down Expand Up @@ -1446,7 +1451,9 @@ where

if !server.in_transaction() {
self.stats.transaction();
server.stats().transaction(&self.application_name);
server
.stats()
.transaction(&self.server_parameters.get_application_name());

// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
Expand Down Expand Up @@ -1495,7 +1502,9 @@ where

if !server.in_transaction() {
self.stats.transaction();
server.stats().transaction(&self.application_name);
server
.stats()
.transaction(self.server_parameters.get_application_name());

// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
Expand Down Expand Up @@ -1547,7 +1556,9 @@ where

Err(Error::ClientError(format!(
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
self.pool_name, self.username, self.application_name
self.pool_name,
self.username,
self.server_parameters.get_application_name()
)))
}
}
Expand Down Expand Up @@ -1704,7 +1715,7 @@ where
client_stats.query();
server.stats().query(
Instant::now().duration_since(query_start).as_millis() as u64,
&self.application_name,
&self.server_parameters.get_application_name(),
);

Ok(())
Expand Down Expand Up @@ -1733,38 +1744,18 @@ where
pool: &ConnectionPool,
client_stats: &ClientStats,
) -> Result<BytesMut, Error> {
if pool.settings.user.statement_timeout > 0 {
match tokio::time::timeout(
tokio::time::Duration::from_millis(pool.settings.user.statement_timeout),
server.recv(),
)
.await
{
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
)
.await?;
Err(err)
}
},
Err(_) => {
error!(
"Statement timeout while talking to {:?} with user {}",
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
}
} else {
match server.recv().await {
let statement_timeout_duration = match pool.settings.user.statement_timeout {
0 => tokio::time::Duration::MAX,
timeout => tokio::time::Duration::from_millis(timeout),
};

match tokio::time::timeout(
statement_timeout_duration,
server.recv(Some(&mut self.server_parameters)),
levkk marked this conversation as resolved.
Show resolved Hide resolved
)
.await
{
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
Expand All @@ -1775,6 +1766,16 @@ where
.await?;
Err(err)
}
},
Err(_) => {
error!(
"Statement timeout while talking to {:?} with user {}",
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
}
}
Expand Down
19 changes: 19 additions & 0 deletions src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ where
bytes.put_slice(user.as_bytes());
bytes.put_u8(0);

// Application name
bytes.put(&b"application_name\0"[..]);
bytes.put_slice(&b"pgcat\0"[..]);

// Database
bytes.put(&b"database\0"[..]);
bytes.put_slice(database.as_bytes());
Expand Down Expand Up @@ -731,6 +735,21 @@ impl BytesMutReader for Cursor<&BytesMut> {
}
}

impl BytesMutReader for BytesMut {
/// Should only be used when reading strings from the message protocol.
/// Can be used to read multiple strings from the same message which are separated by the null byte
fn read_string(&mut self) -> Result<String, Error> {
let null_index = self.iter().position(|&byte| byte == b'\0');

match null_index {
Some(index) => {
let string_bytes = self.split_to(index + 1);
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
}
None => return Err(Error::ParseBytesError("Could not read string".to_string())),
}
}
}
/// Parse (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
Expand Down
2 changes: 1 addition & 1 deletion src/mirrors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl MirroredClient {
}

// Incoming data from server (we read to clear the socket buffer and discard the data)
recv_result = server.recv() => {
recv_result = server.recv(None) => {
match recv_result {
Ok(message) => trace!("Received from mirror: {} {:?}", String::from_utf8_lossy(&message[..]), address.clone()),
Err(err) => {
Expand Down
26 changes: 12 additions & 14 deletions src/pool.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use arc_swap::ArcSwap;
use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
use bytes::{BufMut, BytesMut};
use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
Expand All @@ -25,7 +24,7 @@ use crate::errors::Error;

use crate::auth_passthrough::AuthPassthrough;
use crate::plugins::prewarmer;
use crate::server::Server;
use crate::server::{Server, ServerParameters};
use crate::sharding::ShardingFunction;
use crate::stats::{AddressStats, ClientStats, ServerStats};

Expand Down Expand Up @@ -196,10 +195,10 @@ pub struct ConnectionPool {
/// that should not be queried.
banlist: BanList,

/// The server information (K messages) have to be passed to the
/// The server information has to be passed to the
/// clients on startup. We pre-connect to all shards and replicas
/// on pool creation and save the K messages here.
server_info: Arc<RwLock<BytesMut>>,
/// on pool creation and save the startup parameters here.
original_server_parameters: Arc<RwLock<ServerParameters>>,

/// Pool configuration.
pub settings: PoolSettings,
Expand Down Expand Up @@ -445,7 +444,7 @@ impl ConnectionPool {
addresses,
banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())),
original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
auth_hash: pool_auth_hash,
settings: PoolSettings {
pool_mode: match user.pool_mode {
Expand Down Expand Up @@ -528,7 +527,7 @@ impl ConnectionPool {
for server in 0..self.servers(shard) {
let databases = self.databases.clone();
let validated = Arc::clone(&validated);
let pool_server_info = Arc::clone(&self.server_info);
let pool_server_parameters = Arc::clone(&self.original_server_parameters);

let task = tokio::task::spawn(async move {
let connection = match databases[shard][server].get().await {
Expand All @@ -541,11 +540,10 @@ impl ConnectionPool {

let proxy = connection;
let server = &*proxy;
let server_info = server.server_info();
let server_parameters: ServerParameters = server.server_parameters();

let mut guard = pool_server_info.write();
guard.clear();
guard.put(server_info.clone());
let mut guard = pool_server_parameters.write();
*guard = server_parameters;
validated.store(true, Ordering::Relaxed);
});

Expand All @@ -557,7 +555,7 @@ impl ConnectionPool {

// TODO: compare server information to make sure
// all shards are running identical configurations.
if self.server_info.read().is_empty() {
if !self.validated() {
error!("Could not validate connection pool");
return Err(Error::AllServersDown);
}
Expand Down Expand Up @@ -917,8 +915,8 @@ impl ConnectionPool {
&self.addresses[shard][server]
}

pub fn server_info(&self) -> BytesMut {
self.server_info.read().clone()
pub fn server_parameters(&self) -> ServerParameters {
self.original_server_parameters.read().clone()
}

fn busy_connection_count(&self, address: &Address) -> u32 {
Expand Down
Loading