-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
service/rds/rdsutils: adding more clarity to rdsutils.BuildAuthToken …
…with an example and cl… (#1985) * Adding more clarity to rdsutils.BuildAuthToken with an example and cleaning up the docs * Adding builer to allow for generating auth tokens more easily
- Loading branch information
Showing
6 changed files
with
332 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
### SDK Features | ||
|
||
### SDK Enhancements | ||
* `service/rds/rdsutils`: Clean up the rdsutils package and adds a new builder to construct connection strings [#1985](https://github.com/aws/aws-sdk-go/pull/1985) | ||
* Rewords documentation to be more useful and provides links to prior setup needed to support authentication tokens. Introduces a builder that allows for building connection strings | ||
|
||
### SDK Bugs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
package rdsutils | ||
|
||
import ( | ||
"fmt" | ||
"net/url" | ||
|
||
"github.com/aws/aws-sdk-go/aws/awserr" | ||
"github.com/aws/aws-sdk-go/aws/credentials" | ||
) | ||
|
||
// ConnectionFormat is the type of connection that will be | ||
// used to connect to the database | ||
type ConnectionFormat string | ||
|
||
// ConnectionFormat enums | ||
const ( | ||
NoConnectionFormat ConnectionFormat = "" | ||
TCPFormat ConnectionFormat = "tcp" | ||
) | ||
|
||
// ErrNoConnectionFormat will be returned during build if no format had been | ||
// specified | ||
var ErrNoConnectionFormat = awserr.New("NoConnectionFormat", "No connection format was specified", nil) | ||
|
||
// ConnectionStringBuilder is a builder that will construct a connection | ||
// string with the provided parameters. params field is required to have | ||
// a tls specification and allowCleartextPasswords must be set to true. | ||
type ConnectionStringBuilder struct { | ||
dbName string | ||
endpoint string | ||
region string | ||
user string | ||
creds *credentials.Credentials | ||
|
||
connectFormat ConnectionFormat | ||
params url.Values | ||
} | ||
|
||
// NewConnectionStringBuilder will return an ConnectionStringBuilder | ||
func NewConnectionStringBuilder(endpoint, region, dbUser, dbName string, creds *credentials.Credentials) ConnectionStringBuilder { | ||
return ConnectionStringBuilder{ | ||
dbName: dbName, | ||
endpoint: endpoint, | ||
region: region, | ||
user: dbUser, | ||
creds: creds, | ||
} | ||
} | ||
|
||
// WithEndpoint will return a builder with the given endpoint | ||
func (b ConnectionStringBuilder) WithEndpoint(endpoint string) ConnectionStringBuilder { | ||
b.endpoint = endpoint | ||
return b | ||
} | ||
|
||
// WithRegion will return a builder with the given region | ||
func (b ConnectionStringBuilder) WithRegion(region string) ConnectionStringBuilder { | ||
b.region = region | ||
return b | ||
} | ||
|
||
// WithUser will return a builder with the given user | ||
func (b ConnectionStringBuilder) WithUser(user string) ConnectionStringBuilder { | ||
b.user = user | ||
return b | ||
} | ||
|
||
// WithDBName will return a builder with the given database name | ||
func (b ConnectionStringBuilder) WithDBName(dbName string) ConnectionStringBuilder { | ||
b.dbName = dbName | ||
return b | ||
} | ||
|
||
// WithParams will return a builder with the given params. The parameters | ||
// will be included in the connection query string | ||
// | ||
// Example: | ||
// v := url.Values{} | ||
// v.Add("tls", "rds") | ||
// b := rdsutils.NewConnectionBuilder(endpoint, region, user, dbname, creds) | ||
// connectStr, err := b.WithParams(v).WithTCPFormat().Build() | ||
func (b ConnectionStringBuilder) WithParams(params url.Values) ConnectionStringBuilder { | ||
b.params = params | ||
return b | ||
} | ||
|
||
// WithFormat will return a builder with the given connection format | ||
func (b ConnectionStringBuilder) WithFormat(f ConnectionFormat) ConnectionStringBuilder { | ||
b.connectFormat = f | ||
return b | ||
} | ||
|
||
// WithTCPFormat will set the format to TCP and return the modified builder | ||
func (b ConnectionStringBuilder) WithTCPFormat() ConnectionStringBuilder { | ||
return b.WithFormat(TCPFormat) | ||
} | ||
|
||
// Build will return a new connection string that can be used to open a connection | ||
// to the desired database. | ||
// | ||
// Example: | ||
// b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, creds) | ||
// connectStr, err := b.WithTCPFormat().Build() | ||
// if err != nil { | ||
// panic(err) | ||
// } | ||
// const dbType = "mysql" | ||
// db, err := sql.Open(dbType, connectStr) | ||
func (b ConnectionStringBuilder) Build() (string, error) { | ||
if b.connectFormat == NoConnectionFormat { | ||
return "", ErrNoConnectionFormat | ||
} | ||
|
||
authToken, err := BuildAuthToken(b.endpoint, b.region, b.user, b.creds) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
connectionStr := fmt.Sprintf("%s:%s@%s(%s)/%s", | ||
b.user, authToken, string(b.connectFormat), b.endpoint, b.dbName, | ||
) | ||
|
||
if len(b.params) > 0 { | ||
connectionStr = fmt.Sprintf("%s?%s", connectionStr, b.params.Encode()) | ||
} | ||
return connectionStr, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
package rdsutils_test | ||
|
||
import ( | ||
"net/url" | ||
"regexp" | ||
"testing" | ||
|
||
"github.com/aws/aws-sdk-go/aws/credentials" | ||
"github.com/aws/aws-sdk-go/service/rds/rdsutils" | ||
) | ||
|
||
func TestConnectionStringBuilder(t *testing.T) { | ||
cases := []struct { | ||
user string | ||
endpoint string | ||
region string | ||
dbName string | ||
values url.Values | ||
format rdsutils.ConnectionFormat | ||
creds *credentials.Credentials | ||
|
||
expectedErr error | ||
expectedConnectRegex string | ||
}{ | ||
{ | ||
user: "foo", | ||
endpoint: "foo.bar", | ||
region: "region", | ||
dbName: "name", | ||
format: rdsutils.NoConnectionFormat, | ||
creds: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"), | ||
expectedErr: rdsutils.ErrNoConnectionFormat, | ||
expectedConnectRegex: "", | ||
}, | ||
{ | ||
user: "foo", | ||
endpoint: "foo.bar", | ||
region: "region", | ||
dbName: "name", | ||
format: rdsutils.TCPFormat, | ||
creds: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"), | ||
expectedConnectRegex: `^foo:foo.bar\?Action=connect\&DBUser=foo.*\@tcp\(foo.bar\)/name`, | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
b := rdsutils.NewConnectionStringBuilder(c.endpoint, c.region, c.user, c.dbName, c.creds) | ||
connectStr, err := b.WithFormat(c.format).Build() | ||
|
||
if e, a := c.expectedErr, err; e != a { | ||
t.Errorf("expected %v error, but received %v", e, a) | ||
} | ||
|
||
if re, a := regexp.MustCompile(c.expectedConnectRegex), connectStr; !re.MatchString(a) { | ||
t.Errorf("expect %s to match %s", re, a) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
// Package rdsutils is used to generate authentication tokens used to | ||
// connect to a givent Amazon Relational Database Service (RDS) database. | ||
// | ||
// Before using the authentication please visit the docs here to ensure | ||
// the database has the proper policies to allow for IAM token authentication. | ||
// https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html#UsingWithRDS.IAMDBAuth.Availability | ||
// | ||
// When building the connection string, there are two required parameters that are needed to be set on the query. | ||
// * tls | ||
// * allowCleartextPasswords must be set to true | ||
// | ||
// Example creating a basic auth token with the builder: | ||
// v := url.Values{} | ||
// v.Add("tls", "tls_profile_name") | ||
// v.Add("allowCleartextPasswords", "true") | ||
// b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, creds) | ||
// connectStr, err := b.WithTCPFormat().WithParams(v).Build() | ||
package rdsutils |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
// +build example,exclude | ||
|
||
package rdsutils_test | ||
|
||
import ( | ||
"crypto/tls" | ||
"crypto/x509" | ||
"database/sql" | ||
"flag" | ||
"fmt" | ||
"io/ioutil" | ||
"net/http" | ||
"net/url" | ||
"os" | ||
|
||
"github.com/go-sql-driver/mysql" | ||
|
||
"github.com/aws/aws-sdk-go/aws/credentials/stscreds" | ||
"github.com/aws/aws-sdk-go/aws/session" | ||
"github.com/aws/aws-sdk-go/service/rds/rdsutils" | ||
) | ||
|
||
// ExampleConnectionStringBuilder contains usage of assuming a role and using | ||
// that to build the auth token. | ||
// Usage: | ||
// ./main -user "iamuser" -dbname "foo" -region "us-west-2" -rolearn "arn" -endpoint "dbendpoint" -port 3306 | ||
func ExampleConnectionStringBuilder() { | ||
userPtr := flag.String("user", "", "user of the credentials") | ||
regionPtr := flag.String("region", "us-east-1", "region to be used when grabbing sts creds") | ||
roleArnPtr := flag.String("rolearn", "", "role arn to be used when grabbing sts creds") | ||
endpointPtr := flag.String("endpoint", "", "DB endpoint to be connected to") | ||
portPtr := flag.Int("port", 3306, "DB port to be connected to") | ||
tablePtr := flag.String("table", "test_table", "DB table to query against") | ||
dbNamePtr := flag.String("dbname", "", "DB name to query against") | ||
flag.Parse() | ||
|
||
// Check required flags. Will exit with status code 1 if | ||
// required field isn't set. | ||
if err := requiredFlags( | ||
userPtr, | ||
regionPtr, | ||
roleArnPtr, | ||
endpointPtr, | ||
portPtr, | ||
dbNamePtr, | ||
); err != nil { | ||
fmt.Printf("Error: %v\n\n", err) | ||
flag.PrintDefaults() | ||
os.Exit(1) | ||
} | ||
|
||
err := registerRDSMysqlCerts(http.DefaultClient) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
sess := session.Must(session.NewSession()) | ||
creds := stscreds.NewCredentials(sess, *roleArnPtr) | ||
|
||
v := url.Values{} | ||
// required fields for DB connection | ||
v.Add("tls", "rds") | ||
v.Add("allowCleartextPasswords", "true") | ||
endpoint := fmt.Sprintf("%s:%d", *endpointPtr, *portPtr) | ||
|
||
b := rdsutils.NewConnectionStringBuilder(endpoint, *regionPtr, *userPtr, *dbNamePtr, creds) | ||
connectStr, err := b.WithTCPFormat().WithParams(v).Build() | ||
|
||
const dbType = "mysql" | ||
db, err := sql.Open(dbType, connectStr) | ||
// if an error is encountered here, then most likely security groups are incorrect | ||
// in the database. | ||
if err != nil { | ||
panic(fmt.Errorf("failed to open connection to the database")) | ||
} | ||
|
||
rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s LIMIT 1", *tablePtr)) | ||
if err != nil { | ||
panic(fmt.Errorf("failed to select from table, %q, with %v", *tablePtr, err)) | ||
} | ||
|
||
for rows.Next() { | ||
columns, err := rows.Columns() | ||
if err != nil { | ||
panic(fmt.Errorf("failed to read columns from row: %v", err)) | ||
} | ||
|
||
fmt.Printf("rows colums:\n%d\n", len(columns)) | ||
} | ||
} | ||
|
||
func requiredFlags(flags ...interface{}) error { | ||
for _, f := range flags { | ||
switch f.(type) { | ||
case nil: | ||
return fmt.Errorf("one or more required flags were not set") | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func registerRDSMysqlCerts(c *http.Client) error { | ||
resp, err := c.Get("https://s3.amazonaws.com/rds-downloads/rds-combined-ca-bundle.pem") | ||
if err != nil { | ||
return err | ||
} | ||
|
||
pem, err := ioutil.ReadAll(resp.Body) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
rootCertPool := x509.NewCertPool() | ||
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { | ||
return fmt.Errorf("failed to append cert to cert pool!") | ||
} | ||
|
||
return mysql.RegisterTLSConfig("rds", &tls.Config{RootCAs: rootCertPool, InsecureSkipVerify: true}) | ||
} |