diff --git a/client/base_client.go b/client/base_client.go old mode 100644 new mode 100755 index 26c4505c608..d0ecdb9d8b3 --- a/client/base_client.go +++ b/client/base_client.go @@ -137,9 +137,9 @@ func (c *baseClient) memberLoop() { case <-ctx.Done(): return } - failpoint.Inject("skipUpdateMember", func() { - failpoint.Continue() - }) + if _, _err_ := failpoint.Eval(_curpkg_("skipUpdateMember")); _err_ == nil { + continue + } if err := c.updateMember(); err != nil { log.Error("[pd] failed updateMember", errs.ZapError(err)) } @@ -256,9 +256,9 @@ func (c *baseClient) initClusterID() error { clusterID = members.GetHeader().GetClusterId() continue } - failpoint.Inject("skipClusterIDCheck", func() { - failpoint.Continue() - }) + if _, _err_ := failpoint.Eval(_curpkg_("skipClusterIDCheck")); _err_ == nil { + continue + } // All URLs passed in should have the same cluster ID. if members.GetHeader().GetClusterId() != clusterID { return errors.WithStack(errUnmatchedClusterID) @@ -274,11 +274,11 @@ func (c *baseClient) initClusterID() error { func (c *baseClient) updateMember() error { for i, u := range c.GetURLs() { - failpoint.Inject("skipFirstUpdateMember", func() { + if _, _err_ := failpoint.Eval(_curpkg_("skipFirstUpdateMember")); _err_ == nil { if i == 0 { - failpoint.Continue() + continue } - }) + } members, err := c.getMembers(c.ctx, u, updateMemberTimeout) // Check the cluster ID. if err == nil && members.GetHeader().GetClusterId() != c.clusterID { diff --git a/client/base_client.go__failpoint_stash__ b/client/base_client.go__failpoint_stash__ new file mode 100644 index 00000000000..26c4505c608 --- /dev/null +++ b/client/base_client.go__failpoint_stash__ @@ -0,0 +1,456 @@ +// Copyright 2019 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pd + +import ( + "context" + "fmt" + "reflect" + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/client/errs" + "github.com/tikv/pd/client/grpcutil" + "github.com/tikv/pd/client/tlsutil" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +const ( + globalDCLocation = "global" + memberUpdateInterval = time.Minute +) + +// baseClient is a basic client for all other complex client. +type baseClient struct { + urls atomic.Value // Store as []string + clusterID uint64 + // PD leader URL + leader atomic.Value // Store as string + // PD follower URLs + followers atomic.Value // Store as []string + // addr -> TSO gRPC connection + clientConns sync.Map // Store as map[string]*grpc.ClientConn + // dc-location -> TSO allocator leader URL + allocators sync.Map // Store as map[string]string + + checkLeaderCh chan struct{} + checkTSODispatcherCh chan struct{} + updateConnectionCtxsCh chan struct{} + + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + + security SecurityOption + + // Client option. 
+ option *option +} + +// SecurityOption records options about tls +type SecurityOption struct { + CAPath string + CertPath string + KeyPath string + + SSLCABytes []byte + SSLCertBytes []byte + SSLKEYBytes []byte +} + +// newBaseClient returns a new baseClient. +func newBaseClient(ctx context.Context, urls []string, security SecurityOption) *baseClient { + clientCtx, clientCancel := context.WithCancel(ctx) + bc := &baseClient{ + checkLeaderCh: make(chan struct{}, 1), + checkTSODispatcherCh: make(chan struct{}, 1), + updateConnectionCtxsCh: make(chan struct{}, 1), + ctx: clientCtx, + cancel: clientCancel, + security: security, + option: newOption(), + } + bc.urls.Store(urls) + return bc +} + +func (c *baseClient) init() error { + if err := c.initRetry(c.initClusterID); err != nil { + c.cancel() + return err + } + if err := c.initRetry(c.updateMember); err != nil { + c.cancel() + return err + } + log.Info("[pd] init cluster id", zap.Uint64("cluster-id", c.clusterID)) + + c.wg.Add(1) + go c.memberLoop() + return nil +} + +func (c *baseClient) initRetry(f func() error) error { + var err error + for i := 0; i < c.option.maxRetryTimes; i++ { + if err = f(); err == nil { + return nil + } + select { + case <-c.ctx.Done(): + return err + case <-time.After(time.Second): + } + } + return errors.WithStack(err) +} + +func (c *baseClient) memberLoop() { + defer c.wg.Done() + + ctx, cancel := context.WithCancel(c.ctx) + defer cancel() + + for { + select { + case <-c.checkLeaderCh: + case <-time.After(memberUpdateInterval): + case <-ctx.Done(): + return + } + failpoint.Inject("skipUpdateMember", func() { + failpoint.Continue() + }) + if err := c.updateMember(); err != nil { + log.Error("[pd] failed updateMember", errs.ZapError(err)) + } + } +} + +// ScheduleCheckLeader is used to check leader. +func (c *baseClient) ScheduleCheckLeader() { + select { + case c.checkLeaderCh <- struct{}{}: + default: + } +} + +func (c *baseClient) scheduleCheckTSODispatcher() { + select { + case c.checkTSODispatcherCh <- struct{}{}: + default: + } +} + +func (c *baseClient) scheduleUpdateConnectionCtxs() { + select { + case c.updateConnectionCtxsCh <- struct{}{}: + default: + } +} + +// GetClusterID returns the ClusterID. +func (c *baseClient) GetClusterID(context.Context) uint64 { + return c.clusterID +} + +// GetLeaderAddr returns the leader address. +func (c *baseClient) GetLeaderAddr() string { + leaderAddr := c.leader.Load() + if leaderAddr == nil { + return "" + } + return leaderAddr.(string) +} + +// GetFollowerAddrs returns the follower address. +func (c *baseClient) GetFollowerAddrs() []string { + followerAddrs := c.followers.Load() + if followerAddrs == nil { + return []string{} + } + return followerAddrs.([]string) +} + +// GetURLs returns the URLs. +// For testing use. It should only be called when the client is closed. 
+func (c *baseClient) GetURLs() []string { + return c.urls.Load().([]string) +} + +func (c *baseClient) GetAllocatorLeaderURLs() map[string]string { + allocatorLeader := make(map[string]string) + c.allocators.Range(func(dcLocation, url interface{}) bool { + allocatorLeader[dcLocation.(string)] = url.(string) + return true + }) + return allocatorLeader +} + +func (c *baseClient) getAllocatorLeaderAddrByDCLocation(dcLocation string) (string, bool) { + url, exist := c.allocators.Load(dcLocation) + if !exist { + return "", false + } + return url.(string), true +} + +func (c *baseClient) getAllocatorClientConnByDCLocation(dcLocation string) (*grpc.ClientConn, string) { + url, ok := c.allocators.Load(dcLocation) + if !ok { + panic(fmt.Sprintf("the allocator leader in %s should exist", dcLocation)) + } + cc, ok := c.clientConns.Load(url) + if !ok { + panic(fmt.Sprintf("the client connection of %s in %s should exist", url, dcLocation)) + } + return cc.(*grpc.ClientConn), url.(string) +} + +func (c *baseClient) gcAllocatorLeaderAddr(curAllocatorMap map[string]*pdpb.Member) { + // Clean up the old TSO allocators + c.allocators.Range(func(dcLocationKey, _ interface{}) bool { + dcLocation := dcLocationKey.(string) + // Skip the Global TSO Allocator + if dcLocation == globalDCLocation { + return true + } + if _, exist := curAllocatorMap[dcLocation]; !exist { + log.Info("[pd] delete unused tso allocator", zap.String("dc-location", dcLocation)) + c.allocators.Delete(dcLocation) + } + return true + }) +} + +func (c *baseClient) initClusterID() error { + ctx, cancel := context.WithCancel(c.ctx) + defer cancel() + var clusterID uint64 + for _, u := range c.GetURLs() { + members, err := c.getMembers(ctx, u, c.option.timeout) + if err != nil || members.GetHeader() == nil { + log.Warn("[pd] failed to get cluster id", zap.String("url", u), errs.ZapError(err)) + continue + } + if clusterID == 0 { + clusterID = members.GetHeader().GetClusterId() + continue + } + failpoint.Inject("skipClusterIDCheck", func() { + failpoint.Continue() + }) + // All URLs passed in should have the same cluster ID. + if members.GetHeader().GetClusterId() != clusterID { + return errors.WithStack(errUnmatchedClusterID) + } + } + // Failed to init the cluster ID. + if clusterID == 0 { + return errors.WithStack(errFailInitClusterID) + } + c.clusterID = clusterID + return nil +} + +func (c *baseClient) updateMember() error { + for i, u := range c.GetURLs() { + failpoint.Inject("skipFirstUpdateMember", func() { + if i == 0 { + failpoint.Continue() + } + }) + members, err := c.getMembers(c.ctx, u, updateMemberTimeout) + // Check the cluster ID. + if err == nil && members.GetHeader().GetClusterId() != c.clusterID { + err = errs.ErrClientUpdateMember.FastGenByArgs("cluster id does not match") + } + // Check the TSO Allocator Leader. 
+ var errTSO error + if err == nil { + if members.GetLeader() == nil || len(members.GetLeader().GetClientUrls()) == 0 { + err = errs.ErrClientGetLeader.FastGenByArgs("leader address don't exist") + } + // Still need to update TsoAllocatorLeaders, even if there is no PD leader + errTSO = c.switchTSOAllocatorLeader(members.GetTsoAllocatorLeaders()) + } + + // Failed to get PD leader + if err != nil { + log.Info("[pd] cannot update member from this address", + zap.String("address", u), + errs.ZapError(err)) + select { + case <-c.ctx.Done(): + return errors.WithStack(err) + default: + continue + } + } + + c.updateURLs(members.GetMembers()) + c.updateFollowers(members.GetMembers(), members.GetLeader()) + if err := c.switchLeader(members.GetLeader().GetClientUrls()); err != nil { + return err + } + c.scheduleCheckTSODispatcher() + + // If `switchLeader` succeeds but `switchTSOAllocatorLeader` has an error, + // the error of `switchTSOAllocatorLeader` will be returned. + return errTSO + } + return errs.ErrClientGetLeader.FastGenByArgs(c.GetURLs()) +} + +func (c *baseClient) getMembers(ctx context.Context, url string, timeout time.Duration) (*pdpb.GetMembersResponse, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + cc, err := c.getOrCreateGRPCConn(url) + if err != nil { + return nil, err + } + members, err := pdpb.NewPDClient(cc).GetMembers(ctx, &pdpb.GetMembersRequest{}) + if err != nil { + attachErr := errors.Errorf("error:%s target:%s status:%s", err, cc.Target(), cc.GetState().String()) + return nil, errs.ErrClientGetMember.Wrap(attachErr).GenWithStackByCause() + } + if members.GetHeader().GetError() != nil { + attachErr := errors.Errorf("error:%s target:%s status:%s", members.GetHeader().GetError().String(), cc.Target(), cc.GetState().String()) + return nil, errs.ErrClientGetMember.Wrap(attachErr).GenWithStackByCause() + } + return members, nil +} + +func (c *baseClient) updateURLs(members []*pdpb.Member) { + urls := make([]string, 0, len(members)) + for _, m := range members { + urls = append(urls, m.GetClientUrls()...) + } + + sort.Strings(urls) + oldURLs := c.GetURLs() + // the url list is same. + if reflect.DeepEqual(oldURLs, urls) { + return + } + c.urls.Store(urls) + // Update the connection contexts when member changes if TSO Follower Proxy is enabled. + if c.option.getEnableTSOFollowerProxy() { + c.scheduleUpdateConnectionCtxs() + } + log.Info("[pd] update member urls", zap.Strings("old-urls", oldURLs), zap.Strings("new-urls", urls)) +} + +func (c *baseClient) switchLeader(addrs []string) error { + // FIXME: How to safely compare leader urls? For now, only allows one client url. + addr := addrs[0] + oldLeader := c.GetLeaderAddr() + if addr == oldLeader { + return nil + } + + if _, err := c.getOrCreateGRPCConn(addr); err != nil { + log.Warn("[pd] failed to connect leader", zap.String("leader", addr), errs.ZapError(err)) + return err + } + // Set PD leader and Global TSO Allocator (which is also the PD leader) + c.leader.Store(addr) + c.allocators.Store(globalDCLocation, addr) + log.Info("[pd] switch leader", zap.String("new-leader", addr), zap.String("old-leader", oldLeader)) + return nil +} + +func (c *baseClient) updateFollowers(members []*pdpb.Member, leader *pdpb.Member) { + var addrs []string + for _, member := range members { + if member.GetMemberId() != leader.GetMemberId() { + if len(member.GetClientUrls()) > 0 { + addrs = append(addrs, member.GetClientUrls()...) 
+ } + } + } + c.followers.Store(addrs) +} + +func (c *baseClient) switchTSOAllocatorLeader(allocatorMap map[string]*pdpb.Member) error { + if len(allocatorMap) == 0 { + return nil + } + // Switch to the new one + for dcLocation, member := range allocatorMap { + if len(member.GetClientUrls()) == 0 { + continue + } + addr := member.GetClientUrls()[0] + oldAddr, exist := c.getAllocatorLeaderAddrByDCLocation(dcLocation) + if exist && addr == oldAddr { + continue + } + if _, err := c.getOrCreateGRPCConn(addr); err != nil { + log.Warn("[pd] failed to connect dc tso allocator leader", + zap.String("dc-location", dcLocation), + zap.String("leader", addr), + errs.ZapError(err)) + return err + } + c.allocators.Store(dcLocation, addr) + log.Info("[pd] switch dc tso allocator leader", + zap.String("dc-location", dcLocation), + zap.String("new-leader", addr), + zap.String("old-leader", oldAddr)) + } + // Garbage collection of the old TSO allocator leaders + c.gcAllocatorLeaderAddr(allocatorMap) + return nil +} + +func (c *baseClient) getOrCreateGRPCConn(addr string) (*grpc.ClientConn, error) { + conn, ok := c.clientConns.Load(addr) + if ok { + return conn.(*grpc.ClientConn), nil + } + tlsCfg, err := tlsutil.TLSConfig{ + CAPath: c.security.CAPath, + CertPath: c.security.CertPath, + KeyPath: c.security.KeyPath, + + SSLCABytes: c.security.SSLCABytes, + SSLCertBytes: c.security.SSLCertBytes, + SSLKEYBytes: c.security.SSLKEYBytes, + }.ToTLSConfig() + if err != nil { + return nil, err + } + dCtx, cancel := context.WithTimeout(c.ctx, dialTimeout) + defer cancel() + cc, err := grpcutil.GetClientConn(dCtx, addr, tlsCfg, c.option.gRPCDialOptions...) + if err != nil { + return nil, err + } + if old, ok := c.clientConns.Load(addr); ok { + cc.Close() + log.Debug("use old connection", zap.String("target", cc.Target()), zap.String("state", cc.GetState().String())) + return old.(*grpc.ClientConn), nil + } + c.clientConns.Store(addr, cc) + return cc, nil +} diff --git a/client/binding__failpoint_binding__.go b/client/binding__failpoint_binding__.go new file mode 100755 index 00000000000..0f0f2e691c3 --- /dev/null +++ b/client/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package pd + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/client/client.go b/client/client.go old mode 100644 new mode 100755 index b7e15fe6eb2..ed706e7aad0 --- a/client/client.go +++ b/client/client.go @@ -499,10 +499,10 @@ func (c *client) checkLeaderHealth(ctx context.Context) { healthCli := healthpb.NewHealthClient(cc.(*grpc.ClientConn)) resp, err := healthCli.Check(ctx, &healthpb.HealthCheckRequest{Service: ""}) rpcErr, ok := status.FromError(err) - failpoint.Inject("unreachableNetwork1", func() { + if _, _err_ := failpoint.Eval(_curpkg_("unreachableNetwork1")); _err_ == nil { resp = nil err = status.New(codes.Unavailable, "unavailable").Err() - }) + } if (ok && isNetworkError(rpcErr.Code())) || resp.GetStatus() != healthpb.HealthCheckResponse_SERVING { atomic.StoreInt32(&(c.leaderNetworkFailure), int32(1)) } else { @@ -648,9 +648,9 @@ func (c *client) checkAllocator( } healthCtx, healthCancel := context.WithTimeout(dispatcherCtx, c.option.timeout) resp, err := healthCli.Check(healthCtx, &healthpb.HealthCheckRequest{Service: ""}) - 
failpoint.Inject("unreachableNetwork", func() { + if _, _err_ := failpoint.Eval(_curpkg_("unreachableNetwork")); _err_ == nil { resp.Status = healthpb.HealthCheckResponse_UNKNOWN - }) + } healthCancel() if err == nil && resp.GetStatus() == healthpb.HealthCheckResponse_SERVING { // create a stream of the original allocator @@ -957,10 +957,10 @@ func (c *client) tryConnect( cc, url = c.getAllocatorClientConnByDCLocation(dc) cctx, cancel := context.WithCancel(dispatcherCtx) stream, err = c.createTsoStream(cctx, cancel, pdpb.NewPDClient(cc)) - failpoint.Inject("unreachableNetwork", func() { + if _, _err_ := failpoint.Eval(_curpkg_("unreachableNetwork")); _err_ == nil { stream = nil err = status.New(codes.Unavailable, "unavailable").Err() - }) + } if stream != nil && err == nil { updateAndClear(url, &connectionContext{url, stream, cctx, cancel}) return nil diff --git a/client/client.go__failpoint_stash__ b/client/client.go__failpoint_stash__ new file mode 100644 index 00000000000..b7e15fe6eb2 --- /dev/null +++ b/client/client.go__failpoint_stash__ @@ -0,0 +1,1932 @@ +// Copyright 2016 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pd + +import ( + "context" + "fmt" + "math/rand" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/prometheus/client_golang/prometheus" + "github.com/tikv/pd/client/errs" + "github.com/tikv/pd/client/grpcutil" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" +) + +// Region contains information of a region's meta and its peers. +type Region struct { + Meta *metapb.Region + Leader *metapb.Peer + DownPeers []*metapb.Peer + PendingPeers []*metapb.Peer + Buckets *metapb.Buckets +} + +// GlobalConfigItem standard format of KV pair in GlobalConfig client +type GlobalConfigItem struct { + Name string + Value string + Error error +} + +// Client is a PD (Placement Driver) client. +// It should not be used after calling Close(). +type Client interface { + // GetClusterID gets the cluster ID from PD. + GetClusterID(ctx context.Context) uint64 + // GetAllMembers gets the members Info from PD + GetAllMembers(ctx context.Context) ([]*pdpb.Member, error) + // GetLeaderAddr returns current leader's address. It returns "" before + // syncing leader from server. + GetLeaderAddr() string + // GetTS gets a timestamp from PD. + GetTS(ctx context.Context) (int64, int64, error) + // GetTSAsync gets a timestamp from PD, without block the caller. + GetTSAsync(ctx context.Context) TSFuture + // GetLocalTS gets a local timestamp from PD. + GetLocalTS(ctx context.Context, dcLocation string) (int64, int64, error) + // GetLocalTSAsync gets a local timestamp from PD, without block the caller. 
+ GetLocalTSAsync(ctx context.Context, dcLocation string) TSFuture + // GetRegion gets a region and its leader Peer from PD by key. + // The region may expire after split. Caller is responsible for caching and + // taking care of region change. + // Also it may return nil if PD finds no Region for the key temporarily, + // client should retry later. + GetRegion(ctx context.Context, key []byte, opts ...GetRegionOption) (*Region, error) + // GetRegionFromMember gets a region from certain members. + GetRegionFromMember(ctx context.Context, key []byte, memberURLs []string) (*Region, error) + // GetPrevRegion gets the previous region and its leader Peer of the region where the key is located. + GetPrevRegion(ctx context.Context, key []byte, opts ...GetRegionOption) (*Region, error) + // GetRegionByID gets a region and its leader Peer from PD by id. + GetRegionByID(ctx context.Context, regionID uint64, opts ...GetRegionOption) (*Region, error) + // ScanRegion gets a list of regions, starts from the region that contains key. + // Limit limits the maximum number of regions returned. + // If a region has no leader, corresponding leader will be placed by a peer + // with empty value (PeerID is 0). + ScanRegions(ctx context.Context, key, endKey []byte, limit int) ([]*Region, error) + // GetStore gets a store from PD by store id. + // The store may expire later. Caller is responsible for caching and taking care + // of store change. + GetStore(ctx context.Context, storeID uint64) (*metapb.Store, error) + // GetAllStores gets all stores from pd. + // The store may expire later. Caller is responsible for caching and taking care + // of store change. + GetAllStores(ctx context.Context, opts ...GetStoreOption) ([]*metapb.Store, error) + // Update GC safe point. TiKV will check it and do GC themselves if necessary. + // If the given safePoint is less than the current one, it will not be updated. + // Returns the new safePoint after updating. + UpdateGCSafePoint(ctx context.Context, safePoint uint64) (uint64, error) + // UpdateServiceGCSafePoint updates the safepoint for specific service and + // returns the minimum safepoint across all services, this value is used to + // determine the safepoint for multiple services, it does not trigger a GC + // job. Use UpdateGCSafePoint to trigger the GC job if needed. + UpdateServiceGCSafePoint(ctx context.Context, serviceID string, ttl int64, safePoint uint64) (uint64, error) + // ScatterRegion scatters the specified region. Should use it for a batch of regions, + // and the distribution of these regions will be dispersed. + // NOTICE: This method is the old version of ScatterRegions, you should use the later one as your first choice. + ScatterRegion(ctx context.Context, regionID uint64) error + // ScatterRegions scatters the specified regions. Should use it for a batch of regions, + // and the distribution of these regions will be dispersed. + ScatterRegions(ctx context.Context, regionsID []uint64, opts ...RegionsOption) (*pdpb.ScatterRegionResponse, error) + // SplitRegions split regions by given split keys + SplitRegions(ctx context.Context, splitKeys [][]byte, opts ...RegionsOption) (*pdpb.SplitRegionsResponse, error) + // SplitAndScatterRegions split regions by given split keys and scatter new regions + SplitAndScatterRegions(ctx context.Context, splitKeys [][]byte, opts ...RegionsOption) (*pdpb.SplitAndScatterRegionsResponse, error) + // GetOperator gets the status of operator of the specified region. 
+ GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) + + // LoadGlobalConfig gets the global config from etcd + LoadGlobalConfig(ctx context.Context, names []string) ([]GlobalConfigItem, error) + // StoreGlobalConfig set the config from etcd + StoreGlobalConfig(ctx context.Context, items []GlobalConfigItem) error + // WatchGlobalConfig returns an stream with all global config and updates + WatchGlobalConfig(ctx context.Context) (chan []GlobalConfigItem, error) + // UpdateOption updates the client option. + UpdateOption(option DynamicOption, value interface{}) error + + // GetExternalTimestamp returns external timestamp + GetExternalTimestamp(ctx context.Context) (uint64, error) + // SetExternalTimestamp sets external timestamp + SetExternalTimestamp(ctx context.Context, timestamp uint64) error + + // KeyspaceClient manages keyspace metadata. + KeyspaceClient + // Close closes the client. + Close() +} + +// GetStoreOp represents available options when getting stores. +type GetStoreOp struct { + excludeTombstone bool +} + +// GetStoreOption configures GetStoreOp. +type GetStoreOption func(*GetStoreOp) + +// WithExcludeTombstone excludes tombstone stores from the result. +func WithExcludeTombstone() GetStoreOption { + return func(op *GetStoreOp) { op.excludeTombstone = true } +} + +// RegionsOp represents available options when operate regions +type RegionsOp struct { + group string + retryLimit uint64 +} + +// RegionsOption configures RegionsOp +type RegionsOption func(op *RegionsOp) + +// WithGroup specify the group during Scatter/Split Regions +func WithGroup(group string) RegionsOption { + return func(op *RegionsOp) { op.group = group } +} + +// WithRetry specify the retry limit during Scatter/Split Regions +func WithRetry(retry uint64) RegionsOption { + return func(op *RegionsOp) { op.retryLimit = retry } +} + +// GetRegionOp represents available options when getting regions. +type GetRegionOp struct { + needBuckets bool +} + +// GetRegionOption configures GetRegionOp. +type GetRegionOption func(op *GetRegionOp) + +// WithBuckets means getting region and its buckets. +func WithBuckets() GetRegionOption { + return func(op *GetRegionOp) { op.needBuckets = true } +} + +type tsoRequest struct { + start time.Time + clientCtx context.Context + requestCtx context.Context + done chan error + physical int64 + logical int64 + dcLocation string +} + +type tsoBatchController struct { + maxBatchSize int + // bestBatchSize is a dynamic size that changed based on the current batch effect. + bestBatchSize int + + tsoRequestCh chan *tsoRequest + collectedRequests []*tsoRequest + collectedRequestCount int + + batchStartTime time.Time +} + +func newTSOBatchController(tsoRequestCh chan *tsoRequest, maxBatchSize int) *tsoBatchController { + return &tsoBatchController{ + maxBatchSize: maxBatchSize, + bestBatchSize: 8, /* Starting from a low value is necessary because we need to make sure it will be converged to (current_batch_size - 4) */ + tsoRequestCh: tsoRequestCh, + collectedRequests: make([]*tsoRequest, maxBatchSize+1), + collectedRequestCount: 0, + } +} + +// fetchPendingRequests will start a new round of the batch collecting from the channel. +// It returns true if everything goes well, otherwise false which means we should stop the service. 
+func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, maxBatchWaitInterval time.Duration) error { + var firstTSORequest *tsoRequest + select { + case <-ctx.Done(): + return ctx.Err() + case firstTSORequest = <-tbc.tsoRequestCh: + } + // Start to batch when the first TSO request arrives. + tbc.batchStartTime = time.Now() + tbc.collectedRequestCount = 0 + tbc.pushRequest(firstTSORequest) + + // This loop is for trying best to collect more requests, so we use `tbc.maxBatchSize` here. +fetchPendingRequestsLoop: + for tbc.collectedRequestCount < tbc.maxBatchSize { + select { + case tsoReq := <-tbc.tsoRequestCh: + tbc.pushRequest(tsoReq) + case <-ctx.Done(): + return ctx.Err() + default: + break fetchPendingRequestsLoop + } + } + + // Check whether we should fetch more pending TSO requests from the channel. + // TODO: maybe consider the actual load that returns through a TSO response from PD server. + if tbc.collectedRequestCount >= tbc.maxBatchSize || maxBatchWaitInterval <= 0 { + return nil + } + + // Fetches more pending TSO requests from the channel. + // Try to collect `tbc.bestBatchSize` requests, or wait `maxBatchWaitInterval` + // when `tbc.collectedRequestCount` is less than the `tbc.bestBatchSize`. + if tbc.collectedRequestCount < tbc.bestBatchSize { + after := time.NewTimer(maxBatchWaitInterval) + defer after.Stop() + for tbc.collectedRequestCount < tbc.bestBatchSize { + select { + case tsoReq := <-tbc.tsoRequestCh: + tbc.pushRequest(tsoReq) + case <-ctx.Done(): + return ctx.Err() + case <-after.C: + return nil + } + } + } + + // Do an additional non-block try. Here we test the length with `tbc.maxBatchSize` instead + // of `tbc.bestBatchSize` because trying best to fetch more requests is necessary so that + // we can adjust the `tbc.bestBatchSize` dynamically later. + for tbc.collectedRequestCount < tbc.maxBatchSize { + select { + case tsoReq := <-tbc.tsoRequestCh: + tbc.pushRequest(tsoReq) + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } + } + return nil +} + +func (tbc *tsoBatchController) pushRequest(tsoReq *tsoRequest) { + tbc.collectedRequests[tbc.collectedRequestCount] = tsoReq + tbc.collectedRequestCount++ +} + +func (tbc *tsoBatchController) getCollectedRequests() []*tsoRequest { + return tbc.collectedRequests[:tbc.collectedRequestCount] +} + +// adjustBestBatchSize stabilizes the latency with the AIAD algorithm. +func (tbc *tsoBatchController) adjustBestBatchSize() { + tsoBestBatchSize.Observe(float64(tbc.bestBatchSize)) + length := tbc.collectedRequestCount + if length < tbc.bestBatchSize && tbc.bestBatchSize > 1 { + // Waits too long to collect requests, reduce the target batch size. + tbc.bestBatchSize-- + } else if length > tbc.bestBatchSize+4 /* Hard-coded number, in order to make `tbc.bestBatchSize` stable */ && + tbc.bestBatchSize < tbc.maxBatchSize { + tbc.bestBatchSize++ + } +} + +func (tbc *tsoBatchController) revokePendingTSORequest(err error) { + for i := 0; i < len(tbc.tsoRequestCh); i++ { + req := <-tbc.tsoRequestCh + req.done <- err + } +} + +type tsoDispatcher struct { + dispatcherCancel context.CancelFunc + tsoBatchController *tsoBatchController +} + +type lastTSO struct { + physical int64 + logical int64 +} + +const ( + dialTimeout = 3 * time.Second + updateMemberTimeout = time.Second // Use a shorter timeout to recover faster from network isolation. 
+ tsLoopDCCheckInterval = time.Minute + defaultMaxTSOBatchSize = 10000 // should be higher if client is sending requests in burst + retryInterval = 500 * time.Millisecond + maxRetryTimes = 6 +) + +// LeaderHealthCheckInterval might be changed in the unit to shorten the testing time. +var LeaderHealthCheckInterval = time.Second + +var ( + // errUnmatchedClusterID is returned when found a PD with a different cluster ID. + errUnmatchedClusterID = errors.New("[pd] unmatched cluster id") + // errFailInitClusterID is returned when failed to load clusterID from all supplied PD addresses. + errFailInitClusterID = errors.New("[pd] failed to get cluster id") + // errClosing is returned when request is canceled when client is closing. + errClosing = errors.New("[pd] closing") + // errTSOLength is returned when the number of response timestamps is inconsistent with request. + errTSOLength = errors.New("[pd] tso length in rpc response is incorrect") + // errGlobalConfigNotFound is returned when etcd does not contain the globalConfig item + errGlobalConfigNotFound = errors.New("[pd] global config not found") +) + +// ClientOption configures client. +type ClientOption func(c *client) + +// WithGRPCDialOptions configures the client with gRPC dial options. +func WithGRPCDialOptions(opts ...grpc.DialOption) ClientOption { + return func(c *client) { + c.option.gRPCDialOptions = append(c.option.gRPCDialOptions, opts...) + } +} + +// WithCustomTimeoutOption configures the client with timeout option. +func WithCustomTimeoutOption(timeout time.Duration) ClientOption { + return func(c *client) { + c.option.timeout = timeout + } +} + +// WithForwardingOption configures the client with forwarding option. +func WithForwardingOption(enableForwarding bool) ClientOption { + return func(c *client) { + c.option.enableForwarding = enableForwarding + } +} + +// WithMaxErrorRetry configures the client max retry times when connect meets error. +func WithMaxErrorRetry(count int) ClientOption { + return func(c *client) { + c.option.maxRetryTimes = count + } +} + +type client struct { + *baseClient + // tsoDispatcher is used to dispatch different TSO requests to + // the corresponding dc-location TSO channel. + tsoDispatcher sync.Map // Same as map[string]chan *tsoRequest + // dc-location -> deadline + tsDeadline sync.Map // Same as map[string]chan deadline + // dc-location -> *lastTSO + lastTSMap sync.Map // Same as map[string]*lastTSO + + // For internal usage. + checkTSDeadlineCh chan struct{} + leaderNetworkFailure int32 +} + +// NewClient creates a PD client. +func NewClient(pdAddrs []string, security SecurityOption, opts ...ClientOption) (Client, error) { + return NewClientWithContext(context.Background(), pdAddrs, security, opts...) +} + +// NewClientWithContext creates a PD client with context. +func NewClientWithContext(ctx context.Context, pdAddrs []string, security SecurityOption, opts ...ClientOption) (Client, error) { + log.Info("[pd] create pd client with endpoints", zap.Strings("pd-address", pdAddrs)) + c := &client{ + baseClient: newBaseClient(ctx, addrsToUrls(pdAddrs), security), + checkTSDeadlineCh: make(chan struct{}), + } + // Inject the client options. + for _, opt := range opts { + opt(c) + } + // Init the client base. + if err := c.init(); err != nil { + return nil, err + } + // Start the daemons. + c.updateTSODispatcher() + c.wg.Add(3) + go c.tsLoop() + go c.tsCancelLoop() + go c.leaderCheckLoop() + + return c, nil +} + +// UpdateOption updates the client option. 
+func (c *client) UpdateOption(option DynamicOption, value interface{}) error { + switch option { + case MaxTSOBatchWaitInterval: + interval, ok := value.(time.Duration) + if !ok { + return errors.New("[pd] invalid value type for MaxTSOBatchWaitInterval option, it should be time.Duration") + } + if err := c.option.setMaxTSOBatchWaitInterval(interval); err != nil { + return err + } + case EnableTSOFollowerProxy: + enable, ok := value.(bool) + if !ok { + return errors.New("[pd] invalid value type for EnableTSOFollowerProxy option, it should be bool") + } + c.option.setEnableTSOFollowerProxy(enable) + default: + return errors.New("[pd] unsupported client option") + } + return nil +} + +func (c *client) updateTSODispatcher() { + // Set up the new TSO dispatcher and batch controller. + c.allocators.Range(func(dcLocationKey, _ interface{}) bool { + dcLocation := dcLocationKey.(string) + if !c.checkTSODispatcher(dcLocation) { + c.createTSODispatcher(dcLocation) + } + return true + }) + // Clean up the unused TSO dispatcher + c.tsoDispatcher.Range(func(dcLocationKey, _ interface{}) bool { + dcLocation := dcLocationKey.(string) + // Skip the Global TSO Allocator + if dcLocation == globalDCLocation { + return true + } + if dispatcher, exist := c.allocators.Load(dcLocation); !exist { + log.Info("[pd] delete unused tso dispatcher", zap.String("dc-location", dcLocation)) + dispatcher.(*tsoDispatcher).dispatcherCancel() + c.tsoDispatcher.Delete(dcLocation) + } + return true + }) +} + +func (c *client) leaderCheckLoop() { + defer c.wg.Done() + + leaderCheckLoopCtx, leaderCheckLoopCancel := context.WithCancel(c.ctx) + defer leaderCheckLoopCancel() + + ticker := time.NewTicker(LeaderHealthCheckInterval) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + c.checkLeaderHealth(leaderCheckLoopCtx) + } + } +} + +func (c *client) checkLeaderHealth(ctx context.Context) { + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + if cc, ok := c.clientConns.Load(c.GetLeaderAddr()); ok { + healthCli := healthpb.NewHealthClient(cc.(*grpc.ClientConn)) + resp, err := healthCli.Check(ctx, &healthpb.HealthCheckRequest{Service: ""}) + rpcErr, ok := status.FromError(err) + failpoint.Inject("unreachableNetwork1", func() { + resp = nil + err = status.New(codes.Unavailable, "unavailable").Err() + }) + if (ok && isNetworkError(rpcErr.Code())) || resp.GetStatus() != healthpb.HealthCheckResponse_SERVING { + atomic.StoreInt32(&(c.leaderNetworkFailure), int32(1)) + } else { + atomic.StoreInt32(&(c.leaderNetworkFailure), int32(0)) + } + } +} + +type deadline struct { + timer <-chan time.Time + done chan struct{} + cancel context.CancelFunc +} + +func (c *client) tsCancelLoop() { + defer c.wg.Done() + + tsCancelLoopCtx, tsCancelLoopCancel := context.WithCancel(c.ctx) + defer tsCancelLoopCancel() + + ticker := time.NewTicker(tsLoopDCCheckInterval) + defer ticker.Stop() + for { + // Watch every dc-location's tsDeadlineCh + c.allocators.Range(func(dcLocation, _ interface{}) bool { + c.watchTSDeadline(tsCancelLoopCtx, dcLocation.(string)) + return true + }) + select { + case <-c.checkTSDeadlineCh: + continue + case <-ticker.C: + continue + case <-tsCancelLoopCtx.Done(): + return + } + } +} + +func (c *client) watchTSDeadline(ctx context.Context, dcLocation string) { + if _, exist := c.tsDeadline.Load(dcLocation); !exist { + tsDeadlineCh := make(chan deadline, 1) + c.tsDeadline.Store(dcLocation, tsDeadlineCh) + go func(dc string, tsDeadlineCh <-chan deadline) { + for { + 
select { + case d := <-tsDeadlineCh: + select { + case <-d.timer: + log.Error("[pd] tso request is canceled due to timeout", zap.String("dc-location", dc), errs.ZapError(errs.ErrClientGetTSOTimeout)) + d.cancel() + case <-d.done: + continue + case <-ctx.Done(): + return + } + case <-ctx.Done(): + return + } + } + }(dcLocation, tsDeadlineCh) + } +} + +func (c *client) scheduleCheckTSDeadline() { + select { + case c.checkTSDeadlineCh <- struct{}{}: + default: + } +} + +func (c *client) checkStreamTimeout(streamCtx context.Context, cancel context.CancelFunc, done chan struct{}) { + select { + case <-done: + return + case <-time.After(c.option.timeout): + cancel() + case <-streamCtx.Done(): + } + <-done +} + +func (c *client) GetAllMembers(ctx context.Context) ([]*pdpb.Member, error) { + start := time.Now() + defer func() { cmdDurationGetAllMembers.Observe(time.Since(start).Seconds()) }() + + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + req := &pdpb.GetMembersRequest{Header: c.requestHeader()} + ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) + resp, err := c.getClient().GetMembers(ctx, req) + cancel() + if err = c.respForErr(cmdFailDurationGetAllMembers, start, err, resp.GetHeader()); err != nil { + return nil, err + } + return resp.GetMembers(), nil +} + +func (c *client) tsLoop() { + defer c.wg.Done() + + loopCtx, loopCancel := context.WithCancel(c.ctx) + defer loopCancel() + + ticker := time.NewTicker(tsLoopDCCheckInterval) + defer ticker.Stop() + for { + c.updateTSODispatcher() + select { + case <-ticker.C: + case <-c.checkTSODispatcherCh: + case <-loopCtx.Done(): + return + } + } +} + +func (c *client) createTsoStream(ctx context.Context, cancel context.CancelFunc, client pdpb.PDClient) (pdpb.PD_TsoClient, error) { + done := make(chan struct{}) + // TODO: we need to handle a conner case that this goroutine is timeout while the stream is successfully created. 
+ go c.checkStreamTimeout(ctx, cancel, done) + stream, err := client.Tso(ctx) + done <- struct{}{} + return stream, err +} + +func (c *client) checkAllocator( + dispatcherCtx context.Context, + forwardCancel context.CancelFunc, + dc, forwardedHostTrim, addrTrim, url string, + updateAndClear func(newAddr string, connectionCtx *connectionContext)) { + defer func() { + // cancel the forward stream + forwardCancel() + requestForwarded.WithLabelValues(forwardedHostTrim, addrTrim).Set(0) + }() + cc, u := c.getAllocatorClientConnByDCLocation(dc) + healthCli := healthpb.NewHealthClient(cc) + for { + // the pd/allocator leader change, we need to re-establish the stream + if u != url { + log.Info("[pd] the leader of the allocator leader is changed", zap.String("dc", dc), zap.String("origin", url), zap.String("new", u)) + return + } + healthCtx, healthCancel := context.WithTimeout(dispatcherCtx, c.option.timeout) + resp, err := healthCli.Check(healthCtx, &healthpb.HealthCheckRequest{Service: ""}) + failpoint.Inject("unreachableNetwork", func() { + resp.Status = healthpb.HealthCheckResponse_UNKNOWN + }) + healthCancel() + if err == nil && resp.GetStatus() == healthpb.HealthCheckResponse_SERVING { + // create a stream of the original allocator + cctx, cancel := context.WithCancel(dispatcherCtx) + stream, err := c.createTsoStream(cctx, cancel, pdpb.NewPDClient(cc)) + if err == nil && stream != nil { + log.Info("[pd] recover the original tso stream since the network has become normal", zap.String("dc", dc), zap.String("url", url)) + updateAndClear(url, &connectionContext{url, stream, cctx, cancel}) + return + } + } + select { + case <-dispatcherCtx.Done(): + return + case <-time.After(time.Second): + // To ensure we can get the latest allocator leader + // and once the leader is changed, we can exit this function. + _, u = c.getAllocatorClientConnByDCLocation(dc) + } + } +} + +func (c *client) checkTSODispatcher(dcLocation string) bool { + dispatcher, ok := c.tsoDispatcher.Load(dcLocation) + if !ok || dispatcher == nil { + return false + } + return true +} + +func (c *client) createTSODispatcher(dcLocation string) { + dispatcherCtx, dispatcherCancel := context.WithCancel(c.ctx) + dispatcher := &tsoDispatcher{ + dispatcherCancel: dispatcherCancel, + tsoBatchController: newTSOBatchController( + make(chan *tsoRequest, defaultMaxTSOBatchSize*2), + defaultMaxTSOBatchSize), + } + // Each goroutine is responsible for handling the tso stream request for its dc-location. + // The only case that will make the dispatcher goroutine exit + // is that the loopCtx is done, otherwise there is no circumstance + // this goroutine should exit. + go c.handleDispatcher(dispatcherCtx, dcLocation, dispatcher.tsoBatchController) + c.tsoDispatcher.Store(dcLocation, dispatcher) + log.Info("[pd] tso dispatcher created", zap.String("dc-location", dcLocation)) +} + +func (c *client) handleDispatcher( + dispatcherCtx context.Context, + dc string, + tbc *tsoBatchController) { + var ( + err error + streamAddr string + stream pdpb.PD_TsoClient + streamCtx context.Context + cancel context.CancelFunc + // addr -> connectionContext + connectionCtxs sync.Map + opts []opentracing.StartSpanOption + ) + defer func() { + log.Info("[pd] exit tso dispatcher", zap.String("dc-location", dc)) + // Cancel all connections. + connectionCtxs.Range(func(_, cc interface{}) bool { + cc.(*connectionContext).cancel() + return true + }) + }() + // Call updateConnectionCtxs once to init the connectionCtxs first. 
+ c.updateConnectionCtxs(dispatcherCtx, dc, &connectionCtxs) + // Only the Global TSO needs to watch the updateConnectionCtxsCh to sense the + // change of the cluster when TSO Follower Proxy is enabled. + // TODO: support TSO Follower Proxy for the Local TSO. + if dc == globalDCLocation { + go func() { + var updateTicker = &time.Ticker{} + setNewUpdateTicker := func(ticker *time.Ticker) { + if updateTicker.C != nil { + updateTicker.Stop() + } + updateTicker = ticker + } + // Set to nil before returning to ensure that the existing ticker can be GC. + defer setNewUpdateTicker(nil) + + for { + select { + case <-dispatcherCtx.Done(): + return + case <-c.option.enableTSOFollowerProxyCh: + enableTSOFollowerProxy := c.option.getEnableTSOFollowerProxy() + if enableTSOFollowerProxy && updateTicker.C == nil { + // Because the TSO Follower Proxy is enabled, + // the periodic check needs to be performed. + setNewUpdateTicker(time.NewTicker(memberUpdateInterval)) + } else if !enableTSOFollowerProxy && updateTicker.C != nil { + // Because the TSO Follower Proxy is disabled, + // the periodic check needs to be turned off. + setNewUpdateTicker(&time.Ticker{}) + } else { + // The status of TSO Follower Proxy does not change, and updateConnectionCtxs is not triggered + continue + } + case <-updateTicker.C: + case <-c.updateConnectionCtxsCh: + } + c.updateConnectionCtxs(dispatcherCtx, dc, &connectionCtxs) + } + }() + } + + // Loop through each batch of TSO requests and send them for processing. + streamLoopTimer := time.NewTimer(c.option.timeout) +tsoBatchLoop: + for { + select { + case <-dispatcherCtx.Done(): + return + default: + } + // Start to collect the TSO requests. + maxBatchWaitInterval := c.option.getMaxTSOBatchWaitInterval() + if err = tbc.fetchPendingRequests(dispatcherCtx, maxBatchWaitInterval); err != nil { + if err == context.Canceled { + log.Info("[pd] stop fetching the pending tso requests due to context canceled", + zap.String("dc-location", dc)) + } else { + log.Error("[pd] fetch pending tso requests error", + zap.String("dc-location", dc), errs.ZapError(errs.ErrClientGetTSO, err)) + } + return + } + if maxBatchWaitInterval >= 0 { + tbc.adjustBestBatchSize() + } + streamLoopTimer.Reset(c.option.timeout) + // Choose a stream to send the TSO gRPC request. + streamChoosingLoop: + for { + connectionCtx := c.chooseStream(&connectionCtxs) + if connectionCtx != nil { + streamAddr, stream, streamCtx, cancel = connectionCtx.streamAddr, connectionCtx.stream, connectionCtx.ctx, connectionCtx.cancel + } + // Check stream and retry if necessary. + if stream == nil { + log.Info("[pd] tso stream is not ready", zap.String("dc", dc)) + if c.updateConnectionCtxs(dispatcherCtx, dc, &connectionCtxs) { + continue streamChoosingLoop + } + select { + case <-dispatcherCtx.Done(): + return + case <-streamLoopTimer.C: + err = errs.ErrClientCreateTSOStream.FastGenByArgs(errs.RetryTimeoutErr) + log.Error("[pd] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err)) + c.ScheduleCheckLeader() + c.finishTSORequest(tbc.getCollectedRequests(), 0, 0, 0, errors.WithStack(err)) + continue tsoBatchLoop + case <-time.After(retryInterval): + continue streamChoosingLoop + } + } + select { + case <-streamCtx.Done(): + log.Info("[pd] tso stream is canceled", zap.String("dc", dc), zap.String("stream-addr", streamAddr)) + // Set `stream` to nil and remove this stream from the `connectionCtxs` due to being canceled. 
+ connectionCtxs.Delete(streamAddr) + cancel() + stream = nil + continue + default: + break streamChoosingLoop + } + } + done := make(chan struct{}) + dl := deadline{ + timer: time.After(c.option.timeout), + done: done, + cancel: cancel, + } + tsDeadlineCh, ok := c.tsDeadline.Load(dc) + for !ok || tsDeadlineCh == nil { + c.scheduleCheckTSDeadline() + time.Sleep(time.Millisecond * 100) + tsDeadlineCh, ok = c.tsDeadline.Load(dc) + } + select { + case <-dispatcherCtx.Done(): + return + case tsDeadlineCh.(chan deadline) <- dl: + } + opts = extractSpanReference(tbc, opts[:0]) + err = c.processTSORequests(stream, dc, tbc, opts) + close(done) + // If error happens during tso stream handling, reset stream and run the next trial. + if err != nil { + select { + case <-dispatcherCtx.Done(): + return + default: + } + c.ScheduleCheckLeader() + log.Error("[pd] getTS error", zap.String("dc-location", dc), zap.String("stream-addr", streamAddr), errs.ZapError(errs.ErrClientGetTSO, err)) + // Set `stream` to nil and remove this stream from the `connectionCtxs` due to error. + connectionCtxs.Delete(streamAddr) + cancel() + stream = nil + // Because ScheduleCheckLeader is asynchronous, if the leader changes, we better call `updateMember` ASAP. + if IsLeaderChange(err) { + if err := c.updateMember(); err != nil { + select { + case <-dispatcherCtx.Done(): + return + default: + } + } + // Because the TSO Follower Proxy could be configured online, + // If we change it from on -> off, background updateConnectionCtxs + // will cancel the current stream, then the EOF error caused by cancel() + // should not trigger the updateConnectionCtxs here. + // So we should only call it when the leader changes. + c.updateConnectionCtxs(dispatcherCtx, dc, &connectionCtxs) + } + } + } +} + +// TSO Follower Proxy only supports the Global TSO proxy now. +func (c *client) allowTSOFollowerProxy(dc string) bool { + return dc == globalDCLocation && c.option.getEnableTSOFollowerProxy() +} + +// chooseStream uses the reservoir sampling algorithm to randomly choose a connection. +// connectionCtxs will only have only one stream to choose when the TSO Follower Proxy is off. +func (c *client) chooseStream(connectionCtxs *sync.Map) (connectionCtx *connectionContext) { + idx := 0 + connectionCtxs.Range(func(addr, cc interface{}) bool { + j := rand.Intn(idx + 1) + if j < 1 { + connectionCtx = cc.(*connectionContext) + } + idx++ + return true + }) + return connectionCtx +} + +type connectionContext struct { + streamAddr string + // Current stream to send gRPC requests, maybe a leader or a follower. + stream pdpb.PD_TsoClient + ctx context.Context + cancel context.CancelFunc +} + +func (c *client) updateConnectionCtxs(updaterCtx context.Context, dc string, connectionCtxs *sync.Map) bool { + // Normal connection creating, it will be affected by the `enableForwarding`. + createTSOConnection := c.tryConnect + if c.allowTSOFollowerProxy(dc) { + createTSOConnection = c.tryConnectWithProxy + } + if err := createTSOConnection(updaterCtx, dc, connectionCtxs); err != nil { + log.Error("[pd] update connection contexts failed", zap.String("dc", dc), errs.ZapError(err)) + return false + } + return true +} + +// tryConnect will try to connect to the TSO allocator leader. If the connection becomes unreachable +// and enableForwarding is true, it will create a new connection to a follower to do the forwarding, +// while a new daemon will be created also to switch back to a normal leader connection ASAP the +// connection comes back to normal. 
+func (c *client) tryConnect( + dispatcherCtx context.Context, + dc string, + connectionCtxs *sync.Map, +) error { + var ( + networkErrNum uint64 + err error + stream pdpb.PD_TsoClient + url string + cc *grpc.ClientConn + ) + updateAndClear := func(newAddr string, connectionCtx *connectionContext) { + if cc, loaded := connectionCtxs.LoadOrStore(newAddr, connectionCtx); loaded { + // If the previous connection still exists, we should close it first. + cc.(*connectionContext).cancel() + connectionCtxs.Store(newAddr, connectionCtx) + } + connectionCtxs.Range(func(addr, cc interface{}) bool { + if addr.(string) != newAddr { + cc.(*connectionContext).cancel() + connectionCtxs.Delete(addr) + } + return true + }) + } + // retry several times before falling back to the follower when the network problem happens + + for i := 0; i < maxRetryTimes; i++ { + c.ScheduleCheckLeader() + cc, url = c.getAllocatorClientConnByDCLocation(dc) + cctx, cancel := context.WithCancel(dispatcherCtx) + stream, err = c.createTsoStream(cctx, cancel, pdpb.NewPDClient(cc)) + failpoint.Inject("unreachableNetwork", func() { + stream = nil + err = status.New(codes.Unavailable, "unavailable").Err() + }) + if stream != nil && err == nil { + updateAndClear(url, &connectionContext{url, stream, cctx, cancel}) + return nil + } + + if err != nil && c.option.enableForwarding { + // The reason we need to judge if the error code is equal to "Canceled" here is that + // when we create a stream we use a goroutine to manually control the timeout of the connection. + // There is no need to wait for the transport layer timeout which can reduce the time of unavailability. + // But it conflicts with the retry mechanism since we use the error code to decide if it is caused by network error. + // And actually the `Canceled` error can be regarded as a kind of network error in some way. + if rpcErr, ok := status.FromError(err); ok && (isNetworkError(rpcErr.Code()) || rpcErr.Code() == codes.Canceled) { + networkErrNum++ + } + } + + cancel() + select { + case <-dispatcherCtx.Done(): + return err + case <-time.After(retryInterval): + } + } + + if networkErrNum == maxRetryTimes { + // encounter the network error + followerClient, addr := c.followerClient() + if followerClient != nil { + log.Info("[pd] fall back to use follower to forward tso stream", zap.String("dc", dc), zap.String("addr", addr)) + forwardedHost, ok := c.getAllocatorLeaderAddrByDCLocation(dc) + if !ok { + return errors.Errorf("cannot find the allocator leader in %s", dc) + } + + // create the follower stream + cctx, cancel := context.WithCancel(dispatcherCtx) + cctx = grpcutil.BuildForwardContext(cctx, forwardedHost) + stream, err = c.createTsoStream(cctx, cancel, followerClient) + if err == nil { + forwardedHostTrim := trimHTTPPrefix(forwardedHost) + addrTrim := trimHTTPPrefix(addr) + // the goroutine is used to check the network and change back to the original stream + go c.checkAllocator(dispatcherCtx, cancel, dc, forwardedHostTrim, addrTrim, url, updateAndClear) + requestForwarded.WithLabelValues(forwardedHostTrim, addrTrim).Set(1) + updateAndClear(addr, &connectionContext{addr, stream, cctx, cancel}) + return nil + } + cancel() + } + } + return err +} + +// tryConnectWithProxy will create multiple streams to all the PD servers to work as a TSO proxy to reduce +// the pressure of PD leader. 
+func (c *client) tryConnectWithProxy( + dispatcherCtx context.Context, + dc string, + connectionCtxs *sync.Map, +) error { + clients := c.getAllClients() + leaderAddr := c.GetLeaderAddr() + forwardedHost, ok := c.getAllocatorLeaderAddrByDCLocation(dc) + if !ok { + return errors.Errorf("cannot find the allocator leader in %s", dc) + } + // GC the stale one. + connectionCtxs.Range(func(addr, cc interface{}) bool { + if _, ok := clients[addr.(string)]; !ok { + cc.(*connectionContext).cancel() + connectionCtxs.Delete(addr) + } + return true + }) + // Update the missing one. + for addr, client := range clients { + if _, ok = connectionCtxs.Load(addr); ok { + continue + } + cctx, cancel := context.WithCancel(dispatcherCtx) + // Do not proxy the leader client. + if addr != leaderAddr { + log.Info("[pd] use follower to forward tso stream to do the proxy", zap.String("dc", dc), zap.String("addr", addr)) + cctx = grpcutil.BuildForwardContext(cctx, forwardedHost) + } + // Create the TSO stream. + stream, err := c.createTsoStream(cctx, cancel, client) + if err == nil { + if addr != leaderAddr { + forwardedHostTrim := trimHTTPPrefix(forwardedHost) + addrTrim := trimHTTPPrefix(addr) + requestForwarded.WithLabelValues(forwardedHostTrim, addrTrim).Set(1) + } + connectionCtxs.Store(addr, &connectionContext{addr, stream, cctx, cancel}) + continue + } + log.Error("[pd] create the tso stream failed", zap.String("dc", dc), zap.String("addr", addr), errs.ZapError(err)) + cancel() + } + return nil +} + +func extractSpanReference(tbc *tsoBatchController, opts []opentracing.StartSpanOption) []opentracing.StartSpanOption { + for _, req := range tbc.getCollectedRequests() { + if span := opentracing.SpanFromContext(req.requestCtx); span != nil { + opts = append(opts, opentracing.ChildOf(span.Context())) + } + } + return opts +} + +func (c *client) processTSORequests(stream pdpb.PD_TsoClient, dcLocation string, tbc *tsoBatchController, opts []opentracing.StartSpanOption) error { + if len(opts) > 0 { + span := opentracing.StartSpan("pdclient.processTSORequests", opts...) + defer span.Finish() + } + start := time.Now() + requests := tbc.getCollectedRequests() + count := int64(len(requests)) + req := &pdpb.TsoRequest{ + Header: c.requestHeader(), + Count: uint32(count), + DcLocation: dcLocation, + } + + if err := stream.Send(req); err != nil { + err = errors.WithStack(err) + c.finishTSORequest(requests, 0, 0, 0, err) + return err + } + tsoBatchSendLatency.Observe(float64(time.Since(tbc.batchStartTime))) + resp, err := stream.Recv() + if err != nil { + err = errors.WithStack(err) + c.finishTSORequest(requests, 0, 0, 0, err) + return err + } + requestDurationTSO.Observe(time.Since(start).Seconds()) + tsoBatchSize.Observe(float64(count)) + + if resp.GetCount() != uint32(count) { + err = errors.WithStack(errTSOLength) + c.finishTSORequest(requests, 0, 0, 0, err) + return err + } + + physical, logical, suffixBits := resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits() + // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. + firstLogical := addLogical(logical, -count+1, suffixBits) + c.compareAndSwapTS(dcLocation, physical, firstLogical, suffixBits, count) + c.finishTSORequest(requests, physical, firstLogical, suffixBits, nil) + return nil +} + +// Because of the suffix, we need to shift the count before we add it to the logical part. 
+func addLogical(logical, count int64, suffixBits uint32) int64 { + return logical + count< DefaultSlowRequestTime { log.Warn("kv gets too slow", zap.String("request-key", key), zap.Duration("cost", cost), errs.ZapError(err)) diff --git a/pkg/etcdutil/etcdutil.go__failpoint_stash__ b/pkg/etcdutil/etcdutil.go__failpoint_stash__ new file mode 100644 index 00000000000..2f2445c2459 --- /dev/null +++ b/pkg/etcdutil/etcdutil.go__failpoint_stash__ @@ -0,0 +1,213 @@ +// Copyright 2016 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package etcdutil + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "net/url" + "testing" + "time" + + "github.com/gogo/protobuf/proto" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/tempurl" + "go.etcd.io/etcd/clientv3" + "go.etcd.io/etcd/embed" + "go.etcd.io/etcd/etcdserver" + "go.etcd.io/etcd/pkg/types" + "go.uber.org/zap" +) + +const ( + // DefaultDialTimeout is the maximum amount of time a dial will wait for a + // connection to setup. 30s is long enough for most of the network conditions. + DefaultDialTimeout = 30 * time.Second + + // DefaultRequestTimeout 10s is long enough for most of etcd clusters. + DefaultRequestTimeout = 10 * time.Second + + // DefaultSlowRequestTime 1s for the threshold for normal request, for those + // longer then 1s, they are considered as slow requests. + DefaultSlowRequestTime = time.Second +) + +// CheckClusterID checks etcd cluster ID, returns an error if mismatch. +// This function will never block even quorum is not satisfied. +func CheckClusterID(localClusterID types.ID, um types.URLsMap, tlsConfig *tls.Config) error { + if len(um) == 0 { + return nil + } + + var peerURLs []string + for _, urls := range um { + peerURLs = append(peerURLs, urls.StringSlice()...) + } + + for _, u := range peerURLs { + trp := &http.Transport{ + TLSClientConfig: tlsConfig, + } + remoteCluster, gerr := etcdserver.GetClusterFromRemotePeers(nil, []string{u}, trp) + trp.CloseIdleConnections() + if gerr != nil { + // Do not return error, because other members may be not ready. + log.Error("failed to get cluster from remote", errs.ZapError(errs.ErrEtcdGetCluster, gerr)) + continue + } + + remoteClusterID := remoteCluster.ID() + if remoteClusterID != localClusterID { + return errors.Errorf("Etcd cluster ID mismatch, expect %d, got %d", localClusterID, remoteClusterID) + } + } + return nil +} + +// AddEtcdMember adds an etcd member. +func AddEtcdMember(client *clientv3.Client, urls []string) (*clientv3.MemberAddResponse, error) { + ctx, cancel := context.WithTimeout(client.Ctx(), DefaultRequestTimeout) + addResp, err := client.MemberAdd(ctx, urls) + cancel() + return addResp, errors.WithStack(err) +} + +// ListEtcdMembers returns a list of internal etcd members. 
+func ListEtcdMembers(client *clientv3.Client) (*clientv3.MemberListResponse, error) { + ctx, cancel := context.WithTimeout(client.Ctx(), DefaultRequestTimeout) + failpoint.Inject("SlowMemberList", func(val failpoint.Value) { + d := val.(int) + time.Sleep(time.Duration(d) * time.Second) + }) + listResp, err := client.MemberList(ctx) + cancel() + if err != nil { + return listResp, errs.ErrEtcdMemberList.Wrap(err).GenWithStackByCause() + } + return listResp, nil +} + +// RemoveEtcdMember removes a member by the given id. +func RemoveEtcdMember(client *clientv3.Client, id uint64) (*clientv3.MemberRemoveResponse, error) { + ctx, cancel := context.WithTimeout(client.Ctx(), DefaultRequestTimeout) + rmResp, err := client.MemberRemove(ctx, id) + cancel() + if err != nil { + return rmResp, errs.ErrEtcdMemberRemove.Wrap(err).GenWithStackByCause() + } + return rmResp, nil +} + +// EtcdKVGet returns the etcd GetResponse by given key or key prefix +func EtcdKVGet(c *clientv3.Client, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + ctx, cancel := context.WithTimeout(c.Ctx(), DefaultRequestTimeout) + defer cancel() + + start := time.Now() + failpoint.Inject("SlowEtcdKVGet", func(val failpoint.Value) { + d := val.(int) + time.Sleep(time.Duration(d) * time.Second) + }) + resp, err := clientv3.NewKV(c).Get(ctx, key, opts...) + if cost := time.Since(start); cost > DefaultSlowRequestTime { + log.Warn("kv gets too slow", zap.String("request-key", key), zap.Duration("cost", cost), errs.ZapError(err)) + } + + if err != nil { + e := errs.ErrEtcdKVGet.Wrap(err).GenWithStackByCause() + log.Error("load from etcd meet error", zap.String("key", key), errs.ZapError(e)) + return resp, e + } + return resp, nil +} + +// GetValue gets value with key from etcd. +func GetValue(c *clientv3.Client, key string, opts ...clientv3.OpOption) ([]byte, error) { + resp, err := get(c, key, opts...) + if err != nil { + return nil, err + } + if resp == nil { + return nil, nil + } + return resp.Kvs[0].Value, nil +} + +func get(c *clientv3.Client, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + resp, err := EtcdKVGet(c, key, opts...) + if err != nil { + return nil, err + } + + if n := len(resp.Kvs); n == 0 { + return nil, nil + } else if n > 1 { + return nil, errs.ErrEtcdKVGetResponse.FastGenByArgs(resp.Kvs) + } + return resp, nil +} + +// GetProtoMsgWithModRev returns boolean to indicate whether the key exists or not. +func GetProtoMsgWithModRev(c *clientv3.Client, key string, msg proto.Message, opts ...clientv3.OpOption) (bool, int64, error) { + resp, err := get(c, key, opts...) + if err != nil { + return false, 0, err + } + if resp == nil { + return false, 0, nil + } + value := resp.Kvs[0].Value + if err = proto.Unmarshal(value, msg); err != nil { + return false, 0, errs.ErrProtoUnmarshal.Wrap(err).GenWithStackByCause() + } + return true, resp.Kvs[0].ModRevision, nil +} + +// EtcdKVPutWithTTL put (key, value) into etcd with a ttl of ttlSeconds +func EtcdKVPutWithTTL(ctx context.Context, c *clientv3.Client, key string, value string, ttlSeconds int64) (*clientv3.PutResponse, error) { + kv := clientv3.NewKV(c) + grantResp, err := c.Grant(ctx, ttlSeconds) + if err != nil { + return nil, err + } + return kv.Put(ctx, key, value, clientv3.WithLease(grantResp.ID)) +} + +// NewTestSingleConfig is used to create a etcd config for the unit test purpose. 
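The SlowMemberList and SlowEtcdKVGet injections above take an integer failpoint value and sleep that many seconds. A test could arm one roughly as below; this is only a sketch, and the failpath is assumed to be the package import path plus the failpoint name, which is what the generated _curpkg_ helper produces.

package etcdutil_test

import (
	"testing"

	"github.com/pingcap/failpoint"
)

func TestSlowEtcdKVGet(t *testing.T) {
	// Failpath assumed: package import path + "/" + failpoint name (what _curpkg_ yields).
	const fp = "github.com/tikv/pd/pkg/etcdutil/SlowEtcdKVGet"
	// "return(3)" makes the rewritten branch receive the int 3 and sleep for 3 seconds
	// before the real KV Get runs, which should trip the slow-request warning.
	if err := failpoint.Enable(fp, "return(3)"); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := failpoint.Disable(fp); err != nil {
			t.Error(err)
		}
	}()
	// ... call EtcdKVGet here and assert on the observed latency or log output ...
}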
+func NewTestSingleConfig(t *testing.T) *embed.Config { + cfg := embed.NewConfig() + cfg.Name = "test_etcd" + cfg.Dir = t.TempDir() + cfg.WalDir = "" + cfg.Logger = "zap" + cfg.LogOutputs = []string{"stdout"} + + pu, _ := url.Parse(tempurl.Alloc()) + cfg.LPUrls = []url.URL{*pu} + cfg.APUrls = cfg.LPUrls + cu, _ := url.Parse(tempurl.Alloc()) + cfg.LCUrls = []url.URL{*cu} + cfg.ACUrls = cfg.LCUrls + + cfg.StrictReconfigCheck = false + cfg.InitialCluster = fmt.Sprintf("%s=%s", cfg.Name, &cfg.LPUrls[0]) + cfg.ClusterState = embed.ClusterStateFlagNew + return cfg +} diff --git a/server/api/binding__failpoint_binding__.go b/server/api/binding__failpoint_binding__.go new file mode 100755 index 00000000000..d53dfb178d5 --- /dev/null +++ b/server/api/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package api + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/api/middleware.go b/server/api/middleware.go old mode 100644 new mode 100755 index 04d417233da..28d8ba13858 --- a/server/api/middleware.go +++ b/server/api/middleware.go @@ -64,14 +64,14 @@ func (rm *requestInfoMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques requestInfo := requestutil.GetRequestInfo(r) r = r.WithContext(requestutil.WithRequestInfo(r.Context(), requestInfo)) - failpoint.Inject("addRequestInfoMiddleware", func() { + if _, _err_ := failpoint.Eval(_curpkg_("addRequestInfoMiddleware")); _err_ == nil { w.Header().Add("service-label", requestInfo.ServiceLabel) w.Header().Add("body-param", requestInfo.BodyParam) w.Header().Add("url-param", requestInfo.URLParam) w.Header().Add("method", requestInfo.Method) w.Header().Add("component", requestInfo.Component) w.Header().Add("ip", requestInfo.IP) - }) + } next(w, r) } diff --git a/server/api/middleware.go__failpoint_stash__ b/server/api/middleware.go__failpoint_stash__ new file mode 100644 index 00000000000..04d417233da --- /dev/null +++ b/server/api/middleware.go__failpoint_stash__ @@ -0,0 +1,186 @@ +// Copyright 2019 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
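NewTestSingleConfig above only builds the embedded-etcd configuration; a typical consumer starts the server and points a clientv3 client at it. A hedged sketch of that usage, not taken from the PD test suite:

package etcdutil_test

import (
	"testing"

	"github.com/tikv/pd/pkg/etcdutil"
	"go.etcd.io/etcd/clientv3"
	"go.etcd.io/etcd/embed"
)

func TestWithEmbeddedEtcd(t *testing.T) {
	cfg := etcdutil.NewTestSingleConfig(t)
	etcd, err := embed.StartEtcd(cfg)
	if err != nil {
		t.Fatal(err)
	}
	defer etcd.Close()
	<-etcd.Server.ReadyNotify() // wait until the embedded server is serving

	client, err := clientv3.New(clientv3.Config{
		Endpoints: []string{cfg.LCUrls[0].String()},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer client.Close()
	// ... exercise EtcdKVGet / EtcdKVPutWithTTL against `client` ...
}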
+ +package api + +import ( + "context" + "net/http" + "time" + + "github.com/pingcap/failpoint" + "github.com/tikv/pd/pkg/audit" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/requestutil" + "github.com/tikv/pd/server" + "github.com/tikv/pd/server/cluster" + "github.com/unrolled/render" + "github.com/urfave/negroni" +) + +// serviceMiddlewareBuilder is used to build service middleware for HTTP api +type serviceMiddlewareBuilder struct { + svr *server.Server + handlers []negroni.Handler +} + +func newServiceMiddlewareBuilder(s *server.Server) *serviceMiddlewareBuilder { + return &serviceMiddlewareBuilder{ + svr: s, + handlers: []negroni.Handler{newRequestInfoMiddleware(s), newAuditMiddleware(s), newRateLimitMiddleware(s)}, + } +} + +func (s *serviceMiddlewareBuilder) createHandler(next func(http.ResponseWriter, *http.Request)) http.Handler { + return negroni.New(append(s.handlers, negroni.WrapFunc(next))...) +} + +// requestInfoMiddleware is used to gather info from requsetInfo +type requestInfoMiddleware struct { + svr *server.Server +} + +func newRequestInfoMiddleware(s *server.Server) negroni.Handler { + return &requestInfoMiddleware{svr: s} +} + +func (rm *requestInfoMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + if !rm.svr.GetServiceMiddlewarePersistOptions().IsAuditEnabled() && !rm.svr.GetServiceMiddlewarePersistOptions().IsRateLimitEnabled() { + next(w, r) + return + } + + requestInfo := requestutil.GetRequestInfo(r) + r = r.WithContext(requestutil.WithRequestInfo(r.Context(), requestInfo)) + + failpoint.Inject("addRequestInfoMiddleware", func() { + w.Header().Add("service-label", requestInfo.ServiceLabel) + w.Header().Add("body-param", requestInfo.BodyParam) + w.Header().Add("url-param", requestInfo.URLParam) + w.Header().Add("method", requestInfo.Method) + w.Header().Add("component", requestInfo.Component) + w.Header().Add("ip", requestInfo.IP) + }) + + next(w, r) +} + +type clusterMiddleware struct { + s *server.Server + rd *render.Render +} + +func newClusterMiddleware(s *server.Server) clusterMiddleware { + return clusterMiddleware{ + s: s, + rd: render.New(render.Options{IndentJSON: true}), + } +} + +func (m clusterMiddleware) Middleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rc := m.s.GetRaftCluster() + if rc == nil { + m.rd.JSON(w, http.StatusInternalServerError, errs.ErrNotBootstrapped.FastGenByArgs().Error()) + return + } + ctx := context.WithValue(r.Context(), clusterCtxKey{}, rc) + h.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +type clusterCtxKey struct{} + +func getCluster(r *http.Request) *cluster.RaftCluster { + return r.Context().Value(clusterCtxKey{}).(*cluster.RaftCluster) +} + +type auditMiddleware struct { + svr *server.Server +} + +func newAuditMiddleware(s *server.Server) negroni.Handler { + return &auditMiddleware{svr: s} +} + +// ServeHTTP is used to implememt negroni.Handler for auditMiddleware +func (s *auditMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + if !s.svr.GetServiceMiddlewarePersistOptions().IsAuditEnabled() { + next(w, r) + return + } + + requestInfo, ok := requestutil.RequestInfoFrom(r.Context()) + if !ok { + requestInfo = requestutil.GetRequestInfo(r) + } + + labels := s.svr.GetServiceAuditBackendLabels(requestInfo.ServiceLabel) + if labels == nil { + next(w, r) + return + } + + beforeNextBackends := make([]audit.Backend, 0) + afterNextBackends := make([]audit.Backend, 0) + for _, backend 
:= range s.svr.GetAuditBackend() { + if backend.Match(labels) { + if backend.ProcessBeforeHandler() { + beforeNextBackends = append(beforeNextBackends, backend) + } else { + afterNextBackends = append(afterNextBackends, backend) + } + } + } + for _, backend := range beforeNextBackends { + backend.ProcessHTTPRequest(r) + } + + next(w, r) + + endTime := time.Now().Unix() + r = r.WithContext(requestutil.WithEndTime(r.Context(), endTime)) + for _, backend := range afterNextBackends { + backend.ProcessHTTPRequest(r) + } +} + +type rateLimitMiddleware struct { + svr *server.Server +} + +func newRateLimitMiddleware(s *server.Server) negroni.Handler { + return &rateLimitMiddleware{svr: s} +} + +// ServeHTTP is used to implememt negroni.Handler for rateLimitMiddleware +func (s *rateLimitMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + if !s.svr.GetServiceMiddlewarePersistOptions().IsRateLimitEnabled() { + next(w, r) + return + } + requestInfo, ok := requestutil.RequestInfoFrom(r.Context()) + if !ok { + requestInfo = requestutil.GetRequestInfo(r) + } + + // There is no need to check whether rateLimiter is nil. CreateServer ensures that it is created + rateLimiter := s.svr.GetServiceRateLimiter() + if rateLimiter.Allow(requestInfo.ServiceLabel) { + defer rateLimiter.Release(requestInfo.ServiceLabel) + next(w, r) + } else { + http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) + } +} diff --git a/server/api/region.go b/server/api/region.go old mode 100644 new mode 100755 index 796c3acafa4..d1c70d74be4 --- a/server/api/region.go +++ b/server/api/region.go @@ -290,12 +290,12 @@ func (h *regionsHandler) CheckRegionsReplicated(w http.ResponseWriter, r *http.R } } } - failpoint.Inject("mockPending", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockPending")); _err_ == nil { aok, ok := val.(bool) if ok && aok { state = "PENDING" } - }) + } h.rd.JSON(w, http.StatusOK, state) } @@ -973,13 +973,13 @@ func (h *regionsHandler) SplitRegions(w http.ResponseWriter, r *http.Request) { percentage, newRegionsID := rc.GetRegionSplitter().SplitRegions(r.Context(), splitKeys, retryLimit) s.ProcessedPercentage = percentage s.NewRegionsID = newRegionsID - failpoint.Inject("splitResponses", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("splitResponses")); _err_ == nil { rawID, ok := val.(int) if ok { s.ProcessedPercentage = 100 s.NewRegionsID = []uint64{uint64(rawID)} } - }) + } h.rd.JSON(w, http.StatusOK, &s) } diff --git a/server/api/region.go__failpoint_stash__ b/server/api/region.go__failpoint_stash__ new file mode 100644 index 00000000000..796c3acafa4 --- /dev/null +++ b/server/api/region.go__failpoint_stash__ @@ -0,0 +1,1044 @@ +// Copyright 2016 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
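The rewritten region.go hunks above gate mocked responses behind the mockPending and splitResponses failpoints. A sketch of how a test might arm both; the failpaths are assumed to follow the _curpkg_ convention of package import path plus name.

package api_test

import "github.com/pingcap/failpoint"

// Failpaths assumed: package import path + "/" + failpoint name.
const (
	fpMockPending    = "github.com/tikv/pd/server/api/mockPending"
	fpSplitResponses = "github.com/tikv/pd/server/api/splitResponses"
)

func enableRegionFailpoints() error {
	// CheckRegionsReplicated now reports "PENDING" regardless of the real scheduling state.
	if err := failpoint.Enable(fpMockPending, "return(true)"); err != nil {
		return err
	}
	// SplitRegions now reports 100% processed and a single fake region ID 123.
	return failpoint.Enable(fpSplitResponses, "return(123)")
}

func disableRegionFailpoints() {
	_ = failpoint.Disable(fpMockPending)
	_ = failpoint.Disable(fpSplitResponses)
}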
+ +package api + +import ( + "container/heap" + "encoding/hex" + "fmt" + "net/http" + "net/url" + "sort" + "strconv" + + "github.com/gorilla/mux" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/kvproto/pkg/replication_modepb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/apiutil" + "github.com/tikv/pd/pkg/typeutil" + "github.com/tikv/pd/server" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/schedule/filter" + "github.com/tikv/pd/server/statistics" + "github.com/unrolled/render" + "go.uber.org/zap" +) + +// MetaPeer is api compatible with *metapb.Peer. +// NOTE: This type is exported by HTTP API. Please pay more attention when modifying it. +type MetaPeer struct { + *metapb.Peer + // RoleName is `Role.String()`. + // Since Role is serialized as int by json by default, + // introducing it will make the output of pd-ctl easier to identify Role. + RoleName string `json:"role_name"` + // IsLearner is `Role == "Learner"`. + // Since IsLearner was changed to Role in kvproto in 5.0, this field was introduced to ensure api compatibility. + IsLearner bool `json:"is_learner,omitempty"` +} + +// PDPeerStats is api compatible with *pdpb.PeerStats. +// NOTE: This type is exported by HTTP API. Please pay more attention when modifying it. +type PDPeerStats struct { + *pdpb.PeerStats + Peer MetaPeer `json:"peer"` +} + +func fromPeer(peer *metapb.Peer) MetaPeer { + if peer == nil { + return MetaPeer{} + } + return MetaPeer{ + Peer: peer, + RoleName: peer.GetRole().String(), + IsLearner: core.IsLearner(peer), + } +} + +func fromPeerSlice(peers []*metapb.Peer) []MetaPeer { + if peers == nil { + return nil + } + slice := make([]MetaPeer, len(peers)) + for i, peer := range peers { + slice[i] = fromPeer(peer) + } + return slice +} + +func fromPeerStats(peer *pdpb.PeerStats) PDPeerStats { + return PDPeerStats{ + PeerStats: peer, + Peer: fromPeer(peer.Peer), + } +} + +func fromPeerStatsSlice(peers []*pdpb.PeerStats) []PDPeerStats { + if peers == nil { + return nil + } + slice := make([]PDPeerStats, len(peers)) + for i, peer := range peers { + slice[i] = fromPeerStats(peer) + } + return slice +} + +// RegionInfo records detail region info for api usage. +// NOTE: This type is exported by HTTP API. Please pay more attention when modifying it. +type RegionInfo struct { + ID uint64 `json:"id"` + StartKey string `json:"start_key"` + EndKey string `json:"end_key"` + RegionEpoch *metapb.RegionEpoch `json:"epoch,omitempty"` + Peers []MetaPeer `json:"peers,omitempty"` + + Leader MetaPeer `json:"leader,omitempty"` + DownPeers []PDPeerStats `json:"down_peers,omitempty"` + PendingPeers []MetaPeer `json:"pending_peers,omitempty"` + CPUUsage uint64 `json:"cpu_usage"` + WrittenBytes uint64 `json:"written_bytes"` + ReadBytes uint64 `json:"read_bytes"` + WrittenKeys uint64 `json:"written_keys"` + ReadKeys uint64 `json:"read_keys"` + ApproximateSize int64 `json:"approximate_size"` + ApproximateKeys int64 `json:"approximate_keys"` + Buckets []string `json:"buckets,omitempty"` + + ReplicationStatus *ReplicationStatus `json:"replication_status,omitempty"` +} + +// ReplicationStatus represents the replication mode status of the region. +// NOTE: This type is exported by HTTP API. Please pay more attention when modifying it. 
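MetaPeer above exists purely for API compatibility: the protobuf Role enum marshals as an integer, so RoleName carries the readable form for pd-ctl output. A simplified stand-in struct makes the resulting JSON shape visible; the field values here are illustrative only.

package main

import (
	"encoding/json"
	"fmt"
)

// peerView is a simplified stand-in for api.MetaPeer, just to show why
// RoleName exists alongside the integer-encoded enum.
type peerView struct {
	Role      int32  `json:"role,omitempty"`
	RoleName  string `json:"role_name"`
	IsLearner bool   `json:"is_learner,omitempty"`
}

func main() {
	b, _ := json.Marshal(peerView{Role: 1, RoleName: "Learner", IsLearner: true})
	fmt.Println(string(b)) // {"role":1,"role_name":"Learner","is_learner":true}
}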
+type ReplicationStatus struct { + State string `json:"state"` + StateID uint64 `json:"state_id"` +} + +func fromPBReplicationStatus(s *replication_modepb.RegionReplicationStatus) *ReplicationStatus { + if s == nil { + return nil + } + return &ReplicationStatus{ + State: s.GetState().String(), + StateID: s.GetStateId(), + } +} + +// NewAPIRegionInfo create a new API RegionInfo. +func NewAPIRegionInfo(r *core.RegionInfo) *RegionInfo { + return InitRegion(r, &RegionInfo{}) +} + +// InitRegion init a new API RegionInfo from the core.RegionInfo. +func InitRegion(r *core.RegionInfo, s *RegionInfo) *RegionInfo { + if r == nil { + return nil + } + + s.ID = r.GetID() + s.StartKey = core.HexRegionKeyStr(r.GetStartKey()) + s.EndKey = core.HexRegionKeyStr(r.GetEndKey()) + s.RegionEpoch = r.GetRegionEpoch() + s.Peers = fromPeerSlice(r.GetPeers()) + s.Leader = fromPeer(r.GetLeader()) + s.DownPeers = fromPeerStatsSlice(r.GetDownPeers()) + s.PendingPeers = fromPeerSlice(r.GetPendingPeers()) + s.CPUUsage = r.GetCPUUsage() + s.WrittenBytes = r.GetBytesWritten() + s.WrittenKeys = r.GetKeysWritten() + s.ReadBytes = r.GetBytesRead() + s.ReadKeys = r.GetKeysRead() + s.ApproximateSize = r.GetApproximateSize() + s.ApproximateKeys = r.GetApproximateKeys() + s.ReplicationStatus = fromPBReplicationStatus(r.GetReplicationStatus()) + + keys := r.GetBuckets().GetKeys() + + if len(keys) > 0 { + s.Buckets = make([]string, len(keys)) + for i, key := range keys { + s.Buckets[i] = core.HexRegionKeyStr(key) + } + } + return s +} + +// Adjust is only used in testing, in order to compare the data from json deserialization. +func (r *RegionInfo) Adjust() { + for _, peer := range r.DownPeers { + // Since api.PDPeerStats uses the api.MetaPeer type variable Peer to overwrite PeerStats.Peer, + // it needs to be restored after deserialization to be completely consistent with the original. + peer.PeerStats.Peer = peer.Peer.Peer + } +} + +// RegionsInfo contains some regions with the detailed region info. +type RegionsInfo struct { + Count int `json:"count"` + Regions []RegionInfo `json:"regions"` +} + +// Adjust is only used in testing, in order to compare the data from json deserialization. +func (s *RegionsInfo) Adjust() { + for _, r := range s.Regions { + r.Adjust() + } +} + +type regionHandler struct { + svr *server.Server + rd *render.Render +} + +func newRegionHandler(svr *server.Server, rd *render.Render) *regionHandler { + return ®ionHandler{ + svr: svr, + rd: rd, + } +} + +// @Tags region +// @Summary Search for a region by region ID. +// @Param id path integer true "Region Id" +// @Produce json +// @Success 200 {object} RegionInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /region/id/{id} [get] +func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) { + rc := getCluster(r) + + vars := mux.Vars(r) + regionIDStr := vars["id"] + regionID, err := strconv.ParseUint(regionIDStr, 10, 64) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + + regionInfo := rc.GetRegion(regionID) + h.rd.JSON(w, http.StatusOK, NewAPIRegionInfo(regionInfo)) +} + +// @Tags region +// @Summary Search for a region by a key. 
GetRegion is named to be consistent with gRPC +// @Param key path string true "Region key" +// @Produce json +// @Success 200 {object} RegionInfo +// @Router /region/key/{key} [get] +func (h *regionHandler) GetRegion(w http.ResponseWriter, r *http.Request) { + rc := getCluster(r) + vars := mux.Vars(r) + key := vars["key"] + key, err := url.QueryUnescape(key) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + regionInfo := rc.GetRegionByKey([]byte(key)) + h.rd.JSON(w, http.StatusOK, NewAPIRegionInfo(regionInfo)) +} + +// @Tags region +// @Summary Check if regions in the given key ranges are replicated. Returns 'REPLICATED', 'INPROGRESS', or 'PENDING'. 'PENDING' means that there is at least one region pending for scheduling. Similarly, 'INPROGRESS' means there is at least one region in scheduling. +// @Param startKey query string true "Regions start key, hex encoded" +// @Param endKey query string true "Regions end key, hex encoded" +// @Produce plain +// @Success 200 {string} string "INPROGRESS" +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/replicated [get] +func (h *regionsHandler) CheckRegionsReplicated(w http.ResponseWriter, r *http.Request) { + rc := getCluster(r) + + vars := mux.Vars(r) + startKeyHex := vars["startKey"] + startKey, err := hex.DecodeString(startKeyHex) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + endKeyHex := vars["endKey"] + endKey, err := hex.DecodeString(endKeyHex) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + + regions := rc.ScanRegions(startKey, endKey, -1) + state := "REPLICATED" + for _, region := range regions { + if !filter.IsRegionReplicated(rc, region) { + state = "INPROGRESS" + if rc.GetCoordinator().IsPendingRegion(region.GetID()) { + state = "PENDING" + break + } + } + } + failpoint.Inject("mockPending", func(val failpoint.Value) { + aok, ok := val.(bool) + if ok && aok { + state = "PENDING" + } + }) + h.rd.JSON(w, http.StatusOK, state) +} + +type regionsHandler struct { + svr *server.Server + rd *render.Render +} + +func newRegionsHandler(svr *server.Server, rd *render.Render) *regionsHandler { + return ®ionsHandler{ + svr: svr, + rd: rd, + } +} + +func convertToAPIRegions(regions []*core.RegionInfo) *RegionsInfo { + regionInfos := make([]RegionInfo, len(regions)) + for i, r := range regions { + InitRegion(r, ®ionInfos[i]) + } + return &RegionsInfo{ + Count: len(regions), + Regions: regionInfos, + } +} + +// @Tags region +// @Summary List all regions in the cluster. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Router /regions [get] +func (h *regionsHandler) GetRegions(w http.ResponseWriter, r *http.Request) { + rc := getCluster(r) + regions := rc.GetRegions() + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +// @Tags region +// @Summary List regions in a given range [startKey, endKey). +// @Param key query string true "Region range start key" +// @Param endkey query string true "Region range end key" +// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." 
+// @Router /regions/key [get] +func (h *regionsHandler) ScanRegions(w http.ResponseWriter, r *http.Request) { + rc := getCluster(r) + startKey := r.URL.Query().Get("key") + endKey := r.URL.Query().Get("end_key") + + limit := defaultRegionLimit + if limitStr := r.URL.Query().Get("limit"); limitStr != "" { + var err error + limit, err = strconv.Atoi(limitStr) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + } + if limit > maxRegionLimit { + limit = maxRegionLimit + } + regions := rc.ScanRegions([]byte(startKey), []byte(endKey), limit) + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +// @Tags region +// @Summary Get count of regions. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Router /regions/count [get] +func (h *regionsHandler) GetRegionCount(w http.ResponseWriter, r *http.Request) { + rc := getCluster(r) + count := rc.GetRegionCount() + h.rd.JSON(w, http.StatusOK, &RegionsInfo{Count: count}) +} + +// @Tags region +// @Summary List all regions of a specific store. +// @Param id path integer true "Store Id" +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/store/{id} [get] +func (h *regionsHandler) GetStoreRegions(w http.ResponseWriter, r *http.Request) { + rc := getCluster(r) + + vars := mux.Vars(r) + id, err := strconv.Atoi(vars["id"]) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + regions := rc.GetStoreRegions(uint64(id)) + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +// @Tags region +// @Summary List all regions that miss peer. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/miss-peer [get] +func (h *regionsHandler) GetMissPeerRegions(w http.ResponseWriter, r *http.Request) { + handler := h.svr.GetHandler() + regions, err := handler.GetRegionsByType(statistics.MissPeer) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +// @Tags region +// @Summary List all regions that has extra peer. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/extra-peer [get] +func (h *regionsHandler) GetExtraPeerRegions(w http.ResponseWriter, r *http.Request) { + handler := h.svr.GetHandler() + regions, err := handler.GetRegionsByType(statistics.ExtraPeer) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +// @Tags region +// @Summary List all regions that has pending peer. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." 
+// @Router /regions/check/pending-peer [get] +func (h *regionsHandler) GetPendingPeerRegions(w http.ResponseWriter, r *http.Request) { + handler := h.svr.GetHandler() + regions, err := handler.GetRegionsByType(statistics.PendingPeer) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +// @Tags region +// @Summary List all regions that has down peer. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/down-peer [get] +func (h *regionsHandler) GetDownPeerRegions(w http.ResponseWriter, r *http.Request) { + handler := h.svr.GetHandler() + regions, err := handler.GetRegionsByType(statistics.DownPeer) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +// @Tags region +// @Summary List all regions that has learner peer. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/learner-peer [get] +func (h *regionsHandler) GetLearnerPeerRegions(w http.ResponseWriter, r *http.Request) { + handler := h.svr.GetHandler() + regions, err := handler.GetRegionsByType(statistics.LearnerPeer) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +// @Tags region +// @Summary List all regions that has offline peer. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/offline-peer [get] +func (h *regionsHandler) GetOfflinePeerRegions(w http.ResponseWriter, r *http.Request) { + handler := h.svr.GetHandler() + regions, err := handler.GetOfflinePeer(statistics.OfflinePeer) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +// @Tags region +// @Summary List all regions that are oversized. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/oversized-region [get] +func (h *regionsHandler) GetOverSizedRegions(w http.ResponseWriter, r *http.Request) { + handler := h.svr.GetHandler() + regions, err := handler.GetRegionsByType(statistics.OversizedRegion) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +// @Tags region +// @Summary List all regions that are undersized. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." 
+// @Router /regions/check/undersized-region [get] +func (h *regionsHandler) GetUndersizedRegions(w http.ResponseWriter, r *http.Request) { + handler := h.svr.GetHandler() + regions, err := handler.GetRegionsByType(statistics.UndersizedRegion) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +// @Tags region +// @Summary List all empty regions. +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /regions/check/empty-region [get] +func (h *regionsHandler) GetEmptyRegions(w http.ResponseWriter, r *http.Request) { + handler := h.svr.GetHandler() + regions, err := handler.GetRegionsByType(statistics.EmptyRegion) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +type histItem struct { + Start int64 `json:"start"` + End int64 `json:"end"` + Count int64 `json:"count"` +} + +type histSlice []*histItem + +func (hist histSlice) Len() int { + return len(hist) +} + +func (hist histSlice) Swap(i, j int) { + hist[i], hist[j] = hist[j], hist[i] +} + +func (hist histSlice) Less(i, j int) bool { + return hist[i].Start < hist[j].Start +} + +// @Tags region +// @Summary Get size of histogram. +// @Param bound query integer false "Size bound of region histogram" minimum(1) +// @Produce json +// @Success 200 {array} histItem +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/check/hist-size [get] +func (h *regionsHandler) GetSizeHistogram(w http.ResponseWriter, r *http.Request) { + bound := minRegionHistogramSize + bound, err := calBound(bound, r) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + rc := getCluster(r) + regions := rc.GetRegions() + histSizes := make([]int64, 0, len(regions)) + for _, region := range regions { + histSizes = append(histSizes, region.GetApproximateSize()) + } + histItems := calHist(bound, &histSizes) + h.rd.JSON(w, http.StatusOK, histItems) +} + +// @Tags region +// @Summary Get keys of histogram. +// @Param bound query integer false "Key bound of region histogram" minimum(1000) +// @Produce json +// @Success 200 {array} histItem +// @Failure 400 {string} string "The input is invalid." 
+// @Router /regions/check/hist-keys [get] +func (h *regionsHandler) GetKeysHistogram(w http.ResponseWriter, r *http.Request) { + bound := minRegionHistogramKeys + bound, err := calBound(bound, r) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + rc := getCluster(r) + regions := rc.GetRegions() + histKeys := make([]int64, 0, len(regions)) + for _, region := range regions { + histKeys = append(histKeys, region.GetApproximateKeys()) + } + histItems := calHist(bound, &histKeys) + h.rd.JSON(w, http.StatusOK, histItems) +} + +func calBound(bound int, r *http.Request) (int, error) { + if boundStr := r.URL.Query().Get("bound"); boundStr != "" { + boundInput, err := strconv.Atoi(boundStr) + if err != nil { + return -1, err + } + if bound < boundInput { + bound = boundInput + } + } + return bound, nil +} + +func calHist(bound int, list *[]int64) *[]*histItem { + var histMap = make(map[int64]int) + for _, item := range *list { + multiple := item / int64(bound) + if oldCount, ok := histMap[multiple]; ok { + histMap[multiple] = oldCount + 1 + } else { + histMap[multiple] = 1 + } + } + histItems := make([]*histItem, 0, len(histMap)) + for multiple, count := range histMap { + histInfo := &histItem{} + histInfo.Start = multiple * int64(bound) + histInfo.End = (multiple+1)*int64(bound) - 1 + histInfo.Count = int64(count) + histItems = append(histItems, histInfo) + } + sort.Sort(histSlice(histItems)) + return &histItems +} + +// @Tags region +// @Summary List all range holes whitout any region info. +// @Produce json +// @Success 200 {object} [][]string +// @Router /regions/range-holes [get] +func (h *regionsHandler) GetRangeHoles(w http.ResponseWriter, r *http.Request) { + rc := getCluster(r) + h.rd.JSON(w, http.StatusOK, rc.GetRangeHoles()) +} + +// @Tags region +// @Summary List sibling regions of a specific region. +// @Param id path integer true "Region Id" +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The region does not exist." +// @Router /regions/sibling/{id} [get] +func (h *regionsHandler) GetRegionSiblings(w http.ResponseWriter, r *http.Request) { + rc := getCluster(r) + + vars := mux.Vars(r) + id, err := strconv.Atoi(vars["id"]) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + region := rc.GetRegion(uint64(id)) + if region == nil { + h.rd.JSON(w, http.StatusNotFound, server.ErrRegionNotFound(uint64(id)).Error()) + return + } + + left, right := rc.GetAdjacentRegions(region) + regionsInfo := convertToAPIRegions([]*core.RegionInfo{left, right}) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +const ( + defaultRegionLimit = 16 + maxRegionLimit = 10240 + minRegionHistogramSize = 1 + minRegionHistogramKeys = 1000 +) + +// @Tags region +// @Summary List regions with the highest write flow. +// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/writeflow [get] +func (h *regionsHandler) GetTopWriteFlowRegions(w http.ResponseWriter, r *http.Request) { + h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool { return a.GetBytesWritten() < b.GetBytesWritten() }) +} + +// @Tags region +// @Summary List regions with the highest read flow. 
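calHist above buckets each value into [multiple*bound, (multiple+1)*bound-1] with multiple = value/bound. A worked example of that bucketing with made-up region sizes:

package main

import "fmt"

func main() {
	bound := int64(10)
	sizes := []int64{3, 7, 12, 25, 27, 40}
	counts := map[int64]int{}
	for _, s := range sizes {
		counts[s/bound]++
	}
	for multiple, count := range counts {
		// e.g. [0, 9] -> 2, [10, 19] -> 1, [20, 29] -> 2, [40, 49] -> 1
		fmt.Printf("[%d, %d] -> %d regions\n", multiple*bound, (multiple+1)*bound-1, count)
	}
}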
+// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/readflow [get] +func (h *regionsHandler) GetTopReadFlowRegions(w http.ResponseWriter, r *http.Request) { + h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool { return a.GetBytesRead() < b.GetBytesRead() }) +} + +// @Tags region +// @Summary List regions with the largest conf version. +// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/confver [get] +func (h *regionsHandler) GetTopConfVerRegions(w http.ResponseWriter, r *http.Request) { + h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool { + return a.GetMeta().GetRegionEpoch().GetConfVer() < b.GetMeta().GetRegionEpoch().GetConfVer() + }) +} + +// @Tags region +// @Summary List regions with the largest version. +// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/version [get] +func (h *regionsHandler) GetTopVersionRegions(w http.ResponseWriter, r *http.Request) { + h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool { + return a.GetMeta().GetRegionEpoch().GetVersion() < b.GetMeta().GetRegionEpoch().GetVersion() + }) +} + +// @Tags region +// @Summary List regions with the largest size. +// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/size [get] +func (h *regionsHandler) GetTopSizeRegions(w http.ResponseWriter, r *http.Request) { + h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool { + return a.GetApproximateSize() < b.GetApproximateSize() + }) +} + +// @Tags region +// @Summary List regions with the largest keys. +// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/keys [get] +func (h *regionsHandler) GetTopKeysRegions(w http.ResponseWriter, r *http.Request) { + h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool { + return a.GetApproximateKeys() < b.GetApproximateKeys() + }) +} + +// @Tags region +// @Summary List regions with the highest CPU usage. +// @Param limit query integer false "Limit count" default(16) +// @Produce json +// @Success 200 {object} RegionsInfo +// @Failure 400 {string} string "The input is invalid." +// @Router /regions/cpu [get] +func (h *regionsHandler) GetTopCPURegions(w http.ResponseWriter, r *http.Request) { + h.GetTopNRegions(w, r, func(a, b *core.RegionInfo) bool { + return a.GetCPUUsage() < b.GetCPUUsage() + }) +} + +// @Tags region +// @Summary Accelerate regions scheduling a in given range, only receive hex format for keys +// @Accept json +// @Param body body object true "json params" +// @Param limit query integer false "Limit count" default(256) +// @Produce json +// @Success 200 {string} string "Accelerate regions scheduling in a given range [startKey, endKey)" +// @Failure 400 {string} string "The input is invalid." 
+// @Router /regions/accelerate-schedule [post] +func (h *regionsHandler) AccelerateRegionsScheduleInRange(w http.ResponseWriter, r *http.Request) { + rc := getCluster(r) + var input map[string]interface{} + if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil { + return + } + startKey, rawStartKey, err := apiutil.ParseKey("start_key", input) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + + endKey, rawEndKey, err := apiutil.ParseKey("end_key", input) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + + limit := 256 + if limitStr := r.URL.Query().Get("limit"); limitStr != "" { + var err error + limit, err = strconv.Atoi(limitStr) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + } + if limit > maxRegionLimit { + limit = maxRegionLimit + } + + regions := rc.ScanRegions(startKey, endKey, limit) + if len(regions) > 0 { + regionsIDList := make([]uint64, 0, len(regions)) + for _, region := range regions { + regionsIDList = append(regionsIDList, region.GetID()) + } + rc.AddSuspectRegions(regionsIDList...) + } + h.rd.Text(w, http.StatusOK, fmt.Sprintf("Accelerate regions scheduling in a given range [%s,%s)", rawStartKey, rawEndKey)) +} + +func (h *regionsHandler) GetTopNRegions(w http.ResponseWriter, r *http.Request, less func(a, b *core.RegionInfo) bool) { + rc := getCluster(r) + limit := defaultRegionLimit + if limitStr := r.URL.Query().Get("limit"); limitStr != "" { + var err error + limit, err = strconv.Atoi(limitStr) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + } + if limit > maxRegionLimit { + limit = maxRegionLimit + } + regions := TopNRegions(rc.GetRegions(), less, limit) + regionsInfo := convertToAPIRegions(regions) + h.rd.JSON(w, http.StatusOK, regionsInfo) +} + +// @Tags region +// @Summary Scatter regions by given key ranges or regions id distributed by given group with given retry limit +// @Accept json +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "Scatter regions by given key ranges or regions id distributed by given group with given retry limit" +// @Failure 400 {string} string "The input is invalid." 
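AccelerateRegionsScheduleInRange above, like the key-range form of ScatterRegions that follows, parses start_key and end_key as hex strings. A sketch of building such a request body; the raw key bytes are invented for illustration.

package main

import (
	"encoding/hex"
	"encoding/json"
	"fmt"
)

func main() {
	// Hex-encode the raw keys before placing them in the JSON body.
	body, _ := json.Marshal(map[string]string{
		"start_key": hex.EncodeToString([]byte("t\x80\x00\x00\x00\x00\x00\x00\x01")),
		"end_key":   hex.EncodeToString([]byte("t\x80\x00\x00\x00\x00\x00\x00\x02")),
	})
	fmt.Println(string(body))
}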
+// @Router /regions/scatter [post] +func (h *regionsHandler) ScatterRegions(w http.ResponseWriter, r *http.Request) { + rc := getCluster(r) + var input map[string]interface{} + if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil { + return + } + _, ok1 := input["start_key"].(string) + _, ok2 := input["end_key"].(string) + group, ok := input["group"].(string) + if !ok { + group = "" + } + retryLimit := 5 + if rl, ok := input["retry_limit"].(float64); ok { + retryLimit = int(rl) + } + opsCount := 0 + var failures map[uint64]error + var err error + if ok1 && ok2 { + startKey, _, err := apiutil.ParseKey("start_key", input) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + endKey, _, err := apiutil.ParseKey("end_key", input) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + opsCount, failures, err = rc.GetRegionScatter().ScatterRegionsByRange(startKey, endKey, group, retryLimit) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + } else { + ids, ok := typeutil.JSONToUint64Slice(input["regions_id"]) + if !ok { + h.rd.JSON(w, http.StatusBadRequest, "regions_id is invalid") + return + } + opsCount, failures, err = rc.GetRegionScatter().ScatterRegionsByID(ids, group, retryLimit) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + } + // If there existed any operator failed to be added into Operator Controller, add its regions into unProcessedRegions + percentage := 100 + if len(failures) > 0 { + percentage = 100 - 100*len(failures)/(opsCount+len(failures)) + log.Debug("scatter regions", zap.Errors("failures", func() []error { + r := make([]error, 0, len(failures)) + for _, err := range failures { + r = append(r, err) + } + return r + }())) + } + s := struct { + ProcessedPercentage int `json:"processed-percentage"` + }{ + ProcessedPercentage: percentage, + } + h.rd.JSON(w, http.StatusOK, &s) +} + +// @Tags region +// @Summary Split regions with given split keys +// @Accept json +// @Param body body object true "json params" +// @Produce json +// @Success 200 {string} string "Split regions with given split keys" +// @Failure 400 {string} string "The input is invalid." 
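For the processed-percentage reported by ScatterRegions above, a quick worked example of the formula 100 - 100*len(failures)/(opsCount+len(failures)), with invented numbers:

package main

import "fmt"

func main() {
	// 8 scatter operators were created and 2 regions failed to get one.
	opsCount, failed := 8, 2
	fmt.Println(100 - 100*failed/(opsCount+failed)) // 80
}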
+// @Router /regions/split [post] +func (h *regionsHandler) SplitRegions(w http.ResponseWriter, r *http.Request) { + rc := getCluster(r) + var input map[string]interface{} + if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil { + return + } + rawSplitKeys, ok := input["split_keys"].([]interface{}) + if !ok { + h.rd.JSON(w, http.StatusBadRequest, "split_keys should be provided.") + return + } + if len(rawSplitKeys) < 1 { + h.rd.JSON(w, http.StatusBadRequest, "empty split keys.") + return + } + retryLimit := 5 + if rl, ok := input["retry_limit"].(float64); ok { + retryLimit = int(rl) + } + splitKeys := make([][]byte, 0, len(rawSplitKeys)) + for _, rawKey := range rawSplitKeys { + key, err := hex.DecodeString(rawKey.(string)) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + splitKeys = append(splitKeys, key) + } + s := struct { + ProcessedPercentage int `json:"processed-percentage"` + NewRegionsID []uint64 `json:"regions-id"` + }{} + percentage, newRegionsID := rc.GetRegionSplitter().SplitRegions(r.Context(), splitKeys, retryLimit) + s.ProcessedPercentage = percentage + s.NewRegionsID = newRegionsID + failpoint.Inject("splitResponses", func(val failpoint.Value) { + rawID, ok := val.(int) + if ok { + s.ProcessedPercentage = 100 + s.NewRegionsID = []uint64{uint64(rawID)} + } + }) + h.rd.JSON(w, http.StatusOK, &s) +} + +// RegionHeap implements heap.Interface, used for selecting top n regions. +type RegionHeap struct { + regions []*core.RegionInfo + less func(a, b *core.RegionInfo) bool +} + +func (h *RegionHeap) Len() int { return len(h.regions) } +func (h *RegionHeap) Less(i, j int) bool { return h.less(h.regions[i], h.regions[j]) } +func (h *RegionHeap) Swap(i, j int) { h.regions[i], h.regions[j] = h.regions[j], h.regions[i] } + +// Push pushes an element x onto the heap. +func (h *RegionHeap) Push(x interface{}) { + h.regions = append(h.regions, x.(*core.RegionInfo)) +} + +// Pop removes the minimum element (according to Less) from the heap and returns +// it. +func (h *RegionHeap) Pop() interface{} { + pos := len(h.regions) - 1 + x := h.regions[pos] + h.regions = h.regions[:pos] + return x +} + +// Min returns the minimum region from the heap. +func (h *RegionHeap) Min() *core.RegionInfo { + if h.Len() == 0 { + return nil + } + return h.regions[0] +} + +// TopNRegions returns top n regions according to the given rule. +func TopNRegions(regions []*core.RegionInfo, less func(a, b *core.RegionInfo) bool, n int) []*core.RegionInfo { + if n <= 0 { + return nil + } + + hp := &RegionHeap{ + regions: make([]*core.RegionInfo, 0, n), + less: less, + } + for _, r := range regions { + if hp.Len() < n { + heap.Push(hp, r) + continue + } + if less(hp.Min(), r) { + heap.Pop(hp) + heap.Push(hp, r) + } + } + + res := make([]*core.RegionInfo, hp.Len()) + for i := hp.Len() - 1; i >= 0; i-- { + res[i] = heap.Pop(hp).(*core.RegionInfo) + } + return res +} diff --git a/server/api/router.go b/server/api/router.go old mode 100644 new mode 100755 index 2c750b12eb7..a2874eb7d7c --- a/server/api/router.go +++ b/server/api/router.go @@ -359,14 +359,13 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { unsafeOperationHandler.GetFailedStoresRemovalStatus, setMethods(http.MethodGet), setAuditBackend(prometheus)) // API to set or unset failpoints - failpoint.Inject("enableFailpointAPI", func() { - // this function will be named to "func2". 
It may be used in test - registerPrefix(apiRouter, "/fail", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // The HTTP handler of failpoint requires the full path to be the failpoint path. - r.URL.Path = strings.TrimPrefix(r.URL.Path, prefix+apiPrefix+"/fail") - new(failpoint.HttpHandler).ServeHTTP(w, r) - }), setAuditBackend("test")) - }) + + // this function will be named to "func2". It may be used in test + registerPrefix(apiRouter, "/fail", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The HTTP handler of failpoint requires the full path to be the failpoint path. + r.URL.Path = strings.TrimPrefix(r.URL.Path, prefix+apiPrefix+"/fail") + new(failpoint.HttpHandler).ServeHTTP(w, r) + }), setAuditBackend("test")) // Deprecated: use /pd/api/v1/health instead. rootRouter.HandleFunc("/health", healthHandler.GetHealthStatus).Methods(http.MethodGet) diff --git a/server/api/router.go__failpoint_stash__ b/server/api/router.go__failpoint_stash__ new file mode 100644 index 00000000000..2c750b12eb7 --- /dev/null +++ b/server/api/router.go__failpoint_stash__ @@ -0,0 +1,394 @@ +// Copyright 2016 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "net/http" + "net/http/pprof" + "reflect" + "runtime" + "strings" + + "github.com/gorilla/mux" + "github.com/pingcap/failpoint" + "github.com/tikv/pd/pkg/apiutil" + "github.com/tikv/pd/pkg/audit" + "github.com/tikv/pd/pkg/ratelimit" + "github.com/tikv/pd/server" + "github.com/unrolled/render" +) + +// createRouteOption is used to register service for mux.Route +type createRouteOption func(route *mux.Route) + +// setMethods is used to add HTTP Method matcher for mux.Route +func setMethods(method ...string) createRouteOption { + return func(route *mux.Route) { + route.Methods(method...) + } +} + +// setQueries is used to add queries for mux.Route +func setQueries(pairs ...string) createRouteOption { + return func(route *mux.Route) { + route.Queries(pairs...) + } +} + +// routeCreateFunc is used to registers a new route which will be registered matcher or service by opts for the URL path +func routeCreateFunc(route *mux.Route, handler http.Handler, name string, opts ...createRouteOption) { + route = route.Handler(handler).Name(name) + for _, opt := range opts { + opt(route) + } +} + +func createStreamingRender() *render.Render { + return render.New(render.Options{ + StreamingJSON: true, + }) +} + +func createIndentRender() *render.Render { + return render.New(render.Options{ + IndentJSON: true, + }) +} + +func getFunctionName(f interface{}) string { + strs := strings.Split(runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name(), ".") + return strings.Split(strs[len(strs)-1], "-")[0] +} + +// The returned function is used as a lazy router to avoid the data race problem. +// @title Placement Driver Core API +// @version 1.0 +// @description This is placement driver. 
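With the enableFailpointAPI guard removed above, the /fail prefix is always registered and failpoint.HttpHandler serves it: a PUT with the failpoint term as the body sets a failpoint, a DELETE clears it. A sketch of toggling one over HTTP; the PD address and the /pd/api/v1 prefix are assumptions for this example.

package main

import (
	"fmt"
	"net/http"
	"strings"
)

func main() {
	base := "http://127.0.0.1:2379/pd/api/v1/fail/" // assumed PD address and API prefix
	fp := "github.com/tikv/pd/server/api/mockPending"

	// PUT sets the failpoint; the request body is the failpoint term.
	req, _ := http.NewRequest(http.MethodPut, base+fp, strings.NewReader("return(true)"))
	resp, err := http.DefaultClient.Do(req)
	if err == nil {
		resp.Body.Close()
	}

	// DELETE clears it again.
	req, _ = http.NewRequest(http.MethodDelete, base+fp, nil)
	if resp, err := http.DefaultClient.Do(req); err == nil {
		resp.Body.Close()
	}
	fmt.Println("failpoint toggled")
}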
+// @contact.name Placement Driver Support +// @contact.url https://github.com/tikv/pd/issues +// @contact.email info@pingcap.com +// @license.name Apache 2.0 +// @license.url http://www.apache.org/licenses/LICENSE-2.0.html +// @BasePath /pd/api/v1 +func createRouter(prefix string, svr *server.Server) *mux.Router { + serviceMiddle := newServiceMiddlewareBuilder(svr) + registerPrefix := func(router *mux.Router, prefixPath string, + handleFunc func(http.ResponseWriter, *http.Request), opts ...createRouteOption) { + routeCreateFunc(router.PathPrefix(prefixPath), serviceMiddle.createHandler(handleFunc), + getFunctionName(handleFunc), opts...) + } + registerFunc := func(router *mux.Router, path string, + handleFunc func(http.ResponseWriter, *http.Request), opts ...createRouteOption) { + routeCreateFunc(router.Path(path), serviceMiddle.createHandler(handleFunc), + getFunctionName(handleFunc), opts...) + } + + setAuditBackend := func(labels ...string) createRouteOption { + return func(route *mux.Route) { + if len(labels) > 0 { + svr.SetServiceAuditBackendLabels(route.GetName(), labels) + } + } + } + + // localLog should be used in modifying the configuration or admin operations. + localLog := audit.LocalLogLabel + // prometheus will be used in all API. + prometheus := audit.PrometheusHistogram + + setRateLimitAllowList := func() createRouteOption { + return func(route *mux.Route) { + svr.UpdateServiceRateLimiter(route.GetName(), ratelimit.AddLabelAllowList()) + } + } + + rd := createIndentRender() + rootRouter := mux.NewRouter().PathPrefix(prefix).Subrouter() + handler := svr.GetHandler() + + apiPrefix := "/api/v1" + apiRouter := rootRouter.PathPrefix(apiPrefix).Subrouter() + + clusterRouter := apiRouter.NewRoute().Subrouter() + clusterRouter.Use(newClusterMiddleware(svr).Middleware) + + escapeRouter := clusterRouter.NewRoute().Subrouter().UseEncodedPath() + + operatorHandler := newOperatorHandler(handler, rd) + registerFunc(apiRouter, "/operators", operatorHandler.GetOperators, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/operators", operatorHandler.CreateOperator, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/operators/records", operatorHandler.GetOperatorRecords, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/operators/{region_id}", operatorHandler.GetOperatorsByRegion, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/operators/{region_id}", operatorHandler.DeleteOperatorByRegion, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + + checkerHandler := newCheckerHandler(svr, rd) + registerFunc(apiRouter, "/checker/{name}", checkerHandler.PauseOrResumeChecker, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/checker/{name}", checkerHandler.GetCheckerStatus, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + schedulerHandler := newSchedulerHandler(svr, rd) + registerFunc(apiRouter, "/schedulers", schedulerHandler.GetSchedulers, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/schedulers", schedulerHandler.CreateScheduler, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/schedulers/{name}", schedulerHandler.DeleteScheduler, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/schedulers/{name}", schedulerHandler.PauseOrResumeScheduler, 
setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + + diagnosticHandler := newDiagnosticHandler(svr, rd) + registerFunc(clusterRouter, "/schedulers/diagnostic/{name}", diagnosticHandler.GetDiagnosticResult, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + schedulerConfigHandler := newSchedulerConfigHandler(svr, rd) + registerPrefix(apiRouter, "/scheduler-config", schedulerConfigHandler.GetSchedulerConfig, setAuditBackend(prometheus)) + + clusterHandler := newClusterHandler(svr, rd) + registerFunc(apiRouter, "/cluster", clusterHandler.GetCluster, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/cluster/status", clusterHandler.GetClusterStatus, setAuditBackend(prometheus)) + + confHandler := newConfHandler(svr, rd) + registerFunc(apiRouter, "/config", confHandler.GetConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/config", confHandler.SetConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/config/default", confHandler.GetDefaultConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/config/schedule", confHandler.GetScheduleConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/config/schedule", confHandler.SetScheduleConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/config/pd-server", confHandler.GetPDServerConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/config/replicate", confHandler.GetReplicationConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/config/replicate", confHandler.SetReplicationConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/config/label-property", confHandler.GetLabelPropertyConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/config/label-property", confHandler.SetLabelPropertyConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/config/cluster-version", confHandler.GetClusterVersion, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/config/cluster-version", confHandler.SetClusterVersion, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/config/replication-mode", confHandler.GetReplicationModeConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/config/replication-mode", confHandler.SetReplicationModeConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + + rulesHandler := newRulesHandler(svr, rd) + registerFunc(clusterRouter, "/config/rules", rulesHandler.GetAllRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/config/rules", rulesHandler.SetAllRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/config/rules/batch", rulesHandler.BatchRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/config/rules/group/{group}", rulesHandler.GetRuleByGroup, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/config/rules/region/{region}", rulesHandler.GetRulesByRegion, setMethods(http.MethodGet), setAuditBackend(prometheus)) + 
registerFunc(clusterRouter, "/config/rules/region/{region}/detail", rulesHandler.CheckRegionPlacementRule, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/config/rules/key/{key}", rulesHandler.GetRulesByKey, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/config/rule/{group}/{id}", rulesHandler.GetRuleByGroupAndID, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/config/rule", rulesHandler.SetRule, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/config/rule/{group}/{id}", rulesHandler.DeleteRuleByGroup, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + + registerFunc(clusterRouter, "/config/rule_group/{id}", rulesHandler.GetGroupConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/config/rule_group", rulesHandler.SetGroupConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/config/rule_group/{id}", rulesHandler.DeleteGroupConfig, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/config/rule_groups", rulesHandler.GetAllGroupConfigs, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + registerFunc(clusterRouter, "/config/placement-rule", rulesHandler.GetPlacementRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/config/placement-rule", rulesHandler.SetPlacementRules, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + // {group} can be a regular expression, we should enable path encode to + // support special characters. + registerFunc(clusterRouter, "/config/placement-rule/{group}", rulesHandler.GetPlacementRuleByGroup, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/config/placement-rule/{group}", rulesHandler.SetPlacementRuleByGroup, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(escapeRouter, "/config/placement-rule/{group}", rulesHandler.DeletePlacementRuleByGroup, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + + regionLabelHandler := newRegionLabelHandler(svr, rd) + registerFunc(clusterRouter, "/config/region-label/rules", regionLabelHandler.GetAllRegionLabelRules, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/config/region-label/rules/ids", regionLabelHandler.GetRegionLabelRulesByIDs, setMethods(http.MethodGet), setAuditBackend(prometheus)) + // {id} can be a string with special characters, we should enable path encode to support it. 
+ registerFunc(escapeRouter, "/config/region-label/rule/{id}", regionLabelHandler.GetRegionLabelRuleByID, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(escapeRouter, "/config/region-label/rule/{id}", regionLabelHandler.DeleteRegionLabelRule, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/config/region-label/rule", regionLabelHandler.SetRegionLabelRule, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/config/region-label/rules", regionLabelHandler.PatchRegionLabelRules, setMethods(http.MethodPatch), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/region/id/{id}/label/{key}", regionLabelHandler.GetRegionLabelByKey, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/region/id/{id}/labels", regionLabelHandler.GetRegionLabels, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + storeHandler := newStoreHandler(handler, rd) + registerFunc(clusterRouter, "/store/{id}", storeHandler.GetStore, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/store/{id}", storeHandler.DeleteStore, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/store/{id}/state", storeHandler.SetStoreState, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/store/{id}/label", storeHandler.SetStoreLabel, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/store/{id}/label", storeHandler.DeleteStoreLabel, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/store/{id}/weight", storeHandler.SetStoreWeight, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/store/{id}/limit", storeHandler.SetStoreLimit, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + + storesHandler := newStoresHandler(handler, rd) + registerFunc(clusterRouter, "/stores", storesHandler.GetStores, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/stores/remove-tombstone", storesHandler.RemoveTombStone, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/stores/limit", storesHandler.GetAllStoresLimit, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/stores/limit", storesHandler.SetAllStoresLimit, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/stores/limit/scene", storesHandler.SetStoreLimitScene, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/stores/limit/scene", storesHandler.GetStoreLimitScene, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/stores/progress", storesHandler.GetStoresProgress, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + labelsHandler := newLabelsHandler(svr, rd) + registerFunc(clusterRouter, "/labels", labelsHandler.GetLabels, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/labels/stores", labelsHandler.GetStoresByLabel, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + hotStatusHandler := newHotStatusHandler(handler, rd) + registerFunc(apiRouter, "/hotspot/regions/write", hotStatusHandler.GetHotWriteRegions, 
setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/hotspot/regions/read", hotStatusHandler.GetHotReadRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/hotspot/regions/history", hotStatusHandler.GetHistoryHotRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/hotspot/stores", hotStatusHandler.GetHotStores, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + regionHandler := newRegionHandler(svr, rd) + registerFunc(clusterRouter, "/region/id/{id}", regionHandler.GetRegionByID, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter.UseEncodedPath(), "/region/key/{key}", regionHandler.GetRegion, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + srd := createStreamingRender() + regionsAllHandler := newRegionsHandler(svr, srd) + registerFunc(clusterRouter, "/regions", regionsAllHandler.GetRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + regionsHandler := newRegionsHandler(svr, rd) + registerFunc(clusterRouter, "/regions/key", regionsHandler.ScanRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/count", regionsHandler.GetRegionCount, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/store/{id}", regionsHandler.GetStoreRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/writeflow", regionsHandler.GetTopWriteFlowRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/readflow", regionsHandler.GetTopReadFlowRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/confver", regionsHandler.GetTopConfVerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/version", regionsHandler.GetTopVersionRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/size", regionsHandler.GetTopSizeRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/keys", regionsHandler.GetTopKeysRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/cpu", regionsHandler.GetTopCPURegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/check/miss-peer", regionsHandler.GetMissPeerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/check/extra-peer", regionsHandler.GetExtraPeerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/check/pending-peer", regionsHandler.GetPendingPeerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/check/down-peer", regionsHandler.GetDownPeerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/check/learner-peer", regionsHandler.GetLearnerPeerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/check/empty-region", regionsHandler.GetEmptyRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/check/offline-peer", regionsHandler.GetOfflinePeerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + 
registerFunc(clusterRouter, "/regions/check/oversized-region", regionsHandler.GetOverSizedRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/check/undersized-region", regionsHandler.GetUndersizedRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + registerFunc(clusterRouter, "/regions/check/hist-size", regionsHandler.GetSizeHistogram, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/check/hist-keys", regionsHandler.GetKeysHistogram, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/sibling/{id}", regionsHandler.GetRegionSiblings, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/accelerate-schedule", regionsHandler.AccelerateRegionsScheduleInRange, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/regions/scatter", regionsHandler.ScatterRegions, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/regions/split", regionsHandler.SplitRegions, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/regions/range-holes", regionsHandler.GetRangeHoles, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/regions/replicated", regionsHandler.CheckRegionsReplicated, setMethods(http.MethodGet), setQueries("startKey", "{startKey}", "endKey", "{endKey}"), setAuditBackend(prometheus)) + + registerFunc(apiRouter, "/version", newVersionHandler(rd).GetVersion, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/status", newStatusHandler(svr, rd).GetPDStatus, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + memberHandler := newMemberHandler(svr, rd) + registerFunc(apiRouter, "/members", memberHandler.GetMembers, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/members/name/{name}", memberHandler.DeleteMemberByName, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/members/id/{id}", memberHandler.DeleteMemberByID, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/members/name/{name}", memberHandler.SetMemberPropertyByName, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + + leaderHandler := newLeaderHandler(svr, rd) + registerFunc(apiRouter, "/leader", leaderHandler.GetLeader, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/leader/resign", leaderHandler.ResignLeader, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/leader/transfer/{next_leader}", leaderHandler.TransferLeader, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + + statsHandler := newStatsHandler(svr, rd) + registerFunc(clusterRouter, "/stats/region", statsHandler.GetRegionStatus, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + trendHandler := newTrendHandler(svr, rd) + registerFunc(apiRouter, "/trend", trendHandler.GetTrend, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + adminHandler := newAdminHandler(svr, rd) + registerFunc(clusterRouter, "/admin/cache/region/{id}", adminHandler.DeleteRegionCache, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/admin/cache/regions", 
adminHandler.DeleteAllRegionCache, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + // br ebs restore phase 1 will reset ts, but at that time the cluster hasn't bootstrapped, so cannot use clusterRouter + registerFunc(apiRouter, "/admin/reset-ts", adminHandler.ResetTS, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/admin/persist-file/{file_name}", adminHandler.SavePersistFile, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/admin/persist-file/{file_name}", adminHandler.SavePersistFile, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/admin/cluster/markers/snapshot-recovering", adminHandler.IsSnapshotRecovering, setMethods(http.MethodGet), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/admin/cluster/markers/snapshot-recovering", adminHandler.MarkSnapshotRecovering, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/admin/cluster/markers/snapshot-recovering", adminHandler.UnmarkSnapshotRecovering, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/admin/base-alloc-id", adminHandler.RecoverAllocID, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + + serviceMiddlewareHandler := newServiceMiddlewareHandler(svr, rd) + registerFunc(apiRouter, "/service-middleware/config", serviceMiddlewareHandler.GetServiceMiddlewareConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/service-middleware/config", serviceMiddlewareHandler.SetServiceMiddlewareConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(apiRouter, "/service-middleware/config/rate-limit", serviceMiddlewareHandler.SetRatelimitConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus), setRateLimitAllowList()) + + logHandler := newLogHandler(svr, rd) + registerFunc(apiRouter, "/admin/log", logHandler.SetLogLevel, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + replicationModeHandler := newReplicationModeHandler(svr, rd) + registerFunc(clusterRouter, "/replication_mode/status", replicationModeHandler.GetReplicationModeStatus, setAuditBackend(prometheus)) + + pluginHandler := newPluginHandler(handler, rd) + registerFunc(apiRouter, "/plugin", pluginHandler.LoadPlugin, setMethods(http.MethodPost), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/plugin", pluginHandler.UnloadPlugin, setMethods(http.MethodDelete), setAuditBackend(prometheus)) + + healthHandler := newHealthHandler(svr, rd) + registerFunc(apiRouter, "/health", healthHandler.GetHealthStatus, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/ping", healthHandler.Ping, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + // metric query use to query metric data, the protocol is compatible with prometheus. 
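    // A minimal sketch of how these endpoints are queried, assuming PD has a metric storage
    // backend configured (the pd-server metric-storage option) to proxy the request to; the
    // parameters follow the Prometheus HTTP API:
    //
    //   GET /pd/api/v1/metric/query?query=up&time=2022-01-01T00:00:00Z
    //   GET /pd/api/v1/metric/query_range?query=up&start=2022-01-01T00:00:00Z&end=2022-01-01T01:00:00Z&step=15s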
+ registerFunc(apiRouter, "/metric/query", newQueryMetric(svr).QueryMetric, setMethods(http.MethodGet, http.MethodPost), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/metric/query_range", newQueryMetric(svr).QueryMetric, setMethods(http.MethodGet, http.MethodPost), setAuditBackend(prometheus)) + + // tso API + tsoHandler := newTSOHandler(svr, rd) + registerFunc(apiRouter, "/tso/allocator/transfer/{name}", tsoHandler.TransferLocalTSOAllocator, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + + pprofHandler := newPprofHandler(svr, rd) + // profile API + registerFunc(apiRouter, "/debug/pprof/profile", pprof.Profile) + registerFunc(apiRouter, "/debug/pprof/trace", pprof.Trace) + registerFunc(apiRouter, "/debug/pprof/symbol", pprof.Symbol) + registerFunc(apiRouter, "/debug/pprof/heap", pprofHandler.PProfHeap) + registerFunc(apiRouter, "/debug/pprof/mutex", pprofHandler.PProfMutex) + registerFunc(apiRouter, "/debug/pprof/allocs", pprofHandler.PProfAllocs) + registerFunc(apiRouter, "/debug/pprof/block", pprofHandler.PProfBlock) + registerFunc(apiRouter, "/debug/pprof/goroutine", pprofHandler.PProfGoroutine) + registerFunc(apiRouter, "/debug/pprof/threadcreate", pprofHandler.PProfThreadcreate) + registerFunc(apiRouter, "/debug/pprof/zip", pprofHandler.PProfZip) + + // service GC safepoint API + serviceGCSafepointHandler := newServiceGCSafepointHandler(svr, rd) + registerFunc(apiRouter, "/gc/safepoint", serviceGCSafepointHandler.GetGCSafePoint, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(apiRouter, "/gc/safepoint/{service_id}", serviceGCSafepointHandler.DeleteGCSafePoint, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + + // min resolved ts API + minResolvedTSHandler := newMinResolvedTSHandler(svr, rd) + registerFunc(clusterRouter, "/min-resolved-ts", minResolvedTSHandler.GetMinResolvedTS, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + // unsafe admin operation API + unsafeOperationHandler := newUnsafeOperationHandler(svr, rd) + registerFunc(clusterRouter, "/admin/unsafe/remove-failed-stores", + unsafeOperationHandler.RemoveFailedStores, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) + registerFunc(clusterRouter, "/admin/unsafe/remove-failed-stores/show", + unsafeOperationHandler.GetFailedStoresRemovalStatus, setMethods(http.MethodGet), setAuditBackend(prometheus)) + + // API to set or unset failpoints + failpoint.Inject("enableFailpointAPI", func() { + // this function will be named to "func2". It may be used in test + registerPrefix(apiRouter, "/fail", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The HTTP handler of failpoint requires the full path to be the failpoint path. + r.URL.Path = strings.TrimPrefix(r.URL.Path, prefix+apiPrefix+"/fail") + new(failpoint.HttpHandler).ServeHTTP(w, r) + }), setAuditBackend("test")) + }) + + // Deprecated: use /pd/api/v1/health instead. + rootRouter.HandleFunc("/health", healthHandler.GetHealthStatus).Methods(http.MethodGet) + // Deprecated: use /pd/api/v1/ping instead. 
+ rootRouter.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {}).Methods(http.MethodGet) + + rootRouter.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { + serviceLabel := route.GetName() + methods, _ := route.GetMethods() + path, _ := route.GetPathTemplate() + if len(serviceLabel) == 0 { + return nil + } + if len(methods) > 0 { + for _, method := range methods { + svr.AddServiceLabel(serviceLabel, apiutil.NewAccessPath(path, method)) + } + } else { + svr.AddServiceLabel(serviceLabel, apiutil.NewAccessPath(path, "")) + } + return nil + }) + + return rootRouter +} diff --git a/server/binding__failpoint_binding__.go b/server/binding__failpoint_binding__.go new file mode 100755 index 00000000000..88484133239 --- /dev/null +++ b/server/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package server + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/cluster/binding__failpoint_binding__.go b/server/cluster/binding__failpoint_binding__.go new file mode 100755 index 00000000000..54b9a7feafc --- /dev/null +++ b/server/cluster/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package cluster + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go old mode 100644 new mode 100755 index 23255a01c05..59ac965eb52 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -388,9 +388,9 @@ func (c *RaftCluster) runMetricsCollectionJob() { defer c.wg.Done() ticker := time.NewTicker(metricsCollectionJobInterval) - failpoint.Inject("highFrequencyClusterJobs", func() { + if _, _err_ := failpoint.Eval(_curpkg_("highFrequencyClusterJobs")); _err_ == nil { ticker = time.NewTicker(time.Microsecond) - }) + } defer ticker.Stop() @@ -412,9 +412,9 @@ func (c *RaftCluster) runNodeStateCheckJob() { defer c.wg.Done() ticker := time.NewTicker(nodeStateCheckJobInterval) - failpoint.Inject("highFrequencyClusterJobs", func() { + if _, _err_ := failpoint.Eval(_curpkg_("highFrequencyClusterJobs")); _err_ == nil { ticker = time.NewTicker(2 * time.Second) - }) + } defer ticker.Stop() for { @@ -809,9 +809,9 @@ func (c *RaftCluster) processReportBuckets(buckets *metapb.Buckets) error { bucketEventCounter.WithLabelValues("version_not_match").Inc() return nil } - failpoint.Inject("concurrentBucketHeartbeat", func() { + if _, _err_ := failpoint.Eval(_curpkg_("concurrentBucketHeartbeat")); _err_ == nil { time.Sleep(500 * time.Millisecond) - }) + } if ok := region.UpdateBuckets(buckets, old); ok { return nil } @@ -859,15 +859,15 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo) error { return nil } - failpoint.Inject("concurrentRegionHeartbeat", func() { + if _, _err_ := failpoint.Eval(_curpkg_("concurrentRegionHeartbeat")); _err_ == nil { time.Sleep(500 * time.Millisecond) - }) + } var overlaps []*core.RegionInfo if saveCache { - failpoint.Inject("decEpoch", func() { + if _, _err_ := 
failpoint.Eval(_curpkg_("decEpoch")); _err_ == nil { region = region.Clone(core.SetRegionConfVer(2), core.SetRegionVersion(2)) - }) + } // To prevent a concurrent heartbeat of another region from overriding the up-to-date region info by a stale one, // check its validation again here. // @@ -2006,9 +2006,9 @@ func (c *RaftCluster) onStoreVersionChangeLocked() { clusterVersion := c.opt.GetClusterVersion() // If the cluster version of PD is less than the minimum version of all stores, // it will update the cluster version. - failpoint.Inject("versionChangeConcurrency", func() { + if _, _err_ := failpoint.Eval(_curpkg_("versionChangeConcurrency")); _err_ == nil { time.Sleep(500 * time.Millisecond) - }) + } if minVersion == nil || clusterVersion.Equal(*minVersion) { return } diff --git a/server/cluster/cluster.go__failpoint_stash__ b/server/cluster/cluster.go__failpoint_stash__ new file mode 100644 index 00000000000..23255a01c05 --- /dev/null +++ b/server/cluster/cluster.go__failpoint_stash__ @@ -0,0 +1,2524 @@ +// Copyright 2016 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cluster + +import ( + "context" + "fmt" + "math" + "net/http" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/coreos/go-semver/semver" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/etcdutil" + "github.com/tikv/pd/pkg/logutil" + "github.com/tikv/pd/pkg/netutil" + "github.com/tikv/pd/pkg/progress" + "github.com/tikv/pd/pkg/slice" + "github.com/tikv/pd/pkg/syncutil" + "github.com/tikv/pd/pkg/typeutil" + "github.com/tikv/pd/server/config" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/core/storelimit" + "github.com/tikv/pd/server/id" + syncer "github.com/tikv/pd/server/region_syncer" + "github.com/tikv/pd/server/replication" + "github.com/tikv/pd/server/schedule" + "github.com/tikv/pd/server/schedule/checker" + "github.com/tikv/pd/server/schedule/hbstream" + "github.com/tikv/pd/server/schedule/labeler" + "github.com/tikv/pd/server/schedule/placement" + "github.com/tikv/pd/server/schedulers" + "github.com/tikv/pd/server/statistics" + "github.com/tikv/pd/server/statistics/buckets" + "github.com/tikv/pd/server/storage" + "github.com/tikv/pd/server/storage/endpoint" + "github.com/tikv/pd/server/versioninfo" + "go.etcd.io/etcd/clientv3" + "go.uber.org/zap" +) + +var ( + // DefaultMinResolvedTSPersistenceInterval is the default value of min resolved ts persistence interval. 
+ // If interval in config is zero, it means not to persist resolved ts and check config with this DefaultMinResolvedTSPersistenceInterval + DefaultMinResolvedTSPersistenceInterval = config.DefaultMinResolvedTSPersistenceInterval + regionUpdateCacheEventCounter = regionEventCounter.WithLabelValues("update_cache") + regionUpdateKVEventCounter = regionEventCounter.WithLabelValues("update_kv") +) + +// regionLabelGCInterval is the interval to run region-label's GC work. +const regionLabelGCInterval = time.Hour + +const ( + // nodeStateCheckJobInterval is the interval to run node state check job. + nodeStateCheckJobInterval = 10 * time.Second + // metricsCollectionJobInterval is the interval to run metrics collection job. + metricsCollectionJobInterval = 10 * time.Second + updateStoreStatsInterval = 9 * time.Millisecond + clientTimeout = 3 * time.Second + defaultChangedRegionsLimit = 10000 + gcTombstoreInterval = 30 * 24 * time.Hour + // persistLimitRetryTimes is used to reduce the probability of the persistent error + // since the once the store is add or remove, we shouldn't return an error even if the store limit is failed to persist. + persistLimitRetryTimes = 5 + persistLimitWaitTime = 100 * time.Millisecond + removingAction = "removing" + preparingAction = "preparing" +) + +// Server is the interface for cluster. +type Server interface { + GetAllocator() id.Allocator + GetConfig() *config.Config + GetPersistOptions() *config.PersistOptions + GetStorage() storage.Storage + GetHBStreams() *hbstream.HeartbeatStreams + GetRaftCluster() *RaftCluster + GetBasicCluster() *core.BasicCluster + GetMembers() ([]*pdpb.Member, error) + ReplicateFileToMember(ctx context.Context, member *pdpb.Member, name string, data []byte) error +} + +// RaftCluster is used for cluster config management. +// Raft cluster key format: +// cluster 1 -> /1/raft, value is metapb.Cluster +// cluster 2 -> /2/raft +// For cluster 1 +// store 1 -> /1/raft/s/1, value is metapb.Store +// region 1 -> /1/raft/r/1, value is metapb.Region +type RaftCluster struct { + syncutil.RWMutex + wg sync.WaitGroup + + serverCtx context.Context + ctx context.Context + cancel context.CancelFunc + + etcdClient *clientv3.Client + httpClient *http.Client + + running atomic.Bool + meta *metapb.Cluster + storeConfigManager *config.StoreConfigManager + storage storage.Storage + minResolvedTS uint64 + externalTS uint64 + + // Keep the previous store limit settings when removing a store. + prevStoreLimit map[uint64]map[storelimit.Type]float64 + + // This below fields are all read-only, we cannot update itself after the raft cluster starts. + clusterID uint64 + id id.Allocator + core *core.BasicCluster // cached cluster info + opt *config.PersistOptions + limiter *StoreLimiter + coordinator *coordinator + labelLevelStats *statistics.LabelStatistics + regionStats *statistics.RegionStatistics + hotStat *statistics.HotStat + hotBuckets *buckets.HotBucketCache + ruleManager *placement.RuleManager + regionLabeler *labeler.RegionLabeler + replicationMode *replication.ModeManager + unsafeRecoveryController *unsafeRecoveryController + progressManager *progress.Manager + regionSyncer *syncer.RegionSyncer + changedRegions chan *core.RegionInfo +} + +// Status saves some state information. +// NOTE: This type is exported by HTTP API. Please pay more attention when modifying it. 
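// When the cluster has been bootstrapped, a response from /pd/api/v1/cluster/status serializes
// roughly as follows (illustrative values; replication_status carries the textual form of the
// current replication mode):
//
//	{"raft_bootstrap_time":"2022-01-01T00:00:00Z","is_initialized":true,"replication_status":"<mode>"}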
+type Status struct { + RaftBootstrapTime time.Time `json:"raft_bootstrap_time,omitempty"` + IsInitialized bool `json:"is_initialized"` + ReplicationStatus string `json:"replication_status"` +} + +// NewRaftCluster create a new cluster. +func NewRaftCluster(ctx context.Context, clusterID uint64, regionSyncer *syncer.RegionSyncer, etcdClient *clientv3.Client, + httpClient *http.Client) *RaftCluster { + return &RaftCluster{ + serverCtx: ctx, + clusterID: clusterID, + regionSyncer: regionSyncer, + httpClient: httpClient, + etcdClient: etcdClient, + } +} + +// GetStoreConfig returns the store config. +func (c *RaftCluster) GetStoreConfig() *config.StoreConfig { + return c.storeConfigManager.GetStoreConfig() +} + +// LoadClusterStatus loads the cluster status. +func (c *RaftCluster) LoadClusterStatus() (*Status, error) { + bootstrapTime, err := c.loadBootstrapTime() + if err != nil { + return nil, err + } + var isInitialized bool + if bootstrapTime != typeutil.ZeroTime { + isInitialized = c.isInitialized() + } + var replicationStatus string + if c.replicationMode != nil { + replicationStatus = c.replicationMode.GetReplicationStatus().String() + } + return &Status{ + RaftBootstrapTime: bootstrapTime, + IsInitialized: isInitialized, + ReplicationStatus: replicationStatus, + }, nil +} + +func (c *RaftCluster) isInitialized() bool { + if c.core.GetRegionCount() > 1 { + return true + } + region := c.core.GetRegionByKey(nil) + return region != nil && + len(region.GetVoters()) >= int(c.opt.GetReplicationConfig().MaxReplicas) && + len(region.GetPendingPeers()) == 0 +} + +// loadBootstrapTime loads the saved bootstrap time from etcd. It returns zero +// value of time.Time when there is error or the cluster is not bootstrapped yet. +func (c *RaftCluster) loadBootstrapTime() (time.Time, error) { + var t time.Time + data, err := c.storage.Load(endpoint.ClusterBootstrapTimeKey()) + if err != nil { + return t, err + } + if data == "" { + return t, nil + } + return typeutil.ParseTimestamp([]byte(data)) +} + +// InitCluster initializes the raft cluster. +func (c *RaftCluster) InitCluster( + id id.Allocator, + opt *config.PersistOptions, + storage storage.Storage, + basicCluster *core.BasicCluster) { + c.core, c.opt, c.storage, c.id = basicCluster, opt, storage, id + c.ctx, c.cancel = context.WithCancel(c.serverCtx) + c.labelLevelStats = statistics.NewLabelStatistics() + c.hotStat = statistics.NewHotStat(c.ctx) + c.hotBuckets = buckets.NewBucketsCache(c.ctx) + c.progressManager = progress.NewManager() + c.changedRegions = make(chan *core.RegionInfo, defaultChangedRegionsLimit) + c.prevStoreLimit = make(map[uint64]map[storelimit.Type]float64) + c.unsafeRecoveryController = newUnsafeRecoveryController(c) +} + +// Start starts a cluster. 
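// Besides loading the persisted cluster meta, stores and regions, Start wires up the placement
// rule manager, region labeler, replication mode manager and store config manager, then launches
// the background goroutines tracked by c.wg (coordinator, metrics collection, node state checks,
// stats jobs, region syncing, replication mode, min-resolved-ts, config sync and store stats updates).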
+func (c *RaftCluster) Start(s Server) error { + if c.IsRunning() { + log.Warn("raft cluster has already been started") + return nil + } + + c.Lock() + defer c.Unlock() + + c.InitCluster(s.GetAllocator(), s.GetPersistOptions(), s.GetStorage(), s.GetBasicCluster()) + cluster, err := c.LoadClusterInfo() + if err != nil { + return err + } + if cluster == nil { + return nil + } + + c.ruleManager = placement.NewRuleManager(c.storage, c, c.GetOpts()) + if c.opt.IsPlacementRulesEnabled() { + err = c.ruleManager.Initialize(c.opt.GetMaxReplicas(), c.opt.GetLocationLabels()) + if err != nil { + return err + } + } + + c.regionLabeler, err = labeler.NewRegionLabeler(c.ctx, c.storage, regionLabelGCInterval) + if err != nil { + return err + } + + c.replicationMode, err = replication.NewReplicationModeManager(s.GetConfig().ReplicationMode, c.storage, cluster, s) + if err != nil { + return err + } + c.storeConfigManager = config.NewStoreConfigManager(c.httpClient) + c.coordinator = newCoordinator(c.ctx, cluster, s.GetHBStreams()) + c.regionStats = statistics.NewRegionStatistics(c.opt, c.ruleManager, c.storeConfigManager) + c.limiter = NewStoreLimiter(s.GetPersistOptions()) + c.externalTS, err = c.storage.LoadExternalTS() + if err != nil { + log.Error("load external timestamp meets error", zap.Error(err)) + } + + c.wg.Add(9) + go c.runCoordinator() + go c.runMetricsCollectionJob() + go c.runNodeStateCheckJob() + go c.runStatsBackgroundJobs() + go c.syncRegions() + go c.runReplicationMode() + go c.runMinResolvedTSJob() + go c.runSyncConfig() + go c.runUpdateStoreStats() + + c.running.Store(true) + return nil +} + +// runSyncConfig runs the job to sync tikv config. +func (c *RaftCluster) runSyncConfig() { + defer logutil.LogPanic() + defer c.wg.Done() + + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + stores := c.GetStores() + + syncConfig(c.storeConfigManager, stores) + for { + select { + case <-c.ctx.Done(): + log.Info("sync store config job is stopped") + return + case <-ticker.C: + if !syncConfig(c.storeConfigManager, stores) { + stores = c.GetStores() + } + } + } +} + +func syncConfig(manager *config.StoreConfigManager, stores []*core.StoreInfo) bool { + for index := 0; index < len(stores); index++ { + // filter out the stores that are tiflash + store := stores[index] + if core.IsStoreContainLabel(store.GetMeta(), core.EngineKey, core.EngineTiFlash) { + continue + } + + // filter out the stores that are not up. + if !(store.IsPreparing() || store.IsServing()) { + continue + } + // it will try next store if the current store is failed. + address := netutil.ResolveLoopBackAddr(stores[index].GetStatusAddress(), stores[index].GetAddress()) + if err := manager.ObserveConfig(address); err != nil { + storeSyncConfigEvent.WithLabelValues(address, "fail").Inc() + log.Debug("sync store config failed, it will try next store", zap.Error(err)) + continue + } + storeSyncConfigEvent.WithLabelValues(address, "succ").Inc() + // it will only try one store. + return true + } + return false +} + +// LoadClusterInfo loads cluster related info. 
+func (c *RaftCluster) LoadClusterInfo() (*RaftCluster, error) { + c.meta = &metapb.Cluster{} + ok, err := c.storage.LoadMeta(c.meta) + if err != nil { + return nil, err + } + if !ok { + return nil, nil + } + + c.core.ResetStores() + start := time.Now() + if err := c.storage.LoadStores(c.core.PutStore); err != nil { + return nil, err + } + log.Info("load stores", + zap.Int("count", c.GetStoreCount()), + zap.Duration("cost", time.Since(start)), + ) + + start = time.Now() + + // used to load region from kv storage to cache storage. + if err := storage.TryLoadRegionsOnce(c.ctx, c.storage, c.core.CheckAndPutRegion); err != nil { + return nil, err + } + log.Info("load regions", + zap.Int("count", c.core.GetRegionCount()), + zap.Duration("cost", time.Since(start)), + ) + for _, store := range c.GetStores() { + storeID := store.GetID() + c.hotStat.GetOrCreateRollingStoreStats(storeID) + } + return c, nil +} + +func (c *RaftCluster) runMetricsCollectionJob() { + defer logutil.LogPanic() + defer c.wg.Done() + + ticker := time.NewTicker(metricsCollectionJobInterval) + failpoint.Inject("highFrequencyClusterJobs", func() { + ticker = time.NewTicker(time.Microsecond) + }) + + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + log.Info("metrics are reset") + c.resetMetrics() + log.Info("metrics collection job has been stopped") + return + case <-ticker.C: + c.collectMetrics() + } + } +} + +func (c *RaftCluster) runNodeStateCheckJob() { + defer logutil.LogPanic() + defer c.wg.Done() + + ticker := time.NewTicker(nodeStateCheckJobInterval) + failpoint.Inject("highFrequencyClusterJobs", func() { + ticker = time.NewTicker(2 * time.Second) + }) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + log.Info("node state check job has been stopped") + return + case <-ticker.C: + c.checkStores() + } + } +} + +func (c *RaftCluster) runStatsBackgroundJobs() { + defer logutil.LogPanic() + defer c.wg.Done() + + ticker := time.NewTicker(statistics.RegionsStatsObserveInterval) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + log.Info("statistics background jobs has been stopped") + return + case <-ticker.C: + c.hotStat.ObserveRegionsStats(c.core.GetStoresWriteRate()) + } + } +} + +func (c *RaftCluster) runUpdateStoreStats() { + defer logutil.LogPanic() + defer c.wg.Done() + + ticker := time.NewTicker(updateStoreStatsInterval) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + log.Info("update store stats background jobs has been stopped") + return + case <-ticker.C: + // Update related stores. + start := time.Now() + stores := c.GetStores() + for _, store := range stores { + if store.IsRemoved() { + continue + } + c.core.UpdateStoreStatus(store.GetID()) + } + updateStoreStatsGauge.Set(time.Since(start).Seconds()) + } + } +} + +func (c *RaftCluster) runCoordinator() { + defer logutil.LogPanic() + defer c.wg.Done() + c.coordinator.runUntilStop() +} + +func (c *RaftCluster) syncRegions() { + defer logutil.LogPanic() + defer c.wg.Done() + c.regionSyncer.RunServer(c.ctx, c.changedRegionNotifier()) +} + +func (c *RaftCluster) runReplicationMode() { + defer logutil.LogPanic() + defer c.wg.Done() + c.replicationMode.Run(c.ctx) +} + +// Stop stops the cluster. +func (c *RaftCluster) Stop() { + c.Lock() + if !c.running.CompareAndSwap(true, false) { + c.Unlock() + return + } + + c.coordinator.stop() + c.cancel() + c.Unlock() + c.wg.Wait() + log.Info("raftcluster is stopped") +} + +// IsRunning return if the cluster is running. 
+func (c *RaftCluster) IsRunning() bool { + return c.running.Load() +} + +// Context returns the context of RaftCluster. +func (c *RaftCluster) Context() context.Context { + if c.running.Load() { + return c.ctx + } + return nil +} + +// GetCoordinator returns the coordinator. +func (c *RaftCluster) GetCoordinator() *coordinator { + return c.coordinator +} + +// GetOperatorController returns the operator controller. +func (c *RaftCluster) GetOperatorController() *schedule.OperatorController { + return c.coordinator.opController +} + +// SetPrepared set the prepare check to prepared. Only for test purpose. +func (c *RaftCluster) SetPrepared() { + c.coordinator.prepareChecker.Lock() + defer c.coordinator.prepareChecker.Unlock() + c.coordinator.prepareChecker.prepared = true +} + +// GetRegionScatter returns the region scatter. +func (c *RaftCluster) GetRegionScatter() *schedule.RegionScatterer { + return c.coordinator.regionScatterer +} + +// GetRegionSplitter returns the region splitter +func (c *RaftCluster) GetRegionSplitter() *schedule.RegionSplitter { + return c.coordinator.regionSplitter +} + +// GetMergeChecker returns merge checker. +func (c *RaftCluster) GetMergeChecker() *checker.MergeChecker { + return c.coordinator.checkers.GetMergeChecker() +} + +// GetRuleChecker returns rule checker. +func (c *RaftCluster) GetRuleChecker() *checker.RuleChecker { + return c.coordinator.checkers.GetRuleChecker() +} + +// RecordOpStepWithTTL records OpStep with TTL +func (c *RaftCluster) RecordOpStepWithTTL(regionID uint64) { + c.GetRuleChecker().RecordRegionPromoteToNonWitness(regionID) +} + +// GetSchedulers gets all schedulers. +func (c *RaftCluster) GetSchedulers() []string { + return c.coordinator.getSchedulers() +} + +// GetSchedulerHandlers gets all scheduler handlers. +func (c *RaftCluster) GetSchedulerHandlers() map[string]http.Handler { + return c.coordinator.getSchedulerHandlers() +} + +// AddScheduler adds a scheduler. +func (c *RaftCluster) AddScheduler(scheduler schedule.Scheduler, args ...string) error { + return c.coordinator.addScheduler(scheduler, args...) +} + +// RemoveScheduler removes a scheduler. +func (c *RaftCluster) RemoveScheduler(name string) error { + return c.coordinator.removeScheduler(name) +} + +// PauseOrResumeScheduler pauses or resumes a scheduler. +func (c *RaftCluster) PauseOrResumeScheduler(name string, t int64) error { + return c.coordinator.pauseOrResumeScheduler(name, t) +} + +// IsSchedulerPaused checks if a scheduler is paused. +func (c *RaftCluster) IsSchedulerPaused(name string) (bool, error) { + return c.coordinator.isSchedulerPaused(name) +} + +// IsSchedulerDisabled checks if a scheduler is disabled. +func (c *RaftCluster) IsSchedulerDisabled(name string) (bool, error) { + return c.coordinator.isSchedulerDisabled(name) +} + +// IsSchedulerAllowed checks if a scheduler is allowed. +func (c *RaftCluster) IsSchedulerAllowed(name string) (bool, error) { + return c.coordinator.isSchedulerAllowed(name) +} + +// IsSchedulerExisted checks if a scheduler is existed. +func (c *RaftCluster) IsSchedulerExisted(name string) (bool, error) { + return c.coordinator.isSchedulerExisted(name) +} + +// PauseOrResumeChecker pauses or resumes checker. 
+func (c *RaftCluster) PauseOrResumeChecker(name string, t int64) error { + return c.coordinator.pauseOrResumeChecker(name, t) +} + +// IsCheckerPaused returns if checker is paused +func (c *RaftCluster) IsCheckerPaused(name string) (bool, error) { + return c.coordinator.isCheckerPaused(name) +} + +// GetAllocator returns cluster's id allocator. +func (c *RaftCluster) GetAllocator() id.Allocator { + return c.id +} + +// GetRegionSyncer returns the region syncer. +func (c *RaftCluster) GetRegionSyncer() *syncer.RegionSyncer { + return c.regionSyncer +} + +// GetReplicationMode returns the ReplicationMode. +func (c *RaftCluster) GetReplicationMode() *replication.ModeManager { + return c.replicationMode +} + +// GetRuleManager returns the rule manager reference. +func (c *RaftCluster) GetRuleManager() *placement.RuleManager { + return c.ruleManager +} + +// GetRegionLabeler returns the region labeler. +func (c *RaftCluster) GetRegionLabeler() *labeler.RegionLabeler { + return c.regionLabeler +} + +// GetStorage returns the storage. +func (c *RaftCluster) GetStorage() storage.Storage { + c.RLock() + defer c.RUnlock() + return c.storage +} + +// SetStorage set the storage for test purpose. +func (c *RaftCluster) SetStorage(s storage.Storage) { + c.Lock() + defer c.Unlock() + c.storage = s +} + +// GetOpts returns cluster's configuration. +// There is no need a lock since it won't changed. +func (c *RaftCluster) GetOpts() *config.PersistOptions { + return c.opt +} + +// AddSuspectRegions adds regions to suspect list. +func (c *RaftCluster) AddSuspectRegions(regionIDs ...uint64) { + c.coordinator.checkers.AddSuspectRegions(regionIDs...) +} + +// GetSuspectRegions gets all suspect regions. +func (c *RaftCluster) GetSuspectRegions() []uint64 { + return c.coordinator.checkers.GetSuspectRegions() +} + +// GetHotStat gets hot stat for test. +func (c *RaftCluster) GetHotStat() *statistics.HotStat { + return c.hotStat +} + +// RemoveSuspectRegion removes region from suspect list. +func (c *RaftCluster) RemoveSuspectRegion(id uint64) { + c.coordinator.checkers.RemoveSuspectRegion(id) +} + +// GetUnsafeRecoveryController returns the unsafe recovery controller. +func (c *RaftCluster) GetUnsafeRecoveryController() *unsafeRecoveryController { + return c.unsafeRecoveryController +} + +// AddSuspectKeyRange adds the key range with the its ruleID as the key +// The instance of each keyRange is like following format: +// [2][]byte: start key/end key +func (c *RaftCluster) AddSuspectKeyRange(start, end []byte) { + c.coordinator.checkers.AddSuspectKeyRange(start, end) +} + +// PopOneSuspectKeyRange gets one suspect keyRange group. +// it would return value and true if pop success, or return empty [][2][]byte and false +// if suspectKeyRanges couldn't pop keyRange group. +func (c *RaftCluster) PopOneSuspectKeyRange() ([2][]byte, bool) { + return c.coordinator.checkers.PopOneSuspectKeyRange() +} + +// ClearSuspectKeyRanges clears the suspect keyRanges, only for unit test +func (c *RaftCluster) ClearSuspectKeyRanges() { + c.coordinator.checkers.ClearSuspectKeyRanges() +} + +// HandleStoreHeartbeat updates the store status. 
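// Besides refreshing the cached StoreInfo, the heartbeat path below persists the store meta when
// needed, feeds the reported per-peer read statistics into hotStat, and, when slow stores are
// detected elsewhere in the cluster, asks this store to awaken its hibernated regions through
// resp.AwakenRegions.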
+func (c *RaftCluster) HandleStoreHeartbeat(heartbeat *pdpb.StoreHeartbeatRequest, resp *pdpb.StoreHeartbeatResponse) error { + stats := heartbeat.GetStats() + storeID := stats.GetStoreId() + c.Lock() + defer c.Unlock() + store := c.GetStore(storeID) + if store == nil { + return errors.Errorf("store %v not found", storeID) + } + + nowTime := time.Now() + var newStore *core.StoreInfo + // If this cluster has slow stores, we should awaken hibernated regions in other stores. + if needAwaken, slowStoreIDs := c.NeedAwakenAllRegionsInStore(storeID); needAwaken { + log.Info("forcely awaken hibernated regions", zap.Uint64("store-id", storeID), zap.Uint64s("slow-stores", slowStoreIDs)) + newStore = store.Clone(core.SetStoreStats(stats), core.SetLastHeartbeatTS(nowTime), core.SetLastAwakenTime(nowTime)) + resp.AwakenRegions = &pdpb.AwakenRegions{ + AbnormalStores: slowStoreIDs, + } + } else { + newStore = store.Clone(core.SetStoreStats(stats), core.SetLastHeartbeatTS(nowTime)) + } + + if newStore.IsLowSpace(c.opt.GetLowSpaceRatio()) { + log.Warn("store does not have enough disk space", + zap.Uint64("store-id", storeID), + zap.Uint64("capacity", newStore.GetCapacity()), + zap.Uint64("available", newStore.GetAvailable())) + } + if newStore.NeedPersist() && c.storage != nil { + if err := c.storage.SaveStore(newStore.GetMeta()); err != nil { + log.Error("failed to persist store", zap.Uint64("store-id", storeID), errs.ZapError(err)) + } else { + newStore = newStore.Clone(core.SetLastPersistTime(nowTime)) + } + } + if store := c.core.GetStore(storeID); store != nil { + statistics.UpdateStoreHeartbeatMetrics(store) + } + c.core.PutStore(newStore) + c.hotStat.Observe(storeID, newStore.GetStoreStats()) + c.hotStat.FilterUnhealthyStore(c) + reportInterval := stats.GetInterval() + interval := reportInterval.GetEndTimestamp() - reportInterval.GetStartTimestamp() + + // c.limiter is nil before "start" is called + if c.limiter != nil && c.opt.GetStoreLimitMode() == "auto" { + c.limiter.Collect(newStore.GetStoreStats()) + } + + regions := make(map[uint64]*core.RegionInfo, len(stats.GetPeerStats())) + for _, peerStat := range stats.GetPeerStats() { + regionID := peerStat.GetRegionId() + region := c.GetRegion(regionID) + regions[regionID] = region + if region == nil { + log.Warn("discard hot peer stat for unknown region", + zap.Uint64("region-id", regionID), + zap.Uint64("store-id", storeID)) + continue + } + peer := region.GetStorePeer(storeID) + if peer == nil { + log.Warn("discard hot peer stat for unknown region peer", + zap.Uint64("region-id", regionID), + zap.Uint64("store-id", storeID)) + continue + } + readQueryNum := core.GetReadQueryNum(peerStat.GetQueryStats()) + loads := []float64{ + statistics.RegionReadBytes: float64(peerStat.GetReadBytes()), + statistics.RegionReadKeys: float64(peerStat.GetReadKeys()), + statistics.RegionReadQueryNum: float64(readQueryNum), + statistics.RegionWriteBytes: 0, + statistics.RegionWriteKeys: 0, + statistics.RegionWriteQueryNum: 0, + } + peerInfo := core.NewPeerInfo(peer, loads, interval) + c.hotStat.CheckReadAsync(statistics.NewCheckPeerTask(peerInfo, region)) + } + // Here we will compare the reported regions with the previous hot peers to decide if it is still hot. + c.hotStat.CheckReadAsync(statistics.NewCollectUnReportedPeerTask(storeID, regions, interval)) + return nil +} + +// processReportBuckets update the bucket information. 
+func (c *RaftCluster) processReportBuckets(buckets *metapb.Buckets) error { + region := c.core.GetRegion(buckets.GetRegionId()) + if region == nil { + bucketEventCounter.WithLabelValues("region_cache_miss").Inc() + return errors.Errorf("region %v not found", buckets.GetRegionId()) + } + // use CAS to update the bucket information. + // the two request(A:3,B:2) get the same region and need to update the buckets. + // the A will pass the check and set the version to 3, the B will fail because the region.bucket has changed. + // the retry should keep the old version and the new version will be set to the region.bucket, like two requests (A:2,B:3). + for retry := 0; retry < 3; retry++ { + old := region.GetBuckets() + // region should not update if the version of the buckets is less than the old one. + if old != nil && buckets.GetVersion() <= old.GetVersion() { + bucketEventCounter.WithLabelValues("version_not_match").Inc() + return nil + } + failpoint.Inject("concurrentBucketHeartbeat", func() { + time.Sleep(500 * time.Millisecond) + }) + if ok := region.UpdateBuckets(buckets, old); ok { + return nil + } + } + bucketEventCounter.WithLabelValues("update_failed").Inc() + return nil +} + +// IsPrepared return true if the prepare checker is ready. +func (c *RaftCluster) IsPrepared() bool { + return c.coordinator.prepareChecker.isPrepared() +} + +var regionGuide = core.GenerateRegionGuideFunc(true) + +// processRegionHeartbeat updates the region information. +func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo) error { + origin, err := c.core.PreCheckPutRegion(region) + if err != nil { + return err + } + region.Inherit(origin, c.storeConfigManager.GetStoreConfig().IsEnableRegionBucket()) + + c.hotStat.CheckWriteAsync(statistics.NewCheckExpiredItemTask(region)) + c.hotStat.CheckReadAsync(statistics.NewCheckExpiredItemTask(region)) + reportInterval := region.GetInterval() + interval := reportInterval.GetEndTimestamp() - reportInterval.GetStartTimestamp() + for _, peer := range region.GetPeers() { + peerInfo := core.NewPeerInfo(peer, region.GetWriteLoads(), interval) + c.hotStat.CheckWriteAsync(statistics.NewCheckPeerTask(peerInfo, region)) + } + c.coordinator.CheckTransferWitnessLeader(region) + + hasRegionStats := c.regionStats != nil + // Save to storage if meta is updated. + // Save to cache if meta or leader is updated, or contains any down/pending peer. + // Mark isNew if the region in cache does not have leader. + isNew, saveKV, saveCache, needSync := regionGuide(region, origin) + if !saveKV && !saveCache && !isNew { + // Due to some config changes need to update the region stats as well, + // so we do some extra checks here. + if hasRegionStats && c.regionStats.RegionStatsNeedUpdate(region) { + c.regionStats.Observe(region, c.getRegionStoresLocked(region)) + } + return nil + } + + failpoint.Inject("concurrentRegionHeartbeat", func() { + time.Sleep(500 * time.Millisecond) + }) + + var overlaps []*core.RegionInfo + if saveCache { + failpoint.Inject("decEpoch", func() { + region = region.Clone(core.SetRegionConfVer(2), core.SetRegionVersion(2)) + }) + // To prevent a concurrent heartbeat of another region from overriding the up-to-date region info by a stale one, + // check its validation again here. + // + // However it can't solve the race condition of concurrent heartbeats from the same region. 
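    // AtomicCheckAndPutRegion redoes the stale-region check under the region tree's lock and
    // returns the cached regions that the incoming one overlaps (for example ranges that were
    // merged or split away); those defunct regions are purged from the statistics just below and
    // removed from storage further down.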
+ if overlaps, err = c.core.AtomicCheckAndPutRegion(region); err != nil { + return err + } + + for _, item := range overlaps { + if c.regionStats != nil { + c.regionStats.ClearDefunctRegion(item.GetID()) + } + c.labelLevelStats.ClearDefunctRegion(item.GetID()) + } + regionUpdateCacheEventCounter.Inc() + } + + if hasRegionStats { + c.regionStats.Observe(region, c.getRegionStoresLocked(region)) + } + + if !c.IsPrepared() && isNew { + c.coordinator.prepareChecker.collect(region) + } + + if c.storage != nil { + // If there are concurrent heartbeats from the same region, the last write will win even if + // writes to storage in the critical area. So don't use mutex to protect it. + // Not successfully saved to storage is not fatal, it only leads to longer warm-up + // after restart. Here we only log the error then go on updating cache. + for _, item := range overlaps { + if err := c.storage.DeleteRegion(item.GetMeta()); err != nil { + log.Error("failed to delete region from storage", + zap.Uint64("region-id", item.GetID()), + logutil.ZapRedactStringer("region-meta", core.RegionToHexMeta(item.GetMeta())), + errs.ZapError(err)) + } + } + if saveKV { + if err := c.storage.SaveRegion(region.GetMeta()); err != nil { + log.Error("failed to save region to storage", + zap.Uint64("region-id", region.GetID()), + logutil.ZapRedactStringer("region-meta", core.RegionToHexMeta(region.GetMeta())), + errs.ZapError(err)) + } + regionUpdateKVEventCounter.Inc() + } + } + + if saveKV || needSync { + select { + case c.changedRegions <- region: + default: + } + } + + return nil +} + +func (c *RaftCluster) putMetaLocked(meta *metapb.Cluster) error { + if c.storage != nil { + if err := c.storage.SaveMeta(meta); err != nil { + return err + } + } + c.meta = meta + return nil +} + +// GetBasicCluster returns the basic cluster. +func (c *RaftCluster) GetBasicCluster() *core.BasicCluster { + return c.core +} + +// GetRegionByKey gets regionInfo by region key from cluster. +func (c *RaftCluster) GetRegionByKey(regionKey []byte) *core.RegionInfo { + return c.core.GetRegionByKey(regionKey) +} + +// GetPrevRegionByKey gets previous region and leader peer by the region key from cluster. +func (c *RaftCluster) GetPrevRegionByKey(regionKey []byte) *core.RegionInfo { + return c.core.GetPrevRegionByKey(regionKey) +} + +// ScanRegions scans region with start key, until the region contains endKey, or +// total number greater than limit. +func (c *RaftCluster) ScanRegions(startKey, endKey []byte, limit int) []*core.RegionInfo { + return c.core.ScanRange(startKey, endKey, limit) +} + +// GetRegion searches for a region by ID. +func (c *RaftCluster) GetRegion(regionID uint64) *core.RegionInfo { + return c.core.GetRegion(regionID) +} + +// GetMetaRegions gets regions from cluster. +func (c *RaftCluster) GetMetaRegions() []*metapb.Region { + return c.core.GetMetaRegions() +} + +// GetRegions returns all regions' information in detail. +func (c *RaftCluster) GetRegions() []*core.RegionInfo { + return c.core.GetRegions() +} + +// GetRegionCount returns total count of regions +func (c *RaftCluster) GetRegionCount() int { + return c.core.GetRegionCount() +} + +// GetStoreRegions returns all regions' information with a given storeID. +func (c *RaftCluster) GetStoreRegions(storeID uint64) []*core.RegionInfo { + return c.core.GetStoreRegions(storeID) +} + +// RandLeaderRegions returns some random regions that has leader on the store. 
+func (c *RaftCluster) RandLeaderRegions(storeID uint64, ranges []core.KeyRange) []*core.RegionInfo { + return c.core.RandLeaderRegions(storeID, ranges) +} + +// RandFollowerRegions returns some random regions that has a follower on the store. +func (c *RaftCluster) RandFollowerRegions(storeID uint64, ranges []core.KeyRange) []*core.RegionInfo { + return c.core.RandFollowerRegions(storeID, ranges) +} + +// RandPendingRegions returns some random regions that has a pending peer on the store. +func (c *RaftCluster) RandPendingRegions(storeID uint64, ranges []core.KeyRange) []*core.RegionInfo { + return c.core.RandPendingRegions(storeID, ranges) +} + +// RandLearnerRegions returns some random regions that has a learner peer on the store. +func (c *RaftCluster) RandLearnerRegions(storeID uint64, ranges []core.KeyRange) []*core.RegionInfo { + return c.core.RandLearnerRegions(storeID, ranges) +} + +// GetLeaderStore returns all stores that contains the region's leader peer. +func (c *RaftCluster) GetLeaderStore(region *core.RegionInfo) *core.StoreInfo { + return c.core.GetLeaderStore(region) +} + +// GetFollowerStores returns all stores that contains the region's follower peer. +func (c *RaftCluster) GetFollowerStores(region *core.RegionInfo) []*core.StoreInfo { + return c.core.GetFollowerStores(region) +} + +// GetRegionStores returns all stores that contains the region's peer. +func (c *RaftCluster) GetRegionStores(region *core.RegionInfo) []*core.StoreInfo { + return c.core.GetRegionStores(region) +} + +// GetStoreCount returns the count of stores. +func (c *RaftCluster) GetStoreCount() int { + return c.core.GetStoreCount() +} + +// GetStoreRegionCount returns the number of regions for a given store. +func (c *RaftCluster) GetStoreRegionCount(storeID uint64) int { + return c.core.GetStoreRegionCount(storeID) +} + +// GetAverageRegionSize returns the average region approximate size. +func (c *RaftCluster) GetAverageRegionSize() int64 { + return c.core.GetAverageRegionSize() +} + +// DropCacheRegion removes a region from the cache. +func (c *RaftCluster) DropCacheRegion(id uint64) { + c.core.RemoveRegionIfExist(id) +} + +// DropCacheAllRegion removes all regions from the cache. +func (c *RaftCluster) DropCacheAllRegion() { + c.core.ResetRegionCache() +} + +// GetMetaStores gets stores from cluster. +func (c *RaftCluster) GetMetaStores() []*metapb.Store { + return c.core.GetMetaStores() +} + +// GetStores returns all stores in the cluster. +func (c *RaftCluster) GetStores() []*core.StoreInfo { + return c.core.GetStores() +} + +// GetLeaderStoreByRegionID returns the leader store of the given region. +func (c *RaftCluster) GetLeaderStoreByRegionID(regionID uint64) *core.StoreInfo { + return c.core.GetLeaderStoreByRegionID(regionID) +} + +// GetStore gets store from cluster. +func (c *RaftCluster) GetStore(storeID uint64) *core.StoreInfo { + return c.core.GetStore(storeID) +} + +// GetAdjacentRegions returns regions' information that are adjacent with the specific region ID. +func (c *RaftCluster) GetAdjacentRegions(region *core.RegionInfo) (*core.RegionInfo, *core.RegionInfo) { + return c.core.GetAdjacentRegions(region) +} + +// GetRangeHoles returns all range holes, i.e the key ranges without any region info. +func (c *RaftCluster) GetRangeHoles() [][]string { + return c.core.GetRangeHoles() +} + +// UpdateStoreLabels updates a store's location labels +// If 'force' is true, then update the store's labels forcibly. 
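// As a rough example: if a store currently carries the labels {zone: z1, host: h1} and only
// {host: h2} is submitted, the default merge path keeps zone=z1 and overrides host to h2, while
// force=true replaces the whole label set so that only host=h2 remains.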
+func (c *RaftCluster) UpdateStoreLabels(storeID uint64, labels []*metapb.StoreLabel, force bool) error { + store := c.GetStore(storeID) + if store == nil { + return errs.ErrInvalidStoreID.FastGenByArgs(storeID) + } + newStore := typeutil.DeepClone(store.GetMeta(), core.StoreFactory) + if force { + newStore.Labels = labels + } else { + // If 'force' isn't set, the given labels will merge into those labels which already existed in the store. + newStore.Labels = core.MergeLabels(newStore.GetLabels(), labels) + } + // PutStore will perform label merge. + return c.putStoreImpl(newStore) +} + +// DeleteStoreLabel updates a store's location labels +func (c *RaftCluster) DeleteStoreLabel(storeID uint64, labelKey string) error { + store := c.GetStore(storeID) + if store == nil { + return errs.ErrInvalidStoreID.FastGenByArgs(storeID) + } + newStore := typeutil.DeepClone(store.GetMeta(), core.StoreFactory) + labels := make([]*metapb.StoreLabel, 0, len(newStore.GetLabels())-1) + for _, label := range newStore.GetLabels() { + if label.Key == labelKey { + continue + } + labels = append(labels, label) + } + if len(labels) == len(store.GetLabels()) { + return errors.Errorf("the label key %s does not exist", labelKey) + } + newStore.Labels = labels + // PutStore will perform label merge. + return c.putStoreImpl(newStore) +} + +// PutStore puts a store. +func (c *RaftCluster) PutStore(store *metapb.Store) error { + if err := c.putStoreImpl(store); err != nil { + return err + } + c.OnStoreVersionChange() + c.AddStoreLimit(store) + return nil +} + +// putStoreImpl puts a store. +// If 'force' is true, then overwrite the store's labels. +func (c *RaftCluster) putStoreImpl(store *metapb.Store) error { + c.Lock() + defer c.Unlock() + + if store.GetId() == 0 { + return errors.Errorf("invalid put store %v", store) + } + + if err := c.checkStoreVersion(store); err != nil { + return err + } + + // Store address can not be the same as other stores. + for _, s := range c.GetStores() { + // It's OK to start a new store on the same address if the old store has been removed or physically destroyed. + if s.IsRemoved() || s.IsPhysicallyDestroyed() { + continue + } + if s.GetID() != store.GetId() && s.GetAddress() == store.GetAddress() { + return errors.Errorf("duplicated store address: %v, already registered by %v", store, s.GetMeta()) + } + } + + s := c.GetStore(store.GetId()) + if s == nil { + // Add a new store. + s = core.NewStoreInfo(store) + } else { + // Use the given labels to update the store. + labels := store.GetLabels() + // Update an existed store. 
+ s = s.Clone( + core.SetStoreAddress(store.Address, store.StatusAddress, store.PeerAddress), + core.SetStoreVersion(store.GitHash, store.Version), + core.SetStoreLabels(labels), + core.SetStoreStartTime(store.StartTimestamp), + core.SetStoreDeployPath(store.DeployPath), + ) + } + if err := c.checkStoreLabels(s); err != nil { + return err + } + return c.putStoreLocked(s) +} + +func (c *RaftCluster) checkStoreVersion(store *metapb.Store) error { + v, err := versioninfo.ParseVersion(store.GetVersion()) + if err != nil { + return errors.Errorf("invalid put store %v, error: %s", store, err) + } + clusterVersion := *c.opt.GetClusterVersion() + if !versioninfo.IsCompatible(clusterVersion, *v) { + return errors.Errorf("version should compatible with version %s, got %s", clusterVersion, v) + } + return nil +} + +func (c *RaftCluster) checkStoreLabels(s *core.StoreInfo) error { + keysSet := make(map[string]struct{}) + for _, k := range c.opt.GetLocationLabels() { + keysSet[k] = struct{}{} + if v := s.GetLabelValue(k); len(v) == 0 { + log.Warn("label configuration is incorrect", + zap.Stringer("store", s.GetMeta()), + zap.String("label-key", k)) + if c.opt.GetStrictlyMatchLabel() { + return errors.Errorf("label configuration is incorrect, need to specify the key: %s ", k) + } + } + } + for _, label := range s.GetLabels() { + key := label.GetKey() + if _, ok := keysSet[key]; !ok { + log.Warn("not found the key match with the store label", + zap.Stringer("store", s.GetMeta()), + zap.String("label-key", key)) + if c.opt.GetStrictlyMatchLabel() { + return errors.Errorf("key matching the label was not found in the PD, store label key: %s ", key) + } + } + } + return nil +} + +// RemoveStore marks a store as offline in cluster. +// State transition: Up -> Offline. +func (c *RaftCluster) RemoveStore(storeID uint64, physicallyDestroyed bool) error { + c.Lock() + defer c.Unlock() + + store := c.GetStore(storeID) + if store == nil { + return errs.ErrStoreNotFound.FastGenByArgs(storeID) + } + + // Remove an offline store should be OK, nothing to do. + if store.IsRemoving() && store.IsPhysicallyDestroyed() == physicallyDestroyed { + return nil + } + + if store.IsRemoved() { + return errs.ErrStoreRemoved.FastGenByArgs(storeID) + } + + if store.IsPhysicallyDestroyed() { + return errs.ErrStoreDestroyed.FastGenByArgs(storeID) + } + + if (store.IsPreparing() || store.IsServing()) && !physicallyDestroyed { + if err := c.checkReplicaBeforeOfflineStore(storeID); err != nil { + return err + } + } + + newStore := store.Clone(core.OfflineStore(physicallyDestroyed)) + log.Warn("store has been offline", + zap.Uint64("store-id", storeID), + zap.String("store-address", newStore.GetAddress()), + zap.Bool("physically-destroyed", newStore.IsPhysicallyDestroyed())) + err := c.putStoreLocked(newStore) + if err == nil { + regionSize := float64(c.core.GetStoreRegionSize(storeID)) + c.resetProgress(storeID, store.GetAddress()) + c.progressManager.AddProgress(encodeRemovingProgressKey(storeID), regionSize, regionSize, nodeStateCheckJobInterval) + // record the current store limit in memory + c.prevStoreLimit[storeID] = map[storelimit.Type]float64{ + storelimit.AddPeer: c.GetStoreLimitByType(storeID, storelimit.AddPeer), + storelimit.RemovePeer: c.GetStoreLimitByType(storeID, storelimit.RemovePeer), + } + // TODO: if the persist operation encounters error, the "Unlimited" will be rollback. + // And considering the store state has changed, RemoveStore is actually successful. 
+ _ = c.SetStoreLimit(storeID, storelimit.RemovePeer, storelimit.Unlimited) + } + return err +} + +func (c *RaftCluster) checkReplicaBeforeOfflineStore(storeID uint64) error { + upStores := c.getUpStores() + expectUpStoresNum := len(upStores) - 1 + if expectUpStoresNum < c.opt.GetMaxReplicas() { + return errs.ErrStoresNotEnough.FastGenByArgs(storeID, expectUpStoresNum, c.opt.GetMaxReplicas()) + } + + // Check if there are extra up store to store the leaders of the regions. + evictStores := c.getEvictLeaderStores() + if len(evictStores) < expectUpStoresNum { + return nil + } + + expectUpstores := make(map[uint64]bool) + for _, UpStoreID := range upStores { + if UpStoreID == storeID { + continue + } + expectUpstores[UpStoreID] = true + } + evictLeaderStoresNum := 0 + for _, evictStoreID := range evictStores { + if _, ok := expectUpstores[evictStoreID]; ok { + evictLeaderStoresNum++ + } + } + + // returns error if there is no store for leader. + if evictLeaderStoresNum == len(expectUpstores) { + return errs.ErrNoStoreForRegionLeader.FastGenByArgs(storeID) + } + + return nil +} + +func (c *RaftCluster) getEvictLeaderStores() (evictStores []uint64) { + if c.coordinator == nil { + return nil + } + handler, ok := c.coordinator.getSchedulerHandlers()[schedulers.EvictLeaderName] + if !ok { + return + } + type evictLeaderHandler interface { + EvictStoreIDs() []uint64 + } + h, ok := handler.(evictLeaderHandler) + if !ok { + return + } + return h.EvictStoreIDs() +} + +func (c *RaftCluster) getUpStores() []uint64 { + upStores := make([]uint64, 0) + for _, store := range c.GetStores() { + if store.IsUp() { + upStores = append(upStores, store.GetID()) + } + } + return upStores +} + +// BuryStore marks a store as tombstone in cluster. +// If forceBury is false, the store should be offlined and emptied before calling this func. +func (c *RaftCluster) BuryStore(storeID uint64, forceBury bool) error { + c.Lock() + defer c.Unlock() + + store := c.GetStore(storeID) + if store == nil { + return errs.ErrStoreNotFound.FastGenByArgs(storeID) + } + + // Bury a tombstone store should be OK, nothing to do. + if store.IsRemoved() { + return nil + } + + if store.IsUp() { + if !forceBury { + return errs.ErrStoreIsUp.FastGenByArgs() + } else if !store.IsDisconnected() { + return errors.Errorf("The store %v is not offline nor disconnected", storeID) + } + } + + newStore := store.Clone(core.TombstoneStore()) + log.Warn("store has been Tombstone", + zap.Uint64("store-id", storeID), + zap.String("store-address", newStore.GetAddress()), + zap.String("state", store.GetState().String()), + zap.Bool("physically-destroyed", store.IsPhysicallyDestroyed())) + err := c.putStoreLocked(newStore) + c.onStoreVersionChangeLocked() + if err == nil { + // clean up the residual information. + delete(c.prevStoreLimit, storeID) + c.RemoveStoreLimit(storeID) + c.resetProgress(storeID, store.GetAddress()) + c.hotStat.RemoveRollingStoreStats(storeID) + } + return err +} + +// PauseLeaderTransfer prevents the store from been selected as source or +// target store of TransferLeader. +func (c *RaftCluster) PauseLeaderTransfer(storeID uint64) error { + return c.core.PauseLeaderTransfer(storeID) +} + +// ResumeLeaderTransfer cleans a store's pause state. The store can be selected +// as source or target of TransferLeader again. 
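checkReplicaBeforeOfflineStore above encodes two safety conditions for taking a store offline: enough Up stores must remain to hold max-replicas copies, and at least one surviving Up store must still be allowed to host leaders (i.e. not every survivor is an evict-leader target). A standalone sketch of that check with plain IDs rather than the real store types:

package main

import "fmt"

// canOffline reports whether removing target still leaves enough Up stores for
// maxReplicas copies and at least one Up store that may hold region leaders.
func canOffline(upStores, evictLeaderStores []uint64, target uint64, maxReplicas int) bool {
	remaining := map[uint64]bool{}
	for _, id := range upStores {
		if id != target {
			remaining[id] = true
		}
	}
	if len(remaining) < maxReplicas {
		return false
	}
	evicted := 0
	for _, id := range evictLeaderStores {
		if remaining[id] {
			evicted++
		}
	}
	return evicted != len(remaining)
}

func main() {
	fmt.Println(canOffline([]uint64{1, 2, 3, 4}, []uint64{2, 3}, 4, 3))    // true
	fmt.Println(canOffline([]uint64{1, 2, 3, 4}, []uint64{1, 2, 3}, 4, 3)) // false: no leader host left
}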
+func (c *RaftCluster) ResumeLeaderTransfer(storeID uint64) { + c.core.ResumeLeaderTransfer(storeID) +} + +// SlowStoreEvicted marks a store as a slow store and prevents transferring +// leader to the store +func (c *RaftCluster) SlowStoreEvicted(storeID uint64) error { + return c.core.SlowStoreEvicted(storeID) +} + +// SlowStoreRecovered cleans the evicted state of a store. +func (c *RaftCluster) SlowStoreRecovered(storeID uint64) { + c.core.SlowStoreRecovered(storeID) +} + +// NeedAwakenAllRegionsInStore checks whether we should do AwakenRegions operation. +func (c *RaftCluster) NeedAwakenAllRegionsInStore(storeID uint64) (needAwaken bool, slowStoreIDs []uint64) { + store := c.GetStore(storeID) + // We just return AwakenRegions messages to those Serving stores which need to be awaken. + if store.IsSlow() || !store.NeedAwakenStore() { + return false, nil + } + + needAwaken = false + for _, store := range c.GetStores() { + if store.IsRemoved() { + continue + } + + // We will filter out heartbeat requests from slowStores. + if (store.IsUp() || store.IsRemoving()) && store.IsSlow() && + store.GetStoreStats().GetStoreId() != storeID { + needAwaken = true + slowStoreIDs = append(slowStoreIDs, store.GetID()) + } + } + return needAwaken, slowStoreIDs +} + +// UpStore up a store from offline +func (c *RaftCluster) UpStore(storeID uint64) error { + c.Lock() + defer c.Unlock() + + store := c.GetStore(storeID) + if store == nil { + return errs.ErrStoreNotFound.FastGenByArgs(storeID) + } + + if store.IsRemoved() { + return errs.ErrStoreRemoved.FastGenByArgs(storeID) + } + + if store.IsPhysicallyDestroyed() { + return errs.ErrStoreDestroyed.FastGenByArgs(storeID) + } + + if store.IsUp() { + return nil + } + + options := []core.StoreCreateOption{core.UpStore()} + // get the previous store limit recorded in memory + limiter, exist := c.prevStoreLimit[storeID] + if exist { + options = append(options, + core.ResetStoreLimit(storelimit.AddPeer, limiter[storelimit.AddPeer]), + core.ResetStoreLimit(storelimit.RemovePeer, limiter[storelimit.RemovePeer]), + ) + } + newStore := store.Clone(options...) + log.Warn("store has been up", + zap.Uint64("store-id", storeID), + zap.String("store-address", newStore.GetAddress())) + err := c.putStoreLocked(newStore) + if err == nil { + if exist { + // persist the store limit + _ = c.SetStoreLimit(storeID, storelimit.AddPeer, limiter[storelimit.AddPeer]) + _ = c.SetStoreLimit(storeID, storelimit.RemovePeer, limiter[storelimit.RemovePeer]) + } + c.resetProgress(storeID, store.GetAddress()) + } + return err +} + +// ReadyToServe change store's node state to Serving. +func (c *RaftCluster) ReadyToServe(storeID uint64) error { + c.Lock() + defer c.Unlock() + + store := c.GetStore(storeID) + if store == nil { + return errs.ErrStoreNotFound.FastGenByArgs(storeID) + } + + if store.IsRemoved() { + return errs.ErrStoreRemoved.FastGenByArgs(storeID) + } + + if store.IsPhysicallyDestroyed() { + return errs.ErrStoreDestroyed.FastGenByArgs(storeID) + } + + if store.IsServing() { + return errs.ErrStoreServing.FastGenByArgs(storeID) + } + + newStore := store.Clone(core.UpStore()) + log.Info("store has changed to serving", + zap.Uint64("store-id", storeID), + zap.String("store-address", newStore.GetAddress())) + err := c.putStoreLocked(newStore) + if err == nil { + c.resetProgress(storeID, store.GetAddress()) + } + return err +} + +// SetStoreWeight sets up a store's leader/region balance weight. 
+func (c *RaftCluster) SetStoreWeight(storeID uint64, leaderWeight, regionWeight float64) error { + store := c.GetStore(storeID) + if store == nil { + return errs.ErrStoreNotFound.FastGenByArgs(storeID) + } + + if err := c.storage.SaveStoreWeight(storeID, leaderWeight, regionWeight); err != nil { + return err + } + + newStore := store.Clone( + core.SetLeaderWeight(leaderWeight), + core.SetRegionWeight(regionWeight), + ) + + return c.putStoreLocked(newStore) +} + +func (c *RaftCluster) putStoreLocked(store *core.StoreInfo) error { + if c.storage != nil { + if err := c.storage.SaveStore(store.GetMeta()); err != nil { + return err + } + } + c.core.PutStore(store) + c.hotStat.GetOrCreateRollingStoreStats(store.GetID()) + return nil +} + +func (c *RaftCluster) checkStores() { + var offlineStores []*metapb.Store + var upStoreCount int + stores := c.GetStores() + + for _, store := range stores { + // the store has already been tombstone + if store.IsRemoved() { + if store.DownTime() > gcTombstoreInterval { + err := c.deleteStore(store) + if err != nil { + log.Error("auto gc the tombstore store failed", + zap.Stringer("store", store.GetMeta()), + zap.Duration("down-time", store.DownTime()), + errs.ZapError(err)) + } else { + log.Info("auto gc the tombstore store success", zap.Stringer("store", store.GetMeta()), zap.Duration("down-time", store.DownTime())) + } + } + continue + } + + storeID := store.GetID() + if store.IsPreparing() { + if store.GetUptime() >= c.opt.GetMaxStorePreparingTime() || c.GetRegionCount() < core.InitClusterRegionThreshold { + if err := c.ReadyToServe(storeID); err != nil { + log.Error("change store to serving failed", + zap.Stringer("store", store.GetMeta()), + errs.ZapError(err)) + } + } else if c.IsPrepared() { + threshold := c.getThreshold(stores, store) + log.Debug("store serving threshold", zap.Uint64("store-id", storeID), zap.Float64("threshold", threshold)) + regionSize := float64(store.GetRegionSize()) + if regionSize >= threshold { + if err := c.ReadyToServe(storeID); err != nil { + log.Error("change store to serving failed", + zap.Stringer("store", store.GetMeta()), + errs.ZapError(err)) + } + } else { + remaining := threshold - regionSize + // If we add multiple stores, the total will need to be changed. + c.progressManager.UpdateProgressTotal(encodePreparingProgressKey(storeID), threshold) + c.updateProgress(storeID, store.GetAddress(), preparingAction, regionSize, remaining, true /* inc */) + } + } + } + + if store.IsUp() { + if !store.IsLowSpace(c.opt.GetLowSpaceRatio()) { + upStoreCount++ + } + continue + } + + offlineStore := store.GetMeta() + id := offlineStore.GetId() + regionSize := c.core.GetStoreRegionSize(id) + if c.IsPrepared() { + c.updateProgress(id, store.GetAddress(), removingAction, float64(regionSize), float64(regionSize), false /* dec */) + } + regionCount := c.core.GetStoreRegionCount(id) + // If the store is empty, it can be buried. + if regionCount == 0 { + if err := c.BuryStore(id, false); err != nil { + log.Error("bury store failed", + zap.Stringer("store", offlineStore), + errs.ZapError(err)) + } + } else { + offlineStores = append(offlineStores, offlineStore) + } + } + + if len(offlineStores) == 0 { + return + } + + // When placement rules feature is enabled. It is hard to determine required replica count precisely. 
+ if !c.opt.IsPlacementRulesEnabled() && upStoreCount < c.opt.GetMaxReplicas() { + for _, offlineStore := range offlineStores { + log.Warn("store may not turn into Tombstone, there are no extra up store has enough space to accommodate the extra replica", zap.Stringer("store", offlineStore)) + } + } +} + +func (c *RaftCluster) getThreshold(stores []*core.StoreInfo, store *core.StoreInfo) float64 { + start := time.Now() + if !c.opt.IsPlacementRulesEnabled() { + regionSize := c.core.GetRegionSizeByRange([]byte(""), []byte("")) * int64(c.opt.GetMaxReplicas()) + weight := getStoreTopoWeight(store, stores, c.opt.GetLocationLabels(), c.opt.GetMaxReplicas()) + return float64(regionSize) * weight * 0.9 + } + + keys := c.ruleManager.GetSplitKeys([]byte(""), []byte("")) + if len(keys) == 0 { + return c.calculateRange(stores, store, []byte(""), []byte("")) * 0.9 + } + + storeSize := 0.0 + startKey := []byte("") + for _, key := range keys { + endKey := key + storeSize += c.calculateRange(stores, store, startKey, endKey) + startKey = endKey + } + // the range from the last split key to the last key + storeSize += c.calculateRange(stores, store, startKey, []byte("")) + log.Debug("threshold calculation time", zap.Duration("cost", time.Since(start))) + return storeSize * 0.9 +} + +func (c *RaftCluster) calculateRange(stores []*core.StoreInfo, store *core.StoreInfo, startKey, endKey []byte) float64 { + var storeSize float64 + rules := c.ruleManager.GetRulesForApplyRange(startKey, endKey) + for _, rule := range rules { + if !placement.MatchLabelConstraints(store, rule.LabelConstraints) { + continue + } + + var matchStores []*core.StoreInfo + for _, s := range stores { + if s.IsRemoving() || s.IsRemoved() { + continue + } + if placement.MatchLabelConstraints(s, rule.LabelConstraints) { + matchStores = append(matchStores, s) + } + } + regionSize := c.core.GetRegionSizeByRange(startKey, endKey) * int64(rule.Count) + weight := getStoreTopoWeight(store, matchStores, rule.LocationLabels, rule.Count) + storeSize += float64(regionSize) * weight + log.Debug("calculate range result", + logutil.ZapRedactString("start-key", string(core.HexRegionKey(startKey))), + logutil.ZapRedactString("end-key", string(core.HexRegionKey(endKey))), + zap.Uint64("store-id", store.GetID()), + zap.String("rule", rule.String()), + zap.Int64("region-size", regionSize), + zap.Float64("weight", weight), + zap.Float64("store-size", storeSize), + ) + } + return storeSize +} + +func getStoreTopoWeight(store *core.StoreInfo, stores []*core.StoreInfo, locationLabels []string, count int) float64 { + topology, validLabels, sameLocationStoreNum, isMatch := buildTopology(store, stores, locationLabels, count) + weight := 1.0 + topo := topology + if isMatch { + return weight / float64(count) / sameLocationStoreNum + } + + storeLabels := getSortedLabels(store.GetLabels(), locationLabels) + for _, label := range storeLabels { + if _, ok := topo[label.Value]; ok { + if slice.Contains(validLabels, label.Key) { + weight /= float64(len(topo)) + } + topo = topo[label.Value].(map[string]interface{}) + } else { + break + } + } + + return weight / sameLocationStoreNum +} + +func buildTopology(s *core.StoreInfo, stores []*core.StoreInfo, locationLabels []string, count int) (map[string]interface{}, []string, float64, bool) { + topology := make(map[string]interface{}) + sameLocationStoreNum := 1.0 + totalLabelCount := make([]int, len(locationLabels)) + for _, store := range stores { + if store.IsServing() || store.IsPreparing() { + labelCount := 
updateTopology(topology, getSortedLabels(store.GetLabels(), locationLabels)) + for i, c := range labelCount { + totalLabelCount[i] += c + } + } + } + + validLabels := locationLabels + var isMatch bool + for i, c := range totalLabelCount { + if count/c == 0 { + validLabels = validLabels[:i] + break + } + if count/c == 1 && count%c == 0 { + validLabels = validLabels[:i+1] + isMatch = true + break + } + } + for _, store := range stores { + if store.GetID() == s.GetID() { + continue + } + if s.CompareLocation(store, validLabels) == -1 { + sameLocationStoreNum++ + } + } + + return topology, validLabels, sameLocationStoreNum, isMatch +} + +func getSortedLabels(storeLabels []*metapb.StoreLabel, locationLabels []string) []*metapb.StoreLabel { + var sortedLabels []*metapb.StoreLabel + for _, ll := range locationLabels { + find := false + for _, sl := range storeLabels { + if ll == sl.Key { + sortedLabels = append(sortedLabels, sl) + find = true + break + } + } + // TODO: we need to improve this logic to make the label calculation more accurate if the user has the wrong label settings. + if !find { + sortedLabels = append(sortedLabels, &metapb.StoreLabel{Key: ll, Value: ""}) + } + } + return sortedLabels +} + +// updateTopology records stores' topology in the `topology` variable. +func updateTopology(topology map[string]interface{}, sortedLabels []*metapb.StoreLabel) []int { + labelCount := make([]int, len(sortedLabels)) + if len(sortedLabels) == 0 { + return labelCount + } + topo := topology + for i, l := range sortedLabels { + if _, exist := topo[l.Value]; !exist { + topo[l.Value] = make(map[string]interface{}) + labelCount[i] += 1 + } + topo = topo[l.Value].(map[string]interface{}) + } + return labelCount +} + +func (c *RaftCluster) updateProgress(storeID uint64, storeAddress, action string, current, remaining float64, isInc bool) { + storeLabel := strconv.FormatUint(storeID, 10) + var progress string + switch action { + case removingAction: + progress = encodeRemovingProgressKey(storeID) + case preparingAction: + progress = encodePreparingProgressKey(storeID) + } + + if exist := c.progressManager.AddProgress(progress, current, remaining, nodeStateCheckJobInterval); !exist { + return + } + c.progressManager.UpdateProgress(progress, current, remaining, isInc) + process, ls, cs, err := c.progressManager.Status(progress) + if err != nil { + log.Error("get progress status failed", zap.String("progress", progress), zap.Float64("remaining", remaining), errs.ZapError(err)) + return + } + storesProgressGauge.WithLabelValues(storeAddress, storeLabel, action).Set(process) + storesSpeedGauge.WithLabelValues(storeAddress, storeLabel, action).Set(cs) + storesETAGauge.WithLabelValues(storeAddress, storeLabel, action).Set(ls) +} + +func (c *RaftCluster) resetProgress(storeID uint64, storeAddress string) { + storeLabel := strconv.FormatUint(storeID, 10) + + progress := encodePreparingProgressKey(storeID) + if exist := c.progressManager.RemoveProgress(progress); exist { + storesProgressGauge.DeleteLabelValues(storeAddress, storeLabel, preparingAction) + storesSpeedGauge.DeleteLabelValues(storeAddress, storeLabel, preparingAction) + storesETAGauge.DeleteLabelValues(storeAddress, storeLabel, preparingAction) + } + progress = encodeRemovingProgressKey(storeID) + if exist := c.progressManager.RemoveProgress(progress); exist { + storesProgressGauge.DeleteLabelValues(storeAddress, storeLabel, removingAction) + storesSpeedGauge.DeleteLabelValues(storeAddress, storeLabel, removingAction) + 
storesETAGauge.DeleteLabelValues(storeAddress, storeLabel, removingAction) + } +} + +func encodeRemovingProgressKey(storeID uint64) string { + return fmt.Sprintf("%s-%d", removingAction, storeID) +} + +func encodePreparingProgressKey(storeID uint64) string { + return fmt.Sprintf("%s-%d", preparingAction, storeID) +} + +// RemoveTombStoneRecords removes the tombStone Records. +func (c *RaftCluster) RemoveTombStoneRecords() error { + c.Lock() + defer c.Unlock() + + var failedStores []uint64 + for _, store := range c.GetStores() { + if store.IsRemoved() { + if c.core.GetStoreRegionCount(store.GetID()) > 0 { + log.Warn("skip removing tombstone", zap.Stringer("store", store.GetMeta())) + failedStores = append(failedStores, store.GetID()) + continue + } + // the store has already been tombstone + err := c.deleteStore(store) + if err != nil { + log.Error("delete store failed", + zap.Stringer("store", store.GetMeta()), + errs.ZapError(err)) + return err + } + c.RemoveStoreLimit(store.GetID()) + log.Info("delete store succeeded", + zap.Stringer("store", store.GetMeta())) + } + } + var stores string + if len(failedStores) != 0 { + for i, storeID := range failedStores { + stores += fmt.Sprintf("%d", storeID) + if i != len(failedStores)-1 { + stores += ", " + } + } + return errors.Errorf("failed stores: %v", stores) + } + return nil +} + +// deleteStore deletes the store from the cluster. it's concurrent safe. +func (c *RaftCluster) deleteStore(store *core.StoreInfo) error { + if c.storage != nil { + if err := c.storage.DeleteStore(store.GetMeta()); err != nil { + return err + } + } + c.core.DeleteStore(store) + return nil +} + +// SetHotPendingInfluenceMetrics sets pending influence in hot scheduler. +func (c *RaftCluster) SetHotPendingInfluenceMetrics(storeLabel, rwTy, dim string, load float64) { + hotPendingSum.WithLabelValues(storeLabel, rwTy, dim).Set(load) +} + +func (c *RaftCluster) collectMetrics() { + statsMap := statistics.NewStoreStatisticsMap(c.opt, c.storeConfigManager.GetStoreConfig()) + stores := c.GetStores() + for _, s := range stores { + statsMap.Observe(s, c.hotStat.StoresStats) + } + statsMap.Collect() + + c.coordinator.collectSchedulerMetrics() + c.coordinator.collectHotSpotMetrics() + c.collectClusterMetrics() + c.collectHealthStatus() +} + +func (c *RaftCluster) resetMetrics() { + statsMap := statistics.NewStoreStatisticsMap(c.opt, c.storeConfigManager.GetStoreConfig()) + statsMap.Reset() + + c.coordinator.resetSchedulerMetrics() + c.coordinator.resetHotSpotMetrics() + c.resetClusterMetrics() + c.resetHealthStatus() + c.resetProgressIndicator() +} + +func (c *RaftCluster) collectClusterMetrics() { + if c.regionStats == nil { + return + } + c.regionStats.Collect() + c.labelLevelStats.Collect() + // collect hot cache metrics + c.hotStat.CollectMetrics() +} + +func (c *RaftCluster) resetClusterMetrics() { + if c.regionStats == nil { + return + } + c.regionStats.Reset() + c.labelLevelStats.Reset() + // reset hot cache metrics + c.hotStat.ResetMetrics() +} + +func (c *RaftCluster) collectHealthStatus() { + members, err := GetMembers(c.etcdClient) + if err != nil { + log.Error("get members error", errs.ZapError(err)) + } + healthy := CheckHealth(c.httpClient, members) + for _, member := range members { + var v float64 + if _, ok := healthy[member.GetMemberId()]; ok { + v = 1 + } + healthStatusGauge.WithLabelValues(member.GetName()).Set(v) + } +} + +func (c *RaftCluster) resetHealthStatus() { + healthStatusGauge.Reset() +} + +func (c *RaftCluster) resetProgressIndicator() { + 
c.progressManager.Reset() + storesProgressGauge.Reset() + storesSpeedGauge.Reset() + storesETAGauge.Reset() +} + +// GetRegionStatsByType gets the status of the region by types. +func (c *RaftCluster) GetRegionStatsByType(typ statistics.RegionStatisticType) []*core.RegionInfo { + if c.regionStats == nil { + return nil + } + return c.regionStats.GetRegionStatsByType(typ) +} + +// GetOfflineRegionStatsByType gets the status of the offline region by types. +func (c *RaftCluster) GetOfflineRegionStatsByType(typ statistics.RegionStatisticType) []*core.RegionInfo { + if c.regionStats == nil { + return nil + } + return c.regionStats.GetOfflineRegionStatsByType(typ) +} + +func (c *RaftCluster) updateRegionsLabelLevelStats(regions []*core.RegionInfo) { + for _, region := range regions { + c.labelLevelStats.Observe(region, c.getStoresWithoutLabelLocked(region, core.EngineKey, core.EngineTiFlash), c.opt.GetLocationLabels()) + } +} + +func (c *RaftCluster) getRegionStoresLocked(region *core.RegionInfo) []*core.StoreInfo { + stores := make([]*core.StoreInfo, 0, len(region.GetPeers())) + for _, p := range region.GetPeers() { + if store := c.core.GetStore(p.StoreId); store != nil { + stores = append(stores, store) + } + } + return stores +} + +func (c *RaftCluster) getStoresWithoutLabelLocked(region *core.RegionInfo, key, value string) []*core.StoreInfo { + stores := make([]*core.StoreInfo, 0, len(region.GetPeers())) + for _, p := range region.GetPeers() { + if store := c.core.GetStore(p.StoreId); store != nil && !core.IsStoreContainLabel(store.GetMeta(), key, value) { + stores = append(stores, store) + } + } + return stores +} + +// OnStoreVersionChange changes the version of the cluster when needed. +func (c *RaftCluster) OnStoreVersionChange() { + c.RLock() + defer c.RUnlock() + c.onStoreVersionChangeLocked() +} + +func (c *RaftCluster) onStoreVersionChangeLocked() { + var minVersion *semver.Version + stores := c.GetStores() + for _, s := range stores { + if s.IsRemoved() { + continue + } + v := versioninfo.MustParseVersion(s.GetVersion()) + + if minVersion == nil || v.LessThan(*minVersion) { + minVersion = v + } + } + clusterVersion := c.opt.GetClusterVersion() + // If the cluster version of PD is less than the minimum version of all stores, + // it will update the cluster version. + failpoint.Inject("versionChangeConcurrency", func() { + time.Sleep(500 * time.Millisecond) + }) + if minVersion == nil || clusterVersion.Equal(*minVersion) { + return + } + + if !c.opt.CASClusterVersion(clusterVersion, minVersion) { + log.Error("cluster version changed by API at the same time") + } + err := c.opt.Persist(c.storage) + if err != nil { + log.Error("persist cluster version meet error", errs.ZapError(err)) + } + log.Info("cluster version changed", + zap.Stringer("old-cluster-version", clusterVersion), + zap.Stringer("new-cluster-version", minVersion)) +} + +func (c *RaftCluster) changedRegionNotifier() <-chan *core.RegionInfo { + return c.changedRegions +} + +// GetMetaCluster gets meta cluster. +func (c *RaftCluster) GetMetaCluster() *metapb.Cluster { + c.RLock() + defer c.RUnlock() + return typeutil.DeepClone(c.meta, core.ClusterFactory) +} + +// PutMetaCluster puts meta cluster. 
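The stashed files keep the original failpoint.Inject markers, such as versionChangeConcurrency in onStoreVersionChangeLocked above, while the instrumented files in this patch carry the expanded failpoint.Eval form with package-qualified names via _curpkg_. A minimal sketch of how a test might drive such a point at runtime; the failpoint path below is hypothetical:

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

func main() {
	// Hypothetical path; real ones look like "<package path>/<failpoint name>".
	const fp = "example/versionChangeConcurrency"
	if err := failpoint.Enable(fp, "return(true)"); err != nil {
		panic(err)
	}
	defer failpoint.Disable(fp)

	// Mirrors the expanded form in this patch: Eval returns a nil error only
	// while the failpoint is enabled, so the guarded branch runs.
	if _, err := failpoint.Eval(fp); err == nil {
		fmt.Println("failpoint active")
	}
}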
+func (c *RaftCluster) PutMetaCluster(meta *metapb.Cluster) error { + c.Lock() + defer c.Unlock() + if meta.GetId() != c.clusterID { + return errors.Errorf("invalid cluster %v, mismatch cluster id %d", meta, c.clusterID) + } + return c.putMetaLocked(typeutil.DeepClone(meta, core.ClusterFactory)) +} + +// GetRegionStats returns region statistics from cluster. +func (c *RaftCluster) GetRegionStats(startKey, endKey []byte) *statistics.RegionStats { + return statistics.GetRegionStats(c.core.ScanRange(startKey, endKey, -1)) +} + +// GetRangeCount returns the number of regions in the range. +func (c *RaftCluster) GetRangeCount(startKey, endKey []byte) *statistics.RegionStats { + stats := &statistics.RegionStats{} + stats.Count = c.core.GetRangeCount(startKey, endKey) + return stats +} + +// GetStoresStats returns stores' statistics from cluster. +// And it will be unnecessary to filter unhealthy store, because it has been solved in process heartbeat +func (c *RaftCluster) GetStoresStats() *statistics.StoresStats { + return c.hotStat.StoresStats +} + +// GetStoresLoads returns load stats of all stores. +func (c *RaftCluster) GetStoresLoads() map[uint64][]float64 { + return c.hotStat.GetStoresLoads() +} + +// IsRegionHot checks if a region is in hot state. +func (c *RaftCluster) IsRegionHot(region *core.RegionInfo) bool { + return c.hotStat.IsRegionHot(region, c.opt.GetHotRegionCacheHitsThreshold()) +} + +// GetHotPeerStat returns hot peer stat with specified regionID and storeID. +func (c *RaftCluster) GetHotPeerStat(rw statistics.RWType, regionID, storeID uint64) *statistics.HotPeerStat { + return c.hotStat.GetHotPeerStat(rw, regionID, storeID) +} + +// RegionReadStats returns hot region's read stats. +// The result only includes peers that are hot enough. +// RegionStats is a thread-safe method +func (c *RaftCluster) RegionReadStats() map[uint64][]*statistics.HotPeerStat { + // As read stats are reported by store heartbeat, the threshold needs to be adjusted. + threshold := c.GetOpts().GetHotRegionCacheHitsThreshold() * + (statistics.RegionHeartBeatReportInterval / statistics.StoreHeartBeatReportInterval) + return c.hotStat.RegionStats(statistics.Read, threshold) +} + +// BucketsStats returns hot region's buckets stats. +func (c *RaftCluster) BucketsStats(degree int) map[uint64][]*buckets.BucketStat { + task := buckets.NewCollectBucketStatsTask(degree) + if !c.hotBuckets.CheckAsync(task) { + return nil + } + return task.WaitRet(c.ctx) +} + +// RegionWriteStats returns hot region's write stats. +// The result only includes peers that are hot enough. +func (c *RaftCluster) RegionWriteStats() map[uint64][]*statistics.HotPeerStat { + // RegionStats is a thread-safe method + return c.hotStat.RegionStats(statistics.Write, c.GetOpts().GetHotRegionCacheHitsThreshold()) +} + +// TODO: remove me. +// only used in test. +func (c *RaftCluster) putRegion(region *core.RegionInfo) error { + if c.storage != nil { + if err := c.storage.SaveRegion(region.GetMeta()); err != nil { + return err + } + } + c.core.PutRegion(region) + return nil +} + +// GetHotWriteRegions gets hot write regions' info. +func (c *RaftCluster) GetHotWriteRegions(storeIDs ...uint64) *statistics.StoreHotPeersInfos { + hotWriteRegions := c.coordinator.getHotRegionsByType(statistics.Write) + if len(storeIDs) > 0 && hotWriteRegions != nil { + hotWriteRegions = getHotRegionsByStoreIDs(hotWriteRegions, storeIDs...) + } + return hotWriteRegions +} + +// GetHotReadRegions gets hot read regions' info. 
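The threshold adjustment in RegionReadStats above compensates for read flow being reported per store heartbeat, which arrives more often than region heartbeats, so the hot-cache hit threshold is scaled by the ratio of the two report intervals. A tiny sketch of the arithmetic, assuming the common 60s/10s interval defaults:

package main

import "fmt"

const (
	regionHeartBeatReportInterval = 60 // seconds (assumed default)
	storeHeartBeatReportInterval  = 10 // seconds (assumed default)
)

// readStatsThreshold scales the configured hit threshold by how many store
// heartbeats arrive per region heartbeat.
func readStatsThreshold(hotCacheHitsThreshold int) int {
	return hotCacheHitsThreshold * (regionHeartBeatReportInterval / storeHeartBeatReportInterval)
}

func main() {
	fmt.Println(readStatsThreshold(3)) // 18 with the assumed defaults
}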
+func (c *RaftCluster) GetHotReadRegions(storeIDs ...uint64) *statistics.StoreHotPeersInfos { + hotReadRegions := c.coordinator.getHotRegionsByType(statistics.Read) + if len(storeIDs) > 0 && hotReadRegions != nil { + hotReadRegions = getHotRegionsByStoreIDs(hotReadRegions, storeIDs...) + } + return hotReadRegions +} + +func getHotRegionsByStoreIDs(hotPeerInfos *statistics.StoreHotPeersInfos, storeIDs ...uint64) *statistics.StoreHotPeersInfos { + asLeader := statistics.StoreHotPeersStat{} + asPeer := statistics.StoreHotPeersStat{} + for _, storeID := range storeIDs { + asLeader[storeID] = hotPeerInfos.AsLeader[storeID] + asPeer[storeID] = hotPeerInfos.AsPeer[storeID] + } + return &statistics.StoreHotPeersInfos{ + AsLeader: asLeader, + AsPeer: asPeer, + } +} + +// GetStoreLimiter returns the dynamic adjusting limiter +func (c *RaftCluster) GetStoreLimiter() *StoreLimiter { + return c.limiter +} + +// GetStoreLimitByType returns the store limit for a given store ID and type. +func (c *RaftCluster) GetStoreLimitByType(storeID uint64, typ storelimit.Type) float64 { + return c.opt.GetStoreLimitByType(storeID, typ) +} + +// GetAllStoresLimit returns all store limit +func (c *RaftCluster) GetAllStoresLimit() map[uint64]config.StoreLimitConfig { + return c.opt.GetAllStoresLimit() +} + +// AddStoreLimit add a store limit for a given store ID. +func (c *RaftCluster) AddStoreLimit(store *metapb.Store) { + storeID := store.GetId() + cfg := c.opt.GetScheduleConfig().Clone() + if _, ok := cfg.StoreLimit[storeID]; ok { + return + } + + sc := config.StoreLimitConfig{ + AddPeer: config.DefaultStoreLimit.GetDefaultStoreLimit(storelimit.AddPeer), + RemovePeer: config.DefaultStoreLimit.GetDefaultStoreLimit(storelimit.RemovePeer), + } + if core.IsStoreContainLabel(store, core.EngineKey, core.EngineTiFlash) { + sc = config.StoreLimitConfig{ + AddPeer: config.DefaultTiFlashStoreLimit.GetDefaultStoreLimit(storelimit.AddPeer), + RemovePeer: config.DefaultTiFlashStoreLimit.GetDefaultStoreLimit(storelimit.RemovePeer), + } + } + + cfg.StoreLimit[storeID] = sc + c.opt.SetScheduleConfig(cfg) + var err error + for i := 0; i < persistLimitRetryTimes; i++ { + if err = c.opt.Persist(c.storage); err == nil { + log.Info("store limit added", zap.Uint64("store-id", storeID)) + return + } + time.Sleep(persistLimitWaitTime) + } + log.Error("persist store limit meet error", errs.ZapError(err)) +} + +// RemoveStoreLimit remove a store limit for a given store ID. +func (c *RaftCluster) RemoveStoreLimit(storeID uint64) { + cfg := c.opt.GetScheduleConfig().Clone() + for _, limitType := range storelimit.TypeNameValue { + c.core.ResetStoreLimit(storeID, limitType) + } + delete(cfg.StoreLimit, storeID) + c.opt.SetScheduleConfig(cfg) + var err error + for i := 0; i < persistLimitRetryTimes; i++ { + if err = c.opt.Persist(c.storage); err == nil { + log.Info("store limit removed", zap.Uint64("store-id", storeID)) + id := strconv.FormatUint(storeID, 10) + statistics.StoreLimitGauge.DeleteLabelValues(id, "add-peer") + statistics.StoreLimitGauge.DeleteLabelValues(id, "remove-peer") + return + } + time.Sleep(persistLimitWaitTime) + } + log.Error("persist store limit meet error", errs.ZapError(err)) +} + +// SetMinResolvedTS sets up a store with min resolved ts. 
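AddStoreLimit and RemoveStoreLimit above persist the updated schedule config with a bounded retry loop, sleeping between attempts and logging only the final failure. A generic standalone sketch of that pattern (retry count and wait duration are illustrative):

package main

import (
	"errors"
	"fmt"
	"time"
)

// persistWithRetry tries a bounded number of times, sleeping between attempts,
// and surfaces the last error if every attempt fails.
func persistWithRetry(persist func() error, retries int, wait time.Duration) error {
	var err error
	for i := 0; i < retries; i++ {
		if err = persist(); err == nil {
			return nil
		}
		time.Sleep(wait)
	}
	return err
}

func main() {
	attempts := 0
	err := persistWithRetry(func() error {
		attempts++
		if attempts < 3 {
			return errors.New("transient storage error")
		}
		return nil
	}, 5, 10*time.Millisecond)
	fmt.Println(attempts, err) // 3 <nil>
}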
+func (c *RaftCluster) SetMinResolvedTS(storeID, minResolvedTS uint64) error { + c.Lock() + defer c.Unlock() + + store := c.GetStore(storeID) + if store == nil { + return errs.ErrStoreNotFound.FastGenByArgs(storeID) + } + + newStore := store.Clone(core.SetMinResolvedTS(minResolvedTS)) + c.core.PutStore(newStore) + return nil +} + +func (c *RaftCluster) checkAndUpdateMinResolvedTS() (uint64, bool) { + c.Lock() + defer c.Unlock() + + if !c.isInitialized() { + return math.MaxUint64, false + } + curMinResolvedTS := uint64(math.MaxUint64) + for _, s := range c.GetStores() { + if !core.IsAvailableForMinResolvedTS(s) { + continue + } + if curMinResolvedTS > s.GetMinResolvedTS() { + curMinResolvedTS = s.GetMinResolvedTS() + } + } + if curMinResolvedTS == math.MaxUint64 || curMinResolvedTS <= c.minResolvedTS { + return c.minResolvedTS, false + } + c.minResolvedTS = curMinResolvedTS + return c.minResolvedTS, true +} + +func (c *RaftCluster) runMinResolvedTSJob() { + defer logutil.LogPanic() + defer c.wg.Done() + + interval := c.opt.GetMinResolvedTSPersistenceInterval() + if interval == 0 { + interval = DefaultMinResolvedTSPersistenceInterval + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + + c.loadMinResolvedTS() + for { + select { + case <-c.ctx.Done(): + log.Info("min resolved ts background jobs has been stopped") + return + case <-ticker.C: + interval = c.opt.GetMinResolvedTSPersistenceInterval() + if interval != 0 { + if current, needPersist := c.checkAndUpdateMinResolvedTS(); needPersist { + c.storage.SaveMinResolvedTS(current) + } + } else { + // If interval in config is zero, it means not to persist resolved ts and check config with this interval + interval = DefaultMinResolvedTSPersistenceInterval + } + ticker.Reset(interval) + } + } +} + +func (c *RaftCluster) loadMinResolvedTS() { + // Use `c.GetStorage()` here to prevent from the data race in test. + minResolvedTS, err := c.GetStorage().LoadMinResolvedTS() + if err != nil { + log.Error("load min resolved ts meet error", errs.ZapError(err)) + return + } + c.Lock() + defer c.Unlock() + c.minResolvedTS = minResolvedTS +} + +// GetMinResolvedTS returns the min resolved ts of the cluster. +func (c *RaftCluster) GetMinResolvedTS() uint64 { + c.RLock() + defer c.RUnlock() + if !c.isInitialized() { + return math.MaxUint64 + } + return c.minResolvedTS +} + +// GetExternalTS returns the external timestamp. +func (c *RaftCluster) GetExternalTS() uint64 { + c.RLock() + defer c.RUnlock() + if !c.isInitialized() { + return math.MaxUint64 + } + return c.externalTS +} + +// SetExternalTS sets the external timestamp. +func (c *RaftCluster) SetExternalTS(timestamp uint64) error { + c.Lock() + defer c.Unlock() + c.externalTS = timestamp + c.storage.SaveExternalTS(timestamp) + return nil +} + +// SetStoreLimit sets a store limit for a given type and rate. +func (c *RaftCluster) SetStoreLimit(storeID uint64, typ storelimit.Type, ratePerMin float64) error { + old := c.opt.GetScheduleConfig().Clone() + c.opt.SetStoreLimit(storeID, typ, ratePerMin) + if err := c.opt.Persist(c.storage); err != nil { + // roll back the store limit + c.opt.SetScheduleConfig(old) + log.Error("persist store limit meet error", errs.ZapError(err)) + return err + } + log.Info("store limit changed", zap.Uint64("store-id", storeID), zap.String("type", typ.String()), zap.Float64("rate-per-min", ratePerMin)) + return nil +} + +// SetAllStoresLimit sets all store limit for a given type and rate. 
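checkAndUpdateMinResolvedTS above takes the minimum resolved timestamp over all eligible stores and only ever moves the recorded value forward; an empty result or a smaller minimum leaves it untouched. A standalone sketch with plain uint64 inputs:

package main

import (
	"fmt"
	"math"
)

// advanceMinResolvedTS computes the minimum of the stores' resolved timestamps
// and only advances the recorded value, never moving it backwards.
func advanceMinResolvedTS(current uint64, storeTS []uint64) (uint64, bool) {
	cur := uint64(math.MaxUint64)
	for _, ts := range storeTS {
		if ts < cur {
			cur = ts
		}
	}
	if cur == math.MaxUint64 || cur <= current {
		return current, false
	}
	return cur, true
}

func main() {
	fmt.Println(advanceMinResolvedTS(100, []uint64{120, 130})) // 120 true
	fmt.Println(advanceMinResolvedTS(100, []uint64{90, 130}))  // 100 false: never moves backwards
	fmt.Println(advanceMinResolvedTS(100, nil))                // 100 false: nothing reported
}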
+func (c *RaftCluster) SetAllStoresLimit(typ storelimit.Type, ratePerMin float64) error { + old := c.opt.GetScheduleConfig().Clone() + oldAdd := config.DefaultStoreLimit.GetDefaultStoreLimit(storelimit.AddPeer) + oldRemove := config.DefaultStoreLimit.GetDefaultStoreLimit(storelimit.RemovePeer) + c.opt.SetAllStoresLimit(typ, ratePerMin) + if err := c.opt.Persist(c.storage); err != nil { + // roll back the store limit + c.opt.SetScheduleConfig(old) + config.DefaultStoreLimit.SetDefaultStoreLimit(storelimit.AddPeer, oldAdd) + config.DefaultStoreLimit.SetDefaultStoreLimit(storelimit.RemovePeer, oldRemove) + log.Error("persist store limit meet error", errs.ZapError(err)) + return err + } + log.Info("all store limit changed", zap.String("type", typ.String()), zap.Float64("rate-per-min", ratePerMin)) + return nil +} + +// SetAllStoresLimitTTL sets all store limit for a given type and rate with ttl. +func (c *RaftCluster) SetAllStoresLimitTTL(typ storelimit.Type, ratePerMin float64, ttl time.Duration) { + c.opt.SetAllStoresLimitTTL(c.ctx, c.etcdClient, typ, ratePerMin, ttl) +} + +// GetClusterVersion returns the current cluster version. +func (c *RaftCluster) GetClusterVersion() string { + return c.opt.GetClusterVersion().String() +} + +// GetEtcdClient returns the current etcd client +func (c *RaftCluster) GetEtcdClient() *clientv3.Client { + return c.etcdClient +} + +// GetProgressByID returns the progress details for a given store ID. +func (c *RaftCluster) GetProgressByID(storeID string) (action string, process, ls, cs float64, err error) { + filter := func(progress string) bool { + s := strings.Split(progress, "-") + return len(s) == 2 && s[1] == storeID + } + progress := c.progressManager.GetProgresses(filter) + if len(progress) != 0 { + process, ls, cs, err = c.progressManager.Status(progress[0]) + if err != nil { + return + } + if strings.HasPrefix(progress[0], removingAction) { + action = removingAction + } else if strings.HasPrefix(progress[0], preparingAction) { + action = preparingAction + } + return + } + return "", 0, 0, 0, errs.ErrProgressNotFound.FastGenByArgs(fmt.Sprintf("the given store ID: %s", storeID)) +} + +// GetProgressByAction returns the progress details for a given action. +func (c *RaftCluster) GetProgressByAction(action string) (process, ls, cs float64, err error) { + filter := func(progress string) bool { + return strings.HasPrefix(progress, action) + } + + progresses := c.progressManager.GetProgresses(filter) + if len(progresses) == 0 { + return 0, 0, 0, errs.ErrProgressNotFound.FastGenByArgs(fmt.Sprintf("the action: %s", action)) + } + var p, l, s float64 + for _, progress := range progresses { + p, l, s, err = c.progressManager.Status(progress) + if err != nil { + return + } + process += p + ls += l + cs += s + } + num := float64(len(progresses)) + process /= num + cs /= num + ls /= num + // handle the special cases + if math.IsNaN(ls) || math.IsInf(ls, 0) { + ls = math.MaxFloat64 + } + return +} + +var healthURL = "/pd/api/v1/ping" + +// CheckHealth checks if members are healthy. 
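GetProgressByID and GetProgressByAction above rely on progress entries being keyed as "<action>-<storeID>", the format produced by encodeRemovingProgressKey and encodePreparingProgressKey. A small sketch of the encode/decode pair (the decode helper here is illustrative; the real code splits the key inline in a filter):

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// encodeProgressKey builds a progress key such as "removing-3".
func encodeProgressKey(action string, storeID uint64) string {
	return fmt.Sprintf("%s-%d", action, storeID)
}

// decodeProgressKey is the inverse used when filtering progresses by store ID.
func decodeProgressKey(key string) (action string, storeID uint64, ok bool) {
	parts := strings.Split(key, "-")
	if len(parts) != 2 {
		return "", 0, false
	}
	id, err := strconv.ParseUint(parts[1], 10, 64)
	if err != nil {
		return "", 0, false
	}
	return parts[0], id, true
}

func main() {
	key := encodeProgressKey("removing", 3)
	fmt.Println(key)
	fmt.Println(decodeProgressKey(key))
}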
+func CheckHealth(client *http.Client, members []*pdpb.Member) map[uint64]*pdpb.Member { + healthMembers := make(map[uint64]*pdpb.Member) + for _, member := range members { + for _, cURL := range member.ClientUrls { + ctx, cancel := context.WithTimeout(context.Background(), clientTimeout) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", cURL, healthURL), nil) + if err != nil { + log.Error("failed to new request", errs.ZapError(errs.ErrNewHTTPRequest, err)) + cancel() + continue + } + + resp, err := client.Do(req) + if resp != nil { + resp.Body.Close() + } + cancel() + if err == nil && resp.StatusCode == http.StatusOK { + healthMembers[member.GetMemberId()] = member + break + } + } + } + return healthMembers +} + +// GetMembers return a slice of Members. +func GetMembers(etcdClient *clientv3.Client) ([]*pdpb.Member, error) { + listResp, err := etcdutil.ListEtcdMembers(etcdClient) + if err != nil { + return nil, err + } + + members := make([]*pdpb.Member, 0, len(listResp.Members)) + for _, m := range listResp.Members { + info := &pdpb.Member{ + Name: m.Name, + MemberId: m.ID, + ClientUrls: m.ClientURLs, + PeerUrls: m.PeerURLs, + } + members = append(members, info) + } + + return members, nil +} + +// IsClientURL returns whether addr is a ClientUrl of any member. +func IsClientURL(addr string, etcdClient *clientv3.Client) bool { + members, err := GetMembers(etcdClient) + if err != nil { + return false + } + for _, member := range members { + for _, u := range member.GetClientUrls() { + if u == addr { + return true + } + } + } + return false +} + +// cacheCluster include cache info to improve the performance. +type cacheCluster struct { + *RaftCluster + stores []*core.StoreInfo +} + +// GetStores returns store infos from cache +func (c *cacheCluster) GetStores() []*core.StoreInfo { + return c.stores +} + +// newCacheCluster constructor for cache +func newCacheCluster(c *RaftCluster) *cacheCluster { + return &cacheCluster{ + RaftCluster: c, + stores: c.GetStores(), + } +} + +// GetPausedSchedulerDelayAt returns DelayAt of a paused scheduler +func (c *RaftCluster) GetPausedSchedulerDelayAt(name string) (int64, error) { + return c.coordinator.getPausedSchedulerDelayAt(name) +} + +// GetPausedSchedulerDelayUntil returns DelayUntil of a paused scheduler +func (c *RaftCluster) GetPausedSchedulerDelayUntil(name string) (int64, error) { + return c.coordinator.getPausedSchedulerDelayUntil(name) +} diff --git a/server/cluster/coordinator.go b/server/cluster/coordinator.go old mode 100644 new mode 100755 index c1770f861eb..5511ce38daf --- a/server/cluster/coordinator.go +++ b/server/cluster/coordinator.go @@ -151,9 +151,9 @@ func (c *coordinator) patrolRegions() { patrolCheckRegionsGauge.Set(time.Since(start).Seconds()) start = time.Now() } - failpoint.Inject("break-patrol", func() { - failpoint.Break() - }) + if _, _err_ := failpoint.Eval(_curpkg_("break-patrol")); _err_ == nil { + break + } } } @@ -306,9 +306,9 @@ func (c *coordinator) runUntilStop() { func (c *coordinator) run() { ticker := time.NewTicker(runSchedulerCheckInterval) - failpoint.Inject("changeCoordinatorTicker", func() { + if _, _err_ := failpoint.Eval(_curpkg_("changeCoordinatorTicker")); _err_ == nil { ticker = time.NewTicker(100 * time.Millisecond) - }) + } defer ticker.Stop() log.Info("coordinator starts to collect cluster information") for { diff --git a/server/cluster/coordinator.go__failpoint_stash__ b/server/cluster/coordinator.go__failpoint_stash__ new file mode 100644 index 
00000000000..c1770f861eb --- /dev/null +++ b/server/cluster/coordinator.go__failpoint_stash__ @@ -0,0 +1,995 @@ +// Copyright 2016 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cluster + +import ( + "bytes" + "context" + "net/http" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/cache" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/logutil" + "github.com/tikv/pd/pkg/syncutil" + "github.com/tikv/pd/server/config" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/schedule" + "github.com/tikv/pd/server/schedule/checker" + "github.com/tikv/pd/server/schedule/hbstream" + "github.com/tikv/pd/server/schedule/operator" + "github.com/tikv/pd/server/schedule/plan" + "github.com/tikv/pd/server/schedulers" + "github.com/tikv/pd/server/statistics" + "github.com/tikv/pd/server/storage" + "go.uber.org/zap" +) + +const ( + runSchedulerCheckInterval = 3 * time.Second + checkSuspectRangesInterval = 100 * time.Millisecond + collectFactor = 0.9 + collectTimeout = 5 * time.Minute + maxScheduleRetries = 10 + maxLoadConfigRetries = 10 + + patrolScanRegionLimit = 128 // It takes about 14 minutes to iterate 1 million regions. + // PluginLoad means action for load plugin + PluginLoad = "PluginLoad" + // PluginUnload means action for unload plugin + PluginUnload = "PluginUnload" +) + +// coordinator is used to manage all schedulers and checkers to decide if the region needs to be scheduled. +type coordinator struct { + syncutil.RWMutex + + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + cluster *RaftCluster + prepareChecker *prepareChecker + checkers *checker.Controller + regionScatterer *schedule.RegionScatterer + regionSplitter *schedule.RegionSplitter + schedulers map[string]*scheduleController + opController *schedule.OperatorController + hbStreams *hbstream.HeartbeatStreams + pluginInterface *schedule.PluginInterface + diagnosticManager *diagnosticManager +} + +// newCoordinator creates a new coordinator. 
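The "about 14 minutes" note next to patrolScanRegionLimit above follows from the batch size and the patrol interval, assuming the common 100ms default for patrol-region-interval; a quick back-of-the-envelope check:

package main

import (
	"fmt"
	"time"
)

func main() {
	const (
		regions              = 1_000_000
		patrolScanLimit      = 128                    // from patrolScanRegionLimit above
		patrolRegionInterval = 100 * time.Millisecond // assumed default of patrol-region-interval
	)
	rounds := (regions + patrolScanLimit - 1) / patrolScanLimit
	fmt.Println(rounds, time.Duration(rounds)*patrolRegionInterval)
	// ~7813 rounds, ~13m of waiting plus the per-round checker work, which is
	// roughly where the "about 14 minutes" figure comes from.
}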
+func newCoordinator(ctx context.Context, cluster *RaftCluster, hbStreams *hbstream.HeartbeatStreams) *coordinator { + ctx, cancel := context.WithCancel(ctx) + opController := schedule.NewOperatorController(ctx, cluster, hbStreams) + schedulers := make(map[string]*scheduleController) + return &coordinator{ + ctx: ctx, + cancel: cancel, + cluster: cluster, + prepareChecker: newPrepareChecker(), + checkers: checker.NewController(ctx, cluster, cluster.ruleManager, cluster.regionLabeler, opController), + regionScatterer: schedule.NewRegionScatterer(ctx, cluster, opController), + regionSplitter: schedule.NewRegionSplitter(cluster, schedule.NewSplitRegionsHandler(cluster, opController)), + schedulers: schedulers, + opController: opController, + hbStreams: hbStreams, + pluginInterface: schedule.NewPluginInterface(), + diagnosticManager: newDiagnosticManager(cluster), + } +} + +func (c *coordinator) GetWaitingRegions() []*cache.Item { + return c.checkers.GetWaitingRegions() +} + +func (c *coordinator) IsPendingRegion(region uint64) bool { + return c.checkers.IsPendingRegion(region) +} + +// patrolRegions is used to scan regions. +// The checkers will check these regions to decide if they need to do some operations. +func (c *coordinator) patrolRegions() { + defer logutil.LogPanic() + + defer c.wg.Done() + timer := time.NewTimer(c.cluster.GetOpts().GetPatrolRegionInterval()) + defer timer.Stop() + + log.Info("coordinator starts patrol regions") + start := time.Now() + var ( + key []byte + regions []*core.RegionInfo + ) + for { + select { + case <-timer.C: + timer.Reset(c.cluster.GetOpts().GetPatrolRegionInterval()) + case <-c.ctx.Done(): + log.Info("patrol regions has been stopped") + return + } + if c.cluster.GetUnsafeRecoveryController().IsRunning() { + // Skip patrolling regions during unsafe recovery. + continue + } + + // Check priority regions first. + c.checkPriorityRegions() + // Check suspect regions first. + c.checkSuspectRegions() + // Check regions in the waiting list + c.checkWaitingRegions() + + key, regions = c.checkRegions(key) + if len(regions) == 0 { + continue + } + // Updates the label level isolation statistics. + c.cluster.updateRegionsLabelLevelStats(regions) + if len(key) == 0 { + patrolCheckRegionsGauge.Set(time.Since(start).Seconds()) + start = time.Now() + } + failpoint.Inject("break-patrol", func() { + failpoint.Break() + }) + } +} + +func (c *coordinator) checkRegions(startKey []byte) (key []byte, regions []*core.RegionInfo) { + regions = c.cluster.ScanRegions(startKey, nil, patrolScanRegionLimit) + if len(regions) == 0 { + // Resets the scan key. 
+ key = nil + return + } + + for _, region := range regions { + c.tryAddOperators(region) + key = region.GetEndKey() + } + return +} + +func (c *coordinator) checkSuspectRegions() { + for _, id := range c.checkers.GetSuspectRegions() { + region := c.cluster.GetRegion(id) + c.tryAddOperators(region) + } +} + +func (c *coordinator) checkWaitingRegions() { + items := c.checkers.GetWaitingRegions() + regionListGauge.WithLabelValues("waiting_list").Set(float64(len(items))) + for _, item := range items { + region := c.cluster.GetRegion(item.Key) + c.tryAddOperators(region) + } +} + +// checkPriorityRegions checks priority regions +func (c *coordinator) checkPriorityRegions() { + items := c.checkers.GetPriorityRegions() + removes := make([]uint64, 0) + regionListGauge.WithLabelValues("priority_list").Set(float64(len(items))) + for _, id := range items { + region := c.cluster.GetRegion(id) + if region == nil { + removes = append(removes, id) + continue + } + ops := c.checkers.CheckRegion(region) + // it should skip if region needs to merge + if len(ops) == 0 || ops[0].Kind()&operator.OpMerge != 0 { + continue + } + if !c.opController.ExceedStoreLimit(ops...) { + c.opController.AddWaitingOperator(ops...) + } + } + for _, v := range removes { + c.checkers.RemovePriorityRegions(v) + } +} + +// checkSuspectRanges would pop one suspect key range group +// The regions of new version key range and old version key range would be placed into +// the suspect regions map +func (c *coordinator) checkSuspectRanges() { + defer c.wg.Done() + log.Info("coordinator begins to check suspect key ranges") + ticker := time.NewTicker(checkSuspectRangesInterval) + defer ticker.Stop() + for { + select { + case <-c.ctx.Done(): + log.Info("check suspect key ranges has been stopped") + return + case <-ticker.C: + keyRange, success := c.checkers.PopOneSuspectKeyRange() + if !success { + continue + } + limit := 1024 + regions := c.cluster.ScanRegions(keyRange[0], keyRange[1], limit) + if len(regions) == 0 { + continue + } + regionIDList := make([]uint64, 0, len(regions)) + for _, region := range regions { + regionIDList = append(regionIDList, region.GetID()) + } + + // if the last region's end key is smaller the keyRange[1] which means there existed the remaining regions between + // keyRange[0] and keyRange[1] after scan regions, so we put the end key and keyRange[1] into Suspect KeyRanges + lastRegion := regions[len(regions)-1] + if lastRegion.GetEndKey() != nil && bytes.Compare(lastRegion.GetEndKey(), keyRange[1]) < 0 { + c.checkers.AddSuspectKeyRange(lastRegion.GetEndKey(), keyRange[1]) + } + c.checkers.AddSuspectRegions(regionIDList...) + } + } +} + +func (c *coordinator) tryAddOperators(region *core.RegionInfo) { + if region == nil { + // the region could be recent split, continue to wait. + return + } + id := region.GetID() + if c.opController.GetOperator(id) != nil { + c.checkers.RemoveWaitingRegion(id) + c.checkers.RemoveSuspectRegion(id) + return + } + ops := c.checkers.CheckRegion(region) + if len(ops) == 0 { + return + } + + if !c.opController.ExceedStoreLimit(ops...) { + c.opController.AddWaitingOperator(ops...) + c.checkers.RemoveWaitingRegion(id) + c.checkers.RemoveSuspectRegion(id) + } else { + c.checkers.AddWaitingRegion(region) + } +} + +// drivePushOperator is used to push the unfinished operator to the executor. 
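checkRegions above scans a bounded batch of regions, remembers the last end key as the next starting point, and resets the key once a scan returns nothing, so patrolling wraps around the keyspace indefinitely. A simplified sketch of that continuation pattern over plain sorted strings standing in for region keys:

package main

import "fmt"

// scanBatch returns at most limit keys after start and the key to resume from;
// an empty batch resets the continuation, like clearing the scan key above.
func scanBatch(keys []string, start string, limit int) (batch []string, next string) {
	for _, k := range keys {
		if len(batch) == limit {
			break
		}
		if start == "" || k > start {
			batch = append(batch, k)
		}
	}
	if len(batch) == 0 {
		return nil, "" // wrap around to the beginning of the keyspace
	}
	return batch, batch[len(batch)-1]
}

func main() {
	keys := []string{"a", "b", "c", "d", "e"}
	start := ""
	for i := 0; i < 4; i++ {
		var batch []string
		batch, start = scanBatch(keys, start, 2)
		fmt.Println(batch, start)
	}
}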
+func (c *coordinator) drivePushOperator() { + defer logutil.LogPanic() + + defer c.wg.Done() + log.Info("coordinator begins to actively drive push operator") + ticker := time.NewTicker(schedule.PushOperatorTickInterval) + defer ticker.Stop() + for { + select { + case <-c.ctx.Done(): + log.Info("drive push operator has been stopped") + return + case <-ticker.C: + c.opController.PushOperators() + } + } +} + +func (c *coordinator) runUntilStop() { + c.run() + <-c.ctx.Done() + log.Info("coordinator is stopping") + c.wg.Wait() + log.Info("coordinator has been stopped") +} + +func (c *coordinator) run() { + ticker := time.NewTicker(runSchedulerCheckInterval) + failpoint.Inject("changeCoordinatorTicker", func() { + ticker = time.NewTicker(100 * time.Millisecond) + }) + defer ticker.Stop() + log.Info("coordinator starts to collect cluster information") + for { + if c.shouldRun() { + log.Info("coordinator has finished cluster information preparation") + break + } + select { + case <-ticker.C: + case <-c.ctx.Done(): + log.Info("coordinator stops running") + return + } + } + log.Info("coordinator starts to run schedulers") + var ( + scheduleNames []string + configs []string + err error + ) + for i := 0; i < maxLoadConfigRetries; i++ { + scheduleNames, configs, err = c.cluster.storage.LoadAllScheduleConfig() + select { + case <-c.ctx.Done(): + log.Info("coordinator stops running") + return + default: + } + if err == nil { + break + } + log.Error("cannot load schedulers' config", zap.Int("retry-times", i), errs.ZapError(err)) + } + if err != nil { + log.Fatal("cannot load schedulers' config", errs.ZapError(err)) + } + + scheduleCfg := c.cluster.opt.GetScheduleConfig().Clone() + // The new way to create scheduler with the independent configuration. + for i, name := range scheduleNames { + data := configs[i] + typ := schedule.FindSchedulerTypeByName(name) + var cfg config.SchedulerConfig + for _, c := range scheduleCfg.Schedulers { + if c.Type == typ { + cfg = c + break + } + } + if len(cfg.Type) == 0 { + log.Error("the scheduler type not found", zap.String("scheduler-name", name), errs.ZapError(errs.ErrSchedulerNotFound)) + continue + } + if cfg.Disable { + log.Info("skip create scheduler with independent configuration", zap.String("scheduler-name", name), zap.String("scheduler-type", cfg.Type), zap.Strings("scheduler-args", cfg.Args)) + continue + } + s, err := schedule.CreateScheduler(cfg.Type, c.opController, c.cluster.storage, schedule.ConfigJSONDecoder([]byte(data))) + if err != nil { + log.Error("can not create scheduler with independent configuration", zap.String("scheduler-name", name), zap.Strings("scheduler-args", cfg.Args), errs.ZapError(err)) + continue + } + log.Info("create scheduler with independent configuration", zap.String("scheduler-name", s.GetName())) + if err = c.addScheduler(s); err != nil { + log.Error("can not add scheduler with independent configuration", zap.String("scheduler-name", s.GetName()), zap.Strings("scheduler-args", cfg.Args), errs.ZapError(err)) + } + } + + // The old way to create the scheduler. 
+ k := 0 + for _, schedulerCfg := range scheduleCfg.Schedulers { + if schedulerCfg.Disable { + scheduleCfg.Schedulers[k] = schedulerCfg + k++ + log.Info("skip create scheduler", zap.String("scheduler-type", schedulerCfg.Type), zap.Strings("scheduler-args", schedulerCfg.Args)) + continue + } + + s, err := schedule.CreateScheduler(schedulerCfg.Type, c.opController, c.cluster.storage, schedule.ConfigSliceDecoder(schedulerCfg.Type, schedulerCfg.Args)) + if err != nil { + log.Error("can not create scheduler", zap.String("scheduler-type", schedulerCfg.Type), zap.Strings("scheduler-args", schedulerCfg.Args), errs.ZapError(err)) + continue + } + + log.Info("create scheduler", zap.String("scheduler-name", s.GetName()), zap.Strings("scheduler-args", schedulerCfg.Args)) + if err = c.addScheduler(s, schedulerCfg.Args...); err != nil && !errors.ErrorEqual(err, errs.ErrSchedulerExisted.FastGenByArgs()) { + log.Error("can not add scheduler", zap.String("scheduler-name", s.GetName()), zap.Strings("scheduler-args", schedulerCfg.Args), errs.ZapError(err)) + } else { + // Only records the valid scheduler config. + scheduleCfg.Schedulers[k] = schedulerCfg + k++ + } + } + + // Removes the invalid scheduler config and persist. + scheduleCfg.Schedulers = scheduleCfg.Schedulers[:k] + c.cluster.opt.SetScheduleConfig(scheduleCfg) + if err := c.cluster.opt.Persist(c.cluster.storage); err != nil { + log.Error("cannot persist schedule config", errs.ZapError(err)) + } + + c.wg.Add(3) + // Starts to patrol regions. + go c.patrolRegions() + // Checks suspect key ranges + go c.checkSuspectRanges() + go c.drivePushOperator() +} + +// LoadPlugin load user plugin +func (c *coordinator) LoadPlugin(pluginPath string, ch chan string) { + log.Info("load plugin", zap.String("plugin-path", pluginPath)) + // get func: SchedulerType from plugin + SchedulerType, err := c.pluginInterface.GetFunction(pluginPath, "SchedulerType") + if err != nil { + log.Error("GetFunction SchedulerType error", errs.ZapError(err)) + return + } + schedulerType := SchedulerType.(func() string) + // get func: SchedulerArgs from plugin + SchedulerArgs, err := c.pluginInterface.GetFunction(pluginPath, "SchedulerArgs") + if err != nil { + log.Error("GetFunction SchedulerArgs error", errs.ZapError(err)) + return + } + schedulerArgs := SchedulerArgs.(func() []string) + // create and add user scheduler + s, err := schedule.CreateScheduler(schedulerType(), c.opController, c.cluster.storage, schedule.ConfigSliceDecoder(schedulerType(), schedulerArgs())) + if err != nil { + log.Error("can not create scheduler", zap.String("scheduler-type", schedulerType()), errs.ZapError(err)) + return + } + log.Info("create scheduler", zap.String("scheduler-name", s.GetName())) + if err = c.addScheduler(s); err != nil { + log.Error("can't add scheduler", zap.String("scheduler-name", s.GetName()), errs.ZapError(err)) + return + } + + c.wg.Add(1) + go c.waitPluginUnload(pluginPath, s.GetName(), ch) +} + +func (c *coordinator) waitPluginUnload(pluginPath, schedulerName string, ch chan string) { + defer logutil.LogPanic() + defer c.wg.Done() + // Get signal from channel which means user unload the plugin + for { + select { + case action := <-ch: + if action == PluginUnload { + err := c.removeScheduler(schedulerName) + if err != nil { + log.Error("can not remove scheduler", zap.String("scheduler-name", schedulerName), errs.ZapError(err)) + } else { + log.Info("unload plugin", zap.String("plugin", pluginPath)) + return + } + } else { + log.Error("unknown action", zap.String("action", 
action)) + } + case <-c.ctx.Done(): + log.Info("unload plugin has been stopped") + return + } + } +} + +func (c *coordinator) stop() { + c.cancel() +} + +func (c *coordinator) getHotRegionsByType(typ statistics.RWType) *statistics.StoreHotPeersInfos { + isTraceFlow := c.cluster.GetOpts().IsTraceRegionFlow() + storeLoads := c.cluster.GetStoresLoads() + stores := c.cluster.GetStores() + var infos *statistics.StoreHotPeersInfos + switch typ { + case statistics.Write: + regionStats := c.cluster.RegionWriteStats() + infos = statistics.GetHotStatus(stores, storeLoads, regionStats, statistics.Write, isTraceFlow) + case statistics.Read: + regionStats := c.cluster.RegionReadStats() + infos = statistics.GetHotStatus(stores, storeLoads, regionStats, statistics.Read, isTraceFlow) + default: + } + // update params `IsLearner` and `LastUpdateTime` + for _, stores := range []statistics.StoreHotPeersStat{infos.AsLeader, infos.AsPeer} { + for _, store := range stores { + for _, hotPeer := range store.Stats { + region := c.cluster.GetRegion(hotPeer.RegionID) + hotPeer.UpdateHotPeerStatShow(region) + } + } + } + return infos +} + +func (c *coordinator) getSchedulers() []string { + c.RLock() + defer c.RUnlock() + names := make([]string, 0, len(c.schedulers)) + for name := range c.schedulers { + names = append(names, name) + } + return names +} + +func (c *coordinator) getSchedulerHandlers() map[string]http.Handler { + c.RLock() + defer c.RUnlock() + handlers := make(map[string]http.Handler, len(c.schedulers)) + for name, scheduler := range c.schedulers { + handlers[name] = scheduler.Scheduler + } + return handlers +} + +func (c *coordinator) collectSchedulerMetrics() { + c.RLock() + defer c.RUnlock() + for _, s := range c.schedulers { + var allowScheduler float64 + // If the scheduler is not allowed to schedule, it will disappear in Grafana panel. + // See issue #1341. + if !s.IsPaused() && !s.cluster.GetUnsafeRecoveryController().IsRunning() { + allowScheduler = 1 + } + schedulerStatusGauge.WithLabelValues(s.GetName(), "allow").Set(allowScheduler) + } +} + +func (c *coordinator) resetSchedulerMetrics() { + schedulerStatusGauge.Reset() +} + +func (c *coordinator) collectHotSpotMetrics() { + stores := c.cluster.GetStores() + // Collects hot write region metrics. + collectHotMetrics(c.cluster, stores, statistics.Write) + // Collects hot read region metrics. 
+ collectHotMetrics(c.cluster, stores, statistics.Read) +} + +func collectHotMetrics(cluster *RaftCluster, stores []*core.StoreInfo, typ statistics.RWType) { + var ( + kind string + regionStats map[uint64][]*statistics.HotPeerStat + ) + + switch typ { + case statistics.Read: + regionStats = cluster.RegionReadStats() + kind = statistics.Read.String() + case statistics.Write: + regionStats = cluster.RegionWriteStats() + kind = statistics.Write.String() + } + status := statistics.CollectHotPeerInfos(stores, regionStats) // only returns TotalBytesRate,TotalKeysRate,TotalQueryRate,Count + + for _, s := range stores { + storeAddress := s.GetAddress() + storeID := s.GetID() + storeLabel := strconv.FormatUint(storeID, 10) + stat, hasHotLeader := status.AsLeader[storeID] + if hasHotLeader { + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_bytes_as_leader").Set(stat.TotalBytesRate) + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_keys_as_leader").Set(stat.TotalKeysRate) + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_leader").Set(stat.TotalQueryRate) + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_leader").Set(float64(stat.Count)) + } else { + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_bytes_as_leader") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_keys_as_leader") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_leader") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_leader") + } + + stat, hasHotPeer := status.AsPeer[storeID] + if hasHotPeer { + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_bytes_as_peer").Set(stat.TotalBytesRate) + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_keys_as_peer").Set(stat.TotalKeysRate) + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_peer").Set(stat.TotalQueryRate) + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_peer").Set(float64(stat.Count)) + } else { + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_bytes_as_peer") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_keys_as_peer") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_peer") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_peer") + } + + if !hasHotLeader && !hasHotPeer { + statistics.ForeachRegionStats(func(rwTy statistics.RWType, dim int, _ statistics.RegionStatKind) { + hotPendingSum.DeleteLabelValues(storeLabel, rwTy.String(), statistics.DimToString(dim)) + }) + } + } +} + +func (c *coordinator) resetHotSpotMetrics() { + hotSpotStatusGauge.Reset() + hotPendingSum.Reset() +} + +func (c *coordinator) shouldRun() bool { + return c.prepareChecker.check(c.cluster.GetBasicCluster()) +} + +func (c *coordinator) addScheduler(scheduler schedule.Scheduler, args ...string) error { + c.Lock() + defer c.Unlock() + + if _, ok := c.schedulers[scheduler.GetName()]; ok { + return errs.ErrSchedulerExisted.FastGenByArgs() + } + + s := newScheduleController(c, scheduler) + if err := s.Prepare(c.cluster); err != nil { + return err + } + + c.wg.Add(1) + go c.runScheduler(s) + c.schedulers[s.GetName()] = s + c.cluster.opt.AddSchedulerCfg(s.GetType(), args) + 
return nil +} + +func (c *coordinator) removeScheduler(name string) error { + c.Lock() + defer c.Unlock() + if c.cluster == nil { + return errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return errs.ErrSchedulerNotFound.FastGenByArgs() + } + + opt := c.cluster.opt + if err := c.removeOptScheduler(opt, name); err != nil { + log.Error("can not remove scheduler", zap.String("scheduler-name", name), errs.ZapError(err)) + return err + } + + if err := opt.Persist(c.cluster.storage); err != nil { + log.Error("the option can not persist scheduler config", errs.ZapError(err)) + return err + } + + if err := c.cluster.storage.RemoveScheduleConfig(name); err != nil { + log.Error("can not remove the scheduler config", errs.ZapError(err)) + return err + } + + s.Stop() + schedulerStatusGauge.DeleteLabelValues(name, "allow") + delete(c.schedulers, name) + + return nil +} + +func (c *coordinator) removeOptScheduler(o *config.PersistOptions, name string) error { + v := o.GetScheduleConfig().Clone() + for i, schedulerCfg := range v.Schedulers { + // To create a temporary scheduler is just used to get scheduler's name + decoder := schedule.ConfigSliceDecoder(schedulerCfg.Type, schedulerCfg.Args) + tmp, err := schedule.CreateScheduler(schedulerCfg.Type, schedule.NewOperatorController(c.ctx, nil, nil), storage.NewStorageWithMemoryBackend(), decoder) + if err != nil { + return err + } + if tmp.GetName() == name { + if config.IsDefaultScheduler(tmp.GetType()) { + schedulerCfg.Disable = true + v.Schedulers[i] = schedulerCfg + } else { + v.Schedulers = append(v.Schedulers[:i], v.Schedulers[i+1:]...) + } + o.SetScheduleConfig(v) + return nil + } + } + return nil +} + +func (c *coordinator) pauseOrResumeScheduler(name string, t int64) error { + c.Lock() + defer c.Unlock() + if c.cluster == nil { + return errs.ErrNotBootstrapped.FastGenByArgs() + } + var s []*scheduleController + if name != "all" { + sc, ok := c.schedulers[name] + if !ok { + return errs.ErrSchedulerNotFound.FastGenByArgs() + } + s = append(s, sc) + } else { + for _, sc := range c.schedulers { + s = append(s, sc) + } + } + var err error + for _, sc := range s { + var delayAt, delayUntil int64 + if t > 0 { + delayAt = time.Now().Unix() + delayUntil = delayAt + t + } + atomic.StoreInt64(&sc.delayAt, delayAt) + atomic.StoreInt64(&sc.delayUntil, delayUntil) + } + return err +} + +// isSchedulerAllowed returns whether a scheduler is allowed to schedule, a scheduler is not allowed to schedule if it is paused or blocked by unsafe recovery. 
+func (c *coordinator) isSchedulerAllowed(name string) (bool, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return false, errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return false, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return s.AllowSchedule(false), nil +} + +func (c *coordinator) isSchedulerPaused(name string) (bool, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return false, errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return false, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return s.IsPaused(), nil +} + +func (c *coordinator) isSchedulerDisabled(name string) (bool, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return false, errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return false, errs.ErrSchedulerNotFound.FastGenByArgs() + } + t := s.GetType() + scheduleConfig := c.cluster.GetOpts().GetScheduleConfig() + for _, s := range scheduleConfig.Schedulers { + if t == s.Type { + return s.Disable, nil + } + } + return false, nil +} + +func (c *coordinator) isSchedulerExisted(name string) (bool, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return false, errs.ErrNotBootstrapped.FastGenByArgs() + } + _, ok := c.schedulers[name] + if !ok { + return false, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return true, nil +} + +func (c *coordinator) runScheduler(s *scheduleController) { + defer logutil.LogPanic() + defer c.wg.Done() + defer s.Cleanup(c.cluster) + + timer := time.NewTimer(s.GetInterval()) + defer timer.Stop() + for { + select { + case <-timer.C: + timer.Reset(s.GetInterval()) + diagnosable := s.diagnosticRecorder.isAllowed() + if !s.AllowSchedule(diagnosable) { + continue + } + if op := s.Schedule(diagnosable); len(op) > 0 { + added := c.opController.AddWaitingOperator(op...) + log.Debug("add operator", zap.Int("added", added), zap.Int("total", len(op)), zap.String("scheduler", s.GetName())) + } + + case <-s.Ctx().Done(): + log.Info("scheduler has been stopped", + zap.String("scheduler-name", s.GetName()), + errs.ZapError(s.Ctx().Err())) + return + } + } +} + +func (c *coordinator) pauseOrResumeChecker(name string, t int64) error { + c.Lock() + defer c.Unlock() + if c.cluster == nil { + return errs.ErrNotBootstrapped.FastGenByArgs() + } + p, err := c.checkers.GetPauseController(name) + if err != nil { + return err + } + p.PauseOrResume(t) + return nil +} + +func (c *coordinator) isCheckerPaused(name string) (bool, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return false, errs.ErrNotBootstrapped.FastGenByArgs() + } + p, err := c.checkers.GetPauseController(name) + if err != nil { + return false, err + } + return p.IsPaused(), nil +} + +func (c *coordinator) GetDiagnosticResult(name string) (*DiagnosticResult, error) { + return c.diagnosticManager.getDiagnosticResult(name) +} + +// scheduleController is used to manage a scheduler to schedule. +type scheduleController struct { + schedule.Scheduler + cluster *RaftCluster + opController *schedule.OperatorController + nextInterval time.Duration + ctx context.Context + cancel context.CancelFunc + delayAt int64 + delayUntil int64 + diagnosticRecorder *diagnosticRecorder +} + +// newScheduleController creates a new scheduleController. 
+func newScheduleController(c *coordinator, s schedule.Scheduler) *scheduleController { + ctx, cancel := context.WithCancel(c.ctx) + return &scheduleController{ + Scheduler: s, + cluster: c.cluster, + opController: c.opController, + nextInterval: s.GetMinInterval(), + ctx: ctx, + cancel: cancel, + diagnosticRecorder: c.diagnosticManager.getRecorder(s.GetName()), + } +} + +func (s *scheduleController) Ctx() context.Context { + return s.ctx +} + +func (s *scheduleController) Stop() { + s.cancel() +} + +func (s *scheduleController) Schedule(diagnosable bool) []*operator.Operator { + for i := 0; i < maxScheduleRetries; i++ { + // no need to retry if schedule should stop to speed exit + select { + case <-s.ctx.Done(): + return nil + default: + } + cacheCluster := newCacheCluster(s.cluster) + // we need only process diagnostic once in the retry loop + diagnosable = diagnosable && i == 0 + ops, plans := s.Scheduler.Schedule(cacheCluster, diagnosable) + if diagnosable { + s.diagnosticRecorder.setResultFromPlans(ops, plans) + } + if len(ops) > 0 { + // If we have schedule, reset interval to the minimal interval. + s.nextInterval = s.Scheduler.GetMinInterval() + return ops + } + } + s.nextInterval = s.Scheduler.GetNextInterval(s.nextInterval) + return nil +} + +func (s *scheduleController) DiagnoseDryRun() ([]*operator.Operator, []plan.Plan) { + cacheCluster := newCacheCluster(s.cluster) + return s.Scheduler.Schedule(cacheCluster, true) +} + +// GetInterval returns the interval of scheduling for a scheduler. +func (s *scheduleController) GetInterval() time.Duration { + return s.nextInterval +} + +// AllowSchedule returns if a scheduler is allowed to schedule. +func (s *scheduleController) AllowSchedule(diagnosable bool) bool { + if !s.Scheduler.IsScheduleAllowed(s.cluster) { + if diagnosable { + s.diagnosticRecorder.setResultFromStatus(pending) + } + return false + } + if s.IsPaused() || s.cluster.GetUnsafeRecoveryController().IsRunning() { + if diagnosable { + s.diagnosticRecorder.setResultFromStatus(paused) + } + return false + } + return true +} + +// isPaused returns if a scheduler is paused. 
+func (s *scheduleController) IsPaused() bool { + delayUntil := atomic.LoadInt64(&s.delayUntil) + return time.Now().Unix() < delayUntil +} + +// GetPausedSchedulerDelayAt returns paused timestamp of a paused scheduler +func (s *scheduleController) GetDelayAt() int64 { + if s.IsPaused() { + return atomic.LoadInt64(&s.delayAt) + } + return 0 +} + +// GetPausedSchedulerDelayUntil returns resume timestamp of a paused scheduler +func (s *scheduleController) GetDelayUntil() int64 { + if s.IsPaused() { + return atomic.LoadInt64(&s.delayUntil) + } + return 0 +} + +func (c *coordinator) getPausedSchedulerDelayAt(name string) (int64, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return -1, errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return -1, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return s.GetDelayAt(), nil +} + +func (c *coordinator) getPausedSchedulerDelayUntil(name string) (int64, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return -1, errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return -1, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return s.GetDelayUntil(), nil +} + +// CheckTransferWitnessLeader determines if transfer leader is required, then sends to the scheduler if needed +func (c *coordinator) CheckTransferWitnessLeader(region *core.RegionInfo) { + if core.NeedTransferWitnessLeader(region) { + c.RLock() + s, ok := c.schedulers[schedulers.TransferWitnessLeaderName] + c.RUnlock() + if ok { + select { + case schedulers.RecvRegionInfo(s.Scheduler) <- region: + default: + log.Warn("drop transfer witness leader due to recv region channel full", zap.Uint64("region-id", region.GetID())) + } + } + } +} diff --git a/server/config/binding__failpoint_binding__.go b/server/config/binding__failpoint_binding__.go new file mode 100755 index 00000000000..bbb067eddce --- /dev/null +++ b/server/config/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package config + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/config/persist_options.go b/server/config/persist_options.go old mode 100644 new mode 100755 index 6264cb447b2..2a0969faaa1 --- a/server/config/persist_options.go +++ b/server/config/persist_options.go @@ -680,9 +680,9 @@ func (o *PersistOptions) Persist(storage endpoint.ConfigStorage) error { ClusterVersion: *o.GetClusterVersion(), } err := storage.SaveConfig(cfg) - failpoint.Inject("persistFail", func() { + if _, _err_ := failpoint.Eval(_curpkg_("persistFail")); _err_ == nil { err = errors.New("fail to persist") - }) + } return err } diff --git a/server/config/persist_options.go__failpoint_stash__ b/server/config/persist_options.go__failpoint_stash__ new file mode 100644 index 00000000000..6264cb447b2 --- /dev/null +++ b/server/config/persist_options.go__failpoint_stash__ @@ -0,0 +1,858 @@ +// Copyright 2017 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "context" + "fmt" + "reflect" + "strconv" + "strings" + "sync/atomic" + "time" + "unsafe" + + "github.com/coreos/go-semver/semver" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/cache" + "github.com/tikv/pd/pkg/etcdutil" + "github.com/tikv/pd/pkg/slice" + "github.com/tikv/pd/pkg/typeutil" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/core/storelimit" + "github.com/tikv/pd/server/storage/endpoint" + "go.etcd.io/etcd/clientv3" +) + +// PersistOptions wraps all configurations that need to persist to storage and +// allows to access them safely. +type PersistOptions struct { + // configuration -> ttl value + ttl *cache.TTLString + schedule atomic.Value + replication atomic.Value + pdServerConfig atomic.Value + replicationMode atomic.Value + labelProperty atomic.Value + clusterVersion unsafe.Pointer +} + +// NewPersistOptions creates a new PersistOptions instance. +func NewPersistOptions(cfg *Config) *PersistOptions { + o := &PersistOptions{} + o.schedule.Store(&cfg.Schedule) + o.replication.Store(&cfg.Replication) + o.pdServerConfig.Store(&cfg.PDServerCfg) + o.replicationMode.Store(&cfg.ReplicationMode) + o.labelProperty.Store(cfg.LabelProperty) + o.SetClusterVersion(&cfg.ClusterVersion) + o.ttl = nil + return o +} + +// GetScheduleConfig returns scheduling configurations. +func (o *PersistOptions) GetScheduleConfig() *ScheduleConfig { + return o.schedule.Load().(*ScheduleConfig) +} + +// SetScheduleConfig sets the PD scheduling configuration. +func (o *PersistOptions) SetScheduleConfig(cfg *ScheduleConfig) { + o.schedule.Store(cfg) +} + +// GetReplicationConfig returns replication configurations. +func (o *PersistOptions) GetReplicationConfig() *ReplicationConfig { + return o.replication.Load().(*ReplicationConfig) +} + +// SetReplicationConfig sets the PD replication configuration. +func (o *PersistOptions) SetReplicationConfig(cfg *ReplicationConfig) { + o.replication.Store(cfg) +} + +// GetPDServerConfig returns pd server configurations. +func (o *PersistOptions) GetPDServerConfig() *PDServerConfig { + return o.pdServerConfig.Load().(*PDServerConfig) +} + +// SetPDServerConfig sets the PD configuration. +func (o *PersistOptions) SetPDServerConfig(cfg *PDServerConfig) { + o.pdServerConfig.Store(cfg) +} + +// GetReplicationModeConfig returns the replication mode config. +func (o *PersistOptions) GetReplicationModeConfig() *ReplicationModeConfig { + return o.replicationMode.Load().(*ReplicationModeConfig) +} + +// SetReplicationModeConfig sets the replication mode config. +func (o *PersistOptions) SetReplicationModeConfig(cfg *ReplicationModeConfig) { + o.replicationMode.Store(cfg) +} + +// GetLabelPropertyConfig returns the label property. +func (o *PersistOptions) GetLabelPropertyConfig() LabelPropertyConfig { + return o.labelProperty.Load().(LabelPropertyConfig) +} + +// SetLabelPropertyConfig sets the label property configuration. 
+func (o *PersistOptions) SetLabelPropertyConfig(cfg LabelPropertyConfig) { + o.labelProperty.Store(cfg) +} + +// GetClusterVersion returns the cluster version. +func (o *PersistOptions) GetClusterVersion() *semver.Version { + return (*semver.Version)(atomic.LoadPointer(&o.clusterVersion)) +} + +// SetClusterVersion sets the cluster version. +func (o *PersistOptions) SetClusterVersion(v *semver.Version) { + atomic.StorePointer(&o.clusterVersion, unsafe.Pointer(v)) +} + +// CASClusterVersion sets the cluster version. +func (o *PersistOptions) CASClusterVersion(old, new *semver.Version) bool { + return atomic.CompareAndSwapPointer(&o.clusterVersion, unsafe.Pointer(old), unsafe.Pointer(new)) +} + +// GetLocationLabels returns the location labels for each region. +func (o *PersistOptions) GetLocationLabels() []string { + return o.GetReplicationConfig().LocationLabels +} + +// SetLocationLabels sets the location labels. +func (o *PersistOptions) SetLocationLabels(labels []string) { + v := o.GetReplicationConfig().Clone() + v.LocationLabels = labels + o.SetReplicationConfig(v) +} + +// GetIsolationLevel returns the isolation label for each region. +func (o *PersistOptions) GetIsolationLevel() string { + return o.GetReplicationConfig().IsolationLevel +} + +// IsPlacementRulesEnabled returns if the placement rules is enabled. +func (o *PersistOptions) IsPlacementRulesEnabled() bool { + return o.GetReplicationConfig().EnablePlacementRules +} + +// SetPlacementRuleEnabled set PlacementRuleEnabled +func (o *PersistOptions) SetPlacementRuleEnabled(enabled bool) { + v := o.GetReplicationConfig().Clone() + v.EnablePlacementRules = enabled + o.SetReplicationConfig(v) +} + +// IsPlacementRulesCacheEnabled returns if the placement rules cache is enabled +func (o *PersistOptions) IsPlacementRulesCacheEnabled() bool { + return o.GetReplicationConfig().EnablePlacementRulesCache +} + +// SetPlacementRulesCacheEnabled set EnablePlacementRulesCache +func (o *PersistOptions) SetPlacementRulesCacheEnabled(enabled bool) { + v := o.GetReplicationConfig().Clone() + v.EnablePlacementRulesCache = enabled + o.SetReplicationConfig(v) +} + +// GetStrictlyMatchLabel returns whether check label strict. +func (o *PersistOptions) GetStrictlyMatchLabel() bool { + return o.GetReplicationConfig().StrictlyMatchLabel +} + +// GetMaxReplicas returns the number of replicas for each region. +func (o *PersistOptions) GetMaxReplicas() int { + return int(o.GetReplicationConfig().MaxReplicas) +} + +// SetMaxReplicas sets the number of replicas for each region. 
+func (o *PersistOptions) SetMaxReplicas(replicas int) { + v := o.GetReplicationConfig().Clone() + v.MaxReplicas = uint64(replicas) + o.SetReplicationConfig(v) +} + +const ( + maxSnapshotCountKey = "schedule.max-snapshot-count" + maxMergeRegionSizeKey = "schedule.max-merge-region-size" + maxPendingPeerCountKey = "schedule.max-pending-peer-count" + maxMergeRegionKeysKey = "schedule.max-merge-region-keys" + leaderScheduleLimitKey = "schedule.leader-schedule-limit" + regionScheduleLimitKey = "schedule.region-schedule-limit" + replicaRescheduleLimitKey = "schedule.replica-schedule-limit" + mergeScheduleLimitKey = "schedule.merge-schedule-limit" + hotRegionScheduleLimitKey = "schedule.hot-region-schedule-limit" + schedulerMaxWaitingOperatorKey = "schedule.scheduler-max-waiting-operator" + enableLocationReplacement = "schedule.enable-location-replacement" + // it's related to schedule, but it's not an explicit config + enableTiKVSplitRegion = "schedule.enable-tikv-split-region" +) + +var supportedTTLConfigs = []string{ + maxSnapshotCountKey, + maxMergeRegionSizeKey, + maxPendingPeerCountKey, + maxMergeRegionKeysKey, + leaderScheduleLimitKey, + regionScheduleLimitKey, + replicaRescheduleLimitKey, + mergeScheduleLimitKey, + hotRegionScheduleLimitKey, + schedulerMaxWaitingOperatorKey, + enableLocationReplacement, + enableTiKVSplitRegion, + "default-add-peer", + "default-remove-peer", +} + +// IsSupportedTTLConfig checks whether a key is a supported config item with ttl +func IsSupportedTTLConfig(key string) bool { + for _, supportedConfig := range supportedTTLConfigs { + if key == supportedConfig { + return true + } + } + return strings.HasPrefix(key, "add-peer-") || strings.HasPrefix(key, "remove-peer-") +} + +// GetMaxSnapshotCount returns the number of the max snapshot which is allowed to send. +func (o *PersistOptions) GetMaxSnapshotCount() uint64 { + return o.getTTLUintOr(maxSnapshotCountKey, o.GetScheduleConfig().MaxSnapshotCount) +} + +// GetMaxPendingPeerCount returns the number of the max pending peers. +func (o *PersistOptions) GetMaxPendingPeerCount() uint64 { + return o.getTTLUintOr(maxPendingPeerCountKey, o.GetScheduleConfig().MaxPendingPeerCount) +} + +// GetMaxMergeRegionSize returns the max region size. +func (o *PersistOptions) GetMaxMergeRegionSize() uint64 { + return o.getTTLUintOr(maxMergeRegionSizeKey, o.GetScheduleConfig().MaxMergeRegionSize) +} + +// GetMaxMergeRegionKeys returns the max number of keys. +// It returns size * 10000 if the key of max-merge-region-Keys doesn't exist. +func (o *PersistOptions) GetMaxMergeRegionKeys() uint64 { + keys, exist, err := o.getTTLUint(maxMergeRegionKeysKey) + if exist && err == nil { + return keys + } + size, exist, err := o.getTTLUint(maxMergeRegionSizeKey) + if exist && err == nil { + return size * 10000 + } + return o.GetScheduleConfig().GetMaxMergeRegionKeys() +} + +// GetSplitMergeInterval returns the interval between finishing split and starting to merge. +func (o *PersistOptions) GetSplitMergeInterval() time.Duration { + return o.GetScheduleConfig().SplitMergeInterval.Duration +} + +// SetSplitMergeInterval to set the interval between finishing split and starting to merge. It's only used to test. +func (o *PersistOptions) SetSplitMergeInterval(splitMergeInterval time.Duration) { + v := o.GetScheduleConfig().Clone() + v.SplitMergeInterval = typeutil.Duration{Duration: splitMergeInterval} + o.SetScheduleConfig(v) +} + +// GetSwitchWitnessInterval returns the interval between promote to non-witness and starting to switch to witness. 
+func (o *PersistOptions) GetSwitchWitnessInterval() time.Duration { + return o.GetScheduleConfig().SwitchWitnessInterval.Duration +} + +// IsDiagnosticAllowed returns whether is enable to use diagnostic. +func (o *PersistOptions) IsDiagnosticAllowed() bool { + return o.GetScheduleConfig().EnableDiagnostic +} + +// SetEnableDiagnostic to set the option for diagnose. It's only used to test. +func (o *PersistOptions) SetEnableDiagnostic(enable bool) { + v := o.GetScheduleConfig().Clone() + v.EnableDiagnostic = enable + o.SetScheduleConfig(v) +} + +// IsWitnessAllowed returns whether is enable to use witness. +func (o *PersistOptions) IsWitnessAllowed() bool { + return o.GetScheduleConfig().EnableWitness +} + +// SetEnableWitness to set the option for witness. It's only used to test. +func (o *PersistOptions) SetEnableWitness(enable bool) { + v := o.GetScheduleConfig().Clone() + v.EnableWitness = enable + o.SetScheduleConfig(v) +} + +// SetMaxMergeRegionSize sets the max merge region size. +func (o *PersistOptions) SetMaxMergeRegionSize(maxMergeRegionSize uint64) { + v := o.GetScheduleConfig().Clone() + v.MaxMergeRegionSize = maxMergeRegionSize + o.SetScheduleConfig(v) +} + +// SetMaxMergeRegionKeys sets the max merge region keys. +func (o *PersistOptions) SetMaxMergeRegionKeys(maxMergeRegionKeys uint64) { + v := o.GetScheduleConfig().Clone() + v.MaxMergeRegionKeys = maxMergeRegionKeys + o.SetScheduleConfig(v) +} + +// SetStoreLimit sets a store limit for a given type and rate. +func (o *PersistOptions) SetStoreLimit(storeID uint64, typ storelimit.Type, ratePerMin float64) { + v := o.GetScheduleConfig().Clone() + var sc StoreLimitConfig + var rate float64 + switch typ { + case storelimit.AddPeer: + if _, ok := v.StoreLimit[storeID]; !ok { + rate = DefaultStoreLimit.GetDefaultStoreLimit(storelimit.RemovePeer) + } else { + rate = v.StoreLimit[storeID].RemovePeer + } + sc = StoreLimitConfig{AddPeer: ratePerMin, RemovePeer: rate} + case storelimit.RemovePeer: + if _, ok := v.StoreLimit[storeID]; !ok { + rate = DefaultStoreLimit.GetDefaultStoreLimit(storelimit.AddPeer) + } else { + rate = v.StoreLimit[storeID].AddPeer + } + sc = StoreLimitConfig{AddPeer: rate, RemovePeer: ratePerMin} + } + v.StoreLimit[storeID] = sc + o.SetScheduleConfig(v) +} + +// SetAllStoresLimit sets all store limit for a given type and rate. +func (o *PersistOptions) SetAllStoresLimit(typ storelimit.Type, ratePerMin float64) { + v := o.GetScheduleConfig().Clone() + switch typ { + case storelimit.AddPeer: + DefaultStoreLimit.SetDefaultStoreLimit(storelimit.AddPeer, ratePerMin) + for storeID := range v.StoreLimit { + sc := StoreLimitConfig{AddPeer: ratePerMin, RemovePeer: v.StoreLimit[storeID].RemovePeer} + v.StoreLimit[storeID] = sc + } + case storelimit.RemovePeer: + DefaultStoreLimit.SetDefaultStoreLimit(storelimit.RemovePeer, ratePerMin) + for storeID := range v.StoreLimit { + sc := StoreLimitConfig{AddPeer: v.StoreLimit[storeID].AddPeer, RemovePeer: ratePerMin} + v.StoreLimit[storeID] = sc + } + } + + o.SetScheduleConfig(v) +} + +// IsOneWayMergeEnabled returns if a region can only be merged into the next region of it. +func (o *PersistOptions) IsOneWayMergeEnabled() bool { + return o.GetScheduleConfig().EnableOneWayMerge +} + +// IsCrossTableMergeEnabled returns if across table merge is enabled. +func (o *PersistOptions) IsCrossTableMergeEnabled() bool { + return o.GetScheduleConfig().EnableCrossTableMerge +} + +// GetPatrolRegionInterval returns the interval of patrolling region. 
+func (o *PersistOptions) GetPatrolRegionInterval() time.Duration { + return o.GetScheduleConfig().PatrolRegionInterval.Duration +} + +// GetMaxStoreDownTime returns the max down time of a store. +func (o *PersistOptions) GetMaxStoreDownTime() time.Duration { + return o.GetScheduleConfig().MaxStoreDownTime.Duration +} + +// GetMaxStorePreparingTime returns the max preparing time of a store. +func (o *PersistOptions) GetMaxStorePreparingTime() time.Duration { + return o.GetScheduleConfig().MaxStorePreparingTime.Duration +} + +// GetLeaderScheduleLimit returns the limit for leader schedule. +func (o *PersistOptions) GetLeaderScheduleLimit() uint64 { + return o.getTTLUintOr(leaderScheduleLimitKey, o.GetScheduleConfig().LeaderScheduleLimit) +} + +// GetRegionScheduleLimit returns the limit for region schedule. +func (o *PersistOptions) GetRegionScheduleLimit() uint64 { + return o.getTTLUintOr(regionScheduleLimitKey, o.GetScheduleConfig().RegionScheduleLimit) +} + +// GetReplicaScheduleLimit returns the limit for replica schedule. +func (o *PersistOptions) GetReplicaScheduleLimit() uint64 { + return o.getTTLUintOr(replicaRescheduleLimitKey, o.GetScheduleConfig().ReplicaScheduleLimit) +} + +// GetMergeScheduleLimit returns the limit for merge schedule. +func (o *PersistOptions) GetMergeScheduleLimit() uint64 { + return o.getTTLUintOr(mergeScheduleLimitKey, o.GetScheduleConfig().MergeScheduleLimit) +} + +// GetHotRegionScheduleLimit returns the limit for hot region schedule. +func (o *PersistOptions) GetHotRegionScheduleLimit() uint64 { + return o.getTTLUintOr(hotRegionScheduleLimitKey, o.GetScheduleConfig().HotRegionScheduleLimit) +} + +// GetStoreLimit returns the limit of a store. +func (o *PersistOptions) GetStoreLimit(storeID uint64) (returnSC StoreLimitConfig) { + defer func() { + returnSC.RemovePeer = o.getTTLFloatOr(fmt.Sprintf("remove-peer-%v", storeID), returnSC.RemovePeer) + returnSC.AddPeer = o.getTTLFloatOr(fmt.Sprintf("add-peer-%v", storeID), returnSC.AddPeer) + }() + if limit, ok := o.GetScheduleConfig().StoreLimit[storeID]; ok { + return limit + } + cfg := o.GetScheduleConfig().Clone() + sc := StoreLimitConfig{ + AddPeer: DefaultStoreLimit.GetDefaultStoreLimit(storelimit.AddPeer), + RemovePeer: DefaultStoreLimit.GetDefaultStoreLimit(storelimit.RemovePeer), + } + v, ok1, err := o.getTTLFloat("default-add-peer") + if err != nil { + log.Warn("failed to parse default-add-peer from PersistOptions's ttl storage") + } + canSetAddPeer := ok1 && err == nil + if canSetAddPeer { + returnSC.AddPeer = v + } + + v, ok2, err := o.getTTLFloat("default-remove-peer") + if err != nil { + log.Warn("failed to parse default-remove-peer from PersistOptions's ttl storage") + } + canSetRemovePeer := ok2 && err == nil + if canSetRemovePeer { + returnSC.RemovePeer = v + } + + if canSetAddPeer || canSetRemovePeer { + return returnSC + } + cfg.StoreLimit[storeID] = sc + o.SetScheduleConfig(cfg) + return o.GetScheduleConfig().StoreLimit[storeID] +} + +// GetStoreLimitByType returns the limit of a store with a given type. 
+func (o *PersistOptions) GetStoreLimitByType(storeID uint64, typ storelimit.Type) (returned float64) { + defer func() { + if typ == storelimit.RemovePeer { + returned = o.getTTLFloatOr(fmt.Sprintf("remove-peer-%v", storeID), returned) + } else if typ == storelimit.AddPeer { + returned = o.getTTLFloatOr(fmt.Sprintf("add-peer-%v", storeID), returned) + } + }() + limit := o.GetStoreLimit(storeID) + switch typ { + case storelimit.AddPeer: + return limit.AddPeer + case storelimit.RemovePeer: + return limit.RemovePeer + default: + panic("no such limit type") + } +} + +// GetAllStoresLimit returns the limit of all stores. +func (o *PersistOptions) GetAllStoresLimit() map[uint64]StoreLimitConfig { + return o.GetScheduleConfig().StoreLimit +} + +// GetStoreLimitMode returns the limit mode of store. +func (o *PersistOptions) GetStoreLimitMode() string { + return o.GetScheduleConfig().StoreLimitMode +} + +// GetTolerantSizeRatio gets the tolerant size ratio. +func (o *PersistOptions) GetTolerantSizeRatio() float64 { + return o.GetScheduleConfig().TolerantSizeRatio +} + +// GetLowSpaceRatio returns the low space ratio. +func (o *PersistOptions) GetLowSpaceRatio() float64 { + return o.GetScheduleConfig().LowSpaceRatio +} + +// GetHighSpaceRatio returns the high space ratio. +func (o *PersistOptions) GetHighSpaceRatio() float64 { + return o.GetScheduleConfig().HighSpaceRatio +} + +// GetRegionScoreFormulaVersion returns the formula version config. +func (o *PersistOptions) GetRegionScoreFormulaVersion() string { + return o.GetScheduleConfig().RegionScoreFormulaVersion +} + +// GetSchedulerMaxWaitingOperator returns the number of the max waiting operators. +func (o *PersistOptions) GetSchedulerMaxWaitingOperator() uint64 { + return o.getTTLUintOr(schedulerMaxWaitingOperatorKey, o.GetScheduleConfig().SchedulerMaxWaitingOperator) +} + +// GetLeaderSchedulePolicy is to get leader schedule policy. +func (o *PersistOptions) GetLeaderSchedulePolicy() core.SchedulePolicy { + return core.StringToSchedulePolicy(o.GetScheduleConfig().LeaderSchedulePolicy) +} + +// GetKeyType is to get key type. +func (o *PersistOptions) GetKeyType() core.KeyType { + return core.StringToKeyType(o.GetPDServerConfig().KeyType) +} + +// GetMaxResetTSGap gets the max gap to reset the tso. +func (o *PersistOptions) GetMaxResetTSGap() time.Duration { + return o.GetPDServerConfig().MaxResetTSGap.Duration +} + +// GetDashboardAddress gets dashboard address. +func (o *PersistOptions) GetDashboardAddress() string { + return o.GetPDServerConfig().DashboardAddress +} + +// IsUseRegionStorage returns if the independent region storage is enabled. +func (o *PersistOptions) IsUseRegionStorage() bool { + return o.GetPDServerConfig().UseRegionStorage +} + +// IsRemoveDownReplicaEnabled returns if remove down replica is enabled. +func (o *PersistOptions) IsRemoveDownReplicaEnabled() bool { + return o.GetScheduleConfig().EnableRemoveDownReplica +} + +// IsReplaceOfflineReplicaEnabled returns if replace offline replica is enabled. +func (o *PersistOptions) IsReplaceOfflineReplicaEnabled() bool { + return o.GetScheduleConfig().EnableReplaceOfflineReplica +} + +// IsMakeUpReplicaEnabled returns if make up replica is enabled. +func (o *PersistOptions) IsMakeUpReplicaEnabled() bool { + return o.GetScheduleConfig().EnableMakeUpReplica +} + +// IsRemoveExtraReplicaEnabled returns if remove extra replica is enabled. 
+func (o *PersistOptions) IsRemoveExtraReplicaEnabled() bool { + return o.GetScheduleConfig().EnableRemoveExtraReplica +} + +// IsTikvRegionSplitEnabled returns whether tikv split region is disabled. +func (o *PersistOptions) IsTikvRegionSplitEnabled() bool { + return o.getTTLBoolOr(enableTiKVSplitRegion, o.GetScheduleConfig().EnableTiKVSplitRegion) +} + +// IsLocationReplacementEnabled returns if location replace is enabled. +func (o *PersistOptions) IsLocationReplacementEnabled() bool { + return o.getTTLBoolOr(enableLocationReplacement, o.GetScheduleConfig().EnableLocationReplacement) +} + +// GetMaxMovableHotPeerSize returns the max movable hot peer size. +func (o *PersistOptions) GetMaxMovableHotPeerSize() int64 { + size := o.GetScheduleConfig().MaxMovableHotPeerSize + if size <= 0 { + size = defaultMaxMovableHotPeerSize + } + return size +} + +// IsDebugMetricsEnabled returns if debug metrics is enabled. +func (o *PersistOptions) IsDebugMetricsEnabled() bool { + return o.GetScheduleConfig().EnableDebugMetrics +} + +// IsUseJointConsensus returns if using joint consensus as a operator step is enabled. +func (o *PersistOptions) IsUseJointConsensus() bool { + return o.GetScheduleConfig().EnableJointConsensus +} + +// SetEnableUseJointConsensus to set the option for using joint consensus. It's only used to test. +func (o *PersistOptions) SetEnableUseJointConsensus(enable bool) { + v := o.GetScheduleConfig().Clone() + v.EnableJointConsensus = enable + o.SetScheduleConfig(v) +} + +// IsTraceRegionFlow returns if the region flow is tracing. +// If the accuracy cannot reach 0.1 MB, it is considered not. +func (o *PersistOptions) IsTraceRegionFlow() bool { + return o.GetPDServerConfig().FlowRoundByDigit <= maxTraceFlowRoundByDigit +} + +// GetHotRegionCacheHitsThreshold is a threshold to decide if a region is hot. +func (o *PersistOptions) GetHotRegionCacheHitsThreshold() int { + return int(o.GetScheduleConfig().HotRegionCacheHitsThreshold) +} + +// GetStoresLimit gets the stores' limit. +func (o *PersistOptions) GetStoresLimit() map[uint64]StoreLimitConfig { + return o.GetScheduleConfig().StoreLimit +} + +// GetSchedulers gets the scheduler configurations. +func (o *PersistOptions) GetSchedulers() SchedulerConfigs { + return o.GetScheduleConfig().Schedulers +} + +// GetHotRegionsWriteInterval gets interval for PD to store Hot Region information. +func (o *PersistOptions) GetHotRegionsWriteInterval() time.Duration { + return o.GetScheduleConfig().HotRegionsWriteInterval.Duration +} + +// GetHotRegionsReservedDays gets days hot region information is kept. +func (o *PersistOptions) GetHotRegionsReservedDays() uint64 { + return o.GetScheduleConfig().HotRegionsReservedDays +} + +// AddSchedulerCfg adds the scheduler configurations. 
+func (o *PersistOptions) AddSchedulerCfg(tp string, args []string) { + v := o.GetScheduleConfig().Clone() + for i, schedulerCfg := range v.Schedulers { + // comparing args is to cover the case that there are schedulers in same type but not with same name + // such as two schedulers of type "evict-leader", + // one name is "evict-leader-scheduler-1" and the other is "evict-leader-scheduler-2" + if reflect.DeepEqual(schedulerCfg, SchedulerConfig{Type: tp, Args: args, Disable: false}) { + return + } + + if reflect.DeepEqual(schedulerCfg, SchedulerConfig{Type: tp, Args: args, Disable: true}) { + schedulerCfg.Disable = false + v.Schedulers[i] = schedulerCfg + o.SetScheduleConfig(v) + return + } + } + v.Schedulers = append(v.Schedulers, SchedulerConfig{Type: tp, Args: args, Disable: false}) + o.SetScheduleConfig(v) +} + +// SetLabelProperty sets the label property. +func (o *PersistOptions) SetLabelProperty(typ, labelKey, labelValue string) { + cfg := o.GetLabelPropertyConfig().Clone() + for _, l := range cfg[typ] { + if l.Key == labelKey && l.Value == labelValue { + return + } + } + cfg[typ] = append(cfg[typ], StoreLabel{Key: labelKey, Value: labelValue}) + o.labelProperty.Store(cfg) +} + +// DeleteLabelProperty deletes the label property. +func (o *PersistOptions) DeleteLabelProperty(typ, labelKey, labelValue string) { + cfg := o.GetLabelPropertyConfig().Clone() + oldLabels := cfg[typ] + cfg[typ] = []StoreLabel{} + for _, l := range oldLabels { + if l.Key == labelKey && l.Value == labelValue { + continue + } + cfg[typ] = append(cfg[typ], l) + } + if len(cfg[typ]) == 0 { + delete(cfg, typ) + } + o.labelProperty.Store(cfg) +} + +// Persist saves the configuration to the storage. +func (o *PersistOptions) Persist(storage endpoint.ConfigStorage) error { + cfg := &Config{ + Schedule: *o.GetScheduleConfig(), + Replication: *o.GetReplicationConfig(), + PDServerCfg: *o.GetPDServerConfig(), + ReplicationMode: *o.GetReplicationModeConfig(), + LabelProperty: o.GetLabelPropertyConfig(), + ClusterVersion: *o.GetClusterVersion(), + } + err := storage.SaveConfig(cfg) + failpoint.Inject("persistFail", func() { + err = errors.New("fail to persist") + }) + return err +} + +// Reload reloads the configuration from the storage. +func (o *PersistOptions) Reload(storage endpoint.ConfigStorage) error { + cfg := &Config{} + // pass nil to initialize cfg to default values (all items undefined) + cfg.Adjust(nil, true) + + isExist, err := storage.LoadConfig(cfg) + if err != nil { + return err + } + o.adjustScheduleCfg(&cfg.Schedule) + cfg.PDServerCfg.MigrateDeprecatedFlags() + if isExist { + o.schedule.Store(&cfg.Schedule) + o.replication.Store(&cfg.Replication) + o.pdServerConfig.Store(&cfg.PDServerCfg) + o.replicationMode.Store(&cfg.ReplicationMode) + o.labelProperty.Store(cfg.LabelProperty) + o.SetClusterVersion(&cfg.ClusterVersion) + } + return nil +} + +func (o *PersistOptions) adjustScheduleCfg(scheduleCfg *ScheduleConfig) { + // In case we add new default schedulers. + for _, ps := range DefaultSchedulers { + if slice.NoneOf(scheduleCfg.Schedulers, func(i int) bool { + return scheduleCfg.Schedulers[i].Type == ps.Type + }) { + scheduleCfg.Schedulers = append(scheduleCfg.Schedulers, ps) + } + } + scheduleCfg.MigrateDeprecatedFlags() +} + +// CheckLabelProperty checks the label property. 
+func (o *PersistOptions) CheckLabelProperty(typ string, labels []*metapb.StoreLabel) bool { + pc := o.labelProperty.Load().(LabelPropertyConfig) + for _, cfg := range pc[typ] { + for _, l := range labels { + if l.Key == cfg.Key && l.Value == cfg.Value { + return true + } + } + } + return false +} + +// GetMinResolvedTSPersistenceInterval gets the interval for PD to save min resolved ts. +func (o *PersistOptions) GetMinResolvedTSPersistenceInterval() time.Duration { + return o.GetPDServerConfig().MinResolvedTSPersistenceInterval.Duration +} + +const ttlConfigPrefix = "/config/ttl" + +// SetTTLData set temporary configuration +func (o *PersistOptions) SetTTLData(parCtx context.Context, client *clientv3.Client, key string, value string, ttl time.Duration) error { + if o.ttl == nil { + o.ttl = cache.NewStringTTL(parCtx, time.Second*5, time.Minute*5) + } + _, err := etcdutil.EtcdKVPutWithTTL(parCtx, client, ttlConfigPrefix+"/"+key, value, int64(ttl.Seconds())) + if err != nil { + return err + } + o.ttl.PutWithTTL(key, value, ttl) + return nil +} + +func (o *PersistOptions) getTTLUint(key string) (uint64, bool, error) { + stringForm, ok := o.GetTTLData(key) + if !ok { + return 0, false, nil + } + r, err := strconv.ParseUint(stringForm, 10, 64) + return r, true, err +} + +func (o *PersistOptions) getTTLUintOr(key string, defaultValue uint64) uint64 { + if v, ok, err := o.getTTLUint(key); ok { + if err == nil { + return v + } + log.Warn("failed to parse " + key + " from PersistOptions's ttl storage") + } + return defaultValue +} + +func (o *PersistOptions) getTTLBool(key string) (result bool, contains bool, err error) { + stringForm, ok := o.GetTTLData(key) + if !ok { + return + } + result, err = strconv.ParseBool(stringForm) + contains = true + return +} + +func (o *PersistOptions) getTTLBoolOr(key string, defaultValue bool) bool { + if v, ok, err := o.getTTLBool(key); ok { + if err == nil { + return v + } + log.Warn("failed to parse " + key + " from PersistOptions's ttl storage") + } + return defaultValue +} + +func (o *PersistOptions) getTTLFloat(key string) (float64, bool, error) { + stringForm, ok := o.GetTTLData(key) + if !ok { + return 0, false, nil + } + r, err := strconv.ParseFloat(stringForm, 64) + return r, true, err +} + +func (o *PersistOptions) getTTLFloatOr(key string, defaultValue float64) float64 { + if v, ok, err := o.getTTLFloat(key); ok { + if err == nil { + return v + } + log.Warn("failed to parse " + key + " from PersistOptions's ttl storage") + } + return defaultValue +} + +// GetTTLData returns if there is a TTL data for a given key. 
+func (o *PersistOptions) GetTTLData(key string) (string, bool) { + if o.ttl == nil { + return "", false + } + if result, ok := o.ttl.Get(key); ok { + return result.(string), ok + } + return "", false +} + +// LoadTTLFromEtcd loads temporary configuration which was persisted into etcd +func (o *PersistOptions) LoadTTLFromEtcd(ctx context.Context, client *clientv3.Client) error { + resps, err := etcdutil.EtcdKVGet(client, ttlConfigPrefix, clientv3.WithPrefix()) + if err != nil { + return err + } + if o.ttl == nil { + o.ttl = cache.NewStringTTL(ctx, time.Second*5, time.Minute*5) + } + for _, resp := range resps.Kvs { + key := string(resp.Key)[len(ttlConfigPrefix)+1:] + value := string(resp.Value) + leaseID := resp.Lease + resp, err := client.TimeToLive(ctx, clientv3.LeaseID(leaseID)) + if err != nil { + return err + } + o.ttl.PutWithTTL(key, value, time.Duration(resp.TTL)*time.Second) + } + return nil +} + +// SetAllStoresLimitTTL sets all store limit for a given type and rate with ttl. +func (o *PersistOptions) SetAllStoresLimitTTL(ctx context.Context, client *clientv3.Client, typ storelimit.Type, ratePerMin float64, ttl time.Duration) error { + var err error + switch typ { + case storelimit.AddPeer: + err = o.SetTTLData(ctx, client, "default-add-peer", fmt.Sprint(ratePerMin), ttl) + case storelimit.RemovePeer: + err = o.SetTTLData(ctx, client, "default-remove-peer", fmt.Sprint(ratePerMin), ttl) + } + return err +} diff --git a/server/config/service_middleware_persist_options.go b/server/config/service_middleware_persist_options.go old mode 100644 new mode 100755 index 20f8c110a5f..f0efc72f3b9 --- a/server/config/service_middleware_persist_options.go +++ b/server/config/service_middleware_persist_options.go @@ -74,9 +74,9 @@ func (o *ServiceMiddlewarePersistOptions) Persist(storage endpoint.ServiceMiddle RateLimitConfig: *o.GetRateLimitConfig(), } err := storage.SaveServiceMiddlewareConfig(cfg) - failpoint.Inject("persistServiceMiddlewareFail", func() { + if _, _err_ := failpoint.Eval(_curpkg_("persistServiceMiddlewareFail")); _err_ == nil { err = errors.New("fail to persist") - }) + } return err } diff --git a/server/config/service_middleware_persist_options.go__failpoint_stash__ b/server/config/service_middleware_persist_options.go__failpoint_stash__ new file mode 100644 index 00000000000..20f8c110a5f --- /dev/null +++ b/server/config/service_middleware_persist_options.go__failpoint_stash__ @@ -0,0 +1,96 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "errors" + "sync/atomic" + + "github.com/pingcap/failpoint" + "github.com/tikv/pd/server/storage/endpoint" +) + +// ServiceMiddlewarePersistOptions wraps all service middleware configurations that need to persist to storage and +// allows to access them safely. +type ServiceMiddlewarePersistOptions struct { + audit atomic.Value + rateLimit atomic.Value +} + +// NewServiceMiddlewarePersistOptions creates a new ServiceMiddlewarePersistOptions instance. 
+func NewServiceMiddlewarePersistOptions(cfg *ServiceMiddlewareConfig) *ServiceMiddlewarePersistOptions { + o := &ServiceMiddlewarePersistOptions{} + o.audit.Store(&cfg.AuditConfig) + o.rateLimit.Store(&cfg.RateLimitConfig) + return o +} + +// GetAuditConfig returns pd service middleware configurations. +func (o *ServiceMiddlewarePersistOptions) GetAuditConfig() *AuditConfig { + return o.audit.Load().(*AuditConfig) +} + +// SetAuditConfig sets the PD service middleware configuration. +func (o *ServiceMiddlewarePersistOptions) SetAuditConfig(cfg *AuditConfig) { + o.audit.Store(cfg) +} + +// IsAuditEnabled returns whether audit middleware is enabled +func (o *ServiceMiddlewarePersistOptions) IsAuditEnabled() bool { + return o.GetAuditConfig().EnableAudit +} + +// GetRateLimitConfig returns pd service middleware configurations. +func (o *ServiceMiddlewarePersistOptions) GetRateLimitConfig() *RateLimitConfig { + return o.rateLimit.Load().(*RateLimitConfig) +} + +// SetRateLimitConfig sets the PD service middleware configuration. +func (o *ServiceMiddlewarePersistOptions) SetRateLimitConfig(cfg *RateLimitConfig) { + o.rateLimit.Store(cfg) +} + +// IsRateLimitEnabled returns whether rate limit middleware is enabled +func (o *ServiceMiddlewarePersistOptions) IsRateLimitEnabled() bool { + return o.GetRateLimitConfig().EnableRateLimit +} + +// Persist saves the configuration to the storage. +func (o *ServiceMiddlewarePersistOptions) Persist(storage endpoint.ServiceMiddlewareStorage) error { + cfg := &ServiceMiddlewareConfig{ + AuditConfig: *o.GetAuditConfig(), + RateLimitConfig: *o.GetRateLimitConfig(), + } + err := storage.SaveServiceMiddlewareConfig(cfg) + failpoint.Inject("persistServiceMiddlewareFail", func() { + err = errors.New("fail to persist") + }) + return err +} + +// Reload reloads the configuration from the storage. +func (o *ServiceMiddlewarePersistOptions) Reload(storage endpoint.ServiceMiddlewareStorage) error { + cfg := NewServiceMiddlewareConfig() + + isExist, err := storage.LoadServiceMiddlewareConfig(cfg) + if err != nil { + return err + } + if isExist { + o.audit.Store(&cfg.AuditConfig) + o.rateLimit.Store(&cfg.RateLimitConfig) + } + return nil +} diff --git a/server/election/binding__failpoint_binding__.go b/server/election/binding__failpoint_binding__.go new file mode 100755 index 00000000000..c9acc7e9ccd --- /dev/null +++ b/server/election/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package election + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/election/leadership.go b/server/election/leadership.go old mode 100644 new mode 100755 index c9215318c64..680d8d50de0 --- a/server/election/leadership.go +++ b/server/election/leadership.go @@ -186,7 +186,7 @@ func (ls *Leadership) Watch(serverCtx context.Context, revision int64) { // If the revision is compacted, will meet required revision has been compacted error. // In this case, use the compact revision to re-watch the key. for { - failpoint.Inject("delayWatcher", nil) + failpoint.Eval(_curpkg_("delayWatcher")) rch := watcher.Watch(ctx, ls.leaderKey, clientv3.WithRev(revision)) for wresp := range rch { // meet compacted error, use the compact revision. 
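The hunks above and below all follow the same mechanical rewrite performed by failpoint-ctl: a failpoint.Inject marker becomes a guarded failpoint.Eval call, with _curpkg_ prepending the package path (taken from the generated binding__failpoint_binding__.go) to the failpoint name. As a hedged illustration only, and not part of this patch, a test would typically drive one of these rewritten sites as sketched below; the fully qualified failpoint path is an assumption derived from the package layout visible in this diff, and the test name and assertion comment are hypothetical.

// Sketch: enabling the "persistFail" failpoint from a test, assuming the
// expanded path "github.com/tikv/pd/server/config/persistFail" produced by _curpkg_.
package config_test

import (
	"testing"

	"github.com/pingcap/failpoint"
)

func TestPersistFailpointSketch(t *testing.T) {
	const fp = "github.com/tikv/pd/server/config/persistFail"
	if err := failpoint.Enable(fp, `return(true)`); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := failpoint.Disable(fp); err != nil {
			t.Fatal(err)
		}
	}()
	// With the failpoint enabled, the rewritten Persist() takes the
	// `_err_ == nil` branch and returns the injected "fail to persist" error;
	// call PersistOptions.Persist here and assert on that error.
}

The failpoint library also reads the GO_FAILPOINTS environment variable at startup, so the same site can be switched on for a whole process, e.g. GO_FAILPOINTS="github.com/tikv/pd/server/config/persistFail=return(true)"; treat the exact path as an assumption to be checked against the module path of this repository.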
diff --git a/server/election/leadership.go__failpoint_stash__ b/server/election/leadership.go__failpoint_stash__ new file mode 100644 index 00000000000..c9215318c64 --- /dev/null +++ b/server/election/leadership.go__failpoint_stash__ @@ -0,0 +1,237 @@ +// Copyright 2020 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package election + +import ( + "context" + "sync/atomic" + + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/etcdutil" + "github.com/tikv/pd/server/storage/kv" + "go.etcd.io/etcd/clientv3" + "go.etcd.io/etcd/mvcc/mvccpb" + "go.uber.org/zap" +) + +// GetLeader gets the corresponding leader from etcd by given leaderPath (as the key). +func GetLeader(c *clientv3.Client, leaderPath string) (*pdpb.Member, int64, error) { + leader := &pdpb.Member{} + ok, rev, err := etcdutil.GetProtoMsgWithModRev(c, leaderPath, leader) + if err != nil { + return nil, 0, err + } + if !ok { + return nil, 0, nil + } + + return leader, rev, nil +} + +// Leadership is used to manage the leadership campaigning. +type Leadership struct { + // purpose is used to show what this election for + purpose string + // The lease which is used to get this leadership + lease atomic.Value // stored as *lease + client *clientv3.Client + // leaderKey and leaderValue are key-value pair in etcd + leaderKey string + leaderValue string + + keepAliveCtx context.Context + keepAliveCancelFunc context.CancelFunc +} + +// NewLeadership creates a new Leadership. +func NewLeadership(client *clientv3.Client, leaderKey, purpose string) *Leadership { + leadership := &Leadership{ + purpose: purpose, + client: client, + leaderKey: leaderKey, + } + return leadership +} + +// getLease gets the lease of leadership, only if leadership is valid, +// i.e the owner is a true leader, the lease is not nil. +func (ls *Leadership) getLease() *lease { + l := ls.lease.Load() + if l == nil { + return nil + } + return l.(*lease) +} + +func (ls *Leadership) setLease(lease *lease) { + ls.lease.Store(lease) +} + +// GetClient is used to get the etcd client. +func (ls *Leadership) GetClient() *clientv3.Client { + if ls == nil { + return nil + } + return ls.client +} + +// GetLeaderKey is used to get the leader key of etcd. +func (ls *Leadership) GetLeaderKey() string { + if ls == nil { + return "" + } + return ls.leaderKey +} + +// Campaign is used to campaign the leader with given lease and returns a leadership +func (ls *Leadership) Campaign(leaseTimeout int64, leaderData string, cmps ...clientv3.Cmp) error { + ls.leaderValue = leaderData + // Create a new lease to campaign + newLease := &lease{ + Purpose: ls.purpose, + client: ls.client, + lease: clientv3.NewLease(ls.client), + } + ls.setLease(newLease) + if err := newLease.Grant(leaseTimeout); err != nil { + return err + } + finalCmps := make([]clientv3.Cmp, 0, len(cmps)+1) + finalCmps = append(finalCmps, cmps...) 
+ // The leader key must not exist, so the CreateRevision is 0. + finalCmps = append(finalCmps, clientv3.Compare(clientv3.CreateRevision(ls.leaderKey), "=", 0)) + resp, err := kv.NewSlowLogTxn(ls.client). + If(finalCmps...). + Then(clientv3.OpPut(ls.leaderKey, leaderData, clientv3.WithLease(newLease.ID))). + Commit() + log.Info("check campaign resp", zap.Any("resp", resp)) + if err != nil { + newLease.Close() + return errs.ErrEtcdTxnInternal.Wrap(err).GenWithStackByCause() + } + if !resp.Succeeded { + newLease.Close() + return errs.ErrEtcdTxnConflict.FastGenByArgs() + } + log.Info("write leaderData to leaderPath ok", zap.String("leaderPath", ls.leaderKey), zap.String("purpose", ls.purpose)) + return nil +} + +// Keep will keep the leadership available by update the lease's expired time continuously +func (ls *Leadership) Keep(ctx context.Context) { + if ls == nil { + return + } + ls.keepAliveCtx, ls.keepAliveCancelFunc = context.WithCancel(ctx) + go ls.getLease().KeepAlive(ls.keepAliveCtx) +} + +// Check returns whether the leadership is still available. +func (ls *Leadership) Check() bool { + return ls != nil && ls.getLease() != nil && !ls.getLease().IsExpired() +} + +// LeaderTxn returns txn() with a leader comparison to guarantee that +// the transaction can be executed only if the server is leader. +func (ls *Leadership) LeaderTxn(cs ...clientv3.Cmp) clientv3.Txn { + txn := kv.NewSlowLogTxn(ls.client) + return txn.If(append(cs, ls.leaderCmp())...) +} + +func (ls *Leadership) leaderCmp() clientv3.Cmp { + return clientv3.Compare(clientv3.Value(ls.leaderKey), "=", ls.leaderValue) +} + +// DeleteLeaderKey deletes the corresponding leader from etcd by the leaderPath as the key. +func (ls *Leadership) DeleteLeaderKey() error { + resp, err := kv.NewSlowLogTxn(ls.client).Then(clientv3.OpDelete(ls.leaderKey)).Commit() + if err != nil { + return errs.ErrEtcdKVDelete.Wrap(err).GenWithStackByCause() + } + if !resp.Succeeded { + return errs.ErrEtcdTxnConflict.FastGenByArgs() + } + // Reset the lease as soon as possible. + ls.Reset() + log.Info("delete the leader key ok", zap.String("leaderPath", ls.leaderKey), zap.String("purpose", ls.purpose)) + return nil +} + +// Watch is used to watch the changes of the leadership, usually is used to +// detect the leadership stepping down and restart an election as soon as possible. +func (ls *Leadership) Watch(serverCtx context.Context, revision int64) { + if ls == nil { + return + } + watcher := clientv3.NewWatcher(ls.client) + defer watcher.Close() + ctx, cancel := context.WithCancel(serverCtx) + defer cancel() + // The revision is the revision of last modification on this key. + // If the revision is compacted, will meet required revision has been compacted error. + // In this case, use the compact revision to re-watch the key. + for { + failpoint.Inject("delayWatcher", nil) + rch := watcher.Watch(ctx, ls.leaderKey, clientv3.WithRev(revision)) + for wresp := range rch { + // meet compacted error, use the compact revision. 
+ if wresp.CompactRevision != 0 { + log.Warn("required revision has been compacted, use the compact revision", + zap.Int64("required-revision", revision), + zap.Int64("compact-revision", wresp.CompactRevision)) + revision = wresp.CompactRevision + break + } + if wresp.Canceled { + log.Error("leadership watcher is canceled with", + zap.Int64("revision", revision), + zap.String("leader-key", ls.leaderKey), + zap.String("purpose", ls.purpose), + errs.ZapError(errs.ErrEtcdWatcherCancel, wresp.Err())) + return + } + + for _, ev := range wresp.Events { + if ev.Type == mvccpb.DELETE { + log.Info("current leadership is deleted", + zap.String("leader-key", ls.leaderKey), + zap.String("purpose", ls.purpose)) + return + } + } + } + + select { + case <-ctx.Done(): + // server closed, return + return + default: + } + } +} + +// Reset does some defer jobs such as closing lease, resetting lease etc. +func (ls *Leadership) Reset() { + if ls == nil || ls.getLease() == nil { + return + } + if ls.keepAliveCancelFunc != nil { + ls.keepAliveCancelFunc() + } + ls.getLease().Close() +} diff --git a/server/grpc_service.go b/server/grpc_service.go old mode 100644 new mode 100755 index 7229b085d0b..f44423fffb1 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -73,9 +73,9 @@ type GrpcServer struct { type forwardFn func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) func (s *GrpcServer) unaryMiddleware(ctx context.Context, header *pdpb.RequestHeader, fn forwardFn) (rsp interface{}, err error) { - failpoint.Inject("customTimeout", func() { + if _, _err_ := failpoint.Eval(_curpkg_("customTimeout")); _err_ == nil { time.Sleep(5 * time.Second) - }) + } forwardedHost := getForwardedHost(ctx) if !s.isLocalRequest(forwardedHost) { client, err := s.getDelegateClient(ctx, forwardedHost) @@ -768,10 +768,10 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error { }() for { request, err := server.Recv() - failpoint.Inject("grpcClientClosed", func() { + if _, _err_ := failpoint.Eval(_curpkg_("grpcClientClosed")); _err_ == nil { err = status.Error(codes.Canceled, "grpc client closed") request = nil - }) + } if err == io.EOF { return nil } @@ -779,9 +779,9 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error { return errors.WithStack(err) } forwardedHost := getForwardedHost(stream.Context()) - failpoint.Inject("grpcClientClosed", func() { + if _, _err_ := failpoint.Eval(_curpkg_("grpcClientClosed")); _err_ == nil { forwardedHost = s.GetMember().Member().GetClientUrls()[0] - }) + } if !s.isLocalRequest(forwardedHost) { if forwardStream == nil || lastForwardedHost != forwardedHost { if cancel != nil { @@ -1602,13 +1602,13 @@ func (s *GrpcServer) SyncMaxTS(_ context.Context, request *pdpb.SyncMaxTSRequest syncedDCs = append(syncedDCs, allocator.GetDCLocation()) } - failpoint.Inject("mockLocalAllocatorLeaderChange", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockLocalAllocatorLeaderChange")); _err_ == nil { if !mockLocalAllocatorLeaderChangeFlag { maxLocalTS = nil request.MaxTs = nil mockLocalAllocatorLeaderChangeFlag = true } - }) + } if maxLocalTS == nil { return &pdpb.SyncMaxTSResponse{ @@ -1810,9 +1810,9 @@ func getForwardedHost(ctx context.Context) string { } func (s *GrpcServer) isLocalRequest(forwardedHost string) bool { - failpoint.Inject("useForwardRequest", func() { - failpoint.Return(false) - }) + if _, _err_ := failpoint.Eval(_curpkg_("useForwardRequest")); _err_ == nil { + return false + } if forwardedHost == "" { return true } diff 
--git a/server/grpc_service.go__failpoint_stash__ b/server/grpc_service.go__failpoint_stash__ new file mode 100644 index 00000000000..7229b085d0b --- /dev/null +++ b/server/grpc_service.go__failpoint_stash__ @@ -0,0 +1,2085 @@ +// Copyright 2017 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + "fmt" + "io" + "strconv" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/grpcutil" + "github.com/tikv/pd/pkg/logutil" + "github.com/tikv/pd/pkg/tsoutil" + "github.com/tikv/pd/server/cluster" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/storage/endpoint" + "github.com/tikv/pd/server/storage/kv" + "github.com/tikv/pd/server/tso" + "github.com/tikv/pd/server/versioninfo" + "go.etcd.io/etcd/clientv3" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +const ( + heartbeatSendTimeout = 5 * time.Second + + // tso + maxMergeTSORequests = 10000 + defaultTSOProxyTimeout = 3 * time.Second + + // global config + globalConfigPath = "/global/config/" +) + +// gRPC errors +var ( + // ErrNotLeader is returned when current server is not the leader and not possible to process request. + // TODO: work as proxy. + ErrNotLeader = status.Errorf(codes.Unavailable, "not leader") + ErrNotStarted = status.Errorf(codes.Unavailable, "server not started") + ErrSendHeartbeatTimeout = status.Errorf(codes.DeadlineExceeded, "send heartbeat timeout") +) + +// GrpcServer wraps Server to provide grpc service. +type GrpcServer struct { + *Server +} + +type forwardFn func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) + +func (s *GrpcServer) unaryMiddleware(ctx context.Context, header *pdpb.RequestHeader, fn forwardFn) (rsp interface{}, err error) { + failpoint.Inject("customTimeout", func() { + time.Sleep(5 * time.Second) + }) + forwardedHost := getForwardedHost(ctx) + if !s.isLocalRequest(forwardedHost) { + client, err := s.getDelegateClient(ctx, forwardedHost) + if err != nil { + return nil, err + } + ctx = grpcutil.ResetForwardContext(ctx) + return fn(ctx, client) + } + if err := s.validateRequest(header); err != nil { + return nil, err + } + return nil, nil +} + +func (s *GrpcServer) wrapErrorToHeader(errorType pdpb.ErrorType, message string) *pdpb.ResponseHeader { + return s.errorHeader(&pdpb.Error{ + Type: errorType, + Message: message, + }) +} + +// GetMembers implements gRPC PDServer. +func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb.GetMembersResponse, error) { + // Here we purposely do not check the cluster ID because the client does not know the correct cluster ID + // at startup and needs to get the cluster ID with the first request (i.e. GetMembers). 
+ if s.IsClosed() { + return &pdpb.GetMembersResponse{ + Header: &pdpb.ResponseHeader{ + Error: &pdpb.Error{ + Type: pdpb.ErrorType_UNKNOWN, + Message: errs.ErrServerNotStarted.FastGenByArgs().Error(), + }, + }, + }, nil + } + members, err := cluster.GetMembers(s.GetClient()) + if err != nil { + return &pdpb.GetMembersResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + }, nil + } + + var etcdLeader, pdLeader *pdpb.Member + leaderID := s.member.GetEtcdLeader() + for _, m := range members { + if m.MemberId == leaderID { + etcdLeader = m + break + } + } + + tsoAllocatorManager := s.GetTSOAllocatorManager() + tsoAllocatorLeaders, err := tsoAllocatorManager.GetLocalAllocatorLeaders() + if err != nil { + return &pdpb.GetMembersResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + }, nil + } + + leader := s.member.GetLeader() + for _, m := range members { + if m.MemberId == leader.GetMemberId() { + pdLeader = m + break + } + } + + return &pdpb.GetMembersResponse{ + Header: s.header(), + Members: members, + Leader: pdLeader, + EtcdLeader: etcdLeader, + TsoAllocatorLeaders: tsoAllocatorLeaders, + }, nil +} + +// Tso implements gRPC PDServer. +func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { + var ( + doneCh chan struct{} + errCh chan error + ) + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() + for { + // Prevent unnecessary performance overhead of the channel. + if errCh != nil { + select { + case err := <-errCh: + return errors.WithStack(err) + default: + } + } + request, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return errors.WithStack(err) + } + + streamCtx := stream.Context() + forwardedHost := getForwardedHost(streamCtx) + if !s.isLocalRequest(forwardedHost) { + if errCh == nil { + doneCh = make(chan struct{}) + defer close(doneCh) + errCh = make(chan error) + } + s.dispatchTSORequest(ctx, &tsoRequest{ + forwardedHost, + request, + stream, + }, forwardedHost, doneCh, errCh) + continue + } + + start := time.Now() + // TSO uses leader lease to determine validity. No need to check leader here. 
+ if s.IsClosed() { + return status.Errorf(codes.Unknown, "server not started") + } + if request.GetHeader().GetClusterId() != s.clusterID { + return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", s.clusterID, request.GetHeader().GetClusterId()) + } + count := request.GetCount() + ts, err := s.tsoAllocatorManager.HandleTSORequest(request.GetDcLocation(), count) + if err != nil { + return status.Errorf(codes.Unknown, err.Error()) + } + tsoHandleDuration.Observe(time.Since(start).Seconds()) + response := &pdpb.TsoResponse{ + Header: s.header(), + Timestamp: &ts, + Count: count, + } + if err := stream.Send(response); err != nil { + return errors.WithStack(err) + } + } +} + +type tsoRequest struct { + forwardedHost string + request *pdpb.TsoRequest + stream pdpb.PD_TsoServer +} + +func (s *GrpcServer) dispatchTSORequest(ctx context.Context, request *tsoRequest, forwardedHost string, doneCh <-chan struct{}, errCh chan<- error) { + tsoRequestChInterface, loaded := s.tsoDispatcher.LoadOrStore(forwardedHost, make(chan *tsoRequest, maxMergeTSORequests)) + if !loaded { + tsDeadlineCh := make(chan deadline, 1) + go s.handleDispatcher(ctx, forwardedHost, tsoRequestChInterface.(chan *tsoRequest), tsDeadlineCh, doneCh, errCh) + go watchTSDeadline(ctx, tsDeadlineCh) + } + tsoRequestChInterface.(chan *tsoRequest) <- request +} + +func (s *GrpcServer) handleDispatcher(ctx context.Context, forwardedHost string, tsoRequestCh <-chan *tsoRequest, tsDeadlineCh chan<- deadline, doneCh <-chan struct{}, errCh chan<- error) { + dispatcherCtx, ctxCancel := context.WithCancel(ctx) + defer ctxCancel() + defer s.tsoDispatcher.Delete(forwardedHost) + + var ( + forwardStream pdpb.PD_TsoClient + cancel context.CancelFunc + ) + client, err := s.getDelegateClient(ctx, forwardedHost) + if err != nil { + goto errHandling + } + log.Info("create tso forward stream", zap.String("forwarded-host", forwardedHost)) + forwardStream, cancel, err = s.createTsoForwardStream(client) +errHandling: + if err != nil || forwardStream == nil { + log.Error("create tso forwarding stream error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCCreateStream, err)) + select { + case <-dispatcherCtx.Done(): + return + case _, ok := <-doneCh: + if !ok { + return + } + case errCh <- err: + close(errCh) + return + } + } + defer cancel() + + requests := make([]*tsoRequest, maxMergeTSORequests+1) + for { + select { + case first := <-tsoRequestCh: + pendingTSOReqCount := len(tsoRequestCh) + 1 + requests[0] = first + for i := 1; i < pendingTSOReqCount; i++ { + requests[i] = <-tsoRequestCh + } + done := make(chan struct{}) + dl := deadline{ + timer: time.After(defaultTSOProxyTimeout), + done: done, + cancel: cancel, + } + select { + case tsDeadlineCh <- dl: + case <-dispatcherCtx.Done(): + return + } + err = s.processTSORequests(forwardStream, requests[:pendingTSOReqCount]) + close(done) + if err != nil { + log.Error("proxy forward tso error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCSend, err)) + select { + case <-dispatcherCtx.Done(): + return + case _, ok := <-doneCh: + if !ok { + return + } + case errCh <- err: + close(errCh) + return + } + } + case <-dispatcherCtx.Done(): + return + } + } +} + +func (s *GrpcServer) processTSORequests(forwardStream pdpb.PD_TsoClient, requests []*tsoRequest) error { + start := time.Now() + // Merge the requests + count := uint32(0) + for _, request := range requests { + count += request.request.GetCount() + } + req := 
&pdpb.TsoRequest{ + Header: requests[0].request.GetHeader(), + Count: count, + // TODO: support Local TSO proxy forwarding. + DcLocation: requests[0].request.GetDcLocation(), + } + // Send to the leader stream. + if err := forwardStream.Send(req); err != nil { + return err + } + resp, err := forwardStream.Recv() + if err != nil { + return err + } + tsoProxyHandleDuration.Observe(time.Since(start).Seconds()) + tsoProxyBatchSize.Observe(float64(count)) + // Split the response + physical, logical, suffixBits := resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits() + // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. + // This is different from the logic of client batch, for example, if we have a largest ts whose logical part is 10, + // count is 5, then the splitting results should be 5 and 10. + firstLogical := addLogical(logical, -int64(count), suffixBits) + return s.finishTSORequest(requests, physical, firstLogical, suffixBits) +} + +// Because of the suffix, we need to shift the count before we add it to the logical part. +func addLogical(logical, count int64, suffixBits uint32) int64 { + return logical + count<<suffixBits +} + + if time.Since(lastBind) > s.cfg.HeartbeatStreamBindInterval.Duration { + regionHeartbeatCounter.WithLabelValues(storeAddress, storeLabel, "report", "bind").Inc() + s.hbStreams.BindStream(storeID, server) + // refresh FlowRoundByDigit + flowRoundOption = core.WithFlowRoundByDigit(s.persistOptions.GetPDServerConfig().FlowRoundByDigit) + lastBind = time.Now() + } + + region := core.RegionFromHeartbeat(request, flowRoundOption, core.SetFromHeartbeat(true)) + if region.GetLeader() == nil { + log.Error("invalid request, the leader is nil", zap.Reflect("request", request), errs.ZapError(errs.ErrLeaderNil)) + regionHeartbeatCounter.WithLabelValues(storeAddress, storeLabel, "report", "invalid-leader").Inc() + msg := fmt.Sprintf("invalid request leader, %v", request) + s.hbStreams.SendErr(pdpb.ErrorType_UNKNOWN, msg, request.GetLeader()) + continue + } + if region.GetID() == 0 { + regionHeartbeatCounter.WithLabelValues(storeAddress, storeLabel, "report", "invalid-region").Inc() + msg := fmt.Sprintf("invalid request region, %v", request) + s.hbStreams.SendErr(pdpb.ErrorType_UNKNOWN, msg, request.GetLeader()) + continue + } + + // If the region peer count is 0, then we should not handle this. + if len(region.GetPeers()) == 0 { + log.Warn("invalid region, zero region peer count", + logutil.ZapRedactStringer("region-meta", core.RegionToHexMeta(region.GetMeta()))) + regionHeartbeatCounter.WithLabelValues(storeAddress, storeLabel, "report", "no-peer").Inc() + msg := fmt.Sprintf("invalid region, zero region peer count: %v", logutil.RedactStringer(core.RegionToHexMeta(region.GetMeta()))) + s.hbStreams.SendErr(pdpb.ErrorType_UNKNOWN, msg, request.GetLeader()) + continue + } + start := time.Now() + + err = rc.HandleRegionHeartbeat(region) + if err != nil { + regionHeartbeatCounter.WithLabelValues(storeAddress, storeLabel, "report", "err").Inc() + msg := err.Error() + s.hbStreams.SendErr(pdpb.ErrorType_UNKNOWN, msg, request.GetLeader()) + continue + } + regionHeartbeatHandleDuration.WithLabelValues(storeAddress, storeLabel).Observe(time.Since(start).Seconds()) + regionHeartbeatCounter.WithLabelValues(storeAddress, storeLabel, "report", "ok").Inc() + } +} + +// GetRegion implements gRPC PDServer.
+func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).GetRegion(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.GetRegionResponse), nil + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.GetRegionResponse{Header: s.notBootstrappedHeader()}, nil + } + region := rc.GetRegionByKey(request.GetRegionKey()) + if region == nil { + return &pdpb.GetRegionResponse{Header: s.header()}, nil + } + var buckets *metapb.Buckets + if rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { + buckets = region.GetBuckets() + } + return &pdpb.GetRegionResponse{ + Header: s.header(), + Region: region.GetMeta(), + Leader: region.GetLeader(), + DownPeers: region.GetDownPeers(), + PendingPeers: region.GetPendingPeers(), + Buckets: buckets, + }, nil +} + +// GetPrevRegion implements gRPC PDServer +func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).GetPrevRegion(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.GetRegionResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.GetRegionResponse{Header: s.notBootstrappedHeader()}, nil + } + + region := rc.GetPrevRegionByKey(request.GetRegionKey()) + if region == nil { + return &pdpb.GetRegionResponse{Header: s.header()}, nil + } + var buckets *metapb.Buckets + if rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { + buckets = region.GetBuckets() + } + return &pdpb.GetRegionResponse{ + Header: s.header(), + Region: region.GetMeta(), + Leader: region.GetLeader(), + DownPeers: region.GetDownPeers(), + PendingPeers: region.GetPendingPeers(), + Buckets: buckets, + }, nil +} + +// GetRegionByID implements gRPC PDServer. +func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionByIDRequest) (*pdpb.GetRegionResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).GetRegionByID(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.GetRegionResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.GetRegionResponse{Header: s.notBootstrappedHeader()}, nil + } + region := rc.GetRegion(request.GetRegionId()) + if region == nil { + return &pdpb.GetRegionResponse{Header: s.header()}, nil + } + var buckets *metapb.Buckets + if rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { + buckets = region.GetBuckets() + } + return &pdpb.GetRegionResponse{ + Header: s.header(), + Region: region.GetMeta(), + Leader: region.GetLeader(), + DownPeers: region.GetDownPeers(), + PendingPeers: region.GetPendingPeers(), + Buckets: buckets, + }, nil +} + +// ScanRegions implements gRPC PDServer. 
+func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsRequest) (*pdpb.ScanRegionsResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).ScanRegions(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.ScanRegionsResponse), nil + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.ScanRegionsResponse{Header: s.notBootstrappedHeader()}, nil + } + regions := rc.ScanRegions(request.GetStartKey(), request.GetEndKey(), int(request.GetLimit())) + resp := &pdpb.ScanRegionsResponse{Header: s.header()} + for _, r := range regions { + leader := r.GetLeader() + if leader == nil { + leader = &metapb.Peer{} + } + // Set RegionMetas and Leaders to make it compatible with old client. + resp.RegionMetas = append(resp.RegionMetas, r.GetMeta()) + resp.Leaders = append(resp.Leaders, leader) + resp.Regions = append(resp.Regions, &pdpb.Region{ + Region: r.GetMeta(), + Leader: leader, + DownPeers: r.GetDownPeers(), + PendingPeers: r.GetPendingPeers(), + }) + } + return resp, nil +} + +// AskSplit implements gRPC PDServer. +func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest) (*pdpb.AskSplitResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).AskSplit(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.AskSplitResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.AskSplitResponse{Header: s.notBootstrappedHeader()}, nil + } + if request.GetRegion() == nil { + return &pdpb.AskSplitResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_REGION_NOT_FOUND, + "missing region for split"), + }, nil + } + req := &pdpb.AskSplitRequest{ + Region: request.Region, + } + split, err := rc.HandleAskSplit(req) + if err != nil { + return &pdpb.AskSplitResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + }, nil + } + + return &pdpb.AskSplitResponse{ + Header: s.header(), + NewRegionId: split.NewRegionId, + NewPeerIds: split.NewPeerIds, + }, nil +} + +// AskBatchSplit implements gRPC PDServer. 
+func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSplitRequest) (*pdpb.AskBatchSplitResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).AskBatchSplit(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.AskBatchSplitResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.AskBatchSplitResponse{Header: s.notBootstrappedHeader()}, nil + } + + if !versioninfo.IsFeatureSupported(rc.GetOpts().GetClusterVersion(), versioninfo.BatchSplit) { + return &pdpb.AskBatchSplitResponse{Header: s.incompatibleVersion("batch_split")}, nil + } + if request.GetRegion() == nil { + return &pdpb.AskBatchSplitResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_REGION_NOT_FOUND, + "missing region for split"), + }, nil + } + req := &pdpb.AskBatchSplitRequest{ + Region: request.Region, + SplitCount: request.SplitCount, + } + split, err := rc.HandleAskBatchSplit(req) + if err != nil { + return &pdpb.AskBatchSplitResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + }, nil + } + + return &pdpb.AskBatchSplitResponse{ + Header: s.header(), + Ids: split.Ids, + }, nil +} + +// ReportSplit implements gRPC PDServer. +func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitRequest) (*pdpb.ReportSplitResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).ReportSplit(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.ReportSplitResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.ReportSplitResponse{Header: s.notBootstrappedHeader()}, nil + } + _, err := rc.HandleReportSplit(request) + if err != nil { + return &pdpb.ReportSplitResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + }, nil + } + + return &pdpb.ReportSplitResponse{ + Header: s.header(), + }, nil +} + +// ReportBatchSplit implements gRPC PDServer. +func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportBatchSplitRequest) (*pdpb.ReportBatchSplitResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).ReportBatchSplit(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.ReportBatchSplitResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.ReportBatchSplitResponse{Header: s.notBootstrappedHeader()}, nil + } + + _, err := rc.HandleBatchReportSplit(request) + if err != nil { + return &pdpb.ReportBatchSplitResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, + err.Error()), + }, nil + } + + return &pdpb.ReportBatchSplitResponse{ + Header: s.header(), + }, nil +} + +// GetClusterConfig implements gRPC PDServer. 
+func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClusterConfigRequest) (*pdpb.GetClusterConfigResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).GetClusterConfig(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.GetClusterConfigResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.GetClusterConfigResponse{Header: s.notBootstrappedHeader()}, nil + } + return &pdpb.GetClusterConfigResponse{ + Header: s.header(), + Cluster: rc.GetMetaCluster(), + }, nil +} + +// PutClusterConfig implements gRPC PDServer. +func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClusterConfigRequest) (*pdpb.PutClusterConfigResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).PutClusterConfig(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.PutClusterConfigResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.PutClusterConfigResponse{Header: s.notBootstrappedHeader()}, nil + } + conf := request.GetCluster() + if err := rc.PutMetaCluster(conf); err != nil { + return &pdpb.PutClusterConfigResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, + err.Error()), + }, nil + } + + log.Info("put cluster config ok", zap.Reflect("config", conf)) + + return &pdpb.PutClusterConfigResponse{ + Header: s.header(), + }, nil +} + +// ScatterRegion implements gRPC PDServer. +func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterRegionRequest) (*pdpb.ScatterRegionResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).ScatterRegion(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.ScatterRegionResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.ScatterRegionResponse{Header: s.notBootstrappedHeader()}, nil + } + + if len(request.GetRegionsId()) > 0 { + percentage, err := scatterRegions(rc, request.GetRegionsId(), request.GetGroup(), int(request.GetRetryLimit())) + if err != nil { + return nil, err + } + return &pdpb.ScatterRegionResponse{ + Header: s.header(), + FinishedPercentage: uint64(percentage), + }, nil + } + // TODO: Deprecate it use `request.GetRegionsID`. + //nolint + region := rc.GetRegion(request.GetRegionId()) + if region == nil { + if request.GetRegion() == nil { + //nolint + return &pdpb.ScatterRegionResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_REGION_NOT_FOUND, + "region %d not found"), + }, nil + } + region = core.NewRegionInfo(request.GetRegion(), request.GetLeader()) + } + + op, err := rc.GetRegionScatter().Scatter(region, request.GetGroup()) + if err != nil { + return nil, err + } + if op != nil { + rc.GetOperatorController().AddOperator(op) + } + + return &pdpb.ScatterRegionResponse{ + Header: s.header(), + FinishedPercentage: 100, + }, nil +} + +// GetGCSafePoint implements gRPC PDServer. 
+func (s *GrpcServer) GetGCSafePoint(ctx context.Context, request *pdpb.GetGCSafePointRequest) (*pdpb.GetGCSafePointResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).GetGCSafePoint(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.GetGCSafePointResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.GetGCSafePointResponse{Header: s.notBootstrappedHeader()}, nil + } + + safePoint, err := s.gcSafePointManager.LoadGCSafePoint() + if err != nil { + return nil, err + } + + return &pdpb.GetGCSafePointResponse{ + Header: s.header(), + SafePoint: safePoint, + }, nil +} + +// SyncRegions syncs the regions. +func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { + if s.IsClosed() || s.cluster == nil { + return ErrNotStarted + } + ctx := s.cluster.Context() + if ctx == nil { + return ErrNotStarted + } + return s.cluster.GetRegionSyncer().Sync(ctx, stream) +} + +// UpdateGCSafePoint implements gRPC PDServer. +func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.UpdateGCSafePointRequest) (*pdpb.UpdateGCSafePointResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).UpdateGCSafePoint(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.UpdateGCSafePointResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.UpdateGCSafePointResponse{Header: s.notBootstrappedHeader()}, nil + } + + newSafePoint := request.GetSafePoint() + oldSafePoint, err := s.gcSafePointManager.UpdateGCSafePoint(newSafePoint) + if err != nil { + return nil, err + } + + if newSafePoint > oldSafePoint { + log.Info("updated gc safe point", + zap.Uint64("safe-point", newSafePoint)) + } else if newSafePoint < oldSafePoint { + log.Warn("trying to update gc safe point", + zap.Uint64("old-safe-point", oldSafePoint), + zap.Uint64("new-safe-point", newSafePoint)) + newSafePoint = oldSafePoint + } + + return &pdpb.UpdateGCSafePointResponse{ + Header: s.header(), + NewSafePoint: newSafePoint, + }, nil +} + +// UpdateServiceGCSafePoint update the safepoint for specific service +func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb.UpdateServiceGCSafePointRequest) (*pdpb.UpdateServiceGCSafePointResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).UpdateServiceGCSafePoint(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.UpdateServiceGCSafePointResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.UpdateServiceGCSafePointResponse{Header: s.notBootstrappedHeader()}, nil + } + var storage endpoint.GCSafePointStorage = s.storage + if request.TTL <= 0 { + if err := storage.RemoveServiceGCSafePoint(string(request.ServiceId)); err != nil { + return nil, err + } + } + + nowTSO, err := s.tsoAllocatorManager.HandleTSORequest(tso.GlobalDCLocation, 1) + if err != nil { + return nil, err + } + now, _ := tsoutil.ParseTimestamp(nowTSO) + serviceID := string(request.ServiceId) + min, updated, err := s.gcSafePointManager.UpdateServiceGCSafePoint(serviceID, 
request.GetSafePoint(), request.GetTTL(), now) + if err != nil { + return nil, err + } + if updated { + log.Info("update service GC safe point", + zap.String("service-id", serviceID), + zap.Int64("expire-at", now.Unix()+request.GetTTL()), + zap.Uint64("safepoint", request.GetSafePoint())) + } + return &pdpb.UpdateServiceGCSafePointResponse{ + Header: s.header(), + ServiceId: []byte(min.ServiceID), + TTL: min.ExpiredAt - now.Unix(), + MinSafePoint: min.SafePoint, + }, nil +} + +// GetOperator gets information about the operator belonging to the specify region. +func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorRequest) (*pdpb.GetOperatorResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).GetOperator(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.GetOperatorResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.GetOperatorResponse{Header: s.notBootstrappedHeader()}, nil + } + + opController := rc.GetOperatorController() + requestID := request.GetRegionId() + r := opController.GetOperatorStatus(requestID) + if r == nil { + header := s.errorHeader(&pdpb.Error{ + Type: pdpb.ErrorType_REGION_NOT_FOUND, + Message: "Not Found", + }) + return &pdpb.GetOperatorResponse{Header: header}, nil + } + + return &pdpb.GetOperatorResponse{ + Header: s.header(), + RegionId: requestID, + Desc: []byte(r.Desc()), + Kind: []byte(r.Kind().String()), + Status: r.Status, + }, nil +} + +// validateRequest checks if Server is leader and clusterID is matched. +// TODO: Call it in gRPC interceptor. +func (s *GrpcServer) validateRequest(header *pdpb.RequestHeader) error { + if s.IsClosed() || !s.member.IsLeader() { + return ErrNotLeader + } + if header.GetClusterId() != s.clusterID { + return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", s.clusterID, header.GetClusterId()) + } + return nil +} + +func (s *GrpcServer) header() *pdpb.ResponseHeader { + if s.clusterID == 0 { + return s.wrapErrorToHeader(pdpb.ErrorType_NOT_BOOTSTRAPPED, "cluster id is not ready") + } + return &pdpb.ResponseHeader{ClusterId: s.clusterID} +} + +func (s *GrpcServer) errorHeader(err *pdpb.Error) *pdpb.ResponseHeader { + return &pdpb.ResponseHeader{ + ClusterId: s.clusterID, + Error: err, + } +} + +func (s *GrpcServer) notBootstrappedHeader() *pdpb.ResponseHeader { + return s.errorHeader(&pdpb.Error{ + Type: pdpb.ErrorType_NOT_BOOTSTRAPPED, + Message: "cluster is not bootstrapped", + }) +} + +func (s *GrpcServer) incompatibleVersion(tag string) *pdpb.ResponseHeader { + msg := fmt.Sprintf("%s incompatible with current cluster version %s", tag, s.persistOptions.GetClusterVersion()) + return s.errorHeader(&pdpb.Error{ + Type: pdpb.ErrorType_INCOMPATIBLE_VERSION, + Message: msg, + }) +} + +func (s *GrpcServer) invalidValue(msg string) *pdpb.ResponseHeader { + return s.errorHeader(&pdpb.Error{ + Type: pdpb.ErrorType_INVALID_VALUE, + Message: msg, + }) +} + +// Only used for the TestLocalAllocatorLeaderChange. +var mockLocalAllocatorLeaderChangeFlag = false + +// SyncMaxTS will check whether MaxTS is the biggest one among all Local TSOs this PD is holding when skipCheck is set, +// and write it into all Local TSO Allocators then if it's indeed the biggest one. 
+func (s *GrpcServer) SyncMaxTS(_ context.Context, request *pdpb.SyncMaxTSRequest) (*pdpb.SyncMaxTSResponse, error) { + if err := s.validateInternalRequest(request.GetHeader(), true); err != nil { + return nil, err + } + tsoAllocatorManager := s.GetTSOAllocatorManager() + // There is no dc-location found in this server, return err. + if tsoAllocatorManager.GetClusterDCLocationsNumber() == 0 { + return &pdpb.SyncMaxTSResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, + "empty cluster dc-location found, checker may not work properly"), + }, nil + } + // Get all Local TSO Allocator leaders + allocatorLeaders, err := tsoAllocatorManager.GetHoldingLocalAllocatorLeaders() + if err != nil { + return &pdpb.SyncMaxTSResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + }, nil + } + if !request.GetSkipCheck() { + var maxLocalTS *pdpb.Timestamp + syncedDCs := make([]string, 0, len(allocatorLeaders)) + for _, allocator := range allocatorLeaders { + // No longer leader, just skip here because + // the global allocator will check if all DCs are handled. + if !allocator.IsAllocatorLeader() { + continue + } + currentLocalTSO, err := allocator.GetCurrentTSO() + if err != nil { + return &pdpb.SyncMaxTSResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + }, nil + } + if tsoutil.CompareTimestamp(currentLocalTSO, maxLocalTS) > 0 { + maxLocalTS = currentLocalTSO + } + syncedDCs = append(syncedDCs, allocator.GetDCLocation()) + } + + failpoint.Inject("mockLocalAllocatorLeaderChange", func() { + if !mockLocalAllocatorLeaderChangeFlag { + maxLocalTS = nil + request.MaxTs = nil + mockLocalAllocatorLeaderChangeFlag = true + } + }) + + if maxLocalTS == nil { + return &pdpb.SyncMaxTSResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, + "local tso allocator leaders have changed during the sync, should retry"), + }, nil + } + if request.GetMaxTs() == nil { + return &pdpb.SyncMaxTSResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, + "empty maxTS in the request, should retry"), + }, nil + } + // Found a bigger or equal maxLocalTS, return it directly. + cmpResult := tsoutil.CompareTimestamp(maxLocalTS, request.GetMaxTs()) + if cmpResult >= 0 { + // Found an equal maxLocalTS, plus 1 to logical part before returning it. + // For example, we have a Global TSO t1 and a Local TSO t2, they have the + // same physical and logical parts. After being differentiating with suffix, + // there will be (t1.logical << suffixNum + 0) < (t2.logical << suffixNum + N), + // where N is bigger than 0, which will cause a Global TSO fallback than the previous Local TSO. 
+ if cmpResult == 0 { + maxLocalTS.Logical += 1 + } + return &pdpb.SyncMaxTSResponse{ + Header: s.header(), + MaxLocalTs: maxLocalTS, + SyncedDcs: syncedDCs, + }, nil + } + } + syncedDCs := make([]string, 0, len(allocatorLeaders)) + for _, allocator := range allocatorLeaders { + if !allocator.IsAllocatorLeader() { + continue + } + if err := allocator.WriteTSO(request.GetMaxTs()); err != nil { + return &pdpb.SyncMaxTSResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + }, nil + } + syncedDCs = append(syncedDCs, allocator.GetDCLocation()) + } + return &pdpb.SyncMaxTSResponse{ + Header: s.header(), + SyncedDcs: syncedDCs, + }, nil +} + +// SplitRegions split regions by the given split keys +func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegionsRequest) (*pdpb.SplitRegionsResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).SplitRegions(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.SplitRegionsResponse), err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.SplitRegionsResponse{Header: s.notBootstrappedHeader()}, nil + } + finishedPercentage, newRegionIDs := rc.GetRegionSplitter().SplitRegions(ctx, request.GetSplitKeys(), int(request.GetRetryLimit())) + return &pdpb.SplitRegionsResponse{ + Header: s.header(), + RegionsId: newRegionIDs, + FinishedPercentage: uint64(finishedPercentage), + }, nil +} + +// SplitAndScatterRegions split regions by the given split keys, and scatter regions. +// Only regions which splited successfully will be scattered. +// scatterFinishedPercentage indicates the percentage of successfully splited regions that are scattered. 
+func (s *GrpcServer) SplitAndScatterRegions(ctx context.Context, request *pdpb.SplitAndScatterRegionsRequest) (*pdpb.SplitAndScatterRegionsResponse, error) { + fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { + return pdpb.NewPDClient(client).SplitAndScatterRegions(ctx, request) + } + if rsp, err := s.unaryMiddleware(ctx, request.GetHeader(), fn); err != nil { + return nil, err + } else if rsp != nil { + return rsp.(*pdpb.SplitAndScatterRegionsResponse), err + } + rc := s.GetRaftCluster() + splitFinishedPercentage, newRegionIDs := rc.GetRegionSplitter().SplitRegions(ctx, request.GetSplitKeys(), int(request.GetRetryLimit())) + scatterFinishedPercentage, err := scatterRegions(rc, newRegionIDs, request.GetGroup(), int(request.GetRetryLimit())) + if err != nil { + return nil, err + } + return &pdpb.SplitAndScatterRegionsResponse{ + Header: s.header(), + RegionsId: newRegionIDs, + SplitFinishedPercentage: uint64(splitFinishedPercentage), + ScatterFinishedPercentage: uint64(scatterFinishedPercentage), + }, nil +} + +// scatterRegions add operators to scatter regions and return the processed percentage and error +func scatterRegions(cluster *cluster.RaftCluster, regionsID []uint64, group string, retryLimit int) (int, error) { + opsCount, failures, err := cluster.GetRegionScatter().ScatterRegionsByID(regionsID, group, retryLimit) + if err != nil { + return 0, err + } + percentage := 100 + if len(failures) > 0 { + percentage = 100 - 100*len(failures)/(opsCount+len(failures)) + log.Debug("scatter regions", zap.Errors("failures", func() []error { + r := make([]error, 0, len(failures)) + for _, err := range failures { + r = append(r, err) + } + return r + }())) + } + return percentage, nil +} + +// GetDCLocationInfo gets the dc-location info of the given dc-location from PD leader's TSO allocator manager. +func (s *GrpcServer) GetDCLocationInfo(ctx context.Context, request *pdpb.GetDCLocationInfoRequest) (*pdpb.GetDCLocationInfoResponse, error) { + var err error + if err = s.validateInternalRequest(request.GetHeader(), false); err != nil { + return nil, err + } + if !s.member.IsLeader() { + return nil, ErrNotLeader + } + am := s.tsoAllocatorManager + info, ok := am.GetDCLocationInfo(request.GetDcLocation()) + if !ok { + am.ClusterDCLocationChecker() + return &pdpb.GetDCLocationInfoResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, + fmt.Sprintf("dc-location %s is not found", request.GetDcLocation())), + }, nil + } + resp := &pdpb.GetDCLocationInfoResponse{ + Header: s.header(), + Suffix: info.Suffix, + } + // Because the number of suffix bits is changing dynamically according to the dc-location number, + // there is a corner case may cause the Local TSO is not unique while member changing. + // Example: + // t1: xxxxxxxxxxxxxxx1 | 11 + // t2: xxxxxxxxxxxxxxx | 111 + // So we will force the newly added Local TSO Allocator to have a Global TSO synchronization + // when it becomes the Local TSO Allocator leader. + // Please take a look at https://github.com/tikv/pd/issues/3260 for more details. + if resp.MaxTs, err = am.GetMaxLocalTSO(ctx); err != nil { + return &pdpb.GetDCLocationInfoResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + }, nil + } + return resp, nil +} + +// validateInternalRequest checks if server is closed, which is used to validate +// the gRPC communication between PD servers internally. 
+func (s *GrpcServer) validateInternalRequest(header *pdpb.RequestHeader, onlyAllowLeader bool) error { + if s.IsClosed() { + return ErrNotStarted + } + // If onlyAllowLeader is true, check whether the sender is PD leader. + if onlyAllowLeader { + leaderID := s.GetLeader().GetMemberId() + if leaderID != header.GetSenderId() { + return status.Errorf(codes.FailedPrecondition, "%s, need %d but got %d", errs.MismatchLeaderErr, leaderID, header.GetSenderId()) + } + } + return nil +} + +func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string) (*grpc.ClientConn, error) { + client, ok := s.clientConns.Load(forwardedHost) + if !ok { + tlsConfig, err := s.GetTLSConfig().ToTLSConfig() + if err != nil { + return nil, err + } + cc, err := grpcutil.GetClientConn(ctx, forwardedHost, tlsConfig) + if err != nil { + return nil, err + } + client = cc + s.clientConns.Store(forwardedHost, cc) + } + return client.(*grpc.ClientConn), nil +} + +func getForwardedHost(ctx context.Context) string { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + log.Debug("failed to get forwarding metadata") + } + if t, ok := md[grpcutil.ForwardMetadataKey]; ok { + return t[0] + } + return "" +} + +func (s *GrpcServer) isLocalRequest(forwardedHost string) bool { + failpoint.Inject("useForwardRequest", func() { + failpoint.Return(false) + }) + if forwardedHost == "" { + return true + } + memberAddrs := s.GetMember().Member().GetClientUrls() + for _, addr := range memberAddrs { + if addr == forwardedHost { + return true + } + } + return false +} + +func (s *GrpcServer) createTsoForwardStream(client *grpc.ClientConn) (pdpb.PD_TsoClient, context.CancelFunc, error) { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(s.ctx) + go checkStream(ctx, cancel, done) + forwardStream, err := pdpb.NewPDClient(client).Tso(ctx) + done <- struct{}{} + return forwardStream, cancel, err +} + +func (s *GrpcServer) createHeartbeatForwardStream(client *grpc.ClientConn) (pdpb.PD_RegionHeartbeatClient, context.CancelFunc, error) { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(s.ctx) + go checkStream(ctx, cancel, done) + forwardStream, err := pdpb.NewPDClient(client).RegionHeartbeat(ctx) + done <- struct{}{} + return forwardStream, cancel, err +} + +func forwardRegionHeartbeatClientToServer(forwardStream pdpb.PD_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { + defer close(errCh) + for { + resp, err := forwardStream.Recv() + if err != nil { + errCh <- errors.WithStack(err) + return + } + if err := server.Send(resp); err != nil { + errCh <- errors.WithStack(err) + return + } + } +} + +func (s *GrpcServer) createReportBucketsForwardStream(client *grpc.ClientConn) (pdpb.PD_ReportBucketsClient, context.CancelFunc, error) { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(s.ctx) + go checkStream(ctx, cancel, done) + forwardStream, err := pdpb.NewPDClient(client).ReportBuckets(ctx) + done <- struct{}{} + return forwardStream, cancel, err +} + +func forwardReportBucketClientToServer(forwardStream pdpb.PD_ReportBucketsClient, server *bucketHeartbeatServer, errCh chan error) { + defer close(errCh) + for { + resp, err := forwardStream.CloseAndRecv() + if err != nil { + errCh <- errors.WithStack(err) + return + } + if err := server.Send(resp); err != nil { + errCh <- errors.WithStack(err) + return + } + } +} + +// TODO: If goroutine here timeout when tso stream created successfully, we need to handle it correctly. 
+func checkStream(streamCtx context.Context, cancel context.CancelFunc, done chan struct{}) { + select { + case <-done: + return + case <-time.After(3 * time.Second): + cancel() + case <-streamCtx.Done(): + } + <-done +} + +// StoreGlobalConfig store global config into etcd by transaction +func (s *GrpcServer) StoreGlobalConfig(_ context.Context, request *pdpb.StoreGlobalConfigRequest) (*pdpb.StoreGlobalConfigResponse, error) { + ops := make([]clientv3.Op, len(request.Changes)) + for i, item := range request.Changes { + name := globalConfigPath + item.GetName() + value := item.GetValue() + ops[i] = clientv3.OpPut(name, value) + } + res, err := + kv.NewSlowLogTxn(s.client).Then(ops...).Commit() + if err != nil { + return &pdpb.StoreGlobalConfigResponse{Error: &pdpb.Error{Type: pdpb.ErrorType_UNKNOWN, Message: err.Error()}}, err + } + if !res.Succeeded { + return &pdpb.StoreGlobalConfigResponse{Error: &pdpb.Error{Type: pdpb.ErrorType_UNKNOWN, Message: "failed to execute StoreGlobalConfig transaction"}}, errors.Errorf("failed to execute StoreGlobalConfig transaction") + } + return &pdpb.StoreGlobalConfigResponse{}, err +} + +// LoadGlobalConfig load global config from etcd +func (s *GrpcServer) LoadGlobalConfig(ctx context.Context, request *pdpb.LoadGlobalConfigRequest) (*pdpb.LoadGlobalConfigResponse, error) { + names := request.Names + res := make([]*pdpb.GlobalConfigItem, len(names)) + for i, name := range names { + r, err := s.client.Get(ctx, globalConfigPath+name) + if err != nil { + res[i] = &pdpb.GlobalConfigItem{Name: name, Error: &pdpb.Error{Type: pdpb.ErrorType_UNKNOWN, Message: err.Error()}} + } else if len(r.Kvs) == 0 { + msg := "key " + name + " not found" + res[i] = &pdpb.GlobalConfigItem{Name: name, Error: &pdpb.Error{Type: pdpb.ErrorType_GLOBAL_CONFIG_NOT_FOUND, Message: msg}} + } else { + res[i] = &pdpb.GlobalConfigItem{Name: name, Value: string(r.Kvs[0].Value)} + } + } + return &pdpb.LoadGlobalConfigResponse{Items: res}, nil +} + +// WatchGlobalConfig if the connection of WatchGlobalConfig is end +// or stoped by whatever reason +// just reconnect to it. +func (s *GrpcServer) WatchGlobalConfig(_ *pdpb.WatchGlobalConfigRequest, server pdpb.PD_WatchGlobalConfigServer) error { + ctx, cancel := context.WithCancel(s.Context()) + defer cancel() + err := s.sendAllGlobalConfig(ctx, server) + if err != nil { + return err + } + watchChan := s.client.Watch(ctx, globalConfigPath, clientv3.WithPrefix()) + for { + select { + case <-ctx.Done(): + return nil + case res := <-watchChan: + cfgs := make([]*pdpb.GlobalConfigItem, 0, len(res.Events)) + for _, e := range res.Events { + if e.Type != clientv3.EventTypePut { + continue + } + cfgs = append(cfgs, &pdpb.GlobalConfigItem{Name: string(e.Kv.Key), Value: string(e.Kv.Value)}) + } + if len(cfgs) > 0 { + err := server.Send(&pdpb.WatchGlobalConfigResponse{Changes: cfgs}) + if err != nil { + return err + } + } + } + } +} + +func (s *GrpcServer) sendAllGlobalConfig(ctx context.Context, server pdpb.PD_WatchGlobalConfigServer) error { + configList, err := s.client.Get(ctx, globalConfigPath, clientv3.WithPrefix()) + if err != nil { + return err + } + ls := make([]*pdpb.GlobalConfigItem, configList.Count) + for i, kv := range configList.Kvs { + ls[i] = &pdpb.GlobalConfigItem{Name: string(kv.Key), Value: string(kv.Value)} + } + err = server.Send(&pdpb.WatchGlobalConfigResponse{Changes: ls}) + return err +} + +// Evict the leaders when the store is damaged. 
Damaged regions are emergency errors +// and requires user to manually remove the `evict-leader-scheduler` with pd-ctl +func (s *GrpcServer) handleDamagedStore(stats *pdpb.StoreStats) { + // TODO: regions have no special process for the time being + // and need to be removed in the future + damagedRegions := stats.GetDamagedRegionsId() + if len(damagedRegions) == 0 { + return + } + + for _, regionID := range stats.GetDamagedRegionsId() { + // Remove peers to make sst recovery physically delete files in TiKV. + err := s.GetHandler().AddRemovePeerOperator(regionID, stats.GetStoreId()) + if err != nil { + log.Error("store damaged but can't add remove peer operator", + zap.Uint64("region-id", regionID), zap.Uint64("store-id", stats.GetStoreId()), zap.String("error", err.Error())) + } else { + log.Info("added remove peer operator due to damaged region", + zap.Uint64("region-id", regionID), zap.Uint64("store-id", stats.GetStoreId())) + } + } +} + +// ReportMinResolvedTS implements gRPC PDServer. +func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.ReportMinResolvedTsRequest) (*pdpb.ReportMinResolvedTsResponse, error) { + forwardedHost := getForwardedHost(ctx) + if !s.isLocalRequest(forwardedHost) { + client, err := s.getDelegateClient(ctx, forwardedHost) + if err != nil { + return nil, err + } + ctx = grpcutil.ResetForwardContext(ctx) + return pdpb.NewPDClient(client).ReportMinResolvedTS(ctx, request) + } + + if err := s.validateRequest(request.GetHeader()); err != nil { + return nil, err + } + + rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.ReportMinResolvedTsResponse{Header: s.notBootstrappedHeader()}, nil + } + + storeID := request.GetStoreId() + minResolvedTS := request.GetMinResolvedTs() + if err := rc.SetMinResolvedTS(storeID, minResolvedTS); err != nil { + return nil, err + } + log.Debug("updated min resolved-ts", + zap.Uint64("store", storeID), + zap.Uint64("min resolved-ts", minResolvedTS)) + return &pdpb.ReportMinResolvedTsResponse{ + Header: s.header(), + }, nil +} + +// SetExternalTimestamp implements gRPC PDServer. +func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.SetExternalTimestampRequest) (*pdpb.SetExternalTimestampResponse, error) { + forwardedHost := getForwardedHost(ctx) + if !s.isLocalRequest(forwardedHost) { + client, err := s.getDelegateClient(ctx, forwardedHost) + if err != nil { + return nil, err + } + ctx = grpcutil.ResetForwardContext(ctx) + return pdpb.NewPDClient(client).SetExternalTimestamp(ctx, request) + } + + if err := s.validateRequest(request.GetHeader()); err != nil { + return nil, err + } + + timestamp := request.GetTimestamp() + if err := s.SetExternalTS(timestamp); err != nil { + return &pdpb.SetExternalTimestampResponse{Header: s.invalidValue(err.Error())}, nil + } + log.Debug("set external timestamp", + zap.Uint64("timestamp", timestamp)) + return &pdpb.SetExternalTimestampResponse{ + Header: s.header(), + }, nil +} + +// GetExternalTimestamp implements gRPC PDServer. 
+func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.GetExternalTimestampRequest) (*pdpb.GetExternalTimestampResponse, error) { + forwardedHost := getForwardedHost(ctx) + if !s.isLocalRequest(forwardedHost) { + client, err := s.getDelegateClient(ctx, forwardedHost) + if err != nil { + return nil, err + } + ctx = grpcutil.ResetForwardContext(ctx) + return pdpb.NewPDClient(client).GetExternalTimestamp(ctx, request) + } + + if err := s.validateRequest(request.GetHeader()); err != nil { + return nil, err + } + + timestamp := s.GetExternalTS() + return &pdpb.GetExternalTimestampResponse{ + Header: s.header(), + Timestamp: timestamp, + }, nil +} diff --git a/server/join/binding__failpoint_binding__.go b/server/join/binding__failpoint_binding__.go new file mode 100755 index 00000000000..6560aaf35d7 --- /dev/null +++ b/server/join/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package join + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/join/join.go b/server/join/join.go old mode 100644 new mode 100755 index 6c675bb6d3a..60180fe50f9 --- a/server/join/join.go +++ b/server/join/join.go @@ -150,10 +150,10 @@ func PrepareJoinCluster(cfg *config.Config) error { var addResp *clientv3.MemberAddResponse - failpoint.Inject("add-member-failed", func() { + if _, _err_ := failpoint.Eval(_curpkg_("add-member-failed")); _err_ == nil { listMemberRetryTimes = 2 - failpoint.Goto("LabelSkipAddMember") - }) + goto LabelSkipAddMember + } // - A new PD joins an existing cluster. // - A deleted PD joins to previous cluster. { @@ -163,7 +163,7 @@ func PrepareJoinCluster(cfg *config.Config) error { return err } } - failpoint.Label("LabelSkipAddMember") +LabelSkipAddMember: var ( pds []string diff --git a/server/join/join.go__failpoint_stash__ b/server/join/join.go__failpoint_stash__ new file mode 100644 index 00000000000..6c675bb6d3a --- /dev/null +++ b/server/join/join.go__failpoint_stash__ @@ -0,0 +1,233 @@ +// Copyright 2016 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package join + +import ( + "fmt" + "os" + "path" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/etcdutil" + "github.com/tikv/pd/server/config" + "go.etcd.io/etcd/clientv3" + "go.etcd.io/etcd/embed" + "go.uber.org/zap" +) + +const ( + // privateFileMode grants owner to read/write a file. + privateFileMode = 0600 + // privateDirMode grants owner to make/remove files inside the directory. + privateDirMode = 0700 +) + +// listMemberRetryTimes is the retry times of list member. 
+var listMemberRetryTimes = 20 + +// PrepareJoinCluster sends MemberAdd command to PD cluster, +// and returns the initial configuration of the PD cluster. +// +// TL;TR: The join functionality is safe. With data, join does nothing, w/o data +// +// and it is not a member of cluster, join does MemberAdd, it returns an +// error if PD tries to join itself, missing data or join a duplicated PD. +// +// Etcd automatically re-joins the cluster if there is a data directory. So +// first it checks if there is a data directory or not. If there is, it returns +// an empty string (etcd will get the correct configurations from the data +// directory.) +// +// If there is no data directory, there are following cases: +// +// - A new PD joins an existing cluster. +// What join does: MemberAdd, MemberList, then generate initial-cluster. +// +// - A failed PD re-joins the previous cluster. +// What join does: return an error. (etcd reports: raft log corrupted, +// truncated, or lost?) +// +// - A deleted PD joins to previous cluster. +// What join does: MemberAdd, MemberList, then generate initial-cluster. +// (it is not in the member list and there is no data, so +// we can treat it as a new PD.) +// +// If there is a data directory, there are following special cases: +// +// - A failed PD tries to join the previous cluster but it has been deleted +// during its downtime. +// What join does: return "" (etcd will connect to other peers and find +// that the PD itself has been removed.) +// +// - A deleted PD joins the previous cluster. +// What join does: return "" (as etcd will read data directory and find +// that the PD itself has been removed, so an empty string +// is fine.) +func PrepareJoinCluster(cfg *config.Config) error { + // - A PD tries to join itself. + if cfg.Join == "" { + return nil + } + + if cfg.Join == cfg.AdvertiseClientUrls { + return errors.New("join self is forbidden") + } + + filePath := path.Join(cfg.DataDir, "join") + // Read the persist join config + if _, err := os.Stat(filePath); !os.IsNotExist(err) { + s, err := os.ReadFile(filePath) + if err != nil { + log.Fatal("read the join config meet error", errs.ZapError(errs.ErrIORead, err)) + } + cfg.InitialCluster = strings.TrimSpace(string(s)) + cfg.InitialClusterState = embed.ClusterStateFlagExisting + return nil + } + + initialCluster := "" + // Cases with data directory. + if isDataExist(path.Join(cfg.DataDir, "member")) { + cfg.InitialCluster = initialCluster + cfg.InitialClusterState = embed.ClusterStateFlagExisting + return nil + } + + // Below are cases without data directory. + tlsConfig, err := cfg.Security.ToTLSConfig() + if err != nil { + return err + } + lgc := zap.NewProductionConfig() + lgc.Encoding = log.ZapEncodingName + client, err := clientv3.New(clientv3.Config{ + Endpoints: strings.Split(cfg.Join, ","), + DialTimeout: etcdutil.DefaultDialTimeout, + TLS: tlsConfig, + LogConfig: &lgc, + }) + if err != nil { + return errors.WithStack(err) + } + defer client.Close() + + listResp, err := etcdutil.ListEtcdMembers(client) + if err != nil { + return err + } + + existed := false + for _, m := range listResp.Members { + if len(m.Name) == 0 { + return errors.New("there is a member that has not joined successfully") + } + if m.Name == cfg.Name { + existed = true + } + } + + // - A failed PD re-joins the previous cluster. 
+ if existed { + return errors.New("missing data or join a duplicated pd") + } + + var addResp *clientv3.MemberAddResponse + + failpoint.Inject("add-member-failed", func() { + listMemberRetryTimes = 2 + failpoint.Goto("LabelSkipAddMember") + }) + // - A new PD joins an existing cluster. + // - A deleted PD joins to previous cluster. + { + // First adds member through the API + addResp, err = etcdutil.AddEtcdMember(client, []string{cfg.AdvertisePeerUrls}) + if err != nil { + return err + } + } + failpoint.Label("LabelSkipAddMember") + + var ( + pds []string + listSucc bool + ) + + for i := 0; i < listMemberRetryTimes; i++ { + listResp, err = etcdutil.ListEtcdMembers(client) + if err != nil { + return err + } + + pds = []string{} + for _, memb := range listResp.Members { + n := memb.Name + if addResp != nil && memb.ID == addResp.Member.ID { + n = cfg.Name + listSucc = true + } + if len(n) == 0 { + return errors.New("there is a member that has not joined successfully") + } + for _, m := range memb.PeerURLs { + pds = append(pds, fmt.Sprintf("%s=%s", n, m)) + } + } + + if listSucc { + break + } + time.Sleep(500 * time.Millisecond) + } + if !listSucc { + return errors.Errorf("join failed, adds the new member %s may failed", cfg.Name) + } + + initialCluster = strings.Join(pds, ",") + cfg.InitialCluster = initialCluster + cfg.InitialClusterState = embed.ClusterStateFlagExisting + err = os.MkdirAll(cfg.DataDir, privateDirMode) + if err != nil && !os.IsExist(err) { + return errors.WithStack(err) + } + + err = os.WriteFile(filePath, []byte(cfg.InitialCluster), privateFileMode) + return errors.WithStack(err) +} + +func isDataExist(d string) bool { + dir, err := os.Open(d) + if err != nil { + log.Info("failed to open directory, maybe start for the first time", errs.ZapError(err)) + return false + } + defer func() { + if err := dir.Close(); err != nil { + log.Error("failed to close file", errs.ZapError(err)) + } + }() + + names, err := dir.Readdirnames(-1) + if err != nil { + log.Error("failed to list directory", errs.ZapError(errs.ErrReadDirName, err)) + return false + } + return len(names) != 0 +} diff --git a/server/region_syncer/binding__failpoint_binding__.go b/server/region_syncer/binding__failpoint_binding__.go new file mode 100755 index 00000000000..3db8a06874b --- /dev/null +++ b/server/region_syncer/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package syncer + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/region_syncer/server.go b/server/region_syncer/server.go old mode 100644 new mode 100755 index 14a5c75a78a..036e710868b --- a/server/region_syncer/server.go +++ b/server/region_syncer/server.go @@ -244,13 +244,13 @@ func (s *RegionSyncer) syncHistoryRegion(ctx context.Context, request *pdpb.Sync select { case <-ctx.Done(): log.Info("discontinue sending sync region response") - failpoint.Inject("noFastExitSync", func() { - failpoint.Goto("doSync") - }) + if _, _err_ := failpoint.Eval(_curpkg_("noFastExitSync")); _err_ == nil { + goto doSync + } return nil default: } - failpoint.Label("doSync") + doSync: metas = append(metas, r.GetMeta()) stats = append(stats, r.GetStat()) leader := &metapb.Peer{} diff --git a/server/region_syncer/server.go__failpoint_stash__ 
b/server/region_syncer/server.go__failpoint_stash__ new file mode 100644 index 00000000000..14a5c75a78a --- /dev/null +++ b/server/region_syncer/server.go__failpoint_stash__ @@ -0,0 +1,355 @@ +// Copyright 2018 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncer + +import ( + "context" + "io" + "sync" + "time" + + "github.com/docker/go-units" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/grpcutil" + "github.com/tikv/pd/pkg/ratelimit" + "github.com/tikv/pd/pkg/syncutil" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/storage" + "github.com/tikv/pd/server/storage/kv" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + msgSize = 8 * units.MiB + defaultBucketRate = 20 * units.MiB // 20MB/s + defaultBucketCapacity = 20 * units.MiB // 20MB + maxSyncRegionBatchSize = 100 + syncerKeepAliveInterval = 10 * time.Second + defaultHistoryBufferSize = 10000 +) + +// ClientStream is the client side of the region syncer. +type ClientStream interface { + Recv() (*pdpb.SyncRegionResponse, error) + CloseSend() error +} + +// ServerStream is the server side of the region syncer. +type ServerStream interface { + Send(regions *pdpb.SyncRegionResponse) error +} + +// Server is the abstraction of the syncer storage server. +type Server interface { + LoopContext() context.Context + ClusterID() uint64 + GetMemberInfo() *pdpb.Member + GetLeader() *pdpb.Member + GetStorage() storage.Storage + Name() string + GetRegions() []*core.RegionInfo + GetTLSConfig() *grpcutil.TLSConfig + GetBasicCluster() *core.BasicCluster +} + +// RegionSyncer is used to sync the region information without raft. +type RegionSyncer struct { + mu struct { + syncutil.RWMutex + streams map[string]ServerStream + clientCtx context.Context + clientCancel context.CancelFunc + } + server Server + wg sync.WaitGroup + history *historyBuffer + limit *ratelimit.RateLimiter + tlsConfig *grpcutil.TLSConfig +} + +// NewRegionSyncer returns a region syncer. +// The final consistency is ensured by the heartbeat. +// Strong consistency is not guaranteed. +// Usually open the region syncer in huge cluster and the server +// no longer etcd but go-leveldb. +func NewRegionSyncer(s Server) *RegionSyncer { + localRegionStorage := storage.TryGetLocalRegionStorage(s.GetStorage()) + if localRegionStorage == nil { + return nil + } + syncer := &RegionSyncer{ + server: s, + history: newHistoryBuffer(defaultHistoryBufferSize, localRegionStorage.(kv.Base)), + limit: ratelimit.NewRateLimiter(defaultBucketRate, defaultBucketCapacity), + tlsConfig: s.GetTLSConfig(), + } + syncer.mu.streams = make(map[string]ServerStream) + return syncer +} + +// RunServer runs the server of the region syncer. +// regionNotifier is used to get the changed regions. 
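RunServer below drains the region notifier channel in bounded batches: it blocks for the first changed region, then pulls whatever is already buffered, up to the batch limit, without blocking again. A generic sketch of that drain pattern; the helper name is hypothetical and generics are used only for brevity.

// drainBatch returns the first item plus any items already buffered in ch,
// capped at batchSize in total; it never blocks on an empty channel.
func drainBatch[T any](ch <-chan T, first T, batchSize int) []T {
	batch := make([]T, 0, batchSize)
	batch = append(batch, first)
	pending := len(ch)
	for i := 0; i < pending && len(batch) < batchSize; i++ {
		batch = append(batch, <-ch)
	}
	return batch
}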
+func (s *RegionSyncer) RunServer(ctx context.Context, regionNotifier <-chan *core.RegionInfo) { + var requests []*metapb.Region + var stats []*pdpb.RegionStat + var leaders []*metapb.Peer + var buckets []*metapb.Buckets + ticker := time.NewTicker(syncerKeepAliveInterval) + + defer func() { + ticker.Stop() + s.mu.Lock() + s.mu.streams = make(map[string]ServerStream) + s.mu.Unlock() + }() + + for { + select { + case <-ctx.Done(): + log.Info("region syncer has been stopped") + return + case first := <-regionNotifier: + requests = append(requests, first.GetMeta()) + stats = append(stats, first.GetStat()) + // bucket should not be nil to avoid grpc marshal panic. + bucket := &metapb.Buckets{} + if b := first.GetBuckets(); b != nil { + bucket = b + } + buckets = append(buckets, bucket) + leaders = append(leaders, first.GetLeader()) + startIndex := s.history.GetNextIndex() + s.history.Record(first) + pending := len(regionNotifier) + for i := 0; i < pending && i < maxSyncRegionBatchSize; i++ { + region := <-regionNotifier + requests = append(requests, region.GetMeta()) + stats = append(stats, region.GetStat()) + // bucket should not be nil to avoid grpc marshal panic. + bucket := &metapb.Buckets{} + if b := region.GetBuckets(); b != nil { + bucket = b + } + buckets = append(buckets, bucket) + leaders = append(leaders, region.GetLeader()) + s.history.Record(region) + } + regions := &pdpb.SyncRegionResponse{ + Header: &pdpb.ResponseHeader{ClusterId: s.server.ClusterID()}, + Regions: requests, + StartIndex: startIndex, + RegionStats: stats, + RegionLeaders: leaders, + Buckets: buckets, + } + s.broadcast(regions) + case <-ticker.C: + alive := &pdpb.SyncRegionResponse{ + Header: &pdpb.ResponseHeader{ClusterId: s.server.ClusterID()}, + StartIndex: s.history.GetNextIndex(), + } + s.broadcast(alive) + } + requests = requests[:0] + stats = stats[:0] + leaders = leaders[:0] + buckets = buckets[:0] + } +} + +// GetAllDownstreamNames tries to get the all bind stream's name. +// Only for test +func (s *RegionSyncer) GetAllDownstreamNames() []string { + s.mu.RLock() + names := make([]string, 0, len(s.mu.streams)) + for name := range s.mu.streams { + names = append(names, name) + } + s.mu.RUnlock() + return names +} + +// Sync firstly tries to sync the history records to client. +// then to sync the latest records. 
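Further down, syncHistoryRegion paces full synchronization by charging each batched response's resp.Size() against a token bucket before sending. PD's own pkg/ratelimit wraps a similar bucket; the sketch below is an analogous version written against the standard golang.org/x/time/rate package, reusing the 20 MiB/s rate and 20 MiB capacity from the constants above. The variable and function names are illustrative.

// Assumed imports: context, github.com/docker/go-units,
// github.com/pingcap/kvproto/pkg/pdpb, golang.org/x/time/rate.
var syncLimiter = rate.NewLimiter(rate.Limit(20*units.MiB), 20*units.MiB)

// sendPaced waits for resp.Size() bytes of bucket budget before sending, so
// large full-sync batches cannot saturate the follower's link.
func sendPaced(ctx context.Context, stream pdpb.PD_SyncRegionsServer, resp *pdpb.SyncRegionResponse) error {
	if err := syncLimiter.WaitN(ctx, resp.Size()); err != nil {
		return err
	}
	return stream.Send(resp)
}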
+func (s *RegionSyncer) Sync(ctx context.Context, stream pdpb.PD_SyncRegionsServer) error { + for { + select { + case <-ctx.Done(): + return nil + default: + } + + request, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return errors.WithStack(err) + } + clusterID := request.GetHeader().GetClusterId() + if clusterID != s.server.ClusterID() { + return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", s.server.ClusterID(), clusterID) + } + log.Info("establish sync region stream", + zap.String("requested-server", request.GetMember().GetName()), + zap.String("url", request.GetMember().GetClientUrls()[0])) + + err = s.syncHistoryRegion(ctx, request, stream) + if err != nil { + return err + } + s.bindStream(request.GetMember().GetName(), stream) + } +} + +func (s *RegionSyncer) syncHistoryRegion(ctx context.Context, request *pdpb.SyncRegionRequest, stream pdpb.PD_SyncRegionsServer) error { + startIndex := request.GetStartIndex() + name := request.GetMember().GetName() + records := s.history.RecordsFrom(startIndex) + if len(records) == 0 { + if s.history.GetNextIndex() == startIndex { + log.Info("requested server has already in sync with server", + zap.String("requested-server", name), zap.String("server", s.server.Name()), zap.Uint64("last-index", startIndex)) + return nil + } + // do full synchronization + if startIndex == 0 { + regions := s.server.GetRegions() + lastIndex := 0 + start := time.Now() + metas := make([]*metapb.Region, 0, maxSyncRegionBatchSize) + stats := make([]*pdpb.RegionStat, 0, maxSyncRegionBatchSize) + leaders := make([]*metapb.Peer, 0, maxSyncRegionBatchSize) + buckets := make([]*metapb.Buckets, 0, maxSyncRegionBatchSize) + for syncedIndex, r := range regions { + select { + case <-ctx.Done(): + log.Info("discontinue sending sync region response") + failpoint.Inject("noFastExitSync", func() { + failpoint.Goto("doSync") + }) + return nil + default: + } + failpoint.Label("doSync") + metas = append(metas, r.GetMeta()) + stats = append(stats, r.GetStat()) + leader := &metapb.Peer{} + if r.GetLeader() != nil { + leader = r.GetLeader() + } + leaders = append(leaders, leader) + bucket := &metapb.Buckets{} + if r.GetBuckets() != nil { + bucket = r.GetBuckets() + } + buckets = append(buckets, bucket) + if len(metas) < maxSyncRegionBatchSize && syncedIndex < len(regions)-1 { + continue + } + resp := &pdpb.SyncRegionResponse{ + Header: &pdpb.ResponseHeader{ClusterId: s.server.ClusterID()}, + Regions: metas, + StartIndex: uint64(lastIndex), + RegionStats: stats, + RegionLeaders: leaders, + Buckets: buckets, + } + s.limit.WaitN(ctx, resp.Size()) + lastIndex += len(metas) + if err := stream.Send(resp); err != nil { + log.Error("failed to send sync region response", errs.ZapError(errs.ErrGRPCSend, err)) + return err + } + metas = metas[:0] + stats = stats[:0] + leaders = leaders[:0] + buckets = buckets[:0] + } + log.Info("requested server has completed full synchronization with server", + zap.String("requested-server", name), zap.String("server", s.server.Name()), zap.Duration("cost", time.Since(start))) + return nil + } + log.Warn("no history regions from index, the leader may be restarted", zap.Uint64("index", startIndex)) + return nil + } + log.Info("sync the history regions with server", + zap.String("server", name), + zap.Uint64("from-index", startIndex), + zap.Uint64("last-index", s.history.GetNextIndex()), + zap.Int("records-length", len(records))) + regions := make([]*metapb.Region, len(records)) + stats := 
make([]*pdpb.RegionStat, len(records)) + leaders := make([]*metapb.Peer, len(records)) + buckets := make([]*metapb.Buckets, len(records)) + for i, r := range records { + regions[i] = r.GetMeta() + stats[i] = r.GetStat() + leader := &metapb.Peer{} + if r.GetLeader() != nil { + leader = r.GetLeader() + } + leaders[i] = leader + // bucket should not be nil to avoid grpc marshal panic. + buckets[i] = &metapb.Buckets{} + if r.GetBuckets() != nil { + buckets[i] = r.GetBuckets() + } + } + resp := &pdpb.SyncRegionResponse{ + Header: &pdpb.ResponseHeader{ClusterId: s.server.ClusterID()}, + Regions: regions, + StartIndex: startIndex, + RegionStats: stats, + RegionLeaders: leaders, + Buckets: buckets, + } + return stream.Send(resp) +} + +// bindStream binds the established server stream. +func (s *RegionSyncer) bindStream(name string, stream ServerStream) { + s.mu.Lock() + defer s.mu.Unlock() + s.mu.streams[name] = stream +} + +func (s *RegionSyncer) broadcast(regions *pdpb.SyncRegionResponse) { + var failed []string + s.mu.RLock() + for name, sender := range s.mu.streams { + err := sender.Send(regions) + if err != nil { + log.Error("region syncer send data meet error", errs.ZapError(errs.ErrGRPCSend, err)) + failed = append(failed, name) + } + } + s.mu.RUnlock() + if len(failed) > 0 { + s.mu.Lock() + for _, name := range failed { + delete(s.mu.streams, name) + log.Info("region syncer delete the stream", zap.String("stream", name)) + } + s.mu.Unlock() + } +} diff --git a/server/schedule/binding__failpoint_binding__.go b/server/schedule/binding__failpoint_binding__.go new file mode 100755 index 00000000000..363e7d92b7d --- /dev/null +++ b/server/schedule/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package schedule + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/schedule/checker/binding__failpoint_binding__.go b/server/schedule/checker/binding__failpoint_binding__.go new file mode 100755 index 00000000000..4b0c7dae4c1 --- /dev/null +++ b/server/schedule/checker/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package checker + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/schedule/checker/rule_checker.go b/server/schedule/checker/rule_checker.go old mode 100644 new mode 100755 index 16cc3cf46b4..60c0342e769 --- a/server/schedule/checker/rule_checker.go +++ b/server/schedule/checker/rule_checker.go @@ -98,15 +98,15 @@ func (c *RuleChecker) CheckWithFit(region *core.RegionInfo, fit *placement.Regio } // If the fit is fetched from cache, it seems that the region doesn't need cache if c.cluster.GetOpts().IsPlacementRulesCacheEnabled() && fit.IsCached() { - failpoint.Inject("assertShouldNotCache", func() { + if _, _err_ := failpoint.Eval(_curpkg_("assertShouldNotCache")); _err_ == nil { panic("cached shouldn't be used") - }) + } checkerCounter.WithLabelValues("rule_checker", "get-cache").Inc() return nil } - failpoint.Inject("assertShouldCache", func() { + if _, _err_ := 
failpoint.Eval(_curpkg_("assertShouldCache")); _err_ == nil { panic("cached should be used") - }) + } // If the fit is calculated by FitRegion, which means we get a new fit result, thus we should // invalid the cache if it exists diff --git a/server/schedule/checker/rule_checker.go__failpoint_stash__ b/server/schedule/checker/rule_checker.go__failpoint_stash__ new file mode 100644 index 00000000000..16cc3cf46b4 --- /dev/null +++ b/server/schedule/checker/rule_checker.go__failpoint_stash__ @@ -0,0 +1,561 @@ +// Copyright 2019 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package checker + +import ( + "context" + "errors" + "math" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/cache" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/schedule" + "github.com/tikv/pd/server/schedule/filter" + "github.com/tikv/pd/server/schedule/operator" + "github.com/tikv/pd/server/schedule/placement" + "github.com/tikv/pd/server/versioninfo" + "go.uber.org/zap" +) + +var ( + errNoStoreToAdd = errors.New("no store to add peer") + errNoStoreToReplace = errors.New("no store to replace peer") + errPeerCannotBeLeader = errors.New("peer cannot be leader") + errPeerCannotBeWitness = errors.New("peer cannot be witness") + errNoNewLeader = errors.New("no new leader") + errRegionNoLeader = errors.New("region no leader") +) + +const maxPendingListLen = 100000 + +// RuleChecker fix/improve region by placement rules. +type RuleChecker struct { + PauseController + cluster schedule.Cluster + ruleManager *placement.RuleManager + name string + regionWaitingList cache.Cache + pendingList cache.Cache + switchWitnessCache *cache.TTLUint64 + record *recorder +} + +// NewRuleChecker creates a checker instance. +func NewRuleChecker(ctx context.Context, cluster schedule.Cluster, ruleManager *placement.RuleManager, regionWaitingList cache.Cache) *RuleChecker { + return &RuleChecker{ + cluster: cluster, + ruleManager: ruleManager, + name: "rule-checker", + regionWaitingList: regionWaitingList, + pendingList: cache.NewDefaultCache(maxPendingListLen), + switchWitnessCache: cache.NewIDTTL(ctx, time.Minute, cluster.GetOpts().GetSwitchWitnessInterval()), + record: newRecord(), + } +} + +// GetType returns RuleChecker's Type +func (c *RuleChecker) GetType() string { + return "rule-checker" +} + +// Check checks if the region matches placement rules and returns Operator to +// fix it. 
+func (c *RuleChecker) Check(region *core.RegionInfo) *operator.Operator { + fit := c.cluster.GetRuleManager().FitRegion(c.cluster, region) + return c.CheckWithFit(region, fit) +} + +// CheckWithFit is similar with Checker with placement.RegionFit +func (c *RuleChecker) CheckWithFit(region *core.RegionInfo, fit *placement.RegionFit) (op *operator.Operator) { + // checker is paused + if c.IsPaused() { + checkerCounter.WithLabelValues("rule_checker", "paused").Inc() + return nil + } + // skip no leader region + if region.GetLeader() == nil { + checkerCounter.WithLabelValues("rule_checker", "region-no-leader").Inc() + log.Debug("fail to check region", zap.Uint64("region-id", region.GetID()), zap.Error(errRegionNoLeader)) + return + } + // If the fit is fetched from cache, it seems that the region doesn't need cache + if c.cluster.GetOpts().IsPlacementRulesCacheEnabled() && fit.IsCached() { + failpoint.Inject("assertShouldNotCache", func() { + panic("cached shouldn't be used") + }) + checkerCounter.WithLabelValues("rule_checker", "get-cache").Inc() + return nil + } + failpoint.Inject("assertShouldCache", func() { + panic("cached should be used") + }) + + // If the fit is calculated by FitRegion, which means we get a new fit result, thus we should + // invalid the cache if it exists + c.ruleManager.InvalidCache(region.GetID()) + + checkerCounter.WithLabelValues("rule_checker", "check").Inc() + c.record.refresh(c.cluster) + + if len(fit.RuleFits) == 0 { + checkerCounter.WithLabelValues("rule_checker", "need-split").Inc() + // If the region matches no rules, the most possible reason is it spans across + // multiple rules. + return nil + } + op, err := c.fixOrphanPeers(region, fit) + if err != nil { + log.Debug("fail to fix orphan peer", errs.ZapError(err)) + } else if op != nil { + c.pendingList.Remove(region.GetID()) + return op + } + for _, rf := range fit.RuleFits { + op, err := c.fixRulePeer(region, fit, rf) + if err != nil { + log.Debug("fail to fix rule peer", zap.String("rule-group", rf.Rule.GroupID), zap.String("rule-id", rf.Rule.ID), errs.ZapError(err)) + continue + } + if op != nil { + c.pendingList.Remove(region.GetID()) + return op + } + } + if c.cluster.GetOpts().IsPlacementRulesCacheEnabled() { + if placement.ValidateFit(fit) && placement.ValidateRegion(region) && placement.ValidateStores(fit.GetRegionStores()) { + // If there is no need to fix, we will cache the fit + c.ruleManager.SetRegionFitCache(region, fit) + checkerCounter.WithLabelValues("rule_checker", "set-cache").Inc() + } + } + return nil +} + +// RecordRegionPromoteToNonWitness put the recently switch non-witness region into cache. RuleChecker +// will skip switch it back to witness for a while. +func (c *RuleChecker) RecordRegionPromoteToNonWitness(regionID uint64) { + c.switchWitnessCache.PutWithTTL(regionID, nil, c.cluster.GetOpts().GetSwitchWitnessInterval()) +} + +func (c *RuleChecker) isWitnessEnabled() bool { + return versioninfo.IsFeatureSupported(c.cluster.GetOpts().GetClusterVersion(), versioninfo.SwitchWitness) && + c.cluster.GetOpts().IsWitnessAllowed() +} + +func (c *RuleChecker) fixRulePeer(region *core.RegionInfo, fit *placement.RegionFit, rf *placement.RuleFit) (*operator.Operator, error) { + // make up peers. + if len(rf.Peers) < rf.Rule.Count { + return c.addRulePeer(region, rf) + } + // fix down/offline peers. 
+ for _, peer := range rf.Peers { + if c.isDownPeer(region, peer) { + if c.isStoreDownTimeHitMaxDownTime(peer.GetStoreId()) { + checkerCounter.WithLabelValues("rule_checker", "replace-down").Inc() + return c.replaceUnexpectRulePeer(region, rf, fit, peer, downStatus) + } + // When witness placement rule is enabled, promotes the witness to voter when region has down peer. + if c.isWitnessEnabled() { + if witness, ok := c.hasAvailableWitness(region, peer); ok { + checkerCounter.WithLabelValues("rule_checker", "promote-witness").Inc() + return operator.CreateNonWitnessPeerOperator("promote-witness", c.cluster, region, witness) + } + } + } + if c.isOfflinePeer(peer) { + checkerCounter.WithLabelValues("rule_checker", "replace-offline").Inc() + return c.replaceUnexpectRulePeer(region, rf, fit, peer, offlineStatus) + } + } + // fix loose matched peers. + for _, peer := range rf.PeersWithDifferentRole { + op, err := c.fixLooseMatchPeer(region, fit, rf, peer) + if err != nil { + return nil, err + } + if op != nil { + return op, nil + } + } + return c.fixBetterLocation(region, rf) +} + +func (c *RuleChecker) addRulePeer(region *core.RegionInfo, rf *placement.RuleFit) (*operator.Operator, error) { + checkerCounter.WithLabelValues("rule_checker", "add-rule-peer").Inc() + ruleStores := c.getRuleFitStores(rf) + store, filterByTempState := c.strategy(region, rf.Rule).SelectStoreToAdd(ruleStores) + if store == 0 { + checkerCounter.WithLabelValues("rule_checker", "no-store-add").Inc() + c.handleFilterState(region, filterByTempState) + return nil, errNoStoreToAdd + } + isWitness := rf.Rule.IsWitness + if !c.isWitnessEnabled() { + isWitness = false + } + peer := &metapb.Peer{StoreId: store, Role: rf.Rule.Role.MetaPeerRole(), IsWitness: isWitness} + op, err := operator.CreateAddPeerOperator("add-rule-peer", c.cluster, region, peer, operator.OpReplica) + if err != nil { + return nil, err + } + op.SetPriorityLevel(core.High) + return op, nil +} + +// The peer's store may in Offline or Down, need to be replace. +func (c *RuleChecker) replaceUnexpectRulePeer(region *core.RegionInfo, rf *placement.RuleFit, fit *placement.RegionFit, peer *metapb.Peer, status string) (*operator.Operator, error) { + ruleStores := c.getRuleFitStores(rf) + store, filterByTempState := c.strategy(region, rf.Rule).SelectStoreToFix(ruleStores, peer.GetStoreId()) + if store == 0 { + checkerCounter.WithLabelValues("rule_checker", "no-store-replace").Inc() + c.handleFilterState(region, filterByTempState) + return nil, errNoStoreToReplace + } + var isWitness bool + if c.isWitnessEnabled() { + // No matter whether witness placement rule is enabled or disabled, when peer's downtime + // exceeds the threshold(30min), add a witness and remove the down peer. Then witness is + // promoted to non-witness gradually to improve availability. + if status == "down" { + isWitness = true + } else { + isWitness = rf.Rule.IsWitness + } + } else { + isWitness = false + } + newPeer := &metapb.Peer{StoreId: store, Role: rf.Rule.Role.MetaPeerRole(), IsWitness: isWitness} + // pick the smallest leader store to avoid the Offline store be snapshot generator bottleneck. 
+ var newLeader *metapb.Peer + if region.GetLeader().GetId() == peer.GetId() { + minCount := uint64(math.MaxUint64) + for _, p := range region.GetPeers() { + count := c.record.getOfflineLeaderCount(p.GetStoreId()) + checkPeerhealth := func() bool { + if p.GetId() == peer.GetId() { + return true + } + if region.GetDownPeer(p.GetId()) != nil || region.GetPendingPeer(p.GetId()) != nil { + return false + } + return c.allowLeader(fit, p) + } + if minCount > count && checkPeerhealth() { + minCount = count + newLeader = p + } + } + } + + createOp := func() (*operator.Operator, error) { + if newLeader != nil && newLeader.GetId() != peer.GetId() { + return operator.CreateReplaceLeaderPeerOperator("replace-rule-"+status+"-leader-peer", c.cluster, region, operator.OpReplica, peer.StoreId, newPeer, newLeader) + } + return operator.CreateMovePeerOperator("replace-rule-"+status+"-peer", c.cluster, region, operator.OpReplica, peer.StoreId, newPeer) + } + op, err := createOp() + if err != nil { + return nil, err + } + if newLeader != nil { + c.record.incOfflineLeaderCount(newLeader.GetStoreId()) + } + op.SetPriorityLevel(core.High) + return op, nil +} + +func (c *RuleChecker) fixLooseMatchPeer(region *core.RegionInfo, fit *placement.RegionFit, rf *placement.RuleFit, peer *metapb.Peer) (*operator.Operator, error) { + if core.IsLearner(peer) && rf.Rule.Role != placement.Learner { + checkerCounter.WithLabelValues("rule_checker", "fix-peer-role").Inc() + return operator.CreatePromoteLearnerOperator("fix-peer-role", c.cluster, region, peer) + } + if region.GetLeader().GetId() != peer.GetId() && rf.Rule.Role == placement.Leader { + checkerCounter.WithLabelValues("rule_checker", "fix-leader-role").Inc() + if c.allowLeader(fit, peer) { + return operator.CreateTransferLeaderOperator("fix-leader-role", c.cluster, region, region.GetLeader().GetStoreId(), peer.GetStoreId(), []uint64{}, 0) + } + checkerCounter.WithLabelValues("rule_checker", "not-allow-leader") + return nil, errPeerCannotBeLeader + } + if region.GetLeader().GetId() == peer.GetId() && rf.Rule.Role == placement.Follower { + checkerCounter.WithLabelValues("rule_checker", "fix-follower-role").Inc() + for _, p := range region.GetPeers() { + if c.allowLeader(fit, p) { + return operator.CreateTransferLeaderOperator("fix-follower-role", c.cluster, region, peer.GetStoreId(), p.GetStoreId(), []uint64{}, 0) + } + } + checkerCounter.WithLabelValues("rule_checker", "no-new-leader").Inc() + return nil, errNoNewLeader + } + if core.IsVoter(peer) && rf.Rule.Role == placement.Learner { + checkerCounter.WithLabelValues("rule_checker", "demote-voter-role").Inc() + return operator.CreateDemoteVoterOperator("fix-demote-voter", c.cluster, region, peer) + } + if region.GetLeader().GetId() == peer.GetId() && rf.Rule.IsWitness { + return nil, errPeerCannotBeWitness + } + if !core.IsWitness(peer) && rf.Rule.IsWitness && c.isWitnessEnabled() { + c.switchWitnessCache.UpdateTTL(c.cluster.GetOpts().GetSwitchWitnessInterval()) + if c.switchWitnessCache.Exists(region.GetID()) { + checkerCounter.WithLabelValues("rule_checker", "recently-promote-to-non-witness").Inc() + return nil, nil + } + if len(region.GetPendingPeers()) > 0 { + checkerCounter.WithLabelValues("rule_checker", "cancel-switch-to-witness").Inc() + return nil, nil + } + lv := "set-voter-witness" + if core.IsLearner(peer) { + lv = "set-learner-witness" + } + checkerCounter.WithLabelValues("rule_checker", lv).Inc() + return operator.CreateWitnessPeerOperator("fix-witness-peer", c.cluster, region, peer) + } else if 
core.IsWitness(peer) && (!rf.Rule.IsWitness || !c.isWitnessEnabled()) { + lv := "set-voter-non-witness" + if core.IsLearner(peer) { + lv = "set-learner-non-witness" + } + checkerCounter.WithLabelValues("rule_checker", lv).Inc() + return operator.CreateNonWitnessPeerOperator("fix-non-witness-peer", c.cluster, region, peer) + } + return nil, nil +} + +func (c *RuleChecker) allowLeader(fit *placement.RegionFit, peer *metapb.Peer) bool { + if core.IsLearner(peer) { + return false + } + s := c.cluster.GetStore(peer.GetStoreId()) + if s == nil { + return false + } + stateFilter := &filter.StoreStateFilter{ActionScope: "rule-checker", TransferLeader: true} + if !stateFilter.Target(c.cluster.GetOpts(), s).IsOK() { + return false + } + for _, rf := range fit.RuleFits { + if (rf.Rule.Role == placement.Leader || rf.Rule.Role == placement.Voter) && + placement.MatchLabelConstraints(s, rf.Rule.LabelConstraints) { + return true + } + } + return false +} + +func (c *RuleChecker) fixBetterLocation(region *core.RegionInfo, rf *placement.RuleFit) (*operator.Operator, error) { + if len(rf.Rule.LocationLabels) == 0 || rf.Rule.Count <= 1 { + return nil, nil + } + + strategy := c.strategy(region, rf.Rule) + ruleStores := c.getRuleFitStores(rf) + oldStore := strategy.SelectStoreToRemove(ruleStores) + if oldStore == 0 { + return nil, nil + } + newStore, filterByTempState := strategy.SelectStoreToImprove(ruleStores, oldStore) + if newStore == 0 { + log.Debug("no replacement store", zap.Uint64("region-id", region.GetID())) + c.handleFilterState(region, filterByTempState) + return nil, nil + } + checkerCounter.WithLabelValues("rule_checker", "move-to-better-location").Inc() + isWitness := rf.Rule.IsWitness + if !c.isWitnessEnabled() { + isWitness = false + } + newPeer := &metapb.Peer{StoreId: newStore, Role: rf.Rule.Role.MetaPeerRole(), IsWitness: isWitness} + return operator.CreateMovePeerOperator("move-to-better-location", c.cluster, region, operator.OpReplica, oldStore, newPeer) +} + +func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.RegionFit) (*operator.Operator, error) { + if len(fit.OrphanPeers) == 0 { + return nil, nil + } + isUnhealthyPeer := func(id uint64) bool { + for _, pendingPeer := range region.GetPendingPeers() { + if pendingPeer.GetId() == id { + return true + } + } + for _, downPeer := range region.GetDownPeers() { + if downPeer.Peer.GetId() == id { + return true + } + } + return false + } + // remove orphan peers only when all rules are satisfied (count+role) and all peers selected + // by RuleFits is not pending or down. + hasUnhealthyFit := false +loopFits: + for _, rf := range fit.RuleFits { + if !rf.IsSatisfied() { + hasUnhealthyFit = true + break + } + for _, p := range rf.Peers { + if isUnhealthyPeer(p.GetId()) { + hasUnhealthyFit = true + break loopFits + } + } + } + // If hasUnhealthyFit is false, it is safe to delete the OrphanPeer. + if !hasUnhealthyFit { + checkerCounter.WithLabelValues("rule_checker", "remove-orphan-peer").Inc() + return operator.CreateRemovePeerOperator("remove-orphan-peer", c.cluster, 0, region, fit.OrphanPeers[0].StoreId) + } + // If hasUnhealthyFit is true, try to remove unhealthy orphan peers only if number of OrphanPeers is >= 2. 
+ // Ref https://github.com/tikv/pd/issues/4045 + if len(fit.OrphanPeers) >= 2 { + for _, orphanPeer := range fit.OrphanPeers { + if isUnhealthyPeer(orphanPeer.GetId()) { + checkerCounter.WithLabelValues("rule_checker", "remove-orphan-peer").Inc() + return operator.CreateRemovePeerOperator("remove-orphan-peer", c.cluster, 0, region, orphanPeer.StoreId) + } + } + } + checkerCounter.WithLabelValues("rule_checker", "skip-remove-orphan-peer").Inc() + return nil, nil +} + +func (c *RuleChecker) isDownPeer(region *core.RegionInfo, peer *metapb.Peer) bool { + for _, stats := range region.GetDownPeers() { + if stats.GetPeer().GetId() == peer.GetId() { + storeID := peer.GetStoreId() + store := c.cluster.GetStore(storeID) + if store == nil { + log.Warn("lost the store, maybe you are recovering the PD cluster", zap.Uint64("store-id", storeID)) + return false + } + return true + } + } + return false +} + +func (c *RuleChecker) isStoreDownTimeHitMaxDownTime(storeID uint64) bool { + store := c.cluster.GetStore(storeID) + return store.DownTime() >= c.cluster.GetOpts().GetMaxStoreDownTime() +} + +func (c *RuleChecker) isOfflinePeer(peer *metapb.Peer) bool { + store := c.cluster.GetStore(peer.GetStoreId()) + if store == nil { + log.Warn("lost the store, maybe you are recovering the PD cluster", zap.Uint64("store-id", peer.StoreId)) + return false + } + return !store.IsPreparing() && !store.IsServing() +} + +func (c *RuleChecker) hasAvailableWitness(region *core.RegionInfo, peer *metapb.Peer) (*metapb.Peer, bool) { + witnesses := region.GetWitnesses() + if len(witnesses) == 0 { + return nil, false + } + isAvailable := func(downPeers []*pdpb.PeerStats, witness *metapb.Peer) bool { + for _, stats := range downPeers { + if stats.GetPeer().GetId() == witness.GetId() { + return false + } + } + return c.cluster.GetStore(witness.GetStoreId()) != nil + } + downPeers := region.GetDownPeers() + for _, witness := range witnesses { + if witness.GetId() != peer.GetId() && isAvailable(downPeers, witness) { + return witness, true + } + } + return nil, false +} + +func (c *RuleChecker) strategy(region *core.RegionInfo, rule *placement.Rule) *ReplicaStrategy { + return &ReplicaStrategy{ + checkerName: c.name, + cluster: c.cluster, + isolationLevel: rule.IsolationLevel, + locationLabels: rule.LocationLabels, + region: region, + extraFilters: []filter.Filter{filter.NewLabelConstraintFilter(c.name, rule.LabelConstraints)}, + } +} + +func (c *RuleChecker) getRuleFitStores(rf *placement.RuleFit) []*core.StoreInfo { + var stores []*core.StoreInfo + for _, p := range rf.Peers { + if s := c.cluster.GetStore(p.GetStoreId()); s != nil { + stores = append(stores, s) + } + } + return stores +} + +func (c *RuleChecker) handleFilterState(region *core.RegionInfo, filterByTempState bool) { + if filterByTempState { + c.regionWaitingList.Put(region.GetID(), nil) + c.pendingList.Remove(region.GetID()) + } else { + c.pendingList.Put(region.GetID(), nil) + } +} + +type recorder struct { + offlineLeaderCounter map[uint64]uint64 + lastUpdateTime time.Time +} + +func newRecord() *recorder { + return &recorder{ + offlineLeaderCounter: make(map[uint64]uint64), + lastUpdateTime: time.Now(), + } +} + +func (o *recorder) getOfflineLeaderCount(storeID uint64) uint64 { + return o.offlineLeaderCounter[storeID] +} + +func (o *recorder) incOfflineLeaderCount(storeID uint64) { + o.offlineLeaderCounter[storeID] += 1 + o.lastUpdateTime = time.Now() +} + +// Offline is triggered manually and only appears when the node makes some adjustments. 
here is an operator timeout / 2. +var offlineCounterTTL = 5 * time.Minute + +func (o *recorder) refresh(cluster schedule.Cluster) { + // re-count the offlineLeaderCounter if the store is already tombstone or store is gone. + if len(o.offlineLeaderCounter) > 0 && time.Since(o.lastUpdateTime) > offlineCounterTTL { + needClean := false + for _, storeID := range o.offlineLeaderCounter { + store := cluster.GetStore(storeID) + if store == nil || store.IsRemoved() { + needClean = true + break + } + } + if needClean { + o.offlineLeaderCounter = make(map[uint64]uint64) + } + } +} diff --git a/server/schedule/labeler/binding__failpoint_binding__.go b/server/schedule/labeler/binding__failpoint_binding__.go new file mode 100755 index 00000000000..f1831405e53 --- /dev/null +++ b/server/schedule/labeler/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package labeler + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/schedule/labeler/rules.go b/server/schedule/labeler/rules.go old mode 100644 new mode 100755 index c902fff8f66..1809afbafee --- a/server/schedule/labeler/rules.go +++ b/server/schedule/labeler/rules.go @@ -75,11 +75,11 @@ type LabelRulePatch struct { } func (l *RegionLabel) expireBefore(t time.Time) bool { - failpoint.Inject("regionLabelExpireSub1Minute", func() { + if _, _err_ := failpoint.Eval(_curpkg_("regionLabelExpireSub1Minute")); _err_ == nil { if l.expire != nil { *l.expire = l.expire.Add(-time.Minute) } - }) + } if l.expire == nil { return false } diff --git a/server/schedule/labeler/rules.go__failpoint_stash__ b/server/schedule/labeler/rules.go__failpoint_stash__ new file mode 100644 index 00000000000..c902fff8f66 --- /dev/null +++ b/server/schedule/labeler/rules.go__failpoint_stash__ @@ -0,0 +1,223 @@ +// Copyright 2021 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package labeler + +import ( + "bytes" + "encoding/hex" + "fmt" + "reflect" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "go.uber.org/zap" +) + +// RegionLabel is the label of a region. +// NOTE: This type is exported by HTTP API. Please pay more attention when modifying it. +type RegionLabel struct { + Key string `json:"key"` + Value string `json:"value"` + TTL string `json:"ttl,omitempty"` + StartAt string `json:"start_at,omitempty"` + expire *time.Time +} + +// LabelRule is the rule to assign labels to a region. +// NOTE: This type is exported by HTTP API. Please pay more attention when modifying it. 
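For reference, a hedged example of the JSON these types accept, pieced together from the RegionLabel tags above and the LabelRule/KeyRangeRule tags that follow; the rule ID, label, and key range are illustrative placeholders only.

// Assumed import: encoding/json. This sketch lives inside the labeler package.
const exampleRuleJSON = `{
  "id": "deny-schedule-on-range",
  "index": 0,
  "rule_type": "key-range",
  "labels": [{"key": "schedule", "value": "deny", "ttl": "5m"}],
  "data": [{"start_key": "7480000000000000ff0a", "end_key": "7480000000000000ff0b"}]
}`

// decodeExampleRule unmarshals the rule; checkAndAdjust (below) would then
// compute each label's expire time from StartAt+TTL and convert Data into
// []*KeyRangeRule via initKeyRangeRulesFromLabelRuleData.
func decodeExampleRule() (*LabelRule, error) {
	var rule LabelRule
	if err := json.Unmarshal([]byte(exampleRuleJSON), &rule); err != nil {
		return nil, err
	}
	return &rule, nil
}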
+type LabelRule struct { + ID string `json:"id"` + Index int `json:"index"` + Labels []RegionLabel `json:"labels"` + RuleType string `json:"rule_type"` + Data interface{} `json:"data"` + minExpire *time.Time +} + +const ( + // KeyRange is the rule type that specifies a list of key ranges. + KeyRange = "key-range" +) + +const ( + scheduleOptionLabel = "schedule" + scheduleOptioonValueDeny = "deny" +) + +// KeyRangeRule contains the start key and end key of the LabelRule. +// NOTE: This type is exported by HTTP API. Please pay more attention when modifying it. +type KeyRangeRule struct { + StartKey []byte `json:"-"` // range start key + StartKeyHex string `json:"start_key"` // hex format start key, for marshal/unmarshal + EndKey []byte `json:"-"` // range end key + EndKeyHex string `json:"end_key"` // hex format end key, for marshal/unmarshal +} + +// LabelRulePatch is the patch to update the label rules. +// NOTE: This type is exported by HTTP API. Please pay more attention when modifying it. +type LabelRulePatch struct { + SetRules []*LabelRule `json:"sets"` + DeleteRules []string `json:"deletes"` +} + +func (l *RegionLabel) expireBefore(t time.Time) bool { + failpoint.Inject("regionLabelExpireSub1Minute", func() { + if l.expire != nil { + *l.expire = l.expire.Add(-time.Minute) + } + }) + if l.expire == nil { + return false + } + return l.expire.Before(t) +} + +func (l *RegionLabel) checkAndAdjustExpire() (err error) { + if len(l.TTL) == 0 { + l.expire = nil + return + } + ttl, err := time.ParseDuration(l.TTL) + if err != nil { + return err + } + var startAt time.Time + if len(l.StartAt) == 0 { + startAt = time.Now() + l.StartAt = startAt.Format(time.UnixDate) + } else { + startAt, err = time.Parse(time.UnixDate, l.StartAt) + if err != nil { + return err + } + } + expire := startAt.Add(ttl) + l.expire = &expire + return nil +} + +func (rule *LabelRule) checkAndRemoveExpireLabels(now time.Time) bool { + labels := make([]RegionLabel, 0) + rule.minExpire = nil + for _, l := range rule.Labels { + if l.expireBefore(now) { + continue + } + labels = append(labels, l) + if rule.minExpire == nil || l.expireBefore(*rule.minExpire) { + rule.minExpire = l.expire + } + } + + if len(labels) == len(rule.Labels) { + return false + } + rule.Labels = labels + return true +} + +func (rule *LabelRule) checkAndAdjust() error { + if rule.ID == "" { + return errs.ErrRegionRuleContent.FastGenByArgs("empty rule id") + } + if len(rule.Labels) == 0 { + return errs.ErrRegionRuleContent.FastGenByArgs("no region labels") + } + for id, l := range rule.Labels { + if l.Key == "" { + return errs.ErrRegionRuleContent.FastGenByArgs("empty region label key") + } + if l.Value == "" { + return errs.ErrRegionRuleContent.FastGenByArgs("empty region label value") + } + if err := rule.Labels[id].checkAndAdjustExpire(); err != nil { + err := fmt.Sprintf("region label with invalid ttl info %v", err) + return errs.ErrRegionRuleContent.FastGenByArgs(err) + } + } + rule.checkAndRemoveExpireLabels(time.Now()) + if len(rule.Labels) == 0 { + return errs.ErrRegionRuleContent.FastGenByArgs("region label with expired ttl") + } + + // TODO: change it to switch statement once we support more types. 
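The TODO above asks for a switch once more rule types exist; a behavior-equivalent sketch of that shape, using the same calls as the if/else that follows (only KeyRange is real today, so any additional case would be hypothetical):

	switch rule.RuleType {
	case KeyRange:
		var err error
		rule.Data, err = initKeyRangeRulesFromLabelRuleData(rule.Data)
		return err
	default:
		log.Error("invalid rule type", zap.String("rule-type", rule.RuleType))
		return errs.ErrRegionRuleContent.FastGenByArgs(fmt.Sprintf("invalid rule type: %s", rule.RuleType))
	}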
+ if rule.RuleType == KeyRange { + var err error + rule.Data, err = initKeyRangeRulesFromLabelRuleData(rule.Data) + return err + } + log.Error("invalid rule type", zap.String("rule-type", rule.RuleType)) + return errs.ErrRegionRuleContent.FastGenByArgs(fmt.Sprintf("invalid rule type: %s", rule.RuleType)) +} + +func (rule *LabelRule) expireBefore(t time.Time) bool { + if rule.minExpire == nil { + return false + } + return rule.minExpire.Before(t) +} + +// initKeyRangeRulesFromLabelRuleData init and adjust []KeyRangeRule from `LabelRule.Data“ +func initKeyRangeRulesFromLabelRuleData(data interface{}) ([]*KeyRangeRule, error) { + rules, ok := data.([]interface{}) + if !ok { + return nil, errs.ErrRegionRuleContent.FastGenByArgs(fmt.Sprintf("invalid rule type: %T", data)) + } + if len(rules) == 0 { + return nil, errs.ErrRegionRuleContent.FastGenByArgs("no key ranges") + } + rs := make([]*KeyRangeRule, 0, len(rules)) + for _, r := range rules { + rr, err := initAndAdjustKeyRangeRule(r) + if err != nil { + return nil, err + } + rs = append(rs, rr) + } + return rs, nil +} + +// initAndAdjustKeyRangeRule inits and adjusts the KeyRangeRule from one item in `LabelRule.Data` +func initAndAdjustKeyRangeRule(rule interface{}) (*KeyRangeRule, error) { + data, ok := rule.(map[string]interface{}) + if !ok { + return nil, errs.ErrRegionRuleContent.FastGenByArgs(fmt.Sprintf("invalid rule type: %T", reflect.TypeOf(rule))) + } + startKey, ok := data["start_key"].(string) + if !ok { + return nil, errs.ErrRegionRuleContent.FastGenByArgs(fmt.Sprintf("invalid startKey type: %T", reflect.TypeOf(data["start_key"]))) + } + endKey, ok := data["end_key"].(string) + if !ok { + return nil, errs.ErrRegionRuleContent.FastGenByArgs(fmt.Sprintf("invalid endKey type: %T", reflect.TypeOf(data["end_key"]))) + } + var r KeyRangeRule + r.StartKeyHex, r.EndKeyHex = startKey, endKey + var err error + r.StartKey, err = hex.DecodeString(r.StartKeyHex) + if err != nil { + return nil, errs.ErrHexDecodingString.FastGenByArgs(r.StartKeyHex) + } + r.EndKey, err = hex.DecodeString(r.EndKeyHex) + if err != nil { + return nil, errs.ErrHexDecodingString.FastGenByArgs(r.EndKeyHex) + } + if len(r.EndKey) > 0 && bytes.Compare(r.EndKey, r.StartKey) <= 0 { + return nil, errs.ErrRegionRuleContent.FastGenByArgs("endKey should be greater than startKey") + } + return &r, nil +} diff --git a/server/schedule/operator_controller.go b/server/schedule/operator_controller.go old mode 100644 new mode 100755 index c7b2cc99d34..09cd4c00abc --- a/server/schedule/operator_controller.go +++ b/server/schedule/operator_controller.go @@ -102,9 +102,9 @@ func (oc *OperatorController) GetCluster() Cluster { func (oc *OperatorController) Dispatch(region *core.RegionInfo, source string) { // Check existed operator. if op := oc.GetOperator(region.GetID()); op != nil { - failpoint.Inject("concurrentRemoveOperator", func() { + if _, _err_ := failpoint.Eval(_curpkg_("concurrentRemoveOperator")); _err_ == nil { time.Sleep(500 * time.Millisecond) - }) + } // Update operator status: // The operator status should be STARTED. 
@@ -143,9 +143,9 @@ func (oc *OperatorController) Dispatch(region *core.RegionInfo, source string) { zap.String("status", operator.OpStatusToString(op.Status())), zap.Reflect("operator", op), errs.ZapError(errs.ErrUnexpectedOperatorStatus)) operatorWaitCounter.WithLabelValues(op.Desc(), "unexpected").Inc() - failpoint.Inject("unexpectedOperator", func() { + if _, _err_ := failpoint.Eval(_curpkg_("unexpectedOperator")); _err_ == nil { panic(op) - }) + } _ = op.Cancel() oc.buryOperator(op) operatorWaitCounter.WithLabelValues(op.Desc(), "promote-unexpected").Inc() @@ -408,9 +408,9 @@ func (oc *OperatorController) checkAddOperator(isPromoting bool, ops ...*operato zap.Uint64("region-id", op.RegionID()), zap.String("status", operator.OpStatusToString(op.Status())), zap.Reflect("operator", op), errs.ZapError(errs.ErrUnexpectedOperatorStatus)) - failpoint.Inject("unexpectedOperator", func() { + if _, _err_ := failpoint.Eval(_curpkg_("unexpectedOperator")); _err_ == nil { panic(op) - }) + } operatorWaitCounter.WithLabelValues(op.Desc(), "unexpected-status").Inc() return false } @@ -467,9 +467,9 @@ func (oc *OperatorController) addOperatorLocked(op *operator.Operator) bool { zap.Uint64("region-id", regionID), zap.String("status", operator.OpStatusToString(op.Status())), zap.Reflect("operator", op), errs.ZapError(errs.ErrUnexpectedOperatorStatus)) - failpoint.Inject("unexpectedOperator", func() { + if _, _err_ := failpoint.Eval(_curpkg_("unexpectedOperator")); _err_ == nil { panic(op) - }) + } operatorCounter.WithLabelValues(op.Desc(), "unexpected").Inc() return false } @@ -553,9 +553,9 @@ func (oc *OperatorController) buryOperator(op *operator.Operator, extraFields .. zap.Uint64("region-id", op.RegionID()), zap.String("status", operator.OpStatusToString(op.Status())), zap.Reflect("operator", op), errs.ZapError(errs.ErrUnexpectedOperatorStatus)) - failpoint.Inject("unexpectedOperator", func() { + if _, _err_ := failpoint.Eval(_curpkg_("unexpectedOperator")); _err_ == nil { panic(op) - }) + } operatorCounter.WithLabelValues(op.Desc(), "unexpected").Inc() _ = op.Cancel() } diff --git a/server/schedule/operator_controller.go__failpoint_stash__ b/server/schedule/operator_controller.go__failpoint_stash__ new file mode 100644 index 00000000000..c7b2cc99d34 --- /dev/null +++ b/server/schedule/operator_controller.go__failpoint_stash__ @@ -0,0 +1,869 @@ +// Copyright 2018 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package schedule + +import ( + "container/heap" + "context" + "fmt" + "strconv" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/cache" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/syncutil" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/core/storelimit" + "github.com/tikv/pd/server/schedule/hbstream" + "github.com/tikv/pd/server/schedule/labeler" + "github.com/tikv/pd/server/schedule/operator" + "github.com/tikv/pd/server/versioninfo" + "go.uber.org/zap" +) + +// The source of dispatched region. +const ( + DispatchFromHeartBeat = "heartbeat" + DispatchFromNotifierQueue = "active push" + DispatchFromCreate = "create" +) + +var ( + slowNotifyInterval = 5 * time.Second + fastNotifyInterval = 2 * time.Second + // PushOperatorTickInterval is the interval try to push the operator. + PushOperatorTickInterval = 500 * time.Millisecond + // StoreBalanceBaseTime represents the base time of balance rate. + StoreBalanceBaseTime float64 = 60 + // FastOperatorFinishTime min finish time, if finish duration less than it,op will be pushed to fast operator queue + FastOperatorFinishTime = 10 * time.Second +) + +// OperatorController is used to limit the speed of scheduling. +type OperatorController struct { + syncutil.RWMutex + ctx context.Context + cluster Cluster + operators map[uint64]*operator.Operator + hbStreams *hbstream.HeartbeatStreams + fastOperators *cache.TTLUint64 + counts map[operator.OpKind]uint64 + opRecords *OperatorRecords + wop WaitingOperator + wopStatus *WaitingOperatorStatus + opNotifierQueue operatorQueue +} + +// NewOperatorController creates a OperatorController. +func NewOperatorController(ctx context.Context, cluster Cluster, hbStreams *hbstream.HeartbeatStreams) *OperatorController { + return &OperatorController{ + ctx: ctx, + cluster: cluster, + operators: make(map[uint64]*operator.Operator), + hbStreams: hbStreams, + fastOperators: cache.NewIDTTL(ctx, time.Minute, FastOperatorFinishTime), + counts: make(map[operator.OpKind]uint64), + opRecords: NewOperatorRecords(ctx), + wop: NewRandBuckets(), + wopStatus: NewWaitingOperatorStatus(), + opNotifierQueue: make(operatorQueue, 0), + } +} + +// Ctx returns a context which will be canceled once RaftCluster is stopped. +// For now, it is only used to control the lifetime of TTL cache in schedulers. +func (oc *OperatorController) Ctx() context.Context { + return oc.ctx +} + +// GetCluster exports cluster to evict-scheduler for check store status. +func (oc *OperatorController) GetCluster() Cluster { + oc.RLock() + defer oc.RUnlock() + return oc.cluster +} + +// Dispatch is used to dispatch the operator of a region. +func (oc *OperatorController) Dispatch(region *core.RegionInfo, source string) { + // Check existed operator. + if op := oc.GetOperator(region.GetID()); op != nil { + failpoint.Inject("concurrentRemoveOperator", func() { + time.Sleep(500 * time.Millisecond) + }) + + // Update operator status: + // The operator status should be STARTED. + // Check will call CheckSuccess and CheckTimeout. 
+ step := op.Check(region) + switch op.Status() { + case operator.STARTED: + operatorCounter.WithLabelValues(op.Desc(), "check").Inc() + if source == DispatchFromHeartBeat && oc.checkStaleOperator(op, step, region) { + return + } + oc.SendScheduleCommand(region, step, source) + case operator.SUCCESS: + if op.ContainNonWitnessStep() { + oc.cluster.RecordOpStepWithTTL(op.RegionID()) + } + if oc.RemoveOperator(op) { + operatorWaitCounter.WithLabelValues(op.Desc(), "promote-success").Inc() + oc.PromoteWaitingOperator() + } + if time.Since(op.GetStartTime()) < FastOperatorFinishTime { + log.Debug("op finish duration less than 10s", zap.Uint64("region-id", op.RegionID())) + oc.pushFastOperator(op) + } + case operator.TIMEOUT: + if oc.RemoveOperator(op) { + operatorCounter.WithLabelValues(op.Desc(), "promote-timeout").Inc() + oc.PromoteWaitingOperator() + } + default: + if oc.removeOperatorWithoutBury(op) { + // CREATED, EXPIRED must not appear. + // CANCELED, REPLACED must remove before transition. + log.Error("dispatching operator with unexpected status", + zap.Uint64("region-id", op.RegionID()), + zap.String("status", operator.OpStatusToString(op.Status())), + zap.Reflect("operator", op), errs.ZapError(errs.ErrUnexpectedOperatorStatus)) + operatorWaitCounter.WithLabelValues(op.Desc(), "unexpected").Inc() + failpoint.Inject("unexpectedOperator", func() { + panic(op) + }) + _ = op.Cancel() + oc.buryOperator(op) + operatorWaitCounter.WithLabelValues(op.Desc(), "promote-unexpected").Inc() + oc.PromoteWaitingOperator() + } + } + } +} + +func (oc *OperatorController) checkStaleOperator(op *operator.Operator, step operator.OpStep, region *core.RegionInfo) bool { + err := step.CheckInProgress(oc.cluster, region) + if err != nil { + if oc.RemoveOperator(op, zap.String("reason", err.Error())) { + operatorCounter.WithLabelValues(op.Desc(), "stale").Inc() + operatorWaitCounter.WithLabelValues(op.Desc(), "promote-stale").Inc() + oc.PromoteWaitingOperator() + return true + } + } + // When the "source" is heartbeat, the region may have a newer + // confver than the region that the operator holds. In this case, + // the operator is stale, and will not be executed even we would + // have sent it to TiKV servers. Here, we just cancel it. + origin := op.RegionEpoch() + latest := region.GetRegionEpoch() + changes := latest.GetConfVer() - origin.GetConfVer() + if changes > op.ConfVerChanged(region) { + if oc.RemoveOperator( + op, + zap.String("reason", "stale operator, confver does not meet expectations"), + zap.Reflect("latest-epoch", region.GetRegionEpoch()), + zap.Uint64("diff", changes), + ) { + operatorCounter.WithLabelValues(op.Desc(), "stale").Inc() + operatorWaitCounter.WithLabelValues(op.Desc(), "promote-stale").Inc() + oc.PromoteWaitingOperator() + return true + } + } + + return false +} + +func (oc *OperatorController) getNextPushOperatorTime(step operator.OpStep, now time.Time) time.Time { + nextTime := slowNotifyInterval + switch step.(type) { + case operator.TransferLeader, operator.PromoteLearner, operator.ChangePeerV2Enter, operator.ChangePeerV2Leave: + nextTime = fastNotifyInterval + } + return now.Add(nextTime) +} + +// pollNeedDispatchRegion returns the region need to dispatch, +// "next" is true to indicate that it may exist in next attempt, +// and false is the end for the poll. 
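pollNeedDispatchRegion below leans on a container/heap keyed by each operator's next notify time: the earliest entry is popped, and an entry that is not yet due is pushed back, ending the poll. A minimal sketch of such a time-ordered queue; the real operatorQueue is defined elsewhere in this package, so the names here are illustrative.

// Assumed imports: container/heap (for usage), time.
// timedItem pairs a region's ID with the time its operator should next be pushed.
type timedItem struct {
	regionID uint64
	when     time.Time
}

// timedQueue implements heap.Interface, ordered by the earliest "when".
type timedQueue []*timedItem

func (q timedQueue) Len() int           { return len(q) }
func (q timedQueue) Less(i, j int) bool { return q[i].when.Before(q[j].when) }
func (q timedQueue) Swap(i, j int)      { q[i], q[j] = q[j], q[i] }

func (q *timedQueue) Push(x interface{}) { *q = append(*q, x.(*timedItem)) }
func (q *timedQueue) Pop() interface{} {
	old := *q
	n := len(old)
	it := old[n-1]
	*q = old[:n-1]
	return it
}

// Usage: heap.Init(&q); heap.Push(&q, item); next := heap.Pop(&q).(*timedItem);
// if next.when is still in the future, push it back and stop polling.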
+func (oc *OperatorController) pollNeedDispatchRegion() (r *core.RegionInfo, next bool) { + oc.Lock() + defer oc.Unlock() + if oc.opNotifierQueue.Len() == 0 { + return nil, false + } + item := heap.Pop(&oc.opNotifierQueue).(*operatorWithTime) + regionID := item.op.RegionID() + op, ok := oc.operators[regionID] + if !ok || op == nil { + return nil, true + } + r = oc.cluster.GetRegion(regionID) + if r == nil { + _ = oc.removeOperatorLocked(op) + if op.Cancel() { + log.Warn("remove operator because region disappeared", + zap.Uint64("region-id", op.RegionID()), + zap.Stringer("operator", op)) + operatorCounter.WithLabelValues(op.Desc(), "disappear").Inc() + } + oc.buryOperator(op) + return nil, true + } + step := op.Check(r) + if step == nil { + return r, true + } + now := time.Now() + if now.Before(item.time) { + heap.Push(&oc.opNotifierQueue, item) + return nil, false + } + + // pushes with new notify time. + item.time = oc.getNextPushOperatorTime(step, now) + heap.Push(&oc.opNotifierQueue, item) + return r, true +} + +// PushOperators periodically pushes the unfinished operator to the executor(TiKV). +func (oc *OperatorController) PushOperators() { + for { + r, next := oc.pollNeedDispatchRegion() + if !next { + break + } + if r == nil { + continue + } + + oc.Dispatch(r, DispatchFromNotifierQueue) + } +} + +// AddWaitingOperator adds operators to waiting operators. +func (oc *OperatorController) AddWaitingOperator(ops ...*operator.Operator) int { + oc.Lock() + added := 0 + needPromoted := 0 + + for i := 0; i < len(ops); i++ { + op := ops[i] + desc := op.Desc() + isMerge := false + if op.Kind()&operator.OpMerge != 0 { + if i+1 >= len(ops) { + // should not be here forever + log.Error("orphan merge operators found", zap.String("desc", desc), errs.ZapError(errs.ErrMergeOperator.FastGenByArgs("orphan operator found"))) + oc.Unlock() + return added + } + if ops[i+1].Kind()&operator.OpMerge == 0 { + log.Error("merge operator should be paired", zap.String("desc", + ops[i+1].Desc()), errs.ZapError(errs.ErrMergeOperator.FastGenByArgs("operator should be paired"))) + oc.Unlock() + return added + } + isMerge = true + } + if !oc.checkAddOperator(false, op) { + _ = op.Cancel() + oc.buryOperator(op) + if isMerge { + // Merge operation have two operators, cancel them all + i++ + next := ops[i] + _ = next.Cancel() + oc.buryOperator(next) + } + continue + } + oc.wop.PutOperator(op) + if isMerge { + // count two merge operators as one, so wopStatus.ops[desc] should + // not be updated here + i++ + added++ + oc.wop.PutOperator(ops[i]) + } + operatorWaitCounter.WithLabelValues(desc, "put").Inc() + oc.wopStatus.ops[desc]++ + added++ + needPromoted++ + } + + oc.Unlock() + operatorWaitCounter.WithLabelValues(ops[0].Desc(), "promote-add").Add(float64(needPromoted)) + for i := 0; i < needPromoted; i++ { + oc.PromoteWaitingOperator() + } + return added +} + +// AddOperator adds operators to the running operators. +func (oc *OperatorController) AddOperator(ops ...*operator.Operator) bool { + oc.Lock() + defer oc.Unlock() + + // note: checkAddOperator uses false param for `isPromoting`. + // This is used to keep check logic before fixing issue #4946, + // but maybe user want to add operator when waiting queue is busy + if oc.exceedStoreLimitLocked(ops...) || !oc.checkAddOperator(false, ops...) 
{ + for _, op := range ops { + _ = op.Cancel() + oc.buryOperator(op) + } + return false + } + for _, op := range ops { + if !oc.addOperatorLocked(op) { + return false + } + } + return true +} + +// PromoteWaitingOperator promotes operators from waiting operators. +func (oc *OperatorController) PromoteWaitingOperator() { + oc.Lock() + defer oc.Unlock() + var ops []*operator.Operator + for { + // GetOperator returns one operator or two merge operators + ops = oc.wop.GetOperator() + if ops == nil { + return + } + operatorWaitCounter.WithLabelValues(ops[0].Desc(), "get").Inc() + + if oc.exceedStoreLimitLocked(ops...) || !oc.checkAddOperator(true, ops...) { + for _, op := range ops { + operatorWaitCounter.WithLabelValues(op.Desc(), "promote-canceled").Inc() + _ = op.Cancel() + oc.buryOperator(op) + } + oc.wopStatus.ops[ops[0].Desc()]-- + continue + } + oc.wopStatus.ops[ops[0].Desc()]-- + break + } + + for _, op := range ops { + if !oc.addOperatorLocked(op) { + break + } + } +} + +// checkAddOperator checks if the operator can be added. +// There are several situations that cannot be added: +// - There is no such region in the cluster +// - The epoch of the operator and the epoch of the corresponding region are no longer consistent. +// - The region already has a higher priority or same priority operator. +// - Exceed the max number of waiting operators +// - At least one operator is expired. +func (oc *OperatorController) checkAddOperator(isPromoting bool, ops ...*operator.Operator) bool { + for _, op := range ops { + region := oc.cluster.GetRegion(op.RegionID()) + if region == nil { + log.Debug("region not found, cancel add operator", + zap.Uint64("region-id", op.RegionID())) + operatorWaitCounter.WithLabelValues(op.Desc(), "not-found").Inc() + return false + } + if region.GetRegionEpoch().GetVersion() != op.RegionEpoch().GetVersion() || + region.GetRegionEpoch().GetConfVer() != op.RegionEpoch().GetConfVer() { + log.Debug("region epoch not match, cancel add operator", + zap.Uint64("region-id", op.RegionID()), + zap.Reflect("old", region.GetRegionEpoch()), + zap.Reflect("new", op.RegionEpoch())) + operatorWaitCounter.WithLabelValues(op.Desc(), "epoch-not-match").Inc() + return false + } + if old := oc.operators[op.RegionID()]; old != nil && !isHigherPriorityOperator(op, old) { + log.Debug("already have operator, cancel add operator", + zap.Uint64("region-id", op.RegionID()), + zap.Reflect("old", old)) + operatorWaitCounter.WithLabelValues(op.Desc(), "already-have").Inc() + return false + } + if op.Status() != operator.CREATED { + log.Error("trying to add operator with unexpected status", + zap.Uint64("region-id", op.RegionID()), + zap.String("status", operator.OpStatusToString(op.Status())), + zap.Reflect("operator", op), errs.ZapError(errs.ErrUnexpectedOperatorStatus)) + failpoint.Inject("unexpectedOperator", func() { + panic(op) + }) + operatorWaitCounter.WithLabelValues(op.Desc(), "unexpected-status").Inc() + return false + } + if !isPromoting && oc.wopStatus.ops[op.Desc()] >= oc.cluster.GetOpts().GetSchedulerMaxWaitingOperator() { + log.Debug("exceed max return false", zap.Uint64("waiting", oc.wopStatus.ops[op.Desc()]), zap.String("desc", op.Desc()), zap.Uint64("max", oc.cluster.GetOpts().GetSchedulerMaxWaitingOperator())) + operatorWaitCounter.WithLabelValues(op.Desc(), "exceed-max").Inc() + return false + } + + if op.SchedulerKind() == operator.OpAdmin || op.IsLeaveJointStateOperator() { + continue + } + if cl, ok := oc.cluster.(interface{ GetRegionLabeler() *labeler.RegionLabeler }); ok 
{ + l := cl.GetRegionLabeler() + if l.ScheduleDisabled(region) { + log.Debug("schedule disabled", zap.Uint64("region-id", op.RegionID())) + operatorWaitCounter.WithLabelValues(op.Desc(), "schedule-disabled").Inc() + return false + } + } + } + expired := false + for _, op := range ops { + if op.CheckExpired() { + expired = true + operatorWaitCounter.WithLabelValues(op.Desc(), "expired").Inc() + } + } + return !expired +} + +func isHigherPriorityOperator(new, old *operator.Operator) bool { + return new.GetPriorityLevel() > old.GetPriorityLevel() +} + +func (oc *OperatorController) addOperatorLocked(op *operator.Operator) bool { + regionID := op.RegionID() + + log.Info("add operator", + zap.Uint64("region-id", regionID), + zap.Reflect("operator", op), + zap.String("additional-info", op.GetAdditionalInfo())) + + // If there is an old operator, replace it. The priority should be checked + // already. + if old, ok := oc.operators[regionID]; ok { + _ = oc.removeOperatorLocked(old) + _ = old.Replace() + oc.buryOperator(old) + } + + if !op.Start() { + log.Error("adding operator with unexpected status", + zap.Uint64("region-id", regionID), + zap.String("status", operator.OpStatusToString(op.Status())), + zap.Reflect("operator", op), errs.ZapError(errs.ErrUnexpectedOperatorStatus)) + failpoint.Inject("unexpectedOperator", func() { + panic(op) + }) + operatorCounter.WithLabelValues(op.Desc(), "unexpected").Inc() + return false + } + oc.operators[regionID] = op + operatorCounter.WithLabelValues(op.Desc(), "start").Inc() + operatorSizeHist.WithLabelValues(op.Desc()).Observe(float64(op.ApproximateSize)) + operatorWaitDuration.WithLabelValues(op.Desc()).Observe(op.ElapsedTime().Seconds()) + opInfluence := NewTotalOpInfluence([]*operator.Operator{op}, oc.cluster) + for storeID := range opInfluence.StoresInfluence { + store := oc.cluster.GetStore(storeID) + if store == nil { + log.Info("missing store", zap.Uint64("store-id", storeID)) + continue + } + limit := store.GetStoreLimit() + for n, v := range storelimit.TypeNameValue { + stepCost := opInfluence.GetStoreInfluence(storeID).GetStepCost(v) + if stepCost == 0 { + continue + } + limit.Take(stepCost, v) + storeLimitCostCounter.WithLabelValues(strconv.FormatUint(storeID, 10), n).Add(float64(stepCost) / float64(storelimit.RegionInfluence[v])) + } + } + oc.updateCounts(oc.operators) + + var step operator.OpStep + if region := oc.cluster.GetRegion(op.RegionID()); region != nil { + if step = op.Check(region); step != nil { + oc.SendScheduleCommand(region, step, DispatchFromCreate) + } + } + + heap.Push(&oc.opNotifierQueue, &operatorWithTime{op: op, time: oc.getNextPushOperatorTime(step, time.Now())}) + operatorCounter.WithLabelValues(op.Desc(), "create").Inc() + for _, counter := range op.Counters { + counter.Inc() + } + return true +} + +// RemoveOperator removes a operator from the running operators. +func (oc *OperatorController) RemoveOperator(op *operator.Operator, extraFields ...zap.Field) bool { + oc.Lock() + removed := oc.removeOperatorLocked(op) + oc.Unlock() + if removed { + if op.Cancel() { + log.Info("operator removed", + zap.Uint64("region-id", op.RegionID()), + zap.Duration("takes", op.RunningTime()), + zap.Reflect("operator", op)) + } + oc.buryOperator(op, extraFields...) 
+ } + return removed +} + +func (oc *OperatorController) removeOperatorWithoutBury(op *operator.Operator) bool { + oc.Lock() + defer oc.Unlock() + return oc.removeOperatorLocked(op) +} + +func (oc *OperatorController) removeOperatorLocked(op *operator.Operator) bool { + regionID := op.RegionID() + if cur := oc.operators[regionID]; cur == op { + delete(oc.operators, regionID) + oc.updateCounts(oc.operators) + operatorCounter.WithLabelValues(op.Desc(), "remove").Inc() + return true + } + return false +} + +func (oc *OperatorController) buryOperator(op *operator.Operator, extraFields ...zap.Field) { + st := op.Status() + + if !operator.IsEndStatus(st) { + log.Error("burying operator with non-end status", + zap.Uint64("region-id", op.RegionID()), + zap.String("status", operator.OpStatusToString(op.Status())), + zap.Reflect("operator", op), errs.ZapError(errs.ErrUnexpectedOperatorStatus)) + failpoint.Inject("unexpectedOperator", func() { + panic(op) + }) + operatorCounter.WithLabelValues(op.Desc(), "unexpected").Inc() + _ = op.Cancel() + } + + switch st { + case operator.SUCCESS: + log.Info("operator finish", + zap.Uint64("region-id", op.RegionID()), + zap.Duration("takes", op.RunningTime()), + zap.Reflect("operator", op), + zap.String("additional-info", op.GetAdditionalInfo())) + operatorCounter.WithLabelValues(op.Desc(), "finish").Inc() + operatorDuration.WithLabelValues(op.Desc()).Observe(op.RunningTime().Seconds()) + for _, counter := range op.FinishedCounters { + counter.Inc() + } + case operator.REPLACED: + log.Info("replace old operator", + zap.Uint64("region-id", op.RegionID()), + zap.Duration("takes", op.RunningTime()), + zap.Reflect("operator", op), + zap.String("additional-info", op.GetAdditionalInfo())) + operatorCounter.WithLabelValues(op.Desc(), "replace").Inc() + case operator.EXPIRED: + log.Info("operator expired", + zap.Uint64("region-id", op.RegionID()), + zap.Duration("lives", op.ElapsedTime()), + zap.Reflect("operator", op)) + operatorCounter.WithLabelValues(op.Desc(), "expire").Inc() + case operator.TIMEOUT: + log.Info("operator timeout", + zap.Uint64("region-id", op.RegionID()), + zap.Duration("takes", op.RunningTime()), + zap.Reflect("operator", op), + zap.String("additional-info", op.GetAdditionalInfo())) + operatorCounter.WithLabelValues(op.Desc(), "timeout").Inc() + case operator.CANCELED: + fields := []zap.Field{ + zap.Uint64("region-id", op.RegionID()), + zap.Duration("takes", op.RunningTime()), + zap.Reflect("operator", op), + zap.String("additional-info", op.GetAdditionalInfo()), + } + fields = append(fields, extraFields...) + log.Info("operator canceled", + fields..., + ) + operatorCounter.WithLabelValues(op.Desc(), "cancel").Inc() + } + + oc.opRecords.Put(op) +} + +// GetOperatorStatus gets the operator and its status with the specify id. +func (oc *OperatorController) GetOperatorStatus(id uint64) *OperatorWithStatus { + oc.Lock() + defer oc.Unlock() + if op, ok := oc.operators[id]; ok { + return NewOperatorWithStatus(op) + } + return oc.opRecords.Get(id) +} + +// GetOperator gets a operator from the given region. +func (oc *OperatorController) GetOperator(regionID uint64) *operator.Operator { + oc.RLock() + defer oc.RUnlock() + return oc.operators[regionID] +} + +// GetOperators gets operators from the running operators. 
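+// The result is a snapshot taken under the read lock; it is built from a map,
+// so the order of the returned operators is not deterministic.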
+func (oc *OperatorController) GetOperators() []*operator.Operator { + oc.RLock() + defer oc.RUnlock() + + operators := make([]*operator.Operator, 0, len(oc.operators)) + for _, op := range oc.operators { + operators = append(operators, op) + } + + return operators +} + +// GetWaitingOperators gets operators from the waiting operators. +func (oc *OperatorController) GetWaitingOperators() []*operator.Operator { + oc.RLock() + defer oc.RUnlock() + return oc.wop.ListOperator() +} + +// SendScheduleCommand sends a command to the region. +func (oc *OperatorController) SendScheduleCommand(region *core.RegionInfo, step operator.OpStep, source string) { + log.Info("send schedule command", + zap.Uint64("region-id", region.GetID()), + zap.Stringer("step", step), + zap.String("source", source)) + + useConfChangeV2 := versioninfo.IsFeatureSupported(oc.cluster.GetOpts().GetClusterVersion(), versioninfo.ConfChangeV2) + cmd := step.GetCmd(region, useConfChangeV2) + if cmd == nil { + return + } + oc.hbStreams.SendMsg(region, cmd) +} + +func (oc *OperatorController) pushFastOperator(op *operator.Operator) { + oc.fastOperators.Put(op.RegionID(), op) +} + +// GetRecords gets operators' records. +func (oc *OperatorController) GetRecords(from time.Time) []*operator.OpRecord { + records := make([]*operator.OpRecord, 0, oc.opRecords.ttl.Len()) + for _, id := range oc.opRecords.ttl.GetAllID() { + op := oc.opRecords.Get(id) + if op == nil || op.FinishTime.Before(from) { + continue + } + records = append(records, op.Record(op.FinishTime)) + } + return records +} + +// GetHistory gets operators' history. +func (oc *OperatorController) GetHistory(start time.Time) []operator.OpHistory { + history := make([]operator.OpHistory, 0, oc.opRecords.ttl.Len()) + for _, id := range oc.opRecords.ttl.GetAllID() { + op := oc.opRecords.Get(id) + if op == nil || op.FinishTime.Before(start) { + continue + } + history = append(history, op.History()...) + } + return history +} + +// updateCounts updates resource counts using current pending operators. +func (oc *OperatorController) updateCounts(operators map[uint64]*operator.Operator) { + for k := range oc.counts { + delete(oc.counts, k) + } + for _, op := range operators { + oc.counts[op.SchedulerKind()]++ + } +} + +// OperatorCount gets the count of operators filtered by kind. +// kind only has one OpKind. +func (oc *OperatorController) OperatorCount(kind operator.OpKind) uint64 { + oc.RLock() + defer oc.RUnlock() + return oc.counts[kind] +} + +// GetOpInfluence gets OpInfluence. 
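+// Only operators that are still in progress (neither timed out nor already
+// successful) contribute their unfinished influence; operators whose region
+// can no longer be found are skipped.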
+func (oc *OperatorController) GetOpInfluence(cluster Cluster) operator.OpInfluence { + influence := operator.OpInfluence{ + StoresInfluence: make(map[uint64]*operator.StoreInfluence), + } + oc.RLock() + defer oc.RUnlock() + for _, op := range oc.operators { + if !op.CheckTimeout() && !op.CheckSuccess() { + region := cluster.GetRegion(op.RegionID()) + if region != nil { + op.UnfinishedInfluence(influence, region) + } + } + } + return influence +} + +// GetFastOpInfluence get fast finish operator influence +func (oc *OperatorController) GetFastOpInfluence(cluster Cluster, influence operator.OpInfluence) { + for _, id := range oc.fastOperators.GetAllID() { + value, ok := oc.fastOperators.Get(id) + if !ok { + continue + } + op, ok := value.(*operator.Operator) + if !ok { + continue + } + AddOpInfluence(op, influence, cluster) + } +} + +// AddOpInfluence add operator influence for cluster +func AddOpInfluence(op *operator.Operator, influence operator.OpInfluence, cluster Cluster) { + region := cluster.GetRegion(op.RegionID()) + if region != nil { + op.TotalInfluence(influence, region) + } +} + +// NewTotalOpInfluence creates a OpInfluence. +func NewTotalOpInfluence(operators []*operator.Operator, cluster Cluster) operator.OpInfluence { + influence := operator.OpInfluence{ + StoresInfluence: make(map[uint64]*operator.StoreInfluence), + } + + for _, op := range operators { + AddOpInfluence(op, influence, cluster) + } + + return influence +} + +// SetOperator is only used for test. +func (oc *OperatorController) SetOperator(op *operator.Operator) { + oc.Lock() + defer oc.Unlock() + oc.operators[op.RegionID()] = op + oc.updateCounts(oc.operators) +} + +// OperatorWithStatus records the operator and its status. +type OperatorWithStatus struct { + *operator.Operator + Status pdpb.OperatorStatus + FinishTime time.Time +} + +// NewOperatorWithStatus creates an OperatorStatus from an operator. +func NewOperatorWithStatus(op *operator.Operator) *OperatorWithStatus { + return &OperatorWithStatus{ + Operator: op, + Status: operator.OpStatusToPDPB(op.Status()), + FinishTime: time.Now(), + } +} + +// MarshalJSON returns the status of operator as a JSON string +func (o *OperatorWithStatus) MarshalJSON() ([]byte, error) { + return []byte(`"` + fmt.Sprintf("status: %s, operator: %s", o.Status.String(), o.Operator.String()) + `"`), nil +} + +// OperatorRecords remains the operator and its status for a while. +type OperatorRecords struct { + ttl *cache.TTLUint64 +} + +const operatorStatusRemainTime = 10 * time.Minute + +// NewOperatorRecords returns a OperatorRecords. +func NewOperatorRecords(ctx context.Context) *OperatorRecords { + return &OperatorRecords{ + ttl: cache.NewIDTTL(ctx, time.Minute, operatorStatusRemainTime), + } +} + +// Get gets the operator and its status. +func (o *OperatorRecords) Get(id uint64) *OperatorWithStatus { + v, exist := o.ttl.Get(id) + if !exist { + return nil + } + return v.(*OperatorWithStatus) +} + +// Put puts the operator and its status. +func (o *OperatorRecords) Put(op *operator.Operator) { + id := op.RegionID() + record := NewOperatorWithStatus(op) + o.ttl.Put(id, record) +} + +// ExceedStoreLimit returns true if the store exceeds the cost limit after adding the operator. Otherwise, returns false. +func (oc *OperatorController) ExceedStoreLimit(ops ...*operator.Operator) bool { + oc.Lock() + defer oc.Unlock() + return oc.exceedStoreLimitLocked(ops...) +} + +// exceedStoreLimitLocked returns true if the store exceeds the cost limit after adding the operator. 
Otherwise, returns false. +func (oc *OperatorController) exceedStoreLimitLocked(ops ...*operator.Operator) bool { + // The operator with Urgent priority, like admin operators, should ignore the store limit check. + if len(ops) != 0 && ops[0].GetPriorityLevel() == core.Urgent { + return false + } + opInfluence := NewTotalOpInfluence(ops, oc.cluster) + for storeID := range opInfluence.StoresInfluence { + for _, v := range storelimit.TypeNameValue { + stepCost := opInfluence.GetStoreInfluence(storeID).GetStepCost(v) + if stepCost == 0 { + continue + } + limiter := oc.getOrCreateStoreLimit(storeID, v) + if limiter == nil { + return false + } + if !limiter.Available(stepCost, v) { + return true + } + } + } + return false +} + +// getOrCreateStoreLimit is used to get or create the limit of a store. +func (oc *OperatorController) getOrCreateStoreLimit(storeID uint64, limitType storelimit.Type) storelimit.StoreLimit { + ratePerSec := oc.cluster.GetOpts().GetStoreLimitByType(storeID, limitType) / StoreBalanceBaseTime + s := oc.cluster.GetStore(storeID) + if s == nil { + log.Error("invalid store ID", zap.Uint64("store-id", storeID)) + return nil + } + + limit := s.GetStoreLimit() + limit.Reset(ratePerSec, limitType) + return limit +} diff --git a/server/schedule/region_scatterer.go b/server/schedule/region_scatterer.go old mode 100644 new mode 100755 index 13e4d9215d5..7be03d8eb37 --- a/server/schedule/region_scatterer.go +++ b/server/schedule/region_scatterer.go @@ -230,11 +230,11 @@ func (r *RegionScatterer) scatterRegions(regions map[uint64]*core.RegionInfo, fa for currentRetry := 0; currentRetry <= retryLimit; currentRetry++ { for _, region := range regions { op, err := r.Scatter(region, group) - failpoint.Inject("scatterFail", func() { + if _, _err_ := failpoint.Eval(_curpkg_("scatterFail")); _err_ == nil { if region.GetID() == 1 { err = errors.New("mock error") } - }) + } if err != nil { failures[region.GetID()] = err continue @@ -247,10 +247,10 @@ func (r *RegionScatterer) scatterRegions(regions map[uint64]*core.RegionInfo, fa failures[op.RegionID()] = fmt.Errorf("region %v failed to add operator", op.RegionID()) continue } - failpoint.Inject("scatterHbStreamsDrain", func() { + if _, _err_ := failpoint.Eval(_curpkg_("scatterHbStreamsDrain")); _err_ == nil { r.opController.hbStreams.Drain(1) r.opController.RemoveOperator(op) - }) + } } delete(failures, region.GetID()) } diff --git a/server/schedule/region_scatterer.go__failpoint_stash__ b/server/schedule/region_scatterer.go__failpoint_stash__ new file mode 100644 index 00000000000..13e4d9215d5 --- /dev/null +++ b/server/schedule/region_scatterer.go__failpoint_stash__ @@ -0,0 +1,545 @@ +// Copyright 2017 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
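For readers following the hunks above: the failpoint-ctl "enable" rewrite turns each failpoint.Inject marker into an explicit failpoint.Eval guard that only fires when the named failpoint has been activated, and the __failpoint_stash__ files keep the original marker form so the rewrite can be reverted. The snippet below is a minimal illustration of that pattern under an assumed package path and helper; it is not part of the patch itself.

package example

import (
	"errors"

	"github.com/pingcap/failpoint"
)

// _curpkg_ mirrors the generated binding helper: it prefixes a failpoint name
// with the package import path so names stay unique across packages.
func _curpkg_(name string) string {
	return "github.com/example/pkg/" + name
}

// scatterOnce shows the rewritten (Eval) form of a failpoint guard. Eval returns
// a nil error only while the failpoint is enabled, so the injected branch is
// skipped entirely in normal runs.
func scatterOnce(scatter func() error) error {
	err := scatter()
	if _, _err_ := failpoint.Eval(_curpkg_("scatterFail")); _err_ == nil {
		err = errors.New("mock error") // simulated failure, mirroring the hunk above
	}
	return err
}

A test would turn this on with failpoint.Enable("github.com/example/pkg/scatterFail", "return(true)") and switch it off again with failpoint.Disable.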
+ +package schedule + +import ( + "context" + "fmt" + "math" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/cache" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/syncutil" + "github.com/tikv/pd/pkg/typeutil" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/schedule/filter" + "github.com/tikv/pd/server/schedule/operator" + "github.com/tikv/pd/server/schedule/placement" + "go.uber.org/zap" +) + +const regionScatterName = "region-scatter" + +var gcInterval = time.Minute +var gcTTL = time.Minute * 3 + +type selectedStores struct { + mu syncutil.RWMutex + groupDistribution *cache.TTLString // value type: map[uint64]uint64, group -> StoreID -> count +} + +func newSelectedStores(ctx context.Context) *selectedStores { + return &selectedStores{ + groupDistribution: cache.NewStringTTL(ctx, gcInterval, gcTTL), + } +} + +// Put plus count by storeID and group +func (s *selectedStores) Put(id uint64, group string) { + s.mu.Lock() + defer s.mu.Unlock() + distribution, ok := s.getDistributionByGroupLocked(group) + if !ok { + distribution = map[uint64]uint64{} + distribution[id] = 0 + } + distribution[id]++ + s.groupDistribution.Put(group, distribution) +} + +// Get the count by storeID and group +func (s *selectedStores) Get(id uint64, group string) uint64 { + s.mu.RLock() + defer s.mu.RUnlock() + distribution, ok := s.getDistributionByGroupLocked(group) + if !ok { + return 0 + } + count, ok := distribution[id] + if !ok { + return 0 + } + return count +} + +// GetGroupDistribution get distribution group by `group` +func (s *selectedStores) GetGroupDistribution(group string) (map[uint64]uint64, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + return s.getDistributionByGroupLocked(group) +} + +// TotalCountByStore counts the total count by store +func (s *selectedStores) TotalCountByStore(storeID uint64) uint64 { + s.mu.RLock() + defer s.mu.RUnlock() + groups := s.groupDistribution.GetAllID() + totalCount := uint64(0) + for _, group := range groups { + storeDistribution, ok := s.getDistributionByGroupLocked(group) + if !ok { + continue + } + count, ok := storeDistribution[storeID] + if !ok { + continue + } + totalCount += count + } + return totalCount +} + +// getDistributionByGroupLocked should be called with lock +func (s *selectedStores) getDistributionByGroupLocked(group string) (map[uint64]uint64, bool) { + if result, ok := s.groupDistribution.Get(group); ok { + return result.(map[uint64]uint64), true + } + return nil, false +} + +// RegionScatterer scatters regions. +type RegionScatterer struct { + ctx context.Context + name string + cluster Cluster + ordinaryEngine engineContext + specialEngines sync.Map + opController *OperatorController +} + +// NewRegionScatterer creates a region scatterer. +// RegionScatter is used for the `Lightning`, it will scatter the specified regions before import data. 
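+// Peer and leader placements already chosen are tracked per group in TTL caches
+// (selectedPeer/selectedLeader), so repeated scatter calls keep spreading regions
+// evenly across stores instead of converging on the same targets.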
+func NewRegionScatterer(ctx context.Context, cluster Cluster, opController *OperatorController) *RegionScatterer { + return &RegionScatterer{ + ctx: ctx, + name: regionScatterName, + cluster: cluster, + opController: opController, + ordinaryEngine: newEngineContext(ctx, func() filter.Filter { + return filter.NewEngineFilter(regionScatterName, filter.NotSpecialEngines) + }), + } +} + +type filterFunc func() filter.Filter + +type engineContext struct { + filterFuncs []filterFunc + selectedPeer *selectedStores + selectedLeader *selectedStores +} + +func newEngineContext(ctx context.Context, filterFuncs ...filterFunc) engineContext { + filterFuncs = append(filterFuncs, func() filter.Filter { + return &filter.StoreStateFilter{ActionScope: regionScatterName, MoveRegion: true, ScatterRegion: true} + }) + return engineContext{ + filterFuncs: filterFuncs, + selectedPeer: newSelectedStores(ctx), + selectedLeader: newSelectedStores(ctx), + } +} + +const maxSleepDuration = time.Minute +const initialSleepDuration = 100 * time.Millisecond +const maxRetryLimit = 30 + +// ScatterRegionsByRange directly scatter regions by ScatterRegions +func (r *RegionScatterer) ScatterRegionsByRange(startKey, endKey []byte, group string, retryLimit int) (int, map[uint64]error, error) { + regions := r.cluster.ScanRegions(startKey, endKey, -1) + if len(regions) < 1 { + scatterCounter.WithLabelValues("skip", "empty-region").Inc() + return 0, nil, errors.New("empty region") + } + failures := make(map[uint64]error, len(regions)) + regionMap := make(map[uint64]*core.RegionInfo, len(regions)) + for _, region := range regions { + regionMap[region.GetID()] = region + } + // If there existed any region failed to relocated after retry, add it into unProcessedRegions + opsCount, err := r.scatterRegions(regionMap, failures, group, retryLimit) + if err != nil { + return 0, nil, err + } + return opsCount, failures, nil +} + +// ScatterRegionsByID directly scatter regions by ScatterRegions +func (r *RegionScatterer) ScatterRegionsByID(regionsID []uint64, group string, retryLimit int) (int, map[uint64]error, error) { + if len(regionsID) < 1 { + scatterCounter.WithLabelValues("skip", "empty-region").Inc() + return 0, nil, errors.New("empty region") + } + failures := make(map[uint64]error, len(regionsID)) + regions := make([]*core.RegionInfo, 0, len(regionsID)) + for _, id := range regionsID { + region := r.cluster.GetRegion(id) + if region == nil { + scatterCounter.WithLabelValues("skip", "no-region").Inc() + log.Warn("failed to find region during scatter", zap.Uint64("region-id", id)) + failures[id] = errors.New(fmt.Sprintf("failed to find region %v", id)) + continue + } + regions = append(regions, region) + } + regionMap := make(map[uint64]*core.RegionInfo, len(regions)) + for _, region := range regions { + regionMap[region.GetID()] = region + } + // If there existed any region failed to relocated after retry, add it into unProcessedRegions + opsCount, err := r.scatterRegions(regionMap, failures, group, retryLimit) + if err != nil { + return 0, nil, err + } + return opsCount, failures, nil +} + +// scatterRegions relocates the regions. If the group is defined, the regions' leader with the same group would be scattered +// in a group level instead of cluster level. +// RetryTimes indicates the retry times if any of the regions failed to relocate during scattering. There will be +// time.Sleep between each retry. 
+// Failures indicates the regions which are failed to be relocated, the key of the failures indicates the regionID +// and the value of the failures indicates the failure error. +func (r *RegionScatterer) scatterRegions(regions map[uint64]*core.RegionInfo, failures map[uint64]error, group string, retryLimit int) (int, error) { + if len(regions) < 1 { + scatterCounter.WithLabelValues("skip", "empty-region").Inc() + return 0, errors.New("empty region") + } + if retryLimit > maxRetryLimit { + retryLimit = maxRetryLimit + } + opsCount := 0 + for currentRetry := 0; currentRetry <= retryLimit; currentRetry++ { + for _, region := range regions { + op, err := r.Scatter(region, group) + failpoint.Inject("scatterFail", func() { + if region.GetID() == 1 { + err = errors.New("mock error") + } + }) + if err != nil { + failures[region.GetID()] = err + continue + } + delete(regions, region.GetID()) + opsCount++ + if op != nil { + if ok := r.opController.AddOperator(op); !ok { + // If there existed any operator failed to be added into Operator Controller, add its regions into unProcessedRegions + failures[op.RegionID()] = fmt.Errorf("region %v failed to add operator", op.RegionID()) + continue + } + failpoint.Inject("scatterHbStreamsDrain", func() { + r.opController.hbStreams.Drain(1) + r.opController.RemoveOperator(op) + }) + } + delete(failures, region.GetID()) + } + // all regions have been relocated, break the loop. + if len(regions) < 1 { + break + } + // Wait for a while if there are some regions failed to be relocated + time.Sleep(typeutil.MinDuration(maxSleepDuration, time.Duration(math.Pow(2, float64(currentRetry)))*initialSleepDuration)) + } + return opsCount, nil +} + +// Scatter relocates the region. If the group is defined, the regions' leader with the same group would be scattered +// in a group level instead of cluster level. 
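+// Regions that are not fully replicated, have no leader, or are currently hot
+// are rejected with an error instead of being scattered.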
+func (r *RegionScatterer) Scatter(region *core.RegionInfo, group string) (*operator.Operator, error) { + if !filter.IsRegionReplicated(r.cluster, region) { + r.cluster.AddSuspectRegions(region.GetID()) + scatterCounter.WithLabelValues("skip", "not-replicated").Inc() + log.Warn("region not replicated during scatter", zap.Uint64("region-id", region.GetID())) + return nil, errors.Errorf("region %d is not fully replicated", region.GetID()) + } + + if region.GetLeader() == nil { + scatterCounter.WithLabelValues("skip", "no-leader").Inc() + log.Warn("region no leader during scatter", zap.Uint64("region-id", region.GetID())) + return nil, errors.Errorf("region %d has no leader", region.GetID()) + } + + if r.cluster.IsRegionHot(region) { + scatterCounter.WithLabelValues("skip", "hot").Inc() + log.Warn("region too hot during scatter", zap.Uint64("region-id", region.GetID())) + return nil, errors.Errorf("region %d is hot", region.GetID()) + } + + return r.scatterRegion(region, group), nil +} + +func (r *RegionScatterer) scatterRegion(region *core.RegionInfo, group string) *operator.Operator { + engineFilter := filter.NewEngineFilter(r.name, filter.NotSpecialEngines) + ordinaryPeers := make(map[uint64]*metapb.Peer, len(region.GetPeers())) + specialPeers := make(map[string]map[uint64]*metapb.Peer) + oldFit := r.cluster.GetRuleManager().FitRegion(r.cluster, region) + // Group peers by the engine of their stores + for _, peer := range region.GetPeers() { + store := r.cluster.GetStore(peer.GetStoreId()) + if store == nil { + return nil + } + if engineFilter.Target(r.cluster.GetOpts(), store).IsOK() { + ordinaryPeers[peer.GetStoreId()] = peer + } else { + engine := store.GetLabelValue(core.EngineKey) + if _, ok := specialPeers[engine]; !ok { + specialPeers[engine] = make(map[uint64]*metapb.Peer) + } + specialPeers[engine][peer.GetStoreId()] = peer + } + } + + targetPeers := make(map[uint64]*metapb.Peer, len(region.GetPeers())) // StoreID -> Peer + selectedStores := make(map[uint64]struct{}, len(region.GetPeers())) // selected StoreID set + leaderCandidateStores := make([]uint64, 0, len(region.GetPeers())) // StoreID allowed to become Leader + scatterWithSameEngine := func(peers map[uint64]*metapb.Peer, context engineContext) { // peers: StoreID -> Peer + for _, peer := range peers { + if _, ok := selectedStores[peer.GetStoreId()]; ok { + if allowLeader(oldFit, peer) { + leaderCandidateStores = append(leaderCandidateStores, peer.GetStoreId()) + } + // It is both sourcePeer and targetPeer itself, no need to select. + continue + } + for { + candidates := r.selectCandidates(region, oldFit, peer.GetStoreId(), selectedStores, context) + newPeer := r.selectStore(group, peer, peer.GetStoreId(), candidates, context) + targetPeers[newPeer.GetStoreId()] = newPeer + selectedStores[newPeer.GetStoreId()] = struct{}{} + // If the selected peer is a peer other than origin peer in this region, + // it is considered that the selected peer select itself. + // This origin peer re-selects. + if _, ok := peers[newPeer.GetStoreId()]; !ok || peer.GetStoreId() == newPeer.GetStoreId() { + selectedStores[peer.GetStoreId()] = struct{}{} + if allowLeader(oldFit, peer) { + leaderCandidateStores = append(leaderCandidateStores, newPeer.GetStoreId()) + } + break + } + } + } + } + + scatterWithSameEngine(ordinaryPeers, r.ordinaryEngine) + // FIXME: target leader only considers the ordinary stores, maybe we need to consider the + // special engine stores if the engine supports to become a leader. 
But now there is only + // one engine, tiflash, which does not support the leader, so don't consider it for now. + targetLeader := r.selectAvailableLeaderStore(group, region, leaderCandidateStores, r.ordinaryEngine) + if targetLeader == 0 { + scatterCounter.WithLabelValues("no-leader", "").Inc() + return nil + } + + for engine, peers := range specialPeers { + ctx, ok := r.specialEngines.Load(engine) + if !ok { + ctx = newEngineContext(r.ctx, func() filter.Filter { + return filter.NewEngineFilter(r.name, placement.LabelConstraint{Key: core.EngineKey, Op: placement.In, Values: []string{engine}}) + }) + r.specialEngines.Store(engine, ctx) + } + scatterWithSameEngine(peers, ctx.(engineContext)) + } + + if isSameDistribution(region, targetPeers, targetLeader) { + scatterCounter.WithLabelValues("unnecessary", "").Inc() + r.Put(targetPeers, targetLeader, group) + return nil + } + op, err := operator.CreateScatterRegionOperator("scatter-region", r.cluster, region, targetPeers, targetLeader) + if err != nil { + scatterCounter.WithLabelValues("fail", "").Inc() + for _, peer := range region.GetPeers() { + targetPeers[peer.GetStoreId()] = peer + } + r.Put(targetPeers, region.GetLeader().GetStoreId(), group) + log.Debug("fail to create scatter region operator", errs.ZapError(err)) + return nil + } + if op != nil { + scatterCounter.WithLabelValues("success", "").Inc() + r.Put(targetPeers, targetLeader, group) + op.SetPriorityLevel(core.High) + } + return op +} + +func allowLeader(fit *placement.RegionFit, peer *metapb.Peer) bool { + switch peer.GetRole() { + case metapb.PeerRole_Learner, metapb.PeerRole_DemotingVoter: + return false + } + if peer.IsWitness { + return false + } + peerFit := fit.GetRuleFit(peer.GetId()) + if peerFit == nil || peerFit.Rule == nil { + return false + } + if peerFit.Rule.IsWitness { + return false + } + switch peerFit.Rule.Role { + case placement.Voter, placement.Leader: + return true + } + return false +} + +func isSameDistribution(region *core.RegionInfo, targetPeers map[uint64]*metapb.Peer, targetLeader uint64) bool { + peers := region.GetPeers() + for _, peer := range peers { + if _, ok := targetPeers[peer.GetStoreId()]; !ok { + return false + } + } + return region.GetLeader().GetStoreId() == targetLeader +} + +func (r *RegionScatterer) selectCandidates(region *core.RegionInfo, oldFit *placement.RegionFit, sourceStoreID uint64, selectedStores map[uint64]struct{}, context engineContext) []uint64 { + sourceStore := r.cluster.GetStore(sourceStoreID) + if sourceStore == nil { + log.Error("failed to get the store", zap.Uint64("store-id", sourceStoreID), errs.ZapError(errs.ErrGetSourceStore)) + return nil + } + filters := []filter.Filter{ + filter.NewExcludedFilter(r.name, nil, selectedStores), + } + scoreGuard := filter.NewPlacementSafeguard(r.name, r.cluster.GetOpts(), r.cluster.GetBasicCluster(), r.cluster.GetRuleManager(), region, sourceStore, oldFit) + for _, filterFunc := range context.filterFuncs { + filters = append(filters, filterFunc()) + } + filters = append(filters, scoreGuard) + stores := r.cluster.GetStores() + candidates := make([]uint64, 0) + maxStoreTotalCount := uint64(0) + minStoreTotalCount := uint64(math.MaxUint64) + for _, store := range stores { + count := context.selectedPeer.TotalCountByStore(store.GetID()) + if count > maxStoreTotalCount { + maxStoreTotalCount = count + } + if count < minStoreTotalCount { + minStoreTotalCount = count + } + } + for _, store := range stores { + storeCount := context.selectedPeer.TotalCountByStore(store.GetID()) + // If 
storeCount is equal to the maxStoreTotalCount, we should skip this store as candidate. + // If the storeCount are all the same for the whole cluster(maxStoreTotalCount == minStoreTotalCount), any store + // could be selected as candidate. + if storeCount < maxStoreTotalCount || maxStoreTotalCount == minStoreTotalCount { + if filter.Target(r.cluster.GetOpts(), store, filters) { + candidates = append(candidates, store.GetID()) + } + } + } + return candidates +} + +func (r *RegionScatterer) selectStore(group string, peer *metapb.Peer, sourceStoreID uint64, candidates []uint64, context engineContext) *metapb.Peer { + if len(candidates) < 1 { + return peer + } + var newPeer *metapb.Peer + minCount := uint64(math.MaxUint64) + for _, storeID := range candidates { + count := context.selectedPeer.Get(storeID, group) + if count < minCount { + minCount = count + newPeer = &metapb.Peer{ + StoreId: storeID, + Role: peer.GetRole(), + } + } + } + // if the source store have the least count, we don't need to scatter this peer + for _, storeID := range candidates { + if storeID == sourceStoreID && context.selectedPeer.Get(sourceStoreID, group) <= minCount { + return peer + } + } + if newPeer == nil { + return peer + } + return newPeer +} + +// selectAvailableLeaderStore select the target leader store from the candidates. The candidates would be collected by +// the existed peers store depended on the leader counts in the group level. Please use this func before scatter spacial engines. +func (r *RegionScatterer) selectAvailableLeaderStore(group string, region *core.RegionInfo, leaderCandidateStores []uint64, context engineContext) uint64 { + sourceStore := r.cluster.GetStore(region.GetLeader().GetStoreId()) + if sourceStore == nil { + log.Error("failed to get the store", zap.Uint64("store-id", region.GetLeader().GetStoreId()), errs.ZapError(errs.ErrGetSourceStore)) + return 0 + } + minStoreGroupLeader := uint64(math.MaxUint64) + id := uint64(0) + for _, storeID := range leaderCandidateStores { + store := r.cluster.GetStore(storeID) + if store == nil { + continue + } + storeGroupLeaderCount := context.selectedLeader.Get(storeID, group) + if minStoreGroupLeader > storeGroupLeaderCount { + minStoreGroupLeader = storeGroupLeaderCount + id = storeID + } + } + return id +} + +// Put put the final distribution in the context no matter the operator was created +func (r *RegionScatterer) Put(peers map[uint64]*metapb.Peer, leaderStoreID uint64, group string) { + engineFilter := filter.NewEngineFilter(r.name, filter.NotSpecialEngines) + // Group peers by the engine of their stores + for _, peer := range peers { + storeID := peer.GetStoreId() + store := r.cluster.GetStore(storeID) + if store == nil { + continue + } + if engineFilter.Target(r.cluster.GetOpts(), store).IsOK() { + r.ordinaryEngine.selectedPeer.Put(storeID, group) + scatterDistributionCounter.WithLabelValues( + fmt.Sprintf("%v", storeID), + fmt.Sprintf("%v", false), + core.EngineTiKV).Inc() + } else { + engine := store.GetLabelValue(core.EngineKey) + ctx, _ := r.specialEngines.Load(engine) + ctx.(engineContext).selectedPeer.Put(storeID, group) + scatterDistributionCounter.WithLabelValues( + fmt.Sprintf("%v", storeID), + fmt.Sprintf("%v", false), + engine).Inc() + } + } + r.ordinaryEngine.selectedLeader.Put(leaderStoreID, group) + scatterDistributionCounter.WithLabelValues( + fmt.Sprintf("%v", leaderStoreID), + fmt.Sprintf("%v", true), + core.EngineTiKV).Inc() +} diff --git a/server/schedulers/binding__failpoint_binding__.go 
b/server/schedulers/binding__failpoint_binding__.go new file mode 100755 index 00000000000..7ae23dc7ba6 --- /dev/null +++ b/server/schedulers/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package schedulers + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/schedulers/evict_leader.go b/server/schedulers/evict_leader.go old mode 100644 new mode 100755 index c13c4eb4c70..75bb1f45329 --- a/server/schedulers/evict_leader.go +++ b/server/schedulers/evict_leader.go @@ -133,9 +133,9 @@ func (conf *evictLeaderSchedulerConfig) Persist() error { conf.mu.RLock() defer conf.mu.RUnlock() data, err := schedule.EncodeConfig(conf) - failpoint.Inject("persistFail", func() { + if _, _err_ := failpoint.Eval(_curpkg_("persistFail")); _err_ == nil { err = errors.New("fail to persist") - }) + } if err != nil { return err } diff --git a/server/schedulers/evict_leader.go__failpoint_stash__ b/server/schedulers/evict_leader.go__failpoint_stash__ new file mode 100644 index 00000000000..c13c4eb4c70 --- /dev/null +++ b/server/schedulers/evict_leader.go__failpoint_stash__ @@ -0,0 +1,455 @@ +// Copyright 2017 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package schedulers + +import ( + "net/http" + "strconv" + + "github.com/gorilla/mux" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/apiutil" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/syncutil" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/schedule" + "github.com/tikv/pd/server/schedule/filter" + "github.com/tikv/pd/server/schedule/operator" + "github.com/tikv/pd/server/schedule/plan" + "github.com/tikv/pd/server/storage/endpoint" + "github.com/unrolled/render" +) + +const ( + // EvictLeaderName is evict leader scheduler name. + EvictLeaderName = "evict-leader-scheduler" + // EvictLeaderType is evict leader scheduler type. 
+ EvictLeaderType = "evict-leader" + // EvictLeaderBatchSize is the number of operators to to transfer + // leaders by one scheduling + EvictLeaderBatchSize = 3 + lastStoreDeleteInfo = "The last store has been deleted" +) + +func init() { + schedule.RegisterSliceDecoderBuilder(EvictLeaderType, func(args []string) schedule.ConfigDecoder { + return func(v interface{}) error { + if len(args) != 1 { + return errs.ErrSchedulerConfig.FastGenByArgs("id") + } + conf, ok := v.(*evictLeaderSchedulerConfig) + if !ok { + return errs.ErrScheduleConfigNotExist.FastGenByArgs() + } + + id, err := strconv.ParseUint(args[0], 10, 64) + if err != nil { + return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + } + + ranges, err := getKeyRanges(args[1:]) + if err != nil { + return err + } + conf.StoreIDWithRanges[id] = ranges + return nil + } + }) + + schedule.RegisterScheduler(EvictLeaderType, func(opController *schedule.OperatorController, storage endpoint.ConfigStorage, decoder schedule.ConfigDecoder) (schedule.Scheduler, error) { + conf := &evictLeaderSchedulerConfig{StoreIDWithRanges: make(map[uint64][]core.KeyRange), storage: storage} + if err := decoder(conf); err != nil { + return nil, err + } + conf.cluster = opController.GetCluster() + return newEvictLeaderScheduler(opController, conf), nil + }) +} + +type evictLeaderSchedulerConfig struct { + mu syncutil.RWMutex + storage endpoint.ConfigStorage + StoreIDWithRanges map[uint64][]core.KeyRange `json:"store-id-ranges"` + cluster schedule.Cluster +} + +func (conf *evictLeaderSchedulerConfig) getStores() []uint64 { + conf.mu.RLock() + defer conf.mu.RUnlock() + stores := make([]uint64, 0, len(conf.StoreIDWithRanges)) + for storeID := range conf.StoreIDWithRanges { + stores = append(stores, storeID) + } + return stores +} + +func (conf *evictLeaderSchedulerConfig) BuildWithArgs(args []string) error { + if len(args) != 1 { + return errs.ErrSchedulerConfig.FastGenByArgs("id") + } + + id, err := strconv.ParseUint(args[0], 10, 64) + if err != nil { + return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + } + ranges, err := getKeyRanges(args[1:]) + if err != nil { + return err + } + conf.mu.Lock() + defer conf.mu.Unlock() + conf.StoreIDWithRanges[id] = ranges + return nil +} + +func (conf *evictLeaderSchedulerConfig) Clone() *evictLeaderSchedulerConfig { + conf.mu.RLock() + defer conf.mu.RUnlock() + storeIDWithRanges := make(map[uint64][]core.KeyRange) + for id, ranges := range conf.StoreIDWithRanges { + storeIDWithRanges[id] = append(storeIDWithRanges[id], ranges...) 
+ } + return &evictLeaderSchedulerConfig{ + StoreIDWithRanges: storeIDWithRanges, + } +} + +func (conf *evictLeaderSchedulerConfig) Persist() error { + name := conf.getSchedulerName() + conf.mu.RLock() + defer conf.mu.RUnlock() + data, err := schedule.EncodeConfig(conf) + failpoint.Inject("persistFail", func() { + err = errors.New("fail to persist") + }) + if err != nil { + return err + } + return conf.storage.SaveScheduleConfig(name, data) +} + +func (conf *evictLeaderSchedulerConfig) getSchedulerName() string { + return EvictLeaderName +} + +func (conf *evictLeaderSchedulerConfig) getRanges(id uint64) []string { + conf.mu.RLock() + defer conf.mu.RUnlock() + ranges := conf.StoreIDWithRanges[id] + res := make([]string, 0, len(ranges)*2) + for index := range ranges { + res = append(res, (string)(ranges[index].StartKey), (string)(ranges[index].EndKey)) + } + return res +} + +func (conf *evictLeaderSchedulerConfig) removeStore(id uint64) (succ bool, last bool) { + conf.mu.Lock() + defer conf.mu.Unlock() + _, exists := conf.StoreIDWithRanges[id] + succ, last = false, false + if exists { + delete(conf.StoreIDWithRanges, id) + conf.cluster.ResumeLeaderTransfer(id) + succ = true + last = len(conf.StoreIDWithRanges) == 0 + } + return succ, last +} + +func (conf *evictLeaderSchedulerConfig) resetStore(id uint64, keyRange []core.KeyRange) { + conf.mu.Lock() + defer conf.mu.Unlock() + conf.cluster.PauseLeaderTransfer(id) + conf.StoreIDWithRanges[id] = keyRange +} + +func (conf *evictLeaderSchedulerConfig) getKeyRangesByID(id uint64) []core.KeyRange { + conf.mu.RLock() + defer conf.mu.RUnlock() + if ranges, exist := conf.StoreIDWithRanges[id]; exist { + return ranges + } + return nil +} + +type evictLeaderScheduler struct { + *BaseScheduler + conf *evictLeaderSchedulerConfig + handler http.Handler +} + +// newEvictLeaderScheduler creates an admin scheduler that transfers all leaders +// out of a store. +func newEvictLeaderScheduler(opController *schedule.OperatorController, conf *evictLeaderSchedulerConfig) schedule.Scheduler { + base := NewBaseScheduler(opController) + handler := newEvictLeaderHandler(conf) + return &evictLeaderScheduler{ + BaseScheduler: base, + conf: conf, + handler: handler, + } +} + +// EvictStores returns the IDs of the evict-stores. 
+func (s *evictLeaderScheduler) EvictStoreIDs() []uint64 { + return s.conf.getStores() +} + +func (s *evictLeaderScheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.handler.ServeHTTP(w, r) +} + +func (s *evictLeaderScheduler) GetName() string { + return EvictLeaderName +} + +func (s *evictLeaderScheduler) GetType() string { + return EvictLeaderType +} + +func (s *evictLeaderScheduler) EncodeConfig() ([]byte, error) { + s.conf.mu.RLock() + defer s.conf.mu.RUnlock() + return schedule.EncodeConfig(s.conf) +} + +func (s *evictLeaderScheduler) Prepare(cluster schedule.Cluster) error { + s.conf.mu.RLock() + defer s.conf.mu.RUnlock() + var res error + for id := range s.conf.StoreIDWithRanges { + if err := cluster.PauseLeaderTransfer(id); err != nil { + res = err + } + } + return res +} + +func (s *evictLeaderScheduler) Cleanup(cluster schedule.Cluster) { + s.conf.mu.RLock() + defer s.conf.mu.RUnlock() + for id := range s.conf.StoreIDWithRanges { + cluster.ResumeLeaderTransfer(id) + } +} + +func (s *evictLeaderScheduler) IsScheduleAllowed(cluster schedule.Cluster) bool { + allowed := s.OpController.OperatorCount(operator.OpLeader) < cluster.GetOpts().GetLeaderScheduleLimit() + if !allowed { + operator.OperatorLimitCounter.WithLabelValues(s.GetType(), operator.OpLeader.String()).Inc() + } + return allowed +} + +func (s *evictLeaderScheduler) Schedule(cluster schedule.Cluster, dryRun bool) ([]*operator.Operator, []plan.Plan) { + schedulerCounter.WithLabelValues(s.GetName(), "schedule").Inc() + return scheduleEvictLeaderBatch(s.GetName(), s.GetType(), cluster, s.conf, EvictLeaderBatchSize), nil +} + +func uniqueAppendOperator(dst []*operator.Operator, src ...*operator.Operator) []*operator.Operator { + regionIDs := make(map[uint64]struct{}) + for i := range dst { + regionIDs[dst[i].RegionID()] = struct{}{} + } + for i := range src { + if _, ok := regionIDs[src[i].RegionID()]; ok { + continue + } + regionIDs[src[i].RegionID()] = struct{}{} + dst = append(dst, src[i]) + } + return dst +} + +type evictLeaderStoresConf interface { + getStores() []uint64 + getKeyRangesByID(id uint64) []core.KeyRange +} + +func scheduleEvictLeaderBatch(name, typ string, cluster schedule.Cluster, conf evictLeaderStoresConf, batchSize int) []*operator.Operator { + var ops []*operator.Operator + for i := 0; i < batchSize; i++ { + once := scheduleEvictLeaderOnce(name, typ, cluster, conf) + // no more regions + if len(once) == 0 { + break + } + ops = uniqueAppendOperator(ops, once...) 
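+		// uniqueAppendOperator de-duplicates by region ID, so a region picked in
+		// more than one round contributes only a single operator to the batch.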
+ // the batch has been fulfilled + if len(ops) > batchSize { + break + } + } + return ops +} + +func scheduleEvictLeaderOnce(name, typ string, cluster schedule.Cluster, conf evictLeaderStoresConf) []*operator.Operator { + stores := conf.getStores() + ops := make([]*operator.Operator, 0, len(stores)) + for _, storeID := range stores { + ranges := conf.getKeyRangesByID(storeID) + if len(ranges) == 0 { + continue + } + var filters []filter.Filter + pendingFilter := filter.NewRegionPendingFilter() + downFilter := filter.NewRegionDownFilter() + region := filter.SelectOneRegion(cluster.RandLeaderRegions(storeID, ranges), nil, pendingFilter, downFilter) + if region == nil { + // try to pick unhealthy region + region = filter.SelectOneRegion(cluster.RandLeaderRegions(storeID, ranges), nil) + if region == nil { + schedulerCounter.WithLabelValues(name, "no-leader").Inc() + continue + } + schedulerCounter.WithLabelValues(name, "pick-unhealthy-region").Inc() + unhealthyPeerStores := make(map[uint64]struct{}) + for _, peer := range region.GetDownPeers() { + unhealthyPeerStores[peer.GetPeer().GetStoreId()] = struct{}{} + } + for _, peer := range region.GetPendingPeers() { + unhealthyPeerStores[peer.GetStoreId()] = struct{}{} + } + filters = append(filters, filter.NewExcludedFilter(name, nil, unhealthyPeerStores)) + } + + filters = append(filters, &filter.StoreStateFilter{ActionScope: name, TransferLeader: true}) + candidates := filter.NewCandidates(cluster.GetFollowerStores(region)). + FilterTarget(cluster.GetOpts(), nil, nil, filters...) + // Compatible with old TiKV transfer leader logic. + target := candidates.RandomPick() + targets := candidates.PickAll() + // `targets` MUST contains `target`, so only needs to check if `target` is nil here. + if target == nil { + schedulerCounter.WithLabelValues(name, "no-target-store").Inc() + continue + } + targetIDs := make([]uint64, 0, len(targets)) + for _, t := range targets { + targetIDs = append(targetIDs, t.GetID()) + } + op, err := operator.CreateTransferLeaderOperator(typ, cluster, region, region.GetLeader().GetStoreId(), target.GetID(), targetIDs, operator.OpLeader) + if err != nil { + log.Debug("fail to create evict leader operator", errs.ZapError(err)) + continue + } + op.SetPriorityLevel(core.Urgent) + op.Counters = append(op.Counters, schedulerCounter.WithLabelValues(name, "new-operator")) + ops = append(ops, op) + } + return ops +} + +type evictLeaderHandler struct { + rd *render.Render + config *evictLeaderSchedulerConfig +} + +func (handler *evictLeaderHandler) UpdateConfig(w http.ResponseWriter, r *http.Request) { + var input map[string]interface{} + if err := apiutil.ReadJSONRespondError(handler.rd, w, r.Body, &input); err != nil { + return + } + var args []string + var exists bool + var id uint64 + idFloat, ok := input["store_id"].(float64) + if ok { + id = (uint64)(idFloat) + handler.config.mu.RLock() + if _, exists = handler.config.StoreIDWithRanges[id]; !exists { + if err := handler.config.cluster.PauseLeaderTransfer(id); err != nil { + handler.config.mu.RUnlock() + handler.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + } + handler.config.mu.RUnlock() + args = append(args, strconv.FormatUint(id, 10)) + } + + ranges, ok := (input["ranges"]).([]string) + if ok { + args = append(args, ranges...) + } else if exists { + args = append(args, handler.config.getRanges(id)...) 
+ } + + handler.config.BuildWithArgs(args) + err := handler.config.Persist() + if err != nil { + handler.config.removeStore(id) + handler.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + handler.rd.JSON(w, http.StatusOK, nil) +} + +func (handler *evictLeaderHandler) ListConfig(w http.ResponseWriter, r *http.Request) { + conf := handler.config.Clone() + handler.rd.JSON(w, http.StatusOK, conf) +} + +func (handler *evictLeaderHandler) DeleteConfig(w http.ResponseWriter, r *http.Request) { + idStr := mux.Vars(r)["store_id"] + id, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + handler.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + + var resp interface{} + keyRanges := handler.config.getKeyRangesByID(id) + succ, last := handler.config.removeStore(id) + if succ { + err = handler.config.Persist() + if err != nil { + handler.config.resetStore(id, keyRanges) + handler.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + if last { + if err := handler.config.cluster.RemoveScheduler(EvictLeaderName); err != nil { + if errors.ErrorEqual(err, errs.ErrSchedulerNotFound.FastGenByArgs()) { + handler.rd.JSON(w, http.StatusNotFound, err.Error()) + } else { + handler.config.resetStore(id, keyRanges) + handler.rd.JSON(w, http.StatusInternalServerError, err.Error()) + } + return + } + resp = lastStoreDeleteInfo + } + handler.rd.JSON(w, http.StatusOK, resp) + return + } + + handler.rd.JSON(w, http.StatusNotFound, errs.ErrScheduleConfigNotExist.FastGenByArgs().Error()) +} + +func newEvictLeaderHandler(config *evictLeaderSchedulerConfig) http.Handler { + h := &evictLeaderHandler{ + config: config, + rd: render.New(render.Options{IndentJSON: true}), + } + router := mux.NewRouter() + router.HandleFunc("/config", h.UpdateConfig).Methods(http.MethodPost) + router.HandleFunc("/list", h.ListConfig).Methods(http.MethodGet) + router.HandleFunc("/delete/{store_id}", h.DeleteConfig).Methods(http.MethodDelete) + return router +} diff --git a/server/schedulers/evict_slow_store.go b/server/schedulers/evict_slow_store.go old mode 100644 new mode 100755 index 606da127f68..cf50f8c74b1 --- a/server/schedulers/evict_slow_store.go +++ b/server/schedulers/evict_slow_store.go @@ -60,9 +60,9 @@ type evictSlowStoreSchedulerConfig struct { func (conf *evictSlowStoreSchedulerConfig) Persist() error { name := conf.getSchedulerName() data, err := schedule.EncodeConfig(conf) - failpoint.Inject("persistFail", func() { + if _, _err_ := failpoint.Eval(_curpkg_("persistFail")); _err_ == nil { err = errors.New("fail to persist") - }) + } if err != nil { return err } diff --git a/server/schedulers/evict_slow_store.go__failpoint_stash__ b/server/schedulers/evict_slow_store.go__failpoint_stash__ new file mode 100644 index 00000000000..606da127f68 --- /dev/null +++ b/server/schedulers/evict_slow_store.go__failpoint_stash__ @@ -0,0 +1,234 @@ +// Copyright 2021 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
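The persistFail guards rewritten in evict_leader.go and evict_slow_store.go above are normally driven from tests that force the config-persist path to fail. Below is a rough sketch of that wiring, assuming the failpoint path follows the _curpkg_ convention (package import path + failpoint name) shown in the binding file; the test body itself is illustrative, not taken from this patch.

package schedulers_test

import (
	"testing"

	"github.com/pingcap/failpoint"
)

// TestPersistFailpoint sketches how a test toggles the rewritten persistFail guard.
func TestPersistFailpoint(t *testing.T) {
	const fp = "github.com/tikv/pd/server/schedulers/persistFail"

	// While enabled, failpoint.Eval(fp) returns a nil error, so Persist()
	// takes the injected branch and reports "fail to persist".
	if err := failpoint.Enable(fp, "return(true)"); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := failpoint.Disable(fp); err != nil {
			t.Fatal(err)
		}
	}()

	// ... invoke the scheduler's Persist() here and assert the error surfaces.
}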
+ +package schedulers + +import ( + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/schedule" + "github.com/tikv/pd/server/schedule/operator" + "github.com/tikv/pd/server/schedule/plan" + "github.com/tikv/pd/server/storage/endpoint" + "go.uber.org/zap" +) + +const ( + // EvictSlowStoreName is evict leader scheduler name. + EvictSlowStoreName = "evict-slow-store-scheduler" + // EvictSlowStoreType is evict leader scheduler type. + EvictSlowStoreType = "evict-slow-store" + + slowStoreEvictThreshold = 100 + slowStoreRecoverThreshold = 1 +) + +func init() { + schedule.RegisterSliceDecoderBuilder(EvictSlowStoreType, func(args []string) schedule.ConfigDecoder { + return func(v interface{}) error { + return nil + } + }) + + schedule.RegisterScheduler(EvictSlowStoreType, func(opController *schedule.OperatorController, storage endpoint.ConfigStorage, decoder schedule.ConfigDecoder) (schedule.Scheduler, error) { + conf := &evictSlowStoreSchedulerConfig{storage: storage, EvictedStores: make([]uint64, 0)} + if err := decoder(conf); err != nil { + return nil, err + } + return newEvictSlowStoreScheduler(opController, conf), nil + }) +} + +type evictSlowStoreSchedulerConfig struct { + storage endpoint.ConfigStorage + EvictedStores []uint64 `json:"evict-stores"` +} + +func (conf *evictSlowStoreSchedulerConfig) Persist() error { + name := conf.getSchedulerName() + data, err := schedule.EncodeConfig(conf) + failpoint.Inject("persistFail", func() { + err = errors.New("fail to persist") + }) + if err != nil { + return err + } + return conf.storage.SaveScheduleConfig(name, data) +} + +func (conf *evictSlowStoreSchedulerConfig) getSchedulerName() string { + return EvictSlowStoreName +} + +func (conf *evictSlowStoreSchedulerConfig) getStores() []uint64 { + return conf.EvictedStores +} + +func (conf *evictSlowStoreSchedulerConfig) getKeyRangesByID(id uint64) []core.KeyRange { + if conf.evictStore() != id { + return nil + } + return []core.KeyRange{core.NewKeyRange("", "")} +} + +func (conf *evictSlowStoreSchedulerConfig) evictStore() uint64 { + if len(conf.EvictedStores) == 0 { + return 0 + } + return conf.EvictedStores[0] +} + +func (conf *evictSlowStoreSchedulerConfig) setStoreAndPersist(id uint64) error { + conf.EvictedStores = []uint64{id} + return conf.Persist() +} + +func (conf *evictSlowStoreSchedulerConfig) clearAndPersist() (oldID uint64, err error) { + oldID = conf.evictStore() + if oldID > 0 { + conf.EvictedStores = []uint64{} + err = conf.Persist() + } + return +} + +type evictSlowStoreScheduler struct { + *BaseScheduler + conf *evictSlowStoreSchedulerConfig +} + +func (s *evictSlowStoreScheduler) GetName() string { + return EvictSlowStoreName +} + +func (s *evictSlowStoreScheduler) GetType() string { + return EvictSlowStoreType +} + +func (s *evictSlowStoreScheduler) EncodeConfig() ([]byte, error) { + return schedule.EncodeConfig(s.conf) +} + +func (s *evictSlowStoreScheduler) Prepare(cluster schedule.Cluster) error { + evictStore := s.conf.evictStore() + if evictStore != 0 { + return cluster.SlowStoreEvicted(evictStore) + } + return nil +} + +func (s *evictSlowStoreScheduler) Cleanup(cluster schedule.Cluster) { + s.cleanupEvictLeader(cluster) +} + +func (s *evictSlowStoreScheduler) prepareEvictLeader(cluster schedule.Cluster, storeID uint64) error { + err := s.conf.setStoreAndPersist(storeID) + if err != nil { + log.Info("evict-slow-store-scheduler persist config failed", zap.Uint64("store-id", 
storeID)) + return err + } + + return cluster.SlowStoreEvicted(storeID) +} + +func (s *evictSlowStoreScheduler) cleanupEvictLeader(cluster schedule.Cluster) { + evictSlowStore, err := s.conf.clearAndPersist() + if err != nil { + log.Info("evict-slow-store-scheduler persist config failed", zap.Uint64("store-id", evictSlowStore)) + } + if evictSlowStore == 0 { + return + } + cluster.SlowStoreRecovered(evictSlowStore) +} + +func (s *evictSlowStoreScheduler) schedulerEvictLeader(cluster schedule.Cluster) []*operator.Operator { + return scheduleEvictLeaderBatch(s.GetName(), s.GetType(), cluster, s.conf, EvictLeaderBatchSize) +} + +func (s *evictSlowStoreScheduler) IsScheduleAllowed(cluster schedule.Cluster) bool { + if s.conf.evictStore() != 0 { + allowed := s.OpController.OperatorCount(operator.OpLeader) < cluster.GetOpts().GetLeaderScheduleLimit() + if !allowed { + operator.OperatorLimitCounter.WithLabelValues(s.GetType(), operator.OpLeader.String()).Inc() + } + return allowed + } + return true +} + +func (s *evictSlowStoreScheduler) Schedule(cluster schedule.Cluster, dryRun bool) ([]*operator.Operator, []plan.Plan) { + schedulerCounter.WithLabelValues(s.GetName(), "schedule").Inc() + var ops []*operator.Operator + + if s.conf.evictStore() != 0 { + store := cluster.GetStore(s.conf.evictStore()) + if store == nil || store.IsRemoved() { + // Previous slow store had been removed, remove the scheduler and check + // slow node next time. + log.Info("slow store has been removed", + zap.Uint64("store-id", store.GetID())) + } else if store.GetSlowScore() <= slowStoreRecoverThreshold { + log.Info("slow store has been recovered", + zap.Uint64("store-id", store.GetID())) + } else { + return s.schedulerEvictLeader(cluster), nil + } + s.cleanupEvictLeader(cluster) + return ops, nil + } + + var slowStore *core.StoreInfo + + for _, store := range cluster.GetStores() { + if store.IsRemoved() { + continue + } + + if (store.IsPreparing() || store.IsServing()) && store.IsSlow() { + // Do nothing if there is more than one slow store. + if slowStore != nil { + return ops, nil + } + slowStore = store + } + } + + if slowStore == nil || slowStore.GetSlowScore() < slowStoreEvictThreshold { + return ops, nil + } + + // If there is only one slow store, evict leaders from that store. + log.Info("detected slow store, start to evict leaders", + zap.Uint64("store-id", slowStore.GetID())) + err := s.prepareEvictLeader(cluster, slowStore.GetID()) + if err != nil { + log.Info("prepare for evicting leader failed", zap.Error(err), zap.Uint64("store-id", slowStore.GetID())) + return ops, nil + } + return s.schedulerEvictLeader(cluster), nil +} + +// newEvictSlowStoreScheduler creates a scheduler that detects and evicts slow stores. 
+func newEvictSlowStoreScheduler(opController *schedule.OperatorController, conf *evictSlowStoreSchedulerConfig) schedule.Scheduler { + base := NewBaseScheduler(opController) + + s := &evictSlowStoreScheduler{ + BaseScheduler: base, + conf: conf, + } + return s +} diff --git a/server/server.go b/server/server.go old mode 100644 new mode 100755 index 385c5cb01bc..b2c5ae0210e --- a/server/server.go +++ b/server/server.go @@ -364,9 +364,9 @@ func (s *Server) startEtcd(ctx context.Context) error { }, } - failpoint.Inject("memberNil", func() { + if _, _err_ := failpoint.Eval(_curpkg_("memberNil")); _err_ == nil { time.Sleep(1500 * time.Millisecond) - }) + } s.member = member.NewMember(etcd, client, etcdServerID) return nil } @@ -717,7 +717,7 @@ func (s *Server) createRaftCluster() error { } func (s *Server) stopRaftCluster() { - failpoint.Inject("raftclusterIsBusy", func() {}) + failpoint.Eval(_curpkg_("raftclusterIsBusy")) s.cluster.Stop() } @@ -1465,11 +1465,11 @@ func (s *Server) campaignLeader() { } defer func() { s.tsoAllocatorManager.ResetAllocatorGroup(tso.GlobalDCLocation) - failpoint.Inject("updateAfterResetTSO", func() { + if _, _err_ := failpoint.Eval(_curpkg_("updateAfterResetTSO")); _err_ == nil { if err = allocator.UpdateTSO(); err != nil { panic(err) } - }) + } }() if err := s.reloadConfigFromKV(); err != nil { @@ -1527,14 +1527,14 @@ func (s *Server) campaignLeader() { return } // add failpoint to test exit leader, failpoint judge the member is the give value, then break - failpoint.Inject("exitCampaignLeader", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("exitCampaignLeader")); _err_ == nil { memberString := val.(string) memberID, _ := strconv.ParseUint(memberString, 10, 64) - if s.member.ID() == memberID { - log.Info("exit PD leader") - failpoint.Return() + if memberID == s.member.ID() { + log.Info("exit PD leader", zap.Uint64("member-id", s.member.ID())) + return } - }) + } case <-ctx.Done(): // Server is closed and it should return nil. log.Info("server is closed") diff --git a/server/server.go__failpoint_stash__ b/server/server.go__failpoint_stash__ new file mode 100644 index 00000000000..385c5cb01bc --- /dev/null +++ b/server/server.go__failpoint_stash__ @@ -0,0 +1,1724 @@ +// Copyright 2016 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
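For the server.go hunks above ("memberNil", "raftclusterIsBusy", "updateAfterResetTSO", "exitCampaignLeader"), the same enable/disable pattern applies; "exitCampaignLeader" additionally carries a string value read via failpoint.Value. The sketch below is illustrative only, and the failpoint paths are assumed to follow the usual "<import path>/<name>" convention rather than being taken from this diff.

// Sketch: toggling two of the server-side failpoints from a test.
package server_test

import (
	"testing"

	"github.com/pingcap/failpoint"
)

func TestServerFailpointsSketch(t *testing.T) {
	// "memberNil" takes no value: enabling it just makes startEtcd sleep
	// before s.member is assigned.
	if err := failpoint.Enable("github.com/tikv/pd/server/memberNil", "return(true)"); err != nil {
		t.Fatal(err)
	}
	defer failpoint.Disable("github.com/tikv/pd/server/memberNil")

	// "exitCampaignLeader" reads a string value: the campaign loop exits
	// only when the term's string matches the running member's ID.
	if err := failpoint.Enable("github.com/tikv/pd/server/exitCampaignLeader", `return("1234")`); err != nil {
		t.Fatal(err)
	}
	defer failpoint.Disable("github.com/tikv/pd/server/exitCampaignLeader")

	// ... start a test server here and assert on the resulting behaviour.
}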
+ +package server + +import ( + "bytes" + "context" + "fmt" + "math/rand" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/coreos/go-semver/semver" + "github.com/gorilla/mux" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/diagnosticspb" + "github.com/pingcap/kvproto/pkg/keyspacepb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/pingcap/sysutil" + "github.com/tikv/pd/pkg/apiutil" + "github.com/tikv/pd/pkg/audit" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/etcdutil" + "github.com/tikv/pd/pkg/grpcutil" + "github.com/tikv/pd/pkg/jsonutil" + "github.com/tikv/pd/pkg/logutil" + "github.com/tikv/pd/pkg/ratelimit" + "github.com/tikv/pd/pkg/systimemon" + "github.com/tikv/pd/pkg/tsoutil" + "github.com/tikv/pd/pkg/typeutil" + "github.com/tikv/pd/server/cluster" + "github.com/tikv/pd/server/config" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/encryptionkm" + "github.com/tikv/pd/server/gc" + "github.com/tikv/pd/server/id" + "github.com/tikv/pd/server/keyspace" + "github.com/tikv/pd/server/member" + syncer "github.com/tikv/pd/server/region_syncer" + "github.com/tikv/pd/server/schedule" + "github.com/tikv/pd/server/schedule/hbstream" + "github.com/tikv/pd/server/schedule/placement" + "github.com/tikv/pd/server/storage" + "github.com/tikv/pd/server/storage/endpoint" + "github.com/tikv/pd/server/storage/kv" + "github.com/tikv/pd/server/tso" + "github.com/tikv/pd/server/versioninfo" + "github.com/urfave/negroni" + "go.etcd.io/etcd/clientv3" + "go.etcd.io/etcd/embed" + "go.etcd.io/etcd/pkg/types" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +const ( + etcdTimeout = time.Second * 3 + serverMetricsInterval = time.Minute + leaderTickInterval = 50 * time.Millisecond + // pdRootPath for all pd servers. + pdRootPath = "/pd" + pdAPIPrefix = "/pd/" + pdClusterIDPath = "/pd/cluster_id" + // idAllocPath for idAllocator to save persistent window's end. + idAllocPath = "alloc_id" + idAllocLabel = "idalloc" + + recoveringMarkPath = "cluster/markers/snapshot-recovering" +) + +// EtcdStartTimeout the timeout of the startup etcd. +var EtcdStartTimeout = time.Minute * 5 + +// Server is the pd server. +// nolint +type Server struct { + diagnosticspb.DiagnosticsServer + + // Server state. + isServing int64 + + // Server start timestamp + startTimestamp int64 + + // Configs and initial fields. + cfg *config.Config + serviceMiddlewareCfg *config.ServiceMiddlewareConfig + etcdCfg *embed.Config + serviceMiddlewarePersistOptions *config.ServiceMiddlewarePersistOptions + persistOptions *config.PersistOptions + handler *Handler + + ctx context.Context + serverLoopCtx context.Context + serverLoopCancel func() + serverLoopWg sync.WaitGroup + + // for PD leader election. + member *member.Member + // etcd client + client *clientv3.Client + // http client + httpClient *http.Client + clusterID uint64 // pd cluster id. + rootPath string + + // Server services. + // for id allocator, we can use one allocator for + // store, region and peer, because we just need + // a unique ID. + idAllocator id.Allocator + // for encryption + encryptionKeyManager *encryptionkm.KeyManager + // for storage operation. + storage storage.Storage + // safepoint manager + gcSafePointManager *gc.SafePointManager + // keyspace manager + keyspaceManager *keyspace.Manager + // for basicCluster operation. 
+ basicCluster *core.BasicCluster + // for tso. + tsoAllocatorManager *tso.AllocatorManager + // for raft cluster + cluster *cluster.RaftCluster + // For async region heartbeat. + hbStreams *hbstream.HeartbeatStreams + // Zap logger + lg *zap.Logger + logProps *log.ZapProperties + + // Add callback functions at different stages + startCallbacks []func() + closeCallbacks []func() + + // hot region history info storeage + hotRegionStorage *storage.HotRegionStorage + // Store as map[string]*grpc.ClientConn + clientConns sync.Map + // tsoDispatcher is used to dispatch different TSO requests to + // the corresponding forwarding TSO channel. + tsoDispatcher sync.Map /* Store as map[string]chan *tsoRequest */ + + serviceRateLimiter *ratelimit.Limiter + serviceLabels map[string][]apiutil.AccessPath + apiServiceLabelMap map[apiutil.AccessPath]string + + serviceAuditBackendLabels map[string]*audit.BackendLabels + + auditBackends []audit.Backend +} + +// HandlerBuilder builds a server HTTP handler. +type HandlerBuilder func(context.Context, *Server) (http.Handler, ServiceGroup, error) + +// ServiceGroup used to register the service. +type ServiceGroup struct { + Name string + Version string + IsCore bool + PathPrefix string +} + +const ( + // CorePath the core group, is at REST path `/pd/api/v1`. + CorePath = "/pd/api/v1" + // ExtensionsPath the named groups are REST at `/pd/apis/{GROUP_NAME}/{Version}`. + ExtensionsPath = "/pd/apis" +) + +func combineBuilderServerHTTPService(ctx context.Context, svr *Server, serviceBuilders ...HandlerBuilder) (map[string]http.Handler, error) { + userHandlers := make(map[string]http.Handler) + registerMap := make(map[string]struct{}) + + apiService := negroni.New() + recovery := negroni.NewRecovery() + apiService.Use(recovery) + router := mux.NewRouter() + + for _, build := range serviceBuilders { + handler, info, err := build(ctx, svr) + if err != nil { + return nil, err + } + if !info.IsCore && len(info.PathPrefix) == 0 && (len(info.Name) == 0 || len(info.Version) == 0) { + return nil, errs.ErrAPIInformationInvalid.FastGenByArgs(info.Name, info.Version) + } + var pathPrefix string + if len(info.PathPrefix) != 0 { + pathPrefix = info.PathPrefix + } else if info.IsCore { + pathPrefix = CorePath + } else { + pathPrefix = path.Join(ExtensionsPath, info.Name, info.Version) + } + if _, ok := registerMap[pathPrefix]; ok { + return nil, errs.ErrServiceRegistered.FastGenByArgs(pathPrefix) + } + + log.Info("register REST path", zap.String("path", pathPrefix)) + registerMap[pathPrefix] = struct{}{} + if len(info.PathPrefix) != 0 { + // If PathPrefix is specified, register directly into userHandlers + userHandlers[pathPrefix] = handler + } else { + // If PathPrefix is not specified, register into apiService, + // and finally apiService is registered in userHandlers. + router.PathPrefix(pathPrefix).Handler(handler) + if info.IsCore { + // Deprecated + router.Path("/pd/health").Handler(handler) + // Deprecated + router.Path("/pd/ping").Handler(handler) + } + } + } + apiService.UseHandler(router) + userHandlers[pdAPIPrefix] = apiService + + return userHandlers, nil +} + +// CreateServer creates the UNINITIALIZED pd server with given configuration. 
+func CreateServer(ctx context.Context, cfg *config.Config, serviceBuilders ...HandlerBuilder) (*Server, error) { + log.Info("PD Config", zap.Reflect("config", cfg)) + rand.Seed(time.Now().UnixNano()) + serviceMiddlewareCfg := config.NewServiceMiddlewareConfig() + + s := &Server{ + cfg: cfg, + persistOptions: config.NewPersistOptions(cfg), + serviceMiddlewareCfg: serviceMiddlewareCfg, + serviceMiddlewarePersistOptions: config.NewServiceMiddlewarePersistOptions(serviceMiddlewareCfg), + member: &member.Member{}, + ctx: ctx, + startTimestamp: time.Now().Unix(), + DiagnosticsServer: sysutil.NewDiagnosticsServer(cfg.Log.File.Filename), + } + s.handler = newHandler(s) + + // create audit backend + s.auditBackends = []audit.Backend{ + audit.NewLocalLogBackend(true), + audit.NewPrometheusHistogramBackend(serviceAuditHistogram, false), + } + s.serviceRateLimiter = ratelimit.NewLimiter() + s.serviceAuditBackendLabels = make(map[string]*audit.BackendLabels) + s.serviceRateLimiter = ratelimit.NewLimiter() + s.serviceLabels = make(map[string][]apiutil.AccessPath) + s.apiServiceLabelMap = make(map[apiutil.AccessPath]string) + + // Adjust etcd config. + etcdCfg, err := s.cfg.GenEmbedEtcdConfig() + if err != nil { + return nil, err + } + if len(serviceBuilders) != 0 { + userHandlers, err := combineBuilderServerHTTPService(ctx, s, serviceBuilders...) + if err != nil { + return nil, err + } + etcdCfg.UserHandlers = userHandlers + } + etcdCfg.ServiceRegister = func(gs *grpc.Server) { + grpcServer := &GrpcServer{Server: s} + pdpb.RegisterPDServer(gs, grpcServer) + keyspacepb.RegisterKeyspaceServer(gs, &KeyspaceServer{GrpcServer: grpcServer}) + diagnosticspb.RegisterDiagnosticsServer(gs, s) + } + s.etcdCfg = etcdCfg + s.lg = cfg.GetZapLogger() + s.logProps = cfg.GetZapLogProperties() + return s, nil +} + +func (s *Server) startEtcd(ctx context.Context) error { + newCtx, cancel := context.WithTimeout(ctx, EtcdStartTimeout) + defer cancel() + + etcd, err := embed.StartEtcd(s.etcdCfg) + if err != nil { + return errs.ErrStartEtcd.Wrap(err).GenWithStackByCause() + } + + // Check cluster ID + urlMap, err := types.NewURLsMap(s.cfg.InitialCluster) + if err != nil { + return errs.ErrEtcdURLMap.Wrap(err).GenWithStackByCause() + } + tlsConfig, err := s.cfg.Security.ToTLSConfig() + if err != nil { + return err + } + + if err = etcdutil.CheckClusterID(etcd.Server.Cluster().ID(), urlMap, tlsConfig); err != nil { + return err + } + + select { + // Wait etcd until it is ready to use + case <-etcd.Server.ReadyNotify(): + case <-newCtx.Done(): + return errs.ErrCancelStartEtcd.FastGenByArgs() + } + + endpoints := []string{s.etcdCfg.ACUrls[0].String()} + log.Info("create etcd v3 client", zap.Strings("endpoints", endpoints), zap.Reflect("cert", s.cfg.Security)) + + lgc := zap.NewProductionConfig() + lgc.Encoding = log.ZapEncodingName + client, err := clientv3.New(clientv3.Config{ + Endpoints: endpoints, + DialTimeout: etcdTimeout, + TLS: tlsConfig, + LogConfig: &lgc, + }) + if err != nil { + return errs.ErrNewEtcdClient.Wrap(err).GenWithStackByCause() + } + + etcdServerID := uint64(etcd.Server.ID()) + + // update advertise peer urls. 
+ etcdMembers, err := etcdutil.ListEtcdMembers(client) + if err != nil { + return err + } + for _, m := range etcdMembers.Members { + if etcdServerID == m.ID { + etcdPeerURLs := strings.Join(m.PeerURLs, ",") + if s.cfg.AdvertisePeerUrls != etcdPeerURLs { + log.Info("update advertise peer urls", zap.String("from", s.cfg.AdvertisePeerUrls), zap.String("to", etcdPeerURLs)) + s.cfg.AdvertisePeerUrls = etcdPeerURLs + } + } + } + s.client = client + s.httpClient = &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: tlsConfig, + }, + } + + failpoint.Inject("memberNil", func() { + time.Sleep(1500 * time.Millisecond) + }) + s.member = member.NewMember(etcd, client, etcdServerID) + return nil +} + +// AddStartCallback adds a callback in the startServer phase. +func (s *Server) AddStartCallback(callbacks ...func()) { + s.startCallbacks = append(s.startCallbacks, callbacks...) +} + +func (s *Server) startServer(ctx context.Context) error { + var err error + if err = s.initClusterID(); err != nil { + return err + } + log.Info("init cluster id", zap.Uint64("cluster-id", s.clusterID)) + // It may lose accuracy if use float64 to store uint64. So we store the + // cluster id in label. + metadataGauge.WithLabelValues(fmt.Sprintf("cluster%d", s.clusterID)).Set(0) + serverInfo.WithLabelValues(versioninfo.PDReleaseVersion, versioninfo.PDGitHash).Set(float64(time.Now().Unix())) + + s.rootPath = path.Join(pdRootPath, strconv.FormatUint(s.clusterID, 10)) + s.member.MemberInfo(s.cfg, s.Name(), s.rootPath) + s.member.SetMemberDeployPath(s.member.ID()) + s.member.SetMemberBinaryVersion(s.member.ID(), versioninfo.PDReleaseVersion) + s.member.SetMemberGitHash(s.member.ID(), versioninfo.PDGitHash) + s.idAllocator = id.NewAllocator(&id.AllocatorParams{ + Client: s.client, + RootPath: s.rootPath, + AllocPath: idAllocPath, + Label: idAllocLabel, + Member: s.member.MemberValue(), + }) + s.tsoAllocatorManager = tso.NewAllocatorManager( + s.member, s.rootPath, s.cfg, + func() time.Duration { return s.persistOptions.GetMaxResetTSGap() }) + // Set up the Global TSO Allocator here, it will be initialized once the PD campaigns leader successfully. + s.tsoAllocatorManager.SetUpAllocator(ctx, tso.GlobalDCLocation, s.member.GetLeadership()) + // When disabled the Local TSO, we should clean up the Local TSO Allocator's meta info written in etcd if it exists. 
+ if !s.cfg.EnableLocalTSO { + if err = s.tsoAllocatorManager.CleanUpDCLocation(); err != nil { + return err + } + } + if zone, exist := s.cfg.Labels[config.ZoneLabel]; exist && zone != "" && s.cfg.EnableLocalTSO { + if err = s.tsoAllocatorManager.SetLocalTSOConfig(zone); err != nil { + return err + } + } + s.encryptionKeyManager, err = encryptionkm.NewKeyManager(s.client, &s.cfg.Security.Encryption) + if err != nil { + return err + } + regionStorage, err := storage.NewStorageWithLevelDBBackend(ctx, filepath.Join(s.cfg.DataDir, "region-meta"), s.encryptionKeyManager) + if err != nil { + return err + } + defaultStorage := storage.NewStorageWithEtcdBackend(s.client, s.rootPath) + s.storage = storage.NewCoreStorage(defaultStorage, regionStorage) + s.gcSafePointManager = gc.NewSafePointManager(s.storage) + keyspaceIDAllocator := id.NewAllocator(&id.AllocatorParams{ + Client: s.client, + RootPath: s.rootPath, + AllocPath: endpoint.KeyspaceIDAlloc(), + Label: keyspace.AllocLabel, + Member: s.member.MemberValue(), + Step: keyspace.AllocStep, + }) + s.keyspaceManager, err = keyspace.NewKeyspaceManager(s.storage, keyspaceIDAllocator) + if err != nil { + return err + } + s.basicCluster = core.NewBasicCluster() + s.cluster = cluster.NewRaftCluster(ctx, s.clusterID, syncer.NewRegionSyncer(s), s.client, s.httpClient) + s.hbStreams = hbstream.NewHeartbeatStreams(ctx, s.clusterID, s.cluster) + // initial hot_region_storage in here. + s.hotRegionStorage, err = storage.NewHotRegionsStorage( + ctx, filepath.Join(s.cfg.DataDir, "hot-region"), s.encryptionKeyManager, s.handler) + if err != nil { + return err + } + // Run callbacks + for _, cb := range s.startCallbacks { + cb() + } + + // Server has started. + atomic.StoreInt64(&s.isServing, 1) + return nil +} + +func (s *Server) initClusterID() error { + // Get any cluster key to parse the cluster ID. + resp, err := etcdutil.EtcdKVGet(s.client, pdClusterIDPath) + if err != nil { + return err + } + + // If no key exist, generate a random cluster ID. + if len(resp.Kvs) == 0 { + s.clusterID, err = initOrGetClusterID(s.client, pdClusterIDPath) + return err + } + s.clusterID, err = typeutil.BytesToUint64(resp.Kvs[0].Value) + return err +} + +// AddCloseCallback adds a callback in the Close phase. +func (s *Server) AddCloseCallback(callbacks ...func()) { + s.closeCallbacks = append(s.closeCallbacks, callbacks...) +} + +// Close closes the server. +func (s *Server) Close() { + if !atomic.CompareAndSwapInt64(&s.isServing, 1, 0) { + // server is already closed + return + } + + log.Info("closing server") + + s.stopServerLoop() + + if s.client != nil { + if err := s.client.Close(); err != nil { + log.Error("close etcd client meet error", errs.ZapError(errs.ErrCloseEtcdClient, err)) + } + } + + if s.httpClient != nil { + s.httpClient.CloseIdleConnections() + } + + if s.member.Etcd() != nil { + s.member.Close() + } + + if s.hbStreams != nil { + s.hbStreams.Close() + } + if err := s.storage.Close(); err != nil { + log.Error("close storage meet error", errs.ZapError(err)) + } + + if err := s.hotRegionStorage.Close(); err != nil { + log.Error("close hot region storage meet error", errs.ZapError(err)) + } + + // Run callbacks + for _, cb := range s.closeCallbacks { + cb() + } + + log.Info("close server") +} + +// IsClosed checks whether server is closed or not. +func (s *Server) IsClosed() bool { + return atomic.LoadInt64(&s.isServing) == 0 +} + +// Run runs the pd server. 
+func (s *Server) Run() error { + go systimemon.StartMonitor(s.ctx, time.Now, func() { + log.Error("system time jumps backward", errs.ZapError(errs.ErrIncorrectSystemTime)) + timeJumpBackCounter.Inc() + }) + if err := s.startEtcd(s.ctx); err != nil { + return err + } + if err := s.startServer(s.ctx); err != nil { + return err + } + + s.startServerLoop(s.ctx) + + return nil +} + +// SetServiceAuditBackendForHTTP is used to register service audit config for HTTP. +func (s *Server) SetServiceAuditBackendForHTTP(route *mux.Route, labels ...string) { + if len(route.GetName()) == 0 { + return + } + if len(labels) > 0 { + s.SetServiceAuditBackendLabels(route.GetName(), labels) + } +} + +// Context returns the context of server. +func (s *Server) Context() context.Context { + return s.ctx +} + +// LoopContext returns the loop context of server. +func (s *Server) LoopContext() context.Context { + return s.serverLoopCtx +} + +func (s *Server) startServerLoop(ctx context.Context) { + s.serverLoopCtx, s.serverLoopCancel = context.WithCancel(ctx) + s.serverLoopWg.Add(5) + go s.leaderLoop() + go s.etcdLeaderLoop() + go s.serverMetricsLoop() + go s.tsoAllocatorLoop() + go s.encryptionKeyManagerLoop() +} + +func (s *Server) stopServerLoop() { + s.serverLoopCancel() + s.serverLoopWg.Wait() +} + +func (s *Server) serverMetricsLoop() { + defer logutil.LogPanic() + defer s.serverLoopWg.Done() + + ctx, cancel := context.WithCancel(s.serverLoopCtx) + defer cancel() + for { + select { + case <-time.After(serverMetricsInterval): + s.collectEtcdStateMetrics() + case <-ctx.Done(): + log.Info("server is closed, exit metrics loop") + return + } + } +} + +// tsoAllocatorLoop is used to run the TSO Allocator updating daemon. +func (s *Server) tsoAllocatorLoop() { + defer logutil.LogPanic() + defer s.serverLoopWg.Done() + + ctx, cancel := context.WithCancel(s.serverLoopCtx) + defer cancel() + s.tsoAllocatorManager.AllocatorDaemon(ctx) + log.Info("server is closed, exit allocator loop") +} + +// encryptionKeyManagerLoop is used to start monitor encryption key changes. 
+func (s *Server) encryptionKeyManagerLoop() { + defer logutil.LogPanic() + defer s.serverLoopWg.Done() + + ctx, cancel := context.WithCancel(s.serverLoopCtx) + defer cancel() + s.encryptionKeyManager.StartBackgroundLoop(ctx) + log.Info("server is closed, exist encryption key manager loop") +} + +func (s *Server) collectEtcdStateMetrics() { + etcdStateGauge.WithLabelValues("term").Set(float64(s.member.Etcd().Server.Term())) + etcdStateGauge.WithLabelValues("appliedIndex").Set(float64(s.member.Etcd().Server.AppliedIndex())) + etcdStateGauge.WithLabelValues("committedIndex").Set(float64(s.member.Etcd().Server.CommittedIndex())) +} + +func (s *Server) bootstrapCluster(req *pdpb.BootstrapRequest) (*pdpb.BootstrapResponse, error) { + clusterID := s.clusterID + + log.Info("try to bootstrap raft cluster", + zap.Uint64("cluster-id", clusterID), + zap.String("request", fmt.Sprintf("%v", req))) + + if err := checkBootstrapRequest(clusterID, req); err != nil { + return nil, err + } + + clusterMeta := metapb.Cluster{ + Id: clusterID, + MaxPeerCount: uint32(s.persistOptions.GetMaxReplicas()), + } + + // Set cluster meta + clusterValue, err := clusterMeta.Marshal() + if err != nil { + return nil, errors.WithStack(err) + } + clusterRootPath := endpoint.ClusterRootPath(s.rootPath) + + var ops []clientv3.Op + ops = append(ops, clientv3.OpPut(clusterRootPath, string(clusterValue))) + + // Set bootstrap time + // Because we will write the cluster meta into etcd directly, + // so we need to handle the root key path manually here. + bootstrapKey := endpoint.AppendToRootPath(s.rootPath, endpoint.ClusterBootstrapTimeKey()) + nano := time.Now().UnixNano() + + timeData := typeutil.Uint64ToBytes(uint64(nano)) + ops = append(ops, clientv3.OpPut(bootstrapKey, string(timeData))) + + // Set store meta + storeMeta := req.GetStore() + storePath := endpoint.AppendToRootPath(s.rootPath, endpoint.StorePath(storeMeta.GetId())) + storeValue, err := storeMeta.Marshal() + if err != nil { + return nil, errors.WithStack(err) + } + ops = append(ops, clientv3.OpPut(storePath, string(storeValue))) + + regionValue, err := req.GetRegion().Marshal() + if err != nil { + return nil, errors.WithStack(err) + } + + // Set region meta with region id. + regionPath := endpoint.AppendToRootPath(s.rootPath, endpoint.RegionPath(req.GetRegion().GetId())) + ops = append(ops, clientv3.OpPut(regionPath, string(regionValue))) + + // TODO: we must figure out a better way to handle bootstrap failed, maybe intervene manually. 
+ bootstrapCmp := clientv3.Compare(clientv3.CreateRevision(clusterRootPath), "=", 0) + resp, err := kv.NewSlowLogTxn(s.client).If(bootstrapCmp).Then(ops...).Commit() + if err != nil { + return nil, errs.ErrEtcdTxnInternal.Wrap(err).GenWithStackByCause() + } + if !resp.Succeeded { + log.Warn("cluster already bootstrapped", zap.Uint64("cluster-id", clusterID)) + return nil, errs.ErrEtcdTxnConflict.FastGenByArgs() + } + + log.Info("bootstrap cluster ok", zap.Uint64("cluster-id", clusterID)) + err = s.storage.SaveRegion(req.GetRegion()) + if err != nil { + log.Warn("save the bootstrap region failed", errs.ZapError(err)) + } + err = s.storage.Flush() + if err != nil { + log.Warn("flush the bootstrap region failed", errs.ZapError(err)) + } + + if err := s.cluster.Start(s); err != nil { + return nil, err + } + + return &pdpb.BootstrapResponse{ + ReplicationStatus: s.cluster.GetReplicationMode().GetReplicationStatus(), + }, nil +} + +func (s *Server) createRaftCluster() error { + if s.cluster.IsRunning() { + return nil + } + + return s.cluster.Start(s) +} + +func (s *Server) stopRaftCluster() { + failpoint.Inject("raftclusterIsBusy", func() {}) + s.cluster.Stop() +} + +// GetAddr returns the server urls for clients. +func (s *Server) GetAddr() string { + return s.cfg.AdvertiseClientUrls +} + +// GetClientScheme returns the client URL scheme +func (s *Server) GetClientScheme() string { + if len(s.cfg.Security.CertPath) == 0 && len(s.cfg.Security.KeyPath) == 0 { + return "http" + } + return "https" +} + +// GetMemberInfo returns the server member information. +func (s *Server) GetMemberInfo() *pdpb.Member { + return typeutil.DeepClone(s.member.Member(), core.MemberFactory) +} + +// GetHandler returns the handler for API. +func (s *Server) GetHandler() *Handler { + return s.handler +} + +// GetEndpoints returns the etcd endpoints for outer use. +func (s *Server) GetEndpoints() []string { + return s.client.Endpoints() +} + +// GetClient returns builtin etcd client. +func (s *Server) GetClient() *clientv3.Client { + return s.client +} + +// GetHTTPClient returns builtin etcd client. +func (s *Server) GetHTTPClient() *http.Client { + return s.httpClient +} + +// GetLeader returns the leader of PD cluster(i.e the PD leader). +func (s *Server) GetLeader() *pdpb.Member { + return s.member.GetLeader() +} + +// GetMember returns the member of server. +func (s *Server) GetMember() *member.Member { + return s.member +} + +// GetStorage returns the backend storage of server. +func (s *Server) GetStorage() storage.Storage { + return s.storage +} + +// GetHistoryHotRegionStorage returns the backend storage of historyHotRegion. +func (s *Server) GetHistoryHotRegionStorage() *storage.HotRegionStorage { + return s.hotRegionStorage +} + +// SetStorage changes the storage only for test purpose. +// When we use it, we should prevent calling GetStorage, otherwise, it may cause a data race problem. +func (s *Server) SetStorage(storage storage.Storage) { + s.storage = storage +} + +// GetBasicCluster returns the basic cluster of server. +func (s *Server) GetBasicCluster() *core.BasicCluster { + return s.basicCluster +} + +// GetPersistOptions returns the schedule option. +func (s *Server) GetPersistOptions() *config.PersistOptions { + return s.persistOptions +} + +// GetServiceMiddlewarePersistOptions returns the service middleware persist option. 
+func (s *Server) GetServiceMiddlewarePersistOptions() *config.ServiceMiddlewarePersistOptions { + return s.serviceMiddlewarePersistOptions +} + +// GetHBStreams returns the heartbeat streams. +func (s *Server) GetHBStreams() *hbstream.HeartbeatStreams { + return s.hbStreams +} + +// GetAllocator returns the ID allocator of server. +func (s *Server) GetAllocator() id.Allocator { + return s.idAllocator +} + +// GetTSOAllocatorManager returns the manager of TSO Allocator. +func (s *Server) GetTSOAllocatorManager() *tso.AllocatorManager { + return s.tsoAllocatorManager +} + +// GetKeyspaceManager returns the keyspace manager of server. +func (s *Server) GetKeyspaceManager() *keyspace.Manager { + return s.keyspaceManager +} + +// Name returns the unique etcd Name for this server in etcd cluster. +func (s *Server) Name() string { + return s.cfg.Name +} + +// ClusterID returns the cluster ID of this server. +func (s *Server) ClusterID() uint64 { + return s.clusterID +} + +// StartTimestamp returns the start timestamp of this server +func (s *Server) StartTimestamp() int64 { + return s.startTimestamp +} + +// GetMembers returns PD server list. +func (s *Server) GetMembers() ([]*pdpb.Member, error) { + if s.IsClosed() { + return nil, errs.ErrServerNotStarted.FastGenByArgs() + } + return cluster.GetMembers(s.GetClient()) +} + +// GetServiceMiddlewareConfig gets the service middleware config information. +func (s *Server) GetServiceMiddlewareConfig() *config.ServiceMiddlewareConfig { + cfg := s.serviceMiddlewareCfg.Clone() + cfg.AuditConfig = *s.serviceMiddlewarePersistOptions.GetAuditConfig().Clone() + cfg.RateLimitConfig = *s.serviceMiddlewarePersistOptions.GetRateLimitConfig().Clone() + return cfg +} + +// SetEnableLocalTSO sets enable-local-tso flag of Server. This function only for test. +func (s *Server) SetEnableLocalTSO(enableLocalTSO bool) { + s.cfg.EnableLocalTSO = enableLocalTSO +} + +// GetConfig gets the config information. +func (s *Server) GetConfig() *config.Config { + cfg := s.cfg.Clone() + cfg.Schedule = *s.persistOptions.GetScheduleConfig().Clone() + cfg.Replication = *s.persistOptions.GetReplicationConfig().Clone() + cfg.PDServerCfg = *s.persistOptions.GetPDServerConfig().Clone() + cfg.ReplicationMode = *s.persistOptions.GetReplicationModeConfig() + cfg.LabelProperty = s.persistOptions.GetLabelPropertyConfig().Clone() + cfg.ClusterVersion = *s.persistOptions.GetClusterVersion() + if s.storage == nil { + return cfg + } + sches, configs, err := s.storage.LoadAllScheduleConfig() + if err != nil { + return cfg + } + payload := make(map[string]interface{}) + for i, sche := range sches { + var config interface{} + err := schedule.DecodeConfig([]byte(configs[i]), &config) + if err != nil { + log.Error("failed to decode scheduler config", + zap.String("config", configs[i]), + zap.String("scheduler", sche), + errs.ZapError(err)) + continue + } + payload[sche] = config + } + cfg.Schedule.SchedulersPayload = payload + return cfg +} + +// GetScheduleConfig gets the balance config information. +func (s *Server) GetScheduleConfig() *config.ScheduleConfig { + return s.persistOptions.GetScheduleConfig().Clone() +} + +// SetScheduleConfig sets the balance config information. 
+func (s *Server) SetScheduleConfig(cfg config.ScheduleConfig) error { + if err := cfg.Validate(); err != nil { + return err + } + if err := cfg.Deprecated(); err != nil { + return err + } + old := s.persistOptions.GetScheduleConfig() + cfg.SchedulersPayload = nil + s.persistOptions.SetScheduleConfig(&cfg) + if err := s.persistOptions.Persist(s.storage); err != nil { + s.persistOptions.SetScheduleConfig(old) + log.Error("failed to update schedule config", + zap.Reflect("new", cfg), + zap.Reflect("old", old), + errs.ZapError(err)) + return err + } + log.Info("schedule config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) + return nil +} + +// GetReplicationConfig get the replication config. +func (s *Server) GetReplicationConfig() *config.ReplicationConfig { + return s.persistOptions.GetReplicationConfig().Clone() +} + +// SetReplicationConfig sets the replication config. +func (s *Server) SetReplicationConfig(cfg config.ReplicationConfig) error { + if err := cfg.Validate(); err != nil { + return err + } + old := s.persistOptions.GetReplicationConfig() + if cfg.EnablePlacementRules != old.EnablePlacementRules { + raftCluster := s.GetRaftCluster() + if raftCluster == nil { + return errs.ErrNotBootstrapped.GenWithStackByArgs() + } + if cfg.EnablePlacementRules { + // initialize rule manager. + if err := raftCluster.GetRuleManager().Initialize(int(cfg.MaxReplicas), cfg.LocationLabels); err != nil { + return err + } + } else { + // NOTE: can be removed after placement rules feature is enabled by default. + for _, s := range raftCluster.GetStores() { + if !s.IsRemoved() && s.IsTiFlash() { + return errors.New("cannot disable placement rules with TiFlash nodes") + } + } + } + } + + var rule *placement.Rule + if cfg.EnablePlacementRules { + // replication.MaxReplicas won't work when placement rule is enabled and not only have one default rule. 
+ defaultRule := s.GetRaftCluster().GetRuleManager().GetRule("pd", "default") + + CheckInDefaultRule := func() error { + // replication config won't work when placement rule is enabled and exceeds one default rule + if !(defaultRule != nil && + len(defaultRule.StartKey) == 0 && len(defaultRule.EndKey) == 0) { + return errors.New("cannot update MaxReplicas or LocationLabels when placement rules feature is enabled and not only default rule exists, please update rule instead") + } + if !(defaultRule.Count == int(old.MaxReplicas) && typeutil.StringsEqual(defaultRule.LocationLabels, []string(old.LocationLabels))) { + return errors.New("cannot to update replication config, the default rules do not consistent with replication config, please update rule instead") + } + + return nil + } + + if !(cfg.MaxReplicas == old.MaxReplicas && typeutil.StringsEqual(cfg.LocationLabels, old.LocationLabels)) { + if err := CheckInDefaultRule(); err != nil { + return err + } + rule = defaultRule + } + } + + if rule != nil { + rule.Count = int(cfg.MaxReplicas) + rule.LocationLabels = cfg.LocationLabels + if err := s.GetRaftCluster().GetRuleManager().SetRule(rule); err != nil { + log.Error("failed to update rule count", + errs.ZapError(err)) + return err + } + } + + s.persistOptions.SetReplicationConfig(&cfg) + if err := s.persistOptions.Persist(s.storage); err != nil { + s.persistOptions.SetReplicationConfig(old) + if rule != nil { + rule.Count = int(old.MaxReplicas) + if e := s.GetRaftCluster().GetRuleManager().SetRule(rule); e != nil { + log.Error("failed to roll back count of rule when update replication config", errs.ZapError(e)) + } + } + log.Error("failed to update replication config", + zap.Reflect("new", cfg), + zap.Reflect("old", old), + errs.ZapError(err)) + return err + } + log.Info("replication config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) + return nil +} + +// GetAuditConfig gets the audit config information. +func (s *Server) GetAuditConfig() *config.AuditConfig { + return s.serviceMiddlewarePersistOptions.GetAuditConfig().Clone() +} + +// SetAuditConfig sets the audit config. 
+func (s *Server) SetAuditConfig(cfg config.AuditConfig) error { + old := s.serviceMiddlewarePersistOptions.GetAuditConfig() + s.serviceMiddlewarePersistOptions.SetAuditConfig(&cfg) + if err := s.serviceMiddlewarePersistOptions.Persist(s.storage); err != nil { + s.serviceMiddlewarePersistOptions.SetAuditConfig(old) + log.Error("failed to update Audit config", + zap.Reflect("new", cfg), + zap.Reflect("old", old), + errs.ZapError(err)) + return err + } + log.Info("Audit config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) + return nil +} + +// UpdateRateLimitConfig is used to update rate-limit config which will reserve old limiter-config +func (s *Server) UpdateRateLimitConfig(key, label string, value ratelimit.DimensionConfig) error { + cfg := s.GetServiceMiddlewareConfig() + rateLimitCfg := make(map[string]ratelimit.DimensionConfig) + for label, item := range cfg.LimiterConfig { + rateLimitCfg[label] = item + } + rateLimitCfg[label] = value + return s.UpdateRateLimit(&cfg.RateLimitConfig, key, &rateLimitCfg) +} + +// UpdateRateLimit is used to update rate-limit config which will overwrite limiter-config +func (s *Server) UpdateRateLimit(cfg *config.RateLimitConfig, key string, value interface{}) error { + updated, found, err := jsonutil.AddKeyValue(cfg, key, value) + if err != nil { + return err + } + + if !found { + return errors.Errorf("config item %s not found", key) + } + + if updated { + err = s.SetRateLimitConfig(*cfg) + } + return err +} + +// GetRateLimitConfig gets the rate limit config information. +func (s *Server) GetRateLimitConfig() *config.RateLimitConfig { + return s.serviceMiddlewarePersistOptions.GetRateLimitConfig().Clone() +} + +// SetRateLimitConfig sets the rate limit config. +func (s *Server) SetRateLimitConfig(cfg config.RateLimitConfig) error { + old := s.serviceMiddlewarePersistOptions.GetRateLimitConfig() + s.serviceMiddlewarePersistOptions.SetRateLimitConfig(&cfg) + if err := s.serviceMiddlewarePersistOptions.Persist(s.storage); err != nil { + s.serviceMiddlewarePersistOptions.SetRateLimitConfig(old) + log.Error("failed to update Rate Limit config", + zap.Reflect("new", cfg), + zap.Reflect("old", old), + errs.ZapError(err)) + return err + } + log.Info("Rate Limit config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) + return nil +} + +// GetPDServerConfig gets the balance config information. +func (s *Server) GetPDServerConfig() *config.PDServerConfig { + return s.persistOptions.GetPDServerConfig().Clone() +} + +// SetPDServerConfig sets the server config. 
+func (s *Server) SetPDServerConfig(cfg config.PDServerConfig) error { + switch cfg.DashboardAddress { + case "auto": + case "none": + default: + if !strings.HasPrefix(cfg.DashboardAddress, "http") { + cfg.DashboardAddress = fmt.Sprintf("%s://%s", s.GetClientScheme(), cfg.DashboardAddress) + } + if !cluster.IsClientURL(cfg.DashboardAddress, s.client) { + return errors.Errorf("%s is not the client url of any member", cfg.DashboardAddress) + } + } + if err := cfg.Validate(); err != nil { + return err + } + + old := s.persistOptions.GetPDServerConfig() + s.persistOptions.SetPDServerConfig(&cfg) + if err := s.persistOptions.Persist(s.storage); err != nil { + s.persistOptions.SetPDServerConfig(old) + log.Error("failed to update PDServer config", + zap.Reflect("new", cfg), + zap.Reflect("old", old), + errs.ZapError(err)) + return err + } + log.Info("PD server config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) + return nil +} + +// SetLabelPropertyConfig sets the label property config. +func (s *Server) SetLabelPropertyConfig(cfg config.LabelPropertyConfig) error { + old := s.persistOptions.GetLabelPropertyConfig() + s.persistOptions.SetLabelPropertyConfig(cfg) + if err := s.persistOptions.Persist(s.storage); err != nil { + s.persistOptions.SetLabelPropertyConfig(old) + log.Error("failed to update label property config", + zap.Reflect("new", cfg), + zap.Reflect("old", &old), + errs.ZapError(err)) + return err + } + log.Info("label property config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) + return nil +} + +// SetLabelProperty inserts a label property config. +func (s *Server) SetLabelProperty(typ, labelKey, labelValue string) error { + s.persistOptions.SetLabelProperty(typ, labelKey, labelValue) + err := s.persistOptions.Persist(s.storage) + if err != nil { + s.persistOptions.DeleteLabelProperty(typ, labelKey, labelValue) + log.Error("failed to update label property config", + zap.String("typ", typ), + zap.String("label-key", labelKey), + zap.String("label-value", labelValue), + zap.Reflect("config", s.persistOptions.GetLabelPropertyConfig()), + errs.ZapError(err)) + return err + } + + log.Info("label property config is updated", zap.Reflect("config", s.persistOptions.GetLabelPropertyConfig())) + return nil +} + +// DeleteLabelProperty deletes a label property config. +func (s *Server) DeleteLabelProperty(typ, labelKey, labelValue string) error { + s.persistOptions.DeleteLabelProperty(typ, labelKey, labelValue) + err := s.persistOptions.Persist(s.storage) + if err != nil { + s.persistOptions.SetLabelProperty(typ, labelKey, labelValue) + log.Error("failed to delete label property config", + zap.String("typ", typ), + zap.String("label-key", labelKey), + zap.String("label-value", labelValue), + zap.Reflect("config", s.persistOptions.GetLabelPropertyConfig()), + errs.ZapError(err)) + return err + } + + log.Info("label property config is deleted", zap.Reflect("config", s.persistOptions.GetLabelPropertyConfig())) + return nil +} + +// GetLabelProperty returns the whole label property config. +func (s *Server) GetLabelProperty() config.LabelPropertyConfig { + return s.persistOptions.GetLabelPropertyConfig().Clone() +} + +// SetClusterVersion sets the version of cluster. 
+func (s *Server) SetClusterVersion(v string) error { + version, err := versioninfo.ParseVersion(v) + if err != nil { + return err + } + old := s.persistOptions.GetClusterVersion() + s.persistOptions.SetClusterVersion(version) + err = s.persistOptions.Persist(s.storage) + if err != nil { + s.persistOptions.SetClusterVersion(old) + log.Error("failed to update cluster version", + zap.String("old-version", old.String()), + zap.String("new-version", v), + errs.ZapError(err)) + return err + } + log.Info("cluster version is updated", zap.String("new-version", v)) + return nil +} + +// GetClusterVersion returns the version of cluster. +func (s *Server) GetClusterVersion() semver.Version { + return *s.persistOptions.GetClusterVersion() +} + +// GetTLSConfig get the security config. +func (s *Server) GetTLSConfig() *grpcutil.TLSConfig { + return &s.cfg.Security.TLSConfig +} + +// GetRaftCluster gets Raft cluster. +// If cluster has not been bootstrapped, return nil. +func (s *Server) GetRaftCluster() *cluster.RaftCluster { + if s.IsClosed() || !s.cluster.IsRunning() { + return nil + } + return s.cluster +} + +// GetCluster gets cluster. +func (s *Server) GetCluster() *metapb.Cluster { + return &metapb.Cluster{ + Id: s.clusterID, + MaxPeerCount: uint32(s.persistOptions.GetMaxReplicas()), + } +} + +// GetServerOption gets the option of the server. +func (s *Server) GetServerOption() *config.PersistOptions { + return s.persistOptions +} + +// GetMetaRegions gets meta regions from cluster. +func (s *Server) GetMetaRegions() []*metapb.Region { + cluster := s.GetRaftCluster() + if cluster != nil { + return cluster.GetMetaRegions() + } + return nil +} + +// GetRegions gets regions from cluster. +func (s *Server) GetRegions() []*core.RegionInfo { + cluster := s.GetRaftCluster() + if cluster != nil { + return cluster.GetRegions() + } + return nil +} + +// GetServiceLabels returns ApiAccessPaths by given service label +// TODO: this function will be used for updating api rate limit config +func (s *Server) GetServiceLabels(serviceLabel string) []apiutil.AccessPath { + if apis, ok := s.serviceLabels[serviceLabel]; ok { + return apis + } + return nil +} + +// GetAPIAccessServiceLabel returns service label by given access path +// TODO: this function will be used for updating api rate limit config +func (s *Server) GetAPIAccessServiceLabel(accessPath apiutil.AccessPath) string { + if servicelabel, ok := s.apiServiceLabelMap[accessPath]; ok { + return servicelabel + } + accessPathNoMethod := apiutil.NewAccessPath(accessPath.Path, "") + if servicelabel, ok := s.apiServiceLabelMap[accessPathNoMethod]; ok { + return servicelabel + } + return "" +} + +// AddServiceLabel is used to add the relationship between service label and api access path +// TODO: this function will be used for updating api rate limit config +func (s *Server) AddServiceLabel(serviceLabel string, accessPath apiutil.AccessPath) { + if slice, ok := s.serviceLabels[serviceLabel]; ok { + slice = append(slice, accessPath) + s.serviceLabels[serviceLabel] = slice + } else { + slice = []apiutil.AccessPath{accessPath} + s.serviceLabels[serviceLabel] = slice + } + + s.apiServiceLabelMap[accessPath] = serviceLabel +} + +// GetAuditBackend returns audit backends +func (s *Server) GetAuditBackend() []audit.Backend { + return s.auditBackends +} + +// GetServiceAuditBackendLabels returns audit backend labels by serviceLabel +func (s *Server) GetServiceAuditBackendLabels(serviceLabel string) *audit.BackendLabels { + return 
s.serviceAuditBackendLabels[serviceLabel] +} + +// SetServiceAuditBackendLabels is used to add audit backend labels for service by service label +func (s *Server) SetServiceAuditBackendLabels(serviceLabel string, labels []string) { + s.serviceAuditBackendLabels[serviceLabel] = &audit.BackendLabels{Labels: labels} +} + +// GetServiceRateLimiter is used to get rate limiter +func (s *Server) GetServiceRateLimiter() *ratelimit.Limiter { + return s.serviceRateLimiter +} + +// IsInRateLimitAllowList returns whethis given service label is in allow lost +func (s *Server) IsInRateLimitAllowList(serviceLabel string) bool { + return s.serviceRateLimiter.IsInAllowList(serviceLabel) +} + +// UpdateServiceRateLimiter is used to update RateLimiter +func (s *Server) UpdateServiceRateLimiter(serviceLabel string, opts ...ratelimit.Option) ratelimit.UpdateStatus { + return s.serviceRateLimiter.Update(serviceLabel, opts...) +} + +// GetClusterStatus gets cluster status. +func (s *Server) GetClusterStatus() (*cluster.Status, error) { + s.cluster.Lock() + defer s.cluster.Unlock() + return s.cluster.LoadClusterStatus() +} + +// SetLogLevel sets log level. +func (s *Server) SetLogLevel(level string) error { + if !isLevelLegal(level) { + return errors.Errorf("log level %s is illegal", level) + } + s.cfg.Log.Level = level + log.SetLevel(logutil.StringToZapLogLevel(level)) + log.Warn("log level changed", zap.String("level", log.GetLevel().String())) + return nil +} + +func isLevelLegal(level string) bool { + switch strings.ToLower(level) { + case "fatal", "error", "warn", "warning", "debug", "info": + return true + default: + return false + } +} + +// GetReplicationModeConfig returns the replication mode config. +func (s *Server) GetReplicationModeConfig() *config.ReplicationModeConfig { + return s.persistOptions.GetReplicationModeConfig().Clone() +} + +// SetReplicationModeConfig sets the replication mode. +func (s *Server) SetReplicationModeConfig(cfg config.ReplicationModeConfig) error { + if config.NormalizeReplicationMode(cfg.ReplicationMode) == "" { + return errors.Errorf("invalid replication mode: %v", cfg.ReplicationMode) + } + + old := s.persistOptions.GetReplicationModeConfig() + s.persistOptions.SetReplicationModeConfig(&cfg) + if err := s.persistOptions.Persist(s.storage); err != nil { + s.persistOptions.SetReplicationModeConfig(old) + log.Error("failed to update replication mode config", + zap.Reflect("new", cfg), + zap.Reflect("old", &old), + errs.ZapError(err)) + return err + } + log.Info("replication mode config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) + + cluster := s.GetRaftCluster() + if cluster != nil { + err := cluster.GetReplicationMode().UpdateConfig(cfg) + if err != nil { + log.Warn("failed to update replication mode", errs.ZapError(err)) + // revert to old config + // NOTE: since we can't put the 2 storage mutations in a batch, it + // is possible that memory and persistent data become different + // (when below revert fail). They will become the same after PD is + // restart or PD leader is changed. 
+ s.persistOptions.SetReplicationModeConfig(old) + revertErr := s.persistOptions.Persist(s.storage) + if revertErr != nil { + log.Error("failed to revert replication mode persistent config", errs.ZapError(revertErr)) + } + } + return err + } + + return nil +} + +func (s *Server) leaderLoop() { + defer logutil.LogPanic() + defer s.serverLoopWg.Done() + + for { + if s.IsClosed() { + log.Info("server is closed, return pd leader loop") + return + } + + leader, rev, checkAgain := s.member.CheckLeader() + if checkAgain { + continue + } + if leader != nil { + err := s.reloadConfigFromKV() + if err != nil { + log.Error("reload config failed", errs.ZapError(err)) + continue + } + // Check the cluster dc-location after the PD leader is elected + go s.tsoAllocatorManager.ClusterDCLocationChecker() + syncer := s.cluster.GetRegionSyncer() + if s.persistOptions.IsUseRegionStorage() { + syncer.StartSyncWithLeader(leader.GetClientUrls()[0]) + } + log.Info("start to watch pd leader", zap.Stringer("pd-leader", leader)) + // WatchLeader will keep looping and never return unless the PD leader has changed. + s.member.WatchLeader(s.serverLoopCtx, leader, rev) + syncer.StopSyncWithLeader() + log.Info("pd leader has changed, try to re-campaign a pd leader") + } + + // To make sure the etcd leader and PD leader are on the same server. + etcdLeader := s.member.GetEtcdLeader() + if etcdLeader != s.member.ID() { + log.Info("skip campaigning of pd leader and check later", + zap.String("server-name", s.Name()), + zap.Uint64("etcd-leader-id", etcdLeader), + zap.Uint64("member-id", s.member.ID())) + time.Sleep(200 * time.Millisecond) + continue + } + s.campaignLeader() + } +} + +func (s *Server) campaignLeader() { + log.Info("start to campaign pd leader", zap.String("campaign-pd-leader-name", s.Name())) + if err := s.member.CampaignLeader(s.cfg.LeaderLease); err != nil { + if err.Error() == errs.ErrEtcdTxnConflict.Error() { + log.Info("campaign pd leader meets error due to txn conflict, another PD server may campaign successfully", + zap.String("campaign-pd-leader-name", s.Name())) + } else { + log.Error("campaign pd leader meets error due to etcd error", + zap.String("campaign-pd-leader-name", s.Name()), + errs.ZapError(err)) + } + return + } + + // Start keepalive the leadership and enable TSO service. + // TSO service is strictly enabled/disabled by PD leader lease for 2 reasons: + // 1. lease based approach is not affected by thread pause, slow runtime schedule, etc. + // 2. load region could be slow. Based on lease we can recover TSO service faster. + ctx, cancel := context.WithCancel(s.serverLoopCtx) + var resetLeaderOnce sync.Once + defer resetLeaderOnce.Do(func() { + cancel() + s.member.ResetLeader() + }) + + // maintain the PD leadership, after this, TSO can be service. 
+ s.member.KeepLeader(ctx) + log.Info("campaign pd leader ok", zap.String("campaign-pd-leader-name", s.Name())) + + allocator, err := s.tsoAllocatorManager.GetAllocator(tso.GlobalDCLocation) + if err != nil { + log.Error("failed to get the global TSO allocator", errs.ZapError(err)) + return + } + log.Info("initializing the global TSO allocator") + if err := allocator.Initialize(0); err != nil { + log.Error("failed to initialize the global TSO allocator", errs.ZapError(err)) + return + } + defer func() { + s.tsoAllocatorManager.ResetAllocatorGroup(tso.GlobalDCLocation) + failpoint.Inject("updateAfterResetTSO", func() { + if err = allocator.UpdateTSO(); err != nil { + panic(err) + } + }) + }() + + if err := s.reloadConfigFromKV(); err != nil { + log.Error("failed to reload configuration", errs.ZapError(err)) + return + } + + if err := s.persistOptions.LoadTTLFromEtcd(s.ctx, s.client); err != nil { + log.Error("failed to load persistOptions from etcd", errs.ZapError(err)) + return + } + + if err := s.encryptionKeyManager.SetLeadership(s.member.GetLeadership()); err != nil { + log.Error("failed to initialize encryption", errs.ZapError(err)) + return + } + + // Try to create raft cluster. + if err := s.createRaftCluster(); err != nil { + log.Error("failed to create raft cluster", errs.ZapError(err)) + return + } + defer s.stopRaftCluster() + if err := s.idAllocator.Rebase(); err != nil { + log.Error("failed to sync id from etcd", errs.ZapError(err)) + return + } + // EnableLeader to accept the remaining service, such as GetStore, GetRegion. + s.member.EnableLeader() + // Check the cluster dc-location after the PD leader is elected. + go s.tsoAllocatorManager.ClusterDCLocationChecker() + defer resetLeaderOnce.Do(func() { + // as soon as cancel the leadership keepalive, then other member have chance + // to be new leader. + cancel() + s.member.ResetLeader() + }) + + CheckPDVersion(s.persistOptions) + log.Info("PD cluster leader is ready to serve", zap.String("pd-leader-name", s.Name())) + + leaderTicker := time.NewTicker(leaderTickInterval) + defer leaderTicker.Stop() + + for { + select { + case <-leaderTicker.C: + if !s.member.IsLeader() { + log.Info("no longer a leader because lease has expired, pd leader will step down") + return + } + etcdLeader := s.member.GetEtcdLeader() + if etcdLeader != s.member.ID() { + log.Info("etcd leader changed, resigns pd leadership", zap.String("old-pd-leader-name", s.Name())) + return + } + // add failpoint to test exit leader, failpoint judge the member is the give value, then break + failpoint.Inject("exitCampaignLeader", func(val failpoint.Value) { + memberString := val.(string) + memberID, _ := strconv.ParseUint(memberString, 10, 64) + if s.member.ID() == memberID { + log.Info("exit PD leader") + failpoint.Return() + } + }) + case <-ctx.Done(): + // Server is closed and it should return nil. 
+ log.Info("server is closed") + return + } + } +} + +func (s *Server) etcdLeaderLoop() { + defer logutil.LogPanic() + defer s.serverLoopWg.Done() + + ctx, cancel := context.WithCancel(s.serverLoopCtx) + defer cancel() + for { + select { + case <-time.After(s.cfg.LeaderPriorityCheckInterval.Duration): + s.member.CheckPriority(ctx) + case <-ctx.Done(): + log.Info("server is closed, exit etcd leader loop") + return + } + } +} + +func (s *Server) reloadConfigFromKV() error { + err := s.persistOptions.Reload(s.storage) + if err != nil { + return err + } + err = s.serviceMiddlewarePersistOptions.Reload(s.storage) + if err != nil { + return err + } + s.loadRateLimitConfig() + useRegionStorage := s.persistOptions.IsUseRegionStorage() + regionStorage := storage.TrySwitchRegionStorage(s.storage, useRegionStorage) + if regionStorage != nil { + if useRegionStorage { + log.Info("server enable region storage") + } else { + log.Info("server disable region storage") + } + } + return nil +} + +func (s *Server) loadRateLimitConfig() { + cfg := s.serviceMiddlewarePersistOptions.GetRateLimitConfig().LimiterConfig + for key := range cfg { + value := cfg[key] + s.serviceRateLimiter.Update(key, ratelimit.UpdateDimensionConfig(&value)) + } +} + +// ReplicateFileToMember is used to synchronize state to a member. +// Each member will write `data` to a local file named `name`. +// For security reason, data should be in JSON format. +func (s *Server) ReplicateFileToMember(ctx context.Context, member *pdpb.Member, name string, data []byte) error { + clientUrls := member.GetClientUrls() + if len(clientUrls) == 0 { + log.Warn("failed to replicate file", zap.String("name", name), zap.String("member", member.GetName())) + return errs.ErrClientURLEmpty.FastGenByArgs() + } + url := clientUrls[0] + filepath.Join("/pd/api/v1/admin/persist-file", name) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data)) + req.Header.Set("PD-Allow-follower-handle", "true") + res, err := s.httpClient.Do(req) + if err != nil { + log.Warn("failed to replicate file", zap.String("name", name), zap.String("member", member.GetName()), errs.ZapError(err)) + return errs.ErrSendRequest.Wrap(err).GenWithStackByCause() + } + // Since we don't read the body, we can close it immediately. + res.Body.Close() + if res.StatusCode != http.StatusOK { + log.Warn("failed to replicate file", zap.String("name", name), zap.String("member", member.GetName()), zap.Int("status-code", res.StatusCode)) + return errs.ErrSendRequest.FastGenByArgs() + } + return nil +} + +// PersistFile saves a file in DataDir. +func (s *Server) PersistFile(name string, data []byte) error { + log.Info("persist file", zap.String("name", name), zap.Binary("data", data)) + return os.WriteFile(filepath.Join(s.GetConfig().DataDir, name), data, 0644) // #nosec +} + +// SaveTTLConfig save ttl config +func (s *Server) SaveTTLConfig(data map[string]interface{}, ttl time.Duration) error { + for k := range data { + if !config.IsSupportedTTLConfig(k) { + return fmt.Errorf("unsupported ttl config %s", k) + } + } + for k, v := range data { + if err := s.persistOptions.SetTTLData(s.ctx, s.client, k, fmt.Sprint(v), ttl); err != nil { + return err + } + } + return nil +} + +// IsTTLConfigExist returns true if the ttl config is existed for a given config. 
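+// A hypothetical call sequence pairing it with SaveTTLConfig above (the key
+// shown is illustrative only; it must pass config.IsSupportedTTLConfig):
+//
+//	_ = s.SaveTTLConfig(map[string]interface{}{"schedule.max-merge-region-size": 20}, 5*time.Minute)
+//	ok := s.IsTTLConfigExist("schedule.max-merge-region-size")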
+func (s *Server) IsTTLConfigExist(key string) bool { + if config.IsSupportedTTLConfig(key) { + if _, ok := s.persistOptions.GetTTLData(key); ok { + return true + } + } + return false +} + +// MarkSnapshotRecovering mark pd that we're recovering +// tikv will get this state during BR EBS restore. +// we write this info into etcd for simplicity, the key only stays inside etcd temporary +// during BR EBS restore in which period the cluster is not able to serve request. +// and is deleted after BR EBS restore is done. +func (s *Server) MarkSnapshotRecovering() error { + log.Info("mark snapshot recovering") + markPath := endpoint.AppendToRootPath(s.rootPath, recoveringMarkPath) + // the value doesn't matter, set to a static string + _, err := kv.NewSlowLogTxn(s.client). + If(clientv3.Compare(clientv3.CreateRevision(markPath), "=", 0)). + Then(clientv3.OpPut(markPath, "on")). + Commit() + // if other client already marked, return success too + return err +} + +// IsSnapshotRecovering check whether recovering-mark marked +func (s *Server) IsSnapshotRecovering(ctx context.Context) (bool, error) { + markPath := endpoint.AppendToRootPath(s.rootPath, recoveringMarkPath) + resp, err := s.client.Get(ctx, markPath) + if err != nil { + return false, err + } + return len(resp.Kvs) > 0, nil +} + +// UnmarkSnapshotRecovering unmark recovering mark +func (s *Server) UnmarkSnapshotRecovering(ctx context.Context) error { + log.Info("unmark snapshot recovering") + markPath := endpoint.AppendToRootPath(s.rootPath, recoveringMarkPath) + _, err := s.client.Delete(ctx, markPath) + // if other client already unmarked, return success too + return err +} + +// RecoverAllocID recover alloc id. set current base id to input id +func (s *Server) RecoverAllocID(ctx context.Context, id uint64) error { + return s.idAllocator.SetBase(id) +} + +// GetGlobalTS returns global tso. +func (s *Server) GetGlobalTS() (uint64, error) { + ts, err := s.tsoAllocatorManager.GetGlobalTSO() + if err != nil { + return 0, err + } + return tsoutil.GenerateTS(ts), nil +} + +// GetExternalTS returns external timestamp. +func (s *Server) GetExternalTS() uint64 { + return s.GetRaftCluster().GetExternalTS() +} + +// SetExternalTS returns external timestamp. 
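+// (More precisely, it validates and stores the given external timestamp: the
+// value must not exceed the current global TSO and must be strictly larger
+// than the currently stored external timestamp.)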
+func (s *Server) SetExternalTS(externalTS uint64) error { + globalTS, err := s.GetGlobalTS() + if err != nil { + return err + } + if tsoutil.CompareTimestampUint64(externalTS, globalTS) == 1 { + desc := "the external timestamp should not be larger than global ts" + log.Error(desc, zap.Uint64("request timestamp", externalTS), zap.Uint64("global ts", globalTS)) + return errors.New(desc) + } + currentExternalTS := s.GetRaftCluster().GetExternalTS() + if tsoutil.CompareTimestampUint64(externalTS, currentExternalTS) != 1 { + desc := "the external timestamp should be larger than now" + log.Error(desc, zap.Uint64("request timestamp", externalTS), zap.Uint64("current external timestamp", currentExternalTS)) + return errors.New(desc) + } + s.GetRaftCluster().SetExternalTS(externalTS) + return nil +} diff --git a/server/storage/binding__failpoint_binding__.go b/server/storage/binding__failpoint_binding__.go new file mode 100755 index 00000000000..a1a747a15d5 --- /dev/null +++ b/server/storage/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package storage + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/storage/endpoint/binding__failpoint_binding__.go b/server/storage/endpoint/binding__failpoint_binding__.go new file mode 100755 index 00000000000..5aade927635 --- /dev/null +++ b/server/storage/endpoint/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package endpoint + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/storage/endpoint/gc_key_space.go b/server/storage/endpoint/gc_key_space.go old mode 100644 new mode 100755 index e23197da90e..43b2c88af86 --- a/server/storage/endpoint/gc_key_space.go +++ b/server/storage/endpoint/gc_key_space.go @@ -114,14 +114,14 @@ func (se *StorageEndpoint) LoadMinServiceSafePoint(spaceID string, now time.Time } } // failpoint for immediate removal - failpoint.Inject("removeExpiredKeys", func() { + if _, _err_ := failpoint.Eval(_curpkg_("removeExpiredKeys")); _err_ == nil { for _, key := range expiredKeys { if err = se.Remove(key); err != nil { log.Error("remove expired key meet error", zap.String("key", key), errs.ZapError(err)) } } expiredKeys = []string{} - }) + } // remove expired keys asynchronously go func() { for _, key := range expiredKeys { diff --git a/server/storage/endpoint/gc_key_space.go__failpoint_stash__ b/server/storage/endpoint/gc_key_space.go__failpoint_stash__ new file mode 100644 index 00000000000..e23197da90e --- /dev/null +++ b/server/storage/endpoint/gc_key_space.go__failpoint_stash__ @@ -0,0 +1,198 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package endpoint + +import ( + "encoding/json" + "math" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "go.etcd.io/etcd/clientv3" + "go.uber.org/zap" +) + +// KeySpaceGCSafePoint is gcWorker's safepoint for specific key-space +type KeySpaceGCSafePoint struct { + SpaceID string `json:"space_id"` + SafePoint uint64 `json:"safe_point,omitempty"` +} + +// KeySpaceGCSafePointStorage defines the storage operations on KeySpaces' safe points +type KeySpaceGCSafePointStorage interface { + // Service safe point interfaces. + SaveServiceSafePoint(spaceID string, ssp *ServiceSafePoint) error + LoadServiceSafePoint(spaceID, serviceID string) (*ServiceSafePoint, error) + LoadMinServiceSafePoint(spaceID string, now time.Time) (*ServiceSafePoint, error) + RemoveServiceSafePoint(spaceID, serviceID string) error + // GC safe point interfaces. + SaveKeySpaceGCSafePoint(spaceID string, safePoint uint64) error + LoadKeySpaceGCSafePoint(spaceID string) (uint64, error) + LoadAllKeySpaceGCSafePoints(withGCSafePoint bool) ([]*KeySpaceGCSafePoint, error) +} + +var _ KeySpaceGCSafePointStorage = (*StorageEndpoint)(nil) + +// SaveServiceSafePoint saves service safe point under given key-space. +func (se *StorageEndpoint) SaveServiceSafePoint(spaceID string, ssp *ServiceSafePoint) error { + if ssp.ServiceID == "" { + return errors.New("service id of service safepoint cannot be empty") + } + key := KeySpaceServiceSafePointPath(spaceID, ssp.ServiceID) + value, err := json.Marshal(ssp) + if err != nil { + return err + } + return se.Save(key, string(value)) +} + +// LoadServiceSafePoint reads ServiceSafePoint for the given key-space ID and service name. +// Return nil if no safepoint exist for given service or just expired. +func (se *StorageEndpoint) LoadServiceSafePoint(spaceID, serviceID string) (*ServiceSafePoint, error) { + key := KeySpaceServiceSafePointPath(spaceID, serviceID) + value, err := se.Load(key) + if err != nil || value == "" { + return nil, err + } + ssp := &ServiceSafePoint{} + if err := json.Unmarshal([]byte(value), ssp); err != nil { + return nil, err + } + if ssp.ExpiredAt < time.Now().Unix() { + go func() { + if err = se.Remove(key); err != nil { + log.Error("remove expired key meet error", zap.String("key", key), errs.ZapError(err)) + } + }() + return nil, nil + } + return ssp, nil +} + +// LoadMinServiceSafePoint returns the minimum safepoint for the given key-space. +// Note that gc worker safe point are store separately. +// If no service safe point exist for the given key-space or all the service safe points just expired, return nil. 
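+// Expired service safe points encountered during the scan are removed as a
+// side effect: synchronously when the "removeExpiredKeys" failpoint is
+// enabled (used in tests), otherwise asynchronously in a background goroutine.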
+func (se *StorageEndpoint) LoadMinServiceSafePoint(spaceID string, now time.Time) (*ServiceSafePoint, error) { + prefix := KeySpaceServiceSafePointPrefix(spaceID) + prefixEnd := clientv3.GetPrefixRangeEnd(prefix) + keys, values, err := se.LoadRange(prefix, prefixEnd, 0) + if err != nil { + return nil, err + } + min := &ServiceSafePoint{SafePoint: math.MaxUint64} + expiredKeys := make([]string, 0) + for i, key := range keys { + ssp := &ServiceSafePoint{} + if err = json.Unmarshal([]byte(values[i]), ssp); err != nil { + return nil, err + } + + // gather expired keys + if ssp.ExpiredAt < now.Unix() { + expiredKeys = append(expiredKeys, key) + continue + } + if ssp.SafePoint < min.SafePoint { + min = ssp + } + } + // failpoint for immediate removal + failpoint.Inject("removeExpiredKeys", func() { + for _, key := range expiredKeys { + if err = se.Remove(key); err != nil { + log.Error("remove expired key meet error", zap.String("key", key), errs.ZapError(err)) + } + } + expiredKeys = []string{} + }) + // remove expired keys asynchronously + go func() { + for _, key := range expiredKeys { + if err = se.Remove(key); err != nil { + log.Error("remove expired key meet error", zap.String("key", key), errs.ZapError(err)) + } + } + }() + if min.SafePoint == math.MaxUint64 { + // no service safe point or all of them are expired. + return nil, nil + } + + // successfully found a valid min safe point. + return min, nil +} + +// RemoveServiceSafePoint removes target ServiceSafePoint +func (se *StorageEndpoint) RemoveServiceSafePoint(spaceID, serviceID string) error { + key := KeySpaceServiceSafePointPath(spaceID, serviceID) + return se.Remove(key) +} + +// SaveKeySpaceGCSafePoint saves GCSafePoint to the given key-space. +func (se *StorageEndpoint) SaveKeySpaceGCSafePoint(spaceID string, safePoint uint64) error { + value := strconv.FormatUint(safePoint, 16) + return se.Save(KeySpaceGCSafePointPath(spaceID), value) +} + +// LoadKeySpaceGCSafePoint reads GCSafePoint for the given key-space. +// Returns 0 if target safepoint not exist. +func (se *StorageEndpoint) LoadKeySpaceGCSafePoint(spaceID string) (uint64, error) { + value, err := se.Load(KeySpaceGCSafePointPath(spaceID)) + if err != nil || value == "" { + return 0, err + } + safePoint, err := strconv.ParseUint(value, 16, 64) + if err != nil { + return 0, err + } + return safePoint, nil +} + +// LoadAllKeySpaceGCSafePoints returns slice of KeySpaceGCSafePoint. +// If withGCSafePoint set to false, returned safePoints will be 0. 
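+// GC safe points are persisted as base-16 strings (see SaveKeySpaceGCSafePoint),
+// so values are parsed here with strconv.ParseUint(value, 16, 64).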
+func (se *StorageEndpoint) LoadAllKeySpaceGCSafePoints(withGCSafePoint bool) ([]*KeySpaceGCSafePoint, error) { + prefix := KeySpaceSafePointPrefix() + prefixEnd := clientv3.GetPrefixRangeEnd(prefix) + suffix := KeySpaceGCSafePointSuffix() + keys, values, err := se.LoadRange(prefix, prefixEnd, 0) + if err != nil { + return nil, err + } + safePoints := make([]*KeySpaceGCSafePoint, 0, len(values)) + for i := range keys { + // skip non gc safe points + if !strings.HasSuffix(keys[i], suffix) { + continue + } + safePoint := &KeySpaceGCSafePoint{} + spaceID := strings.TrimPrefix(keys[i], prefix) + spaceID = strings.TrimSuffix(spaceID, suffix) + safePoint.SpaceID = spaceID + if withGCSafePoint { + value, err := strconv.ParseUint(values[i], 16, 64) + if err != nil { + return nil, err + } + safePoint.SafePoint = value + } + safePoints = append(safePoints, safePoint) + } + return safePoints, nil +} diff --git a/server/storage/endpoint/meta.go b/server/storage/endpoint/meta.go old mode 100644 new mode 100755 index bb848485c39..d60de6133a1 --- a/server/storage/endpoint/meta.go +++ b/server/storage/endpoint/meta.go @@ -175,10 +175,10 @@ func (se *StorageEndpoint) LoadRegions(ctx context.Context, f func(region *core. // a variable rangeLimit to work around. rangeLimit := MaxKVRangeLimit for { - failpoint.Inject("slowLoadRegion", func() { + if _, _err_ := failpoint.Eval(_curpkg_("slowLoadRegion")); _err_ == nil { rangeLimit = 1 time.Sleep(time.Second) - }) + } startKey := RegionPath(nextID) _, res, err := se.LoadRange(startKey, endKey, rangeLimit) if err != nil { diff --git a/server/storage/endpoint/meta.go__failpoint_stash__ b/server/storage/endpoint/meta.go__failpoint_stash__ new file mode 100644 index 00000000000..bb848485c39 --- /dev/null +++ b/server/storage/endpoint/meta.go__failpoint_stash__ @@ -0,0 +1,242 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package endpoint + +import ( + "context" + "math" + "strconv" + "time" + + "github.com/gogo/protobuf/proto" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/tikv/pd/pkg/encryption" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/server/core" +) + +// MetaStorage defines the storage operations on the PD cluster meta info. +type MetaStorage interface { + LoadMeta(meta *metapb.Cluster) (bool, error) + SaveMeta(meta *metapb.Cluster) error + LoadStore(storeID uint64, store *metapb.Store) (bool, error) + SaveStore(store *metapb.Store) error + SaveStoreWeight(storeID uint64, leader, region float64) error + LoadStores(f func(store *core.StoreInfo)) error + DeleteStore(store *metapb.Store) error + RegionStorage +} + +// RegionStorage defines the storage operations on the Region meta info. 
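+// Besides the default StorageEndpoint implementation below, the LevelDB-backed
+// levelDBBackend introduced later in this patch is also expected to satisfy it.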
+type RegionStorage interface { + LoadRegion(regionID uint64, region *metapb.Region) (ok bool, err error) + LoadRegions(ctx context.Context, f func(region *core.RegionInfo) []*core.RegionInfo) error + SaveRegion(region *metapb.Region) error + DeleteRegion(region *metapb.Region) error + Flush() error + Close() error +} + +var _ MetaStorage = (*StorageEndpoint)(nil) + +const ( + // MaxKVRangeLimit is the max limit of the number of keys in a range. + MaxKVRangeLimit = 10000 + // MinKVRangeLimit is the min limit of the number of keys in a range. + MinKVRangeLimit = 100 +) + +// LoadMeta loads cluster meta from the storage. This method will only +// be used by the PD server, so we should only implement it for the etcd storage. +func (se *StorageEndpoint) LoadMeta(meta *metapb.Cluster) (bool, error) { + return se.loadProto(clusterPath, meta) +} + +// SaveMeta save cluster meta to the storage. This method will only +// be used by the PD server, so we should only implement it for the etcd storage. +func (se *StorageEndpoint) SaveMeta(meta *metapb.Cluster) error { + return se.saveProto(clusterPath, meta) +} + +// LoadStore loads one store from storage. +func (se *StorageEndpoint) LoadStore(storeID uint64, store *metapb.Store) (bool, error) { + return se.loadProto(StorePath(storeID), store) +} + +// SaveStore saves one store to storage. +func (se *StorageEndpoint) SaveStore(store *metapb.Store) error { + return se.saveProto(StorePath(store.GetId()), store) +} + +// SaveStoreWeight saves a store's leader and region weight to storage. +func (se *StorageEndpoint) SaveStoreWeight(storeID uint64, leader, region float64) error { + leaderValue := strconv.FormatFloat(leader, 'f', -1, 64) + if err := se.Save(storeLeaderWeightPath(storeID), leaderValue); err != nil { + return err + } + regionValue := strconv.FormatFloat(region, 'f', -1, 64) + return se.Save(storeRegionWeightPath(storeID), regionValue) +} + +// LoadStores loads all stores from storage to StoresInfo. +func (se *StorageEndpoint) LoadStores(f func(store *core.StoreInfo)) error { + nextID := uint64(0) + endKey := StorePath(math.MaxUint64) + for { + key := StorePath(nextID) + _, res, err := se.LoadRange(key, endKey, MinKVRangeLimit) + if err != nil { + return err + } + for _, str := range res { + store := &metapb.Store{} + if err := store.Unmarshal([]byte(str)); err != nil { + return errs.ErrProtoUnmarshal.Wrap(err).GenWithStackByArgs() + } + if store.State == metapb.StoreState_Offline { + store.NodeState = metapb.NodeState_Removing + } + if store.State == metapb.StoreState_Tombstone { + store.NodeState = metapb.NodeState_Removed + } + leaderWeight, err := se.loadFloatWithDefaultValue(storeLeaderWeightPath(store.GetId()), 1.0) + if err != nil { + return err + } + regionWeight, err := se.loadFloatWithDefaultValue(storeRegionWeightPath(store.GetId()), 1.0) + if err != nil { + return err + } + newStoreInfo := core.NewStoreInfo(store, core.SetLeaderWeight(leaderWeight), core.SetRegionWeight(regionWeight)) + + nextID = store.GetId() + 1 + f(newStoreInfo) + } + if len(res) < MinKVRangeLimit { + return nil + } + } +} + +func (se *StorageEndpoint) loadFloatWithDefaultValue(path string, def float64) (float64, error) { + res, err := se.Load(path) + if err != nil { + return 0, err + } + if res == "" { + return def, nil + } + val, err := strconv.ParseFloat(res, 64) + if err != nil { + return 0, errs.ErrStrconvParseFloat.Wrap(err).GenWithStackByArgs() + } + return val, nil +} + +// DeleteStore deletes one store from storage. 
+func (se *StorageEndpoint) DeleteStore(store *metapb.Store) error { + return se.Remove(StorePath(store.GetId())) +} + +// LoadRegion loads one region from the backend storage. +func (se *StorageEndpoint) LoadRegion(regionID uint64, region *metapb.Region) (ok bool, err error) { + value, err := se.Load(RegionPath(regionID)) + if err != nil || value == "" { + return false, err + } + err = proto.Unmarshal([]byte(value), region) + if err != nil { + return true, errs.ErrProtoUnmarshal.Wrap(err).GenWithStackByArgs() + } + err = encryption.DecryptRegion(region, se.encryptionKeyManager) + return true, err +} + +// LoadRegions loads all regions from storage to RegionsInfo. +func (se *StorageEndpoint) LoadRegions(ctx context.Context, f func(region *core.RegionInfo) []*core.RegionInfo) error { + nextID := uint64(0) + endKey := RegionPath(math.MaxUint64) + + // Since the region key may be very long, using a larger rangeLimit will cause + // the message packet to exceed the grpc message size limit (4MB). Here we use + // a variable rangeLimit to work around. + rangeLimit := MaxKVRangeLimit + for { + failpoint.Inject("slowLoadRegion", func() { + rangeLimit = 1 + time.Sleep(time.Second) + }) + startKey := RegionPath(nextID) + _, res, err := se.LoadRange(startKey, endKey, rangeLimit) + if err != nil { + if rangeLimit /= 2; rangeLimit >= MinKVRangeLimit { + continue + } + return err + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + for _, r := range res { + region := &metapb.Region{} + if err := region.Unmarshal([]byte(r)); err != nil { + return errs.ErrProtoUnmarshal.Wrap(err).GenWithStackByArgs() + } + if err = encryption.DecryptRegion(region, se.encryptionKeyManager); err != nil { + return err + } + + nextID = region.GetId() + 1 + overlaps := f(core.NewRegionInfo(region, nil)) + for _, item := range overlaps { + if err := se.DeleteRegion(item.GetMeta()); err != nil { + return err + } + } + } + + if len(res) < rangeLimit { + return nil + } + } +} + +// SaveRegion saves one region to storage. +func (se *StorageEndpoint) SaveRegion(region *metapb.Region) error { + region, err := encryption.EncryptRegion(region, se.encryptionKeyManager) + if err != nil { + return err + } + value, err := proto.Marshal(region) + if err != nil { + return errs.ErrProtoMarshal.Wrap(err).GenWithStackByArgs() + } + return se.Save(RegionPath(region.GetId()), string(value)) +} + +// DeleteRegion deletes one region from storage. +func (se *StorageEndpoint) DeleteRegion(region *metapb.Region) error { + return se.Remove(RegionPath(region.GetId())) +} + +// Flush flushes the pending data to the underlying storage backend. +func (se *StorageEndpoint) Flush() error { return nil } + +// Close closes the underlying storage backend. 
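+// For the plain StorageEndpoint both Flush and Close are no-ops; the LevelDB
+// region storage backend (leveldb_backend.go) provides real implementations.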
+func (se *StorageEndpoint) Close() error { return nil } diff --git a/server/storage/kv/binding__failpoint_binding__.go b/server/storage/kv/binding__failpoint_binding__.go new file mode 100755 index 00000000000..91ba6650c6d --- /dev/null +++ b/server/storage/kv/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package kv + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/storage/kv/etcd_kv.go b/server/storage/kv/etcd_kv.go old mode 100644 new mode 100755 index 00320f0134b..549d4ec9b67 --- a/server/storage/kv/etcd_kv.go +++ b/server/storage/kv/etcd_kv.go @@ -86,9 +86,9 @@ func (kv *etcdKVBase) LoadRange(key, endKey string, limit int) ([]string, []stri } func (kv *etcdKVBase) Save(key, value string) error { - failpoint.Inject("etcdSaveFailed", func() { - failpoint.Return(errors.New("save failed")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("etcdSaveFailed")); _err_ == nil { + return errors.New("save failed") + } key = path.Join(kv.rootPath, key) txn := NewSlowLogTxn(kv.client) resp, err := txn.Then(clientv3.OpPut(key, value)).Commit() diff --git a/server/storage/kv/etcd_kv.go__failpoint_stash__ b/server/storage/kv/etcd_kv.go__failpoint_stash__ new file mode 100644 index 00000000000..00320f0134b --- /dev/null +++ b/server/storage/kv/etcd_kv.go__failpoint_stash__ @@ -0,0 +1,173 @@ +// Copyright 2016 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kv + +import ( + "context" + "path" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/etcdutil" + "go.etcd.io/etcd/clientv3" + "go.uber.org/zap" +) + +const ( + requestTimeout = 10 * time.Second + slowRequestTime = time.Second +) + +type etcdKVBase struct { + client *clientv3.Client + rootPath string +} + +// NewEtcdKVBase creates a new etcd kv. +func NewEtcdKVBase(client *clientv3.Client, rootPath string) *etcdKVBase { + return &etcdKVBase{ + client: client, + rootPath: rootPath, + } +} + +func (kv *etcdKVBase) Load(key string) (string, error) { + key = path.Join(kv.rootPath, key) + + resp, err := etcdutil.EtcdKVGet(kv.client, key) + if err != nil { + return "", err + } + if n := len(resp.Kvs); n == 0 { + return "", nil + } else if n > 1 { + return "", errs.ErrEtcdKVGetResponse.GenWithStackByArgs(resp.Kvs) + } + return string(resp.Kvs[0].Value), nil +} + +func (kv *etcdKVBase) LoadRange(key, endKey string, limit int) ([]string, []string, error) { + // Note: reason to use `strings.Join` instead of `path.Join` is that the latter will + // removes suffix '/' of the joined string. 
+ // As a result, when we try to scan from "foo/", it ends up scanning from "/pd/foo" + // internally, and returns unexpected keys such as "foo_bar/baz". + key = strings.Join([]string{kv.rootPath, key}, "/") + endKey = strings.Join([]string{kv.rootPath, endKey}, "/") + + withRange := clientv3.WithRange(endKey) + withLimit := clientv3.WithLimit(int64(limit)) + resp, err := etcdutil.EtcdKVGet(kv.client, key, withRange, withLimit) + if err != nil { + return nil, nil, err + } + keys := make([]string, 0, len(resp.Kvs)) + values := make([]string, 0, len(resp.Kvs)) + for _, item := range resp.Kvs { + keys = append(keys, strings.TrimPrefix(strings.TrimPrefix(string(item.Key), kv.rootPath), "/")) + values = append(values, string(item.Value)) + } + return keys, values, nil +} + +func (kv *etcdKVBase) Save(key, value string) error { + failpoint.Inject("etcdSaveFailed", func() { + failpoint.Return(errors.New("save failed")) + }) + key = path.Join(kv.rootPath, key) + txn := NewSlowLogTxn(kv.client) + resp, err := txn.Then(clientv3.OpPut(key, value)).Commit() + if err != nil { + e := errs.ErrEtcdKVPut.Wrap(err).GenWithStackByCause() + log.Error("save to etcd meet error", zap.String("key", key), zap.String("value", value), errs.ZapError(e)) + return e + } + if !resp.Succeeded { + return errs.ErrEtcdTxnConflict.FastGenByArgs() + } + return nil +} + +func (kv *etcdKVBase) Remove(key string) error { + key = path.Join(kv.rootPath, key) + + txn := NewSlowLogTxn(kv.client) + resp, err := txn.Then(clientv3.OpDelete(key)).Commit() + if err != nil { + err = errs.ErrEtcdKVDelete.Wrap(err).GenWithStackByCause() + log.Error("remove from etcd meet error", zap.String("key", key), errs.ZapError(err)) + return err + } + if !resp.Succeeded { + return errs.ErrEtcdTxnConflict.FastGenByArgs() + } + return nil +} + +// SlowLogTxn wraps etcd transaction and log slow one. +type SlowLogTxn struct { + clientv3.Txn + cancel context.CancelFunc +} + +// NewSlowLogTxn create a SlowLogTxn. +func NewSlowLogTxn(client *clientv3.Client) clientv3.Txn { + ctx, cancel := context.WithTimeout(client.Ctx(), requestTimeout) + return &SlowLogTxn{ + Txn: client.Txn(ctx), + cancel: cancel, + } +} + +// If takes a list of comparison. If all comparisons passed in succeed, +// the operations passed into Then() will be executed. Or the operations +// passed into Else() will be executed. +func (t *SlowLogTxn) If(cs ...clientv3.Cmp) clientv3.Txn { + t.Txn = t.Txn.If(cs...) + return t +} + +// Then takes a list of operations. The Ops list will be executed, if the +// comparisons passed in If() succeed. +func (t *SlowLogTxn) Then(ops ...clientv3.Op) clientv3.Txn { + t.Txn = t.Txn.Then(ops...) + return t +} + +// Commit implements Txn Commit interface. 
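+// Commits slower than slowRequestTime are logged as warnings, and the
+// success/failed transaction counters and duration histograms are updated
+// regardless of the outcome.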
+func (t *SlowLogTxn) Commit() (*clientv3.TxnResponse, error) { + start := time.Now() + resp, err := t.Txn.Commit() + t.cancel() + + cost := time.Since(start) + if cost > slowRequestTime { + log.Warn("txn runs too slow", + zap.Reflect("response", resp), + zap.Duration("cost", cost), + errs.ZapError(err)) + } + label := "success" + if err != nil { + label = "failed" + } + txnCounter.WithLabelValues(label).Inc() + txnDuration.WithLabelValues(label).Observe(cost.Seconds()) + + return resp, errors.WithStack(err) +} diff --git a/server/storage/kv/mem_kv.go b/server/storage/kv/mem_kv.go old mode 100644 new mode 100755 index b68ed89a451..59721b00e73 --- a/server/storage/kv/mem_kv.go +++ b/server/storage/kv/mem_kv.go @@ -54,12 +54,12 @@ func (kv *memoryKV) Load(key string) (string, error) { } func (kv *memoryKV) LoadRange(key, endKey string, limit int) ([]string, []string, error) { - failpoint.Inject("withRangeLimit", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("withRangeLimit")); _err_ == nil { rangeLimit, ok := val.(int) if ok && limit > rangeLimit { - failpoint.Return(nil, nil, errors.Errorf("limit %d exceed max rangeLimit %d", limit, rangeLimit)) + return nil, nil, errors.Errorf("limit %d exceed max rangeLimit %d", limit, rangeLimit) } - }) + } kv.RLock() defer kv.RUnlock() keys := make([]string, 0, limit) diff --git a/server/storage/kv/mem_kv.go__failpoint_stash__ b/server/storage/kv/mem_kv.go__failpoint_stash__ new file mode 100644 index 00000000000..b68ed89a451 --- /dev/null +++ b/server/storage/kv/mem_kv.go__failpoint_stash__ @@ -0,0 +1,91 @@ +// Copyright 2017 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kv + +import ( + "github.com/google/btree" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/tikv/pd/pkg/syncutil" +) + +type memoryKV struct { + syncutil.RWMutex + tree *btree.BTreeG[memoryKVItem] +} + +// NewMemoryKV returns an in-memory kvBase for testing. 
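+// For example:
+//
+//	kv := NewMemoryKV()
+//	_ = kv.Save("k", "v")
+//	v, _ := kv.Load("k") // v == "v"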
+func NewMemoryKV() Base { + return &memoryKV{ + tree: btree.NewG(2, func(i, j memoryKVItem) bool { + return i.Less(&j) + }), + } +} + +type memoryKVItem struct { + key, value string +} + +func (s *memoryKVItem) Less(than *memoryKVItem) bool { + return s.key < than.key +} + +func (kv *memoryKV) Load(key string) (string, error) { + kv.RLock() + defer kv.RUnlock() + item, ok := kv.tree.Get(memoryKVItem{key, ""}) + if !ok { + return "", nil + } + return item.value, nil +} + +func (kv *memoryKV) LoadRange(key, endKey string, limit int) ([]string, []string, error) { + failpoint.Inject("withRangeLimit", func(val failpoint.Value) { + rangeLimit, ok := val.(int) + if ok && limit > rangeLimit { + failpoint.Return(nil, nil, errors.Errorf("limit %d exceed max rangeLimit %d", limit, rangeLimit)) + } + }) + kv.RLock() + defer kv.RUnlock() + keys := make([]string, 0, limit) + values := make([]string, 0, limit) + kv.tree.AscendRange(memoryKVItem{key, ""}, memoryKVItem{endKey, ""}, func(item memoryKVItem) bool { + keys = append(keys, item.key) + values = append(values, item.value) + if limit > 0 { + return len(keys) < limit + } + return true + }) + return keys, values, nil +} + +func (kv *memoryKV) Save(key, value string) error { + kv.Lock() + defer kv.Unlock() + kv.tree.ReplaceOrInsert(memoryKVItem{key, value}) + return nil +} + +func (kv *memoryKV) Remove(key string) error { + kv.Lock() + defer kv.Unlock() + + kv.tree.Delete(memoryKVItem{key, ""}) + return nil +} diff --git a/server/storage/leveldb_backend.go b/server/storage/leveldb_backend.go old mode 100644 new mode 100755 index f30ec48d0ac..935f862732d --- a/server/storage/leveldb_backend.go +++ b/server/storage/leveldb_backend.go @@ -92,9 +92,9 @@ func (lb *levelDBBackend) backgroundFlush() { case <-ticker.C: lb.mu.RLock() isFlush = lb.flushTime.Before(time.Now()) - failpoint.Inject("regionStorageFastFlush", func() { + if _, _err_ := failpoint.Eval(_curpkg_("regionStorageFastFlush")); _err_ == nil { isFlush = true - }) + } lb.mu.RUnlock() if !isFlush { continue diff --git a/server/storage/leveldb_backend.go__failpoint_stash__ b/server/storage/leveldb_backend.go__failpoint_stash__ new file mode 100644 index 00000000000..f30ec48d0ac --- /dev/null +++ b/server/storage/leveldb_backend.go__failpoint_stash__ @@ -0,0 +1,183 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "time" + + "github.com/gogo/protobuf/proto" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/log" + "github.com/syndtr/goleveldb/leveldb" + "github.com/tikv/pd/pkg/encryption" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/syncutil" + "github.com/tikv/pd/server/encryptionkm" + "github.com/tikv/pd/server/storage/endpoint" + "github.com/tikv/pd/server/storage/kv" +) + +const ( + // DefaultFlushRegionRate is the ttl to sync the regions to region storage. 
+ defaultFlushRegionRate = 3 * time.Second + // DefaultBatchSize is the batch size to save the regions to region storage. + defaultBatchSize = 100 +) + +// levelDBBackend is a storage backend that stores data in LevelDB, +// which is mainly used by the PD region storage. +type levelDBBackend struct { + *endpoint.StorageEndpoint + ekm *encryptionkm.KeyManager + mu syncutil.RWMutex + batchRegions map[string]*metapb.Region + batchSize int + cacheSize int + flushRate time.Duration + flushTime time.Time + regionStorageCtx context.Context + regionStorageCancel context.CancelFunc +} + +// newLevelDBBackend is used to create a new LevelDB backend. +func newLevelDBBackend( + ctx context.Context, + filePath string, + ekm *encryptionkm.KeyManager, +) (*levelDBBackend, error) { + levelDB, err := kv.NewLevelDBKV(filePath) + if err != nil { + return nil, err + } + regionStorageCtx, regionStorageCancel := context.WithCancel(ctx) + lb := &levelDBBackend{ + StorageEndpoint: endpoint.NewStorageEndpoint(levelDB, ekm), + ekm: ekm, + batchSize: defaultBatchSize, + flushRate: defaultFlushRegionRate, + batchRegions: make(map[string]*metapb.Region, defaultBatchSize), + flushTime: time.Now().Add(defaultFlushRegionRate), + regionStorageCtx: regionStorageCtx, + regionStorageCancel: regionStorageCancel, + } + go lb.backgroundFlush() + return lb, nil +} + +var dirtyFlushTick = time.Second + +func (lb *levelDBBackend) backgroundFlush() { + var ( + isFlush bool + err error + ) + ticker := time.NewTicker(dirtyFlushTick) + defer ticker.Stop() + for { + select { + case <-ticker.C: + lb.mu.RLock() + isFlush = lb.flushTime.Before(time.Now()) + failpoint.Inject("regionStorageFastFlush", func() { + isFlush = true + }) + lb.mu.RUnlock() + if !isFlush { + continue + } + if err = lb.Flush(); err != nil { + log.Error("flush regions meet error", errs.ZapError(err)) + } + case <-lb.regionStorageCtx.Done(): + return + } + } +} + +func (lb *levelDBBackend) SaveRegion(region *metapb.Region) error { + region, err := encryption.EncryptRegion(region, lb.ekm) + if err != nil { + return err + } + lb.mu.Lock() + defer lb.mu.Unlock() + if lb.cacheSize < lb.batchSize-1 { + lb.batchRegions[endpoint.RegionPath(region.GetId())] = region + lb.cacheSize++ + + lb.flushTime = time.Now().Add(lb.flushRate) + return nil + } + lb.batchRegions[endpoint.RegionPath(region.GetId())] = region + err = lb.flushLocked() + + if err != nil { + return err + } + return nil +} + +func (lb *levelDBBackend) DeleteRegion(region *metapb.Region) error { + return lb.Remove(endpoint.RegionPath(region.GetId())) +} + +// Flush saves the cache region to the underlying storage. +func (lb *levelDBBackend) Flush() error { + lb.mu.Lock() + defer lb.mu.Unlock() + return lb.flushLocked() +} + +func (lb *levelDBBackend) flushLocked() error { + if err := lb.saveRegions(lb.batchRegions); err != nil { + return err + } + lb.cacheSize = 0 + lb.batchRegions = make(map[string]*metapb.Region, lb.batchSize) + return nil +} + +func (lb *levelDBBackend) saveRegions(regions map[string]*metapb.Region) error { + batch := new(leveldb.Batch) + + for key, r := range regions { + value, err := proto.Marshal(r) + if err != nil { + return errs.ErrProtoMarshal.Wrap(err).GenWithStackByCause() + } + batch.Put([]byte(key), value) + } + + if err := lb.Base.(*kv.LevelDBKV).Write(batch, nil); err != nil { + return errs.ErrLevelDBWrite.Wrap(err).GenWithStackByCause() + } + return nil +} + +// Close closes the LevelDB kv. It will call Flush() once before closing. 
+func (lb *levelDBBackend) Close() error { + err := lb.Flush() + if err != nil { + log.Error("meet error before close the region storage", errs.ZapError(err)) + } + lb.regionStorageCancel() + err = lb.Base.(*kv.LevelDBKV).Close() + if err != nil { + return errs.ErrLevelDBClose.Wrap(err).GenWithStackByArgs() + } + return nil +} diff --git a/server/tso/allocator_manager.go b/server/tso/allocator_manager.go old mode 100644 new mode 100755 index 8ce1b898287..f9486252a6e --- a/server/tso/allocator_manager.go +++ b/server/tso/allocator_manager.go @@ -470,7 +470,7 @@ func (am *AllocatorManager) campaignAllocatorLeader( nextLeaderValue := fmt.Sprintf("%v", am.member.ID()) cmps = append(cmps, clientv3.Compare(clientv3.Value(nextLeaderKey), "=", nextLeaderValue)) } - failpoint.Inject("injectNextLeaderKey", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("injectNextLeaderKey")); _err_ == nil { if val.(bool) { // In order not to campaign leader too often in tests time.Sleep(5 * time.Second) @@ -478,7 +478,7 @@ func (am *AllocatorManager) campaignAllocatorLeader( clientv3.Compare(clientv3.Value(nextLeaderKey), "=", "mockValue"), } } - }) + } if err := allocator.CampaignAllocatorLeader(defaultAllocatorLeaderLease, cmps...); err != nil { if err.Error() == errs.ErrEtcdTxnConflict.Error() { log.Info("failed to campaign local tso allocator leader due to txn conflict, another allocator may campaign successfully", diff --git a/server/tso/allocator_manager.go__failpoint_stash__ b/server/tso/allocator_manager.go__failpoint_stash__ new file mode 100644 index 00000000000..8ce1b898287 --- /dev/null +++ b/server/tso/allocator_manager.go__failpoint_stash__ @@ -0,0 +1,1194 @@ +// Copyright 2020 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tso + +import ( + "context" + "fmt" + "math" + "path" + "strconv" + "strings" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/etcdutil" + "github.com/tikv/pd/pkg/grpcutil" + "github.com/tikv/pd/pkg/slice" + "github.com/tikv/pd/pkg/syncutil" + "github.com/tikv/pd/server/config" + "github.com/tikv/pd/server/election" + "github.com/tikv/pd/server/member" + "github.com/tikv/pd/server/storage/kv" + "go.etcd.io/etcd/clientv3" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +const ( + // GlobalDCLocation is the Global TSO Allocator's DC location label. + GlobalDCLocation = "global" + checkStep = time.Minute + patrolStep = time.Second + defaultAllocatorLeaderLease = 3 + leaderTickInterval = 50 * time.Millisecond + localTSOAllocatorEtcdPrefix = "lta" + localTSOSuffixEtcdPrefix = "lts" +) + +var ( + // PriorityCheck exported is only for test. + PriorityCheck = time.Minute +) + +// AllocatorGroupFilter is used to select AllocatorGroup. 
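+// A concrete filter simply inspects the group, e.g. matching on dc-location:
+//
+//	func(ag *allocatorGroup) bool { return ag.dcLocation == GlobalDCLocation }
+//
+// As used by getAllocatorGroups further down (see FilterDCLocation and
+// friends), groups matched by a filter are excluded rather than kept.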
+type AllocatorGroupFilter func(ag *allocatorGroup) bool + +type allocatorGroup struct { + dcLocation string + // ctx is built with cancel from a parent context when set up which can be different + // in order to receive Done() signal correctly. + // cancel would be call when allocatorGroup is deleted to stop background loop. + ctx context.Context + cancel context.CancelFunc + // For the Global TSO Allocator, leadership is a PD leader's + // leadership, and for the Local TSO Allocator, leadership + // is a DC-level certificate to allow an allocator to generate + // TSO for local transactions in its DC. + leadership *election.Leadership + allocator Allocator +} + +// DCLocationInfo is used to record some dc-location related info, +// such as suffix sign and server IDs in this dc-location. +type DCLocationInfo struct { + // dc-location/global (string) -> Member IDs + ServerIDs []uint64 + // dc-location (string) -> Suffix sign. It is collected and maintained by the PD leader. + Suffix int32 +} + +func (info *DCLocationInfo) clone() DCLocationInfo { + copiedInfo := DCLocationInfo{ + Suffix: info.Suffix, + } + // Make a deep copy here for the slice + copiedInfo.ServerIDs = make([]uint64, len(info.ServerIDs)) + copy(copiedInfo.ServerIDs, info.ServerIDs) + return copiedInfo +} + +// AllocatorManager is used to manage the TSO Allocators a PD server holds. +// It is in charge of maintaining TSO allocators' leadership, checking election +// priority, and forwarding TSO allocation requests to correct TSO Allocators. +type AllocatorManager struct { + enableLocalTSO bool + mu struct { + syncutil.RWMutex + // There are two kinds of TSO Allocators: + // 1. Global TSO Allocator, as a global single point to allocate + // TSO for global transactions, such as cross-region cases. + // 2. Local TSO Allocator, servers for DC-level transactions. + // dc-location/global (string) -> TSO Allocator + allocatorGroups map[string]*allocatorGroup + clusterDCLocations map[string]*DCLocationInfo + // The max suffix sign we have so far, it will be used to calculate + // the number of suffix bits we need in the TSO logical part. + maxSuffix int32 + } + wg sync.WaitGroup + // for election use + member *member.Member + // TSO config + rootPath string + saveInterval time.Duration + updatePhysicalInterval time.Duration + maxResetTSGap func() time.Duration + securityConfig *grpcutil.TLSConfig + // for gRPC use + localAllocatorConn struct { + syncutil.RWMutex + clientConns map[string]*grpc.ClientConn + } +} + +// NewAllocatorManager creates a new TSO Allocator Manager. +func NewAllocatorManager( + m *member.Member, + rootPath string, + cfg *config.Config, + maxResetTSGap func() time.Duration, +) *AllocatorManager { + allocatorManager := &AllocatorManager{ + enableLocalTSO: cfg.EnableLocalTSO, + member: m, + rootPath: rootPath, + saveInterval: cfg.TSOSaveInterval.Duration, + updatePhysicalInterval: cfg.TSOUpdatePhysicalInterval.Duration, + maxResetTSGap: maxResetTSGap, + securityConfig: &cfg.Security.TLSConfig, + } + allocatorManager.mu.allocatorGroups = make(map[string]*allocatorGroup) + allocatorManager.mu.clusterDCLocations = make(map[string]*DCLocationInfo) + allocatorManager.localAllocatorConn.clientConns = make(map[string]*grpc.ClientConn) + return allocatorManager +} + +// SetLocalTSOConfig receives the zone label of this PD server and write it into etcd as dc-location +// to make the whole cluster know the DC-level topology for later Local TSO Allocator campaign. 
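+// The resulting etcd entry is keyed by this member's ID (via
+// am.member.GetDCLocationPath) with dcLocation as its value.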
+func (am *AllocatorManager) SetLocalTSOConfig(dcLocation string) error { + serverName := am.member.Member().Name + serverID := am.member.ID() + if err := am.checkDCLocationUpperLimit(dcLocation); err != nil { + log.Error("check dc-location upper limit failed", + zap.Int("upper-limit", int(math.Pow(2, MaxSuffixBits))-1), + zap.String("dc-location", dcLocation), + zap.String("server-name", serverName), + zap.Uint64("server-id", serverID), + errs.ZapError(err)) + return err + } + // The key-value pair in etcd will be like: serverID -> dcLocation + dcLocationKey := am.member.GetDCLocationPath(serverID) + resp, err := kv. + NewSlowLogTxn(am.member.Client()). + Then(clientv3.OpPut(dcLocationKey, dcLocation)). + Commit() + if err != nil { + return errs.ErrEtcdTxnInternal.Wrap(err).GenWithStackByCause() + } + if !resp.Succeeded { + log.Warn("write dc-location configuration into etcd failed", + zap.String("dc-location", dcLocation), + zap.String("server-name", serverName), + zap.Uint64("server-id", serverID)) + return errs.ErrEtcdTxnConflict.FastGenByArgs() + } + log.Info("write dc-location configuration into etcd", + zap.String("dc-location", dcLocation), + zap.String("server-name", serverName), + zap.Uint64("server-id", serverID)) + go am.ClusterDCLocationChecker() + return nil +} + +func (am *AllocatorManager) checkDCLocationUpperLimit(dcLocation string) error { + clusterDCLocations, err := am.GetClusterDCLocationsFromEtcd() + if err != nil { + return err + } + // It's ok to add a new PD to the old dc-location. + if _, ok := clusterDCLocations[dcLocation]; ok { + return nil + } + // Check whether the dc-location number meets the upper limit 2**(LogicalBits-1)-1, + // which includes 1 global and 2**(LogicalBits-1) local + if len(clusterDCLocations) == int(math.Pow(2, MaxSuffixBits))-1 { + return errs.ErrSetLocalTSOConfig.FastGenByArgs("the number of dc-location meets the upper limit") + } + return nil +} + +// GetClusterDCLocationsFromEtcd fetches dcLocation topology from etcd +func (am *AllocatorManager) GetClusterDCLocationsFromEtcd() (clusterDCLocations map[string][]uint64, err error) { + resp, err := etcdutil.EtcdKVGet( + am.member.Client(), + am.member.GetDCLocationPathPrefix(), + clientv3.WithPrefix()) + if err != nil { + return clusterDCLocations, err + } + clusterDCLocations = make(map[string][]uint64) + for _, kv := range resp.Kvs { + // The key will contain the member ID and the value is its dcLocation + serverPath := strings.Split(string(kv.Key), "/") + // Get serverID from serverPath, e.g, /pd/dc-location/1232143243253 -> 1232143243253 + serverID, err := strconv.ParseUint(serverPath[len(serverPath)-1], 10, 64) + dcLocation := string(kv.Value) + if err != nil { + log.Warn("get server id and dcLocation from etcd failed, invalid server id", + zap.Any("splitted-serverPath", serverPath), + zap.String("dc-location", dcLocation), + errs.ZapError(err)) + continue + } + clusterDCLocations[dcLocation] = append(clusterDCLocations[dcLocation], serverID) + } + return clusterDCLocations, nil +} + +// GetDCLocationInfo returns a copy of DCLocationInfo of the given dc-location, +func (am *AllocatorManager) GetDCLocationInfo(dcLocation string) (DCLocationInfo, bool) { + am.mu.RLock() + defer am.mu.RUnlock() + infoPtr, ok := am.mu.clusterDCLocations[dcLocation] + if !ok { + return DCLocationInfo{}, false + } + return infoPtr.clone(), true +} + +// CleanUpDCLocation cleans up certain server's DCLocationInfo +func (am *AllocatorManager) CleanUpDCLocation() error { + serverID := am.member.ID() + 
dcLocationKey := am.member.GetDCLocationPath(serverID) + // remove dcLocationKey from etcd + if resp, err := kv. + NewSlowLogTxn(am.member.Client()). + Then(clientv3.OpDelete(dcLocationKey)). + Commit(); err != nil { + return errs.ErrEtcdTxnInternal.Wrap(err).GenWithStackByCause() + } else if !resp.Succeeded { + return errs.ErrEtcdTxnConflict.FastGenByArgs() + } + log.Info("delete the dc-location key previously written in etcd", + zap.Uint64("server-id", serverID)) + go am.ClusterDCLocationChecker() + return nil +} + +// GetClusterDCLocations returns all dc-locations of a cluster with a copy of map, +// which satisfies dcLocation -> DCLocationInfo. +func (am *AllocatorManager) GetClusterDCLocations() map[string]DCLocationInfo { + am.mu.RLock() + defer am.mu.RUnlock() + dcLocationMap := make(map[string]DCLocationInfo) + for dcLocation, info := range am.mu.clusterDCLocations { + dcLocationMap[dcLocation] = info.clone() + } + return dcLocationMap +} + +// GetClusterDCLocationsNumber returns the number of cluster dc-locations. +func (am *AllocatorManager) GetClusterDCLocationsNumber() int { + am.mu.RLock() + defer am.mu.RUnlock() + return len(am.mu.clusterDCLocations) +} + +// compareAndSetMaxSuffix sets the max suffix sign if suffix is greater than am.mu.maxSuffix. +func (am *AllocatorManager) compareAndSetMaxSuffix(suffix int32) { + am.mu.Lock() + defer am.mu.Unlock() + if suffix > am.mu.maxSuffix { + am.mu.maxSuffix = suffix + } +} + +// GetSuffixBits calculates the bits of suffix sign +// by the max number of suffix so far, +// which will be used in the TSO logical part. +func (am *AllocatorManager) GetSuffixBits() int { + am.mu.RLock() + defer am.mu.RUnlock() + return CalSuffixBits(am.mu.maxSuffix) +} + +// CalSuffixBits calculates the bits of suffix by the max suffix sign. +func CalSuffixBits(maxSuffix int32) int { + // maxSuffix + 1 because we have the Global TSO holds 0 as the suffix sign + return int(math.Ceil(math.Log2(float64(maxSuffix + 1)))) +} + +// SetUpAllocator is used to set up an allocator, which will initialize the allocator and put it into allocator daemon. +// One TSO Allocator should only be set once, and may be initialized and reset multiple times depending on the election. +func (am *AllocatorManager) SetUpAllocator(parentCtx context.Context, dcLocation string, leadership *election.Leadership) { + am.mu.Lock() + defer am.mu.Unlock() + if am.updatePhysicalInterval != config.DefaultTSOUpdatePhysicalInterval { + log.Warn("tso update physical interval is non-default", + zap.Duration("update-physical-interval", am.updatePhysicalInterval)) + } + if _, exist := am.mu.allocatorGroups[dcLocation]; exist { + return + } + var allocator Allocator + if dcLocation == GlobalDCLocation { + allocator = NewGlobalTSOAllocator(am, leadership) + } else { + allocator = NewLocalTSOAllocator(am, leadership, dcLocation) + } + // Create a new allocatorGroup + ctx, cancel := context.WithCancel(parentCtx) + am.mu.allocatorGroups[dcLocation] = &allocatorGroup{ + dcLocation: dcLocation, + ctx: ctx, + cancel: cancel, + leadership: leadership, + allocator: allocator, + } + // Because the Global TSO Allocator only depends on PD leader's leadership, + // so we can directly return here. The election and initialization process + // will happen in server.campaignLeader(). 
+ if dcLocation == GlobalDCLocation { + return + } + // Start election of the Local TSO Allocator here + localTSOAllocator, _ := allocator.(*LocalTSOAllocator) + go am.allocatorLeaderLoop(parentCtx, localTSOAllocator) +} + +func (am *AllocatorManager) getAllocatorPath(dcLocation string) string { + // For backward compatibility, the global timestamp's store path will still use the old one + if dcLocation == GlobalDCLocation { + return am.rootPath + } + return path.Join(am.getLocalTSOAllocatorPath(), dcLocation) +} + +// Add a prefix to the root path to prevent being conflicted +// with other system key paths such as leader, member, alloc_id, raft, etc. +func (am *AllocatorManager) getLocalTSOAllocatorPath() string { + return path.Join(am.rootPath, localTSOAllocatorEtcdPrefix) +} + +// similar logic with leaderLoop in server/server.go +func (am *AllocatorManager) allocatorLeaderLoop(ctx context.Context, allocator *LocalTSOAllocator) { + defer log.Info("server is closed, return local tso allocator leader loop", + zap.String("dc-location", allocator.GetDCLocation()), + zap.String("local-tso-allocator-name", am.member.Member().Name)) + for { + select { + case <-ctx.Done(): + return + default: + } + + // Check whether the Local TSO Allocator has the leader already + allocatorLeader, rev, checkAgain := allocator.CheckAllocatorLeader() + if checkAgain { + continue + } + if allocatorLeader != nil { + log.Info("start to watch allocator leader", + zap.Stringer(fmt.Sprintf("%s-allocator-leader", allocator.GetDCLocation()), allocatorLeader), + zap.String("local-tso-allocator-name", am.member.Member().Name)) + // WatchAllocatorLeader will keep looping and never return unless the Local TSO Allocator leader has changed. + allocator.WatchAllocatorLeader(ctx, allocatorLeader, rev) + log.Info("local tso allocator leader has changed, try to re-campaign a local tso allocator leader", + zap.String("dc-location", allocator.GetDCLocation())) + } + + // Check the next-leader key + nextLeader, err := am.getNextLeaderID(allocator.GetDCLocation()) + if err != nil { + log.Error("get next leader from etcd failed", + zap.String("dc-location", allocator.GetDCLocation()), + errs.ZapError(err)) + time.Sleep(200 * time.Millisecond) + continue + } + isNextLeader := false + if nextLeader != 0 { + if nextLeader != am.member.ID() { + log.Info("skip campaigning of the local tso allocator leader and check later", + zap.String("server-name", am.member.Member().Name), + zap.Uint64("server-id", am.member.ID()), + zap.Uint64("next-leader-id", nextLeader)) + time.Sleep(200 * time.Millisecond) + continue + } + isNextLeader = true + } + + // Make sure the leader is aware of this new dc-location in order to make the + // Global TSO synchronization can cover up this dc-location. 
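+		// In other words, fetch this dc-location's suffix sign and current max TSO
+		// from the PD leader before campaigning; both are required below.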
+ ok, dcLocationInfo, err := am.getDCLocationInfoFromLeader(ctx, allocator.GetDCLocation()) + if err != nil { + log.Error("get dc-location info from pd leader failed", + zap.String("dc-location", allocator.GetDCLocation()), + errs.ZapError(err)) + // PD leader hasn't been elected out, wait for the campaign + if !longSleep(ctx, time.Second) { + return + } + continue + } + if !ok || dcLocationInfo.Suffix <= 0 || dcLocationInfo.MaxTs == nil { + log.Warn("pd leader is not aware of dc-location during allocatorLeaderLoop, wait next round", + zap.String("dc-location", allocator.GetDCLocation()), + zap.Any("dc-location-info", dcLocationInfo), + zap.String("wait-duration", checkStep.String())) + // Because the checkStep is long, we use select here to check whether the ctx is done + // to prevent the leak of goroutine. + if !longSleep(ctx, checkStep) { + return + } + continue + } + + am.campaignAllocatorLeader(ctx, allocator, dcLocationInfo, isNextLeader) + } +} + +// longSleep is used to sleep the long wait duration while also watching the +// ctx.Done() to prevent the goroutine from leaking. This function returns +// true if the sleep is over, false if the ctx is done. +func longSleep(ctx context.Context, waitStep time.Duration) bool { + waitTicker := time.NewTicker(waitStep) + defer waitTicker.Stop() + select { + case <-ctx.Done(): + return false + case <-waitTicker.C: + return true + } +} + +func (am *AllocatorManager) campaignAllocatorLeader( + loopCtx context.Context, + allocator *LocalTSOAllocator, + dcLocationInfo *pdpb.GetDCLocationInfoResponse, + isNextLeader bool, +) { + log.Info("start to campaign local tso allocator leader", + zap.String("dc-location", allocator.GetDCLocation()), + zap.Any("dc-location-info", dcLocationInfo), + zap.String("name", am.member.Member().Name)) + cmps := make([]clientv3.Cmp, 0) + nextLeaderKey := am.nextLeaderKey(allocator.GetDCLocation()) + if !isNextLeader { + cmps = append(cmps, clientv3.Compare(clientv3.CreateRevision(nextLeaderKey), "=", 0)) + } else { + nextLeaderValue := fmt.Sprintf("%v", am.member.ID()) + cmps = append(cmps, clientv3.Compare(clientv3.Value(nextLeaderKey), "=", nextLeaderValue)) + } + failpoint.Inject("injectNextLeaderKey", func(val failpoint.Value) { + if val.(bool) { + // In order not to campaign leader too often in tests + time.Sleep(5 * time.Second) + cmps = []clientv3.Cmp{ + clientv3.Compare(clientv3.Value(nextLeaderKey), "=", "mockValue"), + } + } + }) + if err := allocator.CampaignAllocatorLeader(defaultAllocatorLeaderLease, cmps...); err != nil { + if err.Error() == errs.ErrEtcdTxnConflict.Error() { + log.Info("failed to campaign local tso allocator leader due to txn conflict, another allocator may campaign successfully", + zap.String("dc-location", allocator.GetDCLocation()), + zap.Any("dc-location-info", dcLocationInfo), + zap.String("name", am.member.Member().Name)) + } else { + log.Error("failed to campaign local tso allocator leader due to etcd error", + zap.String("dc-location", allocator.GetDCLocation()), + zap.Any("dc-location-info", dcLocationInfo), + zap.String("name", am.member.Member().Name), + errs.ZapError(err)) + } + return + } + + // Start keepalive the Local TSO Allocator leadership and enable Local TSO service. 
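+	// The deferred ResetAllocatorGroup below mirrors campaignLeader in server.go:
+	// whenever this function returns, the allocator group is reset and its
+	// leadership keepalive context is canceled.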
+ ctx, cancel := context.WithCancel(loopCtx) + defer cancel() + defer am.ResetAllocatorGroup(allocator.GetDCLocation()) + // Maintain the Local TSO Allocator leader + go allocator.KeepAllocatorLeader(ctx) + log.Info("campaign local tso allocator leader ok", + zap.String("dc-location", allocator.GetDCLocation()), + zap.Any("dc-location-info", dcLocationInfo), + zap.String("name", am.member.Member().Name)) + + log.Info("initialize the local TSO allocator", + zap.String("dc-location", allocator.GetDCLocation()), + zap.Any("dc-location-info", dcLocationInfo), + zap.String("name", am.member.Member().Name)) + if err := allocator.Initialize(int(dcLocationInfo.Suffix)); err != nil { + log.Error("failed to initialize the local TSO allocator", + zap.String("dc-location", allocator.GetDCLocation()), + zap.Any("dc-location-info", dcLocationInfo), + errs.ZapError(err)) + return + } + if dcLocationInfo.GetMaxTs().GetPhysical() != 0 { + if err := allocator.WriteTSO(dcLocationInfo.GetMaxTs()); err != nil { + log.Error("failed to write the max local TSO after member changed", + zap.String("dc-location", allocator.GetDCLocation()), + zap.Any("dc-location-info", dcLocationInfo), + errs.ZapError(err)) + return + } + } + am.compareAndSetMaxSuffix(dcLocationInfo.Suffix) + allocator.EnableAllocatorLeader() + // The next leader is me, delete it to finish campaigning + am.deleteNextLeaderID(allocator.GetDCLocation()) + log.Info("local tso allocator leader is ready to serve", + zap.String("dc-location", allocator.GetDCLocation()), + zap.Any("dc-location-info", dcLocationInfo), + zap.String("name", am.member.Member().Name)) + + leaderTicker := time.NewTicker(leaderTickInterval) + defer leaderTicker.Stop() + + for { + select { + case <-leaderTicker.C: + if !allocator.IsAllocatorLeader() { + log.Info("no longer a local tso allocator leader because lease has expired, local tso allocator leader will step down", + zap.String("dc-location", allocator.GetDCLocation()), + zap.Any("dc-location-info", dcLocationInfo), + zap.String("name", am.member.Member().Name)) + return + } + case <-ctx.Done(): + // Server is closed and it should return nil. + log.Info("server is closed, reset the local tso allocator", + zap.String("dc-location", allocator.GetDCLocation()), + zap.Any("dc-location-info", dcLocationInfo), + zap.String("name", am.member.Member().Name)) + return + } + } +} + +// AllocatorDaemon is used to update every allocator's TSO and check whether we have +// any new local allocator that needs to be set up. +func (am *AllocatorManager) AllocatorDaemon(serverCtx context.Context) { + // allocatorPatroller should only work when enableLocalTSO is true to + // set up the new Local TSO Allocator in time. + var patrolTicker = &time.Ticker{} + if am.enableLocalTSO { + patrolTicker = time.NewTicker(patrolStep) + defer patrolTicker.Stop() + } + tsTicker := time.NewTicker(am.updatePhysicalInterval) + defer tsTicker.Stop() + checkerTicker := time.NewTicker(PriorityCheck) + defer checkerTicker.Stop() + + for { + select { + case <-patrolTicker.C: + // Inspect the cluster dc-location info and set up the new Local TSO Allocator in time. + am.allocatorPatroller(serverCtx) + case <-tsTicker.C: + // Update the initialized TSO Allocator to advance TSO. + am.allocatorUpdater() + case <-checkerTicker.C: + // Check and maintain the cluster's meta info about dc-location distribution. + go am.ClusterDCLocationChecker() + // We won't have any Local TSO Allocator set up in PD without enabling Local TSO. 
+ if am.enableLocalTSO { + // Check the election priority of every Local TSO Allocator this PD is holding. + go am.PriorityChecker() + } + // PS: ClusterDCLocationChecker and PriorityChecker are time consuming and low frequent to run, + // we should run them concurrently to speed up the progress. + case <-serverCtx.Done(): + return + } + } +} + +// Update the Local TSO Allocator leaders TSO in memory concurrently. +func (am *AllocatorManager) allocatorUpdater() { + // Filter out allocators without leadership and uninitialized + allocatorGroups := am.getAllocatorGroups(FilterUninitialized(), FilterUnavailableLeadership()) + // Update each allocator concurrently + for _, ag := range allocatorGroups { + am.wg.Add(1) + go am.updateAllocator(ag) + } + am.wg.Wait() +} + +// updateAllocator is used to update the allocator in the group. +func (am *AllocatorManager) updateAllocator(ag *allocatorGroup) { + defer am.wg.Done() + select { + case <-ag.ctx.Done(): + // Resetting the allocator will clear TSO in memory + ag.allocator.Reset() + return + default: + } + if !ag.leadership.Check() { + log.Info("allocator doesn't campaign leadership yet", zap.String("dc-location", ag.dcLocation)) + time.Sleep(200 * time.Millisecond) + return + } + if err := ag.allocator.UpdateTSO(); err != nil { + log.Warn("failed to update allocator's timestamp", zap.String("dc-location", ag.dcLocation), errs.ZapError(err)) + am.ResetAllocatorGroup(ag.dcLocation) + return + } +} + +// Check if we have any new dc-location configured, if yes, +// then set up the corresponding local allocator. +func (am *AllocatorManager) allocatorPatroller(serverCtx context.Context) { + // Collect all dc-locations + dcLocations := am.GetClusterDCLocations() + // Get all Local TSO Allocators + allocatorGroups := am.getAllocatorGroups(FilterDCLocation(GlobalDCLocation)) + // Set up the new one + for dcLocation := range dcLocations { + if slice.NoneOf(allocatorGroups, func(i int) bool { + return allocatorGroups[i].dcLocation == dcLocation + }) { + am.SetUpAllocator(serverCtx, dcLocation, election.NewLeadership( + am.member.Client(), + am.getAllocatorPath(dcLocation), + fmt.Sprintf("%s local allocator leader election", dcLocation), + )) + } + } + // Clean up the unused one + for _, ag := range allocatorGroups { + if _, exist := dcLocations[ag.dcLocation]; !exist { + am.deleteAllocatorGroup(ag.dcLocation) + } + } +} + +// ClusterDCLocationChecker collects all dc-locations of a cluster, computes some related info +// and stores them into the DCLocationInfo, then finally writes them into am.mu.clusterDCLocations. +func (am *AllocatorManager) ClusterDCLocationChecker() { + // Wait for the PD leader to be elected out. 
+	if am.member.GetLeader() == nil {
+		return
+	}
+	newClusterDCLocations, err := am.GetClusterDCLocationsFromEtcd()
+	if err != nil {
+		log.Error("get cluster dc-locations from etcd failed", errs.ZapError(err))
+		return
+	}
+	am.mu.Lock()
+	// Clean up the useless dc-locations
+	for dcLocation := range am.mu.clusterDCLocations {
+		if _, ok := newClusterDCLocations[dcLocation]; !ok {
+			delete(am.mu.clusterDCLocations, dcLocation)
+		}
+	}
+	// May be used to roll back the update afterwards
+	newDCLocations := make([]string, 0)
+	// Update the new dc-locations
+	for dcLocation, serverIDs := range newClusterDCLocations {
+		if _, ok := am.mu.clusterDCLocations[dcLocation]; !ok {
+			am.mu.clusterDCLocations[dcLocation] = &DCLocationInfo{
+				ServerIDs: serverIDs,
+				Suffix:    -1,
+			}
+			newDCLocations = append(newDCLocations, dcLocation)
+		}
+	}
+	// Only the leader can write the TSO suffix to etcd in order to make it consistent in the cluster
+	if am.member.IsLeader() {
+		for dcLocation, info := range am.mu.clusterDCLocations {
+			if info.Suffix > 0 {
+				continue
+			}
+			suffix, err := am.getOrCreateLocalTSOSuffix(dcLocation)
+			if err != nil {
+				log.Warn("get or create the local tso suffix failed", zap.String("dc-location", dcLocation), errs.ZapError(err))
+				continue
+			}
+			if suffix > am.mu.maxSuffix {
+				am.mu.maxSuffix = suffix
+			}
+			am.mu.clusterDCLocations[dcLocation].Suffix = suffix
+		}
+	} else {
+		// Followers should check and update am.mu.maxSuffix
+		maxSuffix, err := am.getMaxLocalTSOSuffix()
+		if err != nil {
+			log.Error("get the max local tso suffix from etcd failed", errs.ZapError(err))
+			// Roll back the new dc-locations we updated before
+			for _, dcLocation := range newDCLocations {
+				delete(am.mu.clusterDCLocations, dcLocation)
+			}
+		} else if maxSuffix > am.mu.maxSuffix {
+			am.mu.maxSuffix = maxSuffix
+		}
+	}
+	am.mu.Unlock()
+}
+
+// getOrCreateLocalTSOSuffix will check whether we have the Local TSO suffix written into etcd.
+// If not, it will write a number into etcd according to its joining order.
+// If yes, it will just return the previously persisted one.
+func (am *AllocatorManager) getOrCreateLocalTSOSuffix(dcLocation string) (int32, error) {
+	// Try to get the suffix from etcd
+	dcLocationSuffix, err := am.getDCLocationSuffixMapFromEtcd()
+	if err != nil {
+		return -1, err
+	}
+	var maxSuffix int32
+	for curDCLocation, suffix := range dcLocationSuffix {
+		// If we already have the suffix persisted in etcd before,
+		// just use it as the result directly.
+		if curDCLocation == dcLocation {
+			return suffix, nil
+		}
+		if suffix > maxSuffix {
+			maxSuffix = suffix
+		}
+	}
+	maxSuffix++
+	localTSOSuffixKey := am.GetLocalTSOSuffixPath(dcLocation)
+	// The Local TSO suffix is determined by the joining order of this dc-location.
+	localTSOSuffixValue := strconv.FormatInt(int64(maxSuffix), 10)
+	txnResp, err := kv.NewSlowLogTxn(am.member.Client()).
+		If(clientv3.Compare(clientv3.CreateRevision(localTSOSuffixKey), "=", 0)).
+		Then(clientv3.OpPut(localTSOSuffixKey, localTSOSuffixValue)).
+ Commit() + if err != nil { + return -1, errs.ErrEtcdTxnInternal.Wrap(err).GenWithStackByCause() + } + if !txnResp.Succeeded { + log.Warn("write local tso suffix into etcd failed", + zap.String("dc-location", dcLocation), + zap.String("local-tso-surfix", localTSOSuffixValue), + zap.String("server-name", am.member.Member().Name), + zap.Uint64("server-id", am.member.ID())) + return -1, errs.ErrEtcdTxnConflict.FastGenByArgs() + } + return maxSuffix, nil +} + +func (am *AllocatorManager) getDCLocationSuffixMapFromEtcd() (map[string]int32, error) { + resp, err := etcdutil.EtcdKVGet( + am.member.Client(), + am.GetLocalTSOSuffixPathPrefix(), + clientv3.WithPrefix()) + if err != nil { + return nil, err + } + dcLocationSuffix := make(map[string]int32) + for _, kv := range resp.Kvs { + suffix, err := strconv.ParseInt(string(kv.Value), 10, 32) + if err != nil { + return nil, err + } + splittedKey := strings.Split(string(kv.Key), "/") + dcLocation := splittedKey[len(splittedKey)-1] + dcLocationSuffix[dcLocation] = int32(suffix) + } + return dcLocationSuffix, nil +} + +func (am *AllocatorManager) getMaxLocalTSOSuffix() (int32, error) { + // Try to get the suffix from etcd + dcLocationSuffix, err := am.getDCLocationSuffixMapFromEtcd() + if err != nil { + return -1, err + } + var maxSuffix int32 + for _, suffix := range dcLocationSuffix { + if suffix > maxSuffix { + maxSuffix = suffix + } + } + return maxSuffix, nil +} + +// GetLocalTSOSuffixPathPrefix returns the etcd key prefix of the Local TSO suffix for the given dc-location. +func (am *AllocatorManager) GetLocalTSOSuffixPathPrefix() string { + return path.Join(am.rootPath, localTSOSuffixEtcdPrefix) +} + +// GetLocalTSOSuffixPath returns the etcd key of the Local TSO suffix for the given dc-location. +func (am *AllocatorManager) GetLocalTSOSuffixPath(dcLocation string) string { + return path.Join(am.GetLocalTSOSuffixPathPrefix(), dcLocation) +} + +// PriorityChecker is used to check the election priority of a Local TSO Allocator. +// In the normal case, if we want to elect a Local TSO Allocator for a certain DC, +// such as dc-1, we need to make sure the follow priority rules: +// 1. The PD server with dc-location="dc-1" needs to be elected as the allocator +// leader with the highest priority. +// 2. If all PD servers with dc-location="dc-1" are down, then the other PD servers +// of DC could be elected. +func (am *AllocatorManager) PriorityChecker() { + serverID := am.member.ID() + myServerDCLocation := am.getServerDCLocation(serverID) + // Check all Local TSO Allocator followers to see if their priorities is higher than the leaders + // Filter out allocators with leadership and initialized + allocatorGroups := am.getAllocatorGroups(FilterDCLocation(GlobalDCLocation), FilterAvailableLeadership()) + for _, allocatorGroup := range allocatorGroups { + localTSOAllocator, _ := allocatorGroup.allocator.(*LocalTSOAllocator) + leaderServerID := localTSOAllocator.GetAllocatorLeader().GetMemberId() + // No leader, maybe the leader is not been watched yet + if leaderServerID == 0 { + continue + } + leaderServerDCLocation := am.getServerDCLocation(leaderServerID) + // For example, an allocator leader for dc-1 is elected by a server of dc-2, then the server of dc-1 will + // find this allocator's dc-location isn't the same with server of dc-2 but is same with itself. 
+ if allocatorGroup.dcLocation != leaderServerDCLocation && allocatorGroup.dcLocation == myServerDCLocation { + log.Info("try to move the local tso allocator", + zap.Uint64("old-leader-id", leaderServerID), + zap.String("old-dc-location", leaderServerDCLocation), + zap.Uint64("next-leader-id", serverID), + zap.String("next-dc-location", myServerDCLocation)) + if err := am.transferLocalAllocator(allocatorGroup.dcLocation, am.member.ID()); err != nil { + log.Error("move the local tso allocator failed", + zap.Uint64("old-leader-id", leaderServerID), + zap.String("old-dc-location", leaderServerDCLocation), + zap.Uint64("next-leader-id", serverID), + zap.String("next-dc-location", myServerDCLocation), + errs.ZapError(err)) + continue + } + } + } + // Check next leader and resign + // Filter out allocators with leadership + allocatorGroups = am.getAllocatorGroups(FilterDCLocation(GlobalDCLocation), FilterUnavailableLeadership()) + for _, allocatorGroup := range allocatorGroups { + nextLeader, err := am.getNextLeaderID(allocatorGroup.dcLocation) + if err != nil { + log.Error("get next leader from etcd failed", + zap.String("dc-location", allocatorGroup.dcLocation), + errs.ZapError(err)) + continue + } + // nextLeader is not empty and isn't same with the server ID, resign the leader + if nextLeader != 0 && nextLeader != serverID { + log.Info("next leader key found, resign current leader", zap.Uint64("nextLeaderID", nextLeader)) + am.ResetAllocatorGroup(allocatorGroup.dcLocation) + } + } +} + +// TransferAllocatorForDCLocation transfer local tso allocator to the target member for the given dcLocation +func (am *AllocatorManager) TransferAllocatorForDCLocation(dcLocation string, memberID uint64) error { + if dcLocation == GlobalDCLocation { + return fmt.Errorf("dc-location %v should be transferred by transfer leader", dcLocation) + } + dcLocationsInfo := am.GetClusterDCLocations() + _, ok := dcLocationsInfo[dcLocation] + if !ok { + return fmt.Errorf("dc-location %v haven't been discovered yet", dcLocation) + } + allocator, err := am.GetAllocator(dcLocation) + if err != nil { + return err + } + localTSOAllocator, _ := allocator.(*LocalTSOAllocator) + leaderServerID := localTSOAllocator.GetAllocatorLeader().GetMemberId() + if leaderServerID == memberID { + return nil + } + return am.transferLocalAllocator(dcLocation, memberID) +} + +func (am *AllocatorManager) getServerDCLocation(serverID uint64) string { + am.mu.RLock() + defer am.mu.RUnlock() + for dcLocation, info := range am.mu.clusterDCLocations { + if slice.AnyOf(info.ServerIDs, func(i int) bool { return info.ServerIDs[i] == serverID }) { + return dcLocation + } + } + return "" +} + +func (am *AllocatorManager) getNextLeaderID(dcLocation string) (uint64, error) { + nextLeaderKey := am.nextLeaderKey(dcLocation) + nextLeaderValue, err := etcdutil.GetValue(am.member.Client(), nextLeaderKey) + if err != nil { + return 0, err + } + if len(nextLeaderValue) == 0 { + return 0, nil + } + return strconv.ParseUint(string(nextLeaderValue), 10, 64) +} + +func (am *AllocatorManager) deleteNextLeaderID(dcLocation string) error { + nextLeaderKey := am.nextLeaderKey(dcLocation) + resp, err := kv.NewSlowLogTxn(am.member.Client()). + Then(clientv3.OpDelete(nextLeaderKey)). 
+ Commit() + if err != nil { + return errs.ErrEtcdKVDelete.Wrap(err).GenWithStackByCause() + } + if !resp.Succeeded { + return errs.ErrEtcdTxnConflict.FastGenByArgs() + } + return nil +} + +// deleteAllocatorGroup should only be used to remove the unused Local TSO Allocator from an unused dc-location. +// If you want to clear or reset a TSO allocator, use (*AllocatorManager).ResetAllocatorGroup. +func (am *AllocatorManager) deleteAllocatorGroup(dcLocation string) { + am.mu.Lock() + defer am.mu.Unlock() + if allocatorGroup, exist := am.mu.allocatorGroups[dcLocation]; exist { + allocatorGroup.allocator.Reset() + allocatorGroup.leadership.Reset() + allocatorGroup.cancel() + delete(am.mu.allocatorGroups, dcLocation) + } +} + +// HandleTSORequest forwards TSO allocation requests to correct TSO Allocators. +func (am *AllocatorManager) HandleTSORequest(dcLocation string, count uint32) (pdpb.Timestamp, error) { + if dcLocation == "" { + dcLocation = GlobalDCLocation + } + allocatorGroup, exist := am.getAllocatorGroup(dcLocation) + if !exist { + err := errs.ErrGetAllocator.FastGenByArgs(fmt.Sprintf("%s allocator not found, generate timestamp failed", dcLocation)) + return pdpb.Timestamp{}, err + } + return allocatorGroup.allocator.GenerateTSO(count) +} + +// ResetAllocatorGroup will reset the allocator's leadership and TSO initialized in memory. +// It usually should be called before re-triggering an Allocator leader campaign. +func (am *AllocatorManager) ResetAllocatorGroup(dcLocation string) { + am.mu.Lock() + defer am.mu.Unlock() + if allocatorGroup, exist := am.mu.allocatorGroups[dcLocation]; exist { + allocatorGroup.allocator.Reset() + // Reset if it still has the leadership. Otherwise the data race may occur because of the re-campaigning. + if allocatorGroup.leadership.Check() { + allocatorGroup.leadership.Reset() + } + } +} + +func (am *AllocatorManager) getAllocatorGroups(filters ...AllocatorGroupFilter) []*allocatorGroup { + am.mu.RLock() + defer am.mu.RUnlock() + var allocatorGroups []*allocatorGroup + for _, ag := range am.mu.allocatorGroups { + if ag == nil { + continue + } + if slice.NoneOf(filters, func(i int) bool { return filters[i](ag) }) { + allocatorGroups = append(allocatorGroups, ag) + } + } + return allocatorGroups +} + +func (am *AllocatorManager) getAllocatorGroup(dcLocation string) (*allocatorGroup, bool) { + am.mu.RLock() + defer am.mu.RUnlock() + allocatorGroup, exist := am.mu.allocatorGroups[dcLocation] + return allocatorGroup, exist +} + +// GetAllocator get the allocator by dc-location. +func (am *AllocatorManager) GetAllocator(dcLocation string) (Allocator, error) { + am.mu.RLock() + defer am.mu.RUnlock() + allocatorGroup, exist := am.mu.allocatorGroups[dcLocation] + if !exist { + return nil, errs.ErrGetAllocator.FastGenByArgs(fmt.Sprintf("%s allocator not found", dcLocation)) + } + return allocatorGroup.allocator, nil +} + +// GetAllocators get all allocators with some filters. +func (am *AllocatorManager) GetAllocators(filters ...AllocatorGroupFilter) []Allocator { + allocatorGroups := am.getAllocatorGroups(filters...) + allocators := make([]Allocator, 0, len(allocatorGroups)) + for _, ag := range allocatorGroups { + allocators = append(allocators, ag.allocator) + } + return allocators +} + +// GetHoldingLocalAllocatorLeaders returns all Local TSO Allocator leaders this server holds. 
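The accessors above build on a small filter-composition pattern: getAllocatorGroups keeps an allocator group only when none of the supplied AllocatorGroupFilter predicates reject it, which is what the slice.NoneOf call expresses. The standalone sketch below uses simplified stand-in types rather than PD's real ones to illustrate the same exclusion-based filtering:

```go
package main

import "fmt"

// Simplified stand-in for the real allocatorGroup type.
type allocatorGroup struct {
	dcLocation string
	hasLeader  bool
}

// A filter returns true when the group should be EXCLUDED, matching the semantics above.
type AllocatorGroupFilter func(ag *allocatorGroup) bool

func FilterDCLocation(dcLocation string) AllocatorGroupFilter {
	return func(ag *allocatorGroup) bool { return ag.dcLocation == dcLocation }
}

func FilterUnavailableLeadership() AllocatorGroupFilter {
	return func(ag *allocatorGroup) bool { return !ag.hasLeader }
}

// getAllocatorGroups keeps a group only if none of the filters match it.
func getAllocatorGroups(groups []*allocatorGroup, filters ...AllocatorGroupFilter) []*allocatorGroup {
	var result []*allocatorGroup
	for _, ag := range groups {
		excluded := false
		for _, f := range filters {
			if f(ag) {
				excluded = true
				break
			}
		}
		if !excluded {
			result = append(result, ag)
		}
	}
	return result
}

func main() {
	groups := []*allocatorGroup{
		{dcLocation: "global", hasLeader: true},
		{dcLocation: "dc-1", hasLeader: true},
		{dcLocation: "dc-2", hasLeader: false},
	}
	// Keep only non-global groups that currently hold leadership: prints "dc-1".
	for _, ag := range getAllocatorGroups(groups, FilterDCLocation("global"), FilterUnavailableLeadership()) {
		fmt.Println(ag.dcLocation)
	}
}
```

Expressing filters as exclusion predicates keeps call sites short: FilterDCLocation(GlobalDCLocation) drops the global group, and FilterUnavailableLeadership() drops groups that do not currently hold leadership.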
+func (am *AllocatorManager) GetHoldingLocalAllocatorLeaders() ([]*LocalTSOAllocator, error) { + localAllocators := am.GetAllocators( + FilterDCLocation(GlobalDCLocation), + FilterUnavailableLeadership()) + localAllocatorLeaders := make([]*LocalTSOAllocator, 0, len(localAllocators)) + for _, localAllocator := range localAllocators { + localAllocatorLeader, ok := localAllocator.(*LocalTSOAllocator) + if !ok { + return nil, errs.ErrGetLocalAllocator.FastGenByArgs("invalid local tso allocator found") + } + localAllocatorLeaders = append(localAllocatorLeaders, localAllocatorLeader) + } + return localAllocatorLeaders, nil +} + +// GetLocalAllocatorLeaders returns all Local TSO Allocator leaders' member info. +func (am *AllocatorManager) GetLocalAllocatorLeaders() (map[string]*pdpb.Member, error) { + localAllocators := am.GetAllocators(FilterDCLocation(GlobalDCLocation)) + localAllocatorLeaderMember := make(map[string]*pdpb.Member) + for _, allocator := range localAllocators { + localAllocator, ok := allocator.(*LocalTSOAllocator) + if !ok { + return nil, errs.ErrGetLocalAllocator.FastGenByArgs("invalid local tso allocator found") + } + localAllocatorLeaderMember[localAllocator.GetDCLocation()] = localAllocator.GetAllocatorLeader() + } + return localAllocatorLeaderMember, nil +} + +func (am *AllocatorManager) getOrCreateGRPCConn(ctx context.Context, addr string) (*grpc.ClientConn, error) { + conn, ok := am.getGRPCConn(addr) + if ok { + return conn, nil + } + tlsCfg, err := am.securityConfig.ToTLSConfig() + if err != nil { + return nil, err + } + ctxWithTimeout, cancel := context.WithTimeout(ctx, dialTimeout) + defer cancel() + cc, err := grpcutil.GetClientConn(ctxWithTimeout, addr, tlsCfg) + if err != nil { + return nil, err + } + am.setGRPCConn(cc, addr) + conn, _ = am.getGRPCConn(addr) + return conn, nil +} + +func (am *AllocatorManager) getDCLocationInfoFromLeader(ctx context.Context, dcLocation string) (bool, *pdpb.GetDCLocationInfoResponse, error) { + if am.member.IsLeader() { + info, ok := am.GetDCLocationInfo(dcLocation) + if !ok { + return false, &pdpb.GetDCLocationInfoResponse{}, nil + } + dcLocationInfo := &pdpb.GetDCLocationInfoResponse{Suffix: info.Suffix} + var err error + if dcLocationInfo.MaxTs, err = am.GetMaxLocalTSO(ctx); err != nil { + return false, &pdpb.GetDCLocationInfoResponse{}, err + } + return ok, dcLocationInfo, nil + } + + leaderAddrs := am.member.GetLeader().GetClientUrls() + if leaderAddrs == nil || len(leaderAddrs) < 1 { + return false, &pdpb.GetDCLocationInfoResponse{}, fmt.Errorf("failed to get leader client url") + } + conn, err := am.getOrCreateGRPCConn(ctx, leaderAddrs[0]) + if err != nil { + return false, &pdpb.GetDCLocationInfoResponse{}, err + } + getCtx, cancel := context.WithTimeout(ctx, rpcTimeout) + defer cancel() + resp, err := pdpb.NewPDClient(conn).GetDCLocationInfo(getCtx, &pdpb.GetDCLocationInfoRequest{ + Header: &pdpb.RequestHeader{ + SenderId: am.member.Member().GetMemberId(), + }, + DcLocation: dcLocation, + }) + if err != nil { + return false, &pdpb.GetDCLocationInfoResponse{}, err + } + if resp.GetHeader().GetError() != nil { + return false, &pdpb.GetDCLocationInfoResponse{}, errors.Errorf("get the dc-location info from leader failed: %s", resp.GetHeader().GetError().String()) + } + return resp.GetSuffix() != 0, resp, nil +} + +// GetMaxLocalTSO will sync with the current Local TSO Allocators among the cluster to get the +// max Local TSO. 
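getOrCreateGRPCConn above is a get-or-dial-and-cache helper: it returns a cached connection when one exists, dials otherwise, and (via setGRPCConn, defined a bit further below) discards the freshly dialed connection when another goroutine has already cached one for the same address. A minimal, library-agnostic sketch of that pattern, using io.Closer in place of *grpc.ClientConn so it stays self-contained:

```go
package main

import (
	"fmt"
	"io"
	"sync"
)

type connCache struct {
	mu    sync.RWMutex
	conns map[string]io.Closer
	dial  func(addr string) (io.Closer, error)
}

func (c *connCache) get(addr string) (io.Closer, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	conn, ok := c.conns[addr]
	return conn, ok
}

// getOrCreate returns the cached connection for addr, dialing a new one on a miss.
// If another goroutine cached a connection first, the newly dialed one is closed,
// mirroring the "use old connection" branch in setGRPCConn below.
func (c *connCache) getOrCreate(addr string) (io.Closer, error) {
	if conn, ok := c.get(addr); ok {
		return conn, nil
	}
	newConn, err := c.dial(addr)
	if err != nil {
		return nil, err
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if old, ok := c.conns[addr]; ok {
		newConn.Close()
		return old, nil
	}
	c.conns[addr] = newConn
	return newConn, nil
}

type fakeConn struct{ addr string }

func (f *fakeConn) Close() error { return nil }

func main() {
	cache := &connCache{
		conns: make(map[string]io.Closer),
		dial:  func(addr string) (io.Closer, error) { return &fakeConn{addr: addr}, nil },
	}
	conn1, _ := cache.getOrCreate("pd-leader:2379")
	conn2, _ := cache.getOrCreate("pd-leader:2379")
	fmt.Println(conn1 == conn2) // true: the second call reuses the cached connection
}
```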
+func (am *AllocatorManager) GetMaxLocalTSO(ctx context.Context) (*pdpb.Timestamp, error) { + // Sync the max local TSO from the other Local TSO Allocators who has been initialized + clusterDCLocations := am.GetClusterDCLocations() + for dcLocation := range clusterDCLocations { + allocatorGroup, ok := am.getAllocatorGroup(dcLocation) + if !(ok && allocatorGroup.leadership.Check()) { + delete(clusterDCLocations, dcLocation) + } + } + maxTSO := &pdpb.Timestamp{} + if len(clusterDCLocations) == 0 { + return maxTSO, nil + } + globalAllocator, err := am.GetAllocator(GlobalDCLocation) + if err != nil { + return nil, err + } + if err := globalAllocator.(*GlobalTSOAllocator).SyncMaxTS(ctx, clusterDCLocations, maxTSO, false); err != nil { + return nil, err + } + return maxTSO, nil +} + +// GetGlobalTSO returns global tso. +func (am *AllocatorManager) GetGlobalTSO() (*pdpb.Timestamp, error) { + globalAllocator, err := am.GetAllocator(GlobalDCLocation) + if err != nil { + return nil, err + } + return globalAllocator.(*GlobalTSOAllocator).getCurrentTSO() +} + +func (am *AllocatorManager) getGRPCConn(addr string) (*grpc.ClientConn, bool) { + am.localAllocatorConn.RLock() + defer am.localAllocatorConn.RUnlock() + conn, ok := am.localAllocatorConn.clientConns[addr] + return conn, ok +} + +func (am *AllocatorManager) setGRPCConn(newConn *grpc.ClientConn, addr string) { + am.localAllocatorConn.Lock() + defer am.localAllocatorConn.Unlock() + if _, ok := am.localAllocatorConn.clientConns[addr]; ok { + newConn.Close() + log.Debug("use old connection", zap.String("target", newConn.Target()), zap.String("state", newConn.GetState().String())) + return + } + am.localAllocatorConn.clientConns[addr] = newConn +} + +func (am *AllocatorManager) transferLocalAllocator(dcLocation string, serverID uint64) error { + nextLeaderKey := am.nextLeaderKey(dcLocation) + // Grant a etcd lease with checkStep * 1.5 + nextLeaderLease := clientv3.NewLease(am.member.Client()) + ctx, cancel := context.WithTimeout(am.member.Client().Ctx(), etcdutil.DefaultRequestTimeout) + leaseResp, err := nextLeaderLease.Grant(ctx, int64(checkStep.Seconds()*1.5)) + cancel() + if err != nil { + err = errs.ErrEtcdGrantLease.Wrap(err).GenWithStackByCause() + log.Error("failed to grant the lease of the next leader key", + zap.String("dc-location", dcLocation), zap.Uint64("serverID", serverID), + errs.ZapError(err)) + return err + } + resp, err := kv.NewSlowLogTxn(am.member.Client()). + If(clientv3.Compare(clientv3.CreateRevision(nextLeaderKey), "=", 0)). + Then(clientv3.OpPut(nextLeaderKey, fmt.Sprint(serverID), clientv3.WithLease(leaseResp.ID))). + Commit() + if err != nil { + err = errs.ErrEtcdTxnInternal.Wrap(err).GenWithStackByCause() + log.Error("failed to write next leader key into etcd", + zap.String("dc-location", dcLocation), zap.Uint64("serverID", serverID), + errs.ZapError(err)) + return err + } + if !resp.Succeeded { + log.Warn("write next leader id into etcd unsuccessfully", zap.String("dc-location", dcLocation)) + return errs.ErrEtcdTxnConflict.GenWithStack("write next leader id into etcd unsuccessfully") + } + return nil +} + +func (am *AllocatorManager) nextLeaderKey(dcLocation string) string { + return path.Join(am.getAllocatorPath(dcLocation), "next-leader") +} + +// EnableLocalTSO returns the value of AllocatorManager.enableLocalTSO. 
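transferLocalAllocator above encodes a leadership-transfer request as a leased next-leader key that is written only if the key does not exist yet, so concurrent transfer attempts cannot overwrite each other, and the request expires on its own once the lease runs out. A condensed sketch of that etcd pattern, with a hypothetical key and TTL rather than PD's real allocator paths and checkStep-derived lease:

```go
package main

import (
	"context"
	"fmt"
	"time"

	"go.etcd.io/etcd/clientv3"
)

// requestTransfer asks the member with the given ID to take over by writing a
// leased next-leader key, but only if nobody else has written one already.
func requestTransfer(cli *clientv3.Client, nextLeaderKey string, serverID uint64, ttl int64) error {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	// The lease makes the request expire on its own if the transfer never happens.
	leaseResp, err := cli.Grant(ctx, ttl)
	if err != nil {
		return err
	}
	resp, err := cli.Txn(ctx).
		If(clientv3.Compare(clientv3.CreateRevision(nextLeaderKey), "=", 0)).
		Then(clientv3.OpPut(nextLeaderKey, fmt.Sprint(serverID), clientv3.WithLease(leaseResp.ID))).
		Commit()
	if err != nil {
		return err
	}
	if !resp.Succeeded {
		return fmt.Errorf("next-leader key already exists, another transfer is in progress")
	}
	return nil
}

func main() {
	cli, err := clientv3.New(clientv3.Config{Endpoints: []string{"127.0.0.1:2379"}})
	if err != nil {
		panic(err)
	}
	defer cli.Close()
	// Hypothetical key layout; PD builds the real key under its allocator path.
	if err := requestTransfer(cli, "/example/dc-1/next-leader", 42, 90); err != nil {
		fmt.Println("transfer not scheduled:", err)
	}
}
```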
+func (am *AllocatorManager) EnableLocalTSO() bool { + return am.enableLocalTSO +} diff --git a/server/tso/binding__failpoint_binding__.go b/server/tso/binding__failpoint_binding__.go new file mode 100755 index 00000000000..ed07caea74b --- /dev/null +++ b/server/tso/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package tso + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/server/tso/global_allocator.go b/server/tso/global_allocator.go old mode 100644 new mode 100755 index 8e035808317..abaedb3ce9f --- a/server/tso/global_allocator.go +++ b/server/tso/global_allocator.go @@ -251,12 +251,12 @@ func (gta *GlobalTSOAllocator) GenerateTSO(count uint32) (pdpb.Timestamp, error) var globalTSOOverflowFlag = true func (gta *GlobalTSOAllocator) precheckLogical(maxTSO *pdpb.Timestamp, suffixBits int) bool { - failpoint.Inject("globalTSOOverflow", func() { + if _, _err_ := failpoint.Eval(_curpkg_("globalTSOOverflow")); _err_ == nil { if globalTSOOverflowFlag { maxTSO.Logical = maxLogical globalTSOOverflowFlag = false } - }) + } // Make sure the physical time is not empty again. if maxTSO.GetPhysical() == 0 { return false diff --git a/server/tso/global_allocator.go__failpoint_stash__ b/server/tso/global_allocator.go__failpoint_stash__ new file mode 100644 index 00000000000..8e035808317 --- /dev/null +++ b/server/tso/global_allocator.go__failpoint_stash__ @@ -0,0 +1,445 @@ +// Copyright 2020 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tso + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/slice" + "github.com/tikv/pd/pkg/tsoutil" + "github.com/tikv/pd/pkg/typeutil" + "github.com/tikv/pd/server/election" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +// Allocator is a Timestamp Oracle allocator. +type Allocator interface { + // Initialize is used to initialize a TSO allocator. + // It will synchronize TSO with etcd and initialize the + // memory for later allocation work. + Initialize(suffix int) error + // IsInitialize is used to indicates whether this allocator is initialized. + IsInitialize() bool + // UpdateTSO is used to update the TSO in memory and the time window in etcd. + UpdateTSO() error + // SetTSO sets the physical part with given TSO. It's mainly used for BR restore. + // Cannot set the TSO smaller than now in any case. 
+ // if ignoreSmaller=true, if input ts is smaller than current, ignore silently, else return error + // if skipUpperBoundCheck=true, skip tso upper bound check + SetTSO(tso uint64, ignoreSmaller, skipUpperBoundCheck bool) error + // GenerateTSO is used to generate a given number of TSOs. + // Make sure you have initialized the TSO allocator before calling. + GenerateTSO(count uint32) (pdpb.Timestamp, error) + // Reset is used to reset the TSO allocator. + Reset() +} + +// GlobalTSOAllocator is the global single point TSO allocator. +type GlobalTSOAllocator struct { + // for global TSO synchronization + allocatorManager *AllocatorManager + // leadership is used to check the current PD server's leadership + // to determine whether a TSO request could be processed. + leadership *election.Leadership + timestampOracle *timestampOracle + // syncRTT is the RTT duration a SyncMaxTS RPC call will cost, + // which is used to estimate the MaxTS in a Global TSO generation + // to reduce the gRPC network IO latency. + syncRTT atomic.Value // store as int64 milliseconds +} + +// NewGlobalTSOAllocator creates a new global TSO allocator. +func NewGlobalTSOAllocator( + am *AllocatorManager, + leadership *election.Leadership, +) Allocator { + gta := &GlobalTSOAllocator{ + allocatorManager: am, + leadership: leadership, + timestampOracle: ×tampOracle{ + client: leadership.GetClient(), + rootPath: am.rootPath, + saveInterval: am.saveInterval, + updatePhysicalInterval: am.updatePhysicalInterval, + maxResetTSGap: am.maxResetTSGap, + dcLocation: GlobalDCLocation, + tsoMux: &tsoObject{}, + }, + } + return gta +} + +func (gta *GlobalTSOAllocator) setSyncRTT(rtt int64) { + gta.syncRTT.Store(rtt) + tsoGauge.WithLabelValues("global_tso_sync_rtt", gta.timestampOracle.dcLocation).Set(float64(rtt)) +} + +func (gta *GlobalTSOAllocator) getSyncRTT() int64 { + syncRTT := gta.syncRTT.Load() + if syncRTT == nil { + return 0 + } + return syncRTT.(int64) +} + +func (gta *GlobalTSOAllocator) estimateMaxTS(count uint32, suffixBits int) (*pdpb.Timestamp, bool, error) { + physical, logical, lastUpdateTime := gta.timestampOracle.generateTSO(int64(count), 0) + if physical == 0 { + return &pdpb.Timestamp{}, false, errs.ErrGenerateTimestamp.FastGenByArgs("timestamp in memory isn't initialized") + } + estimatedMaxTSO := &pdpb.Timestamp{ + Physical: physical + time.Since(lastUpdateTime).Milliseconds() + 2*gta.getSyncRTT(), // TODO: make the coefficient of RTT configurable + Logical: logical, + } + // Precheck to make sure the logical part won't overflow after being differentiated. + // If precheckLogical returns false, it means the logical part is overflow, + // we need to wait a updatePhysicalInterval and retry the estimation later. + if !gta.precheckLogical(estimatedMaxTSO, suffixBits) { + return nil, true, nil + } + return estimatedMaxTSO, false, nil +} + +// Initialize will initialize the created global TSO allocator. +func (gta *GlobalTSOAllocator) Initialize(int) error { + tsoAllocatorRole.WithLabelValues(gta.timestampOracle.dcLocation).Set(1) + // The suffix of a Global TSO should always be 0. + gta.timestampOracle.suffix = 0 + return gta.timestampOracle.SyncTimestamp(gta.leadership) +} + +// IsInitialize is used to indicates whether this allocator is initialized. +func (gta *GlobalTSOAllocator) IsInitialize() bool { + return gta.timestampOracle.isInitialized() +} + +// UpdateTSO is used to update the TSO in memory and the time window in etcd. 
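estimateMaxTS above pads the in-memory physical time with the elapsed time since the last TSO update plus twice the last measured SyncMaxTS RTT, so the estimate is expected to still be ahead of every Local TSO by the time it reaches the allocator leaders. A toy calculation of that padding, with made-up numbers rather than PD's real values:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// Hypothetical inputs: in-memory physical time in ms, time since the last
	// TSO update, and the last observed SyncMaxTS round-trip time.
	var (
		physicalMs  = int64(1_700_000_000_000)
		sinceUpdate = 30 * time.Millisecond
		syncRTT     = 5 * time.Millisecond
	)
	// Same shape as estimateMaxTS: physical + elapsed + 2*RTT.
	estimated := physicalMs + sinceUpdate.Milliseconds() + 2*syncRTT.Milliseconds()
	fmt.Println(estimated) // 1700000000040
}
```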
+func (gta *GlobalTSOAllocator) UpdateTSO() error {
+	return gta.timestampOracle.UpdateTimestamp(gta.leadership)
+}
+
+// SetTSO sets the physical part with the given TSO.
+func (gta *GlobalTSOAllocator) SetTSO(tso uint64, ignoreSmaller, skipUpperBoundCheck bool) error {
+	return gta.timestampOracle.resetUserTimestampInner(gta.leadership, tso, ignoreSmaller, skipUpperBoundCheck)
+}
+
+// GenerateTSO is used to generate the given number of TSOs.
+// Make sure you have initialized the TSO allocator before calling this method.
+// Basically, there are two ways to generate a Global TSO:
+// 1. The old way to generate a normal TSO from memory directly, which makes the TSO service node become a single point.
+// 2. The new way to generate a Global TSO by synchronizing with all other Local TSO Allocators.
+//
+// And for the new way, there are two different strategies:
+// 1. Collect the max Local TSO from all Local TSO Allocator leaders and write it back to them as MaxTS.
+// 2. Estimate a MaxTS and try to write it to all Local TSO Allocator leaders directly to reduce the RTT.
+// During the process, if the estimated MaxTS is not accurate, it will fall back to the collecting way.
+func (gta *GlobalTSOAllocator) GenerateTSO(count uint32) (pdpb.Timestamp, error) {
+	if !gta.leadership.Check() {
+		tsoCounter.WithLabelValues("not_leader", gta.timestampOracle.dcLocation).Inc()
+		return pdpb.Timestamp{}, errs.ErrGenerateTimestamp.FastGenByArgs(fmt.Sprintf("requested pd %s of cluster", errs.NotLeaderErr))
+	}
+	// Check if we have any dc-location configured in the cluster
+	dcLocationMap := gta.allocatorManager.GetClusterDCLocations()
+	// No dc-locations configured in the cluster, use the normal Global TSO generation way.
+	// (without synchronization with other Local TSO Allocators)
+	if len(dcLocationMap) == 0 {
+		return gta.timestampOracle.getTS(gta.leadership, count, 0)
+	}
+
+	// Have dc-locations configured in the cluster, use the Global TSO generation way.
+	// (with synchronization with other Local TSO Allocators)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	for i := 0; i < maxRetryCount; i++ {
+		var (
+			err                    error
+			shouldRetry, skipCheck bool
+			globalTSOResp          pdpb.Timestamp
+			estimatedMaxTSO        *pdpb.Timestamp
+			suffixBits             = gta.allocatorManager.GetSuffixBits()
+		)
+		// TODO: add a switch to control whether to enable the MaxTSO estimation.
+		// 1. Estimate a MaxTS among all Local TSO Allocator leaders according to the RTT.
+		estimatedMaxTSO, shouldRetry, err = gta.estimateMaxTS(count, suffixBits)
+		if err != nil {
+			log.Error("global tso allocator estimates MaxTS failed", errs.ZapError(err))
+			continue
+		}
+		if shouldRetry {
+			time.Sleep(gta.timestampOracle.updatePhysicalInterval)
+			continue
+		}
+	SETTING_PHASE:
+		// 2. Send the MaxTSO to all Local TSO Allocator leaders to make sure the subsequent Local TSOs will be bigger than it.
+		// It's not safe to skip the check the first time here because the estimated maxTSO may not be big enough,
+		// we need to validate it first before we write it into every Local TSO Allocator's memory.
+		globalTSOResp = *estimatedMaxTSO
+		if err = gta.SyncMaxTS(ctx, dcLocationMap, &globalTSOResp, skipCheck); err != nil {
+			log.Error("global tso allocator synchronizes MaxTS failed", errs.ZapError(err))
+			continue
+		}
+		// 3. If skipCheck is false and the maxTSO is bigger than estimatedMaxTSO,
+		// we need to redo the setting phase with the bigger one and skip the check safely.
+		if !skipCheck && tsoutil.CompareTimestamp(&globalTSOResp, estimatedMaxTSO) > 0 {
+			tsoCounter.WithLabelValues("global_tso_sync", gta.timestampOracle.dcLocation).Inc()
+			*estimatedMaxTSO = globalTSOResp
+			// Re-add the count and check the overflow.
+			estimatedMaxTSO.Logical += int64(count)
+			if !gta.precheckLogical(estimatedMaxTSO, suffixBits) {
+				estimatedMaxTSO.Physical += UpdateTimestampGuard.Milliseconds()
+				estimatedMaxTSO.Logical = int64(count)
+			}
+			skipCheck = true
+			goto SETTING_PHASE
+		}
+		// If skipCheck is false and globalTSOResp remains the same, it means the estimatedMaxTSO is valid.
+		if !skipCheck && tsoutil.CompareTimestamp(&globalTSOResp, estimatedMaxTSO) == 0 {
+			tsoCounter.WithLabelValues("global_tso_estimate", gta.timestampOracle.dcLocation).Inc()
+		}
+		// 4. Persist MaxTS into memory, and etcd if needed
+		var currentGlobalTSO *pdpb.Timestamp
+		if currentGlobalTSO, err = gta.getCurrentTSO(); err != nil {
+			log.Error("global tso allocator gets the current global tso in memory failed", errs.ZapError(err))
+			continue
+		}
+		if tsoutil.CompareTimestamp(currentGlobalTSO, &globalTSOResp) < 0 {
+			tsoCounter.WithLabelValues("global_tso_persist", gta.timestampOracle.dcLocation).Inc()
+			// Update the Global TSO in memory
+			if err = gta.timestampOracle.resetUserTimestamp(gta.leadership, tsoutil.GenerateTS(&globalTSOResp), true); err != nil {
+				tsoCounter.WithLabelValues("global_tso_persist_err", gta.timestampOracle.dcLocation).Inc()
+				log.Error("global tso allocator update the global tso in memory failed", errs.ZapError(err))
+				continue
+			}
+		}
+		// 5. Check leadership again before we return the response.
+		if !gta.leadership.Check() {
+			tsoCounter.WithLabelValues("not_leader_anymore", gta.timestampOracle.dcLocation).Inc()
+			return pdpb.Timestamp{}, errs.ErrGenerateTimestamp.FastGenByArgs("not the pd leader anymore")
+		}
+		// 6. Differentiate the logical part to make the TSO unique globally by giving it a unique suffix in the whole cluster
+		globalTSOResp.Logical = gta.timestampOracle.differentiateLogical(globalTSOResp.GetLogical(), suffixBits)
+		globalTSOResp.SuffixBits = uint32(suffixBits)
+		return globalTSOResp, nil
+	}
+	tsoCounter.WithLabelValues("exceeded_max_retry", gta.timestampOracle.dcLocation).Inc()
+	return pdpb.Timestamp{}, errs.ErrGenerateTimestamp.FastGenByArgs("global tso allocator maximum number of retries exceeded")
+}
+
+// Only used for test
+var globalTSOOverflowFlag = true
+
+func (gta *GlobalTSOAllocator) precheckLogical(maxTSO *pdpb.Timestamp, suffixBits int) bool {
+	failpoint.Inject("globalTSOOverflow", func() {
+		if globalTSOOverflowFlag {
+			maxTSO.Logical = maxLogical
+			globalTSOOverflowFlag = false
+		}
+	})
+	// Make sure the physical time is not empty again.
+	if maxTSO.GetPhysical() == 0 {
+		return false
+	}
+	// Check if the logical part will reach the overflow condition after being differentiated.
+ if differentiatedLogical := gta.timestampOracle.differentiateLogical(maxTSO.Logical, suffixBits); differentiatedLogical >= maxLogical { + log.Error("estimated logical part outside of max logical interval, please check ntp time", + zap.Reflect("max-tso", maxTSO), errs.ZapError(errs.ErrLogicOverflow)) + tsoCounter.WithLabelValues("precheck_logical_overflow", gta.timestampOracle.dcLocation).Inc() + return false + } + return true +} + +const ( + dialTimeout = 3 * time.Second + rpcTimeout = 3 * time.Second + // TODO: maybe make syncMaxRetryCount configurable + syncMaxRetryCount = 2 +) + +type syncResp struct { + rpcRes *pdpb.SyncMaxTSResponse + err error + rtt time.Duration +} + +// SyncMaxTS is used to sync MaxTS with all Local TSO Allocator leaders in dcLocationMap. +// If maxTSO is the biggest TSO among all Local TSO Allocators, it will be written into +// each allocator and remines the same after the synchronization. +// If not, it will be replaced with the new max Local TSO and return. +func (gta *GlobalTSOAllocator) SyncMaxTS( + ctx context.Context, + dcLocationMap map[string]DCLocationInfo, + maxTSO *pdpb.Timestamp, + skipCheck bool, +) error { + originalMaxTSO := *maxTSO + for i := 0; i < syncMaxRetryCount; i++ { + // Collect all allocator leaders' client URLs + allocatorLeaders := make(map[string]*pdpb.Member) + for dcLocation := range dcLocationMap { + allocator, err := gta.allocatorManager.GetAllocator(dcLocation) + if err != nil { + return err + } + allocatorLeader := allocator.(*LocalTSOAllocator).GetAllocatorLeader() + if allocatorLeader.GetMemberId() == 0 { + return errs.ErrSyncMaxTS.FastGenByArgs(fmt.Sprintf("%s does not have the local allocator leader yet", dcLocation)) + } + allocatorLeaders[dcLocation] = allocatorLeader + } + leaderURLs := make([]string, 0) + for _, allocator := range allocatorLeaders { + // Check if its client URLs are empty + if len(allocator.GetClientUrls()) < 1 { + continue + } + leaderURL := allocator.GetClientUrls()[0] + if slice.NoneOf(leaderURLs, func(i int) bool { return leaderURLs[i] == leaderURL }) { + leaderURLs = append(leaderURLs, leaderURL) + } + } + // Prepare to make RPC requests concurrently + respCh := make(chan *syncResp, len(leaderURLs)) + wg := sync.WaitGroup{} + request := &pdpb.SyncMaxTSRequest{ + Header: &pdpb.RequestHeader{ + SenderId: gta.allocatorManager.member.ID(), + }, + SkipCheck: skipCheck, + MaxTs: maxTSO, + } + for _, leaderURL := range leaderURLs { + leaderConn, err := gta.allocatorManager.getOrCreateGRPCConn(ctx, leaderURL) + if err != nil { + return err + } + // Send SyncMaxTSRequest to all allocator leaders concurrently. 
+ wg.Add(1) + go func(ctx context.Context, conn *grpc.ClientConn, respCh chan<- *syncResp) { + defer wg.Done() + syncMaxTSResp := &syncResp{} + syncCtx, cancel := context.WithTimeout(ctx, rpcTimeout) + startTime := time.Now() + syncMaxTSResp.rpcRes, syncMaxTSResp.err = pdpb.NewPDClient(conn).SyncMaxTS(syncCtx, request) + // Including RPC request -> RPC processing -> RPC response + syncMaxTSResp.rtt = time.Since(startTime) + cancel() + respCh <- syncMaxTSResp + if syncMaxTSResp.err != nil { + log.Error("sync max ts rpc failed, got an error", zap.String("local-allocator-leader-url", leaderConn.Target()), errs.ZapError(err)) + return + } + if syncMaxTSResp.rpcRes.GetHeader().GetError() != nil { + log.Error("sync max ts rpc failed, got an error", zap.String("local-allocator-leader-url", leaderConn.Target()), + errs.ZapError(errors.Errorf("%s", syncMaxTSResp.rpcRes.GetHeader().GetError().String()))) + return + } + }(ctx, leaderConn, respCh) + } + wg.Wait() + close(respCh) + var ( + errList []error + syncedDCs []string + maxTSORtt time.Duration + ) + // Iterate each response to handle the error and compare MaxTSO. + for resp := range respCh { + if resp.err != nil { + errList = append(errList, resp.err) + } + // If any error occurs, just jump out of the loop. + if len(errList) != 0 { + break + } + if resp.rpcRes == nil { + return errs.ErrSyncMaxTS.FastGenByArgs("got nil response") + } + if skipCheck { + // Set all the Local TSOs to the maxTSO unconditionally, so the MaxLocalTS in response should be nil. + if resp.rpcRes.GetMaxLocalTs() != nil { + return errs.ErrSyncMaxTS.FastGenByArgs("got non-nil max local ts in the second sync phase") + } + syncedDCs = append(syncedDCs, resp.rpcRes.GetSyncedDcs()...) + } else { + // Compare and get the max one + if tsoutil.CompareTimestamp(resp.rpcRes.GetMaxLocalTs(), maxTSO) > 0 { + *maxTSO = *(resp.rpcRes.GetMaxLocalTs()) + if resp.rtt > maxTSORtt { + maxTSORtt = resp.rtt + } + } + syncedDCs = append(syncedDCs, resp.rpcRes.GetSyncedDcs()...) + } + } + // We need to collect all info needed to ensure the consistency of TSO. + // So if any error occurs, the synchronization process will fail directly. + if len(errList) != 0 { + return errs.ErrSyncMaxTS.FastGenWithCause(errList) + } + // Check whether all dc-locations have been considered during the synchronization and retry once if any dc-location missed. + if ok, unsyncedDCs := gta.checkSyncedDCs(dcLocationMap, syncedDCs); !ok { + log.Info("unsynced dc-locations found, will retry", zap.Bool("skip-check", skipCheck), zap.Strings("synced-DCs", syncedDCs), zap.Strings("unsynced-DCs", unsyncedDCs)) + if i < syncMaxRetryCount-1 { + // maxTSO should remain the same. + *maxTSO = originalMaxTSO + // To make sure we have the latest dc-location info + gta.allocatorManager.ClusterDCLocationChecker() + continue + } + return errs.ErrSyncMaxTS.FastGenByArgs(fmt.Sprintf("unsynced dc-locations found, skip-check: %t, synced dc-locations: %+v, unsynced dc-locations: %+v", skipCheck, syncedDCs, unsyncedDCs)) + } + // Update the sync RTT to help estimate MaxTS later. 
+ if maxTSORtt != 0 { + gta.setSyncRTT(maxTSORtt.Milliseconds()) + } + } + return nil +} + +func (gta *GlobalTSOAllocator) checkSyncedDCs(dcLocationMap map[string]DCLocationInfo, syncedDCs []string) (bool, []string) { + var unsyncedDCs []string + for dcLocation := range dcLocationMap { + if slice.NoneOf(syncedDCs, func(i int) bool { return syncedDCs[i] == dcLocation }) { + unsyncedDCs = append(unsyncedDCs, dcLocation) + } + } + log.Debug("check unsynced dc-locations", zap.Strings("unsynced-DCs", unsyncedDCs), zap.Strings("synced-DCs", syncedDCs)) + return len(unsyncedDCs) == 0, unsyncedDCs +} + +func (gta *GlobalTSOAllocator) getCurrentTSO() (*pdpb.Timestamp, error) { + currentPhysical, currentLogical := gta.timestampOracle.getTSO() + if currentPhysical == typeutil.ZeroTime { + return &pdpb.Timestamp{}, errs.ErrGenerateTimestamp.FastGenByArgs("timestamp in memory isn't initialized") + } + return tsoutil.GenerateTimestamp(currentPhysical, uint64(currentLogical)), nil +} + +// Reset is used to reset the TSO allocator. +func (gta *GlobalTSOAllocator) Reset() { + tsoAllocatorRole.WithLabelValues(gta.timestampOracle.dcLocation).Set(0) + gta.timestampOracle.ResetTimestamp() +} diff --git a/server/tso/tso.go b/server/tso/tso.go old mode 100644 new mode 100755 index a19252962e9..c5caf37a570 --- a/server/tso/tso.go +++ b/server/tso/tso.go @@ -194,9 +194,9 @@ func (t *timestampOracle) saveTimestamp(leadership *election.Leadership, ts time func (t *timestampOracle) SyncTimestamp(leadership *election.Leadership) error { tsoCounter.WithLabelValues("sync", t.dcLocation).Inc() - failpoint.Inject("delaySyncTimestamp", func() { + if _, _err_ := failpoint.Eval(_curpkg_("delaySyncTimestamp")); _err_ == nil { time.Sleep(time.Second) - }) + } last, err := t.loadTimestamp() if err != nil { @@ -204,12 +204,12 @@ func (t *timestampOracle) SyncTimestamp(leadership *election.Leadership) error { } next := time.Now() - failpoint.Inject("fallBackSync", func() { + if _, _err_ := failpoint.Eval(_curpkg_("fallBackSync")); _err_ == nil { next = next.Add(time.Hour) - }) - failpoint.Inject("systemTimeSlow", func() { + } + if _, _err_ := failpoint.Eval(_curpkg_("systemTimeSlow")); _err_ == nil { next = next.Add(-time.Hour) - }) + } // If the current system time minus the saved etcd timestamp is less than `UpdateTimestampGuard`, // the timestamp allocation will start from the saved etcd timestamp temporarily. if typeutil.SubRealTimeByWallClock(next, last) < UpdateTimestampGuard { @@ -317,12 +317,12 @@ func (t *timestampOracle) UpdateTimestamp(leadership *election.Leadership) error tsoGap.WithLabelValues(t.dcLocation).Set(float64(time.Since(prevPhysical).Milliseconds())) now := time.Now() - failpoint.Inject("fallBackUpdate", func() { + if _, _err_ := failpoint.Eval(_curpkg_("fallBackUpdate")); _err_ == nil { now = now.Add(time.Hour) - }) - failpoint.Inject("systemTimeSlow", func() { + } + if _, _err_ := failpoint.Eval(_curpkg_("systemTimeSlow")); _err_ == nil { now = now.Add(-time.Hour) - }) + } tsoCounter.WithLabelValues("save", t.dcLocation).Inc() diff --git a/server/tso/tso.go__failpoint_stash__ b/server/tso/tso.go__failpoint_stash__ new file mode 100644 index 00000000000..a19252962e9 --- /dev/null +++ b/server/tso/tso.go__failpoint_stash__ @@ -0,0 +1,420 @@ +// Copyright 2016 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tso + +import ( + "fmt" + "path" + "strings" + "sync/atomic" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/etcdutil" + "github.com/tikv/pd/pkg/syncutil" + "github.com/tikv/pd/pkg/tsoutil" + "github.com/tikv/pd/pkg/typeutil" + "github.com/tikv/pd/server/election" + "go.etcd.io/etcd/clientv3" + "go.uber.org/zap" +) + +const ( + timestampKey = "timestamp" + // UpdateTimestampGuard is the min timestamp interval. + UpdateTimestampGuard = time.Millisecond + // maxLogical is the max upper limit for logical time. + // When a TSO's logical time reaches this limit, + // the physical time will be forced to increase. + maxLogical = int64(1 << 18) + // MaxSuffixBits indicates the max number of suffix bits. + MaxSuffixBits = 4 + // jetLagWarningThreshold is the warning threshold of jetLag in `timestampOracle.UpdateTimestamp`. + // In case of small `updatePhysicalInterval`, the `3 * updatePhysicalInterval` would also is small, + // and trigger unnecessary warnings about clock offset. + // It's an empirical value. + jetLagWarningThreshold = 150 * time.Millisecond +) + +// tsoObject is used to store the current TSO in memory with a RWMutex lock. +type tsoObject struct { + syncutil.RWMutex + physical time.Time + logical int64 + updateTime time.Time +} + +// timestampOracle is used to maintain the logic of TSO. +type timestampOracle struct { + client *clientv3.Client + rootPath string + // TODO: remove saveInterval + saveInterval time.Duration + updatePhysicalInterval time.Duration + maxResetTSGap func() time.Duration + // tso info stored in the memory + tsoMux *tsoObject + // last timestamp window stored in etcd + lastSavedTime atomic.Value // stored as time.Time + suffix int + dcLocation string +} + +func (t *timestampOracle) setTSOPhysical(next time.Time, force bool) { + t.tsoMux.Lock() + defer t.tsoMux.Unlock() + // Do not update the zero physical time if the `force` flag is false. + if t.tsoMux.physical == typeutil.ZeroTime && !force { + return + } + // make sure the ts won't fall back + if typeutil.SubTSOPhysicalByWallClock(next, t.tsoMux.physical) > 0 { + t.tsoMux.physical = next + t.tsoMux.logical = 0 + t.setTSOUpdateTimeLocked(time.Now()) + } +} + +func (t *timestampOracle) setTSOUpdateTimeLocked(updateTime time.Time) { + t.tsoMux.updateTime = updateTime +} + +func (t *timestampOracle) getTSO() (time.Time, int64) { + t.tsoMux.RLock() + defer t.tsoMux.RUnlock() + if t.tsoMux.physical == typeutil.ZeroTime { + return typeutil.ZeroTime, 0 + } + return t.tsoMux.physical, t.tsoMux.logical +} + +// generateTSO will add the TSO's logical part with the given count and returns the new TSO result. 
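The timestampOracle above keeps a millisecond physical part and a logical counter that must stay below maxLogical (1 << 18), because a TSO is ultimately packed into a single uint64 with the logical part in the low 18 bits; PD's tsoutil compose/parse helpers do this packing. The helpers below are simplified illustrations of the same idea, not PD's tsoutil API:

```go
package main

import "fmt"

const logicalBits = 18 // matches maxLogical = int64(1 << 18) above

// composeTS packs a millisecond physical timestamp and a logical counter
// into one uint64, with the logical counter in the low 18 bits.
func composeTS(physical, logical int64) uint64 {
	return uint64(physical)<<logicalBits | uint64(logical)
}

// splitTS is the inverse of composeTS.
func splitTS(ts uint64) (physical, logical int64) {
	return int64(ts >> logicalBits), int64(ts & ((1 << logicalBits) - 1))
}

func main() {
	ts := composeTS(1_700_000_000_000, 42)
	p, l := splitTS(ts)
	fmt.Println(p, l) // 1700000000000 42
}
```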
+func (t *timestampOracle) generateTSO(count int64, suffixBits int) (physical int64, logical int64, lastUpdateTime time.Time) {
+	t.tsoMux.Lock()
+	defer t.tsoMux.Unlock()
+	if t.tsoMux.physical == typeutil.ZeroTime {
+		return 0, 0, typeutil.ZeroTime
+	}
+	physical = t.tsoMux.physical.UnixNano() / int64(time.Millisecond)
+	t.tsoMux.logical += count
+	logical = t.tsoMux.logical
+	if suffixBits > 0 && t.suffix >= 0 {
+		logical = t.differentiateLogical(logical, suffixBits)
+	}
+	// Return the last update time
+	lastUpdateTime = t.tsoMux.updateTime
+	t.setTSOUpdateTimeLocked(time.Now())
+	return physical, logical, lastUpdateTime
+}
+
+// Because the Local TSOs in each Local TSO Allocator are independent of each other, they
+// could be the same at times. To avoid this case, we need to use the logical part of the
+// Local TSO to do some differentiating work.
+// For example, we have three DCs: dc-1, dc-2 and dc-3. The number of suffix bits is defined
+// by suffixBits. Then, for dc-2, the suffix may be 2 because it's persisted
+// in etcd with the value of 2.
+// Once we get a normal TSO like this (18 bits): xxxxxxxxxxxxxxxxxx, we will make the TSO's
+// low bits of logical part from each DC look like:
+//
+// global: xxxxxxxxxx00000000
+// dc-1:   xxxxxxxxxx00000001
+// dc-2:   xxxxxxxxxx00000010
+// dc-3:   xxxxxxxxxx00000011
+func (t *timestampOracle) differentiateLogical(rawLogical int64, suffixBits int) int64 {
+	return rawLogical<<suffixBits + int64(t.suffix)
+}
+
+func (t *timestampOracle) getTimestampPath() string {
+	return path.Join(t.rootPath, timestampKey)
+}
+
+// loadTimestamp will get all time windows of Local/Global TSOs from etcd and return the biggest one.
+func (t *timestampOracle) loadTimestamp() (time.Time, error) {
+	resp, err := etcdutil.EtcdKVGet(t.client, t.rootPath, clientv3.WithPrefix())
+	if err != nil {
+		return typeutil.ZeroTime, err
+	}
+	maxTSWindow := typeutil.ZeroTime
+	for _, kv := range resp.Kvs {
+		key := strings.TrimSpace(string(kv.Key))
+		if !strings.HasSuffix(key, timestampKey) {
+			continue
+		}
+		tsWindow, err := typeutil.ParseTimestamp(kv.Value)
+		if err != nil {
+			log.Error("parse timestamp window that from etcd failed",
+				zap.String("dc-location", t.dcLocation), zap.String("ts-window-key", key), errs.ZapError(err))
+			continue
+		}
+		if typeutil.SubRealTimeByWallClock(tsWindow, maxTSWindow) > 0 {
+			maxTSWindow = tsWindow
+		}
+	}
+	return maxTSWindow, nil
+}
+
+// save timestamp, if lastTs is 0, we think the timestamp doesn't exist, so create it,
+// otherwise, update it.
+func (t *timestampOracle) saveTimestamp(leadership *election.Leadership, ts time.Time) error {
+	key := t.getTimestampPath()
+	data := typeutil.Uint64ToBytes(uint64(ts.UnixNano()))
+	resp, err := leadership.LeaderTxn().
+		Then(clientv3.OpPut(key, string(data))).
+		Commit()
+	if err != nil {
+		return errs.ErrEtcdKVPut.Wrap(err).GenWithStackByCause()
+	}
+	if !resp.Succeeded {
+		return errs.ErrEtcdTxnConflict.FastGenByArgs()
+	}
+	t.lastSavedTime.Store(ts)
+	return nil
+}
+
+// SyncTimestamp is used to synchronize the timestamp.
+func (t *timestampOracle) SyncTimestamp(leadership *election.Leadership) error {
+	tsoCounter.WithLabelValues("sync", t.dcLocation).Inc()
+
+	failpoint.Inject("delaySyncTimestamp", func() {
+		time.Sleep(time.Second)
+	})
+
+	last, err := t.loadTimestamp()
+	if err != nil {
+		return err
+	}
+
+	next := time.Now()
+	failpoint.Inject("fallBackSync", func() {
+		next = next.Add(time.Hour)
+	})
+	failpoint.Inject("systemTimeSlow", func() {
+		next = next.Add(-time.Hour)
+	})
+	// If the current system time minus the saved etcd timestamp is less than `UpdateTimestampGuard`,
+	// the timestamp allocation will start from the saved etcd timestamp temporarily.
+ if typeutil.SubRealTimeByWallClock(next, last) < UpdateTimestampGuard { + log.Error("system time may be incorrect", zap.Time("last", last), zap.Time("next", next), errs.ZapError(errs.ErrIncorrectSystemTime)) + next = last.Add(UpdateTimestampGuard) + } + + save := next.Add(t.saveInterval) + if err = t.saveTimestamp(leadership, save); err != nil { + tsoCounter.WithLabelValues("err_save_sync_ts", t.dcLocation).Inc() + return err + } + + tsoCounter.WithLabelValues("sync_ok", t.dcLocation).Inc() + log.Info("sync and save timestamp", zap.Time("last", last), zap.Time("save", save), zap.Time("next", next)) + // save into memory + t.setTSOPhysical(next, true) + return nil +} + +// isInitialized is used to check whether the timestampOracle is initialized. +// There are two situations we have an uninitialized timestampOracle: +// 1. When the SyncTimestamp has not been called yet. +// 2. When the ResetUserTimestamp has been called already. +func (t *timestampOracle) isInitialized() bool { + t.tsoMux.RLock() + defer t.tsoMux.RUnlock() + return t.tsoMux.physical != typeutil.ZeroTime +} + +// resetUserTimestamp update the TSO in memory with specified TSO by an atomically way. +// When ignoreSmaller is true, resetUserTimestamp will ignore the smaller tso resetting error and do nothing. +// It's used to write MaxTS during the Global TSO synchronization whitout failing the writing as much as possible. +// cannot set timestamp to one which >= current + maxResetTSGap +func (t *timestampOracle) resetUserTimestamp(leadership *election.Leadership, tso uint64, ignoreSmaller bool) error { + return t.resetUserTimestampInner(leadership, tso, ignoreSmaller, false) +} + +func (t *timestampOracle) resetUserTimestampInner(leadership *election.Leadership, tso uint64, ignoreSmaller, skipUpperBoundCheck bool) error { + t.tsoMux.Lock() + defer t.tsoMux.Unlock() + if !leadership.Check() { + tsoCounter.WithLabelValues("err_lease_reset_ts", t.dcLocation).Inc() + return errs.ErrResetUserTimestamp.FastGenByArgs("lease expired") + } + var ( + nextPhysical, nextLogical = tsoutil.ParseTS(tso) + logicalDifference = int64(nextLogical) - t.tsoMux.logical + physicalDifference = typeutil.SubTSOPhysicalByWallClock(nextPhysical, t.tsoMux.physical) + ) + // do not update if next physical time is less/before than prev + if physicalDifference < 0 { + tsoCounter.WithLabelValues("err_reset_small_ts", t.dcLocation).Inc() + if ignoreSmaller { + return nil + } + return errs.ErrResetUserTimestamp.FastGenByArgs("the specified ts is smaller than now") + } + // do not update if next logical time is less/before/equal than prev + if physicalDifference == 0 && logicalDifference <= 0 { + tsoCounter.WithLabelValues("err_reset_small_counter", t.dcLocation).Inc() + if ignoreSmaller { + return nil + } + return errs.ErrResetUserTimestamp.FastGenByArgs("the specified counter is smaller than now") + } + // do not update if physical time is too greater than prev + if !skipUpperBoundCheck && physicalDifference >= t.maxResetTSGap().Milliseconds() { + tsoCounter.WithLabelValues("err_reset_large_ts", t.dcLocation).Inc() + return errs.ErrResetUserTimestamp.FastGenByArgs("the specified ts is too larger than now") + } + // save into etcd only if nextPhysical is close to lastSavedTime + if typeutil.SubRealTimeByWallClock(t.lastSavedTime.Load().(time.Time), nextPhysical) <= UpdateTimestampGuard { + save := nextPhysical.Add(t.saveInterval) + if err := t.saveTimestamp(leadership, save); err != nil { + tsoCounter.WithLabelValues("err_save_reset_ts", t.dcLocation).Inc() + 
+ return err
+ }
+ }
+ // save into memory only if nextPhysical or nextLogical is greater.
+ t.tsoMux.physical = nextPhysical
+ t.tsoMux.logical = int64(nextLogical)
+ t.setTSOUpdateTimeLocked(time.Now())
+ tsoCounter.WithLabelValues("reset_tso_ok", t.dcLocation).Inc()
+ return nil
+}
+
+// UpdateTimestamp is used to update the timestamp.
+// This function will do two things:
+// 1. When the logical time is going to be used up, increase the current physical time.
+// 2. When the time window is not big enough, i.e. the saved etcd time minus the next physical time
+// is less than or equal to `UpdateTimestampGuard`, update the time window by saving the next
+// physical time plus `TSOSaveInterval` into etcd.
+//
+// Here are some constraints that this function must satisfy:
+// 1. The saved time is monotonically increasing.
+// 2. The physical time is monotonically increasing.
+// 3. The physical time is always less than the saved timestamp.
+//
+// NOTICE: this function should be called after the TSO in memory has been initialized
+// and must not be called again once the TSO in memory has been reset.
+func (t *timestampOracle) UpdateTimestamp(leadership *election.Leadership) error {
+ prevPhysical, prevLogical := t.getTSO()
+ tsoGauge.WithLabelValues("tso", t.dcLocation).Set(float64(prevPhysical.UnixNano() / int64(time.Millisecond)))
+ tsoGap.WithLabelValues(t.dcLocation).Set(float64(time.Since(prevPhysical).Milliseconds()))
+
+ now := time.Now()
+ failpoint.Inject("fallBackUpdate", func() {
+ now = now.Add(time.Hour)
+ })
+ failpoint.Inject("systemTimeSlow", func() {
+ now = now.Add(-time.Hour)
+ })
+
+ tsoCounter.WithLabelValues("save", t.dcLocation).Inc()
+
+ jetLag := typeutil.SubRealTimeByWallClock(now, prevPhysical)
+ if jetLag > 3*t.updatePhysicalInterval && jetLag > jetLagWarningThreshold {
+ log.Warn("clock offset", zap.Duration("jet-lag", jetLag), zap.Time("prev-physical", prevPhysical), zap.Time("now", now), zap.Duration("update-physical-interval", t.updatePhysicalInterval))
+ tsoCounter.WithLabelValues("slow_save", t.dcLocation).Inc()
+ }
+
+ if jetLag < 0 {
+ tsoCounter.WithLabelValues("system_time_slow", t.dcLocation).Inc()
+ }
+
+ var next time.Time
+ // If the system time is greater, it will be synchronized with the system time.
+ if jetLag > UpdateTimestampGuard {
+ next = now
+ } else if prevLogical > maxLogical/2 {
+ // The reason for choosing maxLogical/2 here is that it's big enough for common cases:
+ // there are still enough timestamps that can be allocated before the next update.
+ log.Warn("the logical time may be not enough", zap.Int64("prev-logical", prevLogical))
+ next = prevPhysical.Add(time.Millisecond)
+ } else {
+ // It will still use the previous physical time to allocate the timestamp.
+ tsoCounter.WithLabelValues("skip_save", t.dcLocation).Inc()
+ return nil
+ }
+
+ // It is not yet safe to increase the physical time to `next`:
+ // the time window needs to be updated and saved to etcd first.
+ if typeutil.SubRealTimeByWallClock(t.lastSavedTime.Load().(time.Time), next) <= UpdateTimestampGuard {
+ save := next.Add(t.saveInterval)
+ if err := t.saveTimestamp(leadership, save); err != nil {
+ tsoCounter.WithLabelValues("err_save_update_ts", t.dcLocation).Inc()
+ return err
+ }
+ }
+ // save into memory
+ t.setTSOPhysical(next, false)
+
+ return nil
+}
+
+var maxRetryCount = 10
+
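For reference, the fallBackUpdate and systemTimeSlow failpoints injected above can be activated from a test through the failpoint package so that UpdateTimestamp observes a skewed clock. The sketch below is illustrative only and not part of this change; the failpoint path and the test name are assumptions and need to match the real package path of this file.

package tso_test

import (
	"testing"

	"github.com/pingcap/failpoint"
)

// TestSystemTimeSlow sketches how the "systemTimeSlow" failpoint could be toggled.
func TestSystemTimeSlow(t *testing.T) {
	// Assumed failpoint path; it has to match the package path of this file.
	const fp = "github.com/tikv/pd/server/tso/systemTimeSlow"
	if err := failpoint.Enable(fp, "return(true)"); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := failpoint.Disable(fp); err != nil {
			t.Fatal(err)
		}
	}()
	// ... build a timestampOracle and call UpdateTimestamp here; with the failpoint
	// active it behaves as if the system clock were running one hour behind ...
}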
+// getTS is used to get a timestamp.
+func (t *timestampOracle) getTS(leadership *election.Leadership, count uint32, suffixBits int) (pdpb.Timestamp, error) {
+ var resp pdpb.Timestamp
+ if count == 0 {
+ return resp, errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive")
+ }
+ for i := 0; i < maxRetryCount; i++ {
+ currentPhysical, _ := t.getTSO()
+ if currentPhysical == typeutil.ZeroTime {
+ // If it's the leader, SyncTimestamp may not have completed yet.
+ if leadership.Check() {
+ time.Sleep(200 * time.Millisecond)
+ continue
+ }
+ tsoCounter.WithLabelValues("not_leader_anymore", t.dcLocation).Inc()
+ return pdpb.Timestamp{}, errs.ErrGenerateTimestamp.FastGenByArgs("timestamp in memory isn't initialized")
+ }
+ // Get a new TSO result with the given count
+ resp.Physical, resp.Logical, _ = t.generateTSO(int64(count), suffixBits)
+ if resp.GetPhysical() == 0 {
+ return pdpb.Timestamp{}, errs.ErrGenerateTimestamp.FastGenByArgs("timestamp in memory has been reset")
+ }
+ if resp.GetLogical() >= maxLogical {
+ log.Warn("logical part outside of max logical interval, please check ntp time, or adjust config item `tso-update-physical-interval`",
+ zap.Reflect("response", resp),
+ zap.Int("retry-count", i), errs.ZapError(errs.ErrLogicOverflow))
+ tsoCounter.WithLabelValues("logical_overflow", t.dcLocation).Inc()
+ time.Sleep(t.updatePhysicalInterval)
+ continue
+ }
+ // In case the lease expired after the first check.
+ if !leadership.Check() {
+ return pdpb.Timestamp{}, errs.ErrGenerateTimestamp.FastGenByArgs("not the pd or local tso allocator leader anymore")
+ }
+ resp.SuffixBits = uint32(suffixBits)
+ return resp, nil
+ }
+ tsoCounter.WithLabelValues("exceeded_max_retry", t.dcLocation).Inc()
+ return resp, errs.ErrGenerateTimestamp.FastGenByArgs(fmt.Sprintf("generate %s tso maximum number of retries exceeded", t.dcLocation))
+}
+
+// ResetTimestamp is used to reset the timestamp in memory.
+func (t *timestampOracle) ResetTimestamp() {
+ t.tsoMux.Lock()
+ defer t.tsoMux.Unlock()
+ log.Info("reset the timestamp in memory")
+ t.tsoMux.physical = typeutil.ZeroTime
+ t.tsoMux.logical = 0
+ t.setTSOUpdateTimeLocked(typeutil.ZeroTime)
+}
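As a usage note for getTS callers, the sketch below (illustrative only, not PD code) shows one way to order two pdpb.Timestamp values produced by the same allocator: the physical part is compared first and the logical counter breaks ties, mirroring the monotonicity constraints described above.

package main

import (
	"fmt"

	"github.com/pingcap/kvproto/pkg/pdpb"
)

// tsLess reports whether timestamp a precedes timestamp b, assuming both come
// from the same allocator: the physical part dominates, the logical counter breaks ties.
func tsLess(a, b pdpb.Timestamp) bool {
	if a.GetPhysical() != b.GetPhysical() {
		return a.GetPhysical() < b.GetPhysical()
	}
	return a.GetLogical() < b.GetLogical()
}

func main() {
	a := pdpb.Timestamp{Physical: 100, Logical: 5}
	b := pdpb.Timestamp{Physical: 100, Logical: 7}
	fmt.Println(tsLess(a, b)) // true: same physical part, smaller logical counter
}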