Skip to content
This repository was archived by the owner on Apr 13, 2024. It is now read-only.

Commit f6a8e39

Browse files
committed
add test with multiple subscribers
1 parent 84b3af5 commit f6a8e39

File tree

4 files changed

+53
-21
lines changed

4 files changed

+53
-21
lines changed

messaging/impl_zmq.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ Message * ZMQSubSocket::receive(bool non_blocking){
8989
// Make a copy to ensure the data is aligned
9090
r = new ZMQMessage;
9191
r->init((char*)zmq_msg_data(&msg), zmq_msg_size(&msg));
92-
}
92+
}
9393

9494
zmq_msg_close(&msg);
9595
return r;

messaging/messaging.pxd

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ cdef extern from "messaging.hpp":
3939
@staticmethod
4040
Poller * create()
4141
void registerSocket(SubSocket *)
42-
vector[SubSocket*] poll(int)
42+
vector[SubSocket*] poll(int) nogil

messaging/messaging_pyx.pyx

+8-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ cdef class Context:
1919
def __cinit__(self):
2020
self.context = cppContext.create()
2121

22+
def term(self):
23+
del self.context
24+
self.context = NULL
25+
2226
def __dealloc__(self):
2327
pass
2428
# Deleting the context will hang if sockets are still active
@@ -43,8 +47,11 @@ cdef class Poller:
4347

4448
def poll(self, timeout):
4549
sockets = []
50+
cdef int t = timeout
51+
52+
with nogil:
53+
result = self.poller.poll(t)
4654

47-
result = self.poller.poll(timeout)
4855
for s in result:
4956
socket = SubSocket()
5057
socket.setPtr(s)

messaging/tests/test_poller.py

+43-18
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
11
import unittest
2+
import os
23
import time
34
import cereal.messaging as messaging
45

5-
from multiprocessing import Process, Pipe
6+
import concurrent.futures
67

78

8-
def poller(pipe):
9+
def poller():
910
context = messaging.Context()
1011

12+
p = messaging.Poller()
13+
1114
sub = messaging.SubSocket()
1215
sub.connect(context, 'controlsState')
13-
14-
p = messaging.Poller()
1516
p.registerSocket(sub)
1617

17-
while True:
18-
pipe.recv()
18+
socks = p.poll(1000)
19+
r = [s.receive(non_blocking=True) for s in socks]
1920

20-
socks = p.poll(1000)
21-
pipe.send([s.receive(non_blocking=True) for s in socks])
21+
return r
2222

2323

2424
class TestPoller(unittest.TestCase):
@@ -28,19 +28,44 @@ def test_poll_once(self):
2828
pub = messaging.PubSocket()
2929
pub.connect(context, 'controlsState')
3030

31-
pipe, pipe_child = Pipe()
32-
proc = Process(target=poller, args=(pipe_child,))
33-
proc.start()
31+
with concurrent.futures.ThreadPoolExecutor() as e:
32+
poll = e.submit(poller)
33+
34+
time.sleep(0.1) # Slow joiner syndrome
35+
36+
# Send message
37+
pub.send("a")
38+
39+
# Wait for poll result
40+
result = poll.result()
41+
42+
del pub
43+
context.term()
44+
45+
self.assertEqual(result, [b"a"])
46+
47+
@unittest.skipIf(os.environ.get('MSGQ'), "fails under msgq")
48+
def test_poll_and_create_many_subscribers(self):
49+
context = messaging.Context()
50+
51+
pub = messaging.PubSocket()
52+
pub.connect(context, 'controlsState')
53+
54+
with concurrent.futures.ThreadPoolExecutor() as e:
55+
poll = e.submit(poller)
3456

35-
time.sleep(.1)
57+
time.sleep(0.1) # Slow joiner syndrome
58+
c = messaging.Context()
59+
for _ in range(10):
60+
messaging.SubSocket().connect(c, 'controlsState')
3661

37-
# Start poll
38-
pipe.send("go")
62+
# Send message
63+
pub.send("a")
3964

40-
# Send message
41-
pub.send("a")
65+
# Wait for poll result
66+
result = poll.result()
4267

43-
result = pipe.recv()
44-
proc.kill()
68+
del pub
69+
context.term()
4570

4671
self.assertEqual(result, [b"a"])

0 commit comments

Comments
 (0)