Skip to content

Commit

Permalink
Add Support for AWS Launch Template Configuration (#2668)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcmcand committed Sep 19, 2024
2 parents 379736e + 5a6bda3 commit 5f5e53c
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 25 deletions.
36 changes: 36 additions & 0 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,17 @@ class AzureInputVars(schema.Base):
workload_identity_enabled: bool = False


class AWSAmiTypes(enum.Enum):
AL2_x86_64 = "AL2_x86_64"
AL2_x86_64_GPU = "AL2_x86_64_GPU"
CUSTOM = "CUSTOM"


class AWSNodeLaunchTemplate(schema.Base):
pre_bootstrap_command: Optional[str] = None
ami_id: Optional[str] = None


class AWSNodeGroupInputVars(schema.Base):
name: str
instance_type: str
Expand All @@ -137,6 +148,28 @@ class AWSNodeGroupInputVars(schema.Base):
max_size: int
single_subnet: bool
permissions_boundary: Optional[str] = None
ami_type: Optional[AWSAmiTypes] = None
launch_template: Optional[AWSNodeLaunchTemplate] = None

@field_validator("ami_type", mode="before")
@classmethod
def _infer_and_validate_ami_type(cls, value, values) -> str:
gpu_enabled = values.get("gpu", False)

# Auto-set ami_type if not provided
if not value:
if values.get("launch_template") and values["launch_template"].ami_id:
return "CUSTOM"
if gpu_enabled:
return "AL2_x86_64_GPU"
return "AL2_x86_64"

# Explicit validation
if value == "AL2_x86_64" and gpu_enabled:
raise ValueError(
"ami_type 'AL2_x86_64' cannot be used with GPU enabled (gpu=True)."
)
return value


class AWSInputVars(schema.Base):
Expand Down Expand Up @@ -449,6 +482,7 @@ class AWSNodeGroup(schema.Base):
gpu: bool = False
single_subnet: bool = False
permissions_boundary: Optional[str] = None
launch_template: Optional[AWSNodeLaunchTemplate] = None


DEFAULT_AWS_NODE_GROUPS = {
Expand Down Expand Up @@ -525,6 +559,7 @@ def _check_input(cls, data: Any) -> Any:
raise ValueError(
f"Amazon Web Services instance {node_group.instance} not one of available instance types={available_instances}"
)

return data


Expand Down Expand Up @@ -828,6 +863,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
max_size=node_group.max_nodes,
single_subnet=node_group.single_subnet,
permissions_boundary=node_group.permissions_boundary,
launch_template=node_group.launch_template,
)
for name, node_group in self.config.amazon_web_services.node_groups.items()
],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
MIME-Version: 1.0
Content-Type: multipart/mixed; boundary="//"

%{ if node_pre_bootstrap_command != null }
--//
Content-Type: text/x-shellscript; charset="us-ascii"

${node_pre_bootstrap_command}
%{ endif }

%{ if include_bootstrap_cmd }
--//
Content-Type: text/x-shellscript; charset="us-ascii"
#!/bin/bash
set -ex

/etc/eks/bootstrap.sh ${cluster_name} --b64-cluster-ca ${cluster_cert_authority} --apiserver-endpoint ${cluster_endpoint}
%{ endif }

--//
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,52 @@ resource "aws_eks_cluster" "main" {
tags = merge({ Name = var.name }, var.tags)
}

## aws_launch_template user_data invocation
## If using a Custom AMI, then the /etc/eks/bootstrap cmds and args must be included/modified,
## otherwise, on default AWS EKS Node AMI, the bootstrap cmd is appended automatically
resource "aws_launch_template" "main" {
for_each = {
for node_group in var.node_groups :
node_group.name => node_group
if node_group.launch_template != null
}

name_prefix = "eks-${var.name}-${each.value.name}-"
image_id = each.value.launch_template.ami_id

vpc_security_group_ids = var.cluster_security_groups


metadata_options {
http_tokens = "required"
http_endpoint = "enabled"
instance_metadata_tags = "enabled"
}

block_device_mappings {
device_name = "/dev/xvda"
ebs {
volume_size = 50
volume_type = "gp2"
}
}

# https://docs.aws.amazon.com/eks/latest/userguide/launch-templates.html#launch-template-basics
user_data = base64encode(
templatefile(
"${path.module}/files/user_data.tftpl",
{
node_pre_bootstrap_command = each.value.launch_template.pre_bootstrap_command
# This will ensure the bootstrap user data is used to join the node
include_bootstrap_cmd = each.value.launch_template.ami_id != null ? true : false
cluster_name = aws_eks_cluster.main.name
cluster_cert_authority = aws_eks_cluster.main.certificate_authority[0].data
cluster_endpoint = aws_eks_cluster.main.endpoint
}
)
)
}


resource "aws_eks_node_group" "main" {
count = length(var.node_groups)
Expand All @@ -31,15 +77,24 @@ resource "aws_eks_node_group" "main" {
subnet_ids = var.node_groups[count.index].single_subnet ? [element(var.cluster_subnets, 0)] : var.cluster_subnets

instance_types = [var.node_groups[count.index].instance_type]
ami_type = var.node_groups[count.index].gpu == true ? "AL2_x86_64_GPU" : "AL2_x86_64"
disk_size = 50
ami_type = var.node_groups[count.index].ami_type
disk_size = var.node_groups[count.index].launch_template == null ? 50 : null

scaling_config {
min_size = var.node_groups[count.index].min_size
desired_size = var.node_groups[count.index].desired_size
max_size = var.node_groups[count.index].max_size
}

# Only set launch_template if its node_group counterpart parameter is not null
dynamic "launch_template" {
for_each = var.node_groups[count.index].launch_template != null ? [0] : []
content {
id = aws_launch_template.main[var.node_groups[count.index].name].id
version = aws_launch_template.main[var.node_groups[count.index].name].latest_version
}
}

labels = {
"dedicated" = var.node_groups[count.index].name
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,15 @@ variable "node_group_additional_policies" {
variable "node_groups" {
description = "Node groups to add to EKS Cluster"
type = list(object({
name = string
instance_type = string
gpu = bool
min_size = number
desired_size = number
max_size = number
single_subnet = bool
name = string
instance_type = string
gpu = bool
min_size = number
desired_size = number
max_size = number
single_subnet = bool
launch_template = map(any)
ami_type = string
}))
}

Expand Down
16 changes: 9 additions & 7 deletions src/_nebari/stages/infrastructure/template/aws/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ variable "kubernetes_version" {
variable "node_groups" {
description = "AWS node groups"
type = list(object({
name = string
instance_type = string
gpu = bool
min_size = number
desired_size = number
max_size = number
single_subnet = bool
name = string
instance_type = string
gpu = bool
min_size = number
desired_size = number
max_size = number
single_subnet = bool
launch_template = map(any)
ami_type = string
}))
}

Expand Down
1 change: 0 additions & 1 deletion src/_nebari/stages/terraform_state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def check_immutable_fields(self):
nebari_config_diff = utils.JsonDiff(
nebari_config_state.model_dump(), self.config.model_dump()
)

# check if any changed fields are immutable
for keys, old, new in nebari_config_diff.modified():
bottom_level_schema = self.config
Expand Down
22 changes: 14 additions & 8 deletions tests/tests_unit/test_cli_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,26 @@ def test_cli_validate_from_env():
["validate", "--config", tmp_file.resolve()],
env={"NEBARI_SECRET__amazon_web_services__kubernetes_version": "1.20"},
)

assert 0 == valid_result.exit_code
assert not valid_result.exception
assert "Successfully validated configuration" in valid_result.stdout
try:
assert 0 == valid_result.exit_code
assert not valid_result.exception
assert "Successfully validated configuration" in valid_result.stdout
except AssertionError:
print(valid_result.stdout)
raise

invalid_result = runner.invoke(
app,
["validate", "--config", tmp_file.resolve()],
env={"NEBARI_SECRET__amazon_web_services__kubernetes_version": "1.0"},
)

assert 1 == invalid_result.exit_code
assert invalid_result.exception
assert "Invalid `kubernetes-version`" in invalid_result.stdout
try:
assert 1 == invalid_result.exit_code
assert invalid_result.exception
assert "Invalid `kubernetes-version`" in invalid_result.stdout
except AssertionError:
print(invalid_result.stdout)
raise


@pytest.mark.parametrize(
Expand Down

0 comments on commit 5f5e53c

Please sign in to comment.