Refactor TSO forward/dispatcher to be shared by both PD and TSO #6175

Merged · 3 commits · Mar 27, 2023

9 changes: 5 additions & 4 deletions client/grpcutil/grpcutil.go
@@ -90,12 +90,13 @@ func GetOrCreateGRPCConn(ctx context.Context, clientConns *sync.Map, addr string
if err != nil {
return nil, err
}
old, ok := clientConns.LoadOrStore(addr, cc)
if !ok {
conn, loaded := clientConns.LoadOrStore(addr, cc)
if !loaded {
// Successfully stored the connection.
return cc, nil
}
cc.Close()
log.Debug("use old connection", zap.String("target", cc.Target()), zap.String("state", cc.GetState().String()))
return old.(*grpc.ClientConn), nil
cc = conn.(*grpc.ClientConn)
log.Debug("use existing connection", zap.String("target", cc.Target()), zap.String("state", cc.GetState().String()))
return cc, nil
}
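The hunk above settles on the usual `sync.Map.LoadOrStore` caching idiom: dial eagerly, then either keep the connection you created or discard it in favor of the one another goroutine stored first. A minimal standalone sketch of that idiom follows; `getOrCreateConn`, the package name, and the insecure dial option are illustrative assumptions, not the real `GetOrCreateGRPCConn` signature.

```go
// A minimal sketch of the LoadOrStore caching idiom, assuming an insecure
// local dial; not the real GetOrCreateGRPCConn.
package grpcutilsketch

import (
	"context"
	"sync"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// clientConns maps addr (string) -> *grpc.ClientConn.
var clientConns sync.Map

func getOrCreateConn(ctx context.Context, addr string) (*grpc.ClientConn, error) {
	if conn, ok := clientConns.Load(addr); ok {
		// Fast path: a connection for this address is already cached.
		return conn.(*grpc.ClientConn), nil
	}
	cc, err := grpc.DialContext(ctx, addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return nil, err
	}
	conn, loaded := clientConns.LoadOrStore(addr, cc)
	if !loaded {
		// We won the race: the connection we dialed is now the cached one.
		return cc, nil
	}
	// Another goroutine stored a connection first; discard ours and reuse theirs.
	cc.Close()
	return conn.(*grpc.ClientConn), nil
}
```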
8 changes: 4 additions & 4 deletions client/tso_batch_controller.go
@@ -44,16 +44,16 @@ func newTSOBatchController(tsoRequestCh chan *tsoRequest, maxBatchSize int) *tso
// fetchPendingRequests will start a new round of the batch collecting from the channel.
// It returns true if everything goes well, otherwise false which means we should stop the service.
func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, maxBatchWaitInterval time.Duration) error {
var firstTSORequest *tsoRequest
var firstRequest *tsoRequest
select {
case <-ctx.Done():
return ctx.Err()
case firstTSORequest = <-tbc.tsoRequestCh:
case firstRequest = <-tbc.tsoRequestCh:
}
// Start to batch when the first TSO request arrives.
tbc.batchStartTime = time.Now()
tbc.collectedRequestCount = 0
tbc.pushRequest(firstTSORequest)
tbc.pushRequest(firstRequest)

// This loop is for trying best to collect more requests, so we use `tbc.maxBatchSize` here.
fetchPendingRequestsLoop:
@@ -130,7 +130,7 @@ func (tbc *tsoBatchController) adjustBestBatchSize() {
}
}

func (tbc *tsoBatchController) revokePendingTSORequest(err error) {
func (tbc *tsoBatchController) revokePendingRequest(err error) {
for i := 0; i < len(tbc.tsoRequestCh); i++ {
req := <-tbc.tsoRequestCh
req.done <- err
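The `fetchPendingRequests` rename above does not change the batching strategy: block until one request arrives, then opportunistically drain whatever is already buffered, up to the maximum batch size. A compact sketch of that pattern, with simplified stand-in types rather than the real `tsoBatchController`:

```go
// A minimal sketch of the batch-collection pattern, with stand-in types
// instead of the real tsoBatchController.
package tsobatchsketch

import "context"

type request struct {
	// done receives the final result (or error) for this request.
	done chan error
}

// collectBatch blocks until the first request arrives, then drains whatever
// is already buffered without blocking, up to maxBatchSize requests.
func collectBatch(ctx context.Context, ch <-chan *request, maxBatchSize int) ([]*request, error) {
	var first *request
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case first = <-ch:
	}
	batch := []*request{first}
	for len(batch) < maxBatchSize {
		select {
		case req := <-ch:
			batch = append(batch, req)
		default:
			// Nothing else is queued right now; ship what we have.
			return batch, nil
		}
	}
	return batch, nil
}
```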
2 changes: 1 addition & 1 deletion client/tso_client.go
@@ -141,7 +141,7 @@ func (c *tsoClient) Close() {
if dispatcherInterface != nil {
dispatcher := dispatcherInterface.(*tsoDispatcher)
tsoErr := errors.WithStack(errClosing)
dispatcher.tsoBatchController.revokePendingTSORequest(tsoErr)
dispatcher.tsoBatchController.revokePendingRequest(tsoErr)
dispatcher.dispatcherCancel()
}
return true
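For context on the `Close` path touched above: each dispatcher's buffered requests are failed with a closing error so their callers unblock, and the dispatcher goroutine is cancelled. The sketch below condenses that shutdown pattern; `dispatcher`, `request`, and `closeDispatchers` are simplified stand-ins, not the client's actual types.

```go
// A condensed sketch of the shutdown pattern, assuming simplified stand-in
// types; not the client's actual tsoDispatcher.
package tsoclientsketch

import (
	"context"
	"errors"
	"sync"
)

type request struct{ done chan error }

type dispatcher struct {
	requestCh chan *request
	cancel    context.CancelFunc
}

var errClosing = errors.New("tso client is closing")

// closeDispatchers fails every buffered request so callers unblock, then
// cancels each dispatcher goroutine.
func closeDispatchers(dispatchers *sync.Map) {
	dispatchers.Range(func(_, value any) bool {
		d := value.(*dispatcher)
		for i := 0; i < len(d.requestCh); i++ {
			req := <-d.requestCh
			req.done <- errClosing
		}
		d.cancel()
		return true
	})
}
```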
19 changes: 10 additions & 9 deletions client/tso_dispatcher.go
@@ -404,7 +404,7 @@ tsoBatchLoop:
err = errs.ErrClientCreateTSOStream.FastGenByArgs(errs.RetryTimeoutErr)
log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err))
c.svcDiscovery.ScheduleCheckMemberChanged()
c.finishTSORequest(tbc.getCollectedRequests(), 0, 0, 0, errors.WithStack(err))
c.finishRequest(tbc.getCollectedRequests(), 0, 0, 0, errors.WithStack(err))
continue tsoBatchLoop
case <-time.After(retryInterval):
continue streamChoosingLoop
@@ -440,7 +440,7 @@
case tsDeadlineCh.(chan deadline) <- dl:
}
opts = extractSpanReference(tbc, opts[:0])
err = c.processTSORequests(stream, dc, tbc, opts)
err = c.processRequests(stream, dc, tbc, opts)
close(done)
// If error happens during tso stream handling, reset stream and run the next trial.
if err != nil {
@@ -691,23 +691,23 @@ func extractSpanReference(tbc *tsoBatchController, opts []opentracing.StartSpanO
return opts
}

func (c *tsoClient) processTSORequests(stream tsoStream, dcLocation string, tbc *tsoBatchController, opts []opentracing.StartSpanOption) error {
func (c *tsoClient) processRequests(stream tsoStream, dcLocation string, tbc *tsoBatchController, opts []opentracing.StartSpanOption) error {
if len(opts) > 0 {
span := opentracing.StartSpan("pdclient.processTSORequests", opts...)
span := opentracing.StartSpan("pdclient.processRequests", opts...)
defer span.Finish()
}

requests := tbc.getCollectedRequests()
count := int64(len(requests))
physical, logical, suffixBits, err := stream.processRequests(c.svcDiscovery.GetClusterID(), dcLocation, requests, tbc.batchStartTime)
if err != nil {
c.finishTSORequest(requests, 0, 0, 0, err)
c.finishRequest(requests, 0, 0, 0, err)
return err
}
// `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
firstLogical := addLogical(logical, -count+1, suffixBits)
c.compareAndSwapTS(dcLocation, physical, firstLogical, suffixBits, count)
c.finishTSORequest(requests, physical, firstLogical, suffixBits, nil)
c.finishRequest(requests, physical, firstLogical, suffixBits, nil)
return nil
}

@@ -729,8 +729,9 @@ func (c *tsoClient) compareAndSwapTS(dcLocation string, physical, firstLogical i
lastTSOPointer := lastTSOInterface.(*lastTSO)
lastPhysical := lastTSOPointer.physical
lastLogical := lastTSOPointer.logical
// The TSO we get is a range like [largestLogical-count+1, largestLogical], so we save the last TSO's largest logical to compare with the new TSO's first logical.
// For example, if we have a TSO resp with logical 10, count 5, then all TSOs we get will be [6, 7, 8, 9, 10].
// The TSO we get is a range like [largestLogical-count+1, largestLogical], so we save the last TSO's largest logical
// to compare with the new TSO's first logical. For example, if we have a TSO resp with logical 10, count 5, then
// all TSOs we get will be [6, 7, 8, 9, 10].
if tsLessEqual(physical, firstLogical, lastPhysical, lastLogical) {
panic(errors.Errorf("%s timestamp fallback, newly acquired ts (%d, %d) is less or equal to last one (%d, %d)",
dcLocation, physical, firstLogical, lastPhysical, lastLogical))
@@ -747,7 +747,7 @@ func tsLessEqual(physical, logical, thatPhysical, thatLogical int64) bool {
return physical < thatPhysical
}

func (c *tsoClient) finishTSORequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32, err error) {
func (c *tsoClient) finishRequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32, err error) {
for i := 0; i < len(requests); i++ {
if span := opentracing.SpanFromContext(requests[i].requestCtx); span != nil {
span.Finish()
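The comment reflowed in the `compareAndSwapTS` hunk above describes how a batched TSO response is interpreted: the response carries the largest logical value, and the first logical of the range is recovered with `addLogical(logical, -count+1, suffixBits)`. A small worked example under the assumption of zero suffix bits (the helper mirrors the client's `addLogical`; the package and function names are illustrative):

```go
// A worked example of the range arithmetic, assuming zero suffix bits;
// addLogical mirrors the client helper, the rest is illustrative.
package tsodispatchersketch

import "fmt"

// addLogical shifts the count by the suffix bit width before adding it to
// the logical part, matching the helper used by the dispatcher.
func addLogical(logical, count int64, suffixBits uint32) int64 {
	return logical + count<<suffixBits
}

// Demo shows that a response with largest logical 10 and count 5 covers the
// logical range [6, 10].
func Demo() {
	var (
		largestLogical int64  = 10
		count          int64  = 5
		suffixBits     uint32 = 0
	)
	firstLogical := addLogical(largestLogical, -count+1, suffixBits)
	for i := int64(0); i < count; i++ {
		fmt.Println(addLogical(firstLogical, i, suffixBits)) // 6, 7, 8, 9, 10
	}
}
```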
6 changes: 4 additions & 2 deletions client/tso_stream.go
@@ -143,7 +143,8 @@ func (s *pdTSOStream) processRequests(clusterID uint64, dcLocation string, reque
return
}

physical, logical, suffixBits = resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits()
ts := resp.GetTimestamp()
physical, logical, suffixBits = ts.GetPhysical(), ts.GetLogical(), ts.GetSuffixBits()
return
}

@@ -189,6 +190,7 @@ func (s *tsoTSOStream) processRequests(clusterID uint64, dcLocation string, requ
return
}

physical, logical, suffixBits = resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits()
ts := resp.GetTimestamp()
physical, logical, suffixBits = ts.GetPhysical(), ts.GetLogical(), ts.GetSuffixBits()
return
}
205 changes: 11 additions & 194 deletions pkg/mcs/tso/server/grpc_service.go
@@ -20,28 +20,19 @@ import (
"net/http"
"time"

"github.com/pingcap/kvproto/pkg/pdpb"
"github.com/pingcap/kvproto/pkg/tsopb"
"github.com/pingcap/log"
"github.com/pkg/errors"
bs "github.com/tikv/pd/pkg/basicserver"
"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/pkg/mcs/registry"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/grpcutil"
"github.com/tikv/pd/pkg/utils/logutil"
"go.uber.org/zap"
"github.com/tikv/pd/pkg/utils/tsoutil"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const (
// tso
maxMergeTSORequests = 10000
defaultTSOProxyTimeout = 3 * time.Second
)

// gRPC errors
var (
ErrNotStarted = status.Errorf(codes.Unavailable, "server not started")
@@ -116,16 +107,20 @@ func (s *Service) Tso(stream tsopb.TSO_TsoServer) error {
streamCtx := stream.Context()
forwardedHost := grpcutil.GetForwardedHost(streamCtx)
if !s.IsLocalRequest(forwardedHost) {
clientConn, err := s.GetDelegateClient(s.ctx, forwardedHost)
if err != nil {
return errors.WithStack(err)
}

if errCh == nil {
doneCh = make(chan struct{})
defer close(doneCh)
errCh = make(chan error)
}
s.dispatchTSORequest(ctx, &tsoRequest{
forwardedHost,
request,
stream,
}, forwardedHost, doneCh, errCh)

tsoProtoFactory := s.tsoProtoFactory
tsoRequest := tsoutil.NewTSOProtoRequest(forwardedHost, clientConn, request, stream)
s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, tsoProtoFactory, doneCh, errCh)
continue
}

@@ -138,7 +133,7 @@
return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", s.clusterID, request.GetHeader().GetClusterId())
}
count := request.GetCount()
ts, err := s.tsoAllocatorManager.HandleTSORequest(request.GetDcLocation(), count)
ts, err := s.tsoAllocatorManager.HandleRequest(request.GetDcLocation(), count)
if err != nil {
return status.Errorf(codes.Unknown, err.Error())
}
@@ -174,181 +169,3 @@ func (s *Service) errorHeader(err *tsopb.Error) *tsopb.ResponseHeader {
Error: err,
}
}

type tsoRequest struct {
forwardedHost string
request *tsopb.TsoRequest
stream tsopb.TSO_TsoServer
}

func (s *Service) dispatchTSORequest(ctx context.Context, request *tsoRequest, forwardedHost string, doneCh <-chan struct{}, errCh chan<- error) {
tsoRequestChInterface, loaded := s.tsoDispatcher.LoadOrStore(forwardedHost, make(chan *tsoRequest, maxMergeTSORequests))
if !loaded {
tsDeadlineCh := make(chan deadline, 1)
go s.handleDispatcher(ctx, forwardedHost, tsoRequestChInterface.(chan *tsoRequest), tsDeadlineCh, doneCh, errCh)
go watchTSDeadline(ctx, tsDeadlineCh)
}
tsoRequestChInterface.(chan *tsoRequest) <- request
}

func (s *Service) handleDispatcher(ctx context.Context, forwardedHost string, tsoRequestCh <-chan *tsoRequest, tsDeadlineCh chan<- deadline, doneCh <-chan struct{}, errCh chan<- error) {
defer logutil.LogPanic()
dispatcherCtx, ctxCancel := context.WithCancel(ctx)
defer ctxCancel()
defer s.tsoDispatcher.Delete(forwardedHost)

var (
forwardStream tsopb.TSO_TsoClient
cancel context.CancelFunc
)
client, err := s.GetDelegateClient(ctx, forwardedHost)
if err != nil {
goto errHandling
}
log.Info("create tso forward stream", zap.String("forwarded-host", forwardedHost))
forwardStream, cancel, err = s.CreateTsoForwardStream(client)
errHandling:
if err != nil || forwardStream == nil {
log.Error("create tso forwarding stream error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCCreateStream, err))
select {
case <-dispatcherCtx.Done():
return
case _, ok := <-doneCh:
if !ok {
return
}
case errCh <- err:
close(errCh)
return
}
}
defer cancel()

requests := make([]*tsoRequest, maxMergeTSORequests+1)
for {
select {
case first := <-tsoRequestCh:
pendingTSOReqCount := len(tsoRequestCh) + 1
requests[0] = first
for i := 1; i < pendingTSOReqCount; i++ {
requests[i] = <-tsoRequestCh
}
done := make(chan struct{})
dl := deadline{
timer: time.After(defaultTSOProxyTimeout),
done: done,
cancel: cancel,
}
select {
case tsDeadlineCh <- dl:
case <-dispatcherCtx.Done():
return
}
err = s.processTSORequests(forwardStream, requests[:pendingTSOReqCount])
close(done)
if err != nil {
log.Error("proxy forward tso error", zap.String("forwarded-host", forwardedHost), errs.ZapError(errs.ErrGRPCSend, err))
select {
case <-dispatcherCtx.Done():
return
case _, ok := <-doneCh:
if !ok {
return
}
case errCh <- err:
close(errCh)
return
}
}
case <-dispatcherCtx.Done():
return
}
}
}

func (s *Service) processTSORequests(forwardStream tsopb.TSO_TsoClient, requests []*tsoRequest) error {
start := time.Now()
// Merge the requests
count := uint32(0)
for _, request := range requests {
count += request.request.GetCount()
}
req := &tsopb.TsoRequest{
Header: requests[0].request.GetHeader(),
Count: count,
// TODO: support Local TSO proxy forwarding.
DcLocation: requests[0].request.GetDcLocation(),
}
// Send to the leader stream.
if err := forwardStream.Send(req); err != nil {
return err
}
resp, err := forwardStream.Recv()
if err != nil {
return err
}
tsoProxyHandleDuration.Observe(time.Since(start).Seconds())
tsoProxyBatchSize.Observe(float64(count))
// Split the response
physical, logical, suffixBits := resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical(), resp.GetTimestamp().GetSuffixBits()
// `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
// This is different from the logic of client batch, for example, if we have a largest ts whose logical part is 10,
// count is 5, then the splitting results should be 5 and 10.
firstLogical := addLogical(logical, -int64(count), suffixBits)
return s.finishTSORequest(requests, physical, firstLogical, suffixBits)
}

// Because of the suffix, we need to shift the count before we add it to the logical part.
func addLogical(logical, count int64, suffixBits uint32) int64 {
return logical + count<<suffixBits
}

func (s *Service) finishTSORequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32) error {
countSum := int64(0)
for i := 0; i < len(requests); i++ {
count := requests[i].request.GetCount()
countSum += int64(count)
response := &tsopb.TsoResponse{
Header: s.header(),
Count: count,
Timestamp: &pdpb.Timestamp{
Physical: physical,
Logical: addLogical(firstLogical, countSum, suffixBits),
SuffixBits: suffixBits,
},
}
// Send back to the client.
if err := requests[i].stream.Send(response); err != nil {
return err
}
}
return nil
}

type deadline struct {
timer <-chan time.Time
done chan struct{}
cancel context.CancelFunc
}

func watchTSDeadline(ctx context.Context, tsDeadlineCh <-chan deadline) {
defer logutil.LogPanic()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for {
select {
case d := <-tsDeadlineCh:
select {
case <-d.timer:
log.Error("tso proxy request processing is canceled due to timeout", errs.ZapError(errs.ErrProxyTSOTimeout))
d.cancel()
case <-d.done:
continue
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
}
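The proxy code removed above (and, per the PR title, now shared through `pkg/utils/tsoutil`) merges the `Count` of every pending request into one `TsoRequest` and then splits the single response back out, which is why `finishTSORequest` hands request *i* a logical value of `firstLogical + countSum(i)`, whereas the client-side batch instead steps back by `-count+1` to the first logical of its range. The sketch below reproduces only that splitting arithmetic with plain types; `splitMergedResponse` and `splitTS` are illustrative names, not the tsoutil API.

```go
// A sketch of the splitting arithmetic only, with plain types; not the
// tsoutil API introduced by this PR.
package tsoproxysketch

type splitTS struct {
	physical   int64
	logical    int64 // largest logical value granted to this request
	suffixBits uint32
}

func addLogical(logical, count int64, suffixBits uint32) int64 {
	return logical + count<<suffixBits
}

// splitMergedResponse distributes one merged TSO response across the original
// per-request counts, in the order the requests were merged.
func splitMergedResponse(physical, largestLogical int64, suffixBits uint32, counts []uint32) []splitTS {
	var total int64
	for _, c := range counts {
		total += int64(c)
	}
	// Step back by the full merged count to get the logical value just
	// before the batch, as processTSORequests does above.
	firstLogical := addLogical(largestLogical, -total, suffixBits)
	out := make([]splitTS, 0, len(counts))
	countSum := int64(0)
	for _, c := range counts {
		countSum += int64(c)
		out = append(out, splitTS{
			physical:   physical,
			logical:    addLogical(firstLogical, countSum, suffixBits),
			suffixBits: suffixBits,
		})
	}
	return out
}
```

Splitting the comment's example (largest logical 10, merged count 5) across per-request counts of 2 and 3 yields logicals 7 and 10, i.e. the two requests own the sub-ranges (5, 7] and (7, 10].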