Refactor TSO forward/dispatcher to be shared by both PD and TSO (#6175)
ref #5895

Add a general TSO forward/dispatcher for the independent PD (TSO) and TSO services and for cross-cluster forwarding.

Signed-off-by: Bin Shi <binshi.bing@gmail.com>
binshi-bing authored Mar 27, 2023
1 parent 16703ab commit d67bf27
Showing 14 changed files with 582 additions and 510 deletions.
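For orientation before the per-file hunks: the commit pulls the per-forwarded-host dispatch loop out of the PD and TSO gRPC services and into a shared helper (pkg/utils/tsoutil, per the grpc_service.go hunks below). A minimal sketch of what such a shared dispatcher can look like, with hypothetical names and signatures rather than the actual API added by this commit:

package tsoutil // hypothetical sketch, not the real package contents

import (
	"context"
	"sync"
)

// Request abstracts a TSO request arriving on either a pdpb or a tsopb stream,
// so one dispatcher can serve both services.
type Request interface {
	Count() uint32
	SendResponse(physical, firstLogical int64, suffixBits uint32, count uint32) error
}

const maxMergedRequests = 10000

// Dispatcher runs one forwarding goroutine per forwarded host.
type Dispatcher struct {
	channels sync.Map // forwardedHost -> chan Request
}

// DispatchRequest enqueues req for its forwarded host, lazily starting the
// per-host loop on first use.
func (d *Dispatcher) DispatchRequest(ctx context.Context, forwardedHost string, req Request) {
	value, loaded := d.channels.LoadOrStore(forwardedHost, make(chan Request, maxMergedRequests))
	ch := value.(chan Request)
	if !loaded {
		go d.forwardLoop(ctx, forwardedHost, ch)
	}
	ch <- req
}

func (d *Dispatcher) forwardLoop(ctx context.Context, forwardedHost string, ch chan Request) {
	defer d.channels.Delete(forwardedHost)
	for {
		select {
		case <-ctx.Done():
			return
		case first := <-ch:
			// Drain whatever else is queued, merge into one upstream TSO call,
			// then split the returned range back per request (see the
			// tso_dispatcher.go and grpc_service.go hunks below).
			_ = first
		}
	}
}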
9 changes: 5 additions & 4 deletions client/grpcutil/grpcutil.go
@@ -90,12 +90,13 @@ func GetOrCreateGRPCConn(ctx context.Context, clientConns *sync.Map, addr string
if err != nil {
return nil, err
}
old, ok := clientConns.LoadOrStore(addr, cc)
if !ok {
conn, loaded := clientConns.LoadOrStore(addr, cc)
if !loaded {
// Successfully stored the connection.
return cc, nil
}
cc.Close()
log.Debug("use old connection", zap.String("target", cc.Target()), zap.String("state", cc.GetState().String()))
return old.(*grpc.ClientConn), nil
cc = conn.(*grpc.ClientConn)
log.Debug("use existing connection", zap.String("target", cc.Target()), zap.String("state", cc.GetState().String()))
return cc, nil
}
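The rewritten GetOrCreateGRPCConn handles the LoadOrStore race explicitly: the goroutine that loses the race closes its freshly dialed connection and reuses the stored one. A self-contained sketch of that pattern, with fakeConn standing in for *grpc.ClientConn (not the pd client API):

package main

import (
	"fmt"
	"sync"
)

type fakeConn struct{ addr string }

func (c *fakeConn) Close() {}

// connCache mirrors the pattern above: dial first, then keep whichever
// connection won the LoadOrStore race and close the loser.
type connCache struct {
	conns sync.Map // addr -> *fakeConn
}

func (cc *connCache) getOrCreate(addr string) *fakeConn {
	newConn := &fakeConn{addr: addr} // the real code dials a gRPC connection here
	existing, loaded := cc.conns.LoadOrStore(addr, newConn)
	if !loaded {
		// Successfully stored the connection; it is now shared for addr.
		return newConn
	}
	// Another goroutine stored a connection first: discard ours, reuse theirs.
	newConn.Close()
	return existing.(*fakeConn)
}

func main() {
	var cache connCache
	fmt.Println(cache.getOrCreate("pd-1:2379").addr)
}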
8 changes: 4 additions & 4 deletions client/tso_batch_controller.go
@@ -44,16 +44,16 @@ func newTSOBatchController(tsoRequestCh chan *tsoRequest, maxBatchSize int) *tso
// fetchPendingRequests will start a new round of the batch collecting from the channel.
// It returns true if everything goes well, otherwise false which means we should stop the service.
func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, maxBatchWaitInterval time.Duration) error {
var firstTSORequest *tsoRequest
var firstRequest *tsoRequest
select {
case <-ctx.Done():
return ctx.Err()
case firstTSORequest = <-tbc.tsoRequestCh:
case firstRequest = <-tbc.tsoRequestCh:
}
// Start to batch when the first TSO request arrives.
tbc.batchStartTime = time.Now()
tbc.collectedRequestCount = 0
tbc.pushRequest(firstTSORequest)
tbc.pushRequest(firstRequest)

// This loop is for trying best to collect more requests, so we use `tbc.maxBatchSize` here.
fetchPendingRequestsLoop:
@@ -130,7 +130,7 @@ func (tbc *tsoBatchController) adjustBestBatchSize() {
}
}

func (tbc *tsoBatchController) revokePendingTSORequest(err error) {
func (tbc *tsoBatchController) revokePendingRequest(err error) {
for i := 0; i < len(tbc.tsoRequestCh); i++ {
req := <-tbc.tsoRequestCh
req.done <- err
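For reference, fetchPendingRequests above follows a block-then-drain shape: wait for the first request, stamp the batch start time, then opportunistically pull whatever is already queued. A stripped-down sketch of that collection step, using plain ints instead of *tsoRequest and omitting the maxBatchWaitInterval handling:

package main

import (
	"context"
	"fmt"
)

// collectBatch blocks until one request arrives, then greedily drains whatever
// is already buffered, up to maxBatchSize.
func collectBatch(ctx context.Context, requestCh chan int, maxBatchSize int) ([]int, error) {
	var batch []int
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case first := <-requestCh:
		batch = append(batch, first)
	}
	for len(batch) < maxBatchSize {
		select {
		case req := <-requestCh:
			batch = append(batch, req)
		default:
			return batch, nil // nothing else pending right now
		}
	}
	return batch, nil
}

func main() {
	ch := make(chan int, 8)
	for i := 0; i < 3; i++ {
		ch <- i
	}
	batch, _ := collectBatch(context.Background(), ch, 10)
	fmt.Println(batch) // [0 1 2]
}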
2 changes: 1 addition & 1 deletion client/tso_client.go
@@ -141,7 +141,7 @@ func (c *tsoClient) Close() {
if dispatcherInterface != nil {
dispatcher := dispatcherInterface.(*tsoDispatcher)
tsoErr := errors.WithStack(errClosing)
dispatcher.tsoBatchController.revokePendingTSORequest(tsoErr)
dispatcher.tsoBatchController.revokePendingRequest(tsoErr)
dispatcher.dispatcherCancel()
}
return true
19 changes: 10 additions & 9 deletions client/tso_dispatcher.go
@@ -404,7 +404,7 @@ tsoBatchLoop:
err = errs.ErrClientCreateTSOStream.FastGenByArgs(errs.RetryTimeoutErr)
log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err))
c.svcDiscovery.ScheduleCheckMemberChanged()
c.finishTSORequest(tbc.getCollectedRequests(), 0, 0, 0, errors.WithStack(err))
c.finishRequest(tbc.getCollectedRequests(), 0, 0, 0, errors.WithStack(err))
continue tsoBatchLoop
case <-time.After(retryInterval):
continue streamChoosingLoop
@@ -440,7 +440,7 @@
case tsDeadlineCh.(chan deadline) <- dl:
}
opts = extractSpanReference(tbc, opts[:0])
err = c.processTSORequests(stream, dc, tbc, opts)
err = c.processRequests(stream, dc, tbc, opts)
close(done)
// If error happens during tso stream handling, reset stream and run the next trial.
if err != nil {
@@ -691,23 +691,23 @@ func extractSpanReference(tbc *tsoBatchController, opts []opentracing.StartSpanO
return opts
}

func (c *tsoClient) processTSORequests(stream tsoStream, dcLocation string, tbc *tsoBatchController, opts []opentracing.StartSpanOption) error {
func (c *tsoClient) processRequests(stream tsoStream, dcLocation string, tbc *tsoBatchController, opts []opentracing.StartSpanOption) error {
if len(opts) > 0 {
span := opentracing.StartSpan("pdclient.processTSORequests", opts...)
span := opentracing.StartSpan("pdclient.processRequests", opts...)
defer span.Finish()
}

requests := tbc.getCollectedRequests()
count := int64(len(requests))
physical, logical, suffixBits, err := stream.processRequests(c.svcDiscovery.GetClusterID(), dcLocation, requests, tbc.batchStartTime)
if err != nil {
c.finishTSORequest(requests, 0, 0, 0, err)
c.finishRequest(requests, 0, 0, 0, err)
return err
}
// `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
firstLogical := addLogical(logical, -count+1, suffixBits)
c.compareAndSwapTS(dcLocation, physical, firstLogical, suffixBits, count)
c.finishTSORequest(requests, physical, firstLogical, suffixBits, nil)
c.finishRequest(requests, physical, firstLogical, suffixBits, nil)
return nil
}

@@ -729,8 +729,9 @@ func (c *tsoClient) compareAndSwapTS(dcLocation string, physical, firstLogical i
lastTSOPointer := lastTSOInterface.(*lastTSO)
lastPhysical := lastTSOPointer.physical
lastLogical := lastTSOPointer.logical
// The TSO we get is a range like [largestLogical-count+1, largestLogical], so we save the last TSO's largest logical to compare with the new TSO's first logical.
// For example, if we have a TSO resp with logical 10, count 5, then all TSOs we get will be [6, 7, 8, 9, 10].
// The TSO we get is a range like [largestLogical-count+1, largestLogical], so we save the last TSO's largest logical
// to compare with the new TSO's first logical. For example, if we have a TSO resp with logical 10, count 5, then
// all TSOs we get will be [6, 7, 8, 9, 10].
if tsLessEqual(physical, firstLogical, lastPhysical, lastLogical) {
panic(errors.Errorf("%s timestamp fallback, newly acquired ts (%d, %d) is less or equal to last one (%d, %d)",
dcLocation, physical, firstLogical, lastPhysical, lastLogical))
@@ -747,7 +748,7 @@ func tsLessEqual(physical, logical, thatPhysical, thatLogical int64) bool {
return physical < thatPhysical
}

func (c *tsoClient) finishTSORequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32, err error) {
func (c *tsoClient) finishRequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32, err error) {
for i := 0; i < len(requests); i++ {
if span := opentracing.SpanFromContext(requests[i].requestCtx); span != nil {
span.Finish()
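The comments in processRequests and compareAndSwapTS describe how one batched response covers the logical range [largestLogical-count+1, largestLogical]. A worked example of that client-side split, reusing the addLogical helper shown in the grpc_service.go hunk below and taking suffixBits as 0 for simplicity:

package main

import "fmt"

// addLogical shifts count by suffixBits before adding, as in the helper below.
func addLogical(logical, count int64, suffixBits uint32) int64 {
	return logical + count<<suffixBits
}

func main() {
	// One batched response: largest logical part 10 covering count = 5 requests.
	var (
		logical    int64  = 10
		count      int64  = 5
		suffixBits uint32 = 0
	)
	// firstLogical = 10 - 5 + 1 = 6, so the five requests are handed 6..10,
	// matching the [6, 7, 8, 9, 10] example in the comment above.
	firstLogical := addLogical(logical, -count+1, suffixBits)
	for i := int64(0); i < count; i++ {
		fmt.Println(addLogical(firstLogical, i, suffixBits))
	}
}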
6 changes: 4 additions & 2 deletions client/tso_stream.go
@@ -143,7 +143,8 @@ func (s *pdTSOStream) processRequests(clusterID uint64, dcLocation string, reque
return
}

physical, logical, suffixBits = resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits()
ts := resp.GetTimestamp()
physical, logical, suffixBits = ts.GetPhysical(), ts.GetLogical(), ts.GetSuffixBits()
return
}

@@ -189,6 +190,7 @@ func (s *tsoTSOStream) processRequests(clusterID uint64, dcLocation string, requ
return
}

physical, logical, suffixBits = resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits()
ts := resp.GetTimestamp()
physical, logical, suffixBits = ts.GetPhysical(), ts.GetLogical(), ts.GetSuffixBits()
return
}
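The only change in both stream implementations is hoisting resp.GetTimestamp() into a local variable. Protobuf-generated Go getters are nil-receiver safe, so the hoisted form behaves identically even if the response carries no timestamp; a tiny illustration with a hand-written stand-in for the generated message:

package main

import "fmt"

// timestamp mimics a protobuf-generated message: getters tolerate nil receivers.
type timestamp struct {
	Physical int64
}

func (t *timestamp) GetPhysical() int64 {
	if t == nil {
		return 0
	}
	return t.Physical
}

func main() {
	var ts *timestamp             // e.g. the result of resp.GetTimestamp() when unset
	fmt.Println(ts.GetPhysical()) // prints 0 instead of panicking
}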
205 changes: 11 additions & 194 deletions pkg/mcs/tso/server/grpc_service.go
@@ -20,28 +20,19 @@ import (
"net/http"
"time"

"github.com/pingcap/kvproto/pkg/pdpb"
"github.com/pingcap/kvproto/pkg/tsopb"
"github.com/pingcap/log"
"github.com/pkg/errors"
bs "github.com/tikv/pd/pkg/basicserver"
"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/pkg/mcs/registry"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/grpcutil"
"github.com/tikv/pd/pkg/utils/logutil"
"go.uber.org/zap"
"github.com/tikv/pd/pkg/utils/tsoutil"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const (
// tso
maxMergeTSORequests = 10000
defaultTSOProxyTimeout = 3 * time.Second
)

// gRPC errors
var (
ErrNotStarted = status.Errorf(codes.Unavailable, "server not started")
@@ -116,16 +107,20 @@ func (s *Service) Tso(stream tsopb.TSO_TsoServer) error {
streamCtx := stream.Context()
forwardedHost := grpcutil.GetForwardedHost(streamCtx)
if !s.IsLocalRequest(forwardedHost) {
clientConn, err := s.GetDelegateClient(s.ctx, forwardedHost)
if err != nil {
return errors.WithStack(err)
}

if errCh == nil {
doneCh = make(chan struct{})
defer close(doneCh)
errCh = make(chan error)
}
s.dispatchTSORequest(ctx, &tsoRequest{
forwardedHost,
request,
stream,
}, forwardedHost, doneCh, errCh)

tsoProtoFactory := s.tsoProtoFactory
tsoRequest := tsoutil.NewTSOProtoRequest(forwardedHost, clientConn, request, stream)
s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, tsoProtoFactory, doneCh, errCh)
continue
}

@@ -138,7 +133,7 @@
return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", s.clusterID, request.GetHeader().GetClusterId())
}
count := request.GetCount()
ts, err := s.tsoAllocatorManager.HandleTSORequest(request.GetDcLocation(), count)
ts, err := s.tsoAllocatorManager.HandleRequest(request.GetDcLocation(), count)
if err != nil {
return status.Errorf(codes.Unknown, err.Error())
}
@@ -174,181 +169,3 @@ func (s *Service) errorHeader(err *tsopb.Error) *tsopb.ResponseHeader {
Error: err,
}
}
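Everything below this point in grpc_service.go (the old tsoRequest type, dispatchTSORequest, handleDispatcher, processTSORequests, finishTSORequest, and the deadline watcher) is removed by this commit; per the hunk header, only the three closing lines above survive, and the forwarding logic now sits behind the shared dispatcher called from the new Tso handler. A rough sketch of that caller-side shape, with generic types wherever the diff does not show the real ones:

package main

import (
	"context"
	"fmt"
)

type request struct{ forwardedHost string }

type dispatcher interface {
	DispatchRequest(ctx context.Context, req request, doneCh chan struct{}, errCh chan error)
}

// serveStream mirrors the structure of the new Tso handler: local requests are
// answered inline, forwarded ones are queued on the shared dispatcher. The
// done/err channels are created lazily, only once forwarding is needed.
func serveStream(ctx context.Context, recv func() (request, error),
	isLocal func(string) bool, d dispatcher, handleLocal func(request) error) error {
	var (
		doneCh chan struct{}
		errCh  chan error
	)
	for {
		req, err := recv()
		if err != nil {
			return err
		}
		if !isLocal(req.forwardedHost) {
			if errCh == nil {
				doneCh = make(chan struct{})
				defer close(doneCh)
				errCh = make(chan error)
			}
			d.DispatchRequest(ctx, req, doneCh, errCh)
			continue
		}
		if err := handleLocal(req); err != nil {
			return err
		}
	}
}

func main() { fmt.Println("see serveStream above") }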

type tsoRequest struct {
forwardedHost string
request *tsopb.TsoRequest
stream tsopb.TSO_TsoServer
}

func (s *Service) dispatchTSORequest(ctx context.Context, request *tsoRequest, forwardedHost string, doneCh <-chan struct{}, errCh chan<- error) {
tsoRequestChInterface, loaded := s.tsoDispatcher.LoadOrStore(forwardedHost, make(chan *tsoRequest, maxMergeTSORequests))
if !loaded {
tsDeadlineCh := make(chan deadline, 1)
go s.handleDispatcher(ctx, forwardedHost, tsoRequestChInterface.(chan *tsoRequest), tsDeadlineCh, doneCh, errCh)
go watchTSDeadline(ctx, tsDeadlineCh)
}
tsoRequestChInterface.(chan *tsoRequest) <- request
}

func (s *Service) handleDispatcher(ctx context.Context, forwardedHost string, tsoRequestCh <-chan *tsoRequest, tsDeadlineCh chan<- deadline, doneCh <-chan struct{}, errCh chan<- error) {
defer logutil.LogPanic()
dispatcherCtx, ctxCancel := context.WithCancel(ctx)
defer ctxCancel()
defer s.tsoDispatcher.Delete(forwardedHost)

var (
forwardStream tsopb.TSO_TsoClient
cancel context.CancelFunc
)
client, err := s.GetDelegateClient(ctx, forwardedHost)
if err != nil {
goto errHandling
}
log.Info("create tso forward stream", zap.String("forwarded-host", forwardedHost))
forwardStream, cancel, err = s.CreateTsoForwardStream(client)
errHandling:
if err != nil || forwardStream == nil {
log.Error("create tso forwarding stream error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCCreateStream, err))
select {
case <-dispatcherCtx.Done():
return
case _, ok := <-doneCh:
if !ok {
return
}
case errCh <- err:
close(errCh)
return
}
}
defer cancel()

requests := make([]*tsoRequest, maxMergeTSORequests+1)
for {
select {
case first := <-tsoRequestCh:
pendingTSOReqCount := len(tsoRequestCh) + 1
requests[0] = first
for i := 1; i < pendingTSOReqCount; i++ {
requests[i] = <-tsoRequestCh
}
done := make(chan struct{})
dl := deadline{
timer: time.After(defaultTSOProxyTimeout),
done: done,
cancel: cancel,
}
select {
case tsDeadlineCh <- dl:
case <-dispatcherCtx.Done():
return
}
err = s.processTSORequests(forwardStream, requests[:pendingTSOReqCount])
close(done)
if err != nil {
log.Error("proxy forward tso error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCSend, err))
select {
case <-dispatcherCtx.Done():
return
case _, ok := <-doneCh:
if !ok {
return
}
case errCh <- err:
close(errCh)
return
}
}
case <-dispatcherCtx.Done():
return
}
}
}

func (s *Service) processTSORequests(forwardStream tsopb.TSO_TsoClient, requests []*tsoRequest) error {
start := time.Now()
// Merge the requests
count := uint32(0)
for _, request := range requests {
count += request.request.GetCount()
}
req := &tsopb.TsoRequest{
Header: requests[0].request.GetHeader(),
Count: count,
// TODO: support Local TSO proxy forwarding.
DcLocation: requests[0].request.GetDcLocation(),
}
// Send to the leader stream.
if err := forwardStream.Send(req); err != nil {
return err
}
resp, err := forwardStream.Recv()
if err != nil {
return err
}
tsoProxyHandleDuration.Observe(time.Since(start).Seconds())
tsoProxyBatchSize.Observe(float64(count))
// Split the response
physical, logical, suffixBits := resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits()
// `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
// This is different from the logic of client batch, for example, if we have a largest ts whose logical part is 10,
// count is 5, then the splitting results should be 5 and 10.
firstLogical := addLogical(logical, -int64(count), suffixBits)
return s.finishTSORequest(requests, physical, firstLogical, suffixBits)
}

// Because of the suffix, we need to shift the count before we add it to the logical part.
func addLogical(logical, count int64, suffixBits uint32) int64 {
return logical + count<<suffixBits
}

func (s *Service) finishTSORequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32) error {
countSum := int64(0)
for i := 0; i < len(requests); i++ {
count := requests[i].request.GetCount()
countSum += int64(count)
response := &tsopb.TsoResponse{
Header: s.header(),
Count: count,
Timestamp: &pdpb.Timestamp{
Physical: physical,
Logical: addLogical(firstLogical, countSum, suffixBits),
SuffixBits: suffixBits,
},
}
// Send back to the client.
if err := requests[i].stream.Send(response); err != nil {
return err
}
}
return nil
}
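The split in finishTSORequest above differs from the client-side split: each forwarded request is answered with the largest logical of its own sub-range, computed from the running countSum. A worked example with made-up counts, taking suffixBits as 0:

package main

import "fmt"

func addLogical(logical, count int64, suffixBits uint32) int64 {
	return logical + count<<suffixBits
}

func main() {
	// Merged upstream response: largest logical 10 covering count 2+3 = 5.
	var (
		logical    int64  = 10
		suffixBits uint32 = 0
	)
	counts := []int64{2, 3} // per forwarded request
	total := int64(0)
	for _, c := range counts {
		total += c
	}
	firstLogical := addLogical(logical, -total, suffixBits) // 10 - 5 = 5
	countSum := int64(0)
	for i, c := range counts {
		countSum += c
		// Each forwarded request gets the largest logical of its own sub-range:
		// 5+2 = 7 for the first, 5+5 = 10 for the second.
		fmt.Printf("request %d gets logical %d (count %d)\n", i, addLogical(firstLogical, countSum, suffixBits), c)
	}
}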

type deadline struct {
timer <-chan time.Time
done chan struct{}
cancel context.CancelFunc
}

func watchTSDeadline(ctx context.Context, tsDeadlineCh <-chan deadline) {
defer logutil.LogPanic()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for {
select {
case d := <-tsDeadlineCh:
select {
case <-d.timer:
log.Error("tso proxy request processing is canceled due to timeout", errs.ZapError(errs.ErrProxyTSOTimeout))
d.cancel()
case <-d.done:
continue
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
}
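watchTSDeadline above guards each forwarded batch: if the timer fires before the batch's done channel is closed, the stream's cancel function is invoked. The same pattern is expected to carry over to the shared implementation; a compact, self-contained sketch:

package main

import (
	"context"
	"fmt"
	"time"
)

type deadline struct {
	timer  <-chan time.Time
	done   chan struct{}
	cancel context.CancelFunc
}

// watch cancels the in-flight call if its deadline fires before done is closed.
func watch(ctx context.Context, deadlineCh <-chan deadline) {
	for {
		select {
		case d := <-deadlineCh:
			select {
			case <-d.timer:
				d.cancel() // the forwarded call took too long
			case <-d.done:
			case <-ctx.Done():
				return
			}
		case <-ctx.Done():
			return
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	deadlineCh := make(chan deadline, 1)
	go watch(ctx, deadlineCh)

	done := make(chan struct{})
	deadlineCh <- deadline{timer: time.After(50 * time.Millisecond), done: done, cancel: func() { fmt.Println("canceled") }}
	time.Sleep(100 * time.Millisecond) // simulate a slow forward; prints "canceled"
}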
(The remaining file diffs were not loaded on this page.)
