Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Pytree-Dataclass utilities #7

Merged
merged 10 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
command: ruff .
- run:
name: Typecheck (mypy)
command: mypy .
command: mypy stable_baselines3/common/pytree_dataclass.py tests/test_pytree_dataclass.py # TODO: remove, in PR#4.
pytype:
docker:
- image: ghcr.io/alignmentresearch/learned-planners:<< pipeline.parameters.docker_img_version >>
Expand All @@ -62,7 +62,7 @@ jobs:
working_directory: /workspace/third_party/stable-baselines3
steps:
- checkout
- run: pytype -j 4
- run: pytype -j 4 stable_baselines3/common/pytree_dataclass.py tests/test_pytree_dataclass.py # TODO: remove, in PR#4.
py-tests:
docker:
- image: ghcr.io/alignmentresearch/learned-planners:<< pipeline.parameters.docker_img_version >>
Expand Down
273 changes: 267 additions & 6 deletions stable_baselines3/common/pytree_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,282 @@
from typing import Callable, TypeVar
import dataclasses
from typing import (
Any,
Callable,
ClassVar,
Dict,
Generic,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
overload,
)

import optree as ot
from optree import PyTree as PyTree
import torch as th
from optree import CustomTreeNode, PyTree
from typing_extensions import dataclass_transform

__all__ = ["tree_flatten", "PyTree"]
from stable_baselines3.common.type_aliases import TensorIndex
from stable_baselines3.common.utils import zip_strict

__all__ = [
"FrozenPyTreeDataclass",
"MutablePyTreeDataclass",
"TensorTree",
"tree_empty",
"tree_flatten",
"tree_index",
"tree_map",
]

# Generic type variables used in the annotations of the classes and tree utilities below.
T = TypeVar("T")
U = TypeVar("U")

# optree namespace used by this module: PyTree dataclasses are registered under it, and it is the default
# `namespace` argument of the tree_* helpers defined in this file.
SB3_NAMESPACE = "stable-baselines3"

# Class names that `_PyTreeDataclassMeta.__new__` special-cases; they must be defined in this module.
_RESERVED_NAMES = ["_PyTreeDataclassBase", "FrozenPyTreeDataclass", "MutablePyTreeDataclass"]


# We need to inherit from `type(CustomTreeNode)` to prevent conflicts due to different-inheritance in metaclasses.
# - For some reason just inheriting from `typing._ProtocolMeta` does not get rid of that error.
# - Inheriting from `typing._GenericAlias` is impossible, as it's a `typing._Final` class.
#
# But in mypy, inheriting from a dynamic base class from `type` is not supported, so we disable type checking for this
# line.
class _PyTreeDataclassMeta(type(CustomTreeNode)):  # type: ignore[misc]
    """Metaclass that converts classes into dataclasses and registers them as PyTree nodes.

    Classes should not name this metaclass directly. `__new__` raises `TypeError` for every class, other than the
    reserved base classes of this module, that does not inherit from exactly one of `FrozenPyTreeDataclass` or
    `MutablePyTreeDataclass`.

    Usage:
        class MyDataclass(FrozenPyTreeDataclass):  # or MutablePyTreeDataclass
            ...
    """

    # We need to have this `currently_registering` variable because, in the course of making a DataClass with
    # __slots__, another class is created. So `__new__` is called *twice* for every dataclass that uses this metaclass.
    currently_registering: ClassVar[Optional[type]] = None

    def __new__(mcs, name, bases, namespace, slots=True, **kwargs):
        """Create the class, turn it into a dataclass and register it under `SB3_NAMESPACE`.

        :param name: name of the class being created
        :param bases: base classes of the class being created
        :param namespace: namespace of the class body
        :param slots: forwarded to `dataclasses.dataclass(slots=...)`
        :param kwargs: remaining class keywords, forwarded to `dataclasses.dataclass`. `frozen` is popped from here
            for the reserved `FrozenPyTreeDataclass` / `MutablePyTreeDataclass` classes and rejected for any other
            class.
        :return: the newly created class
        :raises TypeError: if a reserved name is used outside this module, if `frozen=` is passed for a
            non-reserved class, or if the class does not inherit from exactly one of `FrozenPyTreeDataclass` and
            `MutablePyTreeDataclass`.
        """
        # First: create the class in the normal way.
        cls = super().__new__(mcs, name, bases, namespace)

        if dataclasses.is_dataclass(cls):
            # If the class we're registering is already a Dataclass, it means it is a descendant of
            # FrozenPyTreeDataclass or MutablePyTreeDataclass.
            # This includes the children which are created when we create a dataclass with __slots__.

            if mcs.currently_registering is not None:
                # We've already created and annotated a class without __slots__, now we create the one with __slots__
                # that will actually get returned from the outer `__new__` call.
                assert mcs.currently_registering.__module__ == cls.__module__
                assert mcs.currently_registering.__name__ == cls.__name__
                mcs.currently_registering = None
                return cls

            # No `else` needed: the branch above always leaves the function.
            assert name not in _RESERVED_NAMES, (
                f"Class with name {name}: classes {_RESERVED_NAMES} don't inherit from a dataclass, so they should "
                "not be in this branch."
            )

            # Otherwise we just mark the current class as what we're registering.
            if not issubclass(cls, (FrozenPyTreeDataclass, MutablePyTreeDataclass)):
                raise TypeError(f"Dataclass {cls} should inherit from FrozenPyTreeDataclass or MutablePyTreeDataclass")
            mcs.currently_registering = cls
        else:
            mcs.currently_registering = cls

        if name in _RESERVED_NAMES:
            if not (
                namespace["__module__"] == "stable_baselines3.common.pytree_dataclass" and namespace["__qualname__"] == name
            ):
                raise TypeError(f"You cannot have another class named {name} with metaclass=_PyTreeDataclassMeta")

            # The common base is deliberately left as a plain (non-dataclass, unregistered) class.
            if name == "_PyTreeDataclassBase":
                return cls
            frozen = kwargs.pop("frozen")
        else:
            if "frozen" in kwargs:
                raise TypeError(
                    "You should not specify frozen= for descendants of FrozenPyTreeDataclass or MutablePyTreeDataclass"
                )

            # Frozen-ness is derived from the ancestry; inheriting from both base classes (or neither) is an error.
            frozen = issubclass(cls, FrozenPyTreeDataclass)
            if frozen:
                if not (not issubclass(cls, MutablePyTreeDataclass) and issubclass(cls, FrozenPyTreeDataclass)):
                    raise TypeError(f"Frozen dataclass {cls} should inherit from FrozenPyTreeDataclass")
            else:
                if not (issubclass(cls, MutablePyTreeDataclass) and not issubclass(cls, FrozenPyTreeDataclass)):
                    raise TypeError(f"Mutable dataclass {cls} should inherit from MutablePyTreeDataclass")

        # Calling `dataclasses.dataclass` here, with slots, is what triggers the EARLY RETURN path above.
        cls = dataclasses.dataclass(frozen=frozen, slots=slots, **kwargs)(cls)

        assert issubclass(cls, CustomTreeNode)
        ot.register_pytree_node_class(cls, namespace=SB3_NAMESPACE)
        return cls


class _PyTreeDataclassBase(CustomTreeNode[T], metaclass=_PyTreeDataclassMeta):
    """
    Provides utility methods common to both MutablePyTreeDataclass and FrozenPyTreeDataclass.

    However, _PyTreeDataclassBase is *not* a dataclass, as it hasn't been passed through the
    `dataclasses.dataclass(...)` creation function.
    """

    # Lazily-filled cache of the dataclass field names. It is stored separately on every class (see `_names`).
    _names_cache: ClassVar[Optional[Tuple[str, ...]]] = None

    # Declare the attribute that dataclasses carry, for type checking purposes only. This class is not itself a
    # dataclass; it just provides utility methods used by both Frozen and Mutable dataclasses.
    __dataclass_fields__: ClassVar[Dict[str, dataclasses.Field[Any]]]

    @classmethod
    def _names(cls) -> Tuple[str, ...]:
        """Return the names of the dataclass fields of `cls`, in declaration order."""
        # Look only at this class's own `__dict__`. A plain `cls._names_cache` lookup would also find the cache of
        # a parent dataclass, and a subclass that adds fields would then silently reuse the parent's shorter tuple.
        names = cls.__dict__.get("_names_cache")
        if names is None:
            names = tuple(f.name for f in dataclasses.fields(cls))
            cls._names_cache = names
        return names

    def __iter__(self):
        """Iterate over the field values, in the same order as `tree_flatten` returns them."""
        seq, _, _ = self.tree_flatten()
        return iter(seq)

    # The annotations here are invalid for Pytype because T does not appear in the rest of the function. But it does
    # appear as a parameter of the containing class, so it's actually not an error.
    def tree_flatten(self) -> tuple[Sequence[T], None, tuple[str, ...]]:  # pytype: disable=invalid-annotation
        """Return `(children, metadata, entries)`: the field values, `None`, and the field names."""
        names = self._names()
        return tuple(getattr(self, n) for n in names), None, names

    @classmethod
    def tree_unflatten(cls, metadata: None, children: Sequence[T]) -> CustomTreeNode[T]:  # pytype: disable=invalid-annotation
        """Rebuild an instance of `cls` from `children`, matched one-to-one with the field names.

        :param metadata: unused; `tree_flatten` always produces `None`
        :param children: one value per dataclass field, in field order
        """
        return cls(**dict(zip_strict(cls._names(), children)))


@dataclass_transform(frozen_default=True)  # pytype: disable=not-supported-yet
class FrozenPyTreeDataclass(_PyTreeDataclassBase[T], Generic[T], frozen=True):
    """Abstract class for immutable dataclass PyTrees"""


@dataclass_transform(frozen_default=False)  # pytype: disable=not-supported-yet
class MutablePyTreeDataclass(_PyTreeDataclassBase[T], Generic[T], frozen=False):
    """Abstract class for mutable dataclass PyTrees"""


# Manually expand the concrete type PyTree[th.Tensor] to make mypy happy.
# See links in https://github.com/metaopt/optree/issues/6: generic recursive types are not currently supported in mypy.
# The quoted "TensorTree" entries are forward references to this alias itself.
TensorTree = Union[
th.Tensor,
Tuple["TensorTree", ...],
Tuple[th.Tensor, ...],
List["TensorTree"],
List[th.Tensor],
Dict[Any, "TensorTree"],
Dict[Any, th.Tensor],
CustomTreeNode[th.Tensor],
PyTree[th.Tensor],
FrozenPyTreeDataclass[th.Tensor],
MutablePyTreeDataclass[th.Tensor],
]

# Type variable bounded by TensorTree. `tree_map` and `tree_index` use it to annotate that the returned tree has the
# same type as the tree that was passed in.
ConcreteTensorTree = TypeVar("ConcreteTensorTree", bound=TensorTree)


@overload
def tree_flatten(
    tree: TensorTree,
    is_leaf: Callable[[TensorTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = SB3_NAMESPACE,
) -> tuple[list[th.Tensor], ot.PyTreeSpec]:
    ...


@overload
def tree_flatten(
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = SB3_NAMESPACE,
) -> tuple[list[T], ot.PyTreeSpec]:
    ...


def tree_flatten(tree, is_leaf=None, *, none_is_leaf=False, namespace=SB3_NAMESPACE):
    """
    Flattens the PyTree (see `optree.tree_flatten`), expanding nodes using the SB3_NAMESPACE by default.

    :param tree: the tree to flatten
    :param is_leaf: optional predicate; forwarded to `optree.tree_flatten`
    :param none_is_leaf: forwarded to `optree.tree_flatten`
    :param namespace: when expanding nodes, use this namespace
    :return: whatever `optree.tree_flatten` returns, annotated above as `(list of leaves, PyTreeSpec)`
    """
    return ot.tree_flatten(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)


@overload
def tree_map(
    func: Callable[..., th.Tensor],
    tree: ConcreteTensorTree,
    *rests: TensorTree,
    is_leaf: Callable[[TensorTree], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = SB3_NAMESPACE,
) -> ConcreteTensorTree:
    ...


@overload
def tree_map(
    # This annotation is supposedly invalid for Pytype because U only appears once.
    func: Callable[..., U],  # pytype: disable=invalid-annotation
    tree: PyTree[T],
    *rests: Any,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    # Must advertise the same default as the implementation below, which uses SB3_NAMESPACE rather than "".
    namespace: str = SB3_NAMESPACE,
) -> PyTree[U]:
    ...


def tree_map(func, tree, *rests, is_leaf=None, none_is_leaf=False, namespace=SB3_NAMESPACE):  # type: ignore
    """
    Maps a function over a PyTree (see `optree.tree_map`), over the trees in `tree` and `*rests`, expanding nodes using
    the SB3_NAMESPACE by default.

    :param func: function applied to the leaves; it receives one argument per tree passed in
    :param tree: the first tree to map over
    :param rests: additional trees, forwarded positionally to `optree.tree_map`
    :param is_leaf: optional predicate; forwarded to `optree.tree_map`
    :param none_is_leaf: forwarded to `optree.tree_map`
    :param namespace: when expanding nodes, use this namespace
    :return: the result of `optree.tree_map`
    """
    return ot.tree_map(func, tree, *rests, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)


def tree_empty(
    tree: ot.PyTree, *, is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, namespace: str = SB3_NAMESPACE
) -> bool:
    """Tell whether `tree` has no leaves at all.

    :param tree: the tree to check
    :param is_leaf: optional predicate; forwarded to `optree.tree_flatten`
    :param none_is_leaf: forwarded to `optree.tree_flatten`
    :param namespace: when expanding nodes, use this namespace
    :return: True iff the tree is empty
    """
    # Only the flattened leaves matter here; the tree spec is discarded.
    leaves, _treespec = ot.tree_flatten(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    if leaves:
        return False
    return True


def tree_index(
    tree: ConcreteTensorTree,
    idx: TensorIndex,
    *,
    is_leaf: None | Callable[[TensorTree], bool] = None,
    none_is_leaf: bool = False,
    namespace: str = SB3_NAMESPACE,
) -> ConcreteTensorTree:
    """
    Apply the index `idx` to every leaf of a PyTree of Tensors.

    :param tree: the tree of tensors to index
    :param idx: the index to use
    :param is_leaf: whether to stop tree traversal at any particular node. `is_leaf(x: PyTree[Tensor])` should return
        True if the traversal should stop at `x`.
    :param none_is_leaf: Whether to consider `None` as a leaf that should be indexed.
    :param namespace: when expanding nodes, use this namespace
    :returns: tree of indexed Tensors
    """

    def _select(leaf):
        # Every leaf is indexed with the same `idx`.
        return leaf[idx]

    indexed = tree_map(_select, tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    return indexed
16 changes: 8 additions & 8 deletions stable_baselines3/common/recurrent/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,12 @@ def _get_samples(
self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_pi = (
self.to_device((lstm_states_pi[0])).contiguous(),
self.to_device((lstm_states_pi[1])).contiguous(),
self.to_device(lstm_states_pi[0]).contiguous(),
self.to_device(lstm_states_pi[1]).contiguous(),
)
lstm_states_vf = (
self.to_device((lstm_states_vf[0])).contiguous(),
self.to_device((lstm_states_vf[1])).contiguous(),
self.to_device(lstm_states_vf[0]).contiguous(),
self.to_device(lstm_states_vf[1]).contiguous(),
)
return RecurrentRolloutBufferSamples(
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
Expand Down Expand Up @@ -371,12 +371,12 @@ def _get_samples(
self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_pi = (
self.to_device((lstm_states_pi[0])).contiguous(),
self.to_device((lstm_states_pi[1])).contiguous(),
self.to_device(lstm_states_pi[0]).contiguous(),
self.to_device(lstm_states_pi[1]).contiguous(),
)
lstm_states_vf = (
self.to_device((lstm_states_vf[0])).contiguous(),
self.to_device((lstm_states_vf[1])).contiguous(),
self.to_device(lstm_states_vf[0]).contiguous(),
self.to_device(lstm_states_vf[1]).contiguous(),
)

observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()}
Expand Down
1 change: 0 additions & 1 deletion stable_baselines3/common/recurrent/policies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch as th
from gymnasium import spaces
from torch import nn
Expand Down
1 change: 1 addition & 0 deletions stable_baselines3/common/recurrent/type_aliases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import NamedTuple, Tuple

import torch as th

from stable_baselines3.common.type_aliases import TensorDict


Expand Down
1 change: 0 additions & 1 deletion tests/test_buffers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import gymnasium as gym
import numpy as np
import optree as ot
import pytest
import torch as th
from gymnasium import spaces
Expand Down
Loading