diff --git a/messenger/pinpoint.go b/messenger/pinpoint.go index 3e29df5..fd93abc 100644 --- a/messenger/pinpoint.go +++ b/messenger/pinpoint.go @@ -94,22 +94,23 @@ func NewPinpoint(cfg []byte, l *onelog.Logger) (Messenger, error) { if c.AppID == "" { return nil, fmt.Errorf("invalid app_id") } - if c.Region == "" { - return nil, fmt.Errorf("invalid region") + + config := &aws.Config{ + MaxRetries: aws.Int(3), } - if c.AccessKey == "" { - return nil, fmt.Errorf("invalid access_key") + if c.AccessKey != "" && c.SecretKey != "" { + config.Credentials = credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "") } - if c.SecretKey == "" { - return nil, fmt.Errorf("invalid secret_key") + if c.Region != "" { + config.Region = &c.Region } - sess := session.Must(session.NewSession()) - svc := pinpoint.New(sess, - aws.NewConfig(). - WithCredentials(credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")). - WithRegion(c.Region), - ) + var sess = session.Must(session.NewSession(config)) + err := checkCredentials(sess) + if err != nil { + return nil, err + } + svc := pinpoint.New(sess) return pinpointMessenger{ client: svc, diff --git a/messenger/ses.go b/messenger/ses.go index b18db7a..39ee286 100644 --- a/messenger/ses.go +++ b/messenger/ses.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ses" + "github.com/aws/aws-sdk-go/service/sts" "github.com/francoispqt/onelog" "github.com/knadh/smtppool" ) @@ -106,6 +107,15 @@ func (s sesMessenger) Close() error { return nil } +func checkCredentials(sess *session.Session) error { + // Create a SES service client. + svc := sts.New(sess) + // Call the GetCallerIdentity API to check credentials + params := &sts.GetCallerIdentityInput{} + _, err := svc.GetCallerIdentity(params) + return err +} + // NewAWSSES creates new instance of pinpoint func NewAWSSES(cfg []byte, l *onelog.Logger) (Messenger, error) { var c sesCfg @@ -113,23 +123,23 @@ func NewAWSSES(cfg []byte, l *onelog.Logger) (Messenger, error) { return nil, err } - if c.Region == "" { - return nil, fmt.Errorf("invalid region") + config := &aws.Config{ + MaxRetries: aws.Int(3), } - if c.AccessKey == "" { - return nil, fmt.Errorf("invalid access_key") + if c.AccessKey != "" && c.SecretKey != "" { + config.Credentials = credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "") } - if c.SecretKey == "" { - return nil, fmt.Errorf("invalid secret_key") + if c.Region != "" { + config.Region = &c.Region } - sess := session.Must(session.NewSession()) - svc := ses.New(sess, - aws.NewConfig(). - WithCredentials(credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")). - WithRegion(c.Region), - ) + var sess = session.Must(session.NewSession(config)) + err := checkCredentials(sess) + if err != nil { + return nil, err + } + svc := ses.New(sess) return sesMessenger{ client: svc, cfg: c,