Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use protocol to force all environment states to have a key #45

Merged
merged 13 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions jumanji/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,22 @@
"""Abstract environment class"""

import abc
from typing import Any, Generic, Tuple, TypeVar
from typing import Any, Generic, Protocol, Tuple, TypeVar

import jax.random
from chex import PRNGKey
dluo96 marked this conversation as resolved.
Show resolved Hide resolved

from jumanji import specs
from jumanji.types import Action, TimeStep

State = TypeVar("State")

class StateProtocol(Protocol):
"""Enforce that the State for every Environment must implement a key."""

key: jax.random.PRNGKey
dluo96 marked this conversation as resolved.
Show resolved Hide resolved


State = TypeVar("State", bound="StateProtocol")


class Environment(abc.ABC, Generic[State]):
Expand Down
3 changes: 3 additions & 0 deletions jumanji/environments/combinatorial/cvrp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from typing import TYPE_CHECKING, NamedTuple

import jax.random

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
from dataclasses import dataclass
else:
Expand Down Expand Up @@ -42,6 +44,7 @@ class State:
visited_mask: Array # (problem_size + 1,)
order: Array # (2 * problem_size,) - this size is worst-case (visit depot after each node)
num_total_visits: jnp.int32
key: jax.random.PRNGKey = jax.random.PRNGKey(0)
dluo96 marked this conversation as resolved.
Show resolved Hide resolved


class Observation(NamedTuple):
Expand Down
3 changes: 3 additions & 0 deletions jumanji/environments/combinatorial/knapsack/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from typing import TYPE_CHECKING, NamedTuple

import jax.random

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
from dataclasses import dataclass
else:
Expand All @@ -40,6 +42,7 @@ class State:
used_mask: Array # (problem_size,)
num_steps: jnp.int32
remaining_budget: jnp.float32
key: jax.random.PRNGKey = jax.random.PRNGKey(0)
dluo96 marked this conversation as resolved.
Show resolved Hide resolved


class Observation(NamedTuple):
Expand Down
3 changes: 3 additions & 0 deletions jumanji/environments/combinatorial/tsp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from typing import TYPE_CHECKING, NamedTuple

import jax.random

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
from dataclasses import dataclass
else:
Expand All @@ -38,6 +40,7 @@ class State:
visited_mask: Array # (problem_size,)
order: Array # (problem_size,)
num_visited: jnp.int32
key: jax.random.PRNGKey = jax.random.PRNGKey(0)
dluo96 marked this conversation as resolved.
Show resolved Hide resolved


class Observation(NamedTuple):
Expand Down
2 changes: 2 additions & 0 deletions jumanji/environments/games/connect4/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING

import jax.numpy as jnp
import jax.random
from chex import Array
from typing_extensions import TypeAlias

Expand All @@ -30,6 +31,7 @@
class State:
current_player: jnp.int8
board: Board
key: jax.random.PRNGKey = jax.random.PRNGKey(0)
dluo96 marked this conversation as resolved.
Show resolved Hide resolved


@dataclass
Expand Down
3 changes: 1 addition & 2 deletions jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@
from jax import jit, random

from jumanji import specs, tree_utils
from jumanji.env import Environment
from jumanji.env import Environment, State
from jumanji.types import Action, TimeStep, restart, termination, transition

State = TypeVar("State")
Observation = TypeVar("Observation")

# Type alias that corresponds to ObsType in the Gym API
Expand Down