Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(http1): support flush chunk in handler #610

Merged
9 changes: 9 additions & 0 deletions pkg/app/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ type RequestContext struct {
formValueFunc FormValueFunc
}

// Flush is the shortcut for ctx.Response.GetHijackWriter().Flush().
// Will return nil if the response writer is not hijacked.
func (ctx *RequestContext) Flush() error {
if ctx.Response.GetHijackWriter() == nil {
return nil
}
return ctx.Response.GetHijackWriter().Flush()
}

func (ctx *RequestContext) SetClientIPFunc(f ClientIP) {
ctx.clientIPFunc = f
}
Expand Down
4 changes: 0 additions & 4 deletions pkg/app/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,6 @@ func (r *fsSmallFileReader) WriteTo(w io.Writer) (int64, error) {
// Directory contents is returned if path points to directory.
//
// Use ServeFileUncompressed is you don't need serving compressed file contents.
//
// See also RequestCtx.SendFile.
func ServeFile(ctx *RequestContext, path string) {
rootFSOnce.Do(func() {
rootFSHandler = rootFS.NewRequestHandler()
Expand Down Expand Up @@ -1204,8 +1202,6 @@ func stripLeadingSlashes(path []byte, stripSlashes int) []byte {
//
// ServeFile may be used for saving network traffic when serving files
// with good compression ratio.
//
// See also RequestCtx.SendFile.
func ServeFileUncompressed(ctx *RequestContext, path string) {
ctx.Request.Header.DelBytes(bytestr.StrAcceptEncoding)
ServeFile(ctx, path)
Expand Down
24 changes: 21 additions & 3 deletions pkg/common/test/mock/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ type Conn struct {
readTimeout time.Duration
zr network.Reader
zw network.ReadWriter
wroteLen int
}

type Recorder interface {
network.Reader
WroteLen() int
}

func (m *Conn) SetWriteTimeout(t time.Duration) error {
Expand Down Expand Up @@ -98,19 +104,31 @@ func (m *Conn) Len() int {
}

func (m *Conn) Malloc(n int) (buf []byte, err error) {
m.wroteLen += n
return m.zw.Malloc(n)
}

func (m *Conn) WriteBinary(b []byte) (n int, err error) {
return m.zw.WriteBinary(b)
n, err = m.zw.WriteBinary(b)
m.wroteLen += n
return n, err
}

func (m *Conn) Flush() error {
return m.zw.Flush()
}

func (m *Conn) WriterRecorder() network.Reader {
return m.zw
func (m *Conn) WriterRecorder() Recorder {
return &recorder{c: m, Reader: m.zw}
}

type recorder struct {
c *Conn
network.Reader
}

func (r *recorder) WroteLen() int {
return r.c.wroteLen
}

func (m *SlowReadConn) Peek(i int) ([]byte, error) {
Expand Down
47 changes: 47 additions & 0 deletions pkg/common/test/mock/writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2023 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package mock

import "bytes"

type ExtWriter struct {
tmp []byte
Buf *bytes.Buffer
IsFinal *bool
welkeyever marked this conversation as resolved.
Show resolved Hide resolved
}

func (m *ExtWriter) Write(p []byte) (n int, err error) {
m.tmp = p
return len(p), nil
}

func (m *ExtWriter) Flush() error {
_, err := m.Buf.Write(m.tmp)
return err
}

func (m *ExtWriter) Finalize() error {
if !*m.IsFinal {
*m.IsFinal = true
}
return nil
}

func (m *ExtWriter) SetBody(body []byte) {
m.Buf.Reset()
m.tmp = body
}
9 changes: 9 additions & 0 deletions pkg/network/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,12 @@ func NewWriter(w io.Writer) Writer {
w: w,
}
}

type ExtWriter interface {
io.Writer
Flush() error

// Finalize will be called by framework before the writer is released.
// Implementations must guarantee that Finalize is safe for multiple calls.
Finalize() error
}
9 changes: 6 additions & 3 deletions pkg/protocol/http1/ext/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,14 @@ func WriteBodyChunked(w network.Writer, r io.Reader) error {
panic("BUG: io.Reader returned 0, nil")
}
if err == io.EOF {
if err = writeChunk(w, buf[:0]); err != nil {
if err = WriteChunk(w, buf[:0], true); err != nil {
break
}
err = nil
}
break
}
if err = writeChunk(w, buf[:n]); err != nil {
if err = WriteChunk(w, buf[:n], true); err != nil {
break
}
}
Expand Down Expand Up @@ -287,7 +287,7 @@ func round2(n int) int {
return 1 << x
}

func writeChunk(w network.Writer, b []byte) (err error) {
func WriteChunk(w network.Writer, b []byte, withFlush bool) (err error) {
n := len(b)
if err = bytesconv.WriteHexInt(w, n); err != nil {
return err
Expand All @@ -303,6 +303,9 @@ func writeChunk(w network.Writer, b []byte) (err error) {
w.WriteBinary(bytestr.StrCRLF) //nolint:errcheck
}

if !withFlush {
return nil
}
err = w.Flush()
return
}
Expand Down
76 changes: 76 additions & 0 deletions pkg/protocol/http1/resp/writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2023 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package resp
FGYFFFF marked this conversation as resolved.
Show resolved Hide resolved

import (
"sync"

"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/http1/ext"
)

type chunkedBodyWriter struct {
sync.Once
finalizeErr error
wroteHeader bool
r *protocol.Response
w network.Writer
}

// Write will encode chunked p before writing
// It will only return the length of p and a nil error if the writing is successful or 0, error otherwise.
//
// NOTE: Write will use the user buffer to flush.
// Before flush successfully, the buffer b should be valid.
func (c *chunkedBodyWriter) Write(p []byte) (n int, err error) {
if !c.wroteHeader {
welkeyever marked this conversation as resolved.
Show resolved Hide resolved
c.r.Header.SetContentLength(-1)
if err = WriteHeader(&c.r.Header, c.w); err != nil {
return
}
c.wroteHeader = true
}
if err = ext.WriteChunk(c.w, p, false); err != nil {
return
}
return len(p), nil
}

func (c *chunkedBodyWriter) Flush() error {
return c.w.Flush()
}

// Finalize will write the ending chunk as well as trailer and flush the writer.
// Warning: do not call this method by yourself, unless you know what you are doing.
func (c *chunkedBodyWriter) Finalize() error {
c.Do(func() {
c.finalizeErr = ext.WriteChunk(c.w, nil, true)
if c.finalizeErr != nil {
return
}
c.finalizeErr = ext.WriteTrailer(c.r.Header.Trailer(), c.w)
})
return c.finalizeErr
}

func NewChunkedBodyWriter(r *protocol.Response, w network.Writer) network.ExtWriter {
return &chunkedBodyWriter{
r: r,
w: w,
}
}
53 changes: 53 additions & 0 deletions pkg/protocol/http1/resp/writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright 2023 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package resp

import (
"strings"
"testing"

"github.com/cloudwego/hertz/internal/bytestr"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/test/mock"
"github.com/cloudwego/hertz/pkg/protocol"
)

func TestNewChunkedBodyWriter(t *testing.T) {
response := protocol.AcquireResponse()
mockConn := mock.NewConn("")
w := NewChunkedBodyWriter(response, mockConn)
w.Write([]byte("hello"))
w.Flush()
out, _ := mockConn.WriterRecorder().ReadBinary(mockConn.WriterRecorder().WroteLen())
assert.True(t, strings.Contains(string(out), "Transfer-Encoding: chunked"))
assert.True(t, strings.Contains(string(out), "5"+string(bytestr.StrCRLF)+"hello"))
assert.False(t, strings.Contains(string(out), "0"+string(bytestr.StrCRLF)+string(bytestr.StrCRLF)))
}

func TestNewChunkedBodyWriter1(t *testing.T) {
response := protocol.AcquireResponse()
mockConn := mock.NewConn("")
w := NewChunkedBodyWriter(response, mockConn)
w.Write([]byte("hello"))
w.Flush()
w.Finalize()
w.Flush()
out, _ := mockConn.WriterRecorder().ReadBinary(mockConn.WriterRecorder().WroteLen())
assert.True(t, strings.Contains(string(out), "Transfer-Encoding: chunked"))
assert.True(t, strings.Contains(string(out), "5"+string(bytestr.StrCRLF)+"hello"))
assert.True(t, strings.Contains(string(out), "0"+string(bytestr.StrCRLF)+string(bytestr.StrCRLF)))
}
5 changes: 5 additions & 0 deletions pkg/protocol/http1/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,11 @@ func writeErrorResponse(zw network.Writer, ctx *app.RequestContext, serverName [
}

func writeResponse(ctx *app.RequestContext, w network.Writer) error {
// Skip default response writing logic if it has been hijacked
if ctx.Response.GetHijackWriter() != nil {
return ctx.Response.GetHijackWriter().Finalize()
welkeyever marked this conversation as resolved.
Show resolved Hide resolved
}

err := resp.Write(&ctx.Response, w)
if err != nil {
return err
Expand Down
Loading