Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support unnamed prepared statements #635

Merged
merged 7 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .circleci/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ cd ../..
pip3 install -r tests/python/requirements.txt
python3 tests/python/tests.py || exit 1


#
# Go tests
# Starts its own pgcat server
#
pushd tests/go
/usr/local/go/bin/go test || exit 1
popd

start_pgcat "info"

# Admin tests
Expand Down
33 changes: 15 additions & 18 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1704,18 +1704,14 @@ where
/// and also the pool's statement cache. Add it to extended protocol data.
fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled
let client_given_name = match self.prepared_statements_enabled {
true => Parse::get_name(&message)?,
false => "".to_string(),
};

if client_given_name.is_empty() {
if !self.prepared_statements_enabled {
debug!("Anonymous parse message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_parse(message, None));
return Ok(());
}

let client_given_name = Parse::get_name(&message)?;
let parse: Parse = (&message).try_into()?;

// Compute the hash of the parse statement
Expand Down Expand Up @@ -1753,18 +1749,15 @@ where
/// saved in the client cache.
async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled
let client_given_name = match self.prepared_statements_enabled {
true => Bind::get_name(&message)?,
false => "".to_string(),
};

if client_given_name.is_empty() {
if !self.prepared_statements_enabled {
debug!("Anonymous bind message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_bind(message, None));
return Ok(());
}

let client_given_name = Bind::get_name(&message)?;

match self.prepared_statements.get(&client_given_name) {
Some((rewritten_parse, _)) => {
let message = Bind::rename(message, &rewritten_parse.name)?;
Expand Down Expand Up @@ -1807,19 +1800,23 @@ where
/// saved in the client cache.
async fn buffer_describe(&mut self, message: BytesMut) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled
let describe: Describe = match self.prepared_statements_enabled {
true => (&message).try_into()?,
false => Describe::empty_new(),
};

if describe.anonymous() {
if !self.prepared_statements_enabled {
debug!("Anonymous describe message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_describe(message, None));

return Ok(());
}

let describe: Describe = (&message).try_into()?;
if describe.target == 'P' {
debug!("Portal describe message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_describe(message, None));

return Ok(());
}

let client_given_name = describe.statement_name.clone();

match self.prepared_statements.get(&client_given_name) {
Expand Down
2 changes: 1 addition & 1 deletion src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,7 @@ pub struct Describe {

#[allow(dead_code)]
len: i32,
target: char,
pub target: char,
pub statement_name: String,
}

Expand Down
5 changes: 5 additions & 0 deletions tests/go/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module pgcat

go 1.21

require github.com/lib/pq v1.10.9
2 changes: 2 additions & 0 deletions tests/go/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
162 changes: 162 additions & 0 deletions tests/go/pgcat.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#
# PgCat config example.
#

#
# General pooler settings
[general]
# What IP to run on, 0.0.0.0 means accessible from everywhere.
host = "0.0.0.0"

# Port to run on, same as PgBouncer used in this example.
port = "${PORT}"

# Whether to enable prometheus exporter or not.
enable_prometheus_exporter = true

# Port at which prometheus exporter listens on.
prometheus_exporter_port = 9930

# How long to wait before aborting a server connection (ms).
connect_timeout = 1000

# How much time to give the health check query to return with a result (ms).
healthcheck_timeout = 1000

# How long to keep connection available for immediate re-use, without running a healthcheck query on it
healthcheck_delay = 30000

# How much time to give clients during shutdown before forcibly killing client connections (ms).
shutdown_timeout = 5000

# For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # Seconds

# If we should log client connections
log_client_connections = false

# If we should log client disconnections
log_client_disconnections = false

# Reload config automatically if it changes.
autoreload = 15000

server_round_robin = false

# TLS
tls_certificate = "../../.circleci/server.cert"
tls_private_key = "../../.circleci/server.key"

# Credentials to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
admin_username = "admin_user"
admin_password = "admin_pass"

# pool
# configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting
# For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db"
[pools.sharded_db]
# Pool mode (see PgBouncer docs for more).
# session: one server connection per connected client
# transaction: one server connection per client transaction
pool_mode = "transaction"

# If the client doesn't specify, route traffic to
# this role by default.
#
# any: round-robin between primary and replicas,
# replica: round-robin between replicas only without touching the primary,
# primary: all queries go to the primary unless otherwise specified.
default_role = "any"

# Query parser. If enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
# we'll direct it to the primary.
query_parser_enabled = true

# If the query parser is enabled and this setting is enabled, we'll attempt to
# infer the role from the query itself.
query_parser_read_write_splitting = true

# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitely selected with our custom protocol.
primary_reads_enabled = true

# So what if you wanted to implement a different hashing function,
# or you've already built one and you want this pooler to use it?
#
# Current options:
#
# pg_bigint_hash: PARTITION BY HASH (Postgres hashing function)
# sha1: A hashing function based on SHA1
#
sharding_function = "pg_bigint_hash"

# Prepared statements cache size.
prepared_statements_cache_size = 500

# Credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
username = "sharding_user"
password = "sharding_user"
# Maximum number of server connections that can be established for this user
# The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users.
pool_size = 5
statement_timeout = 0


[pools.sharded_db.users.1]
username = "other_user"
password = "other_user"
pool_size = 21
statement_timeout = 30000

# Shard 0
[pools.sharded_db.shards.0]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
# Database name (e.g. "postgres")
database = "shard0"

[pools.sharded_db.shards.1]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard1"

[pools.sharded_db.shards.2]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard2"


[pools.simple_db]
pool_mode = "session"
default_role = "primary"
query_parser_enabled = true
query_parser_read_write_splitting = true
primary_reads_enabled = true
sharding_function = "pg_bigint_hash"

[pools.simple_db.users.0]
username = "simple_user"
password = "simple_user"
pool_size = 5
statement_timeout = 30000

[pools.simple_db.shards.0]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
database = "some_db"
52 changes: 52 additions & 0 deletions tests/go/prepared_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package pgcat

import (
"context"
"database/sql"
"fmt"
_ "github.com/lib/pq"
"testing"
)

func Test(t *testing.T) {
t.Cleanup(setup(t))
t.Run("Named parameterized prepared statement works", namedParameterizedPreparedStatement)
t.Run("Unnamed parameterized prepared statement works", unnamedParameterizedPreparedStatement)
}

func namedParameterizedPreparedStatement(t *testing.T) {
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port))
if err != nil {
t.Fatalf("could not open connection: %+v", err)
}

stmt, err := db.Prepare("SELECT $1")

if err != nil {
t.Fatalf("could not prepare: %+v", err)
}

for i := 0; i < 100; i++ {
rows, err := stmt.Query(1)
if err != nil {
t.Fatalf("could not query: %+v", err)
}
_ = rows.Close()
}
}

func unnamedParameterizedPreparedStatement(t *testing.T) {
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port))
if err != nil {
t.Fatalf("could not open connection: %+v", err)
}

for i := 0; i < 100; i++ {
// Under the hood QueryContext generates an unnamed parameterized prepared statement
rows, err := db.QueryContext(context.Background(), "SELECT $1", 1)
if err != nil {
t.Fatalf("could not query: %+v", err)
}
_ = rows.Close()
}
}
Loading