diff --git a/client/client.go b/client/client.go index 802a377cbf1..0683aafe7a1 100644 --- a/client/client.go +++ b/client/client.go @@ -791,10 +791,10 @@ func (c *client) GetLocalTSAsync(ctx context.Context, dcLocation string) TSFutur return req } - if err := tsoClient.dispatchRequest(dcLocation, req); err != nil { + if err := tsoClient.dispatchRequest(ctx, dcLocation, req); err != nil { // Wait for a while and try again time.Sleep(50 * time.Millisecond) - if err = tsoClient.dispatchRequest(dcLocation, req); err != nil { + if err = tsoClient.dispatchRequest(ctx, dcLocation, req); err != nil { req.done <- err } } diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index c5136d7fd09..18e9e13bf73 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -73,7 +73,7 @@ func (c *tsoClient) scheduleUpdateTSOConnectionCtxs() { } } -func (c *tsoClient) dispatchRequest(dcLocation string, request *tsoRequest) error { +func (c *tsoClient) dispatchRequest(ctx context.Context, dcLocation string, request *tsoRequest) error { dispatcher, ok := c.tsoDispatcher.Load(dcLocation) if !ok { err := errs.ErrClientGetTSO.FastGenByArgs(fmt.Sprintf("unknown dc-location %s to the client", dcLocation)) @@ -83,7 +83,11 @@ func (c *tsoClient) dispatchRequest(dcLocation string, request *tsoRequest) erro } defer trace.StartRegion(request.requestCtx, "tsoReqEnqueue").End() - dispatcher.(*tsoDispatcher).tsoBatchController.tsoRequestCh <- request + select { + case <-ctx.Done(): + return ctx.Err() + case dispatcher.(*tsoDispatcher).tsoBatchController.tsoRequestCh <- request: + } return nil } @@ -311,6 +315,14 @@ func (c *tsoClient) createTSODispatcher(dcLocation string) { make(chan *tsoRequest, defaultMaxTSOBatchSize*2), defaultMaxTSOBatchSize), } + failpoint.Inject("shortDispatcherChannel", func() { + dispatcher = &tsoDispatcher{ + dispatcherCancel: dispatcherCancel, + tsoBatchController: newTSOBatchController( + make(chan *tsoRequest, 1), + defaultMaxTSOBatchSize), + } + }) if _, ok := c.tsoDispatcher.LoadOrStore(dcLocation, dispatcher); !ok { // Successfully stored the value. Start the following goroutine. diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index 3e1b0be472e..b66b15d8243 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -858,6 +858,51 @@ func (suite *followerForwardAndHandleTestSuite) TestGetRegionFromFollower() { re.NoError(failpoint.Disable("github.com/tikv/pd/client/fastCheckAvailable")) } +func (suite *followerForwardAndHandleTestSuite) TestGetTSFuture() { + re := suite.Require() + ctx, cancel := context.WithCancel(suite.ctx) + defer cancel() + + re.NoError(failpoint.Enable("github.com/tikv/pd/client/shortDispatcherChannel", "return(true)")) + + cli := setupCli(re, ctx, suite.endpoints) + + ctxs := make([]context.Context, 20) + cancels := make([]context.CancelFunc, 20) + for i := 0; i < 20; i++ { + ctxs[i], cancels[i] = context.WithCancel(ctx) + } + start := time.Now() + wg1 := sync.WaitGroup{} + wg2 := sync.WaitGroup{} + wg3 := sync.WaitGroup{} + wg1.Add(1) + go func() { + <-time.After(time.Second) + for i := 0; i < 20; i++ { + cancels[i]() + } + wg1.Done() + }() + wg2.Add(1) + go func() { + cli.Close() + wg2.Done() + }() + wg3.Add(1) + go func() { + for i := 0; i < 20; i++ { + cli.GetTSAsync(ctxs[i]) + } + wg3.Done() + }() + wg1.Wait() + wg2.Wait() + wg3.Wait() + re.Less(time.Since(start), time.Second*2) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/shortDispatcherChannel")) +} + func checkTS(re *require.Assertions, cli pd.Client, lastTS uint64) uint64 { for i := 0; i < tsoRequestRound; i++ { physical, logical, err := cli.GetTS(context.TODO())