Skip to content

Commit

Permalink
Merge pull request #39796 from DrFaust92/tracking-server
Browse files Browse the repository at this point in the history
r/sagemaker_mlflow_tracking_server - new resource
  • Loading branch information
ewbankkit authored Oct 18, 2024
2 parents 9dfc700 + 821915f commit 008c0cc
Show file tree
Hide file tree
Showing 9 changed files with 728 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .changelog/39796.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:new-resource
aws_sagemaker_mlflow_tracking_server
```
2 changes: 2 additions & 0 deletions internal/service/sagemaker/exports_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ var (
ResourceHumanTaskUI = resourceHumanTaskUI
ResourceImage = resourceImage
ResourceImageVersion = resourceImageVersion
ResourceMlflowTrackingServer = resourceMlflowTrackingServer
ResourceModel = resourceModel
ResourceModelPackageGroup = resourceModelPackageGroup
ResourceModelPackageGroupPolicy = resourceModelPackageGroupPolicy
Expand Down Expand Up @@ -47,6 +48,7 @@ var (
FindHumanTaskUIByName = findHumanTaskUIByName
FindImageByName = findImageByName
FindImageVersionByName = findImageVersionByName
FindMlflowTrackingServerByName = findMlflowTrackingServerByName
FindModelByName = findModelByName
FindModelPackageGroupByName = findModelPackageGroupByName
FindModelPackageGroupPolicyByName = findModelPackageGroupPolicyByName
Expand Down
250 changes: 250 additions & 0 deletions internal/service/sagemaker/mlflow_tracking_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package sagemaker

import (
"context"
"log"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sagemaker"
awstypes "github.com/aws/aws-sdk-go-v2/service/sagemaker/types"
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/retry"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-provider-aws/internal/conns"
"github.com/hashicorp/terraform-provider-aws/internal/enum"
"github.com/hashicorp/terraform-provider-aws/internal/errs"
"github.com/hashicorp/terraform-provider-aws/internal/errs/sdkdiag"
tftags "github.com/hashicorp/terraform-provider-aws/internal/tags"
"github.com/hashicorp/terraform-provider-aws/internal/tfresource"
"github.com/hashicorp/terraform-provider-aws/internal/verify"
"github.com/hashicorp/terraform-provider-aws/names"
)

// @SDKResource("aws_sagemaker_mlflow_tracking_server", name="Mlflow Tracking Server")
// @Tags(identifierAttribute="arn")
func resourceMlflowTrackingServer() *schema.Resource {
return &schema.Resource{
CreateWithoutTimeout: resourceMlflowTrackingServerCreate,
ReadWithoutTimeout: resourceMlflowTrackingServerRead,
UpdateWithoutTimeout: resourceMlflowTrackingServerUpdate,
DeleteWithoutTimeout: resourceMlflowTrackingServerDelete,
Importer: &schema.ResourceImporter{
StateContext: schema.ImportStatePassthroughContext,
},

Schema: map[string]*schema.Schema{
names.AttrARN: {
Type: schema.TypeString,
Computed: true,
},
"artifact_store_uri": {
Type: schema.TypeString,
Required: true,
ValidateFunc: validModelDataURL,
},
names.AttrRoleARN: {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: verify.ValidARN,
},
"tracking_server_name": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
},
"mlflow_version": {
Type: schema.TypeString,
Optional: true,
Computed: true,
ForceNew: true,
},
"tracking_server_url": {
Type: schema.TypeString,
Computed: true,
},
"automatic_model_registration": {
Type: schema.TypeBool,
Optional: true,
Default: false,
},
"tracking_server_size": {
Type: schema.TypeString,
Optional: true,
Default: awstypes.TrackingServerSizeS,
ValidateDiagFunc: enum.Validate[awstypes.TrackingServerSize](),
},
"weekly_maintenance_window_start": {
Type: schema.TypeString,
Optional: true,
Computed: true,
},
names.AttrTags: tftags.TagsSchema(),
names.AttrTagsAll: tftags.TagsSchemaComputed(),
},

CustomizeDiff: verify.SetTagsDiff,
}
}

func resourceMlflowTrackingServerCreate(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
var diags diag.Diagnostics
conn := meta.(*conns.AWSClient).SageMakerClient(ctx)

name := d.Get("tracking_server_name").(string)
input := &sagemaker.CreateMlflowTrackingServerInput{
TrackingServerName: aws.String(name),
ArtifactStoreUri: aws.String(d.Get("artifact_store_uri").(string)),
RoleArn: aws.String(d.Get(names.AttrRoleARN).(string)),
AutomaticModelRegistration: aws.Bool(d.Get("automatic_model_registration").(bool)),
TrackingServerSize: awstypes.TrackingServerSize(d.Get("tracking_server_size").(string)),
Tags: getTagsIn(ctx),
}

if v, ok := d.GetOk("mlflow_version"); ok {
input.MlflowVersion = aws.String(v.(string))
}

if v, ok := d.GetOk("weekly_maintenance_window_start"); ok {
input.WeeklyMaintenanceWindowStart = aws.String(v.(string))
}

_, err := conn.CreateMlflowTrackingServer(ctx, input)
if err != nil {
return sdkdiag.AppendErrorf(diags, "creating SageMaker Mlflow Tracking Server %s: %s", name, err)
}

d.SetId(name)

if _, err := waitMlflowTrackingServerCreated(ctx, conn, d.Id()); err != nil {
return sdkdiag.AppendErrorf(diags, "waiting for SageMaker Mlflow Tracking Server (%s) to delete: %s", d.Id(), err)
}

return append(diags, resourceMlflowTrackingServerRead(ctx, d, meta)...)
}

func resourceMlflowTrackingServerRead(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
var diags diag.Diagnostics
conn := meta.(*conns.AWSClient).SageMakerClient(ctx)

output, err := findMlflowTrackingServerByName(ctx, conn, d.Id())

if !d.IsNewResource() && tfresource.NotFound(err) {
d.SetId("")
log.Printf("[WARN] Unable to find SageMaker Mlflow Tracking Server (%s); removing from state", d.Id())
return diags
}

if err != nil {
return sdkdiag.AppendErrorf(diags, "reading SageMaker Mlflow Tracking Server (%s): %s", d.Id(), err)
}

d.Set("tracking_server_name", output.TrackingServerName)
d.Set(names.AttrARN, output.TrackingServerArn)
d.Set("artifact_store_uri", output.ArtifactStoreUri)
d.Set(names.AttrRoleARN, output.RoleArn)
d.Set("mlflow_version", output.MlflowVersion)
d.Set("tracking_server_size", output.TrackingServerSize)
d.Set("weekly_maintenance_window_start", output.WeeklyMaintenanceWindowStart)
d.Set("tracking_server_url", output.TrackingServerUrl)
d.Set("automatic_model_registration", output.AutomaticModelRegistration)

return diags
}

func resourceMlflowTrackingServerUpdate(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
var diags diag.Diagnostics
conn := meta.(*conns.AWSClient).SageMakerClient(ctx)

if d.HasChangesExcept(names.AttrTags, names.AttrTagsAll) {
input := &sagemaker.UpdateMlflowTrackingServerInput{
TrackingServerName: aws.String(d.Id()),
}

if d.HasChange("artifact_store_uri") {
if v, ok := d.GetOk("artifact_store_uri"); ok {
input.ArtifactStoreUri = aws.String(v.(string))
}
}

if d.HasChange("automatic_model_registration") {
if v, ok := d.GetOk("automatic_model_registration"); ok {
input.AutomaticModelRegistration = aws.Bool(v.(bool))
}
}

if d.HasChange("tracking_server_size") {
if v, ok := d.GetOk("tracking_server_size"); ok {
input.TrackingServerSize = awstypes.TrackingServerSize(v.(string))
}
}

if d.HasChange("weekly_maintenance_window_start") {
if v, ok := d.GetOk("weekly_maintenance_window_start"); ok {
input.WeeklyMaintenanceWindowStart = aws.String(v.(string))
}
}

log.Printf("[DEBUG] SageMaker Mlflow Tracking Server update config: %#v", *input)
_, err := conn.UpdateMlflowTrackingServer(ctx, input)
if err != nil {
return sdkdiag.AppendErrorf(diags, "updating SageMaker Mlflow Tracking Server: %s", err)
}

if _, err := waitMlflowTrackingServerUpdated(ctx, conn, d.Id()); err != nil {
return sdkdiag.AppendErrorf(diags, "waiting for SageMaker Mlflow Tracking Server (%s) to update: %s", d.Id(), err)
}
}

return append(diags, resourceMlflowTrackingServerRead(ctx, d, meta)...)
}

func resourceMlflowTrackingServerDelete(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
var diags diag.Diagnostics
conn := meta.(*conns.AWSClient).SageMakerClient(ctx)

input := &sagemaker.DeleteMlflowTrackingServerInput{
TrackingServerName: aws.String(d.Id()),
}

if _, err := conn.DeleteMlflowTrackingServer(ctx, input); err != nil {
if errs.IsA[*awstypes.ResourceNotFound](err) {
return diags
}
return sdkdiag.AppendErrorf(diags, "deleting SageMaker Mlflow Tracking Server (%s): %s", d.Id(), err)
}

if _, err := waitMlflowTrackingServerDeleted(ctx, conn, d.Id()); err != nil {
return sdkdiag.AppendErrorf(diags, "waiting for SageMaker Mlflow Tracking Server (%s) to delete: %s", d.Id(), err)
}

return diags
}

func findMlflowTrackingServerByName(ctx context.Context, conn *sagemaker.Client, name string) (*sagemaker.DescribeMlflowTrackingServerOutput, error) {
input := &sagemaker.DescribeMlflowTrackingServerInput{
TrackingServerName: aws.String(name),
}

output, err := conn.DescribeMlflowTrackingServer(ctx, input)

if errs.IsA[*awstypes.ResourceNotFound](err) {
return nil, &retry.NotFoundError{
LastError: err,
LastRequest: input,
}
}

if err != nil {
return nil, err
}

if output == nil {
return nil, tfresource.NewEmptyResultError(input)
}

return output, nil
}
Loading

0 comments on commit 008c0cc

Please sign in to comment.