diff --git a/jumanji/environments/logic/game_2048/env.py b/jumanji/environments/logic/game_2048/env.py index 323995bdd..ba8115c16 100644 --- a/jumanji/environments/logic/game_2048/env.py +++ b/jumanji/environments/logic/game_2048/env.py @@ -23,12 +23,7 @@ from jumanji import specs from jumanji.env import Environment from jumanji.environments.logic.game_2048.types import Board, Observation, State -from jumanji.environments.logic.game_2048.utils import ( - move_down, - move_left, - move_right, - move_up, -) +from jumanji.environments.logic.game_2048.utils import can_move, move from jumanji.environments.logic.game_2048.viewer import Game2048Viewer from jumanji.types import TimeStep, restart, termination, transition from jumanji.viewer import Viewer @@ -181,11 +176,7 @@ def step( timestep: the next timestep. """ # Take the action in the environment: Up, Right, Down, Left. - updated_board, additional_reward = jax.lax.switch( - action, - [move_up, move_right, move_down, move_left], - state.board, - ) + updated_board, reward = move(state.board, action) # Generate new key. random_cell_key, new_state_key = jax.random.split(state.key) @@ -209,7 +200,7 @@ def step( action_mask=action_mask, step_count=state.step_count + 1, key=new_state_key, - score=state.score + additional_reward.astype(float), + score=state.score + reward, ) # Generate the observation from the environment state. @@ -227,12 +218,12 @@ def step( timestep = jax.lax.cond( done, lambda: termination( - reward=additional_reward, + reward=reward, observation=observation, extras=extras, ), lambda: transition( - reward=additional_reward, + reward=reward, observation=observation, extras=extras, ), @@ -303,15 +294,7 @@ def _get_action_mask(self, board: Board) -> chex.Array: Returns: action_mask: action mask for the current state of the environment. """ - action_mask = jnp.array( - [ - jnp.any(move_up(board, final_shift=False)[0] != board), - jnp.any(move_right(board, final_shift=False)[0] != board), - jnp.any(move_down(board, final_shift=False)[0] != board), - jnp.any(move_left(board, final_shift=False)[0] != board), - ], - ) - return action_mask + return jax.vmap(can_move, (None, 0))(board, jnp.arange(4)) def render(self, state: State) -> Optional[NDArray]: """Renders the current state of the game board. diff --git a/jumanji/environments/logic/game_2048/utils.py b/jumanji/environments/logic/game_2048/utils.py index ae17c3e66..0cea3389a 100644 --- a/jumanji/environments/logic/game_2048/utils.py +++ b/jumanji/environments/logic/game_2048/utils.py @@ -12,229 +12,244 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -from typing import Tuple +from typing import NamedTuple, Tuple +import chex import jax import jax.numpy as jnp -from jax.numpy import DeviceArray from jumanji.environments.logic.game_2048.types import Board -def shift_nonzero_element(carry: Tuple) -> Tuple[DeviceArray, int]: - """Shift nonzero element from index i to index j and increment j. - For example, in the case of this column [2, 0, 2, 0], this method will be invoked - when `i` equals 0 and 2, and it will return successively ([2, 0, 2, 0], `j` = 1) - and ([2, 2, 2, 0], `j` = 2). - - Args: - carry: - col: a column of the board. - i: the current index. - j: the index of the nonzero element. It also represents the number of nonzero - elements that have been shifted so far. - - Returns: - A tuple containing the updated array (col) and the incremented target index (j). - """ - col, j, i = carry - col = col.at[j].set(col[i]) - j += 1 - return col, j - - -def shift_column_elements_up(carry: Tuple, i: int) -> Tuple[DeviceArray, None]: - """This method calls `shift_nonzero_element` to shift non-zero elements in the column, - and conducts the identity operation if the element is zero. - - Agrs: - carry: - col: a one-dimensional array representing a column of the board. - j: the index of the non zero element. It also represents the number of non-zero - elements that have been shifted so far. - i: the current index. - - Returns: - A tuple containing the updated column and None. - """ - col, j = carry - col, j = jax.lax.cond( - col[i] != 0, - shift_nonzero_element, - lambda col_j_i: col_j_i[:2], - (col, j, i), +def transform_board(board: Board, action: int) -> Board: + """Transform board so that move_left is analagous to move_action. Also, transform back.""" + return jax.lax.switch( + action, + [ + lambda: jnp.transpose(board), + lambda: jnp.flip(board, 1), + lambda: jnp.flip(jnp.transpose(board)), + lambda: board, + ], ) - return (col, j), None - - -def fill_with_zero(carry: Tuple[DeviceArray, int]) -> Tuple[DeviceArray, int]: - """Fill the remaining elements of the column with zeros after shifting non-zero elements to the up. - For example: if the initial column is [2, 0, 2, 0] then this method will be invoked when `j` - equals to 2 and 3. - - Args: - carry: - col: a column of the board. - j: the index of the nonzero element. It also represents the number of nonzero - elements that have been shifted so far. - - Returns: - A tuple containing the updated column and incremented index. - """ - col, j = carry - col = col.at[j].set(0) - j += 1 - return col, j - - -def shift_up(col: DeviceArray) -> DeviceArray: - """Shift all the elements in a column up. - For example: [2, 0, 2, 0] -> [2, 2, 0, 0] - - Args: - col: a column of the board. - - Returns: - The modified column with all the elements shifted up. - """ - j = 0 - (col, j), _ = jax.lax.scan( # In example: [2, 0, 2, 0] -> [2, 2, 2, 0] - f=shift_column_elements_up, init=(col, j), xs=jnp.arange(len(col)) + + +class CanMoveCarry(NamedTuple): + """Carry value for while loop in can_move_left_row.""" + + can_move: bool + row: chex.Array + target_idx: int + origin_idx: int + + @property + def target(self) -> chex.Numeric: + """Tile at target index of row.""" + return self.row[self.target_idx] + + @property + def origin(self) -> chex.Numeric: + """Tile at origin index of row.""" + return self.row[self.origin_idx] + + +def can_move_left_row_cond(carry: CanMoveCarry) -> chex.Numeric: + """Terminate loop when valid move is found or origin reaches end of row.""" + return ~carry.can_move & (carry.origin_idx < carry.row.shape[0]) + + +def can_move_left_row_body(carry: CanMoveCarry) -> CanMoveCarry: + """Check if the current tiles can move and increment the indices.""" + # Check if tiles can move + can_move = (carry.origin != 0) & ( + (carry.target == 0) | (carry.target == carry.origin) ) - col, j = jax.lax.while_loop( # In example: [2, 2, 2, 0] -> [2, 2, 0, 0] - lambda col_j: col_j[1] < len(col_j[0]), - fill_with_zero, - (col, j), + + # Increment indices as if performed a no op + # If not performing no op, loop will be terminated anyways + target_idx = carry.target_idx + (carry.origin != 0) + origin_idx = jax.lax.select( + (carry.origin == 0) | (target_idx == carry.origin_idx), + carry.origin_idx + 1, + carry.origin_idx, ) - return col - - -def merge_elements(carry: Tuple) -> Tuple[DeviceArray, float]: - """Merge two adjacent elements in a column. - For example: col = [1, 1, 2, 2] and i = 2 -> [1, 1, 3, 0], with a reward equal to 2³. - - Args: - carry: a tuple containing the current state of the column, the current index, - and the current reward. - - Returns: - A tuple containing the modified column, and the updated reward. - """ - col, reward, i = carry - new_col_i = col[i] + 1 - col = col.at[i].set(new_col_i) - col = col.at[i + 1].set(0) - reward += 2**new_col_i - return col, reward - - -def merge_equal_elements( - carry: Tuple[DeviceArray, float], i: int -) -> Tuple[Tuple[DeviceArray, float], None]: - """This function merges adjacent non-zero elements in the column of the board, if the - two adjacent elements are equal. - This function will examine each element individually to locate two adjacent equal elements. - For example in the case of [1, 1, 2, 2], this method will call `merge_elements` for `i` equals - to 0 and 2. - - Args: - carry: a tuple containing the current state of the column, and the current reward. - i: the current index. - - Returns: - Tuple containing the updated column and the reward. - """ - col, reward = carry - col, reward = jax.lax.cond( - ((col[i] != 0) & (col[i] == col[i + 1])), - merge_elements, - lambda col_reward_i: col_reward_i[:2], - (col, reward, i), + + # Return updated carry + return carry._replace( + can_move=can_move, target_idx=target_idx, origin_idx=origin_idx ) - return (col, reward), None -def merge_col(col: DeviceArray) -> Tuple[DeviceArray, float]: - """Merge the elements of a column according to the rules of the 2048 game. - For example: [0, 0, 2, 2] -> [0, 0, 3, 0] with a reward equal to 2³. +def can_move_left_row(row: chex.Array) -> bool: + """Check if row can move left.""" + carry = CanMoveCarry(can_move=False, row=row, target_idx=0, origin_idx=1) + can_move: bool = jax.lax.while_loop( + can_move_left_row_cond, can_move_left_row_body, carry + )[0] + return can_move + + +def can_move_left(board: Board) -> bool: + """Check if board can move left.""" + can_move: bool = jax.vmap(can_move_left_row)(board).any() + return can_move + + +def can_move(board: Board, action: int) -> bool: + """Check if board can move with action.""" + return can_move_left(transform_board(board, action)) + + +def can_move_up(board: Board) -> bool: + """Check if board can move up.""" + return can_move(board, 0) + + +def can_move_right(board: Board) -> bool: + """Check if board can move right.""" + return can_move(board, 1) + + +def can_move_down(board: Board) -> bool: + """Check if board can move down.""" + return can_move(board, 2) + + +class MoveUpdate(NamedTuple): + """Update to move carry.""" + + target: chex.Numeric + origin: chex.Numeric + additional_reward: float + target_idx: int + origin_idx: int - Args: - col: a column of the board. - Returns: - A tuple containing the modified column and the total reward obtained by - merging the elements. - """ - reward = 0.0 - elements_indices = jnp.arange(len(col) - 1) - (col, reward), _ = jax.lax.scan( - f=merge_equal_elements, init=(col, reward), xs=elements_indices +class MoveCarry(NamedTuple): + """Carry value for while loop in move_left_row.""" + + row: chex.Array + reward: float + target_idx: int + origin_idx: int + + @property + def target(self) -> chex.Numeric: + """Tile at target index of row.""" + return self.row[self.target_idx] + + @property + def origin(self) -> chex.Numeric: + """Tile at origin index of row.""" + return self.row[self.origin_idx] + + def update(self, update: MoveUpdate) -> "MoveCarry": + """Return new updated carry. This method will cause row to be copied when called within a + jax conditional primative such as `jax.lax.cond` or `jax.lax.switch`. + """ + # Update row + row = self.row + row = row.at[self.target_idx].set(update.target) + row = row.at[self.origin_idx].set(update.origin) + + # Return updated carry + return self._replace( + row=row, + reward=self.reward + update.additional_reward, + target_idx=update.target_idx, + origin_idx=update.origin_idx, + ) + + +def no_op(carry: MoveCarry) -> MoveUpdate: + """Return a move update equivalent to performing a no op.""" + target_idx = carry.target_idx + (carry.origin != 0) + origin_idx = jax.lax.select( + (carry.origin == 0) | (target_idx == carry.origin_idx), + carry.origin_idx + 1, + carry.origin_idx, ) - return col, reward - - -def move_up_col( - carry: Tuple[Board, float], c: int, final_shift: bool = True -) -> Tuple[Tuple[Board, float], None]: - """Move the elements in the specified column up and merge those that are equal in - a single pass. `final_shift` is not needed when computing the action mask - this is - because creating the action mask only requires knowledge of whether the board will - have changed as a result of the action. - - For example: [2, 2, 1, 1] -> [3, 2, 0, 0]. - - Args: - carry: tuple containing the board and the additional reward. - c: column index to perform the move and merge on. - final_shift: is a flag to determine if the column should be shifted up once or - twice. In the "get_action_mask" method, it is set to False, as the purpose is - to check if the action is allowed and one shift is enough for this determination. - - Returns: - Tuple containing the updated board and the additional reward. - """ - board, additional_reward = carry - col = board[:, c] - col = shift_up(col) # In example: [2, 2, 1, 1] -> [2, 2, 1, 1] - col, reward = merge_col(col) # In example: [2, 2, 1, 1] -> [3, 0, 2, 0] - if final_shift: - col = shift_up(col) # In example: [3, 0, 2, 0] -> [3, 2, 0, 0] - additional_reward += reward - return (board.at[:, c].set(col), additional_reward), None - - -def move_up(board: Board, final_shift: bool = True) -> Tuple[Board, float]: - """Move up.""" - additional_reward = 0.0 - col_indices = jnp.arange(board.shape[0]) # Board of size 4 -> [0, 1, 2, 3] - (board, additional_reward), _ = jax.lax.scan( - f=functools.partial(move_up_col, final_shift=final_shift), - init=(board, additional_reward), - xs=col_indices, + return MoveUpdate( + target=carry.target, + origin=carry.origin, + additional_reward=0.0, + target_idx=target_idx, + origin_idx=origin_idx, ) - return board, additional_reward -def move_down(board: Board, final_shift: bool = True) -> Tuple[Board, float]: - """Move down.""" - board, additional_reward = move_up( - board=jnp.flip(board, 0), final_shift=final_shift +def shift(carry: MoveCarry) -> MoveUpdate: + """Return a move update equivalent to shifting origin to target.""" + return MoveUpdate( + target=carry.origin, + origin=0, + additional_reward=0.0, + target_idx=carry.target_idx, + origin_idx=carry.origin_idx + 1, ) - return jnp.flip(board, 0), additional_reward -def move_left(board: Board, final_shift: bool = True) -> Tuple[Board, float]: - """Move left.""" - board, additional_reward = move_up( - board=jnp.rot90(board, k=-1), final_shift=final_shift +def merge(carry: MoveCarry) -> MoveUpdate: + """Return a move update equivalent to merging origin with target.""" + return MoveUpdate( + target=carry.target + 1, + origin=0, + additional_reward=2.0 ** (carry.target + 1), + target_idx=carry.target_idx + 1, + origin_idx=carry.origin_idx + 1, ) - return jnp.rot90(board, k=1), additional_reward -def move_right(board: Board, final_shift: bool = True) -> Tuple[Board, float]: - """Move right.""" - board, additional_reward = move_up( - board=jnp.rot90(board, k=1), final_shift=final_shift - ) - return jnp.rot90(board, k=-1), additional_reward +def move_left_row_cond(carry: MoveCarry) -> chex.Numeric: + """Terminate loop when origin reaches end of row.""" + return carry.origin_idx < carry.row.shape[0] + + +def move_left_row_body(carry: MoveCarry) -> MoveCarry: + """Move the current tiles and increment the indices.""" + # Determine move type + can_shift = (carry.origin != 0) & (carry.target == 0) + can_merge = (carry.origin != 0) & (carry.target == carry.origin) + move_type = can_shift.astype(int) + 2 * can_merge.astype(int) + + # Get update based on move type + update = jax.lax.switch(move_type, [no_op, shift, merge], carry) + + # Return updated carry + return carry.update(update) + + +def move_left_row(row: chex.Array) -> Tuple[chex.Array, float]: + """Move the row left.""" + carry = MoveCarry(row=row, reward=0.0, target_idx=0, origin_idx=1) + row, reward, *_ = jax.lax.while_loop(move_left_row_cond, move_left_row_body, carry) + return row, reward + + +def move_left(board: Board) -> Tuple[Board, float]: + """Move the board left.""" + board, reward = jax.vmap(move_left_row)(board) + return board, reward.sum() + + +def move(board: Board, action: int) -> Tuple[Board, float]: + """Move the board with action.""" + board = transform_board(board, action) + board, reward = move_left(board) + board = transform_board(board, action) + return board, reward + + +def move_up(board: Board) -> Tuple[Board, float]: + """Move the board up.""" + return move(board, 0) + + +def move_right(board: Board) -> Tuple[Board, float]: + """Move the board right.""" + return move(board, 1) + + +def move_down(board: Board) -> Tuple[Board, float]: + """Move the board down.""" + return move(board, 2) diff --git a/jumanji/environments/logic/game_2048/utils_test.py b/jumanji/environments/logic/game_2048/utils_test.py index 0e27c8dcc..5d38983b9 100644 --- a/jumanji/environments/logic/game_2048/utils_test.py +++ b/jumanji/environments/logic/game_2048/utils_test.py @@ -17,6 +17,10 @@ from jumanji.environments.logic.game_2048.types import Board from jumanji.environments.logic.game_2048.utils import ( + can_move_down, + can_move_left, + can_move_right, + can_move_up, move_down, move_left, move_right, @@ -72,6 +76,38 @@ def board8x8() -> Board: return board +def test_can_move_down(board: Board, another_board: Board) -> None: + """Test checking if the board can move down.""" + assert can_move_down(board) + assert can_move_down(another_board) + board = jnp.array([[0, 0, 0, 0], [1, 0, 0, 0], [2, 1, 0, 0], [3, 2, 1, 0]]) + assert ~can_move_down(board) + + +def test_can_move_up(board: Board, another_board: Board) -> None: + """Test checking if the board can move up.""" + assert can_move_up(board) + assert can_move_up(another_board) + board = jnp.array([[4, 2, 1, 0], [3, 1, 0, 0], [2, 0, 0, 0], [1, 0, 0, 0]]) + assert ~can_move_up(board) + + +def test_can_move_right(board: Board, another_board: Board) -> None: + """Test checking if the board can move right.""" + assert can_move_right(board) + assert can_move_right(another_board) + board = jnp.array([[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 2], [0, 1, 2, 3]]) + assert ~can_move_right(board) + + +def test_can_move_left(board: Board, another_board: Board) -> None: + """Test checking if the board can move left.""" + assert can_move_left(board) + assert can_move_left(another_board) + board = jnp.array([[1, 2, 3, 4], [1, 2, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]]) + assert ~can_move_left(board) + + def test_move_down(board: Board, another_board: Board) -> None: """Test shifting the board cells down.""" # First example.