diff --git a/.changelog/39796.txt b/.changelog/39796.txt new file mode 100644 index 000000000000..e65010e0276e --- /dev/null +++ b/.changelog/39796.txt @@ -0,0 +1,3 @@ +```release-note:new-resource +aws_sagemaker_mlflow_tracking_server +``` diff --git a/internal/service/sagemaker/exports_test.go b/internal/service/sagemaker/exports_test.go index 2aa9346fdf9b..4628dca7a016 100644 --- a/internal/service/sagemaker/exports_test.go +++ b/internal/service/sagemaker/exports_test.go @@ -19,6 +19,7 @@ var ( ResourceHumanTaskUI = resourceHumanTaskUI ResourceImage = resourceImage ResourceImageVersion = resourceImageVersion + ResourceMlflowTrackingServer = resourceMlflowTrackingServer ResourceModel = resourceModel ResourceModelPackageGroup = resourceModelPackageGroup ResourceModelPackageGroupPolicy = resourceModelPackageGroupPolicy @@ -47,6 +48,7 @@ var ( FindHumanTaskUIByName = findHumanTaskUIByName FindImageByName = findImageByName FindImageVersionByName = findImageVersionByName + FindMlflowTrackingServerByName = findMlflowTrackingServerByName FindModelByName = findModelByName FindModelPackageGroupByName = findModelPackageGroupByName FindModelPackageGroupPolicyByName = findModelPackageGroupPolicyByName diff --git a/internal/service/sagemaker/mlflow_tracking_server.go b/internal/service/sagemaker/mlflow_tracking_server.go new file mode 100644 index 000000000000..9122d3f606f0 --- /dev/null +++ b/internal/service/sagemaker/mlflow_tracking_server.go @@ -0,0 +1,250 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package sagemaker + +import ( + "context" + "log" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sagemaker" + awstypes "github.com/aws/aws-sdk-go-v2/service/sagemaker/types" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/retry" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/hashicorp/terraform-provider-aws/internal/conns" + "github.com/hashicorp/terraform-provider-aws/internal/enum" + "github.com/hashicorp/terraform-provider-aws/internal/errs" + "github.com/hashicorp/terraform-provider-aws/internal/errs/sdkdiag" + tftags "github.com/hashicorp/terraform-provider-aws/internal/tags" + "github.com/hashicorp/terraform-provider-aws/internal/tfresource" + "github.com/hashicorp/terraform-provider-aws/internal/verify" + "github.com/hashicorp/terraform-provider-aws/names" +) + +// @SDKResource("aws_sagemaker_mlflow_tracking_server", name="Mlflow Tracking Server") +// @Tags(identifierAttribute="arn") +func resourceMlflowTrackingServer() *schema.Resource { + return &schema.Resource{ + CreateWithoutTimeout: resourceMlflowTrackingServerCreate, + ReadWithoutTimeout: resourceMlflowTrackingServerRead, + UpdateWithoutTimeout: resourceMlflowTrackingServerUpdate, + DeleteWithoutTimeout: resourceMlflowTrackingServerDelete, + Importer: &schema.ResourceImporter{ + StateContext: schema.ImportStatePassthroughContext, + }, + + Schema: map[string]*schema.Schema{ + names.AttrARN: { + Type: schema.TypeString, + Computed: true, + }, + "artifact_store_uri": { + Type: schema.TypeString, + Required: true, + ValidateFunc: validModelDataURL, + }, + names.AttrRoleARN: { + Type: schema.TypeString, + Required: true, + ForceNew: true, + ValidateFunc: verify.ValidARN, + }, + "tracking_server_name": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + }, + "mlflow_version": { + Type: schema.TypeString, + Optional: true, + Computed: true, + ForceNew: true, + }, + "tracking_server_url": { + Type: schema.TypeString, + Computed: true, + }, + "automatic_model_registration": { + Type: schema.TypeBool, + Optional: true, + Default: false, + }, + "tracking_server_size": { + Type: schema.TypeString, + Optional: true, + Default: awstypes.TrackingServerSizeS, + ValidateDiagFunc: enum.Validate[awstypes.TrackingServerSize](), + }, + "weekly_maintenance_window_start": { + Type: schema.TypeString, + Optional: true, + Computed: true, + }, + names.AttrTags: tftags.TagsSchema(), + names.AttrTagsAll: tftags.TagsSchemaComputed(), + }, + + CustomizeDiff: verify.SetTagsDiff, + } +} + +func resourceMlflowTrackingServerCreate(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + var diags diag.Diagnostics + conn := meta.(*conns.AWSClient).SageMakerClient(ctx) + + name := d.Get("tracking_server_name").(string) + input := &sagemaker.CreateMlflowTrackingServerInput{ + TrackingServerName: aws.String(name), + ArtifactStoreUri: aws.String(d.Get("artifact_store_uri").(string)), + RoleArn: aws.String(d.Get(names.AttrRoleARN).(string)), + AutomaticModelRegistration: aws.Bool(d.Get("automatic_model_registration").(bool)), + TrackingServerSize: awstypes.TrackingServerSize(d.Get("tracking_server_size").(string)), + Tags: getTagsIn(ctx), + } + + if v, ok := d.GetOk("mlflow_version"); ok { + input.MlflowVersion = aws.String(v.(string)) + } + + if v, ok := d.GetOk("weekly_maintenance_window_start"); ok { + input.WeeklyMaintenanceWindowStart = aws.String(v.(string)) + } + + _, err := conn.CreateMlflowTrackingServer(ctx, input) + if err != nil { + return sdkdiag.AppendErrorf(diags, "creating SageMaker Mlflow Tracking Server %s: %s", name, err) + } + + d.SetId(name) + + if _, err := waitMlflowTrackingServerCreated(ctx, conn, d.Id()); err != nil { + return sdkdiag.AppendErrorf(diags, "waiting for SageMaker Mlflow Tracking Server (%s) to delete: %s", d.Id(), err) + } + + return append(diags, resourceMlflowTrackingServerRead(ctx, d, meta)...) +} + +func resourceMlflowTrackingServerRead(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + var diags diag.Diagnostics + conn := meta.(*conns.AWSClient).SageMakerClient(ctx) + + output, err := findMlflowTrackingServerByName(ctx, conn, d.Id()) + + if !d.IsNewResource() && tfresource.NotFound(err) { + d.SetId("") + log.Printf("[WARN] Unable to find SageMaker Mlflow Tracking Server (%s); removing from state", d.Id()) + return diags + } + + if err != nil { + return sdkdiag.AppendErrorf(diags, "reading SageMaker Mlflow Tracking Server (%s): %s", d.Id(), err) + } + + d.Set("tracking_server_name", output.TrackingServerName) + d.Set(names.AttrARN, output.TrackingServerArn) + d.Set("artifact_store_uri", output.ArtifactStoreUri) + d.Set(names.AttrRoleARN, output.RoleArn) + d.Set("mlflow_version", output.MlflowVersion) + d.Set("tracking_server_size", output.TrackingServerSize) + d.Set("weekly_maintenance_window_start", output.WeeklyMaintenanceWindowStart) + d.Set("tracking_server_url", output.TrackingServerUrl) + d.Set("automatic_model_registration", output.AutomaticModelRegistration) + + return diags +} + +func resourceMlflowTrackingServerUpdate(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + var diags diag.Diagnostics + conn := meta.(*conns.AWSClient).SageMakerClient(ctx) + + if d.HasChangesExcept(names.AttrTags, names.AttrTagsAll) { + input := &sagemaker.UpdateMlflowTrackingServerInput{ + TrackingServerName: aws.String(d.Id()), + } + + if d.HasChange("artifact_store_uri") { + if v, ok := d.GetOk("artifact_store_uri"); ok { + input.ArtifactStoreUri = aws.String(v.(string)) + } + } + + if d.HasChange("automatic_model_registration") { + if v, ok := d.GetOk("automatic_model_registration"); ok { + input.AutomaticModelRegistration = aws.Bool(v.(bool)) + } + } + + if d.HasChange("tracking_server_size") { + if v, ok := d.GetOk("tracking_server_size"); ok { + input.TrackingServerSize = awstypes.TrackingServerSize(v.(string)) + } + } + + if d.HasChange("weekly_maintenance_window_start") { + if v, ok := d.GetOk("weekly_maintenance_window_start"); ok { + input.WeeklyMaintenanceWindowStart = aws.String(v.(string)) + } + } + + log.Printf("[DEBUG] SageMaker Mlflow Tracking Server update config: %#v", *input) + _, err := conn.UpdateMlflowTrackingServer(ctx, input) + if err != nil { + return sdkdiag.AppendErrorf(diags, "updating SageMaker Mlflow Tracking Server: %s", err) + } + + if _, err := waitMlflowTrackingServerUpdated(ctx, conn, d.Id()); err != nil { + return sdkdiag.AppendErrorf(diags, "waiting for SageMaker Mlflow Tracking Server (%s) to update: %s", d.Id(), err) + } + } + + return append(diags, resourceMlflowTrackingServerRead(ctx, d, meta)...) +} + +func resourceMlflowTrackingServerDelete(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + var diags diag.Diagnostics + conn := meta.(*conns.AWSClient).SageMakerClient(ctx) + + input := &sagemaker.DeleteMlflowTrackingServerInput{ + TrackingServerName: aws.String(d.Id()), + } + + if _, err := conn.DeleteMlflowTrackingServer(ctx, input); err != nil { + if errs.IsA[*awstypes.ResourceNotFound](err) { + return diags + } + return sdkdiag.AppendErrorf(diags, "deleting SageMaker Mlflow Tracking Server (%s): %s", d.Id(), err) + } + + if _, err := waitMlflowTrackingServerDeleted(ctx, conn, d.Id()); err != nil { + return sdkdiag.AppendErrorf(diags, "waiting for SageMaker Mlflow Tracking Server (%s) to delete: %s", d.Id(), err) + } + + return diags +} + +func findMlflowTrackingServerByName(ctx context.Context, conn *sagemaker.Client, name string) (*sagemaker.DescribeMlflowTrackingServerOutput, error) { + input := &sagemaker.DescribeMlflowTrackingServerInput{ + TrackingServerName: aws.String(name), + } + + output, err := conn.DescribeMlflowTrackingServer(ctx, input) + + if errs.IsA[*awstypes.ResourceNotFound](err) { + return nil, &retry.NotFoundError{ + LastError: err, + LastRequest: input, + } + } + + if err != nil { + return nil, err + } + + if output == nil { + return nil, tfresource.NewEmptyResultError(input) + } + + return output, nil +} diff --git a/internal/service/sagemaker/mlflow_tracking_server_test.go b/internal/service/sagemaker/mlflow_tracking_server_test.go new file mode 100644 index 000000000000..67d1379f7b3f --- /dev/null +++ b/internal/service/sagemaker/mlflow_tracking_server_test.go @@ -0,0 +1,289 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package sagemaker_test + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/sagemaker" + sdkacctest "github.com/hashicorp/terraform-plugin-testing/helper/acctest" + "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-plugin-testing/terraform" + "github.com/hashicorp/terraform-provider-aws/internal/acctest" + "github.com/hashicorp/terraform-provider-aws/internal/conns" + tfsagemaker "github.com/hashicorp/terraform-provider-aws/internal/service/sagemaker" + "github.com/hashicorp/terraform-provider-aws/internal/tfresource" + "github.com/hashicorp/terraform-provider-aws/names" +) + +func TestAccSageMakerMlflowTrackingServer_basic(t *testing.T) { + ctx := acctest.Context(t) + var mpg sagemaker.DescribeMlflowTrackingServerOutput + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_sagemaker_mlflow_tracking_server.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(ctx, t) }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckMlflowTrackingServerDestroy(ctx), + Steps: []resource.TestStep{ + { + Config: testAccMlflowTrackingServerConfig_basic(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckMlflowTrackingServerExists(ctx, resourceName, &mpg), + resource.TestCheckResourceAttr(resourceName, "tracking_server_name", rName), + resource.TestCheckResourceAttr(resourceName, "automatic_model_registration", acctest.CtFalse), + resource.TestCheckResourceAttr(resourceName, "tracking_server_size", "Small"), + resource.TestCheckResourceAttrSet(resourceName, "tracking_server_url"), + resource.TestCheckResourceAttrPair(resourceName, names.AttrRoleARN, "aws_iam_role.test", names.AttrARN), + acctest.CheckResourceAttrRegionalARN(resourceName, names.AttrARN, "sagemaker", fmt.Sprintf("mlflow-tracking-server/%s", rName)), + resource.TestCheckResourceAttr(resourceName, acctest.CtTagsPercent, acctest.Ct0), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccMlflowTrackingServerConfig_update(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckMlflowTrackingServerExists(ctx, resourceName, &mpg), + resource.TestCheckResourceAttr(resourceName, "tracking_server_name", rName), + resource.TestCheckResourceAttr(resourceName, "automatic_model_registration", acctest.CtTrue), + resource.TestCheckResourceAttr(resourceName, "tracking_server_size", "Medium"), + resource.TestCheckResourceAttr(resourceName, "weekly_maintenance_window_start", "Sun:01:00"), + resource.TestCheckResourceAttrSet(resourceName, "tracking_server_url"), + resource.TestCheckResourceAttrPair(resourceName, names.AttrRoleARN, "aws_iam_role.test", names.AttrARN), + acctest.CheckResourceAttrRegionalARN(resourceName, names.AttrARN, "sagemaker", fmt.Sprintf("mlflow-tracking-server/%s", rName)), + resource.TestCheckResourceAttr(resourceName, acctest.CtTagsPercent, acctest.Ct0), + ), + }, + }, + }) +} + +func TestAccSageMakerMlflowTrackingServer_tags(t *testing.T) { + ctx := acctest.Context(t) + var mpg sagemaker.DescribeMlflowTrackingServerOutput + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_sagemaker_mlflow_tracking_server.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(ctx, t) }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckMlflowTrackingServerDestroy(ctx), + Steps: []resource.TestStep{ + { + Config: testAccMlflowTrackingServerConfig_tags1(rName, acctest.CtKey1, acctest.CtValue1), + Check: resource.ComposeTestCheckFunc( + testAccCheckMlflowTrackingServerExists(ctx, resourceName, &mpg), + resource.TestCheckResourceAttr(resourceName, acctest.CtTagsPercent, acctest.Ct1), + resource.TestCheckResourceAttr(resourceName, acctest.CtTagsKey1, acctest.CtValue1), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccMlflowTrackingServerConfig_tags2(rName, acctest.CtKey1, acctest.CtValue1Updated, acctest.CtKey2, acctest.CtValue2), + Check: resource.ComposeTestCheckFunc( + testAccCheckMlflowTrackingServerExists(ctx, resourceName, &mpg), + resource.TestCheckResourceAttr(resourceName, acctest.CtTagsPercent, acctest.Ct2), + resource.TestCheckResourceAttr(resourceName, acctest.CtTagsKey1, acctest.CtValue1Updated), + resource.TestCheckResourceAttr(resourceName, acctest.CtTagsKey2, acctest.CtValue2), + ), + }, + { + Config: testAccMlflowTrackingServerConfig_tags1(rName, acctest.CtKey2, acctest.CtValue2), + Check: resource.ComposeTestCheckFunc( + testAccCheckMlflowTrackingServerExists(ctx, resourceName, &mpg), + resource.TestCheckResourceAttr(resourceName, acctest.CtTagsPercent, acctest.Ct1), + resource.TestCheckResourceAttr(resourceName, acctest.CtTagsKey2, acctest.CtValue2), + ), + }, + }, + }) +} + +func TestAccSageMakerMlflowTrackingServer_disappears(t *testing.T) { + ctx := acctest.Context(t) + var mpg sagemaker.DescribeMlflowTrackingServerOutput + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_sagemaker_mlflow_tracking_server.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(ctx, t) }, + ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckMlflowTrackingServerDestroy(ctx), + Steps: []resource.TestStep{ + { + Config: testAccMlflowTrackingServerConfig_basic(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckMlflowTrackingServerExists(ctx, resourceName, &mpg), + acctest.CheckResourceDisappears(ctx, acctest.Provider, tfsagemaker.ResourceMlflowTrackingServer(), resourceName), + ), + ExpectNonEmptyPlan: true, + }, + }, + }) +} + +func testAccCheckMlflowTrackingServerDestroy(ctx context.Context) resource.TestCheckFunc { + return func(s *terraform.State) error { + conn := acctest.Provider.Meta().(*conns.AWSClient).SageMakerClient(ctx) + + for _, rs := range s.RootModule().Resources { + if rs.Type != "aws_sagemaker_mlflow_tracking_server" { + continue + } + + _, err := tfsagemaker.FindMlflowTrackingServerByName(ctx, conn, rs.Primary.ID) + + if tfresource.NotFound(err) { + continue + } + + if err != nil { + return fmt.Errorf("reading SageMaker Mlflow Tracking Server (%s): %w", rs.Primary.ID, err) + } + + return fmt.Errorf("sagemaker Mlflow Tracking Server %s still exists", rs.Primary.ID) + } + + return nil + } +} + +func testAccCheckMlflowTrackingServerExists(ctx context.Context, n string, mpg *sagemaker.DescribeMlflowTrackingServerOutput) resource.TestCheckFunc { + return func(s *terraform.State) error { + rs, ok := s.RootModule().Resources[n] + if !ok { + return fmt.Errorf("Not found: %s", n) + } + + if rs.Primary.ID == "" { + return fmt.Errorf("No sagmaker Mlflow Tracking Server ID is set") + } + + conn := acctest.Provider.Meta().(*conns.AWSClient).SageMakerClient(ctx) + resp, err := tfsagemaker.FindMlflowTrackingServerByName(ctx, conn, rs.Primary.ID) + if err != nil { + return err + } + + *mpg = *resp + + return nil + } +} + +func testAccMlflowTrackingServerConfig_base(rName string) string { + return fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_iam_role" "test" { + name = %[1]q + path = "/" + assume_role_policy = data.aws_iam_policy_document.test.json + inline_policy { + name = "TrackingServerPolicy" + policy = data.aws_iam_policy_document.tracking.json + } +} + +data "aws_iam_policy_document" "test" { + statement { + actions = ["sts:AssumeRole"] + + principals { + type = "Service" + identifiers = ["sagemaker.${data.aws_partition.current.dns_suffix}"] + } + } +} + +data "aws_iam_policy_document" "tracking" { + statement { + sid = "Tracking" + effect = "Allow" + actions = [ + "s3:Get*", + "s3:Put*", + "s3:List*", + "sagemaker:AddTags", + "sagemaker:CreateModelPackageGroup", + "sagemaker:CreateModelPackage", + "sagemaker:UpdateModelPackage", + "sagemaker:DescribeModelPackageGroup" + ] + resources = ["*"] + } +} + +resource "aws_s3_bucket" "test" { + bucket = %[1]q + force_destroy = true +} +`, rName) +} + +func testAccMlflowTrackingServerConfig_basic(rName string) string { + return acctest.ConfigCompose(testAccMlflowTrackingServerConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_mlflow_tracking_server" "test" { + tracking_server_name = %[1]q + role_arn = aws_iam_role.test.arn + artifact_store_uri = "s3://${aws_s3_bucket.test.bucket}/path" +} +`, rName)) +} + +func testAccMlflowTrackingServerConfig_update(rName string) string { + return acctest.ConfigCompose(testAccMlflowTrackingServerConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_mlflow_tracking_server" "test" { + tracking_server_name = %[1]q + role_arn = aws_iam_role.test.arn + artifact_store_uri = "s3://${aws_s3_bucket.test.bucket}/path" + automatic_model_registration = true + tracking_server_size = "Medium" + weekly_maintenance_window_start = "Sun:01:00" +} +`, rName)) +} + +func testAccMlflowTrackingServerConfig_tags1(rName, tagKey1, tagValue1 string) string { + return acctest.ConfigCompose(testAccMlflowTrackingServerConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_mlflow_tracking_server" "test" { + tracking_server_name = %[1]q + role_arn = aws_iam_role.test.arn + artifact_store_uri = "s3://${aws_s3_bucket.test.bucket}/path" + + tags = { + %[2]q = %[3]q + } +} +`, rName, tagKey1, tagValue1)) +} + +func testAccMlflowTrackingServerConfig_tags2(rName, tagKey1, tagValue1, tagKey2, tagValue2 string) string { + return acctest.ConfigCompose(testAccMlflowTrackingServerConfig_base(rName), fmt.Sprintf(` +resource "aws_sagemaker_mlflow_tracking_server" "test" { + tracking_server_name = %[1]q + role_arn = aws_iam_role.test.arn + artifact_store_uri = "s3://${aws_s3_bucket.test.bucket}/path" + + tags = { + %[2]q = %[3]q + %[4]q = %[5]q + } +} +`, rName, tagKey1, tagValue1, tagKey2, tagValue2)) +} diff --git a/internal/service/sagemaker/service_package_gen.go b/internal/service/sagemaker/service_package_gen.go index 4f9dca4afdb3..6c32572278d2 100644 --- a/internal/service/sagemaker/service_package_gen.go +++ b/internal/service/sagemaker/service_package_gen.go @@ -140,6 +140,14 @@ func (p *servicePackage) SDKResources(ctx context.Context) []*types.ServicePacka TypeName: "aws_sagemaker_image_version", Name: "Image Version", }, + { + Factory: resourceMlflowTrackingServer, + TypeName: "aws_sagemaker_mlflow_tracking_server", + Name: "Mlflow Tracking Server", + Tags: &types.ServicePackageResourceTags{ + IdentifierAttribute: names.AttrARN, + }, + }, { Factory: resourceModel, TypeName: "aws_sagemaker_model", diff --git a/internal/service/sagemaker/status.go b/internal/service/sagemaker/status.go index 77645f33cebc..6743a8f7419d 100644 --- a/internal/service/sagemaker/status.go +++ b/internal/service/sagemaker/status.go @@ -202,3 +202,19 @@ func statusMonitoringSchedule(ctx context.Context, conn *sagemaker.Client, name return output, string(output.MonitoringScheduleStatus), nil } } + +func statusMlflowTrackingServer(ctx context.Context, conn *sagemaker.Client, name string) retry.StateRefreshFunc { + return func() (interface{}, string, error) { + output, err := findMlflowTrackingServerByName(ctx, conn, name) + + if tfresource.NotFound(err) { + return nil, "", nil + } + + if err != nil { + return nil, "", err + } + + return output, string(output.TrackingServerStatus), nil + } +} diff --git a/internal/service/sagemaker/sweep.go b/internal/service/sagemaker/sweep.go index 36977c46606a..eb2a6c5f5244 100644 --- a/internal/service/sagemaker/sweep.go +++ b/internal/service/sagemaker/sweep.go @@ -87,6 +87,11 @@ func RegisterSweepers() { F: sweepImages, }) + resource.AddTestSweepers("aws_sagemaker_mlflow_tracking_server", &resource.Sweeper{ + Name: "aws_sagemaker_mlflow_tracking_server", + F: sweepMlflowTrackingServers, + }) + resource.AddTestSweepers("aws_sagemaker_model_package_group", &resource.Sweeper{ Name: "aws_sagemaker_model_package_group", F: sweepModelPackageGroups, @@ -1071,3 +1076,44 @@ func sweepPipelines(region string) error { return sweeperErrs.ErrorOrNil() } + +func sweepMlflowTrackingServers(region string) error { + ctx := sweep.Context(region) + client, err := sweep.SharedRegionalSweepClient(ctx, region) + if err != nil { + return fmt.Errorf("getting client: %s", err) + } + conn := client.SageMakerClient(ctx) + + sweepResources := make([]sweep.Sweepable, 0) + var sweeperErrs *multierror.Error + + pages := sagemaker.NewListMlflowTrackingServersPaginator(conn, &sagemaker.ListMlflowTrackingServersInput{}) + for pages.HasMorePages() { + page, err := pages.NextPage(ctx) + + if awsv2.SkipSweepError(err) { + log.Printf("[WARN] Skipping SageMaker Mlflow Tracking Server sweep for %s: %s", region, err) + return sweeperErrs.ErrorOrNil() + } + if err != nil { + sweeperErrs = multierror.Append(sweeperErrs, fmt.Errorf("retrieving SageMaker Mlflow Tracking Servers: %w", err)) + } + + for _, project := range page.TrackingServerSummaries { + name := aws.ToString(project.TrackingServerName) + + r := resourceMlflowTrackingServer() + d := r.Data(nil) + d.SetId(name) + + sweepResources = append(sweepResources, sweep.NewSweepResource(r, d, client)) + } + } + + if err := sweep.SweepOrchestrator(ctx, sweepResources); err != nil { + sweeperErrs = multierror.Append(sweeperErrs, fmt.Errorf("sweeping SageMaker Mlflow Tracking Servers: %w", err)) + } + + return sweeperErrs.ErrorOrNil() +} diff --git a/internal/service/sagemaker/wait.go b/internal/service/sagemaker/wait.go index dd557bc06acd..71ce0f2a8c99 100644 --- a/internal/service/sagemaker/wait.go +++ b/internal/service/sagemaker/wait.go @@ -42,6 +42,7 @@ const ( spaceInServiceTimeout = 10 * time.Minute monitoringScheduleScheduledTimeout = 2 * time.Minute monitoringScheduleStoppedTimeout = 2 * time.Minute + mlflowTrackingServerTimeout = 30 * time.Minute notebookInstanceStatusNotFound = "NotFound" ) @@ -589,3 +590,54 @@ func waitMonitoringScheduleNotFound(ctx context.Context, conn *sagemaker.Client, return nil, err } + +func waitMlflowTrackingServerCreated(ctx context.Context, conn *sagemaker.Client, name string) (*sagemaker.DescribeMlflowTrackingServerOutput, error) { + stateConf := &retry.StateChangeConf{ + Pending: enum.Slice(awstypes.TrackingServerStatusCreating), + Target: enum.Slice(awstypes.TrackingServerStatusCreated), + Refresh: statusMlflowTrackingServer(ctx, conn, name), + Timeout: mlflowTrackingServerTimeout, + } + + outputRaw, err := stateConf.WaitForStateContext(ctx) + + if output, ok := outputRaw.(*sagemaker.DescribeMlflowTrackingServerOutput); ok { + return output, err + } + + return nil, err +} + +func waitMlflowTrackingServerUpdated(ctx context.Context, conn *sagemaker.Client, name string) (*sagemaker.DescribeMlflowTrackingServerOutput, error) { + stateConf := &retry.StateChangeConf{ + Pending: enum.Slice(awstypes.TrackingServerStatusUpdating), + Target: enum.Slice(awstypes.TrackingServerStatusUpdated), + Refresh: statusMlflowTrackingServer(ctx, conn, name), + Timeout: mlflowTrackingServerTimeout, + } + + outputRaw, err := stateConf.WaitForStateContext(ctx) + + if output, ok := outputRaw.(*sagemaker.DescribeMlflowTrackingServerOutput); ok { + return output, err + } + + return nil, err +} + +func waitMlflowTrackingServerDeleted(ctx context.Context, conn *sagemaker.Client, name string) (*sagemaker.DescribeMlflowTrackingServerOutput, error) { + stateConf := &retry.StateChangeConf{ + Pending: enum.Slice(awstypes.TrackingServerStatusDeleting), + Target: []string{}, + Refresh: statusMlflowTrackingServer(ctx, conn, name), + Timeout: mlflowTrackingServerTimeout, + } + + outputRaw, err := stateConf.WaitForStateContext(ctx) + + if output, ok := outputRaw.(*sagemaker.DescribeMlflowTrackingServerOutput); ok { + return output, err + } + + return nil, err +} diff --git a/website/docs/r/sagemaker_mlflow_tracking_server.html.markdown b/website/docs/r/sagemaker_mlflow_tracking_server.html.markdown new file mode 100644 index 000000000000..2b114b6baa71 --- /dev/null +++ b/website/docs/r/sagemaker_mlflow_tracking_server.html.markdown @@ -0,0 +1,62 @@ +--- +subcategory: "SageMaker" +layout: "aws" +page_title: "AWS: aws_sagemaker_mlflow_tracking_server" +description: |- + Provides a SageMaker MLFlow Tracking Server resource. +--- + +# Resource: aws_sagemaker_mlflow_tracking_server + +Provides a SageMaker MLFlow Tracking Server resource. + +## Example Usage + +### Cognito Usage + +```terraform +resource "aws_sagemaker_mlflow_tracking_server" "example" { + tracking_server_name = "example" + role_arn = aws_iam_role.example.arn + artifact_store_uri = "s3://${aws_s3_bucket.example.bucket}/path" +} +``` + +## Argument Reference + +This resource supports the following arguments: + +* `artifact_store_uri` - (Required) The S3 URI for a general purpose bucket to use as the MLflow Tracking Server artifact store. +* `role_arn` - (Required) The Amazon Resource Name (ARN) for an IAM role in your account that the MLflow Tracking Server uses to access the artifact store in Amazon S3. The role should have AmazonS3FullAccess permissions. For more information on IAM permissions for tracking server creation, see [Set up IAM permissions for MLflow](https://docs.aws.amazon.com/sagemaker/latest/dg/mlflow-create-tracking-server-iam.html). +* `tracking_server_name` - (Required) A unique string identifying the tracking server name. This string is part of the tracking server ARN. +* `mlflow_version` - (Optional) The version of MLflow that the tracking server uses. To see which MLflow versions are available to use, see [How it works](https://docs.aws.amazon.com/sagemaker/latest/dg/mlflow.html#mlflow-create-tracking-server-how-it-works). +* `automatic_model_registration` - (Optional) A list of Member Definitions that contains objects that identify the workers that make up the work team. +* `tracking_server_size` - (Optional) The size of the tracking server you want to create. You can choose between "Small", "Medium", and "Large". The default MLflow Tracking Server configuration size is "Small". You can choose a size depending on the projected use of the tracking server such as the volume of data logged, number of users, and frequency of use. +* `weekly_maintenance_window_start` - (Optional) The day and time of the week in Coordinated Universal Time (UTC) 24-hour standard time that weekly maintenance updates are scheduled. For example: TUE:03:30. +* `tags` - (Optional) A map of tags to assign to the resource. If configured with a provider [`default_tags` configuration block](https://registry.terraform.io/providers/hashicorp/aws/latest/docs#default_tags-configuration-block) present, tags with matching keys will overwrite those defined at the provider-level. + +## Attribute Reference + +This resource exports the following attributes in addition to the arguments above: + +* `arn` - The Amazon Resource Name (ARN) assigned by AWS to this MLFlow Tracking Server. +* `id` - The name of the MLFlow Tracking Server. +* `tracking_server_url` - The URL to connect to the MLflow user interface for the described tracking server. +* `tags_all` - A map of tags assigned to the resource, including those inherited from the provider [`default_tags` configuration block](https://registry.terraform.io/providers/hashicorp/aws/latest/docs#default_tags-configuration-block). + +## Import + +In Terraform v1.5.0 and later, use an [`import` block](https://developer.hashicorp.com/terraform/language/import) to import SageMaker MLFlow Tracking Servers using the `workteam_name`. For example: + +```terraform +import { + to = aws_sagemaker_mlflow_tracking_server.example + id = "example" +} +``` + +Using `terraform import`, import SageMaker MLFlow Tracking Servers using the `workteam_name`. For example: + +```console +% terraform import aws_sagemaker_mlflow_tracking_server.example example +```