diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index 498b2c9b1..682e9d50b 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -128,6 +128,17 @@ class AzureInputVars(schema.Base): workload_identity_enabled: bool = False +class AWSAmiTypes(enum.Enum): + AL2_x86_64 = "AL2_x86_64" + AL2_x86_64_GPU = "AL2_x86_64_GPU" + CUSTOM = "CUSTOM" + + +class AWSNodeLaunchTemplate(schema.Base): + pre_bootstrap_command: Optional[str] = None + ami_id: Optional[str] = None + + class AWSNodeGroupInputVars(schema.Base): name: str instance_type: str @@ -137,6 +148,28 @@ class AWSNodeGroupInputVars(schema.Base): max_size: int single_subnet: bool permissions_boundary: Optional[str] = None + ami_type: Optional[AWSAmiTypes] = None + launch_template: Optional[AWSNodeLaunchTemplate] = None + + @field_validator("ami_type", mode="before") + @classmethod + def _infer_and_validate_ami_type(cls, value, values) -> str: + gpu_enabled = values.get("gpu", False) + + # Auto-set ami_type if not provided + if not value: + if values.get("launch_template") and values["launch_template"].ami_id: + return "CUSTOM" + if gpu_enabled: + return "AL2_x86_64_GPU" + return "AL2_x86_64" + + # Explicit validation + if value == "AL2_x86_64" and gpu_enabled: + raise ValueError( + "ami_type 'AL2_x86_64' cannot be used with GPU enabled (gpu=True)." + ) + return value class AWSInputVars(schema.Base): @@ -449,6 +482,7 @@ class AWSNodeGroup(schema.Base): gpu: bool = False single_subnet: bool = False permissions_boundary: Optional[str] = None + launch_template: Optional[AWSNodeLaunchTemplate] = None DEFAULT_AWS_NODE_GROUPS = { @@ -525,6 +559,7 @@ def _check_input(cls, data: Any) -> Any: raise ValueError( f"Amazon Web Services instance {node_group.instance} not one of available instance types={available_instances}" ) + return data @@ -828,6 +863,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): max_size=node_group.max_nodes, single_subnet=node_group.single_subnet, permissions_boundary=node_group.permissions_boundary, + launch_template=node_group.launch_template, ) for name, node_group in self.config.amazon_web_services.node_groups.items() ], diff --git a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/files/user_data.tftpl b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/files/user_data.tftpl new file mode 100644 index 000000000..278e9a627 --- /dev/null +++ b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/files/user_data.tftpl @@ -0,0 +1,20 @@ +MIME-Version: 1.0 +Content-Type: multipart/mixed; boundary="//" + +%{ if node_pre_bootstrap_command != null } +--// +Content-Type: text/x-shellscript; charset="us-ascii" + +${node_pre_bootstrap_command} +%{ endif } + +%{ if include_bootstrap_cmd } +--// +Content-Type: text/x-shellscript; charset="us-ascii" +#!/bin/bash +set -ex + +/etc/eks/bootstrap.sh ${cluster_name} --b64-cluster-ca ${cluster_cert_authority} --apiserver-endpoint ${cluster_endpoint} +%{ endif } + + --// diff --git a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/main.tf b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/main.tf index 6ca547ab3..5b66201f8 100644 --- a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/main.tf +++ b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/main.tf @@ -21,6 +21,52 @@ resource "aws_eks_cluster" "main" { tags = merge({ Name = var.name }, var.tags) } +## aws_launch_template user_data invocation +## If using a Custom AMI, then the /etc/eks/bootstrap cmds and args must be included/modified, +## otherwise, on default AWS EKS Node AMI, the bootstrap cmd is appended automatically +resource "aws_launch_template" "main" { + for_each = { + for node_group in var.node_groups : + node_group.name => node_group + if node_group.launch_template != null + } + + name_prefix = "eks-${var.name}-${each.value.name}-" + image_id = each.value.launch_template.ami_id + + vpc_security_group_ids = var.cluster_security_groups + + + metadata_options { + http_tokens = "required" + http_endpoint = "enabled" + instance_metadata_tags = "enabled" + } + + block_device_mappings { + device_name = "/dev/xvda" + ebs { + volume_size = 50 + volume_type = "gp2" + } + } + + # https://docs.aws.amazon.com/eks/latest/userguide/launch-templates.html#launch-template-basics + user_data = base64encode( + templatefile( + "${path.module}/files/user_data.tftpl", + { + node_pre_bootstrap_command = each.value.launch_template.pre_bootstrap_command + # This will ensure the bootstrap user data is used to join the node + include_bootstrap_cmd = each.value.launch_template.ami_id != null ? true : false + cluster_name = aws_eks_cluster.main.name + cluster_cert_authority = aws_eks_cluster.main.certificate_authority[0].data + cluster_endpoint = aws_eks_cluster.main.endpoint + } + ) + ) +} + resource "aws_eks_node_group" "main" { count = length(var.node_groups) @@ -31,8 +77,8 @@ resource "aws_eks_node_group" "main" { subnet_ids = var.node_groups[count.index].single_subnet ? [element(var.cluster_subnets, 0)] : var.cluster_subnets instance_types = [var.node_groups[count.index].instance_type] - ami_type = var.node_groups[count.index].gpu == true ? "AL2_x86_64_GPU" : "AL2_x86_64" - disk_size = 50 + ami_type = var.node_groups[count.index].ami_type + disk_size = var.node_groups[count.index].launch_template == null ? 50 : null scaling_config { min_size = var.node_groups[count.index].min_size @@ -40,6 +86,15 @@ resource "aws_eks_node_group" "main" { max_size = var.node_groups[count.index].max_size } + # Only set launch_template if its node_group counterpart parameter is not null + dynamic "launch_template" { + for_each = var.node_groups[count.index].launch_template != null ? [0] : [] + content { + id = aws_launch_template.main[var.node_groups[count.index].name].id + version = aws_launch_template.main[var.node_groups[count.index].name].latest_version + } + } + labels = { "dedicated" = var.node_groups[count.index].name } diff --git a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf index 87f5e7c95..4d38d10a1 100644 --- a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf +++ b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf @@ -44,13 +44,15 @@ variable "node_group_additional_policies" { variable "node_groups" { description = "Node groups to add to EKS Cluster" type = list(object({ - name = string - instance_type = string - gpu = bool - min_size = number - desired_size = number - max_size = number - single_subnet = bool + name = string + instance_type = string + gpu = bool + min_size = number + desired_size = number + max_size = number + single_subnet = bool + launch_template = map(any) + ami_type = string })) } diff --git a/src/_nebari/stages/infrastructure/template/aws/variables.tf b/src/_nebari/stages/infrastructure/template/aws/variables.tf index 372109bab..a3f37b9eb 100644 --- a/src/_nebari/stages/infrastructure/template/aws/variables.tf +++ b/src/_nebari/stages/infrastructure/template/aws/variables.tf @@ -31,13 +31,15 @@ variable "kubernetes_version" { variable "node_groups" { description = "AWS node groups" type = list(object({ - name = string - instance_type = string - gpu = bool - min_size = number - desired_size = number - max_size = number - single_subnet = bool + name = string + instance_type = string + gpu = bool + min_size = number + desired_size = number + max_size = number + single_subnet = bool + launch_template = map(any) + ami_type = string })) } diff --git a/src/_nebari/stages/terraform_state/__init__.py b/src/_nebari/stages/terraform_state/__init__.py index 4d93775fa..dd481dad8 100644 --- a/src/_nebari/stages/terraform_state/__init__.py +++ b/src/_nebari/stages/terraform_state/__init__.py @@ -260,7 +260,6 @@ def check_immutable_fields(self): nebari_config_diff = utils.JsonDiff( nebari_config_state.model_dump(), self.config.model_dump() ) - # check if any changed fields are immutable for keys, old, new in nebari_config_diff.modified(): bottom_level_schema = self.config diff --git a/tests/tests_unit/test_cli_validate.py b/tests/tests_unit/test_cli_validate.py index faf2efa8a..07a931acd 100644 --- a/tests/tests_unit/test_cli_validate.py +++ b/tests/tests_unit/test_cli_validate.py @@ -114,20 +114,26 @@ def test_cli_validate_from_env(): ["validate", "--config", tmp_file.resolve()], env={"NEBARI_SECRET__amazon_web_services__kubernetes_version": "1.20"}, ) - - assert 0 == valid_result.exit_code - assert not valid_result.exception - assert "Successfully validated configuration" in valid_result.stdout + try: + assert 0 == valid_result.exit_code + assert not valid_result.exception + assert "Successfully validated configuration" in valid_result.stdout + except AssertionError: + print(valid_result.stdout) + raise invalid_result = runner.invoke( app, ["validate", "--config", tmp_file.resolve()], env={"NEBARI_SECRET__amazon_web_services__kubernetes_version": "1.0"}, ) - - assert 1 == invalid_result.exit_code - assert invalid_result.exception - assert "Invalid `kubernetes-version`" in invalid_result.stdout + try: + assert 1 == invalid_result.exit_code + assert invalid_result.exception + assert "Invalid `kubernetes-version`" in invalid_result.stdout + except AssertionError: + print(invalid_result.stdout) + raise @pytest.mark.parametrize(