Skip to content

Commit

Permalink
zmq4: fix another connection reaper deadlock
Browse files Browse the repository at this point in the history
Fixes #149

Co-authored-by: Sergey Egorov <sergey.egorov@teleste.com>
Co-authored-by: Sebastien Binet <binet@cern.ch>
  • Loading branch information
3 people authored Jan 24, 2024
1 parent 16ca7c0 commit e75c615
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 2 deletions.
102 changes: 102 additions & 0 deletions reaper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright 2024 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package zmq4

import (
"context"
"io"
"net"
"sync/atomic"
"testing"
"time"
)

func TestConnReaperDeadlock2(t *testing.T) {
ep := must(EndPoint("tcp"))
defer cleanUp(ep)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Bind the server.
srv := NewRouter(ctx, WithLogger(Devnull)).(*routerSocket)
if err := srv.Listen(ep); err != nil {
t.Fatalf("could not listen on %q: %+v", ep, err)
}
defer srv.Close()

// Add modified clients connection to server
// so any send to client will trigger context switch
// and be failing.
// Idea is that while srv.Send is progressing,
// the connection will be closed and assigned
// for connection reaper, and reaper will try to remove those
id := "client-x"
srv.sck.mu.Lock()
rmw := srv.sck.w.(*routerMWriter)
for i := 0; i < 2; i++ {
w := &Conn{}
w.Peer.Meta = make(Metadata)
w.Peer.Meta[sysSockID] = id
w.rw = &sockSendEOF{}
w.onCloseErrorCB = srv.sck.scheduleRmConn
// Do not to call srv.addConn as we dont want to have listener on this fake socket
rmw.addConn(w)
srv.sck.conns = append(srv.sck.conns, w)
}
srv.sck.mu.Unlock()

// Now try to send a message from the server to all clients.
msg := NewMsgFrom(nil, nil, []byte("payload"))
msg.Frames[0] = []byte(id)
if err := srv.Send(msg); err != nil {
t.Logf("Send to %s failed: %+v\n", id, err)
}
}

type sockSendEOF struct {
}

var a atomic.Int32

func (r *sockSendEOF) Write(b []byte) (n int, err error) {
// Each odd write fails asap.
// Each even write fails after sleep.
// Such a way we ensure the short write failure
// will cause socket be assinged to connection reaper
// while srv.Send is still in progress due to long writes.
if x := a.Add(1); x&1 == 0 {
time.Sleep(1 * time.Second)
}
return 0, io.EOF
}

func (r *sockSendEOF) Read(b []byte) (int, error) {
return 0, nil
}

func (r *sockSendEOF) Close() error {
return nil
}

func (r *sockSendEOF) LocalAddr() net.Addr {
return nil
}

func (r *sockSendEOF) RemoteAddr() net.Addr {
return nil
}

func (r *sockSendEOF) SetDeadline(t time.Time) error {
return nil
}

func (r *sockSendEOF) SetReadDeadline(t time.Time) error {
return nil
}

func (r *sockSendEOF) SetWriteDeadline(t time.Time) error {
return nil
}
10 changes: 8 additions & 2 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,10 +384,16 @@ func (sck *socket) connReaper() {
return
}

for _, c := range sck.closedConns {
// Clone the known closed connections to avoid data race
// and remove those under reaper unlocked.
// That should fix the deadlock reported in #149.
cc := append([]*Conn{}, sck.closedConns...) // clone
sck.closedConns = sck.closedConns[:0]
sck.reaperCond.L.Unlock()
for _, c := range cc {
sck.rmConn(c)
}
sck.closedConns = nil
sck.reaperCond.L.Lock()
}
}

Expand Down
31 changes: 31 additions & 0 deletions zall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,41 @@
package zmq4

import (
"fmt"
"io"
"log"
"net"
)

var (
Devnull = log.New(io.Discard, "zmq4: ", 0)
)

func must(str string, err error) string {
if err != nil {
panic(err)
}
return str
}

func EndPoint(transport string) (string, error) {
switch transport {
case "tcp":
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
return "", err
}
l, err := net.ListenTCP("tcp", addr)
if err != nil {
return "", err
}
defer l.Close()
return fmt.Sprintf("tcp://%s", l.Addr()), nil
case "ipc":
return "ipc://tmp-" + newUUID(), nil
case "inproc":
return "inproc://tmp-" + newUUID(), nil
default:
panic("invalid transport: [" + transport + "]")
}
}

0 comments on commit e75c615

Please sign in to comment.