From 879f46d2909ba749af877a1377c45d4cc2731904 Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Fri, 11 Nov 2022 06:23:40 +0000 Subject: [PATCH 01/10] docs: updated docstrings on extras --- jumanji/types.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/jumanji/types.py b/jumanji/types.py index 8e6da53e3..07809bb78 100644 --- a/jumanji/types.py +++ b/jumanji/types.py @@ -72,8 +72,9 @@ class TimeStep(Generic[Observation]): also valid in place of a scalar array. extras: environment metrics or things to be seen by the agent but not directly observed (hence not in the observation) e.g. whether an invalid action was - taken, some environment metric, or the player whose turn it is. In most - environments, extras is None. + taken or some environment metric(s). In most environments, extras is None. + In particular, the extras should not contain any quantity that is meant to + be observed by the agent - such quantities should be in the observation. """ step_type: StepType @@ -103,8 +104,7 @@ def restart( observation: array or tree of arrays. extras: environment metrics or things to be seen by the agent but not directly observed (hence not in the observation) e.g. whether an invalid action was - taken, some environment metric, or the player whose turn it is. In most - environments, extras is None. + taken or some environment metric(s). In most environments, extras is None. shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. @@ -136,8 +136,7 @@ def transition( discount: array. extras: environment metrics or things to be seen by the agent but not directly observed (hence not in the observation) e.g. whether an invalid action was - taken, some environment metric, or the player whose turn it is. In most - environments, extras is None. + taken or some environment metric(s). In most environments, extras is None. shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. @@ -168,8 +167,7 @@ def termination( observation: array or tree of arrays. extras: environment metrics or things to be seen by the agent but not directly observed (hence not in the observation) e.g. whether an invalid action was - taken, some environment metric, or the player whose turn it is. In most - environments, extras is None. + taken or some environment metric(s). In most environments, extras is None. shape : optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. @@ -201,8 +199,7 @@ def truncation( discount: array. extras: environment metrics or things to be seen by the agent but not directly observed (hence not in the observation) e.g. whether an invalid action was - taken, some environment metric, or the player whose turn it is. In most - environments, extras is None. + taken or some environment metric(s). In most environments, extras is None. shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. 
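Illustration (not part of the patch series): the reworded docstring above draws a line between `extras`, which carries environment metrics or information the agent never observes, and the observation, which carries everything the agent is meant to see. Below is a minimal sketch of an environment step written in that style; the function name and the metric key are made up for illustration, and only the `transition(reward, observation, discount, extras, shape)` parameters documented in this hunk are assumed.

    import jax.numpy as jnp
    from jumanji.types import transition

    def step_sketch(next_observation, reward):
        # A logging-only quantity goes in `extras`; anything the agent should act
        # on (e.g. the current player id) belongs in the observation itself.
        return transition(
            reward=reward,
            observation=next_observation,
            discount=jnp.ones_like(reward),
            extras={"invalid_action_taken": jnp.asarray(False)},
        )
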
From db4dc776a08a3ed3d1d58e2d2328c5a6174f5ad6 Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Fri, 11 Nov 2022 06:55:55 +0000 Subject: [PATCH 02/10] feat: added current player field to observation --- jumanji/environments/games/connect4/env.py | 20 ++++++++++++-------- jumanji/environments/games/connect4/specs.py | 6 ++++++ jumanji/environments/games/connect4/types.py | 1 + 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/jumanji/environments/games/connect4/env.py b/jumanji/environments/games/connect4/env.py index 2f69d9f02..bf9fcef6a 100644 --- a/jumanji/environments/games/connect4/env.py +++ b/jumanji/environments/games/connect4/env.py @@ -83,10 +83,11 @@ def reset(self, key: PRNGKey) -> Tuple[State, TimeStep[Observation]]: board = jnp.zeros((BOARD_HEIGHT, BOARD_WIDTH), dtype=jnp.int8) action_mask = jnp.ones((BOARD_WIDTH,), dtype=jnp.int8) - obs = Observation(board=board, action_mask=action_mask) + obs = Observation( + board=board, action_mask=action_mask, current_player=jnp.int8(0) + ) - extras = {"current_player": jnp.array(0, dtype=jnp.int8)} - timestep = restart(observation=obs, shape=(self.n_players,), extras=extras) + timestep = restart(observation=obs, shape=(self.n_players,)) state = State(current_player=jnp.int8(0), board=board) @@ -101,8 +102,8 @@ def step(self, state: State, action: Action) -> Tuple[State, TimeStep[Observatio Returns: state: State object corresponding to the next state of the environment, - timestep: TimeStep object corresponding to the timestep returned by the environment, - its `extras` field contains the current player id. + timestep: TimeStep object corresponding to the timestep returned by the environment. + Its `observation` attribute contains a field for the current player id. """ board = state.board @@ -141,7 +142,11 @@ def step(self, state: State, action: Action) -> Tuple[State, TimeStep[Observatio # creating next state next_state = State(current_player=jnp.int8(next_player), board=new_board) - obs = Observation(board=new_board, action_mask=action_mask) + obs = Observation( + board=new_board, + action_mask=action_mask, + current_player=jnp.int8(next_player), + ) timestep = lax.cond( done, @@ -149,13 +154,11 @@ def step(self, state: State, action: Action) -> Tuple[State, TimeStep[Observatio reward=reward, observation=obs, shape=(self.n_players,), - extras={"current_player": next_player}, ), lambda _: transition( reward=reward, observation=obs, shape=(self.n_players,), - extras={"current_player": next_player}, ), operand=None, ) @@ -177,6 +180,7 @@ def observation_spec(self) -> ObservationSpec: maximum=1, name="invalid_mask", ), + current_player=specs.Array(shape=(), dtype=jnp.int8, name="current_player"), ) def action_spec(self) -> specs.DiscreteArray: diff --git a/jumanji/environments/games/connect4/specs.py b/jumanji/environments/games/connect4/specs.py index 4c5f9a98f..3ca2bc252 100644 --- a/jumanji/environments/games/connect4/specs.py +++ b/jumanji/environments/games/connect4/specs.py @@ -23,22 +23,26 @@ def __init__( self, board_obs: specs.Array, action_mask: specs.Array, + current_player: specs.Array, ): name = ( "Observation(\n" f"\tboard: {board_obs.name},\n" f"\taction_mask: {action_mask.name},\n" + f"\tcurrent_player: {current_player.name},\n" ")" ) super().__init__(name=name) self.board_obs = board_obs self.action_mask = action_mask + self.current_player = current_player def __repr__(self) -> str: return ( "ObservationSpec(\n" f"\tboard_obs={repr(self.board_obs)},\n" f"\taction_mask={repr(self.action_mask)},\n" + 
f"\tcurrent_player={repr(self.current_player)},\n" ")" ) @@ -47,6 +51,7 @@ def generate_value(self) -> Observation: return Observation( board=self.board_obs.generate_value(), action_mask=self.action_mask.generate_value(), + current_player=self.current_player.generate_value(), ) def validate(self, value: Observation) -> Observation: @@ -64,6 +69,7 @@ def validate(self, value: Observation) -> Observation: observation = Observation( board=self.board_obs.validate(value.board), action_mask=self.action_mask.validate(value.action_mask), + current_player=self.current_player.validate(value.current_player), ) return observation diff --git a/jumanji/environments/games/connect4/types.py b/jumanji/environments/games/connect4/types.py index ea73b2fb9..957d8f9dd 100644 --- a/jumanji/environments/games/connect4/types.py +++ b/jumanji/environments/games/connect4/types.py @@ -35,3 +35,4 @@ class State: class Observation: board: Board action_mask: Array + current_player: Array From 034eb125ee10049aca1a823f431d3a81c58dc2fa Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Fri, 11 Nov 2022 10:32:43 +0000 Subject: [PATCH 03/10] feat: made current player discrete array in specs --- jumanji/environments/games/connect4/env.py | 5 ++++- jumanji/environments/games/connect4/specs.py | 11 +++++++---- jumanji/environments/games/connect4/types.py | 3 ++- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/jumanji/environments/games/connect4/env.py b/jumanji/environments/games/connect4/env.py index bf9fcef6a..1c0d82d61 100644 --- a/jumanji/environments/games/connect4/env.py +++ b/jumanji/environments/games/connect4/env.py @@ -44,6 +44,7 @@ class Connect4(Environment[State]): - (-1) if it contains a token by the other player. - action_mask: jax array (bool) valid columns (actions) are identified with `True`, invalid ones with `False`. + - current_player: int, id of the current player {0, 1}. - action: Array containing the column to insert the token into {0, 1, 2, 3, 4, 5, 6} @@ -180,7 +181,9 @@ def observation_spec(self) -> ObservationSpec: maximum=1, name="invalid_mask", ), - current_player=specs.Array(shape=(), dtype=jnp.int8, name="current_player"), + current_player=specs.DiscreteArray( + num_values=self.n_players, dtype=jnp.int8, name="current_player" + ), ) def action_spec(self) -> specs.DiscreteArray: diff --git a/jumanji/environments/games/connect4/specs.py b/jumanji/environments/games/connect4/specs.py index 3ca2bc252..497f19c0b 100644 --- a/jumanji/environments/games/connect4/specs.py +++ b/jumanji/environments/games/connect4/specs.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any from jumanji import specs @@ -23,7 +22,7 @@ def __init__( self, board_obs: specs.Array, action_mask: specs.Array, - current_player: specs.Array, + current_player: specs.DiscreteArray, ): name = ( "Observation(\n" @@ -82,6 +81,10 @@ def replace(self, **kwargs: Any) -> "ObservationSpec": Returns: A new copy of `ObservationSpec`. 
""" - all_kwargs = {"board_obs": self.board_obs, "action_mask": self.action_mask} + all_kwargs = { + "board_obs": self.board_obs, + "action_mask": self.action_mask, + "current_player": self.current_player, + } all_kwargs.update(kwargs) - return ObservationSpec(**all_kwargs) + return ObservationSpec(**all_kwargs) # type: ignore diff --git a/jumanji/environments/games/connect4/types.py b/jumanji/environments/games/connect4/types.py index 957d8f9dd..bbbdccc46 100644 --- a/jumanji/environments/games/connect4/types.py +++ b/jumanji/environments/games/connect4/types.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING +import jax.numpy as jnp from chex import Array from typing_extensions import TypeAlias @@ -35,4 +36,4 @@ class State: class Observation: board: Board action_mask: Array - current_player: Array + current_player: jnp.int8 From 5c62e55a279b4f852f2156ace8c2385ce298a2ac Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Mon, 14 Nov 2022 10:31:02 +0000 Subject: [PATCH 04/10] refactor: cast to jnp int8 when next player is created --- jumanji/environments/games/connect4/env.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jumanji/environments/games/connect4/env.py b/jumanji/environments/games/connect4/env.py index 1c0d82d61..e391dd98e 100644 --- a/jumanji/environments/games/connect4/env.py +++ b/jumanji/environments/games/connect4/env.py @@ -129,7 +129,7 @@ def step(self, state: State, action: Action) -> Tuple[State, TimeStep[Observatio action_mask = get_action_mask(new_board) # switching player - next_player = (state.current_player + 1) % self.n_players + next_player = jnp.int8((state.current_player + 1) % self.n_players) # computing reward reward_value = compute_reward(invalid, winning) @@ -141,12 +141,12 @@ def step(self, state: State, action: Action) -> Tuple[State, TimeStep[Observatio reward = reward.at[next_player].set(-reward_value) # creating next state - next_state = State(current_player=jnp.int8(next_player), board=new_board) + next_state = State(current_player=next_player, board=new_board) obs = Observation( board=new_board, action_mask=action_mask, - current_player=jnp.int8(next_player), + current_player=next_player, ) timestep = lax.cond( From 4ffb7354c32a33c849b1f57d3d62a8b0200a2362 Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Mon, 14 Nov 2022 10:41:52 +0000 Subject: [PATCH 05/10] docs: updated description of extras --- jumanji/types.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/jumanji/types.py b/jumanji/types.py index 07809bb78..b5ad35ef4 100644 --- a/jumanji/types.py +++ b/jumanji/types.py @@ -70,11 +70,10 @@ class TimeStep(Generic[Observation]): observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. - extras: environment metrics or things to be seen by the agent but not directly - observed (hence not in the observation) e.g. whether an invalid action was - taken or some environment metric(s). In most environments, extras is None. - In particular, the extras should not contain any quantity that is meant to - be observed by the agent - such quantities should be in the observation. + extras: environment metric(s) or information returned by the environment but + not observed by the agent (hence not in the observation). For example, it + could be whether an invalid action was taken. In most environments, extras + is None. 
""" step_type: StepType @@ -102,9 +101,10 @@ def restart( Args: observation: array or tree of arrays. - extras: environment metrics or things to be seen by the agent but not directly - observed (hence not in the observation) e.g. whether an invalid action was - taken or some environment metric(s). In most environments, extras is None. + extras: environment metric(s) or information returned by the environment but + not observed by the agent (hence not in the observation). For example, it + could be whether an invalid action was taken. In most environments, extras + is None. shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. @@ -134,9 +134,10 @@ def transition( reward: array. observation: array or tree of arrays. discount: array. - extras: environment metrics or things to be seen by the agent but not directly - observed (hence not in the observation) e.g. whether an invalid action was - taken or some environment metric(s). In most environments, extras is None. + extras: environment metric(s) or information returned by the environment but + not observed by the agent (hence not in the observation). For example, it + could be whether an invalid action was taken. In most environments, extras + is None. shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. @@ -165,9 +166,10 @@ def termination( Args: reward: array. observation: array or tree of arrays. - extras: environment metrics or things to be seen by the agent but not directly - observed (hence not in the observation) e.g. whether an invalid action was - taken or some environment metric(s). In most environments, extras is None. + extras: environment metric(s) or information returned by the environment but + not observed by the agent (hence not in the observation). For example, it + could be whether an invalid action was taken. In most environments, extras + is None. shape : optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. @@ -197,9 +199,10 @@ def truncation( reward: array. observation: array or tree of arrays. discount: array. - extras: environment metrics or things to be seen by the agent but not directly - observed (hence not in the observation) e.g. whether an invalid action was - taken or some environment metric(s). In most environments, extras is None. + extras: environment metric(s) or information returned by the environment but + not observed by the agent (hence not in the observation). For example, it + could be whether an invalid action was taken. In most environments, extras + is None. shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. 
From 6cf626e05c4f86ac1fdbdb38ead86ad37865807b Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Mon, 14 Nov 2022 10:49:49 +0000 Subject: [PATCH 06/10] docs: updated docstring description of extras --- jumanji/environments/games/connect4/env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jumanji/environments/games/connect4/env.py b/jumanji/environments/games/connect4/env.py index e391dd98e..af0774910 100644 --- a/jumanji/environments/games/connect4/env.py +++ b/jumanji/environments/games/connect4/env.py @@ -77,8 +77,8 @@ def reset(self, key: PRNGKey) -> Tuple[State, TimeStep[Observation]]: Returns: state: State object corresponding to the new state of the environment, - timestep: TimeStep object corresponding to the first timestep returned by - the environment, its `extras` field contains the current player id. + timestep: TimeStep object corresponding to the first timestep returned by the + environment. Its `observation` attribute contains a field for the current player id. """ del key board = jnp.zeros((BOARD_HEIGHT, BOARD_WIDTH), dtype=jnp.int8) From b6e8b6edfb65697e6874bb3cfb5b927099113eef Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Mon, 14 Nov 2022 11:00:30 +0000 Subject: [PATCH 07/10] feat: removed extra from docstring --- jumanji/environments/combinatorial/tsp/env.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jumanji/environments/combinatorial/tsp/env.py b/jumanji/environments/combinatorial/tsp/env.py index 44827f92f..bc7804349 100644 --- a/jumanji/environments/combinatorial/tsp/env.py +++ b/jumanji/environments/combinatorial/tsp/env.py @@ -104,7 +104,6 @@ def reset_from_state( state: State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment. - extra: Not used. 
""" state = State( problem=problem, From 122b8e9084c5e5bcea6558333f1f8b84b3b2d340 Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Mon, 14 Nov 2022 15:32:02 +0000 Subject: [PATCH 08/10] feat: changed current player in connect4 state to be jnp int8 --- jumanji/environments/games/connect4/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/games/connect4/types.py b/jumanji/environments/games/connect4/types.py index bbbdccc46..4fc92a62e 100644 --- a/jumanji/environments/games/connect4/types.py +++ b/jumanji/environments/games/connect4/types.py @@ -28,7 +28,7 @@ @dataclass class State: - current_player: int + current_player: jnp.int8 board: Board From c7892f3e755376a61f3344b7f639b3f028d646a5 Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Mon, 14 Nov 2022 15:48:15 +0000 Subject: [PATCH 09/10] test: fixed failing connect4 test --- jumanji/environments/games/connect4/env_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/games/connect4/env_test.py b/jumanji/environments/games/connect4/env_test.py index 7c33e3e4f..67b6da22e 100644 --- a/jumanji/environments/games/connect4/env_test.py +++ b/jumanji/environments/games/connect4/env_test.py @@ -41,7 +41,7 @@ def test_connect4__reset(connect4_env: Connect4, empty_board: Array) -> None: assert isinstance(timestep, TimeStep) assert isinstance(state, State) assert state.current_player == 0 - assert timestep.extras["current_player"] == 0 # type: ignore + assert timestep.observation.current_player == 0 assert jnp.array_equal(state.board, empty_board) assert jnp.array_equal( timestep.observation.action_mask, jnp.ones((BOARD_WIDTH,), dtype=jnp.int8) From 541ef5a250e1f6be7b9947f6e9e84438ebbb9721 Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Mon, 14 Nov 2022 16:03:55 +0000 Subject: [PATCH 10/10] feat: updated type of current player in docstring --- jumanji/environments/games/connect4/env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jumanji/environments/games/connect4/env.py b/jumanji/environments/games/connect4/env.py index af0774910..28c4c94e9 100644 --- a/jumanji/environments/games/connect4/env.py +++ b/jumanji/environments/games/connect4/env.py @@ -44,7 +44,7 @@ class Connect4(Environment[State]): - (-1) if it contains a token by the other player. - action_mask: jax array (bool) valid columns (actions) are identified with `True`, invalid ones with `False`. - - current_player: int, id of the current player {0, 1}. + - current_player: jnp.int8, id of the current player {0, 1}. - action: Array containing the column to insert the token into {0, 1, 2, 3, 4, 5, 6} @@ -58,7 +58,7 @@ class Connect4(Environment[State]): - if a player plays an invalid move, this player loses and the game ends. - state: State - - current_player: int, id of the current player {0, 1}. + - current_player: jnp.int8, id of the current player {0, 1}. - board: jax array (int8) of shape (6, 7): each cell contains either: - 1 if it contains a token placed by the current player,