Skip to content

Commit

Permalink
pool flate readers
Browse files Browse the repository at this point in the history
  • Loading branch information
y3llowcake committed Dec 19, 2016
1 parent 3ab3a8b commit 2db2f66
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 10 deletions.
47 changes: 41 additions & 6 deletions compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,22 @@ import (

var (
flateWriterPool = sync.Pool{}
flateReaderPool = sync.Pool{}
)

func decompressNoContextTakeover(r io.Reader) io.Reader {
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))

i := flateReaderPool.Get()
if i == nil {
i = flate.NewReader(nil)
}
i.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
return &flateReadWrapper{i.(io.ReadCloser)}
}

func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
Expand All @@ -36,7 +43,7 @@ func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
fw = i.(*flate.Writer)
fw.Reset(tw)
}
return &flateWrapper{fw: fw, tw: tw}, err
return &flateWriteWrapper{fw: fw, tw: tw}, err
}

// truncWriter is an io.Writer that writes all but the last four bytes of the
Expand Down Expand Up @@ -75,19 +82,19 @@ func (w *truncWriter) Write(p []byte) (int, error) {
return n + nn, err
}

type flateWrapper struct {
type flateWriteWrapper struct {
fw *flate.Writer
tw *truncWriter
}

func (w *flateWrapper) Write(p []byte) (int, error) {
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
if w.fw == nil {
return 0, errWriteClosed
}
return w.fw.Write(p)
}

func (w *flateWrapper) Close() error {
func (w *flateWriteWrapper) Close() error {
if w.fw == nil {
return errWriteClosed
}
Expand All @@ -103,3 +110,31 @@ func (w *flateWrapper) Close() error {
}
return err2
}

type flateReadWrapper struct {
fr io.ReadCloser
}

func (r *flateReadWrapper) Read(p []byte) (int, error) {
if r.fr == nil {
return 0, io.ErrClosedPipe
}
n, err := r.fr.Read(p)
if err == io.EOF {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after
// this final read.
r.Close()
}
return n, err
}

func (r *flateReadWrapper) Close() error {
if r.fr == nil {
return io.ErrClosedPipe
}
err := r.fr.Close()
flateReaderPool.Put(r.fr)
r.fr = nil
return err
}
18 changes: 14 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ type Conn struct {
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)

// Read fields
reader io.ReadCloser // the current reader returned to the application
readErr error
br *bufio.Reader
readRemaining int64 // bytes remaining in current frame.
Expand All @@ -253,7 +254,7 @@ type Conn struct {
messageReader *messageReader // the current low-level reader

readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader) io.Reader
newDecompressionReader func(io.Reader) io.ReadCloser
}

func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
Expand Down Expand Up @@ -855,6 +856,11 @@ func (c *Conn) handleProtocolError(message string) error {
// permanent. Once this method returns a non-nil error, all subsequent calls to
// this method return the same error.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
// Close previous reader, only relevant for decompression.
if c.reader != nil {
c.reader.Close()
c.reader = nil
}

c.messageReader = nil
c.readLength = 0
Expand All @@ -867,11 +873,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
}
if frameType == TextMessage || frameType == BinaryMessage {
c.messageReader = &messageReader{c}
var r io.Reader = c.messageReader
c.reader = c.messageReader
if c.readDecompress {
r = c.newDecompressionReader(r)
c.reader = c.newDecompressionReader(c.reader)
}
return frameType, r, nil
return frameType, c.reader, nil
}
}

Expand Down Expand Up @@ -933,6 +939,10 @@ func (r *messageReader) Read(b []byte) (int, error) {
return 0, err
}

func (r *messageReader) Close() error {
return nil
}

// ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer.
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
Expand Down

0 comments on commit 2db2f66

Please sign in to comment.