diff --git a/cmd/node-termination-handler.go b/cmd/node-termination-handler.go
index cee30a21..6a2f9118 100644
--- a/cmd/node-termination-handler.go
+++ b/cmd/node-termination-handler.go
@@ -200,7 +200,16 @@ func main() {
 			// Exit interruption loop if a SIGTERM is received or the channel is closed
 			break
 		default:
-			drainOrCordonIfNecessary(interruptionEventStore, *node, nthConfig, nodeMetadata, metrics)
+			for event, ok := interruptionEventStore.GetActiveEvent(); ok && !event.InProgress; event, ok = interruptionEventStore.GetActiveEvent() {
+				select {
+				case interruptionEventStore.Workers <- 1:
+					event.InProgress = true
+					go drainOrCordonIfNecessary(interruptionEventStore, event, *node, nthConfig, nodeMetadata, metrics)
+				default:
+					log.Warn().Msg("all workers busy, waiting")
+					break
+				}
+			}
 		}
 	}
 	log.Log().Msg("AWS Node Termination Handler is shutting down")
@@ -254,59 +263,59 @@ func watchForCancellationEvents(cancelChan <-chan monitor.InterruptionEvent, int
 	}
 }
 
-func drainOrCordonIfNecessary(interruptionEventStore *interruptioneventstore.Store, node node.Node, nthConfig config.Config, nodeMetadata ec2metadata.NodeMetadata, metrics observability.Metrics) {
-	if drainEvent, ok := interruptionEventStore.GetActiveEvent(); ok {
-		nodeName := drainEvent.NodeName
-		if drainEvent.PreDrainTask != nil {
-			err := drainEvent.PreDrainTask(*drainEvent, node)
-			if err != nil {
-				log.Log().Err(err).Msg("There was a problem executing the pre-drain task")
-			}
-			metrics.NodeActionsInc("pre-drain", nodeName, err)
-		}
-
-		if nthConfig.CordonOnly || drainEvent.IsRebalanceRecommendation() {
-			err := node.Cordon(nodeName)
-			if err != nil {
-				if errors.IsNotFound(err) {
-					log.Warn().Err(err).Msgf("node '%s' not found in the cluster", nodeName)
-				} else {
-					log.Log().Err(err).Msg("There was a problem while trying to cordon the node")
-					os.Exit(1)
-				}
-			} else {
-				log.Log().Str("node_name", nodeName).Msg("Node successfully cordoned")
-				err = node.LogPods(nodeName)
-				if err != nil {
-					log.Log().Err(err).Msg("There was a problem while trying to log all pod names on the node")
-				}
-				metrics.NodeActionsInc("cordon", nodeName, err)
-			}
-		} else {
-			err := node.CordonAndDrain(nodeName)
-			if err != nil {
-				if errors.IsNotFound(err) {
-					log.Warn().Err(err).Msgf("node '%s' not found in the cluster", nodeName)
-				} else {
-					log.Log().Err(err).Msg("There was a problem while trying to cordon and drain the node")
-					os.Exit(1)
-				}
-			} else {
-				log.Log().Str("node_name", nodeName).Msg("Node successfully cordoned and drained")
-				metrics.NodeActionsInc("cordon-and-drain", nodeName, err)
-			}
-		}
-
-		interruptionEventStore.MarkAllAsDrained(nodeName)
-		if nthConfig.WebhookURL != "" {
-			webhook.Post(nodeMetadata, drainEvent, nthConfig)
-		}
-		if drainEvent.PostDrainTask != nil {
-			err := drainEvent.PostDrainTask(*drainEvent, node)
-			if err != nil {
-				log.Err(err).Msg("There was a problem executing the post-drain task")
-			}
-			metrics.NodeActionsInc("post-drain", nodeName, err)
-		}
-	}
-}
+func drainOrCordonIfNecessary(interruptionEventStore *interruptioneventstore.Store, drainEvent *monitor.InterruptionEvent, node node.Node, nthConfig config.Config, nodeMetadata ec2metadata.NodeMetadata, metrics observability.Metrics) {
+	nodeName := drainEvent.NodeName
+	if drainEvent.PreDrainTask != nil {
+		err := drainEvent.PreDrainTask(*drainEvent, node)
+		if err != nil {
+			log.Log().Err(err).Msg("There was a problem executing the pre-drain task")
+		}
+		metrics.NodeActionsInc("pre-drain", nodeName, err)
+	}
+
+	if nthConfig.CordonOnly || drainEvent.IsRebalanceRecommendation() {
+		err := node.Cordon(nodeName)
+		if err != nil {
+			if errors.IsNotFound(err) {
+				log.Warn().Err(err).Msgf("node '%s' not found in the cluster", nodeName)
+			} else {
+				log.Log().Err(err).Msg("There was a problem while trying to cordon the node")
+				os.Exit(1)
+			}
+		} else {
+			log.Log().Str("node_name", nodeName).Msg("Node successfully cordoned")
+			err = node.LogPods(nodeName)
+			if err != nil {
+				log.Log().Err(err).Msg("There was a problem while trying to log all pod names on the node")
+			}
+			metrics.NodeActionsInc("cordon", nodeName, err)
+		}
+	} else {
+		err := node.CordonAndDrain(nodeName)
+		if err != nil {
+			if errors.IsNotFound(err) {
+				log.Warn().Err(err).Msgf("node '%s' not found in the cluster", nodeName)
+			} else {
+				log.Log().Err(err).Msg("There was a problem while trying to cordon and drain the node")
+				os.Exit(1)
+			}
+		} else {
+			log.Log().Str("node_name", nodeName).Msg("Node successfully cordoned and drained")
+			metrics.NodeActionsInc("cordon-and-drain", nodeName, err)
+		}
+	}
+
+	interruptionEventStore.MarkAllAsDrained(nodeName)
+	if nthConfig.WebhookURL != "" {
+		webhook.Post(nodeMetadata, drainEvent, nthConfig)
+	}
+	if drainEvent.PostDrainTask != nil {
+		err := drainEvent.PostDrainTask(*drainEvent, node)
+		if err != nil {
+			log.Err(err).Msg("There was a problem executing the post-drain task")
+		}
+		metrics.NodeActionsInc("post-drain", nodeName, err)
+	}
+	<-interruptionEventStore.Workers
+
+}
diff --git a/config/helm/aws-node-termination-handler/README.md b/config/helm/aws-node-termination-handler/README.md
index e5ce0822..53b5544b 100644
--- a/config/helm/aws-node-termination-handler/README.md
+++ b/config/helm/aws-node-termination-handler/README.md
@@ -79,7 +79,6 @@ Parameter | Description | Default
 `podMonitor.sampleLimit` | Number of scraped samples accepted | `5000`
 `podMonitor.labels` | Additional PodMonitor metadata labels | `{}`
 
-
 ### AWS Node Termination Handler - Queue-Processor Mode Configuration
 
 Parameter | Description | Default
@@ -89,6 +88,7 @@ Parameter | Description | Default
 `awsRegion` | If specified, use the AWS region for AWS API calls, else NTH will try to find the region through AWS_REGION env var, IMDS, or the specified queue URL | ``
 `checkASGTagBeforeDraining` | If true, check that the instance is tagged with "aws-node-termination-handler/managed" as the key before draining the node | `true`
 `managedAsgTag` | The tag to ensure is on a node if checkASGTagBeforeDraining is true | `aws-node-termination-handler/managed`
+`workers` | The maximum number of parallel event processors | `10`
 
 ### AWS Node Termination Handler - IMDS Mode Configuration
 
diff --git a/config/helm/aws-node-termination-handler/templates/deployment.yaml b/config/helm/aws-node-termination-handler/templates/deployment.yaml
index d6b3e854..f183a86a 100644
--- a/config/helm/aws-node-termination-handler/templates/deployment.yaml
+++ b/config/helm/aws-node-termination-handler/templates/deployment.yaml
@@ -144,6 +144,8 @@ spec:
             value: {{ .Values.checkASGTagBeforeDraining | quote }}
           - name: MANAGED_ASG_TAG
             value: {{ .Values.managedAsgTag | quote }}
+          - name: WORKERS
+            value: {{ .Values.workers | quote }}
           resources:
             {{- toYaml .Values.resources | nindent 12 }}
       {{- if .Values.enablePrometheusServer }}
diff --git a/config/helm/aws-node-termination-handler/values.yaml b/config/helm/aws-node-termination-handler/values.yaml
index 105b8260..e387c948 100644
--- a/config/helm/aws-node-termination-handler/values.yaml
+++ b/config/helm/aws-node-termination-handler/values.yaml
@@ -188,3 +188,6 @@ windowsUpdateStrategy: ""
 # If you have disabled IMDSv1 and are relying on IMDSv2, you'll need to increase the IP hop count to 2 before switching this to false
 # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
 useHostNetwork: true
+
+# The maximum number of parallel event processors to handle concurrent events
+workers: 10
diff --git a/pkg/config/config.go b/pkg/config/config.go
index f5cbd18d..bd2d400f 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -70,6 +70,8 @@ const (
 	logLevelDefault          = "INFO"
 	uptimeFromFileConfigKey  = "UPTIME_FROM_FILE"
 	uptimeFromFileDefault    = ""
+	workersConfigKey         = "WORKERS"
+	workersDefault           = 10
 	// prometheus
 	enablePrometheusDefault   = false
 	enablePrometheusConfigKey = "ENABLE_PROMETHEUS_SERVER"
@@ -116,6 +118,7 @@ type Config struct {
 	AWSRegion             string
 	AWSEndpoint           string
 	QueueURL              string
+	Workers               int
 	AWSSession            *session.Session
 }
 
@@ -162,6 +165,7 @@ func ParseCliArgs() (config Config, err error) {
 	flag.StringVar(&config.AWSRegion, "aws-region", getEnv(awsRegionConfigKey, ""), "If specified, use the AWS region for AWS API calls")
 	flag.StringVar(&config.AWSEndpoint, "aws-endpoint", getEnv(awsEndpointConfigKey, ""), "[testing] If specified, use the AWS endpoint to make API calls")
 	flag.StringVar(&config.QueueURL, "queue-url", getEnv(queueURLConfigKey, ""), "Listens for messages on the specified SQS queue URL")
+	flag.IntVar(&config.Workers, "workers", getIntEnv(workersConfigKey, workersDefault), "The maximum number of parallel event processors")
 
 	flag.Parse()
 
diff --git a/pkg/interruptioneventstore/interruption-event-store.go b/pkg/interruptioneventstore/interruption-event-store.go
index 4d7359bd..ffafa284 100644
--- a/pkg/interruptioneventstore/interruption-event-store.go
+++ b/pkg/interruptioneventstore/interruption-event-store.go
@@ -30,6 +30,7 @@ type Store struct {
 	interruptionEventStore map[string]*monitor.InterruptionEvent
 	ignoredEvents          map[string]struct{}
 	atLeastOneEvent        bool
+	Workers                chan int
 }
 
 // New Creates a new interruption event store
@@ -38,6 +39,7 @@ func New(nthConfig config.Config) *Store {
 		NthConfig:              nthConfig,
 		interruptionEventStore: make(map[string]*monitor.InterruptionEvent),
 		ignoredEvents:          make(map[string]struct{}),
+		Workers:                make(chan int, nthConfig.Workers),
 	}
 }
 
diff --git a/pkg/monitor/sqsevent/sqs-monitor.go b/pkg/monitor/sqsevent/sqs-monitor.go
index d8054e6c..5786238f 100644
--- a/pkg/monitor/sqsevent/sqs-monitor.go
+++ b/pkg/monitor/sqsevent/sqs-monitor.go
@@ -196,7 +196,7 @@ func (m SQSMonitor) retrieveNodeName(instanceID string) (string, error) {
 		}
 		// anything except running might not contain PrivateDnsName
 		if state != ec2.InstanceStateNameRunning {
-			return "", ErrNodeStateNotRunning
+			return "", fmt.Errorf("node: '%s' in state '%s': %w", instanceID, state, ErrNodeStateNotRunning)
 		}
 		return "", fmt.Errorf("unable to retrieve PrivateDnsName name for '%s' in state '%s'", instanceID, state)
 	}
diff --git a/pkg/monitor/types.go b/pkg/monitor/types.go
index 6a0f9755..a237088a 100644
--- a/pkg/monitor/types.go
+++ b/pkg/monitor/types.go
@@ -33,6 +33,7 @@ type InterruptionEvent struct {
 	StartTime     time.Time
 	EndTime       time.Time
 	Drained       bool
+	InProgress    bool
 	PreDrainTask  DrainTask `json:"-"`
 	PostDrainTask DrainTask `json:"-"`
 }
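Note on the concurrency change above: the diff caps parallel drain work by using a buffered channel as a semaphore. The event loop claims a worker slot with a non-blocking send before spawning a goroutine, and drainOrCordonIfNecessary frees the slot by receiving from the channel when it finishes. The snippet below is a minimal, self-contained sketch of that pattern only, not code from this repository; maxWorkers and processEvent are illustrative stand-ins for the new workers setting and the real drain work.

// Sketch only: buffered-channel semaphore bounding concurrent work, as in the
// change above. maxWorkers and processEvent are hypothetical names.
package main

import (
	"fmt"
	"sync"
	"time"
)

func processEvent(id int) {
	// Stand-in for the real cordon/drain work.
	time.Sleep(100 * time.Millisecond)
	fmt.Printf("processed event %d\n", id)
}

func main() {
	const maxWorkers = 3                  // plays the role of the -workers flag / WORKERS env var
	workers := make(chan int, maxWorkers) // plays the role of Store.Workers

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		event := i
		select {
		case workers <- 1: // a slot is free: claim it and process the event concurrently
			wg.Add(1)
			go func() {
				defer wg.Done()
				defer func() { <-workers }() // release the slot when the work is done
				processEvent(event)
			}()
		default: // all slots busy: skip for now; the real event loop retries on its next pass
			fmt.Printf("all workers busy, deferring event %d\n", event)
		}
	}
	wg.Wait()
}

Using the channel as both the limit and the bookkeeping avoids a separate counter guarded by a mutex, which is presumably why the Store only gains the single Workers field.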