Skip to content

Commit

Permalink
feat: frozen lake run observers
Browse files Browse the repository at this point in the history
  • Loading branch information
mrchantey committed Jul 12, 2024
1 parent 3adc56c commit c55e905
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 39 deletions.
46 changes: 25 additions & 21 deletions crates/beet_ecs/src/lifecycle/components/run_timer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@ use bevy::time::Stopwatch;
use std::fmt::Debug;

/// Tracks the last time a node was run.
#[derive(Default, Debug, Component, Reflect)]
#[derive(Default, Debug, Component, Action, Reflect)]
#[reflect(Component, Default)]
#[observers(on_start, on_stop)]
#[systems(
update_run_timers
.run_if(|time: Option<Res<Time>>| time.is_some())
.in_set(PreTickSet)
)]
pub struct RunTimer {
/// Last time the node was last started, or time since level load if never started.
pub last_started: Stopwatch,
Expand All @@ -14,41 +20,39 @@ pub struct RunTimer {
}



fn on_start(trigger: Trigger<OnAdd, Running>, mut query: Query<&mut RunTimer>) {
query
.get_mut(trigger.entity())
.expect(expect_action::ACTION_QUERY_MISSING)
.last_started
.reset();
}
fn on_stop(
trigger: Trigger<OnRemove, Running>,
mut query: Query<&mut RunTimer>,
) {
query
.get_mut(trigger.entity())
.expect(expect_action::ACTION_QUERY_MISSING)
.last_stopped
.reset();
}

/// Syncs [`RunTimer`] components, by default added to [`PreTickSet`].
/// This is added to the [`PreTickSet`], any changes detected were from the previous frame.
/// For this reason timers are reset before they tick to accuratly indicate when the [`Running`]
/// component was *actually* added or removed.
pub fn update_run_timers(
// TODO run_if
time: Res<Time>,
mut timers: Query<&mut RunTimer>,
added: Query<Entity, Added<Running>>,
mut removed: RemovedComponents<Running>,
) {
// 1. reset timers

for added in added.iter() {
if let Ok(mut timer) = timers.get_mut(added) {
timer.last_started.reset();
}
}

for removed in removed.read() {
if let Ok(mut timer) = timers.get_mut(removed) {
timer.last_stopped.reset();
}
}

// 2. tick timers

for mut timer in timers.iter_mut() {
timer.last_started.tick(time.delta());
timer.last_stopped.tick(time.delta());
}
}


#[cfg(test)]
mod test {
use crate::prelude::*;
Expand Down
1 change: 1 addition & 0 deletions crates/beet_ecs/src/lifecycle/lifecycle_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ impl Plugin for LifecyclePlugin {
app.add_plugins(ActionPlugin::<(
InsertInDuration<RunResult>,
InsertOnRun<RunResult>,
RunTimer,
LogOnRun,
// CallOnRun,
SetOnSpawn<Score>,
Expand Down
6 changes: 0 additions & 6 deletions crates/beet_ecs/src/lifecycle/lifecycle_systems_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ impl Plugin for LifecycleSystemsPlugin {
.add_systems(schedule, apply_deferred.after(PreTickSet).before(TickSet))
.add_systems(schedule, apply_deferred.after(TickSet).before(TickSyncSet))
.add_systems(schedule, apply_deferred.after(TickSyncSet).before(PostTickSet))
.add_systems(
schedule,
update_run_timers
.run_if(|time: Option<Res<Time>>| time.is_some())
.in_set(PreTickSet),
)
.add_systems(
schedule,
(sync_interrupts, sync_running).chain().in_set(TickSyncSet),
Expand Down
3 changes: 2 additions & 1 deletion crates/beet_ecs/src/observers/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ pub mod expect_action{
/// we always expect the component to exist.
pub const ACTION_QUERY_MISSING: &str =
"Action entity missing from observer query";

pub const TARGET_MISSING: &str =
"Target entity missing in action";
}

pub mod expect_asset {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub fn frozen_lake_run(mut commands: Commands) {

parent
.spawn((
Name::new("Inference Behavior"),
Name::new("Run Frozen Lake Agent"),
RunOnAppReady::default(),
SequenceFlow,
Repeat::default(),
Expand All @@ -45,6 +45,7 @@ pub fn frozen_lake_run(mut commands: Commands) {
));
parent.spawn((
Name::new("Perform action"),
ContinueRun::default(),
TranslateGrid::new(Duration::from_secs(1)),
TargetAgent(agent),
RunTimer::default(),
Expand Down
29 changes: 19 additions & 10 deletions crates/beet_ml/src/rl_realtime/read_qpolicy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::marker::PhantomData;
#[derive(Debug, Clone, PartialEq, Component, Action, Reflect)]
#[reflect(Component, ActionMeta)]
#[category(ActionCategory::Behavior)]
#[systems(read_q_policy::<P>.in_set(TickSet))]
#[observers(read_q_policy::<P>)]
pub struct ReadQPolicy<P: QPolicy + Asset> {
#[reflect(ignore)]
phantom: PhantomData<P>,
Expand All @@ -21,17 +21,26 @@ impl<P: QPolicy + Asset> Default for ReadQPolicy<P> {
}

fn read_q_policy<P: QPolicy + Asset>(
trigger: Trigger<OnRun>,
mut commands: Commands,
assets: Res<Assets<P>>,
mut agents: Query<(&P::State, &mut P::Action)>,
query: Query<(Entity, &Handle<P>, &ReadQPolicy<P>), With<Running>>,
query: Query<(&ReadQPolicy<P>, &Handle<P>, &TargetAgent)>,
) {
for (entity, handle, _read_q_policy) in query.iter() {
if let Some(policy) = assets.get(handle) {
for (state, mut action) in agents.iter_mut() {
*action = policy.greedy_policy(state).0;
commands.entity(entity).insert(RunResult::Success);
}
}
}
let (_, handle, agent) = query
.get(trigger.entity())
.expect(expect_action::ACTION_QUERY_MISSING);

let policy = assets.get(handle).expect(expect_asset::NOT_READY);

let (state, mut action) = agents
.get_mut(agent.0)
.expect(expect_action::TARGET_MISSING);


*action = policy.greedy_policy(state).0;
log::info!("ReadQPolicy: \n{:?}\n{:?}", state, action);
commands
.entity(trigger.entity())
.trigger(OnRunResult::success());
}

0 comments on commit c55e905

Please sign in to comment.