Skip to content

Commit

Permalink
Merge pull request #61 from NNPDF/scarlehoff-patch-1
Browse files Browse the repository at this point in the history
Update reportengine dependencies
  • Loading branch information
scarlehoff authored Mar 7, 2024
2 parents 79eec2e + 56a3242 commit f393ee9
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 15 deletions.
1 change: 0 additions & 1 deletion conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ requirements:
- python
- flit


run:
- python
- jinja2
Expand Down
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@ classifiers = [
description-file="README.md"
requires = [
"jinja2",
"ruamel_yaml",
"ruamel_yaml<0.18", # the code is not compatible with ruamel 0.18
"matplotlib",
"pandas",
"pygments",
"blessings",
"dask",
"dask[distributed]",
]

[tool.flit.metadata.requires-extra]
test = [
"pytest",
"hypothesis",
]

dashboard = [
"bokeh!=3.0.*,>=2.4.2"
]
15 changes: 9 additions & 6 deletions src/reportengine/resourcebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ class DefaultStylePlugin(WorkerPlugin):
Class used to set style for each dask worker
"""

def __init__(self, style, default_style):
self.style = style
self.default_style = default_style
def __init__(self, style = None, default_style = None):
self.style = style if style is not None else "default"
self.default_style = default_style if default_style is not None else "default"

def setup(self, worker):
from matplotlib import style
Expand Down Expand Up @@ -238,9 +238,12 @@ def execute_parallel(self, scheduler=None):
"""
log.info("Initializing dask.distributed Client")

plugin = DefaultStylePlugin(
style=self.environment.style, default_style=self.environment.default_style
)
if self.environment is not None:
plugin = DefaultStylePlugin(
style=self.environment.style, default_style=self.environment.default_style
)
else:
plugin = DefaultStylePlugin()

if not scheduler:
# the deefault distributed logger is too noisy. Limit it here since
Expand Down
3 changes: 2 additions & 1 deletion src/reportengine/tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def test_collect(self):
d = namespaces.resolve(builder.rootns, [('lists',1)])
assert d['restaurant_collect'] == list("123")
builder.execute_parallel()
assert namespaces.resolve(builder.rootns, ('UK',))['score'] == -1
# since it is using dask it returns a future
assert namespaces.resolve(builder.rootns, ('UK',))['score'].result() == -1

def test_collect_raises(self):
with self.assertRaises(TypeError):
Expand Down
11 changes: 7 additions & 4 deletions src/reportengine/tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,13 @@ def nsspec(x, beginning=()):
self.graph.add_node(mcall, inputs={gcall, hcall})


def _test_ns(self):
def _test_ns(self, promise=False):
mresult = 'fresult: 4'*10
namespace = self.rootns
self.assertEqual(namespace['mresult'], mresult)
if promise:
self.assertEqual(namespace['mresult'].result(), mresult)
else:
self.assertEqual(namespace['mresult'], mresult)


def test_seq_execute(self):
Expand All @@ -88,7 +91,7 @@ def test_seq_execute(self):

def test_parallel_execute(self):
self.execute_parallel()
self._test_ns()
self._test_ns(promise=True)

if __name__ =='__main__':
unittest.main()
unittest.main()

0 comments on commit f393ee9

Please sign in to comment.