Skip to content

Commit

Permalink
Making _set_seed abstract (#770)
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini authored Dec 29, 2022
1 parent 1ce4fc5 commit 578938a
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 18 deletions.
18 changes: 8 additions & 10 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch
import torch.nn as nn
from tensordict.tensordict import TensorDict, TensorDictBase
from torchrl._utils import seed_generator
from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec,
BoundedTensorSpec,
Expand Down Expand Up @@ -85,12 +84,9 @@ def __init__(
def maxstep(self):
return 100

def set_seed(self, seed: int, static_seed=False) -> int:
def _set_seed(self, seed: Optional[int]):
self.seed = seed
self.counter = seed % 17 # make counter a small number
if static_seed:
return seed
return seed_generator(seed)

def custom_fun(self):
return 0
Expand Down Expand Up @@ -136,14 +132,11 @@ def __init__(self, device):
super(MockSerialEnv, self).__init__(device=device)
self.is_closed = False

def set_seed(self, seed: int, static_seed: bool = False) -> int:
def _set_seed(self, seed: Optional[int]):
assert seed >= 1
self.seed = seed
self.counter = seed % 17 # make counter a small number
self.max_val = max(self.counter + 100, self.counter * 2)
if static_seed:
return seed
return seed_generator(seed)

def _step(self, tensordict):
self.counter += 1
Expand Down Expand Up @@ -207,9 +200,14 @@ def __init__(self, device, batch_size=None):
super(MockBatchedLockedEnv, self).__init__(device=device, batch_size=batch_size)
self.counter = 0

set_seed = MockSerialEnv.set_seed
rand_step = MockSerialEnv.rand_step

def _set_seed(self, seed: Optional[int]):
assert seed >= 1
self.seed = seed
self.counter = seed % 17 # make counter a small number
self.max_val = max(self.counter + 100, self.counter * 2)

def _step(self, tensordict):
self.counter += 1
# We use tensordict.batch_size instead of self.batch_size since this method will also be used by MockBatchedUnLockedEnv
Expand Down
5 changes: 1 addition & 4 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def set_seed(
seed = new_seed
return seed

@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError

Expand Down Expand Up @@ -824,10 +825,6 @@ def close(self) -> None:
except AttributeError:
pass

@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError


def make_tensordict(
env: _EnvWrapper,
Expand Down
9 changes: 7 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from torchrl.envs.transforms.utils import check_finite
from torchrl.envs.utils import step_mdp


try:
from torchvision.transforms.functional import center_crop
from torchvision.transforms.functional_tensor import (
Expand Down Expand Up @@ -453,10 +452,16 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:

return tensordict_out

def set_seed(self, seed: int, static_seed: bool = False) -> int:
def set_seed(
self, seed: Optional[int] = None, static_seed: bool = False
) -> Optional[int]:
"""Set the seeds of the environment."""
return self.base_env.set_seed(seed, static_seed=static_seed)

def _set_seed(self, seed: Optional[int]):
"""This method is not used in transformed envs."""
pass

def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
if tensordict is not None:
tensordict = tensordict.clone(recurse=False)
Expand Down
12 changes: 10 additions & 2 deletions torchrl/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,10 @@ def close(self) -> None:
def _shutdown_workers(self) -> None:
raise NotImplementedError

def _set_seed(self, seed: Optional[int]):
"""This method is not used in batched envs."""
pass

def start(self) -> None:
if not self.is_closed:
raise RuntimeError("trying to start a environment that is not closed.")
Expand Down Expand Up @@ -606,7 +610,9 @@ def _shutdown_workers(self) -> None:
del self._envs

@_check_start
def set_seed(self, seed: int, static_seed: bool = False) -> int:
def set_seed(
self, seed: Optional[int] = None, static_seed: bool = False
) -> Optional[int]:
for env in self._envs:
new_seed = env.set_seed(seed, static_seed=static_seed)
seed = new_seed
Expand Down Expand Up @@ -816,7 +822,9 @@ def _shutdown_workers(self) -> None:
del self.parent_channels

@_check_start
def set_seed(self, seed: int, static_seed: bool = False) -> int:
def set_seed(
self, seed: Optional[int] = None, static_seed: bool = False
) -> Optional[int]:
self._seeds = []
for channel in self.parent_channels:
channel.send(("seed", (seed, static_seed)))
Expand Down

0 comments on commit 578938a

Please sign in to comment.