Skip to content

Commit

Permalink
Add context argument to Port validator method
Browse files Browse the repository at this point in the history
For normal `Port` instances, the `validator` signature changes from
`validator(value)` to `validator(value, port)` where `port` is the
instance of the `Port` to which the called validator is assigned. For
`PortNamespace` validators the `port` argument is the instance of the
port namespace itself.

The `port` is especially useful in validators of namespaces because it
can be used to see what ports are actually present. The expose
functionality, which allows to exclude or include only certain ports can
fundamentally change the ports that are present in a namespace, however,
the validator remains untouced. To make the validator robust to these
changes it should check if the port of interest is still present in the
namespace before checking the corresponding value passed in the `values`
argument.
  • Loading branch information
sphuber committed Jan 9, 2020
1 parent 660e8ad commit 2b1a268
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 13 deletions.
22 changes: 20 additions & 2 deletions plumpy/ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,28 @@
import json
import logging
import six
import warnings

from plumpy.utils import is_mutable_property, type_check

if six.PY2:
import collections
from inspect import getargspec as get_arg_spec
else:
import collections.abc as collections
from inspect import getfullargspec as get_arg_spec

_LOGGER = logging.getLogger(__name__)
UNSPECIFIED = ()

__all__ = ['UNSPECIFIED', 'PortValidationError', 'Port', 'InputPort', 'OutputPort']


VALIDATOR_SIGNATURE_DEPRECATION_WARNING = """the validator `{}` has a signature that only takes a single argument.
This has been deprecated and the new signature is `validator(value, port)` where the `port` argument will be the
port instance to which the validator has been assigned."""


class PortValidationError(Exception):
"""Error when validation fails on a port"""

Expand Down Expand Up @@ -190,7 +198,12 @@ def validate(self, value, breadcrumbs=()):
self.name, type(value), self._valid_type)

if not validation_error and self._validator is not None:
result = self.validator(value)
spec = get_arg_spec(self.validator)
if len(spec[0]) == 1:
warnings.warn(VALIDATOR_SIGNATURE_DEPRECATION_WARNING.format(self.validator.__name__))
result = self.validator(value)
else:
result = self.validator(value, self)
if result is not None:
assert isinstance(result, str), "Validator returned non string type"
validation_error = result
Expand Down Expand Up @@ -606,7 +619,12 @@ def validate(self, port_values=None, breadcrumbs=()):

# Validate the validator after the ports themselves, as it most likely will rely on the port values
if self.validator is not None:
message = self.validator(port_values_clone)
spec = get_arg_spec(self.validator)
if len(spec[0]) == 1:
warnings.warn(VALIDATOR_SIGNATURE_DEPRECATION_WARNING.format(self.validator.__name__))
message = self.validator(port_values_clone)
else:
message = self.validator(port_values_clone, self)
if message is not None:
assert isinstance(message, str), \
"Validator returned something other than None or str: '{}'".format(type(message))
Expand Down
8 changes: 4 additions & 4 deletions test/test_expose.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class TestExposeProcess(utils.TestCaseWithLoop):
def setUp(self):
super(TestExposeProcess, self).setUp()

def validator_function(input):
def validator_function(input, port):
pass

class BaseNamespaceProcess(NewLoopProcess):
Expand Down Expand Up @@ -183,7 +183,7 @@ def test_expose_ports_top_level(self):
properties with that of the exposed process
"""

def validator_function(input):
def validator_function(input, port):
pass

# Define child process with all mutable properties of the inputs PortNamespace to a non-default value
Expand Down Expand Up @@ -230,7 +230,7 @@ def test_expose_ports_top_level_override(self):
namespace_options will be the end-all-be-all
"""

def validator_function(input):
def validator_function(input, port):
pass

# Define child process with all mutable properties of the inputs PortNamespace to a non-default value
Expand Down Expand Up @@ -283,7 +283,7 @@ def test_expose_ports_namespace(self):
namespace with the properties of the exposed port namespace
"""

def validator_function(input):
def validator_function(input, port):
pass

# Define child process with all mutable properties of the inputs PortNamespace to a non-default value
Expand Down
13 changes: 8 additions & 5 deletions test/test_port.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def test_validate(self):

def test_validator(self):

def validate(value):
def validate(value, port):
assert isinstance(port, Port)
if not isinstance(value, int):
return "Not int"
return None
Expand All @@ -46,7 +47,8 @@ def test_default(self):
def test_validator(self):
"""Test the validator functionality."""

def integer_validator(value):
def integer_validator(value, port):
assert isinstance(port, Port)
if value < 0:
return 'Only positive integers allowed'

Expand Down Expand Up @@ -84,8 +86,8 @@ def test_default(self):
help_string = 'Help string'
required = False

def validator(value):
pass
def validator(value, port):
assert isinstance(port, Port)

port = OutputPort(name, valid_type=valid_type, help=help_string, required=required, validator=validator)
self.assertEqual(port.name, name)
Expand Down Expand Up @@ -120,7 +122,8 @@ def test_port_namespace(self):
def test_port_namespace_validation(self):
"""Test validate method of a `PortNamespace`."""

def validator(port_values):
def validator(port_values, port):
assert isinstance(port, PortNamespace)
if port_values['explicit'] < 0 or port_values['dynamic'] < 0:
return 'Only positive integers allowed'

Expand Down
4 changes: 2 additions & 2 deletions test/test_process_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_input_namespaced(self):
def test_validator(self):
"""Test the port validator with default."""

def dict_validator(dictionary):
def dict_validator(dictionary, port):
if 'key' not in dictionary or dictionary['key'] is not 'value':
return 'Invalid dictionary'

Expand All @@ -94,7 +94,7 @@ def dict_validator(dictionary):
def test_validate(self):
"""Test the global spec validator functionality."""

def is_valid(inputs):
def is_valid(inputs, port):
if not ('a' in inputs) ^ ('b' in inputs):
return 'Must have a OR b in inputs'
return
Expand Down

0 comments on commit 2b1a268

Please sign in to comment.