Skip to content

Commit

Permalink
adaptor ResponseWriter - adding Hijack method and pass proper fields (#…
Browse files Browse the repository at this point in the history
…1525)

* adding hijack method and pass proper fields

* adding hijack method and pass proper fields - adding tests

* improve hijack handling, use proper test for hijacking

* extend hijackhandler propogation to NewFastHTTPHandlerFunc

* align hijacking of fasthttp adaptor net request with fasthttp request, safe conn handling for proper release of resources and custom hijack handler for more controlled by hijacking implementation

* Implement actual behaviour of net/http Hijacker

---------

Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
  • Loading branch information
gilwo and erikdubbelboer committed Feb 17, 2024
1 parent 56cb753 commit aefd080
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 2 deletions.
50 changes: 48 additions & 2 deletions fasthttpadaptor/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
package fasthttpadaptor

import (
"bufio"
"io"
"net"
"net/http"
"sync"

"github.com/valyala/fasthttp"
)
Expand Down Expand Up @@ -53,8 +56,10 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler {
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
return
}

w := netHTTPResponseWriter{w: ctx.Response.BodyWriter()}
w := netHTTPResponseWriter{
w: ctx.Response.BodyWriter(),
ctx: ctx,
}
h.ServeHTTP(&w, r.WithContext(ctx))

ctx.SetStatusCode(w.StatusCode())
Expand Down Expand Up @@ -86,6 +91,7 @@ type netHTTPResponseWriter struct {
statusCode int
h http.Header
w io.Writer
ctx *fasthttp.RequestCtx
}

func (w *netHTTPResponseWriter) StatusCode() int {
Expand All @@ -111,3 +117,43 @@ func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
}

func (w *netHTTPResponseWriter) Flush() {}

type wrappedConn struct {
net.Conn

wg sync.WaitGroup
once sync.Once
}

func (c *wrappedConn) Close() (err error) {
c.once.Do(func() {
err = c.Conn.Close()
c.wg.Done()
})
return
}

func (w *netHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// Hijack assumes control of the connection, so we need to prevent fasthttp from closing it or
// doing anything else with it.
w.ctx.HijackSetNoResponse(true)

conn := &wrappedConn{Conn: w.ctx.Conn()}
conn.wg.Add(1)
w.ctx.Hijack(func(net.Conn) {
conn.wg.Wait()
})

bufW := bufio.NewWriter(conn)

// Write any unflushed body to the hijacked connection buffer.
unflushedBody := w.ctx.Response.Body()
if len(unflushedBody) > 0 {
if _, err := bufW.Write(unflushedBody); err != nil {
conn.Close()
return nil, nil, err
}
}

return conn, &bufio.ReadWriter{Reader: bufio.NewReader(conn), Writer: bufW}, nil
}
73 changes: 73 additions & 0 deletions fasthttpadaptor/adaptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"net/url"
"reflect"
"testing"
"time"

"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
)

func TestNewFastHTTPHandler(t *testing.T) {
Expand Down Expand Up @@ -143,3 +145,74 @@ func setContextValueMiddleware(next fasthttp.RequestHandler, key string, value a
next(ctx)
}
}

func TestHijack(t *testing.T) {
t.Parallel()

nethttpH := func(w http.ResponseWriter, r *http.Request) {
if f, ok := w.(http.Hijacker); !ok {
t.Errorf("expected http.ResponseWriter to implement http.Hijacker")
} else {
if _, err := w.Write([]byte("foo")); err != nil {
t.Error(err)
}

if c, rw, err := f.Hijack(); err != nil {
t.Error(err)
} else {
if _, err := rw.Write([]byte("bar")); err != nil {
t.Error(err)
}

if err := rw.Flush(); err != nil {
t.Error(err)
}

if err := c.Close(); err != nil {
t.Error(err)
}
}
}
}

s := &fasthttp.Server{
Handler: NewFastHTTPHandler(http.HandlerFunc(nethttpH)),
}

ln := fasthttputil.NewInmemoryListener()

go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %v", err)
}
}()

clientCh := make(chan struct{})
go func() {
c, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %v", err)
}

if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %v", err)
}

buf, err := io.ReadAll(c)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

if string(buf) != "foobar" {
t.Errorf("unexpected response: %q. Expecting %q", buf, "foobar")
}

close(clientCh)
}()

select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}

0 comments on commit aefd080

Please sign in to comment.