Skip to content

Commit

Permalink
zmq4: add option for automatic reconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
thielepaul committed Jun 15, 2022
1 parent 16d169c commit 6d219d4
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 9 deletions.
8 changes: 8 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ func WithLogger(msg *log.Logger) Option {
}
}

// WithAutomaticReconnect allows to configure a socket to automatically
// reconnect on connection loss.
func WithAutomaticReconnect(automaticReconnect bool) Option {
return func(s *socket) {
s.autoReconnect = automaticReconnect
}
}

/*
// TODO(sbinet)
Expand Down
27 changes: 18 additions & 9 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ var (

// socket implements the ZeroMQ socket interface
type socket struct {
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
sec Security
log *log.Logger
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
sec Security
log *log.Logger
autoReconnect bool

mu sync.RWMutex
ids map[string]*Conn // ZMTP connection IDs
Expand All @@ -50,8 +51,9 @@ type socket struct {
listener net.Listener
dialer net.Dialer

closedConns []*Conn
reaperCond *sync.Cond
closedConns []*Conn
reaperCond *sync.Cond
reaperStarted bool
}

func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
Expand Down Expand Up @@ -266,7 +268,10 @@ connect:
return fmt.Errorf("zmq4: got a nil ZMTP connection to %q", endpoint)
}

go sck.connReaper()
if !sck.reaperStarted {
go sck.connReaper()
sck.reaperStarted = true
}
sck.addConn(zconn)
return nil
}
Expand Down Expand Up @@ -319,6 +324,10 @@ func (sck *socket) scheduleRmConn(c *Conn) {
sck.closedConns = append(sck.closedConns, c)
sck.reaperCond.Signal()
sck.reaperCond.L.Unlock()

if sck.autoReconnect {
sck.Dial(sck.ep)
}
}

// Type returns the type of this Socket (PUB, SUB, ...)
Expand Down
57 changes: 57 additions & 0 deletions socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -220,3 +221,59 @@ func TestConnReaperDeadlock(t *testing.T) {
clients[i].Close()
}
}

func TestSocketAutomaticReconnect(t *testing.T) {
listenEndpoint := "tcp://*:1234"
dialEndpoint := "tcp://localhost:1234"
message := "test"

ctx, cancel := context.WithCancel(context.Background())

sub := zmq4.NewSub(context.Background())
pub := zmq4.NewPub(context.Background(), zmq4.WithAutomaticReconnect(true))
defer pub.Close()
if err := sub.Listen(listenEndpoint); err != nil {
t.Fatalf("Sub Dial failed: %v", err)
}
if err := pub.Dial(dialEndpoint); err != nil {
t.Fatalf("Pub Dial failed: %v", err)
}
sub.SetOption(zmq4.OptionSubscribe, message)

wg := new(sync.WaitGroup)
defer wg.Wait()
defer cancel()
wg.Add(1)
go func(t *testing.T) {
defer wg.Done()
for {
pub.Send(zmq4.NewMsgFromString([]string{message}))
if ctx.Err() != nil {
return
}
time.Sleep(1 * time.Millisecond)
}
}(t)

checkConnectionWorking := func(socket zmq4.Socket) {
msg, err := socket.Recv()
if err != nil {
t.Fatalf("Recv failed: %v", err)
}
if string(msg.Frames[0]) != message {
t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
}
}

checkConnectionWorking(sub)
sub.Close()

sub2 := zmq4.NewSub(context.Background())
defer sub2.Close()
if err := sub2.Listen(listenEndpoint); err != nil {
t.Fatalf("Sub Listen failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
sub2.SetOption(zmq4.OptionSubscribe, message)
checkConnectionWorking(sub2)
}

0 comments on commit 6d219d4

Please sign in to comment.