From fe0049112ebdef3924fa18de8fe9b378d186b6d5 Mon Sep 17 00:00:00 2001 From: Donal Byrne Date: Thu, 18 Aug 2022 14:32:29 +0100 Subject: [PATCH] refactor: moved wrapper base class to wrappers.py (#45) Co-authored-by: Donal Byrne --- jumanji/env.py | 80 -------------------------------------- jumanji/env_test.py | 2 +- jumanji/wrappers.py | 94 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 93 insertions(+), 83 deletions(-) diff --git a/jumanji/env.py b/jumanji/env.py index aab03449e..33df8ad41 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -124,86 +124,6 @@ def __exit__(self, *args: Any) -> None: self.close() -class Wrapper(Environment[State], Generic[State]): - """Wraps the environment to allow modular transformations. - Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72 - """ - - def __init__(self, env: Environment): - super().__init__() - self._env = env - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({repr(self._env)})" - - def __getattr__(self, name: str) -> Any: - if name == "__setstate__": - raise AttributeError(name) - return getattr(self._env, name) - - @property - def unwrapped(self) -> Environment: - """Returns the wrapped env.""" - return self._env.unwrapped - - def reset(self, key: PRNGKey) -> Tuple[State, TimeStep, Extra]: - """Resets the environment to an initial state. - - Args: - key: random key used to reset the environment. - - Returns: - state: State object corresponding to the new state of the environment, - timestep: TimeStep object corresponding the first timestep returned by the environment, - extra: metrics, default to None. - """ - return self._env.reset(key) - - def step(self, state: State, action: Action) -> Tuple[State, TimeStep, Extra]: - """Run one timestep of the environment's dynamics. - - Args: - state: State object containing the dynamics of the environment. - action: Array containing the action to take. - - Returns: - state: State object corresponding to the next state of the environment, - timestep: TimeStep object corresponding the timestep returned by the environment, - extra: metrics, default to None. - """ - return self._env.step(state, action) - - def observation_spec(self) -> specs.Spec: - """Returns the observation spec.""" - return self._env.observation_spec() - - def action_spec(self) -> specs.Spec: - """Returns the action spec.""" - return self._env.action_spec() - - def render(self, state: State) -> Any: - """Compute render frames during initialisation of the environment. - - Args: - state: State object containing the dynamics of the environment. - """ - return self._env.render(state) - - def close(self) -> None: - """Perform any necessary cleanup. - - Environments will automatically :meth:`close()` themselves when - garbage collected or when the program exits. - """ - return self._env.close() - - def __enter__(self) -> "Wrapper": - return self - - def __exit__(self, *args: Any) -> None: - self.close() - - def make_environment_spec(environment: Environment) -> specs.EnvironmentSpec: """Returns an `EnvironmentSpec` describing values used by an environment.""" return specs.EnvironmentSpec( diff --git a/jumanji/env_test.py b/jumanji/env_test.py index 295197c59..7c170799d 100644 --- a/jumanji/env_test.py +++ b/jumanji/env_test.py @@ -17,8 +17,8 @@ import pytest import pytest_mock -from jumanji.env import Wrapper from jumanji.testing.fakes import FakeEnvironment, FakeState +from jumanji.wrappers import Wrapper @pytest.fixture diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index 8c6934ef3..6f181c5c4 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -12,7 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, ClassVar, Dict, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Generic, + Optional, + Tuple, + TypeVar, + Union, +) import dm_env.specs import gym @@ -25,13 +35,93 @@ from jax import jit, random from jumanji import specs -from jumanji.env import Environment, Wrapper +from jumanji.env import Environment from jumanji.types import Action, Extra, TimeStep, restart, termination, transition State = TypeVar("State") Observation = TypeVar("Observation") +class Wrapper(Environment[State], Generic[State]): + """Wraps the environment to allow modular transformations. + Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72 + """ + + def __init__(self, env: Environment): + super().__init__() + self._env = env + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({repr(self._env)})" + + def __getattr__(self, name: str) -> Any: + if name == "__setstate__": + raise AttributeError(name) + return getattr(self._env, name) + + @property + def unwrapped(self) -> Environment: + """Returns the wrapped env.""" + return self._env.unwrapped + + def reset(self, key: PRNGKey) -> Tuple[State, TimeStep, Extra]: + """Resets the environment to an initial state. + + Args: + key: random key used to reset the environment. + + Returns: + state: State object corresponding to the new state of the environment, + timestep: TimeStep object corresponding the first timestep returned by the environment, + extra: metrics, default to None. + """ + return self._env.reset(key) + + def step(self, state: State, action: Action) -> Tuple[State, TimeStep, Extra]: + """Run one timestep of the environment's dynamics. + + Args: + state: State object containing the dynamics of the environment. + action: Array containing the action to take. + + Returns: + state: State object corresponding to the next state of the environment, + timestep: TimeStep object corresponding the timestep returned by the environment, + extra: metrics, default to None. + """ + return self._env.step(state, action) + + def observation_spec(self) -> specs.Spec: + """Returns the observation spec.""" + return self._env.observation_spec() + + def action_spec(self) -> specs.Spec: + """Returns the action spec.""" + return self._env.action_spec() + + def render(self, state: State) -> Any: + """Compute render frames during initialisation of the environment. + + Args: + state: State object containing the dynamics of the environment. + """ + return self._env.render(state) + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits. + """ + return self._env.close() + + def __enter__(self) -> "Wrapper": + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + class JumanjiToDMEnvWrapper(dm_env.Environment): """A wrapper that converts Environment to dm_env.Environment."""