Skip to content

Commit

Permalink
Merge pull request #58 from pkuyym/fix-56
Browse files Browse the repository at this point in the history
Simplify parallel part for data processing and fix abnormal exit.
  • Loading branch information
pkuyym authored Dec 6, 2017
2 parents 907898a + 20e2258 commit f9ebff7
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 31 deletions.
5 changes: 1 addition & 4 deletions data_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,7 @@ def reader():

reader, cleanup_callback = xmap_readers_mp(
lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]),
reader,
self._num_threads,
4096,
order=True)
reader, self._num_threads, 4096)

# register callback to main process
atexit.register(cleanup_callback)
Expand Down
68 changes: 41 additions & 27 deletions data_utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from Queue import Queue
from threading import Thread
from multiprocessing import Process, Manager
from multiprocessing import Process, Manager, Value
from paddle.v2.dataset.common import md5file


Expand Down Expand Up @@ -101,40 +101,35 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
:type process_num: int
:param buffer_size: Maximal buffer size.
:type buffer_size: int
:param order: Reserve the order of samples from the given reader.
:type order: bool
:return: The wrappered reader
:rtype: callable
:return: The wrappered reader and cleanup callback
:rtype: tuple
"""
end_flag = XmapEndSignal()

# define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue):
for sample in reader():
in_queue.put(sample)
in_queue.put(end_flag)
read_workers = []
handle_workers = []
flush_workers = []

read_exit_flag = Value('i', 0)
handle_exit_flag = Value('i', 0)
flush_exit_flag = Value('i', 0)

# define a worker to read samples from reader to in_queue with order flag
def order_read_worker(reader, in_queue):
for order_id, sample in enumerate(reader()):
if read_exit_flag.value == 1: break
in_queue.put((order_id, sample))
in_queue.put(end_flag)

# define a worker to handle samples from in_queue by mapper and put results
# to out_queue
def handle_worker(in_queue, out_queue, mapper):
sample = in_queue.get()
while not isinstance(sample, XmapEndSignal):
out_queue.put(mapper(sample))
sample = in_queue.get()
in_queue.put(end_flag)
out_queue.put(end_flag)
# the reading worker should not exit until all handling work exited
while handle_exit_flag.value == 0 or read_exit_flag.value == 0:
time.sleep(0.001)

# define a worker to handle samples from in_queue by mapper and put results
# to out_queue with order
def order_handle_worker(in_queue, out_queue, mapper, out_order):
ins = in_queue.get()
while not isinstance(ins, XmapEndSignal):
if handle_exit_flag.value == 1: break
order_id, sample = ins
result = mapper(sample)
while order_id != out_order[0]:
Expand All @@ -144,22 +139,39 @@ def order_handle_worker(in_queue, out_queue, mapper, out_order):
ins = in_queue.get()
in_queue.put(end_flag)
out_queue.put(end_flag)
# wait for exit of flushing worker
while flush_exit_flag.value == 0 or handle_exit_flag.value == 0:
time.sleep(0.001)
read_exit_flag.value = 1
handle_exit_flag.value = 1

# define a thread worker to flush samples from Manager.Queue to Queue
# for acceleration
def flush_worker(in_queue, out_queue):
finish = 0
while finish < process_num:
while finish < process_num and flush_exit_flag.value == 0:
sample = in_queue.get()
if isinstance(sample, XmapEndSignal):
finish += 1
else:
out_queue.put(sample)
out_queue.put(end_flag)
handle_exit_flag.value = 1
flush_exit_flag.value = 1

def cleanup():
# kill all sub process and threads
os._exit(0)
# first exit flushing workers
flush_exit_flag.value = 1
for w in flush_workers:
w.join()
# next exit handling workers
handle_exit_flag.value = 1
for w in handle_workers:
w.join()
# last exit reading workers
read_exit_flag.value = 1
for w in read_workers:
w.join()

def xreader():
# prepare shared memory
Expand All @@ -169,27 +181,29 @@ def xreader():
out_order = manager.list([0])

# start a read worker in a process
target = order_read_worker if order else read_worker
target = order_read_worker
p = Process(target=target, args=(reader, in_queue))
p.daemon = True
p.start()
read_workers.append(p)

# start handle_workers with multiple processes
target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
target = order_handle_worker
args = (in_queue, out_queue, mapper, out_order)
workers = [
Process(target=target, args=args) for _ in xrange(process_num)
]
for w in workers:
w.daemon = True
w.start()
handle_workers.append(w)

# start a thread to read data from slow Manager.Queue
flush_queue = Queue(buffer_size)
t = Thread(target=flush_worker, args=(out_queue, flush_queue))
t.daemon = True
t.start()
flush_workers.append(t)

# get results
sample = flush_queue.get()
Expand Down

0 comments on commit f9ebff7

Please sign in to comment.