Skip to content

Commit

Permalink
Remove AssumeAsgTagPropagation (#632)
Browse files Browse the repository at this point in the history
* remove AssumeAsgTagPropogation flag+config

* refactor check before invoking asg api in sqs monitor

* update sqs monitor unit tests

* clarify checkASGTagBeforeDraining docs

* removed unused test code in monitor package
  • Loading branch information
brycahta authored Apr 25, 2022
1 parent b76bb73 commit 37e9899
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 154 deletions.
17 changes: 8 additions & 9 deletions cmd/node-termination-handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,14 @@ func main() {
log.Debug().Msgf("AWS Credentials retrieved from provider: %s", creds.ProviderName)

sqsMonitor := sqsevent.SQSMonitor{
CheckIfManaged: nthConfig.CheckASGTagBeforeDraining,
ManagedAsgTag: nthConfig.ManagedAsgTag,
AssumeAsgTagPropagation: nthConfig.AssumeAsgTagPropagation,
QueueURL: nthConfig.QueueURL,
InterruptionChan: interruptionChan,
CancelChan: cancelChan,
SQS: sqs.New(sess),
ASG: autoscaling.New(sess),
EC2: ec2.New(sess),
CheckIfManaged: nthConfig.CheckASGTagBeforeDraining,
ManagedAsgTag: nthConfig.ManagedAsgTag,
QueueURL: nthConfig.QueueURL,
InterruptionChan: interruptionChan,
CancelChan: cancelChan,
SQS: sqs.New(sess),
ASG: autoscaling.New(sess),
EC2: ec2.New(sess),
}
monitoringFns[sqsEvents] = sqsMonitor
}
Expand Down
3 changes: 1 addition & 2 deletions config/helm/aws-node-termination-handler/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ The configuration in this table applies to AWS Node Termination Handler in queue
| `awsRegion` | If specified, use the AWS region for AWS API calls, else NTH will try to find the region through the `AWS_REGION` environment variable, IMDS, or the specified queue URL. | `""` |
| `queueURL` | Listens for messages on the specified SQS queue URL. | `""` |
| `workers` | The maximum amount of parallel event processors to handle concurrent events. | `10` |
| `checkASGTagBeforeDraining` | If `true`, check that the instance is tagged with the `managedAsgTag` before draining the node. | `true` |
| `checkASGTagBeforeDraining` | If `true`, check that the instance is tagged with the `managedAsgTag` before draining the node. If `false`, disables calls ASG API. | `true` |
| `managedAsgTag` | The node tag to check if `checkASGTagBeforeDraining` is `true`. | `aws-node-termination-handler/managed` |
| `assumeAsgTagPropagation` | If `true`, assume that ASG tags will be appear on the ASG's instances. | `false` |
| `useProviderId` | If `true`, fetch node name through Kubernetes node spec ProviderID instead of AWS event PrivateDnsHostname. | `false` |

### IMDS Mode Configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ spec:
value: {{ .Values.checkASGTagBeforeDraining | quote }}
- name: MANAGED_ASG_TAG
value: {{ .Values.managedAsgTag | quote }}
- name: ASSUME_ASG_TAG_PROPAGATION
value: {{ .Values.assumeAsgTagPropagation | quote }}
- name: USE_PROVIDER_ID
value: {{ .Values.useProviderId | quote }}
- name: DRY_RUN
Expand Down
4 changes: 1 addition & 3 deletions config/helm/aws-node-termination-handler/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,12 @@ queueURL: ""
workers: 10

# If true, check that the instance is tagged with "aws-node-termination-handler/managed" as the key before draining the node
# If false, disables calls to ASG API.
checkASGTagBeforeDraining: true

# The tag to ensure is on a node if checkASGTagBeforeDraining is true
managedAsgTag: "aws-node-termination-handler/managed"

# If true, assume that ASG tags will be appear on the ASG's instances
assumeAsgTagPropagation: false

# If true, fetch node name through Kubernetes node spec ProviderID instead of AWS event PrivateDnsHostname.
useProviderId: false

Expand Down
9 changes: 1 addition & 8 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ const (
checkASGTagBeforeDrainingDefault = true
managedAsgTagConfigKey = "MANAGED_ASG_TAG"
managedAsgTagDefault = "aws-node-termination-handler/managed"
assumeAsgTagPropagationKey = "ASSUME_ASG_TAG_PROPAGATION"
assumeAsgTagPropagationDefault = false
useProviderIdConfigKey = "USE_PROVIDER_ID"
useProviderIdDefault = false
metadataTriesConfigKey = "METADATA_TRIES"
Expand Down Expand Up @@ -126,7 +124,6 @@ type Config struct {
EnableRebalanceDraining bool
CheckASGTagBeforeDraining bool
ManagedAsgTag string
AssumeAsgTagPropagation bool
MetadataTries int
CordonOnly bool
TaintNode bool
Expand Down Expand Up @@ -181,7 +178,7 @@ func ParseCliArgs() (config Config, err error) {
flag.BoolVar(&config.EnableSQSTerminationDraining, "enable-sqs-termination-draining", getBoolEnv(enableSQSTerminationDrainingConfigKey, enableSQSTerminationDrainingDefault), "If true, drain nodes when an SQS termination event is received")
flag.BoolVar(&config.EnableRebalanceMonitoring, "enable-rebalance-monitoring", getBoolEnv(enableRebalanceMonitoringConfigKey, enableRebalanceMonitoringDefault), "If true, cordon nodes when the rebalance recommendation notice is received. If you'd like to drain the node in addition to cordoning, then also set \"enableRebalanceDraining\".")
flag.BoolVar(&config.EnableRebalanceDraining, "enable-rebalance-draining", getBoolEnv(enableRebalanceDrainingConfigKey, enableRebalanceDrainingDefault), "If true, drain nodes when the rebalance recommendation notice is received")
flag.BoolVar(&config.CheckASGTagBeforeDraining, "check-asg-tag-before-draining", getBoolEnv(checkASGTagBeforeDrainingConfigKey, checkASGTagBeforeDrainingDefault), "If true, check that the instance is tagged with \"aws-node-termination-handler/managed\" as the key before draining the node")
flag.BoolVar(&config.CheckASGTagBeforeDraining, "check-asg-tag-before-draining", getBoolEnv(checkASGTagBeforeDrainingConfigKey, checkASGTagBeforeDrainingDefault), "If true, check that the instance is tagged with \"aws-node-termination-handler/managed\" as the key before draining the node. If false, disables calls to ASG API.")
flag.StringVar(&config.ManagedAsgTag, "managed-asg-tag", getEnv(managedAsgTagConfigKey, managedAsgTagDefault), "Sets the tag to check for on instances that is propogated from the ASG before taking action, default to aws-node-termination-handler/managed")
flag.IntVar(&config.MetadataTries, "metadata-tries", getIntEnv(metadataTriesConfigKey, metadataTriesDefault), "The number of times to try requesting metadata. If you would like 2 retries, set metadata-tries to 3.")
flag.BoolVar(&config.CordonOnly, "cordon-only", getBoolEnv(cordonOnly, false), "If true, nodes will be cordoned but not drained when an interruption event occurs.")
Expand All @@ -202,7 +199,6 @@ func ParseCliArgs() (config Config, err error) {
flag.StringVar(&config.AWSEndpoint, "aws-endpoint", getEnv(awsEndpointConfigKey, ""), "[testing] If specified, use the AWS endpoint to make API calls")
flag.StringVar(&config.QueueURL, "queue-url", getEnv(queueURLConfigKey, ""), "Listens for messages on the specified SQS queue URL")
flag.IntVar(&config.Workers, "workers", getIntEnv(workersConfigKey, workersDefault), "The amount of parallel event processors.")
flag.BoolVar(&config.AssumeAsgTagPropagation, "assume-asg-tag-propagation", getBoolEnv(assumeAsgTagPropagationKey, assumeAsgTagPropagationDefault), "If true, assume that ASG tags will be appear on the ASG's instances.")
flag.BoolVar(&config.UseProviderId, "use-provider-id", getBoolEnv(useProviderIdConfigKey, useProviderIdDefault), "If true, fetch node name through Kubernetes node spec ProviderID instead of AWS event PrivateDnsHostname.")
flag.Parse()

Expand Down Expand Up @@ -279,7 +275,6 @@ func (c Config) PrintJsonConfigArgs() {
Str("queue_url", c.QueueURL).
Bool("check_asg_tag_before_draining", c.CheckASGTagBeforeDraining).
Str("ManagedAsgTag", c.ManagedAsgTag).
Bool("assume_asg_tag_propagation", c.AssumeAsgTagPropagation).
Bool("use_provider_id", c.UseProviderId).
Msg("aws-node-termination-handler arguments")
}
Expand Down Expand Up @@ -328,7 +323,6 @@ func (c Config) PrintHumanConfigArgs() {
"\tqueue-url: %s,\n"+
"\tcheck-asg-tag-before-draining: %t,\n"+
"\tmanaged-asg-tag: %s,\n"+
"\tassume-asg-tag-propagation: %t,\n"+
"\tuse-provider-id: %t,\n"+
"\taws-endpoint: %s,\n",
c.DryRun,
Expand Down Expand Up @@ -366,7 +360,6 @@ func (c Config) PrintHumanConfigArgs() {
c.QueueURL,
c.CheckASGTagBeforeDraining,
c.ManagedAsgTag,
c.AssumeAsgTagPropagation,
c.UseProviderId,
c.AWSEndpoint,
)
Expand Down
7 changes: 0 additions & 7 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ func setEnvForTest(key string, val string) {

func TestParseCliArgsEnvSuccess(t *testing.T) {
resetFlagsForTest()
setEnvForTest("ASSUME_ASG_TAG_PROPAGATION", "true")
setEnvForTest("USE_PROVIDER_ID", "true")
setEnvForTest("DELETE_LOCAL_DATA", "false")
setEnvForTest("DRY_RUN", "true")
Expand All @@ -68,7 +67,6 @@ func TestParseCliArgsEnvSuccess(t *testing.T) {
h.Ok(t, err)

// Assert all the values were set
h.Equals(t, true, nthConfig.AssumeAsgTagPropagation)
h.Equals(t, true, nthConfig.UseProviderId)
h.Equals(t, false, nthConfig.DeleteLocalData)
h.Equals(t, true, nthConfig.DryRun)
Expand Down Expand Up @@ -104,7 +102,6 @@ func TestParseCliArgsSuccess(t *testing.T) {
resetFlagsForTest()
os.Args = []string{
"cmd",
"--assume-asg-tag-propagation=true",
"--use-provider-id=true",
"--delete-local-data=false",
"--dry-run=true",
Expand All @@ -130,7 +127,6 @@ func TestParseCliArgsSuccess(t *testing.T) {
h.Ok(t, err)

// Assert all the values were set
h.Equals(t, true, nthConfig.AssumeAsgTagPropagation)
h.Equals(t, true, nthConfig.UseProviderId)
h.Equals(t, false, nthConfig.DeleteLocalData)
h.Equals(t, true, nthConfig.DryRun)
Expand Down Expand Up @@ -161,7 +157,6 @@ func TestParseCliArgsSuccess(t *testing.T) {

func TestParseCliArgsOverrides(t *testing.T) {
resetFlagsForTest()
setEnvForTest("ASSUME_ASG_TAG_PROPAGATION", "true")
setEnvForTest("USE_PROVIDER_ID", "true")
setEnvForTest("DELETE_LOCAL_DATA", "true")
setEnvForTest("DRY_RUN", "false")
Expand All @@ -185,7 +180,6 @@ func TestParseCliArgsOverrides(t *testing.T) {
setEnvForTest("CORDON_ONLY", "true")
os.Args = []string{
"cmd",
"--assume-asg-tag-propagation=false",
"--use-provider-id=false",
"--delete-local-data=false",
"--dry-run=true",
Expand Down Expand Up @@ -213,7 +207,6 @@ func TestParseCliArgsOverrides(t *testing.T) {
h.Ok(t, err)

// Assert all the values were set
h.Equals(t, false, nthConfig.AssumeAsgTagPropagation)
h.Equals(t, false, nthConfig.UseProviderId)
h.Equals(t, false, nthConfig.DeleteLocalData)
h.Equals(t, true, nthConfig.DryRun)
Expand Down
39 changes: 18 additions & 21 deletions pkg/monitor/sqsevent/sqs-monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,14 @@ const (

// SQSMonitor is a struct definition that knows how to process events from Amazon EventBridge
type SQSMonitor struct {
InterruptionChan chan<- monitor.InterruptionEvent
CancelChan chan<- monitor.InterruptionEvent
QueueURL string
SQS sqsiface.SQSAPI
ASG autoscalingiface.AutoScalingAPI
EC2 ec2iface.EC2API
CheckIfManaged bool
AssumeAsgTagPropagation bool
ManagedAsgTag string
InterruptionChan chan<- monitor.InterruptionEvent
CancelChan chan<- monitor.InterruptionEvent
QueueURL string
SQS sqsiface.SQSAPI
ASG autoscalingiface.AutoScalingAPI
EC2 ec2iface.EC2API
CheckIfManaged bool
ManagedAsgTag string
}

// InterruptionEventWrapper is a convenience wrapper for associating an interruption event with its error, if any
Expand Down Expand Up @@ -348,26 +347,24 @@ func (m SQSMonitor) getNodeInfo(instanceID string) (*NodeInfo, error) {
}
}

if nodeInfo.AsgName == "" && !m.AssumeAsgTagPropagation {
// If ASG tags are not propagated we might need to use the API
// to retrieve the ASG name
nodeInfo.AsgName, err = m.retrieveAutoScalingGroupName(nodeInfo.InstanceID)
if err != nil {
return nil, fmt.Errorf("unable to retrieve AutoScaling group: %w", err)
if m.CheckIfManaged {
if nodeInfo.AsgName == "" {

This comment has been minimized.

Copy link
@bwagner5

bwagner5 Jul 11, 2022

Contributor

This is actually wrong now. We should skip the ASG lookup like we used to when AssumeASGTagPropagation was true. We should assume that AssumeASGTagPropagation is always true and try to match the tag. We should also just try to match the ASG system tag and if it doesn't exist then assume it's not managed by an ASG.

// If ASG tags are not propagated we might need to use the API
// to retrieve the ASG name
nodeInfo.AsgName, err = m.retrieveAutoScalingGroupName(nodeInfo.InstanceID)
if err != nil {
return nil, fmt.Errorf("unable to retrieve AutoScaling group: %w", err)
}
}
}

if m.CheckIfManaged && nodeInfo.Tags[m.ManagedAsgTag] == "" {
if m.AssumeAsgTagPropagation {
nodeInfo.IsManaged = false
} else {
if nodeInfo.Tags[m.ManagedAsgTag] == "" {
// if ASG tags are not propagated we might have to check the ASG directly
nodeInfo.IsManaged, err = m.isASGManaged(nodeInfo.AsgName, nodeInfo.InstanceID)
if err != nil {
return nil, err
}
}
}

infoJSON, _ := json.MarshalIndent(nodeInfo, " ", " ")
log.Debug().Msgf("Got node info from AWS: %s", infoJSON)

Expand Down
107 changes: 5 additions & 102 deletions pkg/monitor/sqsevent/sqs-monitor_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,30 +83,6 @@ func TestGetNodeInfo_BothTags_Managed(t *testing.T) {
h.Equals(t, true, nodeInfo.IsManaged)
}

func TestGetNodeInfo_BothTags_AssumePropagation_Managed(t *testing.T) {
asgName := "test-asg"
ec2Mock := h.MockedEC2{
DescribeInstancesResp: getDescribeInstancesResp(
"i-beebeebe",
"mydns.example.com",
map[string]string{
"aws-nth/managed": "true",
ASGTagName: asgName,
}),
}
monitor := SQSMonitor{
AssumeAsgTagPropagation: true,
EC2: ec2Mock,
ASG: h.MockedASG{},
CheckIfManaged: true,
ManagedAsgTag: "aws-nth/managed",
}
nodeInfo, err := monitor.getNodeInfo("i-0123456789")
h.Ok(t, err)
h.Equals(t, asgName, nodeInfo.AsgName)
h.Equals(t, true, nodeInfo.IsManaged)
}

func TestGetNodeInfo_ASGTag_ASGNotManaged(t *testing.T) {
ec2Mock := h.MockedEC2{
DescribeInstancesResp: getDescribeInstancesResp(
Expand Down Expand Up @@ -161,40 +137,6 @@ func TestGetNodeInfo_ASGTag_ASGManaged(t *testing.T) {
h.Equals(t, true, nodeInfo.IsManaged)
}

func TestGetNodeInfo_ASGTag_AssumePropagation_NotManaged(t *testing.T) {
ec2Mock := h.MockedEC2{
DescribeInstancesResp: getDescribeInstancesResp(
"i-beebeebe",
"mydns.example.com",
map[string]string{
ASGTagName: "test-asg",
}),
}
asgMock := h.MockedASG{
DescribeAutoScalingInstancesResp: autoscaling.DescribeAutoScalingInstancesOutput{
AutoScalingInstances: []*autoscaling.InstanceDetails{
{AutoScalingGroupName: aws.String("asg-from-api-not-used")},
},
},
DescribeTagsPagesResp: autoscaling.DescribeTagsOutput{
Tags: []*autoscaling.TagDescription{
{Key: aws.String("aws-nth/managed")},
},
},
}
monitor := SQSMonitor{
AssumeAsgTagPropagation: true,
EC2: ec2Mock,
ASG: asgMock,
CheckIfManaged: true,
ManagedAsgTag: "aws-nth/managed",
}
nodeInfo, err := monitor.getNodeInfo("i-0123456789")
h.Ok(t, err)
h.Equals(t, "test-asg", nodeInfo.AsgName)
h.Equals(t, false, nodeInfo.IsManaged)
}

func TestGetNodeInfo_NoASG(t *testing.T) {
ec2Mock := h.MockedEC2{
DescribeInstancesResp: getDescribeInstancesResp("i-beebeebe", "mydns.example.com", map[string]string{}),
Expand Down Expand Up @@ -235,23 +177,6 @@ func TestGetNodeInfo_NoASG_NotManaged(t *testing.T) {
h.Equals(t, false, nodeInfo.IsManaged)
}

func TestGetNodeInfo_NoASG_AssumePropagation_NotManaged(t *testing.T) {
ec2Mock := h.MockedEC2{
DescribeInstancesResp: getDescribeInstancesResp("i-beebeebe", "mydns.example.com", map[string]string{}),
}
monitor := SQSMonitor{
AssumeAsgTagPropagation: true,
EC2: ec2Mock,
ASG: getUnusableMockedASG(),
CheckIfManaged: true,
ManagedAsgTag: "aws-nth/managed",
}
nodeInfo, err := monitor.getNodeInfo("i-0123456789")
h.Ok(t, err)
h.Equals(t, "", nodeInfo.AsgName)
h.Equals(t, false, nodeInfo.IsManaged)
}

func TestGetNodeInfo_ASG(t *testing.T) {
asgName := "my-asg"
ec2Mock := h.MockedEC2{
Expand All @@ -270,25 +195,9 @@ func TestGetNodeInfo_ASG(t *testing.T) {
}
nodeInfo, err := monitor.getNodeInfo("i-0123456789")
h.Ok(t, err)
h.Equals(t, asgName, nodeInfo.AsgName)
h.Equals(t, true, nodeInfo.IsManaged)
}

func TestGetNodeInfo_ASG_AssumePropagation_NotManaged(t *testing.T) {
ec2Mock := h.MockedEC2{
DescribeInstancesResp: getDescribeInstancesResp("i-beebeebe", "mydns.example.com", map[string]string{}),
}
monitor := SQSMonitor{
AssumeAsgTagPropagation: true,
EC2: ec2Mock,
ASG: getUnusableMockedASG(),
CheckIfManaged: true,
ManagedAsgTag: "aws-nth/managed",
}
nodeInfo, err := monitor.getNodeInfo("i-0123456789")
h.Ok(t, err)
// CheckIfManaged defaults to false; therefore, do not call ASG API
h.Equals(t, "", nodeInfo.AsgName)
h.Equals(t, false, nodeInfo.IsManaged)
h.Equals(t, true, nodeInfo.IsManaged)
}

func TestGetNodeInfo_ASG_ASGManaged(t *testing.T) {
Expand Down Expand Up @@ -366,8 +275,9 @@ func TestGetNodeInfo_ASGError(t *testing.T) {
DescribeAutoScalingInstancesErr: fmt.Errorf("error"),
}
monitor := SQSMonitor{
EC2: ec2Mock,
ASG: asgMock,
EC2: ec2Mock,
ASG: asgMock,
CheckIfManaged: true, //enables calling ASG API
}
_, err := monitor.getNodeInfo("i-0123456789")
h.Nok(t, err)
Expand Down Expand Up @@ -423,10 +333,3 @@ func getDescribeInstancesResp(instanceID string, privateDNSName string, tags map
},
}
}

func getUnusableMockedASG() h.MockedASG {
return h.MockedASG{
DescribeAutoScalingInstancesErr: fmt.Errorf("not used"),
DescribeTagsPagesErr: fmt.Errorf("not used"),
}
}

0 comments on commit 37e9899

Please sign in to comment.