Skip to content

Commit

Permalink
Rework QuerySet to be safe
Browse files Browse the repository at this point in the history
  • Loading branch information
cart committed Aug 7, 2021
1 parent 42fedfc commit 110b13b
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 118 deletions.
48 changes: 18 additions & 30 deletions crates/bevy_ecs/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use syn::{
parse_macro_input,
punctuated::Punctuated,
token::Comma,
Data, DataStruct, DeriveInput, Field, Fields, GenericParam, Ident, Index, Lifetime, LitInt,
Path, Result, Token,
Data, DataStruct, DeriveInput, Field, Fields, GenericParam, Ident, Index, LitInt, Path, Result,
Token,
};

struct AllTuples {
Expand Down Expand Up @@ -176,51 +176,36 @@ fn get_idents(fmt_string: fn(usize) -> String, count: usize) -> Vec<Ident> {
.collect::<Vec<Ident>>()
}

fn get_lifetimes(fmt_string: fn(usize) -> String, count: usize) -> Vec<Lifetime> {
(0..count)
.map(|i| Lifetime::new(&fmt_string(i), Span::call_site()))
.collect::<Vec<Lifetime>>()
}

#[proc_macro]
pub fn impl_query_set(_input: TokenStream) -> TokenStream {
let mut tokens = TokenStream::new();
let max_queries = 4;
let queries = get_idents(|i| format!("Q{}", i), max_queries);
let filters = get_idents(|i| format!("F{}", i), max_queries);
let lifetimes = get_lifetimes(|i| format!("'q{}", i), max_queries);
let state_lifetimes = get_lifetimes(|i| format!("'qs{}", i), max_queries);
let mut query_fns = Vec::new();
let mut query_fn_muts = Vec::new();
for i in 0..max_queries {
let query = &queries[i];
let filter = &filters[i];
let lifetime = &lifetimes[i];
let state_lifetime = &state_lifetimes[i];
let fn_name = Ident::new(&format!("q{}", i), Span::call_site());
let fn_name_mut = Ident::new(&format!("q{}_mut", i), Span::call_site());
let index = Index::from(i);
query_fns.push(quote! {
pub fn #fn_name(&self) -> &Query<#lifetime, #state_lifetime, #query, #filter> {
&self.0.#index
}
});
query_fn_muts.push(quote! {
pub fn #fn_name_mut(&mut self) -> &mut Query<#lifetime, #state_lifetime, #query, #filter> {
&mut self.0.#index
pub fn #fn_name(&mut self) -> Query<'_, '_, #query, #filter> {
// SAFE: systems run without conflicts with other systems.
// Conflicting queries in QuerySet are not accessible at the same time
// QuerySets are guaranteed to not conflict with other SystemParams
unsafe {
Query::new(self.world, &self.query_states.#index, self.last_change_tick, self.change_tick)
}
}
});
}

for query_count in 1..=max_queries {
let query = &queries[0..query_count];
let filter = &filters[0..query_count];
let lifetime = &lifetimes[0..query_count];
let state_lifetime = &state_lifetimes[0..query_count];
let query_fn = &query_fns[0..query_count];
let query_fn_mut = &query_fn_muts[0..query_count];
tokens.extend(TokenStream::from(quote! {
impl<#(#lifetime,)* #(#state_lifetime,)* #(#query: WorldQuery + 'static,)* #(#filter: WorldQuery + 'static,)*> SystemParam for QuerySet<(#(Query<#lifetime, #state_lifetime, #query, #filter>,)*)>
impl<'w, 's, #(#query: WorldQuery + 'static,)* #(#filter: WorldQuery + 'static,)*> SystemParam for QuerySet<'w, 's, (#(QueryState<#query, #filter>,)*)>
where #(#filter::Fetch: FilterFetch,)*
{
type Fetch = QuerySetState<(#(QueryState<#query, #filter>,)*)>;
Expand Down Expand Up @@ -276,7 +261,7 @@ pub fn impl_query_set(_input: TokenStream) -> TokenStream {
impl<'w, 's, #(#query: WorldQuery + 'static,)* #(#filter: WorldQuery + 'static,)*> SystemParamFetch<'w, 's> for QuerySetState<(#(QueryState<#query, #filter>,)*)>
where #(#filter::Fetch: FilterFetch,)*
{
type Item = QuerySet<(#(Query<'w, 's, #query, #filter>,)*)>;
type Item = QuerySet<'w, 's, (#(QueryState<#query, #filter>,)*)>;

#[inline]
unsafe fn get_param(
Expand All @@ -285,15 +270,18 @@ pub fn impl_query_set(_input: TokenStream) -> TokenStream {
world: &'w World,
change_tick: u32,
) -> Self::Item {
let (#(#query,)*) = &state.0;
QuerySet((#(Query::new(world, #query, system_meta.last_change_tick, change_tick),)*))
QuerySet {
query_states: &state.0,
world,
last_change_tick: system_meta.last_change_tick,
change_tick,
}
}
}

impl<#(#lifetime,)* #(#state_lifetime,)* #(#query: WorldQuery,)* #(#filter: WorldQuery,)*> QuerySet<(#(Query<#lifetime, #state_lifetime, #query, #filter>,)*)>
impl<'w, 's, #(#query: WorldQuery,)* #(#filter: WorldQuery,)*> QuerySet<'w, 's, (#(QueryState<#query, #filter>,)*)>
where #(#filter::Fetch: FilterFetch,)*
{
#(#query_fn)*
#(#query_fn_mut)*
}
}));
Expand Down
42 changes: 23 additions & 19 deletions crates/bevy_ecs/src/system/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ mod tests {
bundle::Bundles,
component::Components,
entity::{Entities, Entity},
query::{Added, Changed, Or, With, Without},
query::{Added, Changed, Or, QueryState, With, Without},
schedule::{Schedule, Stage, SystemStage},
system::{
ConfigurableSystem, IntoExclusiveSystem, IntoSystem, Local, Query, QuerySet,
Expand Down Expand Up @@ -131,9 +131,9 @@ mod tests {
// Regression test for issue #762
fn query_system(
mut ran: ResMut<bool>,
set: QuerySet<(
Query<(), Or<(Changed<A>, Changed<B>)>>,
Query<(), Or<(Added<A>, Added<B>)>>,
mut set: QuerySet<(
QueryState<(), Or<(Changed<A>, Changed<B>)>>,
QueryState<(), Or<(Added<A>, Added<B>)>>,
)>,
) {
let changed = set.q0().iter().count();
Expand Down Expand Up @@ -236,15 +236,15 @@ mod tests {

#[test]
fn query_set_system() {
fn sys(mut _set: QuerySet<(Query<&mut A>, Query<&A>)>) {}
fn sys(mut _set: QuerySet<(QueryState<&mut A>, QueryState<&A>)>) {}
let mut world = World::default();
run_system(&mut world, sys);
}

#[test]
#[should_panic]
fn conflicting_query_with_query_set_system() {
fn sys(_query: Query<&mut A>, _set: QuerySet<(Query<&mut A>, Query<&B>)>) {}
fn sys(_query: Query<&mut A>, _set: QuerySet<(QueryState<&mut A>, QueryState<&B>)>) {}

let mut world = World::default();
run_system(&mut world, sys);
Expand All @@ -253,7 +253,11 @@ mod tests {
#[test]
#[should_panic]
fn conflicting_query_sets_system() {
fn sys(_set_1: QuerySet<(Query<&mut A>,)>, _set_2: QuerySet<(Query<&mut A>, Query<&B>)>) {}
fn sys(
_set_1: QuerySet<(QueryState<&mut A>,)>,
_set_2: QuerySet<(QueryState<&mut A>, QueryState<&B>)>,
) {
}

let mut world = World::default();
run_system(&mut world, sys);
Expand Down Expand Up @@ -520,8 +524,11 @@ mod tests {
world.insert_resource(A(42));
world.spawn().insert(B(7));

let mut system_state: SystemState<(Res<A>, Query<&B>, QuerySet<(Query<&C>, Query<&D>)>)> =
SystemState::new(&mut world);
let mut system_state: SystemState<(
Res<A>,
Query<&B>,
QuerySet<(QueryState<&C>, QueryState<&D>)>,
)> = SystemState::new(&mut world);
let (a, query, _) = system_state.get(&world);
assert_eq!(*a, A(42), "returned resource matches initial value");
assert_eq!(
Expand Down Expand Up @@ -651,11 +658,10 @@ mod tests {
println!("{} {}", a1.0, a2.0);
}
}

fn query_set(mut queries: QuerySet<(Query<&mut A>,Query<&A>)>, e: Res<Entity>) {
{

let q2 = queries.q0_mut();
fn query_set(mut queries: QuerySet<(QueryState<&mut A>, QueryState<&A>)>, e: Res<Entity>) {
{
let mut q2 = queries.q0();
let mut iter2 = q2.iter_mut();
let mut b = iter2.next().unwrap();

Expand All @@ -668,21 +674,19 @@ mod tests {
}

{

let q1 = queries.q1();
let mut iter = q1.iter();
let a = &*iter.next().unwrap();

let q2 = queries.q0_mut();
let mut q2 = queries.q0();
let mut iter2 = q2.iter_mut();
let mut b = iter2.next().unwrap();

// this should fail to compile (but currently doesn't)
b.0 = a.0;
}
{

let q2 = queries.q0_mut();
let mut q2 = queries.q0();
let mut b = q2.get_mut(*e).unwrap();

let q1 = queries.q1();
Expand All @@ -696,7 +700,7 @@ mod tests {
let q1 = queries.q1();
let a = q1.get(*e).unwrap();

let q2 = queries.q0_mut();
let mut q2 = queries.q0();
let mut b = q2.get_mut(*e).unwrap();
// this should fail to compile (but currently doesn't)
b.0 = a.0
Expand All @@ -708,7 +712,7 @@ mod tests {
run_system(&mut world, system);
}

/// this test exists to show that read-only world-only queries can return data that lives as long as 'world
/// this test exists to show that read-only world-only queries can return data that lives as long as 'world
#[test]
fn long_life_test() {
struct Holder<'w> {
Expand Down
38 changes: 22 additions & 16 deletions crates/bevy_ecs/src/system/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,17 @@ where
}
}

/// Returns an [`Iterator`] over the query results.
#[inline]
pub fn iter_mut(&mut self) -> QueryIter<'_, '_, Q, F> {
// SAFE: system runs without conflicts with other systems.
// same-system queries have runtime borrow checks when they conflict
unsafe {
self.state
.iter_unchecked_manual(self.world, self.last_change_tick, self.change_tick)
}
}

/// Returns an [`Iterator`] over all possible combinations of `K` query results without repetition.
/// This can only be called for read-only queries
///
Expand All @@ -181,17 +192,6 @@ where
}
}

/// Returns an [`Iterator`] over the query results.
#[inline]
pub fn iter_mut(&mut self) -> QueryIter<'_, '_, Q, F> {
// SAFE: system runs without conflicts with other systems.
// same-system queries have runtime borrow checks when they conflict
unsafe {
self.state
.iter_unchecked_manual(self.world, self.last_change_tick, self.change_tick)
}
}

/// Iterates over all possible combinations of `K` query results without repetition.
///
/// The returned value is not an `Iterator`, because that would lead to aliasing of mutable references.
Expand Down Expand Up @@ -285,7 +285,7 @@ where
/// Runs `f` on each query result. This is faster than the equivalent iter() method, but cannot
/// be chained like a normal [`Iterator`].
#[inline]
pub fn for_each_mut(&'s mut self, f: impl FnMut(<Q::Fetch as Fetch<'w, 's>>::Item)) {
pub fn for_each_mut(&mut self, f: impl FnMut(<Q::Fetch as Fetch<'_, '_>>::Item)) {
// SAFE: system runs without conflicts with other systems. same-system queries have runtime
// borrow checks when they conflict
unsafe {
Expand Down Expand Up @@ -328,10 +328,10 @@ where
/// Runs `f` on each query result in parallel using the given task pool.
#[inline]
pub fn par_for_each_mut(
&'s mut self,
&mut self,
task_pool: &TaskPool,
batch_size: usize,
f: impl Fn(<Q::Fetch as Fetch<'w, 's>>::Item) + Send + Sync + Clone,
f: impl Fn(<Q::Fetch as Fetch<'_, '_>>::Item) + Send + Sync + Clone,
) {
// SAFE: system runs without conflicts with other systems. same-system queries have runtime
// borrow checks when they conflict
Expand All @@ -351,7 +351,10 @@ where
///
/// This can only be called for read-only queries, see [`Self::get_mut`] for write-queries.
#[inline]
pub fn get(&'s self, entity: Entity) -> Result<<Q::Fetch as Fetch<'w, 's>>::Item, QueryEntityError>
pub fn get(
&'s self,
entity: Entity,
) -> Result<<Q::Fetch as Fetch<'w, 's>>::Item, QueryEntityError>
where
Q::Fetch: ReadOnlyFetch,
{
Expand Down Expand Up @@ -406,7 +409,10 @@ where
/// entity does not have the given component type or if the given component type does not match
/// this query.
#[inline]
pub fn get_component<T: Component>(&self, entity: Entity) -> Result<&T, QueryComponentError> {
pub fn get_component<T: Component>(
&self,
entity: Entity,
) -> Result<&'w T, QueryComponentError> {
let world = self.world;
let entity_ref = world
.get_entity(entity)
Expand Down
13 changes: 10 additions & 3 deletions crates/bevy_ecs/src/system/system_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ where
fn default_config() {}
}

impl<'w, 's, Q: WorldQuery + 'static, F: WorldQuery + 'static> SystemParamFetch<'w, 's> for QueryState<Q, F>
impl<'w, 's, Q: WorldQuery + 'static, F: WorldQuery + 'static> SystemParamFetch<'w, 's>
for QueryState<Q, F>
where
F::Fetch: FilterFetch,
{
Expand Down Expand Up @@ -184,7 +185,13 @@ fn assert_component_access_compatibility(
query_type, filter_type, system_name, accesses);
}

pub struct QuerySet<T>(T);
pub struct QuerySet<'w, 's, T> {
query_states: &'s T,
world: &'w World,
last_change_tick: u32,
change_tick: u32,
}

pub struct QuerySetState<T>(T);

impl_query_set!();
Expand Down Expand Up @@ -228,7 +235,7 @@ impl<'w, T: Component> Res<'w, T> {
self.ticks
.is_changed(self.last_change_tick, self.change_tick)
}

pub fn into_inner(self) -> &'w T {
self.value
}
Expand Down
9 changes: 5 additions & 4 deletions crates/bevy_render/src/camera/camera.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ use bevy_ecs::{
component::Component,
entity::Entity,
event::EventReader,
prelude::QueryState,
query::Added,
reflect::ReflectComponent,
system::{Query, QuerySet, Res},
system::{QuerySet, Res},
};
use bevy_math::{Mat4, Vec2, Vec3};
use bevy_reflect::{Reflect, ReflectDeserialize};
Expand Down Expand Up @@ -70,8 +71,8 @@ pub fn camera_system<T: CameraProjection + Component>(
mut window_created_events: EventReader<WindowCreated>,
windows: Res<Windows>,
mut queries: QuerySet<(
Query<(Entity, &mut Camera, &mut T)>,
Query<Entity, Added<Camera>>,
QueryState<(Entity, &mut Camera, &mut T)>,
QueryState<Entity, Added<Camera>>,
)>,
) {
let mut changed_window_ids = Vec::new();
Expand Down Expand Up @@ -99,7 +100,7 @@ pub fn camera_system<T: CameraProjection + Component>(
for entity in &mut queries.q1().iter() {
added_cameras.push(entity);
}
for (entity, mut camera, mut camera_projection) in queries.q0_mut().iter_mut() {
for (entity, mut camera, mut camera_projection) in queries.q0().iter_mut() {
if let Some(window) = windows.get(camera.window) {
if changed_window_ids.contains(&window.id())
|| added_cameras.contains(&entity)
Expand Down
Loading

0 comments on commit 110b13b

Please sign in to comment.