Skip to content

Commit

Permalink
service/rds/rdsutils: adding more clarity to rdsutils.BuildAuthToken …
Browse files Browse the repository at this point in the history
…with an example and cl… (#1985)

* Adding more clarity to rdsutils.BuildAuthToken with an example and cleaning up the docs

* Adding builer to allow for generating auth tokens more easily
  • Loading branch information
xibz authored Jun 12, 2018
1 parent 43ba333 commit 44b48b9
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
### SDK Features

### SDK Enhancements
* `service/rds/rdsutils`: Clean up the rdsutils package and adds a new builder to construct connection strings [#1985](https://github.com/aws/aws-sdk-go/pull/1985)
* Rewords documentation to be more useful and provides links to prior setup needed to support authentication tokens. Introduces a builder that allows for building connection strings

### SDK Bugs
127 changes: 127 additions & 0 deletions service/rds/rdsutils/builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package rdsutils

import (
"fmt"
"net/url"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
)

// ConnectionFormat is the type of connection that will be
// used to connect to the database
type ConnectionFormat string

// ConnectionFormat enums
const (
NoConnectionFormat ConnectionFormat = ""
TCPFormat ConnectionFormat = "tcp"
)

// ErrNoConnectionFormat will be returned during build if no format had been
// specified
var ErrNoConnectionFormat = awserr.New("NoConnectionFormat", "No connection format was specified", nil)

// ConnectionStringBuilder is a builder that will construct a connection
// string with the provided parameters. params field is required to have
// a tls specification and allowCleartextPasswords must be set to true.
type ConnectionStringBuilder struct {
dbName string
endpoint string
region string
user string
creds *credentials.Credentials

connectFormat ConnectionFormat
params url.Values
}

// NewConnectionStringBuilder will return an ConnectionStringBuilder
func NewConnectionStringBuilder(endpoint, region, dbUser, dbName string, creds *credentials.Credentials) ConnectionStringBuilder {
return ConnectionStringBuilder{
dbName: dbName,
endpoint: endpoint,
region: region,
user: dbUser,
creds: creds,
}
}

// WithEndpoint will return a builder with the given endpoint
func (b ConnectionStringBuilder) WithEndpoint(endpoint string) ConnectionStringBuilder {
b.endpoint = endpoint
return b
}

// WithRegion will return a builder with the given region
func (b ConnectionStringBuilder) WithRegion(region string) ConnectionStringBuilder {
b.region = region
return b
}

// WithUser will return a builder with the given user
func (b ConnectionStringBuilder) WithUser(user string) ConnectionStringBuilder {
b.user = user
return b
}

// WithDBName will return a builder with the given database name
func (b ConnectionStringBuilder) WithDBName(dbName string) ConnectionStringBuilder {
b.dbName = dbName
return b
}

// WithParams will return a builder with the given params. The parameters
// will be included in the connection query string
//
// Example:
// v := url.Values{}
// v.Add("tls", "rds")
// b := rdsutils.NewConnectionBuilder(endpoint, region, user, dbname, creds)
// connectStr, err := b.WithParams(v).WithTCPFormat().Build()
func (b ConnectionStringBuilder) WithParams(params url.Values) ConnectionStringBuilder {
b.params = params
return b
}

// WithFormat will return a builder with the given connection format
func (b ConnectionStringBuilder) WithFormat(f ConnectionFormat) ConnectionStringBuilder {
b.connectFormat = f
return b
}

// WithTCPFormat will set the format to TCP and return the modified builder
func (b ConnectionStringBuilder) WithTCPFormat() ConnectionStringBuilder {
return b.WithFormat(TCPFormat)
}

// Build will return a new connection string that can be used to open a connection
// to the desired database.
//
// Example:
// b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, creds)
// connectStr, err := b.WithTCPFormat().Build()
// if err != nil {
// panic(err)
// }
// const dbType = "mysql"
// db, err := sql.Open(dbType, connectStr)
func (b ConnectionStringBuilder) Build() (string, error) {
if b.connectFormat == NoConnectionFormat {
return "", ErrNoConnectionFormat
}

authToken, err := BuildAuthToken(b.endpoint, b.region, b.user, b.creds)
if err != nil {
return "", err
}

connectionStr := fmt.Sprintf("%s:%s@%s(%s)/%s",
b.user, authToken, string(b.connectFormat), b.endpoint, b.dbName,
)

if len(b.params) > 0 {
connectionStr = fmt.Sprintf("%s?%s", connectionStr, b.params.Encode())
}
return connectionStr, nil
}
58 changes: 58 additions & 0 deletions service/rds/rdsutils/builder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package rdsutils_test

import (
"net/url"
"regexp"
"testing"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/rds/rdsutils"
)

func TestConnectionStringBuilder(t *testing.T) {
cases := []struct {
user string
endpoint string
region string
dbName string
values url.Values
format rdsutils.ConnectionFormat
creds *credentials.Credentials

expectedErr error
expectedConnectRegex string
}{
{
user: "foo",
endpoint: "foo.bar",
region: "region",
dbName: "name",
format: rdsutils.NoConnectionFormat,
creds: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
expectedErr: rdsutils.ErrNoConnectionFormat,
expectedConnectRegex: "",
},
{
user: "foo",
endpoint: "foo.bar",
region: "region",
dbName: "name",
format: rdsutils.TCPFormat,
creds: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
expectedConnectRegex: `^foo:foo.bar\?Action=connect\&DBUser=foo.*\@tcp\(foo.bar\)/name`,
},
}

for _, c := range cases {
b := rdsutils.NewConnectionStringBuilder(c.endpoint, c.region, c.user, c.dbName, c.creds)
connectStr, err := b.WithFormat(c.format).Build()

if e, a := c.expectedErr, err; e != a {
t.Errorf("expected %v error, but received %v", e, a)
}

if re, a := regexp.MustCompile(c.expectedConnectRegex), connectStr; !re.MatchString(a) {
t.Errorf("expect %s to match %s", re, a)
}
}
}
19 changes: 8 additions & 11 deletions service/rds/rdsutils/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@ import (
"github.com/aws/aws-sdk-go/aws/signer/v4"
)

// BuildAuthToken will return a authentication token for the database's connect
// based on the RDS database endpoint, AWS region, IAM user or role, and AWS credentials.
// BuildAuthToken will return an authorization token used as the password for a DB
// connection.
//
// Endpoint consists of the hostname and port, IE hostname:port, of the RDS database.
// Region is the AWS region the RDS database is in and where the authentication token
// will be generated for. DbUser is the IAM user or role the request will be authenticated
// for. The creds is the AWS credentials the authentication token is signed with.
//
// An error is returned if the authentication token is unable to be signed with
// the credentials, or the endpoint is not a valid URL.
// * endpoint - Endpoint consists of the port needed to connect to the DB. <host>:<port>
// * region - Region is the location of where the DB is
// * dbUser - User account within the database to sign in with
// * creds - Credentials to be signed with
//
// The following example shows how to use BuildAuthToken to create an authentication
// token for connecting to a MySQL database in RDS.
Expand All @@ -27,12 +24,12 @@ import (
//
// // Create the MySQL DNS string for the DB connection
// // user:password@protocol(endpoint)/dbname?<params>
// dnsStr = fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=true",
// connectStr = fmt.Sprintf("%s:%s@tcp(%s)/%s?allowCleartextPasswords=true&tls=rds",
// dbUser, authToken, dbEndpoint, dbName,
// )
//
// // Use db to perform SQL operations on database
// db, err := sql.Open("mysql", dnsStr)
// db, err := sql.Open("mysql", connectStr)
//
// See http://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html
// for more information on using IAM database authentication with RDS.
Expand Down
18 changes: 18 additions & 0 deletions service/rds/rdsutils/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Package rdsutils is used to generate authentication tokens used to
// connect to a givent Amazon Relational Database Service (RDS) database.
//
// Before using the authentication please visit the docs here to ensure
// the database has the proper policies to allow for IAM token authentication.
// https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html#UsingWithRDS.IAMDBAuth.Availability
//
// When building the connection string, there are two required parameters that are needed to be set on the query.
// * tls
// * allowCleartextPasswords must be set to true
//
// Example creating a basic auth token with the builder:
// v := url.Values{}
// v.Add("tls", "tls_profile_name")
// v.Add("allowCleartextPasswords", "true")
// b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, creds)
// connectStr, err := b.WithTCPFormat().WithParams(v).Build()
package rdsutils
119 changes: 119 additions & 0 deletions service/rds/rdsutils/example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// +build example,exclude

package rdsutils_test

import (
"crypto/tls"
"crypto/x509"
"database/sql"
"flag"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"

"github.com/go-sql-driver/mysql"

"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/rds/rdsutils"
)

// ExampleConnectionStringBuilder contains usage of assuming a role and using
// that to build the auth token.
// Usage:
// ./main -user "iamuser" -dbname "foo" -region "us-west-2" -rolearn "arn" -endpoint "dbendpoint" -port 3306
func ExampleConnectionStringBuilder() {
userPtr := flag.String("user", "", "user of the credentials")
regionPtr := flag.String("region", "us-east-1", "region to be used when grabbing sts creds")
roleArnPtr := flag.String("rolearn", "", "role arn to be used when grabbing sts creds")
endpointPtr := flag.String("endpoint", "", "DB endpoint to be connected to")
portPtr := flag.Int("port", 3306, "DB port to be connected to")
tablePtr := flag.String("table", "test_table", "DB table to query against")
dbNamePtr := flag.String("dbname", "", "DB name to query against")
flag.Parse()

// Check required flags. Will exit with status code 1 if
// required field isn't set.
if err := requiredFlags(
userPtr,
regionPtr,
roleArnPtr,
endpointPtr,
portPtr,
dbNamePtr,
); err != nil {
fmt.Printf("Error: %v\n\n", err)
flag.PrintDefaults()
os.Exit(1)
}

err := registerRDSMysqlCerts(http.DefaultClient)
if err != nil {
panic(err)
}

sess := session.Must(session.NewSession())
creds := stscreds.NewCredentials(sess, *roleArnPtr)

v := url.Values{}
// required fields for DB connection
v.Add("tls", "rds")
v.Add("allowCleartextPasswords", "true")
endpoint := fmt.Sprintf("%s:%d", *endpointPtr, *portPtr)

b := rdsutils.NewConnectionStringBuilder(endpoint, *regionPtr, *userPtr, *dbNamePtr, creds)
connectStr, err := b.WithTCPFormat().WithParams(v).Build()

const dbType = "mysql"
db, err := sql.Open(dbType, connectStr)
// if an error is encountered here, then most likely security groups are incorrect
// in the database.
if err != nil {
panic(fmt.Errorf("failed to open connection to the database"))
}

rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s LIMIT 1", *tablePtr))
if err != nil {
panic(fmt.Errorf("failed to select from table, %q, with %v", *tablePtr, err))
}

for rows.Next() {
columns, err := rows.Columns()
if err != nil {
panic(fmt.Errorf("failed to read columns from row: %v", err))
}

fmt.Printf("rows colums:\n%d\n", len(columns))
}
}

func requiredFlags(flags ...interface{}) error {
for _, f := range flags {
switch f.(type) {
case nil:
return fmt.Errorf("one or more required flags were not set")
}
}
return nil
}

func registerRDSMysqlCerts(c *http.Client) error {
resp, err := c.Get("https://s3.amazonaws.com/rds-downloads/rds-combined-ca-bundle.pem")
if err != nil {
return err
}

pem, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}

rootCertPool := x509.NewCertPool()
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
return fmt.Errorf("failed to append cert to cert pool!")
}

return mysql.RegisterTLSConfig("rds", &tls.Config{RootCAs: rootCertPool, InsecureSkipVerify: true})
}

0 comments on commit 44b48b9

Please sign in to comment.