Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow the definition of None as default in process functions #2582

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 56 additions & 32 deletions aiida/backends/tests/engine/test_process_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from __future__ import print_function
from __future__ import absolute_import

from aiida import orm
from aiida.backends.testbase import AiidaTestCase
from aiida.engine import run, run_get_node, submit, calcfunction, workfunction, Process, ExitCode
from aiida.orm import Int, Str, WorkFunctionNode, CalcFunctionNode
from aiida.orm.nodes.data.bool import get_true_node

DEFAULT_INT = 256
Expand All @@ -36,6 +36,8 @@ class TestProcessFunction(AiidaTestCase):
function would complain as the dummy node class is not recognized as a valid process node.
"""

# pylint: disable=too-many-public-methods

def setUp(self):
super(TestProcessFunction, self).setUp()
self.assertIsNone(Process.current())
Expand All @@ -53,9 +55,15 @@ def function_args(data_a):
return data_a

@workfunction
def function_args_with_default(data_a=Int(DEFAULT_INT)):
def function_args_with_default(data_a=orm.Int(DEFAULT_INT)):
return data_a

@calcfunction
def function_with_none_default(int_a, int_b, int_c=None):
if int_c is not None:
return orm.Int(int_a + int_b + int_c)
return orm.Int(int_a + int_b)

@workfunction
def function_kwargs(**kwargs):
return kwargs
Expand All @@ -67,12 +75,12 @@ def function_args_and_kwargs(data_a, **kwargs):
return result

@workfunction
def function_args_and_default(data_a, data_b=Int(DEFAULT_INT)):
def function_args_and_default(data_a, data_b=orm.Int(DEFAULT_INT)):
return {'data_a': data_a, 'data_b': data_b}

@workfunction
def function_defaults(
data_a=Int(DEFAULT_INT), metadata={
data_a=orm.Int(DEFAULT_INT), metadata={
'label': DEFAULT_LABEL,
'description': DEFAULT_DESCRIPTION
}): # pylint: disable=unused-argument,dangerous-default-value,missing-docstring
Expand All @@ -90,6 +98,7 @@ def function_excepts(exception):
self.function_return_true = function_return_true
self.function_args = function_args
self.function_args_with_default = function_args_with_default
self.function_with_none_default = function_with_none_default
self.function_kwargs = function_kwargs
self.function_args_and_kwargs = function_args_and_kwargs
self.function_args_and_default = function_args_and_default
Expand Down Expand Up @@ -125,9 +134,9 @@ def test_source_code_attributes(self):

@calcfunction
def test_process_function(data):
return {'result': Int(data.value + 1)}
return {'result': orm.Int(data.value + 1)}

_, node = test_process_function.run_get_node(data=Int(5))
_, node = test_process_function.run_get_node(data=orm.Int(5))

# Read the source file of the calculation function that should be stored in the repository
function_source_code = node.get_function_source_code().split('\n')
Expand Down Expand Up @@ -157,25 +166,39 @@ def test_function_args(self):
with self.assertRaises(ValueError):
result = self.function_args() # pylint: disable=no-value-for-parameter

result = self.function_args(data_a=Int(arg))
self.assertTrue(isinstance(result, Int))
result = self.function_args(data_a=orm.Int(arg))
self.assertTrue(isinstance(result, orm.Int))
self.assertEqual(result, arg)

def test_function_args_with_default(self):
"""Simple process function that defines a single argument with a default."""
arg = 1

result = self.function_args_with_default()
self.assertTrue(isinstance(result, Int))
self.assertEqual(result, Int(DEFAULT_INT))
self.assertTrue(isinstance(result, orm.Int))
self.assertEqual(result, orm.Int(DEFAULT_INT))

result = self.function_args_with_default(data_a=Int(arg))
self.assertTrue(isinstance(result, Int))
result = self.function_args_with_default(data_a=orm.Int(arg))
self.assertTrue(isinstance(result, orm.Int))
self.assertEqual(result, arg)

def test_function_with_none_default(self):
"""Simple process function that defines a keyword with `None` as default value."""
int_a = orm.Int(1)
int_b = orm.Int(2)
int_c = orm.Int(3)

result = self.function_with_none_default(int_a, int_b)
self.assertTrue(isinstance(result, orm.Int))
self.assertEqual(result, orm.Int(3))

result = self.function_with_none_default(int_a, int_b, int_c)
self.assertTrue(isinstance(result, orm.Int))
self.assertEqual(result, orm.Int(6))

def test_function_kwargs(self):
"""Simple process function that defines keyword arguments."""
kwargs = {'data_a': Int(DEFAULT_INT)}
kwargs = {'data_a': orm.Int(DEFAULT_INT)}

result = self.function_kwargs()
self.assertTrue(isinstance(result, dict))
Expand All @@ -188,8 +211,8 @@ def test_function_kwargs(self):
def test_function_args_and_kwargs(self):
"""Simple process function that defines a positional argument and keyword arguments."""
arg = 1
args = (Int(DEFAULT_INT),)
kwargs = {'data_b': Int(arg)}
args = (orm.Int(DEFAULT_INT),)
kwargs = {'data_b': orm.Int(arg)}

result = self.function_args_and_kwargs(*args)
self.assertTrue(isinstance(result, dict))
Expand All @@ -202,12 +225,12 @@ def test_function_args_and_kwargs(self):
def test_function_args_and_kwargs_default(self):
"""Simple process function that defines a positional argument and an argument with a default."""
arg = 1
args_input_default = (Int(DEFAULT_INT),)
args_input_explicit = (Int(DEFAULT_INT), Int(arg))
args_input_default = (orm.Int(DEFAULT_INT),)
args_input_explicit = (orm.Int(DEFAULT_INT), orm.Int(arg))

result = self.function_args_and_default(*args_input_default)
self.assertTrue(isinstance(result, dict))
self.assertEqual(result, {'data_a': args_input_default[0], 'data_b': Int(DEFAULT_INT)})
self.assertEqual(result, {'data_a': args_input_default[0], 'data_b': orm.Int(DEFAULT_INT)})

result = self.function_args_and_default(*args_input_explicit)
self.assertTrue(isinstance(result, dict))
Expand All @@ -218,13 +241,13 @@ def test_function_args_passing_kwargs(self):
arg = 1

with self.assertRaises(ValueError):
self.function_args(data_a=Int(arg), data_b=Int(arg)) # pylint: disable=unexpected-keyword-arg
self.function_args(data_a=orm.Int(arg), data_b=orm.Int(arg)) # pylint: disable=unexpected-keyword-arg

def test_function_set_label_description(self):
"""Verify that the label and description can be set for all process function variants."""
metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION}

_, node = self.function_args.run_get_node(data_a=Int(DEFAULT_INT), metadata=metadata)
_, node = self.function_args.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata)
self.assertEqual(node.label, CUSTOM_LABEL)
self.assertEqual(node.description, CUSTOM_DESCRIPTION)

Expand All @@ -236,19 +259,19 @@ def test_function_set_label_description(self):
self.assertEqual(node.label, CUSTOM_LABEL)
self.assertEqual(node.description, CUSTOM_DESCRIPTION)

_, node = self.function_args_and_kwargs.run_get_node(data_a=Int(DEFAULT_INT), metadata=metadata)
_, node = self.function_args_and_kwargs.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata)
self.assertEqual(node.label, CUSTOM_LABEL)
self.assertEqual(node.description, CUSTOM_DESCRIPTION)

_, node = self.function_args_and_default.run_get_node(data_a=Int(DEFAULT_INT), metadata=metadata)
_, node = self.function_args_and_default.run_get_node(data_a=orm.Int(DEFAULT_INT), metadata=metadata)
self.assertEqual(node.label, CUSTOM_LABEL)
self.assertEqual(node.description, CUSTOM_DESCRIPTION)

def test_function_defaults(self):
"""Verify that a process function can define a default label and description but can be overriden."""
metadata = {'label': CUSTOM_LABEL, 'description': CUSTOM_DESCRIPTION}

_, node = self.function_defaults.run_get_node(data_a=Int(DEFAULT_INT))
_, node = self.function_defaults.run_get_node(data_a=orm.Int(DEFAULT_INT))
self.assertEqual(node.label, DEFAULT_LABEL)
self.assertEqual(node.description, DEFAULT_DESCRIPTION)

Expand All @@ -264,7 +287,7 @@ def test_launchers(self):
result, node = run_get_node(self.function_return_true)
self.assertTrue(result)
self.assertEqual(result, get_true_node())
self.assertTrue(isinstance(node, CalcFunctionNode))
self.assertTrue(isinstance(node, orm.CalcFunctionNode))

with self.assertRaises(AssertionError):
submit(self.function_return_true)
Expand All @@ -276,7 +299,8 @@ def test_return_exit_code(self):
exit_status = 418
exit_message = 'I am a teapot'

_, node = self.function_exit_code.run_get_node(exit_status=Int(exit_status), exit_message=Str(exit_message))
message = orm.Str(exit_message)
_, node = self.function_exit_code.run_get_node(exit_status=orm.Int(exit_status), exit_message=message)

self.assertTrue(node.is_finished)
self.assertFalse(node.is_finished_ok)
Expand All @@ -288,7 +312,7 @@ def test_normal_exception(self):
exception = 'This process function excepted'

with self.assertRaises(RuntimeError):
_, node = self.function_excepts.run_get_node(exception=Str(exception))
_, node = self.function_excepts.run_get_node(exception=orm.Str(exception))
self.assertTrue(node.is_excepted)
self.assertEqual(node.exception, exception)

Expand All @@ -307,23 +331,23 @@ def mul(data_a, data_b):
def add_mul_wf(data_a, data_b, data_c):
return mul(add(data_a, data_b), data_c)

result, node = add_mul_wf.run_get_node(Int(3), Int(4), Int(5))
result, node = add_mul_wf.run_get_node(orm.Int(3), orm.Int(4), orm.Int(5))

self.assertEqual(result, (3 + 4) * 5)
self.assertIsInstance(node, WorkFunctionNode)
self.assertIsInstance(node, orm.WorkFunctionNode)

def test_hashes(self):
"""Test that the hashes generated for identical process functions with identical inputs are the same."""
_, node1 = self.function_return_input.run_get_node(data=Int(2))
_, node2 = self.function_return_input.run_get_node(data=Int(2))
_, node1 = self.function_return_input.run_get_node(data=orm.Int(2))
_, node2 = self.function_return_input.run_get_node(data=orm.Int(2))
self.assertEqual(node1.get_hash(), node1.get_extra('_aiida_hash'))
self.assertEqual(node2.get_hash(), node2.get_extra('_aiida_hash'))
self.assertEqual(node1.get_hash(), node2.get_hash())

def test_hashes_different(self):
"""Test that the hashes generated for identical process functions with different inputs are the different."""
_, node1 = self.function_return_input.run_get_node(data=Int(2))
_, node2 = self.function_return_input.run_get_node(data=Int(3))
_, node1 = self.function_return_input.run_get_node(data=orm.Int(2))
_, node2 = self.function_return_input.run_get_node(data=orm.Int(3))
self.assertEqual(node1.get_hash(), node1.get_extra('_aiida_hash'))
self.assertEqual(node2.get_hash(), node2.get_extra('_aiida_hash'))
self.assertNotEqual(node1.get_hash(), node2.get_hash())
11 changes: 10 additions & 1 deletion aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,19 @@ def _define(cls, spec):
if i >= first_default_pos:
default = defaults[i - first_default_pos]

# If the keyword was already specified, simply override the default
if spec.has_input(arg):
spec.inputs[arg].default = default
else:
spec.input(arg, valid_type=orm.Data, default=default)
# If the default is `None` make sure that the port also accepts a `NoneType`
# Note that we cannot use `None` because the validation will call `isinstance` which does not work
# when passing `None`, but it does work with `NoneType` which is returned by calling `type(None)`
if default is None:
valid_type = (orm.Data, type(None))
else:
valid_type = (orm.Data,)

spec.input(arg, valid_type=valid_type, default=default)

# If the function support kwargs then allow dynamic inputs, otherwise disallow
spec.inputs.dynamic = keywords is not None
Expand Down
4 changes: 4 additions & 0 deletions aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,10 @@ def _setup_inputs(self):

for name, node in self._flat_inputs().items():

# Certain processes allow to specify ports with `None` as acceptable values
if node is None:
continue

# Special exception: set computer if node is a remote Code and our node does not yet have a computer set
if isinstance(node, Code) and not node.is_local() and not self.node.computer:
self.node.computer = node.get_remote_computer()
Expand Down