Skip to content

Commit

Permalink
feat: rl session
Browse files Browse the repository at this point in the history
  • Loading branch information
mrchantey committed Jun 10, 2024
1 parent ea7a31e commit 8eba3f6
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 242 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ impl Plugin for FrozenLakePlugin {
TranslateGrid,
StepEnvironment<FrozenLakeQTableSession>,
)>::default(),
EpisodeRunnerPlugin::<FrozenLakeEpParams>::default(),
RlSessionPlugin::<FrozenLakeEpParams>::default(),
))
.add_systems(Startup, init_frozen_lake_assets)
.add_systems(Update, reward_grid.in_set(PostTickSet))
.add_systems(
Update,
(spawn_frozen_lake_static, spawn_frozen_lake).in_set(PostTickSet),
(spawn_frozen_lake_session, spawn_frozen_lake_episode)
.in_set(PostTickSet),
)
.init_resource::<RlRng>();

Expand Down
19 changes: 11 additions & 8 deletions crates/beet_ml/src/environments/frozen_lake/spawn_frozen_lake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub fn init_frozen_lake_assets(
});
}

pub fn spawn_frozen_lake_static(
pub fn spawn_frozen_lake_session(
mut events: EventReader<StartSession<FrozenLakeEpParams>>,
mut commands: Commands,
assets: Res<FrozenLakeAssets>,
Expand All @@ -56,7 +56,8 @@ pub fn spawn_frozen_lake_static(
.with_scale(tile_scale),
..default()
},
EpisodeOwner(event.trainer),
SessionEntity(event.session),
DespawnOnSessionEnd,
));
}
}
Expand All @@ -76,7 +77,8 @@ pub fn spawn_frozen_lake_static(
.with_scale(object_scale),
..default()
},
EpisodeOwner(event.trainer),
SessionEntity(event.session),
DespawnOnSessionEnd,
));
}
FrozenLakeCell::Goal => {
Expand All @@ -87,7 +89,8 @@ pub fn spawn_frozen_lake_static(
.with_scale(object_scale),
..default()
},
EpisodeOwner(event.trainer),
SessionEntity(event.session),
DespawnOnSessionEnd,
));
}
FrozenLakeCell::Ice => {}
Expand All @@ -99,7 +102,7 @@ pub fn spawn_frozen_lake_static(
}


pub fn spawn_frozen_lake(
pub fn spawn_frozen_lake_episode(
mut events: EventReader<StartEpisode<FrozenLakeEpParams>>,
mut commands: Commands,
assets: Res<FrozenLakeAssets>,
Expand Down Expand Up @@ -128,10 +131,10 @@ pub fn spawn_frozen_lake(
RlAgentBundle {
state: map.agent_position(),
action: GridDirection::sample(),
table: QTable::default(),
env: FrozenLakeEnv::new(map.clone(), false),
params: event.params.learn_params.clone(),
trainer: EpisodeOwner(event.trainer),
session: SessionEntity(event.session),
despawn: DespawnOnEpisodeEnd,
},
))
.with_children(|parent| {
Expand All @@ -146,7 +149,7 @@ pub fn spawn_frozen_lake(
.with_children(|parent| {
parent.spawn((
TranslateGrid::new(
Duration::from_secs(1),
Duration::from_millis(1),
),
TargetAgent(agent),
RunTimer::default(),
Expand Down
194 changes: 0 additions & 194 deletions crates/beet_ml/src/rl_realtime/episode_runner.rs

This file was deleted.

8 changes: 4 additions & 4 deletions crates/beet_ml/src/rl_realtime/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pub mod rl_agent;
#[allow(unused_imports)]
pub use self::rl_agent::*;
pub mod episode_runner;
pub mod rl_session;
#[allow(unused_imports)]
pub use self::episode_runner::*;
pub use self::rl_session::*;
pub mod rl_components;
#[allow(unused_imports)]
pub use self::rl_components::*;
Expand All @@ -13,6 +13,6 @@ pub use self::step_environment::*;
pub mod rl_plugin;
#[allow(unused_imports)]
pub use self::rl_plugin::*;
pub mod rl_session;
pub mod rl_session_types;
#[allow(unused_imports)]
pub use self::rl_session::*;
pub use self::rl_session_types::*;
27 changes: 12 additions & 15 deletions crates/beet_ml/src/rl_realtime/rl_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,19 @@ use bevy::prelude::*;


#[derive(Bundle)]
pub struct RlAgentBundle<
Env: Component + Environment<State = Table::State, Action = Table::Action>,
Table: Component + QSource,
> {
pub state: Table::State,
pub action: Table::Action,
pub table: Table,
pub struct RlAgentBundle<Env: Component + Environment> {
pub state: Env::State,
pub action: Env::Action,
pub env: Env,
pub params: QLearnParams,
pub trainer: EpisodeOwner,
pub session: SessionEntity,
pub despawn:DespawnOnEpisodeEnd
}


impl<
Env: Component + Environment<State = Table::State, Action = Table::Action>,
Table: Component + QSource,
> RlAgentBundle<Env, Table>
{
}
// #[derive(Bundle)]
// pub struct RlSessionBundle<S: RlSessionTypes>
// where
// S::QSource: Component,
// {
// pub source: S::QSource,
// }
4 changes: 2 additions & 2 deletions crates/beet_ml/src/rl_realtime/rl_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ pub struct RlPlugin;
impl Plugin for RlPlugin {
fn build(&self, app: &mut App) {
let world = app.world_mut();
world.init_component::<EpisodeOwner>();
world.init_component::<SessionEntity>();

let mut registry =
world.get_resource::<AppTypeRegistry>().unwrap().write();

registry.register::<EpisodeOwner>();
registry.register::<SessionEntity>();
}
}
Loading

0 comments on commit 8eba3f6

Please sign in to comment.