-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpull_worker.py
92 lines (73 loc) · 2.79 KB
/
pull_worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import itertools
import logging
import zmq
import config
import argparse
from util import new_task_handler
from local_worker import local_worker
from model import Task
import json
import util
import time
logging.basicConfig(level=logging.DEBUG, format="%(levelname)s: %(message)s")

# How long (milliseconds) to poll for a dispatcher reply before the REQ
# socket is discarded, reconnected, and the request resent.
REQUEST_TIMEOUT = 2500
# How long (seconds) to hold off sending READY after the dispatcher
# replies that no tasks are available.
RETRY_INTERVAL = 1
IDLE_POLL_INTERVAL = 0.05  # seconds to sleep when there is nothing to send


def _connect(context, dispatcher_url):
    """Return a new zmq REQ socket from ``context`` connected to ``dispatcher_url``."""
    client = context.socket(zmq.REQ)
    client.connect(dispatcher_url)
    return client


def _request_with_retry(context, client, dispatcher_url, request, sequence):
    """Send ``request`` on ``client`` and wait for the reply echoing ``sequence``.

    Lazy-pirate style retry: if no reply arrives within REQUEST_TIMEOUT ms,
    the REQ socket is closed (linger 0), a fresh one is connected, and the
    same request is resent. Retries indefinitely.

    Args:
        context: zmq context used to create replacement sockets.
        client: connected REQ socket to send on.
        dispatcher_url: endpoint to reconnect to after a timeout.
        request: multipart frames ``[sequence, type, payload]`` to send.
        sequence: sequence number the reply's first frame must match.

    Returns:
        ``(client, reply)``: the socket to keep using (it is a new object if a
        reconnect happened) and the three reply frames.

    Raises:
        ValueError: if a reply does not have exactly three frames.
    """
    client.send_multipart(request)
    while True:
        if (client.poll(REQUEST_TIMEOUT) & zmq.POLLIN) == 0:
            logging.warning("No reply from server")
            # Socket is confused. Close and remove it.
            client.setsockopt(zmq.LINGER, 0)
            client.close()
            logging.info("Reconnecting to server…")
            client = _connect(context, dispatcher_url)
            logging.info(f"Resending ({request})")
            client.send_multipart(request)
            continue
        reply = client.recv_multipart()
        reply_sequence, _reply_type, _reply_payload = reply
        if int(reply_sequence) != sequence:
            logging.error(f"Mismatched sequence in request vs reply. req_seq: {sequence}, rep_seq: {reply_sequence}")
            # A REQ socket delivers no further reply after recv, so this
            # poll times out and falls into the reconnect-and-resend path.
            continue
        return client, reply


def main(num_processes: int, dispatcher_url: str):
    """Pull tasks from the dispatcher and push results back until interrupted.

    Starts the local task handler, then loops. A finished result is sent as a
    RESULT request. Otherwise, if fewer than ``num_processes`` tasks are in
    flight and no "no tasks" backoff is active, a READY request asks for work.
    Each request carries a sequence number that the reply must echo.

    Args:
        num_processes: passed to ``new_task_handler``; also caps the number
            of tasks handed out locally whose results have not been sent yet.
        dispatcher_url: ZeroMQ endpoint of the task dispatcher.
    """
    _, task_queue, result_queue = new_task_handler(local_worker, num_processes=num_processes)
    logging.info("Connecting to task dispatcher...")
    context = zmq.Context()
    client = _connect(context, dispatcher_url)
    task_count = 0  # tasks put on task_queue whose result has not been sent yet
    last_none_reply_time = 0.0
    try:
        for sequence in itertools.count():
            if not result_queue.empty():
                result = result_queue.get()
                task_count -= 1
                request_type = b"RESULT"
                request_payload = json.dumps(result.dict()).encode()
            elif (time.time() < last_none_reply_time + RETRY_INTERVAL
                  or task_count >= num_processes):
                # Nothing to send: either backing off after a NONE reply or
                # all slots are busy. Sleep instead of spinning, and never
                # fall through with the previous iteration's request.
                time.sleep(IDLE_POLL_INTERVAL)
                continue
            else:
                request_type = b"READY"
                request_payload = b"READY"

            request = [str(sequence).encode(), request_type, request_payload]
            client, reply = _request_with_retry(context, client, dispatcher_url, request, sequence)
            _reply_sequence, reply_type, reply_payload = reply
            logging.info(f"Server replied OK ({reply})")

            if reply_type == util.REPLY_TYPE_TASK:
                task = Task(**json.loads(reply_payload))
                task_count += 1
                task_queue.put(task)
            elif reply_type == util.REPLY_TYPE_NONE:
                last_none_reply_time = time.time()
                logging.info(f"No tasks available, retrying in ({RETRY_INTERVAL})")
            elif reply_type == util.REPLY_TYPE_ACK:
                logging.info("Server received results")
            else:
                logging.error("Server did not recognize request type")
    finally:
        # Close every socket of this context (dropping unsent messages) and
        # terminate it, so an interrupted worker does not hang on exit.
        context.destroy(linger=0)
if __name__ == "__main__":
    # Command-line entry point: -n/--num_processes and -u/--dispatcher_url
    # are forwarded to main(); the URL defaults to config.task_dispatcher_url.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-n', '--num_processes', type=int, default=4)
    arg_parser.add_argument('-u', '--dispatcher_url', type=str, default=config.task_dispatcher_url)
    cli_args = arg_parser.parse_args()
    main(num_processes=cli_args.num_processes, dispatcher_url=cli_args.dispatcher_url)