Skip to content

Commit

Permalink
Allow for role based credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
danquack committed Nov 24, 2023
1 parent 9c42710 commit e0e4744
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 24 deletions.
24 changes: 12 additions & 12 deletions messenger/pinpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,22 +94,22 @@ func NewPinpoint(cfg []byte, l *onelog.Logger) (Messenger, error) {
if c.AppID == "" {
return nil, fmt.Errorf("invalid app_id")
}
if c.Region == "" {
return nil, fmt.Errorf("invalid region")

config := &aws.Config{
MaxRetries: aws.Int(3),
}
if c.AccessKey == "" {
return nil, fmt.Errorf("invalid access_key")
if c.AccessKey != "" && c.SecretKey != "" {
config.Credentials = credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")
}
if c.SecretKey == "" {
return nil, fmt.Errorf("invalid secret_key")
if c.Region != "" {
config.Region = &c.Region
}

sess := session.Must(session.NewSession())
svc := pinpoint.New(sess,
aws.NewConfig().
WithCredentials(credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")).
WithRegion(c.Region),
)
var sess = session.Must(session.NewSession(config))
if !checkCredentials(sess) {
return nil, fmt.Errorf("invalid credentials")
}
svc := pinpoint.New(sess)

return pinpointMessenger{
client: svc,
Expand Down
33 changes: 21 additions & 12 deletions messenger/ses.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ses"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/francoispqt/onelog"
"github.com/knadh/smtppool"
)
Expand Down Expand Up @@ -101,30 +102,38 @@ func (s sesMessenger) Close() error {
return nil
}

func checkCredentials(sess *session.Session) bool {
// Create a SES service client.
svc := sts.New(sess)
// Call the GetCallerIdentity API to check credentials
params := &sts.GetCallerIdentityInput{}
_, err := svc.GetCallerIdentity(params)
return err != nil
}

// NewAWSSES creates new instance of pinpoint
func NewAWSSES(cfg []byte, l *onelog.Logger) (Messenger, error) {
var c sesCfg
if err := json.Unmarshal(cfg, &c); err != nil {
return nil, err
}

if c.Region == "" {
return nil, fmt.Errorf("invalid region")
config := &aws.Config{
MaxRetries: aws.Int(3),
}
if c.AccessKey == "" {
return nil, fmt.Errorf("invalid access_key")
if c.AccessKey != "" && c.SecretKey != "" {
config.Credentials = credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")
}
if c.SecretKey == "" {
return nil, fmt.Errorf("invalid secret_key")
if c.Region != "" {
config.Region = &c.Region
}

sess := session.Must(session.NewSession())
svc := ses.New(sess,
aws.NewConfig().
WithCredentials(credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")).
WithRegion(c.Region),
)
var sess = session.Must(session.NewSession(config))
if !checkCredentials(sess) {
return nil, fmt.Errorf("invalid credentials")
}

svc := ses.New(sess)
return sesMessenger{
client: svc,
cfg: c,
Expand Down

0 comments on commit e0e4744

Please sign in to comment.