From 396c093533fb5754e5b1317a194d1d1082aaf654 Mon Sep 17 00:00:00 2001 From: Gari Singh Date: Fri, 5 Jul 2019 11:16:58 -0400 Subject: [PATCH] FABC-848 Fix TLS issue with PostgreSQL CreateTables was failing with TLS enabled on the PostgreSQL server but Connect() handles TLS properly. Modified the code to set the Postgres.datasource property when setting TLS parameters rather than using using a function-scoped variable. Change-Id: I936ba48aeed3f1d62a623f9e08d3ec3f6e5f61bc Signed-off-by: Gari Singh --- lib/server/db/postgres/internal_test.go | 11 +++++++++++ lib/server/db/postgres/postgres.go | 14 ++++++-------- lib/server/db/postgres/postgres_test.go | 23 +++++++++++++++++++++++ 3 files changed, 40 insertions(+), 8 deletions(-) create mode 100644 lib/server/db/postgres/internal_test.go diff --git a/lib/server/db/postgres/internal_test.go b/lib/server/db/postgres/internal_test.go new file mode 100644 index 000000000..a09156c1b --- /dev/null +++ b/lib/server/db/postgres/internal_test.go @@ -0,0 +1,11 @@ +/* +Copyright IBM Corp. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package postgres + +func (p *Postgres) Datasource() string { + return p.datasource +} diff --git a/lib/server/db/postgres/postgres.go b/lib/server/db/postgres/postgres.go index 7804bbf61..dffbe2f4e 100644 --- a/lib/server/db/postgres/postgres.go +++ b/lib/server/db/postgres/postgres.go @@ -50,10 +50,9 @@ func NewDB( // Connect connects to a PostgreSQL server func (p *Postgres) Connect() error { - datasource := p.datasource clientTLSConfig := p.TLS - p.dbName = util.GetDBName(datasource) + p.dbName = util.GetDBName(p.datasource) dbName := p.dbName log.Debugf("Database Name: %s", dbName) @@ -67,11 +66,11 @@ func (p *Postgres) Connect() error { } root := clientTLSConfig.CertFiles[0] - datasource = fmt.Sprintf("%s sslrootcert=%s", datasource, root) + p.datasource = fmt.Sprintf("%s sslrootcert=%s", p.datasource, root) cert := clientTLSConfig.Client.CertFile key := clientTLSConfig.Client.KeyFile - datasource = fmt.Sprintf("%s sslcert=%s sslkey=%s", datasource, cert, key) + p.datasource = fmt.Sprintf("%s sslcert=%s sslkey=%s", p.datasource, cert, key) } dbNames := []string{dbName, "postgres", "template1"} @@ -79,7 +78,7 @@ func (p *Postgres) Connect() error { var err error for _, dbName := range dbNames { - connStr := getConnStr(datasource, dbName) + connStr := getConnStr(p.datasource, dbName) log.Debugf("Connecting to PostgreSQL server, using connection string: %s", util.MaskDBCred(connStr)) sqlxdb, err = sqlx.Connect("postgres", connStr) @@ -122,14 +121,13 @@ func (p *Postgres) Create() (*db.DB, error) { // CreateDatabase creates database func (p *Postgres) CreateDatabase() (*db.DB, error) { dbName := p.dbName - datasource := p.datasource err := p.createDatabase() if err != nil { return nil, errors.Wrap(err, "Failed to create Postgres database") } - log.Debugf("Connecting to database '%s', using connection string: '%s'", dbName, util.MaskDBCred(datasource)) - sqlxdb, err := sqlx.Open("postgres", datasource) + log.Debugf("Connecting to database '%s', using connection string: '%s'", dbName, util.MaskDBCred(p.datasource)) + sqlxdb, err := sqlx.Open("postgres", p.datasource) if err != nil { return nil, errors.Wrapf(err, "Failed to open database '%s' in Postgres server", dbName) } diff --git a/lib/server/db/postgres/postgres_test.go b/lib/server/db/postgres/postgres_test.go index df6f5b878..eddf22222 100644 --- a/lib/server/db/postgres/postgres_test.go +++ b/lib/server/db/postgres/postgres_test.go @@ -73,6 +73,29 @@ var _ = Describe("Postgres", func() { Expect(db.SqlxDB).To(BeNil()) }) + It("has datasource with TLS connection parameters when TLS is enabled", func() { + db.TLS = &tls.ClientTLSConfig{ + Enabled: true, + CertFiles: []string{"root.pem"}, + Client: tls.KeyCertFiles{ + KeyFile: "key.pem", + CertFile: "cert.pem", + }, + } + db.Connect() + Expect(db.Datasource()).To( + ContainSubstring("sslrootcert=root.pem sslcert=cert.pem sslkey=key.pem"), + ) + }) + + It("does not have has datasource with TLS connection parameters when TLS is enabled", func() { + db.TLS = &tls.ClientTLSConfig{ + Enabled: false, + } + db.Connect() + Expect(db.Datasource()).ToNot(ContainSubstring("sslrootcert")) + }) + It("fail to open database connection if unable to ping database", func() { err := db.Connect() Expect(err).To(HaveOccurred())