diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index fda0dcdff0..f2d22b591e 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -91,6 +91,13 @@ var ( // ErrAlreadyExists is returned when a resource is already existent. ErrAlreadyExists = errors.New("Resource already exists") + + // ErrMultiSnapshots is returned when multiple snapshots are found + // with the same ID + ErrMultiSnapshots = errors.New("Multiple snapshots with the same name found") + + // ErrInvalidMaxResults is returned when a MaxResults pagination parameter is between 1 and 4 + ErrInvalidMaxResults = errors.New("MaxResults parameter must be 0 or greater than or equal to 5") ) // Disk represents a EBS volume @@ -124,11 +131,23 @@ type Snapshot struct { ReadyToUse bool } +// ListSnapshotsResponse is the container for our snapshots along with a pagination token to pass back to the caller +type ListSnapshotsResponse struct { + Snapshots []*Snapshot + NextToken string +} + // SnapshotOptions represents parameters to create an EBS volume type SnapshotOptions struct { Tags map[string]string } +// ec2ListSnapshotsResponse is a helper struct returned from the AWS API calling function to the main ListSnapshots function +type ec2ListSnapshotsResponse struct { + Snapshots []*ec2.Snapshot + NextToken *string +} + // EC2 abstracts aws.EC2 to facilitate its mocking. // See https://docs.aws.amazon.com/sdk-for-go/api/service/ec2/ for details type EC2 interface { @@ -156,6 +175,7 @@ type Cloud interface { CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *SnapshotOptions) (snapshot *Snapshot, err error) DeleteSnapshot(ctx context.Context, snapshotID string) (success bool, err error) GetSnapshotByName(ctx context.Context, name string) (snapshot *Snapshot, err error) + ListSnapshots(ctx context.Context, volumeID string, maxResults int64, nextToken string) (listSnapshotsResponse *ListSnapshotsResponse, err error) } type cloud struct { @@ -542,6 +562,49 @@ func (c *cloud) GetSnapshotByName(ctx context.Context, name string) (snapshot *S return c.ec2SnapshotResponseToStruct(ec2snapshot), nil } +// ListSnapshots retrieves AWS EBS snapshots for an optionally specified volume ID. If maxResults is set, it will return up to maxResults snapshots. If there are more snapshots than maxResults, +// a next token value will be returned to the client as well. They can use this token with subsequent calls to retrieve the next page of results. If maxResults is not set (0), +// there will be no restriction up to 1000 results (https://docs.aws.amazon.com/sdk-for-go/api/service/ec2/#DescribeSnapshotsInput). +func (c *cloud) ListSnapshots(ctx context.Context, volumeID string, maxResults int64, nextToken string) (listSnapshotsResponse *ListSnapshotsResponse, err error) { + if maxResults > 0 && maxResults < 5 { + return nil, ErrInvalidMaxResults + } + + describeSnapshotsInput := &ec2.DescribeSnapshotsInput{ + MaxResults: aws.Int64(maxResults), + } + + if len(nextToken) != 0 { + describeSnapshotsInput.NextToken = aws.String(nextToken) + } + if len(volumeID) != 0 { + describeSnapshotsInput.Filters = []*ec2.Filter{ + { + Name: aws.String("volume-id"), + Values: []*string{aws.String(volumeID)}, + }, + } + } + + ec2SnapshotsResponse, err := c.listSnapshots(ctx, describeSnapshotsInput) + if err != nil { + return nil, err + } + var snapshots []*Snapshot + for _, ec2Snapshot := range ec2SnapshotsResponse.Snapshots { + snapshots = append(snapshots, c.ec2SnapshotResponseToStruct(ec2Snapshot)) + } + + if len(snapshots) == 0 { + return nil, ErrNotFound + } + + return &ListSnapshotsResponse{ + Snapshots: snapshots, + NextToken: aws.StringValue(ec2SnapshotsResponse.NextToken), + }, nil +} + // Helper method converting EC2 snapshot type to the internal struct func (c *cloud) ec2SnapshotResponseToStruct(ec2Snapshot *ec2.Snapshot) *Snapshot { if ec2Snapshot == nil { @@ -625,7 +688,6 @@ func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance, func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2.Snapshot, error) { var snapshots []*ec2.Snapshot var nextToken *string - for { response, err := c.ec2.DescribeSnapshotsWithContext(ctx, request) if err != nil { @@ -640,7 +702,7 @@ func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsI } if l := len(snapshots); l > 1 { - return nil, errors.New("Multiple snapshots with the same name found") + return nil, ErrMultiSnapshots } else if l < 1 { return nil, ErrNotFound } @@ -648,6 +710,28 @@ func (c *cloud) getSnapshot(ctx context.Context, request *ec2.DescribeSnapshotsI return snapshots[0], nil } +// listSnapshots returns all snapshots based from a request +func (c *cloud) listSnapshots(ctx context.Context, request *ec2.DescribeSnapshotsInput) (*ec2ListSnapshotsResponse, error) { + var snapshots []*ec2.Snapshot + var nextToken *string + + response, err := c.ec2.DescribeSnapshotsWithContext(ctx, request) + if err != nil { + return nil, err + } + + snapshots = append(snapshots, response.Snapshots...) + + if response.NextToken != nil { + nextToken = response.NextToken + } + + return &ec2ListSnapshotsResponse{ + Snapshots: snapshots, + NextToken: nextToken, + }, nil +} + // waitForVolume waits for volume to be in the "available" state. // On a random AWS account (shared among several developers) it took 4s on average. func (c *cloud) waitForVolume(ctx context.Context, volumeID string) error { diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index 916209c3d7..f877443001 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -18,6 +18,7 @@ package cloud import ( "context" + "errors" "fmt" "strings" "testing" @@ -647,6 +648,217 @@ func TestGetSnapshotByName(t *testing.T) { } } +func TestListSnapshots(t *testing.T) { + testCases := []struct { + name string + testFunc func(t *testing.T) + }{ + { + name: "success: normal", + testFunc: func(t *testing.T) { + expSnapshots := []*Snapshot{ + { + SourceVolumeID: "snap-test-volume1", + SnapshotID: "snap-test-name1", + }, + { + SourceVolumeID: "snap-test-volume2", + SnapshotID: "snap-test-name2", + }, + } + ec2Snapshots := []*ec2.Snapshot{ + { + SnapshotId: aws.String(expSnapshots[0].SnapshotID), + VolumeId: aws.String("snap-test-volume1"), + State: aws.String("completed"), + }, + { + SnapshotId: aws.String(expSnapshots[1].SnapshotID), + VolumeId: aws.String("snap-test-volume2"), + State: aws.String("completed"), + }, + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockEC2 := mocks.NewMockEC2(mockCtl) + c := newCloud(mockEC2) + + ctx := context.Background() + + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{Snapshots: ec2Snapshots}, nil) + + _, err := c.ListSnapshots(ctx, "", 0, "") + if err != nil { + t.Fatalf("ListSnapshots() failed: expected no error, got: %v", err) + } + }, + }, + { + name: "success: with volume ID", + testFunc: func(t *testing.T) { + sourceVolumeID := "snap-test-volume" + expSnapshots := []*Snapshot{ + { + SourceVolumeID: sourceVolumeID, + SnapshotID: "snap-test-name1", + }, + { + SourceVolumeID: sourceVolumeID, + SnapshotID: "snap-test-name2", + }, + } + ec2Snapshots := []*ec2.Snapshot{ + { + SnapshotId: aws.String(expSnapshots[0].SnapshotID), + VolumeId: aws.String(sourceVolumeID), + State: aws.String("completed"), + }, + { + SnapshotId: aws.String(expSnapshots[1].SnapshotID), + VolumeId: aws.String(sourceVolumeID), + State: aws.String("completed"), + }, + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockEC2 := mocks.NewMockEC2(mockCtl) + c := newCloud(mockEC2) + + ctx := context.Background() + + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{Snapshots: ec2Snapshots}, nil) + + resp, err := c.ListSnapshots(ctx, sourceVolumeID, 0, "") + if err != nil { + t.Fatalf("ListSnapshots() failed: expected no error, got: %v", err) + } + + if len(resp.Snapshots) != len(expSnapshots) { + t.Fatalf("Expected %d snapshots, got %d", len(expSnapshots), len(resp.Snapshots)) + } + + for _, snap := range resp.Snapshots { + if snap.SourceVolumeID != sourceVolumeID { + t.Fatalf("Unexpected source volume. Expected %s, got %s", sourceVolumeID, snap.SourceVolumeID) + } + } + }, + }, + { + name: "success: max results, next token", + testFunc: func(t *testing.T) { + maxResults := 5 + nextTokenValue := "nextTokenValue" + var expSnapshots []*Snapshot + for i := 0; i < maxResults*2; i++ { + expSnapshots = append(expSnapshots, &Snapshot{ + SourceVolumeID: "snap-test-volume1", + SnapshotID: fmt.Sprintf("snap-test-name%d", i), + }) + } + + var ec2Snapshots []*ec2.Snapshot + for i := 0; i < maxResults*2; i++ { + ec2Snapshots = append(ec2Snapshots, &ec2.Snapshot{ + SnapshotId: aws.String(expSnapshots[i].SnapshotID), + VolumeId: aws.String(fmt.Sprintf("snap-test-volume%d", i)), + State: aws.String("completed"), + }) + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockEC2 := mocks.NewMockEC2(mockCtl) + c := newCloud(mockEC2) + + ctx := context.Background() + + firstCall := mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{ + Snapshots: ec2Snapshots[:maxResults], + NextToken: aws.String(nextTokenValue), + }, nil) + secondCall := mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{ + Snapshots: ec2Snapshots[maxResults:], + }, nil) + gomock.InOrder( + firstCall, + secondCall, + ) + + firstSnapshotsResponse, err := c.ListSnapshots(ctx, "", 5, "") + if err != nil { + t.Fatalf("ListSnapshots() failed: expected no error, got: %v", err) + } + + if len(firstSnapshotsResponse.Snapshots) != maxResults { + t.Fatalf("Expected %d snapshots, got %d", maxResults, len(firstSnapshotsResponse.Snapshots)) + } + + if firstSnapshotsResponse.NextToken != nextTokenValue { + t.Fatalf("Expected next token value '%s' got '%s'", nextTokenValue, firstSnapshotsResponse.NextToken) + } + + secondSnapshotsResponse, err := c.ListSnapshots(ctx, "", 0, firstSnapshotsResponse.NextToken) + if err != nil { + t.Fatalf("CreateSnapshot() failed: expected no error, got: %v", err) + } + + if len(secondSnapshotsResponse.Snapshots) != maxResults { + t.Fatalf("Expected %d snapshots, got %d", maxResults, len(secondSnapshotsResponse.Snapshots)) + } + + if secondSnapshotsResponse.NextToken != "" { + t.Fatalf("Expected next token value to be empty got %s", secondSnapshotsResponse.NextToken) + } + }, + }, + { + name: "fail: AWS DescribeSnapshotsWithContext error", + testFunc: func(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockEC2 := mocks.NewMockEC2(mockCtl) + c := newCloud(mockEC2) + + ctx := context.Background() + + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(nil, errors.New("test error")) + + if _, err := c.ListSnapshots(ctx, "", 0, ""); err == nil { + t.Fatalf("ListSnapshots() failed: expected an error, got none") + } + }, + }, + { + name: "fail: no snapshots ErrNotFound", + testFunc: func(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockEC2 := mocks.NewMockEC2(mockCtl) + c := newCloud(mockEC2) + + ctx := context.Background() + + mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Eq(ctx), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{}, nil) + + if _, err := c.ListSnapshots(ctx, "", 0, ""); err != nil { + if err != ErrNotFound { + t.Fatalf("Expected error %v, got %v", ErrNotFound, err) + } + } else { + t.Fatalf("Expected error, got none") + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) + } +} + func newCloud(mockEC2 EC2) Cloud { return &cloud{ metadata: &Metadata{ diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 3e9c5efbba..03c490ce19 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -44,6 +44,7 @@ var ( csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT, + csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS, } ) @@ -369,7 +370,50 @@ func (d *controllerService) DeleteSnapshot(ctx context.Context, req *csi.DeleteS } func (d *controllerService) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { - return nil, status.Error(codes.Unimplemented, "") + klog.V(4).Infof("ListSnapshots: called with args %+v", req) + var snapshots []*cloud.Snapshot + + snapshotID := req.GetSnapshotId() + if len(snapshotID) != 0 { + snapshot, err := d.cloud.GetSnapshotByName(ctx, snapshotID) + if err != nil { + if err == cloud.ErrNotFound { + klog.V(4).Info("ListSnapshots: snapshot not found, returning with success") + return &csi.ListSnapshotsResponse{}, nil + } + return nil, status.Errorf(codes.Internal, "Could not get snapshot ID %q: %v", snapshotID, err) + } + snapshots = append(snapshots, snapshot) + if response, err := newListSnapshotsResponse(&cloud.ListSnapshotsResponse{ + Snapshots: snapshots, + }); err != nil { + return nil, status.Errorf(codes.Internal, "Could not build ListSnapshotsResponse: %v", err) + } else { + return response, nil + } + } + + volumeID := req.GetSourceVolumeId() + nextToken := req.GetStartingToken() + maxEntries := int64(req.GetMaxEntries()) + + cloudSnapshots, err := d.cloud.ListSnapshots(ctx, volumeID, maxEntries, nextToken) + if err != nil { + if err == cloud.ErrNotFound { + klog.V(4).Info("ListSnapshots: snapshot not found, returning with success") + return &csi.ListSnapshotsResponse{}, nil + } + if err == cloud.ErrInvalidMaxResults { + return nil, status.Errorf(codes.InvalidArgument, "Error mapping MaxEntries to AWS MaxResults: %v", err) + } + return nil, status.Errorf(codes.Internal, "Could not list snapshots: %v", err) + } + + response, err := newListSnapshotsResponse(cloudSnapshots) + if err != nil { + return nil, status.Errorf(codes.Internal, "Could not build ListSnapshotsResponse: %v", err) + } + return response, nil } func (d *Driver) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { @@ -430,6 +474,38 @@ func newCreateSnapshotResponse(snapshot *cloud.Snapshot) (*csi.CreateSnapshotRes }, nil } +func newListSnapshotsResponse(cloudResponse *cloud.ListSnapshotsResponse) (*csi.ListSnapshotsResponse, error) { + + var entries []*csi.ListSnapshotsResponse_Entry + for _, snapshot := range cloudResponse.Snapshots { + snapshotResponseEntry, err := newListSnapshotsResponseEntry(snapshot) + if err != nil { + return nil, err + } + entries = append(entries, snapshotResponseEntry) + } + return &csi.ListSnapshotsResponse{ + Entries: entries, + NextToken: cloudResponse.NextToken, + }, nil +} + +func newListSnapshotsResponseEntry(snapshot *cloud.Snapshot) (*csi.ListSnapshotsResponse_Entry, error) { + ts, err := ptypes.TimestampProto(snapshot.CreationTime) + if err != nil { + return nil, err + } + return &csi.ListSnapshotsResponse_Entry{ + Snapshot: &csi.Snapshot{ + SnapshotId: snapshot.SnapshotID, + SourceVolumeId: snapshot.SourceVolumeID, + SizeBytes: snapshot.Size, + CreationTime: ts, + ReadyToUse: snapshot.ReadyToUse, + }, + }, nil +} + func getVolSizeBytes(req *csi.CreateVolumeRequest) (int64, error) { var volSizeBytes int64 capRange := req.GetCapacityRange() diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 6e0ab037fc..96f2656c19 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -1136,6 +1136,191 @@ func TestDeleteSnapshot(t *testing.T) { } } +func TestListSnapshots(t *testing.T) { + testCases := []struct { + name string + testFunc func(t *testing.T) + }{ + { + name: "success normal", + testFunc: func(t *testing.T) { + req := &csi.ListSnapshotsRequest{} + mockCloudSnapshotsResponse := &cloud.ListSnapshotsResponse{ + Snapshots: []*cloud.Snapshot{ + { + SnapshotID: "snapshot-1", + SourceVolumeID: "test-vol", + Size: 1, + CreationTime: time.Now(), + }, + { + SnapshotID: "snapshot-2", + SourceVolumeID: "test-vol", + Size: 1, + CreationTime: time.Now(), + }, + }, + NextToken: "", + } + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().ListSnapshots(gomock.Eq(ctx), gomock.Eq(""), gomock.Eq(int64(0)), gomock.Eq("")).Return(mockCloudSnapshotsResponse, nil) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.ListSnapshots(context.Background(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(resp.GetEntries()) != len(mockCloudSnapshotsResponse.Snapshots) { + t.Fatalf("Expected %d entries, got %d", len(mockCloudSnapshotsResponse.Snapshots), len(resp.GetEntries())) + } + }, + }, + { + name: "success no snapshots", + testFunc: func(t *testing.T) { + req := &csi.ListSnapshotsRequest{} + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().ListSnapshots(gomock.Eq(ctx), gomock.Eq(""), gomock.Eq(int64(0)), gomock.Eq("")).Return(nil, cloud.ErrNotFound) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.ListSnapshots(context.Background(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !reflect.DeepEqual(resp, &csi.ListSnapshotsResponse{}) { + t.Fatalf("Expected empty response, got %+v", resp) + } + }, + }, + { + name: "success snapshot ID", + testFunc: func(t *testing.T) { + req := &csi.ListSnapshotsRequest{ + SnapshotId: "snapshot-1", + } + mockCloudSnapshotsResponse := &cloud.Snapshot{ + SnapshotID: "snapshot-1", + SourceVolumeID: "test-vol", + Size: 1, + CreationTime: time.Now(), + } + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq("snapshot-1")).Return(mockCloudSnapshotsResponse, nil) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.ListSnapshots(context.Background(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(resp.GetEntries()) != 1 { + t.Fatalf("Expected %d entry, got %d", 1, len(resp.GetEntries())) + } + }, + }, + { + name: "success snapshot ID not found", + testFunc: func(t *testing.T) { + req := &csi.ListSnapshotsRequest{ + SnapshotId: "snapshot-1", + } + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq("snapshot-1")).Return(nil, cloud.ErrNotFound) + + awsDriver := controllerService{cloud: mockCloud} + resp, err := awsDriver.ListSnapshots(context.Background(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !reflect.DeepEqual(resp, &csi.ListSnapshotsResponse{}) { + t.Fatalf("Expected empty response, got %+v", resp) + } + }, + }, + { + name: "fail snapshot ID multiple found", + testFunc: func(t *testing.T) { + req := &csi.ListSnapshotsRequest{ + SnapshotId: "snapshot-1", + } + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq("snapshot-1")).Return(nil, cloud.ErrMultiSnapshots) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ListSnapshots(context.Background(), req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != codes.Internal { + t.Fatalf("Expected error code %d, got %d message %s", codes.Internal, srvErr.Code(), srvErr.Message()) + } + } else { + t.Fatalf("Expected error code %d, got no error", codes.Internal) + } + }, + }, + { + name: "fail 0 < MaxEntries < 5", + testFunc: func(t *testing.T) { + req := &csi.ListSnapshotsRequest{ + MaxEntries: 4, + } + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().ListSnapshots(gomock.Eq(ctx), gomock.Eq(""), gomock.Eq(int64(4)), gomock.Eq("")).Return(nil, cloud.ErrInvalidMaxResults) + + awsDriver := controllerService{cloud: mockCloud} + if _, err := awsDriver.ListSnapshots(context.Background(), req); err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != codes.InvalidArgument { + t.Fatalf("Expected error code %d, got %d message %s", codes.InvalidArgument, srvErr.Code(), srvErr.Message()) + } + } else { + t.Fatalf("Expected error code %d, got no error", codes.InvalidArgument) + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) + } +} + func TestControllerPublishVolume(t *testing.T) { stdVolCap := &csi.VolumeCapability{ AccessType: &csi.VolumeCapability_Mount{ diff --git a/pkg/driver/mocks/mock_cloud.go b/pkg/driver/mocks/mock_cloud.go index 53b4aa199f..cb5f47016b 100644 --- a/pkg/driver/mocks/mock_cloud.go +++ b/pkg/driver/mocks/mock_cloud.go @@ -196,6 +196,21 @@ func (mr *MockCloudMockRecorder) IsExistInstance(arg0, arg1 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsExistInstance", reflect.TypeOf((*MockCloud)(nil).IsExistInstance), arg0, arg1) } +// ListSnapshots mocks base method +func (m *MockCloud) ListSnapshots(arg0 context.Context, arg1 string, arg2 int64, arg3 string) (*cloud.ListSnapshotsResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListSnapshots", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*cloud.ListSnapshotsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListSnapshots indicates an expected call of ListSnapshots +func (mr *MockCloudMockRecorder) ListSnapshots(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSnapshots", reflect.TypeOf((*MockCloud)(nil).ListSnapshots), arg0, arg1, arg2, arg3) +} + // WaitForAttachmentState mocks base method func (m *MockCloud) WaitForAttachmentState(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() diff --git a/tests/sanity/fake_cloud_provider.go b/tests/sanity/fake_cloud_provider.go index d1e7fac2a5..4270b97a8d 100644 --- a/tests/sanity/fake_cloud_provider.go +++ b/tests/sanity/fake_cloud_provider.go @@ -31,6 +31,7 @@ type fakeCloudProvider struct { snapshots map[string]*fakeSnapshot m *cloud.Metadata pub map[string]string + tokens map[string]int64 } type fakeDisk struct { @@ -53,6 +54,7 @@ func newFakeCloudProvider() *fakeCloudProvider { Region: "region", AvailabilityZone: "az", }, + tokens: make(map[string]int64), } } @@ -133,11 +135,19 @@ func (c *fakeCloudProvider) IsExistInstance(ctx context.Context, nodeID string) } func (c *fakeCloudProvider) CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *cloud.SnapshotOptions) (snapshot *cloud.Snapshot, err error) { - r1 := rand.New(rand.NewSource(time.Now().UnixNano())) - snapshotID := fmt.Sprintf("snapshot-%d", r1.Uint64()) + var snapshotID string if len(snapshotOptions.Tags[cloud.SnapshotNameTagKey]) == 0 { // for simplicity: let's have the Name and ID identical + r1 := rand.New(rand.NewSource(time.Now().UnixNano())) + snapshotID = fmt.Sprintf("snapshot-%d", r1.Uint64()) snapshotOptions.Tags[cloud.SnapshotNameTagKey] = snapshotID + } else { + snapshotID = snapshotOptions.Tags[cloud.SnapshotNameTagKey] + } + for _, existingSnapshot := range c.snapshots { + if existingSnapshot.Snapshot.SnapshotID == snapshotID && existingSnapshot.Snapshot.SourceVolumeID == volumeID { + return nil, cloud.ErrAlreadyExists + } } s := &fakeSnapshot{ Snapshot: &cloud.Snapshot{ @@ -162,14 +172,37 @@ func (c *fakeCloudProvider) DeleteSnapshot(ctx context.Context, snapshotID strin func (c *fakeCloudProvider) GetSnapshotByName(ctx context.Context, name string) (snapshot *cloud.Snapshot, err error) { var snapshots []*fakeSnapshot for _, s := range c.snapshots { - for key, value := range s.tags { - if key == cloud.SnapshotNameTagKey && value == name { - snapshots = append(snapshots, s) - } + if s.SnapshotID == name { + snapshots = append(snapshots, s) } } if len(snapshots) == 0 { - return nil, nil + return nil, cloud.ErrNotFound } return snapshots[0].Snapshot, nil } + +func (c *fakeCloudProvider) ListSnapshots(ctx context.Context, volumeID string, maxResults int64, nextToken string) (listSnapshotsResponse *cloud.ListSnapshotsResponse, err error) { + var snapshots []*cloud.Snapshot + var retToken string + for _, fakeSnapshot := range c.snapshots { + if fakeSnapshot.Snapshot.SourceVolumeID == volumeID || len(volumeID) == 0 { + snapshots = append(snapshots, fakeSnapshot.Snapshot) + } + } + if maxResults > 0 { + r1 := rand.New(rand.NewSource(time.Now().UnixNano())) + retToken = fmt.Sprintf("token-%d", r1.Uint64()) + c.tokens[retToken] = maxResults + snapshots = snapshots[0:maxResults] + fmt.Printf("%v\n", snapshots) + } + if len(nextToken) != 0 { + snapshots = snapshots[c.tokens[nextToken]:] + } + return &cloud.ListSnapshotsResponse{ + Snapshots: snapshots, + NextToken: retToken, + }, nil + +}