diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 1f300db..63ba6b7 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -18,6 +18,9 @@ jobs: uses: actions/setup-go@v5 with: go-version: 1.21 + - name: Setup Mocks + run: | + make mocks - name: Run golangci-lint uses: golangci/golangci-lint-action@v3 with: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c8e1aa8..9011e72 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -25,6 +25,7 @@ jobs: run: | echo "$FEATURES_CONF" > docker/multi-node/config/features.conf + echo "$FEATURES_CONF" > docker/vanilla/config/features.conf echo "$FEATURES_CONF" > docker/tls/config/features.conf echo "$FEATURES_CONF" > docker/mtls/config/features.conf echo "$FEATURES_CONF" > docker/auth/config/features.conf diff --git a/.golangci.yml b/.golangci.yml index e343d8e..bc7bc30 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -75,6 +75,12 @@ issues: - path: '(.+)test\.go' linters: - govet # Test code field alignment for sake of space is not a concern + - path: 'token_manager_test.go' + linters: + - goconst # Test code is allowed to have constants + - path: 'connection_provider_test.go' + linters: + - goconst # - path: dir/sample\.go # linters: # - lll # Test code is allowed to have long lines diff --git a/README.md b/README.md index e69e754..66aabd4 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +[![codecov](https://codecov.io/gh/aerospike/avs-client-go/graph/badge.svg?token=811TWWPW6S)](https://codecov.io/gh/aerospike/avs-client-go) + # Aerospike Vector Search Go Client > :warning: The go client is currently in development. APIs will break in the future! 
diff --git a/client.go b/client.go index 8d68ce1..6587204 100644 --- a/client.go +++ b/client.go @@ -22,17 +22,27 @@ const ( indexWaitDuration = time.Millisecond * 100 ) const ( - failedToInsertRecord = "failed to insert record" - failedToGetRecord = "failed to get record" - failedToDeleteRecord = "failed to delete record" - failedToCheckRecordExists = "failed to check if record exists" - failedToCheckIsIndexed = "failed to check if record exists" + failedToInsertRecord = "failed to insert record" + failedToGetRecord = "failed to get record" + failedToDeleteRecord = "failed to delete record" + failedToCheckRecordExists = "failed to check if record exists" + failedToCheckIsIndexed = "failed to check if record is indexed" + failedToWaitForIndexCompletion = "failed to wait for index completion" ) +type connProvider interface { + GetNodeIDs() []uint64 + GetRandomConn() (*connection, error) + GetSeedConn() (*connection, error) + GetNodeConn(id uint64) (*connection, error) + Close() error +} + // Client is a client for managing Aerospike Vector Indexes. type Client struct { - logger *slog.Logger - channelProvider *channelProvider + logger *slog.Logger + connectionProvider connProvider + token tokenManager } // NewClient creates a new Client instance. 
@@ -64,23 +74,40 @@ func NewClient( logger = logger.WithGroup("avs") logger.Info("creating new client") - channelProvider, err := newChannelProvider( + var grpcToken tokenManager + + if credentials != nil { + grpcToken = newGrpcJWTToken(credentials.username, credentials.password, logger) + } + + connectionProvider, err := newConnectionProvider( ctx, seeds, listenerName, isLoadBalancer, - credentials, + grpcToken, tlsConfig, logger, ) if err != nil { - logger.Error("failed to create channel provider", slog.Any("error", err)) + grpcToken.Close() + logger.Error("failed to create connection provider", slog.Any("error", err)) + return nil, NewAVSErrorFromGrpc("failed to connect to server", err) } + return newClient(connectionProvider, grpcToken, logger) +} + +func newClient( + connectionProvider connProvider, + token tokenManager, + logger *slog.Logger, +) (*Client, error) { return &Client{ - logger: logger, - channelProvider: channelProvider, + logger: logger, + token: token, + connectionProvider: connectionProvider, }, nil } @@ -91,7 +118,12 @@ func NewClient( // error: An error if the closure fails, otherwise nil. 
func (c *Client) Close() error { c.logger.Info("Closing client") - return c.channelProvider.Close() + + if c.token != nil { + c.token.Close() + } + + return c.connectionProvider.Close() } func (c *Client) put( @@ -108,7 +140,7 @@ func (c *Client) put( slog.Any("key", key), ) - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { logger.Error(failedToInsertRecord, slog.Any("error", err)) return NewAVSError(failedToInsertRecord, err) @@ -259,7 +291,7 @@ func (c *Client) Get(ctx context.Context, ) logger.DebugContext(ctx, "getting record") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { logger.Error(failedToGetRecord, slog.Any("error", err)) return nil, NewAVSError(failedToGetRecord, err) @@ -306,7 +338,7 @@ func (c *Client) Delete(ctx context.Context, namespace string, set *string, key ) logger.DebugContext(ctx, "deleting record") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { logger.Error(failedToDeleteRecord, slog.Any("error", err)) return NewAVSError(failedToDeleteRecord, err) @@ -314,8 +346,8 @@ func (c *Client) Delete(ctx context.Context, namespace string, set *string, key protoKey, err := protos.ConvertToKey(namespace, set, key) if err != nil { - logger.Error(failedToInsertRecord, slog.Any("error", err)) - return NewAVSError(failedToInsertRecord, err) + logger.Error(failedToDeleteRecord, slog.Any("error", err)) + return NewAVSError(failedToDeleteRecord, err) } getReq := &protos.DeleteRequest{ @@ -325,7 +357,7 @@ func (c *Client) Delete(ctx context.Context, namespace string, set *string, key _, err = conn.transactClient.Delete(ctx, getReq) if err != nil { logger.Error(failedToDeleteRecord, slog.Any("error", err)) - return NewAVSErrorFromGrpc(failedToGetRecord, err) + return NewAVSErrorFromGrpc(failedToDeleteRecord, err) } return nil @@ -356,7 +388,7 @@ func (c 
*Client) Exists( ) logger.DebugContext(ctx, "checking if record exists") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { logger.Error(failedToCheckRecordExists, slog.Any("error", err)) return false, NewAVSError(failedToCheckRecordExists, err) @@ -409,7 +441,7 @@ func (c *Client) IsIndexed( ) logger.DebugContext(ctx, "checking if record is indexed") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { logger.Error(failedToCheckIsIndexed, slog.Any("error", err)) return false, NewAVSError(failedToCheckIsIndexed, err) @@ -422,6 +454,10 @@ func (c *Client) IsIndexed( } isIndexedReq := &protos.IsIndexedRequest{ + IndexId: &protos.IndexId{ + Namespace: namespace, + Name: indexName, + }, Key: protoKey, } @@ -449,7 +485,7 @@ func (c *Client) vectorSearch(ctx context.Context, logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) logger.DebugContext(ctx, "searching for vector") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to search for vector" logger.Error(msg, slog.Any("error", err)) @@ -619,10 +655,12 @@ func (c *Client) WaitForIndexCompletion( logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) logger.DebugContext(ctx, "waiting for index completion") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { - logger.Error("failed to wait for index completion", slog.Any("error", err)) - return err + msg := failedToWaitForIndexCompletion + logger.Error(msg, slog.Any("error", err)) + + return NewAVSError(msg, err) } indexStatusReq := createIndexStatusRequest(namespace, indexName) @@ -636,8 +674,10 @@ func (c *Client) WaitForIndexCompletion( for { indexStatus, err := conn.indexClient.GetStatus(ctx, indexStatusReq) 
if err != nil { - logger.ErrorContext(ctx, "failed to wait for index completion", slog.Any("error", err)) - return err + msg := failedToWaitForIndexCompletion + logger.ErrorContext(ctx, msg, slog.Any("error", err)) + + return NewAVSError(msg, err) } // We consider the index completed when unmerged record count == 0 for @@ -655,7 +695,7 @@ func (c *Client) WaitForIndexCompletion( } else { logger.DebugContext(ctx, "index not yet completed", slog.Int64("unmerged", unmerged)) - unmergedNotZeroCount-- + unmergedNotZeroCount++ } timer.Reset(waitInterval) @@ -663,8 +703,11 @@ func (c *Client) WaitForIndexCompletion( select { case <-timer.C: case <-ctx.Done(): - logger.ErrorContext(ctx, "waiting for index completion canceled") - return ctx.Err() + msg := "waiting for index completion canceled" + + logger.ErrorContext(ctx, msg) + + return NewAVSError(msg, ctx.Err()) } } } @@ -771,7 +814,7 @@ func (c *Client) IndexCreateFromIndexDef( logger := c.logger.With(slog.Any("definition", indexDef)) logger.DebugContext(ctx, "creating index from definition") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to create index from definition" logger.Error(msg, slog.Any("error", err)) @@ -785,7 +828,7 @@ func (c *Client) IndexCreateFromIndexDef( _, err = conn.indexClient.Create(ctx, indexCreateReq) if err != nil { - msg := "failed to create index" + msg := "failed to create index from definition" logger.Error(msg, slog.Any("error", err)) return NewAVSErrorFromGrpc(msg, err) @@ -804,7 +847,7 @@ func (c *Client) IndexCreateFromIndexDef( // ctx (context.Context): The context for the operation. // namespace (string): The namespace of the index. // name (string): The name of the index. -// metadata (map[string]string): Metadata to update on the index. +// labels (map[string]string): Labels to update on the index. // hnswParams (*protos.HnswIndexUpdate): The HNSW parameters to update. 
// // Returns: @@ -814,14 +857,14 @@ func (c *Client) IndexUpdate( ctx context.Context, namespace string, indexName string, - metadata map[string]string, + labels map[string]string, hnswParams *protos.HnswIndexUpdate, ) error { logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) logger.DebugContext(ctx, "updating index") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to update index" logger.Error(msg, slog.Any("error", err)) @@ -834,7 +877,7 @@ func (c *Client) IndexUpdate( Namespace: namespace, Name: indexName, }, - Labels: metadata, + Labels: labels, Update: &protos.IndexUpdateRequest_HnswIndexUpdate{ HnswIndexUpdate: hnswParams, }, @@ -866,7 +909,7 @@ func (c *Client) IndexDrop(ctx context.Context, namespace, indexName string) err logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) logger.DebugContext(ctx, "dropping index") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to drop index" logger.Error(msg, slog.Any("error", err)) @@ -909,7 +952,7 @@ func (c *Client) IndexDrop(ctx context.Context, namespace, indexName string) err func (c *Client) IndexList(ctx context.Context, applyDefaults bool) (*protos.IndexDefinitionList, error) { c.logger.DebugContext(ctx, "listing indexes") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to get indexes" @@ -956,7 +999,7 @@ func (c *Client) IndexGet( logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) logger.DebugContext(ctx, "getting index") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to get index" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ 
-999,7 +1042,7 @@ func (c *Client) IndexGetStatus(ctx context.Context, namespace, indexName string logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) logger.DebugContext(ctx, "getting index status") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to get index status" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1041,7 +1084,7 @@ func (c *Client) GcInvalidVertices(ctx context.Context, namespace, indexName str logger.DebugContext(ctx, "garbage collection invalid vertices") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to garbage collect invalid vertices" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1084,7 +1127,7 @@ func (c *Client) CreateUser(ctx context.Context, username, password string, role logger := c.logger.With(slog.String("username", username), slog.Any("roles", roles)) logger.DebugContext(ctx, "creating user") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to create user" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1123,7 +1166,7 @@ func (c *Client) UpdateCredentials(ctx context.Context, username, password strin logger := c.logger.With(slog.String("username", username)) logger.DebugContext(ctx, "updating user credentials") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to update user credentials" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1160,7 +1203,7 @@ func (c *Client) DropUser(ctx context.Context, username string) error { logger := c.logger.With(slog.String("username", username)) logger.DebugContext(ctx, "dropping user") - conn, err := c.channelProvider.GetRandomConn() + conn, err := 
c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to drop user" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1198,7 +1241,7 @@ func (c *Client) GetUser(ctx context.Context, username string) (*protos.User, er logger := c.logger.With(slog.String("username", username)) logger.DebugContext(ctx, "getting user") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to get user" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1234,7 +1277,7 @@ func (c *Client) GetUser(ctx context.Context, username string) (*protos.User, er func (c *Client) ListUsers(ctx context.Context) (*protos.ListUsersResponse, error) { c.logger.DebugContext(ctx, "listing users") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to list users" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1244,7 +1287,7 @@ func (c *Client) ListUsers(ctx context.Context) (*protos.ListUsersResponse, erro usersResp, err := conn.userAdminClient.ListUsers(ctx, &emptypb.Empty{}) if err != nil { - msg := "failed to lists users" + msg := "failed to list users" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) return nil, NewAVSErrorFromGrpc(msg, err) @@ -1268,7 +1311,7 @@ func (c *Client) GrantRoles(ctx context.Context, username string, roles []string logger := c.logger.With(slog.String("username", username), slog.Any("roles", roles)) logger.DebugContext(ctx, "granting user roles") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to grant user roles" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1307,7 +1350,7 @@ func (c *Client) RevokeRoles(ctx context.Context, username string, roles []strin logger := c.logger.With(slog.String("username", username), slog.Any("roles", roles)) logger.DebugContext(ctx, "revoking user 
roles") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to revoke user roles" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1344,7 +1387,7 @@ func (c *Client) RevokeRoles(ctx context.Context, username string, roles []strin func (c *Client) ListRoles(ctx context.Context) (*protos.ListRolesResponse, error) { c.logger.DebugContext(ctx, "listing roles") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to list roles" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1354,7 +1397,7 @@ func (c *Client) ListRoles(ctx context.Context) (*protos.ListRolesResponse, erro rolesResp, err := conn.userAdminClient.ListRoles(ctx, &emptypb.Empty{}) if err != nil { - msg := "failed to lists roles" + msg := "failed to list roles" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) return nil, NewAVSErrorFromGrpc(msg, err) @@ -1364,6 +1407,7 @@ func (c *Client) ListRoles(ctx context.Context) (*protos.ListRolesResponse, erro } // NodeIDs returns a list of all the node IDs that the client is connected to. +// If load-balancer is set true no NodeIDs will be returned. // If a node is accessible but not a part of the cluster it will not be // returned. 
// @@ -1377,7 +1421,7 @@ func (c *Client) ListRoles(ctx context.Context) (*protos.ListRolesResponse, erro func (c *Client) NodeIDs(ctx context.Context) []*protos.NodeId { c.logger.DebugContext(ctx, "getting cluster info") - ids := c.channelProvider.GetNodeIDs() + ids := c.connectionProvider.GetNodeIDs() nodeIDs := make([]*protos.NodeId, len(ids)) for i, id := range ids { @@ -1451,7 +1495,7 @@ func (c *Client) ClusteringState(ctx context.Context, nodeID *protos.NodeId) (*p conn, err := c.getConnection(nodeID) if err != nil { - msg := "failed to list roles" + msg := "failed to get clustering state" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) return nil, NewAVSError(msg, err) @@ -1529,7 +1573,7 @@ func (c *Client) About(ctx context.Context, nodeID *protos.NodeId) (*protos.Abou msg := "failed to make about request" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) - return nil, NewAVSErrorFromGrpc(msg, err) + return nil, NewAVSError(msg, err) } resp, err := conn.aboutClient.Get(ctx, &protos.AboutRequest{}) @@ -1545,10 +1589,10 @@ func (c *Client) About(ctx context.Context, nodeID *protos.NodeId) (*protos.Abou func (c *Client) getConnection(nodeID *protos.NodeId) (*connection, error) { if nodeID == nil { - return c.channelProvider.GetSeedConn() + return c.connectionProvider.GetSeedConn() } - return c.channelProvider.GetNodeConn(nodeID.GetId()) + return c.connectionProvider.GetNodeConn(nodeID.GetId()) } // waitForIndexCreation waits for an index to be created and blocks until it is. 
@@ -1560,7 +1604,7 @@ func (c *Client) waitForIndexCreation(ctx context.Context, ) error { logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to wait for index creation" logger.Error(msg, slog.Any("error", err)) @@ -1611,7 +1655,7 @@ func (c *Client) waitForIndexCreation(ctx context.Context, func (c *Client) waitForIndexDrop(ctx context.Context, namespace, indexName string, waitInterval time.Duration) error { logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to wait for index deletion" logger.Error(msg, slog.Any("error", err)) diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..f03c1a5 --- /dev/null +++ b/client_test.go @@ -0,0 +1,3621 @@ +//go:build unit + +package avs + +import ( + "context" + "fmt" + "io" + "log/slog" + "testing" + "time" + + "github.com/aerospike/avs-client-go/protos" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func TestInsert_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedPutRequest := &protos.PutRequest{ + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + WriteType: protos.WriteType_INSERT_ONLY, + Fields: []*protos.Field{ + { + Name: "field1", + Value: &protos.Value{Value: &protos.Value_StringValue{StringValue: "value1"}}, + }, + }, + IgnoreMemQueueFull: false, + } + + // Set up expectations for transactClient.Put() + mockTransactClient. + EXPECT(). + Put(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). + Do(func(ctx context.Context, in *protos.PutRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedPutRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + recordData := map[string]interface{}{"field1": "value1"} + ignoreMemQueueFull := false + + err = client.Insert(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + assert.NoError(t, err) +} + +func TestInsert_FailsGettingConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + recordData := map[string]interface{}{"field1": "value1"} + ignoreMemQueueFull := false + + err = client.Insert(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToInsertRecord, fmt.Errorf("foo"))) +} + +func TestInsert_FailsConvertingKey(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := struct{}{} + recordData := map[string]interface{}{"field1": "value1"} + ignoreMemQueueFull := false + + err = client.Insert(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToInsertRecord, fmt.Errorf("unsupported key type: struct {}"))) +} + +func TestInsert_FailsConvertingFields(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "key" + recordData := map[string]interface{}{"field1": struct{}{}} + ignoreMemQueueFull := false + + err = client.Insert(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToInsertRecord, fmt.Errorf("error converting field value for key 'field1': unsupported value type: struct {}"))) +} + +func TestInsert_FailsPutRequest(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockTransactClient. + EXPECT(). + Put(gomock.Any(), gomock.Any()). 
+ Return(&emptypb.Empty{}, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "key" + recordData := map[string]interface{}{"field1": "value1"} + ignoreMemQueueFull := false + + err = client.Insert(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToInsertRecord, fmt.Errorf("foo"))) +} + +func TestUpdate_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock connProvider + mockConnProvider := NewMockconnProvider(ctrl) + // Create a mock transactClient + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + // Create a mock connection + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedPutRequest := &protos.PutRequest{ + // You need to fill this with expected values + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + WriteType: protos.WriteType_UPDATE_ONLY, + Fields: []*protos.Field{ + { + Name: "field1", + Value: &protos.Value{Value: &protos.Value_StringValue{StringValue: "value1"}}, + }, + }, + IgnoreMemQueueFull: false, + } + + // Set up expectations for transactClient.Put() + mockTransactClient. + EXPECT(). + Put(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). 
+ Do(func(ctx context.Context, in *protos.PutRequest, opts ...grpc.CallOption) { + // Optionally, you can assert that req matches expectedPutRequest + assert.Equal(t, expectedPutRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + recordData := map[string]interface{}{"field1": "value1"} + ignoreMemQueueFull := false + + // Call the method under test + err = client.Update(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + // Assert no error occurred + assert.NoError(t, err) +} + +func TestReplace_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock connProvider + mockConnProvider := NewMockconnProvider(ctrl) + // Create a mock transactClient + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + // Create a mock connection + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedPutRequest := &protos.PutRequest{ + // You need to fill this with expected values + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + WriteType: protos.WriteType_UPSERT, + Fields: []*protos.Field{ + { + Name: "field1", + Value: &protos.Value{Value: &protos.Value_StringValue{StringValue: "value1"}}, + }, + }, + IgnoreMemQueueFull: false, + } + + // Set up expectations for transactClient.Put() + mockTransactClient. + EXPECT(). + Put(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). 
+ Do(func(ctx context.Context, in *protos.PutRequest, opts ...grpc.CallOption) { + // Optionally, you can assert that req matches expectedPutRequest + assert.Equal(t, expectedPutRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + recordData := map[string]interface{}{"field1": "value1"} + ignoreMemQueueFull := false + + // Call the method under test + err = client.Upsert(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + // Assert no error occurred + assert.NoError(t, err) +} + +func TestGet_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedGetRequest := &protos.GetRequest{ + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + Projection: &protos.ProjectionSpec{ + Include: &protos.ProjectionFilter{ + Type: protos.ProjectionType_ALL, + }, + Exclude: &protos.ProjectionFilter{ + Type: protos.ProjectionType_NONE, + }, + }, + } + + protosRecord := &protos.Record{ + Fields: []*protos.Field{ + {Name: "field1", Value: &protos.Value{Value: &protos.Value_StringValue{StringValue: "value1"}}}, + }, + Metadata: &protos.Record_AerospikeMetadata{ + AerospikeMetadata: &protos.AerospikeRecordMetadata{ + Generation: 10, + Expiration: 1, + }, + }, + } + + expTime := AerospikeEpoch.Add(time.Second * 1) + + expectedRecord := &Record{ + Data: map[string]any{ + "field1": "value1", + }, + Generation: 10, + Expiration: &expTime, + } + + mockTransactClient. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(protosRecord, nil). + Do(func(ctx context.Context, in *protos.GetRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedGetRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + + record, err := client.Get(ctx, namespace, set, key, nil, nil) + + assert.NoError(t, err) + assert.Equal(t, expectedRecord, record) +} + +func TestGet_FailsGettingConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + + _, err = client.Get(ctx, namespace, set, key, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToGetRecord, fmt.Errorf("foo"))) +} + +func TestGet_FailsConvertingKey(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := struct{}{} + + _, err = client.Get(ctx, namespace, set, key, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToGetRecord, fmt.Errorf("unsupported key type: struct {}"))) +} + +func TestGet_FailsGetRequest(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + mockTransactClient. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "key" + + _, err = client.Get(ctx, namespace, set, key, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToGetRecord, fmt.Errorf("foo"))) +} + +func TestDelete_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedDeleteRequest := &protos.DeleteRequest{ + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + } + + mockTransactClient. + EXPECT(). + Delete(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). + Do(func(ctx context.Context, in *protos.DeleteRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedDeleteRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + + err = client.Delete(ctx, namespace, set, key) + + assert.NoError(t, err) +} + +func TestDelete_FailsGettingConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). 
+ GetRandomConn(). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + + err = client.Delete(ctx, namespace, set, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToDeleteRecord, fmt.Errorf("foo"))) +} + +func TestDelete_FailsConvertingKey(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := struct{}{} + + err = client.Delete(ctx, namespace, set, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToDeleteRecord, fmt.Errorf("unsupported key type: struct {}"))) +} + +func TestDelete_FailsDeleteRequest(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + mockTransactClient. + EXPECT(). + Delete(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "key" + + err = client.Delete(ctx, namespace, set, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToDeleteRecord, fmt.Errorf("foo"))) +} + +func TestExists_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedExistsRequest := &protos.ExistsRequest{ + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + } + + mockTransactClient. + EXPECT(). + Exists(gomock.Any(), gomock.Any()). + Return(&protos.Boolean{ + Value: true, + }, nil). 
+ Do(func(ctx context.Context, in *protos.ExistsRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedExistsRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + + exists, err := client.Exists(ctx, namespace, set, key) + + assert.NoError(t, err) + assert.True(t, exists) +} + +func TestExists_FailsGettingConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + + _, err = client.Exists(ctx, namespace, set, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToCheckRecordExists, fmt.Errorf("foo"))) +} + +func TestExists_FailsConvertingKey(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := struct{}{} + + _, err = client.Exists(ctx, namespace, set, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToCheckRecordExists, fmt.Errorf("unsupported key type: struct {}"))) +} + +func TestExists_FailsDeleteRequest(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + mockTransactClient. + EXPECT(). + Exists(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "key" + + _, err = client.Exists(ctx, namespace, set, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToCheckRecordExists, fmt.Errorf("foo"))) +} + +func TestIsIndexed_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedIsIndexedRequest := &protos.IsIndexedRequest{ + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + mockTransactClient. + EXPECT(). + IsIndexed(gomock.Any(), gomock.Any()). + Return(&protos.Boolean{ + Value: true, + }, nil). + Do(func(ctx context.Context, in *protos.IsIndexedRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedIsIndexedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + indexName := "testIndex" + + exists, err := client.IsIndexed(ctx, namespace, set, indexName, key) + + assert.NoError(t, err) + assert.True(t, exists) +} + +func TestIsIndexed_FailsGettingConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + indexName := "testIndex" + + _, err = client.IsIndexed(ctx, namespace, set, indexName, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToCheckIsIndexed, fmt.Errorf("foo"))) +} + +func TestIsIndexed_FailsConvertingKey(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := struct{}{} + indexName := "testIndex" + + _, err = client.IsIndexed(ctx, namespace, set, indexName, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToCheckIsIndexed, fmt.Errorf("unsupported key type: struct {}"))) +} + +func TestIsIndexed_FailsDeleteRequest(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + mockTransactClient. + EXPECT(). + IsIndexed(gomock.Any(), gomock.Any()). 
+ Return(nil, fmt.Errorf("foo")) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "key" + indexName := "testIndex" + + _, err = client.IsIndexed(ctx, namespace, set, indexName, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToCheckIsIndexed, fmt.Errorf("foo"))) +} + +func TestVectorSearchFloat32_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + mockVectorSearchClient := protos.NewMockTransactService_VectorSearchClient(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedVectorSearchFloat32Request := &protos.VectorSearchRequest{ + Index: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + QueryVector: &protos.Vector{ + Data: &protos.Vector_FloatData{ + FloatData: &protos.FloatData{ + Value: []float32{1.0, 2.0, 3.0}, + }, + }, + }, + Limit: 7, + SearchParams: &protos.VectorSearchRequest_HnswSearchParams{ + HnswSearchParams: &protos.HnswSearchParams{ + Ef: Ptr(uint32(8)), + }, + }, + Projection: &protos.ProjectionSpec{ + Include: &protos.ProjectionFilter{ + Type: protos.ProjectionType_ALL, + }, + Exclude: &protos.ProjectionFilter{ + Type: protos.ProjectionType_NONE, + }, + }, + } + + mockTransactClient. + EXPECT(). + VectorSearch(gomock.Any(), gomock.Any()). 
+ Return(mockVectorSearchClient, nil). + Do(func(ctx context.Context, in *protos.VectorSearchRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedVectorSearchFloat32Request, in) + }) + + vectorCounter := 0 + mockVectorSearchClient. + EXPECT(). + Recv(). + AnyTimes(). + DoAndReturn( + func() (*protos.Neighbor, error) { + vectorCounter++ + + if vectorCounter == 4 { + return nil, io.EOF + } + + return &protos.Neighbor{ + Key: &protos.Key{ + Namespace: "testNamespace", + Set: Ptr("testSet"), + Value: &protos.Key_StringValue{ + StringValue: fmt.Sprintf("key-%d", vectorCounter), + }, + }, + Record: &protos.Record{ + Fields: []*protos.Field{ + { + Name: "field1", + Value: &protos.Value{ + Value: &protos.Value_StringValue{ + StringValue: "value1", + }, + }, + }, + }, + Metadata: &protos.Record_AerospikeMetadata{ + AerospikeMetadata: &protos.AerospikeRecordMetadata{ + Generation: uint32(vectorCounter), + Expiration: uint32(vectorCounter), + }, + }, + }, + Distance: float32(vectorCounter), + }, nil + }, + ) + + expectedNeighbors := []*Neighbor{ + { + Record: &Record{ + Data: map[string]any{ + "field1": "value1", + }, + Generation: uint32(1), + Expiration: Ptr(AerospikeEpoch.Add(time.Second * 1)), + }, + Set: Ptr("testSet"), + Key: "key-1", + Namespace: "testNamespace", + Distance: float32(1), + }, + { + Record: &Record{ + Data: map[string]any{ + "field1": "value1", + }, + Generation: uint32(2), + Expiration: Ptr(AerospikeEpoch.Add(time.Second * 2)), + }, + Set: Ptr("testSet"), + Key: "key-2", + Namespace: "testNamespace", + Distance: float32(2), + }, + { + Record: &Record{ + Data: map[string]any{ + "field1": "value1", + }, + Generation: uint32(3), + Expiration: Ptr(AerospikeEpoch.Add(time.Second * 3)), + }, + Set: Ptr("testSet"), + Key: "key-3", + Namespace: "testNamespace", + Distance: float32(3), + }, + } + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare 
input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + vector := []float32{1.0, 2.0, 3.0} + limit := uint32(7) + searchParams := &protos.HnswSearchParams{ + Ef: Ptr(uint32(8)), + } + + neighbors, err := client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) + + assert.NoError(t, err) + assert.Equal(t, len(expectedNeighbors), len(neighbors)) + for i, _ := range neighbors { + assert.EqualExportedValues(t, expectedNeighbors[i], neighbors[i]) + } +} + +func TestVectorSearchFloat32_FailsGettingConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + vector := []float32{1.0, 2.0, 3.0} + limit := uint32(7) + searchParams := &protos.HnswSearchParams{ + Ef: Ptr(uint32(8)), + } + + _, err = client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to search for vector", fmt.Errorf("foo"))) +} + +func TestVectorSearchFloat32_FailsVectorSearch(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockTransactClient. + EXPECT(). 
+ VectorSearch(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + vector := []float32{1.0, 2.0, 3.0} + limit := uint32(7) + searchParams := &protos.HnswSearchParams{ + Ef: Ptr(uint32(8)), + } + + _, err = client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to search for vector", fmt.Errorf("foo"))) +} + +func TestVectorSearchFloat32_FailedToRecvAllNeighbors(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + mockVectorSearchClient := protos.NewMockTransactService_VectorSearchClient(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockTransactClient. + EXPECT(). + VectorSearch(gomock.Any(), gomock.Any()). + Return(mockVectorSearchClient, nil) + + mockVectorSearchClient. + EXPECT(). + Recv(). + AnyTimes(). 
+ DoAndReturn( + func() (*protos.Neighbor, error) { + return nil, fmt.Errorf("foo") + }, + ) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + vector := []float32{1.0, 2.0, 3.0} + limit := uint32(7) + searchParams := &protos.HnswSearchParams{ + Ef: Ptr(uint32(8)), + } + + _, err = client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to receive all neighbors", fmt.Errorf("foo"))) +} + +type unknowKeyValue struct{} + +func (u *unknowKeyValue) isKey_Value() {} + +func TestVectorSearchFloat32_FailedToConvertNeighbor(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + mockVectorSearchClient := protos.NewMockTransactService_VectorSearchClient(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockTransactClient. + EXPECT(). + VectorSearch(gomock.Any(), gomock.Any()). + Return(mockVectorSearchClient, nil) + + mockVectorSearchClient. + EXPECT(). + Recv(). + AnyTimes(). 
+ DoAndReturn( + func() (*protos.Neighbor, error) { + return &protos.Neighbor{ + Key: &protos.Key{Value: protos.NewMockisKey_Value(ctrl)}, + }, nil + }, + ) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + vector := []float32{1.0, 2.0, 3.0} + limit := uint32(7) + searchParams := &protos.HnswSearchParams{ + Ef: Ptr(uint32(8)), + } + + _, err = client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to convert neighbor", fmt.Errorf("error converting neighbor: unsupported key value type: *protos.MockisKey_Value"))) +} + +func TestWaitForIndexCompletion_SuccessAfterZeroCountReturnedTwice(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedGetStatusRequest := &protos.IndexStatusRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + indexStatus := &protos.IndexStatusResponse{ + UnmergedRecordCount: 0, + } + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Times(3). + Return(indexStatus, nil). 
+ Do(func(ctx context.Context, in *protos.IndexStatusRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedGetStatusRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + + err = client.WaitForIndexCompletion(ctx, namespace, indexName, time.Millisecond*1) + + assert.NoError(t, err) +} + +func TestWaitForIndexCompletion_SuccessAfterNonZeroUnmergedCount(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedGetStatusRequest := &protos.IndexStatusRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + indexStatusCount := 0 + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Times(2). 
+ DoAndReturn(func(ctx context.Context, in *protos.IndexStatusRequest, opts ...grpc.CallOption) (*protos.IndexStatusResponse, error) { + assert.Equal(t, expectedGetStatusRequest, in) + indexStatusCount++ + + if indexStatusCount == 1 { + return &protos.IndexStatusResponse{ + UnmergedRecordCount: 1, + }, nil + } else { + return &protos.IndexStatusResponse{ + UnmergedRecordCount: 0, + }, nil + } + + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + + err = client.WaitForIndexCompletion(ctx, namespace, indexName, time.Millisecond*1) + + assert.NoError(t, err) +} + +func TestWaitForIndexCompletion_FailToGetRandomConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + + err = client.WaitForIndexCompletion(ctx, namespace, indexName, time.Millisecond*1) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to wait for index completion", fmt.Errorf("foo"))) +} + +func TestWaitForIndexCompletion_FailGetStatusCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + + err = client.WaitForIndexCompletion(ctx, namespace, indexName, time.Millisecond*1) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to wait for index completion", fmt.Errorf("foo"))) +} + +func TestWaitForIndexCompletion_FailTimeout(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). 
+ GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Return(nil, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + + namespace := "testNamespace" + indexName := "testIndex" + + err = client.WaitForIndexCompletion(ctx, namespace, indexName, time.Second*1) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("waiting for index completion canceled", fmt.Errorf("context deadline exceeded"))) +} + +func TestIndexCreateFromIndexDef_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Times(2). + Return(mockConn, nil) + + indexDef := &protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + expectedIndexCreateRequest := &protos.IndexCreateRequest{ + Definition: indexDef, + } + + mockIndexClient. + EXPECT(). + Create(gomock.Any(), gomock.Any()). + Return(nil, nil). + Do(func(ctx context.Context, in *protos.IndexCreateRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedIndexCreateRequest, in) + }) + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). 
+ Return(nil, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + err = client.IndexCreateFromIndexDef(ctx, indexDef) + + assert.NoError(t, err) +} + +func TestIndexCreateFromIndexDef_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + indexDef := &protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + err = client.IndexCreateFromIndexDef(ctx, indexDef) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to create index from definition", fmt.Errorf("foo"))) +} + +func TestIndexCreateFromIndexDef_FailCreateCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + Create(gomock.Any(), gomock.Any()). 
+ Return(nil, fmt.Errorf("foo")) + + indexDef := &protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + err = client.IndexCreateFromIndexDef(ctx, indexDef) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to create index from definition", fmt.Errorf("foo"))) +} + +func TestIndexUpdate_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedIndexUpdateRequest := &protos.IndexUpdateRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + Labels: map[string]string{ + "foo": "bar", + }, + Update: &protos.IndexUpdateRequest_HnswIndexUpdate{ + HnswIndexUpdate: &protos.HnswIndexUpdate{ + MaxMemQueueSize: Ptr(uint32(10)), + }, + }, + } + + mockIndexClient. + EXPECT(). + Update(gomock.Any(), gomock.Any()). + Return(nil, nil). 
+ Do(func(ctx context.Context, in *protos.IndexUpdateRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedIndexUpdateRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + testMetadata := map[string]string{ + "foo": "bar", + } + hnswParams := &protos.HnswIndexUpdate{ + MaxMemQueueSize: Ptr(uint32(10)), + } + + err = client.IndexUpdate(ctx, testNamespace, testIndex, testMetadata, hnswParams) + + assert.NoError(t, err) +} + +func TestIndexUpdate_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + testMetadata := map[string]string{ + "foo": "bar", + } + hnswParams := &protos.HnswIndexUpdate{ + MaxMemQueueSize: Ptr(uint32(10)), + } + + err = client.IndexUpdate(ctx, testNamespace, testIndex, testMetadata, hnswParams) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to update index", fmt.Errorf("foo"))) +} + +func TestIndexUpdate_FailUpdateCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). 
+ GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + Update(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + testMetadata := map[string]string{ + "foo": "bar", + } + hnswParams := &protos.HnswIndexUpdate{ + MaxMemQueueSize: Ptr(uint32(10)), + } + + err = client.IndexUpdate(ctx, testNamespace, testIndex, testMetadata, hnswParams) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to update index", fmt.Errorf("bar"))) +} + +func TestIndexDrop_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Times(2). + Return(mockConn, nil) + + expectedIndexDropRequest := &protos.IndexDropRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + mockIndexClient. + EXPECT(). + Drop(gomock.Any(), gomock.Any()). + Return(nil, nil). + Do(func(ctx context.Context, in *protos.IndexDropRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedIndexDropRequest, in) + }) + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). 
+ Return(nil, status.Errorf(codes.NotFound, "foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + err = client.IndexDrop(ctx, testNamespace, testIndex) + + assert.NoError(t, err) +} + +func TestIndexDrop_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + err = client.IndexDrop(ctx, testNamespace, testIndex) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to drop index", fmt.Errorf("foo"))) +} + +func TestIndexDrop_FailDropCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + Drop(gomock.Any(), gomock.Any()). 
+ Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + err = client.IndexDrop(ctx, testNamespace, testIndex) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to drop index", fmt.Errorf("bar"))) +} + +func TestIndexList_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedIndexListRequest := &protos.IndexListRequest{ + ApplyDefaults: Ptr(true), + } + + expectedIndexDefs := &protos.IndexDefinitionList{ + Indices: []*protos.IndexDefinition{ + { + Id: &protos.IndexId{ + Namespace: "testNamespace0", + Name: "testIndex0", + }, + }, + { + Id: &protos.IndexId{ + Namespace: "testNamespace1", + Name: "testIndex1", + }, + }, + }, + } + + mockIndexClient. + EXPECT(). + List(gomock.Any(), gomock.Any()). + Return(expectedIndexDefs, nil). 
+		Do(func(ctx context.Context, in *protos.IndexListRequest, opts ...grpc.CallOption) {
+			assert.Equal(t, expectedIndexListRequest, in)
+		})
+
+	// Create the client with the mock connProvider
+	client, err := newClient(mockConnProvider, nil, slog.Default())
+	assert.NoError(t, err)
+
+	// Prepare input parameters
+	ctx := context.Background()
+
+	indexDefs, err := client.IndexList(ctx, true)
+
+	assert.NoError(t, err)
+	assert.EqualExportedValues(t, expectedIndexDefs, indexDefs)
+}
+
+func TestIndexList_FailGetConn(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockConnProvider := NewMockconnProvider(ctrl)
+	mockIndexClient := protos.NewMockIndexServiceClient(ctrl)
+	mockConn := &connection{
+		indexClient: mockIndexClient,
+	}
+
+	mockConnProvider.
+		EXPECT().
+		GetRandomConn().
+		Return(mockConn, fmt.Errorf("foo"))
+
+	// Create the client with the mock connProvider
+	client, err := newClient(mockConnProvider, nil, slog.Default())
+	assert.NoError(t, err)
+
+	// Prepare input parameters
+	ctx := context.Background()
+	_, err = client.IndexList(ctx, true)
+
+	var avsError *Error
+	assert.ErrorAs(t, err, &avsError)
+	assert.Equal(t, avsError, NewAVSError("failed to get indexes", fmt.Errorf("foo")))
+}
+
+func TestIndexList_FailListCall(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockConnProvider := NewMockconnProvider(ctrl)
+	mockIndexClient := protos.NewMockIndexServiceClient(ctrl)
+	mockConn := &connection{
+		indexClient: mockIndexClient,
+	}
+
+	mockConnProvider.
+		EXPECT().
+		GetRandomConn().
+		Return(mockConn, nil)
+
+	mockIndexClient.
+		EXPECT().
+		List(gomock.Any(), gomock.Any()).
+ Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.IndexList(ctx, true) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get indexes", fmt.Errorf("bar"))) +} + +func TestIndexGet_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedIndexListRequest := &protos.IndexGetRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + ApplyDefaults: Ptr(true), + } + + expectedIndexDefs := &protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + mockIndexClient. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(expectedIndexDefs, nil). 
+		Do(func(ctx context.Context, in *protos.IndexGetRequest, opts ...grpc.CallOption) {
+			assert.Equal(t, expectedIndexListRequest, in)
+		})
+
+	// Create the client with the mock connProvider
+	client, err := newClient(mockConnProvider, nil, slog.Default())
+	assert.NoError(t, err)
+
+	// Prepare input parameters
+	ctx := context.Background()
+	testNamespace := "testNamespace"
+	testIndex := "testIndex"
+
+	indexDefs, err := client.IndexGet(ctx, testNamespace, testIndex, true)
+
+	assert.NoError(t, err)
+	assert.EqualExportedValues(t, expectedIndexDefs, indexDefs)
+}
+
+func TestIndexGet_FailGetConn(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockConnProvider := NewMockconnProvider(ctrl)
+	mockIndexClient := protos.NewMockIndexServiceClient(ctrl)
+	mockConn := &connection{
+		indexClient: mockIndexClient,
+	}
+
+	mockConnProvider.
+		EXPECT().
+		GetRandomConn().
+		Return(mockConn, fmt.Errorf("foo"))
+
+	// Create the client with the mock connProvider
+	client, err := newClient(mockConnProvider, nil, slog.Default())
+	assert.NoError(t, err)
+
+	// Prepare input parameters
+	ctx := context.Background()
+	testNamespace := "testNamespace"
+	testIndex := "testIndex"
+
+	_, err = client.IndexGet(ctx, testNamespace, testIndex, true)
+
+	var avsError *Error
+	assert.ErrorAs(t, err, &avsError)
+	assert.Equal(t, avsError, NewAVSError("failed to get index", fmt.Errorf("foo")))
+}
+
+func TestIndexGet_FailGetCall(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockConnProvider := NewMockconnProvider(ctrl)
+	mockIndexClient := protos.NewMockIndexServiceClient(ctrl)
+	mockConn := &connection{
+		indexClient: mockIndexClient,
+	}
+
+	mockConnProvider.
+		EXPECT().
+		GetRandomConn().
+		Return(mockConn, nil)
+
+	mockIndexClient.
+		EXPECT().
+		Get(gomock.Any(), gomock.Any()).
+ Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + _, err = client.IndexGet(ctx, testNamespace, testIndex, true) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get index", fmt.Errorf("bar"))) +} + +func TestIndexGetStatus_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedIndexStatusRequest := &protos.IndexStatusRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + expectedIndexStatusResp := &protos.IndexStatusResponse{ + UnmergedRecordCount: 9, + } + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Return(expectedIndexStatusResp, nil). 
+		Do(func(ctx context.Context, in *protos.IndexStatusRequest, opts ...grpc.CallOption) {
+			assert.Equal(t, expectedIndexStatusRequest, in)
+		})
+
+	// Create the client with the mock connProvider
+	client, err := newClient(mockConnProvider, nil, slog.Default())
+	assert.NoError(t, err)
+
+	// Prepare input parameters
+	ctx := context.Background()
+	testNamespace := "testNamespace"
+	testIndex := "testIndex"
+
+	indexDefs, err := client.IndexGetStatus(ctx, testNamespace, testIndex)
+
+	assert.NoError(t, err)
+	assert.EqualExportedValues(t, expectedIndexStatusResp, indexDefs)
+}
+
+func TestIndexGetStatus_FailGetConn(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockConnProvider := NewMockconnProvider(ctrl)
+	mockIndexClient := protos.NewMockIndexServiceClient(ctrl)
+	mockConn := &connection{
+		indexClient: mockIndexClient,
+	}
+
+	mockConnProvider.
+		EXPECT().
+		GetRandomConn().
+		Return(mockConn, fmt.Errorf("foo"))
+
+	// Create the client with the mock connProvider
+	client, err := newClient(mockConnProvider, nil, slog.Default())
+	assert.NoError(t, err)
+
+	// Prepare input parameters
+	ctx := context.Background()
+	testNamespace := "testNamespace"
+	testIndex := "testIndex"
+
+	_, err = client.IndexGetStatus(ctx, testNamespace, testIndex)
+
+	var avsError *Error
+	assert.ErrorAs(t, err, &avsError)
+	assert.Equal(t, avsError, NewAVSError("failed to get index status", fmt.Errorf("foo")))
+}
+
+func TestIndexGetStatus_FailGetStatusCall(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockConnProvider := NewMockconnProvider(ctrl)
+	mockIndexClient := protos.NewMockIndexServiceClient(ctrl)
+	mockConn := &connection{
+		indexClient: mockIndexClient,
+	}
+
+	mockConnProvider.
+		EXPECT().
+		GetRandomConn().
+		Return(mockConn, nil)
+
+	mockIndexClient.
+		EXPECT().
+		GetStatus(gomock.Any(), gomock.Any()).
+ Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + _, err = client.IndexGetStatus(ctx, testNamespace, testIndex) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get index status", fmt.Errorf("bar"))) +} + +func TestIndexGcInvalidVertices_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + cutoffTime := time.Now() + expectedIndexStatusRequest := &protos.GcInvalidVerticesRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + CutoffTimestamp: cutoffTime.Unix(), + } + + mockIndexClient. + EXPECT(). + GcInvalidVertices(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). 
+		Do(func(ctx context.Context, in *protos.GcInvalidVerticesRequest, opts ...grpc.CallOption) {
+			assert.Equal(t, expectedIndexStatusRequest, in)
+		})
+
+	// Create the client with the mock connProvider
+	client, err := newClient(mockConnProvider, nil, slog.Default())
+	assert.NoError(t, err)
+
+	// Prepare input parameters
+	ctx := context.Background()
+	testNamespace := "testNamespace"
+	testIndex := "testIndex"
+
+	err = client.GcInvalidVertices(ctx, testNamespace, testIndex, cutoffTime)
+
+	assert.NoError(t, err)
+}
+
+func TestIndexGcInvalidVertices_FailGetConn(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockConnProvider := NewMockconnProvider(ctrl)
+	mockIndexClient := protos.NewMockIndexServiceClient(ctrl)
+	mockConn := &connection{
+		indexClient: mockIndexClient,
+	}
+
+	mockConnProvider.
+		EXPECT().
+		GetRandomConn().
+		Return(mockConn, fmt.Errorf("foo"))
+
+	// Create the client with the mock connProvider
+	client, err := newClient(mockConnProvider, nil, slog.Default())
+	assert.NoError(t, err)
+
+	// Prepare input parameters
+	ctx := context.Background()
+	testNamespace := "testNamespace"
+	testIndex := "testIndex"
+	cutoffTime := time.Now()
+
+	err = client.GcInvalidVertices(ctx, testNamespace, testIndex, cutoffTime)
+
+	var avsError *Error
+	assert.ErrorAs(t, err, &avsError)
+	assert.Equal(t, avsError, NewAVSError("failed to garbage collect invalid vertices", fmt.Errorf("foo")))
+}
+
+func TestIndexGcInvalidVertices_FailGcCall(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockConnProvider := NewMockconnProvider(ctrl)
+	mockIndexClient := protos.NewMockIndexServiceClient(ctrl)
+	mockConn := &connection{
+		indexClient: mockIndexClient,
+	}
+
+	mockConnProvider.
+		EXPECT().
+		GetRandomConn().
+		Return(mockConn, nil)
+
+	mockIndexClient.
+		EXPECT().
+		GcInvalidVertices(gomock.Any(), gomock.Any()).
+ Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + cutoffTime := time.Now() + + err = client.GcInvalidVertices(ctx, testNamespace, testIndex, cutoffTime) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to garbage collect invalid vertices", fmt.Errorf("bar"))) +} + +func TestCreateUser_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &protos.AddUserRequest{ + Credentials: &protos.Credentials{ + Username: "testUser", + Credentials: &protos.Credentials_PasswordCredentials{ + PasswordCredentials: &protos.PasswordCredentials{ + Password: "testPass", + }, + }, + }, + Roles: []string{ + "testRole", + }, + } + + mockUserAdminClient. + EXPECT(). + AddUser(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). 
+ Do(func(ctx context.Context, in *protos.AddUserRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testPass := "testPass" + testRoles := []string{"testRole"} + + err = client.CreateUser(ctx, testUser, testPass, testRoles) + + assert.NoError(t, err) +} + +func TestCreateUser_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testPass := "testPass" + testRoles := []string{"testRole"} + + err = client.CreateUser(ctx, testUser, testPass, testRoles) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to create user", fmt.Errorf("foo"))) +} + +func TestCreateUser_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + AddUser(gomock.Any(), gomock.Any()). 
+ Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testPass := "testPass" + testRoles := []string{"testRole"} + + err = client.CreateUser(ctx, testUser, testPass, testRoles) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to create user", fmt.Errorf("bar"))) +} + +func TestUpdateCredentials_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &protos.UpdateCredentialsRequest{ + Credentials: &protos.Credentials{ + Username: "testUser", + Credentials: &protos.Credentials_PasswordCredentials{ + PasswordCredentials: &protos.PasswordCredentials{ + Password: "testPass", + }, + }, + }, + } + + mockUserAdminClient. + EXPECT(). + UpdateCredentials(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). 
+ Do(func(ctx context.Context, in *protos.UpdateCredentialsRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testPass := "testPass" + + err = client.UpdateCredentials(ctx, testUser, testPass) + + assert.NoError(t, err) +} + +func TestUpdateCredentials_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testPass := "testPass" + + err = client.UpdateCredentials(ctx, testUser, testPass) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to update user credentials", fmt.Errorf("foo"))) +} + +func TestUpdateCredentials_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + UpdateCredentials(gomock.Any(), gomock.Any()). 
+ Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testPass := "testPass" + + err = client.UpdateCredentials(ctx, testUser, testPass) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to update user credentials", fmt.Errorf("bar"))) +} + +func TestDropUser_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &protos.DropUserRequest{ + Username: "testUser", + } + + mockUserAdminClient. + EXPECT(). + DropUser(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). + Do(func(ctx context.Context, in *protos.DropUserRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + + err = client.DropUser(ctx, testUser) + + assert.NoError(t, err) +} + +func TestDropUser_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + + err = client.DropUser(ctx, testUser) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to drop user", fmt.Errorf("foo"))) +} + +func TestDropUser_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + DropUser(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + + err = client.DropUser(ctx, testUser) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to drop user", fmt.Errorf("bar"))) +} + +func TestGetUser_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &protos.GetUserRequest{ + Username: "testUser", + } + + expectedUser := &protos.User{ + Username: "testUser", + Roles: []string{ + "testRole", + }, + } + + mockUserAdminClient. + EXPECT(). + GetUser(gomock.Any(), gomock.Any()). + Return(expectedUser, nil). 
+ Do(func(ctx context.Context, in *protos.GetUserRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + + user, err := client.GetUser(ctx, testUser) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedUser, user) +} + +func TestGetUser_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + + _, err = client.GetUser(ctx, testUser) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get user", fmt.Errorf("foo"))) +} + +func TestGetUser_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + GetUser(gomock.Any(), gomock.Any()). 
+ Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + + _, err = client.GetUser(ctx, testUser) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get user", fmt.Errorf("bar"))) +} + +func TestListUsers_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &emptypb.Empty{} + + expectedUsers := &protos.ListUsersResponse{ + Users: []*protos.User{ + { + + Username: "testUser", + Roles: []string{ + "testRole", + }, + }, + }, + } + + mockUserAdminClient. + EXPECT(). + ListUsers(gomock.Any(), gomock.Any()). + Return(expectedUsers, nil). + Do(func(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + user, err := client.ListUsers(ctx) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedUsers, user) +} + +func TestListUsers_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ListUsers(ctx) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to list users", fmt.Errorf("foo"))) +} + +func TestListUsers_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + ListUsers(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ListUsers(ctx) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to list users", fmt.Errorf("bar"))) +} + +func TestRevokeRoles_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &protos.RevokeRolesRequest{ + Username: "testUser", + Roles: []string{"testRole"}, + } + + mockUserAdminClient. + EXPECT(). + RevokeRoles(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). 
+ Do(func(ctx context.Context, in *protos.RevokeRolesRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testRoles := []string{"testRole"} + + err = client.RevokeRoles(ctx, testUser, testRoles) + + assert.NoError(t, err) +} + +func TestRevokeRoles_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testRoles := []string{"testRole"} + + err = client.RevokeRoles(ctx, testUser, testRoles) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to revoke user roles", fmt.Errorf("foo"))) +} + +func TestRevokeRoles_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + RevokeRoles(gomock.Any(), gomock.Any()). 
+ Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testRoles := []string{"testRole"} + + err = client.RevokeRoles(ctx, testUser, testRoles) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to revoke user roles", fmt.Errorf("bar"))) +} + +func TestListRoles_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &emptypb.Empty{} + + expectedRoles := &protos.ListRolesResponse{ + Roles: []*protos.Role{ + { + Id: "testRole", + }, + }, + } + + mockUserAdminClient. + EXPECT(). + ListRoles(gomock.Any(), gomock.Any()). + Return(expectedRoles, nil). + Do(func(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + user, err := client.ListRoles(ctx) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedRoles, user) +} + +func TestListRoles_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). 
+ Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ListRoles(ctx) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to list roles", fmt.Errorf("foo"))) +} + +func TestListRoles_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + ListRoles(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ListRoles(ctx) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to list roles", fmt.Errorf("bar"))) +} + +func TestConnectedNodeEndpoint_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockGrpcConn := NewMockgrpcClientConn(ctrl) + mockConn := &connection{ + grpcConn: mockGrpcConn, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + expectedEndpoint := &protos.ServerEndpoint{ + Address: "1.1.1.1", + Port: 3000, + } + + mockGrpcConn. + EXPECT(). + Target(). 
+ Return("1.1.1.1:3000") + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + endpoint, err := client.ConnectedNodeEndpoint(ctx, nil) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedEndpoint, endpoint) +} + +func TestConnectedNodeEndpoint_FailedGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockGrpcConn := NewMockgrpcClientConn(ctrl) + mockConn := &connection{ + grpcConn: mockGrpcConn, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ConnectedNodeEndpoint(ctx, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get connected endpoint", fmt.Errorf("foo"))) +} + +func TestConnectedNodeEndpoint_FailParsePort(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockGrpcConn := NewMockgrpcClientConn(ctrl) + mockConn := &connection{ + grpcConn: mockGrpcConn, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + mockGrpcConn. + EXPECT(). + Target(). 
+ Return("1.1.1.1:aaaa") + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ConnectedNodeEndpoint(ctx, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Contains(t, avsError.Error(), "failed to parse port") +} + +func TestClusteringState_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockClusterInfoClient := protos.NewMockClusterInfoServiceClient(ctrl) + mockConn := &connection{ + clusterInfoClient: mockClusterInfoClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + expectedRequest := &emptypb.Empty{} + + expectedRoles := &protos.ClusteringState{ + IsInCluster: true, + } + + mockClusterInfoClient. + EXPECT(). + GetClusteringState(gomock.Any(), gomock.Any()). + Return(expectedRoles, nil). + Do(func(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + user, err := client.ClusteringState(ctx, nil) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedRoles, user) +} + +func TestClusteringState_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockClusterInfoClient := protos.NewMockClusterInfoServiceClient(ctrl) + mockConn := &connection{ + clusterInfoClient: mockClusterInfoClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). 
+ Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ClusteringState(ctx, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get clustering state", fmt.Errorf("foo"))) +} + +func TestClusteringState_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockClusterInfoClient := protos.NewMockClusterInfoServiceClient(ctrl) + mockConn := &connection{ + clusterInfoClient: mockClusterInfoClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + mockClusterInfoClient. + EXPECT(). + GetClusteringState(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ClusteringState(ctx, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get clustering state", fmt.Errorf("bar"))) +} + +func TestClusterEndpoints_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockClusterInfoClient := protos.NewMockClusterInfoServiceClient(ctrl) + mockConn := &connection{ + clusterInfoClient: mockClusterInfoClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + listenerName := "test-listener" + expectedRequest := &protos.ClusterNodeEndpointsRequest{ + ListenerName: &listenerName, + } + + expectedResp := &protos.ClusterNodeEndpoints{} + + mockClusterInfoClient. + EXPECT(). 
+ GetClusterEndpoints(gomock.Any(), gomock.Any()). + Return(expectedResp, nil). + Do(func(ctx context.Context, in *protos.ClusterNodeEndpointsRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + user, err := client.ClusterEndpoints(ctx, nil, &listenerName) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedResp, user) +} + +func TestClusterEndpoints_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockClusterInfoClient := protos.NewMockClusterInfoServiceClient(ctrl) + mockConn := &connection{ + clusterInfoClient: mockClusterInfoClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + listenerName := "test-listener" + + _, err = client.ClusterEndpoints(ctx, nil, &listenerName) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get cluster endpoints", fmt.Errorf("foo"))) +} + +func TestClusterEndpoints_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockClusterInfoClient := protos.NewMockClusterInfoServiceClient(ctrl) + mockConn := &connection{ + clusterInfoClient: mockClusterInfoClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + mockClusterInfoClient. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). 
+ Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + listenerName := "test-listener" + _, err = client.ClusterEndpoints(ctx, nil, &listenerName) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get cluster endpoints", fmt.Errorf("bar"))) +} + +func TestAbout_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockAboutClient := protos.NewMockAboutServiceClient(ctrl) + mockConn := &connection{ + aboutClient: mockAboutClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + expectedRequest := &protos.AboutRequest{} + expectedResp := &protos.AboutResponse{} + + mockAboutClient. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(expectedResp, nil). + Do(func(ctx context.Context, in *protos.AboutRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + user, err := client.About(ctx, nil) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedResp, user) +} + +func TestAbout_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockAboutClient := protos.NewMockAboutServiceClient(ctrl) + mockConn := &connection{ + aboutClient: mockAboutClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). 
+ Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.About(ctx, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to make about request", fmt.Errorf("foo"))) +} + +func TestAbout_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockAboutClient := protos.NewMockAboutServiceClient(ctrl) + mockConn := &connection{ + aboutClient: mockAboutClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + mockAboutClient. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, nil, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + _, err = client.About(ctx, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to make about request", fmt.Errorf("bar"))) +} diff --git a/channel_provider.go b/connection_provider.go similarity index 59% rename from channel_provider.go rename to connection_provider.go index d5fb63c..cdca0ad 100644 --- a/channel_provider.go +++ b/connection_provider.go @@ -1,4 +1,4 @@ -// Package avs provides a channel provider for connecting to Aerospike servers. +// Package avs provides a connection provider for connecting to Aerospike servers. 
package avs import ( @@ -7,7 +7,6 @@ import ( "fmt" "log/slog" "math/rand" - "sort" "strings" "sync" "sync/atomic" @@ -21,14 +20,29 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) -var errChannelProviderClosed = errors.New("channel provider is closed") +var errConnectionProviderClosed = errors.New("connectionProvider is closed, cannot get connection") + +type grpcClientConn interface { + grpc.ClientConnInterface + Target() string + Close() error +} + +type tokenManager interface { + RequireTransportSecurity() bool + ScheduleRefresh(func() (*connection, error)) + RefreshToken(context.Context, *connection) error + UnaryInterceptor() grpc.UnaryClientInterceptor + StreamInterceptor() grpc.StreamClientInterceptor + Close() +} // connection represents a gRPC client connection and all the clients (stubs) // for the various AVS services. It's main purpose to remove the need to create // multiple clients for the same connection. This follows the documented grpc // best practice of reusing connections. type connection struct { - grpcConn *grpc.ClientConn + grpcConn grpcClientConn transactClient protos.TransactServiceClient authClient protos.AuthServiceClient userAdminClient protos.UserAdminServiceClient @@ -38,7 +52,7 @@ type connection struct { } // newConnection creates a new connection instance. -func newConnection(conn *grpc.ClientConn) *connection { +func newConnection(conn grpcClientConn) *connection { return &connection{ grpcConn: conn, transactClient: protos.NewTransactServiceClient(conn), @@ -50,27 +64,35 @@ func newConnection(conn *grpc.ClientConn) *connection { } } -// channelAndEndpoints represents a combination of a gRPC client connection and server endpoints. -type channelAndEndpoints struct { +func (conn *connection) close() error { + if conn != nil && conn.grpcConn != nil { + return conn.grpcConn.Close() + } + + return nil +} + +// connectionAndEndpoints represents a combination of a gRPC client connection and server endpoints. 
+type connectionAndEndpoints struct { conn *connection endpoints *protos.ServerEndpointList } -// newConnAndEndpoints creates a new channelAndEndpoints instance. -func newConnAndEndpoints(channel *connection, endpoints *protos.ServerEndpointList) *channelAndEndpoints { - return &channelAndEndpoints{ - conn: channel, +// newConnAndEndpoints creates a new connectionAndEndpoints instance. +func newConnAndEndpoints(conn *connection, endpoints *protos.ServerEndpointList) *connectionAndEndpoints { + return &connectionAndEndpoints{ + conn: conn, endpoints: endpoints, } } -// channelProvider is responsible for managing gRPC client connections to +// connectionProvider is responsible for managing gRPC client connections to // Aerospike servers. // //nolint:govet // We will favor readability over field alignment -type channelProvider struct { +type connectionProvider struct { logger *slog.Logger - nodeConns map[uint64]*channelAndEndpoints + nodeConns map[uint64]*connectionAndEndpoints seedConns []*connection tlsConfig *tls.Config seeds HostPortSlice @@ -79,22 +101,27 @@ type channelProvider struct { clusterID uint64 listenerName *string isLoadBalancer bool - token *tokenManager + token tokenManager stopTendChan chan struct{} closed atomic.Bool + connFactory func(hostPort *HostPort) (*connection, error) } -// newChannelProvider creates a new channelProvider instance. -func newChannelProvider( +// newConnectionProvider creates a new connectionProvider instance. +func newConnectionProvider( ctx context.Context, seeds HostPortSlice, listenerName *string, isLoadBalancer bool, - credentials *UserPassCredentials, + token tokenManager, tlsConfig *tls.Config, logger *slog.Logger, -) (*channelProvider, error) { +) (*connectionProvider, error) { // Initialize the logger. + if logger == nil { + logger = slog.Default() + } + logger = logger.WithGroup("cp") // Validate the seeds. 
@@ -105,12 +132,7 @@ func newChannelProvider( return nil, errors.New(msg) } - // Create a token manager if username and password are provided. - var token *tokenManager - - if credentials != nil { - token = newJWTToken(credentials.username, credentials.password, logger) - + if token != nil { if token.RequireTransportSecurity() && tlsConfig == nil { msg := "tlsConfig is required when username/password authentication" logger.Error(msg) @@ -119,9 +141,9 @@ func newChannelProvider( } } - // Create the channelProvider instance. - cp := &channelProvider{ - nodeConns: make(map[uint64]*channelAndEndpoints), + // Create the connectionProvider instance. + cp := &connectionProvider{ + nodeConns: make(map[uint64]*connectionAndEndpoints), seeds: seeds, listenerName: listenerName, isLoadBalancer: isLoadBalancer, @@ -134,6 +156,15 @@ func newChannelProvider( closed: atomic.Bool{}, } + cp.connFactory = func(hostPort *HostPort) (*connection, error) { + grpcConn, err := createGrcpConn(cp, hostPort) + if err != nil { + return nil, err + } + + return newConnection(grpcConn), nil + } + // Connect to the seed nodes. err := cp.connectToSeeds(ctx) if err != nil { @@ -148,7 +179,7 @@ func newChannelProvider( // Start the tend routine if load balancing is disabled. if !isLoadBalancer { - cp.updateClusterChannels(ctx) // We want at least one tend to occur before we return + cp.updateClusterConns(ctx) // We want at least one tend to occur before we return cp.logger.Debug("starting tend routine") go cp.tend(context.Background()) // Might add a tend specific timeout in the future? @@ -159,8 +190,12 @@ func newChannelProvider( return cp, nil } -// Close closes the channelProvider and releases all resources. -func (cp *channelProvider) Close() error { +// Close closes the connectionProvider and releases all resources. 
+func (cp *connectionProvider) Close() error { + if cp == nil { + return nil + } + if !cp.isLoadBalancer { cp.stopTendChan <- struct{}{} <-cp.stopTendChan @@ -168,20 +203,16 @@ func (cp *channelProvider) Close() error { var firstErr error - if cp.token != nil { - cp.token.Close() - } - - for _, channel := range cp.seedConns { - err := channel.grpcConn.Close() + for _, conn := range cp.seedConns { + err := conn.close() if err != nil { if firstErr == nil { firstErr = err } - cp.logger.Error("failed to close seed channel", + cp.logger.Error("failed to close seed connection", slog.Any("error", err), - slog.String("seed", channel.grpcConn.Target()), + slog.String("seed", conn.grpcConn.Target()), ) } } @@ -193,7 +224,7 @@ func (cp *channelProvider) Close() error { firstErr = err } - cp.logger.Error("failed to close node channel", + cp.logger.Error("failed to close node connection", slog.Any("error", err), slog.String("node", conn.conn.grpcConn.Target()), ) @@ -207,14 +238,14 @@ func (cp *channelProvider) Close() error { } // GetSeedConn returns a gRPC client connection to a seed node. -func (cp *channelProvider) GetSeedConn() (*connection, error) { +func (cp *connectionProvider) GetSeedConn() (*connection, error) { if cp.closed.Load() { - cp.logger.Warn("ChannelProvider is closed, cannot get channel") - return nil, errChannelProviderClosed + cp.logger.Warn("ConnectionProvider is closed, cannot get connection") + return nil, errConnectionProviderClosed } if len(cp.seedConns) == 0 { - msg := "no seed channels found" + msg := "no seed connections found" cp.logger.Warn(msg) return nil, errors.New(msg) @@ -227,69 +258,69 @@ func (cp *channelProvider) GetSeedConn() (*connection, error) { // GetRandomConn returns a gRPC client connection to an Aerospike server. If // isLoadBalancer is enabled, it will return the seed connection. 
-func (cp *channelProvider) GetRandomConn() (*connection, error) { +func (cp *connectionProvider) GetRandomConn() (*connection, error) { if cp.closed.Load() { - cp.logger.Warn("ChannelProvider is closed, cannot get channel") - return nil, errors.New("ChannelProvider is closed") + cp.logger.Warn("ConnectionProvider is closed, cannot get connection") + return nil, errors.New("ConnectionProvider is closed") } if cp.isLoadBalancer { - cp.logger.Debug("load balancer is enabled, using seed channel") + cp.logger.Debug("load balancer is enabled, using seed connection") return cp.GetSeedConn() } cp.nodeConnsLock.RLock() defer cp.nodeConnsLock.RUnlock() - discoverdChannels := make([]*channelAndEndpoints, len(cp.nodeConns)) + discoverdConns := make([]*connectionAndEndpoints, len(cp.nodeConns)) i := 0 - for _, channel := range cp.nodeConns { - discoverdChannels[i] = channel + for _, conn := range cp.nodeConns { + discoverdConns[i] = conn i++ } - if len(discoverdChannels) == 0 { - cp.logger.Warn("no node channels found, using seed channel") + if len(discoverdConns) == 0 { + cp.logger.Warn("no node connections found, using seed connection") return cp.GetSeedConn() } - idx := rand.Intn(len(discoverdChannels)) //nolint:gosec // Security is not an issue here + idx := rand.Intn(len(discoverdConns)) //nolint:gosec // Security is not an issue here - return discoverdChannels[idx].conn, nil + return discoverdConns[idx].conn, nil } // GetNodeConn returns a gRPC client connection to a specific node. If the node // ID cannot be found an error is returned. 
-func (cp *channelProvider) GetNodeConn(nodeID uint64) (*connection, error) { +func (cp *connectionProvider) GetNodeConn(nodeID uint64) (*connection, error) { if cp.closed.Load() { - cp.logger.Warn("ChannelProvider is closed, cannot get channel") - return nil, errors.New("ChannelProvider is closed") + cp.logger.Warn("ConnectionProvider is closed, cannot get connection") + return nil, errors.New("ConnectionProvider is closed") } if cp.isLoadBalancer { - cp.logger.Error("load balancer is enabled, using seed channel") - return nil, errors.New("load balancer is enabled, cannot get specific node channel") + cp.logger.Error("load balancer is enabled, using seed connection") + return nil, errors.New("load balancer is enabled, cannot get specific node connection") } cp.nodeConnsLock.RLock() defer cp.nodeConnsLock.RUnlock() - channel, ok := cp.nodeConns[nodeID] + conn, ok := cp.nodeConns[nodeID] if !ok { - msg := "channel not found for specified node id" + msg := "connection not found for specified node id" cp.logger.Error(msg, slog.Uint64("node", nodeID)) return nil, errors.New(msg) } - return channel.conn, nil + return conn.conn, nil } // GetNodeIDs returns the node IDs of all nodes discovered during cluster // tending. If tending is disabled (LB true) then no node IDs are returned. -func (cp *channelProvider) GetNodeIDs() []uint64 { +func (cp *connectionProvider) GetNodeIDs() []uint64 { cp.nodeConnsLock.RLock() defer cp.nodeConnsLock.RUnlock() @@ -303,9 +334,9 @@ func (cp *channelProvider) GetNodeIDs() []uint64 { } // connectToSeeds connects to the seed nodes and creates gRPC client connections. 
-func (cp *channelProvider) connectToSeeds(ctx context.Context) error { +func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { if len(cp.seedConns) != 0 { - msg := "seed channels already exist, close them first" + msg := "seed connections already exist, close them first" cp.logger.Error(msg) return errors.New(msg) @@ -314,7 +345,7 @@ func (cp *channelProvider) connectToSeeds(ctx context.Context) error { var authErr error wg := sync.WaitGroup{} - seedGrpcConns := make(chan *grpc.ClientConn) + seedConns := make(chan *connection) cp.seedConns = []*connection{} tokenLock := sync.Mutex{} // Ensures only one thread attempts to update token at a time tokenUpdated := false // Ensures token update only occurs once @@ -326,26 +357,30 @@ func (cp *channelProvider) connectToSeeds(ctx context.Context) error { defer wg.Done() logger := cp.logger.With(slog.String("host", seed.String())) + extraCheck := true - grpcConn, err := cp.createGrcpConn(seed) + conn, err := cp.connFactory(seed) if err != nil { - logger.ErrorContext(ctx, "failed to create channel", slog.Any("error", err)) + logger.ErrorContext(ctx, "failed to create connection", slog.Any("error", err)) return } - extraCheck := true - if cp.token != nil { // Only one thread needs to refresh the token. 
Only first will // succeed others will block tokenLock.Lock() if !tokenUpdated { - err := cp.token.RefreshToken(ctx, grpcConn) + err := cp.token.RefreshToken(ctx, conn) if err != nil { logger.WarnContext(ctx, "failed to refresh token", slog.Any("error", err)) authErr = err tokenLock.Unlock() + err = conn.close() + if err != nil { + logger.WarnContext(ctx, "failed to close connection", slog.Any("error", err)) + } + return } @@ -357,30 +392,34 @@ func (cp *channelProvider) connectToSeeds(ctx context.Context) error { } if extraCheck { - client := protos.NewAboutServiceClient(grpcConn) - - about, err := client.Get(ctx, &protos.AboutRequest{}) + about, err := conn.aboutClient.Get(ctx, &protos.AboutRequest{}) if err != nil { logger.WarnContext(ctx, "failed to connect to seed", slog.Any("error", err)) + + err = conn.close() + if err != nil { + logger.WarnContext(ctx, "failed to close connection", slog.Any("error", err)) + } + return } - if newVersion(about.Version).lt(minimumSupportedAVSVersion) { + if newVersion(about.Version).lt(minimumFullySupportedAVSVersion) { logger.WarnContext(ctx, "incompatible server version", slog.String("version", about.Version)) } } - seedGrpcConns <- grpcConn + seedConns <- conn }(seed) } go func() { wg.Wait() - close(seedGrpcConns) + close(seedConns) }() - for conn := range seedGrpcConns { - cp.seedConns = append(cp.seedConns, newConnection(conn)) + for conn := range seedConns { + cp.seedConns = append(cp.seedConns, conn) } if len(cp.seedConns) == 0 { @@ -401,7 +440,7 @@ func (cp *channelProvider) connectToSeeds(ctx context.Context) error { } // updateNodeConns updates the gRPC client connection for a specific node. 
-func (cp *channelProvider) updateNodeConns( +func (cp *connectionProvider) updateNodeConns( ctx context.Context, node uint64, endpoints *protos.ServerEndpointList, @@ -423,62 +462,56 @@ func (cp *channelProvider) updateNodeConns( return nil } -// checkAndSetClusterID checks if the cluster ID has changed and updates it if necessary. -func (cp *channelProvider) checkAndSetClusterID(clusterID uint64) bool { - if clusterID != cp.clusterID { - cp.clusterID = clusterID - return true - } - - return false -} - // getTendConns returns all the gRPC client connections for tend operations. -func (cp *channelProvider) getTendConns() []*grpc.ClientConn { +func (cp *connectionProvider) getTendConns() []*connection { cp.nodeConnsLock.RLock() defer cp.nodeConnsLock.RUnlock() - channels := make([]*grpc.ClientConn, len(cp.seedConns)+len(cp.nodeConns)) + conns := make([]*connection, len(cp.seedConns)+len(cp.nodeConns)) i := 0 - for _, channel := range cp.seedConns { - channels[i] = channel.grpcConn + for _, conn := range cp.seedConns { + conns[i] = conn i++ } - for _, channel := range cp.nodeConns { - channels[i] = channel.conn.grpcConn + for _, conn := range cp.nodeConns { + conns[i] = conn.conn i++ } - return channels + return conns } // getUpdatedEndpoints retrieves the updated server endpoints from the Aerospike cluster. 
-func (cp *channelProvider) getUpdatedEndpoints(ctx context.Context) map[uint64]*protos.ServerEndpointList { +func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint64]*protos.ServerEndpointList { + type idAndEndpoints struct { + endpoints map[uint64]*protos.ServerEndpointList + id uint64 + } + conns := cp.getTendConns() - endpointsChan := make(chan map[uint64]*protos.ServerEndpointList) + newClusterChan := make(chan *idAndEndpoints) endpointsReq := &protos.ClusterNodeEndpointsRequest{ListenerName: cp.listenerName} wg := sync.WaitGroup{} for _, conn := range conns { wg.Add(1) - go func(conn *grpc.ClientConn) { + go func(conn *connection) { defer wg.Done() - logger := cp.logger.With(slog.String("host", conn.Target())) - client := protos.NewClusterInfoServiceClient(conn) + logger := cp.logger.With(slog.String("host", conn.grpcConn.Target())) - clusterID, err := client.GetClusterId(ctx, &emptypb.Empty{}) + clusterID, err := conn.clusterInfoClient.GetClusterId(ctx, &emptypb.Empty{}) if err != nil { logger.WarnContext(ctx, "failed to get cluster ID", slog.Any("error", err)) } - if !cp.checkAndSetClusterID(clusterID.GetId()) { + if clusterID.GetId() == cp.clusterID { logger.DebugContext( ctx, - "old cluster ID found, skipping channel discovery", + "old cluster ID found, skipping connection discovery", slog.Uint64("clusterID", clusterID.GetId()), ) @@ -487,40 +520,58 @@ func (cp *channelProvider) getUpdatedEndpoints(ctx context.Context) map[uint64]* logger.DebugContext(ctx, "new cluster ID found", slog.Uint64("clusterID", clusterID.GetId())) - endpointsResp, err := client.GetClusterEndpoints(ctx, endpointsReq) + endpointsResp, err := conn.clusterInfoClient.GetClusterEndpoints(ctx, endpointsReq) if err != nil { logger.ErrorContext(ctx, "failed to get cluster endpoints", slog.Any("error", err)) return } - endpointsChan <- endpointsResp.Endpoints + newClusterChan <- &idAndEndpoints{ + id: clusterID.GetId(), + endpoints: endpointsResp.Endpoints, + } 
}(conn) } go func() { wg.Wait() - close(endpointsChan) + close(newClusterChan) }() // Stores the endpoints from the node with the largest view of the cluster - var maxTempEndpoints map[uint64]*protos.ServerEndpointList - for endpoints := range endpointsChan { - if maxTempEndpoints == nil || len(endpoints) > len(maxTempEndpoints) { - maxTempEndpoints = endpoints - cp.logger.DebugContext(ctx, "found new cluster ID", slog.Any("endpoints", maxTempEndpoints)) + // Think about a scenario where the nodes are split into two clusters + // momentarily and the client can see both. We are making the decision here + // to connect to the larger of the two formed clusters. + var largestNewCluster *idAndEndpoints + for cluster := range newClusterChan { + if largestNewCluster == nil || len(cluster.endpoints) > len(largestNewCluster.endpoints) { + largestNewCluster = cluster } } - return maxTempEndpoints + if largestNewCluster != nil { + cp.logger.DebugContext( + ctx, + "largest cluster with new id", + slog.Any("endpoints", largestNewCluster.endpoints), + slog.Uint64("id", largestNewCluster.id), + ) + + cp.clusterID = largestNewCluster.id + + return largestNewCluster.endpoints + } + + return nil } // Checks if the node connections need to be updated and updates them if necessary. 
-func (cp *channelProvider) checkAndSetNodeConns( +func (cp *connectionProvider) checkAndSetNodeConns( ctx context.Context, newNodeEndpoints map[uint64]*protos.ServerEndpointList, ) { wg := sync.WaitGroup{} - // Find which nodes have a different endpoint list and update their channel + // Find which nodes have a different endpoint list and update their connection for node, newEndpoints := range newNodeEndpoints { wg.Add(1) @@ -535,27 +586,27 @@ func (cp *channelProvider) checkAndSetNodeConns( if ok { if !endpointListEqual(currEndpoints.endpoints, newEndpoints) { - logger.Debug("endpoints for node changed, recreating channel") + logger.Debug("endpoints for node changed, recreating connection") err := currEndpoints.conn.grpcConn.Close() if err != nil { - logger.Warn("failed to close channel", slog.Any("error", err)) + logger.Warn("failed to close connection", slog.Any("error", err)) } // Either this is a new node or its endpoints have changed err = cp.updateNodeConns(ctx, node, newEndpoints) if err != nil { - logger.Error("failed to create new channel", slog.Any("error", err)) + logger.Error("failed to create new connection", slog.Any("error", err)) } } else { cp.logger.Debug("endpoints for node unchanged") } } else { - logger.Debug("new node found, creating new channel") + logger.Debug("new node found, creating new connection") err := cp.updateNodeConns(ctx, node, newEndpoints) if err != nil { - logger.Error("failed to create new channel", slog.Any("error", err)) + logger.Error("failed to create new connection", slog.Any("error", err)) } } }(node, newEndpoints) @@ -565,17 +616,17 @@ func (cp *channelProvider) checkAndSetNodeConns( } // removeDownNodes removes the gRPC client connections for nodes in nodeConns -// that aren't apart of newNodeEndpoints -func (cp *channelProvider) removeDownNodes(newNodeEndpoints map[uint64]*protos.ServerEndpointList) { +// that aren't a part of newNodeEndpoints +func (cp *connectionProvider) removeDownNodes(newNodeEndpoints 
map[uint64]*protos.ServerEndpointList) { cp.nodeConnsLock.Lock() defer cp.nodeConnsLock.Unlock() - // The cluster state changed. Remove old channels. - for node, channelEndpoints := range cp.nodeConns { + // The cluster state changed. Remove old connections. + for node, connEndpoints := range cp.nodeConns { if _, ok := newNodeEndpoints[node]; !ok { - err := channelEndpoints.conn.grpcConn.Close() + err := connEndpoints.conn.grpcConn.Close() if err != nil { - cp.logger.Warn("failed to close channel", slog.Uint64("node", node), slog.Any("error", err)) + cp.logger.Warn("failed to close connection", slog.Uint64("node", node), slog.Any("error", err)) } delete(cp.nodeConns, node) @@ -583,23 +634,23 @@ func (cp *channelProvider) removeDownNodes(newNodeEndpoints map[uint64]*protos.S } } -// updateClusterChannels updates the gRPC client connections for the Aerospike +// updateClusterConns updates the gRPC client connections for the Aerospike // cluster if the cluster state has changed. -func (cp *channelProvider) updateClusterChannels(ctx context.Context) { +func (cp *connectionProvider) updateClusterConns(ctx context.Context) { updatedEndpoints := cp.getUpdatedEndpoints(ctx) if updatedEndpoints == nil { - cp.logger.Debug("no new cluster ID found, cluster state is unchanged, skipping channel discovery") + cp.logger.Debug("no new cluster ID found, cluster state is unchanged, skipping connection discovery") return } - cp.logger.Debug("new cluster id found, updating channels") + cp.logger.Debug("new cluster id found, updating connections") cp.checkAndSetNodeConns(ctx, updatedEndpoints) cp.removeDownNodes(updatedEndpoints) } -// tend starts a thread to periodically update the cluster channels. -func (cp *channelProvider) tend(ctx context.Context) { +// tend starts a thread to periodically update the cluster connections. 
+func (cp *connectionProvider) tend(ctx context.Context) { timer := time.NewTimer(cp.tendInterval) defer timer.Stop() @@ -612,7 +663,7 @@ func (cp *channelProvider) tend(ctx context.Context) { ctx, cancel := context.WithTimeout(ctx, cp.tendInterval) // TODO: make configurable? - cp.updateClusterChannels(ctx) + cp.updateClusterConns(ctx) if err := ctx.Err(); err != nil { cp.logger.Warn("tend context cancelled", slog.Any("error", err)) @@ -628,72 +679,9 @@ func (cp *channelProvider) tend(ctx context.Context) { } } -func endpointEqual(a, b *protos.ServerEndpoint) bool { - return a.Address == b.Address && a.Port == b.Port && a.IsTls == b.IsTls -} - -func endpointListEqual(a, b *protos.ServerEndpointList) bool { - if len(a.Endpoints) != len(b.Endpoints) { - return false - } - - aEndpoints := make([]*protos.ServerEndpoint, len(a.Endpoints)) - copy(aEndpoints, a.Endpoints) - - bEndpoints := make([]*protos.ServerEndpoint, len(b.Endpoints)) - copy(bEndpoints, b.Endpoints) - - sortFunc := func(endpoints []*protos.ServerEndpoint) func(int, int) bool { - return func(i, j int) bool { - if endpoints[i].Address < endpoints[j].Address { - return true - } else if endpoints[i].Address > endpoints[j].Address { - return false - } - - return endpoints[i].Port < endpoints[j].Port - } - } - - sort.Slice(aEndpoints, sortFunc(aEndpoints)) - sort.Slice(bEndpoints, sortFunc(bEndpoints)) - - for i, endpoint := range aEndpoints { - if !endpointEqual(endpoint, bEndpoints[i]) { - return false - } - } - - return true -} - -func endpointToHostPort(endpoint *protos.ServerEndpoint) *HostPort { - return NewHostPort(endpoint.Address, int(endpoint.Port)) -} - -// createGrpcConnFromEndpoints creates a gRPC client connection from the first -// successful endpoint in endpoints. 
-func (cp *channelProvider) createGrpcConnFromEndpoints( - endpoints *protos.ServerEndpointList, -) (*grpc.ClientConn, error) { - for _, endpoint := range endpoints.Endpoints { - if strings.ContainsRune(endpoint.Address, ':') { - continue // TODO: Add logging and support for IPv6 - } - - conn, err := cp.createGrcpConn(endpointToHostPort(endpoint)) - - if err == nil { - return conn, nil - } - } - - return nil, errors.New("no valid endpoint found") -} - // createGrcpConn creates a gRPC client connection to a host. This handles adding // credential and configuring tls. -func (cp *channelProvider) createGrcpConn(hostPort *HostPort) (*grpc.ClientConn, error) { +func createGrcpConn(cp *connectionProvider, hostPort *HostPort) (grpcClientConn, error) { opts := []grpc.DialOption{} if cp.tlsConfig == nil { @@ -724,11 +712,18 @@ func (cp *channelProvider) createGrcpConn(hostPort *HostPort) (*grpc.ClientConn, return conn, nil } -func (cp *channelProvider) createConnFromEndpoints(endpoints *protos.ServerEndpointList) (*connection, error) { - conn, err := cp.createGrpcConnFromEndpoints(endpoints) - if err != nil { - return nil, err +func (cp *connectionProvider) createConnFromEndpoints(endpoints *protos.ServerEndpointList) (*connection, error) { + for _, endpoint := range endpoints.Endpoints { + if strings.ContainsRune(endpoint.Address, ':') { + continue // TODO: Add logging and support for IPv6 + } + + conn, err := cp.connFactory(endpointToHostPort(endpoint)) + + if err == nil { + return conn, nil + } } - return newConnection(conn), nil + return nil, errors.New("no valid endpoint found") } diff --git a/connection_provider_test.go b/connection_provider_test.go new file mode 100644 index 0000000..86e58d4 --- /dev/null +++ b/connection_provider_test.go @@ -0,0 +1,742 @@ +package avs + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "log/slog" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/aerospike/avs-client-go/protos" + 
"github.com/stretchr/testify/assert" + gomock "go.uber.org/mock/gomock" +) + +func TestNewConnectionProvider_FailSeedsNil(t *testing.T) { + seeds := HostPortSlice{} + listenerName := "listener" + isLoadBalancer := false + tlsConfig := &tls.Config{} //nolint:gosec // tests + + var ( + logger *slog.Logger + token tokenManager + ) + + cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, token, tlsConfig, logger) + if err == nil { + defer cp.Close() + } + + assert.Nil(t, cp) + assert.Equal(t, err, errors.New("seeds cannot be nil or empty")) +} + +func TestNewConnectionProvider_FailNoTLS(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + seeds := HostPortSlice{ + &HostPort{ + Host: "host", + Port: 3000, + }, + } + listenerName := "listener" + isLoadBalancer := false + + var ( + tlsConfig *tls.Config + logger *slog.Logger + ) + + token := NewMocktokenManager(ctrl) + + token. + EXPECT(). + RequireTransportSecurity(). + Return(true) + + cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, token, tlsConfig, logger) + if err == nil { + defer cp.Close() + } + + assert.Nil(t, cp) + assert.Equal(t, err, errors.New("tlsConfig is required when username/password authentication")) +} + +func TestNewConnectionProvider_FailConnectToSeedConns(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + + seeds := HostPortSlice{ + &HostPort{ + Host: "host", + Port: 3000, + }, + } + listenerName := "listener" + isLoadBalancer := false + + var ( + tlsConfig *tls.Config + logger *slog.Logger + token tokenManager + ) + + cp, err := newConnectionProvider(ctx, seeds, &listenerName, isLoadBalancer, token, tlsConfig, logger) + if err == nil { + defer cp.Close() + } + + assert.Nil(t, cp) + assert.Equal(t, "failed to connect to seeds: context deadline exceeded", err.Error()) +} + +func 
TestClose_FailsToCloseConns(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockSeedConn := NewMockgrpcClientConn(ctrl) + mockNodeConn := NewMockgrpcClientConn(ctrl) + + mockSeedConn. + EXPECT(). + Close(). + Return(fmt.Errorf("foo")) + + mockSeedConn. + EXPECT(). + Target(). + Return("") + + mockNodeConn. + EXPECT(). + Close(). + Return(fmt.Errorf("bar")) + + mockNodeConn. + EXPECT(). + Target(). + Return("") + + cp := &connectionProvider{ + isLoadBalancer: true, + seedConns: []*connection{ + { + grpcConn: mockSeedConn, + }, + }, + nodeConns: map[uint64]*connectionAndEndpoints{ + uint64(1): { + conn: &connection{grpcConn: mockNodeConn}, + endpoints: &protos.ServerEndpointList{}, + }, + }, + logger: slog.Default(), + } + + err := cp.Close() + + assert.Equal(t, fmt.Errorf("foo"), err) +} + +func TestGetSeedConn_FailBecauseClosed(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + } + + cp.closed.Store(true) + + _, err := cp.GetSeedConn() + + assert.Equal(t, errors.New("connectionProvider is closed, cannot get connection"), err) +} + +func TestGetSeedConn_FailSeedConnEmpty(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + } + + cp.closed.Store(false) + + _, err := cp.GetSeedConn() + + assert.Equal(t, errors.New("no seed connections found"), err) +} + +func TestConnectToSeeds_FailedAlreadyConnected(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + } + + cp.seedConns = []*connection{ + { + grpcConn: NewMockgrpcClientConn(ctrl), + }, + } + + err := cp.connectToSeeds(context.Background()) + + assert.Equal(t, errors.New("seed connections already exist, close them 
first"), err) +} + +func TestConnectToSeeds_FailedFailedToCreateConnection(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + connFactory: func(_ *HostPort) (*connection, error) { + return nil, fmt.Errorf("foo") + }, + } + + cp.seeds = HostPortSlice{ + &HostPort{ + Host: "host", + Port: 3000, + }, + } + + err := cp.connectToSeeds(context.Background()) + + assert.Equal(t, NewAVSError("failed to connect to seeds", nil), err) +} + +func TestConnectToSeeds_FailedToRefreshToken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockToken := NewMocktokenManager(ctrl) + mockToken. + EXPECT(). + RefreshToken(gomock.Any(), gomock.Any()). + Return(fmt.Errorf("foo")) + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + connFactory: func(_ *HostPort) (*connection, error) { + return nil, nil + }, + token: mockToken, + } + + cp.seeds = HostPortSlice{ + &HostPort{ + Host: "host", + Port: 3000, + }, + } + + err := cp.connectToSeeds(context.Background()) + + assert.Equal(t, NewAVSError("failed to connect to seeds", fmt.Errorf("foo")), err) +} + +func TestUpdateClusterConns_NoNewClusterID(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + + cp := &connectionProvider{ + logger: slog.Default(), + nodeConns: make(map[uint64]*connectionAndEndpoints), + seedConns: []*connection{}, + tlsConfig: nil, + seeds: HostPortSlice{}, + nodeConnsLock: &sync.RWMutex{}, + tendInterval: time.Second * 1, + clusterID: 123, + listenerName: nil, + isLoadBalancer: false, + token: nil, + stopTendChan: make(chan struct{}), + closed: atomic.Bool{}, + } + + cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NoNewClusterID")) + + cp.logger.Debug("Setting up existing node connections") + + grpcConn1 := NewMockgrpcClientConn(ctrl) + 
mockClusterInfoClient1 := protos.NewMockClusterInfoServiceClient(ctrl) + grpcConn2 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient2 := protos.NewMockClusterInfoServiceClient(ctrl) + + grpcConn1. + EXPECT(). + Target(). + Return("") + + mockClusterInfoClient1. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: 123, + }, nil) + + grpcConn2. + EXPECT(). + Target(). + Return("") + + mockClusterInfoClient2. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: 123, + }, nil) + + // Existing node connections + cp.nodeConns[1] = &connectionAndEndpoints{ + conn: &connection{ + grpcConn: grpcConn1, + clusterInfoClient: mockClusterInfoClient1, + }, + endpoints: &protos.ServerEndpointList{}, + } + + cp.nodeConns[2] = &connectionAndEndpoints{ + conn: &connection{ + grpcConn: grpcConn2, + clusterInfoClient: mockClusterInfoClient2, + }, + endpoints: &protos.ServerEndpointList{}, + } + + cp.logger.Debug("Running updateClusterConns") + + cp.updateClusterConns(ctx) + + assert.Equal(t, uint64(123), cp.clusterID) + assert.Len(t, cp.nodeConns, 2) +} + +func TestUpdateClusterConns_NewClusterIDWithDIFFERENTNodeIDs(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + mockClusterInfoClient1111 := protos.NewMockClusterInfoServiceClient(ctrl) + mockClusterInfoClient2222 := protos.NewMockClusterInfoServiceClient(ctrl) + + mockAboutClient1111 := protos.NewMockAboutServiceClient(ctrl) + mockAboutClient2222 := protos.NewMockAboutServiceClient(ctrl) + + mockAboutClient1111. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, nil) + + mockAboutClient2222. + EXPECT(). + Get(gomock.Any(), gomock.Any()). 
+ Return(nil, nil) + + cp := &connectionProvider{ + logger: slog.Default(), + nodeConns: make(map[uint64]*connectionAndEndpoints), + seedConns: []*connection{}, + tlsConfig: &tls.Config{}, //nolint:gosec // tests + seeds: HostPortSlice{}, + nodeConnsLock: &sync.RWMutex{}, + tendInterval: time.Second * 1, + clusterID: 123, + listenerName: nil, + isLoadBalancer: false, + token: nil, + stopTendChan: make(chan struct{}), + closed: atomic.Bool{}, + connFactory: func(hostPort *HostPort) (*connection, error) { + if hostPort.String() == "1.1.1.1:3000" { + return &connection{ + clusterInfoClient: mockClusterInfoClient1111, + aboutClient: mockAboutClient1111, + }, nil + } else if hostPort.String() == "2.2.2.2:3000" { + return &connection{ + clusterInfoClient: mockClusterInfoClient2222, + aboutClient: mockAboutClient2222, + }, nil + } + + return nil, fmt.Errorf("foo") + }, + } + + cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NewClusterID")) + + cp.logger.Debug("Setting up existing node connections") + + grpcConn1 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient1 := protos.NewMockClusterInfoServiceClient(ctrl) + grpcConn2 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient2 := protos.NewMockClusterInfoServiceClient(ctrl) + + grpcConn1. + EXPECT(). + Target(). + Return("") + + mockClusterInfoClient1. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: 789, // Different cluster id from client 2 + }, nil) + + mockClusterInfoClient1. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). + Return(&protos.ClusterNodeEndpoints{ + Endpoints: map[uint64]*protos.ServerEndpointList{ // Smaller num of endpoints from client 2 + 0: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + }, + }, nil) + + grpcConn1. + EXPECT(). + Close(). + Return(nil) + + grpcConn2. + EXPECT(). + Target(). 
+ Return("") + + expectedClusterID := uint64(456) + // After a new cluster is discovered we expect these to be the new nodeConns + expectedNewNodeConns := map[uint64]*connectionAndEndpoints{ + 3: { + conn: &connection{ + clusterInfoClient: mockClusterInfoClient1111, + aboutClient: mockAboutClient1111, + }, + endpoints: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + }, + 4: { + conn: &connection{ + clusterInfoClient: mockClusterInfoClient2222, + aboutClient: mockAboutClient2222, + }, + endpoints: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "2.2.2.2", + Port: 3000, + }, + }, + }, + }, + } + + mockClusterInfoClient2. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: expectedClusterID, + }, nil) + + mockClusterInfoClient2. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). + Return(&protos.ClusterNodeEndpoints{ + Endpoints: map[uint64]*protos.ServerEndpointList{ // larger, so the cluster id 456 will win + 3: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + 4: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "2.2.2.2", + Port: 3000, + }, + }, + }, + }, + }, nil) + + grpcConn2. + EXPECT(). + Close(). + Return(nil) + + // Existing node connections. These will be replaced after a new cluster is found. 
+ cp.nodeConns = map[uint64]*connectionAndEndpoints{ + 1: { + conn: &connection{ + grpcConn: grpcConn1, + clusterInfoClient: mockClusterInfoClient1, + }, + endpoints: &protos.ServerEndpointList{}, + }, + 2: { + conn: &connection{ + grpcConn: grpcConn2, + clusterInfoClient: mockClusterInfoClient2, + }, + endpoints: &protos.ServerEndpointList{}, + }, + } + + cp.logger.Debug("Running updateClusterConns") + + cp.updateClusterConns(ctx) + + assert.Equal(t, expectedClusterID, cp.clusterID) + assert.Len(t, cp.nodeConns, 2) + + for k, v := range cp.nodeConns { + assert.EqualExportedValues(t, expectedNewNodeConns[k].endpoints, v.endpoints) + } +} + +func TestUpdateClusterConns_NewClusterIDWithSAMENodeIDs(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + grpcConn1 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient1 := protos.NewMockClusterInfoServiceClient(ctrl) + grpcConn2 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient2 := protos.NewMockClusterInfoServiceClient(ctrl) + + expectedClusterID := uint64(456) + + mockClusterInfoClient2. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: expectedClusterID, + }, nil) + + mockClusterInfoClient2. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). 
+ Return(&protos.ClusterNodeEndpoints{ + Endpoints: map[uint64]*protos.ServerEndpointList{ // larger, so the cluster id 456 will win + 1: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + 2: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "2.2.2.2", + Port: 3000, + }, + }, + }, + }, + }, nil) + + mockClusterInfoClient1111 := protos.NewMockClusterInfoServiceClient(ctrl) + mockClusterInfoClient2222 := protos.NewMockClusterInfoServiceClient(ctrl) + + mockAboutClient1111 := protos.NewMockAboutServiceClient(ctrl) + mockAboutClient2222 := protos.NewMockAboutServiceClient(ctrl) + + mockAboutClient1111. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, nil) + + mockAboutClient2222. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, nil) + + cp := &connectionProvider{ + logger: slog.Default(), + seedConns: []*connection{}, + tlsConfig: &tls.Config{}, //nolint:gosec // tests + seeds: HostPortSlice{}, + nodeConnsLock: &sync.RWMutex{}, + tendInterval: time.Second * 1, + clusterID: 123, + listenerName: nil, + isLoadBalancer: false, + token: nil, + stopTendChan: make(chan struct{}), + closed: atomic.Bool{}, + connFactory: func(hostPort *HostPort) (*connection, error) { + if hostPort.String() == "1.1.1.1:3000" { + return &connection{ + clusterInfoClient: mockClusterInfoClient1111, + aboutClient: mockAboutClient1111, + }, nil + } else if hostPort.String() == "2.2.2.2:3000" { + return &connection{ + clusterInfoClient: mockClusterInfoClient2222, + aboutClient: mockAboutClient2222, + }, nil + } + + return nil, fmt.Errorf("foo") + }, + + // Existing node connections. These will be replaced after a new cluster is found. 
+ nodeConns: map[uint64]*connectionAndEndpoints{ + 1: { + conn: &connection{ + grpcConn: grpcConn1, + clusterInfoClient: mockClusterInfoClient1, + }, + endpoints: &protos.ServerEndpointList{}, + }, + 2: { + conn: &connection{ + grpcConn: grpcConn2, + clusterInfoClient: mockClusterInfoClient2, + }, + endpoints: &protos.ServerEndpointList{}, + }, + }, + } + + cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NewClusterID")) + + cp.logger.Debug("Setting up existing node connections") + + grpcConn1. + EXPECT(). + Target(). + Return("") + + mockClusterInfoClient1. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: 789, // Different cluster id from client 2 + }, nil) + + mockClusterInfoClient1. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). + Return(&protos.ClusterNodeEndpoints{ + Endpoints: map[uint64]*protos.ServerEndpointList{ // Smaller num of endpoints from client 2 + 0: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + }, + }, nil) + + grpcConn1. + EXPECT(). + Close(). + Return(nil) + + grpcConn2. + EXPECT(). + Target(). + Return("") + + // After a new cluster is discovered we expect these to be the new nodeConns + expectedNewNodeConns := map[uint64]*connectionAndEndpoints{ + 1: { + conn: &connection{ + clusterInfoClient: mockClusterInfoClient1111, + aboutClient: mockAboutClient1111, + }, + endpoints: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + }, + 2: { + conn: &connection{ + clusterInfoClient: mockClusterInfoClient2222, + aboutClient: mockAboutClient2222, + }, + endpoints: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "2.2.2.2", + Port: 3000, + }, + }, + }, + }, + } + + grpcConn2. + EXPECT(). + Close(). 
+ Return(nil) + + cp.logger.Debug("Running updateClusterConns") + + cp.updateClusterConns(ctx) + + assert.Equal(t, expectedClusterID, cp.clusterID) + assert.Len(t, cp.nodeConns, 2) + + for k, v := range cp.nodeConns { + assert.EqualExportedValues(t, expectedNewNodeConns[k].endpoints, v.endpoints) + } +} diff --git a/docker/.env b/docker/.env new file mode 100644 index 0000000..447d4f2 --- /dev/null +++ b/docker/.env @@ -0,0 +1 @@ +AVS_IMAGE=aerospike/aerospike-vector-search:0.10.0 \ No newline at end of file diff --git a/docker/auth/docker-compose.yml b/docker/auth/docker-compose.yml index dbf6d6b..a3ae1a8 100644 --- a/docker/auth/docker-compose.yml +++ b/docker/auth/docker-compose.yml @@ -16,7 +16,7 @@ services: timeout: 20s retries: 20 avs: - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} depends_on: aerospike: condition: service_healthy diff --git a/docker/mtls/docker-compose.yml b/docker/mtls/docker-compose.yml index dbf6d6b..a3ae1a8 100644 --- a/docker/mtls/docker-compose.yml +++ b/docker/mtls/docker-compose.yml @@ -16,7 +16,7 @@ services: timeout: 20s retries: 20 avs: - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} depends_on: aerospike: condition: service_healthy diff --git a/docker/multi-node-LB/docker-compose.yml b/docker/multi-node-LB/docker-compose.yml index 02d8786..89b0981 100644 --- a/docker/multi-node-LB/docker-compose.yml +++ b/docker/multi-node-LB/docker-compose.yml @@ -20,7 +20,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} volumes: - ./config/aerospike-vector-search-1.yml:/etc/aerospike-vector-search/aerospike-vector-search.yml - ./config/features.conf:/etc/aerospike-vector-search/features.conf @@ -35,7 +35,7 @@ 
services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} volumes: - ./config/aerospike-vector-search-2.yml:/etc/aerospike-vector-search/aerospike-vector-search.yml - ./config/features.conf:/etc/aerospike-vector-search/features.conf @@ -50,7 +50,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} volumes: - ./config/aerospike-vector-search-3.yml:/etc/aerospike-vector-search/aerospike-vector-search.yml - ./config/features.conf:/etc/aerospike-vector-search/features.conf diff --git a/docker/multi-node-client-visibility-err/docker-compose.yml b/docker/multi-node-client-visibility-err/docker-compose.yml index 195f9a6..c0dbcc3 100644 --- a/docker/multi-node-client-visibility-err/docker-compose.yml +++ b/docker/multi-node-client-visibility-err/docker-compose.yml @@ -20,7 +20,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} ports: - "10000:10000" volumes: @@ -37,7 +37,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} ports: - "10001:10001" volumes: @@ -54,7 +54,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} # ports: # - "10002:10002" # This causes the visibility err volumes: diff --git a/docker/multi-node/docker-compose.yml b/docker/multi-node/docker-compose.yml index 2cfd93a..94157a0 100644 --- 
a/docker/multi-node/docker-compose.yml +++ b/docker/multi-node/docker-compose.yml @@ -20,7 +20,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} ports: - "10000:10000" volumes: @@ -37,7 +37,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} ports: - "10001:10001" volumes: @@ -54,7 +54,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} ports: - "10002:10002" volumes: diff --git a/docker/tls/docker-compose.yml b/docker/tls/docker-compose.yml index 2308444..a3ae1a8 100644 --- a/docker/tls/docker-compose.yml +++ b/docker/tls/docker-compose.yml @@ -16,8 +16,7 @@ services: timeout: 20s retries: 20 avs: - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector- -search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} depends_on: aerospike: condition: service_healthy diff --git a/docker/vanilla/docker-compose.yml b/docker/vanilla/docker-compose.yml index 2308444..a3ae1a8 100644 --- a/docker/vanilla/docker-compose.yml +++ b/docker/vanilla/docker-compose.yml @@ -16,8 +16,7 @@ services: timeout: 20s retries: 20 avs: - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector- -search-private:0.9.1-SNAPSHOT + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} depends_on: aerospike: condition: service_healthy diff --git a/errors.go b/errors.go index 9a38945..f975989 100644 --- a/errors.go +++ b/errors.go @@ -10,7 +10,7 @@ type Error struct { msg string } -func NewAVSError(msg string, err error) error { +func NewAVSError(msg string, err error) *Error { if err != nil { msg = fmt.Sprintf("%s: %s", msg, 
err.Error()) } @@ -18,14 +18,14 @@ func NewAVSError(msg string, err error) error { return &Error{msg: msg} } -func NewAVSErrorFromGrpc(msg string, gErr error) error { +func NewAVSErrorFromGrpc(msg string, gErr error) *Error { if gErr == nil { return nil } gStatus, ok := status.FromError(gErr) if !ok { - return NewAVSError(gErr.Error(), nil) + return NewAVSError(msg, gErr) } errStr := fmt.Sprintf("%s: server error: %s", msg, gStatus.Code().String()) diff --git a/go.mod b/go.mod index 0bfeb5e..ca47cad 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,11 @@ go 1.21 require ( github.com/aerospike/tools-common-go v0.0.0-20240701164814-36eec593d9c6 + github.com/golang/mock v1.6.0 github.com/stretchr/testify v1.9.0 + go.uber.org/goleak v1.3.0 + go.uber.org/mock v0.4.0 + golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f golang.org/x/net v0.27.0 google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.34.2 diff --git a/go.sum b/go.sum index 91a1f11..e58f316 100644 --- a/go.sum +++ b/go.sum @@ -25,6 +25,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEe github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 h1:y3N7Bm7Y9/CtpiVkw/ZWj6lSlDF3F74SfKwfTCer72Q= @@ -49,6 +51,7 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= @@ -59,16 +62,44 @@ go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgS go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod 
h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod 
h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto/googleapis/rpc v0.0.0-20240711142825-46eb208f015d h1:JU0iKnSg02Gmb5ZdV8nYsKEKsP6o/FGVWTrw4i1DA9A= google.golang.org/genproto/googleapis/rpc v0.0.0-20240711142825-46eb208f015d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= diff --git a/integration_single_node_test.go b/integration_single_node_test.go index 41a92e7..28fba80 100644 --- a/integration_single_node_test.go +++ b/integration_single_node_test.go @@ -9,6 +9,9 @@ import ( "fmt" "log/slog" "os" + "strconv" + "strings" + "sync" "testing" "time" @@ -59,7 +62,7 @@ func TestSingleNodeSuite(t *testing.T) { suites := []*SingleNodeTestSuite{ { ServerTestBaseSuite: ServerTestBaseSuite{ - ComposeFile: "docker/multi-node/docker-compose.yml", // vanilla + ComposeFile: "docker/vanilla/docker-compose.yml", // vanilla AvsLB: false, AvsHostPort: avsHostPort, }, @@ -89,7 +92,30 @@ func TestSingleNodeSuite(t *testing.T) { }, } - for _, s := range suites { + testSuiteEnv := os.Getenv("ASVEC_TEST_SUITES") + picked_suites := map[int]struct{}{} + + if testSuiteEnv != "" { + testSuites := strings.Split(testSuiteEnv, ",") + + for _, s := range testSuites { + i, err := strconv.Atoi(s) + if err != nil { + t.Fatalf("unable to convert %s to int: %v", s, 
err) + } + + picked_suites[i] = struct{}{} + } + } + + logger.Info("Running test suites", slog.Any("suites", picked_suites)) + + for i, s := range suites { + if len(picked_suites) != 0 { + if _, ok := picked_suites[i]; !ok { + continue + } + } suite.Run(t, s) } } @@ -110,7 +136,7 @@ func (suite *SingleNodeTestSuite) TestBasicUpsertGetDelete() { { "test", getUniqueSetName(), - "key1", + getUniqueKey(), map[string]any{ "str": "str", "int": int64(64), @@ -157,6 +183,504 @@ func (suite *SingleNodeTestSuite) TestBasicUpsertGetDelete() { } } +func (suite *SingleNodeTestSuite) TestBasicUpsertExistsDelete() { + records := []struct { + namespace string + set *string + key any + recordData map[string]any + }{ + { + "test", + getUniqueSetName(), + getUniqueKey(), + map[string]any{ + "str": "str", + "int": int64(64), + "float": 3.14, + "bool": false, + "arr": []any{int64(0), int64(1), int64(2), int64(3)}, + "map": map[any]any{ + "foo": "bar", + }, + }, + }, + } + + // ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + // defer cancel() + ctx := context.Background() + + for _, rec := range records { + err := suite.AvsClient.Upsert(ctx, rec.namespace, rec.set, rec.key, rec.recordData, false) + suite.NoError(err) + + if err != nil { + return + } + + exists, err := suite.AvsClient.Exists(ctx, rec.namespace, rec.set, rec.key) + suite.NoError(err) + + if err != nil { + return + } + + suite.True(exists) + + err = suite.AvsClient.Delete(ctx, rec.namespace, rec.set, rec.key) + suite.NoError(err) + + if err != nil { + return + } + + exists, err = suite.AvsClient.Exists(ctx, rec.namespace, rec.set, rec.key) + suite.NoError(err) + + suite.False(exists) + } +} + +func (suite *SingleNodeTestSuite) TestIndexCreate() { + testcases := []struct { + name string + namespace string + indexName string + vectorField string + dimension uint32 + vectorDistanceMetric protos.VectorDistanceMetric + opts *IndexCreateOpts + expectedIndex protos.IndexDefinition + }{ + { + "basic", + 
"test", + "index", + "vector", + 10, + protos.VectorDistanceMetric_SQUARED_EUCLIDEAN, + nil, + protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "test", + Name: "index", + }, + Dimensions: uint32(10), + VectorDistanceMetric: protos.VectorDistanceMetric_SQUARED_EUCLIDEAN, + Type: protos.IndexType_HNSW, + SetFilter: nil, + Field: "vector", + }, + }, + { + "with opts", + "test", + "index", + "vector", + 10, + protos.VectorDistanceMetric_COSINE, + &IndexCreateOpts{ + Sets: []string{"testset"}, + Storage: &protos.IndexStorage{ + Namespace: Ptr("storage-ns"), + Set: Ptr("storage-set"), + }, + Labels: map[string]string{ + "a": "b", + }, + }, + protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "test", + Name: "index", + }, + Dimensions: uint32(10), + VectorDistanceMetric: protos.VectorDistanceMetric_COSINE, + Type: protos.IndexType_HNSW, + SetFilter: Ptr("testset"), + Field: "vector", + Storage: &protos.IndexStorage{ + Namespace: Ptr("storage-ns"), + Set: Ptr("storage-set"), + }, + Labels: map[string]string{ + "a": "b", + }, + Params: &protos.IndexDefinition_HnswParams{ + HnswParams: &protos.HnswParams{}, + }, + }, + }, + } + + for _, tc := range testcases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + err := suite.AvsClient.IndexCreate(ctx, tc.namespace, tc.indexName, tc.vectorField, tc.dimension, tc.vectorDistanceMetric, tc.opts) + suite.NoError(err) + + if err != nil { + return + } + + defer suite.AvsClient.IndexDrop(ctx, tc.namespace, tc.indexName) + + index, err := suite.AvsClient.IndexGet(ctx, tc.namespace, tc.indexName, false) + suite.NoError(err) + + if err != nil { + return + } + + suite.EqualExportedValues(tc.expectedIndex, *index) + }) + } +} + +func (suite *SingleNodeTestSuite) TestIndexUpdate() { + testcases := []struct { + name string + namespace string + indexName string + vectorField string + dimension uint32 + vectorDistanceMetric protos.VectorDistanceMetric + opts *IndexCreateOpts + updateLabels 
map[string]string + updateHnsw *protos.HnswIndexUpdate + expectedIndex protos.IndexDefinition + }{ + { + name: "no update", + namespace: "test", + indexName: "index", + vectorField: "vector", + dimension: 10, + vectorDistanceMetric: protos.VectorDistanceMetric_COSINE, + opts: &IndexCreateOpts{ + Sets: []string{"testset"}, + Storage: &protos.IndexStorage{ + Namespace: Ptr("storage-ns"), + Set: Ptr("storage-set"), + }, + Labels: map[string]string{ + "a": "b", + }, + }, + updateHnsw: nil, + updateLabels: nil, + expectedIndex: protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "test", + Name: "index", + }, + Dimensions: uint32(10), + VectorDistanceMetric: protos.VectorDistanceMetric_COSINE, + Type: protos.IndexType_HNSW, + SetFilter: Ptr("testset"), + Field: "vector", + Storage: &protos.IndexStorage{ + Namespace: Ptr("storage-ns"), + Set: Ptr("storage-set"), + }, + Labels: map[string]string{ + "a": "b", + }, + Params: &protos.IndexDefinition_HnswParams{ + HnswParams: &protos.HnswParams{}, + }, + }, + }, + { + name: "update all params", + namespace: "test", + indexName: "index", + vectorField: "vector", + dimension: 10, + vectorDistanceMetric: protos.VectorDistanceMetric_COSINE, + opts: &IndexCreateOpts{ + Sets: []string{"testset"}, + Storage: &protos.IndexStorage{ + Namespace: Ptr("storage-ns"), + Set: Ptr("storage-set"), + }, + Labels: map[string]string{ + "a": "b", + }, + }, + updateHnsw: &protos.HnswIndexUpdate{ + MaxMemQueueSize: Ptr(uint32(100)), + BatchingParams: &protos.HnswBatchingParams{ + MaxRecords: Ptr(uint32(10_001)), + Interval: Ptr(uint32(10_002)), + }, + CachingParams: &protos.HnswCachingParams{ + MaxEntries: Ptr(uint64(10_003)), + Expiry: Ptr(uint64(10_004)), + }, + HealerParams: &protos.HnswHealerParams{ + MaxScanRatePerNode: Ptr(uint32(10_005)), + MaxScanPageSize: Ptr(uint32(10_006)), + ReindexPercent: Ptr(float32(51)), + Schedule: Ptr("0 0 0 25 12 ?"), + Parallelism: Ptr(uint32(1)), + }, + MergeParams: &protos.HnswIndexMergeParams{ + 
IndexParallelism: Ptr(uint32(2)), + ReIndexParallelism: Ptr(uint32(3)), + }, + }, + updateLabels: map[string]string{ + "c": "d", + }, + expectedIndex: protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "test", + Name: "index", + }, + Dimensions: uint32(10), + VectorDistanceMetric: protos.VectorDistanceMetric_COSINE, + Type: protos.IndexType_HNSW, + SetFilter: Ptr("testset"), + Field: "vector", + Storage: &protos.IndexStorage{ + Namespace: Ptr("storage-ns"), + Set: Ptr("storage-set"), + }, + Labels: map[string]string{ + "a": "b", + "c": "d", + }, + Params: &protos.IndexDefinition_HnswParams{ + HnswParams: &protos.HnswParams{ + MaxMemQueueSize: Ptr(uint32(100)), + BatchingParams: &protos.HnswBatchingParams{ + MaxRecords: Ptr(uint32(10_001)), + Interval: Ptr(uint32(10_002)), + }, + CachingParams: &protos.HnswCachingParams{ + MaxEntries: Ptr(uint64(10_003)), + Expiry: Ptr(uint64(10_004)), + }, + HealerParams: &protos.HnswHealerParams{ + MaxScanRatePerNode: Ptr(uint32(10_005)), + MaxScanPageSize: Ptr(uint32(10_006)), + ReindexPercent: Ptr(float32(51)), + Schedule: Ptr("0 0 0 25 12 ?"), + Parallelism: Ptr(uint32(1)), + }, + MergeParams: &protos.HnswIndexMergeParams{ + IndexParallelism: Ptr(uint32(2)), + ReIndexParallelism: Ptr(uint32(3)), + }, + }, + }, + }, + }, + } + + for _, tc := range testcases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + err := suite.AvsClient.IndexCreate(ctx, tc.namespace, tc.indexName, tc.vectorField, tc.dimension, tc.vectorDistanceMetric, tc.opts) + suite.NoError(err) + + if err != nil { + return + } + + defer suite.AvsClient.IndexDrop(ctx, tc.namespace, tc.indexName) + + index, err := suite.AvsClient.IndexGet(ctx, tc.namespace, tc.indexName, false) + suite.NoError(err) + + if err != nil { + return + } + + err = suite.AvsClient.IndexUpdate(ctx, tc.namespace, tc.indexName, tc.updateLabels, tc.updateHnsw) + suite.NoError(err) + + if err != nil { + return + } + + time.Sleep(time.Second * 3) + + index, 
err = suite.AvsClient.IndexGet(ctx, tc.namespace, tc.indexName, false) + suite.NoError(err) + + suite.EqualExportedValues(tc.expectedIndex, *index) + }) + } +} + +// func (suite *SingleNodeTestSuite) TestIsIndexed() { +// records := []struct { +// namespace string +// set *string +// key any +// recordData map[string]any +// }{ +// { +// namespace: "test", +// set: getUniqueSetName(), +// key: getUniqueKey(), +// recordData: map[string]any{ +// // "str": "str", +// // "int": int64(64), +// // "float": 3.14, +// // "bool": false, +// "arr": getVectorFloat32(10, 1.0), +// // "map": map[any]any{ +// // "foo": "bar", +// // }, +// }, +// }, +// } + +// // ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +// // defer cancel() +// ctx := context.Background() + +// for _, rec := range records { +// err := suite.AvsClient.Upsert(ctx, rec.namespace, rec.set, rec.key, rec.recordData, false) +// suite.NoError(err) + +// for i := 0; i < 10; i++ { +// recordData := map[string]any{ +// // "str": "str", +// // "int": int64(64), +// // "float": 3.14, +// // "bool": false, +// "arr": getVectorFloat32(10, float32(i*2)), +// // "map": map[any]any{ +// // "foo": "bar", +// // }, +// } + +// suite.AvsClient.Upsert(ctx, rec.namespace, rec.set, getUniqueKey(), recordData, false) +// } + +// if err != nil { +// return +// } + +// isIndexed, err := suite.AvsClient.IsIndexed(ctx, rec.namespace, rec.set, "index", rec.key) +// suite.Error(err) + +// var indexOpts *IndexCreateOpts + +// if rec.set != nil { +// indexOpts = &IndexCreateOpts{ +// Sets: []string{*rec.set}, +// } +// } + +// err = suite.AvsClient.IndexCreate( +// ctx, +// rec.namespace, +// "index", +// "arr", +// 10, +// protos.VectorDistanceMetric_SQUARED_EUCLIDEAN, +// indexOpts, +// ) +// suite.NoError(err) + +// if err != nil { +// return +// } + +// isIndexed, err = suite.AvsClient.IsIndexed(ctx, rec.namespace, rec.set, "index", rec.key) +// suite.NoError(err) +// suite.False(isIndexed) + +// defer 
suite.AvsClient.IndexDrop(ctx, rec.namespace, "index") + +// suite.AvsClient.WaitForIndexCompletion(ctx, rec.namespace, "index", time.Second*12) + +// isIndexed, err = suite.AvsClient.IsIndexed(ctx, rec.namespace, rec.set, "index", rec.key) +// suite.NoError(err) + +// // time.Sleep(time.Second * 33330) + +// suite.True(isIndexed) + +// err = suite.AvsClient.Delete(ctx, rec.namespace, rec.set, rec.key) +// suite.NoError(err) +// } +// } + +func (suite *SingleNodeTestSuite) TestIndexGetStatus() { + testcases := []struct { + name string + namespace string + indexName string + vectorField string + dimension uint32 + vectorDistanceMetric protos.VectorDistanceMetric + opts *IndexCreateOpts + expectedStatus protos.IndexStatusResponse + }{ + { + name: "basic", + namespace: "test", + indexName: "index", + vectorField: "vector", + dimension: 10, + vectorDistanceMetric: protos.VectorDistanceMetric_SQUARED_EUCLIDEAN, + opts: nil, + expectedStatus: protos.IndexStatusResponse{UnmergedRecordCount: 0}, + }, + } + + for _, tc := range testcases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + err := suite.AvsClient.IndexCreate(ctx, tc.namespace, tc.indexName, tc.vectorField, tc.dimension, tc.vectorDistanceMetric, tc.opts) + suite.NoError(err) + + if err != nil { + return + } + + defer suite.AvsClient.IndexDrop(ctx, tc.namespace, tc.indexName) + + status, err := suite.AvsClient.IndexGetStatus(ctx, tc.namespace, tc.indexName) + suite.NoError(err) + + if err != nil { + return + } + + suite.EqualExportedValues(tc.expectedStatus, *status) + }) + } +} + +func (suite *SingleNodeTestSuite) TestIndexGCInvalidVertices() { + ctx := context.Background() + namespace := "test" + indexName := "index" + vectorField := "vector" + dimension := uint32(10) + vectorDistanceMetric := protos.VectorDistanceMetric_SQUARED_EUCLIDEAN + opts := &IndexCreateOpts{} + + err := suite.AvsClient.IndexCreate(ctx, namespace, indexName, vectorField, dimension, vectorDistanceMetric, 
opts) + suite.NoError(err) + + defer suite.AvsClient.IndexDrop(ctx, namespace, indexName) + + err = suite.AvsClient.GcInvalidVertices(ctx, namespace, indexName, time.Now()) + suite.NoError(err) +} + func (suite *SingleNodeTestSuite) TestFailsToInsertAlreadyExists() { ctx := context.Background() key := getUniqueKey() @@ -234,6 +758,13 @@ func getUniqueSetName() *string { return &val } +var userNameCount = -1 + +func getUniqueUserName() string { + userNameCount++ + return "user" + fmt.Sprintf("%d", userNameCount) +} + func createNeighborFloat32(namespace string, set *string, key string, distance float32, vector []float32) *Neighbor { return &Neighbor{ Namespace: namespace, @@ -334,6 +865,10 @@ func (suite *SingleNodeTestSuite) TestVectorSearchFloat32() { return } + isIndexed, err := suite.AvsClient.IsIndexed(ctx, tc.namespace, setName, indexName, getKey(0)) + suite.NoError(err) + suite.True(isIndexed) + time.Sleep(time.Second * 10) neighbors, err := suite.AvsClient.VectorSearchFloat32(ctx, tc.namespace, indexName, tc.query, 3, nil, nil, nil) @@ -427,3 +962,385 @@ func (suite *SingleNodeTestSuite) TestVectorSearchBool() { }) } } + +func (suite *SingleNodeTestSuite) TestConcurrentWrites() { + wg := sync.WaitGroup{} + numWrites := 10_000 + keys := []string{} + + for i := 0; i < numWrites; i++ { + keys = append(keys, fmt.Sprintf("concurrent-key-%d", i)) + } + + for i := 0; i < numWrites; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + suite.AvsClient.Upsert(context.Background(), testNamespace, nil, keys[i], map[string]any{"foo": "bar"}, false) + }(i) + } + + wg.Wait() + + wg = sync.WaitGroup{} + + for i := 0; i < numWrites; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + _, err := suite.AvsClient.Get(context.Background(), testNamespace, nil, keys[i], nil, nil) + suite.Assert().NoError(err) + }(i) + } + + wg.Wait() +} + +func (suite *SingleNodeTestSuite) TestUserCreate() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + username := 
getUniqueUserName() + + err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write"}) + suite.NoError(err) + + actualUser, err := suite.AvsClient.GetUser(ctx, username) + suite.NoError(err) + + expectedUser := protos.User{ + Username: username, + Roles: []string{ + "read-write", + }, + } + + suite.EqualExportedValues(expectedUser, *actualUser) + return + +} + +func (suite *SingleNodeTestSuite) TestUserDelete() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + username := getUniqueUserName() + + err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write"}) + suite.NoError(err) + + err = suite.AvsClient.DropUser(ctx, username) + suite.NoError(err) + + _, err = suite.AvsClient.GetUser(ctx, username) + suite.Error(err) +} + +// func (suite *SingleNodeTestSuite) TestUserUpdateCredentials() { +// suite.SkipIfUserPassAuthDisabled() + +// ctx := context.Background() +// username := getUniqueUserName() + +// err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write"}) +// suite.NoError(err) + +// err = suite.AvsClient.UpdateCredentials(ctx, username, "new-password") +// suite.NoError(err) + +// c, err := NewClient( +// ctx, +// HostPortSlice{suite.AvsHostPort}, +// nil, +// suite.AvsLB, +// NewCredentialsFromUserPass(username, "new-password"), +// suite.AvsTLSConfig, +// suite.Logger, +// ) +// suite.NoError(err) + +// c.Close() +// } + +func (suite *SingleNodeTestSuite) TestUserGrantRoles() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + username := getUniqueUserName() + + err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write"}) + suite.NoError(err) + + err = suite.AvsClient.GrantRoles(ctx, username, []string{"admin"}) + suite.NoError(err) + + actualUser, err := suite.AvsClient.GetUser(ctx, username) + suite.NoError(err) + + expectedUser := protos.User{ + Username: username, + Roles: []string{ + "admin", + 
"read-write", + }, + } + + suite.EqualExportedValues(expectedUser, *actualUser) +} + +func (suite *SingleNodeTestSuite) TestUserRevokeRoles() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + username := getUniqueUserName() + + err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write", "admin"}) + suite.NoError(err) + + err = suite.AvsClient.RevokeRoles(ctx, username, []string{"admin"}) + suite.NoError(err) + + actualUser, err := suite.AvsClient.GetUser(ctx, username) + suite.NoError(err) + + expectedUser := protos.User{ + Username: username, + Roles: []string{ + "read-write", + }, + } + + suite.EqualExportedValues(expectedUser, *actualUser) +} + +func (suite *SingleNodeTestSuite) TestListUsers() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + username := getUniqueUserName() + + err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write"}) + suite.NoError(err) + + username1 := getUniqueUserName() + + err = suite.AvsClient.CreateUser(ctx, username1, "test-password", []string{"read-write"}) + suite.NoError(err) + + users, err := suite.AvsClient.ListUsers(ctx) + suite.NoError(err) + + suite.Equal(3, len(users.Users)) + + for _, user := range users.Users { + if user.Username == username { + suite.Equal([]string{"read-write"}, user.Roles) + } else if user.Username == username1 { + suite.Equal([]string{"read-write"}, user.Roles) + } else { + suite.Equal("admin", user.Username) + suite.Equal([]string{"admin", "read-write"}, user.Roles) + } + } +} + +func (suite *SingleNodeTestSuite) TestListRoles() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + roles, err := suite.AvsClient.ListRoles(ctx) + suite.NoError(err) + + suite.Equal(2, len(roles.Roles)) +} + +func (suite *SingleNodeTestSuite) TestNodeIDs() { + ctx := context.Background() + nodeIDs := suite.AvsClient.NodeIDs(ctx) + + if suite.AvsLB { + suite.Equal(0, len(nodeIDs)) + } else { + 
suite.Equal(1, len(nodeIDs)) + } +} + +func (suite *SingleNodeTestSuite) TestConnectedNodeEndpoint() { + testCases := []struct { + name string + nodeId *protos.NodeId + listenerName *string + expectedEndpoints *protos.ServerEndpoint + expectedErrMsg *string + }{ + { + name: "nil-node", + nodeId: nil, + expectedEndpoints: &protos.ServerEndpoint{ + Address: "localhost", + Port: 10000, + IsTls: false, + }, + }, + { + name: "node id DNE", + nodeId: &protos.NodeId{ + Id: 1, + }, + expectedErrMsg: Ptr("failed to get connected endpoint"), + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + endpoint, err := suite.AvsClient.ConnectedNodeEndpoint(ctx, tc.nodeId) + + if tc.expectedErrMsg != nil { + suite.Error(err) + suite.Contains(err.Error(), *tc.expectedErrMsg) + return + } else { + suite.NoError(err) + } + + suite.EqualExportedValues(tc.expectedEndpoints, endpoint) + }) + } +} + +func (suite *SingleNodeTestSuite) TestClusteringState() { + testCases := []struct { + name string + nodeId *protos.NodeId + listenerName *string + expectedErrMsg *string + }{ + { + name: "nil-node", + nodeId: nil, + }, + { + name: "node id DNE", + nodeId: &protos.NodeId{ + Id: 1, + }, + expectedErrMsg: Ptr("failed to get clustering state"), + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + state, err := suite.AvsClient.ClusteringState(ctx, tc.nodeId) + + if tc.expectedErrMsg != nil { + suite.Error(err) + suite.Contains(err.Error(), *tc.expectedErrMsg) + return + } else { + suite.NoError(err) + } + + // Simple test. Ideally we will have a better check in the future. 
+ // Does not test cluster-id + suite.True(state.IsInCluster) + suite.Len(state.Members, 1) + }) + } +} + +func (suite *SingleNodeTestSuite) TestClusterEndpoints() { + testCases := []struct { + name string + nodeId *protos.NodeId + listenerName *string + expectedEndpoints []*protos.ServerEndpoint + expectedErrMsg *string + }{ + { + name: "nil-node", + nodeId: nil, + expectedEndpoints: []*protos.ServerEndpoint{ + { + Address: "127.0.0.1", + Port: 10000, + IsTls: suite.AvsTLSConfig != nil, + }, + }, + }, + { + name: "node id DNE", + nodeId: &protos.NodeId{ + Id: 1, + }, + expectedErrMsg: Ptr("failed to get cluster endpoints"), + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + endpoints, err := suite.AvsClient.ClusterEndpoints(ctx, tc.nodeId, tc.listenerName) + + if tc.expectedErrMsg != nil { + suite.Error(err) + suite.Contains(err.Error(), *tc.expectedErrMsg) + return + } else { + suite.NoError(err) + } + + for id, endpoint := range endpoints.Endpoints { + // When LB is true we aren't able to validate the node-id since + // we did not tend the cluster. When LB is false we are able to + // get the node-id of the single node. 
+ if !suite.AvsLB { + nodeId := suite.AvsClient.NodeIDs(ctx)[0] + suite.Assert().Equal(nodeId.Id, id) + } + + suite.EqualExportedValues(tc.expectedEndpoints[0], endpoint.Endpoints[0]) + } + }) + } +} + +func (suite *SingleNodeTestSuite) TestAbout() { + testCases := []struct { + name string + nodeId *protos.NodeId + expectedVersion string + expectedErrMsg *string + }{ + { + name: "nil-node", + nodeId: nil, + expectedVersion: "0.10.0", + }, + { + name: "node id DNE", + nodeId: &protos.NodeId{ + Id: 1, + }, + expectedErrMsg: Ptr("failed to make about request"), + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + actualVersion, err := suite.AvsClient.About(ctx, tc.nodeId) + + if tc.expectedErrMsg != nil { + suite.Error(err) + suite.Contains(err.Error(), *tc.expectedErrMsg) + return + } else { + suite.NoError(err) + } + + suite.Equal(actualVersion.GetVersion(), tc.expectedVersion) + }) + } +} diff --git a/makefile b/makefile index 61b9c7b..e6aa9f5 100644 --- a/makefile +++ b/makefile @@ -1,16 +1,48 @@ +ifeq (,$(shell go env GOBIN)) +GOBIN=$(shell go env GOPATH)/bin +else +GOBIN=$(shell go env GOBIN) +endif + +GOLANGCI_LINT ?= $(GOBIN)/golangci-lint +GOLANGCI_LINT_VERSION ?= v1.58.0 + +MOCKGEN ?= $(GOBIN)/mockgen +MOCKGEN_VERSION ?= v0.3.0 + ROOT_DIR = $(shell pwd) PROTO_DIR = $(ROOT_DIR)/protos COVERAGE_DIR = $(ROOT_DIR)/coverage COV_UNIT_DIR = $(COVERAGE_DIR)/unit COV_INTEGRATION_DIR = $(COVERAGE_DIR)/integration +$(GOLANGCI_LINT): $(GOBIN) + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(GOBIN) $(GOLANGCI_LINT_VERSION) + .PHONY: protos protos: protoc --proto_path=$(PROTO_DIR) --go_out=$(PROTO_DIR) --go_opt=paths=source_relative \ --go-grpc_out=$(PROTO_DIR) --go-grpc_opt=paths=source_relative $(PROTO_DIR)/*.proto +.PHONY: get-mockgen +get-mockgen: $(MOCKGEN) ## Download mockgen locally if necessary. 
+$(MOCKGEN): $(GOBIN) + go install go.uber.org/mock/mockgen@$(MOCKGEN_VERSION) + +.PHONY: mocks +mocks: get-mockgen + $(MOCKGEN) --source client.go --destination client_mock.go --package avs + $(MOCKGEN) --source connection_provider.go --destination connection_provider_mock.go --package avs + $(MOCKGEN) --source protos/auth_grpc.pb.go --destination protos/auth_grpc_mock.pb.go --package protos + $(MOCKGEN) --source protos/index_grpc.pb.go --destination protos/index_grpc_mock.pb.go --package protos + $(MOCKGEN) --source protos/transact_grpc.pb.go --destination protos/transact_grpc_mock.pb.go --package protos + $(MOCKGEN) --source protos/types.pb.go --destination protos/types_mock.pb.go --package protos + $(MOCKGEN) --source protos/user-admin_grpc.pb.go --destination protos/user-admin_grpc_mock.pb.go --package protos + $(MOCKGEN) --source protos/vector-db_grpc.pb.go --destination protos/vector-db_grpc_mock.pb.go --package protos + + .PHONY: test -test: integration unit +test: unit integration .PHONY: integration integration: @@ -18,7 +50,7 @@ integration: go test -tags=integration -timeout 30m -cover ./... -args -test.gocoverdir=$(COV_INTEGRATION_DIR) .PHONY: unit -unit: +unit: mocks mkdir -p $(COV_UNIT_DIR) || true go test -tags=unit -cover ./... 
-args -test.gocoverdir=$(COV_UNIT_DIR) @@ -26,9 +58,13 @@ unit: coverage: test go tool covdata textfmt -i="$(COV_INTEGRATION_DIR),$(COV_UNIT_DIR)" -o=$(COVERAGE_DIR)/tmp.cov go tool cover -func=$(COVERAGE_DIR)/tmp.cov - grep -v 'testutils.go' $(COVERAGE_DIR)/tmp.cov > $(COVERAGE_DIR)/total.cov + grep -Ev '(testutils\.go|.*\.pb\.go)' $(COVERAGE_DIR)/tmp.cov > $(COVERAGE_DIR)/total.cov PHONY: view-coverage view-coverage: $(COVERAGE_DIR)/total.cov - go tool cover -html=$(COVERAGE_DIR)/total.cov \ No newline at end of file + go tool cover -html=$(COVERAGE_DIR)/total.cov + +PHONY: lint +lint: $(GOLANGCI_LINT) mocks + $(GOLANGCI_LINT) run \ No newline at end of file diff --git a/protos/utils_test.go b/protos/utils_test.go index 2aef07f..7e97e65 100644 --- a/protos/utils_test.go +++ b/protos/utils_test.go @@ -2,12 +2,13 @@ package protos import ( "fmt" + "sort" "testing" "github.com/stretchr/testify/assert" ) -func GetStrPtr(str string) *string { +func Ptr(str string) *string { ptr := str return &ptr } @@ -21,7 +22,7 @@ func TestConvertToKey(t *testing.T) { input: "testString", expected: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_StringValue{ StringValue: "testString", }, @@ -31,7 +32,7 @@ func TestConvertToKey(t *testing.T) { input: []byte{0x01, 0x02, 0x03}, expected: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_BytesValue{ BytesValue: []byte{0x01, 0x02, 0x03}, }, @@ -41,7 +42,7 @@ func TestConvertToKey(t *testing.T) { input: int32(123), expected: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_IntValue{ IntValue: 123, }, @@ -51,7 +52,7 @@ func TestConvertToKey(t *testing.T) { input: int64(123456789), expected: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_LongValue{ LongValue: 123456789, }, @@ -61,7 +62,7 @@ func TestConvertToKey(t *testing.T) { input: 
int(123456789), expected: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_LongValue{ LongValue: 123456789, }, @@ -75,7 +76,7 @@ func TestConvertToKey(t *testing.T) { } for _, tc := range testCases { - result, err := ConvertToKey("testNamespace", GetStrPtr("testSet"), tc.input) + result, err := ConvertToKey("testNamespace", Ptr("testSet"), tc.input) assert.Equal(t, tc.expected, result) assert.Equal(t, tc.expectedErr, err) @@ -97,59 +98,59 @@ func TestConvertFromKey(t *testing.T) { { input: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_StringValue{ StringValue: "testString", }, }, expectedNamespace: "testNamespace", - expectedSet: GetStrPtr("testSet"), + expectedSet: Ptr("testSet"), expectedKey: "testString", }, { input: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_BytesValue{ BytesValue: []byte{0x01, 0x02, 0x03}, }, }, expectedNamespace: "testNamespace", - expectedSet: GetStrPtr("testSet"), + expectedSet: Ptr("testSet"), expectedKey: []byte{0x01, 0x02, 0x03}, }, { input: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_IntValue{ IntValue: 123, }, }, expectedNamespace: "testNamespace", - expectedSet: GetStrPtr("testSet"), + expectedSet: Ptr("testSet"), expectedKey: int32(123), }, { input: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_LongValue{ LongValue: 123456789, }, }, expectedNamespace: "testNamespace", - expectedSet: GetStrPtr("testSet"), + expectedSet: Ptr("testSet"), expectedKey: int64(123456789), }, { input: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &keyUnknown{}, }, expectedNamespace: "testNamespace", - expectedSet: GetStrPtr("testSet"), + expectedSet: Ptr("testSet"), expectedKey: nil, // Unsupported or nil input expectedErr: fmt.Errorf("unsupported key 
value type: *protos.keyUnknown"), }, @@ -171,7 +172,7 @@ func TestConvertToValue(t *testing.T) { testCases := []struct { input any expected *Value - expectedErr error + expectedErr bool }{ { input: "testString", @@ -369,10 +370,48 @@ func TestConvertToValue(t *testing.T) { }, }, }, + { + input: []float32{1, 2}, + expected: &Value{ + Value: &Value_VectorValue{ + VectorValue: &Vector{ + Data: &Vector_FloatData{ + FloatData: &FloatData{ + Value: []float32{1, 2}, + }, + }, + }, + }, + }, + }, + { + input: []bool{true, false}, + expected: &Value{ + Value: &Value_VectorValue{ + VectorValue: &Vector{ + Data: &Vector_BoolData{ + BoolData: &BoolData{ + Value: []bool{true, false}, + }, + }, + }, + }, + }, + }, + { + input: map[int]any{10: struct{}{}}, + expected: nil, + expectedErr: true, + }, + { + input: []any{struct{}{}}, + expected: nil, + expectedErr: true, + }, { input: struct{}{}, // Unsupported type expected: nil, - expectedErr: fmt.Errorf("unsupported value type: struct {}"), + expectedErr: true, }, } @@ -381,11 +420,489 @@ func TestConvertToValue(t *testing.T) { result, err := ConvertToValue(tc.input) assert.Equal(t, tc.expected, result) - assert.Equal(t, tc.expectedErr, err) + + if tc.expectedErr { + assert.Error(t, err) + } }) } } +func TestConvertToMapKey(t *testing.T) { + testCases := []struct { + input any + expected *MapKey + expectedErr error + }{ + { + input: "testString", + expected: &MapKey{ + Value: &MapKey_StringValue{ + StringValue: "testString", + }, + }, + }, + { + input: int32(123), + expected: &MapKey{ + Value: &MapKey_IntValue{ + IntValue: 123, + }, + }, + }, + { + input: int64(123456789), + expected: &MapKey{ + Value: &MapKey_LongValue{ + LongValue: 123456789, + }, + }, + }, + { + input: int(123456789), + expected: &MapKey{ + Value: &MapKey_LongValue{ + LongValue: 123456789, + }, + }, + }, + { + input: []byte{0x01, 0x02, 0x03}, + expected: &MapKey{ + Value: &MapKey_BytesValue{ + BytesValue: []byte{0x01, 0x02, 0x03}, + }, + }, + }, + { + input: 
struct{}{}, + expected: nil, // Unsupported type + expectedErr: fmt.Errorf("unsupported key type: struct {}"), + }, + } + + for _, tc := range testCases { + result, err := ConvertToMapKey(tc.input) + + assert.Equal(t, tc.expected, result) + assert.Equal(t, tc.expectedErr, err) + } +} + +type mapKeyValueUnknown struct{} + +func (*mapKeyValueUnknown) isMapKey_Value() {} //nolint:revive,stylecheck // Grpc generated + +func TestConvertFromMapKey(t *testing.T) { + testCases := []struct { + input *MapKey + expected any + expectedErr error + }{ + { + input: &MapKey{ + Value: &MapKey_StringValue{ + StringValue: "testString", + }, + }, + expected: "testString", + }, + { + input: &MapKey{ + Value: &MapKey_BytesValue{ + BytesValue: []byte{0x01, 0x02, 0x03}, + }, + }, + expected: []byte{0x01, 0x02, 0x03}, + }, + { + input: &MapKey{ + Value: &MapKey_IntValue{ + IntValue: 123, + }, + }, + expected: int32(123), + }, + { + input: &MapKey{ + Value: &MapKey_LongValue{ + LongValue: 123456789, + }, + }, + expected: int64(123456789), + }, + { + input: &MapKey{ + Value: &mapKeyValueUnknown{}, + }, + expected: nil, + expectedErr: fmt.Errorf("unsupported map key value type: *protos.mapKeyValueUnknown"), + }, + } + + for _, tc := range testCases { + result, err := ConvertFromMapKey(tc.input) + + assert.Equal(t, tc.expected, result) + assert.Equal(t, tc.expectedErr, err) + } +} + +func TestConvertToMapValue(t *testing.T) { + testCases := []struct { + input any + expected *Map + expectedErr *string + }{ + { + input: map[any]any{"key": "value"}, + expected: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_StringValue{ + StringValue: "key", + }, + }, + Value: &Value{ + Value: &Value_StringValue{ + StringValue: "value", + }, + }, + }, + }, + }, + }, + { + input: map[string]string{"key": "value"}, + expected: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_StringValue{ + StringValue: "key", + }, + }, + Value: &Value{ + Value: &Value_StringValue{ + 
StringValue: "value", + }, + }, + }, + }, + }, + }, + { + input: map[int]float64{10: 3.124}, + expected: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_LongValue{ + LongValue: int64(10), + }, + }, + Value: &Value{ + Value: &Value_DoubleValue{ + DoubleValue: 3.124, + }, + }, + }, + }, + }, + }, + { + input: map[int]any{10: struct{}{}}, + expected: nil, + expectedErr: Ptr("unsupported map value: unsupported value type: struct {}"), + }, + { + input: map[any]any{struct{}{}: 10}, + expected: nil, + expectedErr: Ptr("unsupported map key: unsupported key type: struct {}"), + }, + } + + for _, tc := range testCases { + result, err := ConvertToMapValue(tc.input) + + assert.Equal(t, tc.expected, result) + + if tc.expectedErr != nil { + assert.ErrorContains(t, err, *tc.expectedErr) + } + } +} + +func TestConvertToList(t *testing.T) { + testCases := []struct { + input []any + expected *List + expectedErr *string + }{ + { + input: []any{"item1", "item2"}, + expected: &List{ + Entries: []*Value{ + { + Value: &Value_StringValue{ + StringValue: "item1", + }, + }, + { + Value: &Value_StringValue{ + StringValue: "item2", + }, + }, + }, + }, + }, + { + input: []any{1, 2}, + expected: &List{ + Entries: []*Value{ + { + Value: &Value_LongValue{ + LongValue: int64(1), + }, + }, + { + Value: &Value_LongValue{ + LongValue: int64(2), + }, + }, + }, + }, + }, + { + input: []any{true, false}, + expected: &List{ + Entries: []*Value{ + { + Value: &Value_BooleanValue{ + BooleanValue: true, + }, + }, + { + Value: &Value_BooleanValue{ + BooleanValue: false, + }, + }, + }, + }, + }, + { + input: []any{struct{}{}}, + expected: nil, + expectedErr: Ptr("unsupported list value: unsupported value type: struct {}"), + }, + } + + for _, tc := range testCases { + result, err := ConvertToList(tc.input) + + assert.Equal(t, tc.expected, result) + + if tc.expectedErr != nil { + assert.ErrorContains(t, err, *tc.expectedErr) + } + } +} + +func TestConvertFromListValue(t *testing.T) { + 
testCases := []struct { + input *List + expected []any + expectedErr *string + }{ + { + input: &List{ + Entries: []*Value{ + { + Value: &Value_StringValue{ + StringValue: "item1", + }, + }, + { + Value: &Value_StringValue{ + StringValue: "item2", + }, + }, + }, + }, + expected: []any{"item1", "item2"}, + }, + { + input: &List{ + Entries: []*Value{ + { + Value: &Value_LongValue{ + LongValue: int64(1), + }, + }, + { + Value: &Value_LongValue{ + LongValue: int64(2), + }, + }, + }, + }, + expected: []any{int64(1), int64(2)}, + }, + { + input: &List{ + Entries: []*Value{ + { + Value: &Value_BooleanValue{ + BooleanValue: true, + }, + }, + { + Value: &Value_BooleanValue{ + BooleanValue: false, + }, + }, + }, + }, + expected: []any{true, false}, + }, + { + input: &List{ + Entries: []*Value{ + { + Value: &valueUnknown{}, + }, + }, + }, + expected: nil, + expectedErr: Ptr("unsupported list value: unsupported value type: *protos.valueUnknown"), + }, + } + + for _, tc := range testCases { + result, err := ConvertFromListValue(tc.input) + + assert.Equal(t, tc.expected, result) + + if tc.expectedErr != nil { + assert.ErrorContains(t, err, *tc.expectedErr) + } + } +} + +func TestConvertFromMapValue(t *testing.T) { + var nilMap map[any]any + testCases := []struct { + input *Map + expected any + expectedErr *string + }{ + { + input: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_StringValue{ + StringValue: "key", + }, + }, + Value: &Value{ + Value: &Value_StringValue{ + StringValue: "value", + }, + }, + }, + }, + }, + expected: map[any]any{"key": "value"}, + }, + { + input: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_StringValue{ + StringValue: "key", + }, + }, + Value: &Value{ + Value: &Value_StringValue{ + StringValue: "value", + }, + }, + }, + }, + }, + expected: map[any]any{"key": "value"}, + }, + { + input: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_LongValue{ + LongValue: int64(10), + }, + }, + Value: &Value{ 
+ Value: &Value_DoubleValue{ + DoubleValue: 3.124, + }, + }, + }, + }, + }, + expected: map[any]any{int64(10): 3.124}, + }, + { + input: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &mapKeyValueUnknown{}, + }, + Value: &Value{ + Value: &Value_DoubleValue{ + DoubleValue: 3.124, + }, + }, + }, + }, + }, + expected: nilMap, + expectedErr: Ptr("unsupported map key value type: *protos.mapKeyValueUnknown"), + }, + { + input: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_LongValue{ + LongValue: int64(10), + }, + }, + Value: &Value{ + Value: &valueUnknown{}, + }, + }, + }, + }, + expected: nilMap, + expectedErr: Ptr("unsupported map value: unsupported value type: *protos.valueUnknown"), + }, + } + + for _, tc := range testCases { + result, err := ConvertFromMapValue(tc.input) + + assert.Equal(t, tc.expected, result) + + if tc.expectedErr != nil { + assert.ErrorContains(t, err, *tc.expectedErr) + } + } +} + type valueUnknown struct{} func (*valueUnknown) isValue_Value() {} //nolint:revive,stylecheck // Grpc generated @@ -496,6 +1013,20 @@ func TestConvertFromValue(t *testing.T) { }, expected: []any{"item1", "item2"}, }, + { + input: &Value{ + Value: &Value_VectorValue{ + VectorValue: &Vector{ + Data: &Vector_FloatData{ + FloatData: &FloatData{ + Value: []float32{1, 2}, + }, + }, + }, + }, + }, + expected: []float32{float32(1), float32(2)}, + }, { input: &Value{ Value: &valueUnknown{}, @@ -514,3 +1045,152 @@ func TestConvertFromValue(t *testing.T) { }) } } + +func TestConvertToFields(t *testing.T) { + testCases := []struct { + input map[string]any + expected []*Field + expectedErr bool + }{ + { + input: map[string]any{ + "key1": "value1", + "key2": 123, + }, + expected: []*Field{ + { + Name: "key1", + Value: &Value{Value: &Value_StringValue{StringValue: "value1"}}, + }, + { + Name: "key2", + Value: &Value{Value: &Value_LongValue{LongValue: 123}}, + }, + }, + }, + { + input: map[string]any{ + "key1": "value1", + "key2": struct{}{}, + }, + 
expected:    nil,
+			expectedErr: true,
+		},
+	}
+
+	sortFunc := func(fields []*Field) func(int, int) bool {
+		return func(i, j int) bool {
+			return fields[i].Name < fields[j].Name
+		}
+	}
+
+	for _, tc := range testCases {
+		result, err := ConvertToFields(tc.input)
+		assert.Equal(t, len(tc.expected), len(result))
+
+		sort.Slice(result, sortFunc(result))
+		sort.Slice(tc.expected, sortFunc(tc.expected))
+
+		for i := range tc.expected {
+			assert.EqualExportedValues(t, tc.expected[i], result[i])
+		}
+
+		if tc.expectedErr {
+			assert.Error(t, err)
+		}
+	}
+}
+
+func TestConvertFromFields(t *testing.T) {
+	testCases := []struct {
+		input       []*Field
+		expected    map[string]any
+		expectedErr bool
+	}{
+		{
+			input: []*Field{
+				{
+					Name:  "key1",
+					Value: &Value{Value: &Value_StringValue{StringValue: "value1"}},
+				},
+				{
+					Name:  "key2",
+					Value: &Value{Value: &Value_LongValue{LongValue: 123}},
+				},
+			},
+			expected: map[string]any{
+				"key1": "value1",
+				"key2": int64(123),
+			},
+		},
+		{
+			input: []*Field{
+				{
+					Name:  "key1",
+					Value: &Value{Value: &Value_StringValue{StringValue: "value1"}},
+				},
+				{
+					Name:  "key2",
+					Value: &Value{Value: &valueUnknown{}},
+				},
+			},
+			expected:    nil,
+			expectedErr: true,
+		},
+	}
+
+	for _, tc := range testCases {
+		result, err := ConvertFromFields(tc.input)
+
+		assert.Equal(t, tc.expected, result)
+
+		if tc.expectedErr {
+			assert.Error(t, err)
+		}
+	}
+}
+
+type unknownVectorType struct{}
+
+func (*unknownVectorType) isVector_Data() {} //nolint:revive,stylecheck // Grpc generated
+
+func TestConvertFromVector(t *testing.T) {
+	testCases := []struct {
+		input       *Vector
+		expected    any
+		expectedErr error
+	}{
+		{
+			input: &Vector{
+				Data: &Vector_FloatData{
+					FloatData: &FloatData{
+						Value: []float32{1, 2},
+					},
+				},
+			},
+			expected: []float32{float32(1), float32(2)},
+		},
+		{
+			input: &Vector{
+				Data: &Vector_BoolData{
+					BoolData: &BoolData{
+						Value: []bool{true, false},
+					},
+				},
+			},
+			expected: []bool{true, false},
+		},
+		{
+			input:       &Vector{Data: 
&unknownVectorType{}}, + expected: nil, + expectedErr: fmt.Errorf("unsupported vector data type: *protos.unknownVectorType"), + }, + } + + for _, tc := range testCases { + result, err := ConvertFromVector(tc.input) + + assert.Equal(t, tc.expected, result) + assert.Equal(t, tc.expectedErr, err) + } +} diff --git a/testutils.go b/testutils.go index b0de1b5..8a4f6b9 100644 --- a/testutils.go +++ b/testutils.go @@ -16,6 +16,7 @@ import ( "github.com/aerospike/avs-client-go/protos" "github.com/aerospike/tools-common-go/client" "github.com/stretchr/testify/suite" + "go.uber.org/goleak" "golang.org/x/net/context" ) @@ -67,21 +68,12 @@ func (suite *ServerTestBaseSuite) TearDownSuite() { if err != nil { fmt.Println("unable to stop docker compose down") } -} - -func GetStrPtr(str string) *string { - ptr := str - return &ptr -} -func GetUint32Ptr(i int) *uint32 { - ptr := uint32(i) - return &ptr + goleak.VerifyNone(suite.T()) } -func GetBoolPtr(b bool) *bool { - ptr := b - return &ptr +func Ptr[T any](value T) *T { + return &value } func CreateFlagStr(name, value string) string { @@ -237,12 +229,12 @@ func (idb *IndexDefinitionBuilder) Build() *protos.IndexDefinition { }, Params: &protos.IndexDefinition_HnswParams{ HnswParams: &protos.HnswParams{ - M: GetUint32Ptr(16), - EfConstruction: GetUint32Ptr(100), - Ef: GetUint32Ptr(100), + M: Ptr(uint32(16)), + EfConstruction: Ptr(uint32(100)), + Ef: Ptr(uint32(100)), BatchingParams: &protos.HnswBatchingParams{ - MaxRecords: GetUint32Ptr(100000), - Interval: GetUint32Ptr(30000), + MaxRecords: Ptr(uint32(100000)), + Interval: Ptr(uint32(30000)), }, CachingParams: &protos.HnswCachingParams{}, HealerParams: &protos.HnswHealerParams{}, @@ -333,7 +325,7 @@ func DockerComposeUp(composeFile string) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() - cmd := exec.CommandContext(ctx, "docker", "-lDEBUG", "compose", fmt.Sprintf("-f%s", composeFile), "up", "-d") + cmd := exec.CommandContext(ctx, 
"docker", "-lDEBUG", "compose", fmt.Sprintf("-f%s", composeFile), "--env-file", "docker/.env", "up", "-d") err := cmd.Run() cmd.Wait() diff --git a/token_manager.go b/token_manager.go index 1947bcc..848f96b 100644 --- a/token_manager.go +++ b/token_manager.go @@ -15,11 +15,11 @@ import ( "google.golang.org/grpc/metadata" ) -// tokenManager is responsible for managing authentication tokens and refreshing +// grpcTokenManager is responsible for managing authentication tokens and refreshing // them when necessary. // //nolint:govet // We will favor readability over field alignment -type tokenManager struct { +type grpcTokenManager struct { username string password string token atomic.Value @@ -29,13 +29,13 @@ type tokenManager struct { refreshScheduled bool } -// newJWTToken creates a new tokenManager instance with the provided username, password, and logger. -func newJWTToken(username, password string, logger *slog.Logger) *tokenManager { +// newGrpcJWTToken creates a new tokenManager instance with the provided username, password, and logger. +func newGrpcJWTToken(username, password string, logger *slog.Logger) *grpcTokenManager { logger.WithGroup("jwt") logger.Debug("creating new token manager") - return &tokenManager{ + return &grpcTokenManager{ username: username, password: password, logger: logger, @@ -44,7 +44,7 @@ func newJWTToken(username, password string, logger *slog.Logger) *tokenManager { } // Close stops the scheduled token refresh and closes the token manager. -func (tm *tokenManager) Close() { +func (tm *grpcTokenManager) Close() { if tm.refreshScheduled { tm.logger.Debug("stopping scheduled token refresh") tm.stopRefreshChan <- struct{}{} @@ -55,17 +55,16 @@ func (tm *tokenManager) Close() { } // setRefreshTimeFromTTL sets the refresh time based on the provided time-to-live (TTL) duration. 
-func (tm *tokenManager) setRefreshTimeFromTTL(ttl time.Duration) { +func (tm *grpcTokenManager) setRefreshTimeFromTTL(ttl time.Duration) { tm.refreshTime.Store(time.Now().Add(ttl)) } // RefreshToken refreshes the authentication token using the provided gRPC client connection. // It returns a boolean indicating if the token was successfully refreshed and // an error if any. It is not thread safe. -func (tm *tokenManager) RefreshToken(ctx context.Context, conn grpc.ClientConnInterface) error { +func (tm *grpcTokenManager) RefreshToken(ctx context.Context, conn *connection) error { // We only want one goroutine to refresh the token at a time - client := protos.NewAuthServiceClient(conn) - resp, err := client.Authenticate(ctx, &protos.AuthRequest{ + resp, err := conn.authClient.Authenticate(ctx, &protos.AuthRequest{ Credentials: createUserPassCredential(tm.username, tm.password), }) @@ -74,6 +73,11 @@ func (tm *tokenManager) RefreshToken(ctx context.Context, conn grpc.ClientConnIn } claims := strings.Split(resp.GetToken(), ".") + + if len(claims) < 3 { + return fmt.Errorf("failed to authenticate: missing either header, payload, or signature") + } + decClaims, err := base64.RawURLEncoding.DecodeString(claims[1]) if err != nil { @@ -89,17 +93,17 @@ func (tm *tokenManager) RefreshToken(ctx context.Context, conn grpc.ClientConnIn expiryToken, ok := tokenMap["exp"].(float64) if !ok { - return fmt.Errorf("%s: %w", "failed to authenticate", err) + return fmt.Errorf("failed to authenticate: unable to find exp in token") } iat, ok := tokenMap["iat"].(float64) if !ok { - return fmt.Errorf("%s: %w", "failed to authenticate", err) + return fmt.Errorf("failed to authenticate: unable to find iat in token") } ttl := time.Duration(expiryToken-iat) * time.Second if ttl <= 0 { - return fmt.Errorf("%s: %w", "failed to authenticate", err) + return fmt.Errorf("failed to authenticate: jwt ttl is less than 0") } tm.logger.DebugContext( @@ -120,7 +124,7 @@ func (tm *tokenManager) 
RefreshToken(ctx context.Context, conn grpc.ClientConnIn // ScheduleRefresh schedules the token refresh using the provided function to // get the gRPC client connection. This is not threadsafe. It should only be // called once. -func (tm *tokenManager) ScheduleRefresh(getConn func() (*connection, error)) { +func (tm *grpcTokenManager) ScheduleRefresh(getConn func() (*connection, error)) { if tm.refreshScheduled { tm.logger.Warn("refresh already scheduled") } @@ -139,7 +143,7 @@ func (tm *tokenManager) ScheduleRefresh(getConn func() (*connection, error)) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - err = tm.RefreshToken(ctx, connClients.grpcConn) + err = tm.RefreshToken(ctx, connClients) if err != nil { tm.logger.Warn("failed to refresh token", slog.Any("error", err)) } @@ -167,12 +171,12 @@ func (tm *tokenManager) ScheduleRefresh(getConn func() (*connection, error)) { } // RequireTransportSecurity returns true to indicate that transport security is required. -func (tm *tokenManager) RequireTransportSecurity() bool { +func (tm *grpcTokenManager) RequireTransportSecurity() bool { return true } // UnaryInterceptor returns the grpc unary client interceptor that attaches the token to outgoing requests. -func (tm *tokenManager) UnaryInterceptor() grpc.UnaryClientInterceptor { +func (tm *grpcTokenManager) UnaryInterceptor() grpc.UnaryClientInterceptor { return func( ctx context.Context, method string, @@ -186,7 +190,7 @@ func (tm *tokenManager) UnaryInterceptor() grpc.UnaryClientInterceptor { } // StreamInterceptor returns the grpc stream client interceptor that attaches the token to outgoing requests. 
-func (tm *tokenManager) StreamInterceptor() grpc.StreamClientInterceptor { +func (tm *grpcTokenManager) StreamInterceptor() grpc.StreamClientInterceptor { return func( ctx context.Context, desc *grpc.StreamDesc, @@ -200,7 +204,7 @@ func (tm *tokenManager) StreamInterceptor() grpc.StreamClientInterceptor { } // attachToken attaches the authentication token to the outgoing context. -func (tm *tokenManager) attachToken(ctx context.Context) context.Context { +func (tm *grpcTokenManager) attachToken(ctx context.Context) context.Context { rawToken := tm.token.Load() if rawToken == nil { return ctx diff --git a/token_manager_test.go b/token_manager_test.go new file mode 100644 index 0000000..392a0d8 --- /dev/null +++ b/token_manager_test.go @@ -0,0 +1,301 @@ +package avs + +import ( + "context" + "fmt" + "log/slog" + "testing" + + "github.com/aerospike/avs-client-go/protos" + "github.com/stretchr/testify/assert" + gomock "go.uber.org/mock/gomock" +) + +func TestRefreshToken_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjMwMDAwMDAwMDAsImlhdCI6MTcyNzExMDc1NH0.GD01CEWxW6-7lHcyeetM95WKdUlwY85m5lFqzcTCtzs" //nolint:gosec,lll // tests + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). 
+ Return(&protos.AuthResponse{ + Token: b64token, + }, nil) + + // Create the token manager with the mock connection + tm := newGrpcJWTToken(username, password, logger) + + // Refresh the token + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert no error occurred + assert.NoError(t, err) + + // Assert the token was set correctly + assert.Equal(t, "Bearer "+b64token, tm.token.Load().(string)) +} + +func TestRefreshToken_FailedToRefreshToken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Create the token manager with the + tm := newGrpcJWTToken(username, password, logger) + + // Refresh + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert error occurred + assert.Equal(t, "failed to authenticate: foo", err.Error()) +} + +func TestRefreshToken_FailedMissingClaims(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). 
+ Return(&protos.AuthResponse{ + Token: "badToken", + }, nil) + + // Create the token manager with the + tm := newGrpcJWTToken(username, password, logger) + + // Refresh + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert error occurred + assert.Error(t, err) + assert.Equal(t, "failed to authenticate: missing either header, payload, or signature", err.Error()) +} + +func TestRefreshToken_FailedToDecodeToken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). + Return(&protos.AuthResponse{ + Token: "badToken.blahz.foo", + }, nil) + + // Create the token manager with the + tm := newGrpcJWTToken(username, password, logger) + + // Refresh + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert error occurred + assert.Error(t, err) + assert.Equal(t, "failed to authenticate: illegal base64 data at input byte 4", err.Error()) +} + +func TestRefreshToken_FailedInvalidJson(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjMwMDAwMDAwMDAsImlhdCI6MTcyNzExMD.GD01CEWxW6-7lHcyeetM95WKdUlwY85m5lFqzcTCtzs" 
//nolint:gosec,lll // tests + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). + Return(&protos.AuthResponse{ + Token: b64token, + }, nil) + + // Create the token manager with the mock connection + tm := newGrpcJWTToken(username, password, logger) + + // Refresh the token + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert no error occurred + assert.Error(t, err) + assert.Equal(t, "failed to authenticate: unexpected end of JSON input", err.Error()) +} + +func TestRefreshToken_FailedFindExp(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE3MjcxMTA3NTR9.50IZcLoS7mQPzQsKvJZyXNUukvT5FdiqN2tynNIjHuk" //nolint:gosec,lll // tests + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). 
+	Return(&protos.AuthResponse{
+		Token: b64token,
+	}, nil)
+
+	// Create the token manager with the mock connection
+	tm := newGrpcJWTToken(username, password, logger)
+
+	// Refresh the token
+	err := tm.RefreshToken(context.Background(), mockConnClients)
+
+	// Assert error occurred
+	assert.Error(t, err)
+	assert.Equal(t, "failed to authenticate: unable to find exp in token", err.Error())
+}
+
+func TestRefreshToken_FailedFindIat(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	username := "testUser"
+	password := "testPassword"
+	logger := slog.Default()
+
+	// Create a mock gRPC client connection
+	mockConn := NewMockgrpcClientConn(ctrl)
+	mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl)
+	mockConnClients := &connection{
+		grpcConn:   mockConn,
+		authClient: mockAuthServiceClient,
+	}
+
+	b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjMwMDAwMDAwMDB9.f5DtjF1sYLH6fz0ThcFKxwngIXkVMLnhJtIrjLi_1p0" //nolint:gosec,lll // tests
+
+	// Set up expectations for AuthServiceClient.Authenticate()
+	mockAuthServiceClient.
+		EXPECT().
+		Authenticate(gomock.Any(), gomock.Any()).
+	Return(&protos.AuthResponse{
+		Token: b64token,
+	}, nil)
+
+	// Create the token manager with the mock connection
+	tm := newGrpcJWTToken(username, password, logger)
+
+	// Refresh the token
+	err := tm.RefreshToken(context.Background(), mockConnClients)
+
+	// Assert error occurred
+	assert.Error(t, err)
+	assert.Equal(t, "failed to authenticate: unable to find iat in token", err.Error())
+}
+
+func TestRefreshToken_FailedTtlLessThan0(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	username := "testUser"
+	password := "testPassword"
+	logger := slog.Default()
+
+	// Create a mock gRPC client connection
+	mockConn := NewMockgrpcClientConn(ctrl)
+	mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl)
+	mockConnClients := &connection{
+		grpcConn:   mockConn,
+		authClient: mockAuthServiceClient,
+	}
+
+	b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjE3MjcxMTA3NTMsImlhdCI6MTcyNzExMDc1NH0.YdH5twU6-LGLtgvD2sktiw1j40MRUe_r4oPN565z4Ok" //nolint:gosec,lll // tests
+
+	// Set up expectations for AuthServiceClient.Authenticate()
+	mockAuthServiceClient.
+		EXPECT().
+		Authenticate(gomock.Any(), gomock.Any()).
+ Return(&protos.AuthResponse{ + Token: b64token, + }, nil) + + // Create the token manager with the mock connection + tm := newGrpcJWTToken(username, password, logger) + + // Refresh the token + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert no error occurred + assert.Error(t, err) + assert.Equal(t, "failed to authenticate: jwt ttl is less than 0", err.Error()) +} diff --git a/types_test.go b/types_test.go new file mode 100644 index 0000000..aab1472 --- /dev/null +++ b/types_test.go @@ -0,0 +1,30 @@ +package avs + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHostPort_String(t *testing.T) { + host := "localhost" + port := 8080 + hp := NewHostPort(host, port) + + expected := "localhost:8080" + result := hp.String() + + assert.Equal(t, expected, result) +} + +func TestHostPortSlice_String(t *testing.T) { + hps := HostPortSlice{ + NewHostPort("localhost", 8080), + NewHostPort("example.com", 1234), + } + + expected := "[localhost:8080, example.com:1234]" + result := hps.String() + + assert.Equal(t, expected, result) +} diff --git a/utils.go b/utils.go index d1e943b..c895790 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,7 @@ package avs import ( + "sort" "strconv" "strings" @@ -55,7 +56,9 @@ func createProjectionSpec(includeFields, excludeFields []string) *protos.Project Type: protos.ProjectionType_SPECIFIED, Fields: includeFields, } - } else if excludeFields != nil { + } + + if excludeFields != nil { spec.Exclude = &protos.ProjectionFilter{ Type: protos.ProjectionType_SPECIFIED, Fields: excludeFields, @@ -74,7 +77,50 @@ func createIndexStatusRequest(namespace, name string) *protos.IndexStatusRequest } } -var minimumSupportedAVSVersion = newVersion("0.9.0") +func endpointEqual(a, b *protos.ServerEndpoint) bool { + return a.Address == b.Address && a.Port == b.Port && a.IsTls == b.IsTls +} + +func endpointListEqual(a, b *protos.ServerEndpointList) bool { + if len(a.Endpoints) != len(b.Endpoints) { + return 
false + } + + aEndpoints := make([]*protos.ServerEndpoint, len(a.Endpoints)) + copy(aEndpoints, a.Endpoints) + + bEndpoints := make([]*protos.ServerEndpoint, len(b.Endpoints)) + copy(bEndpoints, b.Endpoints) + + sortFunc := func(endpoints []*protos.ServerEndpoint) func(int, int) bool { + return func(i, j int) bool { + if endpoints[i].Address < endpoints[j].Address { + return true + } else if endpoints[i].Address > endpoints[j].Address { + return false + } + + return endpoints[i].Port < endpoints[j].Port + } + } + + sort.Slice(aEndpoints, sortFunc(aEndpoints)) + sort.Slice(bEndpoints, sortFunc(bEndpoints)) + + for i, endpoint := range aEndpoints { + if !endpointEqual(endpoint, bEndpoints[i]) { + return false + } + } + + return true +} + +func endpointToHostPort(endpoint *protos.ServerEndpoint) *HostPort { + return NewHostPort(endpoint.Address, int(endpoint.Port)) +} + +var minimumFullySupportedAVSVersion = newVersion("0.10.0") type version []any diff --git a/utils_test.go b/utils_test.go index 2fc6528..532ad1f 100644 --- a/utils_test.go +++ b/utils_test.go @@ -1,6 +1,114 @@ package avs -import "testing" +import ( + reflect "reflect" + "testing" + + "github.com/aerospike/avs-client-go/protos" +) + +func TestCreateProjectionSpec(t *testing.T) { + testCases := []struct { + name string + includeFields []string + excludeFields []string + expectedProjectionSpec *protos.ProjectionSpec + }{ + { + name: "include fields", + includeFields: []string{"field1", "field2"}, + excludeFields: nil, + expectedProjectionSpec: &protos.ProjectionSpec{ + Include: &protos.ProjectionFilter{ + Type: protos.ProjectionType_SPECIFIED, + Fields: []string{"field1", "field2"}, + }, + Exclude: &protos.ProjectionFilter{ + Type: protos.ProjectionType_NONE, + }, + }, + }, + { + name: "exclude fields", + includeFields: nil, + excludeFields: []string{"field3", "field4"}, + expectedProjectionSpec: &protos.ProjectionSpec{ + Include: &protos.ProjectionFilter{ + Type: protos.ProjectionType_ALL, + }, + 
Exclude: &protos.ProjectionFilter{ + Type: protos.ProjectionType_SPECIFIED, + Fields: []string{"field3", "field4"}, + }, + }, + }, + { + name: "default fields", + includeFields: nil, + excludeFields: nil, + expectedProjectionSpec: &protos.ProjectionSpec{ + Include: &protos.ProjectionFilter{ + Type: protos.ProjectionType_ALL, + }, + Exclude: &protos.ProjectionFilter{ + Type: protos.ProjectionType_NONE, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + spec := createProjectionSpec(tc.includeFields, tc.excludeFields) + + if !reflect.DeepEqual(spec, tc.expectedProjectionSpec) { + t.Errorf("expected projection spec %v, got %v", tc.expectedProjectionSpec, spec) + } + }) + } +} +func TestVersionString(t *testing.T) { + testCases := []struct { + name string + v version + expected string + }{ + { + name: "valid version", + v: newVersion("1.2.3"), + expected: "1.2.3", + }, + { + name: "valid version with suffix", + v: newVersion("1.2.3-dev"), + expected: "1.2.3-dev", + }, + { + name: "valid version with multiple suffixes", + v: newVersion("1.2.3-dev.1"), + expected: "1.2.3-dev.1", + }, + { + name: "valid version with pre-release", + v: newVersion("1.2.3-alpha"), + expected: "1.2.3-alpha", + }, + { + name: "valid version with pre-release and build metadata", + v: newVersion("1.2.3-alpha+build123"), + expected: "1.2.3-alpha+build123", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := tc.v.String() + if got != tc.expected { + t.Errorf("expected %s, got %s", tc.expected, got) + } + }) + } +} func TestVersionLTGT(t *testing.T) { testCases := []struct { @@ -75,3 +183,113 @@ func TestVersionLTGT(t *testing.T) { }) } } + +func TestEndpointListEqual(t *testing.T) { + testCases := []struct { + name string + endpoints1 *protos.ServerEndpointList + endpoints2 *protos.ServerEndpointList + want bool + }{ + { + name: "equal endpoints", + endpoints1: &protos.ServerEndpointList{ + Endpoints: 
[]*protos.ServerEndpoint{ + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, + { + Address: "127.0.0.1", + Port: 9090, + IsTls: true, + }, + }, + }, + endpoints2: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "127.0.0.1", + Port: 9090, + IsTls: true, + }, + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, + }, + }, + want: true, + }, + { + name: "different endpoints", + endpoints1: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, + { + Address: "127.0.0.1", + Port: 9090, + IsTls: true, + }, + }, + }, + endpoints2: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, + { + Address: "127.0.0.1", + Port: 9091, + IsTls: true, + }, + }, + }, + want: false, + }, + { + name: "different number of endpoints", + endpoints1: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, + { + Address: "127.0.0.1", + Port: 9090, + IsTls: true, + }, + }, + }, + endpoints2: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, + }, + }, + want: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := endpointListEqual(tc.endpoints1, tc.endpoints2) + if got != tc.want { + t.Errorf("expected %v, got %v", tc.want, got) + } + }) + } +}