Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct instantiation of AWS session object #421

Merged
merged 1 commit into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions cmd/node-termination-handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ import (
"github.com/aws/aws-node-termination-handler/pkg/node"
"github.com/aws/aws-node-termination-handler/pkg/observability"
"github.com/aws/aws-node-termination-handler/pkg/webhook"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/autoscaling"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/sqs"
Expand Down Expand Up @@ -106,10 +109,11 @@ func main() {
// Populate the aws region if available from node metadata and not already explicitly configured
if nthConfig.AWSRegion == "" && nodeMetadata.Region != "" {
nthConfig.AWSRegion = nodeMetadata.Region
if nthConfig.AWSSession != nil {
nthConfig.AWSSession.Config.Region = &nodeMetadata.Region
}
} else if nthConfig.AWSRegion == "" && nodeMetadata.Region == "" && nthConfig.EnableSQSTerminationDraining {
} else if nthConfig.AWSRegion == "" && nthConfig.QueueURL != "" {
nthConfig.AWSRegion = getRegionFromQueueURL(nthConfig.QueueURL)
log.Debug().Str("Retrieved AWS region from queue-url: \"%s\"", nthConfig.AWSRegion)
}
if nthConfig.AWSRegion == "" && nthConfig.EnableSQSTerminationDraining {
nthConfig.Print()
log.Fatal().Msgf("Unable to find the AWS region to process queue events.")
}
Expand Down Expand Up @@ -157,9 +161,14 @@ func main() {
monitoringFns[rebalanceRecommendation] = imdsRebalanceMonitor
}
if nthConfig.EnableSQSTerminationDraining {
creds, err := nthConfig.AWSSession.Config.Credentials.Get()
cfg := aws.NewConfig().WithRegion(nthConfig.AWSRegion).WithEndpoint(nthConfig.AWSEndpoint).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)
sess := session.Must(session.NewSessionWithOptions(session.Options{
Config: *cfg,
SharedConfigState: session.SharedConfigEnable,
}))
creds, err := sess.Config.Credentials.Get()
if err != nil {
log.Err(err).Msg("Unable to get AWS credentials")
log.Fatal().Err(err).Msg("Unable to get AWS credentials")
}
log.Debug().Msgf("AWS Credentials retrieved from provider: %s", creds.ProviderName)

Expand All @@ -169,9 +178,9 @@ func main() {
QueueURL: nthConfig.QueueURL,
InterruptionChan: interruptionChan,
CancelChan: cancelChan,
SQS: sqs.New(nthConfig.AWSSession),
ASG: autoscaling.New(nthConfig.AWSSession),
EC2: ec2.New(nthConfig.AWSSession),
SQS: sqs.New(sess),
ASG: autoscaling.New(sess),
EC2: ec2.New(sess),
}
monitoringFns[sqsEvents] = sqsMonitor
}
Expand Down Expand Up @@ -380,3 +389,14 @@ func runPostDrainTask(node node.Node, nodeName string, drainEvent *monitor.Inter
}
metrics.NodeActionsInc("post-drain", nodeName, err)
}

func getRegionFromQueueURL(queueURL string) string {
for _, partition := range endpoints.DefaultPartitions() {
for regionID := range partition.Regions() {
if strings.Contains(queueURL, regionID) {
return regionID
}
}
}
return ""
}
33 changes: 0 additions & 33 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import (
"strconv"
"strings"

"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/rs/zerolog/log"
)

Expand Down Expand Up @@ -139,7 +137,6 @@ type Config struct {
AWSEndpoint string
QueueURL string
Workers int
AWSSession *session.Session
}

//ParseCliArgs parses cli arguments and uses environment variables as fallback values
Expand Down Expand Up @@ -195,25 +192,6 @@ func ParseCliArgs() (config Config, err error) {

flag.Parse()

if config.EnableSQSTerminationDraining {
sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))
if config.AWSRegion != "" {
sess.Config.Region = &config.AWSRegion
} else if *sess.Config.Region == "" && config.QueueURL != "" {
config.AWSRegion = getRegionFromQueueURL(config.QueueURL)
log.Debug().Str("Retrieved AWS region from queue-url: \"%s\"", config.AWSRegion)
sess.Config.Region = &config.AWSRegion
} else {
config.AWSRegion = *sess.Config.Region
}
config.AWSSession = sess
if config.AWSEndpoint != "" {
config.AWSSession.Config.Endpoint = &config.AWSEndpoint
}
}

if isConfigProvided("pod-termination-grace-period", podTerminationGracePeriodConfigKey) && isConfigProvided("grace-period", gracePeriodConfigKey) {
log.Warn().Msg("Deprecated argument \"grace-period\" and the replacement argument \"pod-termination-grace-period\" was provided. Using the newer argument \"pod-termination-grace-period\"")
} else if isConfigProvided("grace-period", gracePeriodConfigKey) {
Expand Down Expand Up @@ -413,14 +391,3 @@ func isConfigProvided(cliArgName string, envVarName string) bool {
})
return cliArgProvided
}

func getRegionFromQueueURL(queueURL string) string {
for _, partition := range endpoints.DefaultPartitions() {
for regionID := range partition.Regions() {
if strings.Contains(queueURL, regionID) {
return regionID
}
}
}
return ""
}