From 69d64676f429c37a425404fa8b3cc697831cd89d Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Mon, 8 May 2023 10:44:57 +0800 Subject: [PATCH] etcdutil: add watch loop (#6390) close tikv/pd#6391 Signed-off-by: lhy1024 Signed-off-by: zeminzhou --- pkg/keyspace/tso_keyspace_group.go | 167 ++++--------- pkg/tso/keyspace_group_manager.go | 326 ++++++------------------ pkg/tso/keyspace_group_manager_test.go | 20 +- pkg/utils/etcdutil/etcdutil.go | 268 ++++++++++++++++++++ pkg/utils/etcdutil/etcdutil_test.go | 333 +++++++++++++++++++++++++ pkg/utils/tsoutil/tso_dispatcher.go | 18 +- server/grpc_service.go | 8 +- server/keyspace_service.go | 107 +++----- server/server.go | 161 +++--------- 9 files changed, 809 insertions(+), 599 deletions(-) diff --git a/pkg/keyspace/tso_keyspace_group.go b/pkg/keyspace/tso_keyspace_group.go index b65a68011e5..46810be92d5 100644 --- a/pkg/keyspace/tso_keyspace_group.go +++ b/pkg/keyspace/tso_keyspace_group.go @@ -32,6 +32,7 @@ import ( "github.com/tikv/pd/pkg/utils/etcdutil" "github.com/tikv/pd/pkg/utils/logutil" "go.etcd.io/etcd/clientv3" + "go.etcd.io/etcd/mvcc/mvccpb" "go.uber.org/zap" ) @@ -40,10 +41,6 @@ const ( allocNodesToKeyspaceGroupsInterval = 1 * time.Second allocNodesTimeout = 1 * time.Second allocNodesInterval = 10 * time.Millisecond - // TODO: move it to etcdutil - watchEtcdChangeRetryInterval = 1 * time.Second - maxRetryTimes = 25 - retryInterval = 100 * time.Millisecond ) const ( @@ -65,18 +62,14 @@ type GroupManager struct { // store is the storage for keyspace group related information. store endpoint.KeyspaceGroupStorage - client *clientv3.Client - - // tsoServiceKey is the path of TSO service in etcd. - tsoServiceKey string - // tsoServiceEndKey is the end key of TSO service in etcd. - tsoServiceEndKey string - - // TODO: add user kind with different balancer - // when we ensure where the correspondence between tso node and user kind will be found + // nodeBalancer is the balancer for tso nodes. + // TODO: add user kind with different balancer when we ensure where the correspondence between tso node and user kind will be found nodesBalancer balancer.Balancer[string] // serviceRegistryMap stores the mapping from the service registry key to the service address. + // Note: it is only used in tsoNodesWatcher. serviceRegistryMap map[string]string + // tsoNodesWatcher is the watcher for the registered tso servers. + tsoNodesWatcher *etcdutil.LoopWatcher } // NewKeyspaceGroupManager creates a Manager of keyspace group related data. @@ -87,7 +80,6 @@ func NewKeyspaceGroupManager( clusterID uint64, ) *GroupManager { ctx, cancel := context.WithCancel(ctx) - key := discovery.TSOPath(clusterID) groups := make(map[endpoint.UserKind]*indexedHeap) for i := 0; i < int(endpoint.UserKindCount); i++ { groups[endpoint.UserKind(i)] = newIndexedHeap(int(utils.MaxKeyspaceGroupCountInUse)) @@ -96,9 +88,6 @@ func NewKeyspaceGroupManager( ctx: ctx, cancel: cancel, store: store, - client: client, - tsoServiceKey: key, - tsoServiceEndKey: clientv3.GetPrefixRangeEnd(key) + "/", groups: groups, nodesBalancer: balancer.GenByPolicy[string](defaultBalancerPolicy), serviceRegistryMap: make(map[string]string), @@ -106,10 +95,11 @@ func NewKeyspaceGroupManager( // If the etcd client is not nil, start the watch loop for the registered tso servers. // The PD(TSO) Client relies on this info to discover tso servers. 
- if m.client != nil { - log.Info("start the watch loop for tso service discovery") - m.wg.Add(1) - go m.startWatchLoop(ctx) + if client != nil { + m.initTSONodesWatcher(client, clusterID) + m.wg.Add(2) + go m.tsoNodesWatcher.StartWatchLoop() + go m.allocNodesToAllKeyspaceGroups() } return m @@ -130,12 +120,6 @@ func (m *GroupManager) Bootstrap() error { m.Lock() defer m.Unlock() - // If the etcd client is not nil, start the watch loop. - if m.client != nil { - m.wg.Add(1) - go m.allocNodesToAllKeyspaceGroups() - } - // Ignore the error if default keyspace group already exists in the storage (e.g. PD restart/recover). err := m.saveKeyspaceGroups([]*endpoint.KeyspaceGroup{defaultKeyspaceGroup}, false) if err != nil && err != ErrKeyspaceGroupExists { @@ -200,109 +184,42 @@ func (m *GroupManager) allocNodesToAllKeyspaceGroups() { } } -func (m *GroupManager) startWatchLoop(parentCtx context.Context) { - defer logutil.LogPanic() - defer m.wg.Done() - ctx, cancel := context.WithCancel(parentCtx) - defer cancel() - var ( - resp *clientv3.GetResponse - revision int64 - err error - ) - ticker := time.NewTicker(retryInterval) - defer ticker.Stop() - for i := 0; i < maxRetryTimes; i++ { - resp, err = etcdutil.EtcdKVGet(m.client, m.tsoServiceKey, clientv3.WithRange(m.tsoServiceEndKey)) - if err == nil { - revision = resp.Header.Revision + 1 - for _, item := range resp.Kvs { - s := &discovery.ServiceRegistryEntry{} - if err := json.Unmarshal(item.Value, s); err != nil { - log.Warn("failed to unmarshal service registry entry", zap.Error(err)) - continue - } - m.nodesBalancer.Put(s.ServiceAddr) - m.serviceRegistryMap[string(item.Key)] = s.ServiceAddr - } - break - } - log.Warn("failed to get tso service addrs from etcd and will retry", zap.Error(err)) - select { - case <-m.ctx.Done(): - return - case <-ticker.C: +func (m *GroupManager) initTSONodesWatcher(client *clientv3.Client, clusterID uint64) { + tsoServiceKey := discovery.TSOPath(clusterID) + tsoServiceEndKey := clientv3.GetPrefixRangeEnd(tsoServiceKey) + "/" + + putFn := func(kv *mvccpb.KeyValue) error { + s := &discovery.ServiceRegistryEntry{} + if err := json.Unmarshal(kv.Value, s); err != nil { + log.Warn("failed to unmarshal service registry entry", + zap.String("event-kv-key", string(kv.Key)), zap.Error(err)) + return err } + m.nodesBalancer.Put(s.ServiceAddr) + m.serviceRegistryMap[string(kv.Key)] = s.ServiceAddr + return nil } - if err != nil || revision == 0 { - log.Warn("failed to get tso service addrs from etcd finally when loading", zap.Error(err)) - } - for { - select { - case <-ctx.Done(): - return - default: - } - nextRevision, err := m.watchServiceAddrs(ctx, revision) - if err != nil { - log.Error("watcher canceled unexpectedly and a new watcher will start after a while", - zap.Int64("next-revision", nextRevision), - zap.Time("retry-at", time.Now().Add(watchEtcdChangeRetryInterval)), - zap.Error(err)) - revision = nextRevision - time.Sleep(watchEtcdChangeRetryInterval) + deleteFn := func(kv *mvccpb.KeyValue) error { + key := string(kv.Key) + if serviceAddr, ok := m.serviceRegistryMap[key]; ok { + delete(m.serviceRegistryMap, key) + m.nodesBalancer.Delete(serviceAddr) + return nil } + return errors.Errorf("failed to find the service address for key %s", key) } -} -func (m *GroupManager) watchServiceAddrs(ctx context.Context, revision int64) (int64, error) { - watcher := clientv3.NewWatcher(m.client) - defer watcher.Close() - for { - WatchChan: - watchChan := watcher.Watch(ctx, m.tsoServiceKey, clientv3.WithRange(m.tsoServiceEndKey), 
clientv3.WithRev(revision)) - select { - case <-ctx.Done(): - return revision, nil - case wresp := <-watchChan: - if wresp.CompactRevision != 0 { - log.Warn("required revision has been compacted, the watcher will watch again with the compact revision", - zap.Int64("required-revision", revision), - zap.Int64("compact-revision", wresp.CompactRevision)) - revision = wresp.CompactRevision - goto WatchChan - } - if wresp.Err() != nil { - log.Error("watch is canceled or closed", - zap.Int64("required-revision", revision), - zap.Error(wresp.Err())) - return revision, wresp.Err() - } - for _, event := range wresp.Events { - switch event.Type { - case clientv3.EventTypePut: - s := &discovery.ServiceRegistryEntry{} - if err := json.Unmarshal(event.Kv.Value, s); err != nil { - log.Warn("failed to unmarshal service registry entry", - zap.String("event-kv-key", string(event.Kv.Key)), zap.Error(err)) - break - } - m.nodesBalancer.Put(s.ServiceAddr) - m.serviceRegistryMap[string(event.Kv.Key)] = s.ServiceAddr - case clientv3.EventTypeDelete: - key := string(event.Kv.Key) - if serviceAddr, ok := m.serviceRegistryMap[key]; ok { - delete(m.serviceRegistryMap, key) - m.nodesBalancer.Delete(serviceAddr) - } else { - log.Warn("can't retrieve service addr from service registry map", - zap.String("event-kv-key", key)) - } - } - } - revision = wresp.Header.Revision + 1 - } - } + m.tsoNodesWatcher = etcdutil.NewLoopWatcher( + m.ctx, + &m.wg, + client, + "tso-nodes-watcher", + tsoServiceKey, + putFn, + deleteFn, + func() error { return nil }, + clientv3.WithRange(tsoServiceEndKey), + ) } // CreateKeyspaceGroups creates keyspace groups. diff --git a/pkg/tso/keyspace_group_manager.go b/pkg/tso/keyspace_group_manager.go index 0089e0d9bdc..6990a298589 100644 --- a/pkg/tso/keyspace_group_manager.go +++ b/pkg/tso/keyspace_group_manager.go @@ -17,7 +17,6 @@ package tso import ( "context" "encoding/json" - "errors" "fmt" "net/http" "path" @@ -27,7 +26,6 @@ import ( "time" perrors "github.com/pingcap/errors" - "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" "github.com/tikv/pd/pkg/election" @@ -43,19 +41,14 @@ import ( "github.com/tikv/pd/pkg/utils/memberutil" "github.com/tikv/pd/pkg/utils/tsoutil" "go.etcd.io/etcd/clientv3" + "go.etcd.io/etcd/mvcc/mvccpb" "go.uber.org/zap" ) const ( // primaryElectionSuffix is the suffix of the key for keyspace group primary election primaryElectionSuffix = "primary" - // defaultLoadKeyspaceGroupsTimeout is the default timeout for loading the initial - // keyspace group assignment - defaultLoadKeyspaceGroupsTimeout = 30 * time.Second - defaultLoadKeyspaceGroupsBatchSize = int64(400) - defaultLoadFromEtcdRetryInterval = 500 * time.Millisecond - defaultLoadFromEtcdMaxRetryTimes = int(defaultLoadKeyspaceGroupsTimeout / defaultLoadFromEtcdRetryInterval) - watchEtcdChangeRetryInterval = 1 * time.Second + defaultRetryInterval = 500 * time.Millisecond ) type state struct { @@ -205,6 +198,7 @@ type KeyspaceGroupManager struct { tsoSvcStorage *endpoint.StorageEndpoint // cfg is the TSO config cfg ServiceConfig + // loadKeyspaceGroupsTimeout is the timeout for loading the initial keyspace group assignment. loadKeyspaceGroupsTimeout time.Duration loadKeyspaceGroupsBatchSize int64 @@ -212,6 +206,8 @@ type KeyspaceGroupManager struct { // groupUpdateRetryList is the list of keyspace groups which failed to update and need to retry. 
groupUpdateRetryList map[uint32]*endpoint.KeyspaceGroup + + groupWatcher *etcdutil.LoopWatcher } // NewKeyspaceGroupManager creates a new Keyspace Group Manager. @@ -233,19 +229,16 @@ func NewKeyspaceGroupManager( ctx, cancel := context.WithCancel(ctx) kgm := &KeyspaceGroupManager{ - ctx: ctx, - cancel: cancel, - tsoServiceID: tsoServiceID, - etcdClient: etcdClient, - httpClient: httpClient, - electionNamePrefix: electionNamePrefix, - legacySvcRootPath: legacySvcRootPath, - tsoSvcRootPath: tsoSvcRootPath, - cfg: cfg, - loadKeyspaceGroupsTimeout: defaultLoadKeyspaceGroupsTimeout, - loadKeyspaceGroupsBatchSize: defaultLoadKeyspaceGroupsBatchSize, - loadFromEtcdMaxRetryTimes: defaultLoadFromEtcdMaxRetryTimes, - groupUpdateRetryList: make(map[uint32]*endpoint.KeyspaceGroup), + ctx: ctx, + cancel: cancel, + tsoServiceID: tsoServiceID, + etcdClient: etcdClient, + httpClient: httpClient, + electionNamePrefix: electionNamePrefix, + legacySvcRootPath: legacySvcRootPath, + tsoSvcRootPath: tsoSvcRootPath, + cfg: cfg, + groupUpdateRetryList: make(map[uint32]*endpoint.KeyspaceGroup), } kgm.legacySvcStorage = endpoint.NewStorageEndpoint( kv.NewEtcdKVBase(kgm.etcdClient, kgm.legacySvcRootPath), nil) @@ -257,20 +250,69 @@ func NewKeyspaceGroupManager( // Initialize this KeyspaceGroupManager func (kgm *KeyspaceGroupManager) Initialize() error { - // Load the initial keyspace group assignment from storage with time limit - done := make(chan struct{}, 1) - ctx, cancel := context.WithCancel(kgm.ctx) - go kgm.checkInitProgress(ctx, cancel, done) - watchStartRevision, defaultKGConfigured, err := kgm.initAssignment(ctx) - done <- struct{}{} - if err != nil { + rootPath := kgm.legacySvcRootPath + startKey := strings.Join([]string{rootPath, endpoint.KeyspaceGroupIDPath(mcsutils.DefaultKeyspaceGroupID)}, "/") + endKey := strings.Join( + []string{rootPath, clientv3.GetPrefixRangeEnd(endpoint.KeyspaceGroupIDPrefix())}, "/") + + defaultKGConfigured := false + putFn := func(kv *mvccpb.KeyValue) error { + group := &endpoint.KeyspaceGroup{} + if err := json.Unmarshal(kv.Value, group); err != nil { + return errs.ErrJSONUnmarshal.Wrap(err).FastGenWithCause() + } + kgm.updateKeyspaceGroup(group) + if group.ID == mcsutils.DefaultKeyspaceGroupID { + defaultKGConfigured = true + } + return nil + } + deleteFn := func(kv *mvccpb.KeyValue) error { + groupID, err := endpoint.ExtractKeyspaceGroupIDFromPath(string(kv.Key)) + if err != nil { + return err + } + kgm.deleteKeyspaceGroup(groupID) + return nil + } + postEventFn := func() error { + // Retry the groups that are not initialized successfully before. + for id, group := range kgm.groupUpdateRetryList { + delete(kgm.groupUpdateRetryList, id) + kgm.updateKeyspaceGroup(group) + } + return nil + } + kgm.groupWatcher = etcdutil.NewLoopWatcher( + kgm.ctx, + &kgm.wg, + kgm.etcdClient, + "keyspace-watcher", + startKey, + putFn, + deleteFn, + postEventFn, + clientv3.WithRange(endKey), + ) + if kgm.loadKeyspaceGroupsTimeout > 0 { + kgm.groupWatcher.SetLoadTimeout(kgm.loadKeyspaceGroupsTimeout) + } + if kgm.loadFromEtcdMaxRetryTimes > 0 { + kgm.groupWatcher.SetLoadRetryTimes(kgm.loadFromEtcdMaxRetryTimes) + } + if kgm.loadKeyspaceGroupsBatchSize > 0 { + kgm.groupWatcher.SetLoadBatchSize(kgm.loadKeyspaceGroupsBatchSize) + } + kgm.wg.Add(1) + go kgm.groupWatcher.StartWatchLoop() + + if err := kgm.groupWatcher.WaitLoad(); err != nil { log.Error("failed to initialize keyspace group manager", errs.ZapError(err)) // We might have partially loaded/initialized the keyspace groups. 
Close the manager to clean up. kgm.Close() - return err + return errs.ErrLoadKeyspaceGroupsTerminated } - // Initialize the default keyspace group if it isn't configured in the storage. if !defaultKGConfigured { log.Info("initializing default keyspace group") group := &endpoint.KeyspaceGroup{ @@ -280,12 +322,6 @@ func (kgm *KeyspaceGroupManager) Initialize() error { } kgm.updateKeyspaceGroup(group) } - - // Watch/apply keyspace group membership/distribution meta changes dynamically. - kgm.wg.Add(1) - go kgm.startKeyspaceGroupsMetaWatchLoop(watchStartRevision) - - log.Info("keyspace group manager initialized") return nil } @@ -305,222 +341,6 @@ func (kgm *KeyspaceGroupManager) Close() { log.Info("keyspace group manager closed") } -func (kgm *KeyspaceGroupManager) checkInitProgress(ctx context.Context, cancel context.CancelFunc, done chan struct{}) { - defer logutil.LogPanic() - - select { - case <-done: - return - case <-time.After(kgm.loadKeyspaceGroupsTimeout): - log.Error("failed to initialize keyspace group manager", - zap.Any("timeout-setting", kgm.loadKeyspaceGroupsTimeout), - errs.ZapError(errs.ErrLoadKeyspaceGroupsTimeout)) - cancel() - case <-ctx.Done(): - } - <-done -} - -// initAssignment loads initial keyspace group assignment from storage and initialize the group manager. -// Return watchStartRevision, the start revision for watching keyspace group membership/distribution change. -func (kgm *KeyspaceGroupManager) initAssignment( - ctx context.Context, -) (watchStartRevision int64, defaultKGConfigured bool, err error) { - var ( - groups []*endpoint.KeyspaceGroup - more bool - keyspaceGroupsLoaded uint32 - revision int64 - ) - - // Load all keyspace groups from etcd and apply the ones assigned to this tso service. - for { - revision, groups, more, err = kgm.loadKeyspaceGroups(ctx, keyspaceGroupsLoaded, kgm.loadKeyspaceGroupsBatchSize) - if err != nil { - return - } - - keyspaceGroupsLoaded += uint32(len(groups)) - - if watchStartRevision == 0 || revision < watchStartRevision { - watchStartRevision = revision - } - - // Update the keyspace groups - for _, group := range groups { - select { - case <-ctx.Done(): - err = errs.ErrLoadKeyspaceGroupsTerminated - return - default: - } - - if group.ID == mcsutils.DefaultKeyspaceGroupID { - defaultKGConfigured = true - } - - kgm.updateKeyspaceGroup(group) - } - - if !more { - break - } - } - - log.Info("loaded keyspace groups", zap.Uint32("keyspace-groups-loaded", keyspaceGroupsLoaded)) - return -} - -// loadKeyspaceGroups loads keyspace groups from the start ID with limit. -// If limit is 0, it will load all keyspace groups from the start ID. -func (kgm *KeyspaceGroupManager) loadKeyspaceGroups( - ctx context.Context, startID uint32, limit int64, -) (revision int64, ksgs []*endpoint.KeyspaceGroup, more bool, err error) { - rootPath := kgm.legacySvcRootPath - startKey := strings.Join([]string{rootPath, endpoint.KeyspaceGroupIDPath(startID)}, "/") - endKey := strings.Join( - []string{rootPath, clientv3.GetPrefixRangeEnd(endpoint.KeyspaceGroupIDPrefix())}, "/") - opOption := []clientv3.OpOption{clientv3.WithRange(endKey), clientv3.WithLimit(limit)} - - var ( - i int - resp *clientv3.GetResponse - ) - for ; i < kgm.loadFromEtcdMaxRetryTimes; i++ { - resp, err = etcdutil.EtcdKVGet(kgm.etcdClient, startKey, opOption...) 
- - failpoint.Inject("delayLoadKeyspaceGroups", func(val failpoint.Value) { - if sleepIntervalSeconds, ok := val.(int); ok && sleepIntervalSeconds > 0 { - time.Sleep(time.Duration(sleepIntervalSeconds) * time.Second) - } - }) - - failpoint.Inject("loadKeyspaceGroupsTemporaryFail", func(val failpoint.Value) { - if maxFailTimes, ok := val.(int); ok && i < maxFailTimes { - err = errors.New("fail to read from etcd") - failpoint.Continue() - } - }) - - if err == nil && resp != nil { - break - } - - select { - case <-ctx.Done(): - return 0, []*endpoint.KeyspaceGroup{}, false, errs.ErrLoadKeyspaceGroupsTerminated - case <-time.After(defaultLoadFromEtcdRetryInterval): - } - } - - if i == kgm.loadFromEtcdMaxRetryTimes { - return 0, []*endpoint.KeyspaceGroup{}, false, errs.ErrLoadKeyspaceGroupsRetryExhausted.FastGenByArgs(err) - } - - kgs := make([]*endpoint.KeyspaceGroup, 0, len(resp.Kvs)) - for _, item := range resp.Kvs { - kg := &endpoint.KeyspaceGroup{} - if err = json.Unmarshal(item.Value, kg); err != nil { - return 0, nil, false, err - } - kgs = append(kgs, kg) - } - - if resp.Header != nil { - revision = resp.Header.Revision + 1 - } - - return revision, kgs, resp.More, nil -} - -// startKeyspaceGroupsMetaWatchLoop repeatedly watches any change in keyspace group membership/distribution -// and apply the change dynamically. -func (kgm *KeyspaceGroupManager) startKeyspaceGroupsMetaWatchLoop(revision int64) { - defer logutil.LogPanic() - defer kgm.wg.Done() - - // Repeatedly watch/apply keyspace group membership/distribution changes until the context is canceled. - for { - select { - case <-kgm.ctx.Done(): - return - default: - } - - nextRevision, err := kgm.watchKeyspaceGroupsMetaChange(revision) - if err != nil { - log.Error("watcher canceled unexpectedly and a new watcher will start after a while", - zap.Int64("next-revision", nextRevision), - zap.Time("retry-at", time.Now().Add(watchEtcdChangeRetryInterval)), - zap.Error(err)) - time.Sleep(watchEtcdChangeRetryInterval) - } - } -} - -// watchKeyspaceGroupsMetaChange watches any change in keyspace group membership/distribution -// and apply the change dynamically. 
-func (kgm *KeyspaceGroupManager) watchKeyspaceGroupsMetaChange(revision int64) (int64, error) { - watcher := clientv3.NewWatcher(kgm.etcdClient) - defer watcher.Close() - - ksgPrefix := strings.Join([]string{kgm.legacySvcRootPath, endpoint.KeyspaceGroupIDPrefix()}, "/") - log.Info("start to watch keyspace group meta change", zap.Int64("revision", revision), zap.String("prefix", ksgPrefix)) - - for { - watchChan := watcher.Watch(kgm.ctx, ksgPrefix, clientv3.WithPrefix(), clientv3.WithRev(revision)) - for wresp := range watchChan { - if wresp.CompactRevision != 0 { - log.Warn("Required revision has been compacted, the watcher will watch again with the compact revision", - zap.Int64("required-revision", revision), - zap.Int64("compact-revision", wresp.CompactRevision)) - revision = wresp.CompactRevision - break - } - if wresp.Err() != nil { - log.Error("watch is canceled or closed", - zap.Int64("required-revision", revision), - errs.ZapError(errs.ErrEtcdWatcherCancel, wresp.Err())) - return revision, wresp.Err() - } - for _, event := range wresp.Events { - groupID, err := endpoint.ExtractKeyspaceGroupIDFromPath(string(event.Kv.Key)) - if err != nil { - log.Warn("failed to extract keyspace group ID from the key path", - zap.String("key-path", string(event.Kv.Key)), zap.Error(err)) - continue - } - - switch event.Type { - case clientv3.EventTypePut: - group := &endpoint.KeyspaceGroup{} - if err := json.Unmarshal(event.Kv.Value, group); err != nil { - log.Warn("failed to unmarshal keyspace group", - zap.Uint32("keyspace-group-id", groupID), - zap.Error(errs.ErrJSONUnmarshal.Wrap(err).FastGenWithCause())) - break - } - kgm.updateKeyspaceGroup(group) - case clientv3.EventTypeDelete: - kgm.deleteKeyspaceGroup(groupID) - } - } - // Retry the groups that are not initialized successfully before. - for id, group := range kgm.groupUpdateRetryList { - delete(kgm.groupUpdateRetryList, id) - kgm.updateKeyspaceGroup(group) - } - revision = wresp.Header.Revision + 1 - } - - select { - case <-kgm.ctx.Done(): - return revision, nil - default: - } - } -} - func (kgm *KeyspaceGroupManager) isAssignedToMe(group *endpoint.KeyspaceGroup) bool { for _, member := range group.Members { if member.Address == kgm.tsoServiceID.ServiceAddr { diff --git a/pkg/tso/keyspace_group_manager_test.go b/pkg/tso/keyspace_group_manager_test.go index d6e8cb4b046..6b8beb3b0ae 100644 --- a/pkg/tso/keyspace_group_manager_test.go +++ b/pkg/tso/keyspace_group_manager_test.go @@ -106,8 +106,6 @@ func (suite *keyspaceGroupManagerTestSuite) TestNewKeyspaceGroupManager() { re.Equal(legacySvcRootPath, kgm.legacySvcRootPath) re.Equal(tsoSvcRootPath, kgm.tsoSvcRootPath) re.Equal(suite.cfg, kgm.cfg) - re.Equal(defaultLoadKeyspaceGroupsBatchSize, kgm.loadKeyspaceGroupsBatchSize) - re.Equal(defaultLoadKeyspaceGroupsTimeout, kgm.loadKeyspaceGroupsTimeout) am, err := kgm.GetAllocatorManager(mcsutils.DefaultKeyspaceGroupID) re.NoError(err) @@ -179,14 +177,14 @@ func (suite *keyspaceGroupManagerTestSuite) TestLoadKeyspaceGroupsTimeout() { suite.ctx, suite.etcdClient, true, mgr.legacySvcRootPath, mgr.tsoServiceID.ServiceAddr, uint32(0), []uint32{0}) - // Set the timeout to 1 second and inject the delayLoadKeyspaceGroups to return 3 seconds to let + // Set the timeout to 1 second and inject the delayLoad to return 3 seconds to let // the loading sleep 3 seconds. 
mgr.loadKeyspaceGroupsTimeout = time.Second - re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/tso/delayLoadKeyspaceGroups", "return(3)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/etcdutil/delayLoad", "return(3)")) err := mgr.Initialize() // If loading keyspace groups timeout, the initialization should fail with ErrLoadKeyspaceGroupsTerminated. re.Equal(errs.ErrLoadKeyspaceGroupsTerminated, err) - re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/tso/delayLoadKeyspaceGroups")) + re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/etcdutil/delayLoad")) } // TestLoadKeyspaceGroupsSucceedWithTempFailures tests the initialization should succeed when there are temporary @@ -202,13 +200,13 @@ func (suite *keyspaceGroupManagerTestSuite) TestLoadKeyspaceGroupsSucceedWithTem suite.ctx, suite.etcdClient, true, mgr.legacySvcRootPath, mgr.tsoServiceID.ServiceAddr, uint32(0), []uint32{0}) - // Set the max retry times to 3 and inject the loadKeyspaceGroupsTemporaryFail to return 2 to let + // Set the max retry times to 3 and inject the loadTemporaryFail to return 2 to let // loading from etcd fail 2 times but the whole initialization still succeeds. mgr.loadFromEtcdMaxRetryTimes = 3 - re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/tso/loadKeyspaceGroupsTemporaryFail", "return(2)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/etcdutil/loadTemporaryFail", "return(2)")) err := mgr.Initialize() re.NoError(err) - re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/tso/loadKeyspaceGroupsTemporaryFail")) + re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/etcdutil/loadTemporaryFail")) } // TestLoadKeyspaceGroupsFailed tests the initialization should fail when there are too many failures @@ -224,13 +222,13 @@ func (suite *keyspaceGroupManagerTestSuite) TestLoadKeyspaceGroupsFailed() { suite.ctx, suite.etcdClient, true, mgr.legacySvcRootPath, mgr.tsoServiceID.ServiceAddr, uint32(0), []uint32{0}) - // Set the max retry times to 3 and inject the loadKeyspaceGroupsTemporaryFail to return 3 to let + // Set the max retry times to 3 and inject the loadTemporaryFail to return 3 to let // loading from etcd fail 3 times which should cause the whole initialization to fail. 
mgr.loadFromEtcdMaxRetryTimes = 3 - re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/tso/loadKeyspaceGroupsTemporaryFail", "return(3)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/etcdutil/loadTemporaryFail", "return(3)")) err := mgr.Initialize() re.Error(err) - re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/tso/loadKeyspaceGroupsTemporaryFail")) + re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/etcdutil/loadTemporaryFail")) } // TestWatchAndDynamicallyApplyChanges tests the keyspace group manager watch and dynamically apply diff --git a/pkg/utils/etcdutil/etcdutil.go b/pkg/utils/etcdutil/etcdutil.go index b65f8b901a4..32f32bd6087 100644 --- a/pkg/utils/etcdutil/etcdutil.go +++ b/pkg/utils/etcdutil/etcdutil.go @@ -20,6 +20,7 @@ import ( "math/rand" "net/http" "net/url" + "sync" "time" "github.com/gogo/protobuf/proto" @@ -27,9 +28,11 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/utils/logutil" "github.com/tikv/pd/pkg/utils/typeutil" "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/etcdserver" + "go.etcd.io/etcd/mvcc/mvccpb" "go.etcd.io/etcd/pkg/types" "go.uber.org/zap" ) @@ -350,3 +353,268 @@ func InitOrGetClusterID(c *clientv3.Client, key string) (uint64, error) { return typeutil.BytesToUint64(response.Kvs[0].Value) } + +const ( + defaultLoadDataFromEtcdTimeout = 30 * time.Second + defaultLoadFromEtcdRetryInterval = 200 * time.Millisecond + defaultLoadFromEtcdRetryTimes = int(defaultLoadDataFromEtcdTimeout / defaultLoadFromEtcdRetryInterval) + defaultLoadBatchSize = 400 + defaultWatchChangeRetryInterval = 1 * time.Second +) + +// LoopWatcher loads data from etcd and sets a watcher for it. +type LoopWatcher struct { + ctx context.Context + wg *sync.WaitGroup + name string + client *clientv3.Client + + // key is the etcd key to watch. + key string + // opts is used to set etcd options. + opts []clientv3.OpOption + + // forceLoadCh is used to force loading data from etcd. + forceLoadCh chan struct{} + // isLoadedCh is used to notify that the data has been loaded from etcd first time. + isLoadedCh chan error + + // putFn is used to handle the put event. + putFn func(*mvccpb.KeyValue) error + // deleteFn is used to handle the delete event. + deleteFn func(*mvccpb.KeyValue) error + // postEventFn is used to call after handling all events. + postEventFn func() error + + // loadTimeout is used to set the timeout for loading data from etcd. + loadTimeout time.Duration + // loadRetryTimes is used to set the retry times for loading data from etcd. + loadRetryTimes int + // loadBatchSize is used to set the batch size for loading data from etcd. + loadBatchSize int64 + // watchChangeRetryInterval is used to set the retry interval for watching etcd change. + watchChangeRetryInterval time.Duration + // updateClientCh is used to update the etcd client. + // It's only used for testing. + updateClientCh chan *clientv3.Client +} + +// NewLoopWatcher creates a new LoopWatcher. 
+func NewLoopWatcher(ctx context.Context, wg *sync.WaitGroup, client *clientv3.Client, name, key string, + putFn, deleteFn func(*mvccpb.KeyValue) error, postEventFn func() error, opts ...clientv3.OpOption) *LoopWatcher { + return &LoopWatcher{ + ctx: ctx, + client: client, + name: name, + key: key, + wg: wg, + forceLoadCh: make(chan struct{}, 1), + isLoadedCh: make(chan error, 1), + updateClientCh: make(chan *clientv3.Client, 1), + putFn: putFn, + deleteFn: deleteFn, + postEventFn: postEventFn, + opts: opts, + loadTimeout: defaultLoadDataFromEtcdTimeout, + loadRetryTimes: defaultLoadFromEtcdRetryTimes, + loadBatchSize: defaultLoadBatchSize, + watchChangeRetryInterval: defaultWatchChangeRetryInterval, + } +} + +// StartWatchLoop starts a loop to watch the key. +func (lw *LoopWatcher) StartWatchLoop() { + defer logutil.LogPanic() + defer lw.wg.Done() + + ctx, cancel := context.WithTimeout(lw.ctx, lw.loadTimeout) + defer cancel() + watchStartRevision := lw.initFromEtcd(ctx) + + log.Info("start to watch loop", zap.String("name", lw.name), zap.String("key", lw.key)) + for { + select { + case <-lw.ctx.Done(): + log.Info("server is closed, exit watch loop", zap.String("name", lw.name), zap.String("key", lw.key)) + return + default: + } + nextRevision, err := lw.watch(lw.ctx, watchStartRevision) + if err != nil { + log.Error("watcher canceled unexpectedly and a new watcher will start after a while for watch loop", + zap.String("name", lw.name), + zap.String("key", lw.key), + zap.Int64("next-revision", nextRevision), + zap.Time("retry-at", time.Now().Add(lw.watchChangeRetryInterval)), + zap.Error(err)) + watchStartRevision = nextRevision + time.Sleep(lw.watchChangeRetryInterval) + failpoint.Inject("updateClient", func() { + lw.client = <-lw.updateClientCh + }) + } + } +} + +func (lw *LoopWatcher) initFromEtcd(ctx context.Context) int64 { + var ( + watchStartRevision int64 + err error + ) + ticker := time.NewTicker(defaultLoadFromEtcdRetryInterval) + defer ticker.Stop() + + for i := 0; i < lw.loadRetryTimes; i++ { + failpoint.Inject("loadTemporaryFail", func(val failpoint.Value) { + if maxFailTimes, ok := val.(int); ok && i < maxFailTimes { + err = errors.New("fail to read from etcd") + failpoint.Continue() + } + }) + failpoint.Inject("delayLoad", func(val failpoint.Value) { + if sleepIntervalSeconds, ok := val.(int); ok && sleepIntervalSeconds > 0 { + time.Sleep(time.Duration(sleepIntervalSeconds) * time.Second) + } + }) + watchStartRevision, err = lw.load(ctx) + if err == nil { + break + } + select { + case <-ctx.Done(): + lw.isLoadedCh <- errors.Errorf("ctx is done before load data from etcd") + return watchStartRevision + case <-ticker.C: + } + } + if err != nil { + log.Warn("meet error when loading in watch loop", zap.String("name", lw.name), zap.String("key", lw.key), zap.Error(err)) + } + lw.isLoadedCh <- err + return watchStartRevision +} + +func (lw *LoopWatcher) watch(ctx context.Context, revision int64) (nextRevision int64, err error) { + watcher := clientv3.NewWatcher(lw.client) + defer watcher.Close() + + for { + WatchChan: + opts := append(lw.opts, clientv3.WithRev(revision)) + watchChan := watcher.Watch(ctx, lw.key, opts...) 
+ select {
+ case <-ctx.Done():
+ return revision, nil
+ case <-lw.forceLoadCh:
+ revision, err = lw.load(ctx)
+ if err != nil {
+ log.Warn("force load key failed in watch loop", zap.String("name", lw.name),
+ zap.String("key", lw.key), zap.Error(err))
+ }
+ goto WatchChan
+ case wresp := <-watchChan:
+ if wresp.CompactRevision != 0 {
+ log.Warn("required revision has been compacted, use the compact revision in watch loop",
+ zap.Int64("required-revision", revision),
+ zap.Int64("compact-revision", wresp.CompactRevision))
+ revision = wresp.CompactRevision
+ goto WatchChan
+ } else if wresp.Err() != nil { // wresp.Err() is also non-nil when the revision has been compacted, so that case is handled first above
+ log.Error("watcher is canceled in watch loop",
+ zap.Int64("revision", revision),
+ errs.ZapError(errs.ErrEtcdWatcherCancel, wresp.Err()))
+ return revision, wresp.Err()
+ }
+ for _, event := range wresp.Events {
+ switch event.Type {
+ case clientv3.EventTypePut:
+ if err := lw.putFn(event.Kv); err != nil {
+ log.Error("put failed in watch loop", zap.String("name", lw.name),
+ zap.String("key", lw.key), zap.Error(err))
+ }
+ case clientv3.EventTypeDelete:
+ if err := lw.deleteFn(event.Kv); err != nil {
+ log.Error("delete failed in watch loop", zap.String("name", lw.name),
+ zap.String("key", lw.key), zap.Error(err))
+ }
+ }
+ }
+ if err := lw.postEventFn(); err != nil {
+ log.Error("run post event failed in watch loop", zap.String("name", lw.name),
+ zap.String("key", lw.key), zap.Error(err))
+ }
+ revision = wresp.Header.Revision + 1
+ }
+ }
+}
+
+func (lw *LoopWatcher) load(ctx context.Context) (nextRevision int64, err error) {
+ ctx, cancel := context.WithTimeout(ctx, DefaultRequestTimeout)
+ defer cancel()
+ startKey := lw.key
+ // A limit of 0 means no limit.
+ // Otherwise, request one extra key so the last key returned can serve as the start key of the next batch.
+ limit := lw.loadBatchSize
+ if limit != 0 {
+ limit++
+ }
+ for {
+ // Sort by key so that the extra key marks where the next batch starts. This adds no overhead,
+ // because SortByKey with SortAscend is already etcd's default order.
+ opts := append(lw.opts, clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend), clientv3.WithLimit(limit))
+ resp, err := clientv3.NewKV(lw.client).Get(ctx, startKey, opts...)
+ if err != nil {
+ log.Error("load failed in watch loop", zap.String("name", lw.name),
+ zap.String("key", lw.key), zap.Error(err))
+ return 0, err
+ }
+ for i, item := range resp.Kvs {
+ if resp.More && i == len(resp.Kvs)-1 {
+ // The last key is the start key of the next batch.
+ // Skip it here so it is not handled twice; the next load will return it again.
+ startKey = string(item.Key)
+ continue
+ }
+ err = lw.putFn(item)
+ if err != nil {
+ log.Error("put failed in watch loop when loading", zap.String("name", lw.name), zap.String("key", lw.key), zap.Error(err))
+ }
+ }
+ if !resp.More {
+ if err := lw.postEventFn(); err != nil {
+ log.Error("run post event failed in watch loop", zap.String("name", lw.name),
+ zap.String("key", lw.key), zap.Error(err))
+ }
+ log.Info("load finished in watch loop", zap.String("name", lw.name), zap.String("key", lw.key))
+ return resp.Header.Revision + 1, err
+ }
+ }
+}
+
+// ForceLoad signals the watch loop to reload the key from etcd.
+func (lw *LoopWatcher) ForceLoad() {
+ select {
+ case lw.forceLoadCh <- struct{}{}:
+ default:
+ }
+}
+
+// WaitLoad blocks until the initial load from etcd has finished and returns its error, if any.
+func (lw *LoopWatcher) WaitLoad() error {
+ return <-lw.isLoadedCh
+}
+
+// SetLoadRetryTimes sets the retry times when loading data from etcd.
+func (lw *LoopWatcher) SetLoadRetryTimes(times int) { + lw.loadRetryTimes = times +} + +// SetLoadTimeout sets the timeout when loading data from etcd. +func (lw *LoopWatcher) SetLoadTimeout(timeout time.Duration) { + lw.loadTimeout = timeout +} + +// SetLoadBatchSize sets the batch size when loading data from etcd. +func (lw *LoopWatcher) SetLoadBatchSize(size int64) { + lw.loadBatchSize = size +} diff --git a/pkg/utils/etcdutil/etcdutil_test.go b/pkg/utils/etcdutil/etcdutil_test.go index e8aa901bee0..6bf63db79c9 100644 --- a/pkg/utils/etcdutil/etcdutil_test.go +++ b/pkg/utils/etcdutil/etcdutil_test.go @@ -21,19 +21,28 @@ import ( "io" "net" "strings" + "sync" "sync/atomic" "testing" "time" "github.com/pingcap/failpoint" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/utils/tempurl" + "github.com/tikv/pd/pkg/utils/testutil" "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/embed" "go.etcd.io/etcd/etcdserver/etcdserverpb" + "go.etcd.io/etcd/mvcc/mvccpb" "go.etcd.io/etcd/pkg/types" + "go.uber.org/goleak" ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m, testutil.LeakOptions...) +} + func TestMemberHelpers(t *testing.T) { re := require.New(t) cfg1 := NewTestSingleConfig(t) @@ -47,6 +56,9 @@ func TestMemberHelpers(t *testing.T) { client1, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep1}, }) + defer func() { + client1.Close() + }() re.NoError(err) <-etcd1.Server.ReadyNotify() @@ -66,6 +78,9 @@ func TestMemberHelpers(t *testing.T) { client2, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep2}, }) + defer func() { + client2.Close() + }() re.NoError(err) checkMembers(re, client2, []*embed.Etcd{etcd1, etcd2}) @@ -98,6 +113,9 @@ func TestEtcdKVGet(t *testing.T) { client, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep}, }) + defer func() { + client.Close() + }() re.NoError(err) <-etcd.Server.ReadyNotify() @@ -148,6 +166,9 @@ func TestEtcdKVPutWithTTL(t *testing.T) { client, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep}, }) + defer func() { + client.Close() + }() re.NoError(err) <-etcd.Server.ReadyNotify() @@ -188,6 +209,9 @@ func TestInitClusterID(t *testing.T) { client, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep}, }) + defer func() { + client.Close() + }() re.NoError(err) <-etcd.Server.ReadyNotify() @@ -214,6 +238,9 @@ func TestEtcdClientSync(t *testing.T) { // Start a etcd server. cfg1 := NewTestSingleConfig(t) etcd1, err := embed.StartEtcd(cfg1) + defer func() { + etcd1.Close() + }() re.NoError(err) // Create a etcd client with etcd1 as endpoint. @@ -221,6 +248,9 @@ func TestEtcdClientSync(t *testing.T) { urls, err := types.NewURLs([]string{ep1}) re.NoError(err) client1, err := createEtcdClientWithMultiEndpoint(nil, urls) + defer func() { + client1.Close() + }() re.NoError(err) <-etcd1.Server.ReadyNotify() @@ -265,6 +295,9 @@ func TestEtcdScaleInAndOutWithoutMultiPoint(t *testing.T) { // Start a etcd server. 
cfg1 := NewTestSingleConfig(t) etcd1, err := embed.StartEtcd(cfg1) + defer func() { + etcd1.Close() + }() re.NoError(err) ep1 := cfg1.LCUrls[0].String() <-etcd1.Server.ReadyNotify() @@ -273,12 +306,21 @@ func TestEtcdScaleInAndOutWithoutMultiPoint(t *testing.T) { urls, err := types.NewURLs([]string{ep1}) re.NoError(err) client1, err := createEtcdClient(nil, urls[0]) // execute member change operation with this client + defer func() { + client1.Close() + }() re.NoError(err) client2, err := createEtcdClient(nil, urls[0]) // check member change with this client + defer func() { + client2.Close() + }() re.NoError(err) // Add a new member and check members etcd2 := checkAddEtcdMember(t, cfg1, client1) + defer func() { + etcd2.Close() + }() checkMembers(re, client2, []*embed.Etcd{etcd1, etcd2}) // scale in etcd1 @@ -292,6 +334,9 @@ func checkEtcdWithHangLeader(t *testing.T) error { // Start a etcd server. cfg1 := NewTestSingleConfig(t) etcd1, err := embed.StartEtcd(cfg1) + defer func() { + etcd1.Close() + }() re.NoError(err) ep1 := cfg1.LCUrls[0].String() <-etcd1.Server.ReadyNotify() @@ -305,6 +350,9 @@ func checkEtcdWithHangLeader(t *testing.T) error { urls, err := types.NewURLs([]string{proxyAddr}) re.NoError(err) client1, err := createEtcdClientWithMultiEndpoint(nil, urls) + defer func() { + client1.Close() + }() re.NoError(err) // Add a new member and set the client endpoints to etcd1 and etcd2. @@ -408,3 +456,288 @@ func ioCopy(dst io.Writer, src io.Reader, enableDiscard *atomic.Bool) (err error } return err } + +type loopWatcherTestSuite struct { + suite.Suite + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + cleans []func() + etcd *embed.Etcd + client *clientv3.Client + config *embed.Config +} + +func TestLoopWatcherTestSuite(t *testing.T) { + suite.Run(t, new(loopWatcherTestSuite)) +} + +func (suite *loopWatcherTestSuite) SetupSuite() { + t := suite.T() + suite.ctx, suite.cancel = context.WithCancel(context.Background()) + suite.cleans = make([]func(), 0) + // Start a etcd server and create a client with etcd1 as endpoint. 
+ suite.config = NewTestSingleConfig(t) + suite.startEtcd() + ep1 := suite.config.LCUrls[0].String() + urls, err := types.NewURLs([]string{ep1}) + suite.NoError(err) + suite.client, err = createEtcdClient(nil, urls[0]) + suite.NoError(err) + suite.cleans = append(suite.cleans, func() { + suite.client.Close() + }) +} + +func (suite *loopWatcherTestSuite) TearDownSuite() { + suite.cancel() + suite.wg.Wait() + for _, clean := range suite.cleans { + clean() + } +} + +func (suite *loopWatcherTestSuite) TestLoadWithoutKey() { + cache := struct { + sync.RWMutex + data map[string]struct{} + }{ + data: make(map[string]struct{}), + } + watcher := NewLoopWatcher( + suite.ctx, + &suite.wg, + suite.client, + "test", + "TestLoadWithoutKey", + func(kv *mvccpb.KeyValue) error { + cache.Lock() + defer cache.Unlock() + cache.data[string(kv.Key)] = struct{}{} + return nil + }, + func(kv *mvccpb.KeyValue) error { return nil }, + func() error { return nil }, + ) + + suite.wg.Add(1) + go watcher.StartWatchLoop() + err := watcher.WaitLoad() + suite.NoError(err) // although no key, watcher returns no error + cache.RLock() + defer cache.RUnlock() + suite.Len(cache.data, 0) +} + +func (suite *loopWatcherTestSuite) TestCallBack() { + cache := struct { + sync.RWMutex + data map[string]struct{} + }{ + data: make(map[string]struct{}), + } + result := make([]string, 0) + watcher := NewLoopWatcher( + suite.ctx, + &suite.wg, + suite.client, + "test", + "TestCallBack", + func(kv *mvccpb.KeyValue) error { + result = append(result, string(kv.Key)) + return nil + }, + func(kv *mvccpb.KeyValue) error { + cache.Lock() + defer cache.Unlock() + delete(cache.data, string(kv.Key)) + return nil + }, + func() error { + cache.Lock() + defer cache.Unlock() + for _, r := range result { + cache.data[r] = struct{}{} + } + result = result[:0] + return nil + }, + clientv3.WithPrefix(), + ) + + suite.wg.Add(1) + go watcher.StartWatchLoop() + err := watcher.WaitLoad() + suite.NoError(err) + + // put 10 keys + for i := 0; i < 10; i++ { + suite.put(fmt.Sprintf("TestCallBack%d", i), "") + } + time.Sleep(time.Second) + cache.RLock() + suite.Len(cache.data, 10) + cache.RUnlock() + + // delete 10 keys + for i := 0; i < 10; i++ { + key := fmt.Sprintf("TestCallBack%d", i) + _, err = suite.client.Delete(suite.ctx, key) + suite.NoError(err) + } + time.Sleep(time.Second) + cache.RLock() + suite.Empty(cache.data) + cache.RUnlock() +} + +func (suite *loopWatcherTestSuite) TestWatcherLoadLimit() { + for count := 1; count < 10; count++ { + for limit := 0; limit < 10; limit++ { + ctx, cancel := context.WithCancel(suite.ctx) + for i := 0; i < count; i++ { + suite.put(fmt.Sprintf("TestWatcherLoadLimit%d", i), "") + } + cache := struct { + sync.RWMutex + data []string + }{ + data: make([]string, 0), + } + watcher := NewLoopWatcher( + ctx, + &suite.wg, + suite.client, + "test", + "TestWatcherLoadLimit", + func(kv *mvccpb.KeyValue) error { + cache.Lock() + defer cache.Unlock() + cache.data = append(cache.data, string(kv.Key)) + return nil + }, + func(kv *mvccpb.KeyValue) error { + return nil + }, + func() error { + return nil + }, + clientv3.WithPrefix(), + ) + suite.wg.Add(1) + go watcher.StartWatchLoop() + err := watcher.WaitLoad() + suite.NoError(err) + cache.RLock() + suite.Len(cache.data, count) + cache.RUnlock() + cancel() + } + } +} + +func (suite *loopWatcherTestSuite) TestWatcherBreak() { + cache := struct { + sync.RWMutex + data string + }{} + checkCache := func(expect string) { + testutil.Eventually(suite.Require(), func() bool { + cache.RLock() + 
defer cache.RUnlock() + return cache.data == expect + }, testutil.WithWaitFor(time.Second)) + } + + watcher := NewLoopWatcher( + suite.ctx, + &suite.wg, + suite.client, + "test", + "TestWatcherBreak", + func(kv *mvccpb.KeyValue) error { + if string(kv.Key) == "TestWatcherBreak" { + cache.Lock() + defer cache.Unlock() + cache.data = string(kv.Value) + } + return nil + }, + func(kv *mvccpb.KeyValue) error { return nil }, + func() error { return nil }, + ) + watcher.watchChangeRetryInterval = 100 * time.Millisecond + + suite.wg.Add(1) + go watcher.StartWatchLoop() + err := watcher.WaitLoad() + suite.NoError(err) + checkCache("") + + // we use close client and update client in failpoint to simulate the network error and recover + failpoint.Enable("github.com/tikv/pd/pkg/utils/etcdutil/updateClient", "return(true)") + + // Case1: restart the etcd server + suite.etcd.Close() + suite.startEtcd() + suite.put("TestWatcherBreak", "1") + checkCache("1") + + // Case2: close the etcd client and put a new value after watcher restarts + suite.client.Close() + suite.client, err = createEtcdClient(nil, suite.config.LCUrls[0]) + suite.NoError(err) + watcher.updateClientCh <- suite.client + suite.put("TestWatcherBreak", "2") + checkCache("2") + + // Case3: close the etcd client and put a new value before watcher restarts + suite.client.Close() + suite.client, err = createEtcdClient(nil, suite.config.LCUrls[0]) + suite.NoError(err) + suite.put("TestWatcherBreak", "3") + watcher.updateClientCh <- suite.client + checkCache("3") + + // Case4: close the etcd client and put a new value with compact + suite.client.Close() + suite.client, err = createEtcdClient(nil, suite.config.LCUrls[0]) + suite.NoError(err) + suite.put("TestWatcherBreak", "4") + resp, err := EtcdKVGet(suite.client, "TestWatcherBreak") + suite.NoError(err) + revision := resp.Header.Revision + resp2, err := suite.etcd.Server.Compact(suite.ctx, &etcdserverpb.CompactionRequest{Revision: revision}) + suite.NoError(err) + suite.Equal(revision, resp2.Header.Revision) + watcher.updateClientCh <- suite.client + checkCache("4") + + // Case5: there is an error data in cache + cache.Lock() + cache.data = "error" + cache.Unlock() + watcher.ForceLoad() + checkCache("4") + + failpoint.Disable("github.com/tikv/pd/pkg/utils/etcdutil/updateClient") +} + +func (suite *loopWatcherTestSuite) startEtcd() { + etcd1, err := embed.StartEtcd(suite.config) + suite.NoError(err) + suite.etcd = etcd1 + <-etcd1.Server.ReadyNotify() + suite.cleans = append(suite.cleans, func() { + suite.etcd.Close() + }) +} + +func (suite *loopWatcherTestSuite) put(key, value string) { + kv := clientv3.NewKV(suite.client) + _, err := kv.Put(suite.ctx, key, value) + suite.NoError(err) + resp, err := kv.Get(suite.ctx, key) + suite.NoError(err) + suite.Equal(value, string(resp.Kvs[0].Value)) +} diff --git a/pkg/utils/tsoutil/tso_dispatcher.go b/pkg/utils/tsoutil/tso_dispatcher.go index 2bf903a0f46..187fd7a527b 100644 --- a/pkg/utils/tsoutil/tso_dispatcher.go +++ b/pkg/utils/tsoutil/tso_dispatcher.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/log" "github.com/prometheus/client_golang/prometheus" "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/utils/etcdutil" "github.com/tikv/pd/pkg/utils/logutil" "go.uber.org/zap" "google.golang.org/grpc" @@ -64,12 +65,12 @@ func (s *TSODispatcher) DispatchRequest( tsoProtoFactory ProtoFactory, doneCh <-chan struct{}, errCh chan<- error, - updateServicePrimaryAddrChs ...chan<- struct{}) { + tsoPrimaryWatchers ...*etcdutil.LoopWatcher) { val, loaded := 
s.dispatchChs.LoadOrStore(req.getForwardedHost(), make(chan Request, maxMergeRequests)) reqCh := val.(chan Request) if !loaded { tsDeadlineCh := make(chan deadline, 1) - go s.dispatch(ctx, tsoProtoFactory, req.getForwardedHost(), req.getClientConn(), reqCh, tsDeadlineCh, doneCh, errCh, updateServicePrimaryAddrChs...) + go s.dispatch(ctx, tsoProtoFactory, req.getForwardedHost(), req.getClientConn(), reqCh, tsDeadlineCh, doneCh, errCh, tsoPrimaryWatchers...) go watchTSDeadline(ctx, tsDeadlineCh) } reqCh <- req @@ -84,7 +85,7 @@ func (s *TSODispatcher) dispatch( tsDeadlineCh chan<- deadline, doneCh <-chan struct{}, errCh chan<- error, - updateServicePrimaryAddrChs ...chan<- struct{}) { + tsoPrimaryWatchers ...*etcdutil.LoopWatcher) { defer logutil.LogPanic() dispatcherCtx, ctxCancel := context.WithCancel(ctx) defer ctxCancel() @@ -111,7 +112,7 @@ func (s *TSODispatcher) dispatch( defer cancel() requests := make([]Request, maxMergeRequests+1) - needUpdateServicePrimaryAddr := len(updateServicePrimaryAddrChs) > 0 && updateServicePrimaryAddrChs[0] != nil + needUpdateServicePrimaryAddr := len(tsoPrimaryWatchers) > 0 && tsoPrimaryWatchers[0] != nil for { select { case first := <-tsoRequestCh: @@ -137,13 +138,8 @@ func (s *TSODispatcher) dispatch( log.Error("proxy forward tso error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCSend, err)) - if needUpdateServicePrimaryAddr { - if strings.Contains(err.Error(), errs.NotLeaderErr) { - select { - case updateServicePrimaryAddrChs[0] <- struct{}{}: - default: - } - } + if needUpdateServicePrimaryAddr && strings.Contains(err.Error(), errs.NotLeaderErr) { + tsoPrimaryWatchers[0].ForceLoad() } select { case <-dispatcherCtx.Done(): diff --git a/server/grpc_service.go b/server/grpc_service.go index 70c26b46448..c143d7d9443 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -230,7 +230,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { } tsoRequest := tsoutil.NewPDProtoRequest(forwardedHost, clientConn, request, stream) - s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, tsoProtoFactory, doneCh, errCh, s.updateServicePrimaryAddrCh) + s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, tsoProtoFactory, doneCh, errCh, s.tsoPrimaryWatcher) continue } @@ -1802,11 +1802,7 @@ func (s *GrpcServer) getGlobalTSOFromTSOServer(ctx context.Context) (pdpb.Timest ts, err = forwardStream.Recv() if err != nil { if strings.Contains(err.Error(), errs.NotLeaderErr) { - select { - case s.updateServicePrimaryAddrCh <- struct{}{}: - log.Info("update service primary address when meet not leader error") - default: - } + s.tsoPrimaryWatcher.ForceLoad() time.Sleep(retryIntervalRequestTSOServer) continue } diff --git a/server/keyspace_service.go b/server/keyspace_service.go index 0ecfe45c1d7..4069987510d 100644 --- a/server/keyspace_service.go +++ b/server/keyspace_service.go @@ -22,12 +22,11 @@ import ( "github.com/gogo/protobuf/proto" "github.com/pingcap/kvproto/pkg/keyspacepb" "github.com/pingcap/kvproto/pkg/pdpb" - "github.com/pingcap/log" - "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/keyspace" "github.com/tikv/pd/pkg/storage/endpoint" + "github.com/tikv/pd/pkg/utils/etcdutil" "go.etcd.io/etcd/clientv3" - "go.uber.org/zap" + "go.etcd.io/etcd/mvcc/mvccpb" ) // KeyspaceServer wraps GrpcServer to provide keyspace service. 
@@ -73,79 +72,51 @@ func (s *KeyspaceServer) WatchKeyspaces(request *keyspacepb.WatchKeyspacesReques if err := s.validateRequest(request.GetHeader()); err != nil { return err } - ctx, cancel := context.WithCancel(s.Context()) defer cancel() + startKey := path.Join(s.rootPath, endpoint.KeyspaceMetaPrefix()) - revision, err := s.sendAllKeyspaceMeta(ctx, stream) - if err != nil { - return err - } - - watcher := clientv3.NewWatcher(s.client) - defer watcher.Close() - - for { - rch := watcher.Watch(ctx, path.Join(s.rootPath, endpoint.KeyspaceMetaPrefix()), clientv3.WithPrefix(), clientv3.WithRev(revision)) - for wresp := range rch { - if wresp.CompactRevision != 0 { - log.Warn("required revision has been compacted, use the compact revision", - zap.Int64("required-revision", revision), - zap.Int64("compact-revision", wresp.CompactRevision)) - revision = wresp.CompactRevision - break - } - if wresp.Canceled { - log.Error("watcher is canceled with", - zap.Int64("revision", revision), - errs.ZapError(errs.ErrEtcdWatcherCancel, wresp.Err())) - return wresp.Err() - } - keyspaces := make([]*keyspacepb.KeyspaceMeta, 0, len(wresp.Events)) - for _, event := range wresp.Events { - if event.Type != clientv3.EventTypePut { - continue - } - meta := &keyspacepb.KeyspaceMeta{} - if err = proto.Unmarshal(event.Kv.Value, meta); err != nil { - return err - } - keyspaces = append(keyspaces, meta) - } - if len(keyspaces) > 0 { - if err = stream.Send(&keyspacepb.WatchKeyspacesResponse{Header: s.header(), Keyspaces: keyspaces}); err != nil { - return err - } - } - } - select { - case <-ctx.Done(): - // server closed, return - return nil - default: + keyspaces := make([]*keyspacepb.KeyspaceMeta, 0) + putFn := func(kv *mvccpb.KeyValue) error { + meta := &keyspacepb.KeyspaceMeta{} + if err := proto.Unmarshal(kv.Value, meta); err != nil { + return err } + keyspaces = append(keyspaces, meta) + return nil } -} - -func (s *KeyspaceServer) sendAllKeyspaceMeta(ctx context.Context, stream keyspacepb.Keyspace_WatchKeyspacesServer) (int64, error) { - getResp, err := s.client.Get(ctx, path.Join(s.rootPath, endpoint.KeyspaceMetaPrefix()), clientv3.WithPrefix()) - if err != nil { - return 0, err + deleteFn := func(kv *mvccpb.KeyValue) error { + return nil } - metas := make([]*keyspacepb.KeyspaceMeta, getResp.Count) - for i, kv := range getResp.Kvs { - meta := &keyspacepb.KeyspaceMeta{} - if err = proto.Unmarshal(kv.Value, meta); err != nil { - return 0, err - } - metas[i] = meta + postEventFn := func() error { + defer func() { + keyspaces = keyspaces[:0] + }() + return stream.Send(&keyspacepb.WatchKeyspacesResponse{ + Header: s.header(), + Keyspaces: keyspaces}) } - var revision int64 - if getResp.Header != nil { - // start from the next revision - revision = getResp.Header.GetRevision() + 1 + + watcher := etcdutil.NewLoopWatcher( + ctx, + &s.serverLoopWg, + s.client, + "keyspace-server-watcher", + startKey, + putFn, + deleteFn, + postEventFn, + clientv3.WithPrefix(), + ) + s.serverLoopWg.Add(1) + go watcher.StartWatchLoop() + if err := watcher.WaitLoad(); err != nil { + cancel() // cancel context to stop watcher + return err } - return revision, stream.Send(&keyspacepb.WatchKeyspacesResponse{Header: s.header(), Keyspaces: metas}) + + <-ctx.Done() // wait for context done + return nil } // UpdateKeyspaceState updates the state of keyspace specified in the request. 
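For reference, the hunks above (keyspace group manager, keyspace server, and the etcdutil package itself) all wire up the new etcdutil.LoopWatcher the same way. The sketch below is not part of the patch: the package name, key prefix, cache, and function name are illustrative assumptions, while the constructor and method signatures are the ones added in pkg/utils/etcdutil/etcdutil.go above.

package example // hypothetical package, for illustration only

import (
	"context"
	"sync"

	"github.com/tikv/pd/pkg/utils/etcdutil"
	"go.etcd.io/etcd/clientv3"
	"go.etcd.io/etcd/mvcc/mvccpb"
)

// startExampleWatcher shows the wiring pattern used by the callers in this patch:
// build put/delete/post-event callbacks, construct the watcher, add it to the
// wait group, start the loop, and block until the initial load has finished.
func startExampleWatcher(ctx context.Context, wg *sync.WaitGroup, client *clientv3.Client) (*etcdutil.LoopWatcher, error) {
	var (
		mu    sync.RWMutex
		cache = make(map[string]string)
	)
	putFn := func(kv *mvccpb.KeyValue) error {
		// Called for every key during the initial load and for every put event afterwards.
		mu.Lock()
		defer mu.Unlock()
		cache[string(kv.Key)] = string(kv.Value)
		return nil
	}
	deleteFn := func(kv *mvccpb.KeyValue) error {
		// Called for every delete event.
		mu.Lock()
		defer mu.Unlock()
		delete(cache, string(kv.Key))
		return nil
	}
	postEventFn := func() error {
		// Called after each batch of events (and after the initial load) has been applied.
		return nil
	}
	watcher := etcdutil.NewLoopWatcher(
		ctx, wg, client,
		"example-watcher", // name used in log messages
		"/example/prefix", // hypothetical key prefix to load and watch
		putFn, deleteFn, postEventFn,
		clientv3.WithPrefix(),
	)
	// StartWatchLoop calls wg.Done when it exits, so the caller adds to the wait group first.
	wg.Add(1)
	go watcher.StartWatchLoop()
	// WaitLoad blocks until the initial load from etcd has finished (or failed).
	if err := watcher.WaitLoad(); err != nil {
		return nil, err
	}
	return watcher, nil
}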
diff --git a/server/server.go b/server/server.go index a28963706de..0b10bf0d563 100644 --- a/server/server.go +++ b/server/server.go @@ -77,6 +77,7 @@ import ( syncer "github.com/tikv/pd/server/region_syncer" "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/embed" + "go.etcd.io/etcd/mvcc/mvccpb" "go.etcd.io/etcd/pkg/types" "go.uber.org/zap" "google.golang.org/grpc" @@ -104,8 +105,6 @@ const ( maxRetryTimesGetServicePrimary = 25 // retryIntervalGetServicePrimary is the retry interval for getting primary addr. retryIntervalGetServicePrimary = 100 * time.Millisecond - // TODO: move it to etcdutil - watchEtcdChangeRetryInterval = 1 * time.Second ) // EtcdStartTimeout the timeout of the startup etcd. @@ -217,9 +216,7 @@ type Server struct { registry *registry.ServiceRegistry mode string servicePrimaryMap sync.Map /* Store as map[string]string */ - // updateServicePrimaryAddrCh is used to notify the server to update the service primary address. - // Note: it is only used in API service mode. - updateServicePrimaryAddrCh chan struct{} + tsoPrimaryWatcher *etcdutil.LoopWatcher } // HandlerBuilder builds a server HTTP handler. @@ -566,9 +563,10 @@ func (s *Server) startServerLoop(ctx context.Context) { go s.etcdLeaderLoop() go s.serverMetricsLoop() go s.encryptionKeyManagerLoop() - if s.IsAPIServiceMode() { // disable tso service in api server + if s.IsAPIServiceMode() { + s.initTSOPrimaryWatcher() s.serverLoopWg.Add(1) - go s.startWatchServicePrimaryAddrLoop(mcs.TSOServiceName) + go s.tsoPrimaryWatcher.StartWatchLoop() } } @@ -1722,124 +1720,6 @@ func (s *Server) GetServicePrimaryAddr(ctx context.Context, serviceName string) return "", false } -// startWatchServicePrimaryAddrLoop starts a loop to watch the primary address of a given service. -func (s *Server) startWatchServicePrimaryAddrLoop(serviceName string) { - defer logutil.LogPanic() - defer s.serverLoopWg.Done() - ctx, cancel := context.WithCancel(s.serverLoopCtx) - defer cancel() - s.updateServicePrimaryAddrCh = make(chan struct{}, 1) - serviceKey := s.servicePrimaryKey(serviceName) - var ( - revision int64 - err error - ) - for i := 0; i < maxRetryTimesGetServicePrimary; i++ { - revision, err = s.updateServicePrimaryAddr(serviceName) - if revision != 0 && err == nil { // update success - break - } - select { - case <-ctx.Done(): - return - case <-time.After(retryIntervalGetServicePrimary): - } - } - if err != nil { - log.Warn("service primary addr doesn't exist", zap.String("service-key", serviceKey), zap.Error(err)) - } - log.Info("start to watch service primary addr", zap.String("service-key", serviceKey)) - for { - select { - case <-ctx.Done(): - log.Info("server is closed, exist watch service primary addr loop", zap.String("service", serviceName)) - return - default: - } - nextRevision, err := s.watchServicePrimaryAddr(ctx, serviceName, revision) - if err != nil { - log.Error("watcher canceled unexpectedly and a new watcher will start after a while", - zap.Int64("next-revision", nextRevision), - zap.Time("retry-at", time.Now().Add(watchEtcdChangeRetryInterval)), - zap.Error(err)) - revision = nextRevision - time.Sleep(watchEtcdChangeRetryInterval) - } - } -} - -// watchServicePrimaryAddr watches the primary address on etcd. 
-func (s *Server) watchServicePrimaryAddr(ctx context.Context, serviceName string, revision int64) (nextRevision int64, err error) { - serviceKey := s.servicePrimaryKey(serviceName) - watcher := clientv3.NewWatcher(s.client) - defer watcher.Close() - - for { - WatchChan: - watchChan := watcher.Watch(s.serverLoopCtx, serviceKey, clientv3.WithRev(revision)) - select { - case <-ctx.Done(): - return revision, nil - case <-s.updateServicePrimaryAddrCh: - revision, err = s.updateServicePrimaryAddr(serviceName) - if err != nil { - log.Warn("update service primary addr failed", zap.String("service-key", serviceKey), zap.Error(err)) - } - goto WatchChan - case wresp := <-watchChan: - if wresp.CompactRevision != 0 { - log.Warn("required revision has been compacted, use the compact revision", - zap.Int64("required-revision", revision), - zap.Int64("compact-revision", wresp.CompactRevision)) - revision = wresp.CompactRevision - goto WatchChan - } - if wresp.Err() != nil { - log.Error("watcher is canceled with", - zap.Int64("revision", revision), - errs.ZapError(errs.ErrEtcdWatcherCancel, wresp.Err())) - return revision, wresp.Err() - } - for _, event := range wresp.Events { - switch event.Type { - case clientv3.EventTypePut: - primary := &tsopb.Participant{} - if err := proto.Unmarshal(event.Kv.Value, primary); err != nil { - log.Error("watch service primary addr failed", zap.String("service-key", serviceKey), zap.Error(err)) - } else { - listenUrls := primary.GetListenUrls() - if len(listenUrls) > 0 { - // listenUrls[0] is the primary service endpoint of the keyspace group - s.servicePrimaryMap.Store(serviceName, listenUrls[0]) - } else { - log.Warn("service primary addr doesn't exist", zap.String("service-key", serviceKey)) - } - } - case clientv3.EventTypeDelete: - log.Warn("service primary is deleted", zap.String("service-key", serviceKey)) - s.servicePrimaryMap.Delete(serviceName) - } - } - revision = wresp.Header.Revision + 1 - } - } -} - -// updateServicePrimaryAddr updates the primary address from etcd with get operation. -func (s *Server) updateServicePrimaryAddr(serviceName string) (nextRevision int64, err error) { - serviceKey := s.servicePrimaryKey(serviceName) - primary := &tsopb.Participant{} - ok, revision, err := etcdutil.GetProtoMsgWithModRev(s.client, serviceKey, primary) - listenUrls := primary.GetListenUrls() - if !ok || err != nil || len(listenUrls) == 0 { - return 0, err - } - // listenUrls[0] is the primary service endpoint of the keyspace group - s.servicePrimaryMap.Store(serviceName, listenUrls[0]) - log.Info("update service primary addr", zap.String("service-key", serviceKey), zap.String("primary-addr", listenUrls[0])) - return revision, nil -} - // SetServicePrimaryAddr sets the primary address directly. // Note: This function is only used for test. 
func (s *Server) SetServicePrimaryAddr(serviceName, addr string) { @@ -1850,6 +1730,37 @@ func (s *Server) servicePrimaryKey(serviceName string) string { return fmt.Sprintf("/ms/%d/%s/%s/%s", s.clusterID, serviceName, fmt.Sprintf("%05d", 0), "primary") } +func (s *Server) initTSOPrimaryWatcher() { + serviceName := mcs.TSOServiceName + tsoServicePrimaryKey := s.servicePrimaryKey(serviceName) + putFn := func(kv *mvccpb.KeyValue) error { + primary := &tsopb.Participant{} // TODO: use Generics + if err := proto.Unmarshal(kv.Value, primary); err != nil { + return err + } + listenUrls := primary.GetListenUrls() + if len(listenUrls) > 0 { + // listenUrls[0] is the primary service endpoint of the keyspace group + s.servicePrimaryMap.Store(serviceName, listenUrls[0]) + } + return nil + } + deleteFn := func(kv *mvccpb.KeyValue) error { + s.servicePrimaryMap.Delete(serviceName) + return nil + } + s.tsoPrimaryWatcher = etcdutil.NewLoopWatcher( + s.serverLoopCtx, + &s.serverLoopWg, + s.client, + "tso-primary-watcher", + tsoServicePrimaryKey, + putFn, + deleteFn, + func() error { return nil }, + ) +} + // RecoverAllocID recover alloc id. set current base id to input id func (s *Server) RecoverAllocID(ctx context.Context, id uint64) error { return s.idAllocator.SetBase(id)
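Finally, a note on the ForceLoad path used by server/grpc_service.go and pkg/utils/tsoutil/tso_dispatcher.go above: when forwarding a TSO request fails with a not-leader error, the cached primary address is refreshed immediately instead of waiting for the next watch event. The helper below is only an illustration of that pattern (the package and function names are hypothetical); errs.NotLeaderErr and the LoopWatcher API come from this patch.

package example // hypothetical package, for illustration only

import (
	"strings"

	"github.com/tikv/pd/pkg/errs"
	"github.com/tikv/pd/pkg/utils/etcdutil"
)

// refreshPrimaryOnNotLeader mirrors the retry handling added above: if the error
// indicates that the cached TSO primary is stale, ask the watcher to reload the
// primary key from etcd right away rather than waiting for a watch event.
func refreshPrimaryOnNotLeader(watcher *etcdutil.LoopWatcher, err error) {
	if watcher == nil || err == nil {
		return
	}
	if strings.Contains(err.Error(), errs.NotLeaderErr) {
		// ForceLoad only signals the watch loop (non-blocking); the reload itself
		// runs on the watcher's own goroutine.
		watcher.ForceLoad()
	}
}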