Skip to content

Commit

Permalink
server: fix the race on (*Server).clusterID (#7773)
Browse files Browse the repository at this point in the history
close #7772

Fix the race on `(*Server).clusterID`.

Signed-off-by: JmPotato <ghzpotato@gmail.com>
  • Loading branch information
JmPotato authored Jan 29, 2024
1 parent bdbf328 commit e8da033
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
2 changes: 1 addition & 1 deletion server/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) {
}
request := &tsopb.TsoRequest{
Header: &tsopb.RequestHeader{
ClusterId: s.clusterID,
ClusterId: s.ClusterID(),
KeyspaceId: utils.DefaultKeyspaceID,
KeyspaceGroupId: utils.DefaultKeyspaceGroupID,
},
Expand Down
15 changes: 8 additions & 7 deletions server/grpc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,9 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
if s.IsClosed() {
return status.Errorf(codes.Unknown, "server not started")
}
if request.GetHeader().GetClusterId() != s.clusterID {
if clusterID := s.ClusterID(); request.GetHeader().GetClusterId() != clusterID {
return status.Errorf(codes.FailedPrecondition,
"mismatch cluster id, need %d but got %d", s.clusterID, request.GetHeader().GetClusterId())
"mismatch cluster id, need %d but got %d", clusterID, request.GetHeader().GetClusterId())
}
count := request.GetCount()
ctx, task := trace.NewTask(ctx, "tso")
Expand Down Expand Up @@ -2276,17 +2276,18 @@ func (s *GrpcServer) validateRoleInRequest(ctx context.Context, header *pdpb.Req
}
*allowFollower = true
}
if header.GetClusterId() != s.clusterID {
return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", s.clusterID, header.GetClusterId())
if clusterID := s.ClusterID(); header.GetClusterId() != clusterID {
return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", clusterID, header.GetClusterId())
}
return nil
}

func (s *GrpcServer) header() *pdpb.ResponseHeader {
if s.clusterID == 0 {
clusterID := s.ClusterID()
if clusterID == 0 {
return s.wrapErrorToHeader(pdpb.ErrorType_NOT_BOOTSTRAPPED, "cluster id is not ready")
}
return &pdpb.ResponseHeader{ClusterId: s.clusterID}
return &pdpb.ResponseHeader{ClusterId: clusterID}
}

func (s *GrpcServer) wrapErrorToHeader(errorType pdpb.ErrorType, message string) *pdpb.ResponseHeader {
Expand All @@ -2298,7 +2299,7 @@ func (s *GrpcServer) wrapErrorToHeader(errorType pdpb.ErrorType, message string)

// errorHeader wraps the given error into a response header tagged with the
// current cluster ID.
func (s *GrpcServer) errorHeader(err *pdpb.Error) *pdpb.ResponseHeader {
	return &pdpb.ResponseHeader{
		// ClusterID() performs an atomic load, avoiding a race with startup.
		ClusterId: s.ClusterID(),
		Error:     err,
	}
}
Expand Down
32 changes: 17 additions & 15 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,9 @@ type Server struct {
electionClient *clientv3.Client
// http client
httpClient *http.Client
clusterID uint64 // pd cluster id.
rootPath string
// PD cluster ID.
clusterID atomic.Uint64
rootPath string

// Server services.
// for id allocator, we can use one allocator for
Expand Down Expand Up @@ -425,17 +426,18 @@ func (s *Server) AddStartCallback(callbacks ...func()) {
}

func (s *Server) startServer(ctx context.Context) error {
var err error
if s.clusterID, err = etcdutil.InitClusterID(s.client, pdClusterIDPath); err != nil {
clusterID, err := etcdutil.InitClusterID(s.client, pdClusterIDPath)
if err != nil {
log.Error("failed to init cluster id", errs.ZapError(err))
return err
}
log.Info("init cluster id", zap.Uint64("cluster-id", s.clusterID))
s.clusterID.Store(clusterID)
log.Info("init cluster id", zap.Uint64("cluster-id", clusterID))
// It may lose accuracy if use float64 to store uint64. So we store the cluster id in label.
metadataGauge.WithLabelValues(fmt.Sprintf("cluster%d", s.clusterID)).Set(0)
metadataGauge.WithLabelValues(fmt.Sprintf("cluster%d", clusterID)).Set(0)
bs.ServerInfoGauge.WithLabelValues(versioninfo.PDReleaseVersion, versioninfo.PDGitHash).Set(float64(time.Now().Unix()))

s.rootPath = endpoint.PDRootPath(s.clusterID)
s.rootPath = endpoint.PDRootPath(clusterID)
s.member.InitMemberInfo(s.cfg.AdvertiseClientUrls, s.cfg.AdvertisePeerUrls, s.Name(), s.rootPath)
s.member.SetMemberDeployPath(s.member.ID())
s.member.SetMemberBinaryVersion(s.member.ID(), versioninfo.PDReleaseVersion)
Expand Down Expand Up @@ -478,7 +480,7 @@ func (s *Server) startServer(ctx context.Context) error {

s.gcSafePointManager = gc.NewSafePointManager(s.storage, s.cfg.PDServerCfg)
s.basicCluster = core.NewBasicCluster()
s.cluster = cluster.NewRaftCluster(ctx, s.clusterID, s.GetBasicCluster(), s.GetStorage(), syncer.NewRegionSyncer(s), s.client, s.httpClient)
s.cluster = cluster.NewRaftCluster(ctx, clusterID, s.GetBasicCluster(), s.GetStorage(), syncer.NewRegionSyncer(s), s.client, s.httpClient)
keyspaceIDAllocator := id.NewAllocator(&id.AllocatorParams{
Client: s.client,
RootPath: s.rootPath,
Expand All @@ -488,11 +490,11 @@ func (s *Server) startServer(ctx context.Context) error {
Step: keyspace.AllocStep,
})
if s.IsAPIServiceMode() {
s.keyspaceGroupManager = keyspace.NewKeyspaceGroupManager(s.ctx, s.storage, s.client, s.clusterID)
s.keyspaceGroupManager = keyspace.NewKeyspaceGroupManager(s.ctx, s.storage, s.client, clusterID)
}
s.keyspaceManager = keyspace.NewKeyspaceManager(s.ctx, s.storage, s.cluster, keyspaceIDAllocator, &s.cfg.Keyspace, s.keyspaceGroupManager)
s.safePointV2Manager = gc.NewSafePointManagerV2(s.ctx, s.storage, s.storage, s.storage)
s.hbStreams = hbstream.NewHeartbeatStreams(ctx, s.clusterID, "", s.cluster)
s.hbStreams = hbstream.NewHeartbeatStreams(ctx, clusterID, "", s.cluster)
// initial hot_region_storage in here.

s.hotRegionStorage, err = storage.NewHotRegionsStorage(
Expand Down Expand Up @@ -685,7 +687,7 @@ func (s *Server) collectEtcdStateMetrics() {
}

func (s *Server) bootstrapCluster(req *pdpb.BootstrapRequest) (*pdpb.BootstrapResponse, error) {
clusterID := s.clusterID
clusterID := s.ClusterID()

log.Info("try to bootstrap raft cluster",
zap.Uint64("cluster-id", clusterID),
Expand Down Expand Up @@ -916,7 +918,7 @@ func (s *Server) Name() string {

// ClusterID returns the cluster ID of this server.
func (s *Server) ClusterID() uint64 {
	// clusterID is an atomic.Uint64: loading it here is safe against the
	// concurrent Store performed by startServer during bootstrap.
	return s.clusterID.Load()
}

// StartTimestamp returns the start timestamp of this server
Expand Down Expand Up @@ -1409,7 +1411,7 @@ func (s *Server) DirectlyGetRaftCluster() *cluster.RaftCluster {
// GetCluster gets cluster.
func (s *Server) GetCluster() *metapb.Cluster {
	return &metapb.Cluster{
		// Use the atomic accessor rather than reading the field directly,
		// so this does not race with startServer setting the cluster ID.
		Id:           s.ClusterID(),
		MaxPeerCount: uint32(s.persistOptions.GetMaxReplicas()),
	}
}
Expand Down Expand Up @@ -2010,15 +2012,15 @@ func (s *Server) SetServicePrimaryAddr(serviceName, addr string) {

// initTSOPrimaryWatcher creates and starts the watcher that tracks the TSO
// service primary under this cluster's TSO service root path.
func (s *Server) initTSOPrimaryWatcher() {
	serviceName := mcs.TSOServiceName
	// Derive the root path from the atomically-loaded cluster ID to avoid a
	// race with the field being stored during server startup.
	tsoRootPath := endpoint.TSOSvcRootPath(s.ClusterID())
	tsoServicePrimaryKey := endpoint.KeyspaceGroupPrimaryPath(tsoRootPath, mcs.DefaultKeyspaceGroupID)
	s.tsoPrimaryWatcher = s.initServicePrimaryWatcher(serviceName, tsoServicePrimaryKey)
	s.tsoPrimaryWatcher.StartWatchLoop()
}

// initSchedulingPrimaryWatcher creates and starts the watcher that tracks the
// scheduling service primary for this cluster.
func (s *Server) initSchedulingPrimaryWatcher() {
	serviceName := mcs.SchedulingServiceName
	// ClusterID() loads the ID atomically, avoiding a race with startServer.
	primaryKey := endpoint.SchedulingPrimaryPath(s.ClusterID())
	s.schedulingPrimaryWatcher = s.initServicePrimaryWatcher(serviceName, primaryKey)
	s.schedulingPrimaryWatcher.StartWatchLoop()
}
Expand Down

0 comments on commit e8da033

Please sign in to comment.