Skip to content

Commit

Permalink
feat: use protocol to force all environment states to have a key (#45)
Browse files Browse the repository at this point in the history
Co-authored-by: Clément Bonnet <56230714+clement-bonnet@users.noreply.github.com>
  • Loading branch information
dluo96 and clement-bonnet authored Nov 15, 2022
1 parent c90e1d9 commit 9999660
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 6 deletions.
15 changes: 11 additions & 4 deletions jumanji/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@
"""Abstract environment class"""

import abc
from typing import Any, Generic, Tuple, TypeVar
from typing import Any, Generic, Protocol, Tuple, TypeVar

from chex import PRNGKey
import chex

from jumanji import specs
from jumanji.types import Action, TimeStep

State = TypeVar("State")

class StateProtocol(Protocol):
"""Enforce that the State for every Environment must implement a key."""

key: chex.PRNGKey


State = TypeVar("State", bound="StateProtocol")


class Environment(abc.ABC, Generic[State]):
Expand All @@ -36,7 +43,7 @@ def __repr__(self) -> str:
return "Environment."

@abc.abstractmethod
def reset(self, key: PRNGKey) -> Tuple[State, TimeStep]:
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
"""Resets the environment to an initial state.
Args:
Expand Down
4 changes: 4 additions & 0 deletions jumanji/environments/combinatorial/cvrp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@

from typing import TYPE_CHECKING, NamedTuple

import jax.random

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
from dataclasses import dataclass
else:
from chex import dataclass

import chex
import jax.numpy as jnp
from chex import Array

Expand All @@ -42,6 +45,7 @@ class State:
visited_mask: Array # (problem_size + 1,)
order: Array # (2 * problem_size,) - this size is worst-case (visit depot after each node)
num_total_visits: jnp.int32
key: chex.PRNGKey = jax.random.PRNGKey(0)


class Observation(NamedTuple):
Expand Down
4 changes: 4 additions & 0 deletions jumanji/environments/combinatorial/knapsack/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

from typing import TYPE_CHECKING, NamedTuple

import chex
import jax.random

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
from dataclasses import dataclass
else:
Expand All @@ -40,6 +43,7 @@ class State:
used_mask: Array # (problem_size,)
num_steps: jnp.int32
remaining_budget: jnp.float32
key: chex.PRNGKey = jax.random.PRNGKey(0)


class Observation(NamedTuple):
Expand Down
4 changes: 4 additions & 0 deletions jumanji/environments/combinatorial/tsp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

from typing import TYPE_CHECKING, NamedTuple

import chex
import jax.random

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
from dataclasses import dataclass
else:
Expand All @@ -38,6 +41,7 @@ class State:
visited_mask: Array # (problem_size,)
order: Array # (problem_size,)
num_visited: jnp.int32
key: chex.PRNGKey = jax.random.PRNGKey(0)


class Observation(NamedTuple):
Expand Down
3 changes: 3 additions & 0 deletions jumanji/environments/games/connect4/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

from typing import TYPE_CHECKING

import chex
import jax.numpy as jnp
import jax.random
from chex import Array
from typing_extensions import TypeAlias

Expand All @@ -30,6 +32,7 @@
class State:
current_player: jnp.int8
board: Board
key: chex.PRNGKey = jax.random.PRNGKey(0)


@dataclass
Expand Down
3 changes: 1 addition & 2 deletions jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@
from jax import jit, random

from jumanji import specs, tree_utils
from jumanji.env import Environment
from jumanji.env import Environment, State
from jumanji.types import Action, TimeStep, restart, termination, transition

State = TypeVar("State")
Observation = TypeVar("Observation")

# Type alias that corresponds to ObsType in the Gym API
Expand Down

0 comments on commit 9999660

Please sign in to comment.