Skip to content

Commit

Permalink
Merge pull request #22960 from DrFaust92/sagemaker-endpoint-config
Browse files Browse the repository at this point in the history
r/sagemaker_endpoint_configuration - emptiness check for arguments
  • Loading branch information
ewbankkit authored Feb 11, 2022
2 parents 419f146 + 1205098 commit c75c99a
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 19 deletions.
3 changes: 3 additions & 0 deletions .changelog/22960.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
resource/aws_sagemaker_endpoint_configuration: Emptiness check for arguments, Allow not passing `async_inference_config.kms_key_id`.
```
36 changes: 18 additions & 18 deletions internal/service/sagemaker/endpoint_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,8 @@ func expandSagemakerProductionVariants(configured []interface{}) []*sagemaker.Pr
l.InitialVariantWeight = aws.Float64(v.(float64))
}

if v, ok := data["accelerator_type"]; ok && v.(string) != "" {
l.AcceleratorType = aws.String(data["accelerator_type"].(string))
if v, ok := data["accelerator_type"].(string); ok && v != "" {
l.AcceleratorType = aws.String(v)
}

containers = append(containers, l)
Expand Down Expand Up @@ -472,12 +472,12 @@ func expandSagemakerDataCaptureConfig(configured []interface{}) *sagemaker.DataC
c.EnableCapture = aws.Bool(v.(bool))
}

if v, ok := m["kms_key_id"]; ok && v.(string) != "" {
c.KmsKeyId = aws.String(v.(string))
if v, ok := m["kms_key_id"].(string); ok && v != "" {
c.KmsKeyId = aws.String(v)
}

if v, ok := m["capture_content_type_header"]; ok && (len(v.([]interface{})) > 0) {
c.CaptureContentTypeHeader = expandSagemakerCaptureContentTypeHeader(v.([]interface{})[0].(map[string]interface{}))
if v, ok := m["capture_content_type_header"].([]interface{}); ok && (len(v) > 0) {
c.CaptureContentTypeHeader = expandSagemakerCaptureContentTypeHeader(v[0].(map[string]interface{}))
}

return c
Expand Down Expand Up @@ -577,12 +577,12 @@ func expandSagemakerEndpointConfigAsyncInferenceConfig(configured []interface{})

c := &sagemaker.AsyncInferenceConfig{}

if v, ok := m["client_config"]; ok && (len(v.([]interface{})) > 0) {
c.ClientConfig = expandSagemakerEndpointConfigClientConfig(v.([]interface{}))
if v, ok := m["client_config"].([]interface{}); ok && len(v) > 0 {
c.ClientConfig = expandSagemakerEndpointConfigClientConfig(v)
}

if v, ok := m["output_config"]; ok && (len(v.([]interface{})) > 0) {
c.OutputConfig = expandSagemakerEndpointConfigOutputConfig(v.([]interface{}))
if v, ok := m["output_config"].([]interface{}); ok && len(v) > 0 {
c.OutputConfig = expandSagemakerEndpointConfigOutputConfig(v)
}

return c
Expand Down Expand Up @@ -615,12 +615,12 @@ func expandSagemakerEndpointConfigOutputConfig(configured []interface{}) *sagema
S3OutputPath: aws.String(m["s3_output_path"].(string)),
}

if v, ok := m["kms_key_id"]; ok {
c.KmsKeyId = aws.String(v.(string))
if v, ok := m["kms_key_id"].(string); ok && v != "" {
c.KmsKeyId = aws.String(v)
}

if v, ok := m["notification_config"]; ok && (len(v.([]interface{})) > 0) {
c.NotificationConfig = expandSagemakerEndpointConfigNotificationConfig(v.([]interface{}))
if v, ok := m["notification_config"].([]interface{}); ok && len(v) > 0 {
c.NotificationConfig = expandSagemakerEndpointConfigNotificationConfig(v)
}

return c
Expand All @@ -635,12 +635,12 @@ func expandSagemakerEndpointConfigNotificationConfig(configured []interface{}) *

c := &sagemaker.AsyncInferenceNotificationConfig{}

if v, ok := m["error_topic"]; ok {
c.ErrorTopic = aws.String(v.(string))
if v, ok := m["error_topic"].(string); ok && v != "" {
c.ErrorTopic = aws.String(v)
}

if v, ok := m["success_topic"]; ok {
c.SuccessTopic = aws.String(v.(string))
if v, ok := m["success_topic"].(string); ok && v != "" {
c.SuccessTopic = aws.String(v)
}

return c
Expand Down
62 changes: 61 additions & 1 deletion internal/service/sagemaker/endpoint_configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,38 @@ func TestAccSageMakerEndpointConfiguration_async(t *testing.T) {
Steps: []resource.TestStep{
{
Config: testAccSagemakerEndpointConfigurationConfigAsyncConfig(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckSagemakerEndpointConfigurationExists(resourceName),
resource.TestCheckResourceAttr(resourceName, "name", rName),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.#", "1"),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.client_config.#", "0"),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.output_config.#", "1"),
resource.TestCheckResourceAttrSet(resourceName, "async_inference_config.0.output_config.0.s3_output_path"),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.output_config.0.notification_config.#", "0"),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.output_config.0.kms_key_id", ""),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func TestAccSageMakerEndpointConfiguration_async_kms(t *testing.T) {
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_endpoint_configuration.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(t) },
ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID),
Providers: acctest.Providers,
CheckDestroy: testAccCheckSagemakerEndpointConfigurationDestroy,
Steps: []resource.TestStep{
{
Config: testAccSagemakerEndpointConfigurationConfigAsyncKMSConfig(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckSagemakerEndpointConfigurationExists(resourceName),
resource.TestCheckResourceAttr(resourceName, "name", rName),
Expand Down Expand Up @@ -567,7 +599,7 @@ resource "aws_sagemaker_endpoint_configuration" "test" {
`, rName)
}

func testAccSagemakerEndpointConfigurationConfigAsyncConfig(rName string) string {
func testAccSagemakerEndpointConfigurationConfigAsyncKMSConfig(rName string) string {
return testAccSagemakerEndpointConfigurationConfig_Base(rName) + fmt.Sprintf(`
resource "aws_s3_bucket" "test" {
bucket = %[1]q
Expand Down Expand Up @@ -605,6 +637,34 @@ resource "aws_sagemaker_endpoint_configuration" "test" {
`, rName)
}

func testAccSagemakerEndpointConfigurationConfigAsyncConfig(rName string) string {
return testAccSagemakerEndpointConfigurationConfig_Base(rName) + fmt.Sprintf(`
resource "aws_s3_bucket" "test" {
bucket = %[1]q
acl = "private"
force_destroy = true
}
resource "aws_sagemaker_endpoint_configuration" "test" {
name = %[1]q
production_variants {
variant_name = "variant-1"
model_name = aws_sagemaker_model.test.name
initial_instance_count = 2
instance_type = "ml.t2.medium"
initial_variant_weight = 1
}
async_inference_config {
output_config {
s3_output_path = "s3://${aws_s3_bucket.test.bucket}/"
}
}
}
`, rName)
}

func testAccSagemakerEndpointConfigurationConfigAsyncNotifConfig(rName string) string {
return testAccSagemakerEndpointConfigurationConfig_Base(rName) + fmt.Sprintf(`
resource "aws_s3_bucket" "test" {
Expand Down

0 comments on commit c75c99a

Please sign in to comment.