Skip to content

Commit

Permalink
Migrate S3 to SDK v2 (#779)
Browse files Browse the repository at this point in the history
* Migrate/s3

* migrate: aws v2 s3

---------

Co-authored-by: James Kwon <96548424+hongil0316@users.noreply.github.com>
  • Loading branch information
james03160927 and james03160927 authored Nov 25, 2024
1 parent b6a4b37 commit c7caecb
Show file tree
Hide file tree
Showing 16 changed files with 503 additions and 432 deletions.
204 changes: 106 additions & 98 deletions aws/resources/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@ package resources

import (
"context"
goerr "errors"
"fmt"
"math"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/smithy-go"

"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/cloud-nuke/report"
Expand All @@ -26,17 +30,17 @@ func (sb S3Buckets) getS3BucketRegion(bucketName string) (string, error) {
Bucket: aws.String(bucketName),
}

result, err := sb.Client.GetBucketLocationWithContext(sb.Context, input)
result, err := sb.Client.GetBucketLocation(sb.Context, input)
if err != nil {
return "", err
}

if result.LocationConstraint == nil {
if result.LocationConstraint == "" {
// GetBucketLocation returns nil for us-east-1
// https://github.com/aws/aws-sdk-go/issues/1687
return "us-east-1", nil
}
return *result.LocationConstraint, nil
return string(result.LocationConstraint), nil
}

// getS3BucketTags returns S3 Bucket tags.
Expand All @@ -47,18 +51,18 @@ func (bucket *S3Buckets) getS3BucketTags(bucketName string) (map[string]string,

// Please note that svc argument should be created from a session object which is
// in the same region as the bucket or GetBucketTagging will fail.
result, err := bucket.Client.GetBucketTaggingWithContext(bucket.Context, input)
result, err := bucket.Client.GetBucketTagging(bucket.Context, input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case "NoSuchTagSet":
var apiErr *smithy.OperationError
if goerr.As(err, &apiErr) {
if strings.Contains(apiErr.Error(), "NoSuchTagSet: The TagSet does not exist") {
return nil, nil
}
return nil, err
}
return nil, err
}

return util.ConvertS3TagsToMap(result.TagSet), nil
return util.ConvertS3TypesTagsToMap(result.TagSet), nil
}

// S3Bucket - represents S3 bucket
Expand All @@ -73,7 +77,7 @@ type S3Bucket struct {

// getAllS3Buckets returns a map of per region AWS S3 buckets which were created before excludeAfter
func (sb S3Buckets) getAll(c context.Context, configObj config.Config) ([]*string, error) {
output, err := sb.Client.ListBucketsWithContext(sb.Context, &s3.ListBucketsInput{})
output, err := sb.Client.ListBuckets(sb.Context, &s3.ListBucketsInput{})
if err != nil {
return nil, errors.WithStackTrace(err)
}
Expand Down Expand Up @@ -106,14 +110,14 @@ func (sb S3Buckets) getAll(c context.Context, configObj config.Config) ([]*strin
}

// getBucketNamesPerRegions gets valid bucket names concurrently from list of target buckets
func (sb S3Buckets) getBucketNames(targetBuckets []*s3.Bucket, configObj config.Config) ([]*string, error) {
func (sb S3Buckets) getBucketNames(targetBuckets []types.Bucket, configObj config.Config) ([]*string, error) {
var bucketNames []*string
bucketCh := make(chan *S3Bucket, len(targetBuckets))
var wg sync.WaitGroup

for _, bucket := range targetBuckets {
wg.Add(1)
go func(bucket *s3.Bucket) {
go func(bucket types.Bucket) {
defer wg.Done()
sb.getBucketInfo(bucket, bucketCh, configObj)
}(bucket)
Expand Down Expand Up @@ -143,10 +147,10 @@ func (sb S3Buckets) getBucketNames(targetBuckets []*s3.Bucket, configObj config.
}

// getBucketInfo populates the local S3Bucket struct for the passed AWS bucket
func (sb S3Buckets) getBucketInfo(bucket *s3.Bucket, bucketCh chan<- *S3Bucket, configObj config.Config) {
func (sb S3Buckets) getBucketInfo(bucket types.Bucket, bucketCh chan<- *S3Bucket, configObj config.Config) {
var bucketData S3Bucket
bucketData.Name = aws.StringValue(bucket.Name)
bucketData.CreationDate = aws.TimeValue(bucket.CreationDate)
bucketData.Name = aws.ToString(bucket.Name)
bucketData.CreationDate = aws.ToTime(bucket.CreationDate)

bucketRegion, err := sb.getS3BucketRegion(bucketData.Name)
if err != nil {
Expand Down Expand Up @@ -194,93 +198,73 @@ func (sb S3Buckets) getBucketInfo(bucket *s3.Bucket, bucketCh chan<- *S3Bucket,
func (sb S3Buckets) emptyBucket(bucketName *string) error {
// Since the error may happen in the inner function handler for the pager, we need a function scoped variable that
// the inner function can set when there is an error.
var errOut error
pageId := 1

// As bucket versioning is managed separately and you can turn off versioning after the bucket is created,
// we need to check if there are any versions in the bucket regardless of the versioning status.
err := sb.Client.ListObjectVersionsPagesWithContext(
sb.Context,
&s3.ListObjectVersionsInput{
Bucket: bucketName,
MaxKeys: aws.Int64(int64(sb.MaxBatchSize())),
},
func(page *s3.ListObjectVersionsOutput, lastPage bool) (shouldContinue bool) {
logging.Debugf("Deleting page %d of object versions (%d objects) from bucket %s", pageId, len(page.Versions), aws.StringValue(bucketName))
if err := sb.deleteObjectVersions(bucketName, page.Versions); err != nil {
logging.Errorf("Error deleting objects versions for page %d from bucket %s: %s", pageId, aws.StringValue(bucketName), err)
errOut = err
return false
}
logging.Debugf("[OK] - deleted page %d of object versions (%d objects) from bucket %s", pageId, len(page.Versions), aws.StringValue(bucketName))

logging.Debugf("Deleting page %d of deletion markers (%d deletion markers) from bucket %s", pageId, len(page.DeleteMarkers), aws.StringValue(bucketName))
if err := sb.deleteDeletionMarkers(bucketName, page.DeleteMarkers); err != nil {
logging.Debugf("Error deleting deletion markers for page %d from bucket %s: %s", pageId, aws.StringValue(bucketName), err)
errOut = err
return false
}
logging.Debugf("[OK] - deleted page %d of deletion markers (%d deletion markers) from bucket %s", pageId, len(page.DeleteMarkers), aws.StringValue(bucketName))

pageId++
return true
},
)
outputs, err := sb.Client.ListObjectVersions(sb.Context, &s3.ListObjectVersionsInput{
Bucket: bucketName,
MaxKeys: aws.Int32(int32(sb.MaxBatchSize())),
})
if err != nil {
return err
}
if errOut != nil {
return errOut
return errors.WithStackTrace(err)
}
return nil

// Handle non versioned buckets.
err = sb.Client.ListObjectsV2PagesWithContext(
sb.Context,
&s3.ListObjectsV2Input{
Bucket: bucketName,
MaxKeys: aws.Int64(int64(sb.MaxBatchSize())),
},
func(page *s3.ListObjectsV2Output, lastPage bool) (shouldContinue bool) {
logging.Debugf("Deleting object page %d (%d objects) from bucket %s", pageId, len(page.Contents), aws.StringValue(bucketName))
if err := sb.deleteObjects(bucketName, page.Contents); err != nil {
logging.Errorf("Error deleting objects for page %d from bucket %s: %s", pageId, aws.StringValue(bucketName), err)
errOut = err
return false
}
logging.Debugf("[OK] - deleted object page %d (%d objects) from bucket %s", pageId, len(page.Contents), aws.StringValue(bucketName))
logging.Debugf("Deleting page %d of object versions (%d objects) from bucket %s", pageId, len(outputs.Versions), aws.ToString(bucketName))
if err := sb.deleteObjectVersions(bucketName, outputs.Versions); err != nil {
logging.Errorf("Error deleting objects versions for page %d from bucket %s: %s", pageId, aws.ToString(bucketName), err)
return errors.WithStackTrace(err)
}
logging.Debugf("[OK] - deleted page %d of object versions (%d objects) from bucket %s", pageId, len(outputs.Versions), aws.ToString(bucketName))

pageId++
return true
},
)
if err != nil {
return err
logging.Debugf("Deleting page %d of object delete markers (%d objects) from bucket %s", pageId, len(outputs.Versions), aws.ToString(bucketName))
if err := sb.deleteDeletionMarkers(bucketName, outputs.DeleteMarkers); err != nil {
logging.Errorf("Error deleting deletion markers for page %d from bucket %s: %s", pageId, aws.ToString(bucketName), err)
return errors.WithStackTrace(err)
}
if errOut != nil {
return errOut
logging.Debugf("[OK] - deleted page %d of deletion markers (%d deletion markers) from bucket %s", pageId, len(outputs.DeleteMarkers), aws.ToString(bucketName))

paginator := s3.NewListObjectsV2Paginator(sb.Client, &s3.ListObjectsV2Input{
Bucket: bucketName,
MaxKeys: aws.Int32(int32(sb.MaxBatchSize())),
})

for paginator.HasMorePages() {

page, err := paginator.NextPage(sb.Context)
if err != nil {
return errors.WithStackTrace(err)
}

logging.Debugf("Deleting object page %d (%d objects) from bucket %s", pageId, len(page.Contents), aws.ToString(bucketName))
if err := sb.deleteObjects(bucketName, page.Contents); err != nil {
logging.Errorf("Error deleting objects for page %d from bucket %s: %s", pageId, aws.ToString(bucketName), err)
return err
}
pageId++
}
return nil
}

// deleteObjects will delete the provided objects (unversioned) from the specified bucket.
func (sb S3Buckets) deleteObjects(bucketName *string, objects []*s3.Object) error {
func (sb S3Buckets) deleteObjects(bucketName *string, objects []types.Object) error {
if len(objects) == 0 {
logging.Debugf("No objects returned in page")
return nil
}

objectIdentifiers := []*s3.ObjectIdentifier{}
objectIdentifiers := []types.ObjectIdentifier{}
for _, obj := range objects {
objectIdentifiers = append(objectIdentifiers, &s3.ObjectIdentifier{
objectIdentifiers = append(objectIdentifiers, types.ObjectIdentifier{
Key: obj.Key,
})
}
_, err := sb.Client.DeleteObjectsWithContext(
_, err := sb.Client.DeleteObjects(
sb.Context,
&s3.DeleteObjectsInput{
Bucket: bucketName,
Delete: &s3.Delete{
Delete: &types.Delete{
Objects: objectIdentifiers,
Quiet: aws.Bool(false),
},
Expand All @@ -290,24 +274,24 @@ func (sb S3Buckets) deleteObjects(bucketName *string, objects []*s3.Object) erro
}

// deleteObjectVersions will delete the provided object versions from the specified bucket.
func (sb S3Buckets) deleteObjectVersions(bucketName *string, objectVersions []*s3.ObjectVersion) error {
func (sb S3Buckets) deleteObjectVersions(bucketName *string, objectVersions []types.ObjectVersion) error {
if len(objectVersions) == 0 {
logging.Debugf("No object versions returned in page")
return nil
}

objectIdentifiers := []*s3.ObjectIdentifier{}
objectIdentifiers := []types.ObjectIdentifier{}
for _, obj := range objectVersions {
objectIdentifiers = append(objectIdentifiers, &s3.ObjectIdentifier{
objectIdentifiers = append(objectIdentifiers, types.ObjectIdentifier{
Key: obj.Key,
VersionId: obj.VersionId,
})
}
_, err := sb.Client.DeleteObjectsWithContext(
_, err := sb.Client.DeleteObjects(
sb.Context,
&s3.DeleteObjectsInput{
Bucket: bucketName,
Delete: &s3.Delete{
Delete: &types.Delete{
Objects: objectIdentifiers,
Quiet: aws.Bool(false),
},
Expand All @@ -317,24 +301,24 @@ func (sb S3Buckets) deleteObjectVersions(bucketName *string, objectVersions []*s
}

// deleteDeletionMarkers will delete the provided deletion markers from the specified bucket.
func (sb S3Buckets) deleteDeletionMarkers(bucketName *string, objectDelMarkers []*s3.DeleteMarkerEntry) error {
func (sb S3Buckets) deleteDeletionMarkers(bucketName *string, objectDelMarkers []types.DeleteMarkerEntry) error {
if len(objectDelMarkers) == 0 {
logging.Debugf("No deletion markers returned in page")
return nil
}

objectIdentifiers := []*s3.ObjectIdentifier{}
objectIdentifiers := []types.ObjectIdentifier{}
for _, obj := range objectDelMarkers {
objectIdentifiers = append(objectIdentifiers, &s3.ObjectIdentifier{
objectIdentifiers = append(objectIdentifiers, types.ObjectIdentifier{
Key: obj.Key,
VersionId: obj.VersionId,
})
}
_, err := sb.Client.DeleteObjectsWithContext(
_, err := sb.Client.DeleteObjects(
sb.Context,
&s3.DeleteObjectsInput{
Bucket: bucketName,
Delete: &s3.Delete{
Delete: &types.Delete{
Objects: objectIdentifiers,
Quiet: aws.Bool(false),
},
Expand All @@ -349,18 +333,18 @@ func (sb S3Buckets) nukeAllS3BucketObjects(bucketName *string) error {
return fmt.Errorf("Invalid batchsize - %d - should be between %d and %d", sb.MaxBatchSize(), 1, 1000)
}

logging.Debugf("Emptying bucket %s", aws.StringValue(bucketName))
logging.Debugf("Emptying bucket %s", aws.ToString(bucketName))
if err := sb.emptyBucket(bucketName); err != nil {
return err
}
logging.Debugf("[OK] - successfully emptied bucket %s", aws.StringValue(bucketName))
logging.Debugf("[OK] - successfully emptied bucket %s", aws.ToString(bucketName))
return nil
}

// nukeEmptyS3Bucket deletes an empty S3 bucket
func (sb S3Buckets) nukeEmptyS3Bucket(bucketName *string, verifyBucketDeletion bool) error {

_, err := sb.Client.DeleteBucketWithContext(sb.Context, &s3.DeleteBucketInput{
_, err := sb.Client.DeleteBucket(sb.Context, &s3.DeleteBucketInput{
Bucket: bucketName,
})
if err != nil {
Expand All @@ -375,23 +359,47 @@ func (sb S3Buckets) nukeEmptyS3Bucket(bucketName *string, verifyBucketDeletion b
// such, we retry this routine up to 3 times for a total of 300 seconds.
const maxRetries = 3
for i := 0; i < maxRetries; i++ {
logging.Debugf("Waiting until bucket (%s) deletion is propagated (attempt %d / %d)", aws.StringValue(bucketName), i+1, maxRetries)
err = sb.Client.WaitUntilBucketNotExistsWithContext(sb.Context, &s3.HeadBucketInput{
Bucket: bucketName,
})
logging.Debugf("Waiting until bucket (%s) deletion is propagated (attempt %d / %d)", aws.ToString(bucketName), i+1, maxRetries)
err = waitForBucketDeletion(sb.Context, sb.Client, aws.ToString(bucketName))
// Exit early if no error
if err == nil {
logging.Debug("Successfully detected bucket deletion.")
return nil
}
logging.Debugf("Error waiting for bucket (%s) deletion propagation (attempt %d / %d)", aws.StringValue(bucketName), i+1, maxRetries)
logging.Debugf("Error waiting for bucket (%s) deletion propagation (attempt %d / %d)", aws.ToString(bucketName), i+1, maxRetries)
logging.Debugf("Underlying error was: %s", err)
}
return err
}

func waitForBucketDeletion(ctx context.Context, client S3API, bucketName string) error {
waiter := s3.NewBucketNotExistsWaiter(client)

for i := 0; i < maxRetries; i++ {
logging.Debugf("Waiting until bucket (%s) deletion is propagated (attempt %d / %d)", bucketName, i+1, maxRetries)

err := waiter.Wait(ctx, &s3.HeadBucketInput{
Bucket: aws.String(bucketName),
}, waitDuration)
if err == nil {
logging.Debugf("Successfully detected bucket deletion.")
return nil
}
logging.Debugf("Waiting until bucket erorr (%v)", err)

if i == maxRetries-1 {
return fmt.Errorf("failed to confirm bucket deletion after %d attempts: %w", maxRetries, err)
}

logging.Debugf("Error waiting for bucket (%s) deletion propagation (attempt %d / %d)", bucketName, i+1, maxRetries)
logging.Debugf("Underlying error was: %s", err)
}

return fmt.Errorf("unexpected error: reached end of retry loop")
}

func (sb S3Buckets) nukeS3BucketPolicy(bucketName *string) error {
_, err := sb.Client.DeleteBucketPolicyWithContext(
_, err := sb.Client.DeleteBucketPolicy(
sb.Context,
&s3.DeleteBucketPolicyInput{
Bucket: aws.String(*bucketName),
Expand Down Expand Up @@ -443,7 +451,7 @@ func (sb S3Buckets) nukeAll(bucketNames []*string) (delCount int, err error) {

// Record status of this resource
e := report.Entry{
Identifier: aws.StringValue(bucketName),
Identifier: aws.ToString(bucketName),
ResourceType: "S3 Bucket",
Error: err,
}
Expand Down
Loading

0 comments on commit c7caecb

Please sign in to comment.