tighten lock around appending new chunks of read data in stream #28

Merged (6 commits) on May 2, 2020
Changes from 3 commits
2 changes: 1 addition & 1 deletion addr.go
@@ -54,7 +54,7 @@ func (s *Stream) LocalAddr() net.Addr {
 	return s.session.LocalAddr()
 }
 
-// LocalAddr returns the remote address
+// RemoteAddr returns the remote address
 func (s *Stream) RemoteAddr() net.Addr {
 	return s.session.RemoteAddr()
 }
4 changes: 3 additions & 1 deletion go.mod
@@ -2,4 +2,6 @@ module github.com/libp2p/go-yamux
 
 go 1.12
 
-require github.com/libp2p/go-buffer-pool v0.0.2
+require (
+	github.com/libp2p/go-buffer-pool v0.0.2
+)
20 changes: 20 additions & 0 deletions go.sum
@@ -1,2 +1,22 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/libp2p/go-buffer-pool v0.0.2 h1:QNK2iAFa8gjAe1SPz6mHSMuCcjs+X1wlHzeOSqcmlfs=
github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoRZd1Vi32+RXyFM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
2 changes: 0 additions & 2 deletions session_test.go
@@ -548,8 +548,6 @@ func TestSendData_Large(t *testing.T) {
 			t.Errorf("err: %v", err)
 			return
 		}
-
-		t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
 	}()
 
 	go func() {
40 changes: 10 additions & 30 deletions stream.go
@@ -5,8 +5,6 @@ import (
 	"sync"
 	"sync/atomic"
 	"time"
-
-	"github.com/libp2p/go-buffer-pool"
 )
 
 type streamState int
@@ -25,7 +23,6 @@ const (
 // Stream is used to represent a logical stream
 // within a session.
 type Stream struct {
-	recvWindow uint32
 	sendWindow uint32
 
 	id uint32
@@ -35,7 +32,7 @@ type Stream struct {
 	stateLock sync.Mutex
 
 	recvLock sync.Mutex
-	recvBuf  pool.Buffer
+	recvBuf  segmentedBuffer
 
 	sendLock sync.Mutex
 
@@ -52,10 +49,10 @@ func newStream(session *Session, id uint32, state streamState) *Stream {
 		id:            id,
 		session:       session,
 		state:         state,
-		recvWindow:    initialStreamWindow,
 		sendWindow:    initialStreamWindow,
 		readDeadline:  makePipeDeadline(),
 		writeDeadline: makePipeDeadline(),
+		recvBuf:       NewSegmentedBuffer(initialStreamWindow),
 		recvNotifyCh:  make(chan struct{}, 1),
 		sendNotifyCh:  make(chan struct{}, 1),
 	}
@@ -84,9 +81,7 @@ START:
 	case streamRemoteClose:
 		fallthrough
 	case streamClosed:
-		s.recvLock.Lock()
 		empty := s.recvBuf.Len() == 0
-		s.recvLock.Unlock()
 		if empty {
 			return 0, io.EOF
 		}
@@ -213,19 +208,13 @@ func (s *Stream) sendWindowUpdate() error {
 
 	// Determine the delta update
 	max := s.session.config.MaxStreamWindowSize
-	s.recvLock.Lock()
-	delta := (max - uint32(s.recvBuf.Len())) - s.recvWindow
-
-	// Check if we can omit the update
-	if delta < (max/2) && flags == 0 {
-		s.recvLock.Unlock()
+	// Update our window
+	needed, delta := s.recvBuf.GrowTo(uint64(max), flags != 0)
+	if !needed {
 		return nil
 	}
 
-	// Update our window
-	s.recvWindow += delta
-	s.recvLock.Unlock()
 
 	// Send the header
 	hdr := encode(typeWindowUpdate, flags, s.id, delta)
 	if err := s.session.sendMsg(hdr, nil, nil); err != nil {
@@ -409,26 +398,17 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
 	// Wrap in a limited reader
 	conn = &io.LimitedReader{R: conn, N: int64(length)}
 
-	// Copy into buffer
-	s.recvLock.Lock()
-
-	if length > s.recvWindow {
-		s.recvLock.Unlock()
-		s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
+	// Validate it's okay to copy
+	if !s.recvBuf.TryReserve(length) {
+		s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvBuf.Cap(), length)
 		return ErrRecvWindowExceeded
 	}
 
-	s.recvBuf.Grow(int(length))
-	if _, err := io.Copy(&s.recvBuf, conn); err != nil {
-		s.recvLock.Unlock()
+	// Copy into buffer
+	if err := s.recvBuf.Append(conn, int(length)); err != nil {
 		s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
 		return err
 	}
 
-	// Decrement the receive window
-	s.recvWindow -= length
-	s.recvLock.Unlock()
-
 	// Unblock any readers
 	asyncNotify(s.recvNotifyCh)
 	return nil
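
Taken together, the stream.go changes replace the recvLock/recvWindow pair with the buffer's own accounting. Below is a condensed sketch of the new receive path, assuming it sits in the same package as stream.go; the helper name `receive` and the omission of logging are illustrative, not part of the PR:

```go
// receive is a hypothetical helper mirroring the readData flow above.
func receive(buf *segmentedBuffer, frame io.Reader, length uint32) error {
	// Reserve window space up front; this replaces the old pattern of
	// holding recvLock while checking length against recvWindow.
	if !buf.TryReserve(length) {
		return ErrRecvWindowExceeded
	}
	// Append reads the frame body without holding the buffer's mutex,
	// taking it only to commit the segment, so no stream-wide lock is
	// held across the network read anymore.
	return buf.Append(frame, int(length))
}
```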
117 changes: 117 additions & 0 deletions util.go
@@ -1,5 +1,13 @@
package yamux

import (
	"io"
	"sync"
	"sync/atomic"

	pool "github.com/libp2p/go-buffer-pool"
)

// asyncSendErr is used to try an async send of an error
func asyncSendErr(ch chan error, err error) {
	if ch == nil {
@@ -29,3 +37,112 @@ func min(values ...uint32) uint32 {
	}
	return m
}

type segmentedBuffer struct {
	cap     uint64
	pending uint64
	len     uint64
	bm      sync.Mutex
	b       [][]byte
}

// NewSegmentedBuffer allocates a segmented buffer.
func NewSegmentedBuffer(initialCapacity uint32) segmentedBuffer {
	return segmentedBuffer{cap: uint64(initialCapacity), b: make([][]byte, 0)}
}

func (s *segmentedBuffer) Len() int {
	return int(atomic.LoadUint64(&s.len))

[Review thread on this line]
Member: What was the motivation for the atomics? We end up doing a lot of atomic operations just to allow checking the length without taking a lock. (I haven't profiled it, so it may be fine; I'm just wondering.)
Author: Reads will be regularly checking whether the length has extended; my intuition was this would be faster than locking, reading, and unlocking.
Member: Yeah, my real concern is whether that outweighs the cost of multiple atomic operations (and potential memory barriers) while we're holding a lock anyway. But it works, so I have no real objections.
Author: I suspect a lot of that has to do with the list of buffers rather than the length atomics 🤷‍♂️. I'm also going to revert to uint32s throughout, as they were before. There was an underflow I was confirming, which is why I bumped them up to 64 bits.
Member: SGTM.

}
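
The cost question in the thread above could be checked with a micro-benchmark along these lines (a sketch added for illustration; it is not part of the PR):

```go
package yamux

import (
	"sync"
	"sync/atomic"
	"testing"
)

// BenchmarkLenAtomic measures the lock-free read that Len() performs.
func BenchmarkLenAtomic(b *testing.B) {
	var n uint64
	for i := 0; i < b.N; i++ {
		_ = atomic.LoadUint64(&n)
	}
}

// BenchmarkLenMutex measures the alternative: a plain read guarded by a mutex.
func BenchmarkLenMutex(b *testing.B) {
	var mu sync.Mutex
	var n uint64
	for i := 0; i < b.N; i++ {
		mu.Lock()
		_ = n
		mu.Unlock()
	}
}
```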

func (s *segmentedBuffer) Cap() uint64 {
	return atomic.LoadUint64(&s.cap)
}

// GrowTo grows the window back toward max once the unclaimed portion
// (max minus len+cap+pending) has reached at least half the window size,
// or unconditionally when force is set. It reports whether an update is
// needed and how much additional space was reserved.
func (s *segmentedBuffer) GrowTo(max uint64, force bool) (bool, uint32) {
	s.bm.Lock()
	defer s.bm.Unlock()

	currentWindow := atomic.LoadUint64(&s.len) + atomic.LoadUint64(&s.cap) + s.pending

[Review thread on this line]
Member: Nit: I think we can drop the atomic loads here, given that these values are only ever modified under the lock. But I'm really not sure about that.
Author: I think I agree, but it seemed like a sane pattern to keep all accesses atomic when it's needed at least some of the time, to prevent unexpected glitches.
Member: Fair enough.

	if currentWindow > max {
		// Somewhat counter-intuitively, this is not an error. len+cap is the
		// 'window' that must not exceed max, or a reservation would fail and
		// trigger an error. We pre-count 'pending' data (frames whose header
		// has been read and whose body is still being read into available
		// space) so that we don't undercount the remaining window size, but
		// that can mean this sum ends up larger than max.
		return false, 0
	}
	delta := max - currentWindow

	if delta < (max/2) && !force {
		return false, 0
	}

	atomic.AddUint64(&s.cap, delta)
	return true, uint32(delta)
}
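
To make the accounting concrete, here is a worked example with made-up numbers (the variable names mirror the struct fields above):

```go
package main

import "fmt"

func main() {
	const max uint64 = 262144  // 256 KiB window limit from sendWindowUpdate
	buffered := uint64(10240)  // s.len: unread bytes queued in s.b
	capacity := uint64(102400) // s.cap: space already advertised to the peer
	pending := uint64(16384)   // s.pending: reserved bytes still being read in

	currentWindow := buffered + capacity + pending // 129024; must not exceed max
	delta := max - currentWindow                   // 133120 bytes reclaimable

	// Only worth a window-update frame once at least half the window
	// can be re-advertised (unless force is set).
	fmt.Println(delta >= max/2, delta) // true 133120
}
```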

func (s *segmentedBuffer) TryReserve(space uint32) bool {
	// Note that the check-and-set of pending is not atomic on its own;
	// because of this, accesses to pending are protected by bm.
	s.bm.Lock()
	defer s.bm.Unlock()
	if atomic.LoadUint64(&s.cap) < s.pending+uint64(space) {
		return false
	}
	s.pending += uint64(space)
	return true
}

func (s *segmentedBuffer) Read(b []byte) (int, error) {
	s.bm.Lock()
	defer s.bm.Unlock()
	if len(s.b) == 0 {
		return 0, io.EOF
	}
	n := copy(b, s.b[0])
	if n == len(s.b[0]) {
		pool.Put(s.b[0])
		s.b[0] = nil
		s.b = s.b[1:]
	} else {
		s.b[0] = s.b[0][n:]
	}
	if n > 0 {
		atomic.AddUint64(&s.len, ^uint64(n-1))
	}
	return n, nil
}
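
The decrement at the end of Read leans on a standard sync/atomic idiom: there is no unsigned subtract, so adding ^uint64(n-1), which equals 2^64 - n, wraps around to subtract n. A quick standalone check (illustrative, not from the PR):

```go
package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	x := uint64(100)
	n := 37
	atomic.AddUint64(&x, ^uint64(n-1)) // wrapping add of 2^64-37, i.e. x -= 37
	fmt.Println(x)                     // 63
}
```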

func (s *segmentedBuffer) Append(input io.Reader, length int) error {
	dst := pool.Get(length)
	n := 0
	for {
		read, err := input.Read(dst[n:])
		n += read
		switch err {
		case nil:
		case io.EOF:
			if n == length {
				err = nil
			} else {
				err = ErrStreamReset
			}
			fallthrough
		default:
			s.bm.Lock()
			defer s.bm.Unlock()
			if n > 0 {
				atomic.AddUint64(&s.len, uint64(n))
				// cap -= n
				atomic.AddUint64(&s.cap, ^uint64(n-1))
				s.pending = s.pending - uint64(length)
				s.b = append(s.b, dst[0:n])
			}
			return err
		}
	}
}

[Review thread on the io.EOF case]
Member: I actually wonder if we should ignore the error any time we read a full message. But that's probably an edge case that really doesn't matter.
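
For a sense of the whole reserve/append/read cycle, here is a minimal sketch of the buffer lifecycle. It assumes in-package access (segmentedBuffer is unexported), and the payload and window size are made up:

```go
package yamux

import (
	"bytes"
	"fmt"
)

func demoSegmentedBuffer() {
	buf := NewSegmentedBuffer(64) // 64-byte initial window

	payload := []byte("hello yamux")
	if !buf.TryReserve(uint32(len(payload))) {
		panic("window exceeded") // readData would return ErrRecvWindowExceeded
	}
	if err := buf.Append(bytes.NewReader(payload), len(payload)); err != nil {
		panic(err)
	}

	out := make([]byte, 5)
	n, _ := buf.Read(out)             // drains part of the first segment
	fmt.Printf("%d %q\n", n, out[:n]) // 5 "hello"
	fmt.Println(buf.Len())            // 6 bytes still buffered
}
```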