From 1fb7e8ac9e1e4e4dadef52a5317e54a36a8eba14 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Fri, 19 May 2023 13:18:04 -0700 Subject: [PATCH 1/3] client: support a 1:1 mapping with acbws and addrConns --- balancer_conn_wrappers.go | 72 +++------------------- clientconn.go | 125 ++++++++++++++++++-------------------- picker_wrapper.go | 12 ++-- 3 files changed, 73 insertions(+), 136 deletions(-) diff --git a/balancer_conn_wrappers.go b/balancer_conn_wrappers.go index 4f9944697dde..7dea991d8d26 100644 --- a/balancer_conn_wrappers.go +++ b/balancer_conn_wrappers.go @@ -133,24 +133,9 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat // updateSubConnState is invoked by grpc to push a subConn state update to the // underlying balancer. func (ccb *ccBalancerWrapper) updateSubConnState(sc balancer.SubConn, s connectivity.State, err error) { - // When updating addresses for a SubConn, if the address in use is not in - // the new addresses, the old ac will be tearDown() and a new ac will be - // created. tearDown() generates a state change with Shutdown state, we - // don't want the balancer to receive this state change. So before - // tearDown() on the old ac, ac.acbw (acWrapper) will be set to nil, and - // this function will be called with (nil, Shutdown). We don't need to call - // balancer method in this case. - // - // TODO: Suppress the above mentioned state change to Shutdown, so we don't - // have to handle it here. - if sc == nil { - return - } - ccb.mu.Lock() ccb.serializer.Schedule(func(_ context.Context) { ccb.balancer.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: s, ConnectionError: err}) }) - ccb.mu.Unlock() } func (ccb *ccBalancerWrapper) resolverError(err error) { @@ -324,9 +309,7 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer return nil, err } acbw := &acBalancerWrapper{ac: ac, producers: make(map[balancer.ProducerBuilder]*refCountedProducer)} - acbw.ac.mu.Lock() ac.acbw = acbw - acbw.ac.mu.Unlock() return acbw, nil } @@ -347,7 +330,7 @@ func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) { if !ok { return } - ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) + ccb.cc.removeAddrConn(acbw.ac, errConnDrain) } func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { @@ -391,63 +374,24 @@ func (ccb *ccBalancerWrapper) Target() string { // acBalancerWrapper is a wrapper on top of ac for balancers. // It implements balancer.SubConn interface. type acBalancerWrapper struct { + ac *addrConn // read-only + mu sync.Mutex - ac *addrConn producers map[balancer.ProducerBuilder]*refCountedProducer } -func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { - acbw.mu.Lock() - defer acbw.mu.Unlock() - if len(addrs) <= 0 { - acbw.ac.cc.removeAddrConn(acbw.ac, errConnDrain) - return - } - if !acbw.ac.tryUpdateAddrs(addrs) { - cc := acbw.ac.cc - opts := acbw.ac.scopts - acbw.ac.mu.Lock() - // Set old ac.acbw to nil so the Shutdown state update will be ignored - // by balancer. - // - // TODO(bar) the state transition could be wrong when tearDown() old ac - // and creating new ac, fix the transition. - acbw.ac.acbw = nil - acbw.ac.mu.Unlock() - acState := acbw.ac.getState() - acbw.ac.cc.removeAddrConn(acbw.ac, errConnDrain) - - if acState == connectivity.Shutdown { - return - } +func (acbw *acBalancerWrapper) String() string { + return fmt.Sprintf("SubConn(id:%d)", acbw.ac.channelzID.Int()) +} - newAC, err := cc.newAddrConn(addrs, opts) - if err != nil { - channelz.Warningf(logger, acbw.ac.channelzID, "acBalancerWrapper: UpdateAddresses: failed to newAddrConn: %v", err) - return - } - acbw.ac = newAC - newAC.mu.Lock() - newAC.acbw = acbw - newAC.mu.Unlock() - if acState != connectivity.Idle { - go newAC.connect() - } - } +func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { + acbw.ac.updateAddrs(addrs) } func (acbw *acBalancerWrapper) Connect() { - acbw.mu.Lock() - defer acbw.mu.Unlock() go acbw.ac.connect() } -func (acbw *acBalancerWrapper) getAddrConn() *addrConn { - acbw.mu.Lock() - defer acbw.mu.Unlock() - return acbw.ac -} - // NewStream begins a streaming RPC on the addrConn. If the addrConn is not // ready, blocks until it is or ctx expires. Returns an error when the context // expires or the addrConn is shut down. diff --git a/clientconn.go b/clientconn.go index 1def61e5a23d..f2f4a18a28cd 100644 --- a/clientconn.go +++ b/clientconn.go @@ -970,9 +970,6 @@ func (ac *addrConn) connect() error { ac.mu.Unlock() return nil } - // Update connectivity state within the lock to prevent subsequent or - // concurrent calls from resetting the transport more than once. - ac.updateConnectivityState(connectivity.Connecting, nil) ac.mu.Unlock() ac.resetTransport() @@ -991,58 +988,53 @@ func equalAddresses(a, b []resolver.Address) bool { return true } -// tryUpdateAddrs tries to update ac.addrs with the new addresses list. -// -// If ac is TransientFailure, it updates ac.addrs and returns true. The updated -// addresses will be picked up by retry in the next iteration after backoff. -// -// If ac is Shutdown or Idle, it updates ac.addrs and returns true. -// -// If the addresses is the same as the old list, it does nothing and returns -// true. -// -// If ac is Connecting, it returns false. The caller should tear down the ac and -// create a new one. Note that the backoff will be reset when this happens. -// -// If ac is Ready, it checks whether current connected address of ac is in the -// new addrs list. -// - If true, it updates ac.addrs and returns true. The ac will keep using -// the existing connection. -// - If false, it does nothing and returns false. -func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool { +// updateAddrs updates ac.addrs with the new addresses list and handles active +// connections or connection attempts. +func (ac *addrConn) updateAddrs(addrs []resolver.Address) { ac.mu.Lock() - defer ac.mu.Unlock() - channelz.Infof(logger, ac.channelzID, "addrConn: tryUpdateAddrs curAddr: %v, addrs: %v", ac.curAddr, addrs) - if ac.state == connectivity.Shutdown || - ac.state == connectivity.TransientFailure || - ac.state == connectivity.Idle { - ac.addrs = addrs - return true - } + channelz.Infof(logger, ac.channelzID, "addrConn: updateAddrs curAddr: %v, addrs: %v", ac.curAddr, addrs) if equalAddresses(ac.addrs, addrs) { - return true + ac.mu.Unlock() + return } - if ac.state == connectivity.Connecting { - return false + ac.addrs = addrs + + if ac.state == connectivity.Shutdown || + ac.state == connectivity.TransientFailure || + ac.state == connectivity.Idle { + // We were not connecting, so do nothing but update the addresses. + ac.mu.Unlock() + return } - // ac.state is Ready, try to find the connected address. - var curAddrFound bool - for _, a := range addrs { - a.ServerName = ac.cc.getServerName(a) - if reflect.DeepEqual(ac.curAddr, a) { - curAddrFound = true - break + if ac.state == connectivity.Ready { + // try to find the connected address. + for _, a := range addrs { + a.ServerName = ac.cc.getServerName(a) + if reflect.DeepEqual(ac.curAddr, a) { + // We are connected to a valid address, so do nothing bu update + // the addresses. + ac.mu.Unlock() + return + } } } - channelz.Infof(logger, ac.channelzID, "addrConn: tryUpdateAddrs curAddrFound: %v", curAddrFound) - if curAddrFound { - ac.addrs = addrs - } - return curAddrFound + // We are either connected to the wrong address or currently connecting. + // Stop the current iteration and restart. + + ac.cancel() + ac.ctx, ac.cancel = context.WithCancel(ac.cc.ctx) + + curTr := ac.transport + ac.transport = nil + ac.mu.Unlock() + curTr.GracefulClose() + // Since we were connecting/connected, we should start a new connection + // attempt. + go ac.resetTransport() } // getServerName determines the serverName to be used in the connection @@ -1301,7 +1293,8 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) { func (ac *addrConn) resetTransport() { ac.mu.Lock() - if ac.state == connectivity.Shutdown { + acCtx := ac.ctx + if acCtx.Err() != nil { ac.mu.Unlock() return } @@ -1329,15 +1322,14 @@ func (ac *addrConn) resetTransport() { ac.updateConnectivityState(connectivity.Connecting, nil) ac.mu.Unlock() - if err := ac.tryAllAddrs(addrs, connectDeadline); err != nil { + if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil { ac.cc.resolveNow(resolver.ResolveNowOptions{}) // After exhausting all addresses, the addrConn enters // TRANSIENT_FAILURE. - ac.mu.Lock() - if ac.state == connectivity.Shutdown { - ac.mu.Unlock() + if acCtx.Err() != nil { return } + ac.mu.Lock() ac.updateConnectivityState(connectivity.TransientFailure, err) // Backoff. @@ -1352,13 +1344,13 @@ func (ac *addrConn) resetTransport() { ac.mu.Unlock() case <-b: timer.Stop() - case <-ac.ctx.Done(): + case <-acCtx.Done(): timer.Stop() return } ac.mu.Lock() - if ac.state != connectivity.Shutdown { + if acCtx.Err() == nil { ac.updateConnectivityState(connectivity.Idle, err) } ac.mu.Unlock() @@ -1373,14 +1365,13 @@ func (ac *addrConn) resetTransport() { // tryAllAddrs tries to creates a connection to the addresses, and stop when at // the first successful one. It returns an error if no address was successfully // connected, or updates ac appropriately with the new transport. -func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.Time) error { +func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, connectDeadline time.Time) error { var firstConnErr error for _, addr := range addrs { - ac.mu.Lock() - if ac.state == connectivity.Shutdown { - ac.mu.Unlock() + if ctx.Err() != nil { return errConnClosing } + ac.mu.Lock() ac.cc.mu.RLock() ac.dopts.copts.KeepaliveParams = ac.cc.mkp @@ -1394,7 +1385,7 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T channelz.Infof(logger, ac.channelzID, "Subchannel picks a new address %q to connect", addr.Addr) - err := ac.createTransport(addr, copts, connectDeadline) + err := ac.createTransport(ctx, addr, copts, connectDeadline) if err == nil { return nil } @@ -1411,19 +1402,20 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T // createTransport creates a connection to addr. It returns an error if the // address was not successfully connected, or updates ac appropriately with the // new transport. -func (ac *addrConn) createTransport(addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error { +func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error { addr.ServerName = ac.cc.getServerName(addr) - hctx, hcancel := context.WithCancel(ac.ctx) + hctx, hcancel := context.WithCancel(ctx) onClose := func(r transport.GoAwayReason) { ac.mu.Lock() defer ac.mu.Unlock() // adjust params based on GoAwayReason ac.adjustParams(r) - if ac.state == connectivity.Shutdown { - // Already shut down. tearDown() already cleared the transport and - // canceled hctx via ac.ctx, and we expected this connection to be - // closed, so do nothing here. + if ctx.Err() != nil { + // Already shut down or connection attempt canceled. tearDown() or + // updateAddrs() already cleared the transport and canceled hctx + // via ac.ctx, and we expected this connection to be closed, so do + // nothing here. return } hcancel() @@ -1442,7 +1434,7 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne ac.updateConnectivityState(connectivity.Idle, nil) } - connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline) + connectCtx, cancel := context.WithDeadline(ctx, connectDeadline) defer cancel() copts.ChannelzParentID = ac.channelzID @@ -1459,7 +1451,7 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne ac.mu.Lock() defer ac.mu.Unlock() - if ac.state == connectivity.Shutdown { + if ctx.Err() != nil { // This can happen if the subConn was removed while in `Connecting` // state. tearDown() would have set the state to `Shutdown`, but // would not have closed the transport since ac.transport would not @@ -1471,6 +1463,9 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne // The error we pass to Close() is immaterial since there are no open // streams at this point, so no trailers with error details will be sent // out. We just need to pass a non-nil error. + // + // This can also happen when updateAddrs is called during a connection + // attempt. go newTr.Close(transport.ErrConnClosing) return nil } diff --git a/picker_wrapper.go b/picker_wrapper.go index 8e24d864986d..02f975951242 100644 --- a/picker_wrapper.go +++ b/picker_wrapper.go @@ -68,10 +68,8 @@ func (pw *pickerWrapper) updatePicker(p balancer.Picker) { // - wraps the done function in the passed in result to increment the calls // failed or calls succeeded channelz counter before invoking the actual // done function. -func doneChannelzWrapper(acw *acBalancerWrapper, result *balancer.PickResult) { - acw.mu.Lock() - ac := acw.ac - acw.mu.Unlock() +func doneChannelzWrapper(acbw *acBalancerWrapper, result *balancer.PickResult) { + ac := acbw.ac ac.incrCallsStarted() done := result.Done result.Done = func(b balancer.DoneInfo) { @@ -157,14 +155,14 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. return nil, balancer.PickResult{}, status.Error(codes.Unavailable, err.Error()) } - acw, ok := pickResult.SubConn.(*acBalancerWrapper) + acbw, ok := pickResult.SubConn.(*acBalancerWrapper) if !ok { logger.Errorf("subconn returned from pick is type %T, not *acBalancerWrapper", pickResult.SubConn) continue } - if t := acw.getAddrConn().getReadyTransport(); t != nil { + if t := acbw.ac.getReadyTransport(); t != nil { if channelz.IsOn() { - doneChannelzWrapper(acw, &pickResult) + doneChannelzWrapper(acbw, &pickResult) return t, pickResult, nil } return t, pickResult, nil From 5021fdf80ad79afa2b70258cc7712e369f0623b9 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Fri, 19 May 2023 15:44:46 -0700 Subject: [PATCH 2/3] review comments --- balancer_conn_wrappers.go | 2 +- clientconn.go | 20 +++++++++++++------- stream.go | 17 +++++++++++------ 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/balancer_conn_wrappers.go b/balancer_conn_wrappers.go index 7dea991d8d26..6814578bbbf8 100644 --- a/balancer_conn_wrappers.go +++ b/balancer_conn_wrappers.go @@ -300,7 +300,7 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer return nil, fmt.Errorf("grpc: cannot create SubConn when balancer is closed or idle") } - if len(addrs) <= 0 { + if len(addrs) == 0 { return nil, fmt.Errorf("grpc: cannot create SubConn with empty address list") } ac, err := ccb.cc.newAddrConn(addrs, opts) diff --git a/clientconn.go b/clientconn.go index f2f4a18a28cd..5e45f01f91cf 100644 --- a/clientconn.go +++ b/clientconn.go @@ -24,7 +24,6 @@ import ( "fmt" "math" "net/url" - "reflect" "strings" "sync" "sync/atomic" @@ -1010,12 +1009,12 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) { } if ac.state == connectivity.Ready { - // try to find the connected address. + // Try to find the connected address. for _, a := range addrs { a.ServerName = ac.cc.getServerName(a) - if reflect.DeepEqual(ac.curAddr, a) { - // We are connected to a valid address, so do nothing bu update - // the addresses. + if a.Equal(ac.curAddr) { + // We are connected to a valid address, so do nothing but + // update the addresses. ac.mu.Unlock() return } @@ -1028,10 +1027,17 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) { ac.cancel() ac.ctx, ac.cancel = context.WithCancel(ac.cc.ctx) - curTr := ac.transport + // We have to defer here because GracefulClose => Close => onClose, which + // requires locking ac.mu. + defer ac.transport.GracefulClose() ac.transport = nil + + if len(addrs) == 0 { + ac.updateConnectivityState(connectivity.Idle, nil) + } + ac.mu.Unlock() - curTr.GracefulClose() + // Since we were connecting/connected, we should start a new connection // attempt. go ac.resetTransport() diff --git a/stream.go b/stream.go index 75ab86268ba1..10092685b228 100644 --- a/stream.go +++ b/stream.go @@ -1273,14 +1273,19 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin as.p = &parser{r: s} ac.incrCallsStarted() if desc != unaryStreamDesc { - // Listen on cc and stream contexts to cleanup when the user closes the - // ClientConn or cancels the stream context. In all other cases, an error - // should already be injected into the recv buffer by the transport, which - // the client will eventually receive, and then we will cancel the stream's - // context in clientStream.finish. + // Listen on stream context to cleanup when the stream context is + // canceled. Also listen for the addrConn's context in case the + // addrConn is closed or reconnects to a different address. In all + // other cases, an error should already be injected into the recv + // buffer by the transport, which the client will eventually receive, + // and then we will cancel the stream's context in + // addrConnStream.finish. go func() { + ac.mu.Lock() + acCtx := ac.ctx + ac.mu.Unlock() select { - case <-ac.ctx.Done(): + case <-acCtx.Done(): as.finish(status.Error(codes.Canceled, "grpc: the SubConn is closing")) case <-ctx.Done(): as.finish(toRPCErr(ctx.Err())) From 23ff0397cdc1695b6566a82b6cf11ad12cc10451 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Mon, 22 May 2023 13:06:46 -0700 Subject: [PATCH 3/3] rebase fixes --- balancer_conn_wrappers.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/balancer_conn_wrappers.go b/balancer_conn_wrappers.go index 6814578bbbf8..4a4dea189d0e 100644 --- a/balancer_conn_wrappers.go +++ b/balancer_conn_wrappers.go @@ -133,9 +133,11 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat // updateSubConnState is invoked by grpc to push a subConn state update to the // underlying balancer. func (ccb *ccBalancerWrapper) updateSubConnState(sc balancer.SubConn, s connectivity.State, err error) { + ccb.mu.Lock() ccb.serializer.Schedule(func(_ context.Context) { ccb.balancer.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: s, ConnectionError: err}) }) + ccb.mu.Unlock() } func (ccb *ccBalancerWrapper) resolverError(err error) {