Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ModelPackageName in Sagemaker Container Definition #31532

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions internal/service/sagemaker/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ func ResourceModel() *schema.Resource {
ForceNew: true,
ValidateFunc: validModelDataURL,
},
"model_package_name": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validation.StringLenBetween(1, 1024),
},
},
},
},
Expand Down Expand Up @@ -211,6 +217,12 @@ func ResourceModel() *schema.Resource {
ForceNew: true,
ValidateFunc: validModelDataURL,
},
"model_package_name": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validation.StringLenBetween(1, 1024),
},
},
},
},
Expand Down Expand Up @@ -418,6 +430,9 @@ func expandContainer(m map[string]interface{}) *sagemaker.ContainerDefinition {
if v, ok := m["model_data_url"]; ok && v.(string) != "" {
container.ModelDataUrl = aws.String(v.(string))
}
if v, ok := m["model_package_name"]; ok && v.(string) != "" {
container.ModelPackageName = aws.String(v.(string))
}
if v, ok := m["environment"].(map[string]interface{}); ok && len(v) > 0 {
container.Environment = flex.ExpandStringMap(v)
}
Expand Down Expand Up @@ -490,6 +505,9 @@ func flattenContainer(container *sagemaker.ContainerDefinition) []interface{} {
if container.ModelDataUrl != nil {
cfg["model_data_url"] = aws.StringValue(container.ModelDataUrl)
}
if container.ModelPackageName != nil {
cfg["model_package_name"] = aws.StringValue(container.ModelDataUrl)
}
if container.Environment != nil {
cfg["environment"] = aws.StringValueMap(container.Environment)
}
Expand Down
87 changes: 87 additions & 0 deletions internal/service/sagemaker/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,33 @@ func TestAccSageMakerModel_primaryContainerModeSingle(t *testing.T) {
})
}

func TestAccSageMakerModel_primaryContainerModelPackageName(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_model.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckModelDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccModelConfig_primaryContainerPackageName(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckModelExists(ctx, resourceName),
resource.TestCheckResourceAttrSet(resourceName, "primary_container.0.model_package_name"),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func TestAccSageMakerModel_containers(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand Down Expand Up @@ -618,6 +645,66 @@ resource "aws_s3_object" "test" {
`, rName))
}

func testAccModelConfig_primaryContainerPackageName(rName string) string {
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn

primary_container {
image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
model_package_name = %[1]q
}
}

resource "aws_iam_policy" "test" {
name = %[1]q
description = "Allow SageMaker to create model"
policy = data.aws_iam_policy_document.policy.json
}

data "aws_iam_policy_document" "policy" {
statement {
effect = "Allow"

actions = [
"cloudwatch:PutMetricData",
"logs:CreateLogStream",
"logs:PutLogEvents",
"logs:CreateLogGroup",
"logs:DescribeLogStreams",
"ecr:GetAuthorizationToken",
"ecr:BatchCheckLayerAvailability",
"ecr:GetDownloadUrlForLayer",
"ecr:BatchGetImage",
]

resources = [
"*",
]
}

statement {
effect = "Allow"

actions = [
"sagemaker:*",
]

resources = [
"*",
]
}
}

resource "aws_iam_role_policy_attachment" "test" {
role = aws_iam_role.test.name
policy_arn = aws_iam_policy.test.arn
}

`, rName))
}

func testAccModelConfig_primaryContainerHostname(rName string) string {
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
Expand Down
1 change: 1 addition & 0 deletions website/docs/r/sagemaker_model.html.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ The `primary_container` and `container` block both support:
* `image` - (Required) The registry path where the inference code image is stored in Amazon ECR.
* `mode` - (Optional) The container hosts value `SingleModel/MultiModel`. The default value is `SingleModel`.
* `model_data_url` - (Optional) The URL for the S3 location where model artifacts are stored.
* `model_package_name` - (Optional) The name or Amazon Resource Name (ARN) of the model package to use to create the model.
* `container_hostname` - (Optional) The DNS host name for the container.
* `environment` - (Optional) Environment variables for the Docker container.
A list of key value pairs.
Expand Down