From 98bf7a1480ece5e70d4d90201eb7edb5511f85bd Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 30 May 2019 03:10:19 -0400 Subject: [PATCH] Expand API - Closes #1 (Ping API) - Closes #62 (Read/Write convienence methods) - Closes #83 (SetReadLimit) --- example_echo_test.go | 2 - export_test.go | 18 -------- websocket.go | 99 ++++++++++++++++++++++++++++++++++++++++---- websocket_test.go | 2 + wsjson/wsjson.go | 4 -- wspb/wspb.go | 21 +--------- 6 files changed, 95 insertions(+), 51 deletions(-) delete mode 100644 export_test.go diff --git a/example_echo_test.go b/example_echo_test.go index a86d5b89..405c7a41 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -94,7 +94,6 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { // echo reads from the websocket connection and then writes // the received message back to it. // The entire function has 10s to complete. -// The received message is limited to 32768 bytes. func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() @@ -108,7 +107,6 @@ func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { if err != nil { return err } - r = io.LimitReader(r, 32768) w, err := c.Writer(ctx, typ) if err != nil { diff --git a/export_test.go b/export_test.go deleted file mode 100644 index 465ba9eb..00000000 --- a/export_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package websocket - -import ( - "context" -) - -// Write writes p as a single data frame to the connection. This is an optimization -// method for when the entire message is in memory and does not need to be streamed -// to the peer via Writer. -// -// This prevents the allocation of the Writer. -// Furthermore Writer always has to write an additional fin frame when Close is -// called on the writer which can result in worse performance if the full message -// exceeds the buffer size which is 4096 right now as then an extra syscall -// will be necessary to complete the message. -func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { - return c.writeCompleteMessage(ctx, opcode(typ), p) -} diff --git a/websocket.go b/websocket.go index 25688b08..00decaad 100644 --- a/websocket.go +++ b/websocket.go @@ -5,9 +5,13 @@ import ( "context" "fmt" "io" + "io/ioutil" + "math/rand" "os" "runtime" + "strconv" "sync" + "sync/atomic" "time" "golang.org/x/xerrors" @@ -25,6 +29,8 @@ type Conn struct { closer io.Closer client bool + msgReadLimit int64 + closeOnce sync.Once closeErr error closed chan struct{} @@ -41,14 +47,16 @@ type Conn struct { setWriteTimeout chan context.Context setConnContext chan context.Context getConnContext chan context.Context + + pingListener map[string]chan<- struct{} } // Context returns a context derived from parent that will be cancelled -// when the connection is closed. +// when the connection is closed or broken. // If the parent context is cancelled, the connection will be closed. // -// This is an experimental API that may be remove in the future. -// Please let me know how you feel about it. +// This is an experimental API that may be removed in the future. +// Please let me know how you feel about it in https://github.com/nhooyr/websocket/issues/79 func (c *Conn) Context(parent context.Context) context.Context { select { case <-c.closed: @@ -105,6 +113,8 @@ func (c *Conn) Subprotocol() string { func (c *Conn) init() { c.closed = make(chan struct{}) + c.msgReadLimit = 32768 + c.writeDataLock = make(chan struct{}, 1) c.writeFrameLock = make(chan struct{}, 1) @@ -118,6 +128,8 @@ func (c *Conn) init() { c.setConnContext = make(chan context.Context) c.getConnContext = make(chan context.Context) + c.pingListener = make(map[string]chan<- struct{}) + runtime.SetFinalizer(c, func(c *Conn) { c.close(xerrors.New("connection garbage collected")) }) @@ -242,6 +254,10 @@ func (c *Conn) handleControl(h header) { case opPing: c.writePong(b) case opPong: + listener, ok := c.pingListener[string(b)] + if ok { + close(listener) + } case opClose: ce, err := parseClosePayload(b) if err != nil { @@ -321,7 +337,7 @@ func (c *Conn) writePong(p []byte) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeCompleteMessage(ctx, opPong, p) + err := c.writeMessage(ctx, opPong, p) return err } @@ -369,7 +385,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeCompleteMessage(ctx, opClose, p) + err := c.writeMessage(ctx, opClose, p) c.close(cerr) @@ -399,7 +415,7 @@ func (c *Conn) releaseLock(lock chan struct{}) { <-lock } -func (c *Conn) writeCompleteMessage(ctx context.Context, opcode opcode, p []byte) error { +func (c *Conn) writeMessage(ctx context.Context, opcode opcode, p []byte) error { if !opcode.controlOp() { err := c.acquireLock(ctx, c.writeDataLock) if err != nil { @@ -445,6 +461,30 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err }, nil } +// Read is a convenience method to read a single message from the connection. +// +// See the Reader method if you want to be able to reuse buffers or want to stream a message. +func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + typ, r, err := c.Reader(ctx) + if err != nil { + return 0, nil, err + } + + b, err := ioutil.ReadAll(r) + if err != nil { + return typ, b, err + } + + return typ, b, nil +} + +// Write is a convenience method to write a message to the connection. +// +// See the Writer method if you want to stream a message. +func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { + return c.writeMessage(ctx, opcode(typ), p) +} + // messageWriter enables writing to a WebSocket connection. type messageWriter struct { ctx context.Context @@ -519,7 +559,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { if err != nil { return 0, nil, xerrors.Errorf("failed to get reader: %w", err) } - return typ, r, nil + return typ, io.LimitReader(r, c.msgReadLimit), nil } func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { @@ -640,3 +680,48 @@ func (r *messageReader) read(p []byte) (int, error) { return n, nil } + +// SetReadLimit sets the max number of bytes to read for a single message. +// It applies to the Reader and Read methods. +// +// By default, the connection has a message read limit of 32768 bytes. +func (c *Conn) SetReadLimit(n int64) { + atomic.StoreInt64(&c.msgReadLimit, n) +} + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +// Ping sends a ping to the peer and waits for a pong. +// Use this to measure latency or ensure the peer is responsive. +// +// This API is experimental and subject to change. +// Please provide feedback in https://github.com/nhooyr/websocket/issues/1. +func (c *Conn) Ping(ctx context.Context) error { + err := c.ping(ctx) + if err != nil { + return xerrors.Errorf("failed to ping: %w", err) + } + return nil +} + +func (c *Conn) ping(ctx context.Context) error { + id := rand.Uint64() + p := strconv.FormatUint(id, 10) + + pong := make(chan struct{}) + c.pingListener[p] = pong + + err := c.writeMessage(ctx, opPing, []byte(p)) + if err != nil { + return err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-pong: + return nil + } +} diff --git a/websocket_test.go b/websocket_test.go index b8d7b56c..d982732a 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -489,6 +489,8 @@ func TestAutobahnServer(t *testing.T) { func echoLoop(ctx context.Context, c *websocket.Conn) { defer c.Close(websocket.StatusInternalError, "") + c.SetReadLimit(1 << 30) + ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 0ff33bf3..d85700bc 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -12,8 +12,6 @@ import ( ) // Read reads a json message from c into v. -// For security reasons, it will not read messages -// larger than 32768 bytes. func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { err := read(ctx, c, v) if err != nil { @@ -33,8 +31,6 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { return xerrors.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ) } - r = io.LimitReader(r, 32768) - d := json.NewDecoder(r) err = d.Decode(v) if err != nil { diff --git a/wspb/wspb.go b/wspb/wspb.go index 90a0d046..edffede1 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -3,7 +3,6 @@ package wspb import ( "context" - "io" "io/ioutil" "github.com/golang/protobuf/proto" @@ -13,8 +12,6 @@ import ( ) // Read reads a protobuf message from c into v. -// For security reasons, it will not read messages -// larger than 32768 bytes. func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { err := read(ctx, c, v) if err != nil { @@ -34,8 +31,6 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { return xerrors.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ) } - r = io.LimitReader(r, 32768) - b, err := ioutil.ReadAll(r) if err != nil { return xerrors.Errorf("failed to read message: %w", err) @@ -64,19 +59,5 @@ func write(ctx context.Context, c *websocket.Conn, v proto.Message) error { return xerrors.Errorf("failed to marshal protobuf: %w", err) } - w, err := c.Writer(ctx, websocket.MessageBinary) - if err != nil { - return err - } - - _, err = w.Write(b) - if err != nil { - return err - } - - err = w.Close() - if err != nil { - return err - } - return nil + return c.Write(ctx, websocket.MessageBinary, b) }