diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index 433fdb5e0..7e8a0ed1d 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -19,13 +19,13 @@ import numpy as np from jax import Array, lax -EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = tuple(range(7)) # opponent: -1 * piece +EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = [jnp.int8(i) for i in range(7)] # opponent: -1 * piece MAX_TERMINATION_STEPS = 512 # from AlphaZero paper # prepare precomputed values here (e.g., available moves, map to label, etc.) # index: a1: 0, a2: 1, ..., h8: 63 -INIT_BOARD = jnp.int32([4, 1, 0, 0, 0, 0, -1, -4, 2, 1, 0, 0, 0, 0, -1, -2, 3, 1, 0, 0, 0, 0, -1, -3, 5, 1, 0, 0, 0, 0, -1, -5, 6, 1, 0, 0, 0, 0, -1, -6, 3, 1, 0, 0, 0, 0, -1, -3, 2, 1, 0, 0, 0, 0, -1, -2, 4, 1, 0, 0, 0, 0, -1, -4]) # fmt: skip +INIT_BOARD = jnp.int8([4, 1, 0, 0, 0, 0, -1, -4, 2, 1, 0, 0, 0, 0, -1, -2, 3, 1, 0, 0, 0, 0, -1, -3, 5, 1, 0, 0, 0, 0, -1, -5, 6, 1, 0, 0, 0, 0, -1, -6, 3, 1, 0, 0, 0, 0, -1, -3, 2, 1, 0, 0, 0, 0, -1, -2, 4, 1, 0, 0, 0, 0, -1, -4]) # fmt: skip # 8 7 15 23 31 39 47 55 63 # 7 6 14 22 30 38 46 54 62 # 6 5 13 21 29 37 45 53 61 @@ -56,7 +56,7 @@ # 39 11 62 # 38 10 64 # 37 9 64 -FROM_PLANE = -np.ones((64, 73), dtype=np.int32) +FROM_PLANE = -np.ones((64, 73), dtype=np.int8) TO_PLANE = -np.ones((64, 64), dtype=np.int32) # ignores underpromotion zeros, seq, rseq = [0] * 7, list(range(1, 8)), list(range(-7, 0)) # down, up, left, right, down-left, down-right, up-right, up-left, knight, and knight @@ -83,8 +83,8 @@ INIT_LEGAL_ACTION_MASK[ixs] = True LEGAL_DEST = -np.ones((7, 64, 27), np.int32) # LEGAL_DEST[0, :, :] == -1 -LEGAL_DEST_NEAR = -np.ones((64, 16), np.int32) # king and knight moves -LEGAL_DEST_FAR = -np.ones((64, 19), np.int32) # queen moves except king moves +LEGAL_DEST_NEAR = -np.ones((64, 16), np.int8) # king and knight moves +LEGAL_DEST_FAR = -np.ones((64, 19), np.int8) # queen moves except king moves CAN_MOVE = np.zeros((7, 64, 64), dtype=np.bool_) for from_ in range(64): legal_dest = {p: [] for p in range(7)} @@ -93,26 +93,26 @@ continue r0, c0, r1, c1 = from_ % 8, from_ // 8, to % 8, to // 8 if (r1 - r0 == 1 and abs(c1 - c0) <= 1) or ((r0, r1) == (1, 3) and abs(c1 - c0) == 0): - legal_dest[PAWN].append(to) + legal_dest[PAWN.item()].append(to) if (abs(r1 - r0) == 1 and abs(c1 - c0) == 2) or (abs(r1 - r0) == 2 and abs(c1 - c0) == 1): - legal_dest[KNIGHT].append(to) + legal_dest[KNIGHT.item()].append(to) if abs(r1 - r0) == abs(c1 - c0): - legal_dest[BISHOP].append(to) + legal_dest[BISHOP.item()].append(to) if abs(r1 - r0) == 0 or abs(c1 - c0) == 0: - legal_dest[ROOK].append(to) + legal_dest[ROOK.item()].append(to) if (abs(r1 - r0) == 0 or abs(c1 - c0) == 0) or (abs(r1 - r0) == abs(c1 - c0)): - legal_dest[QUEEN].append(to) + legal_dest[QUEEN.item()].append(to) if from_ != to and abs(r1 - r0) <= 1 and abs(c1 - c0) <= 1: - legal_dest[KING].append(to) + legal_dest[KING.item()].append(to) for p in range(1, 7): LEGAL_DEST[p, from_, : len(legal_dest[p])] = legal_dest[p] CAN_MOVE[p, from_, legal_dest[p]] = True - dests = list(set(legal_dest[KING]) | set(legal_dest[KNIGHT])) + dests = list(set(legal_dest[KING.item()]) | set(legal_dest[KNIGHT.item()])) LEGAL_DEST_NEAR[from_, : len(dests)] = dests - dests = list(set(legal_dest[QUEEN]).difference(set(legal_dest[KING]))) + dests = list(set(legal_dest[QUEEN.item()]).difference(set(legal_dest[KING.item()]))) LEGAL_DEST_FAR[from_, : len(dests)] = dests -BETWEEN = -np.ones((64, 64, 6), dtype=np.int32) +BETWEEN = -np.ones((64, 64, 6), dtype=np.int8) for from_ in range(64): for to in range(64): r0, c0, r1, c1 = from_ % 8, from_ // 8, to % 8, to // 8 @@ -138,31 +138,31 @@ class GameState(NamedTuple): - color: Array = jnp.int32(0) # w: 0, b: 1 + color: Array = jnp.int8(0) # w: 0, b: 1 board: Array = INIT_BOARD # (64,) castling_rights: Array = jnp.ones([2, 2], dtype=jnp.bool_) # my queen, my king, opp queen, opp king - en_passant: Array = jnp.int32(-1) - halfmove_count: Array = jnp.int32(0) # number of moves since the last piece capture or pawn move + en_passant: Array = jnp.int8(-1) + halfmove_count: Array = jnp.int8(0) # number of moves since the last piece capture or pawn move fullmove_count: Array = jnp.int32(1) # increase every black move hash_history: Array = jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32).at[0].set(INIT_ZOBRIST_HASH) - board_history: Array = jnp.zeros((8, 64), dtype=jnp.int32).at[0, :].set(INIT_BOARD) + board_history: Array = jnp.zeros((8, 64), dtype=jnp.int8).at[0, :].set(INIT_BOARD) legal_action_mask: Array = INIT_LEGAL_ACTION_MASK step_count: Array = jnp.int32(0) class Action(NamedTuple): - from_: Array = jnp.int32(-1) - to: Array = jnp.int32(-1) - underpromotion: Array = jnp.int32(-1) # 0: rook, 1: bishop, 2: knight + from_: Array = jnp.int8(-1) + to: Array = jnp.int8(-1) + underpromotion: Array = jnp.int8(-1) # 0: rook, 1: bishop, 2: knight @staticmethod def _from_label(label: Array): - from_, plane = label // 73, label % 73 - underpromotion = lax.select(plane >= 9, -1, plane // 3) + from_, plane = jnp.int8(label // 73), jnp.int8(label % 73) + underpromotion = lax.select(plane >= 9, jnp.int8(-1), plane // 3) return Action(from_=from_, to=FROM_PLANE[from_, plane], underpromotion=underpromotion) def _to_label(self): - return self.from_ * 73 + TO_PLANE[self.from_, self.to] + return jnp.int32(self.from_) * 73 + TO_PLANE[self.from_, self.to] class Game: @@ -264,14 +264,14 @@ def _apply_move(state: GameState, a: Action) -> GameState: is_en_passant = (state.en_passant >= 0) & (piece == PAWN) & (state.en_passant == a.to) removed_pawn_pos = a.to - 1 state = state._replace( - board=state.board.at[removed_pawn_pos].set(lax.select(is_en_passant, EMPTY, state.board[removed_pawn_pos])) + board=state.board.at[removed_pawn_pos].set(lax.select(is_en_passant, EMPTY, jnp.int8(state.board[removed_pawn_pos]))) ) is_en_passant = (piece == PAWN) & (jnp.abs(a.to - a.from_) == 2) - state = state._replace(en_passant=lax.select(is_en_passant, (a.to + a.from_) // 2, -1)) + state = state._replace(en_passant=lax.select(is_en_passant, (a.to + a.from_) // 2, jnp.int8(-1))) # update counters captured = (state.board[a.to] < 0) | is_en_passant state = state._replace( - halfmove_count=lax.select(captured | (piece == PAWN), 0, state.halfmove_count + 1), + halfmove_count=lax.select(captured | (piece == PAWN), jnp.int8(0), state.halfmove_count + 1), fullmove_count=state.fullmove_count + jnp.int32(state.color == 1), ) # castling @@ -285,9 +285,9 @@ def _apply_move(state: GameState, a: Action) -> GameState: cond = jnp.bool_([[(a.from_ != 32) & (a.from_ != 0), (a.from_ != 32) & (a.from_ != 56)], [a.to != 7, a.to != 63]]) state = state._replace(castling_rights=state.castling_rights & cond) # promotion to queen - piece = lax.select((piece == PAWN) & (a.from_ % 8 == 6) & (a.underpromotion < 0), QUEEN, piece) + piece = lax.select((piece == PAWN) & (a.from_ % 8 == 6) & (a.underpromotion < 0), QUEEN, jnp.int8(piece)) # underpromotion - piece = lax.select(a.underpromotion < 0, piece, jnp.int32([ROOK, BISHOP, KNIGHT])[a.underpromotion]) + piece = lax.select(a.underpromotion < 0, piece, jnp.int8([ROOK, BISHOP, KNIGHT])[a.underpromotion]) # actually move state = state._replace(board=state.board.at[a.from_].set(EMPTY).at[a.to].set(piece)) # type: ignore return state @@ -365,7 +365,7 @@ def legal_labels(label): can_castle_queen_side &= (b[0] == ROOK) & (b[8] == EMPTY) & (b[16] == EMPTY) & (b[24] == EMPTY) & (b[32] == KING) can_castle_king_side = state.castling_rights[0, 1] can_castle_king_side &= (b[32] == KING) & (b[40] == EMPTY) & (b[48] == EMPTY) & (b[56] == ROOK) - not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state, jnp.int32([16, 24, 32, 40, 48])) + not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state, jnp.int8([16, 24, 32, 40, 48])) mask = mask.at[2364].set(mask[2364] | (can_castle_queen_side & not_checked[:3].all())) mask = mask.at[2367].set(mask[2367] | (can_castle_king_side & not_checked[2:].all())) diff --git a/pgx/experimental/chess.py b/pgx/experimental/chess.py index 36313d069..1a2d3e1c7 100644 --- a/pgx/experimental/chess.py +++ b/pgx/experimental/chess.py @@ -52,25 +52,25 @@ def from_fen(fen: str): castling_rights = jnp.bool_([["Q" in castling, "K" in castling], ["q" in castling, "k" in castling]]) if color == "b": castling_rights = castling_rights[::-1] - mat = jnp.int32(arr).reshape(8, 8) + mat = jnp.int8(arr).reshape(8, 8) if color == "b": mat = -jnp.flip(mat, axis=0) - ep = jnp.int32(-1) if en_passant == "-" else jnp.int32("abcdefgh".index(en_passant[0]) * 8 + int(en_passant[1]) - 1) + ep = jnp.int8(-1) if en_passant == "-" else jnp.int8("abcdefgh".index(en_passant[0]) * 8 + int(en_passant[1]) - 1) if color == "b" and ep >= 0: ep = _flip_pos(ep) x = GameState( board=jnp.rot90(mat, k=3).flatten(), - color=jnp.int32(0) if color == "w" else jnp.int32(1), + color=jnp.int8(0) if color == "w" else jnp.int8(1), castling_rights=castling_rights, en_passant=ep, - halfmove_count=jnp.int32(halfmove_cnt), + halfmove_count=jnp.int8(halfmove_cnt), fullmove_count=jnp.int32(fullmove_cnt), ) legal_action_mask = jax.jit(_legal_action_mask)(x) x = x._replace(legal_action_mask=legal_action_mask) x = _update_history(x) - player_order = jnp.int32([0, 1]) + player_order = jnp.int8([0, 1]) state = State( _player_order=player_order, _x=x,