diff --git a/client/grpcutil/grpcutil.go b/client/grpcutil/grpcutil.go
index 4d6fdd229dd..59b7224f29e 100644
--- a/client/grpcutil/grpcutil.go
+++ b/client/grpcutil/grpcutil.go
@@ -90,12 +90,13 @@ func GetOrCreateGRPCConn(ctx context.Context, clientConns *sync.Map, addr string
 	if err != nil {
 		return nil, err
 	}
-	old, ok := clientConns.LoadOrStore(addr, cc)
-	if !ok {
+	conn, loaded := clientConns.LoadOrStore(addr, cc)
+	if !loaded {
+		// Successfully stored the connection.
 		return cc, nil
 	}
 	cc.Close()
-	log.Debug("use old connection", zap.String("target", cc.Target()), zap.String("state", cc.GetState().String()))
-	return old.(*grpc.ClientConn), nil
+	cc = conn.(*grpc.ClientConn)
+	log.Debug("use existing connection", zap.String("target", cc.Target()), zap.String("state", cc.GetState().String()))
+	return cc, nil
 }
diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go
index 3ad05ca7cba..842c772abd9 100644
--- a/client/tso_batch_controller.go
+++ b/client/tso_batch_controller.go
@@ -44,16 +44,16 @@ func newTSOBatchController(tsoRequestCh chan *tsoRequest, maxBatchSize int) *tso
 // fetchPendingRequests will start a new round of the batch collecting from the channel.
 // It returns true if everything goes well, otherwise false which means we should stop the service.
 func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, maxBatchWaitInterval time.Duration) error {
-	var firstTSORequest *tsoRequest
+	var firstRequest *tsoRequest
 	select {
 	case <-ctx.Done():
 		return ctx.Err()
-	case firstTSORequest = <-tbc.tsoRequestCh:
+	case firstRequest = <-tbc.tsoRequestCh:
 	}
 	// Start to batch when the first TSO request arrives.
 	tbc.batchStartTime = time.Now()
 	tbc.collectedRequestCount = 0
-	tbc.pushRequest(firstTSORequest)
+	tbc.pushRequest(firstRequest)
 
 	// This loop is for trying best to collect more requests, so we use `tbc.maxBatchSize` here.
 fetchPendingRequestsLoop:
@@ -130,7 +130,7 @@ func (tbc *tsoBatchController) adjustBestBatchSize() {
 	}
 }
 
-func (tbc *tsoBatchController) revokePendingTSORequest(err error) {
+func (tbc *tsoBatchController) revokePendingRequest(err error) {
 	for i := 0; i < len(tbc.tsoRequestCh); i++ {
 		req := <-tbc.tsoRequestCh
 		req.done <- err
diff --git a/client/tso_client.go b/client/tso_client.go
index c5427af9dc3..a13d635b986 100644
--- a/client/tso_client.go
+++ b/client/tso_client.go
@@ -141,7 +141,7 @@ func (c *tsoClient) Close() {
 		if dispatcherInterface != nil {
 			dispatcher := dispatcherInterface.(*tsoDispatcher)
 			tsoErr := errors.WithStack(errClosing)
-			dispatcher.tsoBatchController.revokePendingTSORequest(tsoErr)
+			dispatcher.tsoBatchController.revokePendingRequest(tsoErr)
 			dispatcher.dispatcherCancel()
 		}
 		return true
diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go
index 7af4c859a3e..7a7cda43361 100644
--- a/client/tso_dispatcher.go
+++ b/client/tso_dispatcher.go
@@ -404,7 +404,7 @@ tsoBatchLoop:
 					err = errs.ErrClientCreateTSOStream.FastGenByArgs(errs.RetryTimeoutErr)
 					log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err))
 					c.svcDiscovery.ScheduleCheckMemberChanged()
-					c.finishTSORequest(tbc.getCollectedRequests(), 0, 0, 0, errors.WithStack(err))
+					c.finishRequest(tbc.getCollectedRequests(), 0, 0, 0, errors.WithStack(err))
 					continue tsoBatchLoop
 				case <-time.After(retryInterval):
 					continue streamChoosingLoop
@@ -440,7 +440,7 @@ tsoBatchLoop:
 		case tsDeadlineCh.(chan deadline) <- dl:
 		}
 		opts = extractSpanReference(tbc, opts[:0])
-		err = c.processTSORequests(stream, dc, tbc, opts)
+		err = c.processRequests(stream, dc, tbc, opts)
 		close(done)
 		// If error happens during tso stream handling, reset stream and run the next trial.
 		if err != nil {
@@ -691,9 +691,9 @@ func extractSpanReference(tbc *tsoBatchController, opts []opentracing.StartSpanO
 	return opts
 }
 
-func (c *tsoClient) processTSORequests(stream tsoStream, dcLocation string, tbc *tsoBatchController, opts []opentracing.StartSpanOption) error {
+func (c *tsoClient) processRequests(stream tsoStream, dcLocation string, tbc *tsoBatchController, opts []opentracing.StartSpanOption) error {
 	if len(opts) > 0 {
-		span := opentracing.StartSpan("pdclient.processTSORequests", opts...)
+		span := opentracing.StartSpan("pdclient.processRequests", opts...)
 		defer span.Finish()
 	}
@@ -701,13 +701,13 @@ func (c *tsoClient) processTSORequests(stream tsoStream, dcLocation string, tbc
 	count := int64(len(requests))
 	physical, logical, suffixBits, err := stream.processRequests(c.svcDiscovery.GetClusterID(), dcLocation, requests, tbc.batchStartTime)
 	if err != nil {
-		c.finishTSORequest(requests, 0, 0, 0, err)
+		c.finishRequest(requests, 0, 0, 0, err)
 		return err
 	}
 	// `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
 	firstLogical := addLogical(logical, -count+1, suffixBits)
 	c.compareAndSwapTS(dcLocation, physical, firstLogical, suffixBits, count)
-	c.finishTSORequest(requests, physical, firstLogical, suffixBits, nil)
+	c.finishRequest(requests, physical, firstLogical, suffixBits, nil)
 	return nil
 }
@@ -729,8 +729,9 @@ func (c *tsoClient) compareAndSwapTS(dcLocation string, physical, firstLogical i
 	lastTSOPointer := lastTSOInterface.(*lastTSO)
 	lastPhysical := lastTSOPointer.physical
 	lastLogical := lastTSOPointer.logical
-	// The TSO we get is a range like [largestLogical-count+1, largestLogical], so we save the last TSO's largest logical to compare with the new TSO's first logical.
-	// For example, if we have a TSO resp with logical 10, count 5, then all TSOs we get will be [6, 7, 8, 9, 10].
+	// The TSO we get is a range like [largestLogical-count+1, largestLogical], so we save the last TSO's largest logical
+	// to compare with the new TSO's first logical. For example, if we have a TSO resp with logical 10, count 5, then
+	// all TSOs we get will be [6, 7, 8, 9, 10].
 	if tsLessEqual(physical, firstLogical, lastPhysical, lastLogical) {
 		panic(errors.Errorf("%s timestamp fallback, newly acquired ts (%d, %d) is less or equal to last one (%d, %d)",
 			dcLocation, physical, firstLogical, lastPhysical, lastLogical))
@@ -747,7 +748,7 @@ func tsLessEqual(physical, logical, thatPhysical, thatLogical int64) bool {
 	return physical < thatPhysical
 }
 
-func (c *tsoClient) finishTSORequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32, err error) {
+func (c *tsoClient) finishRequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32, err error) {
 	for i := 0; i < len(requests); i++ {
 		if span := opentracing.SpanFromContext(requests[i].requestCtx); span != nil {
 			span.Finish()
diff --git a/client/tso_stream.go b/client/tso_stream.go
index c8435abf9e5..baa764dffb2 100644
--- a/client/tso_stream.go
+++ b/client/tso_stream.go
@@ -143,7 +143,8 @@ func (s *pdTSOStream) processRequests(clusterID uint64, dcLocation string, reque
 		return
 	}
 
-	physical, logical, suffixBits = resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits()
+	ts := resp.GetTimestamp()
+	physical, logical, suffixBits = ts.GetPhysical(), ts.GetLogical(), ts.GetSuffixBits()
 	return
 }
 
@@ -189,6 +190,7 @@ func (s *tsoTSOStream) processRequests(clusterID uint64, dcLocation string, requ
 		return
 	}
 
-	physical, logical, suffixBits = resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits()
+	ts := resp.GetTimestamp()
+	physical, logical, suffixBits = ts.GetPhysical(), ts.GetLogical(), ts.GetSuffixBits()
 	return
 }
diff --git a/pkg/mcs/tso/server/grpc_service.go b/pkg/mcs/tso/server/grpc_service.go
index c650c4910ad..c5a4059e421 100644
--- a/pkg/mcs/tso/server/grpc_service.go
+++ b/pkg/mcs/tso/server/grpc_service.go
@@ -20,28 +20,19 @@ import (
 	"net/http"
 	"time"
 
-	"github.com/pingcap/kvproto/pkg/pdpb"
 	"github.com/pingcap/kvproto/pkg/tsopb"
 	"github.com/pingcap/log"
 	"github.com/pkg/errors"
 	bs "github.com/tikv/pd/pkg/basicserver"
-	"github.com/tikv/pd/pkg/errs"
 	"github.com/tikv/pd/pkg/mcs/registry"
 	"github.com/tikv/pd/pkg/utils/apiutil"
 	"github.com/tikv/pd/pkg/utils/grpcutil"
-	"github.com/tikv/pd/pkg/utils/logutil"
-	"go.uber.org/zap"
+	"github.com/tikv/pd/pkg/utils/tsoutil"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
 
-const (
-	// tso
-	maxMergeTSORequests    = 10000
-	defaultTSOProxyTimeout = 3 * time.Second
-)
-
 // gRPC errors
 var (
 	ErrNotStarted = status.Errorf(codes.Unavailable, "server not started")
@@ -116,16 +107,20 @@ func (s *Service) Tso(stream tsopb.TSO_TsoServer) error {
 		streamCtx := stream.Context()
 		forwardedHost := grpcutil.GetForwardedHost(streamCtx)
 		if !s.IsLocalRequest(forwardedHost) {
+			clientConn, err := s.GetDelegateClient(s.ctx, forwardedHost)
+			if err != nil {
+				return errors.WithStack(err)
+			}
+
 			if errCh == nil {
 				doneCh = make(chan struct{})
 				defer close(doneCh)
 				errCh = make(chan error)
 			}
-			s.dispatchTSORequest(ctx, &tsoRequest{
-				forwardedHost,
-				request,
-				stream,
-			}, forwardedHost, doneCh, errCh)
+
+			tsoProtoFactory := s.tsoProtoFactory
+			tsoRequest := tsoutil.NewTSOProtoRequest(forwardedHost, clientConn, request, stream)
+			s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, tsoProtoFactory, doneCh, errCh)
 			continue
 		}
@@ -138,7 +133,7 @@ func (s *Service) Tso(stream tsopb.TSO_TsoServer) error {
 			return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", s.clusterID, request.GetHeader().GetClusterId())
 		}
 		count := request.GetCount()
-		ts, err := s.tsoAllocatorManager.HandleTSORequest(request.GetDcLocation(), count)
+		ts, err := s.tsoAllocatorManager.HandleRequest(request.GetDcLocation(), count)
 		if err != nil {
 			return status.Errorf(codes.Unknown, err.Error())
 		}
@@ -174,181 +169,3 @@ func (s *Service) errorHeader(err *tsopb.Error) *tsopb.ResponseHeader {
 		Error: err,
 	}
 }
-
-type tsoRequest struct {
-	forwardedHost string
-	request       *tsopb.TsoRequest
-	stream        tsopb.TSO_TsoServer
-}
-
-func (s *Service) dispatchTSORequest(ctx context.Context, request *tsoRequest, forwardedHost string, doneCh <-chan struct{}, errCh chan<- error) {
-	tsoRequestChInterface, loaded := s.tsoDispatcher.LoadOrStore(forwardedHost, make(chan *tsoRequest, maxMergeTSORequests))
-	if !loaded {
-		tsDeadlineCh := make(chan deadline, 1)
-		go s.handleDispatcher(ctx, forwardedHost, tsoRequestChInterface.(chan *tsoRequest), tsDeadlineCh, doneCh, errCh)
-		go watchTSDeadline(ctx, tsDeadlineCh)
-	}
-	tsoRequestChInterface.(chan *tsoRequest) <- request
-}
-
-func (s *Service) handleDispatcher(ctx context.Context, forwardedHost string, tsoRequestCh <-chan *tsoRequest, tsDeadlineCh chan<- deadline, doneCh <-chan struct{}, errCh chan<- error) {
-	defer logutil.LogPanic()
-	dispatcherCtx, ctxCancel := context.WithCancel(ctx)
-	defer ctxCancel()
-	defer s.tsoDispatcher.Delete(forwardedHost)
-
-	var (
-		forwardStream tsopb.TSO_TsoClient
-		cancel        context.CancelFunc
-	)
-	client, err := s.GetDelegateClient(ctx, forwardedHost)
-	if err != nil {
-		goto errHandling
-	}
-	log.Info("create tso forward stream", zap.String("forwarded-host", forwardedHost))
-	forwardStream, cancel, err = s.CreateTsoForwardStream(client)
-errHandling:
-	if err != nil || forwardStream == nil {
-		log.Error("create tso forwarding stream error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCCreateStream, err))
-		select {
-		case <-dispatcherCtx.Done():
-			return
-		case _, ok := <-doneCh:
-			if !ok {
-				return
-			}
-		case errCh <- err:
-			close(errCh)
-			return
-		}
-	}
-	defer cancel()
-
-	requests := make([]*tsoRequest, maxMergeTSORequests+1)
-	for {
-		select {
-		case first := <-tsoRequestCh:
-			pendingTSOReqCount := len(tsoRequestCh) + 1
-			requests[0] = first
-			for i := 1; i < pendingTSOReqCount; i++ {
-				requests[i] = <-tsoRequestCh
-			}
-			done := make(chan struct{})
-			dl := deadline{
-				timer:  time.After(defaultTSOProxyTimeout),
-				done:   done,
-				cancel: cancel,
-			}
-			select {
-			case tsDeadlineCh <- dl:
-			case <-dispatcherCtx.Done():
-				return
-			}
-			err = s.processTSORequests(forwardStream, requests[:pendingTSOReqCount])
-			close(done)
-			if err != nil {
-				log.Error("proxy forward tso error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCSend, err))
-				select {
-				case <-dispatcherCtx.Done():
-					return
-				case _, ok := <-doneCh:
-					if !ok {
-						return
-					}
-				case errCh <- err:
-					close(errCh)
-					return
-				}
-			}
-		case <-dispatcherCtx.Done():
-			return
-		}
-	}
-}
-
-func (s *Service) processTSORequests(forwardStream tsopb.TSO_TsoClient, requests []*tsoRequest) error {
-	start := time.Now()
-	// Merge the requests
-	count := uint32(0)
-	for _, request := range requests {
-		count += request.request.GetCount()
-	}
-	req := &tsopb.TsoRequest{
-		Header: requests[0].request.GetHeader(),
-		Count:  count,
-		// TODO: support Local TSO proxy forwarding.
-		DcLocation: requests[0].request.GetDcLocation(),
-	}
-	// Send to the leader stream.
-	if err := forwardStream.Send(req); err != nil {
-		return err
-	}
-	resp, err := forwardStream.Recv()
-	if err != nil {
-		return err
-	}
-	tsoProxyHandleDuration.Observe(time.Since(start).Seconds())
-	tsoProxyBatchSize.Observe(float64(count))
-	// Split the response
-	physical, logical, suffixBits := resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits()
-	// `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
-	// This is different from the logic of client batch, for example, if we have a largest ts whose logical part is 10,
-	// count is 5, then the splitting results should be 5 and 10.
-	firstLogical := addLogical(logical, -int64(count), suffixBits)
-	return s.finishTSORequest(requests, physical, firstLogical, suffixBits)
-}
-
-// Because of the suffix, we need to shift the count before we add it to the logical part.
-func addLogical(logical, count int64, suffixBits uint32) int64 {
-	return logical + count< 0 {
+			clientConn, err := s.getDelegateClient(s.ctx, forwardedHost)
 			if err != nil {
 				return errors.WithStack(err)
 			}
+
 			if errCh == nil {
 				doneCh = make(chan struct{})
 				defer close(doneCh)
 				errCh = make(chan error)
 			}
-			s.dispatchTSORequest(ctx, &tsoRequest{
-				forwardedHost,
-				request,
-				stream,
-			}, forwardedHost, doneCh, errCh, true)
-			continue
-		}
-		forwardedHost := grpcutil.GetForwardedHost(streamCtx)
-		if !s.isLocalRequest(forwardedHost) {
-			if errCh == nil {
-				doneCh = make(chan struct{})
-				defer close(doneCh)
-				errCh = make(chan error)
+
+			var tsoProtoFactory tsoutil.ProtoFactory
+			if s.IsAPIServiceMode() {
+				tsoProtoFactory = s.tsoProtoFactory
+			} else {
+				tsoProtoFactory = s.pdProtoFactory
 			}
-			s.dispatchTSORequest(ctx, &tsoRequest{
-				forwardedHost,
-				request,
-				stream,
-			}, forwardedHost, doneCh, errCh, false)
+
+			tsoRequest := tsoutil.NewPDProtoRequest(forwardedHost, clientConn, request, stream)
+			s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, tsoProtoFactory, doneCh, errCh)
 			continue
 		}
@@ -258,7 +237,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
 			return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", s.clusterID, request.GetHeader().GetClusterId())
 		}
 		count := request.GetCount()
-		ts, err := s.tsoAllocatorManager.HandleTSORequest(request.GetDcLocation(), count)
+		ts, err := s.tsoAllocatorManager.HandleRequest(request.GetDcLocation(), count)
 		if err != nil {
 			return status.Errorf(codes.Unknown, err.Error())
 		}
@@ -274,215 +253,17 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
 	}
 }
-
-type tsoRequest struct {
-	forwardedHost string
-	request       *pdpb.TsoRequest
-	stream        pdpb.PD_TsoServer
-}
-
-func (s *GrpcServer) dispatchTSORequest(ctx context.Context, request *tsoRequest, forwardedHost string, doneCh <-chan struct{}, errCh chan<- error, withTSOProto bool) {
-	tsoRequestChInterface, loaded := s.tsoDispatcher.LoadOrStore(forwardedHost, make(chan *tsoRequest, maxMergeTSORequests))
-	if !loaded {
-		tsDeadlineCh := make(chan deadline, 1)
-		go s.handleDispatcher(ctx, forwardedHost, tsoRequestChInterface.(chan *tsoRequest), tsDeadlineCh, doneCh, errCh, withTSOProto)
-		go watchTSDeadline(ctx, tsDeadlineCh)
-	}
-	tsoRequestChInterface.(chan *tsoRequest) <- request
-}
-
-func (s *GrpcServer) handleDispatcher(ctx context.Context, forwardedHost string, tsoRequestCh <-chan *tsoRequest, tsDeadlineCh chan<- deadline, doneCh <-chan struct{}, errCh chan<- error, withTSOProto bool) {
-	dispatcherCtx, ctxCancel := context.WithCancel(ctx)
-	defer ctxCancel()
-	defer s.tsoDispatcher.Delete(forwardedHost)
-
-	var (
-		forwardStream    pdpb.PD_TsoClient
-		forwardMCSStream tsopb.TSO_TsoClient
-		cancel           context.CancelFunc
-	)
-	client, err := s.getDelegateClient(ctx, forwardedHost)
-	if err != nil {
-		goto errHandling
-	}
-	log.Info("create tso forward stream", zap.String("forwarded-host", forwardedHost))
-	if withTSOProto {
-		forwardMCSStream, cancel, err = s.createMCSTSOForwardStream(client)
-	} else {
-		forwardStream, cancel, err = s.createTsoForwardStream(client)
-	}
-errHandling:
-	if err != nil || (forwardStream == nil && !withTSOProto) || (forwardMCSStream == nil && withTSOProto) {
-		log.Error("create tso forwarding stream error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCCreateStream, err))
-		select {
-		case <-dispatcherCtx.Done():
-			return
-		case _, ok := <-doneCh:
-			if !ok {
-				return
-			}
-		case errCh <- err:
-			close(errCh)
-			return
-		}
-	}
-	defer cancel()
-
-	requests := make([]*tsoRequest, maxMergeTSORequests+1)
-	for {
-		select {
-		case first := <-tsoRequestCh:
-			pendingTSOReqCount := len(tsoRequestCh) + 1
-			requests[0] = first
-			for i := 1; i < pendingTSOReqCount; i++ {
-				requests[i] = <-tsoRequestCh
-			}
-			done := make(chan struct{})
-			dl := deadline{
-				timer:  time.After(defaultTSOProxyTimeout),
-				done:   done,
-				cancel: cancel,
-			}
-			select {
-			case tsDeadlineCh <- dl:
-			case <-dispatcherCtx.Done():
-				return
-			}
-			err = s.processTSORequests(forwardStream, forwardMCSStream, requests[:pendingTSOReqCount])
-			close(done)
-			if err != nil {
-				log.Error("proxy forward tso error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCSend, err))
-				select {
-				case <-dispatcherCtx.Done():
-					return
-				case _, ok := <-doneCh:
-					if !ok {
-						return
-					}
-				case errCh <- err:
-					close(errCh)
-					return
-				}
-			}
-		case <-dispatcherCtx.Done():
-			return
-		}
-	}
-}
-
-type tsoResp interface {
-	GetTimestamp() *pdpb.Timestamp
-}
-
-func (s *GrpcServer) processTSORequests(forwardStream pdpb.PD_TsoClient, forwardMCSStream tsopb.TSO_TsoClient, requests []*tsoRequest) error {
-	start := time.Now()
-	// Merge the requests
-	count := uint32(0)
-	for _, request := range requests {
-		count += request.request.GetCount()
-	}
-	var (
-		resp tsoResp
-		err  error
-	)
-	if forwardStream != nil {
-		req := &pdpb.TsoRequest{
-			Header: requests[0].request.GetHeader(),
-			Count:  count,
-			// TODO: support Local TSO proxy forwarding.
-			DcLocation: requests[0].request.GetDcLocation(),
-		}
-		// Send to the tso server stream
-		if err := forwardStream.Send(req); err != nil {
-			return err
-		}
-		resp, err = forwardStream.Recv()
-		if err != nil {
-			return err
-		}
-	}
-	if forwardMCSStream != nil {
-		req := &tsopb.TsoRequest{
-			Header: &tsopb.RequestHeader{
-				ClusterId:       requests[0].request.GetHeader().GetClusterId(),
-				KeyspaceId:      utils.DefaultKeyspaceID,
-				KeyspaceGroupId: utils.DefaultKeySpaceGroupID,
-			},
-			Count: count,
-			// TODO: support Local TSO proxy forwarding.
-			DcLocation: requests[0].request.GetDcLocation(),
-		}
-		// Send to the tso server stream.
-		if err := forwardMCSStream.Send(req); err != nil {
-			return err
-		}
-		resp, err = forwardMCSStream.Recv()
-		if err != nil {
-			return err
-		}
-	}
-	tsoProxyHandleDuration.Observe(time.Since(start).Seconds())
-	tsoProxyBatchSize.Observe(float64(count))
-	// Split the response
-	physical, logical, suffixBits := resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits()
-	// `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
-	// This is different from the logic of client batch, for example, if we have a largest ts whose logical part is 10,
-	// count is 5, then the splitting results should be 5 and 10.
-	firstLogical := addLogical(logical, -int64(count), suffixBits)
-	return s.finishTSORequest(requests, physical, firstLogical, suffixBits)
-}
-
-// Because of the suffix, we need to shift the count before we add it to the logical part.
-func addLogical(logical, count int64, suffixBits uint32) int64 {
-	return logical + count<
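
Note on the client-side splitting arithmetic referenced in the comments in processRequests and compareAndSwapTS: the stream returns only the largest logical value of a batch, and every request in the batch derives its own timestamp by offsetting from the first logical. The sketch below is illustrative only, not code from this change; it assumes addLogical shifts the count left by suffixBits before adding it, as the removed server comment describes.

package main

import "fmt"

// addLogical mirrors the helper described in the diff: because of the logical
// suffix encoding, the count is shifted left by suffixBits before being added.
func addLogical(logical, count int64, suffixBits uint32) int64 {
	return logical + count<<suffixBits
}

func main() {
	// Example from the comment in compareAndSwapTS: the response carries the
	// largest logical part, 10, for a batch of count 5.
	var (
		logical    int64  = 10
		count      int64  = 5
		suffixBits uint32 = 0
	)
	// The batch owns the range [largestLogical-count+1, largestLogical].
	firstLogical := addLogical(logical, -count+1, suffixBits)
	for i := int64(0); i < count; i++ {
		fmt.Print(addLogical(firstLogical, i, suffixBits), " ") // 6 7 8 9 10
	}
	fmt.Println()
}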
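The proxy path (removed here and re-homed in pkg/utils/tsoutil) merges the counts of all pending requests into one forwarded TsoRequest and splits the single response afterwards. As the removed comment notes, this differs from the client batch: with a largest logical of 10 and a merged count of 5, the proxy computes firstLogical as 5, and each merged sub-request is answered with the largest logical of its own slice. A hypothetical, simplified sketch of that split (the actual finishTSORequest body is not part of this diff):

package main

import "fmt"

func addLogical(logical, count int64, suffixBits uint32) int64 {
	return logical + count<<suffixBits
}

func main() {
	// Three pending requests asked for 2, 1 and 2 timestamps; the proxy forwards
	// a single request with count 5 and receives largest logical 10.
	counts := []int64{2, 1, 2}
	var (
		logical    int64  = 10
		total      int64  = 5
		suffixBits uint32 = 0
	)
	// Unlike the client batch, subtract the whole count: firstLogical is 5 here.
	firstLogical := addLogical(logical, -total, suffixBits)
	cum := int64(0)
	for _, c := range counts {
		cum += c
		// Each sub-request gets the largest logical of its own slice, so the
		// slices are (5,7], (7,8], (8,10].
		fmt.Printf("count=%d largestLogical=%d\n", c, addLogical(firstLogical, cum, suffixBits))
	}
}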
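The removed dispatcher also guards every forwarded batch with a deadline: before sending, it registers a deadline{timer, done, cancel} on a channel, and a watcher goroutine (watchTSDeadline, whose body is not shown in this diff) cancels the forward stream if the timer fires before done is closed. A simplified stand-in illustrating that pattern; watchDeadlines and the timings below are illustrative only.

package main

import (
	"context"
	"fmt"
	"time"
)

// deadline mirrors the struct used by the removed proxy code: a timer that fires
// after the proxy timeout, a done channel closed when the batch finishes in time,
// and a cancel func that tears down the forward stream otherwise.
type deadline struct {
	timer  <-chan time.Time
	done   chan struct{}
	cancel context.CancelFunc
}

// watchDeadlines is a simplified stand-in for the removed watchTSDeadline loop.
func watchDeadlines(ctx context.Context, ch <-chan deadline) {
	for {
		select {
		case d := <-ch:
			select {
			case <-d.timer:
				fmt.Println("tso proxy request timed out, cancelling the forward stream")
				d.cancel()
			case <-d.done:
				// The batch finished before the deadline; nothing to do.
			case <-ctx.Done():
				return
			}
		case <-ctx.Done():
			return
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	deadlineCh := make(chan deadline, 1)
	go watchDeadlines(ctx, deadlineCh)

	// streamCancel stands in for the cancel func of a forward stream that hangs.
	_, streamCancel := context.WithCancel(ctx)
	defer streamCancel()
	done := make(chan struct{})
	deadlineCh <- deadline{timer: time.After(50 * time.Millisecond), done: done, cancel: streamCancel}
	// done is never closed, so the watcher fires after 50ms and cancels the stream.
	time.Sleep(100 * time.Millisecond)
}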