
refactor: improve consistency of extras #44

Merged
1 change: 0 additions & 1 deletion jumanji/environments/combinatorial/tsp/env.py
@@ -104,7 +104,6 @@ def reset_from_state(
            state: State object corresponding to the new state of the environment.
            timestep: TimeStep object corresponding to the first timestep returned by the
                environment.
-           extra: Not used.
        """
        state = State(
            problem=problem,
33 changes: 20 additions & 13 deletions jumanji/environments/games/connect4/env.py
@@ -44,6 +44,7 @@ class Connect4(Environment[State]):
            - (-1) if it contains a token by the other player.
        - action_mask: jax array (bool)
            valid columns (actions) are identified with `True`, invalid ones with `False`.
+       - current_player: jnp.int8, id of the current player {0, 1}.

    - action: Array containing the column to insert the token into {0, 1, 2, 3, 4, 5, 6}

@@ -57,7 +58,7 @@ class Connect4(Environment[State]):
        - if a player plays an invalid move, this player loses and the game ends.

    - state: State
-       - current_player: int, id of the current player {0, 1}.
+       - current_player: jnp.int8, id of the current player {0, 1}.
        - board: jax array (int8) of shape (6, 7):
            each cell contains either:
                - 1 if it contains a token placed by the current player,
@@ -76,17 +77,18 @@ def reset(self, key: PRNGKey) -> Tuple[State, TimeStep[Observation]]:

        Returns:
            state: State object corresponding to the new state of the environment,
-           timestep: TimeStep object corresponding to the first timestep returned by
-               the environment, its `extras` field contains the current player id.
+           timestep: TimeStep object corresponding to the first timestep returned by the
+               environment. Its `observation` attribute contains a field for the current player id.
        """
        del key
        board = jnp.zeros((BOARD_HEIGHT, BOARD_WIDTH), dtype=jnp.int8)
        action_mask = jnp.ones((BOARD_WIDTH,), dtype=jnp.int8)

-       obs = Observation(board=board, action_mask=action_mask)
+       obs = Observation(
+           board=board, action_mask=action_mask, current_player=jnp.int8(0)
+       )

-       extras = {"current_player": jnp.array(0, dtype=jnp.int8)}
-       timestep = restart(observation=obs, shape=(self.n_players,), extras=extras)
+       timestep = restart(observation=obs, shape=(self.n_players,))

        state = State(current_player=jnp.int8(0), board=board)

@@ -101,8 +103,8 @@ def step(self, state: State, action: Action) -> Tuple[State, TimeStep[Observation]]:

        Returns:
            state: State object corresponding to the next state of the environment,
-           timestep: TimeStep object corresponding to the timestep returned by the environment,
-               its `extras` field contains the current player id.
+           timestep: TimeStep object corresponding to the timestep returned by the environment.
+               Its `observation` attribute contains a field for the current player id.
        """
        board = state.board

@@ -127,7 +129,7 @@ def step(self, state: State, action: Action) -> Tuple[State, TimeStep[Observation]]:
        action_mask = get_action_mask(new_board)

        # switching player
-       next_player = (state.current_player + 1) % self.n_players
+       next_player = jnp.int8((state.current_player + 1) % self.n_players)

        # computing reward
        reward_value = compute_reward(invalid, winning)
@@ -139,23 +141,25 @@ def step(self, state: State, action: Action) -> Tuple[State, TimeStep[Observation]]:
        reward = reward.at[next_player].set(-reward_value)

        # creating next state
-       next_state = State(current_player=jnp.int8(next_player), board=new_board)
+       next_state = State(current_player=next_player, board=new_board)

-       obs = Observation(board=new_board, action_mask=action_mask)
+       obs = Observation(
+           board=new_board,
+           action_mask=action_mask,
+           current_player=next_player,
+       )

        timestep = lax.cond(
            done,
            lambda _: termination(
                reward=reward,
                observation=obs,
                shape=(self.n_players,),
-               extras={"current_player": next_player},
            ),
            lambda _: transition(
                reward=reward,
                observation=obs,
                shape=(self.n_players,),
-               extras={"current_player": next_player},
            ),
            operand=None,
        )
@@ -177,6 +181,9 @@ def observation_spec(self) -> ObservationSpec:
                maximum=1,
                name="invalid_mask",
            ),
+           current_player=specs.DiscreteArray(
+               num_values=self.n_players, dtype=jnp.int8, name="current_player"
+           ),
        )

    def action_spec(self) -> specs.DiscreteArray:
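With this change, the current player id is read from the observation rather than from `timestep.extras`. A minimal usage sketch of the post-change API (the import path mirrors the file layout above; passing the action as a plain int is an assumption):

```python
import jax

from jumanji.environments.games.connect4.env import Connect4

env = Connect4()
state, timestep = env.reset(jax.random.PRNGKey(0))

# The player id now travels with the observation...
assert timestep.observation.current_player == 0
# ...instead of the removed timestep.extras["current_player"].

state, timestep = env.step(state, action=3)  # drop a token into column 3
assert timestep.observation.current_player == 1  # turn passes to the other player
```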
2 changes: 1 addition & 1 deletion jumanji/environments/games/connect4/env_test.py
@@ -41,7 +41,7 @@ def test_connect4__reset(connect4_env: Connect4, empty_board: Array) -> None:
    assert isinstance(timestep, TimeStep)
    assert isinstance(state, State)
    assert state.current_player == 0
-   assert timestep.extras["current_player"] == 0  # type: ignore
+   assert timestep.observation.current_player == 0
    assert jnp.array_equal(state.board, empty_board)
    assert jnp.array_equal(
        timestep.observation.action_mask, jnp.ones((BOARD_WIDTH,), dtype=jnp.int8)
15 changes: 12 additions & 3 deletions jumanji/environments/games/connect4/specs.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Any

from jumanji import specs
@@ -23,22 +22,26 @@ def __init__(
        self,
        board_obs: specs.Array,
        action_mask: specs.Array,
+       current_player: specs.DiscreteArray,
    ):
        name = (
            "Observation(\n"
            f"\tboard: {board_obs.name},\n"
            f"\taction_mask: {action_mask.name},\n"
+           f"\tcurrent_player: {current_player.name},\n"
            ")"
        )
        super().__init__(name=name)
        self.board_obs = board_obs
        self.action_mask = action_mask
+       self.current_player = current_player

    def __repr__(self) -> str:
        return (
            "ObservationSpec(\n"
            f"\tboard_obs={repr(self.board_obs)},\n"
            f"\taction_mask={repr(self.action_mask)},\n"
+           f"\tcurrent_player={repr(self.current_player)},\n"
            ")"
        )

@@ -47,6 +50,7 @@ def generate_value(self) -> Observation:
        return Observation(
            board=self.board_obs.generate_value(),
            action_mask=self.action_mask.generate_value(),
+           current_player=self.current_player.generate_value(),
        )

    def validate(self, value: Observation) -> Observation:
@@ -64,6 +68,7 @@ def validate(self, value: Observation) -> Observation:
        observation = Observation(
            board=self.board_obs.validate(value.board),
            action_mask=self.action_mask.validate(value.action_mask),
+           current_player=self.current_player.validate(value.current_player),
        )
        return observation

@@ -76,6 +81,10 @@ def replace(self, **kwargs: Any) -> "ObservationSpec":
        Returns:
            A new copy of `ObservationSpec`.
        """
-       all_kwargs = {"board_obs": self.board_obs, "action_mask": self.action_mask}
+       all_kwargs = {
+           "board_obs": self.board_obs,
+           "action_mask": self.action_mask,
+           "current_player": self.current_player,
+       }
        all_kwargs.update(kwargs)
-       return ObservationSpec(**all_kwargs)
+       return ObservationSpec(**all_kwargs)  # type: ignore
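For reference, a sketch of how the extended spec composes the three sub-specs; the concrete sub-spec values below are illustrative assumptions mirroring `observation_spec` in the env diff above:

```python
import jax.numpy as jnp

from jumanji import specs
from jumanji.environments.games.connect4.specs import ObservationSpec

obs_spec = ObservationSpec(
    board_obs=specs.Array(shape=(6, 7), dtype=jnp.int8, name="board"),
    action_mask=specs.BoundedArray(
        shape=(7,), dtype=jnp.int8, minimum=0, maximum=1, name="invalid_mask"
    ),
    current_player=specs.DiscreteArray(
        num_values=2, dtype=jnp.int8, name="current_player"
    ),
)

# generate_value/validate now cover the new field as well.
obs = obs_spec.generate_value()
obs_spec.validate(obs)

# replace() copies the spec, overriding only the named sub-spec.
new_spec = obs_spec.replace(
    current_player=specs.DiscreteArray(num_values=2, dtype=jnp.int8, name="player_id")
)
```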
4 changes: 3 additions & 1 deletion jumanji/environments/games/connect4/types.py
@@ -14,6 +14,7 @@

from typing import TYPE_CHECKING

+import jax.numpy as jnp
from chex import Array
from typing_extensions import TypeAlias

@@ -27,11 +28,12 @@

@dataclass
class State:
-   current_player: int
+   current_player: jnp.int8
    board: Board


@dataclass
class Observation:
    board: Board
    action_mask: Array
+   current_player: jnp.int8
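A small illustrative construction of the extended types (board shape and mask values assumed from the environment above):

```python
import jax.numpy as jnp

from jumanji.environments.games.connect4.types import Observation, State

state = State(current_player=jnp.int8(0), board=jnp.zeros((6, 7), dtype=jnp.int8))
obs = Observation(
    board=state.board,
    action_mask=jnp.ones((7,), dtype=jnp.int8),
    current_player=state.current_player,  # the new agent-facing field
)
```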
40 changes: 20 additions & 20 deletions jumanji/types.py
@@ -70,10 +70,10 @@ class TimeStep(Generic[Observation]):
        observation: A NumPy array, or a nested dict, list or tuple of arrays.
            Scalar values that can be cast to NumPy arrays (e.g. Python floats) are
            also valid in place of a scalar array.
-       extras: environment metrics or things to be seen by the agent but not directly
-           observed (hence not in the observation) e.g. whether an invalid action was
-           taken, some environment metric, or the player whose turn it is. In most
-           environments, extras is None.
+       extras: environment metric(s) or information returned by the environment but
+           not observed by the agent (hence not in the observation). For example, it
+           could be whether an invalid action was taken. In most environments, extras
+           is None.
    """

    step_type: StepType
@@ -101,10 +101,10 @@ def restart(

    Args:
        observation: array or tree of arrays.
-       extras: environment metrics or things to be seen by the agent but not directly
-           observed (hence not in the observation) e.g. whether an invalid action was
-           taken, some environment metric, or the player whose turn it is. In most
-           environments, extras is None.
+       extras: environment metric(s) or information returned by the environment but
+           not observed by the agent (hence not in the observation). For example, it
+           could be whether an invalid action was taken. In most environments, extras
+           is None.
        shape: optional parameter to specify the shape of the rewards and discounts.
            Allows multi-agent environment compatibility. Defaults to () for
            scalar reward and discount.
@@ -134,10 +134,10 @@ def transition(
        reward: array.
        observation: array or tree of arrays.
        discount: array.
-       extras: environment metrics or things to be seen by the agent but not directly
-           observed (hence not in the observation) e.g. whether an invalid action was
-           taken, some environment metric, or the player whose turn it is. In most
-           environments, extras is None.
+       extras: environment metric(s) or information returned by the environment but
+           not observed by the agent (hence not in the observation). For example, it
+           could be whether an invalid action was taken. In most environments, extras
+           is None.
        shape: optional parameter to specify the shape of the rewards and discounts.
            Allows multi-agent environment compatibility. Defaults to () for
            scalar reward and discount.
@@ -166,10 +166,10 @@ def termination(
    Args:
        reward: array.
        observation: array or tree of arrays.
-       extras: environment metrics or things to be seen by the agent but not directly
-           observed (hence not in the observation) e.g. whether an invalid action was
-           taken, some environment metric, or the player whose turn it is. In most
-           environments, extras is None.
+       extras: environment metric(s) or information returned by the environment but
+           not observed by the agent (hence not in the observation). For example, it
+           could be whether an invalid action was taken. In most environments, extras
+           is None.
        shape : optional parameter to specify the shape of the rewards and discounts.
            Allows multi-agent environment compatibility. Defaults to () for
            scalar reward and discount.
@@ -199,10 +199,10 @@ def truncation(
        reward: array.
        observation: array or tree of arrays.
        discount: array.
-       extras: environment metrics or things to be seen by the agent but not directly
-           observed (hence not in the observation) e.g. whether an invalid action was
-           taken, some environment metric, or the player whose turn it is. In most
-           environments, extras is None.
+       extras: environment metric(s) or information returned by the environment but
+           not observed by the agent (hence not in the observation). For example, it
+           could be whether an invalid action was taken. In most environments, extras
+           is None.
        shape: optional parameter to specify the shape of the rewards and discounts.
            Allows multi-agent environment compatibility. Defaults to () for
            scalar reward and discount.
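Under the revised docstrings, agent-facing fields (like the current player) belong in the observation, while `extras` is reserved for diagnostics the agent never sees. A hedged sketch of that split using the factory functions above (the extras key is illustrative, not part of the API):

```python
import jax.numpy as jnp

from jumanji.types import restart, transition

obs = jnp.zeros((7,))  # any array or tree of arrays

# First timestep of an episode: optional diagnostics go in `extras`.
timestep = restart(observation=obs, extras={"invalid_action_taken": jnp.array(False)})

# Mid-episode timestep: reward and observation for the agent, diagnostics aside.
timestep = transition(
    reward=jnp.float32(0.0),
    observation=obs,
    extras={"invalid_action_taken": jnp.array(False)},
)
```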