Skip to content

Commit

Permalink
Merge pull request #508 from TarsCloud/feat/lbbniu/timeout
Browse files Browse the repository at this point in the history
Implementing Inter-Service Timeout Propagation using context.Context
  • Loading branch information
lbbniu authored Jan 13, 2024
2 parents b937668 + 0d66280 commit f016edd
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 48 deletions.
14 changes: 11 additions & 3 deletions tars/servant.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/TarsCloud/TarsGo/tars/protocol/res/requestf"
"github.com/TarsCloud/TarsGo/tars/util/current"
"github.com/TarsCloud/TarsGo/tars/util/endpoint"
"github.com/TarsCloud/TarsGo/tars/util/rtimer"
"github.com/TarsCloud/TarsGo/tars/util/tools"
)

Expand Down Expand Up @@ -159,17 +158,26 @@ func (s *ServantProxy) TarsInvoke(ctx context.Context, cType byte,
msg := &Message{Req: &req, Ser: s, Resp: resp}
msg.Init()

timeout := time.Duration(s.timeout) * time.Millisecond
if ok, hashType, hashCode, isHash := current.GetClientHash(ctx); ok {
msg.isHash = isHash
msg.hashType = HashType(hashType)
msg.hashCode = hashCode
}

timeout := time.Duration(s.timeout) * time.Millisecond
if ok, to, isTimeout := current.GetClientTimeout(ctx); ok && isTimeout {
timeout = time.Duration(to) * time.Millisecond
req.ITimeout = int32(to)
}
// timeout delivery
if dl, ok := ctx.Deadline(); ok {
timeout = time.Until(dl)
req.ITimeout = int32(timeout / time.Millisecond)
} else {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}

var err error
s.manager.preInvoke()
Expand Down Expand Up @@ -253,7 +261,7 @@ func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time.
return nil
}
select {
case <-rtimer.After(timeout):
case <-ctx.Done():
msg.Status = basef.TARSINVOKETIMEOUT
adp.failAdd()
msg.End()
Expand Down
99 changes: 56 additions & 43 deletions tars/tarsprotocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ func (s *Protocol) Invoke(ctx context.Context, req []byte) (rsp []byte) {
is := codec.NewReader(req[4:])
reqPackage.ReadFrom(is)

recvPkgTs, ok := current.GetRecvPkgTsFromContext(ctx)
if !ok {
recvPkgTs = time.Now().UnixNano() / 1e6
}

// timeout delivery
now := time.Now().UnixNano() / 1e6
if reqPackage.ITimeout > 0 {
sub := now - recvPkgTs // coroutine scheduling time difference
timeout := int64(reqPackage.ITimeout) - sub
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Millisecond)
defer cancel()
}

if reqPackage.HasMessageType(basef.TARSMESSAGETYPEDYED) {
if dyeingKey, ok := reqPackage.Status[current.StatusDyedKey]; ok {
if ok = current.SetDyeingKey(ctx, dyeingKey); !ok {
Expand All @@ -62,10 +77,6 @@ func (s *Protocol) Invoke(ctx context.Context, req []byte) (rsp []byte) {
}
}

recvPkgTs, ok := current.GetRecvPkgTsFromContext(ctx)
if !ok {
recvPkgTs = time.Now().UnixNano() / 1e6
}
if reqPackage.CPacketType == basef.TARSONEWAY {
defer func() {
endTime := time.Now().UnixNano() / 1e6
Expand All @@ -81,52 +92,54 @@ func (s *Protocol) Invoke(ctx context.Context, req []byte) (rsp []byte) {
rspPackage.IVersion = reqPackage.IVersion
rspPackage.IRequestId = reqPackage.IRequestId

// Improve server timeout handling
now := time.Now().UnixNano() / 1e6
if ok && reqPackage.ITimeout > 0 && now-recvPkgTs > int64(reqPackage.ITimeout) {
select {
case <-ctx.Done():
rspPackage.IRet = basef.TARSSERVERQUEUETIMEOUT
rspPackage.SResultDesc = "server invoke timeout"
ip, _ := current.GetClientIPFromContext(ctx)
port, _ := current.GetClientPortFromContext(ctx)
TLOG.Errorf("handle queue timeout, obj:%s, func:%s, recv time:%d, now:%d, timeout:%d, cost:%d, addr:(%s:%s), reqId:%d",
reqPackage.SServantName, reqPackage.SFuncName, recvPkgTs, now, reqPackage.ITimeout, now-recvPkgTs, ip, port, reqPackage.IRequestId)
} else if reqPackage.SFuncName != "tars_ping" { // not tars_ping, normal business call branch
if s.withContext {
if ok = current.SetRequestStatus(ctx, reqPackage.Status); !ok {
TLOG.Error("Set request status in context fail!")
}
if ok = current.SetRequestContext(ctx, reqPackage.Context); !ok {
TLOG.Error("Set request context in context fail!")
}
}
var err error
if s.app.allFilters.sf != nil {
err = s.app.allFilters.sf(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
} else if sf := s.app.getMiddlewareServerFilter(); sf != nil {
err = sf(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
} else {
// execute pre server filters
for i, v := range s.app.allFilters.preSfs {
err = v(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
if err != nil {
TLOG.Errorf("Pre filter error, No.%v, err: %v", i, err)
TLOG.Errorf("handle queue timeout, obj:%s, func:%s, recv time:%d, now:%d, timeout:%d, cost:%d, addr:(%s:%s), reqId:%d, err: %v",
reqPackage.SServantName, reqPackage.SFuncName, recvPkgTs, now, reqPackage.ITimeout, now-recvPkgTs, ip, port, reqPackage.IRequestId, ctx.Err())
default:
if reqPackage.SFuncName != "tars_ping" { // not tars_ping, normal business call branch
if s.withContext {
if ok = current.SetRequestStatus(ctx, reqPackage.Status); !ok {
TLOG.Error("Set request status in context fail!")
}
if ok = current.SetRequestContext(ctx, reqPackage.Context); !ok {
TLOG.Error("Set request context in context fail!")
}
}
// execute business server
err = s.dispatcher.Dispatch(ctx, s.serverImp, &reqPackage, &rspPackage, s.withContext)
// execute post server filters
for i, v := range s.app.allFilters.postSfs {
err = v(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
if err != nil {
TLOG.Errorf("Post filter error, No.%v, err: %v", i, err)
var err error
if s.app.allFilters.sf != nil {
err = s.app.allFilters.sf(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
} else if sf := s.app.getMiddlewareServerFilter(); sf != nil {
err = sf(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
} else {
// execute pre server filters
for i, v := range s.app.allFilters.preSfs {
err = v(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
if err != nil {
TLOG.Errorf("Pre filter error, No.%v, err: %v", i, err)
}
}
// execute business server
err = s.dispatcher.Dispatch(ctx, s.serverImp, &reqPackage, &rspPackage, s.withContext)
// execute post server filters
for i, v := range s.app.allFilters.postSfs {
err = v(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
if err != nil {
TLOG.Errorf("Post filter error, No.%v, err: %v", i, err)
}
}
}
}
if err != nil {
TLOG.Errorf("RequestID:%d, Found err: %v", reqPackage.IRequestId, err)
rspPackage.IRet = 1
rspPackage.SResultDesc = err.Error()
if tarsErr, ok := err.(*Error); ok {
rspPackage.IRet = tarsErr.Code
if err != nil {
TLOG.Errorf("RequestID:%d, Found err: %v", reqPackage.IRequestId, err)
rspPackage.IRet = 1
rspPackage.SResultDesc = err.Error()
if tarsErr, ok := err.(*Error); ok {
rspPackage.IRet = tarsErr.Code
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion tars/transport/tcphandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ func (t *tcpHandler) getConnContext(connSt *connInfo) context.Context {

func (t *tcpHandler) handleConn(connSt *connInfo, pkg []byte) {
// recvPkgTs are more accurate
ctx := t.getConnContext(connSt)
handler := func() {
defer atomic.AddInt32(&connSt.numInvoke, -1)
ctx := t.getConnContext(connSt)
rsp := t.server.invoke(ctx, pkg)

cPacketType, ok := current.GetPacketTypeFromContext(ctx)
Expand Down
2 changes: 1 addition & 1 deletion tars/transport/udphandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ func (u *udpHandler) Handle() error {
}
pkg := make([]byte, n)
copy(pkg, buffer[0:n])
ctx := u.getConnContext(udpAddr)
go func() {
atomic.AddInt32(&u.server.numInvoke, 1)
defer atomic.AddInt32(&u.server.numInvoke, -1)
ctx := u.getConnContext(udpAddr)
rsp := u.server.invoke(ctx, pkg) // no need to check package

cPacketType, ok := current.GetPacketTypeFromContext(ctx)
Expand Down

0 comments on commit f016edd

Please sign in to comment.