Skip to content

Commit

Permalink
Ensure exitIdle doesn't increment the address list multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
arjan-bal committed Oct 7, 2024
1 parent c4b4aa4 commit f0e479e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
19 changes: 16 additions & 3 deletions balancer/pickfirstleaf/pickfirstleaf.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"errors"
"fmt"
"sync"
"sync/atomic"

"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
Expand Down Expand Up @@ -287,7 +288,7 @@ func (b *pickfirstBalancer) Close() {
func (b *pickfirstBalancer) ExitIdle() {
b.mu.Lock()
defer b.mu.Unlock()
if b.state == connectivity.Idle {
if b.state == connectivity.Idle && b.addressList.currentAddress() == b.addressList.first() {
b.firstPass = true
b.requestConnectionLocked()
}
Expand Down Expand Up @@ -546,11 +547,14 @@ func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
// idlePicker is used when the SubConn is IDLE and kicks the SubConn into
// CONNECTING when Pick is called.
type idlePicker struct {
exitIdle func()
connectionRequested atomic.Bool
exitIdle func()
}

func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
i.exitIdle()
if i.connectionRequested.CompareAndSwap(false, true) {
i.exitIdle()
}
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}

Expand Down Expand Up @@ -590,6 +594,15 @@ func (al *addressList) currentAddress() resolver.Address {
return al.addresses[al.idx]
}

// first returns the first address in the list. If the list is empty, it returns
// an empty address instead.
func (al *addressList) first() resolver.Address {
if len(al.addresses) == 0 {
return resolver.Address{}
}
return al.addresses[0]
}

func (al *addressList) reset() {
al.idx = 0
}
Expand Down
13 changes: 13 additions & 0 deletions balancer/pickfirstleaf/pickfirstleaf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,21 @@ func (s) TestAddressList_Iteration(t *testing.T) {
}

addressList := addressList{}
emptyAddress := resolver.Address{}
if got, want := addressList.first(), emptyAddress; got != want {
t.Fatalf("addressList.first() = %v, want %v", got, want)
}

addressList.updateAddrs(addrs)

if got, want := addressList.first(), addressList.currentAddress(); got != want {
t.Fatalf("addressList.first() = %v, want %v", got, want)
}

if got, want := addressList.first(), addrs[0]; got != want {
t.Fatalf("addressList.first() = %v, want %v", got, want)
}

for i := 0; i < len(addrs); i++ {
if got, want := addressList.isValid(), true; got != want {
t.Fatalf("addressList.isValid() = %t, want %t", got, want)
Expand Down

0 comments on commit f0e479e

Please sign in to comment.