From e0e4744112e7833ff685de6af2a9ca09304fa1ea Mon Sep 17 00:00:00 2001 From: Daniel Quackenbush <25692880+danquack@users.noreply.github.com> Date: Thu, 23 Nov 2023 20:46:50 -0500 Subject: [PATCH] Allow for role based credentials --- messenger/pinpoint.go | 24 ++++++++++++------------ messenger/ses.go | 33 +++++++++++++++++++++------------ 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/messenger/pinpoint.go b/messenger/pinpoint.go index 3e29df5..087634e 100644 --- a/messenger/pinpoint.go +++ b/messenger/pinpoint.go @@ -94,22 +94,22 @@ func NewPinpoint(cfg []byte, l *onelog.Logger) (Messenger, error) { if c.AppID == "" { return nil, fmt.Errorf("invalid app_id") } - if c.Region == "" { - return nil, fmt.Errorf("invalid region") + + config := &aws.Config{ + MaxRetries: aws.Int(3), } - if c.AccessKey == "" { - return nil, fmt.Errorf("invalid access_key") + if c.AccessKey != "" && c.SecretKey != "" { + config.Credentials = credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "") } - if c.SecretKey == "" { - return nil, fmt.Errorf("invalid secret_key") + if c.Region != "" { + config.Region = &c.Region } - sess := session.Must(session.NewSession()) - svc := pinpoint.New(sess, - aws.NewConfig(). - WithCredentials(credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")). - WithRegion(c.Region), - ) + var sess = session.Must(session.NewSession(config)) + if !checkCredentials(sess) { + return nil, fmt.Errorf("invalid credentials") + } + svc := pinpoint.New(sess) return pinpointMessenger{ client: svc, diff --git a/messenger/ses.go b/messenger/ses.go index 672fe6f..3329c79 100644 --- a/messenger/ses.go +++ b/messenger/ses.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ses" + "github.com/aws/aws-sdk-go/service/sts" "github.com/francoispqt/onelog" "github.com/knadh/smtppool" ) @@ -101,6 +102,15 @@ func (s sesMessenger) Close() error { return nil } +func checkCredentials(sess *session.Session) bool { + // Create a SES service client. + svc := sts.New(sess) + // Call the GetCallerIdentity API to check credentials + params := &sts.GetCallerIdentityInput{} + _, err := svc.GetCallerIdentity(params) + return err != nil +} + // NewAWSSES creates new instance of pinpoint func NewAWSSES(cfg []byte, l *onelog.Logger) (Messenger, error) { var c sesCfg @@ -108,23 +118,22 @@ func NewAWSSES(cfg []byte, l *onelog.Logger) (Messenger, error) { return nil, err } - if c.Region == "" { - return nil, fmt.Errorf("invalid region") + config := &aws.Config{ + MaxRetries: aws.Int(3), } - if c.AccessKey == "" { - return nil, fmt.Errorf("invalid access_key") + if c.AccessKey != "" && c.SecretKey != "" { + config.Credentials = credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "") } - if c.SecretKey == "" { - return nil, fmt.Errorf("invalid secret_key") + if c.Region != "" { + config.Region = &c.Region } - sess := session.Must(session.NewSession()) - svc := ses.New(sess, - aws.NewConfig(). - WithCredentials(credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")). - WithRegion(c.Region), - ) + var sess = session.Must(session.NewSession(config)) + if !checkCredentials(sess) { + return nil, fmt.Errorf("invalid credentials") + } + svc := ses.New(sess) return sesMessenger{ client: svc, cfg: c,