diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 970a48ff2cc4..ef832ed8cbf0 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -19,6 +19,9 @@ jobs: - name: Run coverage run: go test -coverprofile=coverage.out -coverpkg=./... ./... + - name: Run coverage with new pickfirst + run: GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST=true go test -coverprofile=coverage_new_pickfirst.out -coverpkg=./... ./... + - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 with: diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 5a2ad60776eb..a6576a21fa15 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -70,6 +70,11 @@ jobs: - type: tests goversion: '1.21' + - type: tests + goversion: '1.22' + testflags: -race + grpcenv: 'GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST=true' + steps: # Setup the environment. - name: Setup GOARCH diff --git a/balancer/pickfirst/pickfirst.go b/balancer/pickfirst/pickfirst.go index 3e792b2b366f..e069346a7565 100644 --- a/balancer/pickfirst/pickfirst.go +++ b/balancer/pickfirst/pickfirst.go @@ -29,13 +29,19 @@ import ( "google.golang.org/grpc/balancer/pickfirst/internal" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/envconfig" internalgrpclog "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" + + _ "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" // For automatically registering the new pickfirst if required. ) func init() { + if envconfig.NewPickFirstEnabled { + return + } balancer.Register(pickfirstBuilder{}) } diff --git a/balancer/pickfirst/pickfirst_test.go b/balancer/pickfirst/pickfirst_test.go new file mode 100644 index 000000000000..43d8b20df3e7 --- /dev/null +++ b/balancer/pickfirst/pickfirst_test.go @@ -0,0 +1,132 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package pickfirst + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/resolver" +) + +const ( + // Default timeout for tests in this package. + defaultTestTimeout = 10 * time.Second + // Default short timeout, to be used when waiting for events which are not + // expected to happen. + defaultTestShortTimeout = 100 * time.Millisecond +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +// TestPickFirstLeaf_InitialResolverError sends a resolver error to the balancer +// before a valid resolver update. It verifies that the clientconn state is +// updated to TRANSIENT_FAILURE. 
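+// Note that the balancer under test is built directly from pickfirstBuilder
+// rather than looked up in the balancer registry, so this test exercises the
+// policy in this package regardless of whether
+// GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST has swapped the registered
+// "pick_first" implementation to the pickfirstleaf package.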
+func (s) TestPickFirstLeaf_InitialResolverError(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc := testutils.NewBalancerClientConn(t) + bal := pickfirstBuilder{}.Build(cc, balancer.BuildOptions{}) + defer bal.Close() + bal.ResolverError(errors.New("resolution failed: test error")) + + if err := cc.WaitForConnectivityState(ctx, connectivity.TransientFailure); err != nil { + t.Fatalf("cc.WaitForConnectivityState(%v) returned error: %v", connectivity.TransientFailure, err) + } + + // After sending a valid update, the LB policy should report CONNECTING. + ccState := balancer.ClientConnState{ + ResolverState: resolver.State{ + Endpoints: []resolver.Endpoint{ + {Addresses: []resolver.Address{{Addr: "1.1.1.1:1"}}}, + {Addresses: []resolver.Address{{Addr: "2.2.2.2:2"}}}, + }, + }, + } + if err := bal.UpdateClientConnState(ccState); err != nil { + t.Fatalf("UpdateClientConnState(%v) returned error: %v", ccState, err) + } + + if err := cc.WaitForConnectivityState(ctx, connectivity.Connecting); err != nil { + t.Fatalf("cc.WaitForConnectivityState(%v) returned error: %v", connectivity.Connecting, err) + } +} + +// TestPickFirstLeaf_ResolverErrorinTF sends a resolver error to the balancer +// before when it's attempting to connect to a SubConn TRANSIENT_FAILURE. It +// verifies that the picker is updated and the SubConn is not closed. +func (s) TestPickFirstLeaf_ResolverErrorinTF(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc := testutils.NewBalancerClientConn(t) + bal := pickfirstBuilder{}.Build(cc, balancer.BuildOptions{}) + defer bal.Close() + + // After sending a valid update, the LB policy should report CONNECTING. + ccState := balancer.ClientConnState{ + ResolverState: resolver.State{ + Endpoints: []resolver.Endpoint{ + {Addresses: []resolver.Address{{Addr: "1.1.1.1:1"}}}, + }, + }, + } + + if err := bal.UpdateClientConnState(ccState); err != nil { + t.Fatalf("UpdateClientConnState(%v) returned error: %v", ccState, err) + } + + sc1 := <-cc.NewSubConnCh + if err := cc.WaitForConnectivityState(ctx, connectivity.Connecting); err != nil { + t.Fatalf("cc.WaitForConnectivityState(%v) returned error: %v", connectivity.Connecting, err) + } + + scErr := fmt.Errorf("test error: connection refused") + sc1.UpdateState(balancer.SubConnState{ + ConnectivityState: connectivity.TransientFailure, + ConnectionError: scErr, + }) + + if err := cc.WaitForPickerWithErr(ctx, scErr); err != nil { + t.Fatalf("cc.WaitForPickerWithErr(%v) returned error: %v", scErr, err) + } + + bal.ResolverError(errors.New("resolution failed: test error")) + if err := cc.WaitForErrPicker(ctx); err != nil { + t.Fatalf("cc.WaitForPickerWithErr() returned error: %v", err) + } + + select { + case <-time.After(defaultTestShortTimeout): + case sc := <-cc.ShutdownSubConnCh: + t.Fatalf("Unexpected SubConn shutdown: %v", sc) + } +} diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go new file mode 100644 index 000000000000..48ce8c50e5c1 --- /dev/null +++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go @@ -0,0 +1,624 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package pickfirstleaf contains the pick_first load balancing policy which +// will be the universal leaf policy after dualstack changes are implemented. +// +// # Experimental +// +// Notice: This package is EXPERIMENTAL and may be changed or removed in a +// later release. +package pickfirstleaf + +import ( + "encoding/json" + "errors" + "fmt" + "sync" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/pickfirst/internal" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/envconfig" + internalgrpclog "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" +) + +func init() { + if envconfig.NewPickFirstEnabled { + // Register as the default pick_first balancer. + Name = "pick_first" + } + balancer.Register(pickfirstBuilder{}) +} + +var ( + logger = grpclog.Component("pick-first-leaf-lb") + // Name is the name of the pick_first_leaf balancer. + // It is changed to "pick_first" in init() if this balancer is to be + // registered as the default pickfirst. + Name = "pick_first_leaf" +) + +// TODO: change to pick-first when this becomes the default pick_first policy. +const logPrefix = "[pick-first-leaf-lb %p] " + +type pickfirstBuilder struct{} + +func (pickfirstBuilder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer { + b := &pickfirstBalancer{ + cc: cc, + addressList: addressList{}, + subConns: resolver.NewAddressMap(), + state: connectivity.Connecting, + mu: sync.Mutex{}, + } + b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b)) + return b +} + +func (b pickfirstBuilder) Name() string { + return Name +} + +func (pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + var cfg pfConfig + if err := json.Unmarshal(js, &cfg); err != nil { + return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err) + } + return cfg, nil +} + +type pfConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` + + // If set to true, instructs the LB policy to shuffle the order of the list + // of endpoints received from the name resolver before attempting to + // connect to them. + ShuffleAddressList bool `json:"shuffleAddressList"` +} + +// scData keeps track of the current state of the subConn. +// It is not safe for concurrent access. +type scData struct { + // The following fields are initialized at build time and read-only after + // that. 
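+	// ("Build time" here refers to construction in newSCData.) The state and
+	// lastErr fields below are mutated as the SubConn reports state changes
+	// and are only accessed while holding the balancer's mutex.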
+ subConn balancer.SubConn + addr resolver.Address + + state connectivity.State + lastErr error +} + +func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) { + sd := &scData{ + state: connectivity.Idle, + addr: addr, + } + sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{ + StateListener: func(state balancer.SubConnState) { + b.updateSubConnState(sd, state) + }, + }) + if err != nil { + return nil, err + } + sd.subConn = sc + return sd, nil +} + +type pickfirstBalancer struct { + // The following fields are initialized at build time and read-only after + // that and therefore do not need to be guarded by a mutex. + logger *internalgrpclog.PrefixLogger + cc balancer.ClientConn + + // The mutex is used to ensure synchronization of updates triggered + // from the idle picker and the already serialized resolver, + // SubConn state updates. + mu sync.Mutex + state connectivity.State + // scData for active subonns mapped by address. + subConns *resolver.AddressMap + addressList addressList + firstPass bool + numTF int +} + +// ResolverError is called by the ClientConn when the name resolver produces +// an error or when pickfirst determined the resolver update to be invalid. +func (b *pickfirstBalancer) ResolverError(err error) { + b.mu.Lock() + defer b.mu.Unlock() + b.resolverErrorLocked(err) +} + +func (b *pickfirstBalancer) resolverErrorLocked(err error) { + if b.logger.V(2) { + b.logger.Infof("Received error from the name resolver: %v", err) + } + + // The picker will not change since the balancer does not currently + // report an error. If the balancer hasn't received a single good resolver + // update yet, transition to TRANSIENT_FAILURE. + if b.state != connectivity.TransientFailure && b.addressList.size() > 0 { + if b.logger.V(2) { + b.logger.Infof("Ignoring resolver error because balancer is using a previous good update.") + } + return + } + + b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: &picker{err: fmt.Errorf("name resolver error: %v", err)}, + }) +} + +func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error { + b.mu.Lock() + defer b.mu.Unlock() + if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 { + // Cleanup state pertaining to the previous resolver state. + // Treat an empty address list like an error by calling b.ResolverError. + b.state = connectivity.TransientFailure + b.closeSubConnsLocked() + b.addressList.updateAddrs(nil) + b.resolverErrorLocked(errors.New("produced zero addresses")) + return balancer.ErrBadResolverState + } + cfg, ok := state.BalancerConfig.(pfConfig) + if state.BalancerConfig != nil && !ok { + return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v: %w", state.BalancerConfig, state.BalancerConfig, balancer.ErrBadResolverState) + } + + if b.logger.V(2) { + b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState)) + } + + var newAddrs []resolver.Address + if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 { + // Perform the optional shuffling described in gRFC A62. The shuffling + // will change the order of endpoints but not touch the order of the + // addresses within each endpoint. - A61 + if cfg.ShuffleAddressList { + endpoints = append([]resolver.Endpoint{}, endpoints...) 
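+			// The slice is copied before shuffling so that the endpoint list
+			// owned by the resolver state is not mutated.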
+ internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] }) + } + + // "Flatten the list by concatenating the ordered list of addresses for + // each of the endpoints, in order." - A61 + for _, endpoint := range endpoints { + // "In the flattened list, interleave addresses from the two address + // families, as per RFC-8305 section 4." - A61 + // TODO: support the above language. + newAddrs = append(newAddrs, endpoint.Addresses...) + } + } else { + // Endpoints not set, process addresses until we migrate resolver + // emissions fully to Endpoints. The top channel does wrap emitted + // addresses with endpoints, however some balancers such as weighted + // target do not forward the corresponding correct endpoints down/split + // endpoints properly. Once all balancers correctly forward endpoints + // down, can delete this else conditional. + newAddrs = state.ResolverState.Addresses + if cfg.ShuffleAddressList { + newAddrs = append([]resolver.Address{}, newAddrs...) + internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] }) + } + } + + // If an address appears in multiple endpoints or in the same endpoint + // multiple times, we keep it only once. We will create only one SubConn + // for the address because an AddressMap is used to store SubConns. + // Not de-duplicating would result in attempting to connect to the same + // SubConn multiple times in the same pass. We don't want this. + newAddrs = deDupAddresses(newAddrs) + + // Since we have a new set of addresses, we are again at first pass. + b.firstPass = true + + // If the previous ready SubConn exists in new address list, + // keep this connection and don't create new SubConns. + prevAddr := b.addressList.currentAddress() + prevAddrsCount := b.addressList.size() + b.addressList.updateAddrs(newAddrs) + if b.state == connectivity.Ready && b.addressList.seekTo(prevAddr) { + return nil + } + + b.reconcileSubConnsLocked(newAddrs) + // If it's the first resolver update or the balancer was already READY + // (but the new address list does not contain the ready SubConn) or + // CONNECTING, enter CONNECTING. + // We may be in TRANSIENT_FAILURE due to a previous empty address list, + // we should still enter CONNECTING because the sticky TF behaviour + // mentioned in A62 applies only when the TRANSIENT_FAILURE is reported + // due to connectivity failures. + if b.state == connectivity.Ready || b.state == connectivity.Connecting || prevAddrsCount == 0 { + // Start connection attempt at first address. + b.state = connectivity.Connecting + b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.Connecting, + Picker: &picker{err: balancer.ErrNoSubConnAvailable}, + }) + b.requestConnectionLocked() + } else if b.state == connectivity.TransientFailure { + // If we're in TRANSIENT_FAILURE, we stay in TRANSIENT_FAILURE until + // we're READY. See A62. + b.requestConnectionLocked() + } + return nil +} + +// UpdateSubConnState is unused as a StateListener is always registered when +// creating SubConns. +func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state balancer.SubConnState) { + b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", subConn, state) +} + +func (b *pickfirstBalancer) Close() { + b.mu.Lock() + defer b.mu.Unlock() + b.closeSubConnsLocked() + b.state = connectivity.Shutdown +} + +// ExitIdle moves the balancer out of idle state. 
It can be called concurrently +// by the idlePicker and clientConn so access to variables should be +// synchronized. +func (b *pickfirstBalancer) ExitIdle() { + b.mu.Lock() + defer b.mu.Unlock() + if b.state == connectivity.Idle && b.addressList.currentAddress() == b.addressList.first() { + b.firstPass = true + b.requestConnectionLocked() + } +} + +func (b *pickfirstBalancer) closeSubConnsLocked() { + for _, sd := range b.subConns.Values() { + sd.(*scData).subConn.Shutdown() + } + b.subConns = resolver.NewAddressMap() +} + +// deDupAddresses ensures that each address appears only once in the slice. +func deDupAddresses(addrs []resolver.Address) []resolver.Address { + seenAddrs := resolver.NewAddressMap() + retAddrs := []resolver.Address{} + + for _, addr := range addrs { + if _, ok := seenAddrs.Get(addr); ok { + continue + } + retAddrs = append(retAddrs, addr) + } + return retAddrs +} + +func (b *pickfirstBalancer) reconcileSubConnsLocked(newAddrs []resolver.Address) { + // Remove old subConns that were not in new address list. + oldAddrsMap := resolver.NewAddressMap() + for _, k := range b.subConns.Keys() { + oldAddrsMap.Set(k, true) + } + + // Flatten the new endpoint addresses. + newAddrsMap := resolver.NewAddressMap() + for _, addr := range newAddrs { + newAddrsMap.Set(addr, true) + } + + // Shut them down and remove them. + for _, oldAddr := range oldAddrsMap.Keys() { + if _, ok := newAddrsMap.Get(oldAddr); ok { + continue + } + val, _ := b.subConns.Get(oldAddr) + val.(*scData).subConn.Shutdown() + b.subConns.Delete(oldAddr) + } +} + +// shutdownRemainingLocked shuts down remaining subConns. Called when a subConn +// becomes ready, which means that all other subConn must be shutdown. +func (b *pickfirstBalancer) shutdownRemainingLocked(selected *scData) { + for _, v := range b.subConns.Values() { + sd := v.(*scData) + if sd.subConn != selected.subConn { + sd.subConn.Shutdown() + } + } + b.subConns = resolver.NewAddressMap() + b.subConns.Set(selected.addr, selected) +} + +// requestConnectionLocked starts connecting on the subchannel corresponding to +// the current address. If no subchannel exists, one is created. If the current +// subchannel is in TransientFailure, a connection to the next address is +// attempted until a subchannel is found. +func (b *pickfirstBalancer) requestConnectionLocked() { + if !b.addressList.isValid() { + return + } + var lastErr error + for valid := true; valid; valid = b.addressList.increment() { + curAddr := b.addressList.currentAddress() + sd, ok := b.subConns.Get(curAddr) + if !ok { + var err error + // We want to assign the new scData to sd from the outer scope, + // hence we can't use := below. + sd, err = b.newSCData(curAddr) + if err != nil { + // This should never happen, unless the clientConn is being shut + // down. + if b.logger.V(2) { + b.logger.Infof("Failed to create a subConn for address %v: %v", curAddr.String(), err) + } + // Do nothing, the LB policy will be closed soon. + return + } + b.subConns.Set(curAddr, sd) + } + + scd := sd.(*scData) + switch scd.state { + case connectivity.Idle: + scd.subConn.Connect() + case connectivity.TransientFailure: + // Try the next address. + lastErr = scd.lastErr + continue + case connectivity.Ready: + // Should never happen. + b.logger.Errorf("Requesting a connection even though we have a READY SubConn") + case connectivity.Shutdown: + // Should never happen. 
+ b.logger.Errorf("SubConn with state SHUTDOWN present in SubConns map") + case connectivity.Connecting: + // Wait for the SubConn to report success or failure. + } + return + } + // All the remaining addresses in the list are in TRANSIENT_FAILURE, end the + // first pass. + b.endFirstPassLocked(lastErr) +} + +func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.SubConnState) { + b.mu.Lock() + defer b.mu.Unlock() + oldState := sd.state + sd.state = newState.ConnectivityState + // Previously relevant SubConns can still callback with state updates. + // To prevent pickers from returning these obsolete SubConns, this logic + // is included to check if the current list of active SubConns includes this + // SubConn. + if activeSD, found := b.subConns.Get(sd.addr); !found || activeSD != sd { + return + } + if newState.ConnectivityState == connectivity.Shutdown { + return + } + + if newState.ConnectivityState == connectivity.Ready { + b.shutdownRemainingLocked(sd) + if !b.addressList.seekTo(sd.addr) { + // This should not fail as we should have only one SubConn after + // entering READY. The SubConn should be present in the addressList. + b.logger.Errorf("Address %q not found address list in %v", sd.addr, b.addressList.addresses) + return + } + b.state = connectivity.Ready + b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.Ready, + Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}}, + }) + return + } + + // If the LB policy is READY, and it receives a subchannel state change, + // it means that the READY subchannel has failed. + // A SubConn can also transition from CONNECTING directly to IDLE when + // a transport is successfully created, but the connection fails + // before the SubConn can send the notification for READY. We treat + // this as a successful connection and transition to IDLE. + if (b.state == connectivity.Ready && newState.ConnectivityState != connectivity.Ready) || (oldState == connectivity.Connecting && newState.ConnectivityState == connectivity.Idle) { + // Once a transport fails, the balancer enters IDLE and starts from + // the first address when the picker is used. + b.shutdownRemainingLocked(sd) + b.state = connectivity.Idle + b.addressList.reset() + b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.Idle, + Picker: &idlePicker{exitIdle: sync.OnceFunc(b.ExitIdle)}, + }) + return + } + + if b.firstPass { + switch newState.ConnectivityState { + case connectivity.Connecting: + // The balancer can be in either IDLE, CONNECTING or + // TRANSIENT_FAILURE. If it's in TRANSIENT_FAILURE, stay in + // TRANSIENT_FAILURE until it's READY. See A62. + // If the balancer is already in CONNECTING, no update is needed. + if b.state == connectivity.Idle { + b.state = connectivity.Connecting + b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.Connecting, + Picker: &picker{err: balancer.ErrNoSubConnAvailable}, + }) + } + case connectivity.TransientFailure: + sd.lastErr = newState.ConnectionError + // Since we're re-using common SubConns while handling resolver + // updates, we could receive an out of turn TRANSIENT_FAILURE from + // a pass over the previous address list. We ignore such updates. + + if curAddr := b.addressList.currentAddress(); !equalAddressIgnoringBalAttributes(&curAddr, &sd.addr) { + return + } + if b.addressList.increment() { + b.requestConnectionLocked() + return + } + // End of the first pass. 
+ b.endFirstPassLocked(newState.ConnectionError) + } + return + } + + // We have finished the first pass, keep re-connecting failing SubConns. + switch newState.ConnectivityState { + case connectivity.TransientFailure: + b.numTF = (b.numTF + 1) % b.subConns.Len() + sd.lastErr = newState.ConnectionError + if b.numTF%b.subConns.Len() == 0 { + b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: &picker{err: newState.ConnectionError}, + }) + } + // We don't need to request re-resolution since the SubConn already + // does that before reporting TRANSIENT_FAILURE. + // TODO: #7534 - Move re-resolution requests from SubConn into + // pick_first. + case connectivity.Idle: + sd.subConn.Connect() + } +} + +func (b *pickfirstBalancer) endFirstPassLocked(lastErr error) { + b.firstPass = false + b.numTF = 0 + b.state = connectivity.TransientFailure + + b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: &picker{err: lastErr}, + }) + // Start re-connecting all the SubConns that are already in IDLE. + for _, v := range b.subConns.Values() { + sd := v.(*scData) + if sd.state == connectivity.Idle { + sd.subConn.Connect() + } + } +} + +type picker struct { + result balancer.PickResult + err error +} + +func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) { + return p.result, p.err +} + +// idlePicker is used when the SubConn is IDLE and kicks the SubConn into +// CONNECTING when Pick is called. +type idlePicker struct { + exitIdle func() +} + +func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) { + i.exitIdle() + return balancer.PickResult{}, balancer.ErrNoSubConnAvailable +} + +// addressList manages sequentially iterating over addresses present in a list +// of endpoints. It provides a 1 dimensional view of the addresses present in +// the endpoints. +// This type is not safe for concurrent access. +type addressList struct { + addresses []resolver.Address + idx int +} + +func (al *addressList) isValid() bool { + return al.idx < len(al.addresses) +} + +func (al *addressList) size() int { + return len(al.addresses) +} + +// increment moves to the next index in the address list. +// This method returns false if it went off the list, true otherwise. +func (al *addressList) increment() bool { + if !al.isValid() { + return false + } + al.idx++ + return al.idx < len(al.addresses) +} + +// currentAddress returns the current address pointed to in the addressList. +// If the list is in an invalid state, it returns an empty address instead. +func (al *addressList) currentAddress() resolver.Address { + if !al.isValid() { + return resolver.Address{} + } + return al.addresses[al.idx] +} + +// first returns the first address in the list. If the list is empty, it returns +// an empty address instead. +func (al *addressList) first() resolver.Address { + if len(al.addresses) == 0 { + return resolver.Address{} + } + return al.addresses[0] +} + +func (al *addressList) reset() { + al.idx = 0 +} + +func (al *addressList) updateAddrs(addrs []resolver.Address) { + al.addresses = addrs + al.reset() +} + +// seekTo returns false if the needle was not found and the current index was +// left unchanged. 
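+// For example, with addresses [A, B, C] and the index at B, seekTo(C) moves
+// the index to C and returns true, while seekTo(D) returns false and leaves
+// the index at B. Addresses are matched using
+// equalAddressIgnoringBalAttributes, so differing BalancerAttributes do not
+// prevent a match.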
+func (al *addressList) seekTo(needle resolver.Address) bool { + for ai, addr := range al.addresses { + if !equalAddressIgnoringBalAttributes(&addr, &needle) { + continue + } + al.idx = ai + return true + } + return false +} + +// equalAddressIgnoringBalAttributes returns true is a and b are considered +// equal. This is different from the Equal method on the resolver.Address type +// which considers all fields to determine equality. Here, we only consider +// fields that are meaningful to the SubConn. +func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool { + return a.Addr == b.Addr && a.ServerName == b.ServerName && + a.Attributes.Equal(b.Attributes) && + a.Metadata == b.Metadata +} diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go new file mode 100644 index 000000000000..2ab40ef1615a --- /dev/null +++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go @@ -0,0 +1,957 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package pickfirstleaf_test + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + + "google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/testutils/pickfirst" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/status" + + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +const ( + // Default timeout for tests in this package. + defaultTestTimeout = 10 * time.Second + // Default short timeout, to be used when waiting for events which are not + // expected to happen. + defaultTestShortTimeout = 100 * time.Millisecond + stateStoringBalancerName = "state_storing" +) + +var stateStoringServiceConfig = fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateStoringBalancerName) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +// setupPickFirstLeaf performs steps required for pick_first tests. It starts a +// bunch of backends exporting the TestService, creates a ClientConn to them +// with service config specifying the use of the state_storing LB policy. 
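+//
+// With the constants above, the service config passed to the ClientConn
+// expands to:
+//
+//	{"loadBalancingConfig": [{"state_storing":{}}]}
+//
+// where state_storing is a thin wrapper around the pickfirstleaf policy that
+// records the state of every SubConn it creates.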
+func setupPickFirstLeaf(t *testing.T, backendCount int, opts ...grpc.DialOption) (*grpc.ClientConn, *manual.Resolver, *backendManager) { + t.Helper() + r := manual.NewBuilderWithScheme("whatever") + backends := make([]*stubserver.StubServer, backendCount) + addrs := make([]resolver.Address, backendCount) + + for i := 0; i < backendCount; i++ { + backend := stubserver.StartTestService(t, nil) + t.Cleanup(func() { + backend.Stop() + }) + backends[i] = backend + addrs[i] = resolver.Address{Addr: backend.Address} + } + + dopts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithResolvers(r), + grpc.WithDefaultServiceConfig(stateStoringServiceConfig), + } + dopts = append(dopts, opts...) + cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...) + if err != nil { + t.Fatalf("grpc.NewClient() failed: %v", err) + } + t.Cleanup(func() { cc.Close() }) + + // At this point, the resolver has not returned any addresses to the channel. + // This RPC must block until the context expires. + sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + client := testgrpc.NewTestServiceClient(cc) + if _, err := client.EmptyCall(sCtx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded { + t.Fatalf("EmptyCall() = %s, want %s", status.Code(err), codes.DeadlineExceeded) + } + return cc, r, &backendManager{backends} +} + +// TestPickFirstLeaf_SimpleResolverUpdate tests the behaviour of the pick first +// policy when given an list of addresses. The following steps are carried +// out in order: +// 1. A list of addresses are given through the resolver. Only one +// of the servers is running. +// 2. RPCs are sent to verify they reach the running server. +// +// The state transitions of the ClientConn and all the subconns created are +// verified. 
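+//
+// Connectivity state transitions are recorded by registering a
+// ccStateSubscriber through the internal SubscribeToConnectivityStateChanges
+// hook; every state change published by the ClientConn is appended to the
+// subscriber's transitions slice and compared against the expected sequence.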
+func (s) TestPickFirstLeaf_SimpleResolverUpdate_FirstServerReady(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + balCh := make(chan *stateStoringBalancer, 1) + balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) + + cc, r, bm := setupPickFirstLeaf(t, 2) + addrs := bm.resolverAddrs() + stateSubscriber := &ccStateSubscriber{} + internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) + + r.UpdateState(resolver.State{Addresses: addrs}) + var bal *stateStoringBalancer + select { + case bal = <-balCh: + case <-ctx.Done(): + t.Fatal("Context expired while waiting for balancer to be built") + } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { + t.Fatal(err) + } + + wantSCStates := []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Ready}, + } + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + wantConnStateTransitions := []connectivity.State{ + connectivity.Connecting, + connectivity.Ready, + } + if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" { + t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff) + } +} + +func (s) TestPickFirstLeaf_SimpleResolverUpdate_FirstServerUnReady(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + balCh := make(chan *stateStoringBalancer, 1) + balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) + + cc, r, bm := setupPickFirstLeaf(t, 2) + addrs := bm.resolverAddrs() + stateSubscriber := &ccStateSubscriber{} + internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) + bm.stopAllExcept(1) + + r.UpdateState(resolver.State{Addresses: addrs}) + var bal *stateStoringBalancer + select { + case bal = <-balCh: + case <-ctx.Done(): + t.Fatal("Context expired while waiting for balancer to be built") + } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { + t.Fatal(err) + } + + wantSCStates := []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready}, + } + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + wantConnStateTransitions := []connectivity.State{ + connectivity.Connecting, + connectivity.Ready, + } + if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" { + t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff) + } +} + +func (s) TestPickFirstLeaf_SimpleResolverUpdate_DuplicateAddrs(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + balCh := make(chan *stateStoringBalancer, 1) + balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) + + cc, r, bm := setupPickFirstLeaf(t, 2) + addrs := bm.resolverAddrs() + stateSubscriber := &ccStateSubscriber{} + internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) + bm.stopAllExcept(1) + + // Add a duplicate entry in the addresslist + r.UpdateState(resolver.State{ + Addresses: 
append([]resolver.Address{addrs[0]}, addrs...), + }) + var bal *stateStoringBalancer + select { + case bal = <-balCh: + case <-ctx.Done(): + t.Fatal("Context expired while waiting for balancer to be built") + } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { + t.Fatal(err) + } + + wantSCStates := []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready}, + } + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + wantConnStateTransitions := []connectivity.State{ + connectivity.Connecting, + connectivity.Ready, + } + if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" { + t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff) + } +} + +// TestPickFirstLeaf_ResolverUpdates_DisjointLists tests the behaviour of the pick first +// policy when the following steps are carried out in order: +// 1. A list of addresses are given through the resolver. Only one +// of the servers is running. +// 2. RPCs are sent to verify they reach the running server. +// 3. A second resolver update is sent. Again, only one of the servers is +// running. This may not be the same server as before. +// 4. RPCs are sent to verify they reach the running server. +// +// The state transitions of the ClientConn and all the subconns created are +// verified. +func (s) TestPickFirstLeaf_ResolverUpdates_DisjointLists(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + balCh := make(chan *stateStoringBalancer, 1) + balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) + cc, r, bm := setupPickFirstLeaf(t, 4) + addrs := bm.resolverAddrs() + stateSubscriber := &ccStateSubscriber{} + internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) + + bm.backends[0].S.Stop() + bm.backends[0].S = nil + r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}}) + var bal *stateStoringBalancer + select { + case bal = <-balCh: + case <-ctx.Done(): + t.Fatal("Context expired while waiting for balancer to be built") + } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { + t.Fatal(err) + } + wantSCStates := []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + bm.backends[2].S.Stop() + bm.backends[2].S = nil + r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[2], addrs[3]}}) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[3]); err != nil { + t.Fatal(err) + } + wantSCStates = []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[2]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[3]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + wantConnStateTransitions := 
[]connectivity.State{ + connectivity.Connecting, + connectivity.Ready, + connectivity.Connecting, + connectivity.Ready, + } + if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" { + t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff) + } +} + +func (s) TestPickFirstLeaf_ResolverUpdates_ActiveBackendInUpdatedList(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + balCh := make(chan *stateStoringBalancer, 1) + balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) + cc, r, bm := setupPickFirstLeaf(t, 3) + addrs := bm.resolverAddrs() + stateSubscriber := &ccStateSubscriber{} + internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) + + bm.backends[0].S.Stop() + bm.backends[0].S = nil + r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}}) + var bal *stateStoringBalancer + select { + case bal = <-balCh: + case <-ctx.Done(): + t.Fatal("Context expired while waiting for balancer to be built") + } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { + t.Fatal(err) + } + wantSCStates := []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + bm.backends[2].S.Stop() + bm.backends[2].S = nil + r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[2], addrs[1]}}) + + // Verify that the ClientConn stays in READY. + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + testutils.AwaitNoStateChange(sCtx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { + t.Fatal(err) + } + wantSCStates = []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + wantConnStateTransitions := []connectivity.State{ + connectivity.Connecting, + connectivity.Ready, + } + if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" { + t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff) + } +} + +func (s) TestPickFirstLeaf_ResolverUpdates_InActiveBackendInUpdatedList(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + balCh := make(chan *stateStoringBalancer, 1) + balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) + cc, r, bm := setupPickFirstLeaf(t, 3) + addrs := bm.resolverAddrs() + stateSubscriber := &ccStateSubscriber{} + internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) + + bm.backends[0].S.Stop() + bm.backends[0].S = nil + r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}}) + var bal *stateStoringBalancer + select { + case bal = <-balCh: + case <-ctx.Done(): + t.Fatal("Context expired while waiting for balancer to be built") + } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != 
nil { + t.Fatal(err) + } + wantSCStates := []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + bm.backends[2].S.Stop() + bm.backends[2].S = nil + if err := bm.backends[0].StartServer(); err != nil { + t.Fatalf("Failed to re-start test backend: %v", err) + } + r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[2]}}) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { + t.Fatal(err) + } + wantSCStates = []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + wantConnStateTransitions := []connectivity.State{ + connectivity.Connecting, + connectivity.Ready, + connectivity.Connecting, + connectivity.Ready, + } + if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" { + t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff) + } +} + +func (s) TestPickFirstLeaf_ResolverUpdates_IdenticalLists(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + balCh := make(chan *stateStoringBalancer, 1) + balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) + cc, r, bm := setupPickFirstLeaf(t, 2) + addrs := bm.resolverAddrs() + stateSubscriber := &ccStateSubscriber{} + internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) + + bm.backends[0].S.Stop() + bm.backends[0].S = nil + r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}}) + var bal *stateStoringBalancer + select { + case bal = <-balCh: + case <-ctx.Done(): + t.Fatal("Context expired while waiting for balancer to be built") + } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { + t.Fatal(err) + } + wantSCStates := []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}}) + + // Verify that the ClientConn stays in READY. 
+ sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + testutils.AwaitNoStateChange(sCtx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { + t.Fatal(err) + } + wantSCStates = []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + wantConnStateTransitions := []connectivity.State{ + connectivity.Connecting, + connectivity.Ready, + } + if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" { + t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff) + } +} + +// TestPickFirstLeaf_StopConnectedServer tests the behaviour of the pick first +// policy when the connected server is shut down. It carries out the following +// steps in order: +// 1. A list of addresses are given through the resolver. Only one +// of the servers is running. +// 2. The running server is stopped, causing the ClientConn to enter IDLE. +// 3. A (possibly different) server is started. +// 4. RPCs are made to kick the ClientConn out of IDLE. The test verifies that +// the RPCs reach the running server. +// +// The test verifies the ClientConn state transitions. +func (s) TestPickFirstLeaf_StopConnectedServer_FirstServerRestart(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + balCh := make(chan *stateStoringBalancer, 1) + balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) + cc, r, bm := setupPickFirstLeaf(t, 2) + addrs := bm.resolverAddrs() + stateSubscriber := &ccStateSubscriber{} + internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) + + // shutdown all active backends except the target. + bm.stopAllExcept(0) + + r.UpdateState(resolver.State{Addresses: addrs}) + var bal *stateStoringBalancer + select { + case bal = <-balCh: + case <-ctx.Done(): + t.Fatal("Context expired while waiting for balancer to be built") + } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { + t.Fatal(err) + } + + wantSCStates := []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + // Shut down the connected server. + bm.backends[0].S.Stop() + bm.backends[0].S = nil + testutils.AwaitState(ctx, t, cc, connectivity.Idle) + + // Start the new target server. 
+ if err := bm.backends[0].StartServer(); err != nil { + t.Fatalf("Failed to start server: %v", err) + } + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + wantConnStateTransitions := []connectivity.State{ + connectivity.Connecting, + connectivity.Ready, + connectivity.Idle, + connectivity.Connecting, + connectivity.Ready, + } + if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" { + t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff) + } +} + +func (s) TestPickFirstLeaf_StopConnectedServer_SecondServerRestart(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + balCh := make(chan *stateStoringBalancer, 1) + balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) + cc, r, bm := setupPickFirstLeaf(t, 2) + addrs := bm.resolverAddrs() + stateSubscriber := &ccStateSubscriber{} + internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) + + // shutdown all active backends except the target. + bm.stopAllExcept(1) + + r.UpdateState(resolver.State{Addresses: addrs}) + var bal *stateStoringBalancer + select { + case bal = <-balCh: + case <-ctx.Done(): + t.Fatal("Context expired while waiting for balancer to be built") + } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { + t.Fatal(err) + } + + wantSCStates := []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + // Shut down the connected server. + bm.backends[1].S.Stop() + bm.backends[1].S = nil + testutils.AwaitState(ctx, t, cc, connectivity.Idle) + + // Start the new target server. 
+ if err := bm.backends[1].StartServer(); err != nil { + t.Fatalf("Failed to start server: %v", err) + } + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { + t.Fatal(err) + } + + wantSCStates = []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready}, + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + wantConnStateTransitions := []connectivity.State{ + connectivity.Connecting, + connectivity.Ready, + connectivity.Idle, + connectivity.Connecting, + connectivity.Ready, + } + if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" { + t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff) + } +} + +func (s) TestPickFirstLeaf_StopConnectedServer_SecondServerToFirst(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + balCh := make(chan *stateStoringBalancer, 1) + balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) + cc, r, bm := setupPickFirstLeaf(t, 2) + addrs := bm.resolverAddrs() + stateSubscriber := &ccStateSubscriber{} + internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) + + // shutdown all active backends except the target. + bm.stopAllExcept(1) + + r.UpdateState(resolver.State{Addresses: addrs}) + var bal *stateStoringBalancer + select { + case bal = <-balCh: + case <-ctx.Done(): + t.Fatal("Context expired while waiting for balancer to be built") + } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { + t.Fatal(err) + } + + wantSCStates := []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + // Shut down the connected server. + bm.backends[1].S.Stop() + bm.backends[1].S = nil + testutils.AwaitState(ctx, t, cc, connectivity.Idle) + + // Start the new target server. 
+ if err := bm.backends[0].StartServer(); err != nil { + t.Fatalf("Failed to start server: %v", err) + } + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { + t.Fatal(err) + } + + wantSCStates = []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + wantConnStateTransitions := []connectivity.State{ + connectivity.Connecting, + connectivity.Ready, + connectivity.Idle, + connectivity.Connecting, + connectivity.Ready, + } + if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" { + t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff) + } +} + +func (s) TestPickFirstLeaf_StopConnectedServer_FirstServerToSecond(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + balCh := make(chan *stateStoringBalancer, 1) + balancer.Register(&stateStoringBalancerBuilder{balancer: balCh}) + cc, r, bm := setupPickFirstLeaf(t, 2) + addrs := bm.resolverAddrs() + stateSubscriber := &ccStateSubscriber{} + internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) + + // shutdown all active backends except the target. + bm.stopAllExcept(0) + + r.UpdateState(resolver.State{Addresses: addrs}) + var bal *stateStoringBalancer + select { + case bal = <-balCh: + case <-ctx.Done(): + t.Fatal("Context expired while waiting for balancer to be built") + } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { + t.Fatal(err) + } + + wantSCStates := []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + // Shut down the connected server. + bm.backends[0].S.Stop() + bm.backends[0].S = nil + testutils.AwaitState(ctx, t, cc, connectivity.Idle) + + // Start the new target server. + if err := bm.backends[1].StartServer(); err != nil { + t.Fatalf("Failed to start server: %v", err) + } + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { + t.Fatal(err) + } + + wantSCStates = []scState{ + {Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown}, + {Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready}, + } + + if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" { + t.Errorf("subconn states mismatch (-want +got):\n%s", diff) + } + + wantConnStateTransitions := []connectivity.State{ + connectivity.Connecting, + connectivity.Ready, + connectivity.Idle, + connectivity.Connecting, + connectivity.Ready, + } + if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" { + t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff) + } +} + +// TestPickFirstLeaf_EmptyAddressList carries out the following steps in order: +// 1. Send a resolver update with one running backend. +// 2. Send an empty address list causing the balancer to enter TRANSIENT_FAILURE. +// 3. Send a resolver update with one running backend. +// The test verifies the ClientConn state transitions. 
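+//
+// An empty address list is treated as a resolver error: UpdateClientConnState
+// returns balancer.ErrBadResolverState and the channel reports
+// TRANSIENT_FAILURE. Because this failure is not the result of a connection
+// attempt, the sticky-TF behaviour from gRFC A62 does not apply and the next
+// valid update moves the channel back through CONNECTING to READY.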
+func (s) TestPickFirstLeaf_EmptyAddressList(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + balChan := make(chan *stateStoringBalancer, 1) + balancer.Register(&stateStoringBalancerBuilder{balancer: balChan}) + cc, r, bm := setupPickFirstLeaf(t, 1) + addrs := bm.resolverAddrs() + + stateSubscriber := &ccStateSubscriber{} + internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber) + + r.UpdateState(resolver.State{Addresses: addrs}) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { + t.Fatal(err) + } + + r.UpdateState(resolver.State{}) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) + + // The balancer should have entered transient failure. + // It should transition to CONNECTING from TRANSIENT_FAILURE as sticky TF + // only applies when the initial TF is reported due to connection failures + // and not bad resolver states. + r.UpdateState(resolver.State{Addresses: addrs}) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { + t.Fatal(err) + } + + wantTransitions := []connectivity.State{ + // From first resolver update. + connectivity.Connecting, + connectivity.Ready, + // From second update. + connectivity.TransientFailure, + // From third update. + connectivity.Connecting, + connectivity.Ready, + } + + if diff := cmp.Diff(wantTransitions, stateSubscriber.transitions); diff != "" { + t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff) + } +} + +// stateStoringBalancer stores the state of the subconns being created. +type stateStoringBalancer struct { + balancer.Balancer + mu sync.Mutex + scStates []*scState +} + +func (b *stateStoringBalancer) Close() { + b.Balancer.Close() +} + +func (b *stateStoringBalancer) ExitIdle() { + if ib, ok := b.Balancer.(balancer.ExitIdler); ok { + ib.ExitIdle() + } +} + +type stateStoringBalancerBuilder struct { + balancer chan *stateStoringBalancer +} + +func (b *stateStoringBalancerBuilder) Name() string { + return stateStoringBalancerName +} + +func (b *stateStoringBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + bal := &stateStoringBalancer{} + bal.Balancer = balancer.Get(pickfirstleaf.Name).Build(&stateStoringCCWrapper{cc, bal}, opts) + b.balancer <- bal + return bal +} + +func (b *stateStoringBalancer) subConnStates() []scState { + b.mu.Lock() + defer b.mu.Unlock() + ret := []scState{} + for _, s := range b.scStates { + ret = append(ret, *s) + } + return ret +} + +func (b *stateStoringBalancer) addSCState(state *scState) { + b.mu.Lock() + b.scStates = append(b.scStates, state) + b.mu.Unlock() +} + +type stateStoringCCWrapper struct { + balancer.ClientConn + b *stateStoringBalancer +} + +func (ccw *stateStoringCCWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { + oldListener := opts.StateListener + scs := &scState{ + State: connectivity.Idle, + Addrs: addrs, + } + ccw.b.addSCState(scs) + opts.StateListener = func(s balancer.SubConnState) { + ccw.b.mu.Lock() + scs.State = s.ConnectivityState + ccw.b.mu.Unlock() + oldListener(s) + } + return ccw.ClientConn.NewSubConn(addrs, opts) +} + +type scState struct { + State connectivity.State + Addrs []resolver.Address +} + +type backendManager struct { + backends []*stubserver.StubServer +} + +func 
(b *backendManager) stopAllExcept(index int) { + for idx, b := range b.backends { + if idx != index { + b.S.Stop() + b.S = nil + } + } +} + +// resolverAddrs returns a list of resolver addresses for the stub server +// backends. Useful when pushing addresses to the manual resolver. +func (b *backendManager) resolverAddrs() []resolver.Address { + addrs := make([]resolver.Address, len(b.backends)) + for i, backend := range b.backends { + addrs[i] = resolver.Address{Addr: backend.Address} + } + return addrs +} + +type ccStateSubscriber struct { + transitions []connectivity.State +} + +func (c *ccStateSubscriber) OnMessage(msg any) { + c.transitions = append(c.transitions, msg.(connectivity.State)) +} diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_test.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_test.go new file mode 100644 index 000000000000..84b3cb65bed4 --- /dev/null +++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_test.go @@ -0,0 +1,259 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package pickfirstleaf + +import ( + "context" + "fmt" + "testing" + "time" + + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/resolver" +) + +const ( + // Default timeout for tests in this package. + defaultTestTimeout = 10 * time.Second + // Default short timeout, to be used when waiting for events which are not + // expected to happen. + defaultTestShortTimeout = 100 * time.Millisecond +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +// TestAddressList_Iteration verifies the behaviour of the addressList while +// iterating through the entries. 
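+// The test below treats the list as a simple cursor: updateAddrs rewinds it to
+// the first entry, currentAddress returns the entry under the cursor,
+// increment advances the cursor and reports whether another entry remains,
+// isValid reports whether the cursor still points at an entry, and reset
+// rewinds the cursor. As a rough sketch (not the actual LB policy code), a
+// caller could walk the list like this:
+//
+//	addrList.updateAddrs(addrs)
+//	for addrList.isValid() {
+//		addr := addrList.currentAddress()
+//		// ... attempt a connection to addr ...
+//		if !addrList.increment() {
+//			break
+//		}
+//	}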
+func (s) TestAddressList_Iteration(t *testing.T) { + addrs := []resolver.Address{ + { + Addr: "192.168.1.1", + ServerName: "test-host-1", + Attributes: attributes.New("key-1", "val-1"), + BalancerAttributes: attributes.New("bal-key-1", "bal-val-1"), + }, + { + Addr: "192.168.1.2", + ServerName: "test-host-2", + Attributes: attributes.New("key-2", "val-2"), + BalancerAttributes: attributes.New("bal-key-2", "bal-val-2"), + }, + { + Addr: "192.168.1.3", + ServerName: "test-host-3", + Attributes: attributes.New("key-3", "val-3"), + BalancerAttributes: attributes.New("bal-key-3", "bal-val-3"), + }, + } + + addressList := addressList{} + emptyAddress := resolver.Address{} + if got, want := addressList.first(), emptyAddress; got != want { + t.Fatalf("addressList.first() = %v, want %v", got, want) + } + + addressList.updateAddrs(addrs) + + if got, want := addressList.first(), addressList.currentAddress(); got != want { + t.Fatalf("addressList.first() = %v, want %v", got, want) + } + + if got, want := addressList.first(), addrs[0]; got != want { + t.Fatalf("addressList.first() = %v, want %v", got, want) + } + + for i := 0; i < len(addrs); i++ { + if got, want := addressList.isValid(), true; got != want { + t.Fatalf("addressList.isValid() = %t, want %t", got, want) + } + if got, want := addressList.currentAddress(), addrs[i]; !want.Equal(got) { + t.Errorf("addressList.currentAddress() = %v, want %v", got, want) + } + if got, want := addressList.increment(), i+1 < len(addrs); got != want { + t.Fatalf("addressList.increment() = %t, want %t", got, want) + } + } + + if got, want := addressList.isValid(), false; got != want { + t.Fatalf("addressList.isValid() = %t, want %t", got, want) + } + + // increment an invalid address list. + if got, want := addressList.increment(), false; got != want { + t.Errorf("addressList.increment() = %t, want %t", got, want) + } + + if got, want := addressList.isValid(), false; got != want { + t.Errorf("addressList.isValid() = %t, want %t", got, want) + } + + addressList.reset() + for i := 0; i < len(addrs); i++ { + if got, want := addressList.isValid(), true; got != want { + t.Fatalf("addressList.isValid() = %t, want %t", got, want) + } + if got, want := addressList.currentAddress(), addrs[i]; !want.Equal(got) { + t.Errorf("addressList.currentAddress() = %v, want %v", got, want) + } + if got, want := addressList.increment(), i+1 < len(addrs); got != want { + t.Fatalf("addressList.increment() = %t, want %t", got, want) + } + } +} + +// TestAddressList_SeekTo verifies the behaviour of addressList.seekTo. +func (s) TestAddressList_SeekTo(t *testing.T) { + addrs := []resolver.Address{ + { + Addr: "192.168.1.1", + ServerName: "test-host-1", + Attributes: attributes.New("key-1", "val-1"), + BalancerAttributes: attributes.New("bal-key-1", "bal-val-1"), + }, + { + Addr: "192.168.1.2", + ServerName: "test-host-2", + Attributes: attributes.New("key-2", "val-2"), + BalancerAttributes: attributes.New("bal-key-2", "bal-val-2"), + }, + { + Addr: "192.168.1.3", + ServerName: "test-host-3", + Attributes: attributes.New("key-3", "val-3"), + BalancerAttributes: attributes.New("bal-key-3", "bal-val-3"), + }, + } + + addressList := addressList{} + addressList.updateAddrs(addrs) + + // Try finding an address in the list. 
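+	// Note: the key below matches the second entry in every field except its
+	// BalancerAttributes; the expectation that seekTo still finds it relies on
+	// the lookup ignoring BalancerAttributes when comparing addresses.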
+	key := resolver.Address{
+		Addr:               "192.168.1.2",
+		ServerName:         "test-host-2",
+		Attributes:         attributes.New("key-2", "val-2"),
+		BalancerAttributes: attributes.New("ignored", "bal-val-2"),
+	}
+
+	if got, want := addressList.seekTo(key), true; got != want {
+		t.Errorf("addressList.seekTo(%v) = %t, want %t", key, got, want)
+	}
+
+	// It should be possible to increment once more now that the pointer has advanced.
+	if got, want := addressList.increment(), true; got != want {
+		t.Errorf("addressList.increment() = %t, want %t", got, want)
+	}
+
+	if got, want := addressList.increment(), false; got != want {
+		t.Errorf("addressList.increment() = %t, want %t", got, want)
+	}
+
+	// Seek to the key again; it is behind the pointer now.
+	if got, want := addressList.seekTo(key), true; got != want {
+		t.Errorf("addressList.seekTo(%v) = %t, want %t", key, got, want)
+	}
+
+	// Seek to a key not in the list.
+	key = resolver.Address{
+		Addr:               "192.168.1.5",
+		ServerName:         "test-host-5",
+		Attributes:         attributes.New("key-5", "val-5"),
+		BalancerAttributes: attributes.New("ignored", "bal-val-5"),
+	}
+
+	if got, want := addressList.seekTo(key), false; got != want {
+		t.Errorf("addressList.seekTo(%v) = %t, want %t", key, got, want)
+	}
+
+	// It should be possible to increment once more since the pointer has not advanced.
+	if got, want := addressList.increment(), true; got != want {
+		t.Errorf("addressList.increment() = %t, want %t", got, want)
+	}
+
+	if got, want := addressList.increment(), false; got != want {
+		t.Errorf("addressList.increment() = %t, want %t", got, want)
+	}
+}
+
+// TestPickFirstLeaf_TFPickerUpdate sends TRANSIENT_FAILURE SubConn state updates
+// for each SubConn managed by a pickfirst balancer. It verifies that the picker
+// is updated with the expected frequency.
+func (s) TestPickFirstLeaf_TFPickerUpdate(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	cc := testutils.NewBalancerClientConn(t)
+	bal := pickfirstBuilder{}.Build(cc, balancer.BuildOptions{})
+	defer bal.Close()
+	ccState := balancer.ClientConnState{
+		ResolverState: resolver.State{
+			Endpoints: []resolver.Endpoint{
+				{Addresses: []resolver.Address{{Addr: "1.1.1.1:1"}}},
+				{Addresses: []resolver.Address{{Addr: "2.2.2.2:2"}}},
+			},
+		},
+	}
+	if err := bal.UpdateClientConnState(ccState); err != nil {
+		t.Fatalf("UpdateClientConnState(%v) returned error: %v", ccState, err)
+	}
+
+	// PF should report TRANSIENT_FAILURE only once all the subconns have failed
+	// once.
+	tfErr := fmt.Errorf("test err: connection refused")
+	sc1 := <-cc.NewSubConnCh
+	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure, ConnectionError: tfErr})
+
+	if err := cc.WaitForPickerWithErr(ctx, balancer.ErrNoSubConnAvailable); err != nil {
+		t.Fatalf("cc.WaitForPickerWithErr(%v) returned error: %v", balancer.ErrNoSubConnAvailable, err)
+	}
+
+	sc2 := <-cc.NewSubConnCh
+	sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+	sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure, ConnectionError: tfErr})
+
+	if err := cc.WaitForPickerWithErr(ctx, tfErr); err != nil {
+		t.Fatalf("cc.WaitForPickerWithErr(%v) returned error: %v", tfErr, err)
+	}
+
+	// Subsequent TRANSIENT_FAILUREs should be reported only after seeing "# of SubConns"
+	// TRANSIENT_FAILUREs.
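+	// Concretely: the single TRANSIENT_FAILURE below must not produce a new
+	// picker (the short timeout verifies that nothing is published), while the
+	// following one, which brings the number of failures seen to the number of
+	// SubConns (two), must surface newTfErr through the picker.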
+ newTfErr := fmt.Errorf("test err: unreachable") + sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure, ConnectionError: newTfErr}) + select { + case <-time.After(defaultTestShortTimeout): + case p := <-cc.NewPickerCh: + sc, err := p.Pick(balancer.PickInfo{}) + t.Fatalf("Unexpected picker update: %v, %v", sc, err) + } + + sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure, ConnectionError: newTfErr}) + if err := cc.WaitForPickerWithErr(ctx, newTfErr); err != nil { + t.Fatalf("cc.WaitForPickerWithErr(%v) returned error: %v", newTfErr, err) + } +} diff --git a/balancer/rls/balancer_test.go b/balancer/rls/balancer_test.go index 16fa77354cde..8c77e3428950 100644 --- a/balancer/rls/balancer_test.go +++ b/balancer/rls/balancer_test.go @@ -1096,6 +1096,9 @@ func (s) TestUpdateStatePauses(t *testing.T) { Init: func(bd *stub.BalancerData) { bd.Data = balancer.Get(pickfirst.Name).Build(bd.ClientConn, bd.BuildOptions) }, + Close: func(bd *stub.BalancerData) { + bd.Data.(balancer.Balancer).Close() + }, ParseConfig: func(sc json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { cfg := &childPolicyConfig{} if err := json.Unmarshal(sc, cfg); err != nil { diff --git a/clientconn.go b/clientconn.go index b47efb33c0e9..4a408d621692 100644 --- a/clientconn.go +++ b/clientconn.go @@ -1249,6 +1249,8 @@ func (ac *addrConn) resetTransportAndUnlock() { ac.mu.Unlock() if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil { + // TODO: #7534 - Move re-resolution requests into the pick_first LB policy + // to ensure one resolution request per pass instead of per subconn failure. ac.cc.resolveNow(resolver.ResolveNowOptions{}) ac.mu.Lock() if acCtx.Err() != nil { diff --git a/clientconn_test.go b/clientconn_test.go index 0cb09001da04..778fe8269e98 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -37,6 +37,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" internalbackoff "google.golang.org/grpc/internal/backoff" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/transport" @@ -418,17 +419,21 @@ func (s) TestWithTransportCredentialsTLS(t *testing.T) { // When creating a transport configured with n addresses, only calculate the // backoff once per "round" of attempts instead of once per address (n times -// per "round" of attempts). -func (s) TestDial_OneBackoffPerRetryGroup(t *testing.T) { +// per "round" of attempts) for old pickfirst and once per address for new pickfirst. +func (s) TestDial_BackoffCountPerRetryGroup(t *testing.T) { var attempts uint32 + wantBackoffs := uint32(1) + if envconfig.NewPickFirstEnabled { + wantBackoffs = 2 + } getMinConnectTimeout := func() time.Duration { - if atomic.AddUint32(&attempts, 1) == 1 { + if atomic.AddUint32(&attempts, 1) <= wantBackoffs { // Once all addresses are exhausted, hang around and wait for the // client.Close to happen rather than re-starting a new round of // attempts. 
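+			// Note: with GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST set, a backoff
+			// value is computed once per address rather than once per round,
+			// which is why wantBackoffs is 2 above in that case.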
 			return time.Hour
 		}
-		t.Error("only one attempt backoff calculation, but got more")
+		t.Errorf("only %d attempt backoff calculation, but got more", wantBackoffs)
 		return 0
 	}
@@ -499,6 +504,10 @@ func (s) TestDial_OneBackoffPerRetryGroup(t *testing.T) {
 		t.Fatal("timed out waiting for test to finish")
 	case <-server2Done:
 	}
+
+	if got, want := atomic.LoadUint32(&attempts), wantBackoffs; got != want {
+		t.Errorf("attempts = %d, want %d", got, want)
+	}
 }
 func (s) TestDialContextCancel(t *testing.T) {
@@ -1062,18 +1071,14 @@ func (s) TestUpdateAddresses_NoopIfCalledWithSameAddresses(t *testing.T) {
 	}
 	// Grab the addrConn and call tryUpdateAddrs.
-	var ac *addrConn
 	client.mu.Lock()
 	for clientAC := range client.conns {
-		ac = clientAC
-		break
+		// Call UpdateAddresses with the same list of addresses, it should be a noop
+		// (even when the SubConn is Connecting, and doesn't have a curAddr).
+		clientAC.acbw.UpdateAddresses(clientAC.addrs)
 	}
 	client.mu.Unlock()
-	// Call UpdateAddresses with the same list of addresses, it should be a noop
-	// (even when the SubConn is Connecting, and doesn't have a curAddr).
-	ac.acbw.UpdateAddresses(addrsList)
-
 	// We've called tryUpdateAddrs - now let's make server2 close the
 	// connection and check that it continues to server3.
 	close(closeServer2)
diff --git a/internal/balancergroup/balancergroup_test.go b/internal/balancergroup/balancergroup_test.go
index 8d22c9ac587e..c154c029d8f2 100644
--- a/internal/balancergroup/balancergroup_test.go
+++ b/internal/balancergroup/balancergroup_test.go
@@ -575,6 +575,7 @@ func (s) TestBalancerGracefulSwitch(t *testing.T) {
 	bg.UpdateClientConnState(testBalancerIDs[0], balancer.ClientConnState{ResolverState: resolver.State{Addresses: testBackendAddrs[0:2]}})
 	bg.Start()
+	defer bg.Close()
 	m1 := make(map[resolver.Address]balancer.SubConn)
 	scs := make(map[balancer.SubConn]bool)
@@ -604,6 +605,9 @@ func (s) TestBalancerGracefulSwitch(t *testing.T) {
 		Init: func(bd *stub.BalancerData) {
 			bd.Data = balancer.Get(pickfirst.Name).Build(bd.ClientConn, bd.BuildOptions)
 		},
+		Close: func(bd *stub.BalancerData) {
+			bd.Data.(balancer.Balancer).Close()
+		},
 		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
 			ccs.ResolverState.Addresses = ccs.ResolverState.Addresses[1:]
 			bal := bd.Data.(balancer.Balancer)
diff --git a/internal/envconfig/envconfig.go b/internal/envconfig/envconfig.go
index 452985f8d8f1..6e7dd6b77270 100644
--- a/internal/envconfig/envconfig.go
+++ b/internal/envconfig/envconfig.go
@@ -50,6 +50,11 @@ var (
 	// xDS fallback is turned on. If this is unset or is false, only the first
 	// xDS server in the list of server configs will be used.
 	XDSFallbackSupport = boolFromEnv("GRPC_EXPERIMENTAL_XDS_FALLBACK", false)
+	// NewPickFirstEnabled is set if the new pickfirst leaf policy is to be used
+	// instead of the existing pickfirst implementation. This can be enabled by
+	// setting the environment variable "GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST"
+	// to "true".
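+	// Note: as with the other flags in this block, the value is read from the
+	// environment once, when this package is initialized, so the variable must
+	// be set before the program starts.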
+ NewPickFirstEnabled = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", false) ) func boolFromEnv(envVar string, def bool) bool { diff --git a/test/balancer_switching_test.go b/test/balancer_switching_test.go index 8074b59b3a47..e5da19d30d0f 100644 --- a/test/balancer_switching_test.go +++ b/test/balancer_switching_test.go @@ -483,6 +483,9 @@ func (s) TestBalancerSwitch_Graceful(t *testing.T) { pf := balancer.Get(pickfirst.Name) bd.Data = pf.Build(bd.ClientConn, bd.BuildOptions) }, + Close: func(bd *stub.BalancerData) { + bd.Data.(balancer.Balancer).Close() + }, UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { bal := bd.Data.(balancer.Balancer) close(ccUpdateCh) diff --git a/test/balancer_test.go b/test/balancer_test.go index f27ec4d3fe90..c2405808f2ea 100644 --- a/test/balancer_test.go +++ b/test/balancer_test.go @@ -850,6 +850,9 @@ func (s) TestMetadataInPickResult(t *testing.T) { cc := &testCCWrapper{ClientConn: bd.ClientConn} bd.Data = balancer.Get(pickfirst.Name).Build(cc, bd.BuildOptions) }, + Close: func(bd *stub.BalancerData) { + bd.Data.(balancer.Balancer).Close() + }, UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { bal := bd.Data.(balancer.Balancer) return bal.UpdateClientConnState(ccs) diff --git a/test/clientconn_state_transition_test.go b/test/clientconn_state_transition_test.go index 6e9bfb37289d..56ebafaa9308 100644 --- a/test/clientconn_state_transition_test.go +++ b/test/clientconn_state_transition_test.go @@ -34,6 +34,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/balancer/stub" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" @@ -323,6 +324,13 @@ func (s) TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T) client, err := grpc.Dial("whatever:///this-gets-overwritten", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), + grpc.WithConnectParams(grpc.ConnectParams{ + // Set a really long back-off delay to ensure the first subConn does + // not enter IDLE before the second subConn connects. + Backoff: backoff.Config{ + BaseDelay: 1 * time.Hour, + }, + }), grpc.WithResolvers(rb)) if err != nil { t.Fatal(err) @@ -334,6 +342,16 @@ func (s) TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T) connectivity.Connecting, connectivity.Ready, } + if envconfig.NewPickFirstEnabled { + want = []connectivity.State{ + // The first subconn fails. + connectivity.Connecting, + connectivity.TransientFailure, + // The second subconn connects. 
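+			// (The hour-long backoff configured above keeps the first subconn
+			// from re-entering IDLE, so no further states are expected from it.)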
+ connectivity.Connecting, + connectivity.Ready, + } + } ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() for i := 0; i < len(want); i++ { diff --git a/test/resolver_update_test.go b/test/resolver_update_test.go index a7526b9d43c5..619979b9b045 100644 --- a/test/resolver_update_test.go +++ b/test/resolver_update_test.go @@ -162,6 +162,9 @@ func (s) TestResolverUpdate_InvalidServiceConfigAfterGoodUpdate(t *testing.T) { pf := balancer.Get(pickfirst.Name) bd.Data = pf.Build(bd.ClientConn, bd.BuildOptions) }, + Close: func(bd *stub.BalancerData) { + bd.Data.(balancer.Balancer).Close() + }, ParseConfig: func(lbCfg json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { cfg := &wrappingBalancerConfig{} if err := json.Unmarshal(lbCfg, cfg); err != nil { diff --git a/xds/internal/balancer/clustermanager/clustermanager_test.go b/xds/internal/balancer/clustermanager/clustermanager_test.go index 079214651871..b606cb9e5e34 100644 --- a/xds/internal/balancer/clustermanager/clustermanager_test.go +++ b/xds/internal/balancer/clustermanager/clustermanager_test.go @@ -607,6 +607,7 @@ func TestClusterGracefulSwitch(t *testing.T) { builder := balancer.Get(balancerName) parser := builder.(balancer.ConfigParser) bal := builder.Build(cc, balancer.BuildOptions{}) + defer bal.Close() configJSON1 := `{ "children": { @@ -644,6 +645,9 @@ func TestClusterGracefulSwitch(t *testing.T) { Init: func(bd *stub.BalancerData) { bd.Data = balancer.Get(pickfirst.Name).Build(bd.ClientConn, bd.BuildOptions) }, + Close: func(bd *stub.BalancerData) { + bd.Data.(balancer.Balancer).Close() + }, UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { bal := bd.Data.(balancer.Balancer) return bal.UpdateClientConnState(ccs) @@ -730,6 +734,7 @@ func (s) TestUpdateStatePauses(t *testing.T) { builder := balancer.Get(balancerName) parser := builder.(balancer.ConfigParser) bal := builder.Build(cc, balancer.BuildOptions{}) + defer bal.Close() configJSON1 := `{ "children": {