Skip to content

Commit

Permalink
feat: frozen lake runner
Browse files Browse the repository at this point in the history
  • Loading branch information
mrchantey committed Jun 11, 2024
1 parent 8eba3f6 commit 463b553
Show file tree
Hide file tree
Showing 27 changed files with 519 additions and 320 deletions.

This file was deleted.

76 changes: 39 additions & 37 deletions crates/beet_ml/src/environments/frozen_lake/frozen_lake_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use strum::IntoEnumIterator;
Serialize,
Deserialize,
Component,
Reflect,
)]
pub enum FrozenLakeCell {
Agent,
Expand Down Expand Up @@ -43,7 +44,7 @@ impl FrozenLakeCell {


/// Define an initial state for a [`FrozenLakeEnv`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Component)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Component, Reflect)]
pub struct FrozenLakeMap {
cells: Vec<FrozenLakeCell>,
size: UVec2,
Expand Down Expand Up @@ -71,8 +72,8 @@ impl FrozenLakeMap {

pub fn cells(&self) -> &Vec<FrozenLakeCell> { &self.cells }
pub fn size(&self) -> UVec2 { self.size }
pub fn width(&self) -> u32 { self.size.x }
pub fn height(&self) -> u32 { self.size.y }
pub fn num_cols(&self) -> u32 { self.size.x }
pub fn num_rows(&self) -> u32 { self.size.y }
pub fn cells_with_positions(
&self,
) -> impl Iterator<Item = (UVec2, &FrozenLakeCell)> {
Expand All @@ -92,7 +93,7 @@ impl FrozenLakeMap {
&self,
position: UVec2,
direction: GridDirection,
) -> Option<TransitionOutcome> {
) -> Option<StepOutcome<GridPos>> {
let direction: IVec2 = direction.into();
let new_pos = IVec2::new(
position.x as i32 + direction.x,
Expand All @@ -104,10 +105,10 @@ impl FrozenLakeMap {
let new_pos =
new_pos.try_into().expect("already checked in bounds");
let new_cell = self.position_to_cell(new_pos);
Some(TransitionOutcome {
Some(StepOutcome {
reward: new_cell.reward(),
pos: new_pos,
is_terminal: new_cell.is_terminal(),
state: GridPos(new_pos),
done: new_cell.is_terminal(),
})
}
}
Expand All @@ -128,29 +129,29 @@ impl FrozenLakeMap {

pub fn transition_outcomes(
&self,
) -> HashMap<(UVec2, GridDirection), TransitionOutcome> {
) -> HashMap<(GridPos, GridDirection), StepOutcome<GridPos>> {
let mut outcomes = HashMap::new();
for (pos, cell) in self.cells_with_positions() {
for action in GridDirection::iter() {
let outcome = if cell.is_terminal() {
// early exit, cannot move from terminal cell
TransitionOutcome {
StepOutcome {
reward: 0.0,
pos,
is_terminal: true,
state: GridPos(pos),
done: true,
}
} else {
// yes you can go here
self.try_transition(pos, action).unwrap_or(
// stay where you are
TransitionOutcome {
StepOutcome {
reward: 0.0,
pos,
is_terminal: false,
state: GridPos(pos),
done: false,
},
)
};
outcomes.insert((pos, action), outcome);
outcomes.insert((GridPos(pos), action), outcome);
}
}

Expand Down Expand Up @@ -190,27 +191,28 @@ impl FrozenLakeMap {
impl FrozenLakeMap {
#[rustfmt::skip]
pub fn default_eight_by_eight() -> Self {
Self {
size: UVec2::new(8, 8),
//https://github.com/openai/gym/blob/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/frozen_lake.py#L17
cells: vec![
//row 1
FrozenLakeCell::Agent, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice,
//row 2
FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice,
//row 3
FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice,
//row 4
FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice,
//row 5
FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice,
//row 6
FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice,
//row 7
FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice,
//row 8
FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Goal,
],
}
todo!();
// Self {
// size: UVec2::new(8, 8),
// //https://github.com/openai/gym/blob/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/frozen_lake.py#L17
// cells: vec![
// //row 1
// FrozenLakeCell::Agent, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice,
// //row 2
// FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice,
// //row 3
// FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice,
// //row 4
// FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice,
// //row 5
// FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice,
// //row 6
// FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice,
// //row 7
// FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice,
// //row 8
// FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Hole, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Ice, FrozenLakeCell::Goal,
// ],
// }
}
}
18 changes: 12 additions & 6 deletions crates/beet_ml/src/environments/frozen_lake/frozen_lake_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,20 @@ impl Plugin for FrozenLakePlugin {
ActionPlugin::<(
TranslateGrid,
StepEnvironment<FrozenLakeQTableSession>,
ReadQPolicy<FrozenLakeQTable>,
)>::default(),
RlSessionPlugin::<FrozenLakeEpParams>::default(),
))
.add_systems(Startup, init_frozen_lake_assets)
.add_systems(PreStartup, init_frozen_lake_assets)
.add_systems(Update, reward_grid.in_set(PostTickSet))
.add_systems(
Update,
(spawn_frozen_lake_session, spawn_frozen_lake_episode)
.in_set(PostTickSet),
)
.init_resource::<RlRng>();
.init_resource::<RlRng>()
.init_asset::<QTable<GridPos, GridDirection>>()
.init_asset_loader::<QTableLoader<GridPos, GridDirection>>();

let world = app.world_mut();
world.init_component::<GridPos>();
Expand All @@ -69,14 +72,17 @@ impl Plugin for FrozenLakePlugin {
#[derive(Debug, Clone, Reflect)]
pub struct FrozenLakeEpParams {
pub learn_params: QLearnParams,
pub map_width: f32,
pub map: FrozenLakeMap,
pub grid_to_world: GridToWorld,
}

impl Default for FrozenLakeEpParams {
fn default() -> Self {
let map = FrozenLakeMap::default_four_by_four();
Self {
learn_params: QLearnParams::default(),
map_width: 4.0,
grid_to_world: GridToWorld::from_frozen_lake_map(&map, 4.0),
map,
}
}
}
Expand All @@ -93,7 +99,7 @@ pub struct FrozenLakeQTableSession;
impl RlSessionTypes for FrozenLakeQTableSession {
type State = GridPos;
type Action = GridDirection;
type QSource = FrozenLakeQTable;
type Env = FrozenLakeEnv;
type QLearnPolicy = FrozenLakeQTable;
type Env = QTableEnv<Self::State, Self::Action>;
type EpisodeParams = FrozenLakeEpParams;
}
62 changes: 62 additions & 0 deletions crates/beet_ml/src/environments/frozen_lake/frozen_lake_scene.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use crate::prelude::*;
use bevy::prelude::*;

/// Spawns the static scenery for a frozen lake map.
///
/// A floor tile is spawned under every grid cell, then a hazard model on each
/// [`FrozenLakeCell::Hole`] and a goal model on each [`FrozenLakeCell::Goal`].
/// Every spawned entity also receives a clone of `bundle`.
/// Nothing is spawned here for ice or agent cells.
pub fn spawn_frozen_lake_scene(
    commands: &mut Commands,
    map: &FrozenLakeMap,
    grid_to_world: &GridToWorld,
    assets: &Res<FrozenLakeAssets>,
    bundle: impl Bundle + Clone,
) {
    let cell_width = grid_to_world.cell_width;

    // Floor tiles: visited column by column, each lowered by one cell width.
    let tile_scale = Vec3::splat(cell_width);
    let (num_cols, num_rows) = (map.num_cols(), map.num_rows());
    let tile_coords = (0..num_cols)
        .flat_map(|x| (0..num_rows).map(move |y| UVec2::new(x, y)));
    for coord in tile_coords {
        let mut tile_pos = grid_to_world.world_pos(coord);
        tile_pos.y -= cell_width;
        commands.spawn((
            SceneBundle {
                scene: assets.tile.clone(),
                transform: Transform::from_translation(tile_pos)
                    .with_scale(tile_scale),
                ..default()
            },
            bundle.clone(),
        ));
    }

    // Hazard and goal models are rendered at half the size of a cell.
    let object_scale = Vec3::splat(cell_width * 0.5);
    for (index, cell) in map.cells().iter().enumerate() {
        let base = grid_to_world.world_pos(map.index_to_position(index));
        let (scene, translation) = match cell {
            FrozenLakeCell::Hole => {
                // this asset is a bit too low, so raise it by a quarter cell
                let raised =
                    Vec3::new(base.x, base.y + cell_width * 0.25, base.z);
                (assets.hazard.clone(), raised)
            }
            FrozenLakeCell::Goal => (assets.goal.clone(), base),
            // ice is just the bare tile; the agent spawns on episode
            FrozenLakeCell::Ice | FrozenLakeCell::Agent => continue,
        };
        commands.spawn((
            SceneBundle {
                scene,
                transform: Transform::from_translation(translation)
                    .with_scale(object_scale),
                ..default()
            },
            bundle.clone(),
        ));
    }
}
30 changes: 24 additions & 6 deletions crates/beet_ml/src/environments/frozen_lake/grid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use crate::prelude::ActionSpace;
use bevy::prelude::*;
use rand::rngs::StdRng;
use rand::Rng;
use serde::Deserialize;
use serde::Serialize;
use strum::EnumCount;
use strum::EnumIter;
use strum::VariantArray;
Expand All @@ -19,6 +21,8 @@ use strum::VariantArray;
DerefMut,
Component,
Reflect,
Serialize,
Deserialize,
)]
pub struct GridPos(pub UVec2);

Expand Down Expand Up @@ -55,6 +59,8 @@ impl From<UVec2> for GridPos {
VariantArray,
EnumIter,
EnumCount,
Serialize,
Deserialize,
)]
pub enum GridDirection {
#[default]
Expand Down Expand Up @@ -124,23 +130,35 @@ impl ActionSpace for GridDirection {
}


#[derive(Debug, Clone, Component)]
#[derive(Debug, Clone, Component, Reflect)]
pub struct GridToWorld {
pub map_width: f32,
pub cell_width: f32,
pub map_size:UVec2,
pub map_size: UVec2,
pub offset: Vec3,
}

impl GridToWorld {
pub fn from_frozen_lake_map(grid: &FrozenLakeMap, map_width: f32) -> Self {
let cell_width = map_width / grid.width() as f32;
let cell_width = map_width / grid.num_cols() as f32;
let h_cell_width = cell_width * 0.5;

let h_map_width = map_width * 0.5;

let offset = Vec3::new(
grid.width() as f32 + h_cell_width,
-h_map_width + h_cell_width,
0.,
grid.height() as f32 + h_cell_width,
) * -0.5;
-h_map_width + h_cell_width,
);


// let mut offset = Vec3::new(
// grid.num_cols() as f32,
// 0.,
// grid.num_rows() as f32,
// ) * -0.5;
// offset.x -= h_cell_width;
// offset.z -= h_cell_width;

Self {
map_size: grid.size(),
Expand Down
Loading

0 comments on commit 463b553

Please sign in to comment.