Skip to content

Commit

Permalink
test: add more test
Browse files Browse the repository at this point in the history
  • Loading branch information
wzekin committed Jan 4, 2023
1 parent 699527f commit d8c5805
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 54 deletions.
6 changes: 2 additions & 4 deletions pkg/protocol/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ import (
)

var (
errBadTrailer = errs.NewPublic("contain forbidden trailer")

ServerDate atomic.Value
ServerDateOnce sync.Once // serverDateOnce.Do(updateServerDate)
)
Expand Down Expand Up @@ -165,7 +163,7 @@ func (h *ResponseHeader) AddTrailerBytes(trailer []byte) error {
}
// Forbidden by RFC 7230, section 4.1.2
if ext.IsBadTrailer(key) {
err = errBadTrailer
err = errs.NewPublicf("forbidden trailer key: %q", key)
continue
}
h.bufKV.key = append(h.bufKV.key[:0], key...)
Expand Down Expand Up @@ -323,7 +321,7 @@ func (h *RequestHeader) AddTrailerBytes(trailer []byte) error {
}
// Forbidden by RFC 7230, section 4.1.2
if ext.IsBadTrailer(key) {
err = errBadTrailer
err = errs.NewPublicf("forbidden trailer key: %q", key)
continue
}
h.bufKV.key = append(h.bufKV.key[:0], key...)
Expand Down
26 changes: 8 additions & 18 deletions pkg/protocol/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,24 +586,14 @@ func TestRequestAddTrailer(t *testing.T) {

func TestRequestAddTrailerError(t *testing.T) {
var h RequestHeader
assert.DeepEqual(t, h.AddTrailer(consts.HeaderAuthorization), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderContentEncoding), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderContentLength), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderContentType), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderContentRange), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderConnection), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderExpect), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderHost), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderKeepAlive), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderMaxForwards), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderProxyConnection), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderProxyAuthenticate), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderProxyAuthorization), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderRange), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderTE), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderTrailer), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderTransferEncoding), errBadTrailer)
assert.DeepEqual(t, h.AddTrailer(consts.HeaderWWWAuthenticate), errBadTrailer)
assert.NotNil(t, h.AddTrailer(consts.HeaderContentType))
assert.NotNil(t, h.AddTrailer(consts.HeaderProxyConnection))
}

func TestResponseAddTrailerError(t *testing.T) {
var h ResponseHeader
assert.NotNil(t, h.AddTrailer(consts.HeaderContentType))
assert.NotNil(t, h.AddTrailer(consts.HeaderProxyConnection))
}

func TestRequestSetTrailer(t *testing.T) {
Expand Down
23 changes: 23 additions & 0 deletions pkg/protocol/http1/ext/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"strings"
"testing"

"github.com/cloudwego/hertz/internal/bytestr"

"github.com/cloudwego/hertz/pkg/common/test/assert"
)

Expand All @@ -43,3 +45,24 @@ func Test_isOnlyCRLF(t *testing.T) {
assert.True(t, isOnlyCRLF([]byte("\r\n")))
assert.True(t, isOnlyCRLF([]byte("\n")))
}

func TestIsBadTrailer(t *testing.T) {
assert.True(t, IsBadTrailer(bytestr.StrAuthorization))
assert.True(t, IsBadTrailer(bytestr.StrContentEncoding))
assert.True(t, IsBadTrailer(bytestr.StrContentLength))
assert.True(t, IsBadTrailer(bytestr.StrContentType))
assert.True(t, IsBadTrailer(bytestr.StrContentRange))
assert.True(t, IsBadTrailer(bytestr.StrConnection))
assert.True(t, IsBadTrailer(bytestr.StrExpect))
assert.True(t, IsBadTrailer(bytestr.StrHost))
assert.True(t, IsBadTrailer(bytestr.StrKeepAlive))
assert.True(t, IsBadTrailer(bytestr.StrMaxForwards))
assert.True(t, IsBadTrailer(bytestr.StrProxyConnection))
assert.True(t, IsBadTrailer(bytestr.StrProxyAuthenticate))
assert.True(t, IsBadTrailer(bytestr.StrProxyAuthorization))
assert.True(t, IsBadTrailer(bytestr.StrRange))
assert.True(t, IsBadTrailer(bytestr.StrTE))
assert.True(t, IsBadTrailer(bytestr.StrTrailer))
assert.True(t, IsBadTrailer(bytestr.StrTransferEncoding))
assert.True(t, IsBadTrailer(bytestr.StrWWWAuthenticate))
}
4 changes: 2 additions & 2 deletions pkg/protocol/http1/req/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ func tryReadTrailer(h *protocol.RequestHeader, r network.Reader, n int) error {
return io.EOF
}

return fmt.Errorf("error when reading request trailer: %w", err)
return errs.NewPublicf("error when reading request trailer: %w", err)
}
b = ext.MustPeekBuffered(r)
headersLen, errParse := parseTrailer(h, b)
Expand Down Expand Up @@ -340,7 +340,7 @@ func parseTrailer(h *protocol.RequestHeader, buf []byte) (int, error) {
}
// Forbidden by RFC 7230, section 4.1.2
if ext.IsBadTrailer(s.Key) {
err = fmt.Errorf("forbidden trailer key %q", s.Key)
err = errs.NewPublicf("forbidden trailer key: %q", s.Key)
continue
}
h.AddArgBytes(s.Key, s.Value, protocol.ArgsHasValue)
Expand Down
20 changes: 10 additions & 10 deletions pkg/protocol/http1/req/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ func GetHTTP1Request(req *protocol.Request) fmt.Stringer {
//
// If MayContinue returns true, the caller must:
//
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
//
// io.EOF is returned if r is closed before reading the first header byte.
func ReadHeaderAndLimitBody(req *protocol.Request, r network.Reader, maxBodySize int, preParse ...bool) error {
Expand All @@ -133,11 +133,11 @@ func ReadHeaderAndLimitBody(req *protocol.Request, r network.Reader, maxBodySize
//
// If MayContinue returns true, the caller must:
//
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
//
// io.EOF is returned if r is closed before reading the first header byte.
func Read(req *protocol.Request, r network.Reader, preParse ...bool) error {
Expand Down
6 changes: 3 additions & 3 deletions pkg/protocol/http1/resp/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ func tryReadTrailer(h *protocol.ResponseHeader, r network.Reader, n int) error {
return io.EOF
}

return fmt.Errorf("error when reading request trailer: %w", err)
return errs.NewPublicf("error when reading request trailer: %w", err)
}
b = ext.MustPeekBuffered(r)
headersLen, errParse := parseTrailer(h, b)
Expand Down Expand Up @@ -326,12 +326,12 @@ func parseTrailer(h *protocol.ResponseHeader, buf []byte) (int, error) {
for s.Next() {
if len(s.Key) > 0 {
if bytes.IndexByte(s.Key, ' ') != -1 || bytes.IndexByte(s.Key, '\t') != -1 {
err = fmt.Errorf("invalid trailer key %q", s.Key)
err = errs.NewPublicf("invalid trailer key: %q", s.Key)
continue
}
// Forbidden by RFC 7230, section 4.1.2
if ext.IsBadTrailer(s.Key) {
err = fmt.Errorf("forbidden trailer key %q", s.Key)
err = errs.NewPublicf("forbidden trailer key: %q", s.Key)
continue
}
h.AddArgBytes(s.Key, s.Value, protocol.ArgsHasValue)
Expand Down
103 changes: 88 additions & 15 deletions pkg/protocol/http1/resp/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ import (
"testing"

"github.com/cloudwego/hertz/internal/bytestr"
"github.com/cloudwego/hertz/pkg/common/bytebufferpool"
errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/test/mock"
Expand Down Expand Up @@ -712,20 +711,94 @@ func testSetResponseBodyStreamChunked(t *testing.T, body string, trailer map[str
}
}

func TestResponseStream(t *testing.T) {
var pool bytebufferpool.Pool
var resp protocol.Response
bodybuf := pool.Get()
body := mock.NewZeroCopyReader("5\r\n56789\r\n0\r\nfoo: bar\r\n\r\n")
stream := AcquireResponseStream(bodybuf, body, -1, &resp.Header)
byteSlice := make([]byte, 4096)
_, err := stream.Read(byteSlice)
func testResponseReadBodyStreamSuccess(t *testing.T, resp *protocol.Response, response string, expectedStatusCode, expectedContentLength int,
expectedContentType, expectedBody string, expectedTrailer map[string]string,
) {
zr := mock.NewZeroCopyReader(response)
err := ReadBodyStream(resp, zr, 0, nil)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
assert.True(t, resp.IsBodyStream())

body, err := io.ReadAll(resp.BodyStream())
if err != nil && err != io.EOF {
t.Fatalf("Unexpected error: %s", err)
}
verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType, "")
if !bytes.Equal(body, []byte(expectedBody)) {
t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), []byte(expectedBody))
}
verifyResponseTrailer(t, &resp.Header, expectedTrailer)
}

func testResponseReadBodyStreamBadTrailer(t *testing.T, resp *protocol.Response, response string) {
zr := mock.NewZeroCopyReader(response)
err := ReadBodyStream(resp, zr, 0, nil)
if err != nil {
t.Fatalf("unexcepted error when reading response: %s", err)
t.Fatalf("Unexpected error: %s", err)
}
assert.True(t, resp.IsBodyStream())

_, err = io.ReadAll(resp.BodyStream())
if err == nil || err == io.EOF {
t.Fatalf("expected error when reading response.")
}
_, err = stream.Read(byteSlice)
assert.DeepEqual(t, err, io.EOF)
assert.DeepEqual(t, string(bytes.Trim(byteSlice, "\x00")), "56789")
verifyResponseTrailer(t, &resp.Header, map[string]string{"Foo": "bar"})
assert.Nil(t, ReleaseResponseStream(stream))
}

func TestResponseReadBodyStream(t *testing.T) {
t.Parallel()

resp := &protocol.Response{}

// usual response
testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789",
consts.StatusOK, 10, "foo/bar", "0123456789", nil)

// zero response
testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 500 OK\r\nContent-Length: 0\r\nContent-Type: foo/bar\r\n\r\n",
consts.StatusInternalServerError, 0, "foo/bar", "", nil)

// response with trailer
testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 300 OK\r\nTransfer-Encoding: chunked\r\nContent-Type: bar\r\n\r\n5\r\n56789\r\n0\r\nfoo: bar\r\n\r\n",
consts.StatusMultipleChoices, -1, "bar", "56789", map[string]string{"Foo": "bar"})

// response with trailer disableNormalizing
resp.Header.DisableNormalizing()
testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 300 OK\r\nTransfer-Encoding: chunked\r\nContent-Type: bar\r\n\r\n5\r\n56789\r\n0\r\nfoo: bar\r\n\r\n",
consts.StatusMultipleChoices, -1, "bar", "56789", map[string]string{"foo": "bar"})

// no content-length ('identity' transfer-encoding)
testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: foobar\r\n\r\nzxxxx",
consts.StatusOK, -2, "foobar", "zxxxx", nil)

// explicitly stated 'Transfer-Encoding: identity'
testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 234 ss\r\nContent-Type: xxx\r\n\r\nxag",
234, -2, "xxx", "xag", nil)

// big 'identity' response
body := string(mock.CreateFixedBody(100500))
testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\n\r\n"+body,
consts.StatusOK, -2, "aa", body, nil)

// chunked response
testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nqwer\r\n2\r\nty\r\n0\r\nFoo2: bar2\r\n\r\n",
200, -1, "text/html", "qwerty", map[string]string{"Foo2": "bar2"})

// chunked response with non-chunked Transfer-Encoding.
testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 230 OK\r\nContent-Type: text\r\nTransfer-Encoding: aaabbb\r\n\r\n2\r\ner\r\n2\r\nty\r\n0\r\nFoo3: bar3\r\n\r\n",
230, -1, "text", "erty", map[string]string{"Foo3": "bar3"})

// chunked response with empty body
testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo5: bar5\r\n\r\n",
consts.StatusOK, -1, "text/html", "", map[string]string{"Foo5": "bar5"})
}

func TestResponseReadBodyStreamBadTrailer(t *testing.T) {
t.Parallel()

resp := &protocol.Response{}

testResponseReadBodyStreamBadTrailer(t, resp, "HTTP/1.1 300 OK\r\nTransfer-Encoding: chunked\r\nContent-Type: bar\r\n\r\n5\r\n56789\r\n0\r\ncontent-type: bar\r\n\r\n")
testResponseReadBodyStreamBadTrailer(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nqwer\r\n2\r\nty\r\n0\r\nproxy-connection: bar2\r\n\r\n")
}
8 changes: 6 additions & 2 deletions pkg/protocol/http1/resp/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/netpoll"
)

var responseStreamPool = sync.Pool{
Expand Down Expand Up @@ -116,6 +117,9 @@ func (rs *responseStream) Read(p []byte) (int, error) {
return copied, err
}

if rs.contentLength == -2 {
}

if rs.offset == rs.contentLength {
return 0, io.EOF
}
Expand Down Expand Up @@ -154,8 +158,8 @@ func (rs *responseStream) Read(p []byte) (int, error) {

if err != nil {
// the data on stream may be incomplete
if err == io.EOF {
if rs.offset != rs.contentLength {
if err == io.EOF || err == netpoll.ErrEOF {
if rs.offset != rs.contentLength && rs.contentLength != -2 {
err = io.ErrUnexpectedEOF
}
// ensure that skipRest works fine
Expand Down

0 comments on commit d8c5805

Please sign in to comment.