Skip to content

Commit

Permalink
Merge pull request #306 from burnpanck/add-Dict-key-validation
Browse files Browse the repository at this point in the history
added a key_trait argument to Dict
  • Loading branch information
minrk authored Nov 28, 2016
2 parents c3b785c + 3dca190 commit b2266cc
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 46 deletions.
30 changes: 23 additions & 7 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,39 +1603,55 @@ def test_dict_assignment():
assert c.value is d


class UniformlyValidatedDictTrait(HasTraits):
class UniformlyValueValidatedDictTrait(HasTraits):

value = Dict(trait=Unicode(),
default_value={'foo': '1'})


class TestInstanceUniformlyValidatedDict(TraitTestBase):
class TestInstanceUniformlyValueValidatedDict(TraitTestBase):

obj = UniformlyValidatedDictTrait()
obj = UniformlyValueValidatedDictTrait()

_default_value = {'foo': '1'}
_good_values = [{'foo': '0', 'bar': '1'}]
_bad_values = [{'foo': 0, 'bar': '1'}]


class KeyValidatedDictTrait(HasTraits):
class NonuniformlyValueValidatedDictTrait(HasTraits):

value = Dict(traits={'foo': Int()},
default_value={'foo': 1})


class TestInstanceKeyValidatedDict(TraitTestBase):
class TestInstanceNonuniformlyValueValidatedDict(TraitTestBase):

obj = KeyValidatedDictTrait()
obj = NonuniformlyValueValidatedDictTrait()

_default_value = {'foo': 1}
_good_values = [{'foo': 0, 'bar': '1'}, {'foo': 0, 'bar': 1}]
_bad_values = [{'foo': '0', 'bar': '1'}]


class KeyValidatedDictTrait(HasTraits):

value = Dict(key_trait=Unicode(),
default_value={'foo': '1'})


class TestInstanceKeyValidatedDict(TraitTestBase):

obj = KeyValidatedDictTrait()

_default_value = {'foo': '1'}
_good_values = [{'foo': '0', 'bar': '1'}]
_bad_values = [{'foo': '0', 0: '1'}]


class FullyValidatedDictTrait(HasTraits):

value = Dict(trait=Unicode(),
key_trait=Unicode(),
traits={'foo': Int()},
default_value={'foo': 1})

Expand All @@ -1646,7 +1662,7 @@ class TestInstanceFullyValidatedDict(TraitTestBase):

_default_value = {'foo': 1}
_good_values = [{'foo': 0, 'bar': '1'}, {'foo': 1, 'bar': '2'}]
_bad_values = [{'foo': 0, 'bar': 1}, {'foo': '0', 'bar': '1'}]
_bad_values = [{'foo': 0, 'bar': 1}, {'foo': '0', 'bar': '1'}, {'foo': 0, 0: '1'}]


def test_dict_default_value():
Expand Down
120 changes: 81 additions & 39 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2482,9 +2482,10 @@ def instance_init(self, obj):

class Dict(Instance):
"""An instance of a Python dict."""
_trait = None
_value_trait = None
_key_trait = None

def __init__(self, trait=None, traits=None, default_value=Undefined,
def __init__(self, value_trait=None, per_key_traits=None, key_trait=None, default_value=Undefined,
**kwargs):
"""Create a dict trait type from a Python dict.
Expand All @@ -2494,24 +2495,48 @@ def __init__(self, trait=None, traits=None, default_value=Undefined,
Parameters
----------
trait : TraitType [ optional ]
value_trait : TraitType [ optional ]
The specified trait type to check and use to restrict contents of
the Container. If unspecified, trait types are not checked.
traits : Dictionary of trait types [ optional ]
per_key_traits : Dictionary of trait types [ optional ]
A Python dictionary containing the types that are valid for
restricting the content of the Dict Container for certain keys.
key_trait : TraitType [ optional ]
The type for restricting the keys of the container. If
unspecified, the types of the keys are not checked.
default_value : SequenceType [ optional ]
The default value for the Dict. Must be dict, tuple, or None, and
will be cast to a dict if not None. If `trait` is specified, the
`default_value` must conform to the constraints it specifies.
"""

# handle deprecated keywords
trait = kwargs.pop('trait', None)
if trait is not None:
if value_trait is not None:
raise TypeError("Found a value for both `value_trait` and it's deprecated alias `trait`.")
value_trait = trait
warn("Keyword `trait` is deprecated, use `value_trait` instead", DeprecationWarning)
traits = kwargs.pop('traits', None)
if traits is not None:
if per_key_traits is not None:
raise TypeError("Found a value for both `per_key_traits` and it's deprecated alias `traits`.")
per_key_traits = traits
warn("Keyword `traits` is deprecated, use `per_key_traits` instead", DeprecationWarning)

# Handling positional arguments
if default_value is Undefined and trait is not None:
if not is_trait(trait):
default_value = trait
trait = None
if default_value is Undefined and value_trait is not None:
if not is_trait(value_trait):
default_value = value_trait
value_trait = None

if key_trait is None and per_key_traits is not None:
if is_trait(per_key_traits):
key_trait = per_key_traits
per_key_traits = None

# Handling default value
if default_value is Undefined:
Expand All @@ -2526,21 +2551,32 @@ def __init__(self, trait=None, traits=None, default_value=Undefined,
raise TypeError('default value of Dict was %s' % default_value)

# Case where a type of TraitType is provided rather than an instance
if is_trait(trait):
if isinstance(trait, type):
if is_trait(value_trait):
if isinstance(value_trait, type):
warn("Traits should be given as instances, not types (for example, `Int()`, not `Int`)"
" Passing types is deprecated in traitlets 4.1.",
DeprecationWarning, stacklevel=2)
self._trait = trait() if isinstance(trait, type) else trait
elif trait is not None:
raise TypeError("`trait` must be a Trait or None, got %s" % repr_type(trait))
value_trait = value_trait()
self._value_trait = value_trait
elif value_trait is not None:
raise TypeError("`value_trait` must be a Trait or None, got %s" % repr_type(value_trait))

if is_trait(key_trait):
if isinstance(key_trait, type):
warn("Traits should be given as instances, not types (for example, `Int()`, not `Int`)"
" Passing types is deprecated in traitlets 4.1.",
DeprecationWarning, stacklevel=2)
key_trait = key_trait()
self._key_trait = key_trait
elif key_trait is not None:
raise TypeError("`key_trait` must be a Trait or None, got %s" % repr_type(key_trait))

self._traits = traits
self._per_key_traits = per_key_traits

super(Dict, self).__init__(klass=dict, args=args, **kwargs)

def element_error(self, obj, element, validator):
e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
def element_error(self, obj, element, validator, side='Values'):
e = side + " of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
% (self.name, class_of(obj), validator.info(), repr_type(element))
raise TraitError(e)

Expand All @@ -2552,41 +2588,47 @@ def validate(self, obj, value):
return value

def validate_elements(self, obj, value):
use_dict = bool(self._traits)
default_to = (self._trait or Any())
if not use_dict and isinstance(default_to, Any):
per_key_override = self._per_key_traits or {}
key_trait = self._key_trait
value_trait = self._value_trait
if not (key_trait or value_trait or per_key_override):
return value

validated = {}
for key in value:
if use_dict and key in self._traits:
validate_with = self._traits[key]
else:
validate_with = default_to
try:
v = value[key]
if not isinstance(validate_with, Any):
v = validate_with._validate(obj, v)
except TraitError:
self.element_error(obj, v, validate_with)
else:
validated[key] = v
v = value[key]
if key_trait:
try:
key = key_trait._validate(obj, key)
except TraitError:
self.element_error(obj, key, key_trait, 'Keys')
active_value_trait = per_key_override.get(key, value_trait)
if active_value_trait:
try:
v = active_value_trait._validate(obj, v)
except TraitError:
self.element_error(obj, v, active_value_trait, 'Values')
validated[key] = v

return self.klass(validated)

def class_init(self, cls, name):
if isinstance(self._trait, TraitType):
self._trait.class_init(cls, None)
if self._traits is not None:
for trait in self._traits.values():
if isinstance(self._value_trait, TraitType):
self._value_trait.class_init(cls, None)
if isinstance(self._key_trait, TraitType):
self._key_trait.class_init(cls, None)
if self._per_key_traits is not None:
for trait in self._per_key_traits.values():
trait.class_init(cls, None)
super(Dict, self).class_init(cls, name)

def instance_init(self, obj):
if isinstance(self._trait, TraitType):
self._trait.instance_init(obj)
if self._traits is not None:
for trait in self._traits.values():
if isinstance(self._value_trait, TraitType):
self._value_trait.instance_init(obj)
if isinstance(self._key_trait, TraitType):
self._key_trait.instance_init(obj)
if self._per_key_traits is not None:
for trait in self._per_key_traits.values():
trait.instance_init(obj)
super(Dict, self).instance_init(obj)

Expand Down

0 comments on commit b2266cc

Please sign in to comment.