Skip to content

Commit

Permalink
feat(typing): better annotation support for PyTree[T]
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Oct 23, 2024
1 parent fcf9c5f commit e899e4d
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 10 deletions.
29 changes: 28 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# pylint: disable=all
# mypy: ignore-errors

# -- Path setup --------------------------------------------------------------

Expand Down Expand Up @@ -83,7 +84,7 @@ def get_version() -> str:
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
source_suffix = {'.rst': 'restructuredtext'}

# The master toctree document.
master_doc = 'index'
Expand All @@ -105,6 +106,9 @@ def get_version() -> str:
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'default'

# A list of warning codes to suppress arbitrary warning messages.
suppress_warnings = ['config.cache']

# -- Options for autodoc -----------------------------------------------------

autodoc_default_options = {
Expand Down Expand Up @@ -175,3 +179,26 @@ def get_version() -> str:

# To make sphinx-copybutton skip all prompt characters generated by pygments
copybutton_exclude = '.linenos, .gp'

# -- Options for autodoc-typehints extension ---------------------------------
always_use_bars_union = True
typehints_use_signature = False
typehints_use_signature_return = False


def typehints_formatter(annotation, config=None):
from typing import Union

if (
isinstance(annotation, type(Union[int, str]))
and annotation.__origin__ is Union
and hasattr(annotation, '__pytree_args__')
):
param, name = annotation.__pytree_args__
if name is not None:
return f':py:class:`{name}`'

from sphinx_autodoc_typehints import format_annotation

return rf':py:class:`PyTree` \[{format_annotation(param,config=config)}]'
return None
12 changes: 6 additions & 6 deletions docs/source/integration.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Integration with Third-Party Libraries
======================================

Integration for JAX
-------------------
Integration for `JAX <https://github.com/jax-ml/jax>`_
------------------------------------------------------

.. currentmodule:: optree.integration.jax

Expand All @@ -14,8 +14,8 @@ Integration for JAX

------

Integration for NumPy
---------------------
Integration for `NumPy <https://github.com/numpy/numpy>`_
---------------------------------------------------------

.. currentmodule:: optree.integration.numpy

Expand All @@ -27,8 +27,8 @@ Integration for NumPy

------

Integration for PyTorch
-----------------------
Integration for `PyTorch <https://github.com/pytorch/pytorch>`_
---------------------------------------------------------------

.. currentmodule:: optree.integration.torch

Expand Down
4 changes: 4 additions & 0 deletions optree/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
SUBMODULES: frozenset[str] = frozenset({'jax', 'numpy', 'torch'})


def __dir__() -> list[str]:
return [*sorted(SUBMODULES), 'SUBMODULES']


def __getattr__(name: str) -> ModuleType:
if name in SUBMODULES:
import importlib # pylint: disable=import-outside-toplevel
Expand Down
19 changes: 16 additions & 3 deletions optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class PyTree(Generic[T]): # pragma: no cover
typing.Union[torch.Tensor,
typing.Tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
typing.List[ForwardRef('PyTree[torch.Tensor]')],
typing.Dict[collections.abc.Hashable, ForwardRef('PyTree[torch.Tensor]')],
typing.Dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
typing.Deque[ForwardRef('PyTree[torch.Tensor]')],
optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
"""
Expand Down Expand Up @@ -232,11 +232,24 @@ def __class_getitem__( # noqa: C901
param, # type: ignore[valid-type]
Tuple[recurse_ref, ...], # type: ignore[valid-type] # Tuple, NamedTuple, PyStructSequence
List[recurse_ref], # type: ignore[valid-type]
Dict[Hashable, recurse_ref], # type: ignore[valid-type] # Dict, OrderedDict, DefaultDict
Dict[Any, recurse_ref], # type: ignore[valid-type] # Dict, OrderedDict, DefaultDict
Deque[recurse_ref], # type: ignore[valid-type]
CustomTreeNode[recurse_ref], # type: ignore[valid-type]
]
pytree_alias.__pytree_args__ = item # type: ignore[attr-defined]

# pylint: disable-next=no-member
original_copy_with = pytree_alias.copy_with # type: ignore[attr-defined]
original_num_params = len(pytree_alias.__args__) # type: ignore[attr-defined]

def copy_with(params: tuple) -> TypeAlias:
if not isinstance(params, tuple) or len(params) != original_num_params:
return original_copy_with(params)
if params[0] is param:
return pytree_alias
return PyTree[params[0]] # type: ignore[misc,valid-type]

object.__setattr__(pytree_alias, 'copy_with', copy_with)
return pytree_alias

def __new__(cls) -> NoReturn: # pylint: disable=arguments-differ
Expand Down Expand Up @@ -302,7 +315,7 @@ class PyTreeTypeVar: # pragma: no cover
typing.Union[torch.Tensor,
typing.Tuple[ForwardRef('TensorTree'), ...],
typing.List[ForwardRef('TensorTree')],
typing.Dict[collections.abc.Hashable, ForwardRef('TensorTree')],
typing.Dict[typing.Any, ForwardRef('TensorTree')],
typing.Deque[ForwardRef('TensorTree')],
optree.typing.CustomTreeNode[ForwardRef('TensorTree')]]
"""
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ typing-modules = ["optree.typing"]
"E402", # module-import-not-at-top-of-file
]
"docs/source/conf.py" = [
"ANN", # flake8-annotations
"INP001", # flake8-no-pep420
]

Expand Down

0 comments on commit e899e4d

Please sign in to comment.