Skip to content

Commit

Permalink
POC Hack to get s3 tags working for aws backend (breaks gcp, and minio I expect)
Browse files Browse the repository at this point in the history
  • Loading branch information
boarder7395 committed Oct 16, 2024
1 parent a80b65a commit 0b6a982
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 18 deletions.
22 changes: 19 additions & 3 deletions backend/src/v2/component/launcher_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func executeV2(
return nil, nil, err
}
// TODO(Bobgy): should we log metadata per each artifact, or batched after uploading all artifacts.
outputArtifacts, err := uploadOutputArtifacts(ctx, executorInput, executorOutput, uploadOutputArtifactsOptions{
outputArtifacts, err := uploadOutputArtifacts(ctx, executorInput, executorOutput, bucketConfig, uploadOutputArtifactsOptions{
bucketConfig: bucketConfig,
bucket: bucket,
metadataClient: metadataClient,
Expand Down Expand Up @@ -415,7 +415,7 @@ type uploadOutputArtifactsOptions struct {
metadataClient metadata.ClientInterface
}

func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.ExecutorInput, executorOutput *pipelinespec.ExecutorOutput, opts uploadOutputArtifactsOptions) ([]*metadata.OutputArtifact, error) {
func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.ExecutorInput, executorOutput *pipelinespec.ExecutorOutput, bucketConfig *objectstore.Config, opts uploadOutputArtifactsOptions) ([]*metadata.OutputArtifact, error) {
// Register artifacts with MLMD.
outputArtifacts := make([]*metadata.OutputArtifact, 0, len(executorInput.GetOutputs().GetArtifacts()))
for name, artifactList := range executorInput.GetOutputs().GetArtifacts() {
Expand All @@ -430,6 +430,22 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec
mergeRuntimeArtifacts(list.Artifacts[0], outputArtifact)
}

// TODO: Getting aws tags from output metadata. Is there a better way to convert map[string]interface{} to map[string]string?
outputMetadata := outputArtifact.GetMetadata()
glog.Infof("Artifact metadata: %#v", outputMetadata)
awsTags, ok := outputMetadata.AsMap()["aws-tags"]
if !ok {
glog.Warningf("Output Artifact %q does not have aws-tags metadata", name)
}
tags := make(map[string]string)
for k, v := range awsTags.(map[string]interface{}) {
if strValue, ok := v.(string); ok {
tags[k] = strValue
} else {
glog.Warningf("Output Artifact %q aws-tags metadata is not a map[string]string: %#v", name, awsTags)
}
}

// Upload artifacts from local path to remote storages.
localDir, err := localPathForURI(outputArtifact.Uri)
if err != nil {
Expand All @@ -439,7 +455,7 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec
if err != nil {
return nil, fmt.Errorf("failed to upload output artifact %q: %w", name, err)
}
if err := objectstore.UploadBlob(ctx, opts.bucket, localDir, blobKey); err != nil {
if err := objectstore.UploadBlob(ctx, opts.bucket, localDir, blobKey, bucketConfig, tags); err != nil {
// We allow components to not produce output files
if errors.Is(err, os.ErrNotExist) {
glog.Warningf("Local filepath %q does not exist", localDir)
Expand Down
3 changes: 2 additions & 1 deletion backend/src/v2/config/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ package config

import (
"fmt"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
"strconv"
"strings"

"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
)

type S3ProviderConfig struct {
Expand Down
69 changes: 69 additions & 0 deletions backend/src/v2/driver/driver_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package driver

import (
"context"
"encoding/json"
"fmt"

"github.com/golang/glog"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/backend/src/v2/metadata"
)

// ArtifactReader abstracts the single MLMD query this package needs:
// fetching the output artifacts of a completed execution, keyed by
// artifact name. Narrowing the dependency to one method keeps callers
// decoupled from the full metadata client.
type ArtifactReader interface {
	GetOutputArtifactsByExecutionId(ctx context.Context, executionId int64) (map[string]*metadata.OutputArtifact, error)
}

// resolveUpstreamArtifacts resolves the runtime artifact produced by upstream
// task taskName under outputArtifactKey.
//
// If the upstream task is a DAG execution, the DAG's "output_artifacts"
// custom property (a JSON map of DAG output key -> artifact selectors) is
// followed recursively down to the container execution that actually produced
// the artifact. Otherwise (base case) the artifact is looked up directly in
// MLMD via the supplied ArtifactReader.
//
// Returns an error when the task, the output key, or the artifact cannot be
// found, or when the recorded output_artifacts property fails to deserialize.
func resolveUpstreamArtifacts(ctx context.Context, tasks map[string]*metadata.Execution, taskName string, outputArtifactKey string, mlmd ArtifactReader) (runtimeArtifact *pipelinespec.RuntimeArtifact, err error) {
	glog.V(4).Info("taskName: ", taskName)
	glog.V(4).Info("outputArtifactKey: ", outputArtifactKey)
	// Guard the map lookup: a missing task previously caused a nil-pointer
	// panic on GetExecution().
	upstreamTask, ok := tasks[taskName]
	if !ok || upstreamTask == nil {
		return nil, fmt.Errorf("cannot find upstream task %q", taskName)
	}
	if *upstreamTask.GetExecution().Type != "system.DAGExecution" {
		// Base case: the producer is a container execution, so its output
		// artifacts are recorded directly in MLMD.
		outputs, err := mlmd.GetOutputArtifactsByExecutionId(ctx, upstreamTask.GetID())
		if err != nil {
			return nil, err
		}
		artifact, ok := outputs[outputArtifactKey]
		if !ok {
			return nil, fmt.Errorf(
				"cannot find output artifact key %q in producer task %q",
				outputArtifactKey,
				taskName,
			)
		}
		return artifact.ToRuntimeArtifact()
	}

	// Recursive case: the producer is a sub-DAG; follow its recorded output
	// mapping one level down.
	outputArtifactsCustomProperty, ok := upstreamTask.GetExecution().GetCustomProperties()["output_artifacts"]
	if !ok {
		return nil, fmt.Errorf("cannot find output_artifacts")
	}
	glog.V(4).Infof("outputArtifactsCustomProperty: %#v", outputArtifactsCustomProperty)
	glog.V(4).Info("outputArtifactsCustomProperty String: ", outputArtifactsCustomProperty.GetStringValue())
	var outputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec
	if err := json.Unmarshal([]byte(outputArtifactsCustomProperty.GetStringValue()), &outputArtifacts); err != nil {
		return nil, err
	}
	glog.V(4).Info("Deserialized outputArtifactsMap: ", outputArtifacts)
	// Guard the key lookup: previously a missing key produced a silent nil
	// selector list and a recursion with an empty subtask name.
	spec, ok := outputArtifacts[outputArtifactKey]
	if !ok {
		return nil, fmt.Errorf("cannot find output artifact key %q in DAG task %q", outputArtifactKey, taskName)
	}
	selectors := spec.GetArtifactSelectors()
	if len(selectors) == 0 {
		return nil, fmt.Errorf("output artifact key %q in DAG task %q has no artifact selectors", outputArtifactKey, taskName)
	}
	// NOTE: when multiple selectors are present, the last one wins — this
	// preserves the original iterate-and-overwrite behavior.
	var subTaskName string
	for _, selector := range selectors {
		subTaskName = selector.ProducerSubtask
		outputArtifactKey = selector.OutputArtifactKey
		glog.V(4).Infof("ProducerSubtask: %v", selector.ProducerSubtask)
		glog.V(4).Infof("OutputArtifactKey: %v", selector.OutputArtifactKey)
	}
	resolved, err := resolveUpstreamArtifacts(ctx, tasks, subTaskName, outputArtifactKey, mlmd)
	if err != nil {
		return nil, err
	}
	glog.V(4).Infof("resolved upstream artifact: %#v", resolved)
	return resolved, nil
}
3 changes: 2 additions & 1 deletion backend/src/v2/objectstore/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ package objectstore
import (
"encoding/json"
"fmt"
"github.com/golang/glog"
"os"
"path"
"regexp"
"strconv"
"strings"

"github.com/golang/glog"
)

// The endpoint uses Kubernetes service DNS name with namespace:
Expand Down
65 changes: 52 additions & 13 deletions backend/src/v2/objectstore/object_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,26 @@ package objectstore
import (
"context"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/golang/glog"
"gocloud.dev/blob"
"gocloud.dev/blob/gcsblob"
_ "gocloud.dev/blob/gcsblob"
"gocloud.dev/blob/s3blob"
"gocloud.dev/gcp"
"golang.org/x/oauth2/google"
"io"
"io/ioutil"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"os"
"path/filepath"
"regexp"
"strings"
)

func OpenBucket(ctx context.Context, k8sClient kubernetes.Interface, namespace string, config *Config) (bucket *blob.Bucket, err error) {
Expand Down Expand Up @@ -81,18 +83,22 @@ func OpenBucket(ctx context.Context, k8sClient kubernetes.Interface, namespace s
bucketURL = strings.Replace(bucketURL, "minio://", "s3://", 1)
}

bucket, err = blob.OpenBucket(ctx, bucketURL)
if err != nil {
return nil, nil, err

Check failure on line 88 in backend/src/v2/objectstore/object_store.go

View workflow job for this annotation

GitHub Actions / run-go-unittests

too many return values
}
// When no provider config is provided, or "FromEnv" is specified, use default credentials from the environment
return blob.OpenBucket(ctx, bucketURL)
return bucket, nil, nil

Check failure on line 91 in backend/src/v2/objectstore/object_store.go

View workflow job for this annotation

GitHub Actions / run-go-unittests

too many return values
}

func UploadBlob(ctx context.Context, bucket *blob.Bucket, localPath, blobPath string) error {
func UploadBlob(ctx context.Context, bucket *blob.Bucket, localPath, blobPath string, bucketConfig *Config, awsTags map[string]string) error {
fileInfo, err := os.Stat(localPath)
if err != nil {
return fmt.Errorf("unable to stat local filepath %q: %w", localPath, err)
}

if !fileInfo.IsDir() {
return uploadFile(ctx, bucket, localPath, blobPath)
return uploadFile(ctx, bucket, localPath, blobPath, bucketConfig, awsTags)
}

// localPath is a directory.
Expand All @@ -103,14 +109,14 @@ func UploadBlob(ctx context.Context, bucket *blob.Bucket, localPath, blobPath st

for _, f := range files {
if f.IsDir() {
err = UploadBlob(ctx, bucket, filepath.Join(localPath, f.Name()), blobPath+"/"+f.Name())
err = UploadBlob(ctx, bucket, filepath.Join(localPath, f.Name()), blobPath+"/"+f.Name(), bucketConfig, awsTags)
if err != nil {
return err
}
} else {
blobFilePath := filepath.Join(blobPath, filepath.Base(f.Name()))
localFilePath := filepath.Join(localPath, f.Name())
if err := uploadFile(ctx, bucket, localFilePath, blobFilePath); err != nil {
if err := uploadFile(ctx, bucket, localFilePath, blobFilePath, bucketConfig, awsTags); err != nil {
return err
}
}
Expand Down Expand Up @@ -149,12 +155,18 @@ func DownloadBlob(ctx context.Context, bucket *blob.Bucket, localDir, blobDir st
return nil
}

func uploadFile(ctx context.Context, bucket *blob.Bucket, localFilePath, blobFilePath string) error {
func uploadFile(ctx context.Context, bucket *blob.Bucket, localFilePath, blobFilePath string, bucketConfig *Config, awsTags map[string]string) error {
errorF := func(err error) error {
return fmt.Errorf("uploadFile(): unable to complete copying %q to remote storage %q: %w", localFilePath, blobFilePath, err)
}

w, err := bucket.NewWriter(ctx, blobFilePath, nil)
glog.Infof("Adding Tag: %#v", map[string]string{"TKRetention": "works"})
writerOptions := &blob.WriterOptions{
Metadata: map[string]string{"TKRetention": "works"},
// Tagging: string{"tkalbach=works&testing=true"},
}

w, err := bucket.NewWriter(ctx, blobFilePath, writerOptions)
if err != nil {
return errorF(fmt.Errorf("unable to open writer for bucket: %w", err))
}
Expand All @@ -173,6 +185,31 @@ func uploadFile(ctx context.Context, bucket *blob.Bucket, localFilePath, blobFil
return errorF(fmt.Errorf("failed to close Writer for bucket: %w", err))
}

// Adding Hack to get tags.
glog.Info("Adding tags to S3 object")
glog.Infof("Bucket URL: %#v", bucketConfig.bucketURL())
// TODO: Only run if the bucket is s3. Determine condition for this.
// TODO: Check if Must can solve the issue if I pass the session through.
mySession := session.Must(session.NewSession())
// TODO: Is region required. I should look how they do this with google blob Bucket.
s3Client := s3.New(mySession)
glog.Info("Creating Client", s3Client)

S3B := S3Bucket{client: s3Client, bucket: bucketConfig.BucketName, key: bucketConfig.Prefix + blobFilePath}
glog.Info("Calling AddTags")
tags := []*s3.Tag{}
for k, v := range awsTags {
glog.Infof("Adding Tag: %#v", map[string]string{k: v})
tags = append(tags, &s3.Tag{
Key: aws.String(k),
Value: aws.String(v),
})
}
err = S3B.AddTags(ctx, tags)
if err != nil {
return errorF(fmt.Errorf("unable to add tags to S3 object: %w", err))
}

glog.Infof("uploadFile(localFilePath=%q, blobFilePath=%q)", localFilePath, blobFilePath)
return nil
}
Expand Down Expand Up @@ -244,6 +281,7 @@ func createS3BucketSession(ctx context.Context, namespace string, sessionInfo *S
return nil, nil
}
config := &aws.Config{}
glog.Infof("sessionInfo.Params: %#v", sessionInfo.Params)
params, err := StructuredS3Params(sessionInfo.Params)
if err != nil {
return nil, err
Expand Down Expand Up @@ -272,6 +310,7 @@ func createS3BucketSession(ctx context.Context, namespace string, sessionInfo *S
config.Endpoint = aws.String(params.Endpoint)
}

glog.Infof("config: %#v", config)
sess, err := session.NewSession(config)
if err != nil {
return nil, fmt.Errorf("Failed to create object store session, %v", err)
Expand Down
75 changes: 75 additions & 0 deletions backend/src/v2/objectstore/object_store_bucket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package objectstore

import (
"context"
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/golang/glog"
)

// createS3Client builds an S3 service client for the given region,
// authenticating with static credentials (access key and secret key; no
// session token).
func createS3Client(region, accessKey, secretKey string) (*s3.S3, error) {
	cfg := &aws.Config{
		Region: aws.String(region),
		Credentials: credentials.NewStaticCredentials(
			accessKey, secretKey, "",
		),
	}
	sess, err := session.NewSession(cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create AWS session: %w", err)
	}
	return s3.New(sess), nil
}

// S3Bucket bundles an S3 service client with the bucket name and object key
// it operates on, so that object-level operations (e.g. tagging) can be
// expressed as methods on one value.
type S3Bucket struct {
	client *s3.S3 // S3 API client used for the calls
	bucket string // target bucket name
	key    string // object key within the bucket
}

// AddTags attaches the given tags to the S3 object identified by b's bucket
// and key via PutObjectTagging (which replaces the object's entire tag set).
//
// The receiver is validated before use; any error — validation or API — is
// wrapped exactly once with context. (The original version wrapped the API
// error twice: once inline and once in the deferred wrapper, and also both
// logged and returned the same error.)
func (b *S3Bucket) AddTags(ctx context.Context, tags []*s3.Tag) (err error) {
	defer func() {
		if err != nil {
			// Single point of wrapping for every return path.
			err = fmt.Errorf("failed to add tags to S3 bucket: %w", err)
		}
	}()

	// Validate the receiver before touching any of its fields.
	if b.client == nil {
		return fmt.Errorf("S3 client is not initialized")
	}
	if b.bucket == "" {
		return fmt.Errorf("bucket name is not set")
	}
	if b.key == "" {
		return fmt.Errorf("object key is not set")
	}

	glog.V(4).Infof("adding %d tag(s) to s3 object %s/%s", len(tags), b.bucket, b.key)
	input := &s3.PutObjectTaggingInput{
		Bucket:  aws.String(b.bucket),
		Key:     aws.String(b.key),
		Tagging: &s3.Tagging{TagSet: tags},
	}
	// Error is returned, not logged here — the caller decides how to report it.
	_, err = b.client.PutObjectTaggingWithContext(ctx, input)
	return err
}

0 comments on commit 0b6a982

Please sign in to comment.