diff --git a/pkg/utils/grpcutil/grpcutil.go b/pkg/utils/grpcutil/grpcutil.go index 1bfb64868f34..9823adf2cfbe 100644 --- a/pkg/utils/grpcutil/grpcutil.go +++ b/pkg/utils/grpcutil/grpcutil.go @@ -252,6 +252,9 @@ func CheckStream(ctx context.Context, cancel context.CancelFunc, done chan struc // NeedRebuildConnection checks if the error is a connection error. func NeedRebuildConnection(err error) bool { + if err == nil { + return false + } return err == io.EOF || strings.Contains(err.Error(), codes.Unavailable.String()) || // Unavailable indicates the service is currently unavailable. This is a most likely a transient condition. strings.Contains(err.Error(), codes.DeadlineExceeded.String()) || // DeadlineExceeded means operation expired before completion. diff --git a/server/forward.go b/server/forward.go index 65750fcd4be9..2d379cfee2a5 100644 --- a/server/forward.go +++ b/server/forward.go @@ -89,6 +89,7 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { forwardStream tsopb.TSO_TsoClient forwardCtx context.Context cancelForward context.CancelFunc + tsoStreamErr error lastForwardedHost string ) defer func() { @@ -96,6 +97,9 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { if cancelForward != nil { cancelForward() } + if grpcutil.NeedRebuildConnection(tsoStreamErr) { + s.closeDelegateClient(lastForwardedHost) + } }() maxConcurrentTSOProxyStreamings := int32(s.GetMaxConcurrentTSOProxyStreamings()) @@ -131,26 +135,31 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), utils.TSOServiceName) if !ok || len(forwardedHost) == 0 { - return errors.WithStack(ErrNotFoundTSOAddr) + tsoStreamErr = errors.WithStack(ErrNotFoundTSOAddr) + return tsoStreamErr } if forwardStream == nil || lastForwardedHost != forwardedHost { if cancelForward != nil { cancelForward() } + s.closeDelegateClient(lastForwardedHost) clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) if err != nil { - return errors.WithStack(err) + tsoStreamErr = errors.WithStack(err) + return tsoStreamErr } forwardStream, forwardCtx, cancelForward, err = s.createTSOForwardStream(stream.Context(), clientConn) if err != nil { - return errors.WithStack(err) + tsoStreamErr = errors.WithStack(err) + return tsoStreamErr } lastForwardedHost = forwardedHost } tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh) if err != nil { + tsoStreamErr = errors.WithStack(err) return errors.WithStack(err) } @@ -363,25 +372,12 @@ func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string return conn.(*grpc.ClientConn), nil } -func (s *GrpcServer) getForwardedHost(ctx, streamCtx context.Context, serviceName ...string) (forwardedHost string, err error) { - if s.IsAPIServiceMode() { - var ok bool - if len(serviceName) == 0 { - return "", ErrNotFoundService - } - forwardedHost, ok = s.GetServicePrimaryAddr(ctx, serviceName[0]) - if !ok || len(forwardedHost) == 0 { - switch serviceName[0] { - case utils.TSOServiceName: - return "", ErrNotFoundTSOAddr - case utils.SchedulingServiceName: - return "", ErrNotFoundSchedulingAddr - } - } - } else if fh := grpcutil.GetForwardedHost(streamCtx); !s.isLocalRequest(fh) { - forwardedHost = fh +func (s *GrpcServer) closeDelegateClient(forwardedHost string) { + client, ok := s.clientConns.LoadAndDelete(forwardedHost) + if !ok { + return } - return forwardedHost, nil + client.(*grpc.ClientConn).Close() } func (s *GrpcServer) isLocalRequest(host string) bool { diff --git a/server/grpc_service.go b/server/grpc_service.go index ef7020f7fee8..8c0349d6d1a9 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -547,9 +547,8 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { return errors.WithStack(err) } - if forwardedHost, err := s.getForwardedHost(ctx, stream.Context(), utils.TSOServiceName); err != nil { - return err - } else if len(forwardedHost) > 0 { + forwardedHost := grpcutil.GetForwardedHost(stream.Context()) + if !s.isLocalRequest(forwardedHost) { clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) if err != nil { return errors.WithStack(err) @@ -1332,6 +1331,8 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error if cancel != nil { cancel() } + + s.closeDelegateClient(lastForwardedSchedulingHost) client, err := s.getDelegateClient(s.ctx, forwardedSchedulingHost) if err != nil { errRegionHeartbeatClient.Inc() @@ -1370,6 +1371,9 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error } if err := forwardSchedulingStream.Send(schedulingpbReq); err != nil { forwardSchedulingStream = nil + if grpcutil.NeedRebuildConnection(err) { + s.closeDelegateClient(lastForwardedSchedulingHost) + } errRegionHeartbeatSend.Inc() log.Error("failed to send request to scheduling service", zap.Error(err)) }