Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use protocol to force all environment states to have a key #45

Merged
merged 13 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions jumanji/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@
"""Abstract environment class"""

import abc
from typing import Any, Generic, Tuple, TypeVar
from typing import Any, Generic, Protocol, Tuple, TypeVar

from chex import PRNGKey
import chex

from jumanji import specs
from jumanji.types import Action, TimeStep

State = TypeVar("State")

class StateProtocol(Protocol):
    """Structural type that the state of every environment must satisfy.

    A state class has to declare a `key` attribute holding a
    `chex.PRNGKey`. Because this is a `typing.Protocol`, state classes do
    not need to inherit from it; static type checkers match them by
    structure.
    """

    key: chex.PRNGKey


# Type variable for environment states. The bound (a forward reference to
# StateProtocol) restricts it to types that declare a `key` attribute.
State = TypeVar("State", bound="StateProtocol")


class Environment(abc.ABC, Generic[State]):
Expand All @@ -36,7 +43,7 @@ def __repr__(self) -> str:
return "Environment."

@abc.abstractmethod
def reset(self, key: PRNGKey) -> Tuple[State, TimeStep]:
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
"""Resets the environment to an initial state.

Args:
Expand Down
4 changes: 4 additions & 0 deletions jumanji/environments/combinatorial/cvrp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@

from typing import TYPE_CHECKING, NamedTuple

import jax.random

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
from dataclasses import dataclass
else:
from chex import dataclass

import chex
import jax.numpy as jnp
from chex import Array

Expand All @@ -42,6 +45,7 @@ class State:
visited_mask: Array # (problem_size + 1,)
order: Array # (2 * problem_size,) - this size is worst-case (visit depot after each node)
num_total_visits: jnp.int32
key: chex.PRNGKey = jax.random.PRNGKey(0)


class Observation(NamedTuple):
Expand Down
4 changes: 4 additions & 0 deletions jumanji/environments/combinatorial/knapsack/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

from typing import TYPE_CHECKING, NamedTuple

import chex
import jax.random

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
from dataclasses import dataclass
else:
Expand All @@ -40,6 +43,7 @@ class State:
used_mask: Array # (problem_size,)
num_steps: jnp.int32
remaining_budget: jnp.float32
key: chex.PRNGKey = jax.random.PRNGKey(0)


class Observation(NamedTuple):
Expand Down
4 changes: 4 additions & 0 deletions jumanji/environments/combinatorial/tsp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

from typing import TYPE_CHECKING, NamedTuple

import chex
import jax.random

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
from dataclasses import dataclass
else:
Expand All @@ -38,6 +41,7 @@ class State:
visited_mask: Array # (problem_size,)
order: Array # (problem_size,)
num_visited: jnp.int32
key: chex.PRNGKey = jax.random.PRNGKey(0)


class Observation(NamedTuple):
Expand Down
3 changes: 3 additions & 0 deletions jumanji/environments/games/connect4/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

from typing import TYPE_CHECKING

import chex
import jax.numpy as jnp
import jax.random
from chex import Array
from typing_extensions import TypeAlias

Expand All @@ -30,6 +32,7 @@
class State:
current_player: jnp.int8
board: Board
key: chex.PRNGKey = jax.random.PRNGKey(0)


@dataclass
Expand Down
3 changes: 1 addition & 2 deletions jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@
from jax import jit, random

from jumanji import specs, tree_utils
from jumanji.env import Environment
from jumanji.env import Environment, State
from jumanji.types import Action, TimeStep, restart, termination, transition

State = TypeVar("State")
Observation = TypeVar("Observation")

# Type alias that corresponds to ObsType in the Gym API
Expand Down