Skip to content

Commit

Permalink
style: miscellaneous style housekeeping (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored Oct 5, 2023
1 parent c1da9d0 commit 7c4d71e
Show file tree
Hide file tree
Showing 13 changed files with 158 additions and 102 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ optree.register_pytree_node(
torch.Tensor,
# (tensor) -> (children, metadata)
flatten_func=lambda tensor: (
(tensor.cpu().numpy(),),
dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
(tensor.cpu().detach().numpy(),),
{'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
),
# (metadata, children) -> tensor
unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
Expand Down
5 changes: 2 additions & 3 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ Tree Manipulation Functions
tree_map_
tree_map_with_path
tree_map_with_path_
tree_replace_nones
tree_transpose
tree_broadcast_prefix
broadcast_prefix
tree_replace_nones
prefix_errors

.. autofunction:: tree_flatten
Expand All @@ -51,10 +51,10 @@ Tree Manipulation Functions
.. autofunction:: tree_map_
.. autofunction:: tree_map_with_path
.. autofunction:: tree_map_with_path_
.. autofunction:: tree_replace_nones
.. autofunction:: tree_transpose
.. autofunction:: tree_broadcast_prefix
.. autofunction:: broadcast_prefix
.. autofunction:: tree_replace_nones
.. autofunction:: prefix_errors

------
Expand All @@ -64,7 +64,6 @@ Tree Reduce Functions

.. autosummary::

tree_replace_nones
tree_reduce
tree_sum
tree_max
Expand Down
2 changes: 0 additions & 2 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,6 @@ class PyTreeSpec {
const std::optional<py::function> &leaf_predicate,
const std::string &registry_namespace);

[[nodiscard]] py::list FlattenUpToImpl(const py::handle &full_tree) const;

template <bool NoneIsLeaf>
static bool AllLeavesImpl(const py::iterable &iterable, const std::string &registry_namespace);

Expand Down
2 changes: 1 addition & 1 deletion optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@
'tree_map_',
'tree_map_with_path',
'tree_map_with_path_',
'tree_replace_nones',
'tree_transpose',
'tree_broadcast_prefix',
'broadcast_prefix',
'tree_replace_nones',
'tree_reduce',
'tree_sum',
'tree_max',
Expand Down
Loading

0 comments on commit 7c4d71e

Please sign in to comment.