diff --git a/.idea/basedtyping.iml b/.idea/basedtyping.iml
index 8e64af1..822b5af 100644
--- a/.idea/basedtyping.iml
+++ b/.idea/basedtyping.iml
@@ -8,8 +8,15 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/watcherTasks.xml b/.idea/watcherTasks.xml
index 20d228e..bd80a55 100644
--- a/.idea/watcherTasks.xml
+++ b/.idea/watcherTasks.xml
@@ -15,7 +15,7 @@
-
+
@@ -35,7 +35,7 @@
-
+
diff --git a/basedtyping/runtime_only.py b/basedtyping/runtime_only.py
index 5dda529..83878f0 100644
--- a/basedtyping/runtime_only.py
+++ b/basedtyping/runtime_only.py
@@ -6,7 +6,27 @@
from __future__ import annotations
-from typing import Final, Final as Final_ext, Literal, Union
+import functools
+import operator
+import sys
+import types
+from _ast import AST, Attribute, BinOp, BitAnd, Constant, Load, Name, Subscript, Tuple
+from ast import NodeTransformer, parse
+from types import GenericAlias
+from typing import (
+ Final,
+ Final as Final_ext,
+ ForwardRef,
+ Literal,
+ Union,
+ Unpack,
+ _eval_type,
+ _Final,
+ _GenericAlias,
+ _should_unflatten_callable_args,
+ _strip_annotations,
+ _type_check,
+)
LiteralType: Final = type(Literal[1])
"""A type that can be used to check if type hints are a ``typing.Literal`` instance"""
@@ -14,3 +34,267 @@
# TODO: this is type[object], we need it to be 'SpecialForm[Union]' (or something)
OldUnionType: Final_ext[type[object]] = type(Union[str, int])
"""A type that can be used to check if type hints are a ``typing.Union`` instance."""
+
+
+def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
+ if getattr(obj, "__no_type_check__", None):
+ return {}
+ # Classes require a special treatment.
+ if isinstance(obj, type):
+ hints = {}
+ for base in reversed(obj.__mro__):
+ if globalns is None:
+ base_globals = getattr(
+ sys.modules.get(base.__module__, None), "__dict__", {}
+ )
+ else:
+ base_globals = globalns
+ ann = base.__dict__.get("__annotations__", {})
+ if isinstance(ann, types.GetSetDescriptorType):
+ ann = {}
+ base_locals = dict(vars(base)) if localns is None else localns
+ if localns is None and globalns is None:
+ # This is surprising, but required. Before Python 3.10,
+ # get_type_hints only evaluated the globalns of
+ # a class. To maintain backwards compatibility, we reverse
+ # the globalns and localns order so that eval() looks into
+ # *base_globals* first rather than *base_locals*.
+ # This only affects ForwardRefs.
+ base_globals, base_locals = base_locals, base_globals
+ p = BasedTypeParser()
+ for name, value in ann.items():
+ if value is None:
+ value = type(None)
+ if isinstance(value, str):
+ value = p.visit(parse(value, mode="eval"))
+ # value = unparse(p.visit(parse(value)))
+ value = ForwardRef(value, is_argument=False, is_class=True)
+ value = _eval_type(value, base_globals, base_locals)
+ hints[name] = value
+ return (
+ hints
+ if include_extras
+ else {k: _strip_annotations(t) for k, t in hints.items()}
+ )
+
+ if globalns is None:
+ if isinstance(obj, types.ModuleType):
+ globalns = obj.__dict__
+ else:
+ nsobj = obj
+ # Find globalns for the unwrapped object.
+ while hasattr(nsobj, "__wrapped__"):
+ nsobj = nsobj.__wrapped__
+ globalns = getattr(nsobj, "__globals__", {})
+ if localns is None:
+ localns = globalns
+ elif localns is None:
+ localns = globalns
+ hints = getattr(obj, "__annotations__", None)
+ if hints is None:
+ # Return empty annotations for something that _could_ have them.
+ if isinstance(obj, _allowed_types):
+ return {}
+ else:
+ raise TypeError(f"{obj!r} is not a module, class, method, or function.")
+ hints = dict(hints)
+ for name, value in hints.items():
+ if value is None:
+ value = type(None)
+ if isinstance(value, str):
+ # class-level forward refs were handled above, this must be either
+ # a module-level annotation or a function argument annotation
+ value = ForwardRef(
+ value, is_argument=not isinstance(obj, types.ModuleType), is_class=False
+ )
+ hints[name] = _eval_type(value, globalns, localns)
+ return (
+ hints
+ if include_extras
+ else {k: _strip_annotations(t) for k, t in hints.items()}
+ )
+
+
+def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
+ """Evaluate all forward references in the given type t.
+ For use of globalns and localns see the docstring for get_type_hints().
+ recursive_guard is used to prevent infinite recursion with a recursive
+ ForwardRef.
+ """
+ if isinstance(t, ForwardRef):
+ return t._evaluate(globalns, localns, recursive_guard)
+ if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)):
+ if isinstance(t, GenericAlias):
+ args = tuple(
+ ForwardRef(arg) if isinstance(arg, str) else arg for arg in t.__args__
+ )
+ is_unpacked = t.__unpacked__
+ if _should_unflatten_callable_args(t, args):
+ t = t.__origin__[(args[:-1], args[-1])]
+ else:
+ t = t.__origin__[args]
+ if is_unpacked:
+ t = Unpack[t]
+ ev_args = tuple(
+ _eval_type(a, globalns, localns, recursive_guard) for a in t.__args__
+ )
+ if ev_args == t.__args__:
+ return t
+ if isinstance(t, GenericAlias):
+ return GenericAlias(t.__origin__, ev_args)
+ if isinstance(t, types.UnionType):
+ return functools.reduce(operator.or_, ev_args)
+ else:
+ return t.copy_with(ev_args)
+ return t
+
+
+class ForwardRef(_Final, _root=True):
+ """Internal wrapper to hold a forward reference."""
+
+ __slots__ = (
+ "__forward_arg__",
+ "__forward_code__",
+ "__forward_evaluated__",
+ "__forward_value__",
+ "__forward_is_argument__",
+ "__forward_is_class__",
+ "__forward_module__",
+ )
+
+ def __init__(self, arg, is_argument=True, module=None, *, is_class=False):
+ if isinstance(arg, str):
+ # If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`.
+ # Unfortunately, this isn't a valid expression on its own, so we
+ # do the unpacking manually.
+ if arg[0] == "*":
+ arg_to_compile = ( # E.g. (*Ts,)[0] or (*tuple[int, int],)[0]
+ f"({arg},)[0]"
+ )
+ else:
+ arg_to_compile = arg
+ elif isinstance(arg, AST):
+ arg_to_compile = arg
+ else:
+ raise TypeError(f"Forward reference must be a string or AST -- got {arg!r}")
+ try:
+ code = compile(arg_to_compile, "", "eval")
+ except SyntaxError:
+ raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}")
+ except TypeError as t:
+ print(arg_to_compile.body, t)
+ ...
+ self.__forward_arg__ = arg
+ self.__forward_code__ = code
+ self.__forward_evaluated__ = False
+ self.__forward_value__ = None
+ self.__forward_is_argument__ = is_argument
+ self.__forward_is_class__ = is_class
+ self.__forward_module__ = module
+
+ def _evaluate(self, globalns, localns, recursive_guard):
+ if self.__forward_arg__ in recursive_guard:
+ return self
+ if not self.__forward_evaluated__ or localns is not globalns:
+ if globalns is None and localns is None:
+ globalns = localns = {}
+ elif globalns is None:
+ globalns = localns
+ elif localns is None:
+ localns = globalns
+ if self.__forward_module__ is not None:
+ globalns = getattr(
+ sys.modules.get(self.__forward_module__, None), "__dict__", globalns
+ )
+ import typing
+
+ import basedtyping
+
+ type_ = _type_check(
+ eval(
+ self.__forward_code__,
+ globalns | {"__secret__": typing, "__basedsecret__": basedtyping},
+ localns,
+ ),
+ "Forward references must evaluate to types.",
+ is_argument=self.__forward_is_argument__,
+ allow_special_forms=self.__forward_is_class__,
+ )
+ self.__forward_value__ = _eval_type(
+ type_, globalns, localns, recursive_guard | {self.__forward_arg__}
+ )
+ self.__forward_evaluated__ = True
+ return self.__forward_value__
+
+ def __eq__(self, other):
+ if not isinstance(other, ForwardRef):
+ return NotImplemented
+ if self.__forward_evaluated__ and other.__forward_evaluated__:
+ return (
+ self.__forward_arg__ == other.__forward_arg__
+ and self.__forward_value__ == other.__forward_value__
+ )
+ return (
+ self.__forward_arg__ == other.__forward_arg__
+ and self.__forward_module__ == other.__forward_module__
+ )
+
+ def __hash__(self):
+ return hash((self.__forward_arg__, self.__forward_module__))
+
+ def __or__(self, other):
+ return Union[self, other]
+
+ def __ror__(self, other):
+ return Union[other, self]
+
+ def __repr__(self):
+ if self.__forward_module__ is None:
+ module_repr = ""
+ else:
+ module_repr = f", module={self.__forward_module__!r}"
+ return f"ForwardRef({self.__forward_arg__!r}{module_repr})"
+
+
+class BasedTypeParser(NodeTransformer):
+ in_subscript = 0
+
+ def __init__(self):
+ self.load = Load()
+
+ def visit_BinOp(self, node: BinOp) -> AST:
+ if isinstance(node.op, BitAnd):
+ extra = dict(lineno=node.lineno, col_offset=node.col_offset, ctx=self.load)
+ return Subscript(
+ Attribute(Name("__basedsecret__", **extra), "Intersection", **extra),
+ Tuple([self.visit(node.left), self.visit(node.right)], **extra),
+ **extra,
+ )
+ return self.generic_visit(node)
+
+ def visit_Constant(self, node: Constant) -> AST:
+ if isinstance(node.value, int):
+ # todo enum
+
+ extra = dict(lineno=node.lineno, col_offset=node.col_offset, ctx=self.load)
+ return Subscript(
+ Attribute(Name("__secret__", **extra), "Literal", **extra),
+ node,
+ **extra,
+ )
+ return self.generic_visit(node)
+
+ def visit_Tuple(self, node: Tuple) -> AST:
+ if self.in_subscript:
+ self.in_subscript = False
+ return self.generic_visit(node)
+ extra = dict(lineno=node.lineno, col_offset=node.col_offset, ctx=self.load)
+ return Subscript(Name("__secret__.Tuple"), self.generic_visit(node), **extra)
+
+ def visit_Subscript(self, node: Subscript) -> AST:
+ if isinstance(node.slice, Tuple):
+ self.in_subscript = True
+ try:
+ return self.generic_visit(node)
+ finally:
+ self.in_subscript = False
diff --git a/tests/test_runtime_only/test_parser.py b/tests/test_runtime_only/test_parser.py
new file mode 100644
index 0000000..c01fd64
--- /dev/null
+++ b/tests/test_runtime_only/test_parser.py
@@ -0,0 +1,82 @@
+from ast import parse, unparse
+from enum import Enum, auto
+
+from pytest import fixture, mark
+
+from basedtyping.runtime_only import get_type_hints
+import basedtyping
+from runtime_only import BasedTypeParser
+import typing
+
+
+class E(Enum):
+ a = auto()
+
+
+@fixture(scope="session")
+def hints():
+ class A:
+ a: "int & str"
+ b: "1"
+ e: "E.a"
+
+ yield get_type_hints(A)
+
+
+def test_get_intersection(hints):
+ a = hints["a"]
+ assert isinstance(a, basedtyping._IntersectionGenericAlias)
+ assert a.__args__ == (int, str)
+
+
+def test_get_literal(hints):
+ b = hints["b"]
+ assert isinstance(b, typing._LiteralGenericAlias)
+ assert b.__args__ == (1,)
+
+
+@mark.xfail(condition=True, reason="this isn't implemented")
+def test_get_literal_enum(hints):
+ e = hints["e"]
+ assert isinstance(e, typing._LiteralGenericAlias)
+ assert e.__args__ == (E.a,)
+
+
+def process(value):
+ return unparse(parser.visit(parse(value)))
+
+
+parser = BasedTypeParser()
+bprefix = "__basedsecret__"
+prefix = "__secret__"
+
+
+def test_intersection():
+ assert process("int & str") == f"{bprefix}.Intersection[int, str]"
+
+
+def test_literal():
+ assert process("1 | 2") == f"{prefix}.Literal[1] | {prefix}.Literal[2]"
+
+
+def test_tuple():
+ assert process("(int, str)") == f"{prefix}.Tuple[int, str]"
+
+
+def test_tuple():
+ assert process("(int, str)") == f"{prefix}.Tuple[int, str]"
+
+
+def test_subscript():
+ """To ensure tuple expressions within a subscript don't get interpreted as tuple literals"""
+ assert process("a[b, c]") == "a[b, c]"
+
+
+@mark.xfail(condition=True, reason="Not implemented")
+def test_implicit():
+ class A:
+ a: "Literal[1]"
+
+ result = get_type_hints(A, implicit=True)["a"]
+ assert isinstance(result, typing._LiteralGenericAlias)
+ assert result.__args__ == (1,)