diff --git a/balancer/pickfirstleaf/pickfirstleaf.go b/balancer/pickfirstleaf/pickfirstleaf.go index 5b842f409922..cfd10d1cddc9 100644 --- a/balancer/pickfirstleaf/pickfirstleaf.go +++ b/balancer/pickfirstleaf/pickfirstleaf.go @@ -30,6 +30,7 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" @@ -287,7 +288,7 @@ func (b *pickfirstBalancer) Close() { func (b *pickfirstBalancer) ExitIdle() { b.mu.Lock() defer b.mu.Unlock() - if b.state == connectivity.Idle { + if b.state == connectivity.Idle && b.addressList.currentAddress() == b.addressList.first() { b.firstPass = true b.requestConnectionLocked() } @@ -546,11 +547,14 @@ func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) { // idlePicker is used when the SubConn is IDLE and kicks the SubConn into // CONNECTING when Pick is called. type idlePicker struct { - exitIdle func() + connectionRequested atomic.Bool + exitIdle func() } func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) { - i.exitIdle() + if i.connectionRequested.CompareAndSwap(false, true) { + i.exitIdle() + } return balancer.PickResult{}, balancer.ErrNoSubConnAvailable } @@ -590,6 +594,15 @@ func (al *addressList) currentAddress() resolver.Address { return al.addresses[al.idx] } +// first returns the first address in the list. If the list is empty, it returns +// an empty address instead. +func (al *addressList) first() resolver.Address { + if len(al.addresses) == 0 { + return resolver.Address{} + } + return al.addresses[0] +} + func (al *addressList) reset() { al.idx = 0 } diff --git a/balancer/pickfirstleaf/pickfirstleaf_test.go b/balancer/pickfirstleaf/pickfirstleaf_test.go index 7dfe6f9d0227..dc62c45f4719 100644 --- a/balancer/pickfirstleaf/pickfirstleaf_test.go +++ b/balancer/pickfirstleaf/pickfirstleaf_test.go @@ -73,8 +73,21 @@ func (s) TestAddressList_Iteration(t *testing.T) { } addressList := addressList{} + emptyAddress := resolver.Address{} + if got, want := addressList.first(), emptyAddress; got != want { + t.Fatalf("addressList.first() = %v, want %v", got, want) + } + addressList.updateAddrs(addrs) + if got, want := addressList.first(), addressList.currentAddress(); got != want { + t.Fatalf("addressList.first() = %v, want %v", got, want) + } + + if got, want := addressList.first(), addrs[0]; got != want { + t.Fatalf("addressList.first() = %v, want %v", got, want) + } + for i := 0; i < len(addrs); i++ { if got, want := addressList.isValid(), true; got != want { t.Fatalf("addressList.isValid() = %t, want %t", got, want)