Skip to content

Commit

Permalink
Support unnamed prepared statements (#635)
Browse files Browse the repository at this point in the history
* Add golang test suite to reproduce issue with unnamed parameterized prepared statements

* Allow caching of unnamed prepared statements

* Passthrough describe on portals

* Remove unneeded kill

* Update Dockerfile.ci with golang

* Move out update of Dockerfiles to separate PR
  • Loading branch information
sjuls authored Nov 9, 2023
1 parent b45c6b1 commit 7c37da2
Show file tree
Hide file tree
Showing 8 changed files with 327 additions and 19 deletions.
9 changes: 9 additions & 0 deletions .circleci/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ cd ../..
pip3 install -r tests/python/requirements.txt
python3 tests/python/tests.py || exit 1


#
# Go tests
# Starts its own pgcat server
#
pushd tests/go
/usr/local/go/bin/go test || exit 1
popd

start_pgcat "info"

# Admin tests
Expand Down
33 changes: 15 additions & 18 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1704,18 +1704,14 @@ where
/// and also the pool's statement cache. Add it to extended protocol data.
fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled
let client_given_name = match self.prepared_statements_enabled {
true => Parse::get_name(&message)?,
false => "".to_string(),
};

if client_given_name.is_empty() {
if !self.prepared_statements_enabled {
debug!("Anonymous parse message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_parse(message, None));
return Ok(());
}

let client_given_name = Parse::get_name(&message)?;
let parse: Parse = (&message).try_into()?;

// Compute the hash of the parse statement
Expand Down Expand Up @@ -1753,18 +1749,15 @@ where
/// saved in the client cache.
async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled
let client_given_name = match self.prepared_statements_enabled {
true => Bind::get_name(&message)?,
false => "".to_string(),
};

if client_given_name.is_empty() {
if !self.prepared_statements_enabled {
debug!("Anonymous bind message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_bind(message, None));
return Ok(());
}

let client_given_name = Bind::get_name(&message)?;

match self.prepared_statements.get(&client_given_name) {
Some((rewritten_parse, _)) => {
let message = Bind::rename(message, &rewritten_parse.name)?;
Expand Down Expand Up @@ -1807,19 +1800,23 @@ where
/// saved in the client cache.
async fn buffer_describe(&mut self, message: BytesMut) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled
let describe: Describe = match self.prepared_statements_enabled {
true => (&message).try_into()?,
false => Describe::empty_new(),
};

if describe.anonymous() {
if !self.prepared_statements_enabled {
debug!("Anonymous describe message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_describe(message, None));

return Ok(());
}

let describe: Describe = (&message).try_into()?;
if describe.target == 'P' {
debug!("Portal describe message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_describe(message, None));

return Ok(());
}

let client_given_name = describe.statement_name.clone();

match self.prepared_statements.get(&client_given_name) {
Expand Down
2 changes: 1 addition & 1 deletion src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,7 @@ pub struct Describe {

#[allow(dead_code)]
len: i32,
target: char,
pub target: char,
pub statement_name: String,
}

Expand Down
5 changes: 5 additions & 0 deletions tests/go/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module pgcat

go 1.21

require github.com/lib/pq v1.10.9
2 changes: 2 additions & 0 deletions tests/go/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
162 changes: 162 additions & 0 deletions tests/go/pgcat.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#
# PgCat config example.
#

#
# General pooler settings
[general]
# What IP to run on, 0.0.0.0 means accessible from everywhere.
host = "0.0.0.0"

# Port to run on, same as PgBouncer used in this example.
port = "${PORT}"

# Whether to enable prometheus exporter or not.
enable_prometheus_exporter = true

# Port at which prometheus exporter listens on.
prometheus_exporter_port = 9930

# How long to wait before aborting a server connection (ms).
connect_timeout = 1000

# How much time to give the health check query to return with a result (ms).
healthcheck_timeout = 1000

# How long to keep connection available for immediate re-use, without running a healthcheck query on it
healthcheck_delay = 30000

# How much time to give clients during shutdown before forcibly killing client connections (ms).
shutdown_timeout = 5000

# For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # Seconds

# If we should log client connections
log_client_connections = false

# If we should log client disconnections
log_client_disconnections = false

# Reload config automatically if it changes.
autoreload = 15000

server_round_robin = false

# TLS
tls_certificate = "../../.circleci/server.cert"
tls_private_key = "../../.circleci/server.key"

# Credentials to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
admin_username = "admin_user"
admin_password = "admin_pass"

# pool
# configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting
# For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db"
[pools.sharded_db]
# Pool mode (see PgBouncer docs for more).
# session: one server connection per connected client
# transaction: one server connection per client transaction
pool_mode = "transaction"

# If the client doesn't specify, route traffic to
# this role by default.
#
# any: round-robin between primary and replicas,
# replica: round-robin between replicas only without touching the primary,
# primary: all queries go to the primary unless otherwise specified.
default_role = "any"

# Query parser. If enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
# we'll direct it to the primary.
query_parser_enabled = true

# If the query parser is enabled and this setting is enabled, we'll attempt to
# infer the role from the query itself.
query_parser_read_write_splitting = true

# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitely selected with our custom protocol.
primary_reads_enabled = true

# So what if you wanted to implement a different hashing function,
# or you've already built one and you want this pooler to use it?
#
# Current options:
#
# pg_bigint_hash: PARTITION BY HASH (Postgres hashing function)
# sha1: A hashing function based on SHA1
#
sharding_function = "pg_bigint_hash"

# Prepared statements cache size.
prepared_statements_cache_size = 500

# Credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
username = "sharding_user"
password = "sharding_user"
# Maximum number of server connections that can be established for this user
# The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users.
pool_size = 5
statement_timeout = 0


[pools.sharded_db.users.1]
username = "other_user"
password = "other_user"
pool_size = 21
statement_timeout = 30000

# Shard 0
[pools.sharded_db.shards.0]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
# Database name (e.g. "postgres")
database = "shard0"

[pools.sharded_db.shards.1]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard1"

[pools.sharded_db.shards.2]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard2"


[pools.simple_db]
pool_mode = "session"
default_role = "primary"
query_parser_enabled = true
query_parser_read_write_splitting = true
primary_reads_enabled = true
sharding_function = "pg_bigint_hash"

[pools.simple_db.users.0]
username = "simple_user"
password = "simple_user"
pool_size = 5
statement_timeout = 30000

[pools.simple_db.shards.0]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
database = "some_db"
52 changes: 52 additions & 0 deletions tests/go/prepared_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package pgcat

import (
"context"
"database/sql"
"fmt"
_ "github.com/lib/pq"
"testing"
)

func Test(t *testing.T) {
t.Cleanup(setup(t))
t.Run("Named parameterized prepared statement works", namedParameterizedPreparedStatement)
t.Run("Unnamed parameterized prepared statement works", unnamedParameterizedPreparedStatement)
}

func namedParameterizedPreparedStatement(t *testing.T) {
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port))
if err != nil {
t.Fatalf("could not open connection: %+v", err)
}

stmt, err := db.Prepare("SELECT $1")

if err != nil {
t.Fatalf("could not prepare: %+v", err)
}

for i := 0; i < 100; i++ {
rows, err := stmt.Query(1)
if err != nil {
t.Fatalf("could not query: %+v", err)
}
_ = rows.Close()
}
}

func unnamedParameterizedPreparedStatement(t *testing.T) {
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port))
if err != nil {
t.Fatalf("could not open connection: %+v", err)
}

for i := 0; i < 100; i++ {
// Under the hood QueryContext generates an unnamed parameterized prepared statement
rows, err := db.QueryContext(context.Background(), "SELECT $1", 1)
if err != nil {
t.Fatalf("could not query: %+v", err)
}
_ = rows.Close()
}
}
Loading

0 comments on commit 7c37da2

Please sign in to comment.