Skip to content

Commit

Permalink
docs: update docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Sep 11, 2024
1 parent d1f281f commit 7a39408
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
3 changes: 3 additions & 0 deletions docs/source/dataclasses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ Integration with :mod:`dataclasses`

.. currentmodule:: optree.dataclasses

.. automodule:: optree.dataclasses
:no-members:

.. autosummary::

field
Expand Down
3 changes: 3 additions & 0 deletions docs/source/functools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ Integration with :mod:`functools`

.. currentmodule:: optree.functools

.. automodule:: optree.functools
:no-members:

.. autosummary::

partial
Expand Down
28 changes: 21 additions & 7 deletions optree/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,27 @@
... z: float = 0.0
... norm: float = optree.dataclasses.field(init=False, pytree_node=False)
...
... def __post_init__(self):
... def __post_init__(self) -> None:
... self.norm = math.hypot(self.x, self.y, self.z)
...
>>> point = Point(1.0, 2.0, 3.0)
>>> leaves, treespec = optree.tree_flatten(point, namespace='my_module')
>>> leaves
[1.0, 2.0, 3.0]
>>> point = Point(2.0, 6.0, 3.0)
>>> point
Point(x=2.0, y=6.0, z=3.0, norm=7.0)
>>> # Flatten without specifying the namespace
>>> optree.tree_flatten(point) # `Point`s are leaf nodes
([Point(x=2.0, y=6.0, z=3.0, norm=7.0)], PyTreeSpec(*))
>>> # Flatten with the namespace
>>> accessors, leaves, treespec = optree.tree_flatten_with_accessor(point, namespace='my_module')
>>> accessors, leaves, treespec # doctest: +IGNORE_WHITESPACE,ELLIPSIS
(
[
PyTreeAccessor(*.x, (DataclassEntry(field='x', type=<class '...Point'>),)),
PyTreeAccessor(*.y, (DataclassEntry(field='y', type=<class '...Point'>),)),
PyTreeAccessor(*.z, (DataclassEntry(field='z', type=<class '...Point'>),))
],
[2.0, 6.0, 3.0],
PyTreeSpec(CustomTreeNode(Point[()], [*, *, *]), namespace='my_module')
)
>>> point == optree.tree_unflatten(treespec, leaves)
True
"""
Expand Down Expand Up @@ -88,7 +102,7 @@ def field( # type: ignore[no-redef] # pylint: disable=function-redefined,too-ma
Setting `pytree_node` in the field factory is equivalent to setting the `pytree_node` metadata
in the original field factory. The `pytree_node` metadata can be accessed using the `metadata`.
If ``pytree_node`` is :data:`None`, the ``metadata.get('pytree_node', True)`` will be used.
If ``pytree_node`` is :data:`None`, the value ``metadata.get('pytree_node', True)`` will be used.
.. note::
If a field is considered a child node, it must be included in the argument list of the
Expand All @@ -100,7 +114,7 @@ def field( # type: ignore[no-redef] # pylint: disable=function-redefined,too-ma
Returns:
dataclasses.Field: The field defined using the provided arguments with
``field.metadata['pytree_node']`` set.
``field.metadata['pytree_node']`` set.
"""
metadata = (metadata or {}).copy()
if pytree_node is None:
Expand Down

0 comments on commit 7a39408

Please sign in to comment.