diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go
index 3ad05ca7cba1..842c772abd96 100644
--- a/client/tso_batch_controller.go
+++ b/client/tso_batch_controller.go
@@ -44,16 +44,16 @@ func newTSOBatchController(tsoRequestCh chan *tsoRequest, maxBatchSize int) *tso
 // fetchPendingRequests will start a new round of the batch collecting from the channel.
 // It returns true if everything goes well, otherwise false which means we should stop the service.
 func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, maxBatchWaitInterval time.Duration) error {
-	var firstTSORequest *tsoRequest
+	var firstRequest *tsoRequest
 	select {
 	case <-ctx.Done():
 		return ctx.Err()
-	case firstTSORequest = <-tbc.tsoRequestCh:
+	case firstRequest = <-tbc.tsoRequestCh:
 	}
 	// Start to batch when the first TSO request arrives.
 	tbc.batchStartTime = time.Now()
 	tbc.collectedRequestCount = 0
-	tbc.pushRequest(firstTSORequest)
+	tbc.pushRequest(firstRequest)
 
 	// This loop is for trying best to collect more requests, so we use `tbc.maxBatchSize` here.
 fetchPendingRequestsLoop:
@@ -130,7 +130,7 @@ func (tbc *tsoBatchController) adjustBestBatchSize() {
 	}
 }
 
-func (tbc *tsoBatchController) revokePendingTSORequest(err error) {
+func (tbc *tsoBatchController) revokePendingRequest(err error) {
 	for i := 0; i < len(tbc.tsoRequestCh); i++ {
 		req := <-tbc.tsoRequestCh
 		req.done <- err
diff --git a/client/tso_client.go b/client/tso_client.go
index c5427af9dc30..a13d635b986f 100644
--- a/client/tso_client.go
+++ b/client/tso_client.go
@@ -141,7 +141,7 @@ func (c *tsoClient) Close() {
 		if dispatcherInterface != nil {
 			dispatcher := dispatcherInterface.(*tsoDispatcher)
 			tsoErr := errors.WithStack(errClosing)
-			dispatcher.tsoBatchController.revokePendingTSORequest(tsoErr)
+			dispatcher.tsoBatchController.revokePendingRequest(tsoErr)
 			dispatcher.dispatcherCancel()
 		}
 		return true
diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go
index 7af4c859a3ee..c9c0f47d4d3f 100644
--- a/client/tso_dispatcher.go
+++ b/client/tso_dispatcher.go
@@ -404,7 +404,7 @@ tsoBatchLoop:
 				err = errs.ErrClientCreateTSOStream.FastGenByArgs(errs.RetryTimeoutErr)
 				log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err))
 				c.svcDiscovery.ScheduleCheckMemberChanged()
-				c.finishTSORequest(tbc.getCollectedRequests(), 0, 0, 0, errors.WithStack(err))
+				c.finishRequest(tbc.getCollectedRequests(), 0, 0, 0, errors.WithStack(err))
 				continue tsoBatchLoop
 			case <-time.After(retryInterval):
 				continue streamChoosingLoop
@@ -440,7 +440,7 @@ tsoBatchLoop:
 		case tsDeadlineCh.(chan deadline) <- dl:
 		}
 		opts = extractSpanReference(tbc, opts[:0])
-		err = c.processTSORequests(stream, dc, tbc, opts)
+		err = c.processRequests(stream, dc, tbc, opts)
 		close(done)
 		// If error happens during tso stream handling, reset stream and run the next trial.
 		if err != nil {
@@ -691,9 +691,9 @@ func extractSpanReference(tbc *tsoBatchController, opts []opentracing.StartSpanO
 	return opts
 }
 
-func (c *tsoClient) processTSORequests(stream tsoStream, dcLocation string, tbc *tsoBatchController, opts []opentracing.StartSpanOption) error {
+func (c *tsoClient) processRequests(stream tsoStream, dcLocation string, tbc *tsoBatchController, opts []opentracing.StartSpanOption) error {
 	if len(opts) > 0 {
-		span := opentracing.StartSpan("pdclient.processTSORequests", opts...)
+		span := opentracing.StartSpan("pdclient.processRequests", opts...)
 		defer span.Finish()
 	}
 
@@ -701,13 +701,13 @@ func (c *tsoClient) processTSORequests(stream tsoStream, dcLocation string, tbc
 	count := int64(len(requests))
 	physical, logical, suffixBits, err := stream.processRequests(c.svcDiscovery.GetClusterID(), dcLocation, requests, tbc.batchStartTime)
 	if err != nil {
-		c.finishTSORequest(requests, 0, 0, 0, err)
+		c.finishRequest(requests, 0, 0, 0, err)
 		return err
 	}
 	// `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
 	firstLogical := addLogical(logical, -count+1, suffixBits)
 	c.compareAndSwapTS(dcLocation, physical, firstLogical, suffixBits, count)
-	c.finishTSORequest(requests, physical, firstLogical, suffixBits, nil)
+	c.finishRequest(requests, physical, firstLogical, suffixBits, nil)
 	return nil
 }
 
@@ -747,7 +747,7 @@ func tsLessEqual(physical, logical, thatPhysical, thatLogical int64) bool {
 	return physical < thatPhysical
 }
 
-func (c *tsoClient) finishTSORequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32, err error) {
+func (c *tsoClient) finishRequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32, err error) {
 	for i := 0; i < len(requests); i++ {
 		if span := opentracing.SpanFromContext(requests[i].requestCtx); span != nil {
 			span.Finish()
diff --git a/pkg/mcs/tso/server/grpc_service.go b/pkg/mcs/tso/server/grpc_service.go
index 504cb8f14439..c5a4059e421d 100644
--- a/pkg/mcs/tso/server/grpc_service.go
+++ b/pkg/mcs/tso/server/grpc_service.go
@@ -118,8 +118,8 @@ func (s *Service) Tso(stream tsopb.TSO_TsoServer) error {
 				errCh = make(chan error)
 			}
-			tsoProtoFactory := s.TSOProtoFactory
-			tsoRequest := tsoutil.NewTSOProtoTSORequest(forwardedHost, clientConn, request, stream)
+			tsoProtoFactory := s.tsoProtoFactory
+			tsoRequest := tsoutil.NewTSOProtoRequest(forwardedHost, clientConn, request, stream)
 			s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, tsoProtoFactory, doneCh, errCh)
 			continue
 		}
@@ -133,7 +133,7 @@ func (s *Service) Tso(stream tsopb.TSO_TsoServer) error {
 			return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", s.clusterID, request.GetHeader().GetClusterId())
 		}
 		count := request.GetCount()
-		ts, err := s.tsoAllocatorManager.HandleTSORequest(request.GetDcLocation(), count)
+		ts, err := s.tsoAllocatorManager.HandleRequest(request.GetDcLocation(), count)
 		if err != nil {
 			return status.Errorf(codes.Unknown, err.Error())
 		}
diff --git a/pkg/mcs/tso/server/server.go b/pkg/mcs/tso/server/server.go
index ab2a0ac519f4..0c83fc4e2527 100644
--- a/pkg/mcs/tso/server/server.go
+++ b/pkg/mcs/tso/server/server.go
@@ -111,12 +111,12 @@ type Server struct {
 	tsoAllocatorManager *tso.AllocatorManager
 	// Store as map[string]*grpc.ClientConn
 	clientConns sync.Map
-	// tsoDispatcher is used to dispatch different TSO requests to
-	// the corresponding forwarding TSO channel.
+	// tsoDispatcher is used to dispatch the TSO requests to
+	// the corresponding forwarding TSO channels.
 	tsoDispatcher *tsoutil.TSODispatcher
-	// TSOProtoFactory is the abstract factory for creating tso
-	// related data structures defined in pd protocol
-	TSOProtoFactory *tsoutil.TSOProtoFactory
+	// tsoProtoFactory is the abstract factory for creating tso
+	// related data structures defined in the tso grpc protocol
+	tsoProtoFactory *tsoutil.TSOProtoFactory
 	// Callback functions for different stages
 	// startCallbacks will be called after the server is started.
@@ -562,7 +562,7 @@ func (s *Server) startServer() (err error) {
 	// Set up the Global TSO Allocator here, it will be initialized once this TSO participant campaigns leader successfully.
 	s.tsoAllocatorManager.SetUpAllocator(s.ctx, tso.GlobalDCLocation, s.participant.GetLeadership())
 	s.tsoDispatcher = tsoutil.NewTSODispatcher(tsoProxyHandleDuration, tsoProxyBatchSize)
-	s.TSOProtoFactory = &tsoutil.TSOProtoFactory{}
+	s.tsoProtoFactory = &tsoutil.TSOProtoFactory{}
 
 	s.service = &Service{Server: s}
diff --git a/pkg/tso/allocator_manager.go b/pkg/tso/allocator_manager.go
index a3de87311bb8..2a5fed634941 100644
--- a/pkg/tso/allocator_manager.go
+++ b/pkg/tso/allocator_manager.go
@@ -983,8 +983,8 @@ func (am *AllocatorManager) deleteAllocatorGroup(dcLocation string) {
 	}
 }
 
-// HandleTSORequest forwards TSO allocation requests to correct TSO Allocators.
-func (am *AllocatorManager) HandleTSORequest(dcLocation string, count uint32) (pdpb.Timestamp, error) {
+// HandleRequest forwards TSO allocation requests to correct TSO Allocators.
+func (am *AllocatorManager) HandleRequest(dcLocation string, count uint32) (pdpb.Timestamp, error) {
 	if dcLocation == "" {
 		dcLocation = GlobalDCLocation
 	}
diff --git a/pkg/utils/tsoutil/tso_dispatcher.go b/pkg/utils/tsoutil/tso_dispatcher.go
index 7f91d945ee05..26e24e119995 100644
--- a/pkg/utils/tsoutil/tso_dispatcher.go
+++ b/pkg/utils/tsoutil/tso_dispatcher.go
@@ -28,7 +28,7 @@ import (
 )
 
 const (
-	maxMergeTSORequests = 10000
+	maxMergeRequests = 10000
 	// DefaultTSOProxyTimeout defines the default timeout value of TSP Proxying
 	DefaultTSOProxyTimeout = 3 * time.Second
 )
@@ -37,13 +37,13 @@ type tsoResp interface {
 	GetTimestamp() *pdpb.Timestamp
 }
 
-// TSODispatcher is used to dispatch different TSO requests to the corresponding forwarding TSO channel.
+// TSODispatcher dispatches the TSO requests to the corresponding forwarding TSO channels.
 type TSODispatcher struct {
 	tsoProxyHandleDuration prometheus.Histogram
 	tsoProxyBatchSize      prometheus.Histogram
 
-	// dispatchChs is used to dispatch different TSO requests to the corresponding forwarding TSO channel.
-	dispatchChs sync.Map // Store as map[string]chan TSORequest
+	// dispatchChs is used to dispatch different TSO requests to the corresponding forwarding TSO channels.
+	dispatchChs sync.Map // Store as map[string]chan Request
 }
 
 // NewTSODispatcher creates and returns a TSODispatcher
@@ -56,18 +56,21 @@ func NewTSODispatcher(tsoProxyHandleDuration, tsoProxyBatchSize prometheus.Histo
 }
 
 // DispatchRequest is the entry point for dispatching/forwarding a tso request to the detination host
-func (s *TSODispatcher) DispatchRequest(ctx context.Context, req TSORequest, tsoProtoFactory ProtoFactory, doneCh <-chan struct{}, errCh chan<- error) {
-	tsoRequestChInterface, loaded := s.dispatchChs.LoadOrStore(req.getForwardedHost(), make(chan TSORequest, maxMergeTSORequests))
+func (s *TSODispatcher) DispatchRequest(
+	ctx context.Context, req Request, tsoProtoFactory ProtoFactory, doneCh <-chan struct{}, errCh chan<- error) {
+	val, loaded := s.dispatchChs.LoadOrStore(req.getForwardedHost(), make(chan Request, maxMergeRequests))
+	reqCh := val.(chan Request)
 	if !loaded {
 		tsDeadlineCh := make(chan deadline, 1)
-		go s.handleDispatcher(ctx, tsoProtoFactory, req.getForwardedHost(), req.getClientConn(), tsoRequestChInterface.(chan TSORequest), tsDeadlineCh, doneCh, errCh)
+		go s.dispatch(ctx, tsoProtoFactory, req.getForwardedHost(), req.getClientConn(), reqCh, tsDeadlineCh, doneCh, errCh)
 		go watchTSDeadline(ctx, tsDeadlineCh)
 	}
-	tsoRequestChInterface.(chan TSORequest) <- req
+	reqCh <- req
 }
 
-func (s *TSODispatcher) handleDispatcher(ctx context.Context, tsoProtoFactory ProtoFactory, forwardedHost string, clientConn *grpc.ClientConn,
-	tsoRequestCh <-chan TSORequest, tsDeadlineCh chan<- deadline, doneCh <-chan struct{}, errCh chan<- error) {
+func (s *TSODispatcher) dispatch(
+	ctx context.Context, tsoProtoFactory ProtoFactory, forwardedHost string, clientConn *grpc.ClientConn,
+	tsoRequestCh <-chan Request, tsDeadlineCh chan<- deadline, doneCh <-chan struct{}, errCh chan<- error) {
 	dispatcherCtx, ctxCancel := context.WithCancel(ctx)
 	defer ctxCancel()
 	defer s.dispatchChs.Delete(forwardedHost)
@@ -75,7 +78,9 @@ func (s *TSODispatcher) handleDispatcher(ctx context.Context, tsoProtoFactory Pr
 	log.Info("create tso forward stream", zap.String("forwarded-host", forwardedHost))
 	forwardStream, cancel, err := tsoProtoFactory.createForwardStream(ctx, clientConn)
 	if err != nil || forwardStream == nil {
-		log.Error("create tso forwarding stream error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCCreateStream, err))
+		log.Error("create tso forwarding stream error",
+			zap.String("forwarded-host", forwardedHost),
+			errs.ZapError(errs.ErrGRPCCreateStream, err))
 		select {
 		case <-dispatcherCtx.Done():
 			return
@@ -90,7 +95,7 @@ func (s *TSODispatcher) handleDispatcher(ctx context.Context, tsoProtoFactory Pr
 	}
 	defer cancel()
 
-	requests := make([]TSORequest, maxMergeTSORequests+1)
+	requests := make([]Request, maxMergeRequests+1)
 	for {
 		select {
 		case first := <-tsoRequestCh:
@@ -110,10 +115,12 @@ func (s *TSODispatcher) handleDispatcher(ctx context.Context, tsoProtoFactory Pr
 			case <-dispatcherCtx.Done():
 				return
 			}
-			err = s.processTSORequests(forwardStream, requests[:pendingTSOReqCount], tsoProtoFactory)
+			err = s.processRequests(forwardStream, requests[:pendingTSOReqCount], tsoProtoFactory)
 			close(done)
 			if err != nil {
-				log.Error("proxy forward tso error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCSend, err))
+				log.Error("proxy forward tso error",
+					zap.String("forwarded-host", forwardedHost),
+					errs.ZapError(errs.ErrGRPCSend, err))
 				select {
 				case <-dispatcherCtx.Done():
 					return
@@ -132,7 +139,7 @@ func (s *TSODispatcher) handleDispatcher(ctx context.Context, tsoProtoFactory Pr
 	}
 }
 
-func (s *TSODispatcher) processTSORequests(forwardStream tsoStream, requests []TSORequest, tsoProtoFactory ProtoFactory) error {
+func (s *TSODispatcher) processRequests(forwardStream stream, requests []Request, tsoProtoFactory ProtoFactory) error {
 	start := time.Now()
 	// Merge the requests
 	count := uint32(0)
@@ -146,12 +153,13 @@ func (s *TSODispatcher) processTSORequests(forwardStream tsoStream, requests []T
 	s.tsoProxyHandleDuration.Observe(time.Since(start).Seconds())
 	s.tsoProxyBatchSize.Observe(float64(count))
 	// Split the response
-	physical, logical, suffixBits := resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits()
+	ts := resp.GetTimestamp()
+	physical, logical, suffixBits := ts.GetPhysical(), ts.GetLogical(), ts.GetSuffixBits()
 	// `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
 	// This is different from the logic of client batch, for example, if we have a largest ts whose logical part is 10,
 	// count is 5, then the splitting results should be 5 and 10.
 	firstLogical := addLogical(logical, -int64(count), suffixBits)
-	return s.finishTSORequest(requests, physical, firstLogical, suffixBits)
+	return s.finishRequest(requests, physical, firstLogical, suffixBits)
 }
 
 // Because of the suffix, we need to shift the count before we add it to the logical part.
@@ -159,7 +167,7 @@ func addLogical(logical, count int64, suffixBits uint32) int64 {
 	return logical + count<<suffixBits
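
The comments in the renamed `processRequests` describe how a single batched timestamp is split back into per-request results: the client takes `logical-count+1` as the first logical part, while the proxy in `pkg/utils/tsoutil` takes `logical-count`, so a largest logical of 10 with a count of 5 splits at 5 and 10. Below is a minimal standalone sketch (not part of the diff) of that arithmetic, with `suffixBits` assumed to be 0 for simplicity.

package main

import "fmt"

// addLogical mirrors the helper shown in the diff: shift the count by the
// suffix bits before adding it to the logical part.
func addLogical(logical, count int64, suffixBits uint32) int64 {
	return logical + count<<suffixBits
}

func main() {
	// Numbers taken from the comment in processRequests: the batched response
	// carries the largest logical part (10) for a batch of count (5) requests.
	logical, count := int64(10), int64(5)
	var suffixBits uint32 // assumed 0 here; allocators may reserve suffix bits per dc-location

	// Client-side split (client/tso_dispatcher.go): the first request gets
	// logical-count+1, so the batch receives logical parts 6 through 10.
	clientFirst := addLogical(logical, -count+1, suffixBits)

	// Proxy-side split (pkg/utils/tsoutil): firstLogical is logical-count,
	// i.e. the split boundaries are 5 and 10 as the comment describes.
	proxyFirst := addLogical(logical, -count, suffixBits)

	fmt.Println(clientFirst, proxyFirst) // Output: 6 5
}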