diff --git a/utils.go b/utils.go index a17a867a..78a4b518 100644 --- a/utils.go +++ b/utils.go @@ -247,32 +247,34 @@ func transferBytes(src io.Reader, dest io.Writer, wg *sync.WaitGroup) (int64, er // tryUpdateConnection attempt to upgrade the connection to a http pdy stream func tryUpdateConnection(req *http.Request, writer http.ResponseWriter, endpoint *url.URL) error { // step: dial the endpoint - tlsConn, err := tryDialEndpoint(endpoint) + server, err := tryDialEndpoint(endpoint) if err != nil { return err } - defer tlsConn.Close() + defer server.Close() - // step: we need to hijack the underlining client connection - clientConn, ok, err := writer.(http.Hijacker).Hijack() - if !ok { + // @check the the response writer implements the Hijack method + if _, ok := writer.(http.Hijacker); !ok { return errors.New("writer does not implement http.Hijacker method") } + + // @step: get the client connection object + client, _, err := writer.(http.Hijacker).Hijack() if err != nil { return err } - defer clientConn.Close() + defer client.Close() // step: write the request to upstream - if err = req.Write(tlsConn); err != nil { + if err = req.Write(server); err != nil { return err } - // step: copy the date between client and upstream endpoint + // @step: copy the data between client and upstream endpoint var wg sync.WaitGroup wg.Add(2) - go transferBytes(tlsConn, clientConn, &wg) - go transferBytes(clientConn, tlsConn, &wg) + go transferBytes(server, client, &wg) + go transferBytes(client, server, &wg) wg.Wait() return nil