diff --git a/.changelog/20066.txt b/.changelog/20066.txt
new file mode 100644
index 000000000000..675a3ca28791
--- /dev/null
+++ b/.changelog/20066.txt
@@ -0,0 +1,3 @@
+```release-note:enhancement
+resource/aws_sagemaker_model: Add `inference_execution_config`.
+```
diff --git a/aws/resource_aws_sagemaker_model.go b/aws/resource_aws_sagemaker_model.go
index 1bf6f5a891ea..6c109c00b836 100644
--- a/aws/resource_aws_sagemaker_model.go
+++ b/aws/resource_aws_sagemaker_model.go
@@ -28,18 +28,8 @@ func resourceAwsSagemakerModel() *schema.Resource {
 				Type:     schema.TypeString,
 				Computed: true,
 			},
-
-			"name": {
-				Type:         schema.TypeString,
-				Optional:     true,
-				Computed:     true,
-				ForceNew:     true,
-				ValidateFunc: validateSagemakerName,
-			},
-
-			"primary_container": {
+			"container": {
 				Type:     schema.TypeList,
-				MaxItems: 1,
 				Optional: true,
 				Elem: &schema.Resource{
 					Schema: map[string]*schema.Schema{
@@ -49,29 +39,6 @@ func resourceAwsSagemakerModel() *schema.Resource {
 							ForceNew:     true,
 							ValidateFunc: validateSagemakerName,
 						},
-
-						"image": {
-							Type:         schema.TypeString,
-							Required:     true,
-							ForceNew:     true,
-							ValidateFunc: validateSagemakerImage,
-						},
-
-						"mode": {
-							Type:         schema.TypeString,
-							Optional:     true,
-							ForceNew:     true,
-							Default:      sagemaker.ContainerModeSingleModel,
-							ValidateFunc: validation.StringInSlice(sagemaker.ContainerMode_Values(), false),
-						},
-
-						"model_data_url": {
-							Type:         schema.TypeString,
-							Optional:     true,
-							ForceNew:     true,
-							ValidateFunc: validateSagemakerModelDataUrl,
-						},
-
 						"environment": {
 							Type:         schema.TypeMap,
 							Optional:     true,
@@ -79,6 +46,12 @@ func resourceAwsSagemakerModel() *schema.Resource {
 							ValidateFunc: validateSagemakerEnvironment,
 							Elem:         &schema.Schema{Type: schema.TypeString},
 						},
+						"image": {
+							Type:         schema.TypeString,
+							Required:     true,
+							ForceNew:     true,
+							ValidateFunc: validateSagemakerImage,
+						},
 						"image_config": {
 							Type:     schema.TypeList,
 							Optional: true,
@@ -94,46 +67,59 @@
 								},
 							},
 						},
+						"mode": {
+							Type:         schema.TypeString,
+							Optional:     true,
+							ForceNew:     true,
+							Default:      sagemaker.ContainerModeSingleModel,
+							ValidateFunc: validation.StringInSlice(sagemaker.ContainerMode_Values(), false),
+						},
+						"model_data_url": {
+							Type:         schema.TypeString,
+							Optional:     true,
+							ForceNew:     true,
+							ValidateFunc: validateSagemakerModelDataUrl,
+						},
 					},
 				},
 			},
-
-			"vpc_config": {
-				Type:     schema.TypeList,
+			"enable_network_isolation": {
+				Type:     schema.TypeBool,
 				Optional: true,
-				MaxItems: 1,
 				ForceNew: true,
-				Elem: &schema.Resource{
-					Schema: map[string]*schema.Schema{
-						"subnets": {
-							Type:     schema.TypeSet,
-							Required: true,
-							Elem:     &schema.Schema{Type: schema.TypeString},
-						},
-						"security_group_ids": {
-							Type:     schema.TypeSet,
-							Required: true,
-							Elem:     &schema.Schema{Type: schema.TypeString},
-						},
-					},
-				},
 			},
-
 			"execution_role_arn": {
 				Type:         schema.TypeString,
 				Required:     true,
 				ForceNew:     true,
 				ValidateFunc: validateArn,
 			},
-
-			"enable_network_isolation": {
-				Type:     schema.TypeBool,
+			"inference_execution_config": {
+				Type:     schema.TypeList,
+				MaxItems: 1,
 				Optional: true,
+				Computed: true,
 				ForceNew: true,
+				Elem: &schema.Resource{
+					Schema: map[string]*schema.Schema{
+						"mode": {
+							Type:         schema.TypeString,
+							Required:     true,
+							ValidateFunc: validation.StringInSlice(sagemaker.InferenceExecutionMode_Values(), false),
+						},
+					},
+				},
 			},
-
-			"container": {
+			"name": {
+				Type:         schema.TypeString,
+				Optional:     true,
+				Computed:     true,
+				ForceNew:     true,
+				ValidateFunc: validateSagemakerName,
+			},
+			"primary_container": {
 				Type:     schema.TypeList,
+				MaxItems: 1,
 				Optional: true,
 				Elem: &schema.Resource{
 					Schema: map[string]*schema.Schema{
@@ -143,29 +129,6 @@ func resourceAwsSagemakerModel() *schema.Resource {
 							ForceNew:     true,
 							ValidateFunc: validateSagemakerName,
 						},
-
-						"image": {
-							Type:         schema.TypeString,
-							Required:     true,
-							ForceNew:     true,
-							ValidateFunc: validateSagemakerImage,
-						},
-
-						"mode": {
-							Type:         schema.TypeString,
-							Optional:     true,
-							ForceNew:     true,
-							Default:      sagemaker.ContainerModeSingleModel,
-							ValidateFunc: validation.StringInSlice(sagemaker.ContainerMode_Values(), false),
-						},
-
-						"model_data_url": {
-							Type:         schema.TypeString,
-							Optional:     true,
-							ForceNew:     true,
-							ValidateFunc: validateSagemakerModelDataUrl,
-						},
-
 						"environment": {
 							Type:         schema.TypeMap,
 							Optional:     true,
@@ -173,6 +136,12 @@ func resourceAwsSagemakerModel() *schema.Resource {
 							ValidateFunc: validateSagemakerEnvironment,
 							Elem:         &schema.Schema{Type: schema.TypeString},
 						},
+						"image": {
+							Type:         schema.TypeString,
+							Required:     true,
+							ForceNew:     true,
+							ValidateFunc: validateSagemakerImage,
+						},
 						"image_config": {
 							Type:     schema.TypeList,
 							Optional: true,
@@ -188,12 +157,46 @@ func resourceAwsSagemakerModel() *schema.Resource {
 								},
 							},
 						},
+						"mode": {
+							Type:         schema.TypeString,
+							Optional:     true,
+							ForceNew:     true,
+							Default:      sagemaker.ContainerModeSingleModel,
+							ValidateFunc: validation.StringInSlice(sagemaker.ContainerMode_Values(), false),
+						},
+						"model_data_url": {
+							Type:         schema.TypeString,
+							Optional:     true,
+							ForceNew:     true,
+							ValidateFunc: validateSagemakerModelDataUrl,
+						},
 					},
 				},
 			},
-
 			"tags":     tagsSchema(),
 			"tags_all": tagsSchemaComputed(),
+			"vpc_config": {
+				Type:     schema.TypeList,
+				Optional: true,
+				MaxItems: 1,
+				ForceNew: true,
+				Elem: &schema.Resource{
+					Schema: map[string]*schema.Schema{
+						"subnets": {
+							Type:     schema.TypeSet,
+							Required: true,
+							MaxItems: 16,
+							Elem:     &schema.Schema{Type: schema.TypeString},
+						},
+						"security_group_ids": {
+							Type:     schema.TypeSet,
+							Required: true,
+							MaxItems: 5,
+							Elem:     &schema.Schema{Type: schema.TypeString},
+						},
+					},
+				},
+			},
 		},
 
 		CustomizeDiff: SetTagsDiff,
@@ -240,6 +243,10 @@ func resourceAwsSagemakerModelCreate(d *schema.ResourceData, meta interface{}) e
 		createOpts.EnableNetworkIsolation = aws.Bool(v.(bool))
 	}
 
+	if v, ok := d.GetOk("inference_execution_config"); ok {
+		createOpts.InferenceExecutionConfig = expandSagemakerModelInferenceExecutionConfig(v.([]interface{}))
+	}
+
 	log.Printf("[DEBUG] Sagemaker model create config: %#v", *createOpts)
 	_, err := retryOnAwsCode("ValidationException", func() (interface{}, error) {
 		return conn.CreateModel(createOpts)
@@ -285,29 +292,29 @@ func resourceAwsSagemakerModelRead(d *schema.ResourceData, meta interface{}) err
 		return fmt.Errorf("error reading Sagemaker model %s: %w", d.Id(), err)
 	}
 
-	if err := d.Set("arn", model.ModelArn); err != nil {
-		return fmt.Errorf("unable to set arn for sagemaker model %q: %+v", d.Id(), err)
-	}
-	if err := d.Set("name", model.ModelName); err != nil {
-		return err
-	}
-	if err := d.Set("execution_role_arn", model.ExecutionRoleArn); err != nil {
-		return err
-	}
-	if err := d.Set("enable_network_isolation", model.EnableNetworkIsolation); err != nil {
-		return err
-	}
+	arn := aws.StringValue(model.ModelArn)
+	d.Set("arn", arn)
+	d.Set("name", model.ModelName)
+	d.Set("execution_role_arn", model.ExecutionRoleArn)
+	d.Set("enable_network_isolation", model.EnableNetworkIsolation)
+
 	if err := d.Set("primary_container", flattenContainer(model.PrimaryContainer)); err != nil {
-		return err
+		return fmt.Errorf("error setting primary_container: %w", err)
 	}
+
 	if err := d.Set("container", flattenContainers(model.Containers)); err != nil {
-		return err
+		return fmt.Errorf("error setting container: %w", err)
 	}
+
 	if err := d.Set("vpc_config", flattenSageMakerVpcConfigResponse(model.VpcConfig)); err != nil {
 		return fmt.Errorf("error setting vpc_config: %w", err)
 	}
 
-	tags, err := keyvaluetags.SagemakerListTags(conn, aws.StringValue(model.ModelArn))
+	if err := d.Set("inference_execution_config", flattenSagemakerModelInferenceExecutionConfig(model.InferenceExecutionConfig)); err != nil {
+		return fmt.Errorf("error setting inference_execution_config: %w", err)
+	}
+
+	tags, err := keyvaluetags.SagemakerListTags(conn, arn)
 	if err != nil {
 		return fmt.Errorf("error listing tags for Sagemaker Model (%s): %w", d.Id(), err)
 	}
@@ -480,3 +487,29 @@ func flattenContainers(containers []*sagemaker.ContainerDefinition) []interface{
 	}
 	return fContainers
 }
+
+func expandSagemakerModelInferenceExecutionConfig(l []interface{}) *sagemaker.InferenceExecutionConfig {
+	if len(l) == 0 {
+		return nil
+	}
+
+	m := l[0].(map[string]interface{})
+
+	config := &sagemaker.InferenceExecutionConfig{
+		Mode: aws.String(m["mode"].(string)),
+	}
+
+	return config
+}
+
+func flattenSagemakerModelInferenceExecutionConfig(config *sagemaker.InferenceExecutionConfig) []interface{} {
+	if config == nil {
+		return []interface{}{}
+	}
+
+	cfg := make(map[string]interface{})
+
+	cfg["mode"] = aws.StringValue(config.Mode)
+
+	return []interface{}{cfg}
+}
diff --git a/aws/resource_aws_sagemaker_model_test.go b/aws/resource_aws_sagemaker_model_test.go
index 3b8b29e29c26..1125eeabec99 100644
--- a/aws/resource_aws_sagemaker_model_test.go
+++ b/aws/resource_aws_sagemaker_model_test.go
@@ -8,6 +8,7 @@ import (
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/awserr"
 	"github.com/aws/aws-sdk-go/service/sagemaker"
+	multierror "github.com/hashicorp/go-multierror"
 	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/acctest"
 	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
 	"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
@@ -26,29 +27,34 @@ func testSweepSagemakerModels(region string) error {
 		return fmt.Errorf("error getting client: %w", err)
 	}
 	conn := client.(*AWSClient).sagemakerconn
+	var sweeperErrs *multierror.Error
+
+	err = conn.ListModelsPages(&sagemaker.ListModelsInput{}, func(page *sagemaker.ListModelsOutput, lastPage bool) bool {
+		for _, model := range page.Models {
+
+			r := resourceAwsSagemakerModel()
+			d := r.Data(nil)
+			d.SetId(aws.StringValue(model.ModelName))
+			err = r.Delete(d, client)
+			if err != nil {
+				log.Printf("[ERROR] %s", err)
+				sweeperErrs = multierror.Append(sweeperErrs, err)
+				continue
+			}
+		}
 
-	req := &sagemaker.ListModelsInput{}
-	resp, err := conn.ListModels(req)
-	if err != nil {
-		return fmt.Errorf("error listing models: %w", err)
-	}
-
-	if len(resp.Models) == 0 {
-		log.Print("[DEBUG] No sagemaker models to sweep")
-		return nil
+		return !lastPage
+	})
+	if testSweepSkipSweepError(err) {
+		log.Printf("[WARN] Skipping SageMaker Model sweep for %s: %s", region, err)
+		return sweeperErrs.ErrorOrNil()
 	}
 
-	for _, model := range resp.Models {
-		_, err := conn.DeleteModel(&sagemaker.DeleteModelInput{
-			ModelName: model.ModelName,
-		})
-		if err != nil {
-			return fmt.Errorf(
-				"error deleting sagemaker model (%s): %w", aws.StringValue(model.ModelName), err)
-		}
+	if err != nil {
+		sweeperErrs = multierror.Append(sweeperErrs, fmt.Errorf("error retrieving Sagemaker Models: %w", err))
 	}
 
-	return nil
+	return sweeperErrs.ErrorOrNil()
 }
 
 func TestAccAWSSagemakerModel_basic(t *testing.T) {
@@ -74,6 +80,34 @@ func TestAccAWSSagemakerModel_basic(t *testing.T) {
 					testAccCheckResourceAttrRegionalARN(resourceName, "arn", "sagemaker", fmt.Sprintf("model/%s", rName)),
 					resource.TestCheckResourceAttr(resourceName, "enable_network_isolation", "false"),
 					resource.TestCheckResourceAttr(resourceName, "tags.%", "0"),
+					resource.TestCheckResourceAttr(resourceName, "inference_execution_config.#", "0"),
+				),
+			},
+			{
+				ResourceName:      resourceName,
+				ImportState:       true,
+				ImportStateVerify: true,
+			},
+		},
+	})
+}
+
+func TestAccAWSSagemakerModel_inferenceExecutionConfig(t *testing.T) {
+	rName := acctest.RandomWithPrefix("tf-acc-test")
+	resourceName := "aws_sagemaker_model.test"
+
+	resource.ParallelTest(t, resource.TestCase{
+		PreCheck:     func() { testAccPreCheck(t) },
+		ErrorCheck:   testAccErrorCheck(t, sagemaker.EndpointsID),
+		Providers:    testAccProviders,
+		CheckDestroy: testAccCheckSagemakerModelDestroy,
+		Steps: []resource.TestStep{
+			{
+				Config: testAccSagemakerModelInferenceExecutionConfig(rName),
+				Check: resource.ComposeTestCheckFunc(
+					testAccCheckSagemakerModelExists(resourceName),
+					resource.TestCheckResourceAttr(resourceName, "inference_execution_config.#", "1"),
+					resource.TestCheckResourceAttr(resourceName, "inference_execution_config.0.mode", "Serial"),
 				),
 			},
 			{
@@ -458,6 +492,27 @@ resource "aws_sagemaker_model" "test" {
 `, rName)
 }
 
+func testAccSagemakerModelInferenceExecutionConfig(rName string) string {
+	return testAccSagemakerModelConfigBase(rName) + fmt.Sprintf(`
+resource "aws_sagemaker_model" "test" {
+  name               = %[1]q
+  execution_role_arn = aws_iam_role.test.arn
+
+  inference_execution_config {
+    mode = "Serial"
+  }
+
+  container {
+    image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
+  }
+
+  container {
+    image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
+  }
+}
+`, rName)
+}
+
 func testAccSagemakerModelConfigTags1(rName, tagKey1, tagValue1 string) string {
 	return testAccSagemakerModelConfigBase(rName) + fmt.Sprintf(`
 resource "aws_sagemaker_model" "test" {
diff --git a/website/docs/r/sagemaker_model.html.markdown b/website/docs/r/sagemaker_model.html.markdown
index 03e99639e1ce..acd8a4a6e586 100644
--- a/website/docs/r/sagemaker_model.html.markdown
+++ b/website/docs/r/sagemaker_model.html.markdown
@@ -20,7 +20,7 @@ resource "aws_sagemaker_model" "example" {
   execution_role_arn = aws_iam_role.example.arn
 
   primary_container {
-    image = "174872318107.dkr.ecr.us-west-2.amazonaws.com/kmeans:1"
+    image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
   }
 }
 
@@ -38,6 +38,10 @@ data "aws_iam_policy_document" "assume_role" {
     }
   }
 }
+
+data "aws_sagemaker_prebuilt_ecr_image" "test" {
+  repository_name = "kmeans"
+}
 ```
 
 ## Argument Reference
@@ -47,6 +51,7 @@ The following arguments are supported:
 * `name` - (Optional) The name of the model (must be unique). If omitted, Terraform will assign a random, unique name.
 * `primary_container` - (Optional) The primary docker image containing inference code that is used when the model is deployed for predictions. If not specified, the `container` argument is required. Fields are documented below.
 * `execution_role_arn` - (Required) A role that SageMaker can assume to access model artifacts and docker images for deployment.
+* `inference_execution_config` - (Optional) Specifies details of how containers in a multi-container endpoint are called. See [Inference Execution Config](#inference-execution-config).
 * `container` (Optional) - Specifies containers in the inference pipeline. If not specified, the `primary_container` argument is required. Fields are documented below.
 * `enable_network_isolation` (Optional) - Isolates the model container. No inbound or outbound network calls can be made to or from the model container.
 * `vpc_config` (Optional) - Specifies the VPC that you want your model to connect to. VpcConfig is used in hosting services and in batch transform.
@@ -66,6 +71,10 @@ The `primary_container` and `container` block both support:
 
 * `repository_access_mode` - (Required) Specifies whether the model container is in Amazon ECR or a private Docker registry accessible from your Amazon Virtual Private Cloud (VPC). Allowed values are: `Platform` and `Vpc`.
 
+## Inference Execution Config
+
+* `mode` - (Required) How containers in a multi-container endpoint are run. Valid values are `Serial` and `Direct`.
+
 ## Attributes Reference
 
 In addition to all arguments above, the following attributes are exported:
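
---

For reviewers, a minimal HCL sketch of how the new argument reads in user configuration, mirroring the acceptance test config above. The `aws_iam_role.example` role and the `example` resource/data source names are illustrative and not part of this diff; `mode = "Direct"` is shown only to exercise the second documented value.

```hcl
# Sketch only: assumes an aws_iam_role.example with a SageMaker assume-role
# policy already exists, as in the documentation example above.
data "aws_sagemaker_prebuilt_ecr_image" "example" {
  repository_name = "kmeans"
}

resource "aws_sagemaker_model" "example" {
  name               = "example-multi-container-model"
  execution_role_arn = aws_iam_role.example.arn

  # "Serial" runs the containers as an inference pipeline; "Direct" lets a
  # specific container on the multi-container endpoint be invoked directly.
  inference_execution_config {
    mode = "Direct"
  }

  container {
    image = data.aws_sagemaker_prebuilt_ecr_image.example.registry_path
  }

  container {
    image = data.aws_sagemaker_prebuilt_ecr_image.example.registry_path
  }
}
```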