From e75c615ba1b356720bd80e876c4f90e1421768d9 Mon Sep 17 00:00:00 2001 From: Sergey Egorov Date: Wed, 24 Jan 2024 10:59:09 +0200 Subject: [PATCH] zmq4: fix another connection reaper deadlock Fixes #149 Co-authored-by: Sergey Egorov Co-authored-by: Sebastien Binet --- reaper_test.go | 102 +++++++++++++++++++++++++++++++++++++++++++++++++ socket.go | 10 ++++- zall_test.go | 31 +++++++++++++++ 3 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 reaper_test.go diff --git a/reaper_test.go b/reaper_test.go new file mode 100644 index 0000000..4b756c5 --- /dev/null +++ b/reaper_test.go @@ -0,0 +1,102 @@ +// Copyright 2024 The go-zeromq Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zmq4 + +import ( + "context" + "io" + "net" + "sync/atomic" + "testing" + "time" +) + +func TestConnReaperDeadlock2(t *testing.T) { + ep := must(EndPoint("tcp")) + defer cleanUp(ep) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Bind the server. + srv := NewRouter(ctx, WithLogger(Devnull)).(*routerSocket) + if err := srv.Listen(ep); err != nil { + t.Fatalf("could not listen on %q: %+v", ep, err) + } + defer srv.Close() + + // Add modified clients connection to server + // so any send to client will trigger context switch + // and be failing. + // Idea is that while srv.Send is progressing, + // the connection will be closed and assigned + // for connection reaper, and reaper will try to remove those + id := "client-x" + srv.sck.mu.Lock() + rmw := srv.sck.w.(*routerMWriter) + for i := 0; i < 2; i++ { + w := &Conn{} + w.Peer.Meta = make(Metadata) + w.Peer.Meta[sysSockID] = id + w.rw = &sockSendEOF{} + w.onCloseErrorCB = srv.sck.scheduleRmConn + // Do not to call srv.addConn as we dont want to have listener on this fake socket + rmw.addConn(w) + srv.sck.conns = append(srv.sck.conns, w) + } + srv.sck.mu.Unlock() + + // Now try to send a message from the server to all clients. + msg := NewMsgFrom(nil, nil, []byte("payload")) + msg.Frames[0] = []byte(id) + if err := srv.Send(msg); err != nil { + t.Logf("Send to %s failed: %+v\n", id, err) + } +} + +type sockSendEOF struct { +} + +var a atomic.Int32 + +func (r *sockSendEOF) Write(b []byte) (n int, err error) { + // Each odd write fails asap. + // Each even write fails after sleep. + // Such a way we ensure the short write failure + // will cause socket be assinged to connection reaper + // while srv.Send is still in progress due to long writes. + if x := a.Add(1); x&1 == 0 { + time.Sleep(1 * time.Second) + } + return 0, io.EOF +} + +func (r *sockSendEOF) Read(b []byte) (int, error) { + return 0, nil +} + +func (r *sockSendEOF) Close() error { + return nil +} + +func (r *sockSendEOF) LocalAddr() net.Addr { + return nil +} + +func (r *sockSendEOF) RemoteAddr() net.Addr { + return nil +} + +func (r *sockSendEOF) SetDeadline(t time.Time) error { + return nil +} + +func (r *sockSendEOF) SetReadDeadline(t time.Time) error { + return nil +} + +func (r *sockSendEOF) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/socket.go b/socket.go index ca198a6..7a41c71 100644 --- a/socket.go +++ b/socket.go @@ -384,10 +384,16 @@ func (sck *socket) connReaper() { return } - for _, c := range sck.closedConns { + // Clone the known closed connections to avoid data race + // and remove those under reaper unlocked. + // That should fix the deadlock reported in #149. + cc := append([]*Conn{}, sck.closedConns...) // clone + sck.closedConns = sck.closedConns[:0] + sck.reaperCond.L.Unlock() + for _, c := range cc { sck.rmConn(c) } - sck.closedConns = nil + sck.reaperCond.L.Lock() } } diff --git a/zall_test.go b/zall_test.go index c3cd079..90fc514 100644 --- a/zall_test.go +++ b/zall_test.go @@ -5,10 +5,41 @@ package zmq4 import ( + "fmt" "io" "log" + "net" ) var ( Devnull = log.New(io.Discard, "zmq4: ", 0) ) + +func must(str string, err error) string { + if err != nil { + panic(err) + } + return str +} + +func EndPoint(transport string) (string, error) { + switch transport { + case "tcp": + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + return "", err + } + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return "", err + } + defer l.Close() + return fmt.Sprintf("tcp://%s", l.Addr()), nil + case "ipc": + return "ipc://tmp-" + newUUID(), nil + case "inproc": + return "inproc://tmp-" + newUUID(), nil + default: + panic("invalid transport: [" + transport + "]") + } +}