Skip to content

Commit

Permalink
Handle ForwardRef, expand TypeVar and link Ellipsis (#214)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaborbernat authored Jan 25, 2022
1 parent f75d19b commit 8b0599d
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 46 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ repos:
hooks:
- id: flake8
additional_dependencies:
- flake8-bugbear==21.11.29
- flake8-comprehensions==3.7
- flake8-bugbear==22.1.11
- flake8-comprehensions==3.8
- flake8-pytest-style==1.6
- flake8-spellcheck==0.24
- flake8-unused-arguments==0.0.9
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
- Add support for type subscriptions with multiple elements, where one or more elements
are tuples; e.g., `nptyping.NDArray[(Any, ...), nptyping.Float]`
- Fix bug for arbitrary types accepting singleton subscriptions; e.g., `nptyping.Float[64]`
- Resolve forward references
- Expand and better handle `TypeVar`
- Add intershpinx reference link for `...` to `Ellipsis` (as is just an alias)

## 1.15.3

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ other =
*\sphinx-autodoc-typehints

[coverage:report]
fail_under = 78
fail_under = 82

[coverage:html]
show_contexts = true
Expand Down
85 changes: 61 additions & 24 deletions src/sphinx_autodoc_typehints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import sys
import textwrap
from ast import FunctionDef, Module, stmt
from typing import Any, AnyStr, Callable, NewType, TypeVar, get_type_hints
from typing import _eval_type # type: ignore # no import defined in stubs
from typing import Any, AnyStr, Callable, ForwardRef, NewType, TypeVar, get_type_hints

from sphinx.application import Sphinx
from sphinx.config import Config
Expand All @@ -24,7 +25,8 @@
def get_annotation_module(annotation: Any) -> str:
if annotation is None:
return "builtins"
if sys.version_info >= (3, 10) and isinstance(annotation, NewType): # type: ignore # isinstance NewType is Callable
is_new_type = sys.version_info >= (3, 10) and isinstance(annotation, NewType) # type: ignore
if is_new_type or isinstance(annotation, TypeVar):
return "typing"
if hasattr(annotation, "__module__"):
return annotation.__module__ # type: ignore # deduced Any
Expand Down Expand Up @@ -79,13 +81,14 @@ def get_annotation_args(annotation: Any, module: str, class_name: str) -> tuple[
return (annotation.type_var,)
elif class_name == "ClassVar" and hasattr(annotation, "__type__"): # ClassVar on Python < 3.7
return (annotation.__type__,)
elif class_name == "TypeVar" and hasattr(annotation, "__constraints__"):
return annotation.__constraints__ # type: ignore # no stubs defined
elif class_name == "NewType" and hasattr(annotation, "__supertype__"):
return (annotation.__supertype__,)
elif class_name == "Literal" and hasattr(annotation, "__values__"):
return annotation.__values__ # type: ignore # deduced Any
elif class_name == "Generic":
return annotation.__parameters__ # type: ignore # deduced Any

return getattr(annotation, "__args__", ())


Expand All @@ -104,29 +107,25 @@ def format_internal_tuple(t: tuple[Any, ...], config: Config) -> str:
return f"({', '.join(fmt)})"


def format_annotation(annotation: Any, config: Config) -> str:
def format_annotation(annotation: Any, config: Config) -> str: # noqa: C901 # too complex
typehints_formatter: Callable[..., str] | None = getattr(config, "typehints_formatter", None)
if typehints_formatter is not None:
formatted = typehints_formatter(annotation, config)
if formatted is not None:
return formatted

# Special cases
if isinstance(annotation, ForwardRef):
value = _resolve_forward_ref(annotation, config)
return format_annotation(value, config)
if annotation is None or annotation is type(None): # noqa: E721
return ":py:obj:`None`"
elif annotation is Ellipsis:
return "..."
if annotation is Ellipsis:
return ":py:data:`...<Ellipsis>`"

if isinstance(annotation, tuple):
return format_internal_tuple(annotation, config)

# Type variables are also handled specially
try:
if isinstance(annotation, TypeVar) and annotation is not AnyStr:
return "\\" + repr(annotation)
except TypeError:
pass

try:
module = get_annotation_module(annotation)
class_name = get_annotation_class_name(annotation, module)
Expand All @@ -143,12 +142,22 @@ def format_annotation(annotation: Any, config: Config) -> str:
prefix = "" if fully_qualified or full_name == class_name else "~"
role = "data" if class_name in _PYDATA_ANNOTATIONS else "class"
args_format = "\\[{}]"
formatted_args = ""
formatted_args: str | None = ""

# Some types require special handling
if full_name == "typing.NewType":
args_format = f"\\(``{annotation.__name__}``, {{}})"
role = "class" if sys.version_info >= (3, 10) else "func"
elif full_name == "typing.TypeVar":
params = {k: getattr(annotation, f"__{k}__") for k in ("bound", "covariant", "contravariant")}
params = {k: v for k, v in params.items() if v}
if "bound" in params:
params["bound"] = f" {format_annotation(params['bound'], config)}"
args_format = f"\\(``{annotation.__name__}``{', {}' if args else ''}"
if params:
args_format += "".join(f", {k}={v}" for k, v in params.items())
args_format += ")"
formatted_args = None if args else args_format
elif full_name == "typing.Optional":
args = tuple(x for x in args if x is not type(None)) # noqa: E721
elif full_name == "typing.Union" and type(None) in args:
Expand Down Expand Up @@ -176,7 +185,19 @@ def format_annotation(annotation: Any, config: Config) -> str:
fmt = [format_annotation(arg, config) for arg in args]
formatted_args = args_format.format(", ".join(fmt))

return f":py:{role}:`{prefix}{full_name}`{formatted_args}"
result = f":py:{role}:`{prefix}{full_name}`{formatted_args}"
return result


def _resolve_forward_ref(annotation: ForwardRef, config: Config) -> Any:
raw, base_globals = annotation.__forward_arg__, config._annotation_globals
params = {"is_class": True} if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1) else {}
value = ForwardRef(raw, is_argument=False, **params)
try:
result = _eval_type(value, base_globals, None)
except NameError:
result = raw # fallback to the value itself as string
return result


# reference: https://github.com/pytorch/pytorch/pull/46548/files
Expand Down Expand Up @@ -284,14 +305,15 @@ def _future_annotations_imported(obj: Any) -> bool:

def get_all_type_hints(obj: Any, name: str) -> dict[str, Any]:
result = _get_type_hint(name, obj)
if result:
return result
result = backfill_type_hints(obj, name)
try:
obj.__annotations__ = result
except (AttributeError, TypeError):
return result
return _get_type_hint(name, obj)
if not result:
result = backfill_type_hints(obj, name)
try:
obj.__annotations__ = result
except (AttributeError, TypeError):
pass
else:
result = _get_type_hint(name, obj)
return result


_TYPE_GUARD_IMPORT_RE = re.compile(r"\nif (typing.)?TYPE_CHECKING:[^\n]*([\s\S]*?)(?=\n\S)")
Expand Down Expand Up @@ -474,7 +496,22 @@ def process_docstring(
except (ValueError, TypeError):
signature = None
type_hints = get_all_type_hints(obj, name)

app.config._annotation_globals = getattr(obj, "__globals__", {}) # type: ignore # config has no such attribute
try:
_inject_types_to_docstring(type_hints, signature, original_obj, app, what, name, lines)
finally:
delattr(app.config, "_annotation_globals")


def _inject_types_to_docstring(
type_hints: dict[str, Any],
signature: inspect.Signature | None,
original_obj: Any,
app: Sphinx,
what: str,
name: str,
lines: list[str],
) -> None:
for arg_name, annotation in type_hints.items():
if arg_name == "return":
continue # this is handled separately later
Expand Down
71 changes: 52 additions & 19 deletions tests/test_sphinx_autodoc_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
T = TypeVar("T")
U = TypeVar("U", covariant=True)
V = TypeVar("V", contravariant=True)
X = TypeVar("X", str, int)
Y = TypeVar("Y", bound=str)
Z = TypeVar("Z", bound="A")
S = TypeVar("S", bound="miss") # type: ignore # miss not defined on purpose # noqa: F821
W = NewType("W", str)


Expand All @@ -61,8 +65,7 @@ class Inner:


class B(Generic[T]):
# This is set to make sure the correct class name ("B") is picked up
name = "Foo"
name = "Foo" # This is set to make sure the correct class name ("B") is picked up


class C(B[str]):
Expand Down Expand Up @@ -147,21 +150,35 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t
(Type[A], ":py:class:`~typing.Type`\\[:py:class:`~%s.A`]" % __name__),
(Any, ":py:data:`~typing.Any`"),
(AnyStr, ":py:data:`~typing.AnyStr`"),
(Generic[T], ":py:class:`~typing.Generic`\\[\\~T]"),
(Generic[T], ":py:class:`~typing.Generic`\\[:py:class:`~typing.TypeVar`\\(``T``)]"),
(Mapping, ":py:class:`~typing.Mapping`"),
(Mapping[T, int], ":py:class:`~typing.Mapping`\\[\\~T, :py:class:`int`]"),
(Mapping[str, V], ":py:class:`~typing.Mapping`\\[:py:class:`str`, \\-V]"),
(Mapping[T, U], ":py:class:`~typing.Mapping`\\[\\~T, \\+U]"),
(Mapping[T, int], ":py:class:`~typing.Mapping`\\[:py:class:`~typing.TypeVar`\\(``T``), :py:class:`int`]"),
(
Mapping[str, V],
":py:class:`~typing.Mapping`\\[:py:class:`str`, :py:class:`~typing.TypeVar`\\(``V``, contravariant=True)]",
),
(
Mapping[T, U],
":py:class:`~typing.Mapping`\\[:py:class:`~typing.TypeVar`\\(``T``), "
":py:class:`~typing.TypeVar`\\(``U``, covariant=True)]",
),
(Mapping[str, bool], ":py:class:`~typing.Mapping`\\[:py:class:`str`, " ":py:class:`bool`]"),
(Dict, ":py:class:`~typing.Dict`"),
(Dict[T, int], ":py:class:`~typing.Dict`\\[\\~T, :py:class:`int`]"),
(Dict[str, V], ":py:class:`~typing.Dict`\\[:py:class:`str`, \\-V]"),
(Dict[T, U], ":py:class:`~typing.Dict`\\[\\~T, \\+U]"),
(Dict[T, int], ":py:class:`~typing.Dict`\\[:py:class:`~typing.TypeVar`\\(``T``), :py:class:`int`]"),
(
Dict[str, V],
":py:class:`~typing.Dict`\\[:py:class:`str`, :py:class:`~typing.TypeVar`\\(``V``, contravariant=True)]",
),
(
Dict[T, U],
":py:class:`~typing.Dict`\\[:py:class:`~typing.TypeVar`\\(``T``),"
" :py:class:`~typing.TypeVar`\\(``U``, covariant=True)]",
),
(Dict[str, bool], ":py:class:`~typing.Dict`\\[:py:class:`str`, " ":py:class:`bool`]"),
(Tuple, ":py:data:`~typing.Tuple`"),
(Tuple[str, bool], ":py:data:`~typing.Tuple`\\[:py:class:`str`, " ":py:class:`bool`]"),
(Tuple[int, int, int], ":py:data:`~typing.Tuple`\\[:py:class:`int`, " ":py:class:`int`, :py:class:`int`]"),
(Tuple[str, ...], ":py:data:`~typing.Tuple`\\[:py:class:`str`, ...]"),
(Tuple[str, ...], ":py:data:`~typing.Tuple`\\[:py:class:`str`, :py:data:`...<Ellipsis>`]"),
(Union, ":py:data:`~typing.Union`"),
(Union[str, bool], ":py:data:`~typing.Union`\\[:py:class:`str`, " ":py:class:`bool`]"),
(Union[str, bool, None], ":py:data:`~typing.Union`\\[:py:class:`str`, " ":py:class:`bool`, :py:obj:`None`]"),
Expand All @@ -178,7 +195,7 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t
":py:data:`~typing.Union`\\[:py:class:`str`, " ":py:class:`bool`, :py:obj:`None`]",
),
(Callable, ":py:data:`~typing.Callable`"),
(Callable[..., int], ":py:data:`~typing.Callable`\\[..., :py:class:`int`]"),
(Callable[..., int], ":py:data:`~typing.Callable`\\[:py:data:`...<Ellipsis>`, :py:class:`int`]"),
(Callable[[int], int], ":py:data:`~typing.Callable`\\[\\[:py:class:`int`], " ":py:class:`int`]"),
(
Callable[[int, str], bool],
Expand All @@ -188,7 +205,11 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t
Callable[[int, str], None],
":py:data:`~typing.Callable`\\[\\[:py:class:`int`, " ":py:class:`str`], :py:obj:`None`]",
),
(Callable[[T], T], ":py:data:`~typing.Callable`\\[\\[\\~T], \\~T]"),
(
Callable[[T], T],
":py:data:`~typing.Callable`\\[\\[:py:class:`~typing.TypeVar`\\(``T``)],"
" :py:class:`~typing.TypeVar`\\(``T``)]",
),
(Pattern, ":py:class:`~typing.Pattern`"),
(Pattern[str], ":py:class:`~typing.Pattern`\\[:py:class:`str`]"),
(IO, ":py:class:`~typing.IO`"),
Expand All @@ -202,14 +223,21 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t
(E, ":py:class:`~%s.E`" % __name__),
(E[int], ":py:class:`~%s.E`\\[:py:class:`int`]" % __name__),
(W, f':py:{"class" if PY310_PLUS else "func"}:' f"`~typing.NewType`\\(``W``, :py:class:`str`)"),
(T, ":py:class:`~typing.TypeVar`\\(``T``)"),
(U, ":py:class:`~typing.TypeVar`\\(``U``, covariant=True)"),
(V, ":py:class:`~typing.TypeVar`\\(``V``, contravariant=True)"),
(X, ":py:class:`~typing.TypeVar`\\(``X``, :py:class:`str`, :py:class:`int`)"),
(Y, ":py:class:`~typing.TypeVar`\\(``Y``, bound= :py:class:`str`)"),
(Z, ":py:class:`~typing.TypeVar`\\(``Z``, bound= :py:class:`~test_sphinx_autodoc_typehints.A`)"),
(S, ":py:class:`~typing.TypeVar`\\(``S``, bound= miss)"),
# ## These test for correct internal tuple rendering, even if not all are valid Tuple types
# Zero-length tuple remains
(Tuple[()], ":py:data:`~typing.Tuple`\\[()]"),
# Internal single tuple with simple types is flattened in the output
(Tuple[(int,)], ":py:data:`~typing.Tuple`\\[:py:class:`int`]"),
(Tuple[(int, int)], ":py:data:`~typing.Tuple`\\[:py:class:`int`, :py:class:`int`]"),
# Ellipsis in single tuple also gets flattened
(Tuple[(int, ...)], ":py:data:`~typing.Tuple`\\[:py:class:`int`, ...]"),
(Tuple[(int, ...)], ":py:data:`~typing.Tuple`\\[:py:class:`int`, :py:data:`...<Ellipsis>`]"),
# Internal tuple with following additional type cannot be flattened (specific to nptyping?)
# These cases will fail if nptyping restructures its internal module hierarchy
(
Expand All @@ -236,7 +264,7 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t
(
nptyping.NDArray[(Any, ...), nptyping.Float],
(
":py:class:`~nptyping.types._ndarray.NDArray`\\[(:py:data:`~typing.Any`, ...), "
":py:class:`~nptyping.types._ndarray.NDArray`\\[(:py:data:`~typing.Any`, :py:data:`...<Ellipsis>`), "
":py:class:`~nptyping.types._number.Float`]"
),
),
Expand All @@ -249,12 +277,15 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t
),
(
nptyping.NDArray[(3, ...), nptyping.Float],
(":py:class:`~nptyping.types._ndarray.NDArray`\\[(3, ...), :py:class:`~nptyping.types._number.Float`]"),
(
":py:class:`~nptyping.types._ndarray.NDArray`\\[(3, :py:data:`...<Ellipsis>`),"
" :py:class:`~nptyping.types._number.Float`]"
),
),
],
)
def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str) -> None:
conf = create_autospec(Config)
conf = create_autospec(Config, _annotation_globals=globals())
result = format_annotation(annotation, conf)
assert result == expected_result

Expand All @@ -266,21 +297,23 @@ def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str
# encapsulate Union in typing.Optional
expected_result_not_simplified = ":py:data:`~typing.Optional`\\[" + expected_result_not_simplified
expected_result_not_simplified += "]"
conf = create_autospec(Config, simplify_optional_unions=False)
conf = create_autospec(Config, simplify_optional_unions=False, _annotation_globals=globals())
assert format_annotation(annotation, conf) == expected_result_not_simplified

# Test with the "fully_qualified" flag turned on
if "typing" in expected_result_not_simplified:
expected_result_not_simplified = expected_result_not_simplified.replace("~typing", "typing")
conf = create_autospec(Config, typehints_fully_qualified=True, simplify_optional_unions=False)
conf = create_autospec(
Config, typehints_fully_qualified=True, simplify_optional_unions=False, _annotation_globals=globals()
)
assert format_annotation(annotation, conf) == expected_result_not_simplified

# Test with the "fully_qualified" flag turned on
if "typing" in expected_result or "nptyping" in expected_result or __name__ in expected_result:
expected_result = expected_result.replace("~typing", "typing")
expected_result = expected_result.replace("~nptyping", "nptyping")
expected_result = expected_result.replace("~" + __name__, __name__)
conf = create_autospec(Config, typehints_fully_qualified=True)
conf = create_autospec(Config, typehints_fully_qualified=True, _annotation_globals=globals())
assert format_annotation(annotation, conf) == expected_result

# Test for the correct role (class vs data) using the official Sphinx inventory
Expand Down
2 changes: 2 additions & 0 deletions whitelist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ contravariant
cpython
csv
dedent
delattr
dirname
docnames
dunder
eval
exc
fget
fmt
Expand Down

0 comments on commit 8b0599d

Please sign in to comment.