Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zmq4: add option for automatic reconnect #127

Merged
merged 1 commit into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ func WithDialerMaxRetries(maxRetries int) Option {
}
}

// WithAutomaticReconnect allows to configure a socket to automatically
// reconnect on connection loss.
func WithAutomaticReconnect(automaticReconnect bool) Option {
sbinet marked this conversation as resolved.
Show resolved Hide resolved
return func(s *socket) {
s.autoReconnect = automaticReconnect
}
}

/*
// TODO(sbinet)

Expand Down
31 changes: 20 additions & 11 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ var (

// socket implements the ZeroMQ socket interface
type socket struct {
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
maxRetries int
sec Security
log *log.Logger
subTopics func() []string
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
maxRetries int
sec Security
log *log.Logger
subTopics func() []string
autoReconnect bool

mu sync.RWMutex
ids map[string]*Conn // ZMTP connection IDs
Expand All @@ -53,8 +54,9 @@ type socket struct {
listener net.Listener
dialer net.Dialer

closedConns []*Conn
reaperCond *sync.Cond
closedConns []*Conn
reaperCond *sync.Cond
reaperStarted bool
}

func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
Expand Down Expand Up @@ -271,7 +273,10 @@ connect:
return fmt.Errorf("zmq4: got a nil ZMTP connection to %q", endpoint)
}

go sck.connReaper()
if !sck.reaperStarted {
go sck.connReaper()
sck.reaperStarted = true
}
sck.addConn(zconn)
return nil
}
Expand Down Expand Up @@ -330,6 +335,10 @@ func (sck *socket) scheduleRmConn(c *Conn) {
sck.closedConns = append(sck.closedConns, c)
sck.reaperCond.Signal()
sck.reaperCond.L.Unlock()

if sck.autoReconnect {
sck.Dial(sck.ep)
}
}

// Type returns the type of this Socket (PUB, SUB, ...)
Expand Down
69 changes: 68 additions & 1 deletion socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func TestSocketSendSubscriptionOnConnect(t *testing.T) {
if err := pub.Dial(endpoint); err != nil {
t.Fatalf("Pub Dial failed: %v", err)
}
wg := new(sync.WaitGroup)
var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -322,3 +322,70 @@ func TestConnMaxRetriesInfinite(t *testing.T) {
t.Fatalf("Dial called %d times, expected at least %d", transport.dialCalledCount, atLeastExpectedRetries)
}
}

func TestSocketAutomaticReconnect(t *testing.T) {
ep, err := EndPoint("tcp")
if err != nil {
t.Fatalf("could not find endpoint: %+v", err)
}
message := "test"

var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sbinet marked this conversation as resolved.
Show resolved Hide resolved

sendMessages := func(socket zmq4.Socket) {
wg.Add(1)
go func(t *testing.T) {
defer wg.Done()
for {
socket.Send(zmq4.NewMsgFromString([]string{message}))
if ctx.Err() != nil {
return
}
time.Sleep(1 * time.Millisecond)
}
}(t)
}

sub := zmq4.NewSub(context.Background(), zmq4.WithAutomaticReconnect(true))
defer sub.Close()
sub.SetOption(zmq4.OptionSubscribe, message)
pub := zmq4.NewPub(context.Background())
if err := pub.Listen(ep); err != nil {
t.Fatalf("Pub Dial failed: %v", err)
}
if err := sub.Dial(ep); err != nil {
t.Fatalf("Sub Dial failed: %v", err)
}

sendMessages(pub)

checkConnectionWorking := func(socket zmq4.Socket) {
for {
msg, err := socket.Recv()
if errors.Is(err, io.EOF) {
continue
}
if err != nil {
t.Fatalf("Recv failed: %v", err)
}
if string(msg.Frames[0]) != message {
t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
}
return
}
}

checkConnectionWorking(sub)
pub.Close()

pub2 := zmq4.NewPub(context.Background())
defer pub2.Close()
if err := pub2.Listen(ep); err != nil {
t.Fatalf("Sub Listen failed: %v", err)
}
sendMessages(pub2)
checkConnectionWorking(sub)
}