From 0d0588d67bd871b4396493e286e9ea4007a08601 Mon Sep 17 00:00:00 2001 From: "wangzekun.zekin" Date: Tue, 21 Mar 2023 19:15:18 +0800 Subject: [PATCH] set timeout even timeout <= 0 --- pkg/common/config/request_option.go | 33 +++++++++++++++++ pkg/common/test/mock/network.go | 16 ++++++--- pkg/protocol/client/client.go | 6 ++-- pkg/protocol/http1/client.go | 55 +++++++++++++---------------- pkg/protocol/http1/client_test.go | 6 ++-- pkg/protocol/request.go | 15 -------- 6 files changed, 77 insertions(+), 54 deletions(-) diff --git a/pkg/common/config/request_option.go b/pkg/common/config/request_option.go index 3737ac57a..891a3bb02 100644 --- a/pkg/common/config/request_option.go +++ b/pkg/common/config/request_option.go @@ -27,6 +27,10 @@ type RequestOptions struct { dialTimeout time.Duration readTimeout time.Duration writeTimeout time.Duration + // Request timeout. Usually set by DoDeadline or DoTimeout + // if <= 0, means not set + requestTimeout time.Duration + begin time.Time } // RequestOption is the only struct to set request-level options. @@ -95,6 +99,25 @@ func WithWriteTimeout(t time.Duration) RequestOption { }} } +// WithRequestTimeout sets whole request timeout. If it reaches timeout, +// the client will return. +// +// This is the request level configuration. +func WithRequestTimeout(t time.Duration) RequestOption { + return RequestOption{F: func(o *RequestOptions) { + o.requestTimeout = t + }} +} + +// WithBegin sets the request begin time. +// +// This is the request level configuration. +func WithBegin(t time.Time) RequestOption { + return RequestOption{F: func(o *RequestOptions) { + o.begin = t + }} +} + func (o *RequestOptions) Apply(opts []RequestOption) { for _, op := range opts { op.F(o) @@ -125,6 +148,14 @@ func (o *RequestOptions) WriteTimeout() time.Duration { return o.writeTimeout } +func (o *RequestOptions) RequestTimeout() time.Duration { + return o.requestTimeout +} + +func (o *RequestOptions) Begin() time.Time { + return o.begin +} + func (o *RequestOptions) CopyTo(dst *RequestOptions) { if dst.tags == nil { dst.tags = make(map[string]string) @@ -138,6 +169,8 @@ func (o *RequestOptions) CopyTo(dst *RequestOptions) { dst.readTimeout = o.readTimeout dst.writeTimeout = o.writeTimeout dst.dialTimeout = o.dialTimeout + dst.requestTimeout = o.requestTimeout + dst.begin = o.begin } // SetPreDefinedOpts Pre define some RequestOption here diff --git a/pkg/common/test/mock/network.go b/pkg/common/test/mock/network.go index 389c47266..83c96bbda 100644 --- a/pkg/common/test/mock/network.go +++ b/pkg/common/test/mock/network.go @@ -42,7 +42,7 @@ type Recorder interface { func (m *Conn) SetWriteTimeout(t time.Duration) error { // TODO implement me - panic("implement me") + return nil } type SlowReadConn struct { @@ -50,7 +50,11 @@ type SlowReadConn struct { } func (m *SlowReadConn) SetWriteTimeout(t time.Duration) error { - // TODO implement me + return nil +} + +func (m *SlowReadConn) SetReadTimeout(t time.Duration) error { + m.Conn.readTimeout = t return nil } @@ -133,7 +137,11 @@ func (r *recorder) WroteLen() int { func (m *SlowReadConn) Peek(i int) ([]byte, error) { b, err := m.zr.Peek(i) - time.Sleep(100 * time.Millisecond) + if m.readTimeout > 0 { + time.Sleep(m.readTimeout) + } else { + time.Sleep(100 * time.Millisecond) + } if err != nil || len(b) != i { time.Sleep(m.readTimeout) return nil, errs.ErrReadTimeout @@ -152,7 +160,7 @@ func NewConn(source string) *Conn { } func NewSlowReadConn(source string) *SlowReadConn { - return &SlowReadConn{NewConn(source)} + return &SlowReadConn{Conn: NewConn(source)} } type SlowWriteConn struct { diff --git a/pkg/protocol/client/client.go b/pkg/protocol/client/client.go index 64e76a47a..0496f0ee7 100644 --- a/pkg/protocol/client/client.go +++ b/pkg/protocol/client/client.go @@ -280,7 +280,8 @@ func DoTimeout(ctx context.Context, req *protocol.Request, resp *protocol.Respon if timeout <= 0 { return errTimeout } - req.SetTimeout(timeout) + // Note: it will overwrite the reqTimeout. + req.SetOptions(config.WithRequestTimeout(timeout), config.WithBegin(time.Now())) return c.Do(ctx, req, resp) } @@ -289,6 +290,7 @@ func DoDeadline(ctx context.Context, req *protocol.Request, resp *protocol.Respo if timeout <= 0 { return errTimeout } - req.SetTimeout(timeout) + // Note: it will overwrite the reqTimeout. + req.SetOptions(config.WithRequestTimeout(timeout), config.WithBegin(time.Now())) return c.Do(ctx, req, resp) } diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go index 90050da9e..aeed399a2 100644 --- a/pkg/protocol/http1/client.go +++ b/pkg/protocol/http1/client.go @@ -379,13 +379,8 @@ func (c *HostClient) Do(ctx context.Context, req *protocol.Request, resp *protoc atomic.AddInt32(&c.pendingRequests, 1) - var before time.Time - if req.GetTimeout() > 0 { - before = time.Now() - } - for { - canIdempotentRetry, err = c.do(req, resp, before) + canIdempotentRetry, err = c.do(req, resp) if err == nil { break } @@ -428,14 +423,14 @@ func (c *HostClient) PendingRequests() int { return int(atomic.LoadInt32(&c.pendingRequests)) } -func (c *HostClient) do(req *protocol.Request, resp *protocol.Response, before time.Time) (bool, error) { +func (c *HostClient) do(req *protocol.Request, resp *protocol.Response) (bool, error) { nilResp := false if resp == nil { nilResp = true resp = protocol.AcquireResponse() } - canIdempotentRetry, err := c.doNonNilReqResp(req, resp, before) + canIdempotentRetry, err := c.doNonNilReqResp(req, resp) if nilResp { protocol.ReleaseResponse(resp) @@ -479,13 +474,13 @@ func updateReqTimeout(reqTimeout, compareTimeout time.Duration, before time.Time if left <= 0 { return true, 0 } - if compareTimeout <= 0 && left > compareTimeout { - return false, compareTimeout + if left > compareTimeout { + return false, left } - return false, left + return false, compareTimeout } -func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Response, before time.Time) (bool, error) { +func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Response) (bool, error) { if req == nil { panic("BUG: req cannot be nil") } @@ -507,10 +502,12 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo if c.DisablePathNormalizing { req.URI().DisablePathNormalizing = true } + reqTimeout := req.Options().RequestTimeout() + begin := req.Options().Begin() dialTimeout := rc.dialTimeout - if req.GetTimeout() < dialTimeout || dialTimeout == 0 { - dialTimeout = req.GetTimeout() + if reqTimeout < dialTimeout || dialTimeout == 0 { + dialTimeout = reqTimeout } cc, err := c.acquireConn(dialTimeout) // if getting connection error, fast fail @@ -527,17 +524,16 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo resp.ParseNetAddr(conn) - shouldClose, timeout := updateReqTimeout(req.GetTimeout(), rc.writeTimeout, before) + shouldClose, timeout := updateReqTimeout(reqTimeout, rc.writeTimeout, begin) if shouldClose { c.closeConn(cc) return true, errTimeout } - if timeout > 0 { - if err = conn.SetWriteTimeout(timeout); err != nil { - c.closeConn(cc) - // try another connection if retry is enabled - return true, err - } + + if err = conn.SetWriteTimeout(timeout); err != nil { + c.closeConn(cc) + // try another connection if retry is enabled + return true, err } resetConnection := false @@ -595,19 +591,18 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo return true, err } - shouldClose, timeout = updateReqTimeout(req.GetTimeout(), rc.readTimeout, before) + shouldClose, timeout = updateReqTimeout(reqTimeout, rc.readTimeout, begin) if shouldClose { c.closeConn(cc) return true, errTimeout } - if timeout > 0 { - // Set Deadline every time, since golang has fixed the performance issue - // See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details - if err = conn.SetReadTimeout(timeout); err != nil { - c.closeConn(cc) - // try another connection if retry is enabled - return true, err - } + + // Set Deadline every time, since golang has fixed the performance issue + // See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details + if err = conn.SetReadTimeout(timeout); err != nil { + c.closeConn(cc) + // try another connection if retry is enabled + return true, err } if customSkipBody || req.Header.IsHead() || req.Header.IsConnect() { diff --git a/pkg/protocol/http1/client_test.go b/pkg/protocol/http1/client_test.go index b2b8e0313..1e4e0139f 100644 --- a/pkg/protocol/http1/client_test.go +++ b/pkg/protocol/http1/client_test.go @@ -80,7 +80,7 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) { return mock.SlowReadDialer(addr) }), MaxConns: 1, - MaxConnWaitTimeout: 120 * time.Millisecond, + MaxConnWaitTimeout: 50 * time.Millisecond, }, Addr: "foobar", } @@ -279,7 +279,7 @@ func TestDoNonNilReqResp(t *testing.T) { req := protocol.AcquireRequest() resp := protocol.AcquireResponse() req.SetHost("foobar") - retry, err := c.doNonNilReqResp(req, resp, time.Now()) + retry, err := c.doNonNilReqResp(req, resp) assert.False(t, retry) assert.Nil(t, err) assert.DeepEqual(t, resp.StatusCode(), 400) @@ -300,7 +300,7 @@ func TestDoNonNilReqResp1(t *testing.T) { req := protocol.AcquireRequest() resp := protocol.AcquireResponse() req.SetHost("foobar") - retry, err := c.doNonNilReqResp(req, resp, time.Now()) + retry, err := c.doNonNilReqResp(req, resp) assert.True(t, retry) assert.NotNil(t, err) } diff --git a/pkg/protocol/request.go b/pkg/protocol/request.go index 7e30e5b9d..d10edd23c 100644 --- a/pkg/protocol/request.go +++ b/pkg/protocol/request.go @@ -51,7 +51,6 @@ import ( "net/url" "strings" "sync" - "time" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" @@ -114,10 +113,6 @@ type Request struct { // Request level options, service discovery options etc. options *config.RequestOptions - - // Request timeout. Usually set by DoDeadline or DoTimeout - // if <= 0, means not set - timeout time.Duration } type requestBodyWriter struct { @@ -160,14 +155,6 @@ func (req *Request) AppendBody(p []byte) { req.BodyBuffer().Write(p) //nolint:errcheck } -func (req *Request) SetTimeout(timeout time.Duration) { - req.timeout = timeout -} - -func (req *Request) GetTimeout() time.Duration { - return req.timeout -} - func (req *Request) BodyBuffer() *bytebufferpool.ByteBuffer { if req.body == nil { req.body = requestBodyPool.Get() @@ -218,7 +205,6 @@ func (req *Request) Reset() { req.CloseBodyStream() req.options = nil - req.timeout = 0 } func (req *Request) IsURIParsed() bool { @@ -398,7 +384,6 @@ func (req *Request) CopyToSkipBody(dst *Request) { req.options.CopyTo(dst.options) } - dst.timeout = req.timeout // do not copy multipartForm - it will be automatically // re-created on the first call to MultipartForm. }