Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
scarlehoff committed Nov 24, 2023
1 parent 4d3f997 commit 17ea7cc
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 13 deletions.
2 changes: 0 additions & 2 deletions conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ requirements:
- python
- flit


run:
- python
- jinja2
Expand All @@ -25,7 +24,6 @@ requirements:
- blessings
- pandoc >=2
- dask
- distributed

test:
requires:
Expand Down
15 changes: 9 additions & 6 deletions src/reportengine/resourcebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ class DefaultStylePlugin(WorkerPlugin):
Class used to set style for each dask worker
"""

def __init__(self, style, default_style):
self.style = style
self.default_style = default_style
def __init__(self, style = None, default_style = None):
self.style = style if style is not None else "default"
self.default_style = default_style if default_style is not None else "default"

def setup(self, worker):
from matplotlib import style
Expand Down Expand Up @@ -238,9 +238,12 @@ def execute_parallel(self, scheduler=None):
"""
log.info("Initializing dask.distributed Client")

plugin = DefaultStylePlugin(
style=self.environment.style, default_style=self.environment.default_style
)
if self.environment is not None:
plugin = DefaultStylePlugin(
style=self.environment.style, default_style=self.environment.default_style
)
else:
plugin = DefaultStylePlugin()

if not scheduler:
# the deefault distributed logger is too noisy. Limit it here since
Expand Down
3 changes: 2 additions & 1 deletion src/reportengine/tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def test_collect(self):
d = namespaces.resolve(builder.rootns, [('lists',1)])
assert d['restaurant_collect'] == list("123")
builder.execute_parallel()
assert namespaces.resolve(builder.rootns, ('UK',))['score'] == -1
# since it is using dask it returns a future
assert namespaces.resolve(builder.rootns, ('UK',))['score'].result() == -1

def test_collect_raises(self):
with self.assertRaises(TypeError):
Expand Down
11 changes: 7 additions & 4 deletions src/reportengine/tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,13 @@ def nsspec(x, beginning=()):
self.graph.add_node(mcall, inputs={gcall, hcall})


def _test_ns(self):
def _test_ns(self, promise=False):
mresult = 'fresult: 4'*10
namespace = self.rootns
self.assertEqual(namespace['mresult'], mresult)
if promise:
self.assertEqual(namespace['mresult'].result(), mresult)
else:
self.assertEqual(namespace['mresult'], mresult)


def test_seq_execute(self):
Expand All @@ -88,7 +91,7 @@ def test_seq_execute(self):

def test_parallel_execute(self):
self.execute_parallel()
self._test_ns()
self._test_ns(promise=True)

if __name__ =='__main__':
unittest.main()
unittest.main()

0 comments on commit 17ea7cc

Please sign in to comment.