Skip to content

Commit

Permalink
Add model_access_config, multi_model_config and inference_specificati…
Browse files Browse the repository at this point in the history
…on_name to container block
  • Loading branch information
deepakbshetty committed Feb 20, 2024
1 parent cb64b2a commit 2185243
Show file tree
Hide file tree
Showing 4 changed files with 535 additions and 4 deletions.
154 changes: 153 additions & 1 deletion internal/service/sagemaker/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,48 @@ func ResourceModel() *schema.Resource {
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ModelCompressionType_Values(), false),
},
"model_access_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"accept_eula": {
Type: schema.TypeBool,
Required: true,
ForceNew: true,
},
},
},
},
},
},
},
},
},
},
"inference_specification_name": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validName,
},
"multi_model_config": {
Type: schema.TypeList,
Optional: true,
ForceNew: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"model_cache_setting": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ModelCacheSetting_Values(), false),
},
},
},
},
},
},
},
Expand Down Expand Up @@ -292,12 +328,49 @@ func ResourceModel() *schema.Resource {
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ModelCompressionType_Values(), false),
},
"model_access_config": {
Type: schema.TypeList,
Optional: true,
ForceNew: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"accept_eula": {
Type: schema.TypeBool,
Required: true,
ForceNew: true,
},
},
},
},
},
},
},
},
},
},
"inference_specification_name": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validName,
},
"multi_model_config": {
Type: schema.TypeList,
Optional: true,
ForceNew: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"model_cache_setting": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ModelCacheSetting_Values(), false),
},
},
},
},
},
},
},
Expand Down Expand Up @@ -521,6 +594,14 @@ func expandContainer(m map[string]interface{}) *sagemaker.ContainerDefinition {
container.ImageConfig = expandModelImageConfig(v.([]interface{}))
}

if v, ok := m["inference_specification_name"]; ok && v.(string) != "" {
container.InferenceSpecificationName = aws.String(v.(string))
}

if v, ok := m["multi_model_config"].([]interface{}); ok && len(v) > 0 {
container.MultiModelConfig = expandMultiModelConfig(v)
}

return &container
}

Expand Down Expand Up @@ -559,6 +640,10 @@ func expandS3ModelDataSource(l []interface{}) *sagemaker.S3ModelDataSource {
s3ModelDataSource.CompressionType = aws.String(v.(string))
}

if v, ok := m["model_access_config"].([]interface{}); ok && len(v) > 0 {
s3ModelDataSource.ModelAccessConfig = expandModelAccessConfig(v)
}

return &s3ModelDataSource
}

Expand Down Expand Up @@ -604,6 +689,38 @@ func expandContainers(a []interface{}) []*sagemaker.ContainerDefinition {
return containers
}

func expandModelAccessConfig(l []interface{}) *sagemaker.ModelAccessConfig {
if len(l) == 0 {
return nil
}

m := l[0].(map[string]interface{})

modelAccessConfig := &sagemaker.ModelAccessConfig{}

if v, ok := m["accept_eula"].(bool); ok {
modelAccessConfig.AcceptEula = aws.Bool(v)
}

return modelAccessConfig
}

func expandMultiModelConfig(l []interface{}) *sagemaker.MultiModelConfig {
if len(l) == 0 {
return nil
}

m := l[0].(map[string]interface{})

multiModelConfig := &sagemaker.MultiModelConfig{}

if v, ok := m["model_cache_setting"].(string); ok && v != "" {
multiModelConfig.ModelCacheSetting = aws.String(v)
}

return multiModelConfig
}

func flattenContainer(container *sagemaker.ContainerDefinition) []interface{} {
if container == nil {
return []interface{}{}
Expand Down Expand Up @@ -634,11 +751,18 @@ func flattenContainer(container *sagemaker.ContainerDefinition) []interface{} {
if container.Environment != nil {
cfg["environment"] = aws.StringValueMap(container.Environment)
}

if container.ImageConfig != nil {
cfg["image_config"] = flattenImageConfig(container.ImageConfig)
}

if container.InferenceSpecificationName != nil {
cfg["inference_specification_name"] = aws.StringValue(container.InferenceSpecificationName)
}

if container.MultiModelConfig != nil {
cfg["multi_model_config"] = flattenMultiModelConfig(container.MultiModelConfig)
}

return []interface{}{cfg}
}

Expand Down Expand Up @@ -673,6 +797,10 @@ func flattenS3ModelDataSource(s3ModelDataSource *sagemaker.S3ModelDataSource) []
cfg["compression_type"] = aws.StringValue(s3ModelDataSource.CompressionType)
}

if s3ModelDataSource.ModelAccessConfig != nil {
cfg["model_access_config"] = flattenModelAccessConfig(s3ModelDataSource.ModelAccessConfig)
}

return []interface{}{cfg}
}

Expand Down Expand Up @@ -714,6 +842,30 @@ func flattenContainers(containers []*sagemaker.ContainerDefinition) []interface{
return fContainers
}

func flattenModelAccessConfig(config *sagemaker.ModelAccessConfig) []interface{} {
if config == nil {
return []interface{}{}
}

cfg := make(map[string]interface{})

cfg["accept_eula"] = aws.BoolValue(config.AcceptEula)

return []interface{}{cfg}
}

func flattenMultiModelConfig(config *sagemaker.MultiModelConfig) []interface{} {
if config == nil {
return []interface{}{}
}

cfg := make(map[string]interface{})

cfg["model_cache_setting"] = aws.StringValue(config.ModelCacheSetting)

return []interface{}{cfg}
}

func expandModelInferenceExecutionConfig(l []interface{}) *sagemaker.InferenceExecutionConfig {
if len(l) == 0 {
return nil
Expand Down
Loading

0 comments on commit 2185243

Please sign in to comment.