Skip to content

Commit

Permalink
docs: update docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Sep 11, 2024
1 parent d1f281f commit 71f1ba0
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 10 deletions.
3 changes: 3 additions & 0 deletions docs/source/dataclasses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ Integration with :mod:`dataclasses`

.. currentmodule:: optree.dataclasses

.. automodule:: optree.dataclasses
:no-members:

.. autosummary::

field
Expand Down
3 changes: 3 additions & 0 deletions docs/source/functools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ Integration with :mod:`functools`

.. currentmodule:: optree.functools

.. automodule:: optree.functools
:no-members:

.. autosummary::

partial
Expand Down
35 changes: 25 additions & 10 deletions optree/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,27 @@
... z: float = 0.0
... norm: float = optree.dataclasses.field(init=False, pytree_node=False)
...
... def __post_init__(self):
... def __post_init__(self) -> None:
... self.norm = math.hypot(self.x, self.y, self.z)
...
>>> point = Point(1.0, 2.0, 3.0)
>>> leaves, treespec = optree.tree_flatten(point, namespace='my_module')
>>> leaves
[1.0, 2.0, 3.0]
>>> point = Point(2.0, 6.0, 3.0)
>>> point
Point(x=2.0, y=6.0, z=3.0, norm=7.0)
>>> # Flatten without specifying the namespace
>>> optree.tree_flatten(point) # `Point`s are leaf nodes
([Point(x=2.0, y=6.0, z=3.0, norm=7.0)], PyTreeSpec(*))
>>> # Flatten with the namespace
>>> accessors, leaves, treespec = optree.tree_flatten_with_accessor(point, namespace='my_module')
>>> accessors, leaves, treespec # doctest: +IGNORE_WHITESPACE,ELLIPSIS
(
[
PyTreeAccessor(*.x, (DataclassEntry(field='x', type=<class '...Point'>),)),
PyTreeAccessor(*.y, (DataclassEntry(field='y', type=<class '...Point'>),)),
PyTreeAccessor(*.z, (DataclassEntry(field='z', type=<class '...Point'>),))
],
[2.0, 6.0, 3.0],
PyTreeSpec(CustomTreeNode(Point[()], [*, *, *]), namespace='my_module')
)
>>> point == optree.tree_unflatten(treespec, leaves)
True
"""
Expand Down Expand Up @@ -86,21 +100,22 @@ def field( # type: ignore[no-redef] # pylint: disable=function-redefined,too-ma
PyTree structure which can be recursively flattened and unflattened. Otherwise, the field will
be considered as PyTree metadata.
Setting `pytree_node` in the field factory is equivalent to setting the `pytree_node` metadata
in the original field factory. The `pytree_node` metadata can be accessed using the `metadata`.
If ``pytree_node`` is :data:`None`, the ``metadata.get('pytree_node', True)`` will be used.
Setting ``pytree_node`` in the field factory is equivalent to setting a key ``'pytree_node'`` in
``metadata`` in the original field factory. The ``pytree_node`` value can be accessed using
``field.metadata['pytree_node']``. If ``pytree_node`` is :data:`None`, the value
``metadata.get('pytree_node', True)`` will be used.
.. note::
If a field is considered a child node, it must be included in the argument list of the
:meth:`__init__` method.
:meth:`__init__` method, i.e., passes ``init=True`` in the field factory.
Args:
pytree_node (bool or None, optional): Whether the field is a PyTree node.
**kwargs (optional): Optional keyword arguments passed to :func:`dataclasses.field`.
Returns:
dataclasses.Field: The field defined using the provided arguments with
``field.metadata['pytree_node']`` set.
``field.metadata['pytree_node']`` set.
"""
metadata = (metadata or {}).copy()
if pytree_node is None:
Expand Down

0 comments on commit 71f1ba0

Please sign in to comment.