From 578938a961fd5ef1b582a2d5da68d75bc641686a Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Thu, 29 Dec 2022 01:15:47 -0500 Subject: [PATCH] Making `_set_seed` abstract (#770) --- test/mocking_classes.py | 18 ++++++++---------- torchrl/envs/common.py | 5 +---- torchrl/envs/transforms/transforms.py | 9 +++++++-- torchrl/envs/vec_env.py | 12 ++++++++++-- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 7629d2874e7..5f92e0da52c 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -7,7 +7,6 @@ import torch import torch.nn as nn from tensordict.tensordict import TensorDict, TensorDictBase -from torchrl._utils import seed_generator from torchrl.data.tensor_specs import ( BinaryDiscreteTensorSpec, BoundedTensorSpec, @@ -85,12 +84,9 @@ def __init__( def maxstep(self): return 100 - def set_seed(self, seed: int, static_seed=False) -> int: + def _set_seed(self, seed: Optional[int]): self.seed = seed self.counter = seed % 17 # make counter a small number - if static_seed: - return seed - return seed_generator(seed) def custom_fun(self): return 0 @@ -136,14 +132,11 @@ def __init__(self, device): super(MockSerialEnv, self).__init__(device=device) self.is_closed = False - def set_seed(self, seed: int, static_seed: bool = False) -> int: + def _set_seed(self, seed: Optional[int]): assert seed >= 1 self.seed = seed self.counter = seed % 17 # make counter a small number self.max_val = max(self.counter + 100, self.counter * 2) - if static_seed: - return seed - return seed_generator(seed) def _step(self, tensordict): self.counter += 1 @@ -207,9 +200,14 @@ def __init__(self, device, batch_size=None): super(MockBatchedLockedEnv, self).__init__(device=device, batch_size=batch_size) self.counter = 0 - set_seed = MockSerialEnv.set_seed rand_step = MockSerialEnv.rand_step + def _set_seed(self, seed: Optional[int]): + assert seed >= 1 + self.seed = seed + self.counter = seed % 17 # make counter a small number + self.max_val = max(self.counter + 100, self.counter * 2) + def _step(self, tensordict): self.counter += 1 # We use tensordict.batch_size instead of self.batch_size since this method will also be used by MockBatchedUnLockedEnv diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 8cb7e52266e..406013d2480 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -493,6 +493,7 @@ def set_seed( seed = new_seed return seed + @abc.abstractmethod def _set_seed(self, seed: Optional[int]): raise NotImplementedError @@ -824,10 +825,6 @@ def close(self) -> None: except AttributeError: pass - @abc.abstractmethod - def _set_seed(self, seed: Optional[int]): - raise NotImplementedError - def make_tensordict( env: _EnvWrapper, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 647760ba4cd..ac7858fe119 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -30,7 +30,6 @@ from torchrl.envs.transforms.utils import check_finite from torchrl.envs.utils import step_mdp - try: from torchvision.transforms.functional import center_crop from torchvision.transforms.functional_tensor import ( @@ -453,10 +452,16 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict_out - def set_seed(self, seed: int, static_seed: bool = False) -> int: + def set_seed( + self, seed: Optional[int] = None, static_seed: bool = False + ) -> Optional[int]: """Set the seeds of the environment.""" return self.base_env.set_seed(seed, static_seed=static_seed) + def _set_seed(self, seed: Optional[int]): + """This method is not used in transformed envs.""" + pass + def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): if tensordict is not None: tensordict = tensordict.clone(recurse=False) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 06f9821eb8a..f119e9e2151 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -514,6 +514,10 @@ def close(self) -> None: def _shutdown_workers(self) -> None: raise NotImplementedError + def _set_seed(self, seed: Optional[int]): + """This method is not used in batched envs.""" + pass + def start(self) -> None: if not self.is_closed: raise RuntimeError("trying to start a environment that is not closed.") @@ -606,7 +610,9 @@ def _shutdown_workers(self) -> None: del self._envs @_check_start - def set_seed(self, seed: int, static_seed: bool = False) -> int: + def set_seed( + self, seed: Optional[int] = None, static_seed: bool = False + ) -> Optional[int]: for env in self._envs: new_seed = env.set_seed(seed, static_seed=static_seed) seed = new_seed @@ -816,7 +822,9 @@ def _shutdown_workers(self) -> None: del self.parent_channels @_check_start - def set_seed(self, seed: int, static_seed: bool = False) -> int: + def set_seed( + self, seed: Optional[int] = None, static_seed: bool = False + ) -> Optional[int]: self._seeds = [] for channel in self.parent_channels: channel.send(("seed", (seed, static_seed)))