From 0b6a98265190afaf09e67e84da713674180df7ea Mon Sep 17 00:00:00 2001 From: Tyler Kalbach Date: Wed, 16 Oct 2024 08:51:27 -0400 Subject: [PATCH] POC hack to get S3 tags working for the AWS backend (likely breaks GCP and MinIO) --- backend/src/v2/component/launcher_v2.go | 22 +++++- backend/src/v2/config/s3.go | 3 +- backend/src/v2/driver/driver_utils.go | 69 +++++++++++++++++ backend/src/v2/objectstore/config.go | 3 +- backend/src/v2/objectstore/object_store.go | 65 ++++++++++++---- .../src/v2/objectstore/object_store_bucket.go | 75 +++++++++++++++++++ 6 files changed, 219 insertions(+), 18 deletions(-) create mode 100644 backend/src/v2/driver/driver_utils.go create mode 100644 backend/src/v2/objectstore/object_store_bucket.go diff --git a/backend/src/v2/component/launcher_v2.go b/backend/src/v2/component/launcher_v2.go index 2b5297c7b09..429536cdb14 100644 --- a/backend/src/v2/component/launcher_v2.go +++ b/backend/src/v2/component/launcher_v2.go @@ -319,7 +319,7 @@ func executeV2( return nil, nil, err } // TODO(Bobgy): should we log metadata per each artifact, or batched after uploading all artifacts. 
- outputArtifacts, err := uploadOutputArtifacts(ctx, executorInput, executorOutput, uploadOutputArtifactsOptions{ + outputArtifacts, err := uploadOutputArtifacts(ctx, executorInput, executorOutput, bucketConfig, uploadOutputArtifactsOptions{ bucketConfig: bucketConfig, bucket: bucket, metadataClient: metadataClient, @@ -415,7 +415,7 @@ type uploadOutputArtifactsOptions struct { metadataClient metadata.ClientInterface } -func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.ExecutorInput, executorOutput *pipelinespec.ExecutorOutput, opts uploadOutputArtifactsOptions) ([]*metadata.OutputArtifact, error) { +func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.ExecutorInput, executorOutput *pipelinespec.ExecutorOutput, bucketConfig *objectstore.Config, opts uploadOutputArtifactsOptions) ([]*metadata.OutputArtifact, error) { // Register artifacts with MLMD. outputArtifacts := make([]*metadata.OutputArtifact, 0, len(executorInput.GetOutputs().GetArtifacts())) for name, artifactList := range executorInput.GetOutputs().GetArtifacts() { @@ -430,6 +430,22 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec mergeRuntimeArtifacts(list.Artifacts[0], outputArtifact) } + // TODO: Getting aws tags from output metadata. Is there a better way to convert map[string]interface{} to map[string]string? + outputMetadata := outputArtifact.GetMetadata() + glog.Infof("Artifact metadata: %#v", outputMetadata) + awsTags, ok := outputMetadata.AsMap()["aws-tags"] + if !ok { + glog.Warningf("Output Artifact %q does not have aws-tags metadata", name) + } + tags := make(map[string]string) + for k, v := range awsTags.(map[string]interface{}) { + if strValue, ok := v.(string); ok { + tags[k] = strValue + } else { + glog.Warningf("Output Artifact %q aws-tags metadata is not a map[string]string: %#v", name, awsTags) + } + } + // Upload artifacts from local path to remote storages. 
localDir, err := localPathForURI(outputArtifact.Uri) if err != nil { @@ -439,7 +455,7 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec if err != nil { return nil, fmt.Errorf("failed to upload output artifact %q: %w", name, err) } - if err := objectstore.UploadBlob(ctx, opts.bucket, localDir, blobKey); err != nil { + if err := objectstore.UploadBlob(ctx, opts.bucket, localDir, blobKey, bucketConfig, tags); err != nil { // We allow components to not produce output files if errors.Is(err, os.ErrNotExist) { glog.Warningf("Local filepath %q does not exist", localDir) diff --git a/backend/src/v2/config/s3.go b/backend/src/v2/config/s3.go index 8cfc86d8514..057b8bc5df3 100644 --- a/backend/src/v2/config/s3.go +++ b/backend/src/v2/config/s3.go @@ -16,9 +16,10 @@ package config import ( "fmt" - "github.com/kubeflow/pipelines/backend/src/v2/objectstore" "strconv" "strings" + + "github.com/kubeflow/pipelines/backend/src/v2/objectstore" ) type S3ProviderConfig struct { diff --git a/backend/src/v2/driver/driver_utils.go b/backend/src/v2/driver/driver_utils.go new file mode 100644 index 00000000000..adf540ae5cb --- /dev/null +++ b/backend/src/v2/driver/driver_utils.go @@ -0,0 +1,69 @@ +package driver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/golang/glog" + "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" + "github.com/kubeflow/pipelines/backend/src/v2/metadata" +) + +type ArtifactReader interface { + GetOutputArtifactsByExecutionId(ctx context.Context, executionId int64) (map[string]*metadata.OutputArtifact, error) +} + +func resolveUpstreamArtifacts(ctx context.Context, tasks map[string]*metadata.Execution, taskName string, outputArtifactKey string, mlmd ArtifactReader) (runtimeArtifact *pipelinespec.RuntimeArtifact, err error) { + glog.V(4).Info("taskName: ", taskName) + glog.V(4).Info("outputArtifactKey: ", outputArtifactKey) + upstreamTask := tasks[taskName] + if *upstreamTask.GetExecution().Type == 
"system.DAGExecution" { + // recurse + outputArtifactsCustomProperty, ok := upstreamTask.GetExecution().GetCustomProperties()["output_artifacts"] + if !ok { + return nil, fmt.Errorf("cannot find output_artifacts") + } + var outputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec + glog.V(4).Infof("outputArtifactsCustomProperty: %#v", outputArtifactsCustomProperty) + glog.V(4).Info("outputArtifactsCustomProperty String: ", outputArtifactsCustomProperty.GetStringValue()) + err = json.Unmarshal([]byte(outputArtifactsCustomProperty.GetStringValue()), &outputArtifacts) + if err != nil { + return nil, err + } + glog.V(4).Info("Deserialized outputArtifactsMap: ", outputArtifacts) + var subTaskName string + outputArtifactSelectors := outputArtifacts[outputArtifactKey].GetArtifactSelectors() + for _, outputArtifactSelector := range outputArtifactSelectors { + subTaskName = outputArtifactSelector.ProducerSubtask + outputArtifactKey = outputArtifactSelector.OutputArtifactKey + glog.V(4).Infof("ProducerSubtask: %v", outputArtifactSelector.ProducerSubtask) + glog.V(4).Infof("OutputArtifactKey: %v", outputArtifactSelector.OutputArtifactKey) + } + downstreamParameterMapping, err := resolveUpstreamArtifacts(ctx, tasks, subTaskName, outputArtifactKey, mlmd) + glog.V(4).Infof("downstreamParameterMapping: %#v", downstreamParameterMapping) + if err != nil { + return nil, err + } + return downstreamParameterMapping, nil + } else { + // base case + outputs, err := mlmd.GetOutputArtifactsByExecutionId(ctx, upstreamTask.GetID()) + if err != nil { + return nil, err + } + artifact, ok := outputs[outputArtifactKey] + if !ok { + return nil, fmt.Errorf( + "cannot find output artifact key %q in producer task %q", + outputArtifactKey, + taskName, + ) + } + runtimeArtifact, err := artifact.ToRuntimeArtifact() + if err != nil { + return nil, err + } + return runtimeArtifact, nil + } +} diff --git a/backend/src/v2/objectstore/config.go b/backend/src/v2/objectstore/config.go 
index 28e82cd65de..8ae2c64a3c8 100644 --- a/backend/src/v2/objectstore/config.go +++ b/backend/src/v2/objectstore/config.go @@ -18,12 +18,13 @@ package objectstore import ( "encoding/json" "fmt" - "github.com/golang/glog" "os" "path" "regexp" "strconv" "strings" + + "github.com/golang/glog" ) // The endpoint uses Kubernetes service DNS name with namespace: diff --git a/backend/src/v2/objectstore/object_store.go b/backend/src/v2/objectstore/object_store.go index a7b85565565..8abb9a8e8b2 100644 --- a/backend/src/v2/objectstore/object_store.go +++ b/backend/src/v2/objectstore/object_store.go @@ -17,9 +17,17 @@ package objectstore import ( "context" "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "regexp" + "strings" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" "github.com/golang/glog" "gocloud.dev/blob" "gocloud.dev/blob/gcsblob" @@ -27,14 +35,8 @@ import ( "gocloud.dev/blob/s3blob" "gocloud.dev/gcp" "golang.org/x/oauth2/google" - "io" - "io/ioutil" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" - "os" - "path/filepath" - "regexp" - "strings" ) func OpenBucket(ctx context.Context, k8sClient kubernetes.Interface, namespace string, config *Config) (bucket *blob.Bucket, err error) { @@ -81,18 +83,22 @@ func OpenBucket(ctx context.Context, k8sClient kubernetes.Interface, namespace s bucketURL = strings.Replace(bucketURL, "minio://", "s3://", 1) } + bucket, err = blob.OpenBucket(ctx, bucketURL) + if err != nil { + return nil, nil, err + } // When no provider config is provided, or "FromEnv" is specified, use default credentials from the environment - return blob.OpenBucket(ctx, bucketURL) + return bucket, nil, nil } -func UploadBlob(ctx context.Context, bucket *blob.Bucket, localPath, blobPath string) error { +func UploadBlob(ctx context.Context, bucket *blob.Bucket, localPath, blobPath string, bucketConfig *Config, awsTags 
map[string]string) error { fileInfo, err := os.Stat(localPath) if err != nil { return fmt.Errorf("unable to stat local filepath %q: %w", localPath, err) } if !fileInfo.IsDir() { - return uploadFile(ctx, bucket, localPath, blobPath) + return uploadFile(ctx, bucket, localPath, blobPath, bucketConfig, awsTags) } // localPath is a directory. @@ -103,14 +109,14 @@ func UploadBlob(ctx context.Context, bucket *blob.Bucket, localPath, blobPath st for _, f := range files { if f.IsDir() { - err = UploadBlob(ctx, bucket, filepath.Join(localPath, f.Name()), blobPath+"/"+f.Name()) + err = UploadBlob(ctx, bucket, filepath.Join(localPath, f.Name()), blobPath+"/"+f.Name(), bucketConfig, awsTags) if err != nil { return err } } else { blobFilePath := filepath.Join(blobPath, filepath.Base(f.Name())) localFilePath := filepath.Join(localPath, f.Name()) - if err := uploadFile(ctx, bucket, localFilePath, blobFilePath); err != nil { + if err := uploadFile(ctx, bucket, localFilePath, blobFilePath, bucketConfig, awsTags); err != nil { return err } } @@ -149,12 +155,18 @@ func DownloadBlob(ctx context.Context, bucket *blob.Bucket, localDir, blobDir st return nil } -func uploadFile(ctx context.Context, bucket *blob.Bucket, localFilePath, blobFilePath string) error { +func uploadFile(ctx context.Context, bucket *blob.Bucket, localFilePath, blobFilePath string, bucketConfig *Config, awsTags map[string]string) error { errorF := func(err error) error { return fmt.Errorf("uploadFile(): unable to complete copying %q to remote storage %q: %w", localFilePath, blobFilePath, err) } - w, err := bucket.NewWriter(ctx, blobFilePath, nil) + glog.Infof("Adding Tag: %#v", map[string]string{"TKRetention": "works"}) + writerOptions := &blob.WriterOptions{ + Metadata: map[string]string{"TKRetention": "works"}, + // Tagging: string{"tkalbach=works&testing=true"}, + } + + w, err := bucket.NewWriter(ctx, blobFilePath, writerOptions) if err != nil { return errorF(fmt.Errorf("unable to open writer for bucket: %w", 
err)) } @@ -173,6 +185,31 @@ func uploadFile(ctx context.Context, bucket *blob.Bucket, localFilePath, blobFil return errorF(fmt.Errorf("failed to close Writer for bucket: %w", err)) } + // Adding Hack to get tags. + glog.Info("Adding tags to S3 object") + glog.Infof("Bucket URL: %#v", bucketConfig.bucketURL()) + // TODO: Only run if the bucket is s3. Determine condition for this. + // TODO: Check if Must can solve the issue if I pass the session through. + mySession := session.Must(session.NewSession()) + // TODO: Is region required. I should look how they do this with google blob Bucket. + s3Client := s3.New(mySession) + glog.Info("Creating Client", s3Client) + + S3B := S3Bucket{client: s3Client, bucket: bucketConfig.BucketName, key: bucketConfig.Prefix + blobFilePath} + glog.Info("Calling AddTags") + tags := []*s3.Tag{} + for k, v := range awsTags { + glog.Infof("Adding Tag: %#v", map[string]string{k: v}) + tags = append(tags, &s3.Tag{ + Key: aws.String(k), + Value: aws.String(v), + }) + } + err = S3B.AddTags(ctx, tags) + if err != nil { + return errorF(fmt.Errorf("unable to add tags to S3 object: %w", err)) + } + glog.Infof("uploadFile(localFilePath=%q, blobFilePath=%q)", localFilePath, blobFilePath) return nil } @@ -244,6 +281,7 @@ func createS3BucketSession(ctx context.Context, namespace string, sessionInfo *S return nil, nil } config := &aws.Config{} + glog.Infof("sessionInfo.Params: %#v", sessionInfo.Params) params, err := StructuredS3Params(sessionInfo.Params) if err != nil { return nil, err @@ -272,6 +310,7 @@ func createS3BucketSession(ctx context.Context, namespace string, sessionInfo *S config.Endpoint = aws.String(params.Endpoint) } + glog.Infof("config: %#v", config) sess, err := session.NewSession(config) if err != nil { return nil, fmt.Errorf("Failed to create object store session, %v", err) diff --git a/backend/src/v2/objectstore/object_store_bucket.go b/backend/src/v2/objectstore/object_store_bucket.go new file mode 100644 index 
00000000000..fe51f3387b9 --- /dev/null +++ b/backend/src/v2/objectstore/object_store_bucket.go @@ -0,0 +1,75 @@ +package objectstore + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/golang/glog" +) + +func createS3Client(region, accessKey, secretKey string) (*s3.S3, error) { + sess, err := session.NewSession(&aws.Config{ + Region: aws.String(region), + Credentials: credentials.NewStaticCredentials( + accessKey, secretKey, "", + ), + }) + if err != nil { + return nil, fmt.Errorf("failed to create AWS session: %w", err) + } + + s3Client := s3.New(sess) + return s3Client, nil +} + +type S3Bucket struct { + client *s3.S3 + bucket string + key string +} + +// AddTags is an additional method for S3Bucket. +func (b *S3Bucket) AddTags(ctx context.Context, tags []*s3.Tag) (err error) { + defer func() { + if err != nil { + // wrap error before returning + err = fmt.Errorf("failed to add tags to S3 bucket: %w", err) + } + }() + glog.Info("Bucket String: ", b.bucket) + bucket := aws.String(b.bucket) + glog.Info("Key String: ", b.key) + key := aws.String(b.key) + glog.Info("Tags: ", tags) + s3Tags := s3.Tagging{TagSet: tags} + + if b.client == nil { + return fmt.Errorf("S3 client is not initialized") + } + if b.bucket == "" { + return fmt.Errorf("bucket name is not set") + } + if b.key == "" { + return fmt.Errorf("object key is not set") + } + + glog.Info("Calling PutObjectTagging") + objectTaggingInput := &s3.PutObjectTaggingInput{ + Bucket: bucket, + Key: key, + Tagging: &s3Tags, + } + + glog.Info("Calling PutObjectTaggingWithContext") + // sess, err1 := createS3BucketSession(ctx, namespace, config.SessionInfo, k8sClient) + _, err = b.client.PutObjectTaggingWithContext(ctx, objectTaggingInput) + if err != nil { + glog.Errorf("failed to add tags to S3 bucket: %v", err) + return fmt.Errorf("failed to add tags to S3 
bucket: %w", err) + } + return err +}