fix unbounded memory growth with threadmap - closes #97
iiSeymour committed Jan 22, 2021
1 parent fe72fbb commit 7796927
Showing 2 changed files with 21 additions and 40 deletions.
bonito/aligner.py (1 addition, 0 deletions)
@@ -31,6 +31,7 @@ def run(self):
         while True:
             item = self.input_queue.get()
             if item is StopIteration:
+                self.output_queue.put(item)
                 break
             k, v = item
             mapping = next(self.aligner.map(v['sequence'], buf=thrbuf, MD=True), None)
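
The single added line above makes the aligner's worker thread forward the StopIteration sentinel to its output queue before exiting, so whatever drains that queue also learns the stream has ended. Below is a minimal sketch of this sentinel-forwarding pattern, using only standard-library primitives and hypothetical names rather than the bonito classes:

import queue
from threading import Thread

def map_worker(func, input_queue, output_queue):
    # Hypothetical worker mirroring the pattern above: consume (key, value)
    # items until the producer signals the end of input.
    while True:
        item = input_queue.get()
        if item is StopIteration:
            output_queue.put(item)  # forward the sentinel so the consumer stops too
            break
        k, v = item
        output_queue.put((k, func(v)))

in_q, out_q = queue.Queue(maxsize=2), queue.Queue(maxsize=2)
Thread(target=map_worker, args=(str.upper, in_q, out_q)).start()
for item in [("r1", "acgt"), StopIteration]:
    in_q.put(item)
print(out_q.get())  # ('r1', 'ACGT')
print(out_q.get())  # the forwarded StopIteration end-of-stream marker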
bonito/multiprocessing.py (20 additions, 40 deletions)
@@ -3,6 +3,7 @@
"""

import queue
from itertools import count
from threading import Thread
from functools import partial
from collections import deque
@@ -31,19 +32,12 @@ def process_map(func, iterator, n_proc=4, maxsize=0):
     return iter(ProcessMap(func, iterator, n_proc, output_queue=Queue(maxsize)))
 
 
-def thread_map(func, iterator, n_thread=4, maxsize=0, preserve_order=False):
+def thread_map(func, iterator, n_thread=4, maxsize=2):
     """
     Take an `iterator` of key, value pairs and apply `func` to all values using `n_thread` threads.
     """
     if n_thread == 0: return ((k, func(v)) for k, v in iterator)
-    return iter(
-        ThreadMap(
-            partial(MapWorkerThread, func),
-            iterator, n_thread,
-            output_queue=queue.Queue(maxsize),
-            preserve_order=preserve_order
-        )
-    )
+    return iter(ThreadMap(partial(MapWorkerThread, func), iterator, n_thread, maxsize=maxsize))
 
 
 class BackgroundIterator:
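
For context, here is a minimal sketch of how the reworked thread_map could be called; the read IDs, sequences, and reversing function are illustrative stand-ins, not bonito data:

from bonito.multiprocessing import thread_map

# Illustrative (key, value) pairs: read IDs mapped to sequences.
reads = [("read_1", "ACGT"), ("read_2", "GGTA")]

# func is applied to each value on n_thread worker threads; results come back
# as (key, result) pairs in input order, with at most maxsize items queued per thread.
for read_id, flipped in thread_map(lambda seq: seq[::-1], reads, n_thread=4, maxsize=2):
    print(read_id, flipped)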
@@ -154,24 +148,21 @@ def run(self):
         while True:
             item = self.input_queue.get()
             if item is StopIteration:
+                self.output_queue.put(item)
                 break
             k, v = item
             self.output_queue.put((k, self.func(v)))
 
 
 class ThreadMap(Thread):
-    def __init__(self, worker_type, iterator, n_thread, output_queue=None, preserve_order=False):
+
+    def __init__(self, worker_type, iterator, n_thread, maxsize=2):
         super().__init__()
         self.iterator = iterator
-        self.work_queue = queue.Queue(n_thread*2)
-        self.output_queue = output_queue or queue.Queue()
-        self.workers = [worker_type(input_queue=self.work_queue, output_queue=self.output_queue) for _ in range(n_thread)]
-        if preserve_order:
-            self.keys = deque()
-            self.results = {}
-        else:
-            self.keys = None
-            self.results = None
+        self.n_thread = n_thread
+        self.work_queues = [queue.Queue(maxsize) for _ in range(n_thread)]
+        self.output_queues = [queue.Queue(maxsize) for _ in range(n_thread)]
+        self.workers = [worker_type(input_queue=in_q, output_queue=out_q) for (in_q, out_q) in zip(self.work_queues, self.output_queues)]
 
     def start(self):
         for worker in self.workers:
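
The constructor above replaces the single output_queue (unbounded by default) and the keys/results bookkeeping with one bounded work queue and one bounded output queue per worker. A small standard-library sketch, independent of bonito, of why a bounded queue.Queue caps how far a producer can run ahead of a slow consumer:

import queue
import threading
import time

q = queue.Queue(maxsize=2)  # same bound as the new default maxsize=2

def producer():
    for i in range(6):
        q.put(i)            # blocks whenever two items are already waiting
        print("queued", i)

threading.Thread(target=producer, daemon=True).start()

for _ in range(6):
    time.sleep(0.1)         # deliberately slow consumer
    print("got", q.get())   # each get() lets the stalled producer advance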
@@ -180,30 +171,19 @@ def start(self):
 
     def __iter__(self):
         self.start()
-        while True:
-            while self.keys:
-                key = self.keys.popleft()
-                if key in self.results:
-                    yield (key, self.results.pop(key))
-                else:
-                    self.keys.appendleft(key)
-                    break
-            item = self.output_queue.get()
+        for i in count():
+            item = self.output_queues[i % self.n_thread].get()
             if item is StopIteration:
+                #do we need to empty output_queues in order to join worker threads?
+                for j in range(i + 1, i + self.n_thread):
+                    self.output_queues[j % self.n_thread].get()
                 break
-
-            if self.results is None:
-                yield item
-            else:
-                k, v = item
-                self.results[k] = v
+            yield item
 
     def run(self):
-        for (k, v) in self.iterator:
-            self.work_queue.put((k, v))
-            if self.keys is not None: self.keys.append(k)
-        for _ in self.workers:
-            self.work_queue.put(StopIteration)
+        for i, (k, v) in enumerate(self.iterator):
+            self.work_queues[i % self.n_thread].put((k, v))
+        for q in self.work_queues:
+            q.put(StopIteration)
         for worker in self.workers:
             worker.join()
-        self.output_queue.put(StopIteration)
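
Putting the pieces together, here is a condensed, self-contained sketch of the round-robin scheme the new ThreadMap uses; it mirrors the diff above with standard-library primitives but is not the bonito implementation itself. Dispatching item i to queue i % n_thread and collecting results in the same order preserves input order without the old keys/results bookkeeping, while the bounded queues keep only a few items in flight per thread.

import queue
from itertools import count
from threading import Thread

def round_robin_thread_map(func, items, n_thread=4, maxsize=2):
    # Illustrative helper, not the bonito ThreadMap: one bounded work queue
    # and one bounded output queue per worker thread.
    work = [queue.Queue(maxsize) for _ in range(n_thread)]
    done = [queue.Queue(maxsize) for _ in range(n_thread)]

    def worker(in_q, out_q):
        while True:
            item = in_q.get()
            if item is StopIteration:
                out_q.put(item)                 # forward the sentinel
                break
            k, v = item
            out_q.put((k, func(v)))

    def feed():
        for i, (k, v) in enumerate(items):
            work[i % n_thread].put((k, v))      # round-robin dispatch
        for q in work:
            q.put(StopIteration)

    for in_q, out_q in zip(work, done):
        Thread(target=worker, args=(in_q, out_q)).start()
    Thread(target=feed).start()

    for i in count():                           # collect in the same round-robin order
        item = done[i % n_thread].get()
        if item is StopIteration:
            for j in range(i + 1, i + n_thread):
                done[j % n_thread].get()        # drain the remaining sentinels
            break
        yield item

# e.g. list(round_robin_thread_map(str.upper, enumerate(["acgt", "ggta"])))
# -> [(0, 'ACGT'), (1, 'GGTA')]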
