diff --git a/client/client.go b/client/client.go
index 8838c184d92..b9535aa504e 100644
--- a/client/client.go
+++ b/client/client.go
@@ -793,7 +793,7 @@ func (c *client) GetLocalTSAsync(ctx context.Context, dcLocation string) TSFutur
 
 	req := c.getTSORequest(ctx, dcLocation)
 	if err := c.dispatchTSORequestWithRetry(req); err != nil {
-		req.done <- err
+		req.tryDone(err)
 	}
 	return req
 }
diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go
index bd7a440fb08..5f3b08c2895 100644
--- a/client/tso_batch_controller.go
+++ b/client/tso_batch_controller.go
@@ -19,7 +19,10 @@ import (
 	"runtime/trace"
 	"time"
 
+	"github.com/pingcap/errors"
+	"github.com/pingcap/log"
 	"github.com/tikv/pd/client/tsoutil"
+	"go.uber.org/zap"
 )
 
 type tsoBatchController struct {
@@ -138,7 +141,7 @@ func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical in
 		tsoReq := tbc.collectedRequests[i]
 		tsoReq.physical, tsoReq.logical = physical, tsoutil.AddLogical(firstLogical, int64(i), suffixBits)
 		defer trace.StartRegion(tsoReq.requestCtx, "pdclient.tsoReqDequeue").End()
-		tsoReq.done <- err
+		tsoReq.tryDone(err)
 	}
 	// Prevent the finished requests from being processed again.
 	tbc.collectedRequestCount = 0
@@ -147,6 +150,15 @@ func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical in
 func (tbc *tsoBatchController) revokePendingRequests(err error) {
 	for i := 0; i < len(tbc.tsoRequestCh); i++ {
 		req := <-tbc.tsoRequestCh
-		req.done <- err
+		req.tryDone(err)
 	}
 }
+
+func (tbc *tsoBatchController) clear() {
+	log.Info("[pd] clear the tso batch controller",
+		zap.Int("max-batch-size", tbc.maxBatchSize), zap.Int("best-batch-size", tbc.bestBatchSize),
+		zap.Int("collected-request-count", tbc.collectedRequestCount), zap.Int("pending-request-count", len(tbc.tsoRequestCh)))
+	tsoErr := errors.WithStack(errClosing)
+	tbc.finishCollectedRequests(0, 0, 0, tsoErr)
+	tbc.revokePendingRequests(tsoErr)
+}
diff --git a/client/tso_client.go b/client/tso_client.go
index c563df0efdb..5f8b12df36f 100644
--- a/client/tso_client.go
+++ b/client/tso_client.go
@@ -21,7 +21,6 @@ import (
 	"sync"
 	"time"
 
-	"github.com/pingcap/errors"
 	"github.com/pingcap/log"
 	"github.com/tikv/pd/client/errs"
 	"go.uber.org/zap"
@@ -64,6 +63,13 @@ var tsoReqPool = sync.Pool{
 	},
 }
 
+func (req *tsoRequest) tryDone(err error) {
+	select {
+	case req.done <- err:
+	default:
+	}
+}
+
 type tsoClient struct {
 	ctx    context.Context
 	cancel context.CancelFunc
@@ -140,9 +146,8 @@ func (c *tsoClient) Close() {
 	c.tsoDispatcher.Range(func(_, dispatcherInterface any) bool {
 		if dispatcherInterface != nil {
 			dispatcher := dispatcherInterface.(*tsoDispatcher)
-			tsoErr := errors.WithStack(errClosing)
-			dispatcher.tsoBatchController.revokePendingRequests(tsoErr)
 			dispatcher.dispatcherCancel()
+			dispatcher.tsoBatchController.clear()
 		}
 		return true
 	})
diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go
index a625f8dbbe1..88f8ffd61b5 100644
--- a/client/tso_dispatcher.go
+++ b/client/tso_dispatcher.go
@@ -95,8 +95,23 @@ func (c *tsoClient) dispatchRequest(request *tsoRequest) (bool, error) {
 		// tsoClient is closed due to the PD service mode switch, which is retryable.
 		return true, c.ctx.Err()
 	default:
+		// This failpoint will increase the possibility that the request is sent to a closed dispatcher.
+		failpoint.Inject("delayDispatchTSORequest", func() {
+			time.Sleep(time.Second)
+		})
 		dispatcher.(*tsoDispatcher).tsoBatchController.tsoRequestCh <- request
 	}
+	// Check the contexts again to make sure the request has not been sent to a closed dispatcher.
+	// Never retry on these conditions to prevent an unexpected data race.
+	select {
+	case <-request.requestCtx.Done():
+		return false, request.requestCtx.Err()
+	case <-request.clientCtx.Done():
+		return false, request.clientCtx.Err()
+	case <-c.ctx.Done():
+		return false, c.ctx.Err()
+	default:
+	}
 	return false, nil
 }
 
@@ -368,6 +383,8 @@ func (c *tsoClient) handleDispatcher(
 			cc.(*tsoConnectionContext).cancel()
 			return true
 		})
+		// Clear the tso batch controller.
+		tbc.clear()
 		c.wg.Done()
 	}()
 	// Call updateTSOConnectionCtxs once to init the connectionCtxs first.
diff --git a/pkg/utils/etcdutil/etcdutil_test.go b/pkg/utils/etcdutil/etcdutil_test.go
index 4fb96895942..e02615b695f 100644
--- a/pkg/utils/etcdutil/etcdutil_test.go
+++ b/pkg/utils/etcdutil/etcdutil_test.go
@@ -239,7 +239,7 @@ func TestRandomKillEtcd(t *testing.T) {
 
 	// Randomly kill an etcd server and restart it
 	cfgs := []embed.Config{etcds[0].Config(), etcds[1].Config(), etcds[2].Config()}
-	for i := 0; i < 10; i++ {
+	for i := 0; i < len(cfgs)*2; i++ {
 		killIndex := rand.Intn(len(etcds))
 		etcds[killIndex].Close()
 		checkEtcdEndpointNum(re, client1, 2)
@@ -452,9 +452,9 @@ func (suite *loopWatcherTestSuite) TestLoadWithLimitChange() {
 	re := suite.Require()
 	re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/etcdutil/meetEtcdError", `return()`))
 	cache := make(map[string]struct{})
-	for i := 0; i < int(maxLoadBatchSize)*2; i++ {
+	testutil.GenerateTestDataConcurrently(int(maxLoadBatchSize)*2, func(i int) {
 		suite.put(re, fmt.Sprintf("TestLoadWithLimitChange%d", i), "")
-	}
+	})
 	watcher := NewLoopWatcher(
 		suite.ctx,
 		&suite.wg,
@@ -583,25 +583,9 @@ func (suite *loopWatcherTestSuite) TestWatcherLoadLargeKey() {
 	count := 65536
 	ctx, cancel := context.WithCancel(suite.ctx)
 	defer cancel()
-
-	// create data
-	var wg sync.WaitGroup
-	tasks := make(chan int, count)
-	for w := 0; w < 16; w++ {
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			for i := range tasks {
-				suite.put(re, fmt.Sprintf("TestWatcherLoadLargeKey/test-%d", i), "")
-			}
-		}()
-	}
-	for i := 0; i < count; i++ {
-		tasks <- i
-	}
-	close(tasks)
-	wg.Wait()
-
+	testutil.GenerateTestDataConcurrently(count, func(i int) {
+		suite.put(re, fmt.Sprintf("TestWatcherLoadLargeKey/test-%d", i), "")
+	})
 	cache := make([]string, 0)
 	watcher := NewLoopWatcher(
 		ctx,
diff --git a/pkg/utils/testutil/testutil.go b/pkg/utils/testutil/testutil.go
index a41fc436ca6..cef952353bc 100644
--- a/pkg/utils/testutil/testutil.go
+++ b/pkg/utils/testutil/testutil.go
@@ -16,7 +16,9 @@ package testutil
 
 import (
 	"os"
+	"runtime"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/pingcap/kvproto/pkg/pdpb"
@@ -101,3 +103,24 @@ func InitTempFileLogger(level string) (fname string) {
 	log.ReplaceGlobals(lg, p)
 	return fname
 }
+
+// GenerateTestDataConcurrently generates test data concurrently with a pool of runtime.NumCPU() workers.
+func GenerateTestDataConcurrently(count int, f func(int)) {
+	var wg sync.WaitGroup
+	tasks := make(chan int, count)
+	workers := runtime.NumCPU()
+	for w := 0; w < workers; w++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for i := range tasks {
+				f(i)
+			}
+		}()
+	}
+	for i := 0; i < count; i++ {
+		tasks <- i
+	}
+	close(tasks)
+	wg.Wait()
+}
diff --git a/server/api/region_test.go b/server/api/region_test.go
index e10bfbd1af0..4198cdcb694 100644
--- a/server/api/region_test.go
+++ b/server/api/region_test.go
@@ -23,7 +23,6 @@ import (
 	"net/http"
 	"net/url"
 	"sort"
-	"sync"
 	"testing"
 	"time"
 
@@ -333,29 +332,15 @@ func TestRegionsWithKillRequest(t *testing.T) {
 	addr := svr.GetAddr()
 	url := fmt.Sprintf("%s%s/api/v1/regions", addr, apiPrefix)
 	mustBootstrapCluster(re, svr)
-	regionCount := 100000
-	// create data
-	var wg sync.WaitGroup
-	tasks := make(chan int, regionCount)
-	for w := 0; w < 16; w++ {
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			for i := range tasks {
-				r := core.NewTestRegionInfo(uint64(i+2), 1,
-					[]byte(fmt.Sprintf("%09d", i)),
-					[]byte(fmt.Sprintf("%09d", i+1)),
-					core.SetApproximateKeys(10), core.SetApproximateSize(10))
-				mustRegionHeartbeat(re, svr, r)
-			}
-		}()
-	}
-	for i := 0; i < regionCount; i++ {
-		tasks <- i
-	}
-	close(tasks)
-	wg.Wait()
+	regionCount := 100000
+	tu.GenerateTestDataConcurrently(regionCount, func(i int) {
+		r := core.NewTestRegionInfo(uint64(i+2), 1,
+			[]byte(fmt.Sprintf("%09d", i)),
+			[]byte(fmt.Sprintf("%09d", i+1)),
+			core.SetApproximateKeys(10), core.SetApproximateSize(10))
+		mustRegionHeartbeat(re, svr, r)
+	})
 
 	ctx, cancel := context.WithCancel(context.Background())
 	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
diff --git a/tests/integrations/tso/client_test.go b/tests/integrations/tso/client_test.go
index b0bd6f1d4e5..d4f484087cf 100644
--- a/tests/integrations/tso/client_test.go
+++ b/tests/integrations/tso/client_test.go
@@ -425,8 +425,9 @@ func (suite *tsoClientTestSuite) TestRandomShutdown() {
 	re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/tso/fastUpdatePhysicalInterval"))
 }
 
-func (suite *tsoClientTestSuite) TestGetTSWhileRestingTSOClient() {
+func (suite *tsoClientTestSuite) TestGetTSWhileResettingTSOClient() {
 	re := suite.Require()
+	re.NoError(failpoint.Enable("github.com/tikv/pd/client/delayDispatchTSORequest", "return(true)"))
 	var (
 		clients    []pd.Client
 		stopSignal atomic.Bool
 	)
@@ -467,6 +468,7 @@ func (suite *tsoClientTestSuite) TestGetTSWhileRestingTSOClient() {
 	}
 	stopSignal.Store(true)
 	wg.Wait()
+	re.NoError(failpoint.Disable("github.com/tikv/pd/client/delayDispatchTSORequest"))
 }
 
 // When we upgrade the PD cluster, there may be a period of time that the old and new PDs are running at the same time.
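Editor's note: the heart of this change is the non-blocking send in tryDone. The standalone sketch below is illustrative only and not part of the patch; fakeRequest is a hypothetical stand-in for tsoRequest, assuming (as in the client's tsoReqPool) that each request's done channel is buffered with capacity 1.

// Illustrative sketch only (not part of the patch). fakeRequest is a
// hypothetical stand-in for tsoRequest with a buffered done channel of
// capacity 1, matching how the client's tsoReqPool allocates requests.
package main

import (
	"errors"
	"fmt"
)

type fakeRequest struct {
	done chan error
}

// tryDone mirrors the pattern introduced by the patch: deliver err without
// blocking. If the channel already holds a result (for example, the request
// was finished by the dispatcher and is then revoked again while the client
// closes), the duplicate completion is dropped instead of blocking forever.
func (r *fakeRequest) tryDone(err error) {
	select {
	case r.done <- err:
	default:
	}
}

func main() {
	req := &fakeRequest{done: make(chan error, 1)}
	req.tryDone(nil)                         // first completion wins
	req.tryDone(errors.New("client closed")) // duplicate completion is dropped, no deadlock
	fmt.Println(<-req.done)                  // prints <nil>
}

Combined with the post-send context re-check in dispatchRequest and the clear() call when a dispatcher exits, the patch aims to ensure that a request handed to a closing dispatcher is still completed exactly once rather than hanging or racing on its done channel.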