Skip to content

Commit

Permalink
websocket: strawman http2 support
Browse files Browse the repository at this point in the history
This patch adds http2 support to x/net/websocket.

It is still pretty hacky and not well tested yet, but
it shows that it can be done.

Change-Id: I123253a74a2dbb6e42e7e31b724362814da112a5
  • Loading branch information
ethanpailes committed Jun 2, 2022
1 parent d233d0c commit df38e5d
Show file tree
Hide file tree
Showing 7 changed files with 424 additions and 4 deletions.
2 changes: 1 addition & 1 deletion http2/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func TestHTTP2Stream(t *testing.T) {
// psudo headers by setting things in the headers hashmap.
// I think the real solution here is to add a new `Protocol`
// field to the `http.Request` struct.
req.Header.Add("HACK-HTTP2-Protocol", "websocket")
req.Header.Add("Hack-Http2-Protocol", "websocket")

resp, err := client.Transport.RoundTrip(req)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions http2/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,7 @@ func (cc *ClientConn) decrStreamReservationsLocked() {
}

func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
if req.Method == "CONNECT" && req.Header.Get("HACK-HTTP2-Protocol") != "" {
if req.Method == "CONNECT" && req.Header.Get("Hack-Http2-Protocol") != "" {
// This is an extended CONNECT https://datatracker.ietf.org/doc/html/rfc8441#section-4
// We need to check if the server supports it.
if err := cc.checkServerSupportsExtendedConnect(); err != nil {
Expand Down Expand Up @@ -1783,7 +1783,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
return nil, err
}

protocol := req.Header.Get("HACK-HTTP2-Protocol")
protocol := req.Header.Get("Hack-Http2-Protocol")

var path string
if req.Method != "CONNECT" || (cc.serverAllowsExtendedConnect && protocol != "") {
Expand Down
109 changes: 108 additions & 1 deletion websocket/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ package websocket

import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
)

// DialError is an error that occurs while dialling a websocket server.
Expand Down Expand Up @@ -79,13 +83,22 @@ func parseAuthority(location *url.URL) string {

// DialConfig opens a new client connection to a WebSocket with a config.
func DialConfig(config *Config) (ws *Conn, err error) {
var client net.Conn
if config.Location == nil {
return nil, &DialError{config, ErrBadWebSocketLocation}
}
if config.Origin == nil {
return nil, &DialError{config, ErrBadWebSocketOrigin}
}

if config.HTTP2Transport != nil {
return dialHTTP2(config)
}

return dialHTTP1(config)
}

func dialHTTP1(config *Config) (ws *Conn, err error) {
var client net.Conn
dialer := config.Dialer
if dialer == nil {
dialer = &net.Dialer{}
Expand All @@ -104,3 +117,97 @@ func DialConfig(config *Config) (ws *Conn, err error) {
Error:
return nil, &DialError{config, err}
}

func dialHTTP2(config *Config) (ws *Conn, err error) {
// Respect tls config set on the top level config if the transport doesn't
// already have one set.
if config.TlsConfig != nil && config.HTTP2Transport.TLSClientConfig == nil {
config.HTTP2Transport.TLSClientConfig = config.TlsConfig
}

// try to respect the dialer configured in the websocket config
if config.Dialer != nil && config.HTTP2Transport.DialTLS == nil {
config.HTTP2Transport.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
d := tls.Dialer{NetDialer: config.Dialer, Config: cfg}
return d.Dial(network, addr)
}
}

if config.Location.Scheme == "ws" && !config.HTTP2Transport.AllowHTTP {
return nil, &DialError{Config: config, Err: errors.New("HTTP/2 requires TLS")}
}

if config.Version != ProtocolVersionHybi13 {
return nil, &DialError{Config: config, Err: ErrBadProtocolVersion}
}

// https://datatracker.ietf.org/doc/html/rfc8441#section-5
// 'The scheme of the target URI (Section 5.1 of [RFC7230]) MUST be
// "https" for "wss"-schemed WebSockets and "http" for "ws"-schemed
// WebSockets.'
if config.Location.Scheme == "wss" {
config.Location.Scheme = "https"
}
if config.Location.Scheme == "ws" {
config.Location.Scheme = "http"
}

// TODO(ethan): replace pipe with something context cancelable
sr, sw := io.Pipe()
req, err := http.NewRequest("CONNECT", config.Location.String(), sr)
if err != nil {
return nil, &DialError{Config: config, Err: err}
}

req.Header.Add("Hack-Http2-Protocol", "websocket")
req.Header.Add("Origin", config.Origin.String())
req.Header.Add("Sec-Websocket-Version", fmt.Sprintf("%d", config.Version))
if len(config.Protocol) > 0 {
req.Header.Add("Sec-Websocket-Protocol", strings.Join(config.Protocol, ","))
}

// inject user supplied headers, if any
for k, vals := range config.Header {
req.Header[k] = vals
}

resp, err := config.HTTP2Transport.RoundTrip(req)
if err != nil {
return nil, &DialError{Config: config, Err: err}
}

// check response headers and status

if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
// we don't support any extentions
return nil, &DialError{Config: config, Err: ErrUnsupportedExtensions}
}

if resp.StatusCode != http.StatusOK {
return nil, &DialError{Config: config, Err: ErrBadStatus}
}

// TODO(ethan): this logic is copied from the HTTP/1.1 branch.
// I should refactor to consolidate.
offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
if offeredProtocol != "" {
protocolMatched := false
for i := 0; i < len(config.Protocol); i++ {
if config.Protocol[i] == offeredProtocol {
protocolMatched = true
break
}
}
if !protocolMatched {
return nil, &DialError{Config: config, Err: ErrBadWebSocketProtocol}
}
config.Protocol = []string{offeredProtocol}
}

// The handshake is complete, so we wrap things up in a Conn and return.
stream := newHTTP2ClientStream(sw, resp)
buf := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
conn := newHybiClientConn(config, buf, stream)

return conn, nil
}
183 changes: 183 additions & 0 deletions websocket/http2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package websocket

import (
"bufio"
"errors"
"fmt"
"io"
"net/http"
"strings"
)

// http2Handshaker performs a websocket handshake over an HTTP/2 connection.
// It is similar to a serverHandshaker, but doesn't use quite the same
// interface due to differences in the underlying transport protocol.
type http2Handshaker struct {
// The server's config.
config *Config
// The user-supplied userHandshake callback.
userHandshake func(*Config, *http.Request) error
}

// handshake performs a handshake for an HTTP/2 connection and returns a
// websocket connection or an HTTP status code and an error. The status
// code is only valid if the error is non-nil.
func (h *http2Handshaker) handshake(w http.ResponseWriter, req *http.Request) (conn *Conn, statusCode int, err error) {
statusCode, err = h.checkHeaders(req)
if err != nil {
return nil, statusCode, err
}

// allow the user to perform protocol negotiation
err = h.userHandshake(h.config, req)
if err != nil {
return nil, http.StatusForbidden, ErrBadHandshake
}

// All the headers we've been sent check out, so we can write
// a 200 response and inform the client if we have chosen a particular
// application protocol.
if len(h.config.Protocol) > 0 {
w.Header().Add("Sec-Websocket-Protocol", h.config.Protocol[0])
}
w.WriteHeader(http.StatusOK)

// Flush to force the status onto the wire so that clients can start
// listening.
flusher, ok := w.(http.Flusher)
if !ok {
return nil, http.StatusInternalServerError, errors.New("websocket: response writer must implement flusher")
}
flusher.Flush()

// to get a conn, we need a buffered readwriter, a readwritecloser, and
// the request
stream := newHTTP2ServerStream(w, req)
buf := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
conn = newHybiConn(h.config, buf, stream, req)
return conn, 0, err
}

func (h *http2Handshaker) checkHeaders(req *http.Request) (statusCode int, err error) {
// TODO(ethan): write tests for all of these checks
if req.Method != "CONNECT" {
return http.StatusMethodNotAllowed, ErrBadRequestMethod
}

protocol := req.Header.Get("Hack-Http2-Protocol")
if protocol != "websocket" {
return http.StatusBadRequest, ErrBadProtocol
}

// "On requests that contain the :protocol pseudo-header field, the
// :scheme and :path pseudo-header fields of the target URI (see
// Section 5) MUST also be included."
if req.URL.Path == "" {
return http.StatusBadRequest, ErrBadPath
}

version := req.Header.Get("Sec-Websocket-Version")
if version == "13" {
h.config.Version = ProtocolVersionHybi13
} else {
return http.StatusBadRequest, ErrBadProtocolVersion
}

// parse the list of request protocols
protocolCSV := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
if protocolCSV != "" {
protocols := strings.Split(protocolCSV, ",")
for i := 0; i < len(protocols); i++ {
// It is ok to mutate Protocol like this because server takes its
// receiver by value, not reference, so the whole thing is copied
// for each request.
h.config.Protocol = append(h.config.Protocol, strings.TrimSpace(protocols[i]))
}
}

return 0, nil
}

//
// http2ServerStream
//

// http2ServerStream is a wrapper around a request and response writer that
// implements io.ReadWriteCloser
type http2ServerStream struct {
w http.ResponseWriter
flusher http.Flusher
req *http.Request
}

func newHTTP2ServerStream(w http.ResponseWriter, req *http.Request) *http2ServerStream {
flusher, ok := w.(http.Flusher)
if !ok {
panic("websocket: response writer must implement flusher")
}

return &http2ServerStream{
w: w,
flusher: flusher,
req: req,
}
}

func (s *http2ServerStream) Read(p []byte) (n int, err error) {
return s.req.Body.Read(p)
}
func (s *http2ServerStream) Write(p []byte) (n int, err error) {
n, err = s.w.Write(p)
if err != nil {
return n, err
}

// We flush every time since the main websocket code is going to wrap
// this in a bufio.Writer and expect that when the bufio.Writer is flushed
// the bytes actually land on the wire.
s.flusher.Flush()

return n, err
}
func (s *http2ServerStream) Close() error {
return s.req.Body.Close()
}

//
// http2ClientStream
//

// http2ClientStream is a wrapper around a writer and an http response that
// implements io.ReadWriteCloser
type http2ClientStream struct {
w *io.PipeWriter
resp *http.Response
}

func newHTTP2ClientStream(w *io.PipeWriter, resp *http.Response) *http2ClientStream {
return &http2ClientStream{
w: w,
resp: resp,
}
}

func (s *http2ClientStream) Read(p []byte) (n int, err error) {
return s.resp.Body.Read(p)
}
func (s *http2ClientStream) Write(p []byte) (n int, err error) {
return s.w.Write(p)
}
func (s *http2ClientStream) Close() error {
wErr := s.w.Close()
rErr := s.resp.Body.Close()
if wErr != nil && rErr != nil {
return fmt.Errorf("client close: %s: %w", wErr, rErr)
}
if wErr != nil {
return wErr
}
if rErr != nil {
return rErr
}
return nil
}
Loading

0 comments on commit df38e5d

Please sign in to comment.