Skip to content

Commit

Permalink
wip: episode runner
Browse files Browse the repository at this point in the history
  • Loading branch information
mrchantey committed Jun 10, 2024
1 parent 87da928 commit ea7a31e
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ impl Plugin for FrozenLakePlugin {
)>::default(),
EpisodeRunnerPlugin::<FrozenLakeEpParams>::default(),
))
.add_systems(Startup, init_frozen_lake_assets)
.add_systems(Update, reward_grid.in_set(PostTickSet))
.add_systems(Update, spawn_frozen_lake.in_set(PostTickSet));

app.init_resource::<RlRng>();

.add_systems(
Update,
(spawn_frozen_lake_static, spawn_frozen_lake).in_set(PostTickSet),
)
.init_resource::<RlRng>();

let world = app.world_mut();
world.init_component::<GridPos>();
Expand Down
134 changes: 89 additions & 45 deletions crates/beet_ml/src/environments/frozen_lake/spawn_frozen_lake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,39 @@ use bevy::prelude::*;
use std::time::Duration;


pub fn spawn_frozen_lake(
mut events: EventReader<StartEpisode<FrozenLakeEpParams>>,
/// Scene handles for the models used to draw the frozen lake environment.
///
/// Inserted as a resource by `init_frozen_lake_assets`; the spawn systems
/// clone these handles instead of calling `AssetServer::load` themselves.
#[derive(Resource)]
pub struct FrozenLakeAssets {
/// Floor tile scene, spawned once for every grid cell.
pub tile: Handle<Scene>,
/// Scene used for the agent entity.
pub character: Handle<Scene>,
/// Scene placed on `FrozenLakeCell::Goal` cells.
pub goal: Handle<Scene>,
/// Scene placed on `FrozenLakeCell::Hole` cells.
pub hazard: Handle<Scene>,
}

/// Requests the four frozen lake scenes from the asset server and stores
/// their handles in a [`FrozenLakeAssets`] resource.
///
/// The resource is inserted through `Commands`, so it becomes available
/// once the queued commands have been applied.
pub fn init_frozen_lake_assets(
    mut commands: Commands,
    asset_server: Res<AssetServer>,
) {
    // Struct fields are evaluated in source order, so the scenes are
    // requested as tile, character, goal, hazard.
    commands.insert_resource(FrozenLakeAssets {
        tile: asset_server
            .load("kaykit-minigame/tileSmall_teamBlue.gltf.glb#Scene0"),
        character: asset_server
            .load("kaykit-minigame/character_dog.gltf.glb#Scene0"),
        goal: asset_server
            .load("kaykit-minigame/flag_teamYellow.gltf.glb#Scene0"),
        hazard: asset_server
            .load("kaykit-minigame/bomb_teamRed.gltf.glb#Scene0"),
    });
}

pub fn spawn_frozen_lake_static(
mut events: EventReader<StartSession<FrozenLakeEpParams>>,
mut commands: Commands,
assets: Res<FrozenLakeAssets>,
) {
for event in events.read() {
let map = FrozenLakeMap::default_four_by_four();
Expand All @@ -16,49 +45,81 @@ pub fn spawn_frozen_lake(
GridToWorld::from_frozen_lake_map(&map, event.params.map_width);

let tile_scale = Vec3::splat(grid_to_world.cell_width);
let tile_handle = asset_server
.load("kaykit-minigame/tileSmall_teamBlue.gltf.glb#Scene0");
for x in 0..map.width() {
for y in 0..map.height() {
let mut pos = grid_to_world.world_pos(UVec2::new(x, y));
pos.y -= grid_to_world.cell_width;
commands.spawn((SceneBundle {
scene: tile_handle.clone(),
transform: Transform::from_translation(pos)
.with_scale(tile_scale),
..default()
},));
commands.spawn((
SceneBundle {
scene: assets.tile.clone(),
transform: Transform::from_translation(pos)
.with_scale(tile_scale),
..default()
},
EpisodeOwner(event.trainer),
));
}
}
// if let Some(agent_pos) = map.agent_position() {
// let pos =
// offset + Vec3::new(agent_pos.x as f32, 0.1, agent_pos.y as f32);
// }

let character_handle =
asset_server.load("kaykit-minigame/character_dog.gltf.glb#Scene0");

let goal_handle = asset_server
.load("kaykit-minigame/flag_teamYellow.gltf.glb#Scene0");
let object_scale = Vec3::splat(grid_to_world.cell_width * 0.5);

let hazard_handle =
asset_server.load("kaykit-minigame/bomb_teamRed.gltf.glb#Scene0");
for (index, cell) in map.cells().iter().enumerate() {
let grid_pos = map.index_to_position(index);
let mut pos = grid_to_world.world_pos(grid_pos);
match cell {
FrozenLakeCell::Hole => {
pos.y += grid_to_world.cell_width * 0.25; // this asset is a bit too low
commands.spawn((
SceneBundle {
scene: assets.hazard.clone(),
transform: Transform::from_translation(pos)
.with_scale(object_scale),
..default()
},
EpisodeOwner(event.trainer),
));
}
FrozenLakeCell::Goal => {
commands.spawn((
SceneBundle {
scene: assets.goal.clone(),
transform: Transform::from_translation(pos)
.with_scale(object_scale),
..default()
},
EpisodeOwner(event.trainer),
));
}
FrozenLakeCell::Ice => {}
FrozenLakeCell::Agent => { /*spawns on episode */ }
}
{}
}
}
}


pub fn spawn_frozen_lake(
mut events: EventReader<StartEpisode<FrozenLakeEpParams>>,
mut commands: Commands,
assets: Res<FrozenLakeAssets>,
) {
for event in events.read() {
// TODO deduplicate
let map = FrozenLakeMap::default_four_by_four();
let grid_to_world =
GridToWorld::from_frozen_lake_map(&map, event.params.map_width);
let object_scale = Vec3::splat(grid_to_world.cell_width * 0.5);

for (index, cell) in map.cells().iter().enumerate() {
let grid_pos = map.index_to_position(index);
let mut pos = grid_to_world.world_pos(grid_pos);
let pos = grid_to_world.world_pos(grid_pos);
match cell {
FrozenLakeCell::Agent => {
let trainer = commands.spawn_empty().id();


commands
.spawn((
SceneBundle {
scene: character_handle.clone(),
scene: assets.character.clone(),
transform: Transform::from_translation(pos)
.with_scale(object_scale),
..default()
Expand All @@ -70,7 +131,7 @@ pub fn spawn_frozen_lake(
table: QTable::default(),
env: FrozenLakeEnv::new(map.clone(), false),
params: event.params.learn_params.clone(),
trainer: EpisodeOwner(trainer),
trainer: EpisodeOwner(event.trainer),
},
))
.with_children(|parent| {
Expand Down Expand Up @@ -99,24 +160,7 @@ pub fn spawn_frozen_lake(
});
});
}
FrozenLakeCell::Hole => {
pos.y += grid_to_world.cell_width * 0.25; // this asset is a bit too low
commands.spawn(SceneBundle {
scene: hazard_handle.clone(),
transform: Transform::from_translation(pos)
.with_scale(object_scale),
..default()
});
}
FrozenLakeCell::Goal => {
commands.spawn(SceneBundle {
scene: goal_handle.clone(),
transform: Transform::from_translation(pos)
.with_scale(object_scale),
..default()
});
}
FrozenLakeCell::Ice => {}
_ => {}
}
{}
}
Expand Down
38 changes: 30 additions & 8 deletions crates/beet_ml/src/rl_realtime/episode_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ pub struct StartEpisode<T: EpisodeParams> {
pub episode: u32,
pub params: T,
}
/// Event announcing that a training session has begun for a trainer.
///
/// `init_episode_runner` sends one for every newly added `EpisodeRunner`,
/// just before that runner's first `StartEpisode`.
#[derive(Debug, Event)]
pub struct StartSession<T>
where
    T: EpisodeParams,
{
    /// Entity holding the `EpisodeRunner` that owns this session.
    pub trainer: Entity,
    /// Copy of the runner's parameters at the time the session started.
    pub params: T,
}
/// Event announcing that a trainer has finished its training session.
///
/// `handle_episode_end` sends it once the runner's episode counter reaches
/// `params.num_episodes()`, instead of starting another episode.
#[derive(Debug, Event)]
pub struct EndSession<T>
where
    T: EpisodeParams,
{
    /// Entity holding the `EpisodeRunner` whose session ended.
    pub trainer: Entity,
    /// Copy of the runner's parameters at the time the session ended.
    pub params: T,
}

#[derive(Debug, Event)]
pub struct EndEpisode<T: EpisodeParams> {
pub trainer: Entity,
Expand Down Expand Up @@ -43,6 +54,8 @@ impl<T: EpisodeParams> Plugin for EpisodeRunnerPlugin<T> {
handle_episode_end::<T>.in_set(PostTickSet),
),
)
.add_event::<StartSession<T>>()
.add_event::<EndSession<T>>()
.add_event::<StartEpisode<T>>()
.add_event::<EndEpisode<T>>();
}
Expand Down Expand Up @@ -79,11 +92,16 @@ impl<T: EpisodeParams> EpisodeRunner<T> {
}

pub fn init_episode_runner<T: EpisodeParams>(
mut events: EventWriter<StartEpisode<T>>,
mut start_session: EventWriter<StartSession<T>>,
mut start_episode: EventWriter<StartEpisode<T>>,
runners: Query<(Entity, &mut EpisodeRunner<T>), Added<EpisodeRunner<T>>>,
) {
for (entity, trainer) in runners.iter() {
events.send(StartEpisode {
start_session.send(StartSession {
trainer: entity,
params: trainer.params.clone(),
});
start_episode.send(StartEpisode {
trainer: entity,
episode: trainer.episode,
params: trainer.params.clone(),
Expand All @@ -93,12 +111,13 @@ pub fn init_episode_runner<T: EpisodeParams>(

pub fn handle_episode_end<T: EpisodeParams>(
mut commands: Commands,
mut start_events: EventWriter<StartEpisode<T>>,
mut end_events: EventReader<EndEpisode<T>>,
mut start_ep: EventWriter<StartEpisode<T>>,
mut end_ep: EventReader<EndEpisode<T>>,
mut end_session: EventWriter<EndSession<T>>,
mut ep_entities: Query<(Entity, &EpisodeOwner)>,
mut trainers: Query<(Entity, &mut EpisodeRunner<T>)>,
) {
for event in end_events.read() {
for event in end_ep.read() {
if let Ok((runner_entity, mut runner)) = trainers.get_mut(event.trainer)
{
for (ep_entity, parent_runner) in ep_entities.iter_mut() {
Expand All @@ -108,14 +127,17 @@ pub fn handle_episode_end<T: EpisodeParams>(
}
runner.episode += 1;
if runner.episode < runner.params.num_episodes() {
start_events.send(StartEpisode {
start_ep.send(StartEpisode {
trainer: runner_entity,
episode: runner.episode,
params: runner.params.clone(),
});
} else {
println!("Training complete");
// todo!("Save model");
// complete!
end_session.send(EndSession {
trainer: runner_entity,
params: runner.params.clone(),
});
}
}
}
Expand Down
16 changes: 8 additions & 8 deletions crates/beet_ml/src/rl_realtime/step_environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ fn step_environment<S: RlSessionTypes>(
&outcome.state,
outcome.reward,
);
log::info!(
"step complete - action: {:?}, reward: {:?}",
action,
outcome.reward
);
// log::info!(
// "step complete - action: {:?}, reward: {:?}",
// action,
// outcome.reward
// );

commands.entity(action_entity).insert(RunResult::Success);

step.step += 1;

if outcome.done || step.step >= params.max_steps {
end_episode_events.send(EndEpisode::new(**trainer));
}
Expand Down Expand Up @@ -109,9 +109,9 @@ mod test {
let mut app = App::new();

app.add_plugins((
AssetPlugin::default(),
LifecyclePlugin,
FrozenLakePlugin,
ActionPlugin::<StepEnvironment<FrozenLakeQTableSession>>::default(),
EpisodeRunnerPlugin::<FrozenLakeEpParams>::default(),
))
.insert_time();

Expand Down

0 comments on commit ea7a31e

Please sign in to comment.