From 06b4b012d7b6a28e3e45892e285d6b0ff998dfc2 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Fri, 12 Jul 2024 23:41:48 +0100 Subject: [PATCH 01/48] Curent Prog Bench: 2151642 --- Cargo.toml | 1 + src/mcts.rs | 4 +- src/tree.rs | 10 ++--- src/tree/edge.rs | 96 ++++++++++++++++++++++++++++++------------------ src/tree/hash.rs | 8 ++-- 5 files changed, 72 insertions(+), 47 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 202ad786..04f80dbd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ panic = 'abort' strip = true lto = true codegen-units = 1 +overflow-checks = true [dependencies] goober = { git = 'https://github.com/jw1912/goober.git' } diff --git a/src/mcts.rs b/src/mcts.rs index 49a0bd71..47176db5 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -188,7 +188,7 @@ impl<'a> Searcher<'a> { // probe hash table to use in place of network if self.tree[ptr].state() == GameState::Ongoing { if let Some(entry) = self.tree.probe_hash(hash) { - 1.0 - entry.wins / entry.visits as f32 + 1.0 - entry.q } else { self.get_utility(ptr, pos) } @@ -227,7 +227,7 @@ impl<'a> Searcher<'a> { self.tree.edge_mut(parent, action).update(u); let edge = self.tree.edge(parent, action); - self.tree.push_hash(hash, edge.visits(), edge.wins()); + self.tree.push_hash(hash, edge.visits(), edge.q()); self.tree.propogate_proven_mates(ptr, child_state); diff --git a/src/tree.rs b/src/tree.rs index d0d8e4a6..cce5f9b5 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -195,7 +195,7 @@ impl Tree { pub fn make_root_node(&mut self, node: i32) { self.root = node; - self.parent_edge = *self.edge(self[node].parent(), self[node].action()); + self.parent_edge = self.edge(self[node].parent(), self[node].action()).clone(); self[node].clear_parent(); self[node].set_state(GameState::Ongoing); } @@ -358,10 +358,10 @@ impl Tree { pub fn display(&self, idx: i32, depth: usize) { let mut bars = vec![true; depth + 1]; - self.display_recurse(Edge::new(idx, 0, 0), depth + 1, 0, &mut bars); + self.display_recurse(&Edge::new(idx, 0, 0), depth + 1, 0, &mut bars); } - fn display_recurse(&self, edge: Edge, depth: usize, ply: usize, bars: &mut [bool]) { + fn display_recurse(&self, edge: &Edge, depth: usize, ply: usize, bars: &mut [bool]) { let node = &self[edge.ptr()]; if depth == 0 { @@ -402,7 +402,7 @@ impl Tree { } let mut active = Vec::new(); - for &action in node.actions() { + for action in node.actions() { if action.ptr() != -1 { active.push(action); } @@ -410,7 +410,7 @@ impl Tree { let end = active.len() - 1; - for (i, &action) in active.iter().enumerate() { + for (i, action) in active.iter().enumerate() { if i == end { bars[ply] = false; } diff --git a/src/tree/edge.rs b/src/tree/edge.rs index 0e4eb7f1..8736f3ca 100644 --- a/src/tree/edge.rs +++ b/src/tree/edge.rs @@ -1,22 +1,37 @@ -#[derive(Clone, Copy, Debug)] +use std::sync::atomic::{AtomicI32, AtomicU16, AtomicU32, AtomicI16, Ordering}; + +#[derive(Debug)] pub struct Edge { - ptr: i32, - mov: u16, - policy: i16, - visits: i32, - wins: f32, - sq_wins: f32, + ptr: AtomicI32, + mov: AtomicU16, + policy: AtomicI16, + visits: AtomicI32, + q: AtomicU32, + sq_q: AtomicU32, +} + +impl Clone for Edge { + fn clone(&self) -> Self { + Self { + ptr: AtomicI32::new(self.ptr()), + mov: AtomicU16::new(self.mov()), + policy: AtomicI16::new(self.policy.load(Ordering::SeqCst)), + visits: AtomicI32::new(self.visits()), + q: AtomicU32::new(self.q.load(Ordering::SeqCst)), + sq_q: AtomicU32::new(self.sq_q.load(Ordering::SeqCst)), + } + } } impl Default for Edge { fn default() -> Self { Self { - ptr: -1, - mov: 0, - policy: 0, - visits: 0, - wins: 0.0, - sq_wins: 0.0, + ptr: AtomicI32::new(-1), + mov: AtomicU16::new(0), + policy: AtomicI16::new(0), + visits: AtomicI32::new(0), + q: AtomicU32::new(0), + sq_q: AtomicU32::new(0), } } } @@ -24,56 +39,65 @@ impl Default for Edge { impl Edge { pub fn new(ptr: i32, mov: u16, policy: i16) -> Self { Self { - ptr, - mov, - policy, - visits: 0, - wins: 0.0, - sq_wins: 0.0, + ptr: AtomicI32::new(ptr), + mov: AtomicU16::new(mov), + policy: AtomicI16::new(policy), + visits: AtomicI32::new(0), + q: AtomicU32::new(0), + sq_q: AtomicU32::new(0), } } pub fn ptr(&self) -> i32 { - self.ptr + self.ptr.load(Ordering::SeqCst) } pub fn mov(&self) -> u16 { - self.mov + self.mov.load(Ordering::SeqCst) } - pub fn policy(&self) -> f32 { - f32::from(self.policy) / f32::from(i16::MAX) + pub fn visits(&self) -> i32 { + self.visits.load(Ordering::SeqCst) } - pub fn visits(&self) -> i32 { - self.visits + pub fn policy(&self) -> f32 { + f32::from(self.policy.load(Ordering::SeqCst)) / f32::from(i16::MAX) } - pub fn wins(&self) -> f32 { - self.wins + fn q64(&self) -> f64 { + f64::from(self.q.load(Ordering::SeqCst)) / f64::from(u32::MAX) } pub fn q(&self) -> f32 { - self.wins / self.visits as f32 + self.q64() as f32 + } + + pub fn sq_q(&self) -> f64 { + f64::from(self.sq_q.load(Ordering::SeqCst)) / f64::from(u32::MAX) } pub fn var(&self) -> f32 { - let v = self.visits as f32; - let var = self.sq_wins / v - (self.wins / v).powi(2); - var.max(0.0) + (self.sq_q() - self.q64().powi(2)).max(0.0) as f32 } pub fn set_ptr(&mut self, ptr: i32) { - self.ptr = ptr; + self.ptr.store(ptr, Ordering::SeqCst); } pub fn set_policy(&mut self, policy: f32) { - self.policy = (policy * f32::from(i16::MAX)) as i16 + self.policy.store((policy * f32::from(i16::MAX)) as i16, Ordering::SeqCst) } pub fn update(&mut self, result: f32) { - self.visits += 1; - self.wins += result; - self.sq_wins += result.powi(2); + let r = f64::from(result); + let v = f64::from(self.visits()); + + let q = (self.q64() * v + r) / (v + 1.0); + let sq_q = (self.sq_q() * v + r.powi(2)) / (v + 1.0); + + self.q.store((q * f64::from(u32::MAX)) as u32, Ordering::SeqCst); + self.sq_q.store((sq_q * f64::from(u32::MAX)) as u32, Ordering::SeqCst); + + self.visits.fetch_add(1, Ordering::SeqCst); } } diff --git a/src/tree/hash.rs b/src/tree/hash.rs index ebb0b358..41c9e0a0 100644 --- a/src/tree/hash.rs +++ b/src/tree/hash.rs @@ -2,7 +2,7 @@ pub struct HashEntry { pub hash: u64, pub visits: i32, - pub wins: f32, + pub q: f32, } impl Default for HashEntry { @@ -10,7 +10,7 @@ impl Default for HashEntry { Self { hash: 0, visits: 0, - wins: 0.0, + q: 0.0, } } } @@ -47,8 +47,8 @@ impl HashTable { } } - pub fn push(&mut self, hash: u64, visits: i32, wins: f32) { + pub fn push(&mut self, hash: u64, visits: i32, q: f32) { let idx = hash % (self.table.len() as u64); - self.table[idx as usize] = HashEntry { hash, visits, wins }; + self.table[idx as usize] = HashEntry { hash, visits, q }; } } From 93ecd8bc9cf16d101e92a5048ff922bd2a029262 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Fri, 12 Jul 2024 23:52:52 +0100 Subject: [PATCH 02/48] Current again Bench: 2130284 --- src/mcts.rs | 2 +- src/tree/hash.rs | 68 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 47176db5..1c9a6381 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -188,7 +188,7 @@ impl<'a> Searcher<'a> { // probe hash table to use in place of network if self.tree[ptr].state() == GameState::Ongoing { if let Some(entry) = self.tree.probe_hash(hash) { - 1.0 - entry.q + 1.0 - entry.q() } else { self.get_utility(ptr, pos) } diff --git a/src/tree/hash.rs b/src/tree/hash.rs index 41c9e0a0..ffa53502 100644 --- a/src/tree/hash.rs +++ b/src/tree/hash.rs @@ -1,47 +1,74 @@ -#[derive(Clone, Copy, Debug)] +use std::sync::atomic::{AtomicU64, Ordering}; + +#[derive(Clone, Copy, Debug, Default)] pub struct HashEntry { - pub hash: u64, + pub hash: u16, pub visits: i32, - pub q: f32, + q: u16, } -impl Default for HashEntry { - fn default() -> Self { - Self { - hash: 0, - visits: 0, - q: 0.0, +impl HashEntry { + pub fn q(&self) -> f32 { + f32::from(self.q) / f32::from(u16::MAX) + } +} + +#[derive(Default)] +struct HashEntryInternal(AtomicU64); + +impl Clone for HashEntryInternal { + fn clone(&self) -> Self { + Self(AtomicU64::new(self.0.load(Ordering::SeqCst))) + } +} + +impl From<&HashEntryInternal> for HashEntry { + fn from(value: &HashEntryInternal) -> Self { + unsafe { + std::mem::transmute(value.0.load(Ordering::SeqCst)) + } + } +} + +impl From for u64 { + fn from(value: HashEntry) -> Self { + unsafe { + std::mem::transmute(value) } } } pub struct HashTable { - table: Vec, + table: Vec, } impl HashTable { pub fn new(size: usize) -> Self { Self { - table: vec![HashEntry::default(); size], + table: vec![HashEntryInternal::default(); size], } } pub fn clear(&mut self) { for entry in &mut self.table { - *entry = HashEntry::default(); + *entry = HashEntryInternal::default(); } } - pub fn fetch(&self, hash: u64) -> &HashEntry { + pub fn fetch(&self, hash: u64) -> HashEntry { let idx = hash % (self.table.len() as u64); - &self.table[idx as usize] + HashEntry::from(&self.table[idx as usize]) + } + + fn key(hash: u64) -> u16 { + (hash >> 48) as u16 } pub fn get(&self, hash: u64) -> Option { let entry = self.fetch(hash); - if entry.hash == hash { - Some(*entry) + if entry.hash == Self::key(hash) { + Some(entry) } else { None } @@ -49,6 +76,13 @@ impl HashTable { pub fn push(&mut self, hash: u64, visits: i32, q: f32) { let idx = hash % (self.table.len() as u64); - self.table[idx as usize] = HashEntry { hash, visits, q }; + + let entry = HashEntry { + hash: Self::key(hash), + visits, + q: (q * f32::from(u16::MAX)) as u16, + }; + + self.table[idx as usize].0.store(u64::from(entry), Ordering::SeqCst) } } From b448d63624a31c5f96109a4a642d42d58490f847 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Fri, 12 Jul 2024 23:57:32 +0100 Subject: [PATCH 03/48] Bench: 2130284 --- src/mcts.rs | 4 ++-- src/tree.rs | 10 +--------- src/tree/edge.rs | 6 +++--- src/tree/hash.rs | 2 +- src/tree/node.rs | 4 ---- 5 files changed, 7 insertions(+), 19 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 1c9a6381..90bcfdd2 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -213,7 +213,7 @@ impl<'a> Searcher<'a> { if child_ptr == -1 { let state = pos.game_state(); child_ptr = self.tree.push(Node::new(state, pos.hash(), ptr, action)); - self.tree.edge_mut(ptr, action).set_ptr(child_ptr); + self.tree.edge(ptr, action).set_ptr(child_ptr); } let u = self.perform_one_iteration(pos, child_ptr, depth); @@ -224,7 +224,7 @@ impl<'a> Searcher<'a> { // flip perspective of score u = 1.0 - u; - self.tree.edge_mut(parent, action).update(u); + self.tree.edge(parent, action).update(u); let edge = self.tree.edge(parent, action); self.tree.push_hash(hash, edge.visits(), edge.q()); diff --git a/src/tree.rs b/src/tree.rs index cce5f9b5..8c227712 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -75,7 +75,7 @@ impl Tree { let parent = self[new].parent(); let action = self[new].action(); - self.edge_mut(parent, action).set_ptr(-1); + self.edge(parent, action).set_ptr(-1); self.delete(new); } @@ -208,14 +208,6 @@ impl Tree { } } - pub fn edge_mut(&mut self, ptr: i32, idx: usize) -> &mut Edge { - if ptr == -1 { - &mut self.parent_edge - } else { - &mut self[ptr].actions_mut()[idx] - } - } - pub fn propogate_proven_mates(&mut self, ptr: i32, child_state: GameState) { match child_state { // if the child node resulted in a loss, then diff --git a/src/tree/edge.rs b/src/tree/edge.rs index 8736f3ca..013b8c5a 100644 --- a/src/tree/edge.rs +++ b/src/tree/edge.rs @@ -80,15 +80,15 @@ impl Edge { (self.sq_q() - self.q64().powi(2)).max(0.0) as f32 } - pub fn set_ptr(&mut self, ptr: i32) { + pub fn set_ptr(&self, ptr: i32) { self.ptr.store(ptr, Ordering::SeqCst); } - pub fn set_policy(&mut self, policy: f32) { + pub fn set_policy(&self, policy: f32) { self.policy.store((policy * f32::from(i16::MAX)) as i16, Ordering::SeqCst) } - pub fn update(&mut self, result: f32) { + pub fn update(&self, result: f32) { let r = f64::from(result); let v = f64::from(self.visits()); diff --git a/src/tree/hash.rs b/src/tree/hash.rs index ffa53502..8da00cee 100644 --- a/src/tree/hash.rs +++ b/src/tree/hash.rs @@ -74,7 +74,7 @@ impl HashTable { } } - pub fn push(&mut self, hash: u64, visits: i32, q: f32) { + pub fn push(&self, hash: u64, visits: i32, q: f32) { let idx = hash % (self.table.len() as u64); let entry = HashEntry { diff --git a/src/tree/node.rs b/src/tree/node.rs index 7b2fd09b..1dff2ed4 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -38,10 +38,6 @@ impl Node { &self.actions } - pub fn actions_mut(&mut self) -> &mut [Edge] { - &mut self.actions - } - pub fn state(&self) -> GameState { self.state } From fa6b62841a830e6ea0c7e7acef5dbb8eb2f0b21a Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sat, 13 Jul 2024 00:14:09 +0100 Subject: [PATCH 04/48] Bench: 2130284 --- src/chess.rs | 26 +++++++++++++++ src/tree.rs | 2 +- src/tree/node.rs | 84 ++++++++++++++++++++++++++++-------------------- 3 files changed, 77 insertions(+), 35 deletions(-) diff --git a/src/chess.rs b/src/chess.rs index 8d6d1dd4..2eaab574 100644 --- a/src/chess.rs +++ b/src/chess.rs @@ -19,6 +19,32 @@ pub enum GameState { Won(u8), } +impl From for u16 { + fn from(value: GameState) -> Self { + match value { + GameState::Ongoing => 0, + GameState::Draw => 1 << 8, + GameState::Lost(x) => (2 << 8) ^ u16::from(x), + GameState::Won(x) => (3 << 8) ^ u16::from(x), + } + } +} + +impl From for GameState { + fn from(value: u16) -> Self { + let discr = value >> 8; + let x = value as u8; + + match discr { + 0 => GameState::Ongoing, + 1 => GameState::Draw, + 2 => GameState::Lost(x), + 3 => GameState::Won(x), + _ => unreachable!(), + } + } +} + impl std::fmt::Display for GameState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/src/tree.rs b/src/tree.rs index 8c227712..5f3a21d2 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -44,7 +44,7 @@ impl Tree { } fn new(cap: usize) -> Self { - let mut tree = Self { + let tree = Self { tree: vec![Node::new(GameState::Ongoing, 0, -1, 0); cap / 8], hash: HashTable::new(cap / 16), root: -1, diff --git a/src/tree/node.rs b/src/tree/node.rs index 1dff2ed4..9381a702 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -1,37 +1,53 @@ +use std::sync::atomic::{AtomicI32, AtomicU16, AtomicU64, Ordering}; + use crate::{chess::Move, tree::Edge, ChessState, GameState, MctsParams, PolicyNetwork}; -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Node { actions: Vec, - state: GameState, - hash: u64, + state: AtomicU16, + hash: AtomicU64, // used for lru - bwd_link: i32, - fwd_link: i32, - parent: i32, - action: u16, + bwd_link: AtomicI32, + fwd_link: AtomicI32, + parent: AtomicI32, + action: AtomicU16, +} + +impl Clone for Node { + fn clone(&self) -> Self { + Self { + actions: self.actions.clone(), + state: AtomicU16::new(self.state.load(Ordering::SeqCst)), + hash: AtomicU64::new(self.hash()), + bwd_link: AtomicI32::new(self.bwd_link()), + fwd_link: AtomicI32::new(self.fwd_link()), + parent: AtomicI32::new(self.parent()), + action: AtomicU16::new(self.action.load(Ordering::SeqCst)), + } + } } impl Node { pub fn new(state: GameState, hash: u64, parent: i32, action: usize) -> Self { Node { actions: Vec::new(), - state, - hash, - parent, - bwd_link: -1, - fwd_link: -1, - action: action as u16, + state: AtomicU16::new(u16::from(state)), + hash: AtomicU64::new(hash), + parent: AtomicI32::new(parent), + bwd_link: AtomicI32::new(-1), + fwd_link: AtomicI32::new(-1), + action: AtomicU16::new(action as u16), } } pub fn parent(&self) -> i32 { - self.parent + self.parent.load(Ordering::SeqCst) } pub fn is_terminal(&self) -> bool { - self.state != GameState::Ongoing + self.state() != GameState::Ongoing } pub fn actions(&self) -> &[Edge] { @@ -39,23 +55,23 @@ impl Node { } pub fn state(&self) -> GameState { - self.state + GameState::from(self.state.load(Ordering::SeqCst)) } pub fn hash(&self) -> u64 { - self.hash + self.hash.load(Ordering::SeqCst) } pub fn bwd_link(&self) -> i32 { - self.bwd_link + self.bwd_link.load(Ordering::SeqCst) } pub fn fwd_link(&self) -> i32 { - self.fwd_link + self.fwd_link.load(Ordering::SeqCst) } - pub fn set_state(&mut self, state: GameState) { - self.state = state; + pub fn set_state(&self, state: GameState) { + self.state.store(u16::from(state), Ordering::SeqCst); } pub fn has_children(&self) -> bool { @@ -63,32 +79,32 @@ impl Node { } pub fn action(&self) -> usize { - usize::from(self.action) + usize::from(self.action.load(Ordering::SeqCst)) } - pub fn clear_parent(&mut self) { - self.parent = -1; - self.action = 0; + pub fn clear_parent(&self) { + self.parent.store(-1, Ordering::SeqCst); + self.action.store(0, Ordering::SeqCst); } pub fn is_not_expanded(&self) -> bool { - self.state == GameState::Ongoing && self.actions.is_empty() + self.state() == GameState::Ongoing && self.actions.is_empty() } pub fn clear(&mut self) { self.actions.clear(); - self.state = GameState::Ongoing; - self.hash = 0; - self.bwd_link = -1; - self.fwd_link = -1; + self.set_state(GameState::Ongoing); + self.hash.store(0, Ordering::SeqCst); + self.set_bwd_link(-1); + self.set_fwd_link(-1); } - pub fn set_fwd_link(&mut self, ptr: i32) { - self.fwd_link = ptr; + pub fn set_fwd_link(&self, ptr: i32) { + self.fwd_link.store(ptr, Ordering::SeqCst); } - pub fn set_bwd_link(&mut self, ptr: i32) { - self.bwd_link = ptr; + pub fn set_bwd_link(&self, ptr: i32) { + self.bwd_link.store(ptr, Ordering::SeqCst); } pub fn expand( From 1bb435353d1a3cef789f1c6686793ce9d8e80f9f Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sat, 13 Jul 2024 01:18:19 +0100 Subject: [PATCH 05/48] Bench: 2130284 --- src/mcts.rs | 4 ++-- src/tree.rs | 18 +++++++++++------- src/tree/edge.rs | 9 +++++++++ src/tree/node.rs | 8 ++++++++ 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 90bcfdd2..0440e7eb 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -6,7 +6,7 @@ pub use params::MctsParams; use crate::{ chess::Move, - tree::{Node, Tree}, + tree::Tree, ChessState, GameState, PolicyNetwork, ValueNetwork, }; @@ -212,7 +212,7 @@ impl<'a> Searcher<'a> { // create and push node if not present if child_ptr == -1 { let state = pos.game_state(); - child_ptr = self.tree.push(Node::new(state, pos.hash(), ptr, action)); + child_ptr = self.tree.push_new(state, pos.hash(), ptr, action); self.tree.edge(ptr, action).set_ptr(child_ptr); } diff --git a/src/tree.rs b/src/tree.rs index 5f3a21d2..e26dff30 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -44,8 +44,8 @@ impl Tree { } fn new(cap: usize) -> Self { - let tree = Self { - tree: vec![Node::new(GameState::Ongoing, 0, -1, 0); cap / 8], + let mut tree = Self { + tree: Vec::new(), hash: HashTable::new(cap / 16), root: -1, empty: 0, @@ -55,6 +55,10 @@ impl Tree { parent_edge: Edge::new(0, 0, 0), }; + for _ in 0..cap / 8 { + tree.tree.push(Node::new(GameState::Ongoing, 0, -1, 0)); + } + let end = tree.cap() as i32 - 1; for i in 0..end { @@ -66,7 +70,7 @@ impl Tree { tree } - pub fn push(&mut self, node: Node) -> i32 { + pub fn push_new(&mut self, state: GameState, hash: u64, parent: i32, action: usize) -> i32 { let mut new = self.empty; // tree is full, do some LRU pruning @@ -84,7 +88,7 @@ impl Tree { self.used += 1; self.empty = self[self.empty].fwd_link(); - self[new] = node; + self[new].set_new(state, hash, parent, action); self.append_to_lru(new); @@ -186,7 +190,7 @@ impl Tree { let end = self.cap() as i32 - 1; for i in 0..end { - self[i] = Node::new(GameState::Ongoing, 0, -1, 0); + self[i].set_new(GameState::Ongoing, 0, -1, 0); self[i].set_fwd_link(i + 1); } @@ -243,7 +247,7 @@ impl Tree { let t = Instant::now(); if self.is_empty() { - let node = self.push(Node::new(GameState::Ongoing, root.hash(), -1, 0)); + let node = self.push_new(GameState::Ongoing, root.hash(), -1, 0); self.make_root_node(node); return; @@ -272,7 +276,7 @@ impl Tree { if !found { println!("info string no subtree found"); - let node = self.push(Node::new(GameState::Ongoing, root.hash(), -1, 0)); + let node = self.push_new(GameState::Ongoing, root.hash(), -1, 0); self.make_root_node(node); } diff --git a/src/tree/edge.rs b/src/tree/edge.rs index 013b8c5a..58b2c432 100644 --- a/src/tree/edge.rs +++ b/src/tree/edge.rs @@ -48,6 +48,15 @@ impl Edge { } } + pub fn set_new(&self, ptr: i32, mov: u16, policy: i16) { + self.ptr.store(ptr, Ordering::SeqCst); + self.mov.store(mov, Ordering::SeqCst); + self.policy.store(policy, Ordering::SeqCst); + self.visits.store(0, Ordering::SeqCst); + self.q.store(0, Ordering::SeqCst); + self.sq_q.store(0, Ordering::SeqCst); + } + pub fn ptr(&self) -> i32 { self.ptr.load(Ordering::SeqCst) } diff --git a/src/tree/node.rs b/src/tree/node.rs index 9381a702..5665f375 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -42,6 +42,14 @@ impl Node { } } + pub fn set_new(&mut self, state: GameState, hash: u64, parent: i32, action: usize) { + self.clear(); + self.state.store(u16::from(state), Ordering::SeqCst); + self.hash.store(hash, Ordering::SeqCst); + self.parent.store(parent, Ordering::SeqCst); + self.action.store(action as u16, Ordering::SeqCst); + } + pub fn parent(&self) -> i32 { self.parent.load(Ordering::SeqCst) } From b7aebaa9e7b3c1e496a7bd32f5c1d1540393427e Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 14 Jul 2024 23:01:24 +0100 Subject: [PATCH 06/48] Bench: 2130284 --- Cargo.toml | 1 - src/mcts.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 04f80dbd..202ad786 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,6 @@ panic = 'abort' strip = true lto = true codegen-units = 1 -overflow-checks = true [dependencies] goober = { git = 'https://github.com/jw1912/goober.git' } diff --git a/src/mcts.rs b/src/mcts.rs index 0440e7eb..10dfed29 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -323,7 +323,7 @@ impl<'a> Searcher<'a> { } action = self.tree.edge(action.ptr(), idx); - depth -= 1; + depth = depth.saturating_sub(1); } (pv, score) From 1179e258527e060100cfe6c52cb2e35d97cb035f Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 14 Jul 2024 23:42:07 +0100 Subject: [PATCH 07/48] Use relaxed ordering Bench: 2089708 --- datagen/src/lib.rs | 4 ++-- src/tree/edge.rs | 40 ++++++++++++++++++++-------------------- src/tree/hash.rs | 6 +++--- src/tree/node.rs | 36 ++++++++++++++++++------------------ 4 files changed, 43 insertions(+), 43 deletions(-) diff --git a/datagen/src/lib.rs b/datagen/src/lib.rs index 42aecf1d..c7bdcb83 100644 --- a/datagen/src/lib.rs +++ b/datagen/src/lib.rs @@ -41,7 +41,7 @@ pub struct Destination { impl Destination { pub fn push(&mut self, game: &Binpack, stop: &AtomicBool) { - if stop.load(Ordering::SeqCst) { + if stop.load(Ordering::Relaxed) { return; } @@ -51,7 +51,7 @@ impl Destination { game.serialise_into(&mut self.writer).unwrap(); if self.games >= self.limit { - stop.store(true, Ordering::SeqCst); + stop.store(true, Ordering::Relaxed); return; } diff --git a/src/tree/edge.rs b/src/tree/edge.rs index 58b2c432..5d286f5b 100644 --- a/src/tree/edge.rs +++ b/src/tree/edge.rs @@ -15,10 +15,10 @@ impl Clone for Edge { Self { ptr: AtomicI32::new(self.ptr()), mov: AtomicU16::new(self.mov()), - policy: AtomicI16::new(self.policy.load(Ordering::SeqCst)), + policy: AtomicI16::new(self.policy.load(Ordering::Relaxed)), visits: AtomicI32::new(self.visits()), - q: AtomicU32::new(self.q.load(Ordering::SeqCst)), - sq_q: AtomicU32::new(self.sq_q.load(Ordering::SeqCst)), + q: AtomicU32::new(self.q.load(Ordering::Relaxed)), + sq_q: AtomicU32::new(self.sq_q.load(Ordering::Relaxed)), } } } @@ -49,32 +49,32 @@ impl Edge { } pub fn set_new(&self, ptr: i32, mov: u16, policy: i16) { - self.ptr.store(ptr, Ordering::SeqCst); - self.mov.store(mov, Ordering::SeqCst); - self.policy.store(policy, Ordering::SeqCst); - self.visits.store(0, Ordering::SeqCst); - self.q.store(0, Ordering::SeqCst); - self.sq_q.store(0, Ordering::SeqCst); + self.ptr.store(ptr, Ordering::Relaxed); + self.mov.store(mov, Ordering::Relaxed); + self.policy.store(policy, Ordering::Relaxed); + self.visits.store(0, Ordering::Relaxed); + self.q.store(0, Ordering::Relaxed); + self.sq_q.store(0, Ordering::Relaxed); } pub fn ptr(&self) -> i32 { - self.ptr.load(Ordering::SeqCst) + self.ptr.load(Ordering::Relaxed) } pub fn mov(&self) -> u16 { - self.mov.load(Ordering::SeqCst) + self.mov.load(Ordering::Relaxed) } pub fn visits(&self) -> i32 { - self.visits.load(Ordering::SeqCst) + self.visits.load(Ordering::Relaxed) } pub fn policy(&self) -> f32 { - f32::from(self.policy.load(Ordering::SeqCst)) / f32::from(i16::MAX) + f32::from(self.policy.load(Ordering::Relaxed)) / f32::from(i16::MAX) } fn q64(&self) -> f64 { - f64::from(self.q.load(Ordering::SeqCst)) / f64::from(u32::MAX) + f64::from(self.q.load(Ordering::Relaxed)) / f64::from(u32::MAX) } pub fn q(&self) -> f32 { @@ -82,7 +82,7 @@ impl Edge { } pub fn sq_q(&self) -> f64 { - f64::from(self.sq_q.load(Ordering::SeqCst)) / f64::from(u32::MAX) + f64::from(self.sq_q.load(Ordering::Relaxed)) / f64::from(u32::MAX) } pub fn var(&self) -> f32 { @@ -90,11 +90,11 @@ impl Edge { } pub fn set_ptr(&self, ptr: i32) { - self.ptr.store(ptr, Ordering::SeqCst); + self.ptr.store(ptr, Ordering::Relaxed); } pub fn set_policy(&self, policy: f32) { - self.policy.store((policy * f32::from(i16::MAX)) as i16, Ordering::SeqCst) + self.policy.store((policy * f32::from(i16::MAX)) as i16, Ordering::Relaxed) } pub fn update(&self, result: f32) { @@ -104,9 +104,9 @@ impl Edge { let q = (self.q64() * v + r) / (v + 1.0); let sq_q = (self.sq_q() * v + r.powi(2)) / (v + 1.0); - self.q.store((q * f64::from(u32::MAX)) as u32, Ordering::SeqCst); - self.sq_q.store((sq_q * f64::from(u32::MAX)) as u32, Ordering::SeqCst); + self.q.store((q * f64::from(u32::MAX)) as u32, Ordering::Relaxed); + self.sq_q.store((sq_q * f64::from(u32::MAX)) as u32, Ordering::Relaxed); - self.visits.fetch_add(1, Ordering::SeqCst); + self.visits.fetch_add(1, Ordering::Relaxed); } } diff --git a/src/tree/hash.rs b/src/tree/hash.rs index 8da00cee..f4bf820e 100644 --- a/src/tree/hash.rs +++ b/src/tree/hash.rs @@ -18,14 +18,14 @@ struct HashEntryInternal(AtomicU64); impl Clone for HashEntryInternal { fn clone(&self) -> Self { - Self(AtomicU64::new(self.0.load(Ordering::SeqCst))) + Self(AtomicU64::new(self.0.load(Ordering::Relaxed))) } } impl From<&HashEntryInternal> for HashEntry { fn from(value: &HashEntryInternal) -> Self { unsafe { - std::mem::transmute(value.0.load(Ordering::SeqCst)) + std::mem::transmute(value.0.load(Ordering::Relaxed)) } } } @@ -83,6 +83,6 @@ impl HashTable { q: (q * f32::from(u16::MAX)) as u16, }; - self.table[idx as usize].0.store(u64::from(entry), Ordering::SeqCst) + self.table[idx as usize].0.store(u64::from(entry), Ordering::Relaxed) } } diff --git a/src/tree/node.rs b/src/tree/node.rs index 5665f375..88cb5f44 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -19,12 +19,12 @@ impl Clone for Node { fn clone(&self) -> Self { Self { actions: self.actions.clone(), - state: AtomicU16::new(self.state.load(Ordering::SeqCst)), + state: AtomicU16::new(self.state.load(Ordering::Relaxed)), hash: AtomicU64::new(self.hash()), bwd_link: AtomicI32::new(self.bwd_link()), fwd_link: AtomicI32::new(self.fwd_link()), parent: AtomicI32::new(self.parent()), - action: AtomicU16::new(self.action.load(Ordering::SeqCst)), + action: AtomicU16::new(self.action.load(Ordering::Relaxed)), } } } @@ -44,14 +44,14 @@ impl Node { pub fn set_new(&mut self, state: GameState, hash: u64, parent: i32, action: usize) { self.clear(); - self.state.store(u16::from(state), Ordering::SeqCst); - self.hash.store(hash, Ordering::SeqCst); - self.parent.store(parent, Ordering::SeqCst); - self.action.store(action as u16, Ordering::SeqCst); + self.state.store(u16::from(state), Ordering::Relaxed); + self.hash.store(hash, Ordering::Relaxed); + self.parent.store(parent, Ordering::Relaxed); + self.action.store(action as u16, Ordering::Relaxed); } pub fn parent(&self) -> i32 { - self.parent.load(Ordering::SeqCst) + self.parent.load(Ordering::Relaxed) } pub fn is_terminal(&self) -> bool { @@ -63,23 +63,23 @@ impl Node { } pub fn state(&self) -> GameState { - GameState::from(self.state.load(Ordering::SeqCst)) + GameState::from(self.state.load(Ordering::Relaxed)) } pub fn hash(&self) -> u64 { - self.hash.load(Ordering::SeqCst) + self.hash.load(Ordering::Relaxed) } pub fn bwd_link(&self) -> i32 { - self.bwd_link.load(Ordering::SeqCst) + self.bwd_link.load(Ordering::Relaxed) } pub fn fwd_link(&self) -> i32 { - self.fwd_link.load(Ordering::SeqCst) + self.fwd_link.load(Ordering::Relaxed) } pub fn set_state(&self, state: GameState) { - self.state.store(u16::from(state), Ordering::SeqCst); + self.state.store(u16::from(state), Ordering::Relaxed); } pub fn has_children(&self) -> bool { @@ -87,12 +87,12 @@ impl Node { } pub fn action(&self) -> usize { - usize::from(self.action.load(Ordering::SeqCst)) + usize::from(self.action.load(Ordering::Relaxed)) } pub fn clear_parent(&self) { - self.parent.store(-1, Ordering::SeqCst); - self.action.store(0, Ordering::SeqCst); + self.parent.store(-1, Ordering::Relaxed); + self.action.store(0, Ordering::Relaxed); } pub fn is_not_expanded(&self) -> bool { @@ -102,17 +102,17 @@ impl Node { pub fn clear(&mut self) { self.actions.clear(); self.set_state(GameState::Ongoing); - self.hash.store(0, Ordering::SeqCst); + self.hash.store(0, Ordering::Relaxed); self.set_bwd_link(-1); self.set_fwd_link(-1); } pub fn set_fwd_link(&self, ptr: i32) { - self.fwd_link.store(ptr, Ordering::SeqCst); + self.fwd_link.store(ptr, Ordering::Relaxed); } pub fn set_bwd_link(&self, ptr: i32) { - self.bwd_link.store(ptr, Ordering::SeqCst); + self.bwd_link.store(ptr, Ordering::Relaxed); } pub fn expand( From 5e034714f976ae6384016529537933ab7dce1ecc Mon Sep 17 00:00:00 2001 From: jw1912 Date: Mon, 15 Jul 2024 22:41:28 +0100 Subject: [PATCH 08/48] Technically ready for SMP Bench: 2096036 --- src/tree.rs | 93 +++++++++++++++++++++---------------------- src/tree/edge.rs | 6 +-- src/tree/node.rs | 101 +++++++++++++++++++++++++++-------------------- 3 files changed, 106 insertions(+), 94 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index e26dff30..2a54591e 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -5,7 +5,10 @@ mod node; pub use edge::Edge; use hash::{HashEntry, HashTable}; pub use node::Node; -use std::time::Instant; +use std::{ + sync::atomic::{AtomicI32, AtomicUsize, Ordering}, + time::Instant, +}; use crate::{ chess::{ChessState, Move}, @@ -15,11 +18,11 @@ use crate::{ pub struct Tree { tree: Vec, hash: HashTable, - root: i32, - empty: i32, - used: usize, - lru_head: i32, - lru_tail: i32, + root: AtomicI32, + empty: AtomicI32, + used: AtomicUsize, + lru_head: AtomicI32, + lru_tail: AtomicI32, parent_edge: Edge, } @@ -31,12 +34,6 @@ impl std::ops::Index for Tree { } } -impl std::ops::IndexMut for Tree { - fn index_mut(&mut self, index: i32) -> &mut Self::Output { - &mut self.tree[index as usize] - } -} - impl Tree { pub fn new_mb(mb: usize) -> Self { let cap = mb * 1024 * 1024 / std::mem::size_of::(); @@ -47,11 +44,11 @@ impl Tree { let mut tree = Self { tree: Vec::new(), hash: HashTable::new(cap / 16), - root: -1, - empty: 0, - used: 0, - lru_head: -1, - lru_tail: -1, + root: AtomicI32::new(-1), + empty: AtomicI32::new(0), + used: AtomicUsize::new(0), + lru_head: AtomicI32::new(-1), + lru_tail: AtomicI32::new(-1), parent_edge: Edge::new(0, 0, 0), }; @@ -70,12 +67,12 @@ impl Tree { tree } - pub fn push_new(&mut self, state: GameState, hash: u64, parent: i32, action: usize) -> i32 { - let mut new = self.empty; + pub fn push_new(&self, state: GameState, hash: u64, parent: i32, action: usize) -> i32 { + let mut new = self.empty.load(Ordering::Relaxed); // tree is full, do some LRU pruning if new == -1 { - new = self.lru_tail; + new = self.lru_tail.load(Ordering::Relaxed); let parent = self[new].parent(); let action = self[new].action(); @@ -86,14 +83,14 @@ impl Tree { assert_ne!(new, -1); - self.used += 1; - self.empty = self[self.empty].fwd_link(); + self.used.fetch_add(1, Ordering::Relaxed); + self.empty.store(self[self.empty.load(Ordering::Relaxed)].fwd_link(), Ordering::Relaxed); self[new].set_new(state, hash, parent, action); self.append_to_lru(new); - if self.used == 1 { - self.lru_tail = new; + if self.used.load(Ordering::Relaxed) == 1 { + self.lru_tail.store(new, Ordering::Relaxed); } new @@ -103,51 +100,51 @@ impl Tree { self.hash.get(hash) } - pub fn push_hash(&mut self, hash: u64, visits: i32, wins: f32) { + pub fn push_hash(&self, hash: u64, visits: i32, wins: f32) { self.hash.push(hash, visits, wins); } - pub fn delete(&mut self, ptr: i32) { + pub fn delete(&self, ptr: i32) { self.remove_from_lru(ptr); self[ptr].clear(); - let empty = self.empty; + let empty = self.empty.load(Ordering::Relaxed); self[ptr].set_fwd_link(empty); - self.empty = ptr; - self.used -= 1; - assert!(self.used < self.cap()); + self.empty.store(ptr, Ordering::Relaxed); + let used = self.used.fetch_sub(1, Ordering::Relaxed); + assert!(used - 1 < self.cap()); } - pub fn make_recently_used(&mut self, ptr: i32) { + pub fn make_recently_used(&self, ptr: i32) { self.remove_from_lru(ptr); self.append_to_lru(ptr); } - fn append_to_lru(&mut self, ptr: i32) { - let old_head = self.lru_head; + fn append_to_lru(&self, ptr: i32) { + let old_head = self.lru_head.load(Ordering::Relaxed); if old_head != -1 { self[old_head].set_bwd_link(ptr); } - self.lru_head = ptr; + self.lru_head.store(ptr, Ordering::Relaxed); self[ptr].set_fwd_link(old_head); self[ptr].set_bwd_link(-1); } - fn remove_from_lru(&mut self, ptr: i32) { + fn remove_from_lru(&self, ptr: i32) { let bwd = self[ptr].bwd_link(); let fwd = self[ptr].fwd_link(); if bwd != -1 { self[bwd].set_fwd_link(fwd); } else { - self.lru_head = fwd; + self.lru_head.store(fwd, Ordering::Relaxed); } if fwd != -1 { self[fwd].set_bwd_link(bwd); } else { - self.lru_tail = bwd; + self.lru_tail.store(bwd, Ordering::Relaxed); } self[ptr].set_bwd_link(-1); @@ -155,7 +152,7 @@ impl Tree { } pub fn root_node(&self) -> i32 { - self.root + self.root.load(Ordering::Relaxed) } pub fn cap(&self) -> usize { @@ -163,7 +160,7 @@ impl Tree { } pub fn len(&self) -> usize { - self.used + self.used.load(Ordering::Relaxed) } pub fn remaining(&self) -> usize { @@ -175,16 +172,16 @@ impl Tree { } pub fn clear(&mut self) { - if self.used == 0 { + if self.is_empty() { return; } self.hash.clear(); - self.root = -1; - self.empty = 0; - self.used = 0; - self.lru_head = -1; - self.lru_tail = -1; + self.root.store(-1, Ordering::Relaxed); + self.empty.store(0, Ordering::Relaxed); + self.used.store(0, Ordering::Relaxed); + self.lru_head.store(-1, Ordering::Relaxed); + self.lru_tail.store(-1, Ordering::Relaxed); self.parent_edge = Edge::new(0, 0, 0); let end = self.cap() as i32 - 1; @@ -198,7 +195,7 @@ impl Tree { } pub fn make_root_node(&mut self, node: i32) { - self.root = node; + self.root.store(node, Ordering::Relaxed); self.parent_edge = self.edge(self[node].parent(), self[node].action()).clone(); self[node].clear_parent(); self[node].set_state(GameState::Ongoing); @@ -212,7 +209,7 @@ impl Tree { } } - pub fn propogate_proven_mates(&mut self, ptr: i32, child_state: GameState) { + pub fn propogate_proven_mates(&self, ptr: i32, child_state: GameState) { match child_state { // if the child node resulted in a loss, then // this node has a guaranteed win @@ -260,7 +257,7 @@ impl Tree { if let Some(board) = prev_board { println!("info string searching for subtree"); - let root = self.recurse_find(self.root, board, root, 2); + let root = self.recurse_find(self.root_node(), board, root, 2); if root != -1 && self[root].has_children() { found = true; diff --git a/src/tree/edge.rs b/src/tree/edge.rs index 5d286f5b..c668e215 100644 --- a/src/tree/edge.rs +++ b/src/tree/edge.rs @@ -48,10 +48,10 @@ impl Edge { } } - pub fn set_new(&self, ptr: i32, mov: u16, policy: i16) { - self.ptr.store(ptr, Ordering::Relaxed); + pub fn set_new(&self, mov: u16, policy: f32) { + self.ptr.store(-1, Ordering::Relaxed); self.mov.store(mov, Ordering::Relaxed); - self.policy.store(policy, Ordering::Relaxed); + self.set_policy(policy); self.visits.store(0, Ordering::Relaxed); self.q.store(0, Ordering::Relaxed); self.sq_q.store(0, Ordering::Relaxed); diff --git a/src/tree/node.rs b/src/tree/node.rs index 88cb5f44..e596d7bf 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -1,10 +1,17 @@ -use std::sync::atomic::{AtomicI32, AtomicU16, AtomicU64, Ordering}; +use std::{ + alloc::{self, Layout}, + sync::atomic::{AtomicI32, AtomicPtr, AtomicU16, AtomicU64, Ordering} +}; use crate::{chess::Move, tree::Edge, ChessState, GameState, MctsParams, PolicyNetwork}; +const EDGE_SIZE: usize = std::mem::size_of::(); +const EDGE_ALIGN: usize = std::mem::align_of::(); + #[derive(Debug)] pub struct Node { - actions: Vec, + actions: AtomicPtr, + num_actions: AtomicU16, state: AtomicU16, hash: AtomicU64, @@ -15,24 +22,11 @@ pub struct Node { action: AtomicU16, } -impl Clone for Node { - fn clone(&self) -> Self { - Self { - actions: self.actions.clone(), - state: AtomicU16::new(self.state.load(Ordering::Relaxed)), - hash: AtomicU64::new(self.hash()), - bwd_link: AtomicI32::new(self.bwd_link()), - fwd_link: AtomicI32::new(self.fwd_link()), - parent: AtomicI32::new(self.parent()), - action: AtomicU16::new(self.action.load(Ordering::Relaxed)), - } - } -} - impl Node { pub fn new(state: GameState, hash: u64, parent: i32, action: usize) -> Self { Node { - actions: Vec::new(), + actions: AtomicPtr::new(std::ptr::null_mut()), + num_actions: AtomicU16::new(0), state: AtomicU16::new(u16::from(state)), hash: AtomicU64::new(hash), parent: AtomicI32::new(parent), @@ -42,7 +36,7 @@ impl Node { } } - pub fn set_new(&mut self, state: GameState, hash: u64, parent: i32, action: usize) { + pub fn set_new(&self, state: GameState, hash: u64, parent: i32, action: usize) { self.clear(); self.state.store(u16::from(state), Ordering::Relaxed); self.hash.store(hash, Ordering::Relaxed); @@ -58,8 +52,20 @@ impl Node { self.state() != GameState::Ongoing } + pub fn num_actions(&self) -> usize { + usize::from(self.num_actions.load(Ordering::Relaxed)) + } + pub fn actions(&self) -> &[Edge] { - &self.actions + let ptr = self.actions.load(Ordering::Relaxed); + + if ptr.is_null() { + return &[]; + } + + unsafe { + std::slice::from_raw_parts(ptr, self.num_actions()) + } } pub fn state(&self) -> GameState { @@ -83,7 +89,7 @@ impl Node { } pub fn has_children(&self) -> bool { - !self.actions.is_empty() + !self.actions.load(Ordering::Relaxed).is_null() } pub fn action(&self) -> usize { @@ -96,11 +102,18 @@ impl Node { } pub fn is_not_expanded(&self) -> bool { - self.state() == GameState::Ongoing && self.actions.is_empty() + self.state() == GameState::Ongoing && !self.has_children() } - pub fn clear(&mut self) { - self.actions.clear(); + pub fn clear(&self) { + let ptr = self.actions.load(Ordering::Relaxed); + let layout = Layout::from_size_align(EDGE_SIZE * self.num_actions(), EDGE_ALIGN).unwrap(); + unsafe { + alloc::dealloc(ptr.cast(), layout); + } + + self.actions.store(std::ptr::null_mut(), Ordering::Relaxed); + self.num_actions.store(0, Ordering::Relaxed); self.set_state(GameState::Ongoing); self.hash.store(0, Ordering::Relaxed); self.set_bwd_link(-1); @@ -116,7 +129,7 @@ impl Node { } pub fn expand( - &mut self, + &self, pos: &ChessState, params: &MctsParams, policy: &PolicyNetwork, @@ -125,41 +138,43 @@ impl Node { let feats = pos.get_policy_feats(); let mut max = f32::NEG_INFINITY; + let mut moves = [(0, 0.0); 256]; + let mut num = 0; pos.map_legal_moves(|mov| { let policy = pos.get_policy(mov, &feats, policy); - - // trick for calculating policy before quantising - self.actions - .push(Edge::new(f32::to_bits(policy) as i32, mov.into(), 0)); max = max.max(policy); + moves[num] = (mov.into(), policy); + num += 1; }); let mut total = 0.0; - for action in &mut self.actions { - let mut policy = f32::from_bits(action.ptr() as u32); - - policy = if ROOT { - ((policy - max) / params.root_pst()).exp() + for (_, policy) in moves[..num].iter_mut() { + *policy = if ROOT { + ((*policy - max) / params.root_pst()).exp() } else { - (policy - max).exp() + (*policy - max).exp() }; - action.set_ptr(f32::to_bits(policy) as i32); + total += *policy; + } + + if num != 0 { + let layout = Layout::from_size_align(EDGE_SIZE * num, EDGE_ALIGN).unwrap(); + let ptr = unsafe { alloc::alloc_zeroed(layout) }; - total += policy; + self.num_actions.store(num as u16, Ordering::Relaxed); + self.actions.store(ptr.cast(), Ordering::Release); } - for action in &mut self.actions { - let policy = f32::from_bits(action.ptr() as u32) / total; - action.set_ptr(-1); - action.set_policy(policy); + for (action, &(mov, policy)) in self.actions().iter().zip(moves[..num].iter()) { + action.set_new(mov, policy / total); } } pub fn relabel_policy( - &mut self, + &self, pos: &ChessState, params: &MctsParams, policy: &PolicyNetwork, @@ -169,7 +184,7 @@ impl Node { let mut policies = Vec::new(); - for action in &self.actions { + for action in self.actions() { let mov = Move::from(action.mov()); let policy = pos.get_policy(mov, &feats, policy); policies.push(policy); @@ -183,7 +198,7 @@ impl Node { total += *policy; } - for (i, action) in self.actions.iter_mut().enumerate() { + for (i, action) in self.actions().iter().enumerate() { action.set_policy(policies[i] / total); } } From 90716d4e31f997eb0f262ff596ee4ea9012421b7 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Tue, 16 Jul 2024 10:22:38 +0100 Subject: [PATCH 09/48] . --- datagen/src/thread.rs | 9 ++++---- src/mcts.rs | 54 ++++++++++++++++++++++++++++--------------- src/uci.rs | 43 +++++++++++++++++++++------------- 3 files changed, 66 insertions(+), 40 deletions(-) diff --git a/datagen/src/thread.rs b/datagen/src/thread.rs index 38be0041..e658c65f 100644 --- a/datagen/src/thread.rs +++ b/datagen/src/thread.rs @@ -118,21 +118,20 @@ impl<'a> DatagenThread<'a> { } let abort = AtomicBool::new(false); + tree.try_use_subtree(&position, &None); let mut searcher = Searcher::new( position.clone(), - tree, - self.params.clone(), + &tree, + &self.params, policy, value, &abort, ); - let (bm, score) = searcher.search(limits, false, &mut 0, &None); + let (bm, score) = searcher.search(limits, false, &mut 0); game.push(position.stm(), bm, score); - tree = searcher.tree_and_board().0; - let mut root_count = 0; position.map_legal_moves(|_| root_count += 1); diff --git a/src/mcts.rs b/src/mcts.rs index b464f84b..40e9a138 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -25,8 +25,8 @@ pub struct Limits { pub struct Searcher<'a> { root_position: ChessState, - tree: Tree, - params: MctsParams, + tree: &'a Tree, + params: &'a MctsParams, policy: &'a PolicyNetwork, value: &'a ValueNetwork, abort: &'a AtomicBool, @@ -35,8 +35,8 @@ pub struct Searcher<'a> { impl<'a> Searcher<'a> { pub fn new( root_position: ChessState, - tree: Tree, - params: MctsParams, + tree: &'a Tree, + params: &'a MctsParams, policy: &'a PolicyNetwork, value: &'a ValueNetwork, abort: &'a AtomicBool, @@ -56,19 +56,17 @@ impl<'a> Searcher<'a> { limits: Limits, uci_output: bool, total_nodes: &mut usize, - prev_board: &Option, ) -> (Move, f32) { let timer = Instant::now(); // attempt to reuse the current tree stored in memory - self.tree.try_use_subtree(&self.root_position, prev_board); let node = self.tree.root_node(); // relabel root policies with root PST value if self.tree[node].has_children() { - self.tree[node].relabel_policy(&self.root_position, &self.params, self.policy); + self.tree[node].relabel_policy(&self.root_position, self.params, self.policy); } else { - self.tree[node].expand::(&self.root_position, &self.params, self.policy); + self.tree[node].expand::(&self.root_position, self.params, self.policy); } let mut nodes = 0; @@ -96,13 +94,13 @@ impl<'a> Searcher<'a> { break; } + if self.abort.load(Ordering::Relaxed) { + break; + } + nodes += 1; if nodes % 256 == 0 { - if self.abort.load(Ordering::Relaxed) { - break; - } - if let Some(time) = limits.max_time { if timer.elapsed().as_millis() >= time { break; @@ -170,6 +168,8 @@ impl<'a> Searcher<'a> { } } + self.abort.store(true, Ordering::Relaxed); + *total_nodes += nodes; if uci_output { @@ -181,6 +181,22 @@ impl<'a> Searcher<'a> { (Move::from(best_child.mov()), best_child.q()) } + pub fn secondary_search(&mut self) { + loop { + let mut pos = self.root_position.clone(); + self.perform_one_iteration(&mut pos, self.tree.root_node(), &mut 0); + + if self.abort.load(Ordering::Relaxed) { + break; + } + + if self.tree[self.tree.root_node()].is_terminal() { + self.abort.store(true, Ordering::Relaxed); + break; + } + } + } + fn perform_one_iteration(&mut self, pos: &mut ChessState, ptr: i32, depth: &mut usize) -> f32 { *depth += 1; @@ -207,7 +223,7 @@ impl<'a> Searcher<'a> { } else { // expand node on the second visit if self.tree[ptr].is_not_expanded() { - self.tree[ptr].expand::(pos, &self.params, self.policy); + self.tree[ptr].expand::(pos, self.params, self.policy); } // select action to take via PUCT @@ -247,7 +263,7 @@ impl<'a> Searcher<'a> { fn get_utility(&self, ptr: i32, pos: &ChessState) -> f32 { match self.tree[ptr].state() { - GameState::Ongoing => pos.get_value_wdl(self.value, &self.params), + GameState::Ongoing => pos.get_value_wdl(self.value, self.params), GameState::Draw => 0.5, GameState::Lost(_) => 0.0, GameState::Won(_) => 1.0, @@ -263,9 +279,9 @@ impl<'a> Searcher<'a> { let edge = self.tree.edge(node.parent(), node.action()); let is_root = edge.ptr() == self.tree.root_node(); - let cpuct = SearchHelpers::get_cpuct(&self.params, edge, is_root); + let cpuct = SearchHelpers::get_cpuct(self.params, edge, is_root); let fpu = SearchHelpers::get_fpu(edge); - let expl_scale = SearchHelpers::get_explore_scaling(&self.params, edge); + let expl_scale = SearchHelpers::get_explore_scaling(self.params, edge); let expl = cpuct * expl_scale; @@ -353,9 +369,9 @@ impl<'a> Searcher<'a> { -400.0 * (1.0 / score.clamp(0.0, 1.0) - 1.0).ln() } - pub fn tree_and_board(self) -> (Tree, ChessState) { - (self.tree, self.root_position) - } + //pub fn tree_and_board(self) -> (Tree, ChessState) { + // (self.tree, self.root_position) + //} pub fn display_moves(&self) { for action in self.tree[self.tree.root_node()].actions() { diff --git a/src/uci.rs b/src/uci.rs index e68d1fbe..2040ff62 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -26,6 +26,7 @@ impl Uci { let mut params = MctsParams::default(); let mut tree = Tree::new_mb(64); let mut report_moves = false; + let mut threads = 1; let mut stored_message: Option = None; @@ -51,15 +52,15 @@ impl Uci { let cmd = *commands.first().unwrap_or(&"oops"); match cmd { "isready" => println!("readyok"), - "setoption" => setoption(&commands, &mut params, &mut report_moves, &mut tree), + "setoption" => setoption(&commands, &mut params, &mut report_moves, &mut tree, &mut threads), "position" => position(commands, &mut pos, &mut prev, &mut tree), "go" => { // increment game ply every time `go` is called root_game_ply += 2; - let res = go( + go( &commands, - tree, + &mut tree, prev, &pos, root_game_ply, @@ -68,10 +69,10 @@ impl Uci { policy, value, &mut stored_message, + threads, ); - tree = res.0; - prev = Some(res.1); + prev = Some(pos.clone()); } "perft" => run_perft(&commands, &pos), "quit" => std::process::exit(0), @@ -143,15 +144,15 @@ impl Uci { }; let mut tree = Tree::new_mb(32); - let abort = AtomicBool::new(false); for fen in bench_fens { + let abort = AtomicBool::new(false); let pos = ChessState::from_fen(fen); - let mut searcher = Searcher::new(pos, tree, params.clone(), policy, value, &abort); + tree.try_use_subtree(&pos, &None); + let mut searcher = Searcher::new(pos, &tree, params, policy, value, &abort); let timer = Instant::now(); - searcher.search(limits, false, &mut total_nodes, &None); + searcher.search(limits, false, &mut total_nodes); time += timer.elapsed().as_secs_f32(); - tree = searcher.tree_and_board().0; tree.clear(); } @@ -172,7 +173,7 @@ fn preamble() { println!("uciok"); } -fn setoption(commands: &[&str], params: &mut MctsParams, report_moves: &mut bool, tree: &mut Tree) { +fn setoption(commands: &[&str], params: &mut MctsParams, report_moves: &mut bool, tree: &mut Tree, threads: &mut usize) { if let ["setoption", "name", "report_moves"] = commands { *report_moves = !*report_moves; return; @@ -183,6 +184,11 @@ fn setoption(commands: &[&str], params: &mut MctsParams, report_moves: &mut bool return; } + if *x == "Threads" { + *threads = y.parse().unwrap(); + return + } + (*x, y.parse::().unwrap_or(0)) } else { return; @@ -241,7 +247,7 @@ fn position( #[allow(clippy::too_many_arguments)] fn go( commands: &[&str], - tree: Tree, + tree: &mut Tree, prev: Option, pos: &ChessState, root_game_ply: u32, @@ -250,7 +256,8 @@ fn go( policy: &PolicyNetwork, value: &ValueNetwork, stored_message: &mut Option, -) -> (Tree, ChessState) { + threads: usize, +) { let mut max_nodes = i32::MAX as usize; let mut max_time = None; let mut max_depth = 256; @@ -313,7 +320,7 @@ fn go( let abort = AtomicBool::new(false); - let mut searcher = Searcher::new(pos.clone(), tree, params.clone(), policy, value, &abort); + tree.try_use_subtree(pos, &prev); let limits = Limits { max_time, @@ -324,7 +331,8 @@ fn go( std::thread::scope(|s| { s.spawn(|| { - let (mov, _) = searcher.search(limits, true, &mut 0, &prev); + let mut searcher = Searcher::new(pos.clone(), tree, params, policy, value, &abort); + let (mov, _) = searcher.search(limits, true, &mut 0); println!("bestmove {}", pos.conv_mov_to_str(mov)); if report_moves { @@ -332,10 +340,13 @@ fn go( } }); + for _ in 0..threads - 1 { + let mut searcher = Searcher::new(pos.clone(), tree, params, policy, value, &abort); + searcher.secondary_search(); + } + *stored_message = handle_search_input(&abort); }); - - searcher.tree_and_board() } fn run_perft(commands: &[&str], pos: &ChessState) { From 8f6d1f6a035b401351b1c7e67fe9f830a4dca1d9 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Tue, 16 Jul 2024 23:24:03 +0100 Subject: [PATCH 10/48] save bytes in node Bench: 2127854 --- src/mcts.rs | 4 ++-- src/tree.rs | 12 ++++++------ src/tree/node.rs | 16 ++++++---------- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 40e9a138..c41cfde7 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -202,7 +202,7 @@ impl<'a> Searcher<'a> { self.tree.make_recently_used(ptr); - let hash = self.tree[ptr].hash(); + let hash = pos.hash(); let parent = self.tree[ptr].parent(); let action = self.tree[ptr].action(); @@ -237,7 +237,7 @@ impl<'a> Searcher<'a> { // create and push node if not present if child_ptr == -1 { let state = pos.game_state(); - child_ptr = self.tree.push_new(state, pos.hash(), ptr, action); + child_ptr = self.tree.push_new(state, ptr, action); self.tree.edge(ptr, action).set_ptr(child_ptr); } diff --git a/src/tree.rs b/src/tree.rs index 2a54591e..1869cfd7 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -53,7 +53,7 @@ impl Tree { }; for _ in 0..cap / 8 { - tree.tree.push(Node::new(GameState::Ongoing, 0, -1, 0)); + tree.tree.push(Node::new(GameState::Ongoing, -1, 0)); } let end = tree.cap() as i32 - 1; @@ -67,7 +67,7 @@ impl Tree { tree } - pub fn push_new(&self, state: GameState, hash: u64, parent: i32, action: usize) -> i32 { + pub fn push_new(&self, state: GameState, parent: i32, action: usize) -> i32 { let mut new = self.empty.load(Ordering::Relaxed); // tree is full, do some LRU pruning @@ -85,7 +85,7 @@ impl Tree { self.used.fetch_add(1, Ordering::Relaxed); self.empty.store(self[self.empty.load(Ordering::Relaxed)].fwd_link(), Ordering::Relaxed); - self[new].set_new(state, hash, parent, action); + self[new].set_new(state, parent, action); self.append_to_lru(new); @@ -187,7 +187,7 @@ impl Tree { let end = self.cap() as i32 - 1; for i in 0..end { - self[i].set_new(GameState::Ongoing, 0, -1, 0); + self[i].set_new(GameState::Ongoing, -1, 0); self[i].set_fwd_link(i + 1); } @@ -244,7 +244,7 @@ impl Tree { let t = Instant::now(); if self.is_empty() { - let node = self.push_new(GameState::Ongoing, root.hash(), -1, 0); + let node = self.push_new(GameState::Ongoing, -1, 0); self.make_root_node(node); return; @@ -273,7 +273,7 @@ impl Tree { if !found { println!("info string no subtree found"); - let node = self.push_new(GameState::Ongoing, root.hash(), -1, 0); + let node = self.push_new(GameState::Ongoing, -1, 0); self.make_root_node(node); } diff --git a/src/tree/node.rs b/src/tree/node.rs index e596d7bf..390281c8 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -1,6 +1,6 @@ use std::{ alloc::{self, Layout}, - sync::atomic::{AtomicI32, AtomicPtr, AtomicU16, AtomicU64, Ordering} + sync::atomic::{AtomicI32, AtomicPtr, AtomicU16, Ordering} }; use crate::{chess::Move, tree::Edge, ChessState, GameState, MctsParams, PolicyNetwork}; @@ -13,7 +13,6 @@ pub struct Node { actions: AtomicPtr, num_actions: AtomicU16, state: AtomicU16, - hash: AtomicU64, // used for lru bwd_link: AtomicI32, @@ -23,12 +22,11 @@ pub struct Node { } impl Node { - pub fn new(state: GameState, hash: u64, parent: i32, action: usize) -> Self { + pub fn new(state: GameState, parent: i32, action: usize) -> Self { Node { actions: AtomicPtr::new(std::ptr::null_mut()), num_actions: AtomicU16::new(0), state: AtomicU16::new(u16::from(state)), - hash: AtomicU64::new(hash), parent: AtomicI32::new(parent), bwd_link: AtomicI32::new(-1), fwd_link: AtomicI32::new(-1), @@ -36,10 +34,9 @@ impl Node { } } - pub fn set_new(&self, state: GameState, hash: u64, parent: i32, action: usize) { + pub fn set_new(&self, state: GameState, parent: i32, action: usize) { self.clear(); self.state.store(u16::from(state), Ordering::Relaxed); - self.hash.store(hash, Ordering::Relaxed); self.parent.store(parent, Ordering::Relaxed); self.action.store(action as u16, Ordering::Relaxed); } @@ -72,9 +69,9 @@ impl Node { GameState::from(self.state.load(Ordering::Relaxed)) } - pub fn hash(&self) -> u64 { - self.hash.load(Ordering::Relaxed) - } + //pub fn hash(&self) -> u64 { + // self.hash.load(Ordering::Relaxed) + //} pub fn bwd_link(&self) -> i32 { self.bwd_link.load(Ordering::Relaxed) @@ -115,7 +112,6 @@ impl Node { self.actions.store(std::ptr::null_mut(), Ordering::Relaxed); self.num_actions.store(0, Ordering::Relaxed); self.set_state(GameState::Ongoing); - self.hash.store(0, Ordering::Relaxed); self.set_bwd_link(-1); self.set_fwd_link(-1); } From 5ec47995bda4a4f7b4d2ea9f6a7cd173ba431aa3 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Tue, 16 Jul 2024 23:41:09 +0100 Subject: [PATCH 11/48] Bench: 2261080 --- src/tree.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 1869cfd7..18559928 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -36,13 +36,13 @@ impl std::ops::Index for Tree { impl Tree { pub fn new_mb(mb: usize) -> Self { - let cap = mb * 1024 * 1024 / std::mem::size_of::(); + let cap = mb * 1024 * 1024 / 48; Self::new(cap) } fn new(cap: usize) -> Self { let mut tree = Self { - tree: Vec::new(), + tree: Vec::with_capacity(cap / 8), hash: HashTable::new(cap / 16), root: AtomicI32::new(-1), empty: AtomicI32::new(0), From 1864b6ebe5f17538b3b41d0e4b4a87965c2b59b8 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Tue, 16 Jul 2024 23:49:40 +0100 Subject: [PATCH 12/48] Save space in hash table Bench: 2261080 --- src/mcts.rs | 2 +- src/tree.rs | 4 ++-- src/tree/hash.rs | 16 +++++++--------- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index c41cfde7..fde4061f 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -252,7 +252,7 @@ impl<'a> Searcher<'a> { self.tree.edge(parent, action).update(u); let edge = self.tree.edge(parent, action); - self.tree.push_hash(hash, edge.visits(), edge.q()); + self.tree.push_hash(hash, edge.q()); self.tree.propogate_proven_mates(ptr, child_state); diff --git a/src/tree.rs b/src/tree.rs index 18559928..d789215e 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -100,8 +100,8 @@ impl Tree { self.hash.get(hash) } - pub fn push_hash(&self, hash: u64, visits: i32, wins: f32) { - self.hash.push(hash, visits, wins); + pub fn push_hash(&self, hash: u64, wins: f32) { + self.hash.push(hash, wins); } pub fn delete(&self, ptr: i32) { diff --git a/src/tree/hash.rs b/src/tree/hash.rs index f4bf820e..1ef71f65 100644 --- a/src/tree/hash.rs +++ b/src/tree/hash.rs @@ -1,9 +1,8 @@ -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicU32, Ordering}; #[derive(Clone, Copy, Debug, Default)] pub struct HashEntry { - pub hash: u16, - pub visits: i32, + hash: u16, q: u16, } @@ -14,11 +13,11 @@ impl HashEntry { } #[derive(Default)] -struct HashEntryInternal(AtomicU64); +struct HashEntryInternal(AtomicU32); impl Clone for HashEntryInternal { fn clone(&self) -> Self { - Self(AtomicU64::new(self.0.load(Ordering::Relaxed))) + Self(AtomicU32::new(self.0.load(Ordering::Relaxed))) } } @@ -30,7 +29,7 @@ impl From<&HashEntryInternal> for HashEntry { } } -impl From for u64 { +impl From for u32 { fn from(value: HashEntry) -> Self { unsafe { std::mem::transmute(value) @@ -74,15 +73,14 @@ impl HashTable { } } - pub fn push(&self, hash: u64, visits: i32, q: f32) { + pub fn push(&self, hash: u64, q: f32) { let idx = hash % (self.table.len() as u64); let entry = HashEntry { hash: Self::key(hash), - visits, q: (q * f32::from(u16::MAX)) as u16, }; - self.table[idx as usize].0.store(u64::from(entry), Ordering::Relaxed) + self.table[idx as usize].0.store(u32::from(entry), Ordering::Relaxed) } } From 730778a3e7a0c78102804cfccc84ae7f965b7ab8 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Wed, 17 Jul 2024 20:37:28 +0100 Subject: [PATCH 13/48] `AtomicVec` Bench: 2261080 --- src/tree.rs | 1 + src/tree/edge.rs | 4 +- src/tree/node.rs | 49 +++++------------------ src/tree/vec.rs | 100 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 111 insertions(+), 43 deletions(-) create mode 100644 src/tree/vec.rs diff --git a/src/tree.rs b/src/tree.rs index d789215e..fa24a697 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -1,6 +1,7 @@ mod edge; mod hash; mod node; +mod vec; pub use edge::Edge; use hash::{HashEntry, HashTable}; diff --git a/src/tree/edge.rs b/src/tree/edge.rs index c668e215..98243116 100644 --- a/src/tree/edge.rs +++ b/src/tree/edge.rs @@ -99,14 +99,12 @@ impl Edge { pub fn update(&self, result: f32) { let r = f64::from(result); - let v = f64::from(self.visits()); + let v = f64::from(self.visits.fetch_add(1, Ordering::Relaxed)); let q = (self.q64() * v + r) / (v + 1.0); let sq_q = (self.sq_q() * v + r.powi(2)) / (v + 1.0); self.q.store((q * f64::from(u32::MAX)) as u32, Ordering::Relaxed); self.sq_q.store((sq_q * f64::from(u32::MAX)) as u32, Ordering::Relaxed); - - self.visits.fetch_add(1, Ordering::Relaxed); } } diff --git a/src/tree/node.rs b/src/tree/node.rs index 390281c8..89465f42 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -1,17 +1,10 @@ -use std::{ - alloc::{self, Layout}, - sync::atomic::{AtomicI32, AtomicPtr, AtomicU16, Ordering} -}; +use std::sync::atomic::{AtomicI32, AtomicU16, Ordering}; -use crate::{chess::Move, tree::Edge, ChessState, GameState, MctsParams, PolicyNetwork}; - -const EDGE_SIZE: usize = std::mem::size_of::(); -const EDGE_ALIGN: usize = std::mem::align_of::(); +use crate::{chess::Move, tree::{Edge, vec::AtomicVec}, ChessState, GameState, MctsParams, PolicyNetwork}; #[derive(Debug)] pub struct Node { - actions: AtomicPtr, - num_actions: AtomicU16, + actions: AtomicVec, state: AtomicU16, // used for lru @@ -24,8 +17,7 @@ pub struct Node { impl Node { pub fn new(state: GameState, parent: i32, action: usize) -> Self { Node { - actions: AtomicPtr::new(std::ptr::null_mut()), - num_actions: AtomicU16::new(0), + actions: AtomicVec::new(), state: AtomicU16::new(u16::from(state)), parent: AtomicI32::new(parent), bwd_link: AtomicI32::new(-1), @@ -50,29 +42,17 @@ impl Node { } pub fn num_actions(&self) -> usize { - usize::from(self.num_actions.load(Ordering::Relaxed)) + self.actions.len() } pub fn actions(&self) -> &[Edge] { - let ptr = self.actions.load(Ordering::Relaxed); - - if ptr.is_null() { - return &[]; - } - - unsafe { - std::slice::from_raw_parts(ptr, self.num_actions()) - } + self.actions.elements() } pub fn state(&self) -> GameState { GameState::from(self.state.load(Ordering::Relaxed)) } - //pub fn hash(&self) -> u64 { - // self.hash.load(Ordering::Relaxed) - //} - pub fn bwd_link(&self) -> i32 { self.bwd_link.load(Ordering::Relaxed) } @@ -86,7 +66,7 @@ impl Node { } pub fn has_children(&self) -> bool { - !self.actions.load(Ordering::Relaxed).is_null() + self.actions.len() != 0 } pub fn action(&self) -> usize { @@ -103,14 +83,7 @@ impl Node { } pub fn clear(&self) { - let ptr = self.actions.load(Ordering::Relaxed); - let layout = Layout::from_size_align(EDGE_SIZE * self.num_actions(), EDGE_ALIGN).unwrap(); - unsafe { - alloc::dealloc(ptr.cast(), layout); - } - - self.actions.store(std::ptr::null_mut(), Ordering::Relaxed); - self.num_actions.store(0, Ordering::Relaxed); + self.actions.clear(); self.set_state(GameState::Ongoing); self.set_bwd_link(-1); self.set_fwd_link(-1); @@ -157,11 +130,7 @@ impl Node { } if num != 0 { - let layout = Layout::from_size_align(EDGE_SIZE * num, EDGE_ALIGN).unwrap(); - let ptr = unsafe { alloc::alloc_zeroed(layout) }; - - self.num_actions.store(num as u16, Ordering::Relaxed); - self.actions.store(ptr.cast(), Ordering::Release); + self.actions.alloc(num); } for (action, &(mov, policy)) in self.actions().iter().zip(moves[..num].iter()) { diff --git a/src/tree/vec.rs b/src/tree/vec.rs new file mode 100644 index 00000000..cb8a5012 --- /dev/null +++ b/src/tree/vec.rs @@ -0,0 +1,100 @@ +use std::{ + alloc::{self, Layout}, sync::atomic::{AtomicPtr, AtomicU16, Ordering} +}; + +use super::Edge; + +const EDGE_SIZE: usize = std::mem::size_of::(); +const EDGE_ALIGN: usize = std::mem::align_of::(); + +#[derive(Debug)] +pub struct AtomicVec { + ptr: AtomicPtr, + len: AtomicU16, + cap: AtomicU16, +} + +impl Drop for AtomicVec { + fn drop(&mut self) { + self.dealloc(); + } +} + +impl AtomicVec { + pub fn new() -> Self { + Self { + ptr: AtomicPtr::new(std::ptr::null_mut()), + len: AtomicU16::new(0), + cap: AtomicU16::new(0), + } + } + + pub fn alloc(&self, len: usize) { + if self.cap() > len { + self.len.store(len as u16, Ordering::Relaxed); + return; + } + + if len == 0 { + return; + } + + self.dealloc(); + + self.len.store(len as u16, Ordering::Relaxed); + self.cap.store(len as u16, Ordering::Relaxed); + + let layout = Layout::from_size_align(EDGE_SIZE * self.cap(), EDGE_ALIGN).unwrap(); + + unsafe { + let ptr = alloc::alloc(layout); + self.ptr.store(ptr.cast(), Ordering::Relaxed); + } + } + + pub fn dealloc(&self) { + let ptr = self.ptr(); + + if ptr.is_null() { + return; + } + + let layout = Layout::from_size_align(EDGE_SIZE * self.cap(), EDGE_ALIGN).unwrap(); + + self.ptr.store(std::ptr::null_mut(), Ordering::Relaxed); + self.len.store(0, Ordering::Relaxed); + self.cap.store(0, Ordering::Relaxed); + + unsafe { + alloc::dealloc(ptr.cast(), layout); + } + } + + pub fn clear(&self) { + self.len.store(0, Ordering::Relaxed); + } + + fn ptr(&self) -> *mut Edge { + self.ptr.load(Ordering::Relaxed) + } + + fn cap(&self) -> usize { + usize::from(self.cap.load(Ordering::Relaxed)) + } + + pub fn len(&self) -> usize { + usize::from(self.len.load(Ordering::Relaxed)) + } + + pub fn elements(&self) -> &[Edge] { + let ptr = self.ptr.load(Ordering::Relaxed); + + if ptr.is_null() { + return &[]; + } + + unsafe { + std::slice::from_raw_parts(ptr, self.len()) + } + } +} \ No newline at end of file From 5d665d2a723a3b9ce90df52cd353c16b7d2f41e3 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Thu, 18 Jul 2024 00:00:39 +0100 Subject: [PATCH 14/48] Bench: 2261080 --- datagen/src/thread.rs | 2 +- src/mcts.rs | 37 +++++++--------- src/tree.rs | 44 +++++++++++++++---- src/tree/node.rs | 57 +++++++++++++----------- src/tree/vec.rs | 100 ------------------------------------------ 5 files changed, 84 insertions(+), 156 deletions(-) delete mode 100644 src/tree/vec.rs diff --git a/datagen/src/thread.rs b/datagen/src/thread.rs index e658c65f..210e8a54 100644 --- a/datagen/src/thread.rs +++ b/datagen/src/thread.rs @@ -139,7 +139,7 @@ impl<'a> DatagenThread<'a> { if root_count <= 112 { let mut policy_pos = PolicyData::new(position.clone(), bm, score); - for action in tree[tree.root_node()].actions() { + for action in tree[tree.root_node()].actions().iter() { policy_pos.push(action.mov().into(), action.visits()); } diff --git a/src/mcts.rs b/src/mcts.rs index fde4061f..d75008cd 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -177,7 +177,7 @@ impl<'a> Searcher<'a> { } let best_action = self.tree.get_best_child(self.tree.root_node()); - let best_child = &self.tree.edge(self.tree.root_node(), best_action); + let best_child = self.tree.edge_copy(self.tree.root_node(), best_action); (Move::from(best_child.mov()), best_child.q()) } @@ -207,7 +207,7 @@ impl<'a> Searcher<'a> { let action = self.tree[ptr].action(); let mut child_state = GameState::Ongoing; - let pvisits = self.tree.edge(parent, action).visits(); + let pvisits = self.tree.get_edge_visits(parent, action); let mut u = if self.tree[ptr].is_terminal() || pvisits == 0 { // probe hash table to use in place of network @@ -229,7 +229,7 @@ impl<'a> Searcher<'a> { // select action to take via PUCT let action = self.pick_action(ptr); - let edge = self.tree.edge(ptr, action); + let edge = self.tree.edge_copy(ptr, action); pos.make_move(Move::from(edge.mov())); let mut child_ptr = edge.ptr(); @@ -238,7 +238,7 @@ impl<'a> Searcher<'a> { if child_ptr == -1 { let state = pos.game_state(); child_ptr = self.tree.push_new(state, ptr, action); - self.tree.edge(ptr, action).set_ptr(child_ptr); + self.tree.set_edge_ptr(ptr, action, child_ptr); } let u = self.perform_one_iteration(pos, child_ptr, depth); @@ -249,10 +249,9 @@ impl<'a> Searcher<'a> { // flip perspective of score u = 1.0 - u; - self.tree.edge(parent, action).update(u); + let new_q = self.tree.update_edge(parent, action, u); - let edge = self.tree.edge(parent, action); - self.tree.push_hash(hash, edge.q()); + self.tree.push_hash(hash, new_q); self.tree.propogate_proven_mates(ptr, child_state); @@ -276,12 +275,12 @@ impl<'a> Searcher<'a> { } let node = &self.tree[ptr]; - let edge = self.tree.edge(node.parent(), node.action()); + let edge = self.tree.edge_copy(node.parent(), node.action()); let is_root = edge.ptr() == self.tree.root_node(); - let cpuct = SearchHelpers::get_cpuct(self.params, edge, is_root); - let fpu = SearchHelpers::get_fpu(edge); - let expl_scale = SearchHelpers::get_explore_scaling(self.params, edge); + let cpuct = SearchHelpers::get_cpuct(self.params, &edge, is_root); + let fpu = SearchHelpers::get_fpu(&edge); + let expl_scale = SearchHelpers::get_explore_scaling(self.params, &edge); let expl = cpuct * expl_scale; @@ -324,7 +323,7 @@ impl<'a> Searcher<'a> { let mate = self.tree[self.tree.root_node()].is_terminal(); let idx = self.tree.get_best_child(self.tree.root_node()); - let mut action = self.tree.edge(self.tree.root_node(), idx); + let mut action = self.tree.edge_copy(self.tree.root_node(), idx); let score = if action.ptr() != -1 { match self.tree[action.ptr()].state() { @@ -347,21 +346,21 @@ impl<'a> Searcher<'a> { break; } - action = self.tree.edge(action.ptr(), idx); + action = self.tree.edge_copy(action.ptr(), idx); depth = depth.saturating_sub(1); } (pv, score) } - fn get_best_action(&self) -> &Edge { + fn get_best_action(&self) -> Edge { let idx = self.tree.get_best_child(self.tree.root_node()); - self.tree.edge(self.tree.root_node(), idx) + self.tree.edge_copy(self.tree.root_node(), idx) } fn get_best_move(&self) -> Move { let idx = self.tree.get_best_child(self.tree.root_node()); - let action = self.tree.edge(self.tree.root_node(), idx); + let action = self.tree.edge_copy(self.tree.root_node(), idx); Move::from(action.mov()) } @@ -369,12 +368,8 @@ impl<'a> Searcher<'a> { -400.0 * (1.0 / score.clamp(0.0, 1.0) - 1.0).ln() } - //pub fn tree_and_board(self) -> (Tree, ChessState) { - // (self.tree, self.root_position) - //} - pub fn display_moves(&self) { - for action in self.tree[self.tree.root_node()].actions() { + for action in self.tree[self.tree.root_node()].actions().iter() { let mov = self.root_position.conv_mov_to_str(action.mov().into()); let q = action.q() * 100.0; println!("{mov} -> {q:.2}%"); diff --git a/src/tree.rs b/src/tree.rs index fa24a697..45f03c31 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -1,7 +1,6 @@ mod edge; mod hash; mod node; -mod vec; pub use edge::Edge; use hash::{HashEntry, HashTable}; @@ -77,7 +76,7 @@ impl Tree { let parent = self[new].parent(); let action = self[new].action(); - self.edge(parent, action).set_ptr(-1); + self.set_edge_ptr(parent, action, -1); self.delete(new); } @@ -197,17 +196,44 @@ impl Tree { pub fn make_root_node(&mut self, node: i32) { self.root.store(node, Ordering::Relaxed); - self.parent_edge = self.edge(self[node].parent(), self[node].action()).clone(); + self.parent_edge = self.edge_copy(self[node].parent(), self[node].action()); self[node].clear_parent(); self[node].set_state(GameState::Ongoing); } - pub fn edge(&self, ptr: i32, idx: usize) -> &Edge { + pub fn edge_copy(&self, ptr: i32, idx: usize) -> Edge { if ptr == -1 { + self.parent_edge.clone() + } else { + self[ptr].actions()[idx].clone() + } + } + + pub fn set_edge_ptr(&self, ptr: i32, idx: usize, set: i32) { + if ptr == -1 { + self.parent_edge.set_ptr(set); + } else { + self[ptr].actions()[idx].set_ptr(set); + } + } + + pub fn get_edge_visits(&self, ptr: i32, idx: usize) -> i32 { + if ptr == -1 { + self.parent_edge.visits() + } else { + self[ptr].actions()[idx].visits() + } + } + + pub fn update_edge(&self, ptr: i32, idx: usize, u: f32) -> f32 { + let edge = if ptr == -1 { &self.parent_edge } else { &self[ptr].actions()[idx] - } + }; + + edge.update(u); + edge.q() } pub fn propogate_proven_mates(&self, ptr: i32, child_state: GameState) { @@ -220,7 +246,7 @@ impl Tree { GameState::Won(n) => { let mut proven_loss = true; let mut max_win_len = n; - for action in self[ptr].actions() { + for action in self[ptr].actions().iter() { if action.ptr() == -1 { proven_loss = false; break; @@ -301,7 +327,7 @@ impl Tree { let node = &self.tree[start as usize]; - for action in node.actions() { + for action in node.actions().iter() { let child_idx = action.ptr(); let mut child_board = this_board.clone(); @@ -396,9 +422,9 @@ impl Tree { } let mut active = Vec::new(); - for action in node.actions() { + for action in node.actions().iter() { if action.ptr() != -1 { - active.push(action); + active.push(action.clone()); } } diff --git a/src/tree/node.rs b/src/tree/node.rs index 89465f42..bd1ccf22 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -1,10 +1,10 @@ -use std::sync::atomic::{AtomicI32, AtomicU16, Ordering}; +use std::sync::{atomic::{AtomicI32, AtomicU16, Ordering}, RwLock, RwLockReadGuard, RwLockWriteGuard}; -use crate::{chess::Move, tree::{Edge, vec::AtomicVec}, ChessState, GameState, MctsParams, PolicyNetwork}; +use crate::{chess::Move, tree::Edge, ChessState, GameState, MctsParams, PolicyNetwork}; #[derive(Debug)] pub struct Node { - actions: AtomicVec, + actions: RwLock>, state: AtomicU16, // used for lru @@ -17,7 +17,7 @@ pub struct Node { impl Node { pub fn new(state: GameState, parent: i32, action: usize) -> Self { Node { - actions: AtomicVec::new(), + actions: RwLock::new(Vec::new()), state: AtomicU16::new(u16::from(state)), parent: AtomicI32::new(parent), bwd_link: AtomicI32::new(-1), @@ -42,11 +42,15 @@ impl Node { } pub fn num_actions(&self) -> usize { - self.actions.len() + self.actions.read().unwrap().len() } - pub fn actions(&self) -> &[Edge] { - self.actions.elements() + pub fn actions(&self) -> RwLockReadGuard> { + self.actions.read().unwrap() + } + + fn actions_mut(&self) -> RwLockWriteGuard> { + self.actions.write().unwrap() } pub fn state(&self) -> GameState { @@ -66,7 +70,7 @@ impl Node { } pub fn has_children(&self) -> bool { - self.actions.len() != 0 + self.actions.read().unwrap().len() != 0 } pub fn action(&self) -> usize { @@ -83,7 +87,7 @@ impl Node { } pub fn clear(&self) { - self.actions.clear(); + self.actions.write().unwrap().clear(); self.set_state(GameState::Ongoing); self.set_bwd_link(-1); self.set_fwd_link(-1); @@ -107,34 +111,37 @@ impl Node { let feats = pos.get_policy_feats(); let mut max = f32::NEG_INFINITY; - let mut moves = [(0, 0.0); 256]; - let mut num = 0; + + let mut actions = self.actions_mut(); pos.map_legal_moves(|mov| { let policy = pos.get_policy(mov, &feats, policy); + + // trick for calculating policy before quantising + actions.push(Edge::new(f32::to_bits(policy) as i32, mov.into(), 0)); max = max.max(policy); - moves[num] = (mov.into(), policy); - num += 1; }); let mut total = 0.0; - for (_, policy) in moves[..num].iter_mut() { - *policy = if ROOT { - ((*policy - max) / params.root_pst()).exp() + for action in actions.iter_mut() { + let mut policy = f32::from_bits(action.ptr() as u32); + + policy = if ROOT { + ((policy - max) / params.root_pst()).exp() } else { - (*policy - max).exp() + (policy - max).exp() }; - total += *policy; - } + action.set_ptr(f32::to_bits(policy) as i32); - if num != 0 { - self.actions.alloc(num); + total += policy; } - for (action, &(mov, policy)) in self.actions().iter().zip(moves[..num].iter()) { - action.set_new(mov, policy / total); + for action in actions.iter_mut() { + let policy = f32::from_bits(action.ptr() as u32) / total; + action.set_ptr(-1); + action.set_policy(policy); } } @@ -149,7 +156,7 @@ impl Node { let mut policies = Vec::new(); - for action in self.actions() { + for action in self.actions().iter() { let mov = Move::from(action.mov()); let policy = pos.get_policy(mov, &feats, policy); policies.push(policy); @@ -163,7 +170,7 @@ impl Node { total += *policy; } - for (i, action) in self.actions().iter().enumerate() { + for (i, action) in self.actions_mut().iter_mut().enumerate() { action.set_policy(policies[i] / total); } } diff --git a/src/tree/vec.rs b/src/tree/vec.rs deleted file mode 100644 index cb8a5012..00000000 --- a/src/tree/vec.rs +++ /dev/null @@ -1,100 +0,0 @@ -use std::{ - alloc::{self, Layout}, sync::atomic::{AtomicPtr, AtomicU16, Ordering} -}; - -use super::Edge; - -const EDGE_SIZE: usize = std::mem::size_of::(); -const EDGE_ALIGN: usize = std::mem::align_of::(); - -#[derive(Debug)] -pub struct AtomicVec { - ptr: AtomicPtr, - len: AtomicU16, - cap: AtomicU16, -} - -impl Drop for AtomicVec { - fn drop(&mut self) { - self.dealloc(); - } -} - -impl AtomicVec { - pub fn new() -> Self { - Self { - ptr: AtomicPtr::new(std::ptr::null_mut()), - len: AtomicU16::new(0), - cap: AtomicU16::new(0), - } - } - - pub fn alloc(&self, len: usize) { - if self.cap() > len { - self.len.store(len as u16, Ordering::Relaxed); - return; - } - - if len == 0 { - return; - } - - self.dealloc(); - - self.len.store(len as u16, Ordering::Relaxed); - self.cap.store(len as u16, Ordering::Relaxed); - - let layout = Layout::from_size_align(EDGE_SIZE * self.cap(), EDGE_ALIGN).unwrap(); - - unsafe { - let ptr = alloc::alloc(layout); - self.ptr.store(ptr.cast(), Ordering::Relaxed); - } - } - - pub fn dealloc(&self) { - let ptr = self.ptr(); - - if ptr.is_null() { - return; - } - - let layout = Layout::from_size_align(EDGE_SIZE * self.cap(), EDGE_ALIGN).unwrap(); - - self.ptr.store(std::ptr::null_mut(), Ordering::Relaxed); - self.len.store(0, Ordering::Relaxed); - self.cap.store(0, Ordering::Relaxed); - - unsafe { - alloc::dealloc(ptr.cast(), layout); - } - } - - pub fn clear(&self) { - self.len.store(0, Ordering::Relaxed); - } - - fn ptr(&self) -> *mut Edge { - self.ptr.load(Ordering::Relaxed) - } - - fn cap(&self) -> usize { - usize::from(self.cap.load(Ordering::Relaxed)) - } - - pub fn len(&self) -> usize { - usize::from(self.len.load(Ordering::Relaxed)) - } - - pub fn elements(&self) -> &[Edge] { - let ptr = self.ptr.load(Ordering::Relaxed); - - if ptr.is_null() { - return &[]; - } - - unsafe { - std::slice::from_raw_parts(ptr, self.len()) - } - } -} \ No newline at end of file From af94337c86806cc373178f2f52b8b06f1de21388 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Fri, 19 Jul 2024 00:06:20 +0100 Subject: [PATCH 15/48] Start work on lrc --- src/mcts.rs | 80 +++++--------- src/mcts/helpers.rs | 18 +-- src/tree.rs | 260 ++++++++++---------------------------------- src/tree/edge.rs | 73 +++++-------- src/tree/half.rs | 50 +++++++++ src/tree/node.rs | 61 ++--------- src/tree/ptr.rs | 30 +++++ src/tree/stats.rs | 67 ++++++++++++ src/uci.rs | 5 - 9 files changed, 281 insertions(+), 363 deletions(-) create mode 100644 src/tree/half.rs create mode 100644 src/tree/ptr.rs create mode 100644 src/tree/stats.rs diff --git a/src/mcts.rs b/src/mcts.rs index d75008cd..606b54ee 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -6,7 +6,7 @@ pub use params::MctsParams; use crate::{ chess::Move, - tree::{Edge, Tree}, + tree::{ActionStats, Edge, NodePtr, Tree}, ChessState, GameState, PolicyNetwork, ValueNetwork, }; @@ -81,7 +81,15 @@ impl<'a> Searcher<'a> { loop { let mut pos = self.root_position.clone(); let mut this_depth = 0; - self.perform_one_iteration(&mut pos, self.tree.root_node(), &mut this_depth); + + let u = self.perform_one_iteration( + &mut pos, + self.tree.root_node(), + self.tree.root_stats(), + &mut this_depth, + ); + + self.tree.root_stats().update(u); cumulative_depth += this_depth - 1; @@ -181,35 +189,14 @@ impl<'a> Searcher<'a> { (Move::from(best_child.mov()), best_child.q()) } - pub fn secondary_search(&mut self) { - loop { - let mut pos = self.root_position.clone(); - self.perform_one_iteration(&mut pos, self.tree.root_node(), &mut 0); - - if self.abort.load(Ordering::Relaxed) { - break; - } - - if self.tree[self.tree.root_node()].is_terminal() { - self.abort.store(true, Ordering::Relaxed); - break; - } - } - } - - fn perform_one_iteration(&mut self, pos: &mut ChessState, ptr: i32, depth: &mut usize) -> f32 { + fn perform_one_iteration(&mut self, pos: &mut ChessState, ptr: NodePtr, node_stats: &ActionStats, depth: &mut usize) -> f32 { *depth += 1; - self.tree.make_recently_used(ptr); - let hash = pos.hash(); - let parent = self.tree[ptr].parent(); - let action = self.tree[ptr].action(); let mut child_state = GameState::Ongoing; - let pvisits = self.tree.get_edge_visits(parent, action); - let mut u = if self.tree[ptr].is_terminal() || pvisits == 0 { + let u = if self.tree[ptr].is_terminal() || node_stats.visits() == 0 { // probe hash table to use in place of network if self.tree[ptr].state() == GameState::Ongoing { if let Some(entry) = self.tree.probe_hash(hash) { @@ -227,7 +214,7 @@ impl<'a> Searcher<'a> { } // select action to take via PUCT - let action = self.pick_action(ptr); + let action = self.pick_action(ptr, &node_stats); let edge = self.tree.edge_copy(ptr, action); pos.make_move(Move::from(edge.mov())); @@ -235,32 +222,28 @@ impl<'a> Searcher<'a> { let mut child_ptr = edge.ptr(); // create and push node if not present - if child_ptr == -1 { + if child_ptr.is_null() { let state = pos.game_state(); - child_ptr = self.tree.push_new(state, ptr, action); + child_ptr = self.tree.push_new(state); self.tree.set_edge_ptr(ptr, action, child_ptr); } - let u = self.perform_one_iteration(pos, child_ptr, depth); + let u = self.perform_one_iteration(pos, child_ptr, &edge.stats(), depth); + + let new_q = self.tree.update_edge_stats(ptr, action, u); + self.tree.push_hash(hash, new_q); + child_state = self.tree[child_ptr].state(); u }; - // flip perspective of score - u = 1.0 - u; - let new_q = self.tree.update_edge(parent, action, u); - - self.tree.push_hash(hash, new_q); - self.tree.propogate_proven_mates(ptr, child_state); - self.tree.make_recently_used(ptr); - - u + 1.0 - u } - fn get_utility(&self, ptr: i32, pos: &ChessState) -> f32 { + fn get_utility(&self, ptr: NodePtr, pos: &ChessState) -> f32 { match self.tree[ptr].state() { GameState::Ongoing => pos.get_value_wdl(self.value, self.params), GameState::Draw => 0.5, @@ -269,18 +252,16 @@ impl<'a> Searcher<'a> { } } - fn pick_action(&self, ptr: i32) -> usize { + fn pick_action(&self, ptr: NodePtr, node_stats: &ActionStats) -> usize { if !self.tree[ptr].has_children() { panic!("trying to pick from no children!"); } - let node = &self.tree[ptr]; - let edge = self.tree.edge_copy(node.parent(), node.action()); - let is_root = edge.ptr() == self.tree.root_node(); + let is_root = ptr == self.tree.root_node(); - let cpuct = SearchHelpers::get_cpuct(self.params, &edge, is_root); - let fpu = SearchHelpers::get_fpu(&edge); - let expl_scale = SearchHelpers::get_explore_scaling(self.params, &edge); + let cpuct = SearchHelpers::get_cpuct(self.params, &node_stats, is_root); + let fpu = SearchHelpers::get_fpu(&node_stats); + let expl_scale = SearchHelpers::get_explore_scaling(self.params, &node_stats); let expl = cpuct * expl_scale; @@ -308,9 +289,8 @@ impl<'a> Searcher<'a> { let elapsed = timer.elapsed(); let nps = nodes as f32 / elapsed.as_secs_f32(); let ms = elapsed.as_millis(); - let hf = self.tree.len() * 1000 / self.tree.cap(); - print!("time {ms} nodes {nodes} nps {nps:.0} hashfull {hf} pv"); + print!("time {ms} nodes {nodes} nps {nps:.0} pv"); for mov in pv_line { print!(" {}", self.root_position.conv_mov_to_str(mov)); @@ -325,7 +305,7 @@ impl<'a> Searcher<'a> { let idx = self.tree.get_best_child(self.tree.root_node()); let mut action = self.tree.edge_copy(self.tree.root_node(), idx); - let score = if action.ptr() != -1 { + let score = if action.ptr().is_null() { match self.tree[action.ptr()].state() { GameState::Lost(_) => 1.1, GameState::Won(_) => -0.1, @@ -338,7 +318,7 @@ impl<'a> Searcher<'a> { let mut pv = Vec::new(); - while (mate || depth > 0) && action.ptr() != -1 { + while (mate || depth > 0) && action.ptr().is_null() { pv.push(Move::from(action.mov())); let idx = self.tree.get_best_child(action.ptr()); diff --git a/src/mcts/helpers.rs b/src/mcts/helpers.rs index 7c7c4c79..5443fc71 100644 --- a/src/mcts/helpers.rs +++ b/src/mcts/helpers.rs @@ -1,4 +1,4 @@ -use crate::{mcts::MctsParams, tree::Edge}; +use crate::{mcts::MctsParams, tree::{ActionStats, Edge}}; pub struct SearchHelpers; @@ -6,7 +6,7 @@ impl SearchHelpers { /// CPUCT /// /// Larger value implies more exploration. - pub fn get_cpuct(params: &MctsParams, parent: &Edge, is_root: bool) -> f32 { + pub fn get_cpuct(params: &MctsParams, node_stats: &ActionStats, is_root: bool) -> f32 { // baseline CPUCT value let mut cpuct = if is_root { params.root_cpuct() @@ -16,11 +16,11 @@ impl SearchHelpers { // scale CPUCT as visits increase let scale = params.cpuct_visits_scale() * 128.0; - cpuct *= 1.0 + ((parent.visits() as f32 + scale) / scale).ln(); + cpuct *= 1.0 + ((node_stats.visits() as f32 + scale) / scale).ln(); // scale CPUCT with variance of Q - if parent.visits() > 1 { - let frac = parent.var().sqrt() / params.cpuct_var_scale(); + if node_stats.visits() > 1 { + let frac = node_stats.var().sqrt() / params.cpuct_var_scale(); cpuct *= 1.0 + params.cpuct_var_weight() * (frac - 1.0); } @@ -30,16 +30,16 @@ impl SearchHelpers { /// Exploration Scaling /// /// Larger value implies more exploration. - pub fn get_explore_scaling(params: &MctsParams, parent: &Edge) -> f32 { - (params.expl_tau() * (parent.visits().max(1) as f32).ln()).exp() + pub fn get_explore_scaling(params: &MctsParams, node_stats: &ActionStats) -> f32 { + (params.expl_tau() * (node_stats.visits().max(1) as f32).ln()).exp() } /// First Play Urgency /// /// #### Note /// Must return a value in [0, 1]. - pub fn get_fpu(parent: &Edge) -> f32 { - 1.0 - parent.q() + pub fn get_fpu(node_stats: &ActionStats) -> f32 { + 1.0 - node_stats.q() } /// Get a predicted win probability for an action diff --git a/src/tree.rs b/src/tree.rs index 45f03c31..fdd2ac8e 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -1,12 +1,19 @@ mod edge; +mod half; mod hash; mod node; +mod ptr; +mod stats; pub use edge::Edge; +use half::TreeHalf; use hash::{HashEntry, HashTable}; pub use node::Node; +pub use ptr::NodePtr; +pub use stats::ActionStats; + use std::{ - sync::atomic::{AtomicI32, AtomicUsize, Ordering}, + sync::atomic::{AtomicBool, Ordering}, time::Instant, }; @@ -16,21 +23,17 @@ use crate::{ }; pub struct Tree { - tree: Vec, + tree: [TreeHalf; 2], + half: AtomicBool, hash: HashTable, - root: AtomicI32, - empty: AtomicI32, - used: AtomicUsize, - lru_head: AtomicI32, - lru_tail: AtomicI32, - parent_edge: Edge, + root_stats: ActionStats, } -impl std::ops::Index for Tree { +impl std::ops::Index for Tree { type Output = Node; - fn index(&self, index: i32) -> &Self::Output { - &self.tree[index as usize] + fn index(&self, index: NodePtr) -> &Self::Output { + &self.tree[usize::from(index.half())][index] } } @@ -41,202 +44,58 @@ impl Tree { } fn new(cap: usize) -> Self { - let mut tree = Self { - tree: Vec::with_capacity(cap / 8), + let tree_size = cap / 8; + + Self { + tree: [ + TreeHalf::new(tree_size, false), + TreeHalf::new(tree_size, true), + ], + half: AtomicBool::new(false), hash: HashTable::new(cap / 16), - root: AtomicI32::new(-1), - empty: AtomicI32::new(0), - used: AtomicUsize::new(0), - lru_head: AtomicI32::new(-1), - lru_tail: AtomicI32::new(-1), - parent_edge: Edge::new(0, 0, 0), - }; - - for _ in 0..cap / 8 { - tree.tree.push(Node::new(GameState::Ongoing, -1, 0)); - } - - let end = tree.cap() as i32 - 1; - - for i in 0..end { - tree[i].set_fwd_link(i + 1); - } - - tree[end].set_fwd_link(-1); - - tree - } - - pub fn push_new(&self, state: GameState, parent: i32, action: usize) -> i32 { - let mut new = self.empty.load(Ordering::Relaxed); - - // tree is full, do some LRU pruning - if new == -1 { - new = self.lru_tail.load(Ordering::Relaxed); - let parent = self[new].parent(); - let action = self[new].action(); - - self.set_edge_ptr(parent, action, -1); - - self.delete(new); - } - - assert_ne!(new, -1); - - self.used.fetch_add(1, Ordering::Relaxed); - self.empty.store(self[self.empty.load(Ordering::Relaxed)].fwd_link(), Ordering::Relaxed); - self[new].set_new(state, parent, action); - - self.append_to_lru(new); - - if self.used.load(Ordering::Relaxed) == 1 { - self.lru_tail.store(new, Ordering::Relaxed); - } - - new - } - - pub fn probe_hash(&self, hash: u64) -> Option { - self.hash.get(hash) - } - - pub fn push_hash(&self, hash: u64, wins: f32) { - self.hash.push(hash, wins); - } - - pub fn delete(&self, ptr: i32) { - self.remove_from_lru(ptr); - self[ptr].clear(); - - let empty = self.empty.load(Ordering::Relaxed); - self[ptr].set_fwd_link(empty); - - self.empty.store(ptr, Ordering::Relaxed); - let used = self.used.fetch_sub(1, Ordering::Relaxed); - assert!(used - 1 < self.cap()); - } - - pub fn make_recently_used(&self, ptr: i32) { - self.remove_from_lru(ptr); - self.append_to_lru(ptr); - } - - fn append_to_lru(&self, ptr: i32) { - let old_head = self.lru_head.load(Ordering::Relaxed); - if old_head != -1 { - self[old_head].set_bwd_link(ptr); - } - self.lru_head.store(ptr, Ordering::Relaxed); - self[ptr].set_fwd_link(old_head); - self[ptr].set_bwd_link(-1); - } - - fn remove_from_lru(&self, ptr: i32) { - let bwd = self[ptr].bwd_link(); - let fwd = self[ptr].fwd_link(); - - if bwd != -1 { - self[bwd].set_fwd_link(fwd); - } else { - self.lru_head.store(fwd, Ordering::Relaxed); + root_stats: ActionStats::default(), } - - if fwd != -1 { - self[fwd].set_bwd_link(bwd); - } else { - self.lru_tail.store(bwd, Ordering::Relaxed); - } - - self[ptr].set_bwd_link(-1); - self[ptr].set_fwd_link(-1); - } - - pub fn root_node(&self) -> i32 { - self.root.load(Ordering::Relaxed) - } - - pub fn cap(&self) -> usize { - self.tree.len() } - pub fn len(&self) -> usize { - self.used.load(Ordering::Relaxed) + fn half(&self) -> usize { + usize::from(self.half.load(Ordering::Relaxed)) } - pub fn remaining(&self) -> usize { - self.cap() - self.len() + pub fn push_new(&self, state: GameState) -> NodePtr { + self.tree[self.half()].push_new(state) } - pub fn is_empty(&self) -> bool { - self.len() == 0 + pub fn root_node(&self) -> NodePtr { + NodePtr::new(self.half.load(Ordering::Relaxed), 0) } - pub fn clear(&mut self) { - if self.is_empty() { - return; - } - - self.hash.clear(); - self.root.store(-1, Ordering::Relaxed); - self.empty.store(0, Ordering::Relaxed); - self.used.store(0, Ordering::Relaxed); - self.lru_head.store(-1, Ordering::Relaxed); - self.lru_tail.store(-1, Ordering::Relaxed); - self.parent_edge = Edge::new(0, 0, 0); - - let end = self.cap() as i32 - 1; - - for i in 0..end { - self[i].set_new(GameState::Ongoing, -1, 0); - self[i].set_fwd_link(i + 1); - } - - self[end].set_fwd_link(-1); + pub fn root_stats(&self) -> &ActionStats { + &self.root_stats } - pub fn make_root_node(&mut self, node: i32) { - self.root.store(node, Ordering::Relaxed); - self.parent_edge = self.edge_copy(self[node].parent(), self[node].action()); - self[node].clear_parent(); - self[node].set_state(GameState::Ongoing); + pub fn edge_copy(&self, ptr: NodePtr, action: usize) -> Edge { + self[ptr].actions()[action].clone() } - pub fn edge_copy(&self, ptr: i32, idx: usize) -> Edge { - if ptr == -1 { - self.parent_edge.clone() - } else { - self[ptr].actions()[idx].clone() - } + pub fn set_edge_ptr(&self, ptr: NodePtr, action: usize, set: NodePtr) { + self[ptr].actions()[action].set_ptr(set); } - pub fn set_edge_ptr(&self, ptr: i32, idx: usize, set: i32) { - if ptr == -1 { - self.parent_edge.set_ptr(set); - } else { - self[ptr].actions()[idx].set_ptr(set); - } + pub fn update_edge_stats(&self, ptr: NodePtr, action: usize, result: f32) -> f32 { + let edge = &self[ptr].actions()[action]; + edge.update(result); + edge.q() } - pub fn get_edge_visits(&self, ptr: i32, idx: usize) -> i32 { - if ptr == -1 { - self.parent_edge.visits() - } else { - self[ptr].actions()[idx].visits() - } + pub fn probe_hash(&self, hash: u64) -> Option { + self.hash.get(hash) } - pub fn update_edge(&self, ptr: i32, idx: usize, u: f32) -> f32 { - let edge = if ptr == -1 { - &self.parent_edge - } else { - &self[ptr].actions()[idx] - }; - - edge.update(u); - edge.q() + pub fn push_hash(&self, hash: u64, wins: f32) { + self.hash.push(hash, wins); } - pub fn propogate_proven_mates(&self, ptr: i32, child_state: GameState) { + pub fn propogate_proven_mates(&self, ptr: NodePtr, child_state: GameState) { match child_state { // if the child node resulted in a loss, then // this node has a guaranteed win @@ -247,7 +106,7 @@ impl Tree { let mut proven_loss = true; let mut max_win_len = n; for action in self[ptr].actions().iter() { - if action.ptr() == -1 { + if action.ptr().is_null() { proven_loss = false; break; } else if let GameState::Won(n) = self[action.ptr()].state() { @@ -271,8 +130,7 @@ impl Tree { let t = Instant::now(); if self.is_empty() { - let node = self.push_new(GameState::Ongoing, -1, 0); - self.make_root_node(node); + let node = self.push_new(GameState::Ongoing); return; } @@ -286,7 +144,7 @@ impl Tree { let root = self.recurse_find(self.root_node(), board, root, 2); - if root != -1 && self[root].has_children() { + if root.is_null() && self[root].has_children() { found = true; if root != self.root_node() { @@ -312,20 +170,20 @@ impl Tree { fn recurse_find( &self, - start: i32, + start: NodePtr, this_board: &ChessState, board: &ChessState, depth: u8, - ) -> i32 { + ) -> NodePtr { if this_board.is_same(board) { return start; } - if start == -1 || depth == 0 { - return -1; + if start.is_null() || depth == 0 { + return NodePtr::NULL; } - let node = &self.tree[start as usize]; + let node = &self[start]; for action in node.actions().iter() { let child_idx = action.ptr(); @@ -335,15 +193,15 @@ impl Tree { let found = self.recurse_find(child_idx, &child_board, board, depth - 1); - if found != -1 { + if found.is_null() { return found; } } - -1 + NodePtr::NULL } - pub fn get_best_child_by_key f32>(&self, ptr: i32, mut key: F) -> usize { + pub fn get_best_child_by_key f32>(&self, ptr: NodePtr, mut key: F) -> usize { let mut best_child = usize::MAX; let mut best_score = f32::NEG_INFINITY; @@ -359,11 +217,11 @@ impl Tree { best_child } - pub fn get_best_child(&self, ptr: i32) -> usize { + pub fn get_best_child(&self, ptr: NodePtr) -> usize { self.get_best_child_by_key(ptr, |child| { if child.visits() == 0 { f32::NEG_INFINITY - } else if child.ptr() != -1 { + } else if child.ptr().is_null() { match self[child.ptr()].state() { GameState::Lost(n) => 1.0 + f32::from(n), GameState::Won(n) => f32::from(n) - 256.0, @@ -376,7 +234,7 @@ impl Tree { }) } - pub fn display(&self, idx: i32, depth: usize) { + pub fn display(&self, idx: NodePtr, depth: usize) { let mut bars = vec![true; depth + 1]; self.display_recurse(&Edge::new(idx, 0, 0), depth + 1, 0, &mut bars); } @@ -423,7 +281,7 @@ impl Tree { let mut active = Vec::new(); for action in node.actions().iter() { - if action.ptr() != -1 { + if action.ptr().is_null() { active.push(action.clone()); } } diff --git a/src/tree/edge.rs b/src/tree/edge.rs index 98243116..3cd8dd18 100644 --- a/src/tree/edge.rs +++ b/src/tree/edge.rs @@ -1,24 +1,22 @@ -use std::sync::atomic::{AtomicI32, AtomicU16, AtomicU32, AtomicI16, Ordering}; +use std::sync::atomic::{AtomicU16, AtomicU32, AtomicI16, Ordering}; + +use super::{ActionStats, NodePtr}; #[derive(Debug)] pub struct Edge { - ptr: AtomicI32, + ptr: AtomicU32, mov: AtomicU16, policy: AtomicI16, - visits: AtomicI32, - q: AtomicU32, - sq_q: AtomicU32, + stats: ActionStats, } impl Clone for Edge { fn clone(&self) -> Self { Self { - ptr: AtomicI32::new(self.ptr()), + ptr: AtomicU32::new(self.ptr().inner()), mov: AtomicU16::new(self.mov()), policy: AtomicI16::new(self.policy.load(Ordering::Relaxed)), - visits: AtomicI32::new(self.visits()), - q: AtomicU32::new(self.q.load(Ordering::Relaxed)), - sq_q: AtomicU32::new(self.sq_q.load(Ordering::Relaxed)), + stats: self.stats.clone(), } } } @@ -26,71 +24,61 @@ impl Clone for Edge { impl Default for Edge { fn default() -> Self { Self { - ptr: AtomicI32::new(-1), + ptr: AtomicU32::new(NodePtr::NULL.inner()), mov: AtomicU16::new(0), policy: AtomicI16::new(0), - visits: AtomicI32::new(0), - q: AtomicU32::new(0), - sq_q: AtomicU32::new(0), + stats: ActionStats::default(), } } } impl Edge { - pub fn new(ptr: i32, mov: u16, policy: i16) -> Self { + pub fn new(ptr: NodePtr, mov: u16, policy: i16) -> Self { Self { - ptr: AtomicI32::new(ptr), + ptr: AtomicU32::new(ptr.inner()), mov: AtomicU16::new(mov), policy: AtomicI16::new(policy), - visits: AtomicI32::new(0), - q: AtomicU32::new(0), - sq_q: AtomicU32::new(0), + stats: ActionStats::default(), } } pub fn set_new(&self, mov: u16, policy: f32) { - self.ptr.store(-1, Ordering::Relaxed); + self.ptr.store(NodePtr::NULL.inner(), Ordering::Relaxed); self.mov.store(mov, Ordering::Relaxed); self.set_policy(policy); - self.visits.store(0, Ordering::Relaxed); - self.q.store(0, Ordering::Relaxed); - self.sq_q.store(0, Ordering::Relaxed); + self.stats.clear(); } - pub fn ptr(&self) -> i32 { - self.ptr.load(Ordering::Relaxed) + pub fn ptr(&self) -> NodePtr { + NodePtr::from_raw(self.ptr.load(Ordering::Relaxed)) } pub fn mov(&self) -> u16 { self.mov.load(Ordering::Relaxed) } - pub fn visits(&self) -> i32 { - self.visits.load(Ordering::Relaxed) - } - pub fn policy(&self) -> f32 { f32::from(self.policy.load(Ordering::Relaxed)) / f32::from(i16::MAX) } - fn q64(&self) -> f64 { - f64::from(self.q.load(Ordering::Relaxed)) / f64::from(u32::MAX) + pub fn stats(&self) -> ActionStats { + self.stats.clone() } - pub fn q(&self) -> f32 { - self.q64() as f32 + pub fn visits(&self) -> i32 { + self.stats.visits() } - pub fn sq_q(&self) -> f64 { - f64::from(self.sq_q.load(Ordering::Relaxed)) / f64::from(u32::MAX) + pub fn q(&self) -> f32 { + self.stats.q() } - pub fn var(&self) -> f32 { - (self.sq_q() - self.q64().powi(2)).max(0.0) as f32 + pub fn sq_q(&self) -> f64 { + self.stats.sq_q() } - pub fn set_ptr(&self, ptr: i32) { - self.ptr.store(ptr, Ordering::Relaxed); + pub fn set_ptr(&self, ptr: NodePtr) { + self.ptr.store(ptr.inner(), Ordering::Relaxed); } pub fn set_policy(&self, policy: f32) { @@ -98,13 +86,6 @@ impl Edge { } pub fn update(&self, result: f32) { - let r = f64::from(result); - let v = f64::from(self.visits.fetch_add(1, Ordering::Relaxed)); - - let q = (self.q64() * v + r) / (v + 1.0); - let sq_q = (self.sq_q() * v + r.powi(2)) / (v + 1.0); - - self.q.store((q * f64::from(u32::MAX)) as u32, Ordering::Relaxed); - self.sq_q.store((sq_q * f64::from(u32::MAX)) as u32, Ordering::Relaxed); + self.stats.update(result); } } diff --git a/src/tree/half.rs b/src/tree/half.rs new file mode 100644 index 00000000..85368173 --- /dev/null +++ b/src/tree/half.rs @@ -0,0 +1,50 @@ +use std::sync::atomic::AtomicUsize; + +use crate::GameState; +use super::{Node, NodePtr}; + + +pub struct TreeHalf { + nodes: Vec, + used: AtomicUsize, + half: bool, +} + +impl std::ops::Index for TreeHalf { + type Output = Node; + + fn index(&self, index: NodePtr) -> &Self::Output { + &self.nodes[index.idx()] + } +} + +impl TreeHalf { + pub fn new(size: usize, half: bool) -> Self { + let mut res = Self { + nodes: Vec::with_capacity(size), + used: AtomicUsize::new(0), + half, + }; + + for _ in 0..size { + res.nodes.push(Node::new(GameState::Ongoing)); + } + + res + } + + pub fn push_new(&self, state: GameState) -> NodePtr { + let idx = self.used.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + if idx == self.nodes.len() { + return NodePtr::NULL; + } + + self.nodes[idx].set_new(state); + + NodePtr::new(self.half, idx as u32) + } +} + + + diff --git a/src/tree/node.rs b/src/tree/node.rs index bd1ccf22..03d8e5e4 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -1,40 +1,24 @@ -use std::sync::{atomic::{AtomicI32, AtomicU16, Ordering}, RwLock, RwLockReadGuard, RwLockWriteGuard}; +use std::sync::{atomic::{AtomicU16, Ordering}, RwLock, RwLockReadGuard, RwLockWriteGuard}; -use crate::{chess::Move, tree::Edge, ChessState, GameState, MctsParams, PolicyNetwork}; +use crate::{chess::Move, tree::{Edge, NodePtr}, ChessState, GameState, MctsParams, PolicyNetwork}; #[derive(Debug)] pub struct Node { actions: RwLock>, state: AtomicU16, - - // used for lru - bwd_link: AtomicI32, - fwd_link: AtomicI32, - parent: AtomicI32, - action: AtomicU16, } impl Node { - pub fn new(state: GameState, parent: i32, action: usize) -> Self { + pub fn new(state: GameState) -> Self { Node { actions: RwLock::new(Vec::new()), state: AtomicU16::new(u16::from(state)), - parent: AtomicI32::new(parent), - bwd_link: AtomicI32::new(-1), - fwd_link: AtomicI32::new(-1), - action: AtomicU16::new(action as u16), } } - pub fn set_new(&self, state: GameState, parent: i32, action: usize) { + pub fn set_new(&self, state: GameState) { self.clear(); self.state.store(u16::from(state), Ordering::Relaxed); - self.parent.store(parent, Ordering::Relaxed); - self.action.store(action as u16, Ordering::Relaxed); - } - - pub fn parent(&self) -> i32 { - self.parent.load(Ordering::Relaxed) } pub fn is_terminal(&self) -> bool { @@ -57,14 +41,6 @@ impl Node { GameState::from(self.state.load(Ordering::Relaxed)) } - pub fn bwd_link(&self) -> i32 { - self.bwd_link.load(Ordering::Relaxed) - } - - pub fn fwd_link(&self) -> i32 { - self.fwd_link.load(Ordering::Relaxed) - } - pub fn set_state(&self, state: GameState) { self.state.store(u16::from(state), Ordering::Relaxed); } @@ -73,15 +49,6 @@ impl Node { self.actions.read().unwrap().len() != 0 } - pub fn action(&self) -> usize { - usize::from(self.action.load(Ordering::Relaxed)) - } - - pub fn clear_parent(&self) { - self.parent.store(-1, Ordering::Relaxed); - self.action.store(0, Ordering::Relaxed); - } - pub fn is_not_expanded(&self) -> bool { self.state() == GameState::Ongoing && !self.has_children() } @@ -89,16 +56,6 @@ impl Node { pub fn clear(&self) { self.actions.write().unwrap().clear(); self.set_state(GameState::Ongoing); - self.set_bwd_link(-1); - self.set_fwd_link(-1); - } - - pub fn set_fwd_link(&self, ptr: i32) { - self.fwd_link.store(ptr, Ordering::Relaxed); - } - - pub fn set_bwd_link(&self, ptr: i32) { - self.bwd_link.store(ptr, Ordering::Relaxed); } pub fn expand( @@ -118,14 +75,14 @@ impl Node { let policy = pos.get_policy(mov, &feats, policy); // trick for calculating policy before quantising - actions.push(Edge::new(f32::to_bits(policy) as i32, mov.into(), 0)); + actions.push(Edge::new(NodePtr::from_raw(f32::to_bits(policy)), mov.into(), 0)); max = max.max(policy); }); let mut total = 0.0; for action in actions.iter_mut() { - let mut policy = f32::from_bits(action.ptr() as u32); + let mut policy = f32::from_bits(action.ptr().inner()); policy = if ROOT { ((policy - max) / params.root_pst()).exp() @@ -133,14 +90,14 @@ impl Node { (policy - max).exp() }; - action.set_ptr(f32::to_bits(policy) as i32); + action.set_ptr(NodePtr::from_raw(f32::to_bits(policy))); total += policy; } for action in actions.iter_mut() { - let policy = f32::from_bits(action.ptr() as u32) / total; - action.set_ptr(-1); + let policy = f32::from_bits(action.ptr().inner()) / total; + action.set_ptr(NodePtr::NULL); action.set_policy(policy); } } diff --git a/src/tree/ptr.rs b/src/tree/ptr.rs new file mode 100644 index 00000000..7af64bbf --- /dev/null +++ b/src/tree/ptr.rs @@ -0,0 +1,30 @@ +#[derive(Clone, Copy, Default, PartialEq, Eq)] +pub struct NodePtr(u32); + +impl NodePtr { + pub const NULL: Self = Self(u32::MAX); + + pub fn is_null(self) -> bool { + self == Self::NULL + } + + pub fn new(half: bool, idx: u32) -> Self { + Self((u32::from(half) << 31) | idx) + } + + pub fn half(self) -> bool { + self.0 & (1 << 31) > 0 + } + + pub fn idx(self) -> usize { + (self.0 & 0x7FFFFFFF) as usize + } + + pub fn inner(self) -> u32 { + self.0 + } + + pub fn from_raw(inner: u32) -> Self { + Self(inner) + } +} \ No newline at end of file diff --git a/src/tree/stats.rs b/src/tree/stats.rs new file mode 100644 index 00000000..99720721 --- /dev/null +++ b/src/tree/stats.rs @@ -0,0 +1,67 @@ +use std::sync::atomic::{AtomicI32, AtomicU32, Ordering}; + +#[derive(Debug)] +pub struct ActionStats { + visits: AtomicI32, + q: AtomicU32, + sq_q: AtomicU32, +} + +impl Clone for ActionStats { + fn clone(&self) -> Self { + Self { + visits: AtomicI32::new(self.visits()), + q: AtomicU32::new(self.q.load(Ordering::Relaxed)), + sq_q: AtomicU32::new(self.sq_q.load(Ordering::Relaxed)), + } + } +} + +impl Default for ActionStats { + fn default() -> Self { + Self { + visits: AtomicI32::new(0), + q: AtomicU32::new(0), + sq_q: AtomicU32::new(0), + } + } +} + +impl ActionStats { + pub fn visits(&self) -> i32 { + self.visits.load(Ordering::Relaxed) + } + + fn q64(&self) -> f64 { + f64::from(self.q.load(Ordering::Relaxed)) / f64::from(u32::MAX) + } + + pub fn q(&self) -> f32 { + self.q64() as f32 + } + + pub fn sq_q(&self) -> f64 { + f64::from(self.sq_q.load(Ordering::Relaxed)) / f64::from(u32::MAX) + } + + pub fn var(&self) -> f32 { + (self.sq_q() - self.q64().powi(2)).max(0.0) as f32 + } + + pub fn update(&self, result: f32) { + let r = f64::from(result); + let v = f64::from(self.visits.fetch_add(1, Ordering::Relaxed)); + + let q = (self.q64() * v + r) / (v + 1.0); + let sq_q = (self.sq_q() * v + r.powi(2)) / (v + 1.0); + + self.q.store((q * f64::from(u32::MAX)) as u32, Ordering::Relaxed); + self.sq_q.store((sq_q * f64::from(u32::MAX)) as u32, Ordering::Relaxed); + } + + pub fn clear(&self) { + self.visits.store(0, Ordering::Relaxed); + self.q.store(0, Ordering::Relaxed); + self.sq_q.store(0, Ordering::Relaxed); + } +} \ No newline at end of file diff --git a/src/uci.rs b/src/uci.rs index 2040ff62..f8b1ed9b 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -110,11 +110,6 @@ impl Uci { } } "tree" => { - let u = tree.len(); - let c = tree.cap(); - let pct = u as f32 * 100.0 / c as f32; - println!("filled {u}/{c} ({pct:.2}%)"); - let depth = commands.get(1).unwrap_or(&"5").parse().unwrap_or(5); tree.display(tree.root_node(), depth); } From 42cdeb7e00fc5f247b4be7822446d18e32ea7e23 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Fri, 19 Jul 2024 11:21:28 +0100 Subject: [PATCH 16/48] . --- src/tree.rs | 39 ++++++++++++++++++++++++++++++++++++--- src/tree/half.rs | 4 ++++ src/tree/node.rs | 2 +- 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index fdd2ac8e..61fc74d5 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -13,7 +13,7 @@ pub use ptr::NodePtr; pub use stats::ActionStats; use std::{ - sync::atomic::{AtomicBool, Ordering}, + sync::{atomic::{AtomicBool, Ordering}, Mutex}, time::Instant, }; @@ -25,6 +25,7 @@ use crate::{ pub struct Tree { tree: [TreeHalf; 2], half: AtomicBool, + flip_lock: Mutex<()>, hash: HashTable, root_stats: ActionStats, } @@ -54,6 +55,7 @@ impl Tree { half: AtomicBool::new(false), hash: HashTable::new(cap / 16), root_stats: ActionStats::default(), + flip_lock: Mutex::new(()), } } @@ -61,8 +63,39 @@ impl Tree { usize::from(self.half.load(Ordering::Relaxed)) } + pub fn is_full(&self) -> bool { + self.tree[self.half()].is_full() + } + + pub fn copy_across(&self, from: NodePtr, to: NodePtr) { + if from == to { + return; + } + + self[to].set_state(self[from].state()); + + let f = &mut *self[from].actions_mut(); + let t = &mut *self[to].actions_mut(); + std::mem::swap(f, t); + } + pub fn push_new(&self, state: GameState) -> NodePtr { - self.tree[self.half()].push_new(state) + let mut new_ptr = self.tree[self.half()].push_new(state); + + if new_ptr.is_null() { + let _lock = self.flip_lock.lock(); + + if self.is_full() { + let old_root_ptr = self.root_node(); + self.half.fetch_xor(true, Ordering::Relaxed); + let new_root_ptr = self.root_node(); + self.copy_across(old_root_ptr, new_root_ptr); + } + + new_ptr = self.push_new(state); + } + + new_ptr } pub fn root_node(&self) -> NodePtr { @@ -158,7 +191,7 @@ impl Tree { if !found { println!("info string no subtree found"); - let node = self.push_new(GameState::Ongoing, -1, 0); + let node = self.push_new(GameState::Ongoing); self.make_root_node(node); } diff --git a/src/tree/half.rs b/src/tree/half.rs index 85368173..7a040809 100644 --- a/src/tree/half.rs +++ b/src/tree/half.rs @@ -44,6 +44,10 @@ impl TreeHalf { NodePtr::new(self.half, idx as u32) } + + pub fn is_full(&self) -> bool { + self.used.load(std::sync::atomic::Ordering::Relaxed) >= self.nodes.len() + } } diff --git a/src/tree/node.rs b/src/tree/node.rs index 03d8e5e4..0814928a 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -33,7 +33,7 @@ impl Node { self.actions.read().unwrap() } - fn actions_mut(&self) -> RwLockWriteGuard> { + pub fn actions_mut(&self) -> RwLockWriteGuard> { self.actions.write().unwrap() } From 4fd985bbc0567e58061bc501c45649dc80d4837c Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sat, 20 Jul 2024 11:24:03 +0100 Subject: [PATCH 17/48] . --- src/mcts.rs | 9 +-------- src/tree.rs | 24 ++++++++++++++++++++++-- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 606b54ee..d1e48a5f 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -219,14 +219,7 @@ impl<'a> Searcher<'a> { let edge = self.tree.edge_copy(ptr, action); pos.make_move(Move::from(edge.mov())); - let mut child_ptr = edge.ptr(); - - // create and push node if not present - if child_ptr.is_null() { - let state = pos.game_state(); - child_ptr = self.tree.push_new(state); - self.tree.set_edge_ptr(ptr, action, child_ptr); - } + let child_ptr = self.tree.fetch_node(pos, ptr, edge.ptr(), action); let u = self.perform_one_iteration(pos, child_ptr, &edge.stats(), depth); diff --git a/src/tree.rs b/src/tree.rs index 61fc74d5..be8c3fce 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -18,8 +18,7 @@ use std::{ }; use crate::{ - chess::{ChessState, Move}, - GameState, + chess::{ChessState, Move}, GameState }; pub struct Tree { @@ -98,6 +97,27 @@ impl Tree { new_ptr } + pub fn fetch_node( + &self, + pos: &ChessState, + parent_ptr: NodePtr, + ptr: NodePtr, + action: usize, + ) -> NodePtr { + if ptr.is_null() { + let state = pos.game_state(); + let new_ptr = self.push_new(state); + self.set_edge_ptr(parent_ptr, action, new_ptr); + new_ptr + } else if ptr.half() != self.half.load(Ordering::Relaxed) { + let new_ptr = self.push_new(GameState::Ongoing); + self.copy_across(ptr, new_ptr); + new_ptr + } else { + ptr + } + } + pub fn root_node(&self) -> NodePtr { NodePtr::new(self.half.load(Ordering::Relaxed), 0) } From d95e937df7b7bfb611b03229d98f6c714c23edb0 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sat, 20 Jul 2024 23:11:52 +0100 Subject: [PATCH 18/48] . --- src/mcts.rs | 10 +++++----- src/tree.rs | 26 +++++++++++++++++++++++--- src/tree/half.rs | 15 +++++++++++++-- src/tree/ptr.rs | 2 +- src/uci.rs | 7 ------- 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index d1e48a5f..90a56dfe 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -214,7 +214,7 @@ impl<'a> Searcher<'a> { } // select action to take via PUCT - let action = self.pick_action(ptr, &node_stats); + let action = self.pick_action(ptr, node_stats); let edge = self.tree.edge_copy(ptr, action); pos.make_move(Move::from(edge.mov())); @@ -252,9 +252,9 @@ impl<'a> Searcher<'a> { let is_root = ptr == self.tree.root_node(); - let cpuct = SearchHelpers::get_cpuct(self.params, &node_stats, is_root); - let fpu = SearchHelpers::get_fpu(&node_stats); - let expl_scale = SearchHelpers::get_explore_scaling(self.params, &node_stats); + let cpuct = SearchHelpers::get_cpuct(self.params, node_stats, is_root); + let fpu = SearchHelpers::get_fpu(node_stats); + let expl_scale = SearchHelpers::get_explore_scaling(self.params, node_stats); let expl = cpuct * expl_scale; @@ -311,7 +311,7 @@ impl<'a> Searcher<'a> { let mut pv = Vec::new(); - while (mate || depth > 0) && action.ptr().is_null() { + while (mate || depth > 0) && !action.ptr().is_null() { pv.push(Move::from(action.mov())); let idx = self.tree.get_best_child(action.ptr()); diff --git a/src/tree.rs b/src/tree.rs index be8c3fce..757bdb35 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -112,6 +112,7 @@ impl Tree { } else if ptr.half() != self.half.load(Ordering::Relaxed) { let new_ptr = self.push_new(GameState::Ongoing); self.copy_across(ptr, new_ptr); + self.set_edge_ptr(parent_ptr, action, new_ptr); new_ptr } else { ptr @@ -148,6 +149,21 @@ impl Tree { self.hash.push(hash, wins); } + pub fn clear_halves(&self) { + self.tree[0].clear(); + self.tree[1].clear(); + } + + pub fn clear(&mut self) { + self.tree[0].clear(); + self.tree[1].clear(); + self.hash.clear(); + } + + pub fn is_empty(&self) -> bool { + self.tree[0].is_empty() && self.tree[1].is_empty() + } + pub fn propogate_proven_mates(&self, ptr: NodePtr, child_state: GameState) { match child_state { // if the child node resulted in a loss, then @@ -184,7 +200,7 @@ impl Tree { if self.is_empty() { let node = self.push_new(GameState::Ongoing); - + assert_eq!(node, self.root_node()); return; } @@ -201,7 +217,9 @@ impl Tree { found = true; if root != self.root_node() { - self.make_root_node(root); + self.half.fetch_xor(true, Ordering::Relaxed); + self.tree[self.half()].clear(); + self.copy_across(root, self.root_node()); println!("info string found subtree"); } else { println!("info string using current tree"); @@ -211,8 +229,10 @@ impl Tree { if !found { println!("info string no subtree found"); + self.clear_halves(); + self.half.fetch_xor(false, Ordering::Relaxed); let node = self.push_new(GameState::Ongoing); - self.make_root_node(node); + assert_eq!(node, self.root_node()); } println!( diff --git a/src/tree/half.rs b/src/tree/half.rs index 7a040809..2b787eaa 100644 --- a/src/tree/half.rs +++ b/src/tree/half.rs @@ -1,4 +1,4 @@ -use std::sync::atomic::AtomicUsize; +use std::sync::atomic::{AtomicUsize, Ordering}; use crate::GameState; use super::{Node, NodePtr}; @@ -8,6 +8,7 @@ pub struct TreeHalf { nodes: Vec, used: AtomicUsize, half: bool, + age: AtomicUsize, } impl std::ops::Index for TreeHalf { @@ -24,6 +25,7 @@ impl TreeHalf { nodes: Vec::with_capacity(size), used: AtomicUsize::new(0), half, + age: AtomicUsize::new(0), }; for _ in 0..size { @@ -45,8 +47,17 @@ impl TreeHalf { NodePtr::new(self.half, idx as u32) } + pub fn clear(&self) { + self.used.store(0, Ordering::Relaxed); + self.age.fetch_add(1, Ordering::Relaxed); + } + + pub fn is_empty(&self) -> bool { + self.used.load(Ordering::Relaxed) == 0 + } + pub fn is_full(&self) -> bool { - self.used.load(std::sync::atomic::Ordering::Relaxed) >= self.nodes.len() + self.used.load(Ordering::Relaxed) >= self.nodes.len() } } diff --git a/src/tree/ptr.rs b/src/tree/ptr.rs index 7af64bbf..517d3869 100644 --- a/src/tree/ptr.rs +++ b/src/tree/ptr.rs @@ -1,4 +1,4 @@ -#[derive(Clone, Copy, Default, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] pub struct NodePtr(u32); impl NodePtr { diff --git a/src/uci.rs b/src/uci.rs index f8b1ed9b..c50430d6 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -69,7 +69,6 @@ impl Uci { policy, value, &mut stored_message, - threads, ); prev = Some(pos.clone()); @@ -251,7 +250,6 @@ fn go( policy: &PolicyNetwork, value: &ValueNetwork, stored_message: &mut Option, - threads: usize, ) { let mut max_nodes = i32::MAX as usize; let mut max_time = None; @@ -335,11 +333,6 @@ fn go( } }); - for _ in 0..threads - 1 { - let mut searcher = Searcher::new(pos.clone(), tree, params, policy, value, &abort); - searcher.secondary_search(); - } - *stored_message = handle_search_input(&abort); }); } From 66d1a1c146ba1ec559614c8446dd9b99eadad489 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 00:29:55 +0100 Subject: [PATCH 19/48] progress! --- src/mcts.rs | 21 ++++++++++-------- src/tree.rs | 55 ++++++++++++++++++++++++++++++------------------ src/tree/half.rs | 18 +++++++++++----- src/tree/node.rs | 13 +++++++++--- src/uci.rs | 12 ++--------- 5 files changed, 72 insertions(+), 47 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 90a56dfe..0becd6d7 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -82,14 +82,16 @@ impl<'a> Searcher<'a> { let mut pos = self.root_position.clone(); let mut this_depth = 0; - let u = self.perform_one_iteration( + if let Some(u) = self.perform_one_iteration( &mut pos, self.tree.root_node(), self.tree.root_stats(), &mut this_depth, - ); - - self.tree.root_stats().update(u); + ) { + self.tree.root_stats().update(u); + } else { + println!("flippin' eck"); + } cumulative_depth += this_depth - 1; @@ -189,8 +191,9 @@ impl<'a> Searcher<'a> { (Move::from(best_child.mov()), best_child.q()) } - fn perform_one_iteration(&mut self, pos: &mut ChessState, ptr: NodePtr, node_stats: &ActionStats, depth: &mut usize) -> f32 { + fn perform_one_iteration(&mut self, pos: &mut ChessState, ptr: NodePtr, node_stats: &ActionStats, depth: &mut usize) -> Option { *depth += 1; + assert_eq!(ptr.half(), self.tree.half() > 0); let hash = pos.hash(); @@ -219,9 +222,9 @@ impl<'a> Searcher<'a> { let edge = self.tree.edge_copy(ptr, action); pos.make_move(Move::from(edge.mov())); - let child_ptr = self.tree.fetch_node(pos, ptr, edge.ptr(), action); + let child_ptr = self.tree.fetch_node(pos, ptr, edge.ptr(), action)?; - let u = self.perform_one_iteration(pos, child_ptr, &edge.stats(), depth); + let u = self.perform_one_iteration(pos, child_ptr, &edge.stats(), depth)?; let new_q = self.tree.update_edge_stats(ptr, action, u); self.tree.push_hash(hash, new_q); @@ -233,7 +236,7 @@ impl<'a> Searcher<'a> { self.tree.propogate_proven_mates(ptr, child_state); - 1.0 - u + Some(1.0 - u) } fn get_utility(&self, ptr: NodePtr, pos: &ChessState) -> f32 { @@ -298,7 +301,7 @@ impl<'a> Searcher<'a> { let idx = self.tree.get_best_child(self.tree.root_node()); let mut action = self.tree.edge_copy(self.tree.root_node(), idx); - let score = if action.ptr().is_null() { + let score = if !action.ptr().is_null() { match self.tree[action.ptr()].state() { GameState::Lost(_) => 1.1, GameState::Won(_) => -0.1, diff --git a/src/tree.rs b/src/tree.rs index 757bdb35..2d3d18c2 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -58,7 +58,7 @@ impl Tree { } } - fn half(&self) -> usize { + pub fn half(&self) -> usize { usize::from(self.half.load(Ordering::Relaxed)) } @@ -78,23 +78,31 @@ impl Tree { std::mem::swap(f, t); } - pub fn push_new(&self, state: GameState) -> NodePtr { - let mut new_ptr = self.tree[self.half()].push_new(state); + pub fn push_new(&self, state: GameState) -> Option { + let new_ptr = self.tree[self.half()].push_new(state); if new_ptr.is_null() { let _lock = self.flip_lock.lock(); if self.is_full() { + println!("info string flipping!"); let old_root_ptr = self.root_node(); + self.half.fetch_xor(true, Ordering::Relaxed); + self.tree[self.half()].clear(); + self.tree[self.half()].bump_age(); + let new_root_ptr = self.root_node(); + + self.push_new(GameState::Ongoing); self.copy_across(old_root_ptr, new_root_ptr); + println!("{old_root_ptr:?}, {new_root_ptr:?}"); } - new_ptr = self.push_new(state); + None + } else { + Some(new_ptr) } - - new_ptr } pub fn fetch_node( @@ -103,22 +111,26 @@ impl Tree { parent_ptr: NodePtr, ptr: NodePtr, action: usize, - ) -> NodePtr { - if ptr.is_null() { + ) -> Option { + if ptr.is_null() || self.is_old(ptr) { let state = pos.game_state(); - let new_ptr = self.push_new(state); + let new_ptr = self.push_new(state)?; self.set_edge_ptr(parent_ptr, action, new_ptr); - new_ptr + Some(new_ptr) } else if ptr.half() != self.half.load(Ordering::Relaxed) { - let new_ptr = self.push_new(GameState::Ongoing); + let new_ptr = self.push_new(GameState::Ongoing)?; self.copy_across(ptr, new_ptr); self.set_edge_ptr(parent_ptr, action, new_ptr); - new_ptr + Some(new_ptr) } else { - ptr + Some(ptr) } } + fn is_old(&self, ptr: NodePtr) -> bool { + self.tree[usize::from(ptr.half())].age() != self[ptr].age() + } + pub fn root_node(&self) -> NodePtr { NodePtr::new(self.half.load(Ordering::Relaxed), 0) } @@ -136,7 +148,9 @@ impl Tree { } pub fn update_edge_stats(&self, ptr: NodePtr, action: usize, result: f32) -> f32 { - let edge = &self[ptr].actions()[action]; + let actions = &self[ptr].actions(); + assert!(actions.len() > action, "node: {ptr:?}"); + let edge = &actions[action]; edge.update(result); edge.q() } @@ -199,7 +213,7 @@ impl Tree { let t = Instant::now(); if self.is_empty() { - let node = self.push_new(GameState::Ongoing); + let node = self.push_new(GameState::Ongoing).unwrap(); assert_eq!(node, self.root_node()); return; } @@ -213,12 +227,13 @@ impl Tree { let root = self.recurse_find(self.root_node(), board, root, 2); - if root.is_null() && self[root].has_children() { + if !root.is_null() && self[root].has_children() { found = true; if root != self.root_node() { self.half.fetch_xor(true, Ordering::Relaxed); self.tree[self.half()].clear(); + self.push_new(GameState::Ongoing); self.copy_across(root, self.root_node()); println!("info string found subtree"); } else { @@ -231,7 +246,7 @@ impl Tree { println!("info string no subtree found"); self.clear_halves(); self.half.fetch_xor(false, Ordering::Relaxed); - let node = self.push_new(GameState::Ongoing); + let node = self.push_new(GameState::Ongoing).unwrap(); assert_eq!(node, self.root_node()); } @@ -266,7 +281,7 @@ impl Tree { let found = self.recurse_find(child_idx, &child_board, board, depth - 1); - if found.is_null() { + if !found.is_null() { return found; } } @@ -294,7 +309,7 @@ impl Tree { self.get_best_child_by_key(ptr, |child| { if child.visits() == 0 { f32::NEG_INFINITY - } else if child.ptr().is_null() { + } else if !child.ptr().is_null() { match self[child.ptr()].state() { GameState::Lost(n) => 1.0 + f32::from(n), GameState::Won(n) => f32::from(n) - 256.0, @@ -354,7 +369,7 @@ impl Tree { let mut active = Vec::new(); for action in node.actions().iter() { - if action.ptr().is_null() { + if !action.ptr().is_null() { active.push(action.clone()); } } diff --git a/src/tree/half.rs b/src/tree/half.rs index 2b787eaa..40d79127 100644 --- a/src/tree/half.rs +++ b/src/tree/half.rs @@ -1,4 +1,4 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering}; use crate::GameState; use super::{Node, NodePtr}; @@ -8,7 +8,7 @@ pub struct TreeHalf { nodes: Vec, used: AtomicUsize, half: bool, - age: AtomicUsize, + age: AtomicU32, } impl std::ops::Index for TreeHalf { @@ -25,11 +25,11 @@ impl TreeHalf { nodes: Vec::with_capacity(size), used: AtomicUsize::new(0), half, - age: AtomicUsize::new(0), + age: AtomicU32::new(0), }; for _ in 0..size { - res.nodes.push(Node::new(GameState::Ongoing)); + res.nodes.push(Node::new(GameState::Ongoing, 0)); } res @@ -42,7 +42,7 @@ impl TreeHalf { return NodePtr::NULL; } - self.nodes[idx].set_new(state); + self.nodes[idx].set_new(state, self.age()); NodePtr::new(self.half, idx as u32) } @@ -59,6 +59,14 @@ impl TreeHalf { pub fn is_full(&self) -> bool { self.used.load(Ordering::Relaxed) >= self.nodes.len() } + + pub fn age(&self) -> u32 { + self.age.load(Ordering::Relaxed) + } + + pub fn bump_age(&self) { + self.age.fetch_add(1, Ordering::Relaxed); + } } diff --git a/src/tree/node.rs b/src/tree/node.rs index 0814928a..f309433c 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -1,4 +1,4 @@ -use std::sync::{atomic::{AtomicU16, Ordering}, RwLock, RwLockReadGuard, RwLockWriteGuard}; +use std::sync::{atomic::{AtomicU16, AtomicU32, Ordering}, RwLock, RwLockReadGuard, RwLockWriteGuard}; use crate::{chess::Move, tree::{Edge, NodePtr}, ChessState, GameState, MctsParams, PolicyNetwork}; @@ -6,19 +6,22 @@ use crate::{chess::Move, tree::{Edge, NodePtr}, ChessState, GameState, MctsParam pub struct Node { actions: RwLock>, state: AtomicU16, + age: AtomicU32, } impl Node { - pub fn new(state: GameState) -> Self { + pub fn new(state: GameState, age: u32) -> Self { Node { actions: RwLock::new(Vec::new()), state: AtomicU16::new(u16::from(state)), + age: AtomicU32::new(age), } } - pub fn set_new(&self, state: GameState) { + pub fn set_new(&self, state: GameState, age: u32) { self.clear(); self.state.store(u16::from(state), Ordering::Relaxed); + self.age.store(age, Ordering::Relaxed); } pub fn is_terminal(&self) -> bool { @@ -37,6 +40,10 @@ impl Node { self.actions.write().unwrap() } + pub fn age(&self) -> u32 { + self.age.load(Ordering::Relaxed) + } + pub fn state(&self) -> GameState { GameState::from(self.state.load(Ordering::Relaxed)) } diff --git a/src/uci.rs b/src/uci.rs index c50430d6..f1fb045f 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -53,7 +53,7 @@ impl Uci { match cmd { "isready" => println!("readyok"), "setoption" => setoption(&commands, &mut params, &mut report_moves, &mut tree, &mut threads), - "position" => position(commands, &mut pos, &mut prev, &mut tree), + "position" => position(commands, &mut pos), "go" => { // increment game ply every time `go` is called root_game_ply += 2; @@ -195,12 +195,7 @@ fn setoption(commands: &[&str], params: &mut MctsParams, report_moves: &mut bool } } -fn position( - commands: Vec<&str>, - pos: &mut ChessState, - prev: &mut Option, - tree: &mut Tree, -) { +fn position(commands: Vec<&str>, pos: &mut ChessState) { let mut fen = String::new(); let mut move_list = Vec::new(); let mut moves = false; @@ -233,9 +228,6 @@ fn position( pos.make_move(this_mov); } - - tree.try_use_subtree(pos, prev); - *prev = Some(pos.clone()); } #[allow(clippy::too_many_arguments)] From 509e2bcd5ef6884d909a1c3c6002a8dff159d4de Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 01:29:24 +0100 Subject: [PATCH 20/48] almost there --- src/mcts.rs | 13 ++++++++++++- src/tree.rs | 52 ++++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 0becd6d7..6cf8fa56 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -193,7 +193,7 @@ impl<'a> Searcher<'a> { fn perform_one_iteration(&mut self, pos: &mut ChessState, ptr: NodePtr, node_stats: &ActionStats, depth: &mut usize) -> Option { *depth += 1; - assert_eq!(ptr.half(), self.tree.half() > 0); + assert!(self.tree.ptr_is_valid(ptr)); let hash = pos.hash(); @@ -220,6 +220,17 @@ impl<'a> Searcher<'a> { let action = self.pick_action(ptr, node_stats); let edge = self.tree.edge_copy(ptr, action); + + let mut found = false; + let mov = Move::from(edge.mov()); + pos.map_legal_moves(|m| { + if m == mov { + found = true; + } + }); + + assert!(found); + pos.make_move(Move::from(edge.mov())); let child_ptr = self.tree.fetch_node(pos, ptr, edge.ptr(), action)?; diff --git a/src/tree.rs b/src/tree.rs index 2d3d18c2..593f0017 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -46,6 +46,8 @@ impl Tree { fn new(cap: usize) -> Self { let tree_size = cap / 8; + println!("info string tree size {tree_size}"); + Self { tree: [ TreeHalf::new(tree_size, false), @@ -66,6 +68,13 @@ impl Tree { self.tree[self.half()].is_full() } + fn flip(&self) { + let old = self.half.fetch_xor(true, Ordering::Relaxed); + let new = usize::from(!old); + self.tree[new].clear(); + self.tree[new].bump_age(); + } + pub fn copy_across(&self, from: NodePtr, to: NodePtr) { if from == to { return; @@ -78,6 +87,7 @@ impl Tree { std::mem::swap(f, t); } + #[must_use] pub fn push_new(&self, state: GameState) -> Option { let new_ptr = self.tree[self.half()].push_new(state); @@ -85,18 +95,14 @@ impl Tree { let _lock = self.flip_lock.lock(); if self.is_full() { - println!("info string flipping!"); let old_root_ptr = self.root_node(); - self.half.fetch_xor(true, Ordering::Relaxed); - self.tree[self.half()].clear(); - self.tree[self.half()].bump_age(); + self.flip(); - let new_root_ptr = self.root_node(); + let new_root_ptr = self.tree[self.half()].push_new(GameState::Ongoing); + assert_eq!(new_root_ptr, self.root_node()); - self.push_new(GameState::Ongoing); self.copy_across(old_root_ptr, new_root_ptr); - println!("{old_root_ptr:?}, {new_root_ptr:?}"); } None @@ -105,6 +111,7 @@ impl Tree { } } + #[must_use] pub fn fetch_node( &self, pos: &ChessState, @@ -120,6 +127,23 @@ impl Tree { } else if ptr.half() != self.half.load(Ordering::Relaxed) { let new_ptr = self.push_new(GameState::Ongoing)?; self.copy_across(ptr, new_ptr); + + //let mut i = 0; + let half = self.half.load(Ordering::Relaxed); + let actions = self[new_ptr].actions(); + for action in actions.iter() { + if action.ptr().half() == half { + action.set_ptr(NodePtr::NULL); + } + } + + //if actions.len() > 0 { + // pos.map_legal_moves(|mov| { + // assert_eq!(mov, Move::from(actions[i].mov())); + // i += 1; + // }); + //} + self.set_edge_ptr(parent_ptr, action, new_ptr); Some(new_ptr) } else { @@ -163,7 +187,7 @@ impl Tree { self.hash.push(hash, wins); } - pub fn clear_halves(&self) { + fn clear_halves(&self) { self.tree[0].clear(); self.tree[1].clear(); } @@ -231,9 +255,8 @@ impl Tree { found = true; if root != self.root_node() { - self.half.fetch_xor(true, Ordering::Relaxed); - self.tree[self.half()].clear(); - self.push_new(GameState::Ongoing); + self.flip(); + self.push_new(GameState::Ongoing).unwrap(); self.copy_across(root, self.root_node()); println!("info string found subtree"); } else { @@ -245,7 +268,7 @@ impl Tree { if !found { println!("info string no subtree found"); self.clear_halves(); - self.half.fetch_xor(false, Ordering::Relaxed); + self.flip(); let node = self.push_new(GameState::Ongoing).unwrap(); assert_eq!(node, self.root_node()); } @@ -256,6 +279,11 @@ impl Tree { ); } + pub fn ptr_is_valid(&self, ptr: NodePtr) -> bool { + ptr.half() == self.half.load(Ordering::Relaxed) + && self.tree[self.half()].age() == self[ptr].age() + } + fn recurse_find( &self, start: NodePtr, From 4a4574d4d93738de34d769c2e3366f4f2e5dc0ba Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 01:34:59 +0100 Subject: [PATCH 21/48] i done broke normal mcts --- src/uci.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/uci.rs b/src/uci.rs index f1fb045f..20c46342 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -137,9 +137,10 @@ impl Uci { max_nodes: 1_000_000, }; - let mut tree = Tree::new_mb(32); + let mut tree = Tree::new_mb(128); for fen in bench_fens { + println!("{fen}"); let abort = AtomicBool::new(false); let pos = ChessState::from_fen(fen); tree.try_use_subtree(&pos, &None); From e58ba4d4d2aa7e794f4ac11924f9246cf7957b25 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 02:14:16 +0100 Subject: [PATCH 22/48] . --- src/mcts.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 6cf8fa56..d75c0fac 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -221,15 +221,15 @@ impl<'a> Searcher<'a> { let edge = self.tree.edge_copy(ptr, action); - let mut found = false; - let mov = Move::from(edge.mov()); - pos.map_legal_moves(|m| { - if m == mov { - found = true; - } - }); - - assert!(found); + //let mut found = false; + //let mov = Move::from(edge.mov()); + //pos.map_legal_moves(|m| { + // if m == mov { + // found = true; + // } + //}); +// + //assert!(found); pos.make_move(Move::from(edge.mov())); From d2f5806aad45e9fcb04675984fad2aedeb0b8eab Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 13:04:48 +0100 Subject: [PATCH 23/48] . --- src/tree.rs | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 593f0017..af0d4e02 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -82,9 +82,20 @@ impl Tree { self[to].set_state(self[from].state()); - let f = &mut *self[from].actions_mut(); - let t = &mut *self[to].actions_mut(); - std::mem::swap(f, t); + // need these mut refs to be dropped immediately + { + let f = &mut *self[from].actions_mut(); + let t = &mut *self[to].actions_mut(); + std::mem::swap(f, t); + } + + let half = self.half.load(Ordering::Relaxed); + let actions = self[to].actions(); + for action in actions.iter() { + if action.ptr().half() == half { + action.set_ptr(NodePtr::NULL); + } + } } #[must_use] @@ -127,23 +138,6 @@ impl Tree { } else if ptr.half() != self.half.load(Ordering::Relaxed) { let new_ptr = self.push_new(GameState::Ongoing)?; self.copy_across(ptr, new_ptr); - - //let mut i = 0; - let half = self.half.load(Ordering::Relaxed); - let actions = self[new_ptr].actions(); - for action in actions.iter() { - if action.ptr().half() == half { - action.set_ptr(NodePtr::NULL); - } - } - - //if actions.len() > 0 { - // pos.map_legal_moves(|mov| { - // assert_eq!(mov, Move::from(actions[i].mov())); - // i += 1; - // }); - //} - self.set_edge_ptr(parent_ptr, action, new_ptr); Some(new_ptr) } else { From 9117e1381ed2ca2909725ac4d64222d952ce2e09 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 13:18:14 +0100 Subject: [PATCH 24/48] . --- src/mcts.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index d75c0fac..7d12d737 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -221,16 +221,6 @@ impl<'a> Searcher<'a> { let edge = self.tree.edge_copy(ptr, action); - //let mut found = false; - //let mov = Move::from(edge.mov()); - //pos.map_legal_moves(|m| { - // if m == mov { - // found = true; - // } - //}); -// - //assert!(found); - pos.make_move(Move::from(edge.mov())); let child_ptr = self.tree.fetch_node(pos, ptr, edge.ptr(), action)?; From e9672a82a53f36745f0b01b86ab571b15d536cda Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 13:30:56 +0100 Subject: [PATCH 25/48] . --- src/tree.rs | 2 +- src/uci.rs | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tree.rs b/src/tree.rs index af0d4e02..4057c569 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -386,7 +386,7 @@ impl Tree { node.state(), ); } else { - println!("root"); + println!("root Q({:.2}%)", self.root_stats.q() * 100.0); } let mut active = Vec::new(); diff --git a/src/uci.rs b/src/uci.rs index 20c46342..54440fde 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -146,7 +146,9 @@ impl Uci { tree.try_use_subtree(&pos, &None); let mut searcher = Searcher::new(pos, &tree, params, policy, value, &abort); let timer = Instant::now(); + let old = total_nodes; searcher.search(limits, false, &mut total_nodes); + println!(" -> {}", total_nodes - old); time += timer.elapsed().as_secs_f32(); tree.clear(); } From 4c47ddc5f1fd7c6c32e7c7e7fb82744824f4fea9 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 14:15:57 +0100 Subject: [PATCH 26/48] . --- src/mcts/params.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcts/params.rs b/src/mcts/params.rs index d1a13c1d..cf8ca73a 100644 --- a/src/mcts/params.rs +++ b/src/mcts/params.rs @@ -131,7 +131,7 @@ macro_rules! make_mcts_params { make_mcts_params! { root_pst: f32 = 3.64, 1.0, 10.0, 0.4, 0.002; - root_cpuct: f32 = 0.624, 0.1, 5.0, 0.065, 0.002; + root_cpuct: f32 = 0.314, 0.1, 5.0, 0.065, 0.002; cpuct: f32 = 0.314, 0.1, 5.0, 0.065, 0.002; cpuct_var_weight: f32 = 0.851, 0.0, 2.0, 0.085, 0.002; cpuct_var_scale: f32 = 0.257, 0.0, 2.0, 0.02, 0.002; From 88d8250e8516305fae49a16bba54e1feef8a9886 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 14:55:40 +0100 Subject: [PATCH 27/48] . --- src/tree.rs | 4 +--- src/tree/half.rs | 6 +----- src/tree/node.rs | 4 ++-- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 4057c569..9906d9c6 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -72,7 +72,6 @@ impl Tree { let old = self.half.fetch_xor(true, Ordering::Relaxed); let new = usize::from(!old); self.tree[new].clear(); - self.tree[new].bump_age(); } pub fn copy_across(&self, from: NodePtr, to: NodePtr) { @@ -187,8 +186,7 @@ impl Tree { } pub fn clear(&mut self) { - self.tree[0].clear(); - self.tree[1].clear(); + self.clear_halves(); self.hash.clear(); } diff --git a/src/tree/half.rs b/src/tree/half.rs index 40d79127..84f4c5a0 100644 --- a/src/tree/half.rs +++ b/src/tree/half.rs @@ -36,7 +36,7 @@ impl TreeHalf { } pub fn push_new(&self, state: GameState) -> NodePtr { - let idx = self.used.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let idx = self.used.fetch_add(1, Ordering::Relaxed); if idx == self.nodes.len() { return NodePtr::NULL; @@ -63,10 +63,6 @@ impl TreeHalf { pub fn age(&self) -> u32 { self.age.load(Ordering::Relaxed) } - - pub fn bump_age(&self) { - self.age.fetch_add(1, Ordering::Relaxed); - } } diff --git a/src/tree/node.rs b/src/tree/node.rs index f309433c..df37455a 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -19,8 +19,8 @@ impl Node { } pub fn set_new(&self, state: GameState, age: u32) { - self.clear(); - self.state.store(u16::from(state), Ordering::Relaxed); + self.actions_mut().clear(); + self.set_state(state); self.age.store(age, Ordering::Relaxed); } From 63233a6d4f01eb153870022c8df8a8807cb6e147 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 15:04:06 +0100 Subject: [PATCH 28/48] Bench: 1682532 --- src/tree.rs | 17 ++++++++++------- src/uci.rs | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 9906d9c6..06411e6b 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -183,6 +183,7 @@ impl Tree { fn clear_halves(&self) { self.tree[0].clear(); self.tree[1].clear(); + self.root_stats.clear(); } pub fn clear(&mut self) { @@ -241,7 +242,7 @@ impl Tree { if let Some(board) = prev_board { println!("info string searching for subtree"); - let root = self.recurse_find(self.root_node(), board, root, 2); + let (root, stats) = self.recurse_find(self.root_node(), board, root, self.root_stats.clone(), 2); if !root.is_null() && self[root].has_children() { found = true; @@ -249,6 +250,7 @@ impl Tree { if root != self.root_node() { self.flip(); self.push_new(GameState::Ongoing).unwrap(); + self.root_stats = stats; self.copy_across(root, self.root_node()); println!("info string found subtree"); } else { @@ -281,14 +283,15 @@ impl Tree { start: NodePtr, this_board: &ChessState, board: &ChessState, + stats: ActionStats, depth: u8, - ) -> NodePtr { + ) -> (NodePtr, ActionStats) { if this_board.is_same(board) { - return start; + return (start, stats); } if start.is_null() || depth == 0 { - return NodePtr::NULL; + return (NodePtr::NULL, ActionStats::default()); } let node = &self[start]; @@ -299,14 +302,14 @@ impl Tree { child_board.make_move(Move::from(action.mov())); - let found = self.recurse_find(child_idx, &child_board, board, depth - 1); + let found = self.recurse_find(child_idx, &child_board, board, action.stats(), depth - 1); - if !found.is_null() { + if !found.0.is_null() { return found; } } - NodePtr::NULL + (NodePtr::NULL, ActionStats::default()) } pub fn get_best_child_by_key f32>(&self, ptr: NodePtr, mut key: F) -> usize { diff --git a/src/uci.rs b/src/uci.rs index 54440fde..a3da62f7 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -137,7 +137,7 @@ impl Uci { max_nodes: 1_000_000, }; - let mut tree = Tree::new_mb(128); + let mut tree = Tree::new_mb(32); for fen in bench_fens { println!("{fen}"); From 23e30dcf2a251dee171ea47131292fc408804b3d Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 15:24:51 +0100 Subject: [PATCH 29/48] . --- src/tree.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/tree.rs b/src/tree.rs index 06411e6b..f86fb6ed 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -387,7 +387,11 @@ impl Tree { node.state(), ); } else { - println!("root Q({:.2}%)", self.root_stats.q() * 100.0); + println!( + "root Q({:.2}%) N({})", + self.root_stats.q() * 100.0, + self.root_stats.visits(), + ); } let mut active = Vec::new(); From c2c6711972005cb9503b0601acd67ec63f8509e4 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 15:57:40 +0100 Subject: [PATCH 30/48] fix --- src/mcts.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcts.rs b/src/mcts.rs index 7d12d737..9879401e 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -203,7 +203,7 @@ impl<'a> Searcher<'a> { // probe hash table to use in place of network if self.tree[ptr].state() == GameState::Ongoing { if let Some(entry) = self.tree.probe_hash(hash) { - 1.0 - entry.q() + entry.q() } else { self.get_utility(ptr, pos) } From c038e0ec84767598e48af3ed861250620b8aba02 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 18:23:21 +0100 Subject: [PATCH 31/48] Cleanup all the garbage Bench: 831571 --- src/mcts.rs | 3 --- src/tree.rs | 46 ++++++++++++++++------------------------------ src/tree/half.rs | 6 +++++- src/uci.rs | 3 --- 4 files changed, 21 insertions(+), 37 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 9879401e..3816022a 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -89,8 +89,6 @@ impl<'a> Searcher<'a> { &mut this_depth, ) { self.tree.root_stats().update(u); - } else { - println!("flippin' eck"); } cumulative_depth += this_depth - 1; @@ -193,7 +191,6 @@ impl<'a> Searcher<'a> { fn perform_one_iteration(&mut self, pos: &mut ChessState, ptr: NodePtr, node_stats: &ActionStats, depth: &mut usize) -> Option { *depth += 1; - assert!(self.tree.ptr_is_valid(ptr)); let hash = pos.hash(); diff --git a/src/tree.rs b/src/tree.rs index f86fb6ed..da49a60c 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -46,8 +46,6 @@ impl Tree { fn new(cap: usize) -> Self { let tree_size = cap / 8; - println!("info string tree size {tree_size}"); - Self { tree: [ TreeHalf::new(tree_size, false), @@ -74,25 +72,23 @@ impl Tree { self.tree[new].clear(); } - pub fn copy_across(&self, from: NodePtr, to: NodePtr) { + pub fn copy_across(&self, from: NodePtr, to: NodePtr) { if from == to { return; } self[to].set_state(self[from].state()); - // need these mut refs to be dropped immediately - { - let f = &mut *self[from].actions_mut(); - let t = &mut *self[to].actions_mut(); - std::mem::swap(f, t); - } + let f = &mut *self[from].actions_mut(); + let t = &mut *self[to].actions_mut(); + std::mem::swap(f, t); - let half = self.half.load(Ordering::Relaxed); - let actions = self[to].actions(); - for action in actions.iter() { - if action.ptr().half() == half { - action.set_ptr(NodePtr::NULL); + if !SAME_HALF { + let half = self.half.load(Ordering::Relaxed); + for action in t.iter() { + if action.ptr().half() == half { + action.set_ptr(NodePtr::NULL); + } } } } @@ -110,9 +106,8 @@ impl Tree { self.flip(); let new_root_ptr = self.tree[self.half()].push_new(GameState::Ongoing); - assert_eq!(new_root_ptr, self.root_node()); - self.copy_across(old_root_ptr, new_root_ptr); + self.copy_across::(old_root_ptr, new_root_ptr); } None @@ -136,8 +131,9 @@ impl Tree { Some(new_ptr) } else if ptr.half() != self.half.load(Ordering::Relaxed) { let new_ptr = self.push_new(GameState::Ongoing)?; - self.copy_across(ptr, new_ptr); + self.copy_across::(ptr, new_ptr); self.set_edge_ptr(parent_ptr, action, new_ptr); + Some(new_ptr) } else { Some(ptr) @@ -166,7 +162,6 @@ impl Tree { pub fn update_edge_stats(&self, ptr: NodePtr, action: usize, result: f32) -> f32 { let actions = &self[ptr].actions(); - assert!(actions.len() > action, "node: {ptr:?}"); let edge = &actions[action]; edge.update(result); edge.q() @@ -230,8 +225,7 @@ impl Tree { let t = Instant::now(); if self.is_empty() { - let node = self.push_new(GameState::Ongoing).unwrap(); - assert_eq!(node, self.root_node()); + self.push_new(GameState::Ongoing).unwrap(); return; } @@ -248,10 +242,8 @@ impl Tree { found = true; if root != self.root_node() { - self.flip(); - self.push_new(GameState::Ongoing).unwrap(); + self.copy_across::(root, self.root_node()); self.root_stats = stats; - self.copy_across(root, self.root_node()); println!("info string found subtree"); } else { println!("info string using current tree"); @@ -263,8 +255,7 @@ impl Tree { println!("info string no subtree found"); self.clear_halves(); self.flip(); - let node = self.push_new(GameState::Ongoing).unwrap(); - assert_eq!(node, self.root_node()); + self.push_new(GameState::Ongoing).unwrap(); } println!( @@ -273,11 +264,6 @@ impl Tree { ); } - pub fn ptr_is_valid(&self, ptr: NodePtr) -> bool { - ptr.half() == self.half.load(Ordering::Relaxed) - && self.tree[self.half()].age() == self[ptr].age() - } - fn recurse_find( &self, start: NodePtr, diff --git a/src/tree/half.rs b/src/tree/half.rs index 84f4c5a0..2d97436f 100644 --- a/src/tree/half.rs +++ b/src/tree/half.rs @@ -56,8 +56,12 @@ impl TreeHalf { self.used.load(Ordering::Relaxed) == 0 } + pub fn used(&self) -> usize { + self.used.load(Ordering::Relaxed) + } + pub fn is_full(&self) -> bool { - self.used.load(Ordering::Relaxed) >= self.nodes.len() + self.used() >= self.nodes.len() } pub fn age(&self) -> u32 { diff --git a/src/uci.rs b/src/uci.rs index a3da62f7..f1fb045f 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -140,15 +140,12 @@ impl Uci { let mut tree = Tree::new_mb(32); for fen in bench_fens { - println!("{fen}"); let abort = AtomicBool::new(false); let pos = ChessState::from_fen(fen); tree.try_use_subtree(&pos, &None); let mut searcher = Searcher::new(pos, &tree, params, policy, value, &abort); let timer = Instant::now(); - let old = total_nodes; searcher.search(limits, false, &mut total_nodes); - println!(" -> {}", total_nodes - old); time += timer.elapsed().as_secs_f32(); tree.clear(); } From 00ea1816c12597ce44251e113f2c55219b67167a Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 21:48:22 +0100 Subject: [PATCH 32/48] fix one issue, another arises Bench: 831571 --- src/tree.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/tree.rs b/src/tree.rs index da49a60c..8a3de9c6 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -242,7 +242,11 @@ impl Tree { found = true; if root != self.root_node() { - self.copy_across::(root, self.root_node()); + if root.half() == self.root_node().half() { + self.copy_across::(root, self.root_node()); + } else { + self.copy_across::(root, self.root_node()); + } self.root_stats = stats; println!("info string found subtree"); } else { From 0b6c7cfa4054dfc04036dcf0667aa3c61587a8cc Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 21 Jul 2024 22:54:12 +0100 Subject: [PATCH 33/48] . --- src/tree.rs | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 8a3de9c6..7ead342e 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -72,7 +72,7 @@ impl Tree { self.tree[new].clear(); } - pub fn copy_across(&self, from: NodePtr, to: NodePtr) { + pub fn copy_across(&self, from: NodePtr, to: NodePtr) { if from == to { return; } @@ -83,7 +83,7 @@ impl Tree { let t = &mut *self[to].actions_mut(); std::mem::swap(f, t); - if !SAME_HALF { + if from.half() != to.half() { let half = self.half.load(Ordering::Relaxed); for action in t.iter() { if action.ptr().half() == half { @@ -107,7 +107,7 @@ impl Tree { let new_root_ptr = self.tree[self.half()].push_new(GameState::Ongoing); - self.copy_across::(old_root_ptr, new_root_ptr); + self.copy_across(old_root_ptr, new_root_ptr); } None @@ -131,7 +131,7 @@ impl Tree { Some(new_ptr) } else if ptr.half() != self.half.load(Ordering::Relaxed) { let new_ptr = self.push_new(GameState::Ongoing)?; - self.copy_across::(ptr, new_ptr); + self.copy_across(ptr, new_ptr); self.set_edge_ptr(parent_ptr, action, new_ptr); Some(new_ptr) @@ -242,11 +242,7 @@ impl Tree { found = true; if root != self.root_node() { - if root.half() == self.root_node().half() { - self.copy_across::(root, self.root_node()); - } else { - self.copy_across::(root, self.root_node()); - } + self.copy_across(root, self.root_node()); self.root_stats = stats; println!("info string found subtree"); } else { From 3df4365c51f84a2c584776e8f78a8344da8f4f20 Mon Sep 17 00:00:00 2001 From: JacquesRW Date: Sun, 28 Jul 2024 23:07:13 +0100 Subject: [PATCH 34/48] no age Bench: 1920840 --- src/tree.rs | 6 +----- src/tree/half.rs | 13 +++---------- src/tree/node.rs | 13 +++---------- 3 files changed, 7 insertions(+), 25 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 7ead342e..b5264414 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -124,7 +124,7 @@ impl Tree { ptr: NodePtr, action: usize, ) -> Option { - if ptr.is_null() || self.is_old(ptr) { + if ptr.is_null() { let state = pos.game_state(); let new_ptr = self.push_new(state)?; self.set_edge_ptr(parent_ptr, action, new_ptr); @@ -140,10 +140,6 @@ impl Tree { } } - fn is_old(&self, ptr: NodePtr) -> bool { - self.tree[usize::from(ptr.half())].age() != self[ptr].age() - } - pub fn root_node(&self) -> NodePtr { NodePtr::new(self.half.load(Ordering::Relaxed), 0) } diff --git a/src/tree/half.rs b/src/tree/half.rs index 2d97436f..a23d520d 100644 --- a/src/tree/half.rs +++ b/src/tree/half.rs @@ -1,4 +1,4 @@ -use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicUsize, Ordering}; use crate::GameState; use super::{Node, NodePtr}; @@ -8,7 +8,6 @@ pub struct TreeHalf { nodes: Vec, used: AtomicUsize, half: bool, - age: AtomicU32, } impl std::ops::Index for TreeHalf { @@ -25,11 +24,10 @@ impl TreeHalf { nodes: Vec::with_capacity(size), used: AtomicUsize::new(0), half, - age: AtomicU32::new(0), }; for _ in 0..size { - res.nodes.push(Node::new(GameState::Ongoing, 0)); + res.nodes.push(Node::new(GameState::Ongoing)); } res @@ -42,14 +40,13 @@ impl TreeHalf { return NodePtr::NULL; } - self.nodes[idx].set_new(state, self.age()); + self.nodes[idx].set_new(state); NodePtr::new(self.half, idx as u32) } pub fn clear(&self) { self.used.store(0, Ordering::Relaxed); - self.age.fetch_add(1, Ordering::Relaxed); } pub fn is_empty(&self) -> bool { @@ -63,10 +60,6 @@ impl TreeHalf { pub fn is_full(&self) -> bool { self.used() >= self.nodes.len() } - - pub fn age(&self) -> u32 { - self.age.load(Ordering::Relaxed) - } } diff --git a/src/tree/node.rs b/src/tree/node.rs index df37455a..cf7af3cf 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -1,4 +1,4 @@ -use std::sync::{atomic::{AtomicU16, AtomicU32, Ordering}, RwLock, RwLockReadGuard, RwLockWriteGuard}; +use std::sync::{atomic::{AtomicU16, Ordering}, RwLock, RwLockReadGuard, RwLockWriteGuard}; use crate::{chess::Move, tree::{Edge, NodePtr}, ChessState, GameState, MctsParams, PolicyNetwork}; @@ -6,22 +6,19 @@ use crate::{chess::Move, tree::{Edge, NodePtr}, ChessState, GameState, MctsParam pub struct Node { actions: RwLock>, state: AtomicU16, - age: AtomicU32, } impl Node { - pub fn new(state: GameState, age: u32) -> Self { + pub fn new(state: GameState) -> Self { Node { actions: RwLock::new(Vec::new()), state: AtomicU16::new(u16::from(state)), - age: AtomicU32::new(age), } } - pub fn set_new(&self, state: GameState, age: u32) { + pub fn set_new(&self, state: GameState) { self.actions_mut().clear(); self.set_state(state); - self.age.store(age, Ordering::Relaxed); } pub fn is_terminal(&self) -> bool { @@ -40,10 +37,6 @@ impl Node { self.actions.write().unwrap() } - pub fn age(&self) -> u32 { - self.age.load(Ordering::Relaxed) - } - pub fn state(&self) -> GameState { GameState::from(self.state.load(Ordering::Relaxed)) } From 5892700c4eb42afd252f1f21f5941235f55882b7 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Thu, 1 Aug 2024 22:30:59 +0100 Subject: [PATCH 35/48] Fixed? Bench: 1920840 --- src/tree.rs | 15 +++------------ src/tree/half.rs | 10 ++++++++++ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index b5264414..0d58556d 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -67,9 +67,9 @@ impl Tree { } fn flip(&self) { - let old = self.half.fetch_xor(true, Ordering::Relaxed); - let new = usize::from(!old); - self.tree[new].clear(); + let old = usize::from(self.half.fetch_xor(true, Ordering::Relaxed)); + self.tree[old].clear_ptrs(); + self.tree[old ^ 1].clear(); } pub fn copy_across(&self, from: NodePtr, to: NodePtr) { @@ -82,15 +82,6 @@ impl Tree { let f = &mut *self[from].actions_mut(); let t = &mut *self[to].actions_mut(); std::mem::swap(f, t); - - if from.half() != to.half() { - let half = self.half.load(Ordering::Relaxed); - for action in t.iter() { - if action.ptr().half() == half { - action.set_ptr(NodePtr::NULL); - } - } - } } #[must_use] diff --git a/src/tree/half.rs b/src/tree/half.rs index a23d520d..d2858364 100644 --- a/src/tree/half.rs +++ b/src/tree/half.rs @@ -49,6 +49,16 @@ impl TreeHalf { self.used.store(0, Ordering::Relaxed); } + pub fn clear_ptrs(&self) { + for node in &self.nodes { + for action in &mut *node.actions_mut() { + if action.ptr().half() != self.half { + action.set_ptr(NodePtr::NULL); + } + } + } + } + pub fn is_empty(&self) -> bool { self.used.load(Ordering::Relaxed) == 0 } From b6e4177c3766af1ee0c60a6676ee1d6564e7d5cb Mon Sep 17 00:00:00 2001 From: jw1912 Date: Fri, 2 Aug 2024 02:18:18 +0100 Subject: [PATCH 36/48] Fix OOM on Workers Bench: 1904002 --- src/tree.rs | 14 ++++++-------- src/tree/node.rs | 4 ++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 0d58556d..eb361a48 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -39,20 +39,18 @@ impl std::ops::Index for Tree { impl Tree { pub fn new_mb(mb: usize) -> Self { - let cap = mb * 1024 * 1024 / 48; - Self::new(cap) + let bytes = mb * 1024 * 1024; + Self::new(bytes / (48 + 20 * 20), bytes / 48 / 16) } - fn new(cap: usize) -> Self { - let tree_size = cap / 8; - + fn new(tree_cap: usize, hash_cap: usize) -> Self { Self { tree: [ - TreeHalf::new(tree_size, false), - TreeHalf::new(tree_size, true), + TreeHalf::new(tree_cap / 2, false), + TreeHalf::new(tree_cap / 2, true), ], half: AtomicBool::new(false), - hash: HashTable::new(cap / 16), + hash: HashTable::new(hash_cap / 4), root_stats: ActionStats::default(), flip_lock: Mutex::new(()), } diff --git a/src/tree/node.rs b/src/tree/node.rs index cf7af3cf..990213fe 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -17,7 +17,7 @@ impl Node { } pub fn set_new(&self, state: GameState) { - self.actions_mut().clear(); + *self.actions_mut() = Vec::new(); self.set_state(state); } @@ -54,7 +54,7 @@ impl Node { } pub fn clear(&self) { - self.actions.write().unwrap().clear(); + *self.actions.write().unwrap() = Vec::new(); self.set_state(GameState::Ongoing); } From 89213f4700909cfc47b1a0e115c728859250bba3 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 4 Aug 2024 15:15:48 +0100 Subject: [PATCH 37/48] . --- datagen/src/thread.rs | 2 +- src/mcts.rs | 254 ++++++++++++++++++++++++++---------------- src/mcts/helpers.rs | 46 +++++++- src/tree.rs | 33 +++--- src/tree/half.rs | 2 +- src/uci.rs | 4 +- 6 files changed, 219 insertions(+), 122 deletions(-) diff --git a/datagen/src/thread.rs b/datagen/src/thread.rs index 210e8a54..92db0dd8 100644 --- a/datagen/src/thread.rs +++ b/datagen/src/thread.rs @@ -119,7 +119,7 @@ impl<'a> DatagenThread<'a> { let abort = AtomicBool::new(false); tree.try_use_subtree(&position, &None); - let mut searcher = Searcher::new( + let searcher = Searcher::new( position.clone(), &tree, &self.params, diff --git a/src/mcts.rs b/src/mcts.rs index 6add8d17..b26635f1 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -51,33 +51,52 @@ impl<'a> Searcher<'a> { } } - pub fn search( - &mut self, - limits: Limits, + #[allow(clippy::too_many_arguments)] + fn playout_until_full_main( + &self, + limits: &Limits, + timer: &Instant, + nodes: &mut usize, + depth: &mut usize, + cumulative_depth: &mut usize, + best_move: &mut Move, + best_move_changes: &mut i32, + previous_score: &mut f32, uci_output: bool, - total_nodes: &mut usize, - ) -> (Move, f32) { - let timer = Instant::now(); - - // attempt to reuse the current tree stored in memory - let node = self.tree.root_node(); - - // relabel root policies with root PST value - if self.tree[node].has_children() { - self.tree[node].relabel_policy(&self.root_position, self.params, self.policy); - } else { - self.tree[node].expand::(&self.root_position, self.params, self.policy); + ) { + if self.playout_until_full_internal( + nodes, + cumulative_depth, + |n, cd| { + self.check_limits( + limits, + timer, + n, + best_move, + best_move_changes, + previous_score, + depth, + cd, + uci_output, + ) + } + ) { + self.abort.store(true, Ordering::Relaxed); } + } - let mut nodes = 0; - let mut depth = 0; - let mut cumulative_depth = 0; - - let mut best_move = Move::NULL; - let mut best_move_changes = 0; - let mut previous_score = f32::NEG_INFINITY; + fn playout_until_full_worker(&mut self, nodes: &mut usize, cumulative_depth: &mut usize) { + let _ = self.playout_until_full_internal(nodes, cumulative_depth, |_, _| false); + } - // search loop + fn playout_until_full_internal( + &self, + nodes: &mut usize, + cumulative_depth: &mut usize, + mut stop: F, + ) -> bool + where F: FnMut(usize, usize) -> bool + { loop { let mut pos = self.root_position.clone(); let mut this_depth = 0; @@ -89,101 +108,142 @@ impl<'a> Searcher<'a> { &mut this_depth, ) { self.tree.root_stats().update(u); + } else { + return false; } - cumulative_depth += this_depth - 1; + *cumulative_depth += this_depth - 1; + *nodes += 1; // proven checkmate if self.tree[self.tree.root_node()].is_terminal() { - break; - } - - if nodes >= limits.max_nodes { - break; + return true; } + // stop signal sent if self.abort.load(Ordering::Relaxed) { - break; + return true; } - nodes += 1; - - if nodes % 128 == 0 { - if let Some(time) = limits.max_time { - if timer.elapsed().as_millis() >= time { - break; - } - } + if stop(*nodes, *cumulative_depth) { + return true; + } + } + } - let new_best_move = self.get_best_move(); - if new_best_move != best_move { - best_move = new_best_move; - best_move_changes += 1; + #[allow(clippy::too_many_arguments)] + fn check_limits( + &self, + limits: &Limits, + timer: &Instant, + nodes: usize, + best_move: &mut Move, + best_move_changes: &mut i32, + previous_score: &mut f32, + depth: &mut usize, + cumulative_depth: usize, + uci_output: bool, + ) -> bool { + if nodes % 128 == 0 { + if let Some(time) = limits.max_time { + if timer.elapsed().as_millis() >= time { + return true; } } - if nodes % 4096 == 0 { - // Time management - if let Some(time) = limits.opt_time { - let elapsed = timer.elapsed().as_millis(); - - // Use more time if our eval is falling, and vice versa - let (_, mut score) = self.get_pv(0); - score = Searcher::get_cp(score); - let eval_diff = if previous_score == f32::NEG_INFINITY { - 0.0 - } else { - previous_score - score - }; - let falling_eval = (1.0 + eval_diff * self.params.tm_falling_eval1()).clamp( - self.params.tm_falling_eval2(), - self.params.tm_falling_eval3(), - ); - - // Use more time if our best move is changing frequently - let best_move_instability = (1.0 - + (best_move_changes as f32 * self.params.tm_bmi1()).ln_1p()) - .clamp(self.params.tm_bmi2(), self.params.tm_bmi3()); - - // Use less time if our best move has a large percentage of visits, and vice versa - let nodes_effort = self.get_best_action().visits() as f32 / nodes as f32; - let best_move_visits = (self.params.tm_bmv1() - - ((nodes_effort + self.params.tm_bmv2()) * self.params.tm_bmv3()).ln_1p() - * self.params.tm_bmv4()) - .clamp(self.params.tm_bmv5(), self.params.tm_bmv6()); - - let total_time = - (time as f32 * falling_eval * best_move_instability * best_move_visits) - as u128; - if elapsed >= total_time { - break; - } - - if nodes % 16384 == 0 { - best_move_changes = 0; - } - previous_score = if previous_score == f32::NEG_INFINITY { - score - } else { - (score + 2.0 * previous_score) / 3.0 - }; - } + let new_best_move = self.get_best_move(); + if new_best_move != *best_move { + *best_move = new_best_move; + *best_move_changes += 1; } + } - // define "depth" as the average depth of selection - let avg_depth = cumulative_depth / nodes; - if avg_depth > depth { - depth = avg_depth; - if depth >= limits.max_depth { - break; + if nodes % 4096 == 0 { + // Time management + if let Some(time) = limits.opt_time { + let (should_stop, score) = SearchHelpers::soft_time_cutoff( + self, + timer, + *previous_score, + *best_move_changes, + nodes, + time + ); + + if should_stop { + return true; } - if uci_output { - self.search_report(depth, &timer, nodes); + if nodes % 16384 == 0 { + *best_move_changes = 0; } + + *previous_score = if *previous_score == f32::NEG_INFINITY { + score + } else { + (score + 2.0 * *previous_score) / 3.0 + }; + } + } + + // define "depth" as the average depth of selection + let avg_depth = cumulative_depth / nodes; + if avg_depth > *depth { + *depth = avg_depth; + if *depth >= limits.max_depth { + return true; + } + + if uci_output { + self.search_report(*depth, timer, nodes); } } + false + } + + pub fn search( + &self, + limits: Limits, + uci_output: bool, + total_nodes: &mut usize, + ) -> (Move, f32) { + let timer = Instant::now(); + + // attempt to reuse the current tree stored in memory + let node = self.tree.root_node(); + + // relabel root policies with root PST value + if self.tree[node].has_children() { + self.tree[node].relabel_policy(&self.root_position, self.params, self.policy); + } else { + self.tree[node].expand::(&self.root_position, self.params, self.policy); + } + + let mut nodes = 0; + let mut depth = 0; + let mut cumulative_depth = 0; + + let mut best_move = Move::NULL; + let mut best_move_changes = 0; + let mut previous_score = f32::NEG_INFINITY; + + // search loop + while !self.abort.load(Ordering::Relaxed) { + self.playout_until_full_main( + &limits, + &timer, + &mut nodes, + &mut depth, + &mut cumulative_depth, + &mut best_move, + &mut best_move_changes, + &mut previous_score, uci_output, + ); + + self.tree.flip(); + } + self.abort.store(true, Ordering::Relaxed); *total_nodes += nodes; @@ -197,7 +257,7 @@ impl<'a> Searcher<'a> { (Move::from(best_child.mov()), best_child.q()) } - fn perform_one_iteration(&mut self, pos: &mut ChessState, ptr: NodePtr, node_stats: &ActionStats, depth: &mut usize) -> Option { + fn perform_one_iteration(&self, pos: &mut ChessState, ptr: NodePtr, node_stats: &ActionStats, depth: &mut usize) -> Option { *depth += 1; let hash = pos.hash(); diff --git a/src/mcts/helpers.rs b/src/mcts/helpers.rs index 145bc7d4..8c42dd2f 100644 --- a/src/mcts/helpers.rs +++ b/src/mcts/helpers.rs @@ -1,4 +1,6 @@ -use crate::{mcts::MctsParams, tree::{ActionStats, Edge}}; +use std::time::Instant; + +use crate::{mcts::{MctsParams, Searcher}, tree::{ActionStats, Edge}}; pub struct SearchHelpers; @@ -107,4 +109,46 @@ impl SearchHelpers { (opt_time, max_time) } } + + pub fn soft_time_cutoff( + searcher: &Searcher, + timer: &Instant, + previous_score: f32, + best_move_changes: i32, + nodes: usize, + time: u128, + ) -> (bool, f32) { + let elapsed = timer.elapsed().as_millis(); + + // Use more time if our eval is falling, and vice versa + let (_, mut score) = searcher.get_pv(0); + score = Searcher::get_cp(score); + let eval_diff = if previous_score == f32::NEG_INFINITY { + 0.0 + } else { + previous_score - score + }; + let falling_eval = (1.0 + eval_diff * searcher.params.tm_falling_eval1()).clamp( + searcher.params.tm_falling_eval2(), + searcher.params.tm_falling_eval3(), + ); + + // Use more time if our best move is changing frequently + let best_move_instability = (1.0 + + (best_move_changes as f32 * searcher.params.tm_bmi1()).ln_1p()) + .clamp(searcher.params.tm_bmi2(), searcher.params.tm_bmi3()); + + // Use less time if our best move has a large percentage of visits, and vice versa + let nodes_effort = searcher.get_best_action().visits() as f32 / nodes as f32; + let best_move_visits = (searcher.params.tm_bmv1() + - ((nodes_effort + searcher.params.tm_bmv2()) * searcher.params.tm_bmv3()).ln_1p() + * searcher.params.tm_bmv4()) + .clamp(searcher.params.tm_bmv5(), searcher.params.tm_bmv6()); + + let total_time = + (time as f32 * falling_eval * best_move_instability * best_move_visits) + as u128; + + (elapsed >= total_time, score) + } } diff --git a/src/tree.rs b/src/tree.rs index eb361a48..d5731391 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -13,7 +13,7 @@ pub use ptr::NodePtr; pub use stats::ActionStats; use std::{ - sync::{atomic::{AtomicBool, Ordering}, Mutex}, + sync::atomic::{AtomicBool, Ordering}, time::Instant, }; @@ -24,7 +24,6 @@ use crate::{ pub struct Tree { tree: [TreeHalf; 2], half: AtomicBool, - flip_lock: Mutex<()>, hash: HashTable, root_stats: ActionStats, } @@ -52,7 +51,6 @@ impl Tree { half: AtomicBool::new(false), hash: HashTable::new(hash_cap / 4), root_stats: ActionStats::default(), - flip_lock: Mutex::new(()), } } @@ -64,11 +62,6 @@ impl Tree { self.tree[self.half()].is_full() } - fn flip(&self) { - let old = usize::from(self.half.fetch_xor(true, Ordering::Relaxed)); - self.tree[old].clear_ptrs(); - self.tree[old ^ 1].clear(); - } pub fn copy_across(&self, from: NodePtr, to: NodePtr) { if from == to { @@ -82,23 +75,23 @@ impl Tree { std::mem::swap(f, t); } - #[must_use] - pub fn push_new(&self, state: GameState) -> Option { - let new_ptr = self.tree[self.half()].push_new(state); - - if new_ptr.is_null() { - let _lock = self.flip_lock.lock(); + pub fn flip(&self) { + let old_root_ptr = self.root_node(); - if self.is_full() { - let old_root_ptr = self.root_node(); + let old = usize::from(self.half.fetch_xor(true, Ordering::Relaxed)); + self.tree[old].clear_ptrs(); + self.tree[old ^ 1].clear(); - self.flip(); + let new_root_ptr = self.tree[self.half()].push_new(GameState::Ongoing); - let new_root_ptr = self.tree[self.half()].push_new(GameState::Ongoing); + self.copy_across(old_root_ptr, new_root_ptr); + } - self.copy_across(old_root_ptr, new_root_ptr); - } + #[must_use] + pub fn push_new(&self, state: GameState) -> Option { + let new_ptr = self.tree[self.half()].push_new(state); + if new_ptr.is_null() { None } else { Some(new_ptr) diff --git a/src/tree/half.rs b/src/tree/half.rs index d2858364..703c8399 100644 --- a/src/tree/half.rs +++ b/src/tree/half.rs @@ -36,7 +36,7 @@ impl TreeHalf { pub fn push_new(&self, state: GameState) -> NodePtr { let idx = self.used.fetch_add(1, Ordering::Relaxed); - if idx == self.nodes.len() { + if idx >= self.nodes.len() { return NodePtr::NULL; } diff --git a/src/uci.rs b/src/uci.rs index f1fb045f..1821fca6 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -143,7 +143,7 @@ impl Uci { let abort = AtomicBool::new(false); let pos = ChessState::from_fen(fen); tree.try_use_subtree(&pos, &None); - let mut searcher = Searcher::new(pos, &tree, params, policy, value, &abort); + let searcher = Searcher::new(pos, &tree, params, policy, value, &abort); let timer = Instant::now(); searcher.search(limits, false, &mut total_nodes); time += timer.elapsed().as_secs_f32(); @@ -316,7 +316,7 @@ fn go( std::thread::scope(|s| { s.spawn(|| { - let mut searcher = Searcher::new(pos.clone(), tree, params, policy, value, &abort); + let searcher = Searcher::new(pos.clone(), tree, params, policy, value, &abort); let (mov, _) = searcher.search(limits, true, &mut 0); println!("bestmove {}", pos.conv_mov_to_str(mov)); From 3572ef0c3150402329ac4cf4a4e44f9d244bda0c Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 4 Aug 2024 15:23:01 +0100 Subject: [PATCH 38/48] Bench: 1954267 --- src/mcts.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/mcts.rs b/src/mcts.rs index b26635f1..e541abe7 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -144,6 +144,10 @@ impl<'a> Searcher<'a> { cumulative_depth: usize, uci_output: bool, ) -> bool { + if nodes >= limits.max_nodes { + return true; + } + if nodes % 128 == 0 { if let Some(time) = limits.max_time { if timer.elapsed().as_millis() >= time { From 8029e1caa8e9bf52475fa1fc8d1bc30e8ad0419f Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 4 Aug 2024 15:37:27 +0100 Subject: [PATCH 39/48] Bench: 1954267 --- datagen/src/thread.rs | 2 +- src/mcts.rs | 34 +++++++++++++++++++++------------- src/tree/node.rs | 8 +++++--- src/uci.rs | 6 ++++-- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/datagen/src/thread.rs b/datagen/src/thread.rs index 92db0dd8..af3b4d45 100644 --- a/datagen/src/thread.rs +++ b/datagen/src/thread.rs @@ -128,7 +128,7 @@ impl<'a> DatagenThread<'a> { &abort, ); - let (bm, score) = searcher.search(limits, false, &mut 0); + let (bm, score) = searcher.search(1, limits, false, &mut 0); game.push(position.stm(), bm, score); diff --git a/src/mcts.rs b/src/mcts.rs index e541abe7..51f063dd 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -11,8 +11,7 @@ use crate::{ }; use std::{ - sync::atomic::{AtomicBool, Ordering}, - time::Instant, + sync::atomic::{AtomicBool, Ordering}, thread, time::Instant }; #[derive(Clone, Copy)] @@ -85,7 +84,7 @@ impl<'a> Searcher<'a> { } } - fn playout_until_full_worker(&mut self, nodes: &mut usize, cumulative_depth: &mut usize) { + fn playout_until_full_worker(&self, nodes: &mut usize, cumulative_depth: &mut usize) { let _ = self.playout_until_full_internal(nodes, cumulative_depth, |_, _| false); } @@ -208,6 +207,7 @@ impl<'a> Searcher<'a> { pub fn search( &self, + threads: usize, limits: Limits, uci_output: bool, total_nodes: &mut usize, @@ -234,16 +234,24 @@ impl<'a> Searcher<'a> { // search loop while !self.abort.load(Ordering::Relaxed) { - self.playout_until_full_main( - &limits, - &timer, - &mut nodes, - &mut depth, - &mut cumulative_depth, - &mut best_move, - &mut best_move_changes, - &mut previous_score, uci_output, - ); + thread::scope(|s| { + s.spawn(|| { + self.playout_until_full_main( + &limits, + &timer, + &mut nodes, + &mut depth, + &mut cumulative_depth, + &mut best_move, + &mut best_move_changes, + &mut previous_score, uci_output, + ); + }); + + for _ in 0..threads - 1 { + s.spawn(|| self.playout_until_full_worker(&mut 0, &mut 0)); + } + }); self.tree.flip(); } diff --git a/src/tree/node.rs b/src/tree/node.rs index 990213fe..947b0827 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -64,13 +64,15 @@ impl Node { params: &MctsParams, policy: &PolicyNetwork, ) { - assert!(self.is_not_expanded()); + let mut actions = self.actions_mut(); + + if actions.len() != 0 { + return; + } let feats = pos.get_policy_feats(); let mut max = f32::NEG_INFINITY; - let mut actions = self.actions_mut(); - pos.map_legal_moves(|mov| { let policy = pos.get_policy(mov, &feats, policy); diff --git a/src/uci.rs b/src/uci.rs index 1821fca6..ecceda47 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -68,6 +68,7 @@ impl Uci { report_moves, policy, value, + threads, &mut stored_message, ); @@ -145,7 +146,7 @@ impl Uci { tree.try_use_subtree(&pos, &None); let searcher = Searcher::new(pos, &tree, params, policy, value, &abort); let timer = Instant::now(); - searcher.search(limits, false, &mut total_nodes); + searcher.search(1, limits, false, &mut total_nodes); time += timer.elapsed().as_secs_f32(); tree.clear(); } @@ -241,6 +242,7 @@ fn go( report_moves: bool, policy: &PolicyNetwork, value: &ValueNetwork, + threads: usize, stored_message: &mut Option, ) { let mut max_nodes = i32::MAX as usize; @@ -317,7 +319,7 @@ fn go( std::thread::scope(|s| { s.spawn(|| { let searcher = Searcher::new(pos.clone(), tree, params, policy, value, &abort); - let (mov, _) = searcher.search(limits, true, &mut 0); + let (mov, _) = searcher.search(threads, limits, true, &mut 0); println!("bestmove {}", pos.conv_mov_to_str(mov)); if report_moves { From c20fd69e932e4f8e4570b406ac64247d3d32038c Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 4 Aug 2024 15:55:58 +0100 Subject: [PATCH 40/48] Bench: 1954267 --- src/uci.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/uci.rs b/src/uci.rs index ecceda47..0b2e173c 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -162,6 +162,7 @@ fn preamble() { println!("id name monty {}", env!("CARGO_PKG_VERSION")); println!("id author Jamie Whiting"); println!("option name Hash type spin default 64 min 1 max 8192"); + println!("option name Threads type spin default 1 min 1 max 512"); println!("option name report_moves type button"); Uci::options(); MctsParams::info(MctsParams::default()); From c5d0cd4628295632763295ad4a417988dc245d06 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 4 Aug 2024 16:16:36 +0100 Subject: [PATCH 41/48] Bench: 1954267 --- src/mcts.rs | 2 +- src/tree.rs | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 51f063dd..0bfe8505 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -253,7 +253,7 @@ impl<'a> Searcher<'a> { } }); - self.tree.flip(); + self.tree.flip(true); } self.abort.store(true, Ordering::Relaxed); diff --git a/src/tree.rs b/src/tree.rs index d5731391..88d69b67 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -75,16 +75,18 @@ impl Tree { std::mem::swap(f, t); } - pub fn flip(&self) { + pub fn flip(&self, copy_across: bool) { let old_root_ptr = self.root_node(); let old = usize::from(self.half.fetch_xor(true, Ordering::Relaxed)); self.tree[old].clear_ptrs(); self.tree[old ^ 1].clear(); - let new_root_ptr = self.tree[self.half()].push_new(GameState::Ongoing); + if copy_across { + let new_root_ptr = self.tree[self.half()].push_new(GameState::Ongoing); - self.copy_across(old_root_ptr, new_root_ptr); + self.copy_across(old_root_ptr, new_root_ptr); + } } #[must_use] @@ -232,7 +234,7 @@ impl Tree { if !found { println!("info string no subtree found"); self.clear_halves(); - self.flip(); + self.flip(false); self.push_new(GameState::Ongoing).unwrap(); } From 2f07abac35725fb1c99df4a52b1b9730ce15b921 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 4 Aug 2024 16:29:41 +0100 Subject: [PATCH 42/48] Working? Bench: 1954267 --- src/mcts.rs | 15 +++++++++++++-- src/tree/node.rs | 14 ++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 0bfe8505..53be35dc 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -302,7 +302,13 @@ impl<'a> Searcher<'a> { let child_ptr = self.tree.fetch_node(pos, ptr, edge.ptr(), action)?; - let u = self.perform_one_iteration(pos, child_ptr, &edge.stats(), depth)?; + self.tree[child_ptr].inc_threads(); + + let maybe_u = self.perform_one_iteration(pos, child_ptr, &edge.stats(), depth); + + self.tree[child_ptr].dec_threads(); + + let u = maybe_u?; let new_q = self.tree.update_edge_stats(ptr, action, u); self.tree.push_hash(hash, new_q); @@ -340,7 +346,12 @@ impl<'a> Searcher<'a> { let expl = cpuct * expl_scale; self.tree.get_best_child_by_key(ptr, |action| { - let q = SearchHelpers::get_action_value(action, fpu); + let q = if !action.ptr().is_null() && self.tree[action.ptr()].threads() > 0 { + 0.0 + } else { + SearchHelpers::get_action_value(action, fpu) + }; + let u = expl * action.policy() / (1 + action.visits()) as f32; q + u diff --git a/src/tree/node.rs b/src/tree/node.rs index 947b0827..391a98d9 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -6,6 +6,7 @@ use crate::{chess::Move, tree::{Edge, NodePtr}, ChessState, GameState, MctsParam pub struct Node { actions: RwLock>, state: AtomicU16, + threads: AtomicU16, } impl Node { @@ -13,6 +14,7 @@ impl Node { Node { actions: RwLock::new(Vec::new()), state: AtomicU16::new(u16::from(state)), + threads: AtomicU16::new(0), } } @@ -29,6 +31,18 @@ impl Node { self.actions.read().unwrap().len() } + pub fn threads(&self) -> u16 { + self.threads.load(Ordering::Relaxed) + } + + pub fn inc_threads(&self) { + self.threads.fetch_add(1, Ordering::Relaxed); + } + + pub fn dec_threads(&self) { + self.threads.fetch_sub(1, Ordering::Relaxed); + } + pub fn actions(&self) -> RwLockReadGuard> { self.actions.read().unwrap() } From c9ffc9bab9bc95b04117e2ff466be2c3da028c9b Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 4 Aug 2024 16:52:45 +0100 Subject: [PATCH 43/48] fix Bench: 1954267 --- src/tree.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 88d69b67..095d9fe5 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -64,14 +64,22 @@ impl Tree { pub fn copy_across(&self, from: NodePtr, to: NodePtr) { + let f = &mut *self[from].actions_mut(); + let t = &mut *self[to].actions_mut(); + + self[to].set_state(self[from].state()); + + if f.is_empty() { + return; + } + + assert!(t.is_empty()); + if from == to { return; } self[to].set_state(self[from].state()); - - let f = &mut *self[from].actions_mut(); - let t = &mut *self[to].actions_mut(); std::mem::swap(f, t); } From ff15098e6df220c70833f6d31fb0abc7bf0f9fe1 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 4 Aug 2024 16:53:31 +0100 Subject: [PATCH 44/48] cleanup Bench: 1954267 --- src/tree.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 095d9fe5..3f248c20 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -64,6 +64,10 @@ impl Tree { pub fn copy_across(&self, from: NodePtr, to: NodePtr) { + if from == to { + return; + } + let f = &mut *self[from].actions_mut(); let t = &mut *self[to].actions_mut(); @@ -75,11 +79,6 @@ impl Tree { assert!(t.is_empty()); - if from == to { - return; - } - - self[to].set_state(self[from].state()); std::mem::swap(f, t); } From 0a040694258fbc4833795c1700ba6ea23479380b Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 4 Aug 2024 16:58:48 +0100 Subject: [PATCH 45/48] Enforce invariant in `Tree::copy_across` Bench: 1954267 --- src/tree.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tree.rs b/src/tree.rs index 3f248c20..4f75bda3 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -91,6 +91,7 @@ impl Tree { if copy_across { let new_root_ptr = self.tree[self.half()].push_new(GameState::Ongoing); + self[new_root_ptr].clear(); self.copy_across(old_root_ptr, new_root_ptr); } @@ -229,6 +230,7 @@ impl Tree { found = true; if root != self.root_node() { + self[self.root_node()].clear(); self.copy_across(root, self.root_node()); self.root_stats = stats; println!("info string found subtree"); From 0d1c819967e72e239928b1e52472157a36f385be Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 4 Aug 2024 19:12:17 +0100 Subject: [PATCH 46/48] Basic fix Bench: 1954267 --- src/tree.rs | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 4f75bda3..49133059 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -117,14 +117,36 @@ impl Tree { action: usize, ) -> Option { if ptr.is_null() { + let actions = self[parent_ptr].actions_mut(); + + let most_recent_ptr = actions[action].ptr(); + if !most_recent_ptr.is_null() { + return Some(most_recent_ptr); + } + + assert_eq!(ptr, most_recent_ptr); + let state = pos.game_state(); let new_ptr = self.push_new(state)?; - self.set_edge_ptr(parent_ptr, action, new_ptr); + + actions[action].set_ptr(new_ptr); + Some(new_ptr) } else if ptr.half() != self.half.load(Ordering::Relaxed) { + let actions = self[parent_ptr].actions_mut(); + + let most_recent_ptr = actions[action].ptr(); + if most_recent_ptr.half() == self.half.load(Ordering::Relaxed){ + return Some(most_recent_ptr); + } + + assert_eq!(ptr, most_recent_ptr); + let new_ptr = self.push_new(GameState::Ongoing)?; + self.copy_across(ptr, new_ptr); - self.set_edge_ptr(parent_ptr, action, new_ptr); + + actions[action].set_ptr(new_ptr); Some(new_ptr) } else { @@ -144,9 +166,9 @@ impl Tree { self[ptr].actions()[action].clone() } - pub fn set_edge_ptr(&self, ptr: NodePtr, action: usize, set: NodePtr) { - self[ptr].actions()[action].set_ptr(set); - } + //pub fn set_edge_ptr(&self, ptr: NodePtr, action: usize, set: NodePtr) { + // self[ptr].actions()[action].set_ptr(set); + //} pub fn update_edge_stats(&self, ptr: NodePtr, action: usize, result: f32) -> f32 { let actions = &self[ptr].actions(); From 1861d056c7b9c7ba4021806741922b62bfeaa318 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Sun, 4 Aug 2024 21:52:10 +0100 Subject: [PATCH 47/48] Fix + Small Cleanup Bench: 1954267 --- src/mcts.rs | 15 ++++++--------- src/tree.rs | 4 ---- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 53be35dc..a91c5411 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -264,9 +264,8 @@ impl<'a> Searcher<'a> { self.search_report(depth.max(1), &timer, nodes); } - let best_action = self.tree.get_best_child(self.tree.root_node()); - let best_child = self.tree.edge_copy(self.tree.root_node(), best_action); - (Move::from(best_child.mov()), best_child.q()) + let best_action = self.get_best_action(); + (Move::from(best_action.mov()), best_action.q()) } fn perform_one_iteration(&self, pos: &mut ChessState, ptr: NodePtr, node_stats: &ActionStats, depth: &mut usize) -> Option { @@ -387,8 +386,7 @@ impl<'a> Searcher<'a> { fn get_pv(&self, mut depth: usize) -> (Vec, f32) { let mate = self.tree[self.tree.root_node()].is_terminal(); - let idx = self.tree.get_best_child(self.tree.root_node()); - let mut action = self.tree.edge_copy(self.tree.root_node(), idx); + let mut action = self.get_best_action(); let score = if !action.ptr().is_null() { match self.tree[action.ptr()].state() { @@ -402,8 +400,9 @@ impl<'a> Searcher<'a> { }; let mut pv = Vec::new(); + let half = self.tree.half() > 0; - while (mate || depth > 0) && !action.ptr().is_null() { + while (mate || depth > 0) && !action.ptr().is_null() && action.ptr().half() == half { pv.push(Move::from(action.mov())); let idx = self.tree.get_best_child(action.ptr()); @@ -424,9 +423,7 @@ impl<'a> Searcher<'a> { } fn get_best_move(&self) -> Move { - let idx = self.tree.get_best_child(self.tree.root_node()); - let action = self.tree.edge_copy(self.tree.root_node(), idx); - Move::from(action.mov()) + Move::from(self.get_best_action().mov()) } fn get_cp(score: f32) -> f32 { diff --git a/src/tree.rs b/src/tree.rs index 49133059..1b51b177 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -166,10 +166,6 @@ impl Tree { self[ptr].actions()[action].clone() } - //pub fn set_edge_ptr(&self, ptr: NodePtr, action: usize, set: NodePtr) { - // self[ptr].actions()[action].set_ptr(set); - //} - pub fn update_edge_stats(&self, ptr: NodePtr, action: usize, result: f32) -> f32 { let actions = &self[ptr].actions(); let edge = &actions[action]; From e2bfed06d9a587acc924e3a8f63c80703ddf04e6 Mon Sep 17 00:00:00 2001 From: jw1912 Date: Mon, 5 Aug 2024 03:51:16 +0100 Subject: [PATCH 48/48] Format Bench: 1954267 --- datagen/src/thread.rs | 10 ++------- src/mcts.rs | 50 ++++++++++++++++++++++++------------------- src/mcts/helpers.rs | 8 ++++--- src/tree.rs | 12 ++++++----- src/tree/edge.rs | 5 +++-- src/tree/half.rs | 6 +----- src/tree/hash.rs | 12 +++++------ src/tree/node.rs | 26 +++++++++++++--------- src/tree/ptr.rs | 2 +- src/tree/stats.rs | 8 ++++--- src/uci.rs | 18 +++++++++++++--- 11 files changed, 88 insertions(+), 69 deletions(-) diff --git a/datagen/src/thread.rs b/datagen/src/thread.rs index af3b4d45..b0a65bf2 100644 --- a/datagen/src/thread.rs +++ b/datagen/src/thread.rs @@ -119,14 +119,8 @@ impl<'a> DatagenThread<'a> { let abort = AtomicBool::new(false); tree.try_use_subtree(&position, &None); - let searcher = Searcher::new( - position.clone(), - &tree, - &self.params, - policy, - value, - &abort, - ); + let searcher = + Searcher::new(position.clone(), &tree, &self.params, policy, value, &abort); let (bm, score) = searcher.search(1, limits, false, &mut 0); diff --git a/src/mcts.rs b/src/mcts.rs index a91c5411..a0ec79a6 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -11,7 +11,9 @@ use crate::{ }; use std::{ - sync::atomic::{AtomicBool, Ordering}, thread, time::Instant + sync::atomic::{AtomicBool, Ordering}, + thread, + time::Instant, }; #[derive(Clone, Copy)] @@ -63,23 +65,19 @@ impl<'a> Searcher<'a> { previous_score: &mut f32, uci_output: bool, ) { - if self.playout_until_full_internal( - nodes, - cumulative_depth, - |n, cd| { - self.check_limits( - limits, - timer, - n, - best_move, - best_move_changes, - previous_score, - depth, - cd, - uci_output, - ) - } - ) { + if self.playout_until_full_internal(nodes, cumulative_depth, |n, cd| { + self.check_limits( + limits, + timer, + n, + best_move, + best_move_changes, + previous_score, + depth, + cd, + uci_output, + ) + }) { self.abort.store(true, Ordering::Relaxed); } } @@ -94,7 +92,8 @@ impl<'a> Searcher<'a> { cumulative_depth: &mut usize, mut stop: F, ) -> bool - where F: FnMut(usize, usize) -> bool + where + F: FnMut(usize, usize) -> bool, { loop { let mut pos = self.root_position.clone(); @@ -170,7 +169,7 @@ impl<'a> Searcher<'a> { *previous_score, *best_move_changes, nodes, - time + time, ); if should_stop { @@ -244,7 +243,8 @@ impl<'a> Searcher<'a> { &mut cumulative_depth, &mut best_move, &mut best_move_changes, - &mut previous_score, uci_output, + &mut previous_score, + uci_output, ); }); @@ -268,7 +268,13 @@ impl<'a> Searcher<'a> { (Move::from(best_action.mov()), best_action.q()) } - fn perform_one_iteration(&self, pos: &mut ChessState, ptr: NodePtr, node_stats: &ActionStats, depth: &mut usize) -> Option { + fn perform_one_iteration( + &self, + pos: &mut ChessState, + ptr: NodePtr, + node_stats: &ActionStats, + depth: &mut usize, + ) -> Option { *depth += 1; let hash = pos.hash(); diff --git a/src/mcts/helpers.rs b/src/mcts/helpers.rs index 8c42dd2f..a5015214 100644 --- a/src/mcts/helpers.rs +++ b/src/mcts/helpers.rs @@ -1,6 +1,9 @@ use std::time::Instant; -use crate::{mcts::{MctsParams, Searcher}, tree::{ActionStats, Edge}}; +use crate::{ + mcts::{MctsParams, Searcher}, + tree::{ActionStats, Edge}, +}; pub struct SearchHelpers; @@ -146,8 +149,7 @@ impl SearchHelpers { .clamp(searcher.params.tm_bmv5(), searcher.params.tm_bmv6()); let total_time = - (time as f32 * falling_eval * best_move_instability * best_move_visits) - as u128; + (time as f32 * falling_eval * best_move_instability * best_move_visits) as u128; (elapsed >= total_time, score) } diff --git a/src/tree.rs b/src/tree.rs index 1b51b177..55095549 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -18,7 +18,8 @@ use std::{ }; use crate::{ - chess::{ChessState, Move}, GameState + chess::{ChessState, Move}, + GameState, }; pub struct Tree { @@ -62,7 +63,6 @@ impl Tree { self.tree[self.half()].is_full() } - pub fn copy_across(&self, from: NodePtr, to: NodePtr) { if from == to { return; @@ -136,7 +136,7 @@ impl Tree { let actions = self[parent_ptr].actions_mut(); let most_recent_ptr = actions[action].ptr(); - if most_recent_ptr.half() == self.half.load(Ordering::Relaxed){ + if most_recent_ptr.half() == self.half.load(Ordering::Relaxed) { return Some(most_recent_ptr); } @@ -242,7 +242,8 @@ impl Tree { if let Some(board) = prev_board { println!("info string searching for subtree"); - let (root, stats) = self.recurse_find(self.root_node(), board, root, self.root_stats.clone(), 2); + let (root, stats) = + self.recurse_find(self.root_node(), board, root, self.root_stats.clone(), 2); if !root.is_null() && self[root].has_children() { found = true; @@ -295,7 +296,8 @@ impl Tree { child_board.make_move(Move::from(action.mov())); - let found = self.recurse_find(child_idx, &child_board, board, action.stats(), depth - 1); + let found = + self.recurse_find(child_idx, &child_board, board, action.stats(), depth - 1); if !found.0.is_null() { return found; diff --git a/src/tree/edge.rs b/src/tree/edge.rs index 3cd8dd18..ae18b691 100644 --- a/src/tree/edge.rs +++ b/src/tree/edge.rs @@ -1,4 +1,4 @@ -use std::sync::atomic::{AtomicU16, AtomicU32, AtomicI16, Ordering}; +use std::sync::atomic::{AtomicI16, AtomicU16, AtomicU32, Ordering}; use super::{ActionStats, NodePtr}; @@ -82,7 +82,8 @@ impl Edge { } pub fn set_policy(&self, policy: f32) { - self.policy.store((policy * f32::from(i16::MAX)) as i16, Ordering::Relaxed) + self.policy + .store((policy * f32::from(i16::MAX)) as i16, Ordering::Relaxed) } pub fn update(&self, result: f32) { diff --git a/src/tree/half.rs b/src/tree/half.rs index 703c8399..b2afdc87 100644 --- a/src/tree/half.rs +++ b/src/tree/half.rs @@ -1,8 +1,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; -use crate::GameState; use super::{Node, NodePtr}; - +use crate::GameState; pub struct TreeHalf { nodes: Vec, @@ -71,6 +70,3 @@ impl TreeHalf { self.used() >= self.nodes.len() } } - - - diff --git a/src/tree/hash.rs b/src/tree/hash.rs index 1ef71f65..ff71e60a 100644 --- a/src/tree/hash.rs +++ b/src/tree/hash.rs @@ -23,17 +23,13 @@ impl Clone for HashEntryInternal { impl From<&HashEntryInternal> for HashEntry { fn from(value: &HashEntryInternal) -> Self { - unsafe { - std::mem::transmute(value.0.load(Ordering::Relaxed)) - } + unsafe { std::mem::transmute(value.0.load(Ordering::Relaxed)) } } } impl From for u32 { fn from(value: HashEntry) -> Self { - unsafe { - std::mem::transmute(value) - } + unsafe { std::mem::transmute(value) } } } @@ -81,6 +77,8 @@ impl HashTable { q: (q * f32::from(u16::MAX)) as u16, }; - self.table[idx as usize].0.store(u32::from(entry), Ordering::Relaxed) + self.table[idx as usize] + .0 + .store(u32::from(entry), Ordering::Relaxed) } } diff --git a/src/tree/node.rs b/src/tree/node.rs index 391a98d9..a89e96f9 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -1,6 +1,13 @@ -use std::sync::{atomic::{AtomicU16, Ordering}, RwLock, RwLockReadGuard, RwLockWriteGuard}; +use std::sync::{ + atomic::{AtomicU16, Ordering}, + RwLock, RwLockReadGuard, RwLockWriteGuard, +}; -use crate::{chess::Move, tree::{Edge, NodePtr}, ChessState, GameState, MctsParams, PolicyNetwork}; +use crate::{ + chess::Move, + tree::{Edge, NodePtr}, + ChessState, GameState, MctsParams, PolicyNetwork, +}; #[derive(Debug)] pub struct Node { @@ -91,13 +98,17 @@ impl Node { let policy = pos.get_policy(mov, &feats, policy); // trick for calculating policy before quantising - actions.push(Edge::new(NodePtr::from_raw(f32::to_bits(policy)), mov.into(), 0)); + actions.push(Edge::new( + NodePtr::from_raw(f32::to_bits(policy)), + mov.into(), + 0, + )); max = max.max(policy); }); let mut total = 0.0; - for action in actions.iter_mut() { + for action in actions.iter_mut() { let mut policy = f32::from_bits(action.ptr().inner()); policy = if ROOT { @@ -118,12 +129,7 @@ impl Node { } } - pub fn relabel_policy( - &self, - pos: &ChessState, - params: &MctsParams, - policy: &PolicyNetwork, - ) { + pub fn relabel_policy(&self, pos: &ChessState, params: &MctsParams, policy: &PolicyNetwork) { let feats = pos.get_policy_feats(); let mut max = f32::NEG_INFINITY; diff --git a/src/tree/ptr.rs b/src/tree/ptr.rs index 517d3869..46523663 100644 --- a/src/tree/ptr.rs +++ b/src/tree/ptr.rs @@ -27,4 +27,4 @@ impl NodePtr { pub fn from_raw(inner: u32) -> Self { Self(inner) } -} \ No newline at end of file +} diff --git a/src/tree/stats.rs b/src/tree/stats.rs index 99720721..11a86b34 100644 --- a/src/tree/stats.rs +++ b/src/tree/stats.rs @@ -55,8 +55,10 @@ impl ActionStats { let q = (self.q64() * v + r) / (v + 1.0); let sq_q = (self.sq_q() * v + r.powi(2)) / (v + 1.0); - self.q.store((q * f64::from(u32::MAX)) as u32, Ordering::Relaxed); - self.sq_q.store((sq_q * f64::from(u32::MAX)) as u32, Ordering::Relaxed); + self.q + .store((q * f64::from(u32::MAX)) as u32, Ordering::Relaxed); + self.sq_q + .store((sq_q * f64::from(u32::MAX)) as u32, Ordering::Relaxed); } pub fn clear(&self) { @@ -64,4 +66,4 @@ impl ActionStats { self.q.store(0, Ordering::Relaxed); self.sq_q.store(0, Ordering::Relaxed); } -} \ No newline at end of file +} diff --git a/src/uci.rs b/src/uci.rs index 0b2e173c..c2063826 100644 --- a/src/uci.rs +++ b/src/uci.rs @@ -52,7 +52,13 @@ impl Uci { let cmd = *commands.first().unwrap_or(&"oops"); match cmd { "isready" => println!("readyok"), - "setoption" => setoption(&commands, &mut params, &mut report_moves, &mut tree, &mut threads), + "setoption" => setoption( + &commands, + &mut params, + &mut report_moves, + &mut tree, + &mut threads, + ), "position" => position(commands, &mut pos), "go" => { // increment game ply every time `go` is called @@ -169,7 +175,13 @@ fn preamble() { println!("uciok"); } -fn setoption(commands: &[&str], params: &mut MctsParams, report_moves: &mut bool, tree: &mut Tree, threads: &mut usize) { +fn setoption( + commands: &[&str], + params: &mut MctsParams, + report_moves: &mut bool, + tree: &mut Tree, + threads: &mut usize, +) { if let ["setoption", "name", "report_moves"] = commands { *report_moves = !*report_moves; return; @@ -182,7 +194,7 @@ fn setoption(commands: &[&str], params: &mut MctsParams, report_moves: &mut bool if *x == "Threads" { *threads = y.parse().unwrap(); - return + return; } (*x, y.parse::().unwrap_or(0))