diff --git a/test/test_01_can_validate.py b/test/test_01_can_validate.py index 09cc7a0..8ea8a6e 100644 --- a/test/test_01_can_validate.py +++ b/test/test_01_can_validate.py @@ -1,53 +1,74 @@ # pylint: disable = missing-docstring, expression-not-assigned -from types import NoneType +import sys import typing import pytest from typing_validation import can_validate, validation_aliases +from typing_validation.inspector import _typing_equiv from typing_validation.validation import _pseudotypes_dict from .test_00_validate import _test_cases, _union_cases, _literal_cases, _alias_cases,_typed_dict_cases, _validation_aliases def assert_recorded_type(t: typing.Any) -> None: _t = can_validate(t).recorded_type - if t is NoneType: + if t is type(None): assert None is _t elif hasattr(t, "__origin__") and hasattr(t, "__args__") and t.__module__ == "typing": - assert t.__origin__[*t.__args__] == _t + if sys.version_info[1] <= 8 and t.__origin__ in _typing_equiv: + t_origin = _typing_equiv[t.__origin__] + else: + t_origin = t.__origin__ + if t.__args__: + assert t_origin[t.__args__] == _t + else: + assert t_origin == _t else: + if sys.version_info[1] <= 8 and t in _typing_equiv: + t = _typing_equiv[t] assert t == _t +_valid_cases_ts = sorted({ + t for _, ts in _test_cases for t in ts +}|{ + _pseudotypes_dict[t] for _, ts in _test_cases for t in ts if t in _pseudotypes_dict +}|{typing.Any}, key=repr) -@pytest.mark.parametrize("val, ts", _test_cases) -def test_valid_cases(val: typing.Any, ts: typing.List[typing.Any]) -> None: - ts = ts+[_pseudotypes_dict[t] for t in ts if t in _pseudotypes_dict] - for t in ts: - assert can_validate(t), f"Should be able to validate {t}" - assert can_validate(typing.Optional[t]), f"Should be able to validate {typing.Optional[t]}" - str(can_validate(t)) - assert_recorded_type(t) - assert can_validate(typing.Any), f"Should be able to validate {typing.Any}" +@pytest.mark.parametrize("t", _valid_cases_ts) +def test_valid_cases(t: typing.Any) -> None: + # ts = ts+[_pseudotypes_dict[t] for t in ts if t in _pseudotypes_dict] + assert can_validate(t), f"Should be able to validate {t}" + assert can_validate(typing.Optional[t]), f"Should be able to validate {typing.Optional[t]}" + str(can_validate(t)) + assert_recorded_type(t) + +_other_cases_ts = sorted({ + t for _, ts in _union_cases+_literal_cases for t in ts +}, key=repr) -@pytest.mark.parametrize("val, ts", _union_cases+_literal_cases) -def test_other_cases(val: typing.Any, ts: typing.List[typing.Any]) -> None: - for t in ts: +@pytest.mark.parametrize("t", _other_cases_ts) +def test_other_cases(t: typing.Any) -> None: + assert can_validate(t), f"Should be able to validate {t}" + assert can_validate(typing.Optional[t]), f"Should be able to validate {typing.Optional[t]}" + str(can_validate(t)) + assert_recorded_type(t) + +_alias_cases_ts = sorted({ + t for _, ts in _alias_cases for t in ts +}, key=repr) + +@pytest.mark.parametrize("t", _alias_cases_ts) +def test_alias_cases(t: typing.Any) -> None: + with validation_aliases(**_validation_aliases): assert can_validate(t), f"Should be able to validate {t}" assert can_validate(typing.Optional[t]), f"Should be able to validate {typing.Optional[t]}" str(can_validate(t)) assert_recorded_type(t) -@pytest.mark.parametrize("val, ts", _alias_cases) -def test_alias_cases(val: typing.Any, ts: typing.List[typing.Any]) -> None: - for t in ts: - with validation_aliases(**_validation_aliases): - assert can_validate(t), f"Should be able to validate {t}" - assert can_validate(typing.Optional[t]), f"Should be able to validate {typing.Optional[t]}" - str(can_validate(t)) - assert_recorded_type(t) - -@pytest.mark.parametrize("t, vals", _typed_dict_cases.items()) -def test_typed_dict_cases(t: typing.Any, vals: typing.List[typing.Any]) -> None: +_typed_dict_cases_ts = sorted(_typed_dict_cases.keys(), key=repr) + +@pytest.mark.parametrize("t", _typed_dict_cases_ts) +def test_typed_dict_cases(t: typing.Any) -> None: with validation_aliases(**_validation_aliases): assert can_validate(t), f"Should be able to validate {t}" assert can_validate(typing.Optional[t]), f"Should be able to validate {typing.Optional[t]}" diff --git a/typing_validation/inspector.py b/typing_validation/inspector.py index f5dc0ef..1da8875 100644 --- a/typing_validation/inspector.py +++ b/typing_validation/inspector.py @@ -5,7 +5,8 @@ """ from __future__ import annotations - +import collections +import collections.abc as collections_abc import sys import typing from typing import Any, Optional, Union, get_type_hints @@ -44,6 +45,36 @@ else: TypeConstructorArgs = typing.Tuple[str, Any] +_typing_equiv = { + list: typing.List, + tuple: typing.Tuple, + set: typing.Set, + frozenset: typing.FrozenSet, + dict: typing.Dict, + collections.deque: typing.Deque, + collections.defaultdict: typing.DefaultDict, + collections_abc.Collection: typing.Collection, + collections_abc.Set: typing.AbstractSet, + collections_abc.MutableSet: typing.MutableSet, + collections_abc.Sequence: typing.Sequence, + collections_abc.MutableSequence: typing.MutableSequence, + collections_abc.Iterable: typing.Iterable, + collections_abc.Iterator: typing.Iterator, + collections_abc.Container: typing.Container, + collections_abc.Mapping: typing.Mapping, + collections_abc.MutableMapping: typing.MutableMapping, + collections_abc.Hashable: typing.Hashable, + collections_abc.Sized: typing.Sized, +} + +if sys.version_info[1] <= 11: + _typing_equiv[collections_abc.ByteString] = typing.ByteString # type: ignore + +def _to_typing_equiv(t: Any) -> Any: + if sys.version_info[1] <= 8 and t in _typing_equiv: + return _typing_equiv[t] + return t + class UnsupportedType(type): r""" Wrapper for an unsupported type encountered by a :class:`TypeInspector` instance during validation. @@ -116,13 +147,16 @@ def _recorded_type(self, idx: int) -> typing.Tuple[Any, int]: member_ts.append(member_t) return typing.Union.__getitem__(tuple(member_ts)), idx if tag == "typed-dict": + for _ in get_type_hints(param): + _, idx = self._recorded_type(idx+1) return param, idx pending_type = None if tag == "type": # if isinstance(param, type): if not isinstance(param, tuple): - return param, idx + return _to_typing_equiv(param), idx pending_type, tag, param = param + pending_type = _to_typing_equiv(pending_type) if tag == "collection": item_t, idx = self._recorded_type(idx+1) t = pending_type[item_t] if pending_type is not None else typing.Collection[item_t] # type: ignore[valid-type] @@ -144,7 +178,7 @@ def _recorded_type(self, idx: int) -> typing.Tuple[Any, int]: item_ts.append(item_t) if not item_ts: item_ts = [tuple()] - t = pending_type.__class_getitem__(tuple(item_ts)) if pending_type is not None else typing.Tuple.__getitem__(tuple(item_ts)) + t = pending_type[tuple(item_ts)] if pending_type is not None else typing.Tuple[tuple(item_ts)] return t, idx assert False, f"Invalid type constructor tag: {repr(tag)}" @@ -214,7 +248,7 @@ def __repr__(self) -> str: return header+"\n"+"\n".join(self._repr()[0]) def _repr(self, idx: int = 0, level: int = 0) -> typing.Tuple[typing.List[str], int]: - # pylint: disable = too-many-return-statements, too-many-branches, too-many-statements + # pylint: disable = too-many-return-statements, too-many-branches, too-many-statements, too-many-locals basic_indent = " " assert len(basic_indent) >= 2 indent = basic_indent*level @@ -267,7 +301,8 @@ def _repr(self, idx: int = 0, level: int = 0) -> typing.Tuple[typing.List[str], if tag == "type": # if isinstance(param, type): if not isinstance(param, tuple): - return [indent+param.__name__], idx + param_name = param.__name__ if isinstance(param, type) else str(param) + return [indent+param_name], idx pending_type, tag, param = param if tag == "collection": item_lines, idx = self._repr(idx+1, level+1) diff --git a/typing_validation/validation.py b/typing_validation/validation.py index cedff7c..7277956 100644 --- a/typing_validation/validation.py +++ b/typing_validation/validation.py @@ -9,7 +9,6 @@ import collections.abc as collections_abc from keyword import iskeyword import sys -from types import NoneType import typing from typing import Any, ForwardRef, Optional, Union, get_type_hints import typing_extensions @@ -29,6 +28,11 @@ def issoftkeyword(s: str) -> bool: r""" Dummy implementation for issoftkeyword in Python 3.7 and 3.8. """ return s == "_" +if sys.version_info[1] >= 10: + from types import NoneType +else: + NoneType = type(None) + _validation_aliases: typing.Dict[str, Any] = {} r""" Current context of type aliases, used to resolve forward references to type aliases in :func:`validate`.