Skip to content

Commit

Permalink
fix: support dynamic classes in multiprocessing (#16)
Browse files Browse the repository at this point in the history
Need to set __module__ to __main__.

uqfoundation/dill#56
  • Loading branch information
william-silversmith authored Feb 24, 2019
1 parent 0dc5bdb commit 1ef5dd4
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 3 deletions.
21 changes: 19 additions & 2 deletions taskqueue/taskqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,8 +663,25 @@ def capturing_soloprocess_upload(*args, **kwargs):
)
tasks = _scatter(tasks, parallel)

pool = pathos.pools.ProcessPool(parallel)
pool.map(uploadfn, tasks)
# This is a hack to get dill to pickle dynamically
# generated classes. This is an important use case
# for when we create iterators with generator __iter__
# functions on demand.

# https://github.com/uqfoundation/dill/issues/56

try:
task = next(item for item in tasks if item is not None)
except StopIteration:
return

cls_module = task.__class__.__module__
task.__class__.__module__ = '__main__'

with pathos.pools.ProcessPool(parallel) as pool:
pool.map(uploadfn, tasks)

task.__class__.__module__ = cls_module

if not error_queue.empty():
errors = []
Expand Down
21 changes: 21 additions & 0 deletions test/pathos_issue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from taskqueue import PrintTask

import copy

def crt_tasks(a,b):
bounds = 5

class TaskIterator():
def __init__(self, x):
self.x = x
def __len__(self):
return b-a
def __getitem__(self, slc):
itr = copy.deepcopy(self)
itr.x = 666
return itr
def __iter__(self):
for i in range(a,b):
yield PrintTask(str(i) + str(self.x))

return TaskIterator(bounds)
13 changes: 12 additions & 1 deletion test/test_taskqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,15 @@ def test_local_taskqueue():

with MockTaskQueue(parallel=True, progress=False) as tq:
for i in range(200):
tq.insert(ExecutePrintTask(), [i], { 'wow2': 4 })
tq.insert(ExecutePrintTask(), [i], { 'wow2': 4 })

def test_parallel_insert_all():
import pathos_issue

global QURL
tq = GreenTaskQueue(QURL)

tasks = pathos_issue.crt_tasks(5, 20)
tq.insert_all(tasks, parallel=2)

tq.purge()

0 comments on commit 1ef5dd4

Please sign in to comment.