From f8521f6c7358cf2d890cb99984b8a3a26898bde7 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 9 Mar 2023 14:08:23 +0100 Subject: [PATCH] Workchains: Raise if `if_/while_` predicate does not return boolean (#259) The `if_` and `while_` conditionals are constructed with a predicate. The interface expects the predicate to be a callable that returns a boolean, which if true, the body of the conditional is entered. The problem is that the type of the value returned by the predicate was not explicitly checked, and any value that would evaluate as truthy would be accepted. This could potentially lead to unexpected behavior, such as an infinite loop for the `while_` construct. Here the `_Conditional.is_true` method is updated to explicitly check the type of the value returned by the predicate. If anything but a boolean is returned, a `TypeError` is raised. Cherry-pick: 800bcf154c0ea0d4576636b95d2ad2285adec266 --- src/plumpy/workchains.py | 7 ++++++- test/test_workchains.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 1bf0196b..e4eb6b57 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -388,7 +388,12 @@ def predicate(self) -> PREDICATE_TYPE: return self._predicate def is_true(self, workflow: 'WorkChain') -> bool: - return self._predicate(workflow) + result = self._predicate(workflow) + + if not isinstance(result, bool): + raise TypeError(f'The conditional predicate `{self._predicate.__name__}` did not return a boolean') + + return result def __call__(self, *instructions: Union[_Instruction, WC_COMMAND_TYPE]) -> _Instruction: assert self._body is None, 'Instructions have already been set' diff --git a/test/test_workchains.py b/test/test_workchains.py index 7ac020d6..2748c955 100644 --- a/test/test_workchains.py +++ b/test/test_workchains.py @@ -618,3 +618,14 @@ def step_two(self): workchain = Wf(inputs=dict(subspace={'one': 1, 'two': 2})) workchain.execute() + + +@pytest.mark.parametrize('construct', (if_, while_)) +def test_conditional_return_type(construct): + """Test that a conditional passed to the ``if_`` and ``while_`` functions that does not return a ``bool`` raises.""" + + def invalid_conditional(self): + return 'true' + + with pytest.raises(TypeError, match='The conditional predicate `invalid_conditional` did not return a boolean'): + construct(invalid_conditional)[0].is_true(None)