forked from AlignmentResearch/stable-baselines3
-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Pytree-Dataclass utilities #7
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
559f259
Checkout updated file from 'contrib-recurrent'
rhaps0dy 7465369
Fix type errors
rhaps0dy 5bf253c
Add docstrings
rhaps0dy 1f5fcd3
Typecheck only the files that should typecheck for this PR
rhaps0dy 71c5e97
Fix ruff errors
rhaps0dy 18a2945
TypeVar S is unused
rhaps0dy a9370d8
Some improvements along with a test
rhaps0dy 57dfc30
Test all the pytree and dataclass functions/metaclasses
rhaps0dy 1e7cf77
Fix ruff and types
rhaps0dy b9ed049
Simplify
rhaps0dy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,282 @@ | ||
from typing import Callable, TypeVar | ||
import dataclasses | ||
from typing import ( | ||
Any, | ||
Callable, | ||
ClassVar, | ||
Dict, | ||
Generic, | ||
List, | ||
Optional, | ||
Sequence, | ||
Tuple, | ||
TypeVar, | ||
Union, | ||
overload, | ||
) | ||
|
||
import optree as ot | ||
from optree import PyTree as PyTree | ||
import torch as th | ||
from optree import CustomTreeNode, PyTree | ||
from typing_extensions import dataclass_transform | ||
|
||
__all__ = ["tree_flatten", "PyTree"] | ||
from stable_baselines3.common.type_aliases import TensorIndex | ||
from stable_baselines3.common.utils import zip_strict | ||
|
||
__all__ = [ | ||
"FrozenPyTreeDataclass", | ||
"MutablePyTreeDataclass", | ||
"TensorTree", | ||
"tree_empty", | ||
"tree_flatten", | ||
"tree_index", | ||
"tree_map", | ||
] | ||
|
||
T = TypeVar("T") | ||
U = TypeVar("U") | ||
|
||
SB3_NAMESPACE = "stable-baselines3" | ||
|
||
_RESERVED_NAMES = ["_PyTreeDataclassBase", "FrozenPyTreeDataclass", "MutablePyTreeDataclass"] | ||
|
||
|
||
# We need to inherit from `type(CustomTreeNode)` to prevent conflicts due to different-inheritance in metaclasses. | ||
# - For some reason just inheriting from `typing._ProtocolMeta` does not get rid of that error. | ||
# - Inheriting from `typing._GenericAlias` is impossible, as it's a `typing._Final` class. | ||
# | ||
# But in mypy, inheriting from a dynamic base class from `type` is not supported, so we disable type checking for this | ||
# line. | ||
class _PyTreeDataclassMeta(type(CustomTreeNode)): # type: ignore[misc] | ||
"""Metaclass to register dataclasses as PyTrees. | ||
|
||
Usage: | ||
class MyDataclass(metaclass=_DataclassPyTreeMeta): | ||
... | ||
""" | ||
|
||
# We need to have this `currently_registering` variable because, in the course of making a DataClass with __slots__, | ||
# another class is created. So this will be called *twice* for every dataclass we annotate with this metaclass. | ||
currently_registering: ClassVar[Optional[type]] = None | ||
|
||
def __new__(mcs, name, bases, namespace, slots=True, **kwargs): | ||
# First: create the class in the normal way. | ||
cls = super().__new__(mcs, name, bases, namespace) | ||
|
||
if dataclasses.is_dataclass(cls): | ||
# If the class we're registering is already a Dataclass, it means it is a descendant of FrozenPyTreeDataclass or | ||
# MutablePyTreeDataclass. | ||
# This includes the children which are created when we create a dataclass with __slots__. | ||
|
||
if mcs.currently_registering is not None: | ||
# We've already created and annotated a class without __slots__, now we create the one with __slots__ | ||
# that will actually get returned after from the __new__ method. | ||
assert mcs.currently_registering.__module__ == cls.__module__ | ||
assert mcs.currently_registering.__name__ == cls.__name__ | ||
mcs.currently_registering = None | ||
return cls | ||
|
||
else: | ||
assert name not in _RESERVED_NAMES, ( | ||
f"Class with name {name}: classes {_RESERVED_NAMES} don't inherit from a dataclass, so they should " | ||
"not be in this branch." | ||
) | ||
|
||
# Otherwise we just mark the current class as what we're registering. | ||
if not issubclass(cls, (FrozenPyTreeDataclass, MutablePyTreeDataclass)): | ||
raise TypeError(f"Dataclass {cls} should inherit from FrozenPyTreeDataclass or MutablePyTreeDataclass") | ||
mcs.currently_registering = cls | ||
rhaps0dy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
mcs.currently_registering = cls | ||
|
||
if name in _RESERVED_NAMES: | ||
if not ( | ||
namespace["__module__"] == "stable_baselines3.common.pytree_dataclass" and namespace["__qualname__"] == name | ||
): | ||
raise TypeError(f"You cannot have another class named {name} with metaclass=_PyTreeDataclassMeta") | ||
|
||
if name == "_PyTreeDataclassBase": | ||
return cls | ||
frozen = kwargs.pop("frozen") | ||
else: | ||
if "frozen" in kwargs: | ||
raise TypeError( | ||
"You should not specify frozen= for descendants of FrozenPyTreeDataclass or MutablePyTreeDataclass" | ||
) | ||
|
||
frozen = issubclass(cls, FrozenPyTreeDataclass) | ||
if frozen: | ||
if not (not issubclass(cls, MutablePyTreeDataclass) and issubclass(cls, FrozenPyTreeDataclass)): | ||
rhaps0dy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise TypeError(f"Frozen dataclass {cls} should inherit from FrozenPyTreeDataclass") | ||
else: | ||
if not (issubclass(cls, MutablePyTreeDataclass) and not issubclass(cls, FrozenPyTreeDataclass)): | ||
rhaps0dy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise TypeError(f"Mutable dataclass {cls} should inherit from MutablePyTreeDataclass") | ||
|
||
# Calling `dataclasses.dataclass` here, with slots, is what triggers the EARLY RETURN path above. | ||
cls = dataclasses.dataclass(frozen=frozen, slots=slots, **kwargs)(cls) | ||
|
||
assert issubclass(cls, CustomTreeNode) | ||
ot.register_pytree_node_class(cls, namespace=SB3_NAMESPACE) | ||
return cls | ||
|
||
|
||
class _PyTreeDataclassBase(CustomTreeNode[T], metaclass=_PyTreeDataclassMeta): | ||
""" | ||
Provides utility methods common to both MutablePyTreeDataclass and FrozenPyTreeDataclass. | ||
|
||
However _PyTreeDataclassBase is *not* a dataclass. as it hasn't been passed through the `dataclasses.dataclass(...)` | ||
creation function. | ||
""" | ||
|
||
_names_cache: ClassVar[Optional[Tuple[str, ...]]] = None | ||
|
||
# Mark this class as a dataclass, for type checking purposes. | ||
# Instead, it provides utility methods used by both Frozen and Mutable dataclasses. | ||
__dataclass_fields__: ClassVar[Dict[str, dataclasses.Field[Any]]] | ||
|
||
@classmethod | ||
def _names(cls) -> Tuple[str, ...]: | ||
rhaps0dy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if cls._names_cache is None: | ||
cls._names_cache = tuple(f.name for f in dataclasses.fields(cls)) | ||
return cls._names_cache | ||
|
||
def __iter__(self): | ||
seq, _, _ = self.tree_flatten() | ||
return iter(seq) | ||
|
||
# The annotations here are invalid for Pytype because T does not appear in the rest of the function. But it does | ||
# appear as a parameter of the containing class, so it's actually not an error. | ||
def tree_flatten(self) -> tuple[Sequence[T], None, tuple[str, ...]]: # pytype: disable=invalid-annotation | ||
names = self._names() | ||
return tuple(getattr(self, n) for n in names), None, names | ||
|
||
@classmethod | ||
def tree_unflatten(cls, metadata: None, children: Sequence[T]) -> CustomTreeNode[T]: # pytype: disable=invalid-annotation | ||
return cls(**dict(zip_strict(cls._names(), children))) | ||
|
||
|
||
@dataclass_transform(frozen_default=True) # pytype: disable=not-supported-yet | ||
class FrozenPyTreeDataclass(_PyTreeDataclassBase[T], Generic[T], frozen=True): | ||
"Abstract class for immutable dataclass PyTrees" | ||
... | ||
|
||
|
||
@dataclass_transform(frozen_default=False) # pytype: disable=not-supported-yet | ||
class MutablePyTreeDataclass(_PyTreeDataclassBase[T], Generic[T], frozen=False): | ||
"Abstract class for mutable dataclass PyTrees" | ||
... | ||
|
||
|
||
# Manually expand the concrete type PyTree[th.Tensor] to make mypy happy. | ||
# See links in https://github.com/metaopt/optree/issues/6, generic recursive types are not currently supported in mypy | ||
TensorTree = Union[ | ||
th.Tensor, | ||
Tuple["TensorTree", ...], | ||
Tuple[th.Tensor, ...], | ||
List["TensorTree"], | ||
List[th.Tensor], | ||
Dict[Any, "TensorTree"], | ||
Dict[Any, th.Tensor], | ||
CustomTreeNode[th.Tensor], | ||
PyTree[th.Tensor], | ||
FrozenPyTreeDataclass[th.Tensor], | ||
MutablePyTreeDataclass[th.Tensor], | ||
] | ||
|
||
ConcreteTensorTree = TypeVar("ConcreteTensorTree", bound=TensorTree) | ||
|
||
|
||
@overload | ||
def tree_flatten( | ||
tree: TensorTree, | ||
is_leaf: Callable[[TensorTree], bool] | None = None, | ||
*, | ||
none_is_leaf: bool = False, | ||
namespace: str = SB3_NAMESPACE, | ||
) -> tuple[list[th.Tensor], ot.PyTreeSpec]: | ||
... | ||
|
||
|
||
@overload | ||
def tree_flatten( | ||
tree: ot.PyTree[T], | ||
tree: PyTree[T], | ||
is_leaf: Callable[[T], bool] | None = None, | ||
*, | ||
none_is_leaf: bool = False, | ||
namespace: str = SB3_NAMESPACE | ||
namespace: str = SB3_NAMESPACE, | ||
) -> tuple[list[T], ot.PyTreeSpec]: | ||
"""optree.tree_flatten(...) but the default namespace is SB3_NAMESPACE""" | ||
... | ||
|
||
|
||
def tree_flatten(tree, is_leaf=None, *, none_is_leaf=False, namespace=SB3_NAMESPACE): | ||
""" | ||
Flattens the PyTree (see `optree.tree_flatten`), expanding nodes using the SB3_NAMESPACE by default. | ||
""" | ||
return ot.tree_flatten(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) | ||
|
||
|
||
@overload | ||
def tree_map( | ||
func: Callable[..., th.Tensor], | ||
tree: ConcreteTensorTree, | ||
*rests: TensorTree, | ||
is_leaf: Callable[[TensorTree], bool] | None = None, | ||
none_is_leaf: bool = False, | ||
namespace: str = SB3_NAMESPACE, | ||
) -> ConcreteTensorTree: | ||
... | ||
|
||
|
||
@overload | ||
def tree_map( | ||
# This annotation is supposedly invalid for Pytype because U only appears once. | ||
func: Callable[..., U], # pytype: disable=invalid-annotation | ||
tree: PyTree[T], | ||
*rests: Any, | ||
is_leaf: Callable[[T], bool] | None = None, | ||
none_is_leaf: bool = False, | ||
namespace: str = "", | ||
) -> PyTree[U]: | ||
... | ||
|
||
|
||
def tree_map(func, tree, *rests, is_leaf=None, none_is_leaf=False, namespace=SB3_NAMESPACE): # type: ignore | ||
""" | ||
Maps a function over a PyTree (see `optree.tree_map`), over the trees in `tree` and `*rests`, expanding nodes using | ||
the SB3_NAMESPACE by default. | ||
""" | ||
return ot.tree_map(func, tree, *rests, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) | ||
|
||
|
||
def tree_empty( | ||
tree: ot.PyTree, *, is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, namespace: str = SB3_NAMESPACE | ||
) -> bool: | ||
"""Is the tree `tree` empty, i.e. without leaves? | ||
|
||
:param tree: the tree to check | ||
:param namespace: when expanding nodes, use this namespace | ||
:return: True iff the tree is empty | ||
""" | ||
flattened_state, _ = ot.tree_flatten(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) | ||
return not bool(flattened_state) | ||
|
||
|
||
def tree_index( | ||
tree: ConcreteTensorTree, | ||
idx: TensorIndex, | ||
*, | ||
is_leaf: None | Callable[[TensorTree], bool] = None, | ||
none_is_leaf: bool = False, | ||
namespace: str = SB3_NAMESPACE, | ||
) -> ConcreteTensorTree: | ||
""" | ||
Index each leaf of a PyTree of Tensors using the index `idx`. | ||
|
||
:param tree: the tree of tensors to index | ||
:param idx: the index to use | ||
:param is_leaf: whether to stop tree traversal at any particular node. `is_leaf(x: PyTree[Tensor])` should return | ||
True if the traversal should stop at `x`. | ||
:param none_is_leaf: Whether to consider `None` as a leaf that should be indexed. | ||
:param namespace: | ||
:returns: tree of indexed Tensors | ||
""" | ||
return tree_map(lambda x: x[idx], tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This 'else' is unnecessary as the
if
always kicks out of the function (either through an assertion failure, or a direct return). Removing the 'else' makes the nesting a little less severe, which is IMO slightly nice for reading (and makes it so I don't have to think about whichif
it corresponds to).