Skip to content

Commit

Permalink
caddyhttp: record num. bytes read when response writer is hijacked (#…
Browse files Browse the repository at this point in the history
…6173)

* record the number of bytes read when response writer is hijacked

* record body size when not nil
  • Loading branch information
WeidiDeng authored Apr 17, 2024
1 parent 70953e8 commit e0daa39
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
37 changes: 37 additions & 0 deletions modules/caddyhttp/responsewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ type responseRecorder struct {
size int
wroteHeader bool
stream bool

readSize *int
}

// NewResponseRecorder returns a new ResponseRecorder that can be
Expand Down Expand Up @@ -240,6 +242,12 @@ func (rr *responseRecorder) FlushError() error {
return nil
}

// Private interface so it can only be used in this package
// #TODO: maybe export it later
func (rr *responseRecorder) setReadSize(size *int) {
rr.readSize = size
}

func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
//nolint:bodyclose
conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
Expand All @@ -249,6 +257,15 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not
conn = &hijackedConn{conn, rr}
brw.Writer.Reset(conn)

buffered := brw.Reader.Buffered()
if buffered != 0 {
conn.(*hijackedConn).updateReadSize(buffered)
data, _ := brw.Peek(buffered)
brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn))
} else {
brw.Reader.Reset(conn)
}
return conn, brw, nil
}

Expand All @@ -258,6 +275,24 @@ type hijackedConn struct {
rr *responseRecorder
}

func (hc *hijackedConn) updateReadSize(n int) {
if hc.rr.readSize != nil {
*hc.rr.readSize += n
}
}

func (hc *hijackedConn) Read(p []byte) (int, error) {
n, err := hc.Conn.Read(p)
hc.updateReadSize(n)
return n, err
}

func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) {
n, err := io.Copy(w, hc.Conn)
hc.updateReadSize(int(n))
return n, err
}

func (hc *hijackedConn) Write(p []byte) (int, error) {
n, err := hc.Conn.Write(p)
hc.rr.size += n
Expand Down Expand Up @@ -298,4 +333,6 @@ var (
_ io.ReaderFrom = (*ResponseWriterWrapper)(nil)
_ io.ReaderFrom = (*responseRecorder)(nil)
_ io.ReaderFrom = (*hijackedConn)(nil)

_ io.WriterTo = (*hijackedConn)(nil)
)
5 changes: 5 additions & 0 deletions modules/caddyhttp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Body != nil {
bodyReader = &lengthReader{Source: r.Body}
r.Body = bodyReader

// should always be true, private interface can only be referenced in the same package
if setReadSizer, ok := wrec.(interface{ setReadSize(*int) }); ok {
setReadSizer.setReadSize(&bodyReader.Length)
}
}

// capture the original version of the request
Expand Down

0 comments on commit e0daa39

Please sign in to comment.