From 8d64cfe23c12823102855e2d99145f635788631e Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 12 Sep 2024 12:54:57 -0700 Subject: [PATCH 01/42] increase test-coverage --- channel_provider.go | 2 +- client.go | 3 +- docker/auth/docker-compose.yml | 2 +- docker/mtls/docker-compose.yml | 2 +- docker/multi-node-LB/docker-compose.yml | 6 +- .../docker-compose.yml | 6 +- docker/multi-node/docker-compose.yml | 6 +- docker/tls/docker-compose.yml | 3 +- docker/vanilla/docker-compose.yml | 3 +- errors.go | 6 +- go.mod | 1 + go.sum | 2 + integration_single_node_test.go | 308 +++++++++++++++++- makefile | 20 +- testutils.go | 3 + utils.go | 2 +- 16 files changed, 350 insertions(+), 25 deletions(-) diff --git a/channel_provider.go b/channel_provider.go index d5fb63c..6030644 100644 --- a/channel_provider.go +++ b/channel_provider.go @@ -365,7 +365,7 @@ func (cp *channelProvider) connectToSeeds(ctx context.Context) error { return } - if newVersion(about.Version).lt(minimumSupportedAVSVersion) { + if newVersion(about.Version).lt(minimumFullySupportedAVSVersion) { logger.WarnContext(ctx, "incompatible server version", slog.String("version", about.Version)) } } diff --git a/client.go b/client.go index 8d68ce1..b0f5c16 100644 --- a/client.go +++ b/client.go @@ -1364,6 +1364,7 @@ func (c *Client) ListRoles(ctx context.Context) (*protos.ListRolesResponse, erro } // NodeIDs returns a list of all the node IDs that the client is connected to. +// If load-balancer is set true no NodeIDs will be returned. // If a node is accessible but not a part of the cluster it will not be // returned. // @@ -1529,7 +1530,7 @@ func (c *Client) About(ctx context.Context, nodeID *protos.NodeId) (*protos.Abou msg := "failed to make about request" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) - return nil, NewAVSErrorFromGrpc(msg, err) + return nil, NewAVSError(msg, err) } resp, err := conn.aboutClient.Get(ctx, &protos.AboutRequest{}) diff --git a/docker/auth/docker-compose.yml b/docker/auth/docker-compose.yml index dbf6d6b..3a1c1a8 100644 --- a/docker/auth/docker-compose.yml +++ b/docker/auth/docker-compose.yml @@ -16,7 +16,7 @@ services: timeout: 20s retries: 20 avs: - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 depends_on: aerospike: condition: service_healthy diff --git a/docker/mtls/docker-compose.yml b/docker/mtls/docker-compose.yml index dbf6d6b..3a1c1a8 100644 --- a/docker/mtls/docker-compose.yml +++ b/docker/mtls/docker-compose.yml @@ -16,7 +16,7 @@ services: timeout: 20s retries: 20 avs: - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 depends_on: aerospike: condition: service_healthy diff --git a/docker/multi-node-LB/docker-compose.yml b/docker/multi-node-LB/docker-compose.yml index 02d8786..3e59264 100644 --- a/docker/multi-node-LB/docker-compose.yml +++ b/docker/multi-node-LB/docker-compose.yml @@ -20,7 +20,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 volumes: - ./config/aerospike-vector-search-1.yml:/etc/aerospike-vector-search/aerospike-vector-search.yml - ./config/features.conf:/etc/aerospike-vector-search/features.conf @@ -35,7 +35,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 volumes: - ./config/aerospike-vector-search-2.yml:/etc/aerospike-vector-search/aerospike-vector-search.yml - ./config/features.conf:/etc/aerospike-vector-search/features.conf @@ -50,7 +50,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 volumes: - ./config/aerospike-vector-search-3.yml:/etc/aerospike-vector-search/aerospike-vector-search.yml - ./config/features.conf:/etc/aerospike-vector-search/features.conf diff --git a/docker/multi-node-client-visibility-err/docker-compose.yml b/docker/multi-node-client-visibility-err/docker-compose.yml index 195f9a6..01d12e4 100644 --- a/docker/multi-node-client-visibility-err/docker-compose.yml +++ b/docker/multi-node-client-visibility-err/docker-compose.yml @@ -20,7 +20,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 ports: - "10000:10000" volumes: @@ -37,7 +37,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 ports: - "10001:10001" volumes: @@ -54,7 +54,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 # ports: # - "10002:10002" # This causes the visibility err volumes: diff --git a/docker/multi-node/docker-compose.yml b/docker/multi-node/docker-compose.yml index 2cfd93a..2478015 100644 --- a/docker/multi-node/docker-compose.yml +++ b/docker/multi-node/docker-compose.yml @@ -20,7 +20,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 ports: - "10000:10000" volumes: @@ -37,7 +37,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 ports: - "10001:10001" volumes: @@ -54,7 +54,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector-search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 ports: - "10002:10002" volumes: diff --git a/docker/tls/docker-compose.yml b/docker/tls/docker-compose.yml index 2308444..3a1c1a8 100644 --- a/docker/tls/docker-compose.yml +++ b/docker/tls/docker-compose.yml @@ -16,8 +16,7 @@ services: timeout: 20s retries: 20 avs: - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector- -search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 depends_on: aerospike: condition: service_healthy diff --git a/docker/vanilla/docker-compose.yml b/docker/vanilla/docker-compose.yml index 2308444..3a1c1a8 100644 --- a/docker/vanilla/docker-compose.yml +++ b/docker/vanilla/docker-compose.yml @@ -16,8 +16,7 @@ services: timeout: 20s retries: 20 avs: - image: aerospike.jfrog.io/docker/aerospike/aerospike-vector- -search-private:0.9.1-SNAPSHOT + image: aerospike/aerospike-vector-search:0.10.0 depends_on: aerospike: condition: service_healthy diff --git a/errors.go b/errors.go index 9a38945..f975989 100644 --- a/errors.go +++ b/errors.go @@ -10,7 +10,7 @@ type Error struct { msg string } -func NewAVSError(msg string, err error) error { +func NewAVSError(msg string, err error) *Error { if err != nil { msg = fmt.Sprintf("%s: %s", msg, err.Error()) } @@ -18,14 +18,14 @@ func NewAVSError(msg string, err error) error { return &Error{msg: msg} } -func NewAVSErrorFromGrpc(msg string, gErr error) error { +func NewAVSErrorFromGrpc(msg string, gErr error) *Error { if gErr == nil { return nil } gStatus, ok := status.FromError(gErr) if !ok { - return NewAVSError(gErr.Error(), nil) + return NewAVSError(msg, gErr) } errStr := fmt.Sprintf("%s: server error: %s", msg, gStatus.Code().String()) diff --git a/go.mod b/go.mod index 0bfeb5e..28d5ac7 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect + go.uber.org/goleak v1.3.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect diff --git a/go.sum b/go.sum index 91a1f11..a922ec1 100644 --- a/go.sum +++ b/go.sum @@ -59,6 +59,8 @@ go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgS go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= diff --git a/integration_single_node_test.go b/integration_single_node_test.go index 41a92e7..b33664c 100644 --- a/integration_single_node_test.go +++ b/integration_single_node_test.go @@ -9,6 +9,9 @@ import ( "fmt" "log/slog" "os" + "strconv" + "strings" + "sync" "testing" "time" @@ -59,7 +62,7 @@ func TestSingleNodeSuite(t *testing.T) { suites := []*SingleNodeTestSuite{ { ServerTestBaseSuite: ServerTestBaseSuite{ - ComposeFile: "docker/multi-node/docker-compose.yml", // vanilla + ComposeFile: "docker/vanilla/docker-compose.yml", // vanilla AvsLB: false, AvsHostPort: avsHostPort, }, @@ -89,7 +92,30 @@ func TestSingleNodeSuite(t *testing.T) { }, } - for _, s := range suites { + testSuiteEnv := os.Getenv("ASVEC_TEST_SUITES") + picked_suites := map[int]struct{}{} + + if testSuiteEnv != "" { + testSuites := strings.Split(testSuiteEnv, ",") + + for _, s := range testSuites { + i, err := strconv.Atoi(s) + if err != nil { + t.Fatalf("unable to convert %s to int: %v", s, err) + } + + picked_suites[i] = struct{}{} + } + } + + logger.Info("Running test suites", slog.Any("suites", picked_suites)) + + for i, s := range suites { + if len(picked_suites) != 0 { + if _, ok := picked_suites[i]; !ok { + continue + } + } suite.Run(t, s) } } @@ -234,6 +260,13 @@ func getUniqueSetName() *string { return &val } +var userNameCount = -1 + +func getUniqueUserName() string { + userNameCount++ + return "user" + fmt.Sprintf("%d", userNameCount) +} + func createNeighborFloat32(namespace string, set *string, key string, distance float32, vector []float32) *Neighbor { return &Neighbor{ Namespace: namespace, @@ -427,3 +460,274 @@ func (suite *SingleNodeTestSuite) TestVectorSearchBool() { }) } } + +func (suite *SingleNodeTestSuite) TestConcurrentWrites() { + wg := sync.WaitGroup{} + numWrites := 10_000 + keys := []string{} + + for i := 0; i < numWrites; i++ { + keys = append(keys, fmt.Sprintf("concurrent-key-%d", i)) + } + + for i := 0; i < numWrites; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + suite.AvsClient.Upsert(context.Background(), testNamespace, nil, keys[i], map[string]any{"foo": "bar"}, false) + }(i) + } + + wg.Wait() + + wg = sync.WaitGroup{} + + for i := 0; i < numWrites; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + _, err := suite.AvsClient.Get(context.Background(), testNamespace, nil, keys[i], nil, nil) + suite.Assert().NoError(err) + }(i) + } + + wg.Wait() +} + +func (suite *SingleNodeTestSuite) TestUserCreate() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + username := getUniqueUserName() + + err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write"}) + suite.NoError(err) + + actualUser, err := suite.AvsClient.GetUser(ctx, username) + suite.NoError(err) + + expectedUser := protos.User{ + Username: username, + Roles: []string{ + "read-write", + }, + } + + suite.EqualExportedValues(expectedUser, *actualUser) + return + +} + +func (suite *SingleNodeTestSuite) TestUserDelete() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + username := getUniqueUserName() + + err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write"}) + suite.NoError(err) + + err = suite.AvsClient.DropUser(ctx, username) + suite.NoError(err) + + _, err = suite.AvsClient.GetUser(ctx, username) + suite.Error(err) +} + +func (suite *SingleNodeTestSuite) TestUserGrantRoles() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + username := getUniqueUserName() + + err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write"}) + suite.NoError(err) + + err = suite.AvsClient.GrantRoles(ctx, username, []string{"admin"}) + suite.NoError(err) + + actualUser, err := suite.AvsClient.GetUser(ctx, username) + suite.NoError(err) + + expectedUser := protos.User{ + Username: username, + Roles: []string{ + "admin", + "read-write", + }, + } + + suite.EqualExportedValues(expectedUser, *actualUser) +} + +func (suite *SingleNodeTestSuite) TestUserRevokeRoles() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + username := getUniqueUserName() + + err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write", "admin"}) + suite.NoError(err) + + err = suite.AvsClient.RevokeRoles(ctx, username, []string{"admin"}) + suite.NoError(err) + + actualUser, err := suite.AvsClient.GetUser(ctx, username) + suite.NoError(err) + + expectedUser := protos.User{ + Username: username, + Roles: []string{ + "read-write", + }, + } + + suite.EqualExportedValues(expectedUser, *actualUser) +} + +func (suite *SingleNodeTestSuite) TestListUsers() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + username := getUniqueUserName() + + err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write"}) + suite.NoError(err) + + username1 := getUniqueUserName() + + err = suite.AvsClient.CreateUser(ctx, username1, "test-password", []string{"read-write"}) + suite.NoError(err) + + users, err := suite.AvsClient.ListUsers(ctx) + suite.NoError(err) + + suite.Equal(3, len(users.Users)) + + for _, user := range users.Users { + if user.Username == username { + suite.Equal([]string{"read-write"}, user.Roles) + } else if user.Username == username1 { + suite.Equal([]string{"read-write"}, user.Roles) + } else { + suite.Equal("admin", user.Username) + suite.Equal([]string{"admin", "read-write"}, user.Roles) + } + } +} + +func (suite *SingleNodeTestSuite) TestListRoles() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + roles, err := suite.AvsClient.ListRoles(ctx) + suite.NoError(err) + + suite.Equal(2, len(roles.Roles)) +} + +func (suite *SingleNodeTestSuite) TestNodeIDs() { + ctx := context.Background() + nodeIDs := suite.AvsClient.NodeIDs(ctx) + + if suite.AvsLB { + suite.Equal(0, len(nodeIDs)) + } else { + suite.Equal(1, len(nodeIDs)) + } +} + +func (suite *SingleNodeTestSuite) TestClusterEndpoints() { + testCases := []struct { + name string + nodeId *protos.NodeId + listenerName *string + expectedEndpoints []*protos.ServerEndpoint + expectedErrMsg *string + }{ + { + name: "nil-node", + nodeId: nil, + expectedEndpoints: []*protos.ServerEndpoint{ + { + Address: "127.0.0.1", + Port: 10000, + IsTls: suite.AvsTLSConfig != nil, + }, + }, + }, + { + name: "node id DNE", + nodeId: &protos.NodeId{ + Id: 1, + }, + expectedErrMsg: GetStrPtr("failed to get cluster endpoints"), + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + endpoints, err := suite.AvsClient.ClusterEndpoints(ctx, tc.nodeId, tc.listenerName) + + if tc.expectedErrMsg != nil { + suite.Error(err) + suite.Contains(err.Error(), *tc.expectedErrMsg) + return + } else { + suite.NoError(err) + } + + for id, endpoint := range endpoints.Endpoints { + // When LB is true we aren't able to validate the node-id since + // we did not tend the cluster. When LB is false we are able to + // get the node-id of the single node. + if !suite.AvsLB { + nodeId := suite.AvsClient.NodeIDs(ctx)[0] + suite.Assert().Equal(nodeId.Id, id) + } + + suite.EqualExportedValues(tc.expectedEndpoints[0], endpoint.Endpoints[0]) + } + }) + } +} + +func (suite *SingleNodeTestSuite) TestAbout() { + testCases := []struct { + name string + nodeId *protos.NodeId + expectedVersion string + expectedErrMsg *string + }{ + { + name: "nil-node", + nodeId: nil, + expectedVersion: "0.10.0", + }, + { + name: "node id DNE", + nodeId: &protos.NodeId{ + Id: 1, + }, + expectedErrMsg: GetStrPtr("failed to make about request"), + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + actualVersion, err := suite.AvsClient.About(ctx, tc.nodeId) + + if tc.expectedErrMsg != nil { + suite.Error(err) + suite.Contains(err.Error(), *tc.expectedErrMsg) + return + } else { + suite.NoError(err) + } + + suite.Equal(actualVersion.GetVersion(), tc.expectedVersion) + }) + } +} diff --git a/makefile b/makefile index 61b9c7b..7d5d43e 100644 --- a/makefile +++ b/makefile @@ -1,9 +1,21 @@ +ifeq (,$(shell go env GOBIN)) +GOBIN=$(shell go env GOPATH)/bin +else +GOBIN=$(shell go env GOBIN) +endif + +GOLANGCI_LINT ?= $(GOBIN)/golangci-lint +GOLANGCI_LINT_VERSION ?= v1.54.0 + ROOT_DIR = $(shell pwd) PROTO_DIR = $(ROOT_DIR)/protos COVERAGE_DIR = $(ROOT_DIR)/coverage COV_UNIT_DIR = $(COVERAGE_DIR)/unit COV_INTEGRATION_DIR = $(COVERAGE_DIR)/integration +$(GOLANGCI_LINT): $(GOBIN) + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(GOBIN) $(GOLANGCI_LINT_VERSION) + .PHONY: protos protos: protoc --proto_path=$(PROTO_DIR) --go_out=$(PROTO_DIR) --go_opt=paths=source_relative \ @@ -13,7 +25,7 @@ protos: test: integration unit .PHONY: integration -integration: +integration: $(GOLEAK) mkdir -p $(COV_INTEGRATION_DIR) || true go test -tags=integration -timeout 30m -cover ./... -args -test.gocoverdir=$(COV_INTEGRATION_DIR) @@ -31,4 +43,8 @@ coverage: test PHONY: view-coverage view-coverage: $(COVERAGE_DIR)/total.cov - go tool cover -html=$(COVERAGE_DIR)/total.cov \ No newline at end of file + go tool cover -html=$(COVERAGE_DIR)/total.cov + +PHONY: lint +lint: $(GOLANGCI_LINT) + $(GOLANGCI_LINT) run \ No newline at end of file diff --git a/testutils.go b/testutils.go index b0de1b5..8aebcf5 100644 --- a/testutils.go +++ b/testutils.go @@ -16,6 +16,7 @@ import ( "github.com/aerospike/avs-client-go/protos" "github.com/aerospike/tools-common-go/client" "github.com/stretchr/testify/suite" + "go.uber.org/goleak" "golang.org/x/net/context" ) @@ -67,6 +68,8 @@ func (suite *ServerTestBaseSuite) TearDownSuite() { if err != nil { fmt.Println("unable to stop docker compose down") } + + goleak.VerifyNone(suite.T()) } func GetStrPtr(str string) *string { diff --git a/utils.go b/utils.go index d1e943b..b099ff4 100644 --- a/utils.go +++ b/utils.go @@ -74,7 +74,7 @@ func createIndexStatusRequest(namespace, name string) *protos.IndexStatusRequest } } -var minimumSupportedAVSVersion = newVersion("0.9.0") +var minimumFullySupportedAVSVersion = newVersion("0.10.0") type version []any From 5d9b3bcf7fb999beaa5795953f6f18367213da80 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 12 Sep 2024 13:22:59 -0700 Subject: [PATCH 02/42] add feature key file --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c8e1aa8..9011e72 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -25,6 +25,7 @@ jobs: run: | echo "$FEATURES_CONF" > docker/multi-node/config/features.conf + echo "$FEATURES_CONF" > docker/vanilla/config/features.conf echo "$FEATURES_CONF" > docker/tls/config/features.conf echo "$FEATURES_CONF" > docker/mtls/config/features.conf echo "$FEATURES_CONF" > docker/auth/config/features.conf From 548218c8944625657853caa020b241efd4996f0a Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 12 Sep 2024 13:53:59 -0700 Subject: [PATCH 03/42] remove generated code from coverage --- makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/makefile b/makefile index 7d5d43e..eb79cec 100644 --- a/makefile +++ b/makefile @@ -38,7 +38,7 @@ unit: coverage: test go tool covdata textfmt -i="$(COV_INTEGRATION_DIR),$(COV_UNIT_DIR)" -o=$(COVERAGE_DIR)/tmp.cov go tool cover -func=$(COVERAGE_DIR)/tmp.cov - grep -v 'testutils.go' $(COVERAGE_DIR)/tmp.cov > $(COVERAGE_DIR)/total.cov + grep -Ev '(testutils\.go|.*\.pb\.go)' $(COVERAGE_DIR)/tmp.cov > $(COVERAGE_DIR)/total.cov PHONY: view-coverage From c984d8ce470d82c63655ae65f204f1cda40ebddf Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Fri, 13 Sep 2024 16:17:57 -0700 Subject: [PATCH 04/42] more tests --- client.go | 6 +- integration_single_node_test.go | 261 +++++++++++++++++++++++++++++++- 2 files changed, 264 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index b0f5c16..7aee55b 100644 --- a/client.go +++ b/client.go @@ -26,7 +26,7 @@ const ( failedToGetRecord = "failed to get record" failedToDeleteRecord = "failed to delete record" failedToCheckRecordExists = "failed to check if record exists" - failedToCheckIsIndexed = "failed to check if record exists" + failedToCheckIsIndexed = "failed to check if record is indexed" ) // Client is a client for managing Aerospike Vector Indexes. @@ -422,6 +422,10 @@ func (c *Client) IsIndexed( } isIndexedReq := &protos.IsIndexedRequest{ + IndexId: &protos.IndexId{ + Namespace: namespace, + Name: indexName, + }, Key: protoKey, } diff --git a/integration_single_node_test.go b/integration_single_node_test.go index b33664c..ec9d28b 100644 --- a/integration_single_node_test.go +++ b/integration_single_node_test.go @@ -63,7 +63,7 @@ func TestSingleNodeSuite(t *testing.T) { { ServerTestBaseSuite: ServerTestBaseSuite{ ComposeFile: "docker/vanilla/docker-compose.yml", // vanilla - AvsLB: false, + AvsLB: true, AvsHostPort: avsHostPort, }, }, @@ -136,7 +136,7 @@ func (suite *SingleNodeTestSuite) TestBasicUpsertGetDelete() { { "test", getUniqueSetName(), - "key1", + getUniqueKey(), map[string]any{ "str": "str", "int": int64(64), @@ -183,6 +183,259 @@ func (suite *SingleNodeTestSuite) TestBasicUpsertGetDelete() { } } +func (suite *SingleNodeTestSuite) TestBasicUpsertExistsDelete() { + records := []struct { + namespace string + set *string + key any + recordData map[string]any + }{ + { + "test", + getUniqueSetName(), + getUniqueKey(), + map[string]any{ + "str": "str", + "int": int64(64), + "float": 3.14, + "bool": false, + "arr": []any{int64(0), int64(1), int64(2), int64(3)}, + "map": map[any]any{ + "foo": "bar", + }, + }, + }, + } + + // ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + // defer cancel() + ctx := context.Background() + + for _, rec := range records { + err := suite.AvsClient.Upsert(ctx, rec.namespace, rec.set, rec.key, rec.recordData, false) + suite.NoError(err) + + if err != nil { + return + } + + exists, err := suite.AvsClient.Exists(ctx, rec.namespace, rec.set, rec.key) + suite.NoError(err) + + if err != nil { + return + } + + suite.True(exists) + + err = suite.AvsClient.Delete(ctx, rec.namespace, rec.set, rec.key) + suite.NoError(err) + + if err != nil { + return + } + + exists, err = suite.AvsClient.Exists(ctx, rec.namespace, rec.set, rec.key) + suite.NoError(err) + + suite.False(exists) + } +} + +func (suite *SingleNodeTestSuite) TestIndexCreate() { + testcases := []struct { + name string + namespace string + indexName string + vectorField string + dimension uint32 + vectorDistanceMetric protos.VectorDistanceMetric + opts *IndexCreateOpts + expectedIndex protos.IndexDefinition + }{ + { + "basic", + "test", + "index", + "vector", + 10, + protos.VectorDistanceMetric_SQUARED_EUCLIDEAN, + nil, + protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "test", + Name: "index", + }, + Dimensions: uint32(10), + VectorDistanceMetric: protos.VectorDistanceMetric_SQUARED_EUCLIDEAN, + Type: protos.IndexType_HNSW, + SetFilter: nil, + Field: "vector", + }, + }, + { + "with opts", + "test", + "index", + "vector", + 10, + protos.VectorDistanceMetric_COSINE, + &IndexCreateOpts{ + Sets: []string{"testset"}, + Storage: &protos.IndexStorage{ + Namespace: GetStrPtr("storage-ns"), + Set: GetStrPtr("storage-set"), + }, + Labels: map[string]string{ + "a": "b", + }, + }, + protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "test", + Name: "index", + }, + Dimensions: uint32(10), + VectorDistanceMetric: protos.VectorDistanceMetric_COSINE, + Type: protos.IndexType_HNSW, + SetFilter: GetStrPtr("testset"), + Field: "vector", + Storage: &protos.IndexStorage{ + Namespace: GetStrPtr("storage-ns"), + Set: GetStrPtr("storage-set"), + }, + Labels: map[string]string{ + "a": "b", + }, + Params: &protos.IndexDefinition_HnswParams{ + HnswParams: &protos.HnswParams{}, + }, + }, + }, + } + + for _, tc := range testcases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + err := suite.AvsClient.IndexCreate(ctx, tc.namespace, tc.indexName, tc.vectorField, tc.dimension, tc.vectorDistanceMetric, tc.opts) + suite.NoError(err) + + if err != nil { + return + } + + defer suite.AvsClient.IndexDrop(ctx, tc.namespace, tc.indexName) + + index, err := suite.AvsClient.IndexGet(ctx, tc.namespace, tc.indexName, false) + suite.NoError(err) + + if err != nil { + return + } + + suite.EqualExportedValues(tc.expectedIndex, *index) + }) + } +} + +// func (suite *SingleNodeTestSuite) TestIsIndexed() { +// records := []struct { +// namespace string +// set *string +// key any +// recordData map[string]any +// }{ +// { +// namespace: "test", +// set: getUniqueSetName(), +// key: getUniqueKey(), +// recordData: map[string]any{ +// // "str": "str", +// // "int": int64(64), +// // "float": 3.14, +// // "bool": false, +// "arr": getVectorFloat32(10, 1.0), +// // "map": map[any]any{ +// // "foo": "bar", +// // }, +// }, +// }, +// } + +// // ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +// // defer cancel() +// ctx := context.Background() + +// for _, rec := range records { +// err := suite.AvsClient.Upsert(ctx, rec.namespace, rec.set, rec.key, rec.recordData, false) +// suite.NoError(err) + +// for i := 0; i < 10; i++ { +// recordData := map[string]any{ +// // "str": "str", +// // "int": int64(64), +// // "float": 3.14, +// // "bool": false, +// "arr": getVectorFloat32(10, float32(i*2)), +// // "map": map[any]any{ +// // "foo": "bar", +// // }, +// } + +// suite.AvsClient.Upsert(ctx, rec.namespace, rec.set, getUniqueKey(), recordData, false) +// } + +// if err != nil { +// return +// } + +// isIndexed, err := suite.AvsClient.IsIndexed(ctx, rec.namespace, rec.set, "index", rec.key) +// suite.Error(err) + +// var indexOpts *IndexCreateOpts + +// if rec.set != nil { +// indexOpts = &IndexCreateOpts{ +// Sets: []string{*rec.set}, +// } +// } + +// err = suite.AvsClient.IndexCreate( +// ctx, +// rec.namespace, +// "index", +// "arr", +// 10, +// protos.VectorDistanceMetric_SQUARED_EUCLIDEAN, +// indexOpts, +// ) +// suite.NoError(err) + +// if err != nil { +// return +// } + +// isIndexed, err = suite.AvsClient.IsIndexed(ctx, rec.namespace, rec.set, "index", rec.key) +// suite.NoError(err) +// suite.False(isIndexed) + +// defer suite.AvsClient.IndexDrop(ctx, rec.namespace, "index") + +// suite.AvsClient.WaitForIndexCompletion(ctx, rec.namespace, "index", time.Second*12) + +// isIndexed, err = suite.AvsClient.IsIndexed(ctx, rec.namespace, rec.set, "index", rec.key) +// suite.NoError(err) + +// // time.Sleep(time.Second * 33330) + +// suite.True(isIndexed) + +// err = suite.AvsClient.Delete(ctx, rec.namespace, rec.set, rec.key) +// suite.NoError(err) +// } +// } + func (suite *SingleNodeTestSuite) TestFailsToInsertAlreadyExists() { ctx := context.Background() key := getUniqueKey() @@ -367,6 +620,10 @@ func (suite *SingleNodeTestSuite) TestVectorSearchFloat32() { return } + isIndexed, err := suite.AvsClient.IsIndexed(ctx, tc.namespace, setName, indexName, getKey(0)) + suite.NoError(err) + suite.True(isIndexed) + time.Sleep(time.Second * 10) neighbors, err := suite.AvsClient.VectorSearchFloat32(ctx, tc.namespace, indexName, tc.query, 3, nil, nil, nil) From 5bc2a3618c0f9e8f272f1872b9a1dc416fa9ef83 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Mon, 16 Sep 2024 11:45:24 -0700 Subject: [PATCH 05/42] more tests --- client.go | 2 +- integration_single_node_test.go | 357 +++++++++++++++++++++++++++++++- testutils.go | 10 + 3 files changed, 367 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 7aee55b..27a34f6 100644 --- a/client.go +++ b/client.go @@ -1456,7 +1456,7 @@ func (c *Client) ClusteringState(ctx context.Context, nodeID *protos.NodeId) (*p conn, err := c.getConnection(nodeID) if err != nil { - msg := "failed to list roles" + msg := "failed to get clustering state" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) return nil, NewAVSError(msg, err) diff --git a/integration_single_node_test.go b/integration_single_node_test.go index ec9d28b..13e68e2 100644 --- a/integration_single_node_test.go +++ b/integration_single_node_test.go @@ -63,7 +63,7 @@ func TestSingleNodeSuite(t *testing.T) { { ServerTestBaseSuite: ServerTestBaseSuite{ ComposeFile: "docker/vanilla/docker-compose.yml", // vanilla - AvsLB: true, + AvsLB: false, AvsHostPort: avsHostPort, }, }, @@ -339,6 +339,185 @@ func (suite *SingleNodeTestSuite) TestIndexCreate() { } } +func (suite *SingleNodeTestSuite) TestIndexUpdate() { + testcases := []struct { + name string + namespace string + indexName string + vectorField string + dimension uint32 + vectorDistanceMetric protos.VectorDistanceMetric + opts *IndexCreateOpts + updateLabels map[string]string + updateHnsw *protos.HnswIndexUpdate + expectedIndex protos.IndexDefinition + }{ + { + name: "no update", + namespace: "test", + indexName: "index", + vectorField: "vector", + dimension: 10, + vectorDistanceMetric: protos.VectorDistanceMetric_COSINE, + opts: &IndexCreateOpts{ + Sets: []string{"testset"}, + Storage: &protos.IndexStorage{ + Namespace: GetStrPtr("storage-ns"), + Set: GetStrPtr("storage-set"), + }, + Labels: map[string]string{ + "a": "b", + }, + }, + updateHnsw: nil, + updateLabels: nil, + expectedIndex: protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "test", + Name: "index", + }, + Dimensions: uint32(10), + VectorDistanceMetric: protos.VectorDistanceMetric_COSINE, + Type: protos.IndexType_HNSW, + SetFilter: GetStrPtr("testset"), + Field: "vector", + Storage: &protos.IndexStorage{ + Namespace: GetStrPtr("storage-ns"), + Set: GetStrPtr("storage-set"), + }, + Labels: map[string]string{ + "a": "b", + }, + Params: &protos.IndexDefinition_HnswParams{ + HnswParams: &protos.HnswParams{}, + }, + }, + }, + { + name: "update all params", + namespace: "test", + indexName: "index", + vectorField: "vector", + dimension: 10, + vectorDistanceMetric: protos.VectorDistanceMetric_COSINE, + opts: &IndexCreateOpts{ + Sets: []string{"testset"}, + Storage: &protos.IndexStorage{ + Namespace: GetStrPtr("storage-ns"), + Set: GetStrPtr("storage-set"), + }, + Labels: map[string]string{ + "a": "b", + }, + }, + updateHnsw: &protos.HnswIndexUpdate{ + MaxMemQueueSize: GetUint32Ptr(100), + BatchingParams: &protos.HnswBatchingParams{ + MaxRecords: GetUint32Ptr(10_001), + Interval: GetUint32Ptr(10_002), + }, + CachingParams: &protos.HnswCachingParams{ + MaxEntries: GetUint64Ptr(10_003), + Expiry: GetUint64Ptr(10_004), + }, + HealerParams: &protos.HnswHealerParams{ + MaxScanRatePerNode: GetUint32Ptr(10_005), + MaxScanPageSize: GetUint32Ptr(10_006), + ReindexPercent: GetFloat32Ptr(51), + Schedule: GetStrPtr("0 0 0 25 12 ?"), + Parallelism: GetUint32Ptr(1), + }, + MergeParams: &protos.HnswIndexMergeParams{ + IndexParallelism: GetUint32Ptr(2), + ReIndexParallelism: GetUint32Ptr(3), + }, + }, + updateLabels: map[string]string{ + "c": "d", + }, + expectedIndex: protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "test", + Name: "index", + }, + Dimensions: uint32(10), + VectorDistanceMetric: protos.VectorDistanceMetric_COSINE, + Type: protos.IndexType_HNSW, + SetFilter: GetStrPtr("testset"), + Field: "vector", + Storage: &protos.IndexStorage{ + Namespace: GetStrPtr("storage-ns"), + Set: GetStrPtr("storage-set"), + }, + Labels: map[string]string{ + "a": "b", + "c": "d", + }, + Params: &protos.IndexDefinition_HnswParams{ + HnswParams: &protos.HnswParams{ + MaxMemQueueSize: GetUint32Ptr(100), + BatchingParams: &protos.HnswBatchingParams{ + MaxRecords: GetUint32Ptr(10_001), + Interval: GetUint32Ptr(10_002), + }, + CachingParams: &protos.HnswCachingParams{ + MaxEntries: GetUint64Ptr(10_003), + Expiry: GetUint64Ptr(10_004), + }, + HealerParams: &protos.HnswHealerParams{ + MaxScanRatePerNode: GetUint32Ptr(10_005), + MaxScanPageSize: GetUint32Ptr(10_006), + ReindexPercent: GetFloat32Ptr(51), + Schedule: GetStrPtr("0 0 0 25 12 ?"), + Parallelism: GetUint32Ptr(1), + }, + MergeParams: &protos.HnswIndexMergeParams{ + IndexParallelism: GetUint32Ptr(2), + ReIndexParallelism: GetUint32Ptr(3), + }, + }, + }, + }, + }, + } + + for _, tc := range testcases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + err := suite.AvsClient.IndexCreate(ctx, tc.namespace, tc.indexName, tc.vectorField, tc.dimension, tc.vectorDistanceMetric, tc.opts) + suite.NoError(err) + + if err != nil { + return + } + + defer suite.AvsClient.IndexDrop(ctx, tc.namespace, tc.indexName) + + index, err := suite.AvsClient.IndexGet(ctx, tc.namespace, tc.indexName, false) + suite.NoError(err) + + if err != nil { + return + } + + err = suite.AvsClient.IndexUpdate(ctx, tc.namespace, tc.indexName, tc.updateLabels, tc.updateHnsw) + suite.NoError(err) + + if err != nil { + return + } + + time.Sleep(time.Second * 3) + + index, err = suite.AvsClient.IndexGet(ctx, tc.namespace, tc.indexName, false) + suite.NoError(err) + + suite.EqualExportedValues(tc.expectedIndex, *index) + }) + } +} + // func (suite *SingleNodeTestSuite) TestIsIndexed() { // records := []struct { // namespace string @@ -436,6 +615,72 @@ func (suite *SingleNodeTestSuite) TestIndexCreate() { // } // } +func (suite *SingleNodeTestSuite) TestIndexGetStatus() { + testcases := []struct { + name string + namespace string + indexName string + vectorField string + dimension uint32 + vectorDistanceMetric protos.VectorDistanceMetric + opts *IndexCreateOpts + expectedStatus protos.IndexStatusResponse + }{ + { + name: "basic", + namespace: "test", + indexName: "index", + vectorField: "vector", + dimension: 10, + vectorDistanceMetric: protos.VectorDistanceMetric_SQUARED_EUCLIDEAN, + opts: nil, + expectedStatus: protos.IndexStatusResponse{UnmergedRecordCount: 0}, + }, + } + + for _, tc := range testcases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + err := suite.AvsClient.IndexCreate(ctx, tc.namespace, tc.indexName, tc.vectorField, tc.dimension, tc.vectorDistanceMetric, tc.opts) + suite.NoError(err) + + if err != nil { + return + } + + defer suite.AvsClient.IndexDrop(ctx, tc.namespace, tc.indexName) + + status, err := suite.AvsClient.IndexGetStatus(ctx, tc.namespace, tc.indexName) + suite.NoError(err) + + if err != nil { + return + } + + suite.EqualExportedValues(tc.expectedStatus, *status) + }) + } +} + +func (suite *SingleNodeTestSuite) TestIndexGCInvalidVertices() { + ctx := context.Background() + namespace := "test" + indexName := "index" + vectorField := "vector" + dimension := uint32(10) + vectorDistanceMetric := protos.VectorDistanceMetric_SQUARED_EUCLIDEAN + opts := &IndexCreateOpts{} + + err := suite.AvsClient.IndexCreate(ctx, namespace, indexName, vectorField, dimension, vectorDistanceMetric, opts) + suite.NoError(err) + + defer suite.AvsClient.IndexDrop(ctx, namespace, indexName) + + err = suite.AvsClient.GcInvalidVertices(ctx, namespace, indexName, time.Now()) + suite.NoError(err) +} + func (suite *SingleNodeTestSuite) TestFailsToInsertAlreadyExists() { ctx := context.Background() key := getUniqueKey() @@ -791,6 +1036,31 @@ func (suite *SingleNodeTestSuite) TestUserDelete() { suite.Error(err) } +func (suite *SingleNodeTestSuite) TestUserUpdateCredentials() { + suite.SkipIfUserPassAuthDisabled() + + ctx := context.Background() + username := getUniqueUserName() + + err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write"}) + suite.NoError(err) + + err = suite.AvsClient.UpdateCredentials(ctx, username, "new-password") + suite.NoError(err) + + _, err = NewClient( + ctx, + HostPortSlice{suite.AvsHostPort}, + nil, + suite.AvsLB, + NewCredentialsFromUserPass(username, "new-password"), + suite.AvsTLSConfig, + suite.Logger, + ) + suite.NoError(err) + +} + func (suite *SingleNodeTestSuite) TestUserGrantRoles() { suite.SkipIfUserPassAuthDisabled() @@ -894,6 +1164,91 @@ func (suite *SingleNodeTestSuite) TestNodeIDs() { } } +func (suite *SingleNodeTestSuite) TestConnectedNodeEndpoint() { + testCases := []struct { + name string + nodeId *protos.NodeId + listenerName *string + expectedEndpoints *protos.ServerEndpoint + expectedErrMsg *string + }{ + { + name: "nil-node", + nodeId: nil, + expectedEndpoints: &protos.ServerEndpoint{ + Address: "localhost", + Port: 10000, + IsTls: false, + }, + }, + { + name: "node id DNE", + nodeId: &protos.NodeId{ + Id: 1, + }, + expectedErrMsg: GetStrPtr("failed to get connected endpoint"), + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + endpoint, err := suite.AvsClient.ConnectedNodeEndpoint(ctx, tc.nodeId) + + if tc.expectedErrMsg != nil { + suite.Error(err) + suite.Contains(err.Error(), *tc.expectedErrMsg) + return + } else { + suite.NoError(err) + } + + suite.EqualExportedValues(tc.expectedEndpoints, endpoint) + }) + } +} + +func (suite *SingleNodeTestSuite) TestClusteringState() { + testCases := []struct { + name string + nodeId *protos.NodeId + listenerName *string + expectedErrMsg *string + }{ + { + name: "nil-node", + nodeId: nil, + }, + { + name: "node id DNE", + nodeId: &protos.NodeId{ + Id: 1, + }, + expectedErrMsg: GetStrPtr("failed to get clustering state"), + }, + } + + for _, tc := range testCases { + suite.T().Run(tc.name, func(t *testing.T) { + ctx := context.Background() + state, err := suite.AvsClient.ClusteringState(ctx, tc.nodeId) + + if tc.expectedErrMsg != nil { + suite.Error(err) + suite.Contains(err.Error(), *tc.expectedErrMsg) + return + } else { + suite.NoError(err) + } + + // Simple test. Ideally we will have a better check in the future. + // Does not test cluster-id + suite.True(state.IsInCluster) + suite.Len(state.Members, 1) + }) + } +} + func (suite *SingleNodeTestSuite) TestClusterEndpoints() { testCases := []struct { name string diff --git a/testutils.go b/testutils.go index 8aebcf5..c70c4f0 100644 --- a/testutils.go +++ b/testutils.go @@ -82,6 +82,16 @@ func GetUint32Ptr(i int) *uint32 { return &ptr } +func GetUint64Ptr(i int) *uint64 { + ptr := uint64(i) + return &ptr +} + +func GetFloat32Ptr(i float32) *float32 { + ptr := float32(i) + return &ptr +} + func GetBoolPtr(b bool) *bool { ptr := b return &ptr From 6156272c1f4266d5a55fdec12e05cf00a96f0f2f Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Mon, 16 Sep 2024 12:17:02 -0700 Subject: [PATCH 06/42] cleanup client --- integration_single_node_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/integration_single_node_test.go b/integration_single_node_test.go index 13e68e2..dbd3842 100644 --- a/integration_single_node_test.go +++ b/integration_single_node_test.go @@ -1048,7 +1048,7 @@ func (suite *SingleNodeTestSuite) TestUserUpdateCredentials() { err = suite.AvsClient.UpdateCredentials(ctx, username, "new-password") suite.NoError(err) - _, err = NewClient( + c, err := NewClient( ctx, HostPortSlice{suite.AvsHostPort}, nil, @@ -1059,6 +1059,7 @@ func (suite *SingleNodeTestSuite) TestUserUpdateCredentials() { ) suite.NoError(err) + c.Close() } func (suite *SingleNodeTestSuite) TestUserGrantRoles() { From d071716396e2d9e18c6712f5562aeea5d644fe85 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Mon, 16 Sep 2024 12:38:20 -0700 Subject: [PATCH 07/42] rename channel to connection --- client.go | 68 +++---- channel_provider.go => connection_provider.go | 170 +++++++++--------- 2 files changed, 119 insertions(+), 119 deletions(-) rename channel_provider.go => connection_provider.go (75%) diff --git a/client.go b/client.go index 27a34f6..40933ee 100644 --- a/client.go +++ b/client.go @@ -31,8 +31,8 @@ const ( // Client is a client for managing Aerospike Vector Indexes. type Client struct { - logger *slog.Logger - channelProvider *channelProvider + logger *slog.Logger + connectionProvider *connectionProvider } // NewClient creates a new Client instance. @@ -64,7 +64,7 @@ func NewClient( logger = logger.WithGroup("avs") logger.Info("creating new client") - channelProvider, err := newChannelProvider( + connectionProvider, err := newConnectionProvider( ctx, seeds, listenerName, @@ -74,13 +74,13 @@ func NewClient( logger, ) if err != nil { - logger.Error("failed to create channel provider", slog.Any("error", err)) + logger.Error("failed to create connection provider", slog.Any("error", err)) return nil, NewAVSErrorFromGrpc("failed to connect to server", err) } return &Client{ - logger: logger, - channelProvider: channelProvider, + logger: logger, + connectionProvider: connectionProvider, }, nil } @@ -91,7 +91,7 @@ func NewClient( // error: An error if the closure fails, otherwise nil. func (c *Client) Close() error { c.logger.Info("Closing client") - return c.channelProvider.Close() + return c.connectionProvider.Close() } func (c *Client) put( @@ -108,7 +108,7 @@ func (c *Client) put( slog.Any("key", key), ) - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { logger.Error(failedToInsertRecord, slog.Any("error", err)) return NewAVSError(failedToInsertRecord, err) @@ -259,7 +259,7 @@ func (c *Client) Get(ctx context.Context, ) logger.DebugContext(ctx, "getting record") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { logger.Error(failedToGetRecord, slog.Any("error", err)) return nil, NewAVSError(failedToGetRecord, err) @@ -306,7 +306,7 @@ func (c *Client) Delete(ctx context.Context, namespace string, set *string, key ) logger.DebugContext(ctx, "deleting record") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { logger.Error(failedToDeleteRecord, slog.Any("error", err)) return NewAVSError(failedToDeleteRecord, err) @@ -356,7 +356,7 @@ func (c *Client) Exists( ) logger.DebugContext(ctx, "checking if record exists") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { logger.Error(failedToCheckRecordExists, slog.Any("error", err)) return false, NewAVSError(failedToCheckRecordExists, err) @@ -409,7 +409,7 @@ func (c *Client) IsIndexed( ) logger.DebugContext(ctx, "checking if record is indexed") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { logger.Error(failedToCheckIsIndexed, slog.Any("error", err)) return false, NewAVSError(failedToCheckIsIndexed, err) @@ -453,7 +453,7 @@ func (c *Client) vectorSearch(ctx context.Context, logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) logger.DebugContext(ctx, "searching for vector") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to search for vector" logger.Error(msg, slog.Any("error", err)) @@ -623,7 +623,7 @@ func (c *Client) WaitForIndexCompletion( logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) logger.DebugContext(ctx, "waiting for index completion") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { logger.Error("failed to wait for index completion", slog.Any("error", err)) return err @@ -775,7 +775,7 @@ func (c *Client) IndexCreateFromIndexDef( logger := c.logger.With(slog.Any("definition", indexDef)) logger.DebugContext(ctx, "creating index from definition") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to create index from definition" logger.Error(msg, slog.Any("error", err)) @@ -825,7 +825,7 @@ func (c *Client) IndexUpdate( logger.DebugContext(ctx, "updating index") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to update index" logger.Error(msg, slog.Any("error", err)) @@ -870,7 +870,7 @@ func (c *Client) IndexDrop(ctx context.Context, namespace, indexName string) err logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) logger.DebugContext(ctx, "dropping index") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to drop index" logger.Error(msg, slog.Any("error", err)) @@ -913,7 +913,7 @@ func (c *Client) IndexDrop(ctx context.Context, namespace, indexName string) err func (c *Client) IndexList(ctx context.Context, applyDefaults bool) (*protos.IndexDefinitionList, error) { c.logger.DebugContext(ctx, "listing indexes") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to get indexes" @@ -960,7 +960,7 @@ func (c *Client) IndexGet( logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) logger.DebugContext(ctx, "getting index") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to get index" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1003,7 +1003,7 @@ func (c *Client) IndexGetStatus(ctx context.Context, namespace, indexName string logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) logger.DebugContext(ctx, "getting index status") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to get index status" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1045,7 +1045,7 @@ func (c *Client) GcInvalidVertices(ctx context.Context, namespace, indexName str logger.DebugContext(ctx, "garbage collection invalid vertices") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to garbage collect invalid vertices" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1088,7 +1088,7 @@ func (c *Client) CreateUser(ctx context.Context, username, password string, role logger := c.logger.With(slog.String("username", username), slog.Any("roles", roles)) logger.DebugContext(ctx, "creating user") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to create user" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1127,7 +1127,7 @@ func (c *Client) UpdateCredentials(ctx context.Context, username, password strin logger := c.logger.With(slog.String("username", username)) logger.DebugContext(ctx, "updating user credentials") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to update user credentials" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1164,7 +1164,7 @@ func (c *Client) DropUser(ctx context.Context, username string) error { logger := c.logger.With(slog.String("username", username)) logger.DebugContext(ctx, "dropping user") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to drop user" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1202,7 +1202,7 @@ func (c *Client) GetUser(ctx context.Context, username string) (*protos.User, er logger := c.logger.With(slog.String("username", username)) logger.DebugContext(ctx, "getting user") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to get user" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1238,7 +1238,7 @@ func (c *Client) GetUser(ctx context.Context, username string) (*protos.User, er func (c *Client) ListUsers(ctx context.Context) (*protos.ListUsersResponse, error) { c.logger.DebugContext(ctx, "listing users") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to list users" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1272,7 +1272,7 @@ func (c *Client) GrantRoles(ctx context.Context, username string, roles []string logger := c.logger.With(slog.String("username", username), slog.Any("roles", roles)) logger.DebugContext(ctx, "granting user roles") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to grant user roles" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1311,7 +1311,7 @@ func (c *Client) RevokeRoles(ctx context.Context, username string, roles []strin logger := c.logger.With(slog.String("username", username), slog.Any("roles", roles)) logger.DebugContext(ctx, "revoking user roles") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to revoke user roles" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1348,7 +1348,7 @@ func (c *Client) RevokeRoles(ctx context.Context, username string, roles []strin func (c *Client) ListRoles(ctx context.Context) (*protos.ListRolesResponse, error) { c.logger.DebugContext(ctx, "listing roles") - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to list roles" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -1382,7 +1382,7 @@ func (c *Client) ListRoles(ctx context.Context) (*protos.ListRolesResponse, erro func (c *Client) NodeIDs(ctx context.Context) []*protos.NodeId { c.logger.DebugContext(ctx, "getting cluster info") - ids := c.channelProvider.GetNodeIDs() + ids := c.connectionProvider.GetNodeIDs() nodeIDs := make([]*protos.NodeId, len(ids)) for i, id := range ids { @@ -1550,10 +1550,10 @@ func (c *Client) About(ctx context.Context, nodeID *protos.NodeId) (*protos.Abou func (c *Client) getConnection(nodeID *protos.NodeId) (*connection, error) { if nodeID == nil { - return c.channelProvider.GetSeedConn() + return c.connectionProvider.GetSeedConn() } - return c.channelProvider.GetNodeConn(nodeID.GetId()) + return c.connectionProvider.GetNodeConn(nodeID.GetId()) } // waitForIndexCreation waits for an index to be created and blocks until it is. @@ -1565,7 +1565,7 @@ func (c *Client) waitForIndexCreation(ctx context.Context, ) error { logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to wait for index creation" logger.Error(msg, slog.Any("error", err)) @@ -1616,7 +1616,7 @@ func (c *Client) waitForIndexCreation(ctx context.Context, func (c *Client) waitForIndexDrop(ctx context.Context, namespace, indexName string, waitInterval time.Duration) error { logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) - conn, err := c.channelProvider.GetRandomConn() + conn, err := c.connectionProvider.GetRandomConn() if err != nil { msg := "failed to wait for index deletion" logger.Error(msg, slog.Any("error", err)) diff --git a/channel_provider.go b/connection_provider.go similarity index 75% rename from channel_provider.go rename to connection_provider.go index 6030644..4632393 100644 --- a/channel_provider.go +++ b/connection_provider.go @@ -1,4 +1,4 @@ -// Package avs provides a channel provider for connecting to Aerospike servers. +// Package avs provides a connection provider for connecting to Aerospike servers. package avs import ( @@ -21,7 +21,7 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) -var errChannelProviderClosed = errors.New("channel provider is closed") +var errConnectionProviderClosed = errors.New("connection provider is closed") // connection represents a gRPC client connection and all the clients (stubs) // for the various AVS services. It's main purpose to remove the need to create @@ -50,27 +50,27 @@ func newConnection(conn *grpc.ClientConn) *connection { } } -// channelAndEndpoints represents a combination of a gRPC client connection and server endpoints. -type channelAndEndpoints struct { +// connectionAndEndpoints represents a combination of a gRPC client connection and server endpoints. +type connectionAndEndpoints struct { conn *connection endpoints *protos.ServerEndpointList } -// newConnAndEndpoints creates a new channelAndEndpoints instance. -func newConnAndEndpoints(channel *connection, endpoints *protos.ServerEndpointList) *channelAndEndpoints { - return &channelAndEndpoints{ - conn: channel, +// newConnAndEndpoints creates a new connectionAndEndpoints instance. +func newConnAndEndpoints(conn *connection, endpoints *protos.ServerEndpointList) *connectionAndEndpoints { + return &connectionAndEndpoints{ + conn: conn, endpoints: endpoints, } } -// channelProvider is responsible for managing gRPC client connections to +// connectionProvider is responsible for managing gRPC client connections to // Aerospike servers. // //nolint:govet // We will favor readability over field alignment -type channelProvider struct { +type connectionProvider struct { logger *slog.Logger - nodeConns map[uint64]*channelAndEndpoints + nodeConns map[uint64]*connectionAndEndpoints seedConns []*connection tlsConfig *tls.Config seeds HostPortSlice @@ -84,8 +84,8 @@ type channelProvider struct { closed atomic.Bool } -// newChannelProvider creates a new channelProvider instance. -func newChannelProvider( +// newConnectionProvider creates a new connectionProvider instance. +func newConnectionProvider( ctx context.Context, seeds HostPortSlice, listenerName *string, @@ -93,7 +93,7 @@ func newChannelProvider( credentials *UserPassCredentials, tlsConfig *tls.Config, logger *slog.Logger, -) (*channelProvider, error) { +) (*connectionProvider, error) { // Initialize the logger. logger = logger.WithGroup("cp") @@ -119,9 +119,9 @@ func newChannelProvider( } } - // Create the channelProvider instance. - cp := &channelProvider{ - nodeConns: make(map[uint64]*channelAndEndpoints), + // Create the connectionProvider instance. + cp := &connectionProvider{ + nodeConns: make(map[uint64]*connectionAndEndpoints), seeds: seeds, listenerName: listenerName, isLoadBalancer: isLoadBalancer, @@ -148,7 +148,7 @@ func newChannelProvider( // Start the tend routine if load balancing is disabled. if !isLoadBalancer { - cp.updateClusterChannels(ctx) // We want at least one tend to occur before we return + cp.updateClusterConns(ctx) // We want at least one tend to occur before we return cp.logger.Debug("starting tend routine") go cp.tend(context.Background()) // Might add a tend specific timeout in the future? @@ -159,8 +159,8 @@ func newChannelProvider( return cp, nil } -// Close closes the channelProvider and releases all resources. -func (cp *channelProvider) Close() error { +// Close closes the connectionProvider and releases all resources. +func (cp *connectionProvider) Close() error { if !cp.isLoadBalancer { cp.stopTendChan <- struct{}{} <-cp.stopTendChan @@ -172,16 +172,16 @@ func (cp *channelProvider) Close() error { cp.token.Close() } - for _, channel := range cp.seedConns { - err := channel.grpcConn.Close() + for _, conn := range cp.seedConns { + err := conn.grpcConn.Close() if err != nil { if firstErr == nil { firstErr = err } - cp.logger.Error("failed to close seed channel", + cp.logger.Error("failed to close seed connection", slog.Any("error", err), - slog.String("seed", channel.grpcConn.Target()), + slog.String("seed", conn.grpcConn.Target()), ) } } @@ -193,7 +193,7 @@ func (cp *channelProvider) Close() error { firstErr = err } - cp.logger.Error("failed to close node channel", + cp.logger.Error("failed to close node connection", slog.Any("error", err), slog.String("node", conn.conn.grpcConn.Target()), ) @@ -207,14 +207,14 @@ func (cp *channelProvider) Close() error { } // GetSeedConn returns a gRPC client connection to a seed node. -func (cp *channelProvider) GetSeedConn() (*connection, error) { +func (cp *connectionProvider) GetSeedConn() (*connection, error) { if cp.closed.Load() { - cp.logger.Warn("ChannelProvider is closed, cannot get channel") - return nil, errChannelProviderClosed + cp.logger.Warn("ConnectionProvider is closed, cannot get connection") + return nil, errConnectionProviderClosed } if len(cp.seedConns) == 0 { - msg := "no seed channels found" + msg := "no seed connections found" cp.logger.Warn(msg) return nil, errors.New(msg) @@ -227,69 +227,69 @@ func (cp *channelProvider) GetSeedConn() (*connection, error) { // GetRandomConn returns a gRPC client connection to an Aerospike server. If // isLoadBalancer is enabled, it will return the seed connection. -func (cp *channelProvider) GetRandomConn() (*connection, error) { +func (cp *connectionProvider) GetRandomConn() (*connection, error) { if cp.closed.Load() { - cp.logger.Warn("ChannelProvider is closed, cannot get channel") - return nil, errors.New("ChannelProvider is closed") + cp.logger.Warn("ConnectionProvider is closed, cannot get connection") + return nil, errors.New("ConnectionProvider is closed") } if cp.isLoadBalancer { - cp.logger.Debug("load balancer is enabled, using seed channel") + cp.logger.Debug("load balancer is enabled, using seed connection") return cp.GetSeedConn() } cp.nodeConnsLock.RLock() defer cp.nodeConnsLock.RUnlock() - discoverdChannels := make([]*channelAndEndpoints, len(cp.nodeConns)) + discoverdConns := make([]*connectionAndEndpoints, len(cp.nodeConns)) i := 0 - for _, channel := range cp.nodeConns { - discoverdChannels[i] = channel + for _, conn := range cp.nodeConns { + discoverdConns[i] = conn i++ } - if len(discoverdChannels) == 0 { - cp.logger.Warn("no node channels found, using seed channel") + if len(discoverdConns) == 0 { + cp.logger.Warn("no node connections found, using seed connection") return cp.GetSeedConn() } - idx := rand.Intn(len(discoverdChannels)) //nolint:gosec // Security is not an issue here + idx := rand.Intn(len(discoverdConns)) //nolint:gosec // Security is not an issue here - return discoverdChannels[idx].conn, nil + return discoverdConns[idx].conn, nil } // GetNodeConn returns a gRPC client connection to a specific node. If the node // ID cannot be found an error is returned. -func (cp *channelProvider) GetNodeConn(nodeID uint64) (*connection, error) { +func (cp *connectionProvider) GetNodeConn(nodeID uint64) (*connection, error) { if cp.closed.Load() { - cp.logger.Warn("ChannelProvider is closed, cannot get channel") - return nil, errors.New("ChannelProvider is closed") + cp.logger.Warn("ConnectionProvider is closed, cannot get connection") + return nil, errors.New("ConnectionProvider is closed") } if cp.isLoadBalancer { - cp.logger.Error("load balancer is enabled, using seed channel") - return nil, errors.New("load balancer is enabled, cannot get specific node channel") + cp.logger.Error("load balancer is enabled, using seed connection") + return nil, errors.New("load balancer is enabled, cannot get specific node connection") } cp.nodeConnsLock.RLock() defer cp.nodeConnsLock.RUnlock() - channel, ok := cp.nodeConns[nodeID] + conn, ok := cp.nodeConns[nodeID] if !ok { - msg := "channel not found for specified node id" + msg := "connection not found for specified node id" cp.logger.Error(msg, slog.Uint64("node", nodeID)) return nil, errors.New(msg) } - return channel.conn, nil + return conn.conn, nil } // GetNodeIDs returns the node IDs of all nodes discovered during cluster // tending. If tending is disabled (LB true) then no node IDs are returned. -func (cp *channelProvider) GetNodeIDs() []uint64 { +func (cp *connectionProvider) GetNodeIDs() []uint64 { cp.nodeConnsLock.RLock() defer cp.nodeConnsLock.RUnlock() @@ -303,9 +303,9 @@ func (cp *channelProvider) GetNodeIDs() []uint64 { } // connectToSeeds connects to the seed nodes and creates gRPC client connections. -func (cp *channelProvider) connectToSeeds(ctx context.Context) error { +func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { if len(cp.seedConns) != 0 { - msg := "seed channels already exist, close them first" + msg := "seed connections already exist, close them first" cp.logger.Error(msg) return errors.New(msg) @@ -329,7 +329,7 @@ func (cp *channelProvider) connectToSeeds(ctx context.Context) error { grpcConn, err := cp.createGrcpConn(seed) if err != nil { - logger.ErrorContext(ctx, "failed to create channel", slog.Any("error", err)) + logger.ErrorContext(ctx, "failed to create connection", slog.Any("error", err)) return } @@ -401,7 +401,7 @@ func (cp *channelProvider) connectToSeeds(ctx context.Context) error { } // updateNodeConns updates the gRPC client connection for a specific node. -func (cp *channelProvider) updateNodeConns( +func (cp *connectionProvider) updateNodeConns( ctx context.Context, node uint64, endpoints *protos.ServerEndpointList, @@ -424,7 +424,7 @@ func (cp *channelProvider) updateNodeConns( } // checkAndSetClusterID checks if the cluster ID has changed and updates it if necessary. -func (cp *channelProvider) checkAndSetClusterID(clusterID uint64) bool { +func (cp *connectionProvider) checkAndSetClusterID(clusterID uint64) bool { if clusterID != cp.clusterID { cp.clusterID = clusterID return true @@ -434,28 +434,28 @@ func (cp *channelProvider) checkAndSetClusterID(clusterID uint64) bool { } // getTendConns returns all the gRPC client connections for tend operations. -func (cp *channelProvider) getTendConns() []*grpc.ClientConn { +func (cp *connectionProvider) getTendConns() []*grpc.ClientConn { cp.nodeConnsLock.RLock() defer cp.nodeConnsLock.RUnlock() - channels := make([]*grpc.ClientConn, len(cp.seedConns)+len(cp.nodeConns)) + conns := make([]*grpc.ClientConn, len(cp.seedConns)+len(cp.nodeConns)) i := 0 - for _, channel := range cp.seedConns { - channels[i] = channel.grpcConn + for _, conn := range cp.seedConns { + conns[i] = conn.grpcConn i++ } - for _, channel := range cp.nodeConns { - channels[i] = channel.conn.grpcConn + for _, conn := range cp.nodeConns { + conns[i] = conn.conn.grpcConn i++ } - return channels + return conns } // getUpdatedEndpoints retrieves the updated server endpoints from the Aerospike cluster. -func (cp *channelProvider) getUpdatedEndpoints(ctx context.Context) map[uint64]*protos.ServerEndpointList { +func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint64]*protos.ServerEndpointList { conns := cp.getTendConns() endpointsChan := make(chan map[uint64]*protos.ServerEndpointList) endpointsReq := &protos.ClusterNodeEndpointsRequest{ListenerName: cp.listenerName} @@ -478,7 +478,7 @@ func (cp *channelProvider) getUpdatedEndpoints(ctx context.Context) map[uint64]* if !cp.checkAndSetClusterID(clusterID.GetId()) { logger.DebugContext( ctx, - "old cluster ID found, skipping channel discovery", + "old cluster ID found, skipping connection discovery", slog.Uint64("clusterID", clusterID.GetId()), ) @@ -515,12 +515,12 @@ func (cp *channelProvider) getUpdatedEndpoints(ctx context.Context) map[uint64]* } // Checks if the node connections need to be updated and updates them if necessary. -func (cp *channelProvider) checkAndSetNodeConns( +func (cp *connectionProvider) checkAndSetNodeConns( ctx context.Context, newNodeEndpoints map[uint64]*protos.ServerEndpointList, ) { wg := sync.WaitGroup{} - // Find which nodes have a different endpoint list and update their channel + // Find which nodes have a different endpoint list and update their connection for node, newEndpoints := range newNodeEndpoints { wg.Add(1) @@ -535,27 +535,27 @@ func (cp *channelProvider) checkAndSetNodeConns( if ok { if !endpointListEqual(currEndpoints.endpoints, newEndpoints) { - logger.Debug("endpoints for node changed, recreating channel") + logger.Debug("endpoints for node changed, recreating connection") err := currEndpoints.conn.grpcConn.Close() if err != nil { - logger.Warn("failed to close channel", slog.Any("error", err)) + logger.Warn("failed to close connection", slog.Any("error", err)) } // Either this is a new node or its endpoints have changed err = cp.updateNodeConns(ctx, node, newEndpoints) if err != nil { - logger.Error("failed to create new channel", slog.Any("error", err)) + logger.Error("failed to create new connection", slog.Any("error", err)) } } else { cp.logger.Debug("endpoints for node unchanged") } } else { - logger.Debug("new node found, creating new channel") + logger.Debug("new node found, creating new connection") err := cp.updateNodeConns(ctx, node, newEndpoints) if err != nil { - logger.Error("failed to create new channel", slog.Any("error", err)) + logger.Error("failed to create new connection", slog.Any("error", err)) } } }(node, newEndpoints) @@ -566,16 +566,16 @@ func (cp *channelProvider) checkAndSetNodeConns( // removeDownNodes removes the gRPC client connections for nodes in nodeConns // that aren't apart of newNodeEndpoints -func (cp *channelProvider) removeDownNodes(newNodeEndpoints map[uint64]*protos.ServerEndpointList) { +func (cp *connectionProvider) removeDownNodes(newNodeEndpoints map[uint64]*protos.ServerEndpointList) { cp.nodeConnsLock.Lock() defer cp.nodeConnsLock.Unlock() - // The cluster state changed. Remove old channels. - for node, channelEndpoints := range cp.nodeConns { + // The cluster state changed. Remove old connections. + for node, connEndpoints := range cp.nodeConns { if _, ok := newNodeEndpoints[node]; !ok { - err := channelEndpoints.conn.grpcConn.Close() + err := connEndpoints.conn.grpcConn.Close() if err != nil { - cp.logger.Warn("failed to close channel", slog.Uint64("node", node), slog.Any("error", err)) + cp.logger.Warn("failed to close connection", slog.Uint64("node", node), slog.Any("error", err)) } delete(cp.nodeConns, node) @@ -583,23 +583,23 @@ func (cp *channelProvider) removeDownNodes(newNodeEndpoints map[uint64]*protos.S } } -// updateClusterChannels updates the gRPC client connections for the Aerospike +// updateClusterConns updates the gRPC client connections for the Aerospike // cluster if the cluster state has changed. -func (cp *channelProvider) updateClusterChannels(ctx context.Context) { +func (cp *connectionProvider) updateClusterConns(ctx context.Context) { updatedEndpoints := cp.getUpdatedEndpoints(ctx) if updatedEndpoints == nil { - cp.logger.Debug("no new cluster ID found, cluster state is unchanged, skipping channel discovery") + cp.logger.Debug("no new cluster ID found, cluster state is unchanged, skipping connection discovery") return } - cp.logger.Debug("new cluster id found, updating channels") + cp.logger.Debug("new cluster id found, updating connections") cp.checkAndSetNodeConns(ctx, updatedEndpoints) cp.removeDownNodes(updatedEndpoints) } -// tend starts a thread to periodically update the cluster channels. -func (cp *channelProvider) tend(ctx context.Context) { +// tend starts a thread to periodically update the cluster connections. +func (cp *connectionProvider) tend(ctx context.Context) { timer := time.NewTimer(cp.tendInterval) defer timer.Stop() @@ -612,7 +612,7 @@ func (cp *channelProvider) tend(ctx context.Context) { ctx, cancel := context.WithTimeout(ctx, cp.tendInterval) // TODO: make configurable? - cp.updateClusterChannels(ctx) + cp.updateClusterConns(ctx) if err := ctx.Err(); err != nil { cp.logger.Warn("tend context cancelled", slog.Any("error", err)) @@ -673,7 +673,7 @@ func endpointToHostPort(endpoint *protos.ServerEndpoint) *HostPort { // createGrpcConnFromEndpoints creates a gRPC client connection from the first // successful endpoint in endpoints. -func (cp *channelProvider) createGrpcConnFromEndpoints( +func (cp *connectionProvider) createGrpcConnFromEndpoints( endpoints *protos.ServerEndpointList, ) (*grpc.ClientConn, error) { for _, endpoint := range endpoints.Endpoints { @@ -693,7 +693,7 @@ func (cp *channelProvider) createGrpcConnFromEndpoints( // createGrcpConn creates a gRPC client connection to a host. This handles adding // credential and configuring tls. -func (cp *channelProvider) createGrcpConn(hostPort *HostPort) (*grpc.ClientConn, error) { +func (cp *connectionProvider) createGrcpConn(hostPort *HostPort) (*grpc.ClientConn, error) { opts := []grpc.DialOption{} if cp.tlsConfig == nil { @@ -724,7 +724,7 @@ func (cp *channelProvider) createGrcpConn(hostPort *HostPort) (*grpc.ClientConn, return conn, nil } -func (cp *channelProvider) createConnFromEndpoints(endpoints *protos.ServerEndpointList) (*connection, error) { +func (cp *connectionProvider) createConnFromEndpoints(endpoints *protos.ServerEndpointList) (*connection, error) { conn, err := cp.createGrpcConnFromEndpoints(endpoints) if err != nil { return nil, err From 08464a1affb78b2e28a58f43e8110e29010764bd Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Mon, 16 Sep 2024 17:01:35 -0700 Subject: [PATCH 08/42] add client unit tests with mocks --- client.go | 23 +- client_test.go | 838 +++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 5 +- go.sum | 29 ++ makefile | 20 +- 5 files changed, 909 insertions(+), 6 deletions(-) create mode 100644 client_test.go diff --git a/client.go b/client.go index 40933ee..17aaa56 100644 --- a/client.go +++ b/client.go @@ -29,10 +29,18 @@ const ( failedToCheckIsIndexed = "failed to check if record is indexed" ) +type connProvider interface { + GetNodeIDs() []uint64 + GetRandomConn() (*connection, error) + GetSeedConn() (*connection, error) + GetNodeConn(id uint64) (*connection, error) + Close() error +} + // Client is a client for managing Aerospike Vector Indexes. type Client struct { logger *slog.Logger - connectionProvider *connectionProvider + connectionProvider connProvider } // NewClient creates a new Client instance. @@ -78,6 +86,13 @@ func NewClient( return nil, NewAVSErrorFromGrpc("failed to connect to server", err) } + return newClient(connectionProvider, logger) +} + +func newClient( + connectionProvider connProvider, + logger *slog.Logger, +) (*Client, error) { return &Client{ logger: logger, connectionProvider: connectionProvider, @@ -314,8 +329,8 @@ func (c *Client) Delete(ctx context.Context, namespace string, set *string, key protoKey, err := protos.ConvertToKey(namespace, set, key) if err != nil { - logger.Error(failedToInsertRecord, slog.Any("error", err)) - return NewAVSError(failedToInsertRecord, err) + logger.Error(failedToDeleteRecord, slog.Any("error", err)) + return NewAVSError(failedToDeleteRecord, err) } getReq := &protos.DeleteRequest{ @@ -325,7 +340,7 @@ func (c *Client) Delete(ctx context.Context, namespace string, set *string, key _, err = conn.transactClient.Delete(ctx, getReq) if err != nil { logger.Error(failedToDeleteRecord, slog.Any("error", err)) - return NewAVSErrorFromGrpc(failedToGetRecord, err) + return NewAVSErrorFromGrpc(failedToDeleteRecord, err) } return nil diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..8f5688d --- /dev/null +++ b/client_test.go @@ -0,0 +1,838 @@ +package avs + +import ( + "context" + "fmt" + "log/slog" + "testing" + "time" + + "github.com/aerospike/avs-client-go/protos" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func TestInsert_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedPutRequest := &protos.PutRequest{ + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + WriteType: protos.WriteType_INSERT_ONLY, + Fields: []*protos.Field{ + { + Name: "field1", + Value: &protos.Value{Value: &protos.Value_StringValue{StringValue: "value1"}}, + }, + }, + IgnoreMemQueueFull: false, + } + + // Set up expectations for transactClient.Put() + mockTransactClient. + EXPECT(). + Put(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). + Do(func(ctx context.Context, in *protos.PutRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedPutRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + recordData := map[string]interface{}{"field1": "value1"} + ignoreMemQueueFull := false + + err = client.Insert(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + assert.NoError(t, err) +} + +func TestInsert_FailsGettingConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + recordData := map[string]interface{}{"field1": "value1"} + ignoreMemQueueFull := false + + err = client.Insert(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToInsertRecord, fmt.Errorf("foo"))) +} + +func TestInsert_FailsConvertingKey(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := struct{}{} + recordData := map[string]interface{}{"field1": "value1"} + ignoreMemQueueFull := false + + err = client.Insert(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToInsertRecord, fmt.Errorf("unsupported key type: struct {}"))) +} + +func TestInsert_FailsConvertingFields(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "key" + recordData := map[string]interface{}{"field1": struct{}{}} + ignoreMemQueueFull := false + + err = client.Insert(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToInsertRecord, fmt.Errorf("error converting field value for key 'field1': unsupported value type: struct {}"))) +} + +func TestInsert_FailsPutRequest(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockTransactClient. + EXPECT(). + Put(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "key" + recordData := map[string]interface{}{"field1": "value1"} + ignoreMemQueueFull := false + + err = client.Insert(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToInsertRecord, fmt.Errorf("foo"))) +} + +func TestUpdate_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock connProvider + mockConnProvider := NewMockconnProvider(ctrl) + // Create a mock transactClient + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + // Create a mock connection + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedPutRequest := &protos.PutRequest{ + // You need to fill this with expected values + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + WriteType: protos.WriteType_UPDATE_ONLY, + Fields: []*protos.Field{ + { + Name: "field1", + Value: &protos.Value{Value: &protos.Value_StringValue{StringValue: "value1"}}, + }, + }, + IgnoreMemQueueFull: false, + } + + // Set up expectations for transactClient.Put() + mockTransactClient. + EXPECT(). + Put(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). + Do(func(ctx context.Context, in *protos.PutRequest, opts ...grpc.CallOption) { + // Optionally, you can assert that req matches expectedPutRequest + assert.Equal(t, expectedPutRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + recordData := map[string]interface{}{"field1": "value1"} + ignoreMemQueueFull := false + + // Call the method under test + err = client.Update(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + // Assert no error occurred + assert.NoError(t, err) +} + +func TestReplace_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock connProvider + mockConnProvider := NewMockconnProvider(ctrl) + // Create a mock transactClient + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + // Create a mock connection + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedPutRequest := &protos.PutRequest{ + // You need to fill this with expected values + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + WriteType: protos.WriteType_UPSERT, + Fields: []*protos.Field{ + { + Name: "field1", + Value: &protos.Value{Value: &protos.Value_StringValue{StringValue: "value1"}}, + }, + }, + IgnoreMemQueueFull: false, + } + + // Set up expectations for transactClient.Put() + mockTransactClient. + EXPECT(). + Put(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). + Do(func(ctx context.Context, in *protos.PutRequest, opts ...grpc.CallOption) { + // Optionally, you can assert that req matches expectedPutRequest + assert.Equal(t, expectedPutRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + recordData := map[string]interface{}{"field1": "value1"} + ignoreMemQueueFull := false + + // Call the method under test + err = client.Upsert(ctx, namespace, set, key, recordData, ignoreMemQueueFull) + + // Assert no error occurred + assert.NoError(t, err) +} + +func TestGet_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedGetRequest := &protos.GetRequest{ + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + Projection: &protos.ProjectionSpec{ + Include: &protos.ProjectionFilter{ + Type: protos.ProjectionType_ALL, + }, + Exclude: &protos.ProjectionFilter{ + Type: protos.ProjectionType_NONE, + }, + }, + } + + protosRecord := &protos.Record{ + Fields: []*protos.Field{ + {Name: "field1", Value: &protos.Value{Value: &protos.Value_StringValue{StringValue: "value1"}}}, + }, + Metadata: &protos.Record_AerospikeMetadata{ + AerospikeMetadata: &protos.AerospikeRecordMetadata{ + Generation: 10, + Expiration: 1, + }, + }, + } + + expTime := AerospikeEpoch.Add(time.Second * 1) + + expectedRecord := &Record{ + Data: map[string]any{ + "field1": "value1", + }, + Generation: 10, + Expiration: &expTime, + } + + mockTransactClient. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(protosRecord, nil). + Do(func(ctx context.Context, in *protos.GetRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedGetRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + + record, err := client.Get(ctx, namespace, set, key, nil, nil) + + assert.NoError(t, err) + assert.Equal(t, expectedRecord, record) +} + +func TestGet_FailsGettingConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + + _, err = client.Get(ctx, namespace, set, key, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToGetRecord, fmt.Errorf("foo"))) +} + +func TestGet_FailsConvertingKey(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := struct{}{} + + _, err = client.Get(ctx, namespace, set, key, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToGetRecord, fmt.Errorf("unsupported key type: struct {}"))) +} + +func TestGet_FailsGetRequest(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + mockTransactClient. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "key" + + _, err = client.Get(ctx, namespace, set, key, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToGetRecord, fmt.Errorf("foo"))) +} + +func TestDelete_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedDeleteRequest := &protos.DeleteRequest{ + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + } + + mockTransactClient. + EXPECT(). + Delete(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). + Do(func(ctx context.Context, in *protos.DeleteRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedDeleteRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + + err = client.Delete(ctx, namespace, set, key) + + assert.NoError(t, err) +} + +func TestDelete_FailsGettingConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + + err = client.Delete(ctx, namespace, set, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToDeleteRecord, fmt.Errorf("foo"))) +} + +func TestDelete_FailsConvertingKey(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := struct{}{} + + err = client.Delete(ctx, namespace, set, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToDeleteRecord, fmt.Errorf("unsupported key type: struct {}"))) +} + +func TestDelete_FailsDeleteRequest(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + mockTransactClient. + EXPECT(). + Delete(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "key" + + err = client.Delete(ctx, namespace, set, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToDeleteRecord, fmt.Errorf("foo"))) +} + +func TestExists_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedExistsRequest := &protos.ExistsRequest{ + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + } + + mockTransactClient. + EXPECT(). + Exists(gomock.Any(), gomock.Any()). + Return(&protos.Boolean{ + Value: true, + }, nil). + Do(func(ctx context.Context, in *protos.ExistsRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedExistsRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + + exists, err := client.Exists(ctx, namespace, set, key) + + assert.NoError(t, err) + assert.True(t, exists) +} + +func TestExists_FailsGettingConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + + _, err = client.Exists(ctx, namespace, set, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToCheckRecordExists, fmt.Errorf("foo"))) +} + +func TestExists_FailsConvertingKey(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := struct{}{} + + _, err = client.Exists(ctx, namespace, set, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToCheckRecordExists, fmt.Errorf("unsupported key type: struct {}"))) +} + +func TestExists_FailsDeleteRequest(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + mockTransactClient. + EXPECT(). + Exists(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "key" + + _, err = client.Exists(ctx, namespace, set, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToCheckRecordExists, fmt.Errorf("foo"))) +} diff --git a/go.mod b/go.mod index 28d5ac7..ca47cad 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,11 @@ go 1.21 require ( github.com/aerospike/tools-common-go v0.0.0-20240701164814-36eec593d9c6 + github.com/golang/mock v1.6.0 github.com/stretchr/testify v1.9.0 + go.uber.org/goleak v1.3.0 + go.uber.org/mock v0.4.0 + golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f golang.org/x/net v0.27.0 google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.34.2 @@ -16,7 +20,6 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect - go.uber.org/goleak v1.3.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect diff --git a/go.sum b/go.sum index a922ec1..e58f316 100644 --- a/go.sum +++ b/go.sum @@ -25,6 +25,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEe github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 h1:y3N7Bm7Y9/CtpiVkw/ZWj6lSlDF3F74SfKwfTCer72Q= @@ -49,6 +51,7 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= @@ -61,16 +64,42 @@ go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2L go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto/googleapis/rpc v0.0.0-20240711142825-46eb208f015d h1:JU0iKnSg02Gmb5ZdV8nYsKEKsP6o/FGVWTrw4i1DA9A= google.golang.org/genproto/googleapis/rpc v0.0.0-20240711142825-46eb208f015d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= diff --git a/makefile b/makefile index eb79cec..57ba5b9 100644 --- a/makefile +++ b/makefile @@ -7,6 +7,9 @@ endif GOLANGCI_LINT ?= $(GOBIN)/golangci-lint GOLANGCI_LINT_VERSION ?= v1.54.0 +MOCKGEN ?= $(GOBIN)/mockgen +MOCKGEN_VERSION ?= v0.3.0 + ROOT_DIR = $(shell pwd) PROTO_DIR = $(ROOT_DIR)/protos COVERAGE_DIR = $(ROOT_DIR)/coverage @@ -21,6 +24,21 @@ protos: protoc --proto_path=$(PROTO_DIR) --go_out=$(PROTO_DIR) --go_opt=paths=source_relative \ --go-grpc_out=$(PROTO_DIR) --go-grpc_opt=paths=source_relative $(PROTO_DIR)/*.proto +.PHONY: get-mockgen +get-mockgen: $(MOCKGEN) ## Download mockgen locally if necessary. +$(MOCKGEN): $(GOBIN) + go install go.uber.org/mock/mockgen@$(MOCKGEN_VERSION) + +.PHONY: mocks +mocks: get-mockgen + $(MOCKGEN) --source client.go --destination client_mock.go --package avs + $(MOCKGEN) --source protos/auth_grpc.pb.go --destination protos/auth_grpc_mock.pb.go --package protos + $(MOCKGEN) --source protos/index_grpc.pb.go --destination protos/index_grpc_mock.pb.go --package protos + $(MOCKGEN) --source protos/transact_grpc.pb.go --destination protos/transact_grpc_mock.pb.go --package protos + $(MOCKGEN) --source protos/user-admin_grpc.pb.go --destination protos/user-admin_grpc_mock.pb.go --package protos + $(MOCKGEN) --source protos/vector-db_grpc.pb.go --destination protos/vector-db_grpc_mock.pb.go --package protos + + .PHONY: test test: integration unit @@ -30,7 +48,7 @@ integration: $(GOLEAK) go test -tags=integration -timeout 30m -cover ./... -args -test.gocoverdir=$(COV_INTEGRATION_DIR) .PHONY: unit -unit: +unit: mocks mkdir -p $(COV_UNIT_DIR) || true go test -tags=unit -cover ./... -args -test.gocoverdir=$(COV_UNIT_DIR) From 69f5094db65adc827ddaec31d3d35050dfaff29a Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Mon, 16 Sep 2024 17:05:19 -0700 Subject: [PATCH 09/42] make unit run first --- makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/makefile b/makefile index 57ba5b9..9865f21 100644 --- a/makefile +++ b/makefile @@ -40,7 +40,7 @@ mocks: get-mockgen .PHONY: test -test: integration unit +test: unit integration .PHONY: integration integration: $(GOLEAK) From 0f8bc194a78c101504713bc4ce8c5e0cc702000d Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Tue, 17 Sep 2024 11:29:37 -0700 Subject: [PATCH 10/42] more tests --- client_test.go | 512 +++++++++++++++++++++++++++++++++++++++++++++++++ makefile | 1 + testutils.go | 7 +- 3 files changed, 518 insertions(+), 2 deletions(-) diff --git a/client_test.go b/client_test.go index 8f5688d..38997d5 100644 --- a/client_test.go +++ b/client_test.go @@ -3,6 +3,7 @@ package avs import ( "context" "fmt" + "io" "log/slog" "testing" "time" @@ -836,3 +837,514 @@ func TestExists_FailsDeleteRequest(t *testing.T) { assert.ErrorAs(t, err, &avsError) assert.Equal(t, avsError, NewAVSError(failedToCheckRecordExists, fmt.Errorf("foo"))) } + +func TestIsIndexed_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedIsIndexedRequest := &protos.IsIndexedRequest{ + Key: &protos.Key{ + Namespace: "testNamespace", + Set: nil, + Value: &protos.Key_StringValue{ + StringValue: "testKey", + }, + }, + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + mockTransactClient. + EXPECT(). + IsIndexed(gomock.Any(), gomock.Any()). + Return(&protos.Boolean{ + Value: true, + }, nil). + Do(func(ctx context.Context, in *protos.IsIndexedRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedIsIndexedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + indexName := "testIndex" + + exists, err := client.IsIndexed(ctx, namespace, set, indexName, key) + + assert.NoError(t, err) + assert.True(t, exists) +} + +func TestIsIndexed_FailsGettingConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "testKey" + indexName := "testIndex" + + _, err = client.IsIndexed(ctx, namespace, set, indexName, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToCheckIsIndexed, fmt.Errorf("foo"))) +} + +func TestIsIndexed_FailsConvertingKey(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := struct{}{} + indexName := "testIndex" + + _, err = client.IsIndexed(ctx, namespace, set, indexName, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToCheckIsIndexed, fmt.Errorf("unsupported key type: struct {}"))) +} + +func TestIsIndexed_FailsDeleteRequest(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + mockTransactClient. + EXPECT(). + IsIndexed(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + var set *string = nil + key := "key" + indexName := "testIndex" + + _, err = client.IsIndexed(ctx, namespace, set, indexName, key) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError(failedToCheckIsIndexed, fmt.Errorf("foo"))) +} + +func TestVectorSearchFloat32_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + mockVectorSearchClient := protos.NewMockTransactService_VectorSearchClient(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + // Prepare the expected PutRequest + expectedVectorSearchFloat32Request := &protos.VectorSearchRequest{ + Index: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + QueryVector: &protos.Vector{ + Data: &protos.Vector_FloatData{ + FloatData: &protos.FloatData{ + Value: []float32{1.0, 2.0, 3.0}, + }, + }, + }, + Limit: 7, + SearchParams: &protos.VectorSearchRequest_HnswSearchParams{ + HnswSearchParams: &protos.HnswSearchParams{ + Ef: GetUint32Ptr(8), + }, + }, + Projection: &protos.ProjectionSpec{ + Include: &protos.ProjectionFilter{ + Type: protos.ProjectionType_ALL, + }, + Exclude: &protos.ProjectionFilter{ + Type: protos.ProjectionType_NONE, + }, + }, + } + + mockTransactClient. + EXPECT(). + VectorSearch(gomock.Any(), gomock.Any()). + Return(mockVectorSearchClient, nil). + Do(func(ctx context.Context, in *protos.VectorSearchRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedVectorSearchFloat32Request, in) + }) + + vectorCounter := 0 + mockVectorSearchClient. + EXPECT(). + Recv(). + AnyTimes(). + DoAndReturn( + func() (*protos.Neighbor, error) { + vectorCounter++ + + if vectorCounter == 4 { + return nil, io.EOF + } + + return &protos.Neighbor{ + Key: &protos.Key{ + Namespace: "testNamespace", + Set: GetStrPtr("testSet"), + Value: &protos.Key_StringValue{ + StringValue: fmt.Sprintf("key-%d", vectorCounter), + }, + }, + Record: &protos.Record{ + Fields: []*protos.Field{ + { + Name: "field1", + Value: &protos.Value{ + Value: &protos.Value_StringValue{ + StringValue: "value1", + }, + }, + }, + }, + Metadata: &protos.Record_AerospikeMetadata{ + AerospikeMetadata: &protos.AerospikeRecordMetadata{ + Generation: uint32(vectorCounter), + Expiration: uint32(vectorCounter), + }, + }, + }, + Distance: float32(vectorCounter), + }, nil + }, + ) + + expectedNeighbors := []*Neighbor{ + { + Record: &Record{ + Data: map[string]any{ + "field1": "value1", + }, + Generation: uint32(1), + Expiration: GetTimePtr(AerospikeEpoch.Add(time.Second * 1)), + }, + Set: GetStrPtr("testSet"), + Key: "key-1", + Namespace: "testNamespace", + Distance: float32(1), + }, + { + Record: &Record{ + Data: map[string]any{ + "field1": "value1", + }, + Generation: uint32(2), + Expiration: GetTimePtr(AerospikeEpoch.Add(time.Second * 2)), + }, + Set: GetStrPtr("testSet"), + Key: "key-2", + Namespace: "testNamespace", + Distance: float32(2), + }, + { + Record: &Record{ + Data: map[string]any{ + "field1": "value1", + }, + Generation: uint32(3), + Expiration: GetTimePtr(AerospikeEpoch.Add(time.Second * 3)), + }, + Set: GetStrPtr("testSet"), + Key: "key-3", + Namespace: "testNamespace", + Distance: float32(3), + }, + } + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + vector := []float32{1.0, 2.0, 3.0} + limit := uint32(7) + searchParams := &protos.HnswSearchParams{ + Ef: GetUint32Ptr(8), + } + + neighbors, err := client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) + + assert.NoError(t, err) + assert.Equal(t, len(expectedNeighbors), len(neighbors)) + for i, _ := range neighbors { + assert.EqualExportedValues(t, expectedNeighbors[i], neighbors[i]) + } +} + +func TestVectorSearchFloat32_FailsGettingConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + vector := []float32{1.0, 2.0, 3.0} + limit := uint32(7) + searchParams := &protos.HnswSearchParams{ + Ef: GetUint32Ptr(8), + } + + _, err = client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to search for vector", fmt.Errorf("foo"))) +} + +func TestVectorSearchFloat32_FailsVectorSearch(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockTransactClient. + EXPECT(). + VectorSearch(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + vector := []float32{1.0, 2.0, 3.0} + limit := uint32(7) + searchParams := &protos.HnswSearchParams{ + Ef: GetUint32Ptr(8), + } + + _, err = client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to search for vector", fmt.Errorf("foo"))) +} + +func TestVectorSearchFloat32_FailedToRecvAllNeighbors(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + mockVectorSearchClient := protos.NewMockTransactService_VectorSearchClient(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockTransactClient. + EXPECT(). + VectorSearch(gomock.Any(), gomock.Any()). + Return(mockVectorSearchClient, nil) + + mockVectorSearchClient. + EXPECT(). + Recv(). + AnyTimes(). + DoAndReturn( + func() (*protos.Neighbor, error) { + return nil, fmt.Errorf("foo") + }, + ) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + vector := []float32{1.0, 2.0, 3.0} + limit := uint32(7) + searchParams := &protos.HnswSearchParams{ + Ef: GetUint32Ptr(8), + } + + _, err = client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to receive all neighbors", fmt.Errorf("foo"))) +} + +type unknowKeyValue struct{} + +func (u *unknowKeyValue) isKey_Value() {} + +func TestVectorSearchFloat32_FailedToConvertNeighbor(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockTransactClient := protos.NewMockTransactServiceClient(ctrl) + mockConn := &connection{ + transactClient: mockTransactClient, + } + mockVectorSearchClient := protos.NewMockTransactService_VectorSearchClient(ctrl) + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockTransactClient. + EXPECT(). + VectorSearch(gomock.Any(), gomock.Any()). + Return(mockVectorSearchClient, nil) + + mockVectorSearchClient. + EXPECT(). + Recv(). + AnyTimes(). + DoAndReturn( + func() (*protos.Neighbor, error) { + return &protos.Neighbor{ + Key: &protos.Key{Value: protos.NewMockisKey_Value(ctrl)}, + }, nil + }, + ) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + vector := []float32{1.0, 2.0, 3.0} + limit := uint32(7) + searchParams := &protos.HnswSearchParams{ + Ef: GetUint32Ptr(8), + } + + _, err = client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to convert neighbor", fmt.Errorf("error converting neighbor: unsupported key value type: *protos.MockisKey_Value"))) +} diff --git a/makefile b/makefile index 9865f21..c8a91ac 100644 --- a/makefile +++ b/makefile @@ -35,6 +35,7 @@ mocks: get-mockgen $(MOCKGEN) --source protos/auth_grpc.pb.go --destination protos/auth_grpc_mock.pb.go --package protos $(MOCKGEN) --source protos/index_grpc.pb.go --destination protos/index_grpc_mock.pb.go --package protos $(MOCKGEN) --source protos/transact_grpc.pb.go --destination protos/transact_grpc_mock.pb.go --package protos + $(MOCKGEN) --source protos/types.pb.go --destination protos/types_mock.pb.go --package protos $(MOCKGEN) --source protos/user-admin_grpc.pb.go --destination protos/user-admin_grpc_mock.pb.go --package protos $(MOCKGEN) --source protos/vector-db_grpc.pb.go --destination protos/vector-db_grpc_mock.pb.go --package protos diff --git a/testutils.go b/testutils.go index c70c4f0..5c2cd7d 100644 --- a/testutils.go +++ b/testutils.go @@ -93,8 +93,11 @@ func GetFloat32Ptr(i float32) *float32 { } func GetBoolPtr(b bool) *bool { - ptr := b - return &ptr + return &b +} + +func GetTimePtr(t time.Time) *time.Time { + return &t } func CreateFlagStr(name, value string) string { From c4d72e93e135e1648e895c643ef527b0d7f0e21e Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Tue, 17 Sep 2024 13:31:46 -0700 Subject: [PATCH 11/42] more tests --- client.go | 17 ++-- client_test.go | 216 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 17aaa56..c048c54 100644 --- a/client.go +++ b/client.go @@ -640,8 +640,10 @@ func (c *Client) WaitForIndexCompletion( conn, err := c.connectionProvider.GetRandomConn() if err != nil { - logger.Error("failed to wait for index completion", slog.Any("error", err)) - return err + msg := "failed to wait for index completion" + logger.Error(msg, slog.Any("error", err)) + + return NewAVSError(msg, err) } indexStatusReq := createIndexStatusRequest(namespace, indexName) @@ -655,8 +657,10 @@ func (c *Client) WaitForIndexCompletion( for { indexStatus, err := conn.indexClient.GetStatus(ctx, indexStatusReq) if err != nil { - logger.ErrorContext(ctx, "failed to wait for index completion", slog.Any("error", err)) - return err + msg := "failed to wait for index completion" + logger.ErrorContext(ctx, msg, slog.Any("error", err)) + + return NewAVSError(msg, err) } // We consider the index completed when unmerged record count == 0 for @@ -674,7 +678,7 @@ func (c *Client) WaitForIndexCompletion( } else { logger.DebugContext(ctx, "index not yet completed", slog.Int64("unmerged", unmerged)) - unmergedNotZeroCount-- + unmergedNotZeroCount++ } timer.Reset(waitInterval) @@ -682,8 +686,9 @@ func (c *Client) WaitForIndexCompletion( select { case <-timer.C: case <-ctx.Done(): + msg := "failed to wait for index completion" logger.ErrorContext(ctx, "waiting for index completion canceled") - return ctx.Err() + return NewAVSError(msg, ctx.Err()) } } } diff --git a/client_test.go b/client_test.go index 38997d5..a39f653 100644 --- a/client_test.go +++ b/client_test.go @@ -1348,3 +1348,219 @@ func TestVectorSearchFloat32_FailedToConvertNeighbor(t *testing.T) { assert.ErrorAs(t, err, &avsError) assert.Equal(t, avsError, NewAVSError("failed to convert neighbor", fmt.Errorf("error converting neighbor: unsupported key value type: *protos.MockisKey_Value"))) } + +func TestWaitForIndexCompletion_SuccessAfterZeroCountReturnedTwice(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedGetStatusRequest := &protos.IndexStatusRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + indexStatus := &protos.IndexStatusResponse{ + UnmergedRecordCount: 0, + } + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Times(3). + Return(indexStatus, nil). + Do(func(ctx context.Context, in *protos.IndexStatusRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedGetStatusRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + + err = client.WaitForIndexCompletion(ctx, namespace, indexName, time.Millisecond*1) + + assert.NoError(t, err) +} + +func TestWaitForIndexCompletion_SuccessAfterNonZeroUnmergedCount(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedGetStatusRequest := &protos.IndexStatusRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + indexStatusCount := 0 + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Times(2). + DoAndReturn(func(ctx context.Context, in *protos.IndexStatusRequest, opts ...grpc.CallOption) (*protos.IndexStatusResponse, error) { + assert.Equal(t, expectedGetStatusRequest, in) + indexStatusCount++ + + if indexStatusCount == 1 { + return &protos.IndexStatusResponse{ + UnmergedRecordCount: 1, + }, nil + } else { + return &protos.IndexStatusResponse{ + UnmergedRecordCount: 0, + }, nil + } + + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + + err = client.WaitForIndexCompletion(ctx, namespace, indexName, time.Millisecond*1) + + assert.NoError(t, err) +} + +func TestWaitForIndexCompletion_FailToGetRandomConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + + err = client.WaitForIndexCompletion(ctx, namespace, indexName, time.Millisecond*1) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to wait for index completion", fmt.Errorf("foo"))) +} + +func TestWaitForIndexCompletion_FailGetStatusCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + namespace := "testNamespace" + indexName := "testIndex" + + err = client.WaitForIndexCompletion(ctx, namespace, indexName, time.Millisecond*1) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to wait for index completion", fmt.Errorf("foo"))) +} + +func TestWaitForIndexCompletion_FailTimeout(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + // Set up expectations for connProvider.GetRandomConn() + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Return(nil, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + + namespace := "testNamespace" + indexName := "testIndex" + + err = client.WaitForIndexCompletion(ctx, namespace, indexName, time.Second*1) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to wait for index completion", fmt.Errorf("context deadline exceeded"))) +} From 6ca315f92ece588917ac29189c550a80f932eea3 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Tue, 17 Sep 2024 14:16:13 -0700 Subject: [PATCH 12/42] more tests --- README.md | 2 + client.go | 8 +- client_test.go | 265 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 271 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e69e754..66aabd4 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +[![codecov](https://codecov.io/gh/aerospike/avs-client-go/graph/badge.svg?token=811TWWPW6S)](https://codecov.io/gh/aerospike/avs-client-go) + # Aerospike Vector Search Go Client > :warning: The go client is currently in development. APIs will break in the future! diff --git a/client.go b/client.go index c048c54..798668d 100644 --- a/client.go +++ b/client.go @@ -809,7 +809,7 @@ func (c *Client) IndexCreateFromIndexDef( _, err = conn.indexClient.Create(ctx, indexCreateReq) if err != nil { - msg := "failed to create index" + msg := "failed to create index from definition" logger.Error(msg, slog.Any("error", err)) return NewAVSErrorFromGrpc(msg, err) @@ -828,7 +828,7 @@ func (c *Client) IndexCreateFromIndexDef( // ctx (context.Context): The context for the operation. // namespace (string): The namespace of the index. // name (string): The name of the index. -// metadata (map[string]string): Metadata to update on the index. +// labels (map[string]string): Labels to update on the index. // hnswParams (*protos.HnswIndexUpdate): The HNSW parameters to update. // // Returns: @@ -838,7 +838,7 @@ func (c *Client) IndexUpdate( ctx context.Context, namespace string, indexName string, - metadata map[string]string, + labels map[string]string, hnswParams *protos.HnswIndexUpdate, ) error { logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) @@ -858,7 +858,7 @@ func (c *Client) IndexUpdate( Namespace: namespace, Name: indexName, }, - Labels: metadata, + Labels: labels, Update: &protos.IndexUpdateRequest_HnswIndexUpdate{ HnswIndexUpdate: hnswParams, }, diff --git a/client_test.go b/client_test.go index a39f653..3671949 100644 --- a/client_test.go +++ b/client_test.go @@ -1564,3 +1564,268 @@ func TestWaitForIndexCompletion_FailTimeout(t *testing.T) { assert.ErrorAs(t, err, &avsError) assert.Equal(t, avsError, NewAVSError("failed to wait for index completion", fmt.Errorf("context deadline exceeded"))) } + +func TestIndexCreateFromIndexDef_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Times(2). + Return(mockConn, nil) + + indexDef := &protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + expectedIndexCreateRequest := &protos.IndexCreateRequest{ + Definition: indexDef, + } + + mockIndexClient. + EXPECT(). + Create(gomock.Any(), gomock.Any()). + Return(nil, nil). + Do(func(ctx context.Context, in *protos.IndexCreateRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedIndexCreateRequest, in) + }) + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Return(nil, nil) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + err = client.IndexCreateFromIndexDef(ctx, indexDef) + + assert.NoError(t, err) +} + +func TestIndexCreateFromIndexDef_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + indexDef := &protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + err = client.IndexCreateFromIndexDef(ctx, indexDef) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to create index from definition", fmt.Errorf("foo"))) +} + +func TestIndexCreateFromIndexDef_FailCreateCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + Create(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + indexDef := &protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + err = client.IndexCreateFromIndexDef(ctx, indexDef) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to create index from definition", fmt.Errorf("foo"))) +} + +func TestIndexUpdate_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedIndexUpdateRequest := &protos.IndexUpdateRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + Labels: map[string]string{ + "foo": "bar", + }, + Update: &protos.IndexUpdateRequest_HnswIndexUpdate{ + HnswIndexUpdate: &protos.HnswIndexUpdate{ + MaxMemQueueSize: GetUint32Ptr(10), + }, + }, + } + + mockIndexClient. + EXPECT(). + Update(gomock.Any(), gomock.Any()). + Return(nil, nil). + Do(func(ctx context.Context, in *protos.IndexUpdateRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedIndexUpdateRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + testMetadata := map[string]string{ + "foo": "bar", + } + hnswParams := &protos.HnswIndexUpdate{ + MaxMemQueueSize: GetUint32Ptr(10), + } + + err = client.IndexUpdate(ctx, testNamespace, testIndex, testMetadata, hnswParams) + + assert.NoError(t, err) +} + +func TestIndexUpdate_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + testMetadata := map[string]string{ + "foo": "bar", + } + hnswParams := &protos.HnswIndexUpdate{ + MaxMemQueueSize: GetUint32Ptr(10), + } + + err = client.IndexUpdate(ctx, testNamespace, testIndex, testMetadata, hnswParams) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to update index", fmt.Errorf("foo"))) +} + +func TestIndexUpdate_FailUpdateCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + Update(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + testMetadata := map[string]string{ + "foo": "bar", + } + hnswParams := &protos.HnswIndexUpdate{ + MaxMemQueueSize: GetUint32Ptr(10), + } + + err = client.IndexUpdate(ctx, testNamespace, testIndex, testMetadata, hnswParams) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to update index", fmt.Errorf("bar"))) +} From 1a137cc8e5f0645dec7292a07ff09f1bbcb273f5 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Tue, 17 Sep 2024 15:47:32 -0700 Subject: [PATCH 13/42] more tests --- client_test.go | 238 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 238 insertions(+) diff --git a/client_test.go b/client_test.go index 3671949..9ff2c72 100644 --- a/client_test.go +++ b/client_test.go @@ -10,6 +10,8 @@ import ( "github.com/aerospike/avs-client-go/protos" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" "github.com/stretchr/testify/assert" @@ -1829,3 +1831,239 @@ func TestIndexUpdate_FailUpdateCall(t *testing.T) { assert.ErrorAs(t, err, &avsError) assert.Equal(t, avsError, NewAVSError("failed to update index", fmt.Errorf("bar"))) } + +func TestIndexDrop_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Times(2). + Return(mockConn, nil) + + expectedIndexDropRequest := &protos.IndexDropRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + mockIndexClient. + EXPECT(). + Drop(gomock.Any(), gomock.Any()). + Return(nil, nil). + Do(func(ctx context.Context, in *protos.IndexDropRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedIndexDropRequest, in) + }) + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Return(nil, status.Errorf(codes.NotFound, "foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + err = client.IndexDrop(ctx, testNamespace, testIndex) + + assert.NoError(t, err) +} + +func TestIndexDrop_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + err = client.IndexDrop(ctx, testNamespace, testIndex) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to drop index", fmt.Errorf("foo"))) +} + +func TestIndexDrop_FailDropCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + Drop(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + err = client.IndexDrop(ctx, testNamespace, testIndex) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to drop index", fmt.Errorf("bar"))) +} + +func TestIndexList_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedIndexListRequest := &protos.IndexListRequest{ + ApplyDefaults: GetBoolPtr(true), + } + + expectedIndexDefs := &protos.IndexDefinitionList{ + Indices: []*protos.IndexDefinition{ + { + Id: &protos.IndexId{ + Namespace: "testNamespace0", + Name: "testIndex0", + }, + }, + { + Id: &protos.IndexId{ + Namespace: "testNamespace1", + Name: "testIndex1", + }, + }, + }, + } + + mockIndexClient. + EXPECT(). + List(gomock.Any(), gomock.Any()). + Return(expectedIndexDefs, nil). + Do(func(ctx context.Context, in *protos.IndexListRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedIndexListRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + indexDefs, err := client.IndexList(ctx, true) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedIndexDefs, indexDefs) +} + +func TestIndexList_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + _, err = client.IndexList(ctx, true) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get indexes", fmt.Errorf("foo"))) +} + +func TestIndexList_FailDropCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + List(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.IndexList(ctx, true) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get indexes", fmt.Errorf("bar"))) +} From 3f5711f03c04dad600b54d51f44801cf43330b57 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Wed, 18 Sep 2024 08:44:23 -0700 Subject: [PATCH 14/42] more tests --- client_test.go | 351 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 351 insertions(+) diff --git a/client_test.go b/client_test.go index 9ff2c72..55f1829 100644 --- a/client_test.go +++ b/client_test.go @@ -2067,3 +2067,354 @@ func TestIndexList_FailDropCall(t *testing.T) { assert.ErrorAs(t, err, &avsError) assert.Equal(t, avsError, NewAVSError("failed to get indexes", fmt.Errorf("bar"))) } + +func TestIndexGet_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedIndexListRequest := &protos.IndexGetRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + ApplyDefaults: GetBoolPtr(true), + } + + expectedIndexDefs := &protos.IndexDefinition{ + Id: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + mockIndexClient. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(expectedIndexDefs, nil). + Do(func(ctx context.Context, in *protos.IndexGetRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedIndexListRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + indexDefs, err := client.IndexGet(ctx, testNamespace, testIndex, true) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedIndexDefs, indexDefs) +} + +func TestIndexGet_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + _, err = client.IndexGet(ctx, testNamespace, testIndex, true) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get index", fmt.Errorf("foo"))) +} + +func TestIndexGet_FailDropCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + _, err = client.IndexGet(ctx, testNamespace, testIndex, true) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get index", fmt.Errorf("bar"))) +} + +func TestIndexGetStatus_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedIndexStatusRequest := &protos.IndexStatusRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + } + + expectedIndexStatusResp := &protos.IndexStatusResponse{ + UnmergedRecordCount: 9, + } + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Return(expectedIndexStatusResp, nil). + Do(func(ctx context.Context, in *protos.IndexStatusRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedIndexStatusRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + indexDefs, err := client.IndexGetStatus(ctx, testNamespace, testIndex) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedIndexStatusResp, indexDefs) +} + +func TestIndexGetStatus_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + _, err = client.IndexGetStatus(ctx, testNamespace, testIndex) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get index status", fmt.Errorf("foo"))) +} + +func TestIndexGetStatus_FailDropCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + GetStatus(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + _, err = client.IndexGetStatus(ctx, testNamespace, testIndex) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get index status", fmt.Errorf("bar"))) +} + +func TestIndexGcInvalidVertices_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + cutoffTime := time.Now() + expectedIndexStatusRequest := &protos.GcInvalidVerticesRequest{ + IndexId: &protos.IndexId{ + Namespace: "testNamespace", + Name: "testIndex", + }, + CutoffTimestamp: cutoffTime.Unix(), + } + + mockIndexClient. + EXPECT(). + GcInvalidVertices(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). + Do(func(ctx context.Context, in *protos.GcInvalidVerticesRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedIndexStatusRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + + err = client.GcInvalidVertices(ctx, testNamespace, testIndex, cutoffTime) + + assert.NoError(t, err) +} + +func TestIndexGcInvalidVertices_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + cutoffTime := time.Now() + + err = client.GcInvalidVertices(ctx, testNamespace, testIndex, cutoffTime) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to garbage collect invalid vertices", fmt.Errorf("foo"))) +} + +func TestIndexGcInvalidVertices_FailDropCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockIndexClient := protos.NewMockIndexServiceClient(ctrl) + mockConn := &connection{ + indexClient: mockIndexClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockIndexClient. + EXPECT(). + GcInvalidVertices(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testNamespace := "testNamespace" + testIndex := "testIndex" + cutoffTime := time.Now() + + err = client.GcInvalidVertices(ctx, testNamespace, testIndex, cutoffTime) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to garbage collect invalid vertices", fmt.Errorf("bar"))) +} From f659151914a236a5719bf67fae4f9b4686cc6968 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Wed, 18 Sep 2024 09:58:10 -0700 Subject: [PATCH 15/42] more tests --- client.go | 2 +- client_test.go | 676 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 677 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 798668d..752d6ef 100644 --- a/client.go +++ b/client.go @@ -1268,7 +1268,7 @@ func (c *Client) ListUsers(ctx context.Context) (*protos.ListUsersResponse, erro usersResp, err := conn.userAdminClient.ListUsers(ctx, &emptypb.Empty{}) if err != nil { - msg := "failed to lists users" + msg := "failed to list users" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) return nil, NewAVSErrorFromGrpc(msg, err) diff --git a/client_test.go b/client_test.go index 55f1829..966e44f 100644 --- a/client_test.go +++ b/client_test.go @@ -2418,3 +2418,679 @@ func TestIndexGcInvalidVertices_FailDropCall(t *testing.T) { assert.ErrorAs(t, err, &avsError) assert.Equal(t, avsError, NewAVSError("failed to garbage collect invalid vertices", fmt.Errorf("bar"))) } + +func TestCreateUser_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &protos.AddUserRequest{ + Credentials: &protos.Credentials{ + Username: "testUser", + Credentials: &protos.Credentials_PasswordCredentials{ + PasswordCredentials: &protos.PasswordCredentials{ + Password: "testPass", + }, + }, + }, + Roles: []string{ + "testRole", + }, + } + + mockUserAdminClient. + EXPECT(). + AddUser(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). + Do(func(ctx context.Context, in *protos.AddUserRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testPass := "testPass" + testRoles := []string{"testRole"} + + err = client.CreateUser(ctx, testUser, testPass, testRoles) + + assert.NoError(t, err) +} + +func TestCreateUser_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testPass := "testPass" + testRoles := []string{"testRole"} + + err = client.CreateUser(ctx, testUser, testPass, testRoles) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to create user", fmt.Errorf("foo"))) +} + +func TestCreateUser_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + AddUser(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testPass := "testPass" + testRoles := []string{"testRole"} + + err = client.CreateUser(ctx, testUser, testPass, testRoles) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to create user", fmt.Errorf("bar"))) +} + +func TestUpdateCredentials_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &protos.UpdateCredentialsRequest{ + Credentials: &protos.Credentials{ + Username: "testUser", + Credentials: &protos.Credentials_PasswordCredentials{ + PasswordCredentials: &protos.PasswordCredentials{ + Password: "testPass", + }, + }, + }, + } + + mockUserAdminClient. + EXPECT(). + UpdateCredentials(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). + Do(func(ctx context.Context, in *protos.UpdateCredentialsRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testPass := "testPass" + + err = client.UpdateCredentials(ctx, testUser, testPass) + + assert.NoError(t, err) +} + +func TestUpdateCredentials_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testPass := "testPass" + + err = client.UpdateCredentials(ctx, testUser, testPass) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to update user credentials", fmt.Errorf("foo"))) +} + +func TestUpdateCredentials_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + UpdateCredentials(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testPass := "testPass" + + err = client.UpdateCredentials(ctx, testUser, testPass) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to update user credentials", fmt.Errorf("bar"))) +} + +func TestDropUser_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &protos.DropUserRequest{ + Username: "testUser", + } + + mockUserAdminClient. + EXPECT(). + DropUser(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). + Do(func(ctx context.Context, in *protos.DropUserRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + + err = client.DropUser(ctx, testUser) + + assert.NoError(t, err) +} + +func TestDropUser_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + + err = client.DropUser(ctx, testUser) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to drop user", fmt.Errorf("foo"))) +} + +func TestDropUser_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + DropUser(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + + err = client.DropUser(ctx, testUser) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to drop user", fmt.Errorf("bar"))) +} + +func TestGetUser_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &protos.GetUserRequest{ + Username: "testUser", + } + + expectedUser := &protos.User{ + Username: "testUser", + Roles: []string{ + "testRole", + }, + } + + mockUserAdminClient. + EXPECT(). + GetUser(gomock.Any(), gomock.Any()). + Return(expectedUser, nil). + Do(func(ctx context.Context, in *protos.GetUserRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + + user, err := client.GetUser(ctx, testUser) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedUser, user) +} + +func TestGetUser_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + + _, err = client.GetUser(ctx, testUser) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get user", fmt.Errorf("foo"))) +} + +func TestGetUser_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + GetUser(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + + _, err = client.GetUser(ctx, testUser) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get user", fmt.Errorf("bar"))) +} + +func TestListUsers_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &emptypb.Empty{} + + expectedUsers := &protos.ListUsersResponse{ + Users: []*protos.User{ + { + + Username: "testUser", + Roles: []string{ + "testRole", + }, + }, + }, + } + + mockUserAdminClient. + EXPECT(). + ListUsers(gomock.Any(), gomock.Any()). + Return(expectedUsers, nil). + Do(func(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + user, err := client.ListUsers(ctx) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedUsers, user) +} + +func TestListUsers_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ListUsers(ctx) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to list users", fmt.Errorf("foo"))) +} + +func TestListUsers_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + ListUsers(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ListUsers(ctx) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to list users", fmt.Errorf("bar"))) +} + +func TestRevokeRoles_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &protos.RevokeRolesRequest{ + Username: "testUser", + Roles: []string{"testRole"}, + } + + mockUserAdminClient. + EXPECT(). + RevokeRoles(gomock.Any(), gomock.Any()). + Return(&emptypb.Empty{}, nil). + Do(func(ctx context.Context, in *protos.RevokeRolesRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testRoles := []string{"testRole"} + + err = client.RevokeRoles(ctx, testUser, testRoles) + + assert.NoError(t, err) +} + +func TestRevokeRoles_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testRoles := []string{"testRole"} + + err = client.RevokeRoles(ctx, testUser, testRoles) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to revoke user roles", fmt.Errorf("foo"))) +} + +func TestRevokeRoles_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + RevokeRoles(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + testUser := "testUser" + testRoles := []string{"testRole"} + + err = client.RevokeRoles(ctx, testUser, testRoles) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to revoke user roles", fmt.Errorf("bar"))) +} From 4184ad6f5217a4597a66eff64be8a2cb09926c12 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Wed, 18 Sep 2024 12:37:53 -0700 Subject: [PATCH 16/42] add more tests --- client.go | 2 +- client_test.go | 523 +++++++++++++++++++++++++++++++++++++++++ connection_provider.go | 22 +- makefile | 1 + 4 files changed, 539 insertions(+), 9 deletions(-) diff --git a/client.go b/client.go index 752d6ef..e43b49d 100644 --- a/client.go +++ b/client.go @@ -1378,7 +1378,7 @@ func (c *Client) ListRoles(ctx context.Context) (*protos.ListRolesResponse, erro rolesResp, err := conn.userAdminClient.ListRoles(ctx, &emptypb.Empty{}) if err != nil { - msg := "failed to lists roles" + msg := "failed to list roles" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) return nil, NewAVSErrorFromGrpc(msg, err) diff --git a/client_test.go b/client_test.go index 966e44f..e9efbd2 100644 --- a/client_test.go +++ b/client_test.go @@ -3094,3 +3094,526 @@ func TestRevokeRoles_FailCall(t *testing.T) { assert.ErrorAs(t, err, &avsError) assert.Equal(t, avsError, NewAVSError("failed to revoke user roles", fmt.Errorf("bar"))) } + +func TestListRoles_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + expectedRequest := &emptypb.Empty{} + + expectedRoles := &protos.ListRolesResponse{ + Roles: []*protos.Role{ + { + Id: "testRole", + }, + }, + } + + mockUserAdminClient. + EXPECT(). + ListRoles(gomock.Any(), gomock.Any()). + Return(expectedRoles, nil). + Do(func(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + user, err := client.ListRoles(ctx) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedRoles, user) +} + +func TestListRoles_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserAdminClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserAdminClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ListRoles(ctx) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to list roles", fmt.Errorf("foo"))) +} + +func TestListRoles_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockUserClient := protos.NewMockUserAdminServiceClient(ctrl) + mockConn := &connection{ + userAdminClient: mockUserClient, + } + + mockConnProvider. + EXPECT(). + GetRandomConn(). + Return(mockConn, nil) + + mockUserClient. + EXPECT(). + ListRoles(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ListRoles(ctx) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to list roles", fmt.Errorf("bar"))) +} + +func TestConnectedNodeEndpoint_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockGrpcConn := NewMockGrpcClientConn(ctrl) + mockConn := &connection{ + grpcConn: mockGrpcConn, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + expectedEndpoint := &protos.ServerEndpoint{ + Address: "1.1.1.1", + Port: 3000, + } + + mockGrpcConn. + EXPECT(). + Target(). + Return("1.1.1.1:3000") + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + endpoint, err := client.ConnectedNodeEndpoint(ctx, nil) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedEndpoint, endpoint) +} + +func TestConnectedNodeEndpoint_FailedGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockGrpcConn := NewMockGrpcClientConn(ctrl) + mockConn := &connection{ + grpcConn: mockGrpcConn, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ConnectedNodeEndpoint(ctx, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get connected endpoint", fmt.Errorf("foo"))) +} + +func TestConnectedNodeEndpoint_FailParsePort(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockGrpcConn := NewMockGrpcClientConn(ctrl) + mockConn := &connection{ + grpcConn: mockGrpcConn, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + mockGrpcConn. + EXPECT(). + Target(). + Return("1.1.1.1:aaaa") + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ConnectedNodeEndpoint(ctx, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Contains(t, avsError.Error(), "failed to parse port") +} + +func TestClusteringState_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockClusterInfoClient := protos.NewMockClusterInfoServiceClient(ctrl) + mockConn := &connection{ + clusterInfoClient: mockClusterInfoClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + expectedRequest := &emptypb.Empty{} + + expectedRoles := &protos.ClusteringState{ + IsInCluster: true, + } + + mockClusterInfoClient. + EXPECT(). + GetClusteringState(gomock.Any(), gomock.Any()). + Return(expectedRoles, nil). + Do(func(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + user, err := client.ClusteringState(ctx, nil) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedRoles, user) +} + +func TestClusteringState_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockClusterInfoClient := protos.NewMockClusterInfoServiceClient(ctrl) + mockConn := &connection{ + clusterInfoClient: mockClusterInfoClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ClusteringState(ctx, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get clustering state", fmt.Errorf("foo"))) +} + +func TestClusteringState_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockClusterInfoClient := protos.NewMockClusterInfoServiceClient(ctrl) + mockConn := &connection{ + clusterInfoClient: mockClusterInfoClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + mockClusterInfoClient. + EXPECT(). + GetClusteringState(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.ClusteringState(ctx, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get clustering state", fmt.Errorf("bar"))) +} + +func TestClusterEndpoints_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockClusterInfoClient := protos.NewMockClusterInfoServiceClient(ctrl) + mockConn := &connection{ + clusterInfoClient: mockClusterInfoClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + listenerName := "test-listener" + expectedRequest := &protos.ClusterNodeEndpointsRequest{ + ListenerName: &listenerName, + } + + expectedResp := &protos.ClusterNodeEndpoints{} + + mockClusterInfoClient. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). + Return(expectedResp, nil). + Do(func(ctx context.Context, in *protos.ClusterNodeEndpointsRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + user, err := client.ClusterEndpoints(ctx, nil, &listenerName) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedResp, user) +} + +func TestClusterEndpoints_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockClusterInfoClient := protos.NewMockClusterInfoServiceClient(ctrl) + mockConn := &connection{ + clusterInfoClient: mockClusterInfoClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + listenerName := "test-listener" + + _, err = client.ClusterEndpoints(ctx, nil, &listenerName) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get cluster endpoints", fmt.Errorf("foo"))) +} + +func TestClusterEndpoints_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockClusterInfoClient := protos.NewMockClusterInfoServiceClient(ctrl) + mockConn := &connection{ + clusterInfoClient: mockClusterInfoClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + mockClusterInfoClient. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + listenerName := "test-listener" + _, err = client.ClusterEndpoints(ctx, nil, &listenerName) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to get cluster endpoints", fmt.Errorf("bar"))) +} + +func TestAbout_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockAboutClient := protos.NewMockAboutServiceClient(ctrl) + mockConn := &connection{ + aboutClient: mockAboutClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + expectedRequest := &protos.AboutRequest{} + expectedResp := &protos.AboutResponse{} + + mockAboutClient. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(expectedResp, nil). + Do(func(ctx context.Context, in *protos.AboutRequest, opts ...grpc.CallOption) { + assert.Equal(t, expectedRequest, in) + }) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + user, err := client.About(ctx, nil) + + assert.NoError(t, err) + assert.EqualExportedValues(t, expectedResp, user) +} + +func TestAbout_FailGetConn(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockAboutClient := protos.NewMockAboutServiceClient(ctrl) + mockConn := &connection{ + aboutClient: mockAboutClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, fmt.Errorf("foo")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + + _, err = client.About(ctx, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to make about request", fmt.Errorf("foo"))) +} + +func TestAbout_FailCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConnProvider := NewMockconnProvider(ctrl) + mockAboutClient := protos.NewMockAboutServiceClient(ctrl) + mockConn := &connection{ + aboutClient: mockAboutClient, + } + + mockConnProvider. + EXPECT(). + GetSeedConn(). + Return(mockConn, nil) + + mockAboutClient. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("bar")) + + // Create the client with the mock connProvider + client, err := newClient(mockConnProvider, slog.Default()) + assert.NoError(t, err) + + // Prepare input parameters + ctx := context.Background() + _, err = client.About(ctx, nil) + + var avsError *Error + assert.ErrorAs(t, err, &avsError) + assert.Equal(t, avsError, NewAVSError("failed to make about request", fmt.Errorf("bar"))) +} diff --git a/connection_provider.go b/connection_provider.go index 4632393..1382646 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -23,12 +23,18 @@ import ( var errConnectionProviderClosed = errors.New("connection provider is closed") +type GrpcClientConn interface { + grpc.ClientConnInterface + Target() string + Close() error +} + // connection represents a gRPC client connection and all the clients (stubs) // for the various AVS services. It's main purpose to remove the need to create // multiple clients for the same connection. This follows the documented grpc // best practice of reusing connections. type connection struct { - grpcConn *grpc.ClientConn + grpcConn GrpcClientConn transactClient protos.TransactServiceClient authClient protos.AuthServiceClient userAdminClient protos.UserAdminServiceClient @@ -38,7 +44,7 @@ type connection struct { } // newConnection creates a new connection instance. -func newConnection(conn *grpc.ClientConn) *connection { +func newConnection(conn GrpcClientConn) *connection { return &connection{ grpcConn: conn, transactClient: protos.NewTransactServiceClient(conn), @@ -314,7 +320,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { var authErr error wg := sync.WaitGroup{} - seedGrpcConns := make(chan *grpc.ClientConn) + seedGrpcConns := make(chan GrpcClientConn) cp.seedConns = []*connection{} tokenLock := sync.Mutex{} // Ensures only one thread attempts to update token at a time tokenUpdated := false // Ensures token update only occurs once @@ -434,11 +440,11 @@ func (cp *connectionProvider) checkAndSetClusterID(clusterID uint64) bool { } // getTendConns returns all the gRPC client connections for tend operations. -func (cp *connectionProvider) getTendConns() []*grpc.ClientConn { +func (cp *connectionProvider) getTendConns() []GrpcClientConn { cp.nodeConnsLock.RLock() defer cp.nodeConnsLock.RUnlock() - conns := make([]*grpc.ClientConn, len(cp.seedConns)+len(cp.nodeConns)) + conns := make([]GrpcClientConn, len(cp.seedConns)+len(cp.nodeConns)) i := 0 for _, conn := range cp.seedConns { @@ -464,7 +470,7 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6 for _, conn := range conns { wg.Add(1) - go func(conn *grpc.ClientConn) { + go func(conn GrpcClientConn) { defer wg.Done() logger := cp.logger.With(slog.String("host", conn.Target())) @@ -675,7 +681,7 @@ func endpointToHostPort(endpoint *protos.ServerEndpoint) *HostPort { // successful endpoint in endpoints. func (cp *connectionProvider) createGrpcConnFromEndpoints( endpoints *protos.ServerEndpointList, -) (*grpc.ClientConn, error) { +) (GrpcClientConn, error) { for _, endpoint := range endpoints.Endpoints { if strings.ContainsRune(endpoint.Address, ':') { continue // TODO: Add logging and support for IPv6 @@ -693,7 +699,7 @@ func (cp *connectionProvider) createGrpcConnFromEndpoints( // createGrcpConn creates a gRPC client connection to a host. This handles adding // credential and configuring tls. -func (cp *connectionProvider) createGrcpConn(hostPort *HostPort) (*grpc.ClientConn, error) { +func (cp *connectionProvider) createGrcpConn(hostPort *HostPort) (GrpcClientConn, error) { opts := []grpc.DialOption{} if cp.tlsConfig == nil { diff --git a/makefile b/makefile index c8a91ac..7ed72c1 100644 --- a/makefile +++ b/makefile @@ -32,6 +32,7 @@ $(MOCKGEN): $(GOBIN) .PHONY: mocks mocks: get-mockgen $(MOCKGEN) --source client.go --destination client_mock.go --package avs + $(MOCKGEN) --source connection_provider.go --destination connection_provider_mock.go --package avs $(MOCKGEN) --source protos/auth_grpc.pb.go --destination protos/auth_grpc_mock.pb.go --package protos $(MOCKGEN) --source protos/index_grpc.pb.go --destination protos/index_grpc_mock.pb.go --package protos $(MOCKGEN) --source protos/transact_grpc.pb.go --destination protos/transact_grpc_mock.pb.go --package protos From d0ba6c4ec824388bcf2d6991a586d99317295a71 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Wed, 18 Sep 2024 14:45:00 -0700 Subject: [PATCH 17/42] add more tests --- connection_provider.go | 4 + connection_provider_test.go | 28 +++ protos/utils_test.go | 453 +++++++++++++++++++++++++++++++++++- 3 files changed, 482 insertions(+), 3 deletions(-) create mode 100644 connection_provider_test.go diff --git a/connection_provider.go b/connection_provider.go index 1382646..eb66a1a 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -101,6 +101,10 @@ func newConnectionProvider( logger *slog.Logger, ) (*connectionProvider, error) { // Initialize the logger. + if logger == nil { + logger = slog.Default() + } + logger = logger.WithGroup("cp") // Validate the seeds. diff --git a/connection_provider_test.go b/connection_provider_test.go new file mode 100644 index 0000000..fb5474d --- /dev/null +++ b/connection_provider_test.go @@ -0,0 +1,28 @@ +package avs + +import ( + "context" + "crypto/tls" + "errors" + "log/slog" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewConnectionProvider(t *testing.T) { + seeds := HostPortSlice{} + listenerName := "listener" + isLoadBalancer := false + credentials := &UserPassCredentials{ + username: "admin", + password: "password", + } + tlsConfig := &tls.Config{} + var logger *slog.Logger + + cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, credentials, tlsConfig, logger) + + assert.Nil(t, cp) + assert.Error(t, err, errors.New("seeds cannot be nil or empty")) +} diff --git a/protos/utils_test.go b/protos/utils_test.go index 2aef07f..8df2f55 100644 --- a/protos/utils_test.go +++ b/protos/utils_test.go @@ -171,7 +171,7 @@ func TestConvertToValue(t *testing.T) { testCases := []struct { input any expected *Value - expectedErr error + expectedErr bool }{ { input: "testString", @@ -369,10 +369,48 @@ func TestConvertToValue(t *testing.T) { }, }, }, + { + input: []float32{1, 2}, + expected: &Value{ + Value: &Value_VectorValue{ + VectorValue: &Vector{ + Data: &Vector_FloatData{ + FloatData: &FloatData{ + Value: []float32{1, 2}, + }, + }, + }, + }, + }, + }, + { + input: []bool{true, false}, + expected: &Value{ + Value: &Value_VectorValue{ + VectorValue: &Vector{ + Data: &Vector_BoolData{ + BoolData: &BoolData{ + Value: []bool{true, false}, + }, + }, + }, + }, + }, + }, + { + input: map[int]any{10: struct{}{}}, + expected: nil, + expectedErr: true, + }, + { + input: []any{struct{}{}}, + expected: nil, + expectedErr: true, + }, { input: struct{}{}, // Unsupported type expected: nil, - expectedErr: fmt.Errorf("unsupported value type: struct {}"), + expectedErr: true, }, } @@ -381,11 +419,328 @@ func TestConvertToValue(t *testing.T) { result, err := ConvertToValue(tc.input) assert.Equal(t, tc.expected, result) - assert.Equal(t, tc.expectedErr, err) + + if tc.expectedErr { + assert.Error(t, err) + } }) } } +func TestConvertToMapKey(t *testing.T) { + testCases := []struct { + input any + expected *MapKey + expectedErr error + }{ + { + input: "testString", + expected: &MapKey{ + Value: &MapKey_StringValue{ + StringValue: "testString", + }, + }, + }, + { + input: int32(123), + expected: &MapKey{ + Value: &MapKey_IntValue{ + IntValue: 123, + }, + }, + }, + { + input: int64(123456789), + expected: &MapKey{ + Value: &MapKey_LongValue{ + LongValue: 123456789, + }, + }, + }, + { + input: int(123456789), + expected: &MapKey{ + Value: &MapKey_LongValue{ + LongValue: 123456789, + }, + }, + }, + { + input: []byte{0x01, 0x02, 0x03}, + expected: &MapKey{ + Value: &MapKey_BytesValue{ + BytesValue: []byte{0x01, 0x02, 0x03}, + }, + }, + }, + { + input: struct{}{}, + expected: nil, // Unsupported type + expectedErr: fmt.Errorf("unsupported key type: struct {}"), + }, + } + + for _, tc := range testCases { + result, err := ConvertToMapKey(tc.input) + + assert.Equal(t, tc.expected, result) + assert.Equal(t, tc.expectedErr, err) + } +} + +type mapKeyValueUnknown struct{} + +func (*mapKeyValueUnknown) isMapKey_Value() {} //nolint:revive,stylecheck // Grpc generated + +func TestConvertFromMapKey(t *testing.T) { + testCases := []struct { + input *MapKey + expected any + expectedErr error + }{ + { + input: &MapKey{ + Value: &MapKey_StringValue{ + StringValue: "testString", + }, + }, + expected: "testString", + }, + { + input: &MapKey{ + Value: &MapKey_BytesValue{ + BytesValue: []byte{0x01, 0x02, 0x03}, + }, + }, + expected: []byte{0x01, 0x02, 0x03}, + }, + { + input: &MapKey{ + Value: &MapKey_IntValue{ + IntValue: 123, + }, + }, + expected: int32(123), + }, + { + input: &MapKey{ + Value: &MapKey_LongValue{ + LongValue: 123456789, + }, + }, + expected: int64(123456789), + }, + { + input: &MapKey{ + Value: &mapKeyValueUnknown{}, + }, + expected: nil, + expectedErr: fmt.Errorf("unsupported map key value type: *protos.mapKeyValueUnknown"), + }, + } + + for _, tc := range testCases { + result, err := ConvertFromMapKey(tc.input) + + assert.Equal(t, tc.expected, result) + assert.Equal(t, tc.expectedErr, err) + } +} + +func TestConvertToMapValue(t *testing.T) { + testCases := []struct { + input any + expected *Map + expectedErr *string + }{ + { + input: map[any]any{"key": "value"}, + expected: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_StringValue{ + StringValue: "key", + }, + }, + Value: &Value{ + Value: &Value_StringValue{ + StringValue: "value", + }, + }, + }, + }, + }, + }, + { + input: map[string]string{"key": "value"}, + expected: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_StringValue{ + StringValue: "key", + }, + }, + Value: &Value{ + Value: &Value_StringValue{ + StringValue: "value", + }, + }, + }, + }, + }, + }, + { + input: map[int]float64{10: 3.124}, + expected: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_LongValue{ + LongValue: int64(10), + }, + }, + Value: &Value{ + Value: &Value_DoubleValue{ + DoubleValue: 3.124, + }, + }, + }, + }, + }, + }, + { + input: map[int]any{10: struct{}{}}, + expected: nil, + expectedErr: GetStrPtr("unsupported map value: unsupported value type: struct {}"), + }, + } + + for _, tc := range testCases { + result, err := ConvertToMapValue(tc.input) + + assert.Equal(t, tc.expected, result) + + if tc.expectedErr != nil { + assert.ErrorContains(t, err, *tc.expectedErr) + } + } +} + +func TestConvertFromMapValue(t *testing.T) { + var nilMap map[any]any + testCases := []struct { + input *Map + expected any + expectedErr *string + }{ + { + input: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_StringValue{ + StringValue: "key", + }, + }, + Value: &Value{ + Value: &Value_StringValue{ + StringValue: "value", + }, + }, + }, + }, + }, + expected: map[any]any{"key": "value"}, + }, + { + input: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_StringValue{ + StringValue: "key", + }, + }, + Value: &Value{ + Value: &Value_StringValue{ + StringValue: "value", + }, + }, + }, + }, + }, + expected: map[any]any{"key": "value"}, + }, + { + input: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_LongValue{ + LongValue: int64(10), + }, + }, + Value: &Value{ + Value: &Value_DoubleValue{ + DoubleValue: 3.124, + }, + }, + }, + }, + }, + expected: map[any]any{int64(10): 3.124}, + }, + { + input: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &mapKeyValueUnknown{}, + }, + Value: &Value{ + Value: &Value_DoubleValue{ + DoubleValue: 3.124, + }, + }, + }, + }, + }, + expected: nilMap, + expectedErr: GetStrPtr("unsupported map key value type: *protos.mapKeyValueUnknown"), + }, + { + input: &Map{ + Entries: []*MapEntry{ + { + Key: &MapKey{ + Value: &MapKey_LongValue{ + LongValue: int64(10), + }, + }, + Value: &Value{ + Value: &valueUnknown{}, + }, + }, + }, + }, + expected: nilMap, + expectedErr: GetStrPtr("unsupported map value: unsupported value type: *protos.valueUnknown"), + }, + } + + for _, tc := range testCases { + result, err := ConvertFromMapValue(tc.input) + + assert.Equal(t, tc.expected, result) + + if tc.expectedErr != nil { + assert.ErrorContains(t, err, *tc.expectedErr) + } + } +} + type valueUnknown struct{} func (*valueUnknown) isValue_Value() {} //nolint:revive,stylecheck // Grpc generated @@ -514,3 +869,95 @@ func TestConvertFromValue(t *testing.T) { }) } } + +func TestConvertToFields(t *testing.T) { + testCases := []struct { + input map[string]any + expected []*Field + expectedErr bool + }{ + { + input: map[string]any{ + "key1": "value1", + "key2": 123, + }, + expected: []*Field{ + { + Name: "key1", + Value: &Value{Value: &Value_StringValue{StringValue: "value1"}}, + }, + { + Name: "key2", + Value: &Value{Value: &Value_LongValue{LongValue: 123}}, + }, + }, + }, + { + input: map[string]any{ + "key1": "value1", + "key2": struct{}{}, + }, + expected: nil, + expectedErr: true, + }, + } + + for _, tc := range testCases { + result, err := ConvertToFields(tc.input) + + assert.Equal(t, tc.expected, result) + + if tc.expectedErr { + assert.Error(t, err) + } + } +} + +func TestConvertFromFields(t *testing.T) { + testCases := []struct { + input []*Field + expected map[string]any + expectedErr bool + }{ + { + input: []*Field{ + { + Name: "key1", + Value: &Value{Value: &Value_StringValue{StringValue: "value1"}}, + }, + { + Name: "key2", + Value: &Value{Value: &Value_LongValue{LongValue: 123}}, + }, + }, + expected: map[string]any{ + "key1": "value1", + "key2": int64(123), + }, + }, + { + input: []*Field{ + { + Name: "key1", + Value: &Value{Value: &Value_StringValue{StringValue: "value1"}}, + }, + { + Name: "key2", + Value: &Value{Value: &valueUnknown{}}, + }, + }, + expected: nil, + expectedErr: true, + }, + } + + for _, tc := range testCases { + result, err := ConvertFromFields(tc.input) + + assert.Equal(t, tc.expected, result) + + if tc.expectedErr { + assert.Error(t, err) + } + } +} From befc3a5065bdd8d7aa1ddab363d90d9798ffdd90 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 19 Sep 2024 12:55:56 -0700 Subject: [PATCH 18/42] add more tests --- protos/utils_test.go | 220 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 220 insertions(+) diff --git a/protos/utils_test.go b/protos/utils_test.go index 8df2f55..bdabc65 100644 --- a/protos/utils_test.go +++ b/protos/utils_test.go @@ -615,6 +615,11 @@ func TestConvertToMapValue(t *testing.T) { expected: nil, expectedErr: GetStrPtr("unsupported map value: unsupported value type: struct {}"), }, + { + input: map[any]any{struct{}{}: 10}, + expected: nil, + expectedErr: GetStrPtr("unsupported map key: unsupported key type: struct {}"), + }, } for _, tc := range testCases { @@ -628,6 +633,162 @@ func TestConvertToMapValue(t *testing.T) { } } +func TestConvertToList(t *testing.T) { + testCases := []struct { + input []any + expected *List + expectedErr *string + }{ + { + input: []any{"item1", "item2"}, + expected: &List{ + Entries: []*Value{ + { + Value: &Value_StringValue{ + StringValue: "item1", + }, + }, + { + Value: &Value_StringValue{ + StringValue: "item2", + }, + }, + }, + }, + }, + { + input: []any{1, 2}, + expected: &List{ + Entries: []*Value{ + { + Value: &Value_LongValue{ + LongValue: int64(1), + }, + }, + { + Value: &Value_LongValue{ + LongValue: int64(2), + }, + }, + }, + }, + }, + { + input: []any{true, false}, + expected: &List{ + Entries: []*Value{ + { + Value: &Value_BooleanValue{ + BooleanValue: true, + }, + }, + { + Value: &Value_BooleanValue{ + BooleanValue: false, + }, + }, + }, + }, + }, + { + input: []any{struct{}{}}, + expected: nil, + expectedErr: GetStrPtr("unsupported list value: unsupported value type: struct {}"), + }, + } + + for _, tc := range testCases { + result, err := ConvertToList(tc.input) + + assert.Equal(t, tc.expected, result) + + if tc.expectedErr != nil { + assert.ErrorContains(t, err, *tc.expectedErr) + } + } +} + +func TestConvertFromListValue(t *testing.T) { + testCases := []struct { + input *List + expected []any + expectedErr *string + }{ + { + input: &List{ + Entries: []*Value{ + { + Value: &Value_StringValue{ + StringValue: "item1", + }, + }, + { + Value: &Value_StringValue{ + StringValue: "item2", + }, + }, + }, + }, + expected: []any{"item1", "item2"}, + }, + { + input: &List{ + Entries: []*Value{ + { + Value: &Value_LongValue{ + LongValue: int64(1), + }, + }, + { + Value: &Value_LongValue{ + LongValue: int64(2), + }, + }, + }, + }, + expected: []any{int64(1), int64(2)}, + }, + { + input: &List{ + Entries: []*Value{ + { + Value: &Value_BooleanValue{ + BooleanValue: true, + }, + }, + { + Value: &Value_BooleanValue{ + BooleanValue: false, + }, + }, + }, + }, + expected: []any{true, false}, + }, + { + input: &List{ + Entries: []*Value{ + { + Value: &valueUnknown{}, + }, + }, + }, + expected: nil, + expectedErr: GetStrPtr("unsupported list value: unsupported value type: *protos.valueUnknown"), + }, + } + + for _, tc := range testCases { + result, err := ConvertFromListValue(tc.input) + + assert.Equal(t, tc.expected, result) + + if tc.expectedErr != nil { + assert.ErrorContains(t, err, *tc.expectedErr) + } + } +} + func TestConvertFromMapValue(t *testing.T) { var nilMap map[any]any testCases := []struct { @@ -851,6 +1012,20 @@ func TestConvertFromValue(t *testing.T) { }, expected: []any{"item1", "item2"}, }, + { + input: &Value{ + Value: &Value_VectorValue{ + VectorValue: &Vector{ + Data: &Vector_FloatData{ + FloatData: &FloatData{ + Value: []float32{1, 2}, + }, + }, + }, + }, + }, + expected: []any{float32(1), float32(2)}, + }, { input: &Value{ Value: &valueUnknown{}, @@ -961,3 +1136,48 @@ func TestConvertFromFields(t *testing.T) { } } } + +type unknownVectorType struct{} + +func (*unknownVectorType) isVector_Data() {} + +func TestConvertFromVector(t *testing.T) { + testCases := []struct { + input *Vector + expected []any + expectedErr error + }{ + { + input: &Vector{ + Data: &Vector_FloatData{ + FloatData: &FloatData{ + Value: []float32{1, 2}, + }, + }, + }, + expected: []any{float32(1), float32(2)}, + }, + { + input: &Vector{ + Data: &Vector_BoolData{ + BoolData: &BoolData{ + Value: []bool{true, false}, + }, + }, + }, + expected: []any{true, false}, + }, + { + input: &Vector{Data: &unknownVectorType{}}, + expected: nil, + expectedErr: fmt.Errorf("unsupported value type: *protos.unknownVectorType"), + }, + } + + for _, tc := range testCases { + result, err := ConvertFromVector(tc.input) + + assert.Equal(t, tc.expected, result) + assert.Equal(t, tc.expectedErr, err) + } +} From 7cb8408fdfd8909385baa36e1e81aba5805d99f2 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 19 Sep 2024 13:23:41 -0700 Subject: [PATCH 19/42] fix unit tests, make tokenManager interface --- client.go | 8 +++++++- client_test.go | 6 +++--- connection_provider.go | 38 ++++++++++++++++++++----------------- connection_provider_test.go | 10 +++++----- protos/utils_test.go | 15 +++++++++------ token_manager.go | 26 ++++++++++++------------- 6 files changed, 58 insertions(+), 45 deletions(-) diff --git a/client.go b/client.go index e43b49d..ce15e36 100644 --- a/client.go +++ b/client.go @@ -72,12 +72,18 @@ func NewClient( logger = logger.WithGroup("avs") logger.Info("creating new client") + var grpcToken tokenManager + + if credentials != nil { + grpcToken = newGrpcJWTToken(credentials.username, credentials.password, logger) + } + connectionProvider, err := newConnectionProvider( ctx, seeds, listenerName, isLoadBalancer, - credentials, + grpcToken, tlsConfig, logger, ) diff --git a/client_test.go b/client_test.go index e9efbd2..99610b7 100644 --- a/client_test.go +++ b/client_test.go @@ -3209,7 +3209,7 @@ func TestConnectedNodeEndpoint_Success(t *testing.T) { defer ctrl.Finish() mockConnProvider := NewMockconnProvider(ctrl) - mockGrpcConn := NewMockGrpcClientConn(ctrl) + mockGrpcConn := NewMockgrpcClientConn(ctrl) mockConn := &connection{ grpcConn: mockGrpcConn, } @@ -3247,7 +3247,7 @@ func TestConnectedNodeEndpoint_FailedGetConn(t *testing.T) { defer ctrl.Finish() mockConnProvider := NewMockconnProvider(ctrl) - mockGrpcConn := NewMockGrpcClientConn(ctrl) + mockGrpcConn := NewMockgrpcClientConn(ctrl) mockConn := &connection{ grpcConn: mockGrpcConn, } @@ -3276,7 +3276,7 @@ func TestConnectedNodeEndpoint_FailParsePort(t *testing.T) { defer ctrl.Finish() mockConnProvider := NewMockconnProvider(ctrl) - mockGrpcConn := NewMockGrpcClientConn(ctrl) + mockGrpcConn := NewMockgrpcClientConn(ctrl) mockConn := &connection{ grpcConn: mockGrpcConn, } diff --git a/connection_provider.go b/connection_provider.go index eb66a1a..90c8b12 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -23,18 +23,27 @@ import ( var errConnectionProviderClosed = errors.New("connection provider is closed") -type GrpcClientConn interface { +type grpcClientConn interface { grpc.ClientConnInterface Target() string Close() error } +type tokenManager interface { + RequireTransportSecurity() bool + ScheduleRefresh(func() (*connection, error)) + RefreshToken(context.Context, grpcClientConn) error + UnaryInterceptor() grpc.UnaryClientInterceptor + StreamInterceptor() grpc.StreamClientInterceptor + Close() +} + // connection represents a gRPC client connection and all the clients (stubs) // for the various AVS services. It's main purpose to remove the need to create // multiple clients for the same connection. This follows the documented grpc // best practice of reusing connections. type connection struct { - grpcConn GrpcClientConn + grpcConn grpcClientConn transactClient protos.TransactServiceClient authClient protos.AuthServiceClient userAdminClient protos.UserAdminServiceClient @@ -44,7 +53,7 @@ type connection struct { } // newConnection creates a new connection instance. -func newConnection(conn GrpcClientConn) *connection { +func newConnection(conn grpcClientConn) *connection { return &connection{ grpcConn: conn, transactClient: protos.NewTransactServiceClient(conn), @@ -85,7 +94,7 @@ type connectionProvider struct { clusterID uint64 listenerName *string isLoadBalancer bool - token *tokenManager + token tokenManager stopTendChan chan struct{} closed atomic.Bool } @@ -96,7 +105,7 @@ func newConnectionProvider( seeds HostPortSlice, listenerName *string, isLoadBalancer bool, - credentials *UserPassCredentials, + token tokenManager, tlsConfig *tls.Config, logger *slog.Logger, ) (*connectionProvider, error) { @@ -115,12 +124,7 @@ func newConnectionProvider( return nil, errors.New(msg) } - // Create a token manager if username and password are provided. - var token *tokenManager - - if credentials != nil { - token = newJWTToken(credentials.username, credentials.password, logger) - + if token != nil { if token.RequireTransportSecurity() && tlsConfig == nil { msg := "tlsConfig is required when username/password authentication" logger.Error(msg) @@ -324,7 +328,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { var authErr error wg := sync.WaitGroup{} - seedGrpcConns := make(chan GrpcClientConn) + seedGrpcConns := make(chan grpcClientConn) cp.seedConns = []*connection{} tokenLock := sync.Mutex{} // Ensures only one thread attempts to update token at a time tokenUpdated := false // Ensures token update only occurs once @@ -444,11 +448,11 @@ func (cp *connectionProvider) checkAndSetClusterID(clusterID uint64) bool { } // getTendConns returns all the gRPC client connections for tend operations. -func (cp *connectionProvider) getTendConns() []GrpcClientConn { +func (cp *connectionProvider) getTendConns() []grpcClientConn { cp.nodeConnsLock.RLock() defer cp.nodeConnsLock.RUnlock() - conns := make([]GrpcClientConn, len(cp.seedConns)+len(cp.nodeConns)) + conns := make([]grpcClientConn, len(cp.seedConns)+len(cp.nodeConns)) i := 0 for _, conn := range cp.seedConns { @@ -474,7 +478,7 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6 for _, conn := range conns { wg.Add(1) - go func(conn GrpcClientConn) { + go func(conn grpcClientConn) { defer wg.Done() logger := cp.logger.With(slog.String("host", conn.Target())) @@ -685,7 +689,7 @@ func endpointToHostPort(endpoint *protos.ServerEndpoint) *HostPort { // successful endpoint in endpoints. func (cp *connectionProvider) createGrpcConnFromEndpoints( endpoints *protos.ServerEndpointList, -) (GrpcClientConn, error) { +) (grpcClientConn, error) { for _, endpoint := range endpoints.Endpoints { if strings.ContainsRune(endpoint.Address, ':') { continue // TODO: Add logging and support for IPv6 @@ -703,7 +707,7 @@ func (cp *connectionProvider) createGrpcConnFromEndpoints( // createGrcpConn creates a gRPC client connection to a host. This handles adding // credential and configuring tls. -func (cp *connectionProvider) createGrcpConn(hostPort *HostPort) (GrpcClientConn, error) { +func (cp *connectionProvider) createGrcpConn(hostPort *HostPort) (grpcClientConn, error) { opts := []grpc.DialOption{} if cp.tlsConfig == nil { diff --git a/connection_provider_test.go b/connection_provider_test.go index fb5474d..a0b7d09 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -14,14 +14,14 @@ func TestNewConnectionProvider(t *testing.T) { seeds := HostPortSlice{} listenerName := "listener" isLoadBalancer := false - credentials := &UserPassCredentials{ - username: "admin", - password: "password", - } + // credentials := &UserPassCredentials{ + // username: "admin", + // password: "password", + // } tlsConfig := &tls.Config{} var logger *slog.Logger - cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, credentials, tlsConfig, logger) + cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, nil, tlsConfig, logger) assert.Nil(t, cp) assert.Error(t, err, errors.New("seeds cannot be nil or empty")) diff --git a/protos/utils_test.go b/protos/utils_test.go index bdabc65..6fec48b 100644 --- a/protos/utils_test.go +++ b/protos/utils_test.go @@ -1024,7 +1024,7 @@ func TestConvertFromValue(t *testing.T) { }, }, }, - expected: []any{float32(1), float32(2)}, + expected: []float32{float32(1), float32(2)}, }, { input: &Value{ @@ -1079,8 +1079,11 @@ func TestConvertToFields(t *testing.T) { for _, tc := range testCases { result, err := ConvertToFields(tc.input) + assert.Equal(t, len(tc.expected), len(result)) - assert.Equal(t, tc.expected, result) + for i, _ := range tc.expected { + assert.EqualExportedValues(t, tc.expected[i], result[i]) + } if tc.expectedErr { assert.Error(t, err) @@ -1144,7 +1147,7 @@ func (*unknownVectorType) isVector_Data() {} func TestConvertFromVector(t *testing.T) { testCases := []struct { input *Vector - expected []any + expected any expectedErr error }{ { @@ -1155,7 +1158,7 @@ func TestConvertFromVector(t *testing.T) { }, }, }, - expected: []any{float32(1), float32(2)}, + expected: []float32{float32(1), float32(2)}, }, { input: &Vector{ @@ -1165,12 +1168,12 @@ func TestConvertFromVector(t *testing.T) { }, }, }, - expected: []any{true, false}, + expected: []bool{true, false}, }, { input: &Vector{Data: &unknownVectorType{}}, expected: nil, - expectedErr: fmt.Errorf("unsupported value type: *protos.unknownVectorType"), + expectedErr: fmt.Errorf("unsupported vector data type: *protos.unknownVectorType"), }, } diff --git a/token_manager.go b/token_manager.go index 1947bcc..db1c653 100644 --- a/token_manager.go +++ b/token_manager.go @@ -15,11 +15,11 @@ import ( "google.golang.org/grpc/metadata" ) -// tokenManager is responsible for managing authentication tokens and refreshing +// grpcTokenManager is responsible for managing authentication tokens and refreshing // them when necessary. // //nolint:govet // We will favor readability over field alignment -type tokenManager struct { +type grpcTokenManager struct { username string password string token atomic.Value @@ -29,13 +29,13 @@ type tokenManager struct { refreshScheduled bool } -// newJWTToken creates a new tokenManager instance with the provided username, password, and logger. -func newJWTToken(username, password string, logger *slog.Logger) *tokenManager { +// newGrpcJWTToken creates a new tokenManager instance with the provided username, password, and logger. +func newGrpcJWTToken(username, password string, logger *slog.Logger) *grpcTokenManager { logger.WithGroup("jwt") logger.Debug("creating new token manager") - return &tokenManager{ + return &grpcTokenManager{ username: username, password: password, logger: logger, @@ -44,7 +44,7 @@ func newJWTToken(username, password string, logger *slog.Logger) *tokenManager { } // Close stops the scheduled token refresh and closes the token manager. -func (tm *tokenManager) Close() { +func (tm *grpcTokenManager) Close() { if tm.refreshScheduled { tm.logger.Debug("stopping scheduled token refresh") tm.stopRefreshChan <- struct{}{} @@ -55,14 +55,14 @@ func (tm *tokenManager) Close() { } // setRefreshTimeFromTTL sets the refresh time based on the provided time-to-live (TTL) duration. -func (tm *tokenManager) setRefreshTimeFromTTL(ttl time.Duration) { +func (tm *grpcTokenManager) setRefreshTimeFromTTL(ttl time.Duration) { tm.refreshTime.Store(time.Now().Add(ttl)) } // RefreshToken refreshes the authentication token using the provided gRPC client connection. // It returns a boolean indicating if the token was successfully refreshed and // an error if any. It is not thread safe. -func (tm *tokenManager) RefreshToken(ctx context.Context, conn grpc.ClientConnInterface) error { +func (tm *grpcTokenManager) RefreshToken(ctx context.Context, conn grpcClientConn) error { // We only want one goroutine to refresh the token at a time client := protos.NewAuthServiceClient(conn) resp, err := client.Authenticate(ctx, &protos.AuthRequest{ @@ -120,7 +120,7 @@ func (tm *tokenManager) RefreshToken(ctx context.Context, conn grpc.ClientConnIn // ScheduleRefresh schedules the token refresh using the provided function to // get the gRPC client connection. This is not threadsafe. It should only be // called once. -func (tm *tokenManager) ScheduleRefresh(getConn func() (*connection, error)) { +func (tm *grpcTokenManager) ScheduleRefresh(getConn func() (*connection, error)) { if tm.refreshScheduled { tm.logger.Warn("refresh already scheduled") } @@ -167,12 +167,12 @@ func (tm *tokenManager) ScheduleRefresh(getConn func() (*connection, error)) { } // RequireTransportSecurity returns true to indicate that transport security is required. -func (tm *tokenManager) RequireTransportSecurity() bool { +func (tm *grpcTokenManager) RequireTransportSecurity() bool { return true } // UnaryInterceptor returns the grpc unary client interceptor that attaches the token to outgoing requests. -func (tm *tokenManager) UnaryInterceptor() grpc.UnaryClientInterceptor { +func (tm *grpcTokenManager) UnaryInterceptor() grpc.UnaryClientInterceptor { return func( ctx context.Context, method string, @@ -186,7 +186,7 @@ func (tm *tokenManager) UnaryInterceptor() grpc.UnaryClientInterceptor { } // StreamInterceptor returns the grpc stream client interceptor that attaches the token to outgoing requests. -func (tm *tokenManager) StreamInterceptor() grpc.StreamClientInterceptor { +func (tm *grpcTokenManager) StreamInterceptor() grpc.StreamClientInterceptor { return func( ctx context.Context, desc *grpc.StreamDesc, @@ -200,7 +200,7 @@ func (tm *tokenManager) StreamInterceptor() grpc.StreamClientInterceptor { } // attachToken attaches the authentication token to the outgoing context. -func (tm *tokenManager) attachToken(ctx context.Context) context.Context { +func (tm *grpcTokenManager) attachToken(ctx context.Context) context.Context { rawToken := tm.token.Load() if rawToken == nil { return ctx From 05973cb299b60c4009f44a72cd4b4a3a5f4ba0af Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 19 Sep 2024 14:19:38 -0700 Subject: [PATCH 20/42] add more tests --- connection_provider.go | 2 +- connection_provider_test.go | 153 ++++++++++++++++++++++++++++++++++-- 2 files changed, 147 insertions(+), 8 deletions(-) diff --git a/connection_provider.go b/connection_provider.go index 90c8b12..acd1f78 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -21,7 +21,7 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) -var errConnectionProviderClosed = errors.New("connection provider is closed") +var errConnectionProviderClosed = errors.New("connectionProvider is closed, cannot get connection") type grpcClientConn interface { grpc.ClientConnInterface diff --git a/connection_provider_test.go b/connection_provider_test.go index a0b7d09..53ae1be 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -4,25 +4,164 @@ import ( "context" "crypto/tls" "errors" + "fmt" "log/slog" + "sync/atomic" "testing" + "time" + "github.com/aerospike/avs-client-go/protos" "github.com/stretchr/testify/assert" + gomock "go.uber.org/mock/gomock" ) -func TestNewConnectionProvider(t *testing.T) { +func TestNewConnectionProvider_FailSeedsNil(t *testing.T) { seeds := HostPortSlice{} listenerName := "listener" isLoadBalancer := false - // credentials := &UserPassCredentials{ - // username: "admin", - // password: "password", - // } tlsConfig := &tls.Config{} var logger *slog.Logger + var token tokenManager - cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, nil, tlsConfig, logger) + cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, token, tlsConfig, logger) assert.Nil(t, cp) - assert.Error(t, err, errors.New("seeds cannot be nil or empty")) + assert.Equal(t, err, errors.New("seeds cannot be nil or empty")) +} + +func TestNewConnectionProvider_FailNoTLS(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + seeds := HostPortSlice{ + &HostPort{ + Host: "host", + Port: 3000, + }, + } + listenerName := "listener" + isLoadBalancer := false + + var tlsConfig *tls.Config + var logger *slog.Logger + + token := NewMocktokenManager(ctrl) + + token. + EXPECT(). + RequireTransportSecurity(). + Return(true) + + cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, token, tlsConfig, logger) + + assert.Nil(t, cp) + assert.Equal(t, err, errors.New("tlsConfig is required when username/password authentication")) +} + +func TestNewConnectionProvider_FailConnectToSeedConns(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + + seeds := HostPortSlice{ + &HostPort{ + Host: "host", + Port: 3000, + }, + } + listenerName := "listener" + isLoadBalancer := false + + var tlsConfig *tls.Config + var logger *slog.Logger + var token tokenManager + + cp, err := newConnectionProvider(ctx, seeds, &listenerName, isLoadBalancer, token, tlsConfig, logger) + + assert.Nil(t, cp) + assert.Equal(t, "failed to connect to seeds: context deadline exceeded", err.Error()) +} + +func TestClose_FailsToCloseConns(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockSeedConn := NewMockgrpcClientConn(ctrl) + mockNodeConn := NewMockgrpcClientConn(ctrl) + + mockSeedConn. + EXPECT(). + Close(). + Return(fmt.Errorf("foo")) + + mockSeedConn. + EXPECT(). + Target(). + Return("") + + mockNodeConn. + EXPECT(). + Close(). + Return(fmt.Errorf("bar")) + + mockNodeConn. + EXPECT(). + Target(). + Return("") + + cp := &connectionProvider{ + isLoadBalancer: true, + seedConns: []*connection{ + { + grpcConn: mockSeedConn, + }, + }, + nodeConns: map[uint64]*connectionAndEndpoints{ + uint64(1): { + conn: &connection{grpcConn: mockNodeConn}, + endpoints: &protos.ServerEndpointList{}, + }, + }, + logger: slog.Default(), + } + + err := cp.Close() + + assert.Equal(t, fmt.Errorf("foo"), err) +} + +func TestGetSeedConn_FailBecauseClosed(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + } + + cp.closed.Store(true) + + _, err := cp.GetSeedConn() + + assert.Equal(t, errors.New("connectionProvider is closed, cannot get connection"), err) +} + +func TestGetSeedConn_FailSeedConnEmpty(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + } + + cp.closed.Store(false) + + _, err := cp.GetSeedConn() + + assert.Equal(t, errors.New("no seed connections found"), err) } From dbeed1c44c08bf0ac112679b244ad1d384c9aab0 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Fri, 20 Sep 2024 12:11:01 -0700 Subject: [PATCH 21/42] fix tests --- client.go | 1 + connection_provider.go | 4 ++++ connection_provider_test.go | 3 +++ 3 files changed, 8 insertions(+) diff --git a/client.go b/client.go index ce15e36..53e6537 100644 --- a/client.go +++ b/client.go @@ -112,6 +112,7 @@ func newClient( // error: An error if the closure fails, otherwise nil. func (c *Client) Close() error { c.logger.Info("Closing client") + return c.connectionProvider.Close() } diff --git a/connection_provider.go b/connection_provider.go index acd1f78..c370d18 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -175,6 +175,10 @@ func newConnectionProvider( // Close closes the connectionProvider and releases all resources. func (cp *connectionProvider) Close() error { + if cp == nil { + return nil + } + if !cp.isLoadBalancer { cp.stopTendChan <- struct{}{} <-cp.stopTendChan diff --git a/connection_provider_test.go b/connection_provider_test.go index 53ae1be..5d4adec 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -24,6 +24,7 @@ func TestNewConnectionProvider_FailSeedsNil(t *testing.T) { var token tokenManager cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, token, tlsConfig, logger) + defer cp.Close() assert.Nil(t, cp) assert.Equal(t, err, errors.New("seeds cannot be nil or empty")) @@ -53,6 +54,7 @@ func TestNewConnectionProvider_FailNoTLS(t *testing.T) { Return(true) cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, token, tlsConfig, logger) + defer cp.Close() assert.Nil(t, cp) assert.Equal(t, err, errors.New("tlsConfig is required when username/password authentication")) @@ -79,6 +81,7 @@ func TestNewConnectionProvider_FailConnectToSeedConns(t *testing.T) { var token tokenManager cp, err := newConnectionProvider(ctx, seeds, &listenerName, isLoadBalancer, token, tlsConfig, logger) + defer cp.Close() assert.Nil(t, cp) assert.Equal(t, "failed to connect to seeds: context deadline exceeded", err.Error()) From 05c3c269b6fcad73c285a5ee11c015ff028d396f Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Fri, 20 Sep 2024 15:02:32 -0700 Subject: [PATCH 22/42] fix tests --- client.go | 10 +- client_test.go | 174 ++++++++++++++++---------------- connection_provider.go | 10 +- integration_single_node_test.go | 48 ++++----- 4 files changed, 125 insertions(+), 117 deletions(-) diff --git a/client.go b/client.go index 53e6537..38a7d27 100644 --- a/client.go +++ b/client.go @@ -41,6 +41,7 @@ type connProvider interface { type Client struct { logger *slog.Logger connectionProvider connProvider + token tokenManager } // NewClient creates a new Client instance. @@ -88,19 +89,22 @@ func NewClient( logger, ) if err != nil { + grpcToken.Close() logger.Error("failed to create connection provider", slog.Any("error", err)) return nil, NewAVSErrorFromGrpc("failed to connect to server", err) } - return newClient(connectionProvider, logger) + return newClient(connectionProvider, grpcToken, logger) } func newClient( connectionProvider connProvider, + token tokenManager, logger *slog.Logger, ) (*Client, error) { return &Client{ logger: logger, + token: token, connectionProvider: connectionProvider, }, nil } @@ -113,6 +117,10 @@ func newClient( func (c *Client) Close() error { c.logger.Info("Closing client") + if c.token != nil { + c.token.Close() + } + return c.connectionProvider.Close() } diff --git a/client_test.go b/client_test.go index 99610b7..56284b0 100644 --- a/client_test.go +++ b/client_test.go @@ -63,7 +63,7 @@ func TestInsert_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -92,7 +92,7 @@ func TestInsert_FailsGettingConn(t *testing.T) { Return(nil, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -127,7 +127,7 @@ func TestInsert_FailsConvertingKey(t *testing.T) { Return(mockConn, nil) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -162,7 +162,7 @@ func TestInsert_FailsConvertingFields(t *testing.T) { Return(mockConn, nil) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -202,7 +202,7 @@ func TestInsert_FailsPutRequest(t *testing.T) { Return(&emptypb.Empty{}, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -270,7 +270,7 @@ func TestUpdate_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -338,7 +338,7 @@ func TestReplace_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -422,7 +422,7 @@ func TestGet_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -450,7 +450,7 @@ func TestGet_FailsGettingConn(t *testing.T) { Return(nil, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -483,7 +483,7 @@ func TestGet_FailsConvertingKey(t *testing.T) { Return(mockConn, nil) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -521,7 +521,7 @@ func TestGet_FailsGetRequest(t *testing.T) { Return(mockConn, nil) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -573,7 +573,7 @@ func TestDelete_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -600,7 +600,7 @@ func TestDelete_FailsGettingConn(t *testing.T) { Return(nil, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -633,7 +633,7 @@ func TestDelete_FailsConvertingKey(t *testing.T) { Return(mockConn, nil) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -671,7 +671,7 @@ func TestDelete_FailsDeleteRequest(t *testing.T) { Return(mockConn, nil) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -725,7 +725,7 @@ func TestExists_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -753,7 +753,7 @@ func TestExists_FailsGettingConn(t *testing.T) { Return(nil, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -786,7 +786,7 @@ func TestExists_FailsConvertingKey(t *testing.T) { Return(mockConn, nil) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -824,7 +824,7 @@ func TestExists_FailsDeleteRequest(t *testing.T) { Return(mockConn, nil) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -882,7 +882,7 @@ func TestIsIndexed_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -911,7 +911,7 @@ func TestIsIndexed_FailsGettingConn(t *testing.T) { Return(nil, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -945,7 +945,7 @@ func TestIsIndexed_FailsConvertingKey(t *testing.T) { Return(mockConn, nil) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -984,7 +984,7 @@ func TestIsIndexed_FailsDeleteRequest(t *testing.T) { Return(mockConn, nil) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1142,7 +1142,7 @@ func TestVectorSearchFloat32_Success(t *testing.T) { } // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1177,7 +1177,7 @@ func TestVectorSearchFloat32_FailsGettingConn(t *testing.T) { Return(nil, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1219,7 +1219,7 @@ func TestVectorSearchFloat32_FailsVectorSearch(t *testing.T) { Return(nil, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1272,7 +1272,7 @@ func TestVectorSearchFloat32_FailedToRecvAllNeighbors(t *testing.T) { ) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1331,7 +1331,7 @@ func TestVectorSearchFloat32_FailedToConvertNeighbor(t *testing.T) { ) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1388,7 +1388,7 @@ func TestWaitForIndexCompletion_SuccessAfterZeroCountReturnedTwice(t *testing.T) }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1446,7 +1446,7 @@ func TestWaitForIndexCompletion_SuccessAfterNonZeroUnmergedCount(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1476,7 +1476,7 @@ func TestWaitForIndexCompletion_FailToGetRandomConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1513,7 +1513,7 @@ func TestWaitForIndexCompletion_FailGetStatusCall(t *testing.T) { Return(nil, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1550,7 +1550,7 @@ func TestWaitForIndexCompletion_FailTimeout(t *testing.T) { Return(nil, nil) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1607,7 +1607,7 @@ func TestIndexCreateFromIndexDef_Success(t *testing.T) { Return(nil, nil) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1641,7 +1641,7 @@ func TestIndexCreateFromIndexDef_FailGetConn(t *testing.T) { } // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1682,7 +1682,7 @@ func TestIndexCreateFromIndexDef_FailCreateCall(t *testing.T) { } // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1734,7 +1734,7 @@ func TestIndexUpdate_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1769,7 +1769,7 @@ func TestIndexUpdate_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1811,7 +1811,7 @@ func TestIndexUpdate_FailUpdateCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1869,7 +1869,7 @@ func TestIndexDrop_Success(t *testing.T) { Return(nil, status.Errorf(codes.NotFound, "foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1898,7 +1898,7 @@ func TestIndexDrop_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1934,7 +1934,7 @@ func TestIndexDrop_FailDropCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -1994,7 +1994,7 @@ func TestIndexList_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2022,7 +2022,7 @@ func TestIndexList_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2055,7 +2055,7 @@ func TestIndexList_FailDropCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2107,7 +2107,7 @@ func TestIndexGet_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2137,7 +2137,7 @@ func TestIndexGet_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2173,7 +2173,7 @@ func TestIndexGet_FailDropCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2223,7 +2223,7 @@ func TestIndexGetStatus_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2253,7 +2253,7 @@ func TestIndexGetStatus_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2289,7 +2289,7 @@ func TestIndexGetStatus_FailDropCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2337,7 +2337,7 @@ func TestIndexGcInvalidVertices_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2366,7 +2366,7 @@ func TestIndexGcInvalidVertices_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2403,7 +2403,7 @@ func TestIndexGcInvalidVertices_FailDropCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2457,7 +2457,7 @@ func TestCreateUser_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2487,7 +2487,7 @@ func TestCreateUser_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2524,7 +2524,7 @@ func TestCreateUser_FailCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2575,7 +2575,7 @@ func TestUpdateCredentials_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2604,7 +2604,7 @@ func TestUpdateCredentials_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2640,7 +2640,7 @@ func TestUpdateCredentials_FailCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2683,7 +2683,7 @@ func TestDropUser_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2711,7 +2711,7 @@ func TestDropUser_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2746,7 +2746,7 @@ func TestDropUser_FailCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2795,7 +2795,7 @@ func TestGetUser_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2824,7 +2824,7 @@ func TestGetUser_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2859,7 +2859,7 @@ func TestGetUser_FailCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2911,7 +2911,7 @@ func TestListUsers_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2939,7 +2939,7 @@ func TestListUsers_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -2973,7 +2973,7 @@ func TestListUsers_FailCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3015,7 +3015,7 @@ func TestRevokeRoles_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3044,7 +3044,7 @@ func TestRevokeRoles_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3080,7 +3080,7 @@ func TestRevokeRoles_FailCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3129,7 +3129,7 @@ func TestListRoles_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3157,7 +3157,7 @@ func TestListRoles_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3191,7 +3191,7 @@ func TestListRoles_FailCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3230,7 +3230,7 @@ func TestConnectedNodeEndpoint_Success(t *testing.T) { Return("1.1.1.1:3000") // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3258,7 +3258,7 @@ func TestConnectedNodeEndpoint_FailedGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3292,7 +3292,7 @@ func TestConnectedNodeEndpoint_FailParsePort(t *testing.T) { Return("1.1.1.1:aaaa") // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3335,7 +3335,7 @@ func TestClusteringState_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3363,7 +3363,7 @@ func TestClusteringState_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3397,7 +3397,7 @@ func TestClusteringState_FailCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3441,7 +3441,7 @@ func TestClusterEndpoints_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3469,7 +3469,7 @@ func TestClusterEndpoints_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3504,7 +3504,7 @@ func TestClusterEndpoints_FailCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3544,7 +3544,7 @@ func TestAbout_Success(t *testing.T) { }) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3572,7 +3572,7 @@ func TestAbout_FailGetConn(t *testing.T) { Return(mockConn, fmt.Errorf("foo")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters @@ -3606,7 +3606,7 @@ func TestAbout_FailCall(t *testing.T) { Return(nil, fmt.Errorf("bar")) // Create the client with the mock connProvider - client, err := newClient(mockConnProvider, slog.Default()) + client, err := newClient(mockConnProvider, nil, slog.Default()) assert.NoError(t, err) // Prepare input parameters diff --git a/connection_provider.go b/connection_provider.go index c370d18..9002213 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -96,6 +96,7 @@ type connectionProvider struct { isLoadBalancer bool token tokenManager stopTendChan chan struct{} + initialized bool closed atomic.Bool } @@ -151,6 +152,7 @@ func newConnectionProvider( // Connect to the seed nodes. err := cp.connectToSeeds(ctx) if err != nil { + cp.Close() logger.Error("failed to connect to seeds", slog.Any("error", err)) return nil, err } @@ -170,6 +172,8 @@ func newConnectionProvider( cp.logger.Debug("load balancer is enabled, not starting tend routine") } + cp.initialized = true + return cp, nil } @@ -179,17 +183,13 @@ func (cp *connectionProvider) Close() error { return nil } - if !cp.isLoadBalancer { + if !cp.isLoadBalancer && cp.initialized { cp.stopTendChan <- struct{}{} <-cp.stopTendChan } var firstErr error - if cp.token != nil { - cp.token.Close() - } - for _, conn := range cp.seedConns { err := conn.grpcConn.Close() if err != nil { diff --git a/integration_single_node_test.go b/integration_single_node_test.go index dbd3842..ad3e2ab 100644 --- a/integration_single_node_test.go +++ b/integration_single_node_test.go @@ -1036,31 +1036,31 @@ func (suite *SingleNodeTestSuite) TestUserDelete() { suite.Error(err) } -func (suite *SingleNodeTestSuite) TestUserUpdateCredentials() { - suite.SkipIfUserPassAuthDisabled() - - ctx := context.Background() - username := getUniqueUserName() - - err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write"}) - suite.NoError(err) +// func (suite *SingleNodeTestSuite) TestUserUpdateCredentials() { +// suite.SkipIfUserPassAuthDisabled() - err = suite.AvsClient.UpdateCredentials(ctx, username, "new-password") - suite.NoError(err) - - c, err := NewClient( - ctx, - HostPortSlice{suite.AvsHostPort}, - nil, - suite.AvsLB, - NewCredentialsFromUserPass(username, "new-password"), - suite.AvsTLSConfig, - suite.Logger, - ) - suite.NoError(err) - - c.Close() -} +// ctx := context.Background() +// username := getUniqueUserName() + +// err := suite.AvsClient.CreateUser(ctx, username, "test-password", []string{"read-write"}) +// suite.NoError(err) + +// err = suite.AvsClient.UpdateCredentials(ctx, username, "new-password") +// suite.NoError(err) + +// c, err := NewClient( +// ctx, +// HostPortSlice{suite.AvsHostPort}, +// nil, +// suite.AvsLB, +// NewCredentialsFromUserPass(username, "new-password"), +// suite.AvsTLSConfig, +// suite.Logger, +// ) +// suite.NoError(err) + +// c.Close() +// } func (suite *SingleNodeTestSuite) TestUserGrantRoles() { suite.SkipIfUserPassAuthDisabled() From 460ce11c434f26c45da5a49a4e1f224c887c7551 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Fri, 20 Sep 2024 15:20:26 -0700 Subject: [PATCH 23/42] fix tests --- connection_provider.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/connection_provider.go b/connection_provider.go index 9002213..eac5d30 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -96,7 +96,6 @@ type connectionProvider struct { isLoadBalancer bool token tokenManager stopTendChan chan struct{} - initialized bool closed atomic.Bool } @@ -172,8 +171,6 @@ func newConnectionProvider( cp.logger.Debug("load balancer is enabled, not starting tend routine") } - cp.initialized = true - return cp, nil } @@ -183,7 +180,7 @@ func (cp *connectionProvider) Close() error { return nil } - if !cp.isLoadBalancer && cp.initialized { + if !cp.isLoadBalancer { cp.stopTendChan <- struct{}{} <-cp.stopTendChan } @@ -348,6 +345,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { grpcConn, err := cp.createGrcpConn(seed) if err != nil { logger.ErrorContext(ctx, "failed to create connection", slog.Any("error", err)) + grpcConn.Close() return } @@ -363,7 +361,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { logger.WarnContext(ctx, "failed to refresh token", slog.Any("error", err)) authErr = err tokenLock.Unlock() - + grpcConn.Close() return } @@ -380,6 +378,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { about, err := client.Get(ctx, &protos.AboutRequest{}) if err != nil { logger.WarnContext(ctx, "failed to connect to seed", slog.Any("error", err)) + grpcConn.Close() return } From 12d951f26c97cccca06ea29e594b0526e43a3198 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Fri, 20 Sep 2024 15:21:10 -0700 Subject: [PATCH 24/42] again --- connection_provider.go | 1 - 1 file changed, 1 deletion(-) diff --git a/connection_provider.go b/connection_provider.go index eac5d30..f85c00b 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -151,7 +151,6 @@ func newConnectionProvider( // Connect to the seed nodes. err := cp.connectToSeeds(ctx) if err != nil { - cp.Close() logger.Error("failed to connect to seeds", slog.Any("error", err)) return nil, err } From b8051cf2ca3a89b31d0899142713f5d256ff6c15 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Fri, 20 Sep 2024 16:35:59 -0700 Subject: [PATCH 25/42] more tests --- connection_provider.go | 61 ++-------- connection_provider_test.go | 221 ++++++++++++++++++++++++++++++++++++ utils.go | 44 +++++++ utils_test.go | 116 ++++++++++++++++++- 4 files changed, 388 insertions(+), 54 deletions(-) diff --git a/connection_provider.go b/connection_provider.go index f85c00b..f9fa169 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -7,7 +7,6 @@ import ( "fmt" "log/slog" "math/rand" - "sort" "strings" "sync" "sync/atomic" @@ -450,20 +449,20 @@ func (cp *connectionProvider) checkAndSetClusterID(clusterID uint64) bool { } // getTendConns returns all the gRPC client connections for tend operations. -func (cp *connectionProvider) getTendConns() []grpcClientConn { +func (cp *connectionProvider) getTendConns() []*connection { cp.nodeConnsLock.RLock() defer cp.nodeConnsLock.RUnlock() - conns := make([]grpcClientConn, len(cp.seedConns)+len(cp.nodeConns)) + conns := make([]*connection, len(cp.seedConns)+len(cp.nodeConns)) i := 0 for _, conn := range cp.seedConns { - conns[i] = conn.grpcConn + conns[i] = conn i++ } for _, conn := range cp.nodeConns { - conns[i] = conn.conn.grpcConn + conns[i] = conn.conn i++ } @@ -480,13 +479,12 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6 for _, conn := range conns { wg.Add(1) - go func(conn grpcClientConn) { + go func(conn *connection) { defer wg.Done() - logger := cp.logger.With(slog.String("host", conn.Target())) - client := protos.NewClusterInfoServiceClient(conn) + logger := cp.logger.With(slog.String("host", conn.grpcConn.Target())) - clusterID, err := client.GetClusterId(ctx, &emptypb.Empty{}) + clusterID, err := conn.clusterInfoClient.GetClusterId(ctx, &emptypb.Empty{}) if err != nil { logger.WarnContext(ctx, "failed to get cluster ID", slog.Any("error", err)) } @@ -503,7 +501,7 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6 logger.DebugContext(ctx, "new cluster ID found", slog.Uint64("clusterID", clusterID.GetId())) - endpointsResp, err := client.GetClusterEndpoints(ctx, endpointsReq) + endpointsResp, err := conn.clusterInfoClient.GetClusterEndpoints(ctx, endpointsReq) if err != nil { logger.ErrorContext(ctx, "failed to get cluster endpoints", slog.Any("error", err)) return @@ -644,49 +642,6 @@ func (cp *connectionProvider) tend(ctx context.Context) { } } -func endpointEqual(a, b *protos.ServerEndpoint) bool { - return a.Address == b.Address && a.Port == b.Port && a.IsTls == b.IsTls -} - -func endpointListEqual(a, b *protos.ServerEndpointList) bool { - if len(a.Endpoints) != len(b.Endpoints) { - return false - } - - aEndpoints := make([]*protos.ServerEndpoint, len(a.Endpoints)) - copy(aEndpoints, a.Endpoints) - - bEndpoints := make([]*protos.ServerEndpoint, len(b.Endpoints)) - copy(bEndpoints, b.Endpoints) - - sortFunc := func(endpoints []*protos.ServerEndpoint) func(int, int) bool { - return func(i, j int) bool { - if endpoints[i].Address < endpoints[j].Address { - return true - } else if endpoints[i].Address > endpoints[j].Address { - return false - } - - return endpoints[i].Port < endpoints[j].Port - } - } - - sort.Slice(aEndpoints, sortFunc(aEndpoints)) - sort.Slice(bEndpoints, sortFunc(bEndpoints)) - - for i, endpoint := range aEndpoints { - if !endpointEqual(endpoint, bEndpoints[i]) { - return false - } - } - - return true -} - -func endpointToHostPort(endpoint *protos.ServerEndpoint) *HostPort { - return NewHostPort(endpoint.Address, int(endpoint.Port)) -} - // createGrpcConnFromEndpoints creates a gRPC client connection from the first // successful endpoint in endpoints. func (cp *connectionProvider) createGrpcConnFromEndpoints( diff --git a/connection_provider_test.go b/connection_provider_test.go index 5d4adec..8f47967 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log/slog" + "sync" "sync/atomic" "testing" "time" @@ -168,3 +169,223 @@ func TestGetSeedConn_FailSeedConnEmpty(t *testing.T) { assert.Equal(t, errors.New("no seed connections found"), err) } +func TestUpdateClusterConns_NoNewClusterID(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + + cp := &connectionProvider{ + logger: slog.Default(), + nodeConns: make(map[uint64]*connectionAndEndpoints), + seedConns: []*connection{}, + tlsConfig: nil, + seeds: HostPortSlice{}, + nodeConnsLock: &sync.RWMutex{}, + tendInterval: time.Second * 1, + clusterID: 123, + listenerName: nil, + isLoadBalancer: false, + token: nil, + stopTendChan: make(chan struct{}), + closed: atomic.Bool{}, + } + + cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NoNewClusterID")) + + cp.logger.Debug("Setting up existing node connections") + + grpcConn1 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient1 := protos.NewMockClusterInfoServiceClient(ctrl) + grpcConn2 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient2 := protos.NewMockClusterInfoServiceClient(ctrl) + + grpcConn1. + EXPECT(). + Target(). + Return("") + + mockClusterInfoClient1. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: 123, + }, nil) + + grpcConn2. + EXPECT(). + Target(). + Return("") + + mockClusterInfoClient2. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: 123, + }, nil) + + // Existing node connections + cp.nodeConns[1] = &connectionAndEndpoints{ + conn: &connection{ + grpcConn: grpcConn1, + clusterInfoClient: mockClusterInfoClient1, + }, + endpoints: &protos.ServerEndpointList{}, + } + + cp.nodeConns[2] = &connectionAndEndpoints{ + conn: &connection{ + grpcConn: grpcConn2, + clusterInfoClient: mockClusterInfoClient2, + }, + endpoints: &protos.ServerEndpointList{}, + } + + cp.logger.Debug("Running updateClusterConns") + + cp.updateClusterConns(ctx) + + assert.Equal(t, uint64(123), cp.clusterID) + assert.Len(t, cp.nodeConns, 2) +} + +// func TestUpdateClusterConns_NewClusterID(t *testing.T) { +// ctrl := gomock.NewController(t) +// defer ctrl.Finish() + +// ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) +// defer cancel() + +// cp := &connectionProvider{ +// logger: slog.Default(), +// nodeConns: make(map[uint64]*connectionAndEndpoints), +// seedConns: []*connection{}, +// tlsConfig: &tls.Config{}, +// seeds: HostPortSlice{}, +// nodeConnsLock: &sync.RWMutex{}, +// tendInterval: time.Second * 1, +// clusterID: 123, +// listenerName: nil, +// isLoadBalancer: false, +// token: nil, +// stopTendChan: make(chan struct{}), +// closed: atomic.Bool{}, +// } + +// cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NewClusterID")) + +// cp.logger.Debug("Setting up existing node connections") + +// grpcConn1 := NewMockgrpcClientConn(ctrl) +// mockClusterInfoClient1 := protos.NewMockClusterInfoServiceClient(ctrl) +// grpcConn2 := NewMockgrpcClientConn(ctrl) +// mockClusterInfoClient2 := protos.NewMockClusterInfoServiceClient(ctrl) + +// grpcConn1. +// EXPECT(). +// Target(). +// Return("") + +// mockClusterInfoClient1. +// EXPECT(). +// GetClusterId(gomock.Any(), gomock.Any()). +// Return(&protos.ClusterId{ +// Id: 123, +// }, nil) + +// grpcConn1. +// EXPECT(). +// Close(). +// Return(nil) + +// grpcConn2. +// EXPECT(). +// Target(). +// Return("") + +// mockClusterInfoClient2. +// EXPECT(). +// GetClusterId(gomock.Any(), gomock.Any()). +// Return(&protos.ClusterId{ +// Id: 456, +// }, nil) + +// mockClusterInfoClient2. +// EXPECT(). +// GetClusterEndpoints(gomock.Any(), gomock.Any()). +// Return(&protos.ClusterNodeEndpoints{ +// Endpoints: map[uint64]*protos.ServerEndpointList{ +// 3: { +// Endpoints: []*protos.ServerEndpoint{ +// { +// Address: "1.1.1.1", +// Port: 3000, +// }, +// }, +// }, +// 4: { +// Endpoints: []*protos.ServerEndpoint{ +// { +// Address: "2.2.2.2", +// Port: 3000, +// }, +// }, +// }, +// }, +// }, nil) + +// grpcConn2. +// EXPECT(). +// Close(). +// Return(nil) + +// // Existing node connections +// cp.nodeConns[1] = &connectionAndEndpoints{ +// conn: &connection{ +// grpcConn: grpcConn1, +// clusterInfoClient: mockClusterInfoClient1, +// }, +// endpoints: &protos.ServerEndpointList{}, +// } + +// cp.nodeConns[2] = &connectionAndEndpoints{ +// conn: &connection{ +// grpcConn: grpcConn2, +// clusterInfoClient: mockClusterInfoClient2, +// }, +// endpoints: &protos.ServerEndpointList{}, +// } + +// cp.logger.Debug("Running updateClusterConns") + +// // New cluster ID +// // newEndpoints := &protos.ServerEndpointList{ +// // Endpoints: []*protos.ServerEndpoint{ +// // { +// // Address: "host1", +// // Port: 3000, +// // }, +// // { +// // Address: "host2", +// // Port: 3000, +// // }, +// // }, +// // } + +// cp.updateClusterConns(ctx) + +// // cp.checkAndSetNodeConns(ctx, map[uint64]*protos.ServerEndpointList{ +// // 1: newEndpoints, +// // 2: newEndpoints, +// // }) + +// // cp.removeDownNodes(map[uint64]*protos.ServerEndpointList{ +// // 1: newEndpoints, +// // 2: newEndpoints, +// // }) + +// // cp.updateClusterConns(ctx) + +// assert.Equal(t, uint64(456), cp.clusterID) +// assert.Len(t, cp.nodeConns, 2) +// } diff --git a/utils.go b/utils.go index b099ff4..557abae 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,7 @@ package avs import ( + "sort" "strconv" "strings" @@ -74,6 +75,49 @@ func createIndexStatusRequest(namespace, name string) *protos.IndexStatusRequest } } +func endpointEqual(a, b *protos.ServerEndpoint) bool { + return a.Address == b.Address && a.Port == b.Port && a.IsTls == b.IsTls +} + +func endpointListEqual(a, b *protos.ServerEndpointList) bool { + if len(a.Endpoints) != len(b.Endpoints) { + return false + } + + aEndpoints := make([]*protos.ServerEndpoint, len(a.Endpoints)) + copy(aEndpoints, a.Endpoints) + + bEndpoints := make([]*protos.ServerEndpoint, len(b.Endpoints)) + copy(bEndpoints, b.Endpoints) + + sortFunc := func(endpoints []*protos.ServerEndpoint) func(int, int) bool { + return func(i, j int) bool { + if endpoints[i].Address < endpoints[j].Address { + return true + } else if endpoints[i].Address > endpoints[j].Address { + return false + } + + return endpoints[i].Port < endpoints[j].Port + } + } + + sort.Slice(aEndpoints, sortFunc(aEndpoints)) + sort.Slice(bEndpoints, sortFunc(bEndpoints)) + + for i, endpoint := range aEndpoints { + if !endpointEqual(endpoint, bEndpoints[i]) { + return false + } + } + + return true +} + +func endpointToHostPort(endpoint *protos.ServerEndpoint) *HostPort { + return NewHostPort(endpoint.Address, int(endpoint.Port)) +} + var minimumFullySupportedAVSVersion = newVersion("0.10.0") type version []any diff --git a/utils_test.go b/utils_test.go index 2fc6528..cd0d71e 100644 --- a/utils_test.go +++ b/utils_test.go @@ -1,6 +1,10 @@ package avs -import "testing" +import ( + "testing" + + "github.com/aerospike/avs-client-go/protos" +) func TestVersionLTGT(t *testing.T) { testCases := []struct { @@ -75,3 +79,113 @@ func TestVersionLTGT(t *testing.T) { }) } } + +func TestEndpointListEqual(t *testing.T) { + testCases := []struct { + name string + endpoints1 *protos.ServerEndpointList + endpoints2 *protos.ServerEndpointList + want bool + }{ + { + name: "equal endpoints", + endpoints1: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, + { + Address: "127.0.0.1", + Port: 9090, + IsTls: true, + }, + }, + }, + endpoints2: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, + { + Address: "127.0.0.1", + Port: 9090, + IsTls: true, + }, + }, + }, + want: true, + }, + { + name: "different endpoints", + endpoints1: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, + { + Address: "127.0.0.1", + Port: 9090, + IsTls: true, + }, + }, + }, + endpoints2: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, + { + Address: "127.0.0.1", + Port: 9091, + IsTls: true, + }, + }, + }, + want: false, + }, + { + name: "different number of endpoints", + endpoints1: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, + { + Address: "127.0.0.1", + Port: 9090, + IsTls: true, + }, + }, + }, + endpoints2: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, + }, + }, + want: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := endpointListEqual(tc.endpoints1, tc.endpoints2) + if got != tc.want { + t.Errorf("expected %v, got %v", tc.want, got) + } + }) + } +} From f8e2b8990dc88ea914286db24ac6c36a47aa7d99 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Fri, 20 Sep 2024 16:44:32 -0700 Subject: [PATCH 26/42] more tests --- utils.go | 4 +- utils_test.go | 104 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/utils.go b/utils.go index 557abae..c895790 100644 --- a/utils.go +++ b/utils.go @@ -56,7 +56,9 @@ func createProjectionSpec(includeFields, excludeFields []string) *protos.Project Type: protos.ProjectionType_SPECIFIED, Fields: includeFields, } - } else if excludeFields != nil { + } + + if excludeFields != nil { spec.Exclude = &protos.ProjectionFilter{ Type: protos.ProjectionType_SPECIFIED, Fields: excludeFields, diff --git a/utils_test.go b/utils_test.go index cd0d71e..4886715 100644 --- a/utils_test.go +++ b/utils_test.go @@ -1,11 +1,115 @@ package avs import ( + reflect "reflect" "testing" "github.com/aerospike/avs-client-go/protos" ) +func TestCreateProjectionSpec(t *testing.T) { + testCases := []struct { + name string + includeFields []string + excludeFields []string + expectedProjectionSpec *protos.ProjectionSpec + }{ + { + name: "include fields", + includeFields: []string{"field1", "field2"}, + excludeFields: nil, + expectedProjectionSpec: &protos.ProjectionSpec{ + Include: &protos.ProjectionFilter{ + Type: protos.ProjectionType_SPECIFIED, + Fields: []string{"field1", "field2"}, + }, + Exclude: &protos.ProjectionFilter{ + Type: protos.ProjectionType_NONE, + }, + }, + }, + { + name: "exclude fields", + includeFields: nil, + excludeFields: []string{"field3", "field4"}, + expectedProjectionSpec: &protos.ProjectionSpec{ + Include: &protos.ProjectionFilter{ + Type: protos.ProjectionType_ALL, + }, + Exclude: &protos.ProjectionFilter{ + Type: protos.ProjectionType_SPECIFIED, + Fields: []string{"field3", "field4"}, + }, + }, + }, + { + name: "default fields", + includeFields: nil, + excludeFields: nil, + expectedProjectionSpec: &protos.ProjectionSpec{ + Include: &protos.ProjectionFilter{ + Type: protos.ProjectionType_ALL, + }, + Exclude: &protos.ProjectionFilter{ + Type: protos.ProjectionType_NONE, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + spec := createProjectionSpec(tc.includeFields, tc.excludeFields) + + if !reflect.DeepEqual(spec, tc.expectedProjectionSpec) { + t.Errorf("expected projection spec %v, got %v", tc.expectedProjectionSpec, spec) + } + }) + } +} +func TestVersionString(t *testing.T) { + testCases := []struct { + name string + v version + expected string + }{ + { + name: "valid version", + v: newVersion("1.2.3"), + expected: "1.2.3", + }, + { + name: "valid version with suffix", + v: newVersion("1.2.3-dev"), + expected: "1.2.3-dev", + }, + { + name: "valid version with multiple suffixes", + v: newVersion("1.2.3-dev.1"), + expected: "1.2.3-dev.1", + }, + { + name: "valid version with pre-release", + v: newVersion("1.2.3-alpha"), + expected: "1.2.3-alpha", + }, + { + name: "valid version with pre-release and build metadata", + v: newVersion("1.2.3-alpha+build123"), + expected: "1.2.3-alpha+build123", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := tc.v.String() + if got != tc.expected { + t.Errorf("expected %s, got %s", tc.expected, got) + } + }) + } +} + func TestVersionLTGT(t *testing.T) { testCases := []struct { name string From 8959927239cc7f33d50e2703b46134fe41918c0c Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Fri, 20 Sep 2024 16:50:53 -0700 Subject: [PATCH 27/42] add sort --- protos/utils_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/protos/utils_test.go b/protos/utils_test.go index 6fec48b..acd6195 100644 --- a/protos/utils_test.go +++ b/protos/utils_test.go @@ -2,6 +2,7 @@ package protos import ( "fmt" + "sort" "testing" "github.com/stretchr/testify/assert" @@ -1077,10 +1078,19 @@ func TestConvertToFields(t *testing.T) { }, } + sortFunc := func(fields []*Field) func(int, int) bool { + return func(i, j int) bool { + return fields[i].Name < fields[j].Name + } + } + for _, tc := range testCases { result, err := ConvertToFields(tc.input) assert.Equal(t, len(tc.expected), len(result)) + sort.Slice(result, sortFunc(result)) + sort.Slice(tc.expected, sortFunc(result)) + for i, _ := range tc.expected { assert.EqualExportedValues(t, tc.expected[i], result[i]) } From ea25445ff8bbe969c01872154c712aeffdf87536 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Mon, 23 Sep 2024 10:35:48 -0700 Subject: [PATCH 28/42] add tests --- connection_provider.go | 27 +++++++++++++++------------ token_manager.go | 18 +++++++++++------- utils_test.go | 10 +++++----- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/connection_provider.go b/connection_provider.go index f9fa169..e69c4c2 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -31,7 +31,7 @@ type grpcClientConn interface { type tokenManager interface { RequireTransportSecurity() bool ScheduleRefresh(func() (*connection, error)) - RefreshToken(context.Context, grpcClientConn) error + RefreshToken(context.Context, *connection) error UnaryInterceptor() grpc.UnaryClientInterceptor StreamInterceptor() grpc.StreamClientInterceptor Close() @@ -64,6 +64,10 @@ func newConnection(conn grpcClientConn) *connection { } } +func (conn *connection) close() error { + return conn.grpcConn.Close() +} + // connectionAndEndpoints represents a combination of a gRPC client connection and server endpoints. type connectionAndEndpoints struct { conn *connection @@ -186,7 +190,7 @@ func (cp *connectionProvider) Close() error { var firstErr error for _, conn := range cp.seedConns { - err := conn.grpcConn.Close() + err := conn.close() if err != nil { if firstErr == nil { firstErr = err @@ -327,7 +331,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { var authErr error wg := sync.WaitGroup{} - seedGrpcConns := make(chan grpcClientConn) + seedConns := make(chan *connection) cp.seedConns = []*connection{} tokenLock := sync.Mutex{} // Ensures only one thread attempts to update token at a time tokenUpdated := false // Ensures token update only occurs once @@ -348,18 +352,19 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { } extraCheck := true + conn := newConnection(grpcConn) if cp.token != nil { // Only one thread needs to refresh the token. Only first will // succeed others will block tokenLock.Lock() if !tokenUpdated { - err := cp.token.RefreshToken(ctx, grpcConn) + err := cp.token.RefreshToken(ctx, conn) if err != nil { logger.WarnContext(ctx, "failed to refresh token", slog.Any("error", err)) authErr = err tokenLock.Unlock() - grpcConn.Close() + conn.close() return } @@ -371,9 +376,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { } if extraCheck { - client := protos.NewAboutServiceClient(grpcConn) - - about, err := client.Get(ctx, &protos.AboutRequest{}) + about, err := conn.aboutClient.Get(ctx, &protos.AboutRequest{}) if err != nil { logger.WarnContext(ctx, "failed to connect to seed", slog.Any("error", err)) grpcConn.Close() @@ -385,17 +388,17 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { } } - seedGrpcConns <- grpcConn + seedConns <- conn }(seed) } go func() { wg.Wait() - close(seedGrpcConns) + close(seedConns) }() - for conn := range seedGrpcConns { - cp.seedConns = append(cp.seedConns, newConnection(conn)) + for conn := range seedConns { + cp.seedConns = append(cp.seedConns, conn) } if len(cp.seedConns) == 0 { diff --git a/token_manager.go b/token_manager.go index db1c653..848f96b 100644 --- a/token_manager.go +++ b/token_manager.go @@ -62,10 +62,9 @@ func (tm *grpcTokenManager) setRefreshTimeFromTTL(ttl time.Duration) { // RefreshToken refreshes the authentication token using the provided gRPC client connection. // It returns a boolean indicating if the token was successfully refreshed and // an error if any. It is not thread safe. -func (tm *grpcTokenManager) RefreshToken(ctx context.Context, conn grpcClientConn) error { +func (tm *grpcTokenManager) RefreshToken(ctx context.Context, conn *connection) error { // We only want one goroutine to refresh the token at a time - client := protos.NewAuthServiceClient(conn) - resp, err := client.Authenticate(ctx, &protos.AuthRequest{ + resp, err := conn.authClient.Authenticate(ctx, &protos.AuthRequest{ Credentials: createUserPassCredential(tm.username, tm.password), }) @@ -74,6 +73,11 @@ func (tm *grpcTokenManager) RefreshToken(ctx context.Context, conn grpcClientCon } claims := strings.Split(resp.GetToken(), ".") + + if len(claims) < 3 { + return fmt.Errorf("failed to authenticate: missing either header, payload, or signature") + } + decClaims, err := base64.RawURLEncoding.DecodeString(claims[1]) if err != nil { @@ -89,17 +93,17 @@ func (tm *grpcTokenManager) RefreshToken(ctx context.Context, conn grpcClientCon expiryToken, ok := tokenMap["exp"].(float64) if !ok { - return fmt.Errorf("%s: %w", "failed to authenticate", err) + return fmt.Errorf("failed to authenticate: unable to find exp in token") } iat, ok := tokenMap["iat"].(float64) if !ok { - return fmt.Errorf("%s: %w", "failed to authenticate", err) + return fmt.Errorf("failed to authenticate: unable to find iat in token") } ttl := time.Duration(expiryToken-iat) * time.Second if ttl <= 0 { - return fmt.Errorf("%s: %w", "failed to authenticate", err) + return fmt.Errorf("failed to authenticate: jwt ttl is less than 0") } tm.logger.DebugContext( @@ -139,7 +143,7 @@ func (tm *grpcTokenManager) ScheduleRefresh(getConn func() (*connection, error)) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - err = tm.RefreshToken(ctx, connClients.grpcConn) + err = tm.RefreshToken(ctx, connClients) if err != nil { tm.logger.Warn("failed to refresh token", slog.Any("error", err)) } diff --git a/utils_test.go b/utils_test.go index 4886715..532ad1f 100644 --- a/utils_test.go +++ b/utils_test.go @@ -209,16 +209,16 @@ func TestEndpointListEqual(t *testing.T) { }, endpoints2: &protos.ServerEndpointList{ Endpoints: []*protos.ServerEndpoint{ - { - Address: "localhost", - Port: 8080, - IsTls: false, - }, { Address: "127.0.0.1", Port: 9090, IsTls: true, }, + { + Address: "localhost", + Port: 8080, + IsTls: false, + }, }, }, want: true, From bf79f62d395390252f733b66a3045f5ca0eb79fd Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Mon, 23 Sep 2024 10:36:00 -0700 Subject: [PATCH 29/42] add tests --- token_manager_test.go | 301 ++++++++++++++++++++++++++++++++++++++++++ types_test.go | 30 +++++ 2 files changed, 331 insertions(+) create mode 100644 token_manager_test.go create mode 100644 types_test.go diff --git a/token_manager_test.go b/token_manager_test.go new file mode 100644 index 0000000..8a3c971 --- /dev/null +++ b/token_manager_test.go @@ -0,0 +1,301 @@ +package avs + +import ( + "context" + "fmt" + "log/slog" + "testing" + + "github.com/aerospike/avs-client-go/protos" + "github.com/stretchr/testify/assert" + gomock "go.uber.org/mock/gomock" +) + +func TestRefreshToken_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjMwMDAwMDAwMDAsImlhdCI6MTcyNzExMDc1NH0.GD01CEWxW6-7lHcyeetM95WKdUlwY85m5lFqzcTCtzs" + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). + Return(&protos.AuthResponse{ + Token: b64token, + }, nil) + + // Create the token manager with the mock connection + tm := newGrpcJWTToken(username, password, logger) + + // Refresh the token + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert no error occurred + assert.NoError(t, err) + + // Assert the token was set correctly + assert.Equal(t, "Bearer "+b64token, tm.token.Load().(string)) +} + +func TestRefreshToken_FailedToRefreshToken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("foo")) + + // Create the token manager with the + tm := newGrpcJWTToken(username, password, logger) + + // Refresh + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert error occurred + assert.Equal(t, "failed to authenticate: foo", err.Error()) +} + +func TestRefreshToken_FailedMissingClaims(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). + Return(&protos.AuthResponse{ + Token: "badToken", + }, nil) + + // Create the token manager with the + tm := newGrpcJWTToken(username, password, logger) + + // Refresh + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert error occurred + assert.Error(t, err) + assert.Equal(t, "failed to authenticate: missing either header, payload, or signature", err.Error()) +} + +func TestRefreshToken_FailedToDecodeToken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). + Return(&protos.AuthResponse{ + Token: "badToken.blah.foo", + }, nil) + + // Create the token manager with the + tm := newGrpcJWTToken(username, password, logger) + + // Refresh + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert error occurred + assert.Error(t, err) + assert.Equal(t, "failed to authenticate: invalid character 'V' in literal null (expecting 'u')", err.Error()) +} + +func TestRefreshToken_FailedInvalidJson(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjMwMDAwMDAwMDAsImlhdCI6MTcyNzExMD.GD01CEWxW6-7lHcyeetM95WKdUlwY85m5lFqzcTCtzs" + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). + Return(&protos.AuthResponse{ + Token: b64token, + }, nil) + + // Create the token manager with the mock connection + tm := newGrpcJWTToken(username, password, logger) + + // Refresh the token + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert no error occurred + assert.Error(t, err) + assert.Equal(t, "failed to authenticate: unexpected end of JSON input", err.Error()) +} + +func TestRefreshToken_FailedFindExp(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE3MjcxMTA3NTR9.50IZcLoS7mQPzQsKvJZyXNUukvT5FdiqN2tynNIjHuk" + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). + Return(&protos.AuthResponse{ + Token: b64token, + }, nil) + + // Create the token manager with the mock connection + tm := newGrpcJWTToken(username, password, logger) + + // Refresh the token + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert no error occurred + assert.Error(t, err) + assert.Equal(t, "failed to authenticate: unable to find exp in token", err.Error()) +} + +func TestRefreshToken_FailedFindIat(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjMwMDAwMDAwMDB9.f5DtjF1sYLH6fz0ThcFKxwngIXkVMLnhJtIrjLi_1p0" + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). + Return(&protos.AuthResponse{ + Token: b64token, + }, nil) + + // Create the token manager with the mock connection + tm := newGrpcJWTToken(username, password, logger) + + // Refresh the token + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert no error occurred + assert.Error(t, err) + assert.Equal(t, "failed to authenticate: unable to find iat in token", err.Error()) +} + +func TestRefreshToken_FailedTtlLessThan0(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + username := "testUser" + password := "testPassword" + logger := slog.Default() + + // Create a mock gRPC client connection + mockConn := NewMockgrpcClientConn(ctrl) + mockAuthServiceClient := protos.NewMockAuthServiceClient(ctrl) + mockConnClients := &connection{ + grpcConn: mockConn, + authClient: mockAuthServiceClient, + } + + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjE3MjcxMTA3NTMsImlhdCI6MTcyNzExMDc1NH0.YdH5twU6-LGLtgvD2sktiw1j40MRUe_r4oPN565z4Ok" + + // Set up expectations for AuthServiceClient.Authenticate() + mockAuthServiceClient. + EXPECT(). + Authenticate(gomock.Any(), gomock.Any()). + Return(&protos.AuthResponse{ + Token: b64token, + }, nil) + + // Create the token manager with the mock connection + tm := newGrpcJWTToken(username, password, logger) + + // Refresh the token + err := tm.RefreshToken(context.Background(), mockConnClients) + + // Assert no error occurred + assert.Error(t, err) + assert.Equal(t, "failed to authenticate: jwt ttl is less than 0", err.Error()) +} diff --git a/types_test.go b/types_test.go new file mode 100644 index 0000000..aab1472 --- /dev/null +++ b/types_test.go @@ -0,0 +1,30 @@ +package avs + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHostPort_String(t *testing.T) { + host := "localhost" + port := 8080 + hp := NewHostPort(host, port) + + expected := "localhost:8080" + result := hp.String() + + assert.Equal(t, expected, result) +} + +func TestHostPortSlice_String(t *testing.T) { + hps := HostPortSlice{ + NewHostPort("localhost", 8080), + NewHostPort("example.com", 1234), + } + + expected := "[localhost:8080, example.com:1234]" + result := hps.String() + + assert.Equal(t, expected, result) +} From bbaaa12c5985870f0fbbf0a34cbe45c998dcca6f Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Mon, 23 Sep 2024 12:57:17 -0700 Subject: [PATCH 30/42] add tests --- connection_provider.go | 85 +++++---- connection_provider_test.go | 350 +++++++++++++++++++++--------------- token_manager_test.go | 4 +- 3 files changed, 259 insertions(+), 180 deletions(-) diff --git a/connection_provider.go b/connection_provider.go index e69c4c2..e02ebec 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -87,19 +87,21 @@ func newConnAndEndpoints(conn *connection, endpoints *protos.ServerEndpointList) // //nolint:govet // We will favor readability over field alignment type connectionProvider struct { - logger *slog.Logger - nodeConns map[uint64]*connectionAndEndpoints - seedConns []*connection - tlsConfig *tls.Config - seeds HostPortSlice - nodeConnsLock *sync.RWMutex - tendInterval time.Duration - clusterID uint64 - listenerName *string - isLoadBalancer bool - token tokenManager - stopTendChan chan struct{} - closed atomic.Bool + logger *slog.Logger + nodeConns map[uint64]*connectionAndEndpoints + seedConns []*connection + tlsConfig *tls.Config + seeds HostPortSlice + nodeConnsLock *sync.RWMutex + tendInterval time.Duration + clusterID uint64 + listenerName *string + isLoadBalancer bool + token tokenManager + stopTendChan chan struct{} + closed atomic.Bool + grpcConnFactory func(hostPort *HostPort) (grpcClientConn, error) // For testing + connFactory func(conn grpcClientConn) *connection // For testing } // newConnectionProvider creates a new connectionProvider instance. @@ -151,6 +153,11 @@ func newConnectionProvider( closed: atomic.Bool{}, } + cp.connFactory = newConnection + cp.grpcConnFactory = func(hostPort *HostPort) (grpcClientConn, error) { + return createGrcpConn(cp, hostPort) + } + // Connect to the seed nodes. err := cp.connectToSeeds(ctx) if err != nil { @@ -344,7 +351,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { logger := cp.logger.With(slog.String("host", seed.String())) - grpcConn, err := cp.createGrcpConn(seed) + grpcConn, err := cp.grpcConnFactory(seed) if err != nil { logger.ErrorContext(ctx, "failed to create connection", slog.Any("error", err)) grpcConn.Close() @@ -352,7 +359,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { } extraCheck := true - conn := newConnection(grpcConn) + conn := cp.connFactory(grpcConn) if cp.token != nil { // Only one thread needs to refresh the token. Only first will @@ -441,16 +448,6 @@ func (cp *connectionProvider) updateNodeConns( return nil } -// checkAndSetClusterID checks if the cluster ID has changed and updates it if necessary. -func (cp *connectionProvider) checkAndSetClusterID(clusterID uint64) bool { - if clusterID != cp.clusterID { - cp.clusterID = clusterID - return true - } - - return false -} - // getTendConns returns all the gRPC client connections for tend operations. func (cp *connectionProvider) getTendConns() []*connection { cp.nodeConnsLock.RLock() @@ -474,8 +471,13 @@ func (cp *connectionProvider) getTendConns() []*connection { // getUpdatedEndpoints retrieves the updated server endpoints from the Aerospike cluster. func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint64]*protos.ServerEndpointList { + type idAndEndpoints struct { + id uint64 + endpoints map[uint64]*protos.ServerEndpointList + } + conns := cp.getTendConns() - endpointsChan := make(chan map[uint64]*protos.ServerEndpointList) + newClusterChan := make(chan *idAndEndpoints) endpointsReq := &protos.ClusterNodeEndpointsRequest{ListenerName: cp.listenerName} wg := sync.WaitGroup{} @@ -492,7 +494,7 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6 logger.WarnContext(ctx, "failed to get cluster ID", slog.Any("error", err)) } - if !cp.checkAndSetClusterID(clusterID.GetId()) { + if clusterID.GetId() == cp.clusterID { logger.DebugContext( ctx, "old cluster ID found, skipping connection discovery", @@ -510,25 +512,32 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6 return } - endpointsChan <- endpointsResp.Endpoints + newClusterChan <- &idAndEndpoints{ + id: clusterID.GetId(), + endpoints: endpointsResp.Endpoints, + } }(conn) } go func() { wg.Wait() - close(endpointsChan) + close(newClusterChan) }() // Stores the endpoints from the node with the largest view of the cluster - var maxTempEndpoints map[uint64]*protos.ServerEndpointList - for endpoints := range endpointsChan { - if maxTempEndpoints == nil || len(endpoints) > len(maxTempEndpoints) { - maxTempEndpoints = endpoints - cp.logger.DebugContext(ctx, "found new cluster ID", slog.Any("endpoints", maxTempEndpoints)) + var largestNewCluster *idAndEndpoints + for cluster := range newClusterChan { + if largestNewCluster == nil || len(cluster.endpoints) > len(largestNewCluster.endpoints) { + largestNewCluster = cluster + cp.logger.DebugContext(ctx, "found new cluster ID", slog.Any("endpoints", largestNewCluster)) } } - return maxTempEndpoints + if largestNewCluster != nil { + cp.clusterID = largestNewCluster.id + } + + return largestNewCluster.endpoints } // Checks if the node connections need to be updated and updates them if necessary. @@ -655,7 +664,7 @@ func (cp *connectionProvider) createGrpcConnFromEndpoints( continue // TODO: Add logging and support for IPv6 } - conn, err := cp.createGrcpConn(endpointToHostPort(endpoint)) + conn, err := cp.grpcConnFactory(endpointToHostPort(endpoint)) if err == nil { return conn, nil @@ -667,7 +676,7 @@ func (cp *connectionProvider) createGrpcConnFromEndpoints( // createGrcpConn creates a gRPC client connection to a host. This handles adding // credential and configuring tls. -func (cp *connectionProvider) createGrcpConn(hostPort *HostPort) (grpcClientConn, error) { +func createGrcpConn(cp *connectionProvider, hostPort *HostPort) (grpcClientConn, error) { opts := []grpc.DialOption{} if cp.tlsConfig == nil { @@ -704,5 +713,5 @@ func (cp *connectionProvider) createConnFromEndpoints(endpoints *protos.ServerEn return nil, err } - return newConnection(conn), nil + return cp.connFactory(conn), nil } diff --git a/connection_provider_test.go b/connection_provider_test.go index 8f47967..be0ef96 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -249,143 +249,213 @@ func TestUpdateClusterConns_NoNewClusterID(t *testing.T) { assert.Len(t, cp.nodeConns, 2) } -// func TestUpdateClusterConns_NewClusterID(t *testing.T) { -// ctrl := gomock.NewController(t) -// defer ctrl.Finish() - -// ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) -// defer cancel() - -// cp := &connectionProvider{ -// logger: slog.Default(), -// nodeConns: make(map[uint64]*connectionAndEndpoints), -// seedConns: []*connection{}, -// tlsConfig: &tls.Config{}, -// seeds: HostPortSlice{}, -// nodeConnsLock: &sync.RWMutex{}, -// tendInterval: time.Second * 1, -// clusterID: 123, -// listenerName: nil, -// isLoadBalancer: false, -// token: nil, -// stopTendChan: make(chan struct{}), -// closed: atomic.Bool{}, -// } - -// cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NewClusterID")) - -// cp.logger.Debug("Setting up existing node connections") - -// grpcConn1 := NewMockgrpcClientConn(ctrl) -// mockClusterInfoClient1 := protos.NewMockClusterInfoServiceClient(ctrl) -// grpcConn2 := NewMockgrpcClientConn(ctrl) -// mockClusterInfoClient2 := protos.NewMockClusterInfoServiceClient(ctrl) - -// grpcConn1. -// EXPECT(). -// Target(). -// Return("") - -// mockClusterInfoClient1. -// EXPECT(). -// GetClusterId(gomock.Any(), gomock.Any()). -// Return(&protos.ClusterId{ -// Id: 123, -// }, nil) - -// grpcConn1. -// EXPECT(). -// Close(). -// Return(nil) - -// grpcConn2. -// EXPECT(). -// Target(). -// Return("") - -// mockClusterInfoClient2. -// EXPECT(). -// GetClusterId(gomock.Any(), gomock.Any()). -// Return(&protos.ClusterId{ -// Id: 456, -// }, nil) - -// mockClusterInfoClient2. -// EXPECT(). -// GetClusterEndpoints(gomock.Any(), gomock.Any()). -// Return(&protos.ClusterNodeEndpoints{ -// Endpoints: map[uint64]*protos.ServerEndpointList{ -// 3: { -// Endpoints: []*protos.ServerEndpoint{ -// { -// Address: "1.1.1.1", -// Port: 3000, -// }, -// }, -// }, -// 4: { -// Endpoints: []*protos.ServerEndpoint{ -// { -// Address: "2.2.2.2", -// Port: 3000, -// }, -// }, -// }, -// }, -// }, nil) - -// grpcConn2. -// EXPECT(). -// Close(). -// Return(nil) - -// // Existing node connections -// cp.nodeConns[1] = &connectionAndEndpoints{ -// conn: &connection{ -// grpcConn: grpcConn1, -// clusterInfoClient: mockClusterInfoClient1, -// }, -// endpoints: &protos.ServerEndpointList{}, -// } - -// cp.nodeConns[2] = &connectionAndEndpoints{ -// conn: &connection{ -// grpcConn: grpcConn2, -// clusterInfoClient: mockClusterInfoClient2, -// }, -// endpoints: &protos.ServerEndpointList{}, -// } - -// cp.logger.Debug("Running updateClusterConns") - -// // New cluster ID -// // newEndpoints := &protos.ServerEndpointList{ -// // Endpoints: []*protos.ServerEndpoint{ -// // { -// // Address: "host1", -// // Port: 3000, -// // }, -// // { -// // Address: "host2", -// // Port: 3000, -// // }, -// // }, -// // } - -// cp.updateClusterConns(ctx) - -// // cp.checkAndSetNodeConns(ctx, map[uint64]*protos.ServerEndpointList{ -// // 1: newEndpoints, -// // 2: newEndpoints, -// // }) - -// // cp.removeDownNodes(map[uint64]*protos.ServerEndpointList{ -// // 1: newEndpoints, -// // 2: newEndpoints, -// // }) - -// // cp.updateClusterConns(ctx) - -// assert.Equal(t, uint64(456), cp.clusterID) -// assert.Len(t, cp.nodeConns, 2) -// } +func TestUpdateClusterConns_NewClusterID(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + mockNewGrpcConn1111 := NewMockgrpcClientConn(ctrl) + mockNewGrpcConn2222 := NewMockgrpcClientConn(ctrl) + + mockClusterInfoClient1111 := protos.NewMockClusterInfoServiceClient(ctrl) + mockClusterInfoClient2222 := protos.NewMockClusterInfoServiceClient(ctrl) + + mockAboutClient1111 := protos.NewMockAboutServiceClient(ctrl) + mockAboutClient2222 := protos.NewMockAboutServiceClient(ctrl) + + mockAboutClient1111. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, nil) + + mockAboutClient2222. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, nil) + + cp := &connectionProvider{ + logger: slog.Default(), + nodeConns: make(map[uint64]*connectionAndEndpoints), + seedConns: []*connection{}, + tlsConfig: &tls.Config{}, + seeds: HostPortSlice{}, + nodeConnsLock: &sync.RWMutex{}, + tendInterval: time.Second * 1, + clusterID: 123, + listenerName: nil, + isLoadBalancer: false, + token: nil, + stopTendChan: make(chan struct{}), + closed: atomic.Bool{}, + grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + if hostPort.String() == "1.1.1.1:3000" { + return mockNewGrpcConn1111, nil + } else if hostPort.String() == "2.2.2.2:3000" { + return mockNewGrpcConn2222, nil + } + + return nil, fmt.Errorf("foo") + }, + connFactory: func(grpcConn grpcClientConn) *connection { + if grpcConn == mockNewGrpcConn1111 { + return &connection{ + clusterInfoClient: mockClusterInfoClient1111, + aboutClient: mockAboutClient1111, + } + } else if grpcConn == mockNewGrpcConn2222 { + return &connection{ + clusterInfoClient: mockClusterInfoClient2222, + aboutClient: mockAboutClient2222, + } + } + + return nil + }, + } + + cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NewClusterID")) + + cp.logger.Debug("Setting up existing node connections") + + grpcConn1 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient1 := protos.NewMockClusterInfoServiceClient(ctrl) + grpcConn2 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient2 := protos.NewMockClusterInfoServiceClient(ctrl) + + grpcConn1. + EXPECT(). + Target(). + Return("") + + mockClusterInfoClient1. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: 789, // Different cluster id from client 2 + }, nil) + + mockClusterInfoClient1. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). + Return(&protos.ClusterNodeEndpoints{ + Endpoints: map[uint64]*protos.ServerEndpointList{ // Smaller num of endpoints from client 2 + 0: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + }, + }, nil) + + grpcConn1. + EXPECT(). + Close(). + Return(nil) + + grpcConn2. + EXPECT(). + Target(). + Return("") + + mockClusterInfoClient2. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: 456, + }, nil) + + mockClusterInfoClient2. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). + Return(&protos.ClusterNodeEndpoints{ + Endpoints: map[uint64]*protos.ServerEndpointList{ // larger, so the cluster id 456 will win + 3: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + 4: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "2.2.2.2", + Port: 3000, + }, + }, + }, + }, + }, nil) + + grpcConn2. + EXPECT(). + Close(). + Return(nil) + + // Existing node connections. These will be replaced after a new cluster is found. + cp.nodeConns = map[uint64]*connectionAndEndpoints{ + 1: { + conn: &connection{ + grpcConn: grpcConn1, + clusterInfoClient: mockClusterInfoClient1, + }, + endpoints: &protos.ServerEndpointList{}, + }, + 2: { + conn: &connection{ + grpcConn: grpcConn2, + clusterInfoClient: mockClusterInfoClient2, + }, + endpoints: &protos.ServerEndpointList{}, + }, + } + + // After a new cluster is discovered we expect these to be the new nodeConns + expectedNewNodeConns := map[uint64]*connectionAndEndpoints{ + 3: { + conn: &connection{ + clusterInfoClient: mockClusterInfoClient1111, + aboutClient: mockAboutClient1111, + }, + endpoints: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + }, + 4: { + conn: &connection{ + clusterInfoClient: mockClusterInfoClient2222, + aboutClient: mockAboutClient2222, + }, + endpoints: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "2.2.2.2", + Port: 3000, + }, + }, + }, + }, + } + + cp.logger.Debug("Running updateClusterConns") + + cp.updateClusterConns(ctx) + + assert.Equal(t, uint64(456), cp.clusterID) + assert.Len(t, cp.nodeConns, 2) + + for k, v := range cp.nodeConns { + assert.EqualExportedValues(t, expectedNewNodeConns[k].endpoints, v.endpoints) + } +} diff --git a/token_manager_test.go b/token_manager_test.go index 8a3c971..26b20e0 100644 --- a/token_manager_test.go +++ b/token_manager_test.go @@ -138,7 +138,7 @@ func TestRefreshToken_FailedToDecodeToken(t *testing.T) { EXPECT(). Authenticate(gomock.Any(), gomock.Any()). Return(&protos.AuthResponse{ - Token: "badToken.blah.foo", + Token: "badToken.blahz.foo", }, nil) // Create the token manager with the @@ -149,7 +149,7 @@ func TestRefreshToken_FailedToDecodeToken(t *testing.T) { // Assert error occurred assert.Error(t, err) - assert.Equal(t, "failed to authenticate: invalid character 'V' in literal null (expecting 'u')", err.Error()) + assert.Equal(t, "failed to authenticate: illegal base64 data at input byte 4", err.Error()) } func TestRefreshToken_FailedInvalidJson(t *testing.T) { From 913801329b9b4827564b3c1fe1b45391fe7331c3 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Mon, 23 Sep 2024 13:06:18 -0700 Subject: [PATCH 31/42] fix tests --- connection_provider.go | 3 ++- connection_provider_test.go | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/connection_provider.go b/connection_provider.go index e02ebec..8d9de6a 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -535,9 +535,10 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6 if largestNewCluster != nil { cp.clusterID = largestNewCluster.id + return largestNewCluster.endpoints } - return largestNewCluster.endpoints + return nil } // Checks if the node connections need to be updated and updates them if necessary. diff --git a/connection_provider_test.go b/connection_provider_test.go index be0ef96..647e5da 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -169,6 +169,10 @@ func TestGetSeedConn_FailSeedConnEmpty(t *testing.T) { assert.Equal(t, errors.New("no seed connections found"), err) } + +func TestconnectToSeeds(t *testing.T) { +} + func TestUpdateClusterConns_NoNewClusterID(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() From e52ab475ba3e74939a99e9ef58721fe9c9a8cd8a Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Mon, 23 Sep 2024 13:32:03 -0700 Subject: [PATCH 32/42] more tests --- connection_provider.go | 7 +- connection_provider_test.go | 302 +++++++++++++++++++++++++++++++++++- 2 files changed, 301 insertions(+), 8 deletions(-) diff --git a/connection_provider.go b/connection_provider.go index 8d9de6a..ea6d253 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -65,7 +65,11 @@ func newConnection(conn grpcClientConn) *connection { } func (conn *connection) close() error { - return conn.grpcConn.Close() + if conn.grpcConn != nil { + return conn.grpcConn.Close() + } + + return nil } // connectionAndEndpoints represents a combination of a gRPC client connection and server endpoints. @@ -354,7 +358,6 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { grpcConn, err := cp.grpcConnFactory(seed) if err != nil { logger.ErrorContext(ctx, "failed to create connection", slog.Any("error", err)) - grpcConn.Close() return } diff --git a/connection_provider_test.go b/connection_provider_test.go index 647e5da..a4265d5 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -170,7 +170,85 @@ func TestGetSeedConn_FailSeedConnEmpty(t *testing.T) { assert.Equal(t, errors.New("no seed connections found"), err) } -func TestconnectToSeeds(t *testing.T) { +func TestConnectToSeeds_FailedAlreadyConnected(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + } + + cp.seedConns = []*connection{ + { + grpcConn: NewMockgrpcClientConn(ctrl), + }, + } + + err := cp.connectToSeeds(context.Background()) + + assert.Equal(t, errors.New("seed connections already exist, close them first"), err) +} + +func TestConnectToSeeds_FailedFailedToCreateConnection(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + return nil, fmt.Errorf("foo") + }, + } + + cp.seeds = HostPortSlice{ + &HostPort{ + Host: "host", + Port: 3000, + }, + } + + err := cp.connectToSeeds(context.Background()) + + assert.Equal(t, NewAVSError("failed to connect to seeds", nil), err) +} + +func TestConnectToSeeds_FailedToRefreshToken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockToken := NewMocktokenManager(ctrl) + mockToken. + EXPECT(). + RefreshToken(gomock.Any(), gomock.Any()). + Return(fmt.Errorf("foo")) + + cp := &connectionProvider{ + isLoadBalancer: true, + closed: atomic.Bool{}, + logger: slog.Default(), + grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + return nil, nil + }, + connFactory: func(conn grpcClientConn) *connection { + return &connection{} + }, + token: mockToken, + } + + cp.seeds = HostPortSlice{ + &HostPort{ + Host: "host", + Port: 3000, + }, + } + + err := cp.connectToSeeds(context.Background()) + + assert.Equal(t, NewAVSError("failed to connect to seeds", fmt.Errorf("foo")), err) } func TestUpdateClusterConns_NoNewClusterID(t *testing.T) { @@ -253,7 +331,7 @@ func TestUpdateClusterConns_NoNewClusterID(t *testing.T) { assert.Len(t, cp.nodeConns, 2) } -func TestUpdateClusterConns_NewClusterID(t *testing.T) { +func TestUpdateClusterConns_NewClusterIDWithDIFFERENTNodeIDs(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -366,11 +444,44 @@ func TestUpdateClusterConns_NewClusterID(t *testing.T) { Target(). Return("") + expectedClusterID := uint64(456) + // After a new cluster is discovered we expect these to be the new nodeConns + expectedNewNodeConns := map[uint64]*connectionAndEndpoints{ + 3: { + conn: &connection{ + clusterInfoClient: mockClusterInfoClient1111, + aboutClient: mockAboutClient1111, + }, + endpoints: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + }, + 4: { + conn: &connection{ + clusterInfoClient: mockClusterInfoClient2222, + aboutClient: mockAboutClient2222, + }, + endpoints: &protos.ServerEndpointList{ + Endpoints: []*protos.ServerEndpoint{ + { + Address: "2.2.2.2", + Port: 3000, + }, + }, + }, + }, + } + mockClusterInfoClient2. EXPECT(). GetClusterId(gomock.Any(), gomock.Any()). Return(&protos.ClusterId{ - Id: 456, + Id: expectedClusterID, }, nil) mockClusterInfoClient2. @@ -420,9 +531,183 @@ func TestUpdateClusterConns_NewClusterID(t *testing.T) { }, } + cp.logger.Debug("Running updateClusterConns") + + cp.updateClusterConns(ctx) + + assert.Equal(t, expectedClusterID, cp.clusterID) + assert.Len(t, cp.nodeConns, 2) + + for k, v := range cp.nodeConns { + assert.EqualExportedValues(t, expectedNewNodeConns[k].endpoints, v.endpoints) + } +} + +func TestUpdateClusterConns_NewClusterIDWithSAMENodeIDs(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + grpcConn1 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient1 := protos.NewMockClusterInfoServiceClient(ctrl) + grpcConn2 := NewMockgrpcClientConn(ctrl) + mockClusterInfoClient2 := protos.NewMockClusterInfoServiceClient(ctrl) + + expectedClusterID := uint64(456) + + mockClusterInfoClient2. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: expectedClusterID, + }, nil) + + mockClusterInfoClient2. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). + Return(&protos.ClusterNodeEndpoints{ + Endpoints: map[uint64]*protos.ServerEndpointList{ // larger, so the cluster id 456 will win + 1: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + 2: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "2.2.2.2", + Port: 3000, + }, + }, + }, + }, + }, nil) + + mockNewGrpcConn1111 := NewMockgrpcClientConn(ctrl) + mockNewGrpcConn2222 := NewMockgrpcClientConn(ctrl) + + mockClusterInfoClient1111 := protos.NewMockClusterInfoServiceClient(ctrl) + mockClusterInfoClient2222 := protos.NewMockClusterInfoServiceClient(ctrl) + + mockAboutClient1111 := protos.NewMockAboutServiceClient(ctrl) + mockAboutClient2222 := protos.NewMockAboutServiceClient(ctrl) + + mockAboutClient1111. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, nil) + + mockAboutClient2222. + EXPECT(). + Get(gomock.Any(), gomock.Any()). + Return(nil, nil) + + cp := &connectionProvider{ + logger: slog.Default(), + seedConns: []*connection{}, + tlsConfig: &tls.Config{}, + seeds: HostPortSlice{}, + nodeConnsLock: &sync.RWMutex{}, + tendInterval: time.Second * 1, + clusterID: 123, + listenerName: nil, + isLoadBalancer: false, + token: nil, + stopTendChan: make(chan struct{}), + closed: atomic.Bool{}, + grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + if hostPort.String() == "1.1.1.1:3000" { + return mockNewGrpcConn1111, nil + } else if hostPort.String() == "2.2.2.2:3000" { + return mockNewGrpcConn2222, nil + } + + return nil, fmt.Errorf("foo") + }, + connFactory: func(grpcConn grpcClientConn) *connection { + if grpcConn == mockNewGrpcConn1111 { + return &connection{ + clusterInfoClient: mockClusterInfoClient1111, + aboutClient: mockAboutClient1111, + } + } else if grpcConn == mockNewGrpcConn2222 { + return &connection{ + clusterInfoClient: mockClusterInfoClient2222, + aboutClient: mockAboutClient2222, + } + } + + return nil + }, + // Existing node connections. These will be replaced after a new cluster is found. + nodeConns: map[uint64]*connectionAndEndpoints{ + 1: { + conn: &connection{ + grpcConn: grpcConn1, + clusterInfoClient: mockClusterInfoClient1, + }, + endpoints: &protos.ServerEndpointList{}, + }, + 2: { + conn: &connection{ + grpcConn: grpcConn2, + clusterInfoClient: mockClusterInfoClient2, + }, + endpoints: &protos.ServerEndpointList{}, + }, + }, + } + + cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NewClusterID")) + + cp.logger.Debug("Setting up existing node connections") + + grpcConn1. + EXPECT(). + Target(). + Return("") + + mockClusterInfoClient1. + EXPECT(). + GetClusterId(gomock.Any(), gomock.Any()). + Return(&protos.ClusterId{ + Id: 789, // Different cluster id from client 2 + }, nil) + + mockClusterInfoClient1. + EXPECT(). + GetClusterEndpoints(gomock.Any(), gomock.Any()). + Return(&protos.ClusterNodeEndpoints{ + Endpoints: map[uint64]*protos.ServerEndpointList{ // Smaller num of endpoints from client 2 + 0: { + Endpoints: []*protos.ServerEndpoint{ + { + Address: "1.1.1.1", + Port: 3000, + }, + }, + }, + }, + }, nil) + + grpcConn1. + EXPECT(). + Close(). + Return(nil) + + grpcConn2. + EXPECT(). + Target(). + Return("") + // After a new cluster is discovered we expect these to be the new nodeConns expectedNewNodeConns := map[uint64]*connectionAndEndpoints{ - 3: { + 1: { conn: &connection{ clusterInfoClient: mockClusterInfoClient1111, aboutClient: mockAboutClient1111, @@ -436,7 +721,7 @@ func TestUpdateClusterConns_NewClusterID(t *testing.T) { }, }, }, - 4: { + 2: { conn: &connection{ clusterInfoClient: mockClusterInfoClient2222, aboutClient: mockAboutClient2222, @@ -452,11 +737,16 @@ func TestUpdateClusterConns_NewClusterID(t *testing.T) { }, } + grpcConn2. + EXPECT(). + Close(). + Return(nil) + cp.logger.Debug("Running updateClusterConns") cp.updateClusterConns(ctx) - assert.Equal(t, uint64(456), cp.clusterID) + assert.Equal(t, expectedClusterID, cp.clusterID) assert.Len(t, cp.nodeConns, 2) for k, v := range cp.nodeConns { From e6f975f96dae2decac04ce999b108322cb28faed Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Tue, 24 Sep 2024 13:06:14 -0700 Subject: [PATCH 33/42] fix linter --- .golangci.yml | 6 ++++++ client.go | 20 +++++++++++-------- client_test.go | 2 ++ connection_provider.go | 4 +++- connection_provider_test.go | 39 ++++++++++++++++++++++++------------- protos/utils_test.go | 4 ++-- token_manager_test.go | 10 +++++----- 7 files changed, 56 insertions(+), 29 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index e343d8e..3538c5c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -75,6 +75,12 @@ issues: - path: '(.+)test\.go' linters: - govet # Test code field alignment for sake of space is not a concern + - path: 'token_manager_test.go' + linters: + - goconst # Test code is allowed to have constants + - path: 'connection_provider_test.go' + linters: + - goconst # Test cod # - path: dir/sample\.go # linters: # - lll # Test code is allowed to have long lines diff --git a/client.go b/client.go index 38a7d27..4de465e 100644 --- a/client.go +++ b/client.go @@ -22,11 +22,12 @@ const ( indexWaitDuration = time.Millisecond * 100 ) const ( - failedToInsertRecord = "failed to insert record" - failedToGetRecord = "failed to get record" - failedToDeleteRecord = "failed to delete record" - failedToCheckRecordExists = "failed to check if record exists" - failedToCheckIsIndexed = "failed to check if record is indexed" + failedToInsertRecord = "failed to insert record" + failedToGetRecord = "failed to get record" + failedToDeleteRecord = "failed to delete record" + failedToCheckRecordExists = "failed to check if record exists" + failedToCheckIsIndexed = "failed to check if record is indexed" + failedToWaitForIndexCompletion = "failed to wait for index completion" ) type connProvider interface { @@ -91,6 +92,7 @@ func NewClient( if err != nil { grpcToken.Close() logger.Error("failed to create connection provider", slog.Any("error", err)) + return nil, NewAVSErrorFromGrpc("failed to connect to server", err) } @@ -651,11 +653,11 @@ func (c *Client) WaitForIndexCompletion( waitInterval time.Duration, ) error { logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) - logger.DebugContext(ctx, "waiting for index completion") + logger.DebugContext(ctx, failedToWaitForIndexCompletion) conn, err := c.connectionProvider.GetRandomConn() if err != nil { - msg := "failed to wait for index completion" + msg := failedToWaitForIndexCompletion logger.Error(msg, slog.Any("error", err)) return NewAVSError(msg, err) @@ -672,7 +674,7 @@ func (c *Client) WaitForIndexCompletion( for { indexStatus, err := conn.indexClient.GetStatus(ctx, indexStatusReq) if err != nil { - msg := "failed to wait for index completion" + msg := failedToWaitForIndexCompletion logger.ErrorContext(ctx, msg, slog.Any("error", err)) return NewAVSError(msg, err) @@ -702,7 +704,9 @@ func (c *Client) WaitForIndexCompletion( case <-timer.C: case <-ctx.Done(): msg := "failed to wait for index completion" + logger.ErrorContext(ctx, "waiting for index completion canceled") + return NewAVSError(msg, ctx.Err()) } } diff --git a/client_test.go b/client_test.go index 56284b0..8dac3ed 100644 --- a/client_test.go +++ b/client_test.go @@ -1,3 +1,5 @@ +//go:build unit + package avs import ( diff --git a/connection_provider.go b/connection_provider.go index ea6d253..de8e8ad 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -375,6 +375,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { authErr = err tokenLock.Unlock() conn.close() + return } @@ -390,6 +391,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { if err != nil { logger.WarnContext(ctx, "failed to connect to seed", slog.Any("error", err)) grpcConn.Close() + return } @@ -475,8 +477,8 @@ func (cp *connectionProvider) getTendConns() []*connection { // getUpdatedEndpoints retrieves the updated server endpoints from the Aerospike cluster. func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint64]*protos.ServerEndpointList { type idAndEndpoints struct { - id uint64 endpoints map[uint64]*protos.ServerEndpointList + id uint64 } conns := cp.getTendConns() diff --git a/connection_provider_test.go b/connection_provider_test.go index a4265d5..18c5ba7 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -20,12 +20,17 @@ func TestNewConnectionProvider_FailSeedsNil(t *testing.T) { seeds := HostPortSlice{} listenerName := "listener" isLoadBalancer := false - tlsConfig := &tls.Config{} - var logger *slog.Logger - var token tokenManager + tlsConfig := &tls.Config{} //nolint:gosec // tests + + var ( + logger *slog.Logger + token tokenManager + ) cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, token, tlsConfig, logger) - defer cp.Close() + if err == nil { + defer cp.Close() + } assert.Nil(t, cp) assert.Equal(t, err, errors.New("seeds cannot be nil or empty")) @@ -44,8 +49,10 @@ func TestNewConnectionProvider_FailNoTLS(t *testing.T) { listenerName := "listener" isLoadBalancer := false - var tlsConfig *tls.Config - var logger *slog.Logger + var ( + tlsConfig *tls.Config + logger *slog.Logger + ) token := NewMocktokenManager(ctrl) @@ -55,7 +62,9 @@ func TestNewConnectionProvider_FailNoTLS(t *testing.T) { Return(true) cp, err := newConnectionProvider(context.Background(), seeds, &listenerName, isLoadBalancer, token, tlsConfig, logger) - defer cp.Close() + if err == nil { + defer cp.Close() + } assert.Nil(t, cp) assert.Equal(t, err, errors.New("tlsConfig is required when username/password authentication")) @@ -77,12 +86,16 @@ func TestNewConnectionProvider_FailConnectToSeedConns(t *testing.T) { listenerName := "listener" isLoadBalancer := false - var tlsConfig *tls.Config - var logger *slog.Logger - var token tokenManager + var ( + tlsConfig *tls.Config + logger *slog.Logger + token tokenManager + ) cp, err := newConnectionProvider(ctx, seeds, &listenerName, isLoadBalancer, token, tlsConfig, logger) - defer cp.Close() + if err == nil { + defer cp.Close() + } assert.Nil(t, cp) assert.Equal(t, "failed to connect to seeds: context deadline exceeded", err.Error()) @@ -361,7 +374,7 @@ func TestUpdateClusterConns_NewClusterIDWithDIFFERENTNodeIDs(t *testing.T) { logger: slog.Default(), nodeConns: make(map[uint64]*connectionAndEndpoints), seedConns: []*connection{}, - tlsConfig: &tls.Config{}, + tlsConfig: &tls.Config{}, //nolint:gosec // tests seeds: HostPortSlice{}, nodeConnsLock: &sync.RWMutex{}, tendInterval: time.Second * 1, @@ -610,7 +623,7 @@ func TestUpdateClusterConns_NewClusterIDWithSAMENodeIDs(t *testing.T) { cp := &connectionProvider{ logger: slog.Default(), seedConns: []*connection{}, - tlsConfig: &tls.Config{}, + tlsConfig: &tls.Config{}, //nolint:gosec // tests seeds: HostPortSlice{}, nodeConnsLock: &sync.RWMutex{}, tendInterval: time.Second * 1, diff --git a/protos/utils_test.go b/protos/utils_test.go index acd6195..55e19b2 100644 --- a/protos/utils_test.go +++ b/protos/utils_test.go @@ -1091,7 +1091,7 @@ func TestConvertToFields(t *testing.T) { sort.Slice(result, sortFunc(result)) sort.Slice(tc.expected, sortFunc(result)) - for i, _ := range tc.expected { + for i := range tc.expected { assert.EqualExportedValues(t, tc.expected[i], result[i]) } @@ -1152,7 +1152,7 @@ func TestConvertFromFields(t *testing.T) { type unknownVectorType struct{} -func (*unknownVectorType) isVector_Data() {} +func (*unknownVectorType) isVector_Data() {} //nolint:revive,stylecheck // Grpc generated func TestConvertFromVector(t *testing.T) { testCases := []struct { diff --git a/token_manager_test.go b/token_manager_test.go index 26b20e0..392a0d8 100644 --- a/token_manager_test.go +++ b/token_manager_test.go @@ -27,7 +27,7 @@ func TestRefreshToken_Success(t *testing.T) { authClient: mockAuthServiceClient, } - b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjMwMDAwMDAwMDAsImlhdCI6MTcyNzExMDc1NH0.GD01CEWxW6-7lHcyeetM95WKdUlwY85m5lFqzcTCtzs" + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjMwMDAwMDAwMDAsImlhdCI6MTcyNzExMDc1NH0.GD01CEWxW6-7lHcyeetM95WKdUlwY85m5lFqzcTCtzs" //nolint:gosec,lll // tests // Set up expectations for AuthServiceClient.Authenticate() mockAuthServiceClient. @@ -168,7 +168,7 @@ func TestRefreshToken_FailedInvalidJson(t *testing.T) { authClient: mockAuthServiceClient, } - b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjMwMDAwMDAwMDAsImlhdCI6MTcyNzExMD.GD01CEWxW6-7lHcyeetM95WKdUlwY85m5lFqzcTCtzs" + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjMwMDAwMDAwMDAsImlhdCI6MTcyNzExMD.GD01CEWxW6-7lHcyeetM95WKdUlwY85m5lFqzcTCtzs" //nolint:gosec,lll // tests // Set up expectations for AuthServiceClient.Authenticate() mockAuthServiceClient. @@ -205,7 +205,7 @@ func TestRefreshToken_FailedFindExp(t *testing.T) { authClient: mockAuthServiceClient, } - b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE3MjcxMTA3NTR9.50IZcLoS7mQPzQsKvJZyXNUukvT5FdiqN2tynNIjHuk" + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE3MjcxMTA3NTR9.50IZcLoS7mQPzQsKvJZyXNUukvT5FdiqN2tynNIjHuk" //nolint:gosec,lll // tests // Set up expectations for AuthServiceClient.Authenticate() mockAuthServiceClient. @@ -242,7 +242,7 @@ func TestRefreshToken_FailedFindIat(t *testing.T) { authClient: mockAuthServiceClient, } - b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjMwMDAwMDAwMDB9.f5DtjF1sYLH6fz0ThcFKxwngIXkVMLnhJtIrjLi_1p0" + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjMwMDAwMDAwMDB9.f5DtjF1sYLH6fz0ThcFKxwngIXkVMLnhJtIrjLi_1p0" //nolint:gosec,lll // tests // Set up expectations for AuthServiceClient.Authenticate() mockAuthServiceClient. @@ -279,7 +279,7 @@ func TestRefreshToken_FailedTtlLessThan0(t *testing.T) { authClient: mockAuthServiceClient, } - b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjE3MjcxMTA3NTMsImlhdCI6MTcyNzExMDc1NH0.YdH5twU6-LGLtgvD2sktiw1j40MRUe_r4oPN565z4Ok" + b64token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjE3MjcxMTA3NTMsImlhdCI6MTcyNzExMDc1NH0.YdH5twU6-LGLtgvD2sktiw1j40MRUe_r4oPN565z4Ok" //nolint:gosec,lll // tests // Set up expectations for AuthServiceClient.Authenticate() mockAuthServiceClient. From 4f355541939b5d524e212cb0c1bb03e6fcdf116d Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 26 Sep 2024 08:57:07 -0700 Subject: [PATCH 34/42] review changes --- client.go | 6 +++--- connection_provider.go | 10 +++++++++- makefile | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 4de465e..6587204 100644 --- a/client.go +++ b/client.go @@ -653,7 +653,7 @@ func (c *Client) WaitForIndexCompletion( waitInterval time.Duration, ) error { logger := c.logger.With(slog.String("namespace", namespace), slog.String("indexName", indexName)) - logger.DebugContext(ctx, failedToWaitForIndexCompletion) + logger.DebugContext(ctx, "waiting for index completion") conn, err := c.connectionProvider.GetRandomConn() if err != nil { @@ -703,9 +703,9 @@ func (c *Client) WaitForIndexCompletion( select { case <-timer.C: case <-ctx.Done(): - msg := "failed to wait for index completion" + msg := "waiting for index completion canceled" - logger.ErrorContext(ctx, "waiting for index completion canceled") + logger.ErrorContext(ctx, msg) return NewAVSError(msg, ctx.Err()) } diff --git a/connection_provider.go b/connection_provider.go index de8e8ad..982145f 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -530,11 +530,19 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6 }() // Stores the endpoints from the node with the largest view of the cluster + // Think about a scenario where the nodes are split into two cluster + // momentarily and the client can see both. We are making the decision here + // to connect to the larger of the two formed cluster. var largestNewCluster *idAndEndpoints for cluster := range newClusterChan { if largestNewCluster == nil || len(cluster.endpoints) > len(largestNewCluster.endpoints) { largestNewCluster = cluster - cp.logger.DebugContext(ctx, "found new cluster ID", slog.Any("endpoints", largestNewCluster)) + cp.logger.DebugContext( + ctx, + "largest cluster with new id", + slog.Any("endpoints", largestNewCluster.endpoints), + slog.Uint64("id", largestNewCluster.id), + ) } } diff --git a/makefile b/makefile index 7ed72c1..75aae10 100644 --- a/makefile +++ b/makefile @@ -66,5 +66,5 @@ view-coverage: $(COVERAGE_DIR)/total.cov go tool cover -html=$(COVERAGE_DIR)/total.cov PHONY: lint -lint: $(GOLANGCI_LINT) +lint: $(GOLANGCI_LINT) mocks $(GOLANGCI_LINT) run \ No newline at end of file From 1a5d80fdddc53ea1128ad0bfddaa1cbbd20c1cae Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 26 Sep 2024 09:00:19 -0700 Subject: [PATCH 35/42] fix tests --- client_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client_test.go b/client_test.go index 8dac3ed..2a34c39 100644 --- a/client_test.go +++ b/client_test.go @@ -1566,7 +1566,7 @@ func TestWaitForIndexCompletion_FailTimeout(t *testing.T) { var avsError *Error assert.ErrorAs(t, err, &avsError) - assert.Equal(t, avsError, NewAVSError("failed to wait for index completion", fmt.Errorf("context deadline exceeded"))) + assert.Equal(t, avsError, NewAVSError("waiting for index completion canceled", fmt.Errorf("context deadline exceeded"))) } func TestIndexCreateFromIndexDef_Success(t *testing.T) { From 43e8bb907e361e0da1ec6e9ec18e88291e3316b5 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 26 Sep 2024 09:06:22 -0700 Subject: [PATCH 36/42] fix linter --- .github/workflows/golangci-lint.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 1f300db..63ba6b7 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -18,6 +18,9 @@ jobs: uses: actions/setup-go@v5 with: go-version: 1.21 + - name: Setup Mocks + run: | + make mocks - name: Run golangci-lint uses: golangci/golangci-lint-action@v3 with: From d4d536a4c00ce5344851df265a05b925e551fda7 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 26 Sep 2024 09:19:33 -0700 Subject: [PATCH 37/42] fix lint --- connection_provider_test.go | 6 +++--- makefile | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/connection_provider_test.go b/connection_provider_test.go index 18c5ba7..fbd7f06 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -212,7 +212,7 @@ func TestConnectToSeeds_FailedFailedToCreateConnection(t *testing.T) { isLoadBalancer: true, closed: atomic.Bool{}, logger: slog.Default(), - grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + grpcConnFactory: func(_ *HostPort) (grpcClientConn, error) { return nil, fmt.Errorf("foo") }, } @@ -243,10 +243,10 @@ func TestConnectToSeeds_FailedToRefreshToken(t *testing.T) { isLoadBalancer: true, closed: atomic.Bool{}, logger: slog.Default(), - grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + grpcConnFactory: func(_ *HostPort) (grpcClientConn, error) { return nil, nil }, - connFactory: func(conn grpcClientConn) *connection { + connFactory: func(_ grpcClientConn) *connection { return &connection{} }, token: mockToken, diff --git a/makefile b/makefile index 75aae10..028d918 100644 --- a/makefile +++ b/makefile @@ -5,7 +5,7 @@ GOBIN=$(shell go env GOBIN) endif GOLANGCI_LINT ?= $(GOBIN)/golangci-lint -GOLANGCI_LINT_VERSION ?= v1.54.0 +GOLANGCI_LINT_VERSION ?= v1.58.0 MOCKGEN ?= $(GOBIN)/mockgen MOCKGEN_VERSION ?= v0.3.0 From cfa92d1f73287efb776b2fa7a00f4b9e8c8818ed Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 26 Sep 2024 15:38:52 -0700 Subject: [PATCH 38/42] review changes --- client_test.go | 38 ++++---- connection_provider.go | 34 +++++--- connection_provider_test.go | 6 +- docker/.env | 1 + docker/auth/docker-compose.yml | 2 +- docker/mtls/docker-compose.yml | 2 +- docker/multi-node-LB/docker-compose.yml | 6 +- .../docker-compose.yml | 6 +- docker/multi-node/docker-compose.yml | 6 +- docker/tls/docker-compose.yml | 2 +- docker/vanilla/docker-compose.yml | 2 +- integration_single_node_test.go | 86 +++++++++---------- makefile | 2 +- protos/utils_test.go | 46 +++++----- testutils.go | 40 ++------- 15 files changed, 131 insertions(+), 148 deletions(-) create mode 100644 docker/.env diff --git a/client_test.go b/client_test.go index 2a34c39..f03c1a5 100644 --- a/client_test.go +++ b/client_test.go @@ -1036,7 +1036,7 @@ func TestVectorSearchFloat32_Success(t *testing.T) { Limit: 7, SearchParams: &protos.VectorSearchRequest_HnswSearchParams{ HnswSearchParams: &protos.HnswSearchParams{ - Ef: GetUint32Ptr(8), + Ef: Ptr(uint32(8)), }, }, Projection: &protos.ProjectionSpec{ @@ -1073,7 +1073,7 @@ func TestVectorSearchFloat32_Success(t *testing.T) { return &protos.Neighbor{ Key: &protos.Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &protos.Key_StringValue{ StringValue: fmt.Sprintf("key-%d", vectorCounter), }, @@ -1108,9 +1108,9 @@ func TestVectorSearchFloat32_Success(t *testing.T) { "field1": "value1", }, Generation: uint32(1), - Expiration: GetTimePtr(AerospikeEpoch.Add(time.Second * 1)), + Expiration: Ptr(AerospikeEpoch.Add(time.Second * 1)), }, - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Key: "key-1", Namespace: "testNamespace", Distance: float32(1), @@ -1121,9 +1121,9 @@ func TestVectorSearchFloat32_Success(t *testing.T) { "field1": "value1", }, Generation: uint32(2), - Expiration: GetTimePtr(AerospikeEpoch.Add(time.Second * 2)), + Expiration: Ptr(AerospikeEpoch.Add(time.Second * 2)), }, - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Key: "key-2", Namespace: "testNamespace", Distance: float32(2), @@ -1134,9 +1134,9 @@ func TestVectorSearchFloat32_Success(t *testing.T) { "field1": "value1", }, Generation: uint32(3), - Expiration: GetTimePtr(AerospikeEpoch.Add(time.Second * 3)), + Expiration: Ptr(AerospikeEpoch.Add(time.Second * 3)), }, - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Key: "key-3", Namespace: "testNamespace", Distance: float32(3), @@ -1154,7 +1154,7 @@ func TestVectorSearchFloat32_Success(t *testing.T) { vector := []float32{1.0, 2.0, 3.0} limit := uint32(7) searchParams := &protos.HnswSearchParams{ - Ef: GetUint32Ptr(8), + Ef: Ptr(uint32(8)), } neighbors, err := client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) @@ -1189,7 +1189,7 @@ func TestVectorSearchFloat32_FailsGettingConn(t *testing.T) { vector := []float32{1.0, 2.0, 3.0} limit := uint32(7) searchParams := &protos.HnswSearchParams{ - Ef: GetUint32Ptr(8), + Ef: Ptr(uint32(8)), } _, err = client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) @@ -1231,7 +1231,7 @@ func TestVectorSearchFloat32_FailsVectorSearch(t *testing.T) { vector := []float32{1.0, 2.0, 3.0} limit := uint32(7) searchParams := &protos.HnswSearchParams{ - Ef: GetUint32Ptr(8), + Ef: Ptr(uint32(8)), } _, err = client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) @@ -1284,7 +1284,7 @@ func TestVectorSearchFloat32_FailedToRecvAllNeighbors(t *testing.T) { vector := []float32{1.0, 2.0, 3.0} limit := uint32(7) searchParams := &protos.HnswSearchParams{ - Ef: GetUint32Ptr(8), + Ef: Ptr(uint32(8)), } _, err = client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) @@ -1343,7 +1343,7 @@ func TestVectorSearchFloat32_FailedToConvertNeighbor(t *testing.T) { vector := []float32{1.0, 2.0, 3.0} limit := uint32(7) searchParams := &protos.HnswSearchParams{ - Ef: GetUint32Ptr(8), + Ef: Ptr(uint32(8)), } _, err = client.VectorSearchFloat32(ctx, namespace, indexName, vector, uint32(limit), searchParams, nil, nil) @@ -1722,7 +1722,7 @@ func TestIndexUpdate_Success(t *testing.T) { }, Update: &protos.IndexUpdateRequest_HnswIndexUpdate{ HnswIndexUpdate: &protos.HnswIndexUpdate{ - MaxMemQueueSize: GetUint32Ptr(10), + MaxMemQueueSize: Ptr(uint32(10)), }, }, } @@ -1747,7 +1747,7 @@ func TestIndexUpdate_Success(t *testing.T) { "foo": "bar", } hnswParams := &protos.HnswIndexUpdate{ - MaxMemQueueSize: GetUint32Ptr(10), + MaxMemQueueSize: Ptr(uint32(10)), } err = client.IndexUpdate(ctx, testNamespace, testIndex, testMetadata, hnswParams) @@ -1782,7 +1782,7 @@ func TestIndexUpdate_FailGetConn(t *testing.T) { "foo": "bar", } hnswParams := &protos.HnswIndexUpdate{ - MaxMemQueueSize: GetUint32Ptr(10), + MaxMemQueueSize: Ptr(uint32(10)), } err = client.IndexUpdate(ctx, testNamespace, testIndex, testMetadata, hnswParams) @@ -1824,7 +1824,7 @@ func TestIndexUpdate_FailUpdateCall(t *testing.T) { "foo": "bar", } hnswParams := &protos.HnswIndexUpdate{ - MaxMemQueueSize: GetUint32Ptr(10), + MaxMemQueueSize: Ptr(uint32(10)), } err = client.IndexUpdate(ctx, testNamespace, testIndex, testMetadata, hnswParams) @@ -1967,7 +1967,7 @@ func TestIndexList_Success(t *testing.T) { Return(mockConn, nil) expectedIndexListRequest := &protos.IndexListRequest{ - ApplyDefaults: GetBoolPtr(true), + ApplyDefaults: Ptr(true), } expectedIndexDefs := &protos.IndexDefinitionList{ @@ -2090,7 +2090,7 @@ func TestIndexGet_Success(t *testing.T) { Namespace: "testNamespace", Name: "testIndex", }, - ApplyDefaults: GetBoolPtr(true), + ApplyDefaults: Ptr(true), } expectedIndexDefs := &protos.IndexDefinition{ diff --git a/connection_provider.go b/connection_provider.go index 982145f..b54e6fc 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -65,7 +65,7 @@ func newConnection(conn grpcClientConn) *connection { } func (conn *connection) close() error { - if conn.grpcConn != nil { + if conn != nil && conn.grpcConn != nil { return conn.grpcConn.Close() } @@ -104,8 +104,8 @@ type connectionProvider struct { token tokenManager stopTendChan chan struct{} closed atomic.Bool - grpcConnFactory func(hostPort *HostPort) (grpcClientConn, error) // For testing - connFactory func(conn grpcClientConn) *connection // For testing + grpcConnFactory func(hostPort *HostPort) (grpcClientConn, error) + connFactory func(conn grpcClientConn) *connection } // newConnectionProvider creates a new connectionProvider instance. @@ -362,7 +362,7 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { } extraCheck := true - conn := cp.connFactory(grpcConn) + conn := newConnection(grpcConn) if cp.token != nil { // Only one thread needs to refresh the token. Only first will @@ -374,7 +374,11 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { logger.WarnContext(ctx, "failed to refresh token", slog.Any("error", err)) authErr = err tokenLock.Unlock() - conn.close() + + err = conn.close() + if err != nil { + logger.WarnContext(ctx, "failed to close connection", slog.Any("error", err)) + } return } @@ -390,7 +394,11 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { about, err := conn.aboutClient.Get(ctx, &protos.AboutRequest{}) if err != nil { logger.WarnContext(ctx, "failed to connect to seed", slog.Any("error", err)) - grpcConn.Close() + + err = conn.close() + if err != nil { + logger.WarnContext(ctx, "failed to close connection", slog.Any("error", err)) + } return } @@ -537,16 +545,16 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6 for cluster := range newClusterChan { if largestNewCluster == nil || len(cluster.endpoints) > len(largestNewCluster.endpoints) { largestNewCluster = cluster - cp.logger.DebugContext( - ctx, - "largest cluster with new id", - slog.Any("endpoints", largestNewCluster.endpoints), - slog.Uint64("id", largestNewCluster.id), - ) } } if largestNewCluster != nil { + cp.logger.DebugContext( + ctx, + "largest cluster with new id", + slog.Any("endpoints", largestNewCluster.endpoints), + slog.Uint64("id", largestNewCluster.id), + ) cp.clusterID = largestNewCluster.id return largestNewCluster.endpoints } @@ -605,7 +613,7 @@ func (cp *connectionProvider) checkAndSetNodeConns( } // removeDownNodes removes the gRPC client connections for nodes in nodeConns -// that aren't apart of newNodeEndpoints +// that aren't a part of newNodeEndpoints func (cp *connectionProvider) removeDownNodes(newNodeEndpoints map[uint64]*protos.ServerEndpointList) { cp.nodeConnsLock.Lock() defer cp.nodeConnsLock.Unlock() diff --git a/connection_provider_test.go b/connection_provider_test.go index fbd7f06..637ef62 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -246,10 +246,8 @@ func TestConnectToSeeds_FailedToRefreshToken(t *testing.T) { grpcConnFactory: func(_ *HostPort) (grpcClientConn, error) { return nil, nil }, - connFactory: func(_ grpcClientConn) *connection { - return &connection{} - }, - token: mockToken, + connFactory: newConnection, + token: mockToken, } cp.seeds = HostPortSlice{ diff --git a/docker/.env b/docker/.env new file mode 100644 index 0000000..447d4f2 --- /dev/null +++ b/docker/.env @@ -0,0 +1 @@ +AVS_IMAGE=aerospike/aerospike-vector-search:0.10.0 \ No newline at end of file diff --git a/docker/auth/docker-compose.yml b/docker/auth/docker-compose.yml index 3a1c1a8..a3ae1a8 100644 --- a/docker/auth/docker-compose.yml +++ b/docker/auth/docker-compose.yml @@ -16,7 +16,7 @@ services: timeout: 20s retries: 20 avs: - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} depends_on: aerospike: condition: service_healthy diff --git a/docker/mtls/docker-compose.yml b/docker/mtls/docker-compose.yml index 3a1c1a8..a3ae1a8 100644 --- a/docker/mtls/docker-compose.yml +++ b/docker/mtls/docker-compose.yml @@ -16,7 +16,7 @@ services: timeout: 20s retries: 20 avs: - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} depends_on: aerospike: condition: service_healthy diff --git a/docker/multi-node-LB/docker-compose.yml b/docker/multi-node-LB/docker-compose.yml index 3e59264..89b0981 100644 --- a/docker/multi-node-LB/docker-compose.yml +++ b/docker/multi-node-LB/docker-compose.yml @@ -20,7 +20,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} volumes: - ./config/aerospike-vector-search-1.yml:/etc/aerospike-vector-search/aerospike-vector-search.yml - ./config/features.conf:/etc/aerospike-vector-search/features.conf @@ -35,7 +35,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} volumes: - ./config/aerospike-vector-search-2.yml:/etc/aerospike-vector-search/aerospike-vector-search.yml - ./config/features.conf:/etc/aerospike-vector-search/features.conf @@ -50,7 +50,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} volumes: - ./config/aerospike-vector-search-3.yml:/etc/aerospike-vector-search/aerospike-vector-search.yml - ./config/features.conf:/etc/aerospike-vector-search/features.conf diff --git a/docker/multi-node-client-visibility-err/docker-compose.yml b/docker/multi-node-client-visibility-err/docker-compose.yml index 01d12e4..c0dbcc3 100644 --- a/docker/multi-node-client-visibility-err/docker-compose.yml +++ b/docker/multi-node-client-visibility-err/docker-compose.yml @@ -20,7 +20,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} ports: - "10000:10000" volumes: @@ -37,7 +37,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} ports: - "10001:10001" volumes: @@ -54,7 +54,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} # ports: # - "10002:10002" # This causes the visibility err volumes: diff --git a/docker/multi-node/docker-compose.yml b/docker/multi-node/docker-compose.yml index 2478015..94157a0 100644 --- a/docker/multi-node/docker-compose.yml +++ b/docker/multi-node/docker-compose.yml @@ -20,7 +20,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} ports: - "10000:10000" volumes: @@ -37,7 +37,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} ports: - "10001:10001" volumes: @@ -54,7 +54,7 @@ services: depends_on: aerospike: condition: service_healthy - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} ports: - "10002:10002" volumes: diff --git a/docker/tls/docker-compose.yml b/docker/tls/docker-compose.yml index 3a1c1a8..a3ae1a8 100644 --- a/docker/tls/docker-compose.yml +++ b/docker/tls/docker-compose.yml @@ -16,7 +16,7 @@ services: timeout: 20s retries: 20 avs: - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} depends_on: aerospike: condition: service_healthy diff --git a/docker/vanilla/docker-compose.yml b/docker/vanilla/docker-compose.yml index 3a1c1a8..a3ae1a8 100644 --- a/docker/vanilla/docker-compose.yml +++ b/docker/vanilla/docker-compose.yml @@ -16,7 +16,7 @@ services: timeout: 20s retries: 20 avs: - image: aerospike/aerospike-vector-search:0.10.0 + image: ${AVS_IMAGE:-"AVS_IMAGE env not set"} depends_on: aerospike: condition: service_healthy diff --git a/integration_single_node_test.go b/integration_single_node_test.go index ad3e2ab..28fba80 100644 --- a/integration_single_node_test.go +++ b/integration_single_node_test.go @@ -283,8 +283,8 @@ func (suite *SingleNodeTestSuite) TestIndexCreate() { &IndexCreateOpts{ Sets: []string{"testset"}, Storage: &protos.IndexStorage{ - Namespace: GetStrPtr("storage-ns"), - Set: GetStrPtr("storage-set"), + Namespace: Ptr("storage-ns"), + Set: Ptr("storage-set"), }, Labels: map[string]string{ "a": "b", @@ -298,11 +298,11 @@ func (suite *SingleNodeTestSuite) TestIndexCreate() { Dimensions: uint32(10), VectorDistanceMetric: protos.VectorDistanceMetric_COSINE, Type: protos.IndexType_HNSW, - SetFilter: GetStrPtr("testset"), + SetFilter: Ptr("testset"), Field: "vector", Storage: &protos.IndexStorage{ - Namespace: GetStrPtr("storage-ns"), - Set: GetStrPtr("storage-set"), + Namespace: Ptr("storage-ns"), + Set: Ptr("storage-set"), }, Labels: map[string]string{ "a": "b", @@ -362,8 +362,8 @@ func (suite *SingleNodeTestSuite) TestIndexUpdate() { opts: &IndexCreateOpts{ Sets: []string{"testset"}, Storage: &protos.IndexStorage{ - Namespace: GetStrPtr("storage-ns"), - Set: GetStrPtr("storage-set"), + Namespace: Ptr("storage-ns"), + Set: Ptr("storage-set"), }, Labels: map[string]string{ "a": "b", @@ -379,11 +379,11 @@ func (suite *SingleNodeTestSuite) TestIndexUpdate() { Dimensions: uint32(10), VectorDistanceMetric: protos.VectorDistanceMetric_COSINE, Type: protos.IndexType_HNSW, - SetFilter: GetStrPtr("testset"), + SetFilter: Ptr("testset"), Field: "vector", Storage: &protos.IndexStorage{ - Namespace: GetStrPtr("storage-ns"), - Set: GetStrPtr("storage-set"), + Namespace: Ptr("storage-ns"), + Set: Ptr("storage-set"), }, Labels: map[string]string{ "a": "b", @@ -403,33 +403,33 @@ func (suite *SingleNodeTestSuite) TestIndexUpdate() { opts: &IndexCreateOpts{ Sets: []string{"testset"}, Storage: &protos.IndexStorage{ - Namespace: GetStrPtr("storage-ns"), - Set: GetStrPtr("storage-set"), + Namespace: Ptr("storage-ns"), + Set: Ptr("storage-set"), }, Labels: map[string]string{ "a": "b", }, }, updateHnsw: &protos.HnswIndexUpdate{ - MaxMemQueueSize: GetUint32Ptr(100), + MaxMemQueueSize: Ptr(uint32(100)), BatchingParams: &protos.HnswBatchingParams{ - MaxRecords: GetUint32Ptr(10_001), - Interval: GetUint32Ptr(10_002), + MaxRecords: Ptr(uint32(10_001)), + Interval: Ptr(uint32(10_002)), }, CachingParams: &protos.HnswCachingParams{ - MaxEntries: GetUint64Ptr(10_003), - Expiry: GetUint64Ptr(10_004), + MaxEntries: Ptr(uint64(10_003)), + Expiry: Ptr(uint64(10_004)), }, HealerParams: &protos.HnswHealerParams{ - MaxScanRatePerNode: GetUint32Ptr(10_005), - MaxScanPageSize: GetUint32Ptr(10_006), - ReindexPercent: GetFloat32Ptr(51), - Schedule: GetStrPtr("0 0 0 25 12 ?"), - Parallelism: GetUint32Ptr(1), + MaxScanRatePerNode: Ptr(uint32(10_005)), + MaxScanPageSize: Ptr(uint32(10_006)), + ReindexPercent: Ptr(float32(51)), + Schedule: Ptr("0 0 0 25 12 ?"), + Parallelism: Ptr(uint32(1)), }, MergeParams: &protos.HnswIndexMergeParams{ - IndexParallelism: GetUint32Ptr(2), - ReIndexParallelism: GetUint32Ptr(3), + IndexParallelism: Ptr(uint32(2)), + ReIndexParallelism: Ptr(uint32(3)), }, }, updateLabels: map[string]string{ @@ -443,11 +443,11 @@ func (suite *SingleNodeTestSuite) TestIndexUpdate() { Dimensions: uint32(10), VectorDistanceMetric: protos.VectorDistanceMetric_COSINE, Type: protos.IndexType_HNSW, - SetFilter: GetStrPtr("testset"), + SetFilter: Ptr("testset"), Field: "vector", Storage: &protos.IndexStorage{ - Namespace: GetStrPtr("storage-ns"), - Set: GetStrPtr("storage-set"), + Namespace: Ptr("storage-ns"), + Set: Ptr("storage-set"), }, Labels: map[string]string{ "a": "b", @@ -455,25 +455,25 @@ func (suite *SingleNodeTestSuite) TestIndexUpdate() { }, Params: &protos.IndexDefinition_HnswParams{ HnswParams: &protos.HnswParams{ - MaxMemQueueSize: GetUint32Ptr(100), + MaxMemQueueSize: Ptr(uint32(100)), BatchingParams: &protos.HnswBatchingParams{ - MaxRecords: GetUint32Ptr(10_001), - Interval: GetUint32Ptr(10_002), + MaxRecords: Ptr(uint32(10_001)), + Interval: Ptr(uint32(10_002)), }, CachingParams: &protos.HnswCachingParams{ - MaxEntries: GetUint64Ptr(10_003), - Expiry: GetUint64Ptr(10_004), + MaxEntries: Ptr(uint64(10_003)), + Expiry: Ptr(uint64(10_004)), }, HealerParams: &protos.HnswHealerParams{ - MaxScanRatePerNode: GetUint32Ptr(10_005), - MaxScanPageSize: GetUint32Ptr(10_006), - ReindexPercent: GetFloat32Ptr(51), - Schedule: GetStrPtr("0 0 0 25 12 ?"), - Parallelism: GetUint32Ptr(1), + MaxScanRatePerNode: Ptr(uint32(10_005)), + MaxScanPageSize: Ptr(uint32(10_006)), + ReindexPercent: Ptr(float32(51)), + Schedule: Ptr("0 0 0 25 12 ?"), + Parallelism: Ptr(uint32(1)), }, MergeParams: &protos.HnswIndexMergeParams{ - IndexParallelism: GetUint32Ptr(2), - ReIndexParallelism: GetUint32Ptr(3), + IndexParallelism: Ptr(uint32(2)), + ReIndexParallelism: Ptr(uint32(3)), }, }, }, @@ -1187,7 +1187,7 @@ func (suite *SingleNodeTestSuite) TestConnectedNodeEndpoint() { nodeId: &protos.NodeId{ Id: 1, }, - expectedErrMsg: GetStrPtr("failed to get connected endpoint"), + expectedErrMsg: Ptr("failed to get connected endpoint"), }, } @@ -1225,7 +1225,7 @@ func (suite *SingleNodeTestSuite) TestClusteringState() { nodeId: &protos.NodeId{ Id: 1, }, - expectedErrMsg: GetStrPtr("failed to get clustering state"), + expectedErrMsg: Ptr("failed to get clustering state"), }, } @@ -1274,7 +1274,7 @@ func (suite *SingleNodeTestSuite) TestClusterEndpoints() { nodeId: &protos.NodeId{ Id: 1, }, - expectedErrMsg: GetStrPtr("failed to get cluster endpoints"), + expectedErrMsg: Ptr("failed to get cluster endpoints"), }, } @@ -1323,7 +1323,7 @@ func (suite *SingleNodeTestSuite) TestAbout() { nodeId: &protos.NodeId{ Id: 1, }, - expectedErrMsg: GetStrPtr("failed to make about request"), + expectedErrMsg: Ptr("failed to make about request"), }, } diff --git a/makefile b/makefile index 028d918..e6aa9f5 100644 --- a/makefile +++ b/makefile @@ -45,7 +45,7 @@ mocks: get-mockgen test: unit integration .PHONY: integration -integration: $(GOLEAK) +integration: mkdir -p $(COV_INTEGRATION_DIR) || true go test -tags=integration -timeout 30m -cover ./... -args -test.gocoverdir=$(COV_INTEGRATION_DIR) diff --git a/protos/utils_test.go b/protos/utils_test.go index 55e19b2..7e97e65 100644 --- a/protos/utils_test.go +++ b/protos/utils_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" ) -func GetStrPtr(str string) *string { +func Ptr(str string) *string { ptr := str return &ptr } @@ -22,7 +22,7 @@ func TestConvertToKey(t *testing.T) { input: "testString", expected: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_StringValue{ StringValue: "testString", }, @@ -32,7 +32,7 @@ func TestConvertToKey(t *testing.T) { input: []byte{0x01, 0x02, 0x03}, expected: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_BytesValue{ BytesValue: []byte{0x01, 0x02, 0x03}, }, @@ -42,7 +42,7 @@ func TestConvertToKey(t *testing.T) { input: int32(123), expected: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_IntValue{ IntValue: 123, }, @@ -52,7 +52,7 @@ func TestConvertToKey(t *testing.T) { input: int64(123456789), expected: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_LongValue{ LongValue: 123456789, }, @@ -62,7 +62,7 @@ func TestConvertToKey(t *testing.T) { input: int(123456789), expected: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_LongValue{ LongValue: 123456789, }, @@ -76,7 +76,7 @@ func TestConvertToKey(t *testing.T) { } for _, tc := range testCases { - result, err := ConvertToKey("testNamespace", GetStrPtr("testSet"), tc.input) + result, err := ConvertToKey("testNamespace", Ptr("testSet"), tc.input) assert.Equal(t, tc.expected, result) assert.Equal(t, tc.expectedErr, err) @@ -98,59 +98,59 @@ func TestConvertFromKey(t *testing.T) { { input: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_StringValue{ StringValue: "testString", }, }, expectedNamespace: "testNamespace", - expectedSet: GetStrPtr("testSet"), + expectedSet: Ptr("testSet"), expectedKey: "testString", }, { input: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_BytesValue{ BytesValue: []byte{0x01, 0x02, 0x03}, }, }, expectedNamespace: "testNamespace", - expectedSet: GetStrPtr("testSet"), + expectedSet: Ptr("testSet"), expectedKey: []byte{0x01, 0x02, 0x03}, }, { input: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_IntValue{ IntValue: 123, }, }, expectedNamespace: "testNamespace", - expectedSet: GetStrPtr("testSet"), + expectedSet: Ptr("testSet"), expectedKey: int32(123), }, { input: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &Key_LongValue{ LongValue: 123456789, }, }, expectedNamespace: "testNamespace", - expectedSet: GetStrPtr("testSet"), + expectedSet: Ptr("testSet"), expectedKey: int64(123456789), }, { input: &Key{ Namespace: "testNamespace", - Set: GetStrPtr("testSet"), + Set: Ptr("testSet"), Value: &keyUnknown{}, }, expectedNamespace: "testNamespace", - expectedSet: GetStrPtr("testSet"), + expectedSet: Ptr("testSet"), expectedKey: nil, // Unsupported or nil input expectedErr: fmt.Errorf("unsupported key value type: *protos.keyUnknown"), }, @@ -614,12 +614,12 @@ func TestConvertToMapValue(t *testing.T) { { input: map[int]any{10: struct{}{}}, expected: nil, - expectedErr: GetStrPtr("unsupported map value: unsupported value type: struct {}"), + expectedErr: Ptr("unsupported map value: unsupported value type: struct {}"), }, { input: map[any]any{struct{}{}: 10}, expected: nil, - expectedErr: GetStrPtr("unsupported map key: unsupported key type: struct {}"), + expectedErr: Ptr("unsupported map key: unsupported key type: struct {}"), }, } @@ -694,7 +694,7 @@ func TestConvertToList(t *testing.T) { { input: []any{struct{}{}}, expected: nil, - expectedErr: GetStrPtr("unsupported list value: unsupported value type: struct {}"), + expectedErr: Ptr("unsupported list value: unsupported value type: struct {}"), }, } @@ -775,7 +775,7 @@ func TestConvertFromListValue(t *testing.T) { }, }, expected: nil, - expectedErr: GetStrPtr("unsupported list value: unsupported value type: *protos.valueUnknown"), + expectedErr: Ptr("unsupported list value: unsupported value type: *protos.valueUnknown"), }, } @@ -870,7 +870,7 @@ func TestConvertFromMapValue(t *testing.T) { }, }, expected: nilMap, - expectedErr: GetStrPtr("unsupported map key value type: *protos.mapKeyValueUnknown"), + expectedErr: Ptr("unsupported map key value type: *protos.mapKeyValueUnknown"), }, { input: &Map{ @@ -888,7 +888,7 @@ func TestConvertFromMapValue(t *testing.T) { }, }, expected: nilMap, - expectedErr: GetStrPtr("unsupported map value: unsupported value type: *protos.valueUnknown"), + expectedErr: Ptr("unsupported map value: unsupported value type: *protos.valueUnknown"), }, } diff --git a/testutils.go b/testutils.go index 5c2cd7d..8a4f6b9 100644 --- a/testutils.go +++ b/testutils.go @@ -72,32 +72,8 @@ func (suite *ServerTestBaseSuite) TearDownSuite() { goleak.VerifyNone(suite.T()) } -func GetStrPtr(str string) *string { - ptr := str - return &ptr -} - -func GetUint32Ptr(i int) *uint32 { - ptr := uint32(i) - return &ptr -} - -func GetUint64Ptr(i int) *uint64 { - ptr := uint64(i) - return &ptr -} - -func GetFloat32Ptr(i float32) *float32 { - ptr := float32(i) - return &ptr -} - -func GetBoolPtr(b bool) *bool { - return &b -} - -func GetTimePtr(t time.Time) *time.Time { - return &t +func Ptr[T any](value T) *T { + return &value } func CreateFlagStr(name, value string) string { @@ -253,12 +229,12 @@ func (idb *IndexDefinitionBuilder) Build() *protos.IndexDefinition { }, Params: &protos.IndexDefinition_HnswParams{ HnswParams: &protos.HnswParams{ - M: GetUint32Ptr(16), - EfConstruction: GetUint32Ptr(100), - Ef: GetUint32Ptr(100), + M: Ptr(uint32(16)), + EfConstruction: Ptr(uint32(100)), + Ef: Ptr(uint32(100)), BatchingParams: &protos.HnswBatchingParams{ - MaxRecords: GetUint32Ptr(100000), - Interval: GetUint32Ptr(30000), + MaxRecords: Ptr(uint32(100000)), + Interval: Ptr(uint32(30000)), }, CachingParams: &protos.HnswCachingParams{}, HealerParams: &protos.HnswHealerParams{}, @@ -349,7 +325,7 @@ func DockerComposeUp(composeFile string) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() - cmd := exec.CommandContext(ctx, "docker", "-lDEBUG", "compose", fmt.Sprintf("-f%s", composeFile), "up", "-d") + cmd := exec.CommandContext(ctx, "docker", "-lDEBUG", "compose", fmt.Sprintf("-f%s", composeFile), "--env-file", "docker/.env", "up", "-d") err := cmd.Run() cmd.Wait() From afec4abffb24d46c1f460098402e007fd6b3e293 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 26 Sep 2024 19:47:47 -0700 Subject: [PATCH 39/42] fix lint --- connection_provider.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/connection_provider.go b/connection_provider.go index b54e6fc..2989e71 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -555,7 +555,9 @@ func (cp *connectionProvider) getUpdatedEndpoints(ctx context.Context) map[uint6 slog.Any("endpoints", largestNewCluster.endpoints), slog.Uint64("id", largestNewCluster.id), ) + cp.clusterID = largestNewCluster.id + return largestNewCluster.endpoints } From 6dc0ad6dc6078038f630d5933b60af8bdd2b2682 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Fri, 27 Sep 2024 09:19:09 -0700 Subject: [PATCH 40/42] fix tests, combine factory methods --- connection_provider.go | 81 +++++++++++++--------------- connection_provider_test.go | 102 +++++++++++++++++++++++------------- 2 files changed, 101 insertions(+), 82 deletions(-) diff --git a/connection_provider.go b/connection_provider.go index 2989e71..64aab1b 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -91,21 +91,21 @@ func newConnAndEndpoints(conn *connection, endpoints *protos.ServerEndpointList) // //nolint:govet // We will favor readability over field alignment type connectionProvider struct { - logger *slog.Logger - nodeConns map[uint64]*connectionAndEndpoints - seedConns []*connection - tlsConfig *tls.Config - seeds HostPortSlice - nodeConnsLock *sync.RWMutex - tendInterval time.Duration - clusterID uint64 - listenerName *string - isLoadBalancer bool - token tokenManager - stopTendChan chan struct{} - closed atomic.Bool - grpcConnFactory func(hostPort *HostPort) (grpcClientConn, error) - connFactory func(conn grpcClientConn) *connection + logger *slog.Logger + nodeConns map[uint64]*connectionAndEndpoints + seedConns []*connection + tlsConfig *tls.Config + seeds HostPortSlice + nodeConnsLock *sync.RWMutex + tendInterval time.Duration + clusterID uint64 + listenerName *string + isLoadBalancer bool + token tokenManager + stopTendChan chan struct{} + closed atomic.Bool + // grpcConnFactory func(hostPort *HostPort) (grpcClientConn, error) + connFactory func(hostPort *HostPort) (*connection, error) } // newConnectionProvider creates a new connectionProvider instance. @@ -157,9 +157,13 @@ func newConnectionProvider( closed: atomic.Bool{}, } - cp.connFactory = newConnection - cp.grpcConnFactory = func(hostPort *HostPort) (grpcClientConn, error) { - return createGrcpConn(cp, hostPort) + cp.connFactory = func(hostPort *HostPort) (*connection, error) { + grpcConn, err := createGrcpConn(cp, hostPort) + if err != nil { + return nil, err + } + + return newConnection(grpcConn), nil } // Connect to the seed nodes. @@ -354,16 +358,14 @@ func (cp *connectionProvider) connectToSeeds(ctx context.Context) error { defer wg.Done() logger := cp.logger.With(slog.String("host", seed.String())) + extraCheck := true - grpcConn, err := cp.grpcConnFactory(seed) + conn, err := cp.connFactory(seed) if err != nil { logger.ErrorContext(ctx, "failed to create connection", slog.Any("error", err)) return } - extraCheck := true - conn := newConnection(grpcConn) - if cp.token != nil { // Only one thread needs to refresh the token. Only first will // succeed others will block @@ -678,26 +680,6 @@ func (cp *connectionProvider) tend(ctx context.Context) { } } -// createGrpcConnFromEndpoints creates a gRPC client connection from the first -// successful endpoint in endpoints. -func (cp *connectionProvider) createGrpcConnFromEndpoints( - endpoints *protos.ServerEndpointList, -) (grpcClientConn, error) { - for _, endpoint := range endpoints.Endpoints { - if strings.ContainsRune(endpoint.Address, ':') { - continue // TODO: Add logging and support for IPv6 - } - - conn, err := cp.grpcConnFactory(endpointToHostPort(endpoint)) - - if err == nil { - return conn, nil - } - } - - return nil, errors.New("no valid endpoint found") -} - // createGrcpConn creates a gRPC client connection to a host. This handles adding // credential and configuring tls. func createGrcpConn(cp *connectionProvider, hostPort *HostPort) (grpcClientConn, error) { @@ -732,10 +714,17 @@ func createGrcpConn(cp *connectionProvider, hostPort *HostPort) (grpcClientConn, } func (cp *connectionProvider) createConnFromEndpoints(endpoints *protos.ServerEndpointList) (*connection, error) { - conn, err := cp.createGrpcConnFromEndpoints(endpoints) - if err != nil { - return nil, err + for _, endpoint := range endpoints.Endpoints { + if strings.ContainsRune(endpoint.Address, ':') { + continue // TODO: Add logging and support for IPv6 + } + + conn, err := cp.connFactory(endpointToHostPort(endpoint)) + + if err == nil { + return conn, nil + } } - return cp.connFactory(conn), nil + return nil, errors.New("no valid endpoint found") } diff --git a/connection_provider_test.go b/connection_provider_test.go index 637ef62..3540234 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -212,7 +212,7 @@ func TestConnectToSeeds_FailedFailedToCreateConnection(t *testing.T) { isLoadBalancer: true, closed: atomic.Bool{}, logger: slog.Default(), - grpcConnFactory: func(_ *HostPort) (grpcClientConn, error) { + connFactory: func(_ *HostPort) (*connection, error) { return nil, fmt.Errorf("foo") }, } @@ -243,11 +243,11 @@ func TestConnectToSeeds_FailedToRefreshToken(t *testing.T) { isLoadBalancer: true, closed: atomic.Bool{}, logger: slog.Default(), - grpcConnFactory: func(_ *HostPort) (grpcClientConn, error) { + connFactory: func(_ *HostPort) (*connection, error) { return nil, nil }, - connFactory: newConnection, - token: mockToken, + // connFactory: newConnection, + token: mockToken, } cp.seeds = HostPortSlice{ @@ -349,8 +349,8 @@ func TestUpdateClusterConns_NewClusterIDWithDIFFERENTNodeIDs(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - mockNewGrpcConn1111 := NewMockgrpcClientConn(ctrl) - mockNewGrpcConn2222 := NewMockgrpcClientConn(ctrl) + // mockNewGrpcConn1111 := NewMockgrpcClientConn(ctrl) + // mockNewGrpcConn2222 := NewMockgrpcClientConn(ctrl) mockClusterInfoClient1111 := protos.NewMockClusterInfoServiceClient(ctrl) mockClusterInfoClient2222 := protos.NewMockClusterInfoServiceClient(ctrl) @@ -382,30 +382,45 @@ func TestUpdateClusterConns_NewClusterIDWithDIFFERENTNodeIDs(t *testing.T) { token: nil, stopTendChan: make(chan struct{}), closed: atomic.Bool{}, - grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + connFactory: func(hostPort *HostPort) (*connection, error) { if hostPort.String() == "1.1.1.1:3000" { - return mockNewGrpcConn1111, nil - } else if hostPort.String() == "2.2.2.2:3000" { - return mockNewGrpcConn2222, nil - } - - return nil, fmt.Errorf("foo") - }, - connFactory: func(grpcConn grpcClientConn) *connection { - if grpcConn == mockNewGrpcConn1111 { return &connection{ clusterInfoClient: mockClusterInfoClient1111, aboutClient: mockAboutClient1111, - } - } else if grpcConn == mockNewGrpcConn2222 { + }, nil + } else if hostPort.String() == "2.2.2.2:3000" { return &connection{ clusterInfoClient: mockClusterInfoClient2222, aboutClient: mockAboutClient2222, - } + }, nil } - return nil + return nil, fmt.Errorf("foo") }, + // grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + // if hostPort.String() == "1.1.1.1:3000" { + // return mockNewGrpcConn1111, nil + // } else if hostPort.String() == "2.2.2.2:3000" { + // return mockNewGrpcConn2222, nil + // } + + // return nil, fmt.Errorf("foo") + // }, + // connFactory: func(grpcConn grpcClientConn) *connection { + // if grpcConn == mockNewGrpcConn1111 { + // return &connection{ + // clusterInfoClient: mockClusterInfoClient1111, + // aboutClient: mockAboutClient1111, + // } + // } else if grpcConn == mockNewGrpcConn2222 { + // return &connection{ + // clusterInfoClient: mockClusterInfoClient2222, + // aboutClient: mockAboutClient2222, + // } + // } + + // return nil + // }, } cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NewClusterID")) @@ -599,8 +614,8 @@ func TestUpdateClusterConns_NewClusterIDWithSAMENodeIDs(t *testing.T) { }, }, nil) - mockNewGrpcConn1111 := NewMockgrpcClientConn(ctrl) - mockNewGrpcConn2222 := NewMockgrpcClientConn(ctrl) + // mockNewGrpcConn1111 := NewMockgrpcClientConn(ctrl) + // mockNewGrpcConn2222 := NewMockgrpcClientConn(ctrl) mockClusterInfoClient1111 := protos.NewMockClusterInfoServiceClient(ctrl) mockClusterInfoClient2222 := protos.NewMockClusterInfoServiceClient(ctrl) @@ -631,30 +646,45 @@ func TestUpdateClusterConns_NewClusterIDWithSAMENodeIDs(t *testing.T) { token: nil, stopTendChan: make(chan struct{}), closed: atomic.Bool{}, - grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + connFactory: func(hostPort *HostPort) (*connection, error) { if hostPort.String() == "1.1.1.1:3000" { - return mockNewGrpcConn1111, nil - } else if hostPort.String() == "2.2.2.2:3000" { - return mockNewGrpcConn2222, nil - } - - return nil, fmt.Errorf("foo") - }, - connFactory: func(grpcConn grpcClientConn) *connection { - if grpcConn == mockNewGrpcConn1111 { return &connection{ clusterInfoClient: mockClusterInfoClient1111, aboutClient: mockAboutClient1111, - } - } else if grpcConn == mockNewGrpcConn2222 { + }, nil + } else if hostPort.String() == "2.2.2.2:3000" { return &connection{ clusterInfoClient: mockClusterInfoClient2222, aboutClient: mockAboutClient2222, - } + }, nil } - return nil + return nil, fmt.Errorf("foo") }, + // grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { + // if hostPort.String() == "1.1.1.1:3000" { + // return mockNewGrpcConn1111, nil + // } else if hostPort.String() == "2.2.2.2:3000" { + // return mockNewGrpcConn2222, nil + // } + + // return nil, fmt.Errorf("foo") + // }, + // connFactory: func(grpcConn grpcClientConn) *connection { + // if grpcConn == mockNewGrpcConn1111 { + // return &connection{ + // clusterInfoClient: mockClusterInfoClient1111, + // aboutClient: mockAboutClient1111, + // } + // } else if grpcConn == mockNewGrpcConn2222 { + // return &connection{ + // clusterInfoClient: mockClusterInfoClient2222, + // aboutClient: mockAboutClient2222, + // } + // } + + // return nil + // }, // Existing node connections. These will be replaced after a new cluster is found. nodeConns: map[uint64]*connectionAndEndpoints{ 1: { From 4a7c3a9fde6796d75f407b36bc44c8ed56952cdb Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Fri, 27 Sep 2024 09:23:02 -0700 Subject: [PATCH 41/42] fix lint --- connection_provider_test.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/connection_provider_test.go b/connection_provider_test.go index 3540234..168fba4 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -349,9 +349,6 @@ func TestUpdateClusterConns_NewClusterIDWithDIFFERENTNodeIDs(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - // mockNewGrpcConn1111 := NewMockgrpcClientConn(ctrl) - // mockNewGrpcConn2222 := NewMockgrpcClientConn(ctrl) - mockClusterInfoClient1111 := protos.NewMockClusterInfoServiceClient(ctrl) mockClusterInfoClient2222 := protos.NewMockClusterInfoServiceClient(ctrl) @@ -614,9 +611,6 @@ func TestUpdateClusterConns_NewClusterIDWithSAMENodeIDs(t *testing.T) { }, }, nil) - // mockNewGrpcConn1111 := NewMockgrpcClientConn(ctrl) - // mockNewGrpcConn2222 := NewMockgrpcClientConn(ctrl) - mockClusterInfoClient1111 := protos.NewMockClusterInfoServiceClient(ctrl) mockClusterInfoClient2222 := protos.NewMockClusterInfoServiceClient(ctrl) From 28bcc993d4e5f96a79181b0706d5f5027fef4ae3 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Fri, 27 Sep 2024 09:38:13 -0700 Subject: [PATCH 42/42] rm comments --- .golangci.yml | 2 +- connection_provider.go | 3 +-- connection_provider_test.go | 50 +------------------------------------ 3 files changed, 3 insertions(+), 52 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 3538c5c..bc7bc30 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -80,7 +80,7 @@ issues: - goconst # Test code is allowed to have constants - path: 'connection_provider_test.go' linters: - - goconst # Test cod + - goconst # - path: dir/sample\.go # linters: # - lll # Test code is allowed to have long lines diff --git a/connection_provider.go b/connection_provider.go index 64aab1b..cdca0ad 100644 --- a/connection_provider.go +++ b/connection_provider.go @@ -104,8 +104,7 @@ type connectionProvider struct { token tokenManager stopTendChan chan struct{} closed atomic.Bool - // grpcConnFactory func(hostPort *HostPort) (grpcClientConn, error) - connFactory func(hostPort *HostPort) (*connection, error) + connFactory func(hostPort *HostPort) (*connection, error) } // newConnectionProvider creates a new connectionProvider instance. diff --git a/connection_provider_test.go b/connection_provider_test.go index 168fba4..86e58d4 100644 --- a/connection_provider_test.go +++ b/connection_provider_test.go @@ -246,7 +246,6 @@ func TestConnectToSeeds_FailedToRefreshToken(t *testing.T) { connFactory: func(_ *HostPort) (*connection, error) { return nil, nil }, - // connFactory: newConnection, token: mockToken, } @@ -394,30 +393,6 @@ func TestUpdateClusterConns_NewClusterIDWithDIFFERENTNodeIDs(t *testing.T) { return nil, fmt.Errorf("foo") }, - // grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { - // if hostPort.String() == "1.1.1.1:3000" { - // return mockNewGrpcConn1111, nil - // } else if hostPort.String() == "2.2.2.2:3000" { - // return mockNewGrpcConn2222, nil - // } - - // return nil, fmt.Errorf("foo") - // }, - // connFactory: func(grpcConn grpcClientConn) *connection { - // if grpcConn == mockNewGrpcConn1111 { - // return &connection{ - // clusterInfoClient: mockClusterInfoClient1111, - // aboutClient: mockAboutClient1111, - // } - // } else if grpcConn == mockNewGrpcConn2222 { - // return &connection{ - // clusterInfoClient: mockClusterInfoClient2222, - // aboutClient: mockAboutClient2222, - // } - // } - - // return nil - // }, } cp.logger = cp.logger.With(slog.String("test", "TestUpdateClusterConns_NewClusterID")) @@ -655,30 +630,7 @@ func TestUpdateClusterConns_NewClusterIDWithSAMENodeIDs(t *testing.T) { return nil, fmt.Errorf("foo") }, - // grpcConnFactory: func(hostPort *HostPort) (grpcClientConn, error) { - // if hostPort.String() == "1.1.1.1:3000" { - // return mockNewGrpcConn1111, nil - // } else if hostPort.String() == "2.2.2.2:3000" { - // return mockNewGrpcConn2222, nil - // } - - // return nil, fmt.Errorf("foo") - // }, - // connFactory: func(grpcConn grpcClientConn) *connection { - // if grpcConn == mockNewGrpcConn1111 { - // return &connection{ - // clusterInfoClient: mockClusterInfoClient1111, - // aboutClient: mockAboutClient1111, - // } - // } else if grpcConn == mockNewGrpcConn2222 { - // return &connection{ - // clusterInfoClient: mockClusterInfoClient2222, - // aboutClient: mockAboutClient2222, - // } - // } - - // return nil - // }, + // Existing node connections. These will be replaced after a new cluster is found. nodeConns: map[uint64]*connectionAndEndpoints{ 1: {