From 8abd3fefdfc6aca7c487d863cd38adbfd161ef30 Mon Sep 17 00:00:00 2001 From: sukun Date: Tue, 13 Feb 2024 23:51:17 +0530 Subject: [PATCH] fix race in connection timeout --- p2p/transport/webrtc/connection.go | 10 +++++++++- p2p/transport/webrtc/stream.go | 10 +++++++--- p2p/transport/webrtc/stream_read.go | 12 ++++-------- p2p/transport/webrtc/transport_test.go | 4 ++-- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index a9463a199b..f220360ed4 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -145,7 +145,12 @@ func (c *connection) closeWithError(err error) { c.cancel() // closing peerconnection will close the datachannels associated with the streams c.pc.Close() - for _, s := range c.streams { + + c.m.Lock() + streams := c.streams + c.streams = nil + c.m.Unlock() + for _, s := range streams { s.closeForShutdown(err) } c.scope.Done() @@ -207,6 +212,9 @@ func (c *connection) Transport() tpt.Transport { return c.transport } func (c *connection) addStream(str *stream) error { c.m.Lock() defer c.m.Unlock() + if c.streams == nil { + return c.closeErr + } if _, ok := c.streams[str.id]; ok { return errors.New("stream ID already exists") } diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index 42e90d5fad..23fef07bbc 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -141,7 +141,8 @@ func newStream( func (s *stream) Close() error { s.mx.Lock() if s.closeForShutdownErr != nil { - return s.closeForShutdownErr + s.mx.Unlock() + return nil } s.mx.Unlock() @@ -166,6 +167,7 @@ func (s *stream) Close() error { func (s *stream) Reset() error { s.mx.Lock() if s.closeForShutdownErr != nil { + s.mx.Unlock() return nil } s.mx.Unlock() @@ -180,6 +182,8 @@ func (s *stream) Reset() error { } func (s *stream) closeForShutdown(closeErr error) { + defer s.cleanup() + s.mx.Lock() defer s.mx.Unlock() @@ -189,7 +193,6 @@ func (s *stream) closeForShutdown(closeErr error) { case s.sendStateChanged <- struct{}{}: default: } - s.cleanup() } func (s *stream) SetDeadline(t time.Time) error { @@ -275,7 +278,8 @@ func (s *stream) spawnControlMessageReader() { s.processIncomingFlag(s.nextMessage.Flag) s.nextMessage = nil } - for s.sendState != sendStateDataReceived && s.sendState != sendStateReset { + for s.closeForShutdownErr == nil && + s.sendState != sendStateDataReceived && s.sendState != sendStateReset { var msg pb.Message if !setDeadline() { return diff --git a/p2p/transport/webrtc/stream_read.go b/p2p/transport/webrtc/stream_read.go index 9d54ec51be..002ebac0ec 100644 --- a/p2p/transport/webrtc/stream_read.go +++ b/p2p/transport/webrtc/stream_read.go @@ -37,11 +37,11 @@ func (s *stream) Read(b []byte) (int, error) { var msg pb.Message if err := s.reader.ReadMsg(&msg); err != nil { s.mx.Lock() + // connection was closed + if s.closeForShutdownErr != nil { + return 0, s.closeForShutdownErr + } if err == io.EOF { - // connection was closed - if s.closeForShutdownErr != nil { - return 0, s.closeForShutdownErr - } // if the channel was properly closed, return EOF if s.receiveState == receiveStateDataRead { return 0, io.EOF @@ -59,10 +59,6 @@ func (s *stream) Read(b []byte) (int, error) { if s.receiveState == receiveStateDataRead { return 0, io.EOF } - // connection was closed - if s.closeForShutdownErr != nil { - return 0, s.closeForShutdownErr - } return 0, err } s.mx.Lock() diff --git a/p2p/transport/webrtc/transport_test.go b/p2p/transport/webrtc/transport_test.go index 18e677f7c7..11a0cd1600 100644 --- a/p2p/transport/webrtc/transport_test.go +++ b/p2p/transport/webrtc/transport_test.go @@ -689,7 +689,7 @@ func TestConnectionTimeoutOnListener(t *testing.T) { start := time.Now() for { if _, err := str.Write([]byte("test")); err != nil { - require.True(t, os.IsTimeout(err)) + require.True(t, os.IsTimeout(err), "invalid error type: %v", err) break } @@ -697,7 +697,7 @@ func TestConnectionTimeoutOnListener(t *testing.T) { t.Fatal("timeout") } // make sure to not write too often, we don't want to fill the flow control window - time.Sleep(5 * time.Millisecond) + time.Sleep(20 * time.Millisecond) } // make sure that accepting a stream also returns an error... _, err = conn.AcceptStream()