Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Chess] Use int8 #1259

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 30 additions & 30 deletions pgx/_src/games/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
import numpy as np
from jax import Array, lax

EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = tuple(range(7)) # opponent: -1 * piece
EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = [jnp.int8(i) for i in range(7)] # opponent: -1 * piece
MAX_TERMINATION_STEPS = 512 # from AlphaZero paper

# prepare precomputed values here (e.g., available moves, map to label, etc.)

# index: a1: 0, a2: 1, ..., h8: 63
INIT_BOARD = jnp.int32([4, 1, 0, 0, 0, 0, -1, -4, 2, 1, 0, 0, 0, 0, -1, -2, 3, 1, 0, 0, 0, 0, -1, -3, 5, 1, 0, 0, 0, 0, -1, -5, 6, 1, 0, 0, 0, 0, -1, -6, 3, 1, 0, 0, 0, 0, -1, -3, 2, 1, 0, 0, 0, 0, -1, -2, 4, 1, 0, 0, 0, 0, -1, -4]) # fmt: skip
INIT_BOARD = jnp.int8([4, 1, 0, 0, 0, 0, -1, -4, 2, 1, 0, 0, 0, 0, -1, -2, 3, 1, 0, 0, 0, 0, -1, -3, 5, 1, 0, 0, 0, 0, -1, -5, 6, 1, 0, 0, 0, 0, -1, -6, 3, 1, 0, 0, 0, 0, -1, -3, 2, 1, 0, 0, 0, 0, -1, -2, 4, 1, 0, 0, 0, 0, -1, -4]) # fmt: skip
# 8 7 15 23 31 39 47 55 63
# 7 6 14 22 30 38 46 54 62
# 6 5 13 21 29 37 45 53 61
Expand Down Expand Up @@ -56,7 +56,7 @@
# 39 11 62
# 38 10 64
# 37 9 64
FROM_PLANE = -np.ones((64, 73), dtype=np.int32)
FROM_PLANE = -np.ones((64, 73), dtype=np.int8)
TO_PLANE = -np.ones((64, 64), dtype=np.int32) # ignores underpromotion
zeros, seq, rseq = [0] * 7, list(range(1, 8)), list(range(-7, 0))
# down, up, left, right, down-left, down-right, up-right, up-left, knight, and knight
Expand All @@ -83,8 +83,8 @@
INIT_LEGAL_ACTION_MASK[ixs] = True

LEGAL_DEST = -np.ones((7, 64, 27), np.int32) # LEGAL_DEST[0, :, :] == -1
LEGAL_DEST_NEAR = -np.ones((64, 16), np.int32) # king and knight moves
LEGAL_DEST_FAR = -np.ones((64, 19), np.int32) # queen moves except king moves
LEGAL_DEST_NEAR = -np.ones((64, 16), np.int8) # king and knight moves
LEGAL_DEST_FAR = -np.ones((64, 19), np.int8) # queen moves except king moves
CAN_MOVE = np.zeros((7, 64, 64), dtype=np.bool_)
for from_ in range(64):
legal_dest = {p: [] for p in range(7)}
Expand All @@ -93,26 +93,26 @@
continue
r0, c0, r1, c1 = from_ % 8, from_ // 8, to % 8, to // 8
if (r1 - r0 == 1 and abs(c1 - c0) <= 1) or ((r0, r1) == (1, 3) and abs(c1 - c0) == 0):
legal_dest[PAWN].append(to)
legal_dest[PAWN.item()].append(to)
if (abs(r1 - r0) == 1 and abs(c1 - c0) == 2) or (abs(r1 - r0) == 2 and abs(c1 - c0) == 1):
legal_dest[KNIGHT].append(to)
legal_dest[KNIGHT.item()].append(to)
if abs(r1 - r0) == abs(c1 - c0):
legal_dest[BISHOP].append(to)
legal_dest[BISHOP.item()].append(to)
if abs(r1 - r0) == 0 or abs(c1 - c0) == 0:
legal_dest[ROOK].append(to)
legal_dest[ROOK.item()].append(to)
if (abs(r1 - r0) == 0 or abs(c1 - c0) == 0) or (abs(r1 - r0) == abs(c1 - c0)):
legal_dest[QUEEN].append(to)
legal_dest[QUEEN.item()].append(to)
if from_ != to and abs(r1 - r0) <= 1 and abs(c1 - c0) <= 1:
legal_dest[KING].append(to)
legal_dest[KING.item()].append(to)
for p in range(1, 7):
LEGAL_DEST[p, from_, : len(legal_dest[p])] = legal_dest[p]
CAN_MOVE[p, from_, legal_dest[p]] = True
dests = list(set(legal_dest[KING]) | set(legal_dest[KNIGHT]))
dests = list(set(legal_dest[KING.item()]) | set(legal_dest[KNIGHT.item()]))
LEGAL_DEST_NEAR[from_, : len(dests)] = dests
dests = list(set(legal_dest[QUEEN]).difference(set(legal_dest[KING])))
dests = list(set(legal_dest[QUEEN.item()]).difference(set(legal_dest[KING.item()])))
LEGAL_DEST_FAR[from_, : len(dests)] = dests

BETWEEN = -np.ones((64, 64, 6), dtype=np.int32)
BETWEEN = -np.ones((64, 64, 6), dtype=np.int8)
for from_ in range(64):
for to in range(64):
r0, c0, r1, c1 = from_ % 8, from_ // 8, to % 8, to // 8
Expand All @@ -138,31 +138,31 @@


class GameState(NamedTuple):
color: Array = jnp.int32(0) # w: 0, b: 1
color: Array = jnp.int8(0) # w: 0, b: 1
board: Array = INIT_BOARD # (64,)
castling_rights: Array = jnp.ones([2, 2], dtype=jnp.bool_) # my queen, my king, opp queen, opp king
en_passant: Array = jnp.int32(-1)
halfmove_count: Array = jnp.int32(0) # number of moves since the last piece capture or pawn move
en_passant: Array = jnp.int8(-1)
halfmove_count: Array = jnp.int8(0) # number of moves since the last piece capture or pawn move
fullmove_count: Array = jnp.int32(1) # increase every black move
hash_history: Array = jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32).at[0].set(INIT_ZOBRIST_HASH)
board_history: Array = jnp.zeros((8, 64), dtype=jnp.int32).at[0, :].set(INIT_BOARD)
board_history: Array = jnp.zeros((8, 64), dtype=jnp.int8).at[0, :].set(INIT_BOARD)
legal_action_mask: Array = INIT_LEGAL_ACTION_MASK
step_count: Array = jnp.int32(0)


class Action(NamedTuple):
from_: Array = jnp.int32(-1)
to: Array = jnp.int32(-1)
underpromotion: Array = jnp.int32(-1) # 0: rook, 1: bishop, 2: knight
from_: Array = jnp.int8(-1)
to: Array = jnp.int8(-1)
underpromotion: Array = jnp.int8(-1) # 0: rook, 1: bishop, 2: knight

@staticmethod
def _from_label(label: Array):
from_, plane = label // 73, label % 73
underpromotion = lax.select(plane >= 9, -1, plane // 3)
from_, plane = jnp.int8(label // 73), jnp.int8(label % 73)
underpromotion = lax.select(plane >= 9, jnp.int8(-1), plane // 3)
return Action(from_=from_, to=FROM_PLANE[from_, plane], underpromotion=underpromotion)

def _to_label(self):
return self.from_ * 73 + TO_PLANE[self.from_, self.to]
return jnp.int32(self.from_) * 73 + TO_PLANE[self.from_, self.to]


class Game:
Expand Down Expand Up @@ -264,14 +264,14 @@ def _apply_move(state: GameState, a: Action) -> GameState:
is_en_passant = (state.en_passant >= 0) & (piece == PAWN) & (state.en_passant == a.to)
removed_pawn_pos = a.to - 1
state = state._replace(
board=state.board.at[removed_pawn_pos].set(lax.select(is_en_passant, EMPTY, state.board[removed_pawn_pos]))
board=state.board.at[removed_pawn_pos].set(lax.select(is_en_passant, EMPTY, jnp.int8(state.board[removed_pawn_pos])))
)
is_en_passant = (piece == PAWN) & (jnp.abs(a.to - a.from_) == 2)
state = state._replace(en_passant=lax.select(is_en_passant, (a.to + a.from_) // 2, -1))
state = state._replace(en_passant=lax.select(is_en_passant, (a.to + a.from_) // 2, jnp.int8(-1)))
# update counters
captured = (state.board[a.to] < 0) | is_en_passant
state = state._replace(
halfmove_count=lax.select(captured | (piece == PAWN), 0, state.halfmove_count + 1),
halfmove_count=lax.select(captured | (piece == PAWN), jnp.int8(0), state.halfmove_count + 1),
fullmove_count=state.fullmove_count + jnp.int32(state.color == 1),
)
# castling
Expand All @@ -285,9 +285,9 @@ def _apply_move(state: GameState, a: Action) -> GameState:
cond = jnp.bool_([[(a.from_ != 32) & (a.from_ != 0), (a.from_ != 32) & (a.from_ != 56)], [a.to != 7, a.to != 63]])
state = state._replace(castling_rights=state.castling_rights & cond)
# promotion to queen
piece = lax.select((piece == PAWN) & (a.from_ % 8 == 6) & (a.underpromotion < 0), QUEEN, piece)
piece = lax.select((piece == PAWN) & (a.from_ % 8 == 6) & (a.underpromotion < 0), QUEEN, jnp.int8(piece))
# underpromotion
piece = lax.select(a.underpromotion < 0, piece, jnp.int32([ROOK, BISHOP, KNIGHT])[a.underpromotion])
piece = lax.select(a.underpromotion < 0, piece, jnp.int8([ROOK, BISHOP, KNIGHT])[a.underpromotion])
# actually move
state = state._replace(board=state.board.at[a.from_].set(EMPTY).at[a.to].set(piece)) # type: ignore
return state
Expand Down Expand Up @@ -365,7 +365,7 @@ def legal_labels(label):
can_castle_queen_side &= (b[0] == ROOK) & (b[8] == EMPTY) & (b[16] == EMPTY) & (b[24] == EMPTY) & (b[32] == KING)
can_castle_king_side = state.castling_rights[0, 1]
can_castle_king_side &= (b[32] == KING) & (b[40] == EMPTY) & (b[48] == EMPTY) & (b[56] == ROOK)
not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state, jnp.int32([16, 24, 32, 40, 48]))
not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state, jnp.int8([16, 24, 32, 40, 48]))
mask = mask.at[2364].set(mask[2364] | (can_castle_queen_side & not_checked[:3].all()))
mask = mask.at[2367].set(mask[2367] | (can_castle_king_side & not_checked[2:].all()))

Expand Down
10 changes: 5 additions & 5 deletions pgx/experimental/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,25 @@ def from_fen(fen: str):
castling_rights = jnp.bool_([["Q" in castling, "K" in castling], ["q" in castling, "k" in castling]])
if color == "b":
castling_rights = castling_rights[::-1]
mat = jnp.int32(arr).reshape(8, 8)
mat = jnp.int8(arr).reshape(8, 8)
if color == "b":
mat = -jnp.flip(mat, axis=0)
ep = jnp.int32(-1) if en_passant == "-" else jnp.int32("abcdefgh".index(en_passant[0]) * 8 + int(en_passant[1]) - 1)
ep = jnp.int8(-1) if en_passant == "-" else jnp.int8("abcdefgh".index(en_passant[0]) * 8 + int(en_passant[1]) - 1)
if color == "b" and ep >= 0:
ep = _flip_pos(ep)
x = GameState(
board=jnp.rot90(mat, k=3).flatten(),
color=jnp.int32(0) if color == "w" else jnp.int32(1),
color=jnp.int8(0) if color == "w" else jnp.int8(1),
castling_rights=castling_rights,
en_passant=ep,
halfmove_count=jnp.int32(halfmove_cnt),
halfmove_count=jnp.int8(halfmove_cnt),
fullmove_count=jnp.int32(fullmove_cnt),
)
legal_action_mask = jax.jit(_legal_action_mask)(x)
x = x._replace(legal_action_mask=legal_action_mask)
x = _update_history(x)

player_order = jnp.int32([0, 1])
player_order = jnp.int8([0, 1])
state = State(
_player_order=player_order,
_x=x,
Expand Down
Loading