wip: episode runner
mrchantey committed Jun 10, 2024
1 parent 36ef87d commit 87da928
Showing 8 changed files with 105 additions and 107 deletions.
58 changes: 34 additions & 24 deletions crates/beet_ml/src/environments/frozen_lake/frozen_lake_plugin.rs
@@ -2,9 +2,6 @@ use crate::prelude::*;
use beet_ecs::prelude::*;
use bevy::prelude::*;

pub type FrozenLakeQTable = QTable<GridPos, GridDirection>;


/**
Implementation of the OpenAI Gym Frozen Lake environment.
https://github.com/openai/gym/blob/master/gym/envs/toy_text/frozen_lake.py
@@ -38,32 +35,12 @@ Reward schedule:
**/
pub struct FrozenLakePlugin;

#[derive(Debug, Clone, Reflect)]
pub struct FrozenLakeEpParams {
pub learn_params: QLearnParams,
pub map_width: f32,
}

impl Default for FrozenLakeEpParams {
fn default() -> Self {
Self {
learn_params: QLearnParams::default(),
map_width: 4.,
}
}
}

impl EpisodeParams for FrozenLakeEpParams {
fn num_episodes(&self) -> u32 { self.learn_params.n_training_episodes }
}


impl Plugin for FrozenLakePlugin {
fn build(&self, app: &mut App) {
app.add_plugins((
ActionPlugin::<(
TranslateGrid,
StepEnvironment<FrozenLakeEnv, FrozenLakeQTable>,
StepEnvironment<FrozenLakeQTableSession>,
)>::default(),
EpisodeRunnerPlugin::<FrozenLakeEpParams>::default(),
))
@@ -84,3 +61,36 @@ impl Plugin for FrozenLakePlugin {
registry.register::<GridDirection>();
}
}


#[derive(Debug, Clone, Reflect)]
pub struct FrozenLakeEpParams {
pub learn_params: QLearnParams,
pub map_width: f32,
}

impl Default for FrozenLakeEpParams {
fn default() -> Self {
Self {
learn_params: QLearnParams::default(),
map_width: 4.0,
}
}
}

impl EpisodeParams for FrozenLakeEpParams {
fn num_episodes(&self) -> u32 { self.learn_params.n_training_episodes }
}

pub type FrozenLakeQTable = QTable<GridPos, GridDirection>;


pub struct FrozenLakeQTableSession;

impl RlSessionTypes for FrozenLakeQTableSession {
type State = GridPos;
type Action = GridDirection;
type QSource = FrozenLakeQTable;
type Env = FrozenLakeEnv;
type EpisodeParams = FrozenLakeEpParams;
}
@@ -93,8 +93,7 @@ pub fn spawn_frozen_lake(
parent.spawn((
TargetAgent(agent),
StepEnvironment::<
FrozenLakeEnv,
FrozenLakeQTable,
FrozenLakeQTableSession,
>::new(event.episode),
));
});
2 changes: 1 addition & 1 deletion crates/beet_ml/src/rl/q_table.rs
@@ -109,7 +109,7 @@ mod test {
}
}

let eval = QTableTrainer::new(
let eval = QTableTrainer::<FrozenLakeQTableSession>::new(
env,
source,
params,
51 changes: 16 additions & 35 deletions crates/beet_ml/src/rl/q_table_trainer.rs
@@ -1,33 +1,23 @@
use super::*;
use crate::prelude::RlSessionTypes;
use rand::Rng;



/// Used for training a QTable to completion with a provided [`Environment`].
pub struct QTableTrainer<
S: StateSpace + Clone,
A: ActionSpace,
Env: Environment<State = S, Action = A>,
Table: QSource<State = S, Action = A>,
> {
pub table: Table,
pub env: Readonly<Env>,
pub struct QTableTrainer<S: RlSessionTypes> {
pub table: S::QSource,
pub env: Readonly<S::Env>,
pub params: Readonly<QLearnParams>,
initial_state: S,
initial_state: S::State,
}

impl<
S: StateSpace + Clone,
A: ActionSpace,
Env: Environment<State = S, Action = A>,
Table: QSource<State = S, Action = A>,
> QTableTrainer<S, A, Env, Table>
{
impl<S: RlSessionTypes> QTableTrainer<S> {
pub fn new(
env: Env,
table: Table,
env: S::Env,
table: S::QSource,
params: QLearnParams,
initial_state: S,
initial_state: S::State,
) -> Self {
Self {
table,
@@ -39,15 +29,9 @@ impl<
}


impl<
S: StateSpace + Clone,
A: ActionSpace,
Env: Environment<State = S, Action = A>,
Table: QSource<State = S, Action = A>,
> QSource for QTableTrainer<S, A, Env, Table>
{
type Action = A;
type State = S;
impl<S: RlSessionTypes> QSource for QTableTrainer<S> {
type Action = S::Action;
type State = S::State;
fn greedy_policy(&self, state: &Self::State) -> (Self::Action, QValue) {
self.table.greedy_policy(state)
}
@@ -74,12 +58,9 @@ impl<
}


impl<
S: StateSpace + Clone,
A: ActionSpace,
Env: Environment<State = S, Action = A>,
Table: QSource<State = S, Action = A>,
> QTrainer for QTableTrainer<S, A, Env, Table>
impl<S: RlSessionTypes> QTrainer for QTableTrainer<S>
where
S::State: Clone,
{
fn train(&mut self, rng: &mut impl Rng) {
let params = &self.params;
@@ -160,7 +141,7 @@ mod test {
let env = FrozenLakeEnv::new(map, false);
let params = QLearnParams::default();

let mut trainer = QTableTrainer::new(
let mut trainer = QTableTrainer::<FrozenLakeQTableSession>::new(
env.clone(),
QTable::default(),
params,
11 changes: 6 additions & 5 deletions crates/beet_ml/src/rl_realtime/episode_runner.rs
@@ -1,4 +1,3 @@
use crate::prelude::*;
use beet_ecs::prelude::*;
use bevy::prelude::*;
use std::borrow::Cow;
@@ -21,7 +20,7 @@ pub struct EndEpisode<T: EpisodeParams> {
phantom: PhantomData<T>,
}

impl EndEpisode<FrozenLakeEpParams> {
impl<T: EpisodeParams> EndEpisode<T> {
pub fn new(trainer: Entity) -> Self {
Self {
trainer,
@@ -39,7 +38,10 @@ impl<T: EpisodeParams> Plugin for EpisodeRunnerPlugin<T> {
fn build(&self, app: &mut App) {
app.add_systems(
Update,
(init_episode_runner::<T>, handle_episode_end::<T>).in_set(TickSet),
(
init_episode_runner::<T>.in_set(PreTickSet),
handle_episode_end::<T>.in_set(PostTickSet),
),
)
.add_event::<StartEpisode<T>>()
.add_event::<EndEpisode<T>>();
@@ -153,8 +155,7 @@ mod test {
LifecyclePlugin::default(),
EpisodeRunnerPlugin::<FrozenLakeEpParams>::default(),
))
.add_systems(Update, start_ep.in_set(PostTickSet))
.add_systems(Update, end_ep.in_set(PreTickSet));
.add_systems(Update, (start_ep, end_ep).in_set(TickSet));
let mut params = FrozenLakeEpParams::default();
params.learn_params.n_training_episodes = 1;
app.world_mut().spawn(EpisodeRunner::new(params));
3 changes: 3 additions & 0 deletions crates/beet_ml/src/rl_realtime/mod.rs
@@ -13,3 +13,6 @@ pub use self::step_environment::*;
pub mod rl_plugin;
#[allow(unused_imports)]
pub use self::rl_plugin::*;
pub mod rl_session;
#[allow(unused_imports)]
pub use self::rl_session::*;
11 changes: 11 additions & 0 deletions crates/beet_ml/src/rl_realtime/rl_session.rs
@@ -0,0 +1,11 @@
use crate::prelude::*;



pub trait RlSessionTypes: 'static + Send + Sync {
type State: StateSpace;
type Action: ActionSpace;
type QSource: QSource<State = Self::State, Action = Self::Action>;
type Env: Environment<State = Self::State, Action = Self::Action>;
type EpisodeParams: EpisodeParams;
}
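
For orientation, this trait bundles the state, action, Q-source, environment, and episode-parameter types behind a single name, so generic code takes one type parameter instead of four. A minimal sketch of how it is implemented and consumed, mirroring the FrozenLake types defined earlier in this commit (illustration only, not an extra file in the diff):

// The concrete session type for the frozen lake environment.
pub struct FrozenLakeQTableSession;

impl RlSessionTypes for FrozenLakeQTableSession {
	type State = GridPos;
	type Action = GridDirection;
	type QSource = QTable<GridPos, GridDirection>;
	type Env = FrozenLakeEnv;
	type EpisodeParams = FrozenLakeEpParams;
}

// Generic consumers then name only the session, e.g.
// QTableTrainer::<FrozenLakeQTableSession>::new(env, table, params, initial_state)
// StepEnvironment::<FrozenLakeQTableSession>::new(episode)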
73 changes: 33 additions & 40 deletions crates/beet_ml/src/rl_realtime/step_environment.rs
@@ -6,20 +6,13 @@ use std::marker::PhantomData;

#[derive(Debug, Clone, PartialEq, Component, Reflect)]
#[reflect(Component, ActionMeta)]
pub struct StepEnvironment<
Env: Component + Environment<State = Table::State, Action = Table::Action>,
Table: Component + QSource,
> {
pub struct StepEnvironment<S: RlSessionTypes> {
episode: u32,
step: u32,
phantom: PhantomData<(Env, Table)>,
phantom: PhantomData<S>,
}

impl<
Env: Component + Environment<State = Table::State, Action = Table::Action>,
Table: Component + QSource,
> StepEnvironment<Env, Table>
{
impl<S: RlSessionTypes> StepEnvironment<S> {
pub fn new(episode: u32) -> Self {
Self {
episode,
@@ -30,27 +23,29 @@ impl<
}


fn step_environment<
Env: Component + Environment<State = Table::State, Action = Table::Action>,
Table: Component + QSource,
>(
fn step_environment<S: RlSessionTypes>(
mut rng: ResMut<RlRng>,
mut end_episode_events: EventWriter<EndEpisode<S::EpisodeParams>>,
mut commands: Commands,
mut agents: Query<(
&Table::State,
&mut Table::Action,
&mut Table,
&mut Env,
&S::State,
&mut S::Action,
&mut S::QSource,
&mut S::Env,
&QLearnParams,
&EpisodeOwner,
)>,
mut query: Query<
(Entity, &TargetAgent, &mut StepEnvironment<Env, Table>),
(Entity, &TargetAgent, &mut StepEnvironment<S>),
Added<Running>,
>,
) {
) where
S::State: Component,
S::Action: Component,
S::QSource: Component,
S::Env: Component,
{
for (action_entity, agent, mut step) in query.iter_mut() {
log::info!("step start");
let Ok((state, mut action, mut table, mut env, params, trainer)) =
agents.get_mut(**agent)
else {
@@ -80,28 +75,23 @@

step.step += 1;
if outcome.done || step.step >= params.max_steps {
commands.entity(**trainer).insert(RunResult::Success);
log::info!("episode complete");
end_episode_events.send(EndEpisode::new(**trainer));
}
}
}

impl<
Env: Component + Environment<State = Table::State, Action = Table::Action>,
Table: Component + QSource,
> ActionMeta for StepEnvironment<Env, Table>
{
impl<S: RlSessionTypes> ActionMeta for StepEnvironment<S> {
fn category(&self) -> ActionCategory { ActionCategory::Behavior }
}

impl<
Env: Component + Environment<State = Table::State, Action = Table::Action>,
Table: Component + QSource,
> ActionSystems for StepEnvironment<Env, Table>
impl<S: RlSessionTypes> ActionSystems for StepEnvironment<S>
where
S::State: Component,
S::Action: Component,
S::QSource: Component,
S::Env: Component,
{
fn systems() -> SystemConfigs {
step_environment::<Env, Table>.in_set(TickSet)
}
fn systems() -> SystemConfigs { step_environment::<S>.in_set(TickSet) }
}


@@ -118,8 +108,12 @@ mod test {
fn works() -> Result<()> {
let mut app = App::new();

app.add_plugins((LifecyclePlugin, FrozenLakePlugin))
.insert_time();
app.add_plugins((
AssetPlugin::default(),
LifecyclePlugin,
FrozenLakePlugin,
))
.insert_time();

let map = FrozenLakeMap::default_four_by_four();

@@ -140,7 +134,7 @@
.with_children(|parent| {
parent.spawn((
TargetAgent(parent.parent_entity()),
StepEnvironment::<FrozenLakeEnv, FrozenLakeQTable>::new(0),
StepEnvironment::<FrozenLakeQTableSession>::new(0),
Running,
));
})
@@ -162,8 +156,7 @@
let table = app.world().get::<FrozenLakeQTable>(agent).unwrap();
expect(table.keys().next()).to_be(Some(&GridPos(UVec2::new(0, 0))))?;
let inner = table.values().next().unwrap();
expect(inner.iter().next().unwrap())
.to_be((&GridDirection::Left, &0.))?;
expect(inner.iter().next().unwrap().1).to_be(&0.)?;

Ok(())
}
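
Taken together, the tests above suggest the minimal wiring for a realtime training run. The sketch below is an assumption-laden illustration: the beet_ml/beet_ecs prelude paths, the episode count, and the update loop are guesses based on the test code in this commit, not part of the diff.

use beet_ecs::prelude::*; // assumed import path for LifecyclePlugin
use beet_ml::prelude::*; // assumed import path for the frozen lake and episode runner types
use bevy::prelude::*;

fn run_frozen_lake_training() {
	// Headless, test-style setup mirroring the tests in this commit.
	let mut app = App::new();
	app.add_plugins((
		AssetPlugin::default(),
		LifecyclePlugin::default(),
		FrozenLakePlugin,
	));

	// FrozenLakePlugin already registers EpisodeRunnerPlugin::<FrozenLakeEpParams>,
	// so spawning a runner is enough to start emitting StartEpisode events.
	let mut params = FrozenLakeEpParams::default();
	params.learn_params.n_training_episodes = 100;
	app.world_mut().spawn(EpisodeRunner::new(params));

	// Step the schedule; each Update tick advances the active episode.
	for _ in 0..1_000 {
		app.update();
	}
}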