From af8ae34018114477a9c55b5a6ce4dd850300d122 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 5 Mar 2019 21:39:56 +0100 Subject: [PATCH] Allow the definition of `None` as default in process functions It was impossible to define a process function with a keyword argument that had `None` as default, because the dynamically created ports always only specified `orm.Data` as valid types. So the validation of the default value would fail during the spec definition. Here we detect if `None` is passed as the default in the function signature, in which case we define the tuple of `orm.Data` and `type(None)` to be a valid type. Note that we need to use `type(None)` because the port validation later on will call `isinstance(input, valid_types)` which will not work if one of the values in `valid_types` is simply `None`. The `Process._setup_inputs` had to be adjusted to skip values of `None` because they cannot be linked to obviously but they can now potentially be passed as inputs to a `Process`. --- .../tests/engine/test_process_function.py | 88 ++++++++++++------- aiida/engine/processes/functions.py | 11 ++- aiida/engine/processes/process.py | 4 + 3 files changed, 70 insertions(+), 33 deletions(-) diff --git a/aiida/backends/tests/engine/test_process_function.py b/aiida/backends/tests/engine/test_process_function.py index 4ef24e6e27..2b514c3692 100644 --- a/aiida/backends/tests/engine/test_process_function.py +++ b/aiida/backends/tests/engine/test_process_function.py @@ -12,9 +12,9 @@ from __future__ import print_function from __future__ import absolute_import +from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.engine import run, run_get_node, submit, calcfunction, workfunction, Process, ExitCode -from aiida.orm import Int, Str, WorkFunctionNode, CalcFunctionNode from aiida.orm.nodes.data.bool import get_true_node DEFAULT_INT = 256 @@ -36,6 +36,8 @@ class TestProcessFunction(AiidaTestCase): function would complain as the dummy node class is not recognized as a valid process node. """ + # pylint: disable=too-many-public-methods + def setUp(self): super(TestProcessFunction, self).setUp() self.assertIsNone(Process.current()) @@ -53,9 +55,15 @@ def function_args(data_a): return data_a @workfunction - def function_args_with_default(data_a=Int(DEFAULT_INT)): + def function_args_with_default(data_a=orm.Int(DEFAULT_INT)): return data_a + @calcfunction + def function_with_none_default(int_a, int_b, int_c=None): + if int_c is not None: + return orm.Int(int_a + int_b + int_c) + return orm.Int(int_a + int_b) + @workfunction def function_kwargs(**kwargs): return kwargs @@ -67,12 +75,12 @@ def function_args_and_kwargs(data_a, **kwargs): return result @workfunction - def function_args_and_default(data_a, data_b=Int(DEFAULT_INT)): + def function_args_and_default(data_a, data_b=orm.Int(DEFAULT_INT)): return {'data_a': data_a, 'data_b': data_b} @workfunction def function_defaults( - data_a=Int(DEFAULT_INT), metadata={ + data_a=orm.Int(DEFAULT_INT), metadata={ 'label': DEFAULT_LABEL, 'description': DEFAULT_DESCRIPTION }): # pylint: disable=unused-argument,dangerous-default-value,missing-docstring @@ -90,6 +98,7 @@ def function_excepts(exception): self.function_return_true = function_return_true self.function_args = function_args self.function_args_with_default = function_args_with_default + self.function_with_none_default = function_with_none_default self.function_kwargs = function_kwargs self.function_args_and_kwargs = function_args_and_kwargs self.function_args_and_default = function_args_and_default @@ -125,9 +134,9 @@ def test_source_code_attributes(self): @calcfunction def test_process_function(data): - return {'result': Int(data.value + 1)} + return {'result': orm.Int(data.value + 1)} - _, node = test_process_function.run_get_node(data=Int(5)) + _, node = test_process_function.run_get_node(data=orm.Int(5)) # Read the source file of the calculation function that should be stored in the repository function_source_code = node.get_function_source_code().split('\n') @@ -157,8 +166,8 @@ def test_function_args(self): with self.assertRaises(ValueError): result = self.function_args() # pylint: disable=no-value-for-parameter - result = self.function_args(data_a=Int(arg)) - self.assertTrue(isinstance(result, Int)) + result = self.function_args(data_a=orm.Int(arg)) + self.assertTrue(isinstance(result, orm.Int)) self.assertEqual(result, arg) def test_function_args_with_default(self): @@ -166,16 +175,30 @@ def test_function_args_with_default(self): arg = 1 result = self.function_args_with_default() - self.assertTrue(isinstance(result, Int)) - self.assertEqual(result, Int(DEFAULT_INT)) + self.assertTrue(isinstance(result, orm.Int)) + self.assertEqual(result, orm.Int(DEFAULT_INT)) - result = self.function_args_with_default(data_a=Int(arg)) - self.assertTrue(isinstance(result, Int)) + result = self.function_args_with_default(data_a=orm.Int(arg)) + self.assertTrue(isinstance(result, orm.Int)) self.assertEqual(result, arg) + def test_function_with_none_default(self): + """Simple process function that defines a keyword with `None` as default value.""" + int_a = orm.Int(1) + int_b = orm.Int(2) + int_c = orm.Int(3) + + result = self.function_with_none_default(int_a, int_b) + self.assertTrue(isinstance(result, orm.Int)) + self.assertEqual(result, orm.Int(3)) + + result = self.function_with_none_default(int_a, int_b, int_c) + self.assertTrue(isinstance(result, orm.Int)) + self.assertEqual(result, orm.Int(6)) + def test_function_kwargs(self): """Simple process function that defines keyword arguments.""" - kwargs = {'data_a': Int(DEFAULT_INT)} + kwargs = {'data_a': orm.Int(DEFAULT_INT)} result = self.function_kwargs() self.assertTrue(isinstance(result, dict)) @@ -188,8 +211,8 @@ def test_function_kwargs(self): def test_function_args_and_kwargs(self): """Simple process function that defines a positional argument and keyword arguments.""" arg = 1 - args = (Int(DEFAULT_INT),) - kwargs = {'data_b': Int(arg)} + args = (orm.Int(DEFAULT_INT),) + kwargs = {'data_b': orm.Int(arg)} result = self.function_args_and_kwargs(*args) self.assertTrue(isinstance(result, dict)) @@ -202,12 +225,12 @@ def test_function_args_and_kwargs(self): def test_function_args_and_kwargs_default(self): """Simple process function that defines a positional argument and an argument with a default.""" arg = 1 - args_input_default = (Int(DEFAULT_INT),) - args_input_explicit = (Int(DEFAULT_INT), Int(arg)) + args_input_default = (orm.Int(DEFAULT_INT),) + args_input_explicit = (orm.Int(DEFAULT_INT), orm.Int(arg)) result = self.function_args_and_default(*args_input_default) self.assertTrue(isinstance(result, dict)) - self.assertEqual(result, {'data_a': args_input_default[0], 'data_b': Int(DEFAULT_INT)}) + self.assertEqual(result, {'data_a': args_input_default[0], 'data_b': orm.Int(DEFAULT_INT)}) result = self.function_args_and_default(*args_input_explicit) self.assertTrue(isinstance(result, dict)) @@ -218,13 +241,13 @@ def test_function_args_passing_kwargs(self): arg = 1 with self.assertRaises(ValueError): - self.function_args(data_a=Int(arg), data_b=Int(arg)) # pylint: disable=unexpected-keyword-arg + self.function_args(data_a=orm.Int(arg), data_b=orm.Int(arg)) # pylint: disable=unexpected-keyword-arg def test_function_set_label_description(self): """Verify that the label and description can be set for all process function variants.""" metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION} - _, node = self.function_args.run_get_node(data_a=Int(DEFAULT_INT), metadata=metadata) + _, node = self.function_args.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata) self.assertEqual(node.label, CUSTOM_LABEL) self.assertEqual(node.description, CUSTOM_DESCRIPTION) @@ -236,11 +259,11 @@ def test_function_set_label_description(self): self.assertEqual(node.label, CUSTOM_LABEL) self.assertEqual(node.description, CUSTOM_DESCRIPTION) - _, node = self.function_args_and_kwargs.run_get_node(data_a=Int(DEFAULT_INT), metadata=metadata) + _, node = self.function_args_and_kwargs.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata) self.assertEqual(node.label, CUSTOM_LABEL) self.assertEqual(node.description, CUSTOM_DESCRIPTION) - _, node = self.function_args_and_default.run_get_node(data_a=Int(DEFAULT_INT), metadata=metadata) + _, node = self.function_args_and_default.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata) self.assertEqual(node.label, CUSTOM_LABEL) self.assertEqual(node.description, CUSTOM_DESCRIPTION) @@ -248,7 +271,7 @@ def test_function_defaults(self): """Verify that a process function can define a default label and description but can be overriden.""" metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION} - _, node = self.function_defaults.run_get_node(data_a=Int(DEFAULT_INT)) + _, node = self.function_defaults.run_get_node(data_a=orm.Int(DEFAULT_INT)) self.assertEqual(node.label, DEFAULT_LABEL) self.assertEqual(node.description, DEFAULT_DESCRIPTION) @@ -264,7 +287,7 @@ def test_launchers(self): result, node = run_get_node(self.function_return_true) self.assertTrue(result) self.assertEqual(result, get_true_node()) - self.assertTrue(isinstance(node, CalcFunctionNode)) + self.assertTrue(isinstance(node, orm.CalcFunctionNode)) with self.assertRaises(AssertionError): submit(self.function_return_true) @@ -276,7 +299,8 @@ def test_return_exit_code(self): exit_status = 418 exit_message = 'I am a teapot' - _, node = self.function_exit_code.run_get_node(exit_status=Int(exit_status), exit_message=Str(exit_message)) + message = orm.Str(exit_message) + _, node = self.function_exit_code.run_get_node(exit_status=orm.Int(exit_status), exit_message=message) self.assertTrue(node.is_finished) self.assertFalse(node.is_finished_ok) @@ -288,7 +312,7 @@ def test_normal_exception(self): exception = 'This process function excepted' with self.assertRaises(RuntimeError): - _, node = self.function_excepts.run_get_node(exception=Str(exception)) + _, node = self.function_excepts.run_get_node(exception=orm.Str(exception)) self.assertTrue(node.is_excepted) self.assertEqual(node.exception, exception) @@ -307,23 +331,23 @@ def mul(data_a, data_b): def add_mul_wf(data_a, data_b, data_c): return mul(add(data_a, data_b), data_c) - result, node = add_mul_wf.run_get_node(Int(3), Int(4), Int(5)) + result, node = add_mul_wf.run_get_node(orm.Int(3), orm.Int(4), orm.Int(5)) self.assertEqual(result, (3 + 4) * 5) - self.assertIsInstance(node, WorkFunctionNode) + self.assertIsInstance(node, orm.WorkFunctionNode) def test_hashes(self): """Test that the hashes generated for identical process functions with identical inputs are the same.""" - _, node1 = self.function_return_input.run_get_node(data=Int(2)) - _, node2 = self.function_return_input.run_get_node(data=Int(2)) + _, node1 = self.function_return_input.run_get_node(data=orm.Int(2)) + _, node2 = self.function_return_input.run_get_node(data=orm.Int(2)) self.assertEqual(node1.get_hash(), node1.get_extra('_aiida_hash')) self.assertEqual(node2.get_hash(), node2.get_extra('_aiida_hash')) self.assertEqual(node1.get_hash(), node2.get_hash()) def test_hashes_different(self): """Test that the hashes generated for identical process functions with different inputs are the different.""" - _, node1 = self.function_return_input.run_get_node(data=Int(2)) - _, node2 = self.function_return_input.run_get_node(data=Int(3)) + _, node1 = self.function_return_input.run_get_node(data=orm.Int(2)) + _, node2 = self.function_return_input.run_get_node(data=orm.Int(3)) self.assertEqual(node1.get_hash(), node1.get_extra('_aiida_hash')) self.assertEqual(node2.get_hash(), node2.get_extra('_aiida_hash')) self.assertNotEqual(node1.get_hash(), node2.get_hash()) diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 74a634f041..9b474c6241 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -192,10 +192,19 @@ def _define(cls, spec): if i >= first_default_pos: default = defaults[i - first_default_pos] + # If the keyword was already specified, simply override the default if spec.has_input(arg): spec.inputs[arg].default = default else: - spec.input(arg, valid_type=orm.Data, default=default) + # If the default is `None` make sure that the port also accepts a `NoneType` + # Note that we cannot use `None` because the validation will call `isinstance` which does not work + # when passing `None`, but it does work with `NoneType` which is returned by calling `type(None)` + if default is None: + valid_type = (orm.Data, type(None)) + else: + valid_type = (orm.Data,) + + spec.input(arg, valid_type=valid_type, default=default) # If the function support kwargs then allow dynamic inputs, otherwise disallow spec.inputs.dynamic = keywords is not None diff --git a/aiida/engine/processes/process.py b/aiida/engine/processes/process.py index c9d7dda56a..2dc496c6e9 100644 --- a/aiida/engine/processes/process.py +++ b/aiida/engine/processes/process.py @@ -588,6 +588,10 @@ def _setup_inputs(self): for name, node in self._flat_inputs().items(): + # Certain processes allow to specify ports with `None` as acceptable values + if node is None: + continue + # Special exception: set computer if node is a remote Code and our node does not yet have a computer set if isinstance(node, Code) and not node.is_local() and not self.node.computer: self.node.computer = node.get_remote_computer()