Skip to content

Commit

Permalink
Workchains: Raise if if_/while_ predicate does not return boolean (#…
Browse files Browse the repository at this point in the history
…259)

The `if_` and `while_` conditionals are constructed with a predicate.
The interface expects the predicate to be a callable that returns a
boolean, which if true, the body of the conditional is entered.

The problem is that the type of the value returned by the predicate was
not explicitly checked, and any value that would evaluate as truthy
would be accepted. This could potentially lead to unexpected behavior,
such as an infinite loop for the `while_` construct.

Here the `_Conditional.is_true` method is updated to explicitly check
the type of the value returned by the predicate. If anything but a
boolean is returned, a `TypeError` is raised.

Cherry-pick: 800bcf1
  • Loading branch information
sphuber authored and Sebastiaan Huber committed Mar 9, 2023
1 parent 514041b commit f8521f6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,12 @@ def predicate(self) -> PREDICATE_TYPE:
return self._predicate

def is_true(self, workflow: 'WorkChain') -> bool:
return self._predicate(workflow)
result = self._predicate(workflow)

if not isinstance(result, bool):
raise TypeError(f'The conditional predicate `{self._predicate.__name__}` did not return a boolean')

return result

def __call__(self, *instructions: Union[_Instruction, WC_COMMAND_TYPE]) -> _Instruction:
assert self._body is None, 'Instructions have already been set'
Expand Down
11 changes: 11 additions & 0 deletions test/test_workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,3 +618,14 @@ def step_two(self):

workchain = Wf(inputs=dict(subspace={'one': 1, 'two': 2}))
workchain.execute()


@pytest.mark.parametrize('construct', (if_, while_))
def test_conditional_return_type(construct):
"""Test that a conditional passed to the ``if_`` and ``while_`` functions that does not return a ``bool`` raises."""

def invalid_conditional(self):
return 'true'

with pytest.raises(TypeError, match='The conditional predicate `invalid_conditional` did not return a boolean'):
construct(invalid_conditional)[0].is_true(None)

0 comments on commit f8521f6

Please sign in to comment.