From 999966053339de9ed0244ff8f760edfc30c8cf9c Mon Sep 17 00:00:00 2001 From: Daniel <57721552+dluo96@users.noreply.github.com> Date: Tue, 15 Nov 2022 17:15:19 +0000 Subject: [PATCH] feat: use protocol to force all environment states to have a key (#45) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Clément Bonnet <56230714+clement-bonnet@users.noreply.github.com> --- jumanji/env.py | 15 +++++++++++---- jumanji/environments/combinatorial/cvrp/types.py | 4 ++++ .../environments/combinatorial/knapsack/types.py | 4 ++++ jumanji/environments/combinatorial/tsp/types.py | 4 ++++ jumanji/environments/games/connect4/types.py | 3 +++ jumanji/wrappers.py | 3 +-- 6 files changed, 27 insertions(+), 6 deletions(-) diff --git a/jumanji/env.py b/jumanji/env.py index 8199fabb5..7949a8031 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -15,14 +15,21 @@ """Abstract environment class""" import abc -from typing import Any, Generic, Tuple, TypeVar +from typing import Any, Generic, Protocol, Tuple, TypeVar -from chex import PRNGKey +import chex from jumanji import specs from jumanji.types import Action, TimeStep -State = TypeVar("State") + +class StateProtocol(Protocol): + """Enforce that the State for every Environment must implement a key.""" + + key: chex.PRNGKey + + +State = TypeVar("State", bound="StateProtocol") class Environment(abc.ABC, Generic[State]): @@ -36,7 +43,7 @@ def __repr__(self) -> str: return "Environment." @abc.abstractmethod - def reset(self, key: PRNGKey) -> Tuple[State, TimeStep]: + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: """Resets the environment to an initial state. Args: diff --git a/jumanji/environments/combinatorial/cvrp/types.py b/jumanji/environments/combinatorial/cvrp/types.py index 1c976f427..b156912a0 100644 --- a/jumanji/environments/combinatorial/cvrp/types.py +++ b/jumanji/environments/combinatorial/cvrp/types.py @@ -14,11 +14,14 @@ from typing import TYPE_CHECKING, NamedTuple +import jax.random + if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 from dataclasses import dataclass else: from chex import dataclass +import chex import jax.numpy as jnp from chex import Array @@ -42,6 +45,7 @@ class State: visited_mask: Array # (problem_size + 1,) order: Array # (2 * problem_size,) - this size is worst-case (visit depot after each node) num_total_visits: jnp.int32 + key: chex.PRNGKey = jax.random.PRNGKey(0) class Observation(NamedTuple): diff --git a/jumanji/environments/combinatorial/knapsack/types.py b/jumanji/environments/combinatorial/knapsack/types.py index d659924a3..37911af15 100644 --- a/jumanji/environments/combinatorial/knapsack/types.py +++ b/jumanji/environments/combinatorial/knapsack/types.py @@ -14,6 +14,9 @@ from typing import TYPE_CHECKING, NamedTuple +import chex +import jax.random + if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 from dataclasses import dataclass else: @@ -40,6 +43,7 @@ class State: used_mask: Array # (problem_size,) num_steps: jnp.int32 remaining_budget: jnp.float32 + key: chex.PRNGKey = jax.random.PRNGKey(0) class Observation(NamedTuple): diff --git a/jumanji/environments/combinatorial/tsp/types.py b/jumanji/environments/combinatorial/tsp/types.py index 3b3e540b1..ed9bed3fd 100644 --- a/jumanji/environments/combinatorial/tsp/types.py +++ b/jumanji/environments/combinatorial/tsp/types.py @@ -14,6 +14,9 @@ from typing import TYPE_CHECKING, NamedTuple +import chex +import jax.random + if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 from dataclasses import dataclass else: @@ -38,6 +41,7 @@ class State: visited_mask: Array # (problem_size,) order: Array # (problem_size,) num_visited: jnp.int32 + key: chex.PRNGKey = jax.random.PRNGKey(0) class Observation(NamedTuple): diff --git a/jumanji/environments/games/connect4/types.py b/jumanji/environments/games/connect4/types.py index 4fc92a62e..52c8d6552 100644 --- a/jumanji/environments/games/connect4/types.py +++ b/jumanji/environments/games/connect4/types.py @@ -14,7 +14,9 @@ from typing import TYPE_CHECKING +import chex import jax.numpy as jnp +import jax.random from chex import Array from typing_extensions import TypeAlias @@ -30,6 +32,7 @@ class State: current_player: jnp.int8 board: Board + key: chex.PRNGKey = jax.random.PRNGKey(0) @dataclass diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index c9074852d..b941ada0c 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -35,10 +35,9 @@ from jax import jit, random from jumanji import specs, tree_utils -from jumanji.env import Environment +from jumanji.env import Environment, State from jumanji.types import Action, TimeStep, restart, termination, transition -State = TypeVar("State") Observation = TypeVar("Observation") # Type alias that corresponds to ObsType in the Gym API