diff --git a/.circleci/config.yml b/.circleci/config.yml index 70f122604..928662eda 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -51,7 +51,7 @@ jobs: command: ruff . - run: name: Typecheck (mypy) - command: mypy . + command: mypy stable_baselines3/common/pytree_dataclass.py tests/test_pytree_dataclass.py # TODO: remove, in PR#4. pytype: docker: - image: ghcr.io/alignmentresearch/learned-planners:<< pipeline.parameters.docker_img_version >> @@ -62,7 +62,7 @@ jobs: working_directory: /workspace/third_party/stable-baselines3 steps: - checkout - - run: pytype -j 4 + - run: pytype -j 4 stable_baselines3/common/pytree_dataclass.py tests/test_pytree_dataclass.py # TODO: remove, in PR#4. py-tests: docker: - image: ghcr.io/alignmentresearch/learned-planners:<< pipeline.parameters.docker_img_version >> diff --git a/stable_baselines3/common/pytree_dataclass.py b/stable_baselines3/common/pytree_dataclass.py index a9b75b87d..a8a0f4807 100644 --- a/stable_baselines3/common/pytree_dataclass.py +++ b/stable_baselines3/common/pytree_dataclass.py @@ -1,21 +1,282 @@ -from typing import Callable, TypeVar +import dataclasses +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Generic, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + overload, +) import optree as ot -from optree import PyTree as PyTree +import torch as th +from optree import CustomTreeNode, PyTree +from typing_extensions import dataclass_transform -__all__ = ["tree_flatten", "PyTree"] +from stable_baselines3.common.type_aliases import TensorIndex +from stable_baselines3.common.utils import zip_strict + +__all__ = [ + "FrozenPyTreeDataclass", + "MutablePyTreeDataclass", + "TensorTree", + "tree_empty", + "tree_flatten", + "tree_index", + "tree_map", +] T = TypeVar("T") +U = TypeVar("U") SB3_NAMESPACE = "stable-baselines3" +_RESERVED_NAMES = ["_PyTreeDataclassBase", "FrozenPyTreeDataclass", "MutablePyTreeDataclass"] + + +# We need to inherit from `type(CustomTreeNode)` to prevent 
# conflicts due to different-inheritance in metaclasses.
# - For some reason just inheriting from `typing._ProtocolMeta` does not get rid of that error.
# - Inheriting from `typing._GenericAlias` is impossible, as it's a `typing._Final` class.
#
# But in mypy, inheriting from a dynamic base class from `type` is not supported, so we disable type checking for this
# line.
class _PyTreeDataclassMeta(type(CustomTreeNode)):  # type: ignore[misc]
    """Metaclass that turns a class into a dataclass and registers it as a PyTree node.

    It is not meant to be used directly: classes must inherit from `FrozenPyTreeDataclass` or
    `MutablePyTreeDataclass`, otherwise a `TypeError` is raised.

    Usage:
        class MyDataclass(FrozenPyTreeDataclass):
            a: int
    """

    # In the course of making a dataclass with __slots__, another class is created, so `__new__` is re-entered while
    # the first call is still running. This variable holds the class whose registration is in progress, so that the
    # re-entrant call can recognize the situation and return early.
    currently_registering: ClassVar[Optional[type]] = None

    def __new__(mcs, name, bases, namespace, slots=True, **kwargs):
        """Create the class, make it a dataclass and register it as a PyTree node under `SB3_NAMESPACE`.

        :param name: name of the class being created
        :param bases: base classes of the class being created
        :param namespace: namespace of the class body
        :param slots: forwarded to `dataclasses.dataclass(slots=...)`
        :param kwargs: extra keyword arguments forwarded to `dataclasses.dataclass`. `frozen` is only accepted for
            the reserved classes defined in this module.
        :return: the registered dataclass. `_PyTreeDataclassBase` and the re-entrant call are returned unregistered.
        :raises TypeError: if the class does not descend from exactly one of `FrozenPyTreeDataclass` /
            `MutablePyTreeDataclass`, if it passes `frozen=`, or if it reuses one of the reserved names.
        """
        # First: create the class in the normal way.
        cls = super().__new__(mcs, name, bases, namespace)

        if dataclasses.is_dataclass(cls):
            # If the class we're registering is already a dataclass, it means it is a descendant of
            # FrozenPyTreeDataclass or MutablePyTreeDataclass.
            # This includes the children which are created when we create a dataclass with __slots__.

            if mcs.currently_registering is not None:
                # We've already created and annotated a class without __slots__, now we create the one with __slots__
                # that will actually get returned from the outer `__new__` call.
                assert mcs.currently_registering.__module__ == cls.__module__
                assert mcs.currently_registering.__name__ == cls.__name__
                mcs.currently_registering = None
                return cls

            assert name not in _RESERVED_NAMES, (
                f"Class with name {name}: classes {_RESERVED_NAMES} don't inherit from a dataclass, so they should "
                "not be in this branch."
            )

            if not issubclass(cls, (FrozenPyTreeDataclass, MutablePyTreeDataclass)):
                raise TypeError(f"Dataclass {cls} should inherit from FrozenPyTreeDataclass or MutablePyTreeDataclass")

        # Mark the current class as what we're registering. The `finally` below guarantees the marker is cleared on
        # every exit path: validation errors, the early return for `_PyTreeDataclassBase`, and `slots=False` (where
        # no second class is created, so no re-entrant call clears it). A stale marker would make the *next* class
        # definition look like a re-entrant call.
        mcs.currently_registering = cls
        try:
            if name in _RESERVED_NAMES:
                if not (
                    namespace["__module__"] == "stable_baselines3.common.pytree_dataclass"
                    and namespace["__qualname__"] == name
                ):
                    raise TypeError(f"You cannot have another class named {name} with metaclass=_PyTreeDataclassMeta")

                if name == "_PyTreeDataclassBase":
                    return cls
                frozen = kwargs.pop("frozen")
            else:
                if "frozen" in kwargs:
                    raise TypeError(
                        "You should not specify frozen= for descendants of FrozenPyTreeDataclass or MutablePyTreeDataclass"
                    )

                frozen = issubclass(cls, FrozenPyTreeDataclass)
                if frozen:
                    if not (not issubclass(cls, MutablePyTreeDataclass) and issubclass(cls, FrozenPyTreeDataclass)):
                        raise TypeError(f"Frozen dataclass {cls} should inherit from FrozenPyTreeDataclass")
                else:
                    if not (issubclass(cls, MutablePyTreeDataclass) and not issubclass(cls, FrozenPyTreeDataclass)):
                        raise TypeError(f"Mutable dataclass {cls} should inherit from MutablePyTreeDataclass")

            # Calling `dataclasses.dataclass` here, with slots, is what triggers the EARLY RETURN path above.
            cls = dataclasses.dataclass(frozen=frozen, slots=slots, **kwargs)(cls)
        finally:
            mcs.currently_registering = None

        assert issubclass(cls, CustomTreeNode)
        ot.register_pytree_node_class(cls, namespace=SB3_NAMESPACE)
        return cls


class _PyTreeDataclassBase(CustomTreeNode[T], metaclass=_PyTreeDataclassMeta):
    """
    Provides utility methods common to both MutablePyTreeDataclass and FrozenPyTreeDataclass.

    However, _PyTreeDataclassBase is *not* a dataclass, as it hasn't been passed through the
    `dataclasses.dataclass(...)` creation function.
    """

    # Lazily-filled cache of the dataclass field names; see `_names()`.
    _names_cache: ClassVar[Optional[Tuple[str, ...]]] = None

    # Declared (without a value) so that type checkers treat this class as a dataclass. This class itself only
    # provides the utility methods used by both Frozen and Mutable dataclasses.
    __dataclass_fields__: ClassVar[Dict[str, dataclasses.Field[Any]]]

    @classmethod
    def _names(cls) -> Tuple[str, ...]:
        """Return the names of the dataclass fields of `cls`, in definition order (cached per class)."""
        # Only look at the cache stored on `cls` itself. A plain `cls._names_cache` lookup walks the MRO, so a
        # subclass that adds fields would otherwise reuse the (shorter) tuple cached on its parent.
        names = cls.__dict__.get("_names_cache")
        if names is None:
            names = tuple(f.name for f in dataclasses.fields(cls))
            cls._names_cache = names
        return names

    def __iter__(self):
        """Iterate over the field values, in field-definition order."""
        seq, _, _ = self.tree_flatten()
        return iter(seq)

    # The annotations here are invalid for Pytype because T does not appear in the rest of the function. But it does
    # appear as a parameter of the containing class, so it's actually not an error.
    def tree_flatten(self) -> tuple[Sequence[T], None, tuple[str, ...]]:  # pytype: disable=invalid-annotation
        """Return `(children, metadata, entries)`: the field values, `None`, and the field names."""
        names = self._names()
        return tuple(getattr(self, n) for n in names), None, names

    @classmethod
    def tree_unflatten(cls, metadata: None, children: Sequence[T]) -> CustomTreeNode[T]:  # pytype: disable=invalid-annotation
        """Rebuild an instance of `cls` from `children`, given in the same order as the field names."""
        return cls(**dict(zip_strict(cls._names(), children)))


@dataclass_transform(frozen_default=True)  # pytype: disable=not-supported-yet
class FrozenPyTreeDataclass(_PyTreeDataclassBase[T], Generic[T], frozen=True):
    """Abstract class for immutable dataclass PyTrees"""


@dataclass_transform(frozen_default=False)  # pytype: disable=not-supported-yet
class MutablePyTreeDataclass(_PyTreeDataclassBase[T], Generic[T], frozen=False):
    """Abstract class for mutable dataclass PyTrees"""


# Manually expand the concrete type PyTree[th.Tensor] to make mypy happy.
# See links in https://github.com/metaopt/optree/issues/6, generic recursive types are not currently supported in mypy
TensorTree = Union[
    th.Tensor,
    Tuple["TensorTree", ...],
    Tuple[th.Tensor, ...],
    List["TensorTree"],
    List[th.Tensor],
    Dict[Any, "TensorTree"],
    Dict[Any, th.Tensor],
    CustomTreeNode[th.Tensor],
    PyTree[th.Tensor],
    FrozenPyTreeDataclass[th.Tensor],
    MutablePyTreeDataclass[th.Tensor],
]

# Used by functions that return a tree of the same concrete type as the tree they were given.
ConcreteTensorTree = TypeVar("ConcreteTensorTree", bound=TensorTree)


@overload
def tree_flatten(
    tree: TensorTree,
    is_leaf: Callable[[TensorTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = SB3_NAMESPACE,
) -> tuple[list[th.Tensor], ot.PyTreeSpec]:
    ...


@overload
def tree_flatten(
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = SB3_NAMESPACE,
) -> tuple[list[T], ot.PyTreeSpec]:
    ...


def tree_flatten(tree, is_leaf=None, *, none_is_leaf=False, namespace=SB3_NAMESPACE):
    """
    Flattens the PyTree (see `optree.tree_flatten`), expanding nodes using the SB3_NAMESPACE by default.

    :param tree: the tree to flatten
    :param is_leaf: whether to stop tree traversal at any particular node. `is_leaf(x)` should return True if the
        traversal should stop at `x`.
    :param none_is_leaf: whether to consider `None` as a leaf
    :param namespace: when expanding nodes, use this namespace
    :return: the list of leaves and the `optree.PyTreeSpec` describing the structure of `tree`
    """
    return ot.tree_flatten(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)


@overload
def tree_map(
    func: Callable[..., th.Tensor],
    tree: ConcreteTensorTree,
    *rests: TensorTree,
    is_leaf: Callable[[TensorTree], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = SB3_NAMESPACE,
) -> ConcreteTensorTree:
    ...


@overload
def tree_map(
    # This annotation is supposedly invalid for Pytype because U only appears once.
    func: Callable[..., U],  # pytype: disable=invalid-annotation
    tree: PyTree[T],
    *rests: Any,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    # Must advertise the same default as the implementation and the other overload.
    namespace: str = SB3_NAMESPACE,
) -> PyTree[U]:
    ...
def tree_map(func, tree, *rests, is_leaf=None, none_is_leaf=False, namespace=SB3_NAMESPACE):  # type: ignore
    """
    Maps a function over a PyTree (see `optree.tree_map`), over the trees in `tree` and `*rests`, expanding nodes using
    the SB3_NAMESPACE by default.

    :param func: function called with one leaf from `tree` and the corresponding leaf of each tree in `rests`
    :param tree: the tree to map over
    :param rests: further trees, passed through to `optree.tree_map`
    :param is_leaf: whether to stop tree traversal at any particular node. `is_leaf(x)` should return True if the
        traversal should stop at `x`.
    :param none_is_leaf: whether to consider `None` as a leaf
    :param namespace: when expanding nodes, use this namespace
    :return: the result of `optree.tree_map`
    """
    return ot.tree_map(func, tree, *rests, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)


def tree_empty(
    tree: ot.PyTree[T],
    *,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = SB3_NAMESPACE,
) -> bool:
    """Is the tree `tree` empty, i.e. without leaves?

    :param tree: the tree to check
    :param is_leaf: whether to stop tree traversal at any particular node. `is_leaf(x)` should return True if the
        traversal should stop at `x`.
    :param none_is_leaf: whether to consider `None` as a leaf; if so, a tree that only contains `None` is not empty
    :param namespace: when expanding nodes, use this namespace
    :return: True iff the tree is empty
    """
    leaves, _ = ot.tree_flatten(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    return not leaves


def tree_index(
    tree: ConcreteTensorTree,
    idx: TensorIndex,
    *,
    is_leaf: None | Callable[[TensorTree], bool] = None,
    none_is_leaf: bool = False,
    namespace: str = SB3_NAMESPACE,
) -> ConcreteTensorTree:
    """
    Index each leaf of a PyTree of Tensors using the index `idx`.

    :param tree: the tree of tensors to index
    :param idx: the index to use
    :param is_leaf: whether to stop tree traversal at any particular node. `is_leaf(x: PyTree[Tensor])` should return
        True if the traversal should stop at `x`.
    :param none_is_leaf: Whether to consider `None` as a leaf that should be indexed.
    :param namespace: when expanding nodes, use this namespace
    :return: tree of indexed Tensors
    """
    return tree_map(lambda x: x[idx], tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
+ :param namespace: + :returns: tree of indexed Tensors + """ + return tree_map(lambda x: x[idx], tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index 68dbf1a79..9d1e3edcb 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -225,12 +225,12 @@ def _get_samples( self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), ) lstm_states_pi = ( - self.to_device((lstm_states_pi[0])).contiguous(), - self.to_device((lstm_states_pi[1])).contiguous(), + self.to_device(lstm_states_pi[0]).contiguous(), + self.to_device(lstm_states_pi[1]).contiguous(), ) lstm_states_vf = ( - self.to_device((lstm_states_vf[0])).contiguous(), - self.to_device((lstm_states_vf[1])).contiguous(), + self.to_device(lstm_states_vf[0]).contiguous(), + self.to_device(lstm_states_vf[1]).contiguous(), ) return RecurrentRolloutBufferSamples( # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) @@ -371,12 +371,12 @@ def _get_samples( self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), ) lstm_states_pi = ( - self.to_device((lstm_states_pi[0])).contiguous(), - self.to_device((lstm_states_pi[1])).contiguous(), + self.to_device(lstm_states_pi[0]).contiguous(), + self.to_device(lstm_states_pi[1]).contiguous(), ) lstm_states_vf = ( - self.to_device((lstm_states_vf[0])).contiguous(), - self.to_device((lstm_states_vf[1])).contiguous(), + self.to_device(lstm_states_vf[0]).contiguous(), + self.to_device(lstm_states_vf[1]).contiguous(), ) observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index 87372a1d9..be08a7c21 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ 
b/stable_baselines3/common/recurrent/policies.py @@ -1,6 +1,5 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union -import numpy as np import torch as th from gymnasium import spaces from torch import nn diff --git a/stable_baselines3/common/recurrent/type_aliases.py b/stable_baselines3/common/recurrent/type_aliases.py index 21ac0e0d9..b1f04446b 100644 --- a/stable_baselines3/common/recurrent/type_aliases.py +++ b/stable_baselines3/common/recurrent/type_aliases.py @@ -1,6 +1,7 @@ from typing import NamedTuple, Tuple import torch as th + from stable_baselines3.common.type_aliases import TensorDict diff --git a/tests/test_buffers.py b/tests/test_buffers.py index e69328bf1..b4d7ac8e0 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -1,6 +1,5 @@ import gymnasium as gym import numpy as np -import optree as ot import pytest import torch as th from gymnasium import spaces diff --git a/tests/test_pytree_dataclass.py b/tests/test_pytree_dataclass.py new file mode 100644 index 000000000..ad4c797cc --- /dev/null +++ b/tests/test_pytree_dataclass.py @@ -0,0 +1,208 @@ +from dataclasses import FrozenInstanceError +from typing import Optional + +import pytest + +import stable_baselines3.common.pytree_dataclass as ptd + + +@pytest.mark.parametrize("ParentPyTreeClass", (ptd.FrozenPyTreeDataclass, ptd.MutablePyTreeDataclass)) +def test_dataclass_mapped_have_slots(ParentPyTreeClass: type) -> None: + """ + If after running `tree_map` the class still has __slots__ and they're the same, then the correct class (the one with + __slots__) is what has been registered as a Pytree custom node. 
@pytest.mark.parametrize("ParentPyTreeClass", (ptd.FrozenPyTreeDataclass, ptd.MutablePyTreeDataclass))
def test_dataclass_mapped_have_slots(ParentPyTreeClass: type) -> None:
    """
    If after running `tree_map` the class still has __slots__ and they're the same, then the correct class (the one with
    __slots__) is what has been registered as a Pytree custom node.
    """

    class D(ParentPyTreeClass):
        a: int
        b: str

    d = D(4, "b")

    assert D.__slots__ == ("a", "b")
    assert d.__slots__ == ("a", "b")

    d2 = ptd.tree_map(lambda x: x * 2, d)

    assert d2.a == 8 and d2.b == "bb"

    assert isinstance(d2, D)
    assert d2.__slots__ == d.__slots__


@pytest.mark.parametrize("ParentPyTreeClass", (ptd.FrozenPyTreeDataclass, ptd.MutablePyTreeDataclass))
def test_dataclass_frozen_explicit(ParentPyTreeClass: type) -> None:
    """Descendants of the PyTree dataclasses may not pass `frozen=` themselves."""

    class D(ParentPyTreeClass):
        a: int

    with pytest.raises(TypeError, match="You should not specify frozen= for descendants"):

        class D(ParentPyTreeClass, frozen=True):  # type: ignore # noqa:F811
            a: int


@pytest.mark.parametrize("frozen", (True, False))
def test_dataclass_must_be_descendant(frozen: bool) -> None:
    """classes with metaclass _PyTreeDataclassMeta must be descendants of FrozenPyTreeDataclass or MutablePyTreeDataclass"""

    # First with arbitrary name
    with pytest.raises(TypeError):

        class D(ptd._PyTreeDataclassBase, frozen=frozen):  # type: ignore
            pass

    with pytest.raises(TypeError):

        class D(metaclass=ptd._PyTreeDataclassMeta, frozen=frozen):  # type: ignore # noqa: F811
            pass

    with pytest.raises(TypeError, match="[^ ]* dataclass .* should inherit"):

        class D(ptd._PyTreeDataclassBase):  # type: ignore # noqa: F811
            pass

    with pytest.raises(TypeError, match="[^ ]* dataclass .* should inherit"):

        class D(metaclass=ptd._PyTreeDataclassMeta):  # type: ignore # noqa: F811
            pass

    # Then try to copy each of the reserved names:
    ## _PyTreeDataclassBase
    with pytest.raises(TypeError):

        class _PyTreeDataclassBase(ptd._PyTreeDataclassBase, frozen=frozen):  # type: ignore
            pass

    with pytest.raises(TypeError):

        class _PyTreeDataclassBase(metaclass=ptd._PyTreeDataclassMeta, frozen=frozen):  # type: ignore
            pass

    with pytest.raises(TypeError, match="You cannot have another class named"):

        class _PyTreeDataclassBase(ptd._PyTreeDataclassBase):  # type: ignore
            pass

    with pytest.raises(TypeError, match="You cannot have another class named"):

        class _PyTreeDataclassBase(metaclass=ptd._PyTreeDataclassMeta):  # type: ignore
            pass

    ## FrozenPyTreeDataclass
    with pytest.raises(TypeError):

        class FrozenPyTreeDataclass(ptd._PyTreeDataclassBase, frozen=frozen):  # type: ignore
            pass

    with pytest.raises(TypeError):

        class FrozenPyTreeDataclass(metaclass=ptd._PyTreeDataclassMeta, frozen=frozen):  # type: ignore # noqa: F811
            pass

    with pytest.raises(TypeError, match="You cannot have another class named"):

        class FrozenPyTreeDataclass(ptd._PyTreeDataclassBase):  # type: ignore # noqa: F811
            pass

    with pytest.raises(TypeError, match="You cannot have another class named"):

        class FrozenPyTreeDataclass(metaclass=ptd._PyTreeDataclassMeta):  # type: ignore # noqa: F811
            pass

    ## MutablePyTreeDataclass
    with pytest.raises(TypeError):

        class MutablePyTreeDataclass(ptd._PyTreeDataclassBase, frozen=frozen):  # type: ignore
            pass

    with pytest.raises(TypeError):

        class MutablePyTreeDataclass(metaclass=ptd._PyTreeDataclassMeta, frozen=frozen):  # type: ignore # noqa:F811
            pass

    with pytest.raises(TypeError, match="You cannot have another class named"):

        class MutablePyTreeDataclass(ptd._PyTreeDataclassBase):  # type: ignore # noqa:F811
            pass

    with pytest.raises(TypeError, match="You cannot have another class named"):

        class MutablePyTreeDataclass(metaclass=ptd._PyTreeDataclassMeta):  # type: ignore # noqa:F811
            pass


def test_dataclass_frozen_or_not() -> None:
    """Mutable PyTree dataclasses accept attribute assignment, frozen ones reject it."""

    class MutA(ptd.MutablePyTreeDataclass):
        a: int

    class FrozenA(ptd.FrozenPyTreeDataclass):
        a: int

    inst1 = MutA(2)
    inst2 = FrozenA(2)

    # Assign a value different from the one given to the constructor, so that the mutation is actually observable.
    inst1.a = 3
    assert inst1.a == 3

    with pytest.raises(FrozenInstanceError):
        inst2.a = 3  # type: ignore[misc]
    # The rejected assignment must leave the frozen instance untouched.
    assert inst2.a == 2


@pytest.mark.parametrize("ParentPyTreeClass", (ptd.FrozenPyTreeDataclass, ptd.MutablePyTreeDataclass))
def test_dataclass_inheriting_dataclass(ParentPyTreeClass: type) -> None:
    """A PyTree dataclass can be subclassed, and the subclass gets the fields of both classes."""

    class A(ParentPyTreeClass):
        a: int

    inst = A(3)
    assert inst.a == 3

    class B(A):
        b: int

    inst = B(2, 4)
    assert inst.a == 2
    assert inst.b == 4


def test_tree_flatten() -> None:
    """`tree_flatten` expands PyTree dataclasses and, by default, drops `None` fields."""

    class A(ptd.FrozenPyTreeDataclass):
        a: Optional[int]

    flat, _ = ptd.tree_flatten((A(3), A(None), {"a": A(4)}))  # type: ignore
    assert flat == [3, 4]


def test_tree_map() -> None:
    """`tree_map` maps over builtin containers and PyTree dataclasses alike, leaving `None` in place."""

    class A(ptd.FrozenPyTreeDataclass):
        a: Optional[int]

    out = ptd.tree_map(lambda x: x * 2, ([2, 3], 4, A(5), None, {"a": 6}))  # type: ignore
    assert out == ([4, 6], 8, A(10), None, {"a": 12})


def test_tree_empty() -> None:
    """`tree_empty` is True iff the tree has no leaves; `None` only counts as a leaf with `none_is_leaf=True`."""
    assert ptd.tree_empty(())  # type: ignore
    assert ptd.tree_empty([])  # type: ignore
    assert ptd.tree_empty({})  # type: ignore
    assert not ptd.tree_empty({"a": 2})  # type: ignore
    assert not ptd.tree_empty([2])  # type: ignore

    class A(ptd.FrozenPyTreeDataclass):
        a: Optional[int]

    assert ptd.tree_empty([A(None)])  # type: ignore
    assert not ptd.tree_empty([A(None)], none_is_leaf=True)  # type: ignore
    assert not ptd.tree_empty([A(2)])  # type: ignore


def test_tree_index() -> None:
    """`tree_index` indexes every leaf; `is_leaf` makes the two lists count as leaves instead of being expanded."""
    l1 = ["a", "b", "c"]
    l2 = ["hi", "bye"]
    idx = 1

    e1 = l1[idx]
    e2 = l2[idx]

    class A(ptd.FrozenPyTreeDataclass):
        a: str

    out_tree = ptd.tree_index([A(l1), A(l2), l1, (l2, {"a": l1})], idx, is_leaf=lambda x: x is l1 or x is l2)  # type: ignore
    assert out_tree == [A(e1), A(e2), e1, (e2, {"a": e1})]