From 689d5a9c119d479d33660f82e2282166f77cf7d1 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Fri, 19 May 2023 13:18:04 -0700 Subject: [PATCH] client: support a 1:1 mapping with acbws and addrConns --- balancer_conn_wrappers.go | 70 +++------------------- clientconn.go | 122 +++++++++++++++++++------------------- picker_wrapper.go | 12 ++-- 3 files changed, 73 insertions(+), 131 deletions(-) diff --git a/balancer_conn_wrappers.go b/balancer_conn_wrappers.go index 1865a3f09c2b..fdd801305951 100644 --- a/balancer_conn_wrappers.go +++ b/balancer_conn_wrappers.go @@ -107,19 +107,6 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat // updateSubConnState is invoked by grpc to push a subConn state update to the // underlying balancer. func (ccb *ccBalancerWrapper) updateSubConnState(sc balancer.SubConn, s connectivity.State, err error) { - // When updating addresses for a SubConn, if the address in use is not in - // the new addresses, the old ac will be tearDown() and a new ac will be - // created. tearDown() generates a state change with Shutdown state, we - // don't want the balancer to receive this state change. So before - // tearDown() on the old ac, ac.acbw (acWrapper) will be set to nil, and - // this function will be called with (nil, Shutdown). We don't need to call - // balancer method in this case. - // - // TODO: Suppress the above mentioned state change to Shutdown, so we don't - // have to handle it here. - if sc == nil { - return - } ccb.serializer.Schedule(func(_ context.Context) { ccb.balancer.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: s, ConnectionError: err}) }) @@ -193,9 +180,7 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer return nil, err } acbw := &acBalancerWrapper{ac: ac, producers: make(map[balancer.ProducerBuilder]*refCountedProducer)} - acbw.ac.mu.Lock() ac.acbw = acbw - acbw.ac.mu.Unlock() return acbw, nil } @@ -204,7 +189,7 @@ func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) { if !ok { return } - ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) + ccb.cc.removeAddrConn(acbw.ac, errConnDrain) } func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { @@ -236,63 +221,24 @@ func (ccb *ccBalancerWrapper) Target() string { // acBalancerWrapper is a wrapper on top of ac for balancers. // It implements balancer.SubConn interface. type acBalancerWrapper struct { + ac *addrConn // read-only + mu sync.Mutex - ac *addrConn producers map[balancer.ProducerBuilder]*refCountedProducer } -func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { - acbw.mu.Lock() - defer acbw.mu.Unlock() - if len(addrs) <= 0 { - acbw.ac.cc.removeAddrConn(acbw.ac, errConnDrain) - return - } - if !acbw.ac.tryUpdateAddrs(addrs) { - cc := acbw.ac.cc - opts := acbw.ac.scopts - acbw.ac.mu.Lock() - // Set old ac.acbw to nil so the Shutdown state update will be ignored - // by balancer. - // - // TODO(bar) the state transition could be wrong when tearDown() old ac - // and creating new ac, fix the transition. - acbw.ac.acbw = nil - acbw.ac.mu.Unlock() - acState := acbw.ac.getState() - acbw.ac.cc.removeAddrConn(acbw.ac, errConnDrain) - - if acState == connectivity.Shutdown { - return - } +func (acbw *acBalancerWrapper) String() string { + return fmt.Sprintf("SubConn(id:%d)", acbw.ac.channelzID.Int()) +} - newAC, err := cc.newAddrConn(addrs, opts) - if err != nil { - channelz.Warningf(logger, acbw.ac.channelzID, "acBalancerWrapper: UpdateAddresses: failed to newAddrConn: %v", err) - return - } - acbw.ac = newAC - newAC.mu.Lock() - newAC.acbw = acbw - newAC.mu.Unlock() - if acState != connectivity.Idle { - go newAC.connect() - } - } +func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { + acbw.ac.updateAddrs(addrs) } func (acbw *acBalancerWrapper) Connect() { - acbw.mu.Lock() - defer acbw.mu.Unlock() go acbw.ac.connect() } -func (acbw *acBalancerWrapper) getAddrConn() *addrConn { - acbw.mu.Lock() - defer acbw.mu.Unlock() - return acbw.ac -} - // NewStream begins a streaming RPC on the addrConn. If the addrConn is not // ready, blocks until it is or ctx expires. Returns an error when the context // expires or the addrConn is shut down. diff --git a/clientconn.go b/clientconn.go index 50d08a49a205..1ffcf969d17b 100644 --- a/clientconn.go +++ b/clientconn.go @@ -836,9 +836,6 @@ func (ac *addrConn) connect() error { ac.mu.Unlock() return nil } - // Update connectivity state within the lock to prevent subsequent or - // concurrent calls from resetting the transport more than once. - ac.updateConnectivityState(connectivity.Connecting, nil) ac.mu.Unlock() ac.resetTransport() @@ -857,58 +854,53 @@ func equalAddresses(a, b []resolver.Address) bool { return true } -// tryUpdateAddrs tries to update ac.addrs with the new addresses list. -// -// If ac is TransientFailure, it updates ac.addrs and returns true. The updated -// addresses will be picked up by retry in the next iteration after backoff. -// -// If ac is Shutdown or Idle, it updates ac.addrs and returns true. -// -// If the addresses is the same as the old list, it does nothing and returns -// true. -// -// If ac is Connecting, it returns false. The caller should tear down the ac and -// create a new one. Note that the backoff will be reset when this happens. -// -// If ac is Ready, it checks whether current connected address of ac is in the -// new addrs list. -// - If true, it updates ac.addrs and returns true. The ac will keep using -// the existing connection. -// - If false, it does nothing and returns false. -func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool { +// updateAddrs updates ac.addrs with the new addresses list and handles active +// connections or connection attempts. +func (ac *addrConn) updateAddrs(addrs []resolver.Address) { ac.mu.Lock() - defer ac.mu.Unlock() - channelz.Infof(logger, ac.channelzID, "addrConn: tryUpdateAddrs curAddr: %v, addrs: %v", ac.curAddr, addrs) - if ac.state == connectivity.Shutdown || - ac.state == connectivity.TransientFailure || - ac.state == connectivity.Idle { - ac.addrs = addrs - return true - } + channelz.Infof(logger, ac.channelzID, "addrConn: updateAddrs curAddr: %v, addrs: %v", ac.curAddr, addrs) if equalAddresses(ac.addrs, addrs) { - return true + ac.mu.Unlock() + return } - if ac.state == connectivity.Connecting { - return false + ac.addrs = addrs + + if ac.state == connectivity.Shutdown || + ac.state == connectivity.TransientFailure || + ac.state == connectivity.Idle { + // We were not connecting, so do nothing but update the addresses. + ac.mu.Unlock() + return } - // ac.state is Ready, try to find the connected address. - var curAddrFound bool - for _, a := range addrs { - a.ServerName = ac.cc.getServerName(a) - if reflect.DeepEqual(ac.curAddr, a) { - curAddrFound = true - break + if ac.state == connectivity.Ready { + // try to find the connected address. + for _, a := range addrs { + a.ServerName = ac.cc.getServerName(a) + if reflect.DeepEqual(ac.curAddr, a) { + // We are connected to a valid address, so do nothing bu update + // the addresses. + ac.mu.Unlock() + return + } } } - channelz.Infof(logger, ac.channelzID, "addrConn: tryUpdateAddrs curAddrFound: %v", curAddrFound) - if curAddrFound { - ac.addrs = addrs - } - return curAddrFound + // We are either connected to the wrong address or currently connecting. + // Stop the current iteration and restart. + + ac.cancel() + ac.ctx, ac.cancel = context.WithCancel(ac.cc.ctx) + + curTr := ac.transport + ac.transport = nil + ac.mu.Unlock() + curTr.GracefulClose() + // Since we were connecting/connected, we should start a new connection + // attempt. + go ac.resetTransport() } // getServerName determines the serverName to be used in the connection @@ -1166,7 +1158,7 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) { func (ac *addrConn) resetTransport() { ac.mu.Lock() - if ac.state == connectivity.Shutdown { + if ac.ctx.Err() != nil { ac.mu.Unlock() return } @@ -1192,17 +1184,17 @@ func (ac *addrConn) resetTransport() { connectDeadline := time.Now().Add(dialDuration) ac.updateConnectivityState(connectivity.Connecting, nil) + acCtx := ac.ctx ac.mu.Unlock() - if err := ac.tryAllAddrs(addrs, connectDeadline); err != nil { + if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil { ac.cc.resolveNow(resolver.ResolveNowOptions{}) // After exhausting all addresses, the addrConn enters // TRANSIENT_FAILURE. - ac.mu.Lock() - if ac.state == connectivity.Shutdown { - ac.mu.Unlock() + if acCtx.Err() != nil { return } + ac.mu.Lock() ac.updateConnectivityState(connectivity.TransientFailure, err) // Backoff. @@ -1217,13 +1209,13 @@ func (ac *addrConn) resetTransport() { ac.mu.Unlock() case <-b: timer.Stop() - case <-ac.ctx.Done(): + case <-acCtx.Done(): timer.Stop() return } ac.mu.Lock() - if ac.state != connectivity.Shutdown { + if acCtx.Err() == nil { ac.updateConnectivityState(connectivity.Idle, err) } ac.mu.Unlock() @@ -1238,11 +1230,11 @@ func (ac *addrConn) resetTransport() { // tryAllAddrs tries to creates a connection to the addresses, and stop when at // the first successful one. It returns an error if no address was successfully // connected, or updates ac appropriately with the new transport. -func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.Time) error { +func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, connectDeadline time.Time) error { var firstConnErr error for _, addr := range addrs { ac.mu.Lock() - if ac.state == connectivity.Shutdown { + if ac.ctx.Err() != nil { ac.mu.Unlock() return errConnClosing } @@ -1259,7 +1251,7 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T channelz.Infof(logger, ac.channelzID, "Subchannel picks a new address %q to connect", addr.Addr) - err := ac.createTransport(addr, copts, connectDeadline) + err := ac.createTransport(ctx, addr, copts, connectDeadline) if err == nil { return nil } @@ -1276,19 +1268,20 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T // createTransport creates a connection to addr. It returns an error if the // address was not successfully connected, or updates ac appropriately with the // new transport. -func (ac *addrConn) createTransport(addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error { +func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error { addr.ServerName = ac.cc.getServerName(addr) - hctx, hcancel := context.WithCancel(ac.ctx) + hctx, hcancel := context.WithCancel(ctx) onClose := func(r transport.GoAwayReason) { ac.mu.Lock() defer ac.mu.Unlock() // adjust params based on GoAwayReason ac.adjustParams(r) - if ac.state == connectivity.Shutdown { - // Already shut down. tearDown() already cleared the transport and - // canceled hctx via ac.ctx, and we expected this connection to be - // closed, so do nothing here. + if ctx.Err() != nil { + // Already shut down or connection attempt canceled. tearDown() or + // updateAddrs() already cleared the transport and canceled hctx + // via ac.ctx, and we expected this connection to be closed, so do + // nothing here. return } hcancel() @@ -1307,7 +1300,7 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne ac.updateConnectivityState(connectivity.Idle, nil) } - connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline) + connectCtx, cancel := context.WithDeadline(ctx, connectDeadline) defer cancel() copts.ChannelzParentID = ac.channelzID @@ -1346,6 +1339,11 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne ac.updateConnectivityState(connectivity.Idle, nil) return nil } + if ctx.Err() != nil { + // updateAddrs stopped this connection attempt just after it completed. + // Pretend it didn't happen. + return nil + } ac.curAddr = addr ac.transport = newTr ac.startHealthCheck(hctx) // Will set state to READY if appropriate. diff --git a/picker_wrapper.go b/picker_wrapper.go index c525dc070fc6..b05dc4b043cf 100644 --- a/picker_wrapper.go +++ b/picker_wrapper.go @@ -63,10 +63,8 @@ func (pw *pickerWrapper) updatePicker(p balancer.Picker) { // - wraps the done function in the passed in result to increment the calls // failed or calls succeeded channelz counter before invoking the actual // done function. -func doneChannelzWrapper(acw *acBalancerWrapper, result *balancer.PickResult) { - acw.mu.Lock() - ac := acw.ac - acw.mu.Unlock() +func doneChannelzWrapper(acbw *acBalancerWrapper, result *balancer.PickResult) { + ac := acbw.ac ac.incrCallsStarted() done := result.Done result.Done = func(b balancer.DoneInfo) { @@ -152,14 +150,14 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. return nil, balancer.PickResult{}, status.Error(codes.Unavailable, err.Error()) } - acw, ok := pickResult.SubConn.(*acBalancerWrapper) + acbw, ok := pickResult.SubConn.(*acBalancerWrapper) if !ok { logger.Errorf("subconn returned from pick is type %T, not *acBalancerWrapper", pickResult.SubConn) continue } - if t := acw.getAddrConn().getReadyTransport(); t != nil { + if t := acbw.ac.getReadyTransport(); t != nil { if channelz.IsOn() { - doneChannelzWrapper(acw, &pickResult) + doneChannelzWrapper(acbw, &pickResult) return t, pickResult, nil } return t, pickResult, nil