From ea7baf8eb510edcdcfc5d1a76d62acd46226e6f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerg=C5=91=20Jedlicska?= Date: Wed, 15 Feb 2023 19:20:45 +0100 Subject: [PATCH 1/2] fix(type_checking): make sure forwardrefs blank pass type checking --- src/specklepy/objects/base.py | 4 ++++ tests/unit/test_type_validation.py | 1 + 2 files changed, 5 insertions(+) diff --git a/src/specklepy/objects/base.py b/src/specklepy/objects/base.py index b3189fa6..ce5eab00 100644 --- a/src/specklepy/objects/base.py +++ b/src/specklepy/objects/base.py @@ -5,6 +5,7 @@ Any, ClassVar, Dict, + ForwardRef, List, Optional, Set, @@ -217,6 +218,9 @@ def _validate_type(t: Optional[type], value: Any) -> Tuple[bool, Any]: return True, t(value) if getattr(t, "__module__", None) == "typing": + if isinstance(t, ForwardRef): + return True, value + origin = getattr(t, "__origin__") # below is what in nicer for >= py38 # origin = get_origin(t) diff --git a/tests/unit/test_type_validation.py b/tests/unit/test_type_validation.py index 6411234f..55e19092 100644 --- a/tests/unit/test_type_validation.py +++ b/tests/unit/test_type_validation.py @@ -106,6 +106,7 @@ def __init__(self, foo: str) -> None: True, fake_bases, ), + (List["int"], [2, 3, 4], True, [2, 3, 4]) ], ) def test_validate_type( From ae42bec1c3bc9a8fedd21d9eb9bc77983decb74e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerg=C5=91=20Jedlicska?= Date: Wed, 15 Feb 2023 19:21:48 +0100 Subject: [PATCH 2/2] style(formatting): rerun formatting --- src/specklepy/objects/base.py | 2 +- tests/unit/test_type_validation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/specklepy/objects/base.py b/src/specklepy/objects/base.py index ce5eab00..a4b09888 100644 --- a/src/specklepy/objects/base.py +++ b/src/specklepy/objects/base.py @@ -220,7 +220,7 @@ def _validate_type(t: Optional[type], value: Any) -> Tuple[bool, Any]: if getattr(t, "__module__", None) == "typing": if isinstance(t, ForwardRef): return True, value - + origin = getattr(t, "__origin__") # below is what in nicer for >= py38 # origin = get_origin(t) diff --git a/tests/unit/test_type_validation.py b/tests/unit/test_type_validation.py index 55e19092..1980ca21 100644 --- a/tests/unit/test_type_validation.py +++ b/tests/unit/test_type_validation.py @@ -106,7 +106,7 @@ def __init__(self, foo: str) -> None: True, fake_bases, ), - (List["int"], [2, 3, 4], True, [2, 3, 4]) + (List["int"], [2, 3, 4], True, [2, 3, 4]), ], ) def test_validate_type(