From 22576cd0ce07f1a13b06bc5193a7e2d801273fdf Mon Sep 17 00:00:00 2001 From: xibz Date: Sat, 9 Jun 2018 21:36:31 -0700 Subject: [PATCH] Adding more clarity to rdsutils.BuildAuthToken with an example and cleaning up the docs --- service/rds/rdsutils/connect.go | 4 +- service/rds/rdsutils/example_test.go | 91 ++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 service/rds/rdsutils/example_test.go diff --git a/service/rds/rdsutils/connect.go b/service/rds/rdsutils/connect.go index cfa923188e7..8d7b6667ca9 100644 --- a/service/rds/rdsutils/connect.go +++ b/service/rds/rdsutils/connect.go @@ -9,8 +9,8 @@ import ( "github.com/aws/aws-sdk-go/aws/signer/v4" ) -// BuildAuthToken will return a authentication token for the database's connect -// based on the RDS database endpoint, AWS region, IAM user or role, and AWS credentials. +// BuildAuthToken will return an authorization token used as the password for a DB +// connection. // // Endpoint consists of the hostname and port, IE hostname:port, of the RDS database. // Region is the AWS region the RDS database is in and where the authentication token diff --git a/service/rds/rdsutils/example_test.go b/service/rds/rdsutils/example_test.go new file mode 100644 index 00000000000..d0af718d291 --- /dev/null +++ b/service/rds/rdsutils/example_test.go @@ -0,0 +1,91 @@ +package rdsutils_test + +import ( + "database/sql" + "flag" + "fmt" + "os" + + _ "github.com/go-sql-driver/mysql" + + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/rds/rdsutils" +) + +// Example contains usage of assuming a role and using +// that to build the auth token. +// Usage: +// ./main -user "iamuser" -dbname "foo" -region "us-west-2" -rolearn "arn" -endpoint "dbendpoint" -port 3306 +func ExampleRDSUtils_ConnectViaAssumeRole() { + userPtr := flag.String("user", "", "user of the credentials") + regionPtr := flag.String("region", "us-east-1", "region to be used when grabbing sts creds") + roleArnPtr := flag.String("rolearn", "", "role arn to be used when grabbing sts creds") + endpointPtr := flag.String("endpoint", "", "DB endpoint to be connected to") + portPtr := flag.Int("port", 3306, "DB port to be connected to") + tablePtr := flag.String("table", "test_table", "DB table to query against") + dbNamePtr := flag.String("dbname", "", "DB name to query against") + flag.Parse() + + // Check required flags. Will exit with status code 1 if + // required field isn't set. + if err := requiredFlags( + userPtr, + regionPtr, + roleArnPtr, + endpointPtr, + portPtr, + dbNamePtr, + ); err != nil { + fmt.Printf("Error: %v\n\n", err) + flag.PrintDefaults() + os.Exit(1) + } + + sess := session.Must(session.NewSession()) + creds := stscreds.NewCredentials(sess, *roleArnPtr) + + endpoint := fmt.Sprintf("%s:%d", *endpointPtr, *portPtr) + token, err := rdsutils.BuildAuthToken(endpoint, *regionPtr, *userPtr, creds) + if err != nil { + panic(fmt.Errorf("failed to build authentication token: %v", err)) + } + + // builds the connection endpoint for the SQL driver to use. + dnsStr := fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=true", + *userPtr, token, *endpointPtr, *dbNamePtr, + ) + + const dbType = "mysql" + + db, err := sql.Open(dbType, dnsStr) + // if an error is encountered here, then most likely security groups are incorrect + // in the database. + if err != nil { + panic(fmt.Errorf("failed to open connection to the database")) + } + + rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s LIMIT 1", *tablePtr)) + if err != nil { + panic(fmt.Errorf("failed to select from table, %q, with %v", *tablePtr, err)) + } + + for rows.Next() { + columns, err := rows.Columns() + if err != nil { + panic(fmt.Errorf("failed to read columns from row: %v", err)) + } + + fmt.Printf("rows colums:\n%d\n", len(columns)) + } +} + +func requiredFlags(flags ...interface{}) error { + for _, f := range flags { + switch f.(type) { + case nil: + return fmt.Errorf("one or more required flags were not set") + } + } + return nil +}