diff --git a/crates/beet_ml/src/environments/frozen_lake/frozen_lake_environment.rs b/crates/beet_ml/src/environments/frozen_lake/frozen_lake_environment.rs deleted file mode 100644 index 671e8d03..00000000 --- a/crates/beet_ml/src/environments/frozen_lake/frozen_lake_environment.rs +++ /dev/null @@ -1,74 +0,0 @@ -use crate::prelude::*; -use bevy::prelude::*; -use bevy::utils::HashMap; -use rand::rngs::StdRng; -use rand::SeedableRng; - -/// The outcome of a transition. -#[derive(Debug, Clone)] -pub struct TransitionOutcome { - /// The new position of the agent. - pub pos: UVec2, - /// The reward obtained from the transition. - pub reward: f32, - /// Whether the new state is terminal. - pub is_terminal: bool, -} - -#[derive(Debug, Clone, Component)] -/// An environment for the Frozen Lake game. -pub struct FrozenLakeEnv { - /// A number generator for determining. - rng: StdRng, - /// Whether there is a 2/3 chance the agent moves left or right of the intended direction. - is_slippery: bool, - /// The transition probabilities for each state-action pair. - outcomes: HashMap<(UVec2, GridDirection), TransitionOutcome>, -} - -impl FrozenLakeEnv { - /// Creates a new environment. - /// # Panics - /// If the map has no agent position. - pub fn new(grid: FrozenLakeMap, is_slippery: bool) -> Self { - Self { - is_slippery, - rng: StdRng::from_entropy(), - outcomes: grid.transition_outcomes(), - } - } - pub fn with_slippery_rng(mut self, rng: StdRng) -> Self { - self.rng = rng; - self - } -} - -impl Environment for FrozenLakeEnv { - type State = GridPos; - type Action = GridDirection; - - - fn step( - &mut self, - state: &Self::State, - action: &Self::Action, - ) -> StepOutcome { - let action = if self.is_slippery { - action.as_slippery(&mut self.rng) - } else { - action.clone() - }; - let TransitionOutcome { - pos, - reward, - is_terminal, - } = self.outcomes[&(**state, action)]; - // println!("pos: {:?}, reward: {:?}, is_terminal: {:?}", pos, reward, is_terminal); - - StepOutcome { - state: pos.into(), - reward, - done: is_terminal, - } - } -} diff --git a/crates/beet_ml/src/environments/frozen_lake/frozen_lake_map.rs b/crates/beet_ml/src/environments/frozen_lake/frozen_lake_map.rs index fa3561b5..6a105f0b 100644 --- a/crates/beet_ml/src/environments/frozen_lake/frozen_lake_map.rs +++ b/crates/beet_ml/src/environments/frozen_lake/frozen_lake_map.rs @@ -16,6 +16,7 @@ use strum::IntoEnumIterator; Serialize, Deserialize, Component, + Reflect, )] pub enum FrozenLakeCell { Agent, @@ -43,7 +44,7 @@ impl FrozenLakeCell { /// Define an intial state for a [`FrozenLakeEnv`]. 
-#[derive(Debug, Clone, PartialEq, Eq, Hash, Component)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Component, Reflect)] pub struct FrozenLakeMap { cells: Vec, size: UVec2, @@ -71,8 +72,8 @@ impl FrozenLakeMap { pub fn cells(&self) -> &Vec { &self.cells } pub fn size(&self) -> UVec2 { self.size } - pub fn width(&self) -> u32 { self.size.x } - pub fn height(&self) -> u32 { self.size.y } + pub fn num_cols(&self) -> u32 { self.size.x } + pub fn num_rows(&self) -> u32 { self.size.y } pub fn cells_with_positions( &self, ) -> impl Iterator { @@ -92,7 +93,7 @@ impl FrozenLakeMap { &self, position: UVec2, direction: GridDirection, - ) -> Option { + ) -> Option> { let direction: IVec2 = direction.into(); let new_pos = IVec2::new( position.x as i32 + direction.x, @@ -104,10 +105,10 @@ impl FrozenLakeMap { let new_pos = new_pos.try_into().expect("already checked in bounds"); let new_cell = self.position_to_cell(new_pos); - Some(TransitionOutcome { + Some(StepOutcome { reward: new_cell.reward(), - pos: new_pos, - is_terminal: new_cell.is_terminal(), + state: GridPos(new_pos), + done: new_cell.is_terminal(), }) } } @@ -128,29 +129,29 @@ impl FrozenLakeMap { pub fn transition_outcomes( &self, - ) -> HashMap<(UVec2, GridDirection), TransitionOutcome> { + ) -> HashMap<(GridPos, GridDirection), StepOutcome> { let mut outcomes = HashMap::new(); for (pos, cell) in self.cells_with_positions() { for action in GridDirection::iter() { let outcome = if cell.is_terminal() { // early exit, cannot move from terminal cell - TransitionOutcome { + StepOutcome { reward: 0.0, - pos, - is_terminal: true, + state: GridPos(pos), + done: true, } } else { // yes you can go here self.try_transition(pos, action).unwrap_or( // stay where you are - TransitionOutcome { + StepOutcome { reward: 0.0, - pos, - is_terminal: false, + state: GridPos(pos), + done: false, }, ) }; - outcomes.insert((pos, action), outcome); + outcomes.insert((GridPos(pos), action), outcome); } } @@ -190,27 +191,28 @@ impl FrozenLakeMap { impl FrozenLakeMap { #[rustfmt::skip] pub fn default_eight_by_eight() -> Self { - Self { - size: UVec2::new(8, 8), - //https://github.com/openai/gym/blob/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/frozen_lake.py#L17 - cells: vec![ - //row 1 - FrozenLakeCell::Agent, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, - //row 2 - FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, - //row 3 - FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, - //row 4 - FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, - //row 5 - FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, - //row 6 - FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, - //row 7 - FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, - //row 8 - 
FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Goal, - ], - } + todo!(); + // Self { + // size: UVec2::new(8, 8), + // //https://github.com/openai/gym/blob/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/frozen_lake.py#L17 + // cells: vec![ + // //row 1 + // FrozenLakeCell::Agent, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, + // //row 2 + // FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, + // //row 3 + // FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, + // //row 4 + // FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, + // //row 5 + // FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, + // //row 6 + // FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, + // //row 7 + // FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, + // //row 8 + // FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Goal, + // ], + // } } } diff --git a/crates/beet_ml/src/environments/frozen_lake/frozen_lake_plugin.rs b/crates/beet_ml/src/environments/frozen_lake/frozen_lake_plugin.rs index cdc8d5b4..191c7adb 100644 --- a/crates/beet_ml/src/environments/frozen_lake/frozen_lake_plugin.rs +++ b/crates/beet_ml/src/environments/frozen_lake/frozen_lake_plugin.rs @@ -41,17 +41,20 @@ impl Plugin for FrozenLakePlugin { ActionPlugin::<( TranslateGrid, StepEnvironment, + ReadQPolicy, )>::default(), RlSessionPlugin::::default(), )) - .add_systems(Startup, init_frozen_lake_assets) + .add_systems(PreStartup, init_frozen_lake_assets) .add_systems(Update, reward_grid.in_set(PostTickSet)) .add_systems( Update, (spawn_frozen_lake_session, spawn_frozen_lake_episode) .in_set(PostTickSet), ) - .init_resource::(); + .init_resource::() + .init_asset::>() + .init_asset_loader::>(); let world = app.world_mut(); world.init_component::(); @@ -69,14 +72,17 @@ impl Plugin for FrozenLakePlugin { #[derive(Debug, Clone, Reflect)] pub struct FrozenLakeEpParams { pub learn_params: QLearnParams, - pub map_width: f32, + pub map: FrozenLakeMap, + pub grid_to_world: GridToWorld, } impl Default for FrozenLakeEpParams { fn default() -> Self { + let map = FrozenLakeMap::default_four_by_four(); Self { learn_params: QLearnParams::default(), - map_width: 4.0, + grid_to_world: GridToWorld::from_frozen_lake_map(&map, 4.0), + map, } } } @@ -93,7 +99,7 @@ pub struct FrozenLakeQTableSession; impl RlSessionTypes for FrozenLakeQTableSession { type State = GridPos; type Action = GridDirection; - type QSource = FrozenLakeQTable; - type Env = FrozenLakeEnv; + type QLearnPolicy = FrozenLakeQTable; + type Env = QTableEnv; type 
EpisodeParams = FrozenLakeEpParams; } diff --git a/crates/beet_ml/src/environments/frozen_lake/frozen_lake_scene.rs b/crates/beet_ml/src/environments/frozen_lake/frozen_lake_scene.rs new file mode 100644 index 00000000..23810619 --- /dev/null +++ b/crates/beet_ml/src/environments/frozen_lake/frozen_lake_scene.rs @@ -0,0 +1,62 @@ +use crate::prelude::*; +use bevy::prelude::*; + +pub fn spawn_frozen_lake_scene( + commands: &mut Commands, + map: &FrozenLakeMap, + grid_to_world: &GridToWorld, + assets: &Res, + bundle: impl Bundle + Clone, +) { + let tile_scale = Vec3::splat(grid_to_world.cell_width); + for x in 0..map.num_cols() { + for y in 0..map.num_rows() { + let mut pos = grid_to_world.world_pos(UVec2::new(x, y)); + pos.y -= grid_to_world.cell_width; + commands.spawn(( + SceneBundle { + scene: assets.tile.clone(), + transform: Transform::from_translation(pos) + .with_scale(tile_scale), + ..default() + }, + bundle.clone(), + )); + } + } + + let object_scale = Vec3::splat(grid_to_world.cell_width * 0.5); + + for (index, cell) in map.cells().iter().enumerate() { + let grid_pos = map.index_to_position(index); + let mut pos = grid_to_world.world_pos(grid_pos); + match cell { + FrozenLakeCell::Hole => { + pos.y += grid_to_world.cell_width * 0.25; // this asset is a bit too low + commands.spawn(( + SceneBundle { + scene: assets.hazard.clone(), + transform: Transform::from_translation(pos) + .with_scale(object_scale), + ..default() + }, + bundle.clone(), + )); + } + FrozenLakeCell::Goal => { + commands.spawn(( + SceneBundle { + scene: assets.goal.clone(), + transform: Transform::from_translation(pos) + .with_scale(object_scale), + ..default() + }, + bundle.clone(), + )); + } + FrozenLakeCell::Ice => {} + FrozenLakeCell::Agent => { /*spawns on episode */ } + } + {} + } +} diff --git a/crates/beet_ml/src/environments/frozen_lake/grid.rs b/crates/beet_ml/src/environments/frozen_lake/grid.rs index 191dc3b0..9be4935b 100644 --- a/crates/beet_ml/src/environments/frozen_lake/grid.rs +++ b/crates/beet_ml/src/environments/frozen_lake/grid.rs @@ -3,6 +3,8 @@ use crate::prelude::ActionSpace; use bevy::prelude::*; use rand::rngs::StdRng; use rand::Rng; +use serde::Deserialize; +use serde::Serialize; use strum::EnumCount; use strum::EnumIter; use strum::VariantArray; @@ -19,6 +21,8 @@ use strum::VariantArray; DerefMut, Component, Reflect, + Serialize, + Deserialize, )] pub struct GridPos(pub UVec2); @@ -55,6 +59,8 @@ impl From for GridPos { VariantArray, EnumIter, EnumCount, + Serialize, + Deserialize, )] pub enum GridDirection { #[default] @@ -124,23 +130,35 @@ impl ActionSpace for GridDirection { } -#[derive(Debug, Clone, Component)] +#[derive(Debug, Clone, Component, Reflect)] pub struct GridToWorld { pub map_width: f32, pub cell_width: f32, - pub map_size:UVec2, + pub map_size: UVec2, pub offset: Vec3, } impl GridToWorld { pub fn from_frozen_lake_map(grid: &FrozenLakeMap, map_width: f32) -> Self { - let cell_width = map_width / grid.width() as f32; + let cell_width = map_width / grid.num_cols() as f32; let h_cell_width = cell_width * 0.5; + + let h_map_width = map_width * 0.5; + let offset = Vec3::new( - grid.width() as f32 + h_cell_width, + -h_map_width + h_cell_width, 0., - grid.height() as f32 + h_cell_width, - ) * -0.5; + -h_map_width + h_cell_width, + ); + + + // let mut offset = Vec3::new( + // grid.num_cols() as f32, + // 0., + // grid.num_rows() as f32, + // ) * -0.5; + // offset.x -= h_cell_width; + // offset.z -= h_cell_width; Self { map_size: grid.size(), diff --git 
a/crates/beet_ml/src/environments/frozen_lake/mod.rs b/crates/beet_ml/src/environments/frozen_lake/mod.rs index b6aeb4d5..b1c2bfb9 100644 --- a/crates/beet_ml/src/environments/frozen_lake/mod.rs +++ b/crates/beet_ml/src/environments/frozen_lake/mod.rs @@ -4,6 +4,9 @@ pub use self::frozen_lake_plugin::*; pub mod frozen_lake_map; #[allow(unused_imports)] pub use self::frozen_lake_map::*; +pub mod frozen_lake_scene; +#[allow(unused_imports)] +pub use self::frozen_lake_scene::*; pub mod translate_grid; #[allow(unused_imports)] pub use self::translate_grid::*; @@ -19,6 +22,3 @@ pub use self::reward_grid::*; pub mod spawn_frozen_lake; #[allow(unused_imports)] pub use self::spawn_frozen_lake::*; -pub mod frozen_lake_environment; -#[allow(unused_imports)] -pub use self::frozen_lake_environment::*; diff --git a/crates/beet_ml/src/environments/frozen_lake/spawn_frozen_lake.rs b/crates/beet_ml/src/environments/frozen_lake/spawn_frozen_lake.rs index b0415a92..5f406c4c 100644 --- a/crates/beet_ml/src/environments/frozen_lake/spawn_frozen_lake.rs +++ b/crates/beet_ml/src/environments/frozen_lake/spawn_frozen_lake.rs @@ -39,65 +39,17 @@ pub fn spawn_frozen_lake_session( assets: Res, ) { for event in events.read() { - let map = FrozenLakeMap::default_four_by_four(); + let FrozenLakeEpParams { + map, grid_to_world, .. + } = &event.params; - let grid_to_world = - GridToWorld::from_frozen_lake_map(&map, event.params.map_width); - - let tile_scale = Vec3::splat(grid_to_world.cell_width); - for x in 0..map.width() { - for y in 0..map.height() { - let mut pos = grid_to_world.world_pos(UVec2::new(x, y)); - pos.y -= grid_to_world.cell_width; - commands.spawn(( - SceneBundle { - scene: assets.tile.clone(), - transform: Transform::from_translation(pos) - .with_scale(tile_scale), - ..default() - }, - SessionEntity(event.session), - DespawnOnSessionEnd, - )); - } - } - - let object_scale = Vec3::splat(grid_to_world.cell_width * 0.5); - - for (index, cell) in map.cells().iter().enumerate() { - let grid_pos = map.index_to_position(index); - let mut pos = grid_to_world.world_pos(grid_pos); - match cell { - FrozenLakeCell::Hole => { - pos.y += grid_to_world.cell_width * 0.25; // this asset is a bit too low - commands.spawn(( - SceneBundle { - scene: assets.hazard.clone(), - transform: Transform::from_translation(pos) - .with_scale(object_scale), - ..default() - }, - SessionEntity(event.session), - DespawnOnSessionEnd, - )); - } - FrozenLakeCell::Goal => { - commands.spawn(( - SceneBundle { - scene: assets.goal.clone(), - transform: Transform::from_translation(pos) - .with_scale(object_scale), - ..default() - }, - SessionEntity(event.session), - DespawnOnSessionEnd, - )); - } - FrozenLakeCell::Ice => {} - FrozenLakeCell::Agent => { /*spawns on episode */ } - } - {} - } + spawn_frozen_lake_scene( + &mut commands, + map, + grid_to_world, + &assets, + (SessionEntity(event.session), DespawnOnSessionEnd), + ) } } @@ -108,64 +60,52 @@ pub fn spawn_frozen_lake_episode( assets: Res, ) { for event in events.read() { - // TODO deduplicate - let map = FrozenLakeMap::default_four_by_four(); - let grid_to_world = - GridToWorld::from_frozen_lake_map(&map, event.params.map_width); + let FrozenLakeEpParams { + map, grid_to_world, .. 
+ } = &event.params; + let object_scale = Vec3::splat(grid_to_world.cell_width * 0.5); - for (index, cell) in map.cells().iter().enumerate() { - let grid_pos = map.index_to_position(index); - let pos = grid_to_world.world_pos(grid_pos); - match cell { - FrozenLakeCell::Agent => { - commands - .spawn(( - SceneBundle { - scene: assets.character.clone(), - transform: Transform::from_translation(pos) - .with_scale(object_scale), - ..default() - }, - grid_to_world.clone(), - RlAgentBundle { - state: map.agent_position(), - action: GridDirection::sample(), - env: FrozenLakeEnv::new(map.clone(), false), - params: event.params.learn_params.clone(), - session: SessionEntity(event.session), - despawn: DespawnOnEpisodeEnd, - }, - )) - .with_children(|parent| { - let agent = parent.parent_entity(); + let agent_pos = map.agent_position(); + let agent_pos = grid_to_world.world_pos(*agent_pos); + + + commands + .spawn(( + SceneBundle { + scene: assets.character.clone(), + transform: Transform::from_translation(agent_pos) + .with_scale(object_scale), + ..default() + }, + grid_to_world.clone(), + RlAgentBundle { + state: map.agent_position(), + action: GridDirection::sample(), + env: QTableEnv::new(map.transition_outcomes()), + params: event.params.learn_params.clone(), + session: SessionEntity(event.session), + despawn: DespawnOnEpisodeEnd, + }, + )) + .with_children(|parent| { + let agent = parent.parent_entity(); - parent - .spawn(( - Running, - SequenceSelector, - Repeat::default(), - )) - .with_children(|parent| { - parent.spawn(( - TranslateGrid::new( - Duration::from_millis(1), - ), - TargetAgent(agent), - RunTimer::default(), - )); - parent.spawn(( - TargetAgent(agent), - StepEnvironment::< - FrozenLakeQTableSession, - >::new(event.episode), - )); - }); - }); - } - _ => {} - } - {} - } + parent + .spawn((Running, SequenceSelector, Repeat::default())) + .with_children(|parent| { + parent.spawn(( + TranslateGrid::new(Duration::from_millis(100)), + TargetAgent(agent), + RunTimer::default(), + )); + parent.spawn(( + TargetAgent(agent), + StepEnvironment::::new( + event.episode, + ), + )); + }); + }); } } diff --git a/crates/beet_ml/src/rl/environment.rs b/crates/beet_ml/src/rl/environment.rs index 63fe5c0a..33ebb44f 100644 --- a/crates/beet_ml/src/rl/environment.rs +++ b/crates/beet_ml/src/rl/environment.rs @@ -23,6 +23,7 @@ pub trait Environment: 'static + Send + Sync + Clone { // fn action_space(&self) -> Action; } +#[derive(Clone)] pub struct StepOutcome { pub state: State, pub reward: f32, @@ -30,7 +31,7 @@ pub struct StepOutcome { } pub trait DiscreteSpace: - 'static + Send + Sync + Debug + Hash + Clone + PartialEq + Eq + Component + 'static + Send + Sync + Debug + Hash + Clone + PartialEq + Eq + Component + TypePath { // type Value; // const LEN: usize; @@ -47,7 +48,8 @@ impl< + Clone + PartialEq + Eq - + Component, + + Component + + TypePath > DiscreteSpace for T { } diff --git a/crates/beet_ml/src/rl/mod.rs b/crates/beet_ml/src/rl/mod.rs index e6237c88..2f382fe3 100644 --- a/crates/beet_ml/src/rl/mod.rs +++ b/crates/beet_ml/src/rl/mod.rs @@ -7,6 +7,9 @@ pub use self::q_table_selector::*; pub mod q_trainer; #[allow(unused_imports)] pub use self::q_trainer::*; +pub mod q_table_loader; +#[allow(unused_imports)] +pub use self::q_table_loader::*; pub mod q_learn_params; #[allow(unused_imports)] pub use self::q_learn_params::*; @@ -19,9 +22,12 @@ pub use self::q_table_trainer::*; pub mod environment; #[allow(unused_imports)] pub use self::environment::*; -pub mod q_source; +pub mod q_table_env; 
+#[allow(unused_imports)] +pub use self::q_table_env::*; +pub mod q_policy; #[allow(unused_imports)] -pub use self::q_source::*; +pub use self::q_policy::*; pub mod hash_q_table; #[allow(unused_imports)] pub use self::hash_q_table::*; diff --git a/crates/beet_ml/src/rl/q_source.rs b/crates/beet_ml/src/rl/q_policy.rs similarity index 77% rename from crates/beet_ml/src/rl/q_source.rs rename to crates/beet_ml/src/rl/q_policy.rs index fdb59e84..d815b26c 100644 --- a/crates/beet_ml/src/rl/q_source.rs +++ b/crates/beet_ml/src/rl/q_policy.rs @@ -1,7 +1,7 @@ use crate::prelude::*; use rand::Rng; -pub trait QSource: 'static + Send + Sync { +pub trait QPolicy: 'static + Send + Sync { type State: StateSpace; type Action: ActionSpace; @@ -12,15 +12,37 @@ pub trait QSource: 'static + Send + Sync { rng: &mut impl Rng, epsilon: f32, action: &Self::Action, - prev_state: &Self::State, state: &Self::State, + // the anticipated state, it may not occur + next_state: &Self::State, reward: f32, ) -> Self::Action { - self.set_discounted_reward(params, action, reward, prev_state, state); - let (action, _) = self.epsilon_greedy_policy(&state, epsilon, rng); + self.set_discounted_reward(params, action, reward, state, next_state); + let (action, _) = self.epsilon_greedy_policy(&next_state, epsilon, rng); action } + fn set_discounted_reward( + &mut self, + params: &QLearnParams, + action: &Self::Action, + reward: QValue, + state: &Self::State, + // the anticipated state, it may not occur + next_state: &Self::State, + ) { + let prev_q = self.get_q(&state, &action); + let (_, new_max_q) = self.greedy_policy(&next_state); + + // Bellman equation + // Q(s,a):= Q(s,a) + lr [R(s,a) + gamma * max Q(s',a') - Q(s,a)] + let discounted_reward = prev_q + + params.learning_rate + * (reward + params.gamma * new_max_q - prev_q); + + self.set_q(&state, &action, discounted_reward); + } + fn greedy_policy(&self, state: &Self::State) -> (Self::Action, QValue); fn epsilon_greedy_policy( &self, @@ -51,25 +73,4 @@ pub trait QSource: 'static + Send + Sync { action: &Self::Action, value: QValue, ); - - - fn set_discounted_reward( - &mut self, - params: &QLearnParams, - action: &Self::Action, - reward: QValue, - prev_state: &Self::State, - next_state: &Self::State, - ) { - let prev_q = self.get_q(&prev_state, &action); - let (_, new_max_q) = self.greedy_policy(&next_state); - - // Bellman equation - // Q(s,a):= Q(s,a) + lr [R(s,a) + gamma * max Q(s',a') - Q(s,a)] - let discounted_reward = prev_q - + params.learning_rate - * (reward + params.gamma * new_max_q - prev_q); - - self.set_q(&prev_state, &action, discounted_reward); - } } diff --git a/crates/beet_ml/src/rl/q_table.rs b/crates/beet_ml/src/rl/q_table.rs index fe3cb9e5..effcbd52 100644 --- a/crates/beet_ml/src/rl/q_table.rs +++ b/crates/beet_ml/src/rl/q_table.rs @@ -1,9 +1,22 @@ use crate::prelude::*; use bevy::prelude::*; use bevy::utils::HashMap; +use serde::Deserialize; +use serde::Serialize; pub type QValue = f32; -#[derive(Debug, Clone, PartialEq, Component, Deref, DerefMut, Reflect)] +#[derive( + Debug, + Clone, + PartialEq, + Component, + Deref, + DerefMut, + Reflect, + Serialize, + Deserialize, + Asset, +)] #[reflect(Default)] pub struct QTable( pub HashMap>, @@ -13,7 +26,7 @@ impl Default for QTable { fn default() -> Self { Self(HashMap::default()) } } -impl QSource for QTable { +impl QPolicy for QTable { type Action = Action; type State = State; @@ -80,7 +93,7 @@ mod test { let mut rng = StdRng::seed_from_u64(0); let map = FrozenLakeMap::default_four_by_four(); let 
initial_state = map.agent_position(); - let env = FrozenLakeEnv::new(map, false); + let env = QTableEnv::new(map.transition_outcomes()); for episode in 0..params.n_training_episodes { let mut state = initial_state.clone(); diff --git a/crates/beet_ml/src/rl/q_table_env.rs b/crates/beet_ml/src/rl/q_table_env.rs new file mode 100644 index 00000000..cf7d817e --- /dev/null +++ b/crates/beet_ml/src/rl/q_table_env.rs @@ -0,0 +1,35 @@ +use crate::prelude::*; +use bevy::prelude::*; +use bevy::utils::HashMap; +// /// A number generator for determining. +// rng: StdRng, +// /// Whether there is a 2/3 chance the agent moves left or right of the intended direction. +// is_slippery: bool, + +#[derive(Clone, Component)] +/// An environment for the Frozen Lake game. +pub struct QTableEnv<S: StateSpace, A: ActionSpace> { + /// The transition probabilities for each state-action pair. + outcomes: HashMap<(S, A), StepOutcome<S>>, +} + +impl<S: StateSpace, A: ActionSpace> QTableEnv<S, A> { + pub fn new(outcomes: HashMap<(S, A), StepOutcome<S>>) -> Self { + Self { outcomes } + } +} + +impl<S: StateSpace, A: ActionSpace> Environment + for QTableEnv<S, A> +{ + type State = S; + type Action = A; + + fn step( + &mut self, + state: &Self::State, + action: &Self::Action, + ) -> StepOutcome<Self::State> { + self.outcomes[&(state.clone(), action.clone())].clone() + } +} diff --git a/crates/beet_ml/src/rl/q_table_loader.rs b/crates/beet_ml/src/rl/q_table_loader.rs new file mode 100644 index 00000000..5879421c --- /dev/null +++ b/crates/beet_ml/src/rl/q_table_loader.rs @@ -0,0 +1,43 @@ +use crate::prelude::*; +use bevy::asset::io::Reader; +use bevy::asset::AssetLoader; +use bevy::asset::AsyncReadExt; +use bevy::asset::LoadContext; +use bevy::utils::ConditionalSendFuture; +use serde::de::DeserializeOwned; +use std::marker::PhantomData; + +#[derive(Default)] +pub struct QTableLoader<State, Action> { + phantom: PhantomData<(State, Action)>, +} + +impl< + State: StateSpace + DeserializeOwned, + Action: ActionSpace + DeserializeOwned, + > AssetLoader for QTableLoader<State, Action> +{ + type Asset = QTable<State, Action>; + type Settings = (); + type Error = anyhow::Error; + + fn load<'a>( + &'a self, + reader: &'a mut Reader, + _settings: &'a Self::Settings, + _load_context: &'a mut LoadContext, + ) -> impl ConditionalSendFuture + + futures::Future< + Output = Result< + <Self as AssetLoader>::Asset, + <Self as AssetLoader>::Error, + >, + > { + Box::pin(async move { + let mut bytes = Vec::new(); + reader.read_to_end(&mut bytes).await?; + let table = bevy::scene::ron::de::from_bytes::<QTable<State, Action>>(&bytes)?; + Ok(table) + }) + } +} diff --git a/crates/beet_ml/src/rl/q_table_selector.rs b/crates/beet_ml/src/rl/q_table_selector.rs index 538b638b..d10f8327 100644 --- a/crates/beet_ml/src/rl/q_table_selector.rs +++ b/crates/beet_ml/src/rl/q_table_selector.rs @@ -11,14 +11,14 @@ use bevy::prelude::*; /// - If a child succeeds, evaluate reward and select next action.
#[derive(Debug, Clone, PartialEq, Component, Reflect)] #[reflect(Component, ActionMeta)] -pub struct QTableSelector { +pub struct QTableSelector { pub evaluate: bool, pub learner: L, pub current_episode: usize, pub current_step: usize, } -fn q_table_selector( +fn q_table_selector( mut commands: Commands, mut agents: Query<(&L::State, &mut L::Action, &Reward)>, mut query: Query< @@ -70,10 +70,10 @@ fn q_table_selector( } } -impl ActionMeta for QTableSelector { +impl ActionMeta for QTableSelector { fn category(&self) -> ActionCategory { ActionCategory::Agent } } -impl ActionSystems for QTableSelector { +impl ActionSystems for QTableSelector { fn systems() -> SystemConfigs { q_table_selector::.in_set(TickSet) } } diff --git a/crates/beet_ml/src/rl/q_table_trainer.rs b/crates/beet_ml/src/rl/q_table_trainer.rs index 72a1b082..e91e2a8b 100644 --- a/crates/beet_ml/src/rl/q_table_trainer.rs +++ b/crates/beet_ml/src/rl/q_table_trainer.rs @@ -6,7 +6,7 @@ use rand::Rng; /// Used for training a QTable to completion with a provided [`Environment`]. pub struct QTableTrainer { - pub table: S::QSource, + pub table: S::QLearnPolicy, pub env: Readonly, pub params: Readonly, initial_state: S::State, @@ -15,7 +15,7 @@ pub struct QTableTrainer { impl QTableTrainer { pub fn new( env: S::Env, - table: S::QSource, + table: S::QLearnPolicy, params: QLearnParams, initial_state: S::State, ) -> Self { @@ -29,7 +29,7 @@ impl QTableTrainer { } -impl QSource for QTableTrainer { +impl QPolicy for QTableTrainer { type Action = S::Action; type State = S::State; fn greedy_policy(&self, state: &Self::State) -> (Self::Action, QValue) { @@ -62,7 +62,7 @@ impl QTrainer for QTableTrainer where S::State: Clone, { - fn train(&mut self, rng: &mut impl Rng) { + fn train_with_rng(&mut self, rng: &mut impl Rng) { let params = &self.params; for episode in 0..params.n_training_episodes { @@ -138,7 +138,7 @@ mod test { let mut policy_rng = StdRng::seed_from_u64(0); let map = FrozenLakeMap::default_four_by_four(); let initial_state = map.agent_position(); - let env = FrozenLakeEnv::new(map, false); + let env = QTableEnv::new(map.transition_outcomes()); let params = QLearnParams::default(); let mut trainer = QTableTrainer::::new( @@ -148,7 +148,7 @@ mod test { initial_state, ); let now = Instant::now(); - trainer.train(&mut policy_rng); + trainer.train_with_rng(&mut policy_rng); // My PC: 10ms // Github Actions: 50ms let elapsed = now.elapsed(); diff --git a/crates/beet_ml/src/rl/q_trainer.rs b/crates/beet_ml/src/rl/q_trainer.rs index 05c5e366..e60d2cab 100644 --- a/crates/beet_ml/src/rl/q_trainer.rs +++ b/crates/beet_ml/src/rl/q_trainer.rs @@ -2,9 +2,10 @@ use crate::prelude::*; use rand::Rng; -pub trait QTrainer: 'static + Send + Sync + QSource { +pub trait QTrainer: 'static + Send + Sync + QPolicy { + fn train(&mut self) { self.train_with_rng(&mut rand::thread_rng()) } /// Immediately train an entire agent - fn train(&mut self, rng: &mut impl Rng); + fn train_with_rng(&mut self, rng: &mut impl Rng); /// Immediately evaluate an entire agent fn evaluate(&self) -> Evaluation; diff --git a/crates/beet_ml/src/rl_realtime/mod.rs b/crates/beet_ml/src/rl_realtime/mod.rs index 0abf7ee2..e3c2729a 100644 --- a/crates/beet_ml/src/rl_realtime/mod.rs +++ b/crates/beet_ml/src/rl_realtime/mod.rs @@ -1,18 +1,21 @@ pub mod rl_agent; #[allow(unused_imports)] pub use self::rl_agent::*; -pub mod rl_session; -#[allow(unused_imports)] -pub use self::rl_session::*; pub mod rl_components; #[allow(unused_imports)] pub use self::rl_components::*; pub mod 
step_environment; #[allow(unused_imports)] pub use self::step_environment::*; +pub mod rl_session_types; +#[allow(unused_imports)] +pub use self::rl_session_types::*; +pub mod read_qpolicy; +#[allow(unused_imports)] +pub use self::read_qpolicy::*; pub mod rl_plugin; #[allow(unused_imports)] pub use self::rl_plugin::*; -pub mod rl_session_types; +pub mod rl_session; #[allow(unused_imports)] -pub use self::rl_session_types::*; +pub use self::rl_session::*; diff --git a/crates/beet_ml/src/rl_realtime/read_qpolicy.rs b/crates/beet_ml/src/rl_realtime/read_qpolicy.rs new file mode 100644 index 00000000..f63ae8cf --- /dev/null +++ b/crates/beet_ml/src/rl_realtime/read_qpolicy.rs @@ -0,0 +1,42 @@ +use crate::prelude::*; +use beet_ecs::prelude::*; +use bevy::ecs::schedule::SystemConfigs; +use bevy::prelude::*; + +#[derive(Debug, Clone, PartialEq, Component, Reflect)] +#[reflect(Component, ActionMeta)] +pub struct ReadQPolicy<P: QPolicy + Asset> { + pub policy_handle: Handle<P>, +} + +impl<P: QPolicy + Asset> ReadQPolicy<P> { + pub fn new(table: Handle<P>) -> Self { + Self { + policy_handle: table, + } + } +} + +fn read_q_policy<P: QPolicy + Asset>( + mut commands: Commands, + assets: Res<Assets<P>>, + mut agents: Query<(&P::State, &mut P::Action)>, + query: Query<(Entity, &ReadQPolicy<P>), With<Running>>, +) { + for (entity, read_q_policy) in query.iter() { + if let Some(policy) = assets.get(&read_q_policy.policy_handle) { + for (state, mut action) in agents.iter_mut() { + *action = policy.greedy_policy(state).0; + commands.entity(entity).insert(RunResult::Success); + } + } + } +} + +impl<P: QPolicy + Asset> ActionMeta for ReadQPolicy<P> { + fn category(&self) -> ActionCategory { ActionCategory::Behavior } +} + +impl<P: QPolicy + Asset> ActionSystems for ReadQPolicy<P> { + fn systems() -> SystemConfigs { read_q_policy::<P>
.in_set(TickSet) } +} diff --git a/crates/beet_ml/src/rl_realtime/rl_session.rs b/crates/beet_ml/src/rl_realtime/rl_session.rs index afc6a10c..8d15abff 100644 --- a/crates/beet_ml/src/rl_realtime/rl_session.rs +++ b/crates/beet_ml/src/rl_realtime/rl_session.rs @@ -179,7 +179,7 @@ mod test { mut events: EventReader>, ) { for event in events.read() { - commands.spawn(SessionEntity(event.session)); + commands.spawn((SessionEntity(event.session), DespawnOnEpisodeEnd)); } } diff --git a/crates/beet_ml/src/rl_realtime/rl_session_types.rs b/crates/beet_ml/src/rl_realtime/rl_session_types.rs index 596862da..a2a3042a 100644 --- a/crates/beet_ml/src/rl_realtime/rl_session_types.rs +++ b/crates/beet_ml/src/rl_realtime/rl_session_types.rs @@ -5,7 +5,7 @@ use crate::prelude::*; pub trait RlSessionTypes: 'static + Send + Sync { type State: StateSpace; type Action: ActionSpace; - type QSource: QSource; + type QLearnPolicy: QPolicy; type Env: Environment; type EpisodeParams: EpisodeParams; } diff --git a/crates/beet_ml/src/rl_realtime/step_environment.rs b/crates/beet_ml/src/rl_realtime/step_environment.rs index c5c339c5..e4f5401c 100644 --- a/crates/beet_ml/src/rl_realtime/step_environment.rs +++ b/crates/beet_ml/src/rl_realtime/step_environment.rs @@ -27,7 +27,7 @@ fn step_environment( mut rng: ResMut, mut end_episode_events: EventWriter>, mut commands: Commands, - mut sessions: Query<&mut S::QSource>, + mut sessions: Query<&mut S::QLearnPolicy>, mut agents: Query<( &S::State, &mut S::Action, @@ -42,7 +42,7 @@ fn step_environment( ) where S::State: Component, S::Action: Component, - S::QSource: Component, + S::QLearnPolicy: Component, S::Env: Component, { for (action_entity, agent, mut step) in query.iter_mut() { @@ -91,7 +91,7 @@ impl ActionSystems for StepEnvironment where S::State: Component, S::Action: Component, - S::QSource: Component, + S::QLearnPolicy: Component, S::Env: Component, { fn systems() -> SystemConfigs { step_environment::.in_set(TickSet) } @@ -129,7 +129,7 @@ mod test { .spawn(RlAgentBundle { state: map.agent_position(), action: GridDirection::sample_with_rng(&mut *rng), - env: FrozenLakeEnv::new(map, false), + env: QTableEnv::new(map.transition_outcomes()), params: QLearnParams::default(), session: SessionEntity(session), despawn: DespawnOnEpisodeEnd, diff --git a/examples/fetch.rs b/examples/fetch.rs index 697466bb..cf245124 100644 --- a/examples/fetch.rs +++ b/examples/fetch.rs @@ -44,7 +44,7 @@ fn main() { fn setup_camera(mut commands: Commands) { commands.spawn(( CameraDistance { - width: ITEM_OFFSET * 1.6, + width: ITEM_OFFSET * 1.4, offset: Vec3::new(0., 1.6, ITEM_OFFSET), }, Camera3dBundle::default(), diff --git a/examples/frozen_lake_run.rs b/examples/frozen_lake_run.rs new file mode 100644 index 00000000..cea13e42 --- /dev/null +++ b/examples/frozen_lake_run.rs @@ -0,0 +1,75 @@ +use beet::prelude::*; +use beet_examples::*; +use bevy::prelude::*; +use std::time::Duration; + +const SCENE_SCALE: f32 = 1.; + +fn main() { + let mut app = App::new(); + app.add_plugins(( + ExamplePlugin3d { ground: false }, + DefaultBeetPlugins, + FrozenLakePlugin, + )) + .add_systems(Startup, setup); + + app.run(); +} + + +fn setup( + mut commands: Commands, + assets: Res, + asset_server: Res, +) { + // camera + commands + .spawn((CameraDistance::new(SCENE_SCALE), Camera3dBundle::default())); + // scene + let map = FrozenLakeMap::default_four_by_four(); + let grid_to_world = GridToWorld::from_frozen_lake_map(&map, SCENE_SCALE); + + spawn_frozen_lake_scene(&mut commands, &map, &grid_to_world, 
&assets, ()); + // agent + + + + + let agent_grid_pos = map.agent_position(); + let agent_pos = grid_to_world.world_pos(*agent_grid_pos); + let object_scale = Vec3::splat(grid_to_world.cell_width * 0.5); + + let policy_handle = + asset_server.load::("ml/frozen_lake_qtable.ron"); + + commands + .spawn(( + SceneBundle { + scene: assets.character.clone(), + transform: Transform::from_translation(agent_pos) + .with_scale(object_scale), + ..default() + }, + grid_to_world.clone(), + agent_grid_pos, + GridDirection::sample(), + )) + .with_children(|parent| { + let agent = parent.parent_entity(); + + parent + .spawn((Running, SequenceSelector, Repeat::default())) + .with_children(|parent| { + parent.spawn(( + TargetAgent(agent), + ReadQPolicy::new(policy_handle), + )); + parent.spawn(( + TranslateGrid::new(Duration::from_secs(1)), + TargetAgent(agent), + RunTimer::default(), + )); + }); + }); +} diff --git a/examples/frozen_lake.rs b/examples/frozen_lake_train.rs similarity index 58% rename from examples/frozen_lake.rs rename to examples/frozen_lake_train.rs index 9c277cb2..616d0367 100644 --- a/examples/frozen_lake.rs +++ b/examples/frozen_lake_train.rs @@ -1,8 +1,11 @@ +//! # Frozen Lake Training +//! +//! Rendering a reinforcement learning algorithm can be entertaining and useful for debugging. use beet::prelude::*; use beet_examples::*; use bevy::prelude::*; -const MAP_WIDTH: f32 = 4.; +const SCENE_SCALE: f32 = 1.; fn main() { let mut app = App::new(); @@ -18,20 +21,17 @@ fn main() { fn setup_camera(mut commands: Commands) { - commands.spawn(( - CameraDistance { - width: MAP_WIDTH * 1.1, - offset: Vec3::new(0., 4., 4.), - }, - Camera3dBundle::default(), - )); + commands + .spawn((CameraDistance::new(SCENE_SCALE), Camera3dBundle::default())); } fn setup_runner(mut commands: Commands) { + let map = FrozenLakeMap::default_four_by_four(); let params = FrozenLakeEpParams { learn_params: default(), - map_width: MAP_WIDTH, + grid_to_world: GridToWorld::from_frozen_lake_map(&map, SCENE_SCALE), + map, }; commands.spawn((RlSession::new(params), FrozenLakeQTable::default())); } diff --git a/examples/hello_ml_vanilla.rs b/examples/hello_ml_basic.rs similarity index 100% rename from examples/hello_ml_vanilla.rs rename to examples/hello_ml_basic.rs diff --git a/examples/hello_rl_basic.rs b/examples/hello_rl_basic.rs new file mode 100644 index 00000000..6501186c --- /dev/null +++ b/examples/hello_rl_basic.rs @@ -0,0 +1,35 @@ +use beet::prelude::*; +use bevy::scene::ron; +use std::fs::File; +use std::fs::{ + self, +}; +use std::io::Write; + +fn main() -> anyhow::Result<()> { + let map = FrozenLakeMap::default_four_by_four(); + let initial_state = map.agent_position(); + let env = QTableEnv::new(map.transition_outcomes()); + let params = QLearnParams::default(); + + let mut trainer = QTableTrainer::::new( + env.clone(), + QTable::default(), + params, + initial_state, + ); + trainer.train(); + let eval = trainer.evaluate(); + assert_eq!(eval.mean, 1.); + assert_eq!(eval.std, 0.); + assert_eq!(eval.total_steps, 600); + // println!("Model trained\nMean: {}, Std: {}", eval.mean, eval.std); + + let table = trainer.table; + let text = ron::ser::to_string_pretty(&table, Default::default())?; + fs::create_dir_all("assets/ml")?; + File::create("assets/ml/frozen_lake_qtable.ron") + .and_then(|mut file| file.write(text.as_bytes()))?; + // save table to ron file + Ok(()) +} diff --git a/examples/src/camera_distance.rs b/examples/src/camera_distance.rs index ef3fcf7b..8c40373f 100644 --- a/examples/src/camera_distance.rs 
+++ b/examples/src/camera_distance.rs @@ -9,25 +9,14 @@ pub struct CameraDistance { } impl Default for CameraDistance { - fn default() -> Self { - Self { - width: 10.0, - offset: Vec3::ZERO, - } - } + fn default() -> Self { Self::new(10.) } } impl CameraDistance { - pub fn new(x: f32) -> Self { - Self { - width: x, - offset: Vec3::ZERO, - } - } - pub fn new_with_origin(width: f32, origin: Vec3) -> Self { + pub fn new(scale: f32) -> Self { Self { - width, - offset: origin, + width: scale * 1.1, + offset: Vec3::new(0., scale, scale), } } }
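
Note on the Bellman update that moves from QSource::set_discounted_reward into the new QPolicy trait above: the sketch below restates the update rule outside of Bevy so the discounting step can be sanity-checked by hand. The function name and the numeric values are illustrative assumptions only; they are not taken from QLearnParams::default().

/// Minimal sketch of the rule used by `QPolicy::set_discounted_reward`:
/// Q(s,a) := Q(s,a) + lr * (R(s,a) + gamma * max_a' Q(s',a') - Q(s,a))
fn discounted_reward(
    prev_q: f32,
    reward: f32,
    next_max_q: f32,
    learning_rate: f32,
    gamma: f32,
) -> f32 {
    prev_q + learning_rate * (reward + gamma * next_max_q - prev_q)
}

fn main() {
    // Illustrative numbers: stepping onto the goal cell (reward 1.0) from a
    // state whose current estimate is 0.0, with no future value beyond the goal.
    let updated = discounted_reward(0.0, 1.0, 0.0, 0.7, 0.99);
    assert!((updated - 0.7).abs() < f32::EPSILON);
}

In the trait itself the same arithmetic reads prev_q via get_q, takes new_max_q from greedy_policy(next_state), and writes the result back with set_q.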