Improve tso proxy
Signed-off-by: Bin Shi <binshi.bing@gmail.com>
binshi-bing committed Jun 7, 2023
1 parent 7b1893d commit bacbd0d
Showing 8 changed files with 402 additions and 207 deletions.
34 changes: 2 additions & 32 deletions pkg/mcs/tso/server/grpc_service.go
@@ -28,7 +28,6 @@ import (
"github.com/tikv/pd/pkg/mcs/registry"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/grpcutil"
"github.com/tikv/pd/pkg/utils/tsoutil"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@@ -83,21 +82,7 @@ func (s *Service) RegisterRESTHandler(userDefineHandlers map[string]http.Handler

// Tso returns a stream of timestamps
func (s *Service) Tso(stream tsopb.TSO_TsoServer) error {
var (
doneCh chan struct{}
errCh chan error
)
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()
for {
// Prevent unnecessary performance overhead of the channel.
if errCh != nil {
select {
case err := <-errCh:
return errors.WithStack(err)
default:
}
}
request, err := stream.Recv()
if err == io.EOF {
return nil
@@ -106,24 +91,9 @@ func (s *Service) Tso(stream tsopb.TSO_TsoServer) error {
return errors.WithStack(err)
}

streamCtx := stream.Context()
forwardedHost := grpcutil.GetForwardedHost(streamCtx)
forwardedHost := grpcutil.GetForwardedHost(stream.Context())
if !s.IsLocalRequest(forwardedHost) {
clientConn, err := s.GetDelegateClient(s.ctx, forwardedHost)
if err != nil {
return errors.WithStack(err)
}

if errCh == nil {
doneCh = make(chan struct{})
defer close(doneCh)
errCh = make(chan error)
}

tsoProtoFactory := s.tsoProtoFactory
tsoRequest := tsoutil.NewTSOProtoRequest(forwardedHost, clientConn, request, stream)
s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, tsoProtoFactory, doneCh, errCh)
continue
return status.Error(codes.Unimplemented, "tso microservice does not support forwarding requests")
}

start := time.Now()
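
After this change, Service.Tso handles only local requests: instead of creating a delegate connection and handing the request to a dispatcher, a request whose forwarded host is not local is rejected immediately. Below is a minimal, standalone sketch of that control flow; handleForwardedHost and isLocal are illustrative stand-ins, not functions from the repository.

```go
package main

import (
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// handleForwardedHost mirrors the new check in Service.Tso: a non-local
// forwarded host is rejected with codes.Unimplemented instead of being
// handed to a forwarding dispatcher. isLocal stands in for s.IsLocalRequest.
func handleForwardedHost(forwardedHost string, isLocal func(string) bool) error {
	if !isLocal(forwardedHost) {
		return status.Error(codes.Unimplemented,
			"tso microservice does not support forwarding requests")
	}
	return nil // fall through to local TSO handling
}

func main() {
	// Assume an empty forwarded host means the request was sent directly to us.
	isLocal := func(host string) bool { return host == "" }
	fmt.Println(handleForwardedHost("10.0.0.2:3379", isLocal)) // Unimplemented error
	fmt.Println(handleForwardedHost("", isLocal))              // <nil>
}
```
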
3 changes: 0 additions & 3 deletions pkg/mcs/tso/server/server.go
@@ -114,9 +114,6 @@ type Server struct {
keyspaceGroupManager *tso.KeyspaceGroupManager
// Store as map[string]*grpc.ClientConn
clientConns sync.Map
// tsoDispatcher is used to dispatch the TSO requests to
// the corresponding forwarding TSO channels.
tsoDispatcher *tsoutil.TSODispatcher
// tsoProtoFactory is the abstract factory for creating tso
// related data structures defined in the tso grpc protocol
tsoProtoFactory *tsoutil.TSOProtoFactory
215 changes: 128 additions & 87 deletions pkg/utils/tsoutil/tso_dispatcher.go
@@ -16,15 +16,13 @@ package tsoutil

import (
"context"
"strings"
"sync"
"time"

"github.com/pingcap/kvproto/pkg/pdpb"
"github.com/pingcap/log"
"github.com/prometheus/client_golang/prometheus"
"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/pkg/utils/etcdutil"
"github.com/tikv/pd/pkg/utils/logutil"
"go.uber.org/zap"
"google.golang.org/grpc"
@@ -45,119 +43,159 @@ type TSODispatcher struct {
tsoProxyHandleDuration prometheus.Histogram
tsoProxyBatchSize prometheus.Histogram

ctx context.Context
// dispatchChs is used to dispatch different TSO requests to the corresponding forwarding TSO channels.
dispatchChs sync.Map // Store as map[string]chan Request
dispatchChs sync.Map // Store as map[string]chan Request (forwardedHost -> dispatch channel)
// lastErrors is used to record the last error of each forwarding TSO channel.
lastErrors sync.Map // Store as map[string]error (forwardedHost -> last error)
}

// NewTSODispatcher creates and returns a TSODispatcher
func NewTSODispatcher(tsoProxyHandleDuration, tsoProxyBatchSize prometheus.Histogram) *TSODispatcher {
func NewTSODispatcher(
ctx context.Context, tsoProxyHandleDuration, tsoProxyBatchSize prometheus.Histogram,
) *TSODispatcher {
tsoDispatcher := &TSODispatcher{
ctx: ctx,
tsoProxyHandleDuration: tsoProxyHandleDuration,
tsoProxyBatchSize: tsoProxyBatchSize,
}
return tsoDispatcher
}

// DispatchRequest is the entry point for dispatching/forwarding a tso request to the destination host
func (s *TSODispatcher) DispatchRequest(
ctx context.Context,
req Request,
tsoProtoFactory ProtoFactory,
doneCh <-chan struct{},
errCh chan<- error,
tsoPrimaryWatchers ...*etcdutil.LoopWatcher) {
// GetAndDeleteLastError gets and deletes the last error of the forwarded host
func (s *TSODispatcher) GetAndDeleteLastError(forwardedHost string) error {
if val, loaded := s.lastErrors.LoadAndDelete(forwardedHost); loaded {
return val.(error)
}
return nil
}

// DispatchRequest is the entry point for dispatching/forwarding a tso request to the destination host
func (s *TSODispatcher) DispatchRequest(req Request, tsoProtoFactory ProtoFactory) {
val, loaded := s.dispatchChs.LoadOrStore(req.getForwardedHost(), make(chan Request, maxMergeRequests))
reqCh := val.(chan Request)
if !loaded {
tsDeadlineCh := make(chan deadline, 1)
go s.dispatch(ctx, tsoProtoFactory, req.getForwardedHost(), req.getClientConn(), reqCh, tsDeadlineCh, doneCh, errCh, tsoPrimaryWatchers...)
go watchTSDeadline(ctx, tsDeadlineCh)
go s.startDispatchLoop(req.getForwardedHost(), req.getClientConn(), reqCh, tsoProtoFactory)
}
reqCh <- req
}

func (s *TSODispatcher) dispatch(
ctx context.Context,
tsoProtoFactory ProtoFactory,
forwardedHost string,
clientConn *grpc.ClientConn,
tsoRequestCh <-chan Request,
tsDeadlineCh chan<- deadline,
doneCh <-chan struct{},
errCh chan<- error,
tsoPrimaryWatchers ...*etcdutil.LoopWatcher) {
// cleanup collects the unprocessed and still-queued requests when the dispatch loop exits and,
// if the loop failed, replies to them with the final forwarding error.
func (s *TSODispatcher) cleanup(
forwardedHost string, finalForwardErr error,
unprocessedRequests []Request, unprocessedReqCount int,
) {
pendingReqCount := unprocessedReqCount
pendingRequests := unprocessedRequests[:unprocessedReqCount]
val, loaded := s.dispatchChs.LoadAndDelete(forwardedHost)
if loaded {
reqCh := val.(chan Request)
waitingReqCount := len(reqCh)
for i := 0; i < waitingReqCount; i++ {
req := <-reqCh
pendingRequests = append(pendingRequests, req)
pendingReqCount++
}
}
if finalForwardErr != nil {
for i := 0; i < pendingReqCount; i++ {
if pendingRequests[i] != nil {
pendingRequests[i].sendErrorResponseAsync(finalForwardErr)
}
}
} else if pendingReqCount > 0 {
log.Warn("the dispatch loop exited with pending requests unprocessed",
zap.String("forwarded-host", forwardedHost),
zap.Int("pending-requests-count", pendingReqCount))
}
}

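// startDispatchLoop starts the dispatch loop for the forwarded host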
func (s *TSODispatcher) startDispatchLoop(
forwardedHost string, clientConn *grpc.ClientConn,
tsoRequestCh <-chan Request, tsoProtoFactory ProtoFactory,
) {
defer logutil.LogPanic()
dispatcherCtx, ctxCancel := context.WithCancel(ctx)
defer ctxCancel()
defer s.dispatchChs.Delete(forwardedHost)
ctx, cancel := context.WithCancel(s.ctx)
defer cancel()

// forwardErr indicates the failure in the forwarding stream which causes the dispatch loop to exit.
var (
forwardErr error
forwardStream stream
)
pendingRequests := make([]Request, maxMergeRequests+1)
pendingTSOReqCount := 0

log.Info("start the dispatch loop", zap.String("forwarded-host", forwardedHost))
defer func() {
log.Info("exiting from the dispatch loop. cleaning up the pending requests",
zap.String("forwarded-host", forwardedHost))
if forwardStream != nil {
forwardStream.closeSend()
}
s.cleanup(forwardedHost, forwardErr, pendingRequests, pendingTSOReqCount)
log.Info("the dispatch loop exited", zap.String("forwarded-host", forwardedHost))
}()

forwardStream, cancel, err := tsoProtoFactory.createForwardStream(ctx, clientConn)
if err != nil || forwardStream == nil {
forwardStream, _, forwardErr = tsoProtoFactory.createForwardStream(ctx, clientConn)
if forwardErr != nil {
log.Error("create tso forwarding stream error",
zap.String("forwarded-host", forwardedHost),
errs.ZapError(errs.ErrGRPCCreateStream, err))
select {
case <-dispatcherCtx.Done():
return
case _, ok := <-doneCh:
if !ok {
return
}
case errCh <- err:
close(errCh)
return
}
errs.ZapError(errs.ErrGRPCCreateStream, forwardErr))
s.lastErrors.Store(forwardedHost, forwardErr)
return
}
defer cancel()

requests := make([]Request, maxMergeRequests+1)
needUpdateServicePrimaryAddr := len(tsoPrimaryWatchers) > 0 && tsoPrimaryWatchers[0] != nil
tsDeadlineCh := make(chan deadline, 1)
go watchTSDeadline(ctx, forwardedHost, tsDeadlineCh)

for {
select {
case first := <-tsoRequestCh:
pendingTSOReqCount := len(tsoRequestCh) + 1
requests[0] = first
pendingTSOReqCount = len(tsoRequestCh) + 1
pendingRequests[0] = first
for i := 1; i < pendingTSOReqCount; i++ {
requests[i] = <-tsoRequestCh
}
done := make(chan struct{})
dl := deadline{
timer: time.After(DefaultTSOProxyTimeout),
done: done,
cancel: cancel,
pendingRequests[i] = <-tsoRequestCh
}
select {
case tsDeadlineCh <- dl:
case <-dispatcherCtx.Done():
return
}
err = s.processRequests(forwardStream, requests[:pendingTSOReqCount], tsoProtoFactory)
close(done)
if err != nil {
forwardErr = s.processRequestsWithDeadLine(
ctx, tsDeadlineCh, forwardStream, pendingRequests[:pendingTSOReqCount], tsoProtoFactory)
if forwardErr != nil {
log.Error("proxy forward tso error",
zap.String("forwarded-host", forwardedHost),
errs.ZapError(errs.ErrGRPCSend, err))
if needUpdateServicePrimaryAddr && strings.Contains(err.Error(), errs.NotLeaderErr) {
tsoPrimaryWatchers[0].ForceLoad()
}
select {
case <-dispatcherCtx.Done():
return
case _, ok := <-doneCh:
if !ok {
return
}
case errCh <- err:
close(errCh)
return
}
errs.ZapError(errs.ErrGRPCSend, forwardErr))
s.lastErrors.Store(forwardedHost, forwardErr)
return
}
case <-dispatcherCtx.Done():
// All requests are processed successfully, reset this counter to avoid unnecessary cleanup.
pendingTSOReqCount = 0
case <-ctx.Done():
return
}
}
}

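// processRequestsWithDeadLine registers a deadline for this batch with the per-host deadline
// watcher before forwarding it, and marks the deadline done once processRequests returns.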
func (s *TSODispatcher) processRequestsWithDeadLine(
ctx context.Context, tsDeadlineCh chan<- deadline, forwardStream stream,
requests []Request, tsoProtoFactory ProtoFactory,
) error {
cctx, cancel := context.WithCancel(ctx)
defer cancel()
done := make(chan struct{})
dl := deadline{
timer: time.After(DefaultTSOProxyTimeout),
done: done,
cancel: cancel,
}
select {
case tsDeadlineCh <- dl:
case <-cctx.Done():
return nil
}
err := s.processRequests(forwardStream, requests, tsoProtoFactory)
close(done)
return err
}

func (s *TSODispatcher) processRequests(forwardStream stream, requests []Request, tsoProtoFactory ProtoFactory) error {
// Merge the requests
count := uint32(0)
@@ -179,24 +217,20 @@ func (s *TSODispatcher) processRequests(forwardStream stream, requests []Request
// This is different from the logic of client batch, for example, if we have a largest ts whose logical part is 10,
// count is 5, then the splitting results should be 5 and 10.
firstLogical := addLogical(logical, -int64(count), suffixBits)
return s.finishRequest(requests, physical, firstLogical, suffixBits)
s.finishRequest(requests, physical, firstLogical, suffixBits)
return nil
}

// Because of the suffix, we need to shift the count before we add it to the logical part.
func addLogical(logical, count int64, suffixBits uint32) int64 {
return logical + count<<suffixBits
}

func (s *TSODispatcher) finishRequest(requests []Request, physical, firstLogical int64, suffixBits uint32) error {
func (s *TSODispatcher) finishRequest(requests []Request, physical, firstLogical int64, suffixBits uint32) {
countSum := int64(0)
for i := 0; i < len(requests); i++ {
newCountSum, err := requests[i].postProcess(countSum, physical, firstLogical, suffixBits)
if err != nil {
return err
}
countSum = newCountSum
countSum = requests[i].sendResponseAsync(countSum, physical, firstLogical, suffixBits)
}
return nil
}

type deadline struct {
@@ -205,10 +239,17 @@ type deadline struct {
cancel context.CancelFunc
}

func watchTSDeadline(ctx context.Context, tsDeadlineCh <-chan deadline) {
func watchTSDeadline(ctx context.Context, forwardedHost string, tsDeadlineCh <-chan deadline) {
defer logutil.LogPanic()
ctx, cancel := context.WithCancel(ctx)
defer cancel()

log.Info("start to watch tso proxy request deadline", zap.String("forwarded-host", forwardedHost))
defer func() {
log.Info("tso proxy request deadline watch loop is closed",
zap.String("forwarded-host", forwardedHost))
}()

for {
select {
case d := <-tsDeadlineCh:
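
With doneCh/errCh gone, the dispatch loop now reports failures by recording the final forwarding error in the lastErrors map, keyed by forwarded host, and callers retrieve and clear it via GetAndDeleteLastError. The sketch below is a minimal, self-contained illustration of that handoff pattern; errorSink is an illustrative reduction, not a type from the repository.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// errorSink models the lastErrors handoff used by TSODispatcher: the dispatch
// loop stores the final forwarding error under its forwarded host, and the
// caller polls for it later instead of reading from a per-stream errCh.
type errorSink struct {
	lastErrors sync.Map // forwardedHost -> error
}

func (s *errorSink) record(forwardedHost string, err error) {
	s.lastErrors.Store(forwardedHost, err)
}

// getAndDelete mirrors GetAndDeleteLastError: read the error once, then clear it.
func (s *errorSink) getAndDelete(forwardedHost string) error {
	if val, loaded := s.lastErrors.LoadAndDelete(forwardedHost); loaded {
		return val.(error)
	}
	return nil
}

func main() {
	var sink errorSink
	sink.record("pd-1:2379", errors.New("create tso forwarding stream error"))

	fmt.Println(sink.getAndDelete("pd-1:2379")) // create tso forwarding stream error
	fmt.Println(sink.getAndDelete("pd-1:2379")) // <nil> -- consumed by the first read
}
```

One consequence of LoadAndDelete is that the error is delivered at most once: the first reader consumes it, which decouples the dispatch loop's lifetime from any single client stream.
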
12 changes: 12 additions & 0 deletions pkg/utils/tsoutil/tso_proto_factory.go
@@ -56,6 +56,8 @@ func (s *PDProtoFactory) createForwardStream(ctx context.Context, clientConn *gr
type stream interface {
// process sends a request and receives the response through the stream
process(clusterID uint64, count, keyspaceID, keyspaceGroupID uint32, dcLocation string) (response, error)
// closeSend closes the stream on the sender side
closeSend()
}

type tsoStream struct {
@@ -83,6 +85,11 @@ func (s *tsoStream) process(clusterID uint64, count, keyspaceID, keyspaceGroupID
return resp, nil
}

// closeSend closes the stream on the sender side
func (s *tsoStream) closeSend() {
s.stream.CloseSend()
}

type pdStream struct {
stream pdpb.PD_TsoClient
}
@@ -105,3 +112,8 @@ func (s *pdStream) process(clusterID uint64, count, _, _ uint32, dcLocation stri
}
return resp, nil
}

// closeSend closes the stream on the sender side
func (s *pdStream) closeSend() {
s.stream.CloseSend()
}
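
The new closeSend method lets the dispatch loop half-close the forwarding stream before cleaning up pending requests; both wrappers delegate to the underlying gRPC client stream's CloseSend and drop its error. Below is a small, hypothetical sketch of that wrapper pattern; sender, fakeClient, and forwardStream are stand-ins for the generated stream clients and the repository's types.

```go
package main

import "fmt"

// sender captures the only part of a gRPC client stream needed here:
// CloseSend half-closes the sending side of the stream.
type sender interface {
	CloseSend() error
}

// fakeClient is a test double standing in for the generated tsopb/pdpb stream clients.
type fakeClient struct{}

func (fakeClient) CloseSend() error {
	fmt.Println("CloseSend called")
	return nil
}

// forwardStream mirrors the shape of tsoStream/pdStream: a thin wrapper whose
// closeSend drops the error, matching the best-effort close in the diff above.
type forwardStream struct {
	stream sender
}

func (s *forwardStream) closeSend() {
	_ = s.stream.CloseSend()
}

func main() {
	fs := &forwardStream{stream: fakeClient{}}
	// The dispatch loop defers this so the forwarding stream is half-closed
	// before pending requests are cleaned up.
	defer fs.closeSend()
	fmt.Println("dispatch loop exiting")
}
```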