From 7a394083eb5663a5b167982078681b6b0d8ab5a6 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 11 Sep 2024 19:39:47 +0800 Subject: [PATCH] docs: update docstrings --- docs/source/dataclasses.rst | 3 +++ docs/source/functools.rst | 3 +++ optree/dataclasses.py | 28 +++++++++++++++++++++------- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/docs/source/dataclasses.rst b/docs/source/dataclasses.rst index 15642b16..95a4e0fd 100644 --- a/docs/source/dataclasses.rst +++ b/docs/source/dataclasses.rst @@ -3,6 +3,9 @@ Integration with :mod:`dataclasses` .. currentmodule:: optree.dataclasses +.. automodule:: optree.dataclasses + :no-members: + .. autosummary:: field diff --git a/docs/source/functools.rst b/docs/source/functools.rst index f8dc1f54..fbcbb54c 100644 --- a/docs/source/functools.rst +++ b/docs/source/functools.rst @@ -3,6 +3,9 @@ Integration with :mod:`functools` .. currentmodule:: optree.functools +.. automodule:: optree.functools + :no-members: + .. autosummary:: partial diff --git a/optree/dataclasses.py b/optree/dataclasses.py index 473d37af..c0158620 100644 --- a/optree/dataclasses.py +++ b/optree/dataclasses.py @@ -30,13 +30,27 @@ ... z: float = 0.0 ... norm: float = optree.dataclasses.field(init=False, pytree_node=False) ... -... def __post_init__(self): +... def __post_init__(self) -> None: ... self.norm = math.hypot(self.x, self.y, self.z) ... ->>> point = Point(1.0, 2.0, 3.0) ->>> leaves, treespec = optree.tree_flatten(point, namespace='my_module') ->>> leaves -[1.0, 2.0, 3.0] +>>> point = Point(2.0, 6.0, 3.0) +>>> point +Point(x=2.0, y=6.0, z=3.0, norm=7.0) +>>> # Flatten without specifying the namespace +>>> optree.tree_flatten(point) # `Point`s are leaf nodes +([Point(x=2.0, y=6.0, z=3.0, norm=7.0)], PyTreeSpec(*)) +>>> # Flatten with the namespace +>>> accessors, leaves, treespec = optree.tree_flatten_with_accessor(point, namespace='my_module') +>>> accessors, leaves, treespec # doctest: +IGNORE_WHITESPACE,ELLIPSIS +( + [ + PyTreeAccessor(*.x, (DataclassEntry(field='x', type=),)), + PyTreeAccessor(*.y, (DataclassEntry(field='y', type=),)), + PyTreeAccessor(*.z, (DataclassEntry(field='z', type=),)) + ], + [2.0, 6.0, 3.0], + PyTreeSpec(CustomTreeNode(Point[()], [*, *, *]), namespace='my_module') +) >>> point == optree.tree_unflatten(treespec, leaves) True """ @@ -88,7 +102,7 @@ def field( # type: ignore[no-redef] # pylint: disable=function-redefined,too-ma Setting `pytree_node` in the field factory is equivalent to setting the `pytree_node` metadata in the original field factory. The `pytree_node` metadata can be accessed using the `metadata`. - If ``pytree_node`` is :data:`None`, the ``metadata.get('pytree_node', True)`` will be used. + If ``pytree_node`` is :data:`None`, the value ``metadata.get('pytree_node', True)`` will be used. .. note:: If a field is considered a child node, it must be included in the argument list of the @@ -100,7 +114,7 @@ def field( # type: ignore[no-redef] # pylint: disable=function-redefined,too-ma Returns: dataclasses.Field: The field defined using the provided arguments with - ``field.metadata['pytree_node']`` set. + ``field.metadata['pytree_node']`` set. """ metadata = (metadata or {}).copy() if pytree_node is None: