Skip to content

Commit

Permalink
feat: hello_ml observer
Browse files Browse the repository at this point in the history
  • Loading branch information
mrchantey committed Jul 9, 2024
1 parent fd639b4 commit c2978f5
Show file tree
Hide file tree
Showing 11 changed files with 127 additions and 157 deletions.
3 changes: 2 additions & 1 deletion crates/beet_ecs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ categories.workspace = true
default = []
# default = ["reflect"]
reflect = []
test = ["dep:sweet"]

[dependencies]
beet_ecs_macros.workspace = true
Expand All @@ -30,9 +31,9 @@ extend.workspace = true
num-traits.workspace = true

bevy.workspace = true
sweet = { workspace = true, optional = true }

[dev-dependencies]
pretty_env_logger.workspace = true
sweet.workspace = true
bincode.workspace = true
ron.workspace = true
4 changes: 2 additions & 2 deletions crates/beet_ecs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub mod graph;
pub mod lifecycle;
pub mod observers;
pub mod reflect;
#[cfg(test)]
#[cfg(any(test, feature = "test"))]
pub mod test;
pub mod tree;

Expand All @@ -34,7 +34,7 @@ pub mod prelude {
pub use crate::observers::*;
// pub use crate::lifecycle::*;
pub use crate::reflect::*;
#[cfg(test)]
#[cfg(any(test, feature = "test"))]
pub use crate::test::*;
pub use crate::tree::*;
pub use beet_ecs_macros::*;
Expand Down
3 changes: 1 addition & 2 deletions crates/beet_examples/src/scenes/seek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ pub fn seek(mut commands: Commands) {
// behavior
parent.spawn((
Name::new("Seek"),
RunOnSpawn,
RunOnAppReady::default(),
ContinueRun::default(),
InsertOnTrigger::<AppReady, Running>::default(),
TargetAgent(parent.parent_entity()),
Seek,
));
Expand Down
7 changes: 2 additions & 5 deletions crates/beet_examples/src/scenes/sentence_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@ pub fn sentence_selector(mut commands: Commands) {
.spawn((
Name::new("Sentence Selector"),
AssetLoadBlockAppReady,
InsertOnSend::<AppReady, Running>::default(),
RunOnAppReady::default(),
TargetAgent(agent),
bert_handle,
SentenceScorer::default(),
ScoreSelector {
consume_scores: true,
},
SentenceFlow::default(),
))
.with_children(|parent| {
parent.spawn((
Expand Down
6 changes: 4 additions & 2 deletions crates/beet_examples/src/serde_utils/ready_on_asset_load.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::beet::prelude::AppReady;
use beet_net::events::RunOnAppReady;
use bevy::prelude::*;
use std::marker::PhantomData;

Expand All @@ -21,10 +22,10 @@ impl<A: Asset> Plugin for ReadyOnAssetLoadPlugin<A> {

pub fn ready_on_asset_load<A: Asset>(
mut asset_events: EventReader<AssetEvent<A>>,
mut ready_events: EventWriter<AppReady>,
mut commands: Commands,
query: Query<(Entity, &Handle<A>), With<AssetLoadBlockAppReady>>,
all_blocks: Query<Entity, With<AssetLoadBlockAppReady>>,
all_awaiting: Query<Entity, With<RunOnAppReady>>,
) {
let mut total_ready = 0;
for ev in asset_events.read() {
Expand All @@ -44,6 +45,7 @@ pub fn ready_on_asset_load<A: Asset>(
}
let total_blocks = all_blocks.iter().count();
if total_blocks > 0 && total_blocks == total_ready {
ready_events.send(AppReady);
let targets = all_awaiting.iter().collect::<Vec<_>>();
commands.trigger_targets(AppReady,targets);
}
}
1 change: 1 addition & 0 deletions crates/beet_ml/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ wasm-bindgen-futures.workspace = true
console_error_panic_hook.workspace = true

[dev-dependencies]
beet_ecs = { workspace = true, features = ["test"] }
pretty_env_logger.workspace = true
sweet.workspace = true

Expand Down
2 changes: 1 addition & 1 deletion crates/beet_ml/src/language/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ impl Bert {
}


/// Score a list of entities with a [`Sentence`] against a root entity with a [`Sentence`]. This returns a list of entities with their sentence and raw cosine similarity scores.
/// Score and **sort** a list of entities with a [`Sentence`] against a root entity with a [`Sentence`]. This returns a list of entities with their sentence and raw cosine similarity scores.
/// Scores are in a range of `0..1`, higher means more similar, the list is sorted in descending order.
/// This calls [`Bert::get_embeddings`] and has the associated performance implications.
/// If the root is missing a [`Sentence`] an empty vec will be returned.
Expand Down
2 changes: 1 addition & 1 deletion crates/beet_ml/src/language/bert_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub struct BertPlugin;

impl Plugin for BertPlugin {
fn build(&self, app: &mut App) {
app.add_plugins(ActionPlugin::<SentenceScorer>::default())
app.add_plugins(ActionPlugin::<SentenceFlow>::default())
.init_asset::<Bert>()
.init_asset_loader::<BertLoader>()
.register_type::<Sentence>()
Expand Down
4 changes: 2 additions & 2 deletions crates/beet_ml/src/language/selectors/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pub mod find_sentence_steer_target;
#[allow(unused_imports)]
pub use self::find_sentence_steer_target::*;
pub mod sentence_scorer;
pub mod sentence_flow;
#[allow(unused_imports)]
pub use self::sentence_scorer::*;
pub use self::sentence_flow::*;
111 changes: 111 additions & 0 deletions crates/beet_ml/src/language/selectors/sentence_flow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use crate::prelude::*;
use beet_ecs::prelude::*;
use bevy::prelude::*;
use forky_core::ResultTEExt;
use std::borrow::Cow;

/// This component is for use with [`SentenceFlow`]. Add to either the agent or a child behavior.
#[derive(Debug, Clone, Component, PartialEq, Reflect)]
#[reflect(Component)]
pub struct Sentence(pub Cow<'static, str>);
impl Sentence {
pub fn new(s: impl Into<Cow<'static, str>>) -> Self { Self(s.into()) }
}

/// Runs the child with the [`Sentence`] that is most similar to that of the agent.
/// for use with [`ScoreSelector`]
#[derive(Debug, Default, Clone, PartialEq, Action, Reflect)]
#[reflect(Component, ActionMeta)]
#[category(ActionCategory::ChildBehaviors)]
#[observers(sentence_flow)]
pub struct SentenceFlow;

impl SentenceFlow {
pub fn new() -> Self { Self {} }
}

fn sentence_flow(
trigger: Trigger<OnRun>,
mut commands: Commands,
mut berts: ResMut<Assets<Bert>>,
sentences: Query<&Sentence>,
// TODO double query, ie added running and added asset
query: Query<(&SentenceFlow, &Handle<Bert>, &TargetAgent, &Children)>,
) {
let (_scorer, handle, agent, children) = query
.get(trigger.entity())
.expect(expect_action::NO_ACTION_COMP);
let Some(bert) = berts.get_mut(handle) else {
// not ready yet
log::warn!("SentenceFlow: Bert asset was not ready, will not run");
return;
};

let children = children.into_iter().cloned().collect::<Vec<_>>();
//todo: async
bert.score_sentences(agent.0, children, &sentences)
.ok_or(|e| log::error!("{e}"))
.map(|scores| {
if let Some((entity, ..)) = scores.first() {
commands.entity(*entity).trigger(OnRun);
} else {
log::warn!("SentenceFlow: No scores returned");
}
});
}

#[cfg(test)]
mod test {
use crate::prelude::*;
use anyhow::Result;
use beet_ecs::prelude::*;
use bevy::prelude::*;
use sweet::*;



#[test]
fn works() -> Result<()> {
pretty_env_logger::try_init().ok();

let mut app = App::new();
app.add_plugins((
MinimalPlugins,
AssetPlugin::default(),
BertPlugin::default(),
LifecyclePlugin,
))
.finish();
let on_run = observe_trigger_names::<OnRun>(app.world_mut());

block_on_asset_load::<Bert>(&mut app, "default-bert.ron");

let handle = app
.world_mut()
.resource_mut::<AssetServer>()
.load::<Bert>("default-bert.ron");

let agent = app.world_mut().spawn(Sentence::new("destroy")).id();


app.world_mut()
.spawn((
Name::new("root"),
TargetAgent(agent),
handle,
SentenceFlow::default(),
))
.with_children(|parent| {
parent.spawn((Name::new("heal"), Sentence::new("heal")));
parent.spawn((Name::new("kill"), Sentence::new("kill")));
})
.flush_trigger(OnRun);


expect(&on_run).to_have_been_called_times(2)?;
expect(&on_run).to_have_returned_nth_with(0, &"root".to_string())?;
expect(&on_run).to_have_returned_nth_with(1, &"kill".to_string())?;

Ok(())
}
}
141 changes: 0 additions & 141 deletions crates/beet_ml/src/language/selectors/sentence_scorer.rs

This file was deleted.

0 comments on commit c2978f5

Please sign in to comment.