From 7c37da2fade58d0e59f2e429827fe2a1a9b2b482 Mon Sep 17 00:00:00 2001 From: Jakob Schultz-Falk Date: Thu, 9 Nov 2023 01:36:45 +0100 Subject: [PATCH] Support unnamed prepared statements (#635) * Add golang test suite to reproduce issue with unnamed parameterized prepared statements * Allow caching of unnamed prepared statements * Passthrough describe on portals * Remove unneeded kill * Update Dockerfile.ci with golang * Move out update of Dockerfiles to separate PR --- .circleci/run_tests.sh | 9 +++ src/client.rs | 33 ++++---- src/messages.rs | 2 +- tests/go/go.mod | 5 ++ tests/go/go.sum | 2 + tests/go/pgcat.toml | 162 ++++++++++++++++++++++++++++++++++++++ tests/go/prepared_test.go | 52 ++++++++++++ tests/go/setup.go | 81 +++++++++++++++++++ 8 files changed, 327 insertions(+), 19 deletions(-) create mode 100644 tests/go/go.mod create mode 100644 tests/go/go.sum create mode 100644 tests/go/pgcat.toml create mode 100644 tests/go/prepared_test.go create mode 100644 tests/go/setup.go diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index 4ba497c3..3a31240a 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -108,6 +108,15 @@ cd ../.. pip3 install -r tests/python/requirements.txt python3 tests/python/tests.py || exit 1 + +# +# Go tests +# Starts its own pgcat server +# +pushd tests/go +/usr/local/go/bin/go test || exit 1 +popd + start_pgcat "info" # Admin tests diff --git a/src/client.rs b/src/client.rs index 31dcb4bd..dd89697c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1704,18 +1704,14 @@ where /// and also the pool's statement cache. Add it to extended protocol data. fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> { // Avoid parsing if prepared statements not enabled - let client_given_name = match self.prepared_statements_enabled { - true => Parse::get_name(&message)?, - false => "".to_string(), - }; - - if client_given_name.is_empty() { + if !self.prepared_statements_enabled { debug!("Anonymous parse message"); self.extended_protocol_data_buffer .push_back(ExtendedProtocolData::create_new_parse(message, None)); return Ok(()); } + let client_given_name = Parse::get_name(&message)?; let parse: Parse = (&message).try_into()?; // Compute the hash of the parse statement @@ -1753,18 +1749,15 @@ where /// saved in the client cache. async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> { // Avoid parsing if prepared statements not enabled - let client_given_name = match self.prepared_statements_enabled { - true => Bind::get_name(&message)?, - false => "".to_string(), - }; - - if client_given_name.is_empty() { + if !self.prepared_statements_enabled { debug!("Anonymous bind message"); self.extended_protocol_data_buffer .push_back(ExtendedProtocolData::create_new_bind(message, None)); return Ok(()); } + let client_given_name = Bind::get_name(&message)?; + match self.prepared_statements.get(&client_given_name) { Some((rewritten_parse, _)) => { let message = Bind::rename(message, &rewritten_parse.name)?; @@ -1807,12 +1800,7 @@ where /// saved in the client cache. async fn buffer_describe(&mut self, message: BytesMut) -> Result<(), Error> { // Avoid parsing if prepared statements not enabled - let describe: Describe = match self.prepared_statements_enabled { - true => (&message).try_into()?, - false => Describe::empty_new(), - }; - - if describe.anonymous() { + if !self.prepared_statements_enabled { debug!("Anonymous describe message"); self.extended_protocol_data_buffer .push_back(ExtendedProtocolData::create_new_describe(message, None)); @@ -1820,6 +1808,15 @@ where return Ok(()); } + let describe: Describe = (&message).try_into()?; + if describe.target == 'P' { + debug!("Portal describe message"); + self.extended_protocol_data_buffer + .push_back(ExtendedProtocolData::create_new_describe(message, None)); + + return Ok(()); + } + let client_given_name = describe.statement_name.clone(); match self.prepared_statements.get(&client_given_name) { diff --git a/src/messages.rs b/src/messages.rs index 3a26f42a..4390d9f9 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -1109,7 +1109,7 @@ pub struct Describe { #[allow(dead_code)] len: i32, - target: char, + pub target: char, pub statement_name: String, } diff --git a/tests/go/go.mod b/tests/go/go.mod new file mode 100644 index 00000000..faa2292e --- /dev/null +++ b/tests/go/go.mod @@ -0,0 +1,5 @@ +module pgcat + +go 1.21 + +require github.com/lib/pq v1.10.9 diff --git a/tests/go/go.sum b/tests/go/go.sum new file mode 100644 index 00000000..aeddeae3 --- /dev/null +++ b/tests/go/go.sum @@ -0,0 +1,2 @@ +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= diff --git a/tests/go/pgcat.toml b/tests/go/pgcat.toml new file mode 100644 index 00000000..72eba8a4 --- /dev/null +++ b/tests/go/pgcat.toml @@ -0,0 +1,162 @@ +# +# PgCat config example. +# + +# +# General pooler settings +[general] +# What IP to run on, 0.0.0.0 means accessible from everywhere. +host = "0.0.0.0" + +# Port to run on, same as PgBouncer used in this example. +port = "${PORT}" + +# Whether to enable prometheus exporter or not. +enable_prometheus_exporter = true + +# Port at which prometheus exporter listens on. +prometheus_exporter_port = 9930 + +# How long to wait before aborting a server connection (ms). +connect_timeout = 1000 + +# How much time to give the health check query to return with a result (ms). +healthcheck_timeout = 1000 + +# How long to keep connection available for immediate re-use, without running a healthcheck query on it +healthcheck_delay = 30000 + +# How much time to give clients during shutdown before forcibly killing client connections (ms). +shutdown_timeout = 5000 + +# For how long to ban a server if it fails a health check (seconds). +ban_time = 60 # Seconds + +# If we should log client connections +log_client_connections = false + +# If we should log client disconnections +log_client_disconnections = false + +# Reload config automatically if it changes. +autoreload = 15000 + +server_round_robin = false + +# TLS +tls_certificate = "../../.circleci/server.cert" +tls_private_key = "../../.circleci/server.key" + +# Credentials to access the virtual administrative database (pgbouncer or pgcat) +# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.. +admin_username = "admin_user" +admin_password = "admin_pass" + +# pool +# configs are structured as pool. +# the pool_name is what clients use as database name when connecting +# For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db" +[pools.sharded_db] +# Pool mode (see PgBouncer docs for more). +# session: one server connection per connected client +# transaction: one server connection per client transaction +pool_mode = "transaction" + +# If the client doesn't specify, route traffic to +# this role by default. +# +# any: round-robin between primary and replicas, +# replica: round-robin between replicas only without touching the primary, +# primary: all queries go to the primary unless otherwise specified. +default_role = "any" + +# Query parser. If enabled, we'll attempt to parse +# every incoming query to determine if it's a read or a write. +# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write, +# we'll direct it to the primary. +query_parser_enabled = true + +# If the query parser is enabled and this setting is enabled, we'll attempt to +# infer the role from the query itself. +query_parser_read_write_splitting = true + +# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for +# load balancing of read queries. Otherwise, the primary will only be used for write +# queries. The primary can always be explicitely selected with our custom protocol. +primary_reads_enabled = true + +# So what if you wanted to implement a different hashing function, +# or you've already built one and you want this pooler to use it? +# +# Current options: +# +# pg_bigint_hash: PARTITION BY HASH (Postgres hashing function) +# sha1: A hashing function based on SHA1 +# +sharding_function = "pg_bigint_hash" + +# Prepared statements cache size. +prepared_statements_cache_size = 500 + +# Credentials for users that may connect to this cluster +[pools.sharded_db.users.0] +username = "sharding_user" +password = "sharding_user" +# Maximum number of server connections that can be established for this user +# The maximum number of connection from a single Pgcat process to any database in the cluster +# is the sum of pool_size across all users. +pool_size = 5 +statement_timeout = 0 + + +[pools.sharded_db.users.1] +username = "other_user" +password = "other_user" +pool_size = 21 +statement_timeout = 30000 + +# Shard 0 +[pools.sharded_db.shards.0] +# [ host, port, role ] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ] +] +# Database name (e.g. "postgres") +database = "shard0" + +[pools.sharded_db.shards.1] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ], +] +database = "shard1" + +[pools.sharded_db.shards.2] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ], +] +database = "shard2" + + +[pools.simple_db] +pool_mode = "session" +default_role = "primary" +query_parser_enabled = true +query_parser_read_write_splitting = true +primary_reads_enabled = true +sharding_function = "pg_bigint_hash" + +[pools.simple_db.users.0] +username = "simple_user" +password = "simple_user" +pool_size = 5 +statement_timeout = 30000 + +[pools.simple_db.shards.0] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ] +] +database = "some_db" diff --git a/tests/go/prepared_test.go b/tests/go/prepared_test.go new file mode 100644 index 00000000..0a42e721 --- /dev/null +++ b/tests/go/prepared_test.go @@ -0,0 +1,52 @@ +package pgcat + +import ( + "context" + "database/sql" + "fmt" + _ "github.com/lib/pq" + "testing" +) + +func Test(t *testing.T) { + t.Cleanup(setup(t)) + t.Run("Named parameterized prepared statement works", namedParameterizedPreparedStatement) + t.Run("Unnamed parameterized prepared statement works", unnamedParameterizedPreparedStatement) +} + +func namedParameterizedPreparedStatement(t *testing.T) { + db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port)) + if err != nil { + t.Fatalf("could not open connection: %+v", err) + } + + stmt, err := db.Prepare("SELECT $1") + + if err != nil { + t.Fatalf("could not prepare: %+v", err) + } + + for i := 0; i < 100; i++ { + rows, err := stmt.Query(1) + if err != nil { + t.Fatalf("could not query: %+v", err) + } + _ = rows.Close() + } +} + +func unnamedParameterizedPreparedStatement(t *testing.T) { + db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port)) + if err != nil { + t.Fatalf("could not open connection: %+v", err) + } + + for i := 0; i < 100; i++ { + // Under the hood QueryContext generates an unnamed parameterized prepared statement + rows, err := db.QueryContext(context.Background(), "SELECT $1", 1) + if err != nil { + t.Fatalf("could not query: %+v", err) + } + _ = rows.Close() + } +} diff --git a/tests/go/setup.go b/tests/go/setup.go new file mode 100644 index 00000000..32ffc4ba --- /dev/null +++ b/tests/go/setup.go @@ -0,0 +1,81 @@ +package pgcat + +import ( + "context" + "database/sql" + _ "embed" + "fmt" + "math/rand" + "os" + "os/exec" + "strings" + "testing" + "time" +) + +//go:embed pgcat.toml +var pgcatCfg string + +var port = rand.Intn(32760-20000) + 20000 + +func setup(t *testing.T) func() { + cfg, err := os.CreateTemp("/tmp", "pgcat_cfg_*.toml") + if err != nil { + t.Fatalf("could not create temp file: %+v", err) + } + + pgcatCfg = strings.Replace(pgcatCfg, "\"${PORT}\"", fmt.Sprintf("%d", port), 1) + + _, err = cfg.Write([]byte(pgcatCfg)) + if err != nil { + t.Fatalf("could not write temp file: %+v", err) + } + + commandPath := "../../target/debug/pgcat" + if os.Getenv("CARGO_TARGET_DIR") != "" { + commandPath = os.Getenv("CARGO_TARGET_DIR") + "/debug/pgcat" + } + + cmd := exec.Command(commandPath, cfg.Name()) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + go func() { + err = cmd.Run() + if err != nil { + t.Errorf("could not run pgcat: %+v", err) + } + }() + + deadline, cancelFunc := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + defer cancelFunc() + for { + select { + case <-deadline.Done(): + break + case <-time.After(50 * time.Millisecond): + db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=pgcat user=admin_user password=admin_pass sslmode=disable", port)) + if err != nil { + continue + } + rows, err := db.QueryContext(deadline, "SHOW STATS") + if err != nil { + continue + } + _ = rows.Close() + _ = db.Close() + break + } + break + } + + return func() { + err := cmd.Process.Signal(os.Interrupt) + if err != nil { + t.Fatalf("could not interrupt pgcat: %+v", err) + } + err = os.Remove(cfg.Name()) + if err != nil { + t.Fatalf("could not remove temp file: %+v", err) + } + } +}