Skip to content

Commit

Permalink
refactor: moved wrapper base class to wrappers.py (#45)
Browse files: browse the repository at this point in the history
Co-authored-by: Donal Byrne <d.byrne@instadeep.com>
  • Loading branch information
djbyrne and djbyrne authored Aug 18, 2022
1 parent 3c8d1e8 commit fe00491
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 83 deletions.
80 changes: 0 additions & 80 deletions jumanji/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,86 +124,6 @@ def __exit__(self, *args: Any) -> None:
self.close()


class Wrapper(Environment[State], Generic[State]):
    """Base class for environment wrappers.

    Holds an inner environment and forwards every call to it, so a subclass
    only has to override the methods whose behaviour it wants to change.
    Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72
    """

    def __init__(self, env: Environment):
        """Store `env` as the wrapped environment.

        Args:
            env: the environment that calls are forwarded to.
        """
        super().__init__()
        self._env = env

    def __repr__(self) -> str:
        """Return `<WrapperClassName>(<repr of the wrapped env>)`."""
        wrapper_name = self.__class__.__name__
        return f"{wrapper_name}({self._env!r})"

    def __getattr__(self, name: str) -> Any:
        """Look up attributes that are missing on the wrapper on the wrapped env.

        Args:
            name: name of the attribute that normal lookup failed to find.

        Raises:
            AttributeError: if `name` is "__setstate__", which is never forwarded.
        """
        # NOTE(review): `__setstate__` is presumably excluded so that copy/pickle
        # do not pick up the inner environment's hook - confirm against upstream.
        if name != "__setstate__":
            return getattr(self._env, name)
        raise AttributeError(name)

    @property
    def unwrapped(self) -> Environment:
        """Return the `unwrapped` attribute of the wrapped env."""
        inner = self._env
        return inner.unwrapped

    def reset(self, key: PRNGKey) -> Tuple[State, TimeStep, Extra]:
        """Reset the environment by delegating to the wrapped env.

        Args:
            key: random key used to reset the environment.

        Returns:
            state: State object corresponding to the new state of the environment,
            timestep: TimeStep object corresponding to the first timestep returned
                by the environment,
            extra: metrics, default to None.
        """
        outcome = self._env.reset(key)
        return outcome

    def step(self, state: State, action: Action) -> Tuple[State, TimeStep, Extra]:
        """Advance the environment by one timestep by delegating to the wrapped env.

        Args:
            state: State object containing the dynamics of the environment.
            action: Array containing the action to take.

        Returns:
            state: State object corresponding to the next state of the environment,
            timestep: TimeStep object corresponding to the timestep returned by the
                environment,
            extra: metrics, default to None.
        """
        outcome = self._env.step(state, action)
        return outcome

    def observation_spec(self) -> specs.Spec:
        """Return the observation spec of the wrapped env."""
        inner = self._env
        return inner.observation_spec()

    def action_spec(self) -> specs.Spec:
        """Return the action spec of the wrapped env."""
        inner = self._env
        return inner.action_spec()

    def render(self, state: State) -> Any:
        """Render `state` by delegating to the wrapped env.

        Args:
            state: State object containing the dynamics of the environment.
        """
        frames = self._env.render(state)
        return frames

    def close(self) -> None:
        """Perform any necessary cleanup by delegating to the wrapped env.

        Environments are expected to :meth:`close()` themselves when garbage
        collected or when the program exits; calling this does it explicitly.
        """
        inner = self._env
        return inner.close()

    def __enter__(self) -> "Wrapper":
        """Enter a `with` block; the wrapper itself is the context value."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Leave a `with` block by closing the environment."""
        self.close()


def make_environment_spec(environment: Environment) -> specs.EnvironmentSpec:
"""Returns an `EnvironmentSpec` describing values used by an environment."""
return specs.EnvironmentSpec(
Expand Down
2 changes: 1 addition & 1 deletion jumanji/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import pytest
import pytest_mock

from jumanji.env import Wrapper
from jumanji.testing.fakes import FakeEnvironment, FakeState
from jumanji.wrappers import Wrapper


@pytest.fixture
Expand Down
94 changes: 92 additions & 2 deletions jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, ClassVar, Dict, Optional, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
ClassVar,
Dict,
Generic,
Optional,
Tuple,
TypeVar,
Union,
)

import dm_env.specs
import gym
Expand All @@ -25,13 +35,93 @@
from jax import jit, random

from jumanji import specs
from jumanji.env import Environment, Wrapper
from jumanji.env import Environment
from jumanji.types import Action, Extra, TimeStep, restart, termination, transition

State = TypeVar("State")
Observation = TypeVar("Observation")


class Wrapper(Environment[State], Generic[State]):
    """Wraps the environment to allow modular transformations.

    Every method forwards to the wrapped environment, so subclasses override
    only what they need to change.
    Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72
    """

    def __init__(self, env: Environment):
        """Store the environment to wrap.

        Args:
            env: the environment that calls are forwarded to.
        """
        super().__init__()
        self._env = env

    def __repr__(self) -> str:
        """Return `<WrapperClassName>(<repr of the wrapped env>)`."""
        return f"{self.__class__.__name__}({repr(self._env)})"

    def __getattr__(self, name: str) -> Any:
        """Forward lookups of attributes missing on the wrapper to the wrapped env.

        Args:
            name: name of the attribute that normal lookup failed to find.

        Raises:
            AttributeError: if `name` is "__setstate__" or "_env"; these are never
                forwarded.
        """
        # This method only runs after normal lookup has failed. If `_env` itself
        # is missing (instance created without running `__init__`, or a subclass
        # touching attributes before calling `super().__init__()`), evaluating
        # `self._env` below would re-enter this method until RecursionError.
        # Raising AttributeError instead keeps `hasattr`/`getattr(..., default)`
        # working on such instances.
        if name in ("__setstate__", "_env"):
            raise AttributeError(name)
        return getattr(self._env, name)

    @property
    def unwrapped(self) -> Environment:
        """Returns the `unwrapped` attribute of the wrapped env."""
        return self._env.unwrapped

    def reset(self, key: PRNGKey) -> Tuple[State, TimeStep, Extra]:
        """Resets the environment to an initial state.

        Args:
            key: random key used to reset the environment.

        Returns:
            state: State object corresponding to the new state of the environment,
            timestep: TimeStep object corresponding to the first timestep returned
                by the environment,
            extra: metrics, default to None.
        """
        return self._env.reset(key)

    def step(self, state: State, action: Action) -> Tuple[State, TimeStep, Extra]:
        """Run one timestep of the environment's dynamics.

        Args:
            state: State object containing the dynamics of the environment.
            action: Array containing the action to take.

        Returns:
            state: State object corresponding to the next state of the environment,
            timestep: TimeStep object corresponding to the timestep returned by the
                environment,
            extra: metrics, default to None.
        """
        return self._env.step(state, action)

    def observation_spec(self) -> specs.Spec:
        """Returns the observation spec of the wrapped env."""
        return self._env.observation_spec()

    def action_spec(self) -> specs.Spec:
        """Returns the action spec of the wrapped env."""
        return self._env.action_spec()

    def render(self, state: State) -> Any:
        """Render the given state by delegating to the wrapped env.

        Args:
            state: State object containing the dynamics of the environment.
        """
        return self._env.render(state)

    def close(self) -> None:
        """Perform any necessary cleanup by delegating to the wrapped env.

        Environments are expected to :meth:`close()` themselves when garbage
        collected or when the program exits; calling this does it explicitly.
        """
        return self._env.close()

    def __enter__(self) -> "Wrapper":
        """Enter a `with` block; the wrapper itself is the context value."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Leave a `with` block by closing the environment."""
        self.close()


class JumanjiToDMEnvWrapper(dm_env.Environment):
"""A wrapper that converts Environment to dm_env.Environment."""

Expand Down

0 comments on commit fe00491

Please sign in to comment.