diff --git a/client/grpcutil/grpcutil.go b/client/grpcutil/grpcutil.go index 169eb2a4634..59b7224f29e 100644 --- a/client/grpcutil/grpcutil.go +++ b/client/grpcutil/grpcutil.go @@ -90,12 +90,13 @@ func GetOrCreateGRPCConn(ctx context.Context, clientConns *sync.Map, addr string if err != nil { return nil, err } - existing, loaded := clientConns.LoadOrStore(addr, cc) + conn, loaded := clientConns.LoadOrStore(addr, cc) if !loaded { // Successfully stored the connection. return cc, nil } cc.Close() + cc = conn.(*grpc.ClientConn) log.Debug("use existing connection", zap.String("target", cc.Target()), zap.String("state", cc.GetState().String())) - return existing.(*grpc.ClientConn), nil + return cc, nil } diff --git a/server/grpc_service.go b/server/grpc_service.go index 57715789033..5f323b667b8 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -202,17 +202,9 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { return errors.WithStack(err) } - streamCtx := stream.Context() - forwardedHost := grpcutil.GetForwardedHost(streamCtx) - if s.IsAPIServiceMode() || !s.isLocalRequest(forwardedHost) { - if s.IsAPIServiceMode() { - var ok bool - forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) - if !ok || len(forwardedHost) == 0 { - return ErrNotFoundTSOAddr - } - } - + if forwardedHost, err := s.getForwardedHost(ctx, stream.Context()); err != nil { + return err + } else if len(forwardedHost) > 0 { clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) if err != nil { return errors.WithStack(err) @@ -261,6 +253,19 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { } } +func (s *GrpcServer) getForwardedHost(ctx, streamCtx context.Context) (forwardedHost string, err error) { + if s.IsAPIServiceMode() { + var ok bool + forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) + if !ok || len(forwardedHost) == 0 { + return "", ErrNotFoundTSOAddr + } + } else if fh := grpcutil.GetForwardedHost(streamCtx); !s.isLocalRequest(fh) { + forwardedHost = fh + } + return forwardedHost, nil +} + // Bootstrap implements gRPC PDServer. func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapRequest) (*pdpb.BootstrapResponse, error) { fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) {