Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: release v0.6.5 #822

Merged
merged 9 commits into from
Jun 19, 2023
3 changes: 2 additions & 1 deletion _typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ contant = "contant"
referer = "referer"
HeaderReferer = "HeaderReferer"
expectedReferer = "expectedReferer"
Referer = "Referer"
Referer = "Referer"
O_WRONLY = "O_WRONLY"
2 changes: 1 addition & 1 deletion cmd/hz/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func New(c *cli.Context) error {
return cli.Exit(fmt.Errorf("persist manifest failed: %v", err), meta.PersistError)
}
if !args.NeedGoMod && args.IsNew() {
fmt.Println(meta.AddThriftReplace)
logs.Warn(meta.AddThriftReplace)
}

return nil
Expand Down
33 changes: 33 additions & 0 deletions pkg/app/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -1294,3 +1294,36 @@ func (ctx *RequestContext) Bind(obj interface{}) error {
func (ctx *RequestContext) Validate(obj interface{}) error {
return binding.Validate(obj)
}

// VisitAllQueryArgs calls f for each existing query arg.
//
// f must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestContext) VisitAllQueryArgs(f func(key, value []byte)) {
ctx.QueryArgs().VisitAll(f)
}

// VisitAllPostArgs calls f for each existing post arg.
//
// f must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestContext) VisitAllPostArgs(f func(key, value []byte)) {
ctx.Request.PostArgs().VisitAll(f)
}

// VisitAllHeaders calls f for each request header.
//
// f must not retain references to key and/or value after returning.
// Copy key and/or value contents before returning if you need retaining them.
//
// To get the headers in order they were received use VisitAllInOrder.
func (ctx *RequestContext) VisitAllHeaders(f func(key, value []byte)) {
ctx.Request.Header.VisitAll(f)
}

// VisitAllCookie calls f for each request cookie.
//
// f must not retain references to key and/or value after returning.
func (ctx *RequestContext) VisitAllCookie(f func(key, value []byte)) {
ctx.Request.Header.VisitAllCookie(f)
}
54 changes: 54 additions & 0 deletions pkg/app/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1248,3 +1248,57 @@ func TestRequestContext_SetCookiePathEmpty(t *testing.T) {
c.SetCookie("user", "hertz", 1, "", "localhost", protocol.CookieSameSiteDisabled, true, true)
assert.DeepEqual(t, "user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure", c.Response.Header.Get("Set-Cookie"))
}

func TestRequestContext_VisitAll(t *testing.T) {
t.Run("VisitAllQueryArgs", func(t *testing.T) {
c := NewContext(0)
var s []string
c.QueryArgs().Add("cloudwego", "hertz")
c.QueryArgs().Add("hello", "world")
c.VisitAllQueryArgs(func(key, value []byte) {
s = append(s, string(key), string(value))
})
assert.DeepEqual(t, []string{"cloudwego", "hertz", "hello", "world"}, s)
})

t.Run("VisitAllPostArgs", func(t *testing.T) {
c := NewContext(0)
var s []string
c.PostArgs().Add("cloudwego", "hertz")
c.PostArgs().Add("hello", "world")
c.VisitAllPostArgs(func(key, value []byte) {
s = append(s, string(key), string(value))
})
assert.DeepEqual(t, []string{"cloudwego", "hertz", "hello", "world"}, s)
})

t.Run("VisitAllCookie", func(t *testing.T) {
c := NewContext(0)
var s []string
c.Request.Header.Set("Cookie", "aaa=bbb;ccc=ddd")
c.VisitAllCookie(func(key, value []byte) {
s = append(s, string(key), string(value))
})
assert.DeepEqual(t, []string{"aaa", "bbb", "ccc", "ddd"}, s)
})

t.Run("VisitAllHeaders", func(t *testing.T) {
c := NewContext(0)
c.Request.Header.Set("xxx", "yyy")
c.Request.Header.Set("xxx2", "yyy2")
c.VisitAllHeaders(
func(k, v []byte) {
key := string(k)
value := string(v)
if key != "Xxx" && key != "Xxx2" {
t.Fatalf("Unexpected %v. Expected %v", key, "xxx or yyy")
}
if key == "Xxx" && value != "yyy" {
t.Fatalf("Unexpected %v. Expected %v", value, "yyy")
}
if key == "Xxx2" && value != "yyy2" {
t.Fatalf("Unexpected %v. Expected %v", value, "yyy2")
}
})
})
}
16 changes: 16 additions & 0 deletions pkg/common/test/mock/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,22 @@ func NewSlowReadConn(source string) *SlowReadConn {
return &SlowReadConn{Conn: NewConn(source)}
}

type ErrorReadConn struct {
*Conn
errorToReturn error
}

func NewErrorReadConn(err error) *ErrorReadConn {
return &ErrorReadConn{
Conn: NewConn(""),
errorToReturn: err,
}
}

func (er *ErrorReadConn) Peek(n int) ([]byte, error) {
return nil, er.errorToReturn
}

type SlowWriteConn struct {
*Conn
writeTimeout time.Duration
Expand Down
2 changes: 1 addition & 1 deletion pkg/network/standard/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (t *transport) serve() (err error) {
if err != nil {
return err
}
hlog.SystemLogger().Infof("HERTZ: HTTP server listening on address=%s", t.ln.Addr().String())
hlog.SystemLogger().Infof("HTTP server listening on address=%s", t.ln.Addr().String())
for {
ctx := context.Background()
conn, err := t.ln.Accept()
Expand Down
3 changes: 3 additions & 0 deletions pkg/protocol/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ func (h *ResponseHeader) CopyTo(dst *ResponseHeader) {
dst.server = append(dst.server[:0], h.server...)
dst.h = copyArgs(dst.h, h.h)
dst.cookies = copyArgs(dst.cookies, h.cookies)
dst.protocol = h.protocol
dst.headerLength = h.headerLength
h.Trailer().CopyTo(dst.Trailer())
}

Expand Down Expand Up @@ -1107,6 +1109,7 @@ func (h *RequestHeader) CopyTo(dst *RequestHeader) {
dst.cookies = copyArgs(dst.cookies, h.cookies)
dst.cookiesCollected = h.cookiesCollected
dst.rawHeaders = append(dst.rawHeaders[:0], h.rawHeaders...)
dst.protocol = h.protocol
}

// Peek returns header value for the given key.
Expand Down
58 changes: 58 additions & 0 deletions pkg/protocol/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -656,3 +656,61 @@ func expectResponseHeaderAll(t *testing.T, h *ResponseHeader, key string, expect
}
assert.DeepEqual(t, h.PeekAll(key), expectedValue)
}

func TestRequestHeaderCopyTo(t *testing.T) {
t.Parallel()

h, hCopy := &RequestHeader{}, &RequestHeader{}
h.SetProtocol(consts.HTTP10)
h.SetMethod(consts.MethodPatch)
h.SetNoDefaultContentType(true)
h.Add(consts.HeaderConnection, "keep-alive")
h.Add("Content-Type", "aaa")
h.Add(consts.HeaderHost, "aaabbb")
h.Add("User-Agent", "asdfas")
h.Add("Content-Length", "1123")
h.Add("Cookie", "foobar=baz")
h.Add("aaa", "aaa")
h.Add("aaa", "bbb")

h.CopyTo(hCopy)
expectRequestHeaderAll(t, hCopy, consts.HeaderConnection, [][]byte{[]byte("keep-alive")})
expectRequestHeaderAll(t, hCopy, "Content-Type", [][]byte{[]byte("aaa")})
expectRequestHeaderAll(t, hCopy, consts.HeaderHost, [][]byte{[]byte("aaabbb")})
expectRequestHeaderAll(t, hCopy, "User-Agent", [][]byte{[]byte("asdfas")})
expectRequestHeaderAll(t, hCopy, "Content-Length", [][]byte{[]byte("1123")})
expectRequestHeaderAll(t, hCopy, "Cookie", [][]byte{[]byte("foobar=baz")})
expectRequestHeaderAll(t, hCopy, "aaa", [][]byte{[]byte("aaa"), []byte("bbb")})
assert.DeepEqual(t, hCopy.GetProtocol(), consts.HTTP10)
assert.DeepEqual(t, hCopy.noDefaultContentType, true)
assert.DeepEqual(t, string(hCopy.Method()), consts.MethodPatch)
}

func TestResponseHeaderCopyTo(t *testing.T) {
t.Parallel()

h, hCopy := &ResponseHeader{}, &ResponseHeader{}
h.SetProtocol(consts.HTTP10)
h.SetHeaderLength(100)
h.SetNoDefaultContentType(true)
h.Add(consts.HeaderContentType, "aaa/bbb")
h.Add(consts.HeaderContentEncoding, "gzip")
h.Add(consts.HeaderConnection, "close")
h.Add(consts.HeaderContentLength, "1234")
h.Add(consts.HeaderServer, "aaaa")
h.Add(consts.HeaderSetCookie, "cccc")
h.Add("aaa", "aaa")
h.Add("aaa", "bbb")

h.CopyTo(hCopy)
expectResponseHeaderAll(t, hCopy, consts.HeaderContentType, [][]byte{[]byte("aaa/bbb")})
expectResponseHeaderAll(t, hCopy, consts.HeaderContentEncoding, [][]byte{[]byte("gzip")})
expectResponseHeaderAll(t, hCopy, consts.HeaderConnection, [][]byte{[]byte("close")})
expectResponseHeaderAll(t, hCopy, consts.HeaderContentLength, [][]byte{[]byte("1234")})
expectResponseHeaderAll(t, hCopy, consts.HeaderServer, [][]byte{[]byte("aaaa")})
expectResponseHeaderAll(t, hCopy, consts.HeaderSetCookie, [][]byte{[]byte("cccc")})
expectResponseHeaderAll(t, hCopy, "aaa", [][]byte{[]byte("aaa"), []byte("bbb")})
assert.DeepEqual(t, hCopy.GetProtocol(), consts.HTTP10)
assert.DeepEqual(t, hCopy.noDefaultContentType, true)
assert.DeepEqual(t, hCopy.GetHeaderLength(), 100)
}
4 changes: 2 additions & 2 deletions pkg/protocol/http1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -628,8 +628,8 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo
if !c.ResponseBodyStream {
err = respI.ReadHeaderAndLimitBody(resp, zr, c.MaxResponseBodySize)
} else {
err = respI.ReadBodyStream(resp, zr, c.MaxResponseBodySize, func() error {
if shouldCloseConn {
err = respI.ReadBodyStream(resp, zr, c.MaxResponseBodySize, func(shouldClose bool) error {
if shouldCloseConn || shouldClose {
c.closeConn(cc)
} else {
c.releaseConn(cc)
Expand Down
1 change: 1 addition & 0 deletions pkg/protocol/http1/ext/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ func (rs *bodyStream) skipRest() error {
}

// ReleaseBodyStream releases the body stream.
// Error of skipRest may be returned if there is one.
//
// NOTE: Be careful to use this method unless you know what it's for.
func ReleaseBodyStream(requestReader io.Reader) (err error) {
Expand Down
37 changes: 19 additions & 18 deletions pkg/protocol/http1/resp/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (

"github.com/cloudwego/hertz/pkg/common/bytebufferpool"
errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/hlog"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/consts"
Expand Down Expand Up @@ -107,42 +108,42 @@ func ReadHeaderAndLimitBody(resp *protocol.Response, r network.Reader, maxBodySi
}
}

var cLen int

if !resp.MustSkipBody() {
bodyBuf := resp.BodyBuffer()
bodyBuf.Reset()
bodyBuf.B, err = ext.ReadBody(r, resp.Header.ContentLength(), maxBodySize, bodyBuf.B)
if err != nil {
return err
}
cLen = len(bodyBuf.B)
}

if resp.Header.ContentLength() == -1 {
err = ext.ReadTrailer(resp.Header.Trailer(), r)
if err != nil && err != io.EOF {
return err
if resp.Header.ContentLength() == -1 {
err = ext.ReadTrailer(resp.Header.Trailer(), r)
if err != nil && err != io.EOF {
return err
}
}
}

if !resp.MustSkipBody() {
resp.Header.SetContentLength(cLen)
resp.Header.SetContentLength(len(bodyBuf.B))
}

return nil
}

type clientRespStream struct {
r io.Reader
closeCallback func() error
closeCallback func(shouldClose bool) error
}

func (c *clientRespStream) Close() (err error) {
runtime.SetFinalizer(c, nil)
ext.ReleaseBodyStream(c.r)
// If error happened in release, the connection may be in abnormal state.
// Close it in the callback in order to avoid other unexpected problems.
err = ext.ReleaseBodyStream(c.r)
shouldClose := false
if err != nil {
shouldClose = true
hlog.Warnf("connection will be closed instead of recycled because an error occurred during the stream body release: %s", err.Error())
}
if c.closeCallback != nil {
err = c.closeCallback()
err = c.closeCallback(shouldClose)
}
c.reset()
return
Expand All @@ -164,7 +165,7 @@ var clientRespStreamPool = sync.Pool{
},
}

func convertClientRespStream(bs io.Reader, fn func() error) *clientRespStream {
func convertClientRespStream(bs io.Reader, fn func(shouldClose bool) error) *clientRespStream {
clientStream := clientRespStreamPool.Get().(*clientRespStream)
clientStream.r = bs
clientStream.closeCallback = fn
Expand All @@ -173,7 +174,7 @@ func convertClientRespStream(bs io.Reader, fn func() error) *clientRespStream {
}

// ReadBodyStream reads response body in stream
func ReadBodyStream(resp *protocol.Response, r network.Reader, maxBodySize int, closeCallBack func() error) error {
func ReadBodyStream(resp *protocol.Response, r network.Reader, maxBodySize int, closeCallBack func(shouldClose bool) error) error {
resp.ResetBody()
err := ReadHeader(&resp.Header, r)
if err != nil {
Expand Down
37 changes: 33 additions & 4 deletions pkg/protocol/http1/resp/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import (
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/consts"
"github.com/cloudwego/hertz/pkg/protocol/http1/ext"
"github.com/cloudwego/netpoll"
)

Expand Down Expand Up @@ -411,15 +412,16 @@ func TestResponseReadLimitBody(t *testing.T) {
}

func TestResponseReadWithoutBody(t *testing.T) {
t.Parallel()

var resp protocol.Response

testResponseReadWithoutBody(t, &resp, "HTTP/1.1 304 Not Modified\r\nContent-Type: aa\r\nContent-Encoding: gzip\r\nContent-Length: 1235\r\n\r\n", false,
consts.StatusNotModified, 1235, "aa", nil, "gzip", consts.HTTP11)

testResponseReadWithoutBody(t, &resp, "HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTrailer: Foo\r\nContent-Encoding: deflate\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo: bar\r\n\r\n", false,
consts.StatusNoContent, -1, "aab", map[string]string{"Foo": "bar"}, "deflate", consts.HTTP11)
testResponseReadWithoutBody(t, &resp, "HTTP/1.1 200 Foo Bar\r\nContent-Type: aab\r\nTrailer: Foo\r\nContent-Encoding: deflate\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo: bar\r\n\r\nHTTP/1.2", false,
consts.StatusOK, 0, "aab", map[string]string{"Foo": "bar"}, "deflate", consts.HTTP11)

testResponseReadWithoutBody(t, &resp, "HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTrailer: Foo\r\nContent-Encoding: deflate\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo: bar\r\n\r\nHTTP/1.2", true,
consts.StatusNoContent, -1, "aab", nil, "deflate", consts.HTTP11)

testResponseReadWithoutBody(t, &resp, "HTTP/1.1 123 AAA\r\nContent-Type: xxx\r\nContent-Encoding: gzip\r\nContent-Length: 3434\r\n\r\n", false,
123, 3434, "xxx", nil, "gzip", consts.HTTP11)
Expand Down Expand Up @@ -524,6 +526,12 @@ func verifyResponseTrailer(t *testing.T, h *protocol.ResponseHeader, expectedTra
t.Fatalf("Unexpected trailer %q. Expected %q. Got %q", k, v, got)
}
}

h.Trailer().VisitAll(func(key, value []byte) {
if v := expectedTrailers[string(key)]; string(value) != v {
t.Fatalf("Unexpected trailer %q. Expected %q. Got %q", string(key), v, string(value))
}
})
}

func testResponseReadLimitBodyError(t *testing.T, s string, maxBodySize int) {
Expand Down Expand Up @@ -607,6 +615,27 @@ func testResponseBodyStreamWithTrailer(t *testing.T, body []byte, disableNormali
}
}

func TestResponseReadBodyStreamBadReader(t *testing.T) {
t.Parallel()

resp := protocol.AcquireResponse()

errReader := mock.NewErrorReadConn(errors.New("test error"))

bodyBuf := resp.BodyBuffer()
bodyBuf.Reset()

bodyStream := ext.AcquireBodyStream(bodyBuf, errReader, resp.Header.Trailer(), 100)
resp.ConstructBodyStream(bodyBuf, convertClientRespStream(bodyStream, func(shouldClose bool) error {
assert.True(t, shouldClose)
return nil
}))

stBody := resp.BodyStream()
closer, _ := stBody.(io.Closer)
closer.Close()
}

func TestSetResponseBodyStreamFixedSize(t *testing.T) {
t.Parallel()

Expand Down
Loading