diff --git a/docs/source/dataclasses.rst b/docs/source/dataclasses.rst index 15642b16..95a4e0fd 100644 --- a/docs/source/dataclasses.rst +++ b/docs/source/dataclasses.rst @@ -3,6 +3,9 @@ Integration with :mod:`dataclasses` .. currentmodule:: optree.dataclasses +.. automodule:: optree.dataclasses + :no-members: + .. autosummary:: field diff --git a/docs/source/functools.rst b/docs/source/functools.rst index f8dc1f54..fbcbb54c 100644 --- a/docs/source/functools.rst +++ b/docs/source/functools.rst @@ -3,6 +3,9 @@ Integration with :mod:`functools` .. currentmodule:: optree.functools +.. automodule:: optree.functools + :no-members: + .. autosummary:: partial diff --git a/optree/dataclasses.py b/optree/dataclasses.py index 473d37af..b4fd0914 100644 --- a/optree/dataclasses.py +++ b/optree/dataclasses.py @@ -30,13 +30,27 @@ ... z: float = 0.0 ... norm: float = optree.dataclasses.field(init=False, pytree_node=False) ... -... def __post_init__(self): +... def __post_init__(self) -> None: ... self.norm = math.hypot(self.x, self.y, self.z) ... ->>> point = Point(1.0, 2.0, 3.0) ->>> leaves, treespec = optree.tree_flatten(point, namespace='my_module') ->>> leaves -[1.0, 2.0, 3.0] +>>> point = Point(2.0, 6.0, 3.0) +>>> point +Point(x=2.0, y=6.0, z=3.0, norm=7.0) +>>> # Flatten without specifying the namespace +>>> optree.tree_flatten(point) # `Point`s are leaf nodes +([Point(x=2.0, y=6.0, z=3.0, norm=7.0)], PyTreeSpec(*)) +>>> # Flatten with the namespace +>>> accessors, leaves, treespec = optree.tree_flatten_with_accessor(point, namespace='my_module') +>>> accessors, leaves, treespec # doctest: +IGNORE_WHITESPACE,ELLIPSIS +( + [ + PyTreeAccessor(*.x, (DataclassEntry(field='x', type=),)), + PyTreeAccessor(*.y, (DataclassEntry(field='y', type=),)), + PyTreeAccessor(*.z, (DataclassEntry(field='z', type=),)) + ], + [2.0, 6.0, 3.0], + PyTreeSpec(CustomTreeNode(Point[()], [*, *, *]), namespace='my_module') +) >>> point == optree.tree_unflatten(treespec, leaves) True """ @@ -86,13 +100,14 @@ def field( # type: ignore[no-redef] # pylint: disable=function-redefined,too-ma PyTree structure which can be recursively flattened and unflattened. Otherwise, the field will be considered as PyTree metadata. - Setting `pytree_node` in the field factory is equivalent to setting the `pytree_node` metadata - in the original field factory. The `pytree_node` metadata can be accessed using the `metadata`. - If ``pytree_node`` is :data:`None`, the ``metadata.get('pytree_node', True)`` will be used. + Setting ``pytree_node`` in the field factory is equivalent to setting a key ``'pytree_node'`` in + ``metadata`` in the original field factory. The ``pytree_node`` value can be accessed using + ``field.metadata['pytree_node']``. If ``pytree_node`` is :data:`None`, the value + ``metadata.get('pytree_node', True)`` will be used. .. note:: If a field is considered a child node, it must be included in the argument list of the - :meth:`__init__` method. + :meth:`__init__` method, i.e., passes ``init=True`` in the field factory. Args: pytree_node (bool or None, optional): Whether the field is a PyTree node. @@ -100,7 +115,7 @@ def field( # type: ignore[no-redef] # pylint: disable=function-redefined,too-ma Returns: dataclasses.Field: The field defined using the provided arguments with - ``field.metadata['pytree_node']`` set. + ``field.metadata['pytree_node']`` set. """ metadata = (metadata or {}).copy() if pytree_node is None: