diff --git a/docs/options.md b/docs/options.md index 88d1deed4c..707dc03e84 100644 --- a/docs/options.md +++ b/docs/options.md @@ -12,4 +12,4 @@ There are a couple of driver options that can be passed as arguments when starti | logging-format | json | text | Sets the log format. Permitted formats: text, json| | user-agent-extra | csi-ebs | helm | Extra string appended to user agent| | enable-otel-tracing | true | false | If set to true, the driver will enable opentelemetry tracing. Might need [additional env variables](https://opentelemetry.io/docs/specs/otel/configuration/sdk-environment-variables/#general-sdk-configuration) to export the traces to the right collector| -| batching | true | true | If set to true, the driver will enable batching of API calls. This is especially helpful for improving performance in workloads that are sensitive to EC2 rate limits| +| batching | true | true | If set to true, the driver will enable batching of API calls. This is especially helpful for improving performance in workloads that are sensitive to EC2 rate limits at the cost of a small increase to worst-case latency| diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index 56905f587e..5c7c361e14 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -33,6 +33,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/batcher" dm "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud/devicemanager" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/util" "k8s.io/apimachinery/pkg/util/wait" @@ -137,6 +138,12 @@ const ( AwsEbsDriverTagKey = "ebs.csi.aws.com/cluster" ) +// Batcher +const ( + volumeIDBatcher batcherType = iota + volumeTagBatcher +) + var ( // ErrMultiDisks is an error that is returned when multiple // disks are found with the same volume name. 
@@ -235,21 +242,41 @@ type ec2ListSnapshotsResponse struct { NextToken *string } +// batcherType is an enum representing the types of batchers available. +type batcherType int + +// batcherManager maintains a collection of batchers for different types of tasks. +type batcherManager struct { + batchers map[batcherType]*batcher.Batcher[string, *ec2.Volume] +} + type cloud struct { region string ec2 ec2iface.EC2API dm dm.DeviceManager + bm *batcherManager } var _ Cloud = &cloud{} // NewCloud returns a new instance of AWS cloud // It panics if session is invalid -func NewCloud(region string, awsSdkDebugLog bool, userAgentExtra string) (Cloud, error) { - return newEC2Cloud(region, awsSdkDebugLog, userAgentExtra) +func NewCloud(region string, awsSdkDebugLog bool, userAgentExtra string, batching bool) (Cloud, error) { + c := newEC2Cloud(region, awsSdkDebugLog, userAgentExtra) + + if batching { + klog.V(4).InfoS("NewCloud: batching enabled") + cloudInstance, ok := c.(*cloud) + if !ok { + return nil, fmt.Errorf("expected *cloud type but got %T", c) + } + cloudInstance.bm = newBatcherManager(cloudInstance.ec2) + } + + return c, nil } -func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string) (Cloud, error) { +func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string) Cloud { awsConfig := &aws.Config{ Region: aws.String(region), CredentialsChainVerboseErrors: aws.Bool(true), @@ -296,7 +323,135 @@ func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string) (Clo region: region, dm: dm.NewDeviceManager(), ec2: svc, - }, nil + } +} + +// newBatcherManager initializes a new instance of batcherManager. 
+func newBatcherManager(svc ec2iface.EC2API) *batcherManager { + return &batcherManager{ + batchers: map[batcherType]*batcher.Batcher[string, *ec2.Volume]{ + volumeIDBatcher: batcher.New(500, 1*time.Second, func(ids []string) (map[string]*ec2.Volume, error) { + return execBatchDescribeVolumes(svc, ids, volumeIDBatcher) + }), + volumeTagBatcher: batcher.New(500, 1*time.Second, func(names []string) (map[string]*ec2.Volume, error) { + return execBatchDescribeVolumes(svc, names, volumeTagBatcher) + }), + }, + } +} + +// getBatcher fetches a specific type of batcher from the batcherManager. +func (bm *batcherManager) getBatcher(b batcherType) *batcher.Batcher[string, *ec2.Volume] { + return bm.batchers[b] +} + +// execBatchDescribeVolumes executes a batched DescribeVolumes API call depending on the type of batcher. +func execBatchDescribeVolumes(svc ec2iface.EC2API, input []string, bType batcherType) (map[string]*ec2.Volume, error) { + var request *ec2.DescribeVolumesInput + + switch bType { + case volumeIDBatcher: + klog.V(7).InfoS("execBatchDescribeVolumes", "volumeIds", input) + request = &ec2.DescribeVolumesInput{ + VolumeIds: aws.StringSlice(input), + } + + case volumeTagBatcher: + klog.V(7).InfoS("execBatchDescribeVolumes", "names", input) + filters := []*ec2.Filter{ + { + Name: aws.String("tag:" + VolumeNameTagKey), + Values: aws.StringSlice(input), + }, + } + request = &ec2.DescribeVolumesInput{ + Filters: filters, + } + + default: + return nil, fmt.Errorf("execBatchDescribeVolumes: unsupported request type") + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + resp, err := describeVolumes(ctx, svc, request) + if err != nil { + return nil, err + } + + result := make(map[string]*ec2.Volume) + + for _, volume := range resp { + key, err := extractVolumeKey(volume, bType) + if err != nil { + klog.Warningf("execBatchDescribeVolumes: skipping volume: %v, reason: %v", volume, err) + continue + } + result[key] = volume + } + + 
klog.V(7).InfoS("execBatchDescribeVolumes: success", "result", result) + return result, nil +} + +// batchDescribeVolumes processes a DescribeVolumes request. Depending on the request, +// it determines the appropriate batcher to use, queues the task, and waits for the result. A volume missing from the batch result is reported as an error wrapping ErrNotFound. +func (c *cloud) batchDescribeVolumes(request *ec2.DescribeVolumesInput) (*ec2.Volume, error) { + var bType batcherType + var task string + + switch { + case len(request.VolumeIds) == 1 && request.VolumeIds[0] != nil: + bType = volumeIDBatcher + task = *request.VolumeIds[0] + + case len(request.Filters) == 1 && request.Filters[0] != nil && request.Filters[0].Name != nil && *request.Filters[0].Name == "tag:"+VolumeNameTagKey && len(request.Filters[0].Values) == 1 && request.Filters[0].Values[0] != nil: // every pointer is nil-checked before it is dereferenced here and below + bType = volumeTagBatcher + task = *request.Filters[0].Values[0] + + default: + return nil, fmt.Errorf("batchDescribeVolumes: invalid request, request: %v", request) + } + + ch := make(chan batcher.BatchResult[*ec2.Volume]) + + b := c.bm.getBatcher(bType) + b.AddTask(task, ch) + + r := <-ch + + if r.Err != nil { + return nil, r.Err + } + if r.Result == nil { + return nil, fmt.Errorf("batchDescribeVolumes: no volume found %s: %w", task, ErrNotFound) // wrapped so errors.Is(err, ErrNotFound) matches, as on the unbatched getVolume path + } + return r.Result, nil +} + +// extractVolumeKey retrieves the key associated with a given volume based on the batcher type. +// For the volumeIDBatcher type, it returns the volume's ID. +// For other types, it searches for the VolumeNameTagKey within the volume's tags. 
+func extractVolumeKey(v *ec2.Volume, bType batcherType) (string, error) { + if bType == volumeIDBatcher { + if v.VolumeId == nil { + return "", errors.New("extractVolumeKey: missing volume ID") + } + return *v.VolumeId, nil + } + for _, tag := range v.Tags { + if tag == nil || tag.Key == nil || tag.Value == nil { // must precede any *tag.Key use: InfoS arguments are evaluated even when V(7) is disabled + klog.V(7).InfoS("extractVolumeKey: skipping volume due to missing tag", "volume", v, "tag", tag) + continue + } + klog.V(7).InfoS("extractVolumeKey: processing tag", "volume", v, "*tag.Key", *tag.Key, "VolumeNameTagKey", VolumeNameTagKey) + if *tag.Key == VolumeNameTagKey { + klog.V(7).InfoS("extractVolumeKey: found volume name tag", "volume", v, "tag", tag) + return *tag.Value, nil + } + } + return "", errors.New("extractVolumeKey: missing VolumeNameTagKey in volume tags") } func (c *cloud) CreateDisk(ctx context.Context, volumeName string, diskOptions *DiskOptions) (*Disk, error) { @@ -704,7 +859,7 @@ func (c *cloud) WaitForAttachmentState(ctx context.Context, volumeID, expectedSt return true, nil } // continue waiting - klog.V(4).InfoS("Waiting for volume state", "volumeID", volumeID, "actual", attachmentState, "desired", expectedState) + klog.InfoS("Waiting for volume state", "volumeID", volumeID, "actual", attachmentState, "desired", expectedState) return false, nil } @@ -929,11 +1084,11 @@ func (c *cloud) EnableFastSnapshotRestores(ctx context.Context, availabilityZone return response, nil } -func (c *cloud) getVolume(ctx context.Context, request *ec2.DescribeVolumesInput) (*ec2.Volume, error) { +func describeVolumes(ctx context.Context, svc ec2iface.EC2API, request *ec2.DescribeVolumesInput) ([]*ec2.Volume, error) { var volumes []*ec2.Volume var nextToken *string for { - response, err := c.ec2.DescribeVolumesWithContext(ctx, request) + response, err := svc.DescribeVolumesWithContext(ctx, request) if err != nil { return nil, err } @@ -944,14 +1099,25 @@ func (c *cloud) getVolume(ctx context.Context, request *ec2.DescribeVolumesInput } request.NextToken = 
nextToken } + return volumes, nil +} - if l := len(volumes); l > 1 { - return nil, ErrMultiDisks - } else if l < 1 { - return nil, ErrNotFound - } +func (c *cloud) getVolume(ctx context.Context, request *ec2.DescribeVolumesInput) (*ec2.Volume, error) { + if c.bm == nil { + volumes, err := describeVolumes(ctx, c.ec2, request) + if err != nil { + return nil, err + } - return volumes[0], nil + if l := len(volumes); l > 1 { + return nil, ErrMultiDisks + } else if l < 1 { + return nil, ErrNotFound + } + return volumes[0], nil + } else { + return c.batchDescribeVolumes(request) + } } func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance, error) { diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index 3a2c7ad6e0..806ee9845e 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -23,6 +23,7 @@ import ( "reflect" "sort" "strings" + "sync" "testing" "github.com/aws/aws-sdk-go/aws" @@ -44,6 +45,208 @@ const ( defaultPath = "/dev/xvdaa" ) +func generateVolumes(volIdCount, volTagCount int) []*ec2.Volume { + volumes := make([]*ec2.Volume, 0, volIdCount+volTagCount) + + for i := 0; i < volIdCount; i++ { + volumeID := fmt.Sprintf("vol-%d", i) + volumes = append(volumes, &ec2.Volume{VolumeId: aws.String(volumeID)}) + } + + for i := 0; i < volTagCount; i++ { + volumeName := fmt.Sprintf("vol-name-%d", i) + volumes = append(volumes, &ec2.Volume{Tags: []*ec2.Tag{{Key: aws.String(VolumeNameTagKey), Value: aws.String(volumeName)}}}) + } + + return volumes +} + +func extractVolumeIdentifiers(volumes []*ec2.Volume) (volumeIDs []string, volumeNames []string) { + for _, volume := range volumes { + if volume.VolumeId != nil { + volumeIDs = append(volumeIDs, *volume.VolumeId) + } + for _, tag := range volume.Tags { + if tag.Key != nil && *tag.Key == VolumeNameTagKey && tag.Value != nil { + volumeNames = append(volumeNames, *tag.Value) + } + } + } + return volumeIDs, volumeNames +} + +func TestBatchDescribeVolumes(t *testing.T) { + testCases 
:= []struct { + name string + volumes []*ec2.Volume + expErr error + mockFunc func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) + }{ + { + name: "TestBatchDescribeVolumes: volume by ID", + volumes: generateVolumes(10, 0), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + volumeOutput := &ec2.DescribeVolumesOutput{Volumes: volumes} + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(volumeOutput, expErr).Times(1) + }, + expErr: nil, + }, + { + name: "TestBatchDescribeVolumes: volume by tag", + volumes: generateVolumes(0, 10), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + volumeOutput := &ec2.DescribeVolumesOutput{Volumes: volumes} + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(volumeOutput, expErr).Times(1) + }, + expErr: nil, + }, + { + name: "TestBatchDescribeVolumes: volume by ID and tag", + volumes: generateVolumes(10, 10), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + volumeOutput := &ec2.DescribeVolumesOutput{Volumes: volumes} + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(volumeOutput, expErr).Times(2) + }, + expErr: nil, + }, + { + name: "TestBatchDescribeVolumes: max capacity", + volumes: generateVolumes(500, 0), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + volumeOutput := &ec2.DescribeVolumesOutput{Volumes: volumes} + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(volumeOutput, expErr).Times(1) + }, + expErr: nil, + }, + { + name: "TestBatchDescribeVolumes: capacity exceeded", + volumes: generateVolumes(550, 0), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + volumeOutput := &ec2.DescribeVolumesOutput{Volumes: volumes} + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(volumeOutput, expErr).Times(2) + }, + expErr: nil, + }, + { + 
name: "TestBatchDescribeVolumes: EC2 API generic error", + volumes: generateVolumes(4, 0), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(nil, expErr).Times(1) + }, + expErr: fmt.Errorf("Generic EC2 API error"), + }, + { + name: "TestBatchDescribeVolumes: volume not found", + volumes: generateVolumes(1, 0), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(nil, expErr).Times(1) + }, + expErr: fmt.Errorf("volume not found"), + }, + { + name: "TestBatchDescribeVolumes: invalid tag", + volumes: []*ec2.Volume{ + { + Tags: []*ec2.Tag{ + {Key: aws.String("InvalidKey"), Value: aws.String("InvalidValue")}, + }, + }, + }, + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + + volumeOutput := &ec2.DescribeVolumesOutput{Volumes: volumes} + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(volumeOutput, expErr).Times(0) + }, + expErr: fmt.Errorf("invalid tag"), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockEC2 := NewMockEC2API(mockCtrl) + c := newCloud(mockEC2) + cloudInstance := c.(*cloud) + cloudInstance.bm = newBatcherManager(cloudInstance.ec2) + + tc.mockFunc(mockEC2, tc.expErr, tc.volumes) + volumeIDs, volumeNames := extractVolumeIdentifiers(tc.volumes) + executeDescribeVolumesTest(t, cloudInstance, volumeIDs, volumeNames, tc.expErr) + }) + } +} +func executeDescribeVolumesTest(t *testing.T, c *cloud, volumeIDs, volumeNames []string, expErr error) { + var wg sync.WaitGroup + + getRequestForID := func(id string) *ec2.DescribeVolumesInput { + return &ec2.DescribeVolumesInput{VolumeIds: []*string{&id}} + } + + getRequestForTag := func(volName string) 
*ec2.DescribeVolumesInput { + return &ec2.DescribeVolumesInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("tag:" + VolumeNameTagKey), + Values: []*string{&volName}, + }, + }, + } + } + + requests := make([]*ec2.DescribeVolumesInput, 0, len(volumeIDs)+len(volumeNames)) + for _, volumeID := range volumeIDs { + requests = append(requests, getRequestForID(volumeID)) + } + for _, volumeName := range volumeNames { + requests = append(requests, getRequestForTag(volumeName)) + } + + r := make([]chan *ec2.Volume, len(requests)) + e := make([]chan error, len(requests)) + + for i, request := range requests { + wg.Add(1) + r[i] = make(chan *ec2.Volume, 1) + e[i] = make(chan error, 1) + + go func(req *ec2.DescribeVolumesInput, resultCh chan *ec2.Volume, errCh chan error) { + defer wg.Done() + volume, err := c.batchDescribeVolumes(req) + if err != nil { + errCh <- err + return + } + resultCh <- volume + // passing `request` as a parameter to create a copy + // TODO remove after https://github.com/golang/go/discussions/56010 is implemented + }(request, r[i], e[i]) + } + + wg.Wait() + + for i := range requests { + select { + case result := <-r[i]: + if result == nil { + t.Errorf("Received nil result for a request") + } + case err := <-e[i]: + if expErr == nil { + t.Errorf("Error while processing request: %v", err) + } + if !errors.Is(err, expErr) { + t.Errorf("Expected error %v, but got %v", expErr, err) + } + default: + t.Errorf("Did not receive a result or an error for a request") + } + } +} + func TestCreateDisk(t *testing.T) { testCases := []struct { name string @@ -769,9 +972,9 @@ func TestAttachDisk(t *testing.T) { attachRequest := createAttachRequest(volumeID, nodeID, path) gomock.InOrder( - mockEC2.EXPECT().DescribeInstancesWithContext(ctx, instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), - mockEC2.EXPECT().AttachVolumeWithContext(ctx, attachRequest).Return(createAttachVolumeOutput(volumeID, nodeID, path), nil), - 
mockEC2.EXPECT().DescribeVolumesWithContext(ctx, volumeRequest).Return(createDescribeVolumesOutput(volumeID, nodeID, path, "attached"), nil), + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), + mockEC2.EXPECT().AttachVolumeWithContext(gomock.Any(), attachRequest).Return(createAttachVolumeOutput(volumeID, nodeID, path), nil), + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), volumeRequest).Return(createDescribeVolumesOutput([]*string{&volumeID}, nodeID, path, "attached"), nil), ) }, }, @@ -790,8 +993,8 @@ func TestAttachDisk(t *testing.T) { assert.NoError(t, err) gomock.InOrder( - mockEC2.EXPECT().DescribeInstancesWithContext(ctx, instanceRequest).Return(newDescribeInstancesOutput(nodeID, volumeID), nil), - mockEC2.EXPECT().DescribeVolumesWithContext(ctx, volumeRequest).Return(createDescribeVolumesOutput(volumeID, nodeID, path, "attached"), nil)) + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), instanceRequest).Return(newDescribeInstancesOutput(nodeID, volumeID), nil), + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), volumeRequest).Return(createDescribeVolumesOutput([]*string{&volumeID}, nodeID, path, "attached"), nil)) }, }, { @@ -805,8 +1008,8 @@ func TestAttachDisk(t *testing.T) { attachRequest := createAttachRequest(volumeID, nodeID, path) gomock.InOrder( - mockEC2.EXPECT().DescribeInstancesWithContext(ctx, instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), - mockEC2.EXPECT().AttachVolumeWithContext(ctx, attachRequest).Return(nil, errors.New("AttachVolume error")), + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), + mockEC2.EXPECT().AttachVolumeWithContext(gomock.Any(), attachRequest).Return(nil, errors.New("AttachVolume error")), ) }, }, @@ -864,7 +1067,7 @@ func TestAttachDisk(t *testing.T) { gomock.InOrder( mockEC2.EXPECT().DescribeInstancesWithContext(ctx, 
instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), mockEC2.EXPECT().AttachVolumeWithContext(ctx, attachRequest).Return(createAttachVolumeOutput(volumeID, nodeID, path), nil), - mockEC2.EXPECT().DescribeVolumesWithContext(ctx, volumeRequest).Return(createDescribeVolumesOutput(volumeID, nodeID, path, "attached"), nil), + mockEC2.EXPECT().DescribeVolumesWithContext(ctx, volumeRequest).Return(createDescribeVolumesOutput([]*string{&volumeID}, nodeID, path, "attached"), nil), mockEC2.EXPECT().DescribeInstancesWithContext(ctx, createInstanceRequest2).Return(newDescribeInstancesOutput(nodeID2), nil), mockEC2.EXPECT().AttachVolumeWithContext(ctx, attachRequest2).Return(createAttachVolumeOutput(volumeID, nodeID2, path), nil), @@ -925,9 +1128,9 @@ func TestDetachDisk(t *testing.T) { detachRequest := createDetachRequest(volumeID, nodeID) gomock.InOrder( - mockEC2.EXPECT().DescribeInstancesWithContext(ctx, instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), - mockEC2.EXPECT().DetachVolumeWithContext(ctx, detachRequest).Return(nil, nil), - mockEC2.EXPECT().DescribeVolumesWithContext(ctx, volumeRequest).Return(createDescribeVolumesOutput(volumeID, nodeID, "", "detached"), nil), + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), + mockEC2.EXPECT().DetachVolumeWithContext(gomock.Any(), detachRequest).Return(nil, nil), + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), volumeRequest).Return(createDescribeVolumesOutput([]*string{&volumeID}, nodeID, "", "detached"), nil), ) }, }, @@ -941,8 +1144,8 @@ func TestDetachDisk(t *testing.T) { detachRequest := createDetachRequest(volumeID, nodeID) gomock.InOrder( - mockEC2.EXPECT().DescribeInstancesWithContext(ctx, instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), - mockEC2.EXPECT().DetachVolumeWithContext(ctx, detachRequest).Return(nil, errors.New("DetachVolume error")), + 
mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), + mockEC2.EXPECT().DetachVolumeWithContext(gomock.Any(), detachRequest).Return(nil, errors.New("DetachVolume error")), ) }, }, @@ -956,8 +1159,8 @@ func TestDetachDisk(t *testing.T) { detachRequest := createDetachRequest(volumeID, nodeID) gomock.InOrder( - mockEC2.EXPECT().DescribeInstancesWithContext(ctx, instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), - mockEC2.EXPECT().DetachVolumeWithContext(ctx, detachRequest).Return(nil, ErrNotFound), + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), + mockEC2.EXPECT().DetachVolumeWithContext(gomock.Any(), detachRequest).Return(nil, ErrNotFound), ) }, }, @@ -1029,6 +1232,12 @@ func TestGetDiskByName(t *testing.T) { Size: aws.Int64(util.BytesToGiB(tc.volumeCapacity)), AvailabilityZone: aws.String(tc.availabilityZone), OutpostArn: aws.String(tc.outpostArn), + Tags: []*ec2.Tag{ + { + Key: aws.String(VolumeNameTagKey), + Value: aws.String(tc.volumeName), + }, + }, } ctx := context.Background() @@ -2195,11 +2404,12 @@ func TestWaitForAttachmentState(t *testing.T) { } func newCloud(mockEC2 ec2iface.EC2API) Cloud { - return &cloud{ + c := &cloud{ region: "test-region", dm: dm.NewDeviceManager(), ec2: mockEC2, } + return c } func newDescribeInstancesOutput(nodeID string, volumeID ...string) *ec2.DescribeInstancesOutput { @@ -2272,20 +2482,24 @@ func createDetachRequest(volumeID, nodeID string) *ec2.DetachVolumeInput { } } -func createDescribeVolumesOutput(volumeID, nodeID, path, state string) *ec2.DescribeVolumesOutput { - return &ec2.DescribeVolumesOutput{ - Volumes: []*ec2.Volume{ - { - VolumeId: aws.String(volumeID), - Attachments: []*ec2.VolumeAttachment{ - { - Device: aws.String(path), - InstanceId: aws.String(nodeID), - State: aws.String(state), - }, +func createDescribeVolumesOutput(volumeIDs []*string, 
nodeID, path, state string) *ec2.DescribeVolumesOutput { + volumes := make([]*ec2.Volume, 0, len(volumeIDs)) + + for _, volumeID := range volumeIDs { + volumes = append(volumes, &ec2.Volume{ + VolumeId: volumeID, + Attachments: []*ec2.VolumeAttachment{ + { + Device: aws.String(path), + InstanceId: aws.String(nodeID), + State: aws.String(state), }, }, - }, + }) + } + + return &ec2.DescribeVolumesOutput{ + Volumes: volumes, } } diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index b15e1ae3bf..09d508208f 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -89,7 +89,8 @@ func newControllerService(driverOptions *DriverOptions) controllerService { region = metadata.GetRegion() } - cloudSrv, err := NewCloudFunc(region, driverOptions.awsSdkDebugLog, driverOptions.userAgentExtra) + klog.InfoS("batching", "status", driverOptions.batching) + cloudSrv, err := NewCloudFunc(region, driverOptions.awsSdkDebugLog, driverOptions.userAgentExtra, driverOptions.batching) if err != nil { panic(err) } @@ -420,7 +421,7 @@ func (d *controllerService) ControllerPublishVolume(ctx context.Context, req *cs if err != nil { return nil, status.Errorf(codes.Internal, "Could not attach volume %q to node %q: %v", volumeID, nodeID, err) } - klog.V(2).InfoS("ControllerPublishVolume: attached", "volumeID", volumeID, "nodeID", nodeID, "devicePath", devicePath) + klog.InfoS("ControllerPublishVolume: attached", "volumeID", volumeID, "nodeID", nodeID, "devicePath", devicePath) pvInfo := map[string]string{DevicePathKey: devicePath} return &csi.ControllerPublishVolumeResponse{PublishContext: pvInfo}, nil @@ -448,6 +449,7 @@ func validateControllerPublishVolumeRequest(req *csi.ControllerPublishVolumeRequ func (d *controllerService) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { klog.V(4).InfoS("ControllerUnpublishVolume: called", "args", *req) + if err := 
validateControllerUnpublishVolumeRequest(req); err != nil { return nil, err } @@ -463,12 +465,12 @@ func (d *controllerService) ControllerUnpublishVolume(ctx context.Context, req * klog.V(2).InfoS("ControllerUnpublishVolume: detaching", "volumeID", volumeID, "nodeID", nodeID) if err := d.cloud.DetachDisk(ctx, volumeID, nodeID); err != nil { if errors.Is(err, cloud.ErrNotFound) { - klog.V(2).InfoS("ControllerUnpublishVolume: attachment not found", "volumeID", volumeID, "nodeID", nodeID) + klog.InfoS("ControllerUnpublishVolume: attachment not found", "volumeID", volumeID, "nodeID", nodeID) return &csi.ControllerUnpublishVolumeResponse{}, nil } return nil, status.Errorf(codes.Internal, "Could not detach volume %q from node %q: %v", volumeID, nodeID, err) } - klog.V(2).InfoS("ControllerUnpublishVolume: detached", "volumeID", volumeID, "nodeID", nodeID) + klog.InfoS("ControllerUnpublishVolume: detached", "volumeID", volumeID, "nodeID", nodeID) return &csi.ControllerUnpublishVolumeResponse{}, nil } diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 51487e6359..6e91e0672f 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -53,8 +53,8 @@ func TestNewControllerService(t *testing.T) { testErr = errors.New("test error") testRegion = "test-region" - getNewCloudFunc = func(expectedRegion string, _ bool) func(region string, awsSdkDebugLog bool, userAgentExtra string) (cloud.Cloud, error) { - return func(region string, awsSdkDebugLog bool, userAgentExtra string) (cloud.Cloud, error) { + getNewCloudFunc = func(expectedRegion string, _ bool) func(region string, awsSdkDebugLog bool, userAgentExtra string, batching bool) (cloud.Cloud, error) { + return func(region string, awsSdkDebugLog bool, userAgentExtra string, batching bool) (cloud.Cloud, error) { if region != expectedRegion { t.Fatalf("expected region %q but got %q", expectedRegion, region) } @@ -66,7 +66,7 @@ func TestNewControllerService(t *testing.T) { testCases := 
[]struct { name string region string - newCloudFunc func(string, bool, string) (cloud.Cloud, error) + newCloudFunc func(string, bool, string, bool) (cloud.Cloud, error) newMetadataFuncErrors bool expectPanic bool }{ @@ -78,7 +78,7 @@ func TestNewControllerService(t *testing.T) { { name: "AWS_REGION variable set, newCloud errors", region: "foo", - newCloudFunc: func(region string, awsSdkDebugLog bool, userAgentExtra string) (cloud.Cloud, error) { + newCloudFunc: func(region string, awsSdkDebugLog bool, userAgentExtra string, batching bool) (cloud.Cloud, error) { return nil, testErr }, expectPanic: true, diff --git a/tests/e2e/dynamic_provisioning.go b/tests/e2e/dynamic_provisioning.go index abe13d6d1a..48cf268d3c 100644 --- a/tests/e2e/dynamic_provisioning.go +++ b/tests/e2e/dynamic_provisioning.go @@ -555,7 +555,7 @@ var _ = Describe("[ebs-csi-e2e] [single-az] Dynamic Provisioning", func() { availabilityZones := strings.Split(os.Getenv(awsAvailabilityZonesEnv), ",") availabilityZone := availabilityZones[rand.Intn(len(availabilityZones))] region := availabilityZone[0 : len(availabilityZone)-1] - cloud, err := awscloud.NewCloud(region, false, "") + cloud, err := awscloud.NewCloud(region, false, "", true) if err != nil { Fail(fmt.Sprintf("could not get NewCloud: %v", err)) } diff --git a/tests/e2e/pre_provsioning.go b/tests/e2e/pre_provsioning.go index 0de2b65cef..c6da7e5b0d 100644 --- a/tests/e2e/pre_provsioning.go +++ b/tests/e2e/pre_provsioning.go @@ -88,7 +88,7 @@ var _ = Describe("[ebs-csi-e2e] [single-az] Pre-Provisioned", func() { Tags: map[string]string{awscloud.VolumeNameTagKey: dummyVolumeName, awscloud.AwsEbsDriverTagKey: "true"}, } var err error - cloud, err = awscloud.NewCloud(region, false, "") + cloud, err = awscloud.NewCloud(region, false, "", true) if err != nil { Fail(fmt.Sprintf("could not get NewCloud: %v", err)) } @@ -261,7 +261,7 @@ var _ = Describe("[ebs-csi-e2e] [single-az] Pre-Provisioned with Multi-Attach", Tags: 
map[string]string{awscloud.VolumeNameTagKey: dummyVolumeName, awscloud.AwsEbsDriverTagKey: "true"}, } var err error - cloud, err = awscloud.NewCloud(region, false, "") + cloud, err = awscloud.NewCloud(region, false, "", true) if err != nil { Fail(fmt.Sprintf("could not get NewCloud: %v", err)) }