diff --git a/naga/src/arena/handle_set.rs b/naga/src/arena/handle_set.rs index ef2ded2ddb1..f2ce058d12f 100644 --- a/naga/src/arena/handle_set.rs +++ b/naga/src/arena/handle_set.rs @@ -25,6 +25,10 @@ impl HandleSet { } } + pub fn is_empty(&self) -> bool { + self.members.is_empty() + } + /// Return a new, empty `HandleSet`, sized to hold handles from `arena`. pub fn for_arena(arena: &impl ArenaType) -> Self { let len = arena.len(); diff --git a/naga/src/front/atomic_upgrade.rs b/naga/src/front/atomic_upgrade.rs index 2053702d454..529883ae410 100644 --- a/naga/src/front/atomic_upgrade.rs +++ b/naga/src/front/atomic_upgrade.rs @@ -85,15 +85,50 @@ impl Padding { } } -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)] -pub struct ContainedGlobalVariable { - pub field_indices: Vec, - pub handle: Handle, +#[derive(Debug, Default)] +pub struct Upgrades { + /// Global variables that we've accessed using atomic operations. + /// + /// This includes globals with composite types (arrays, structs) where we've + /// only accessed some components (elements, fields) atomically. + globals: crate::arena::HandleSet, + + /// Struct fields that we've accessed using atomic operations. + /// + /// Each key refers to some [`Struct`] type, and each value is a set of + /// the indices of the fields in that struct that have been accessed + /// atomically. + /// + /// This includes fields with composite types (arrays, structs) + /// of which we've only accessed some components (elements, fields) + /// atomically. + /// + /// [`Struct`]: crate::TypeInner::Struct + fields: crate::FastHashMap, bit_set::BitSet>, +} + +impl Upgrades { + pub fn insert_global(&mut self, global: Handle) { + self.globals.insert(global); + } + + pub fn insert_field(&mut self, struct_type: Handle, field: usize) { + self.fields.entry(struct_type).or_default().insert(field); + } + + pub fn is_empty(&self) -> bool { + self.globals.is_empty() + } } struct UpgradeState<'a> { padding: Padding, module: &'a mut Module, + + /// A map from old types to their upgraded versions. + /// + /// This ensures we never try to rebuild a type more than once. + upgraded_types: crate::FastHashMap, Handle>, } impl UpgradeState<'_> { @@ -123,29 +158,42 @@ impl UpgradeState<'_> { fn upgrade_type( &mut self, ty: Handle, - field_indices: &mut Vec, + upgrades: &Upgrades, ) -> Result, Error> { let padding = self.inc_padding(); padding.trace("visiting type: ", ty); + // If we've already upgraded this type, return the handle we produced at + // the time. + if let Some(&new) = self.upgraded_types.get(&ty) { + return Ok(new); + } + let inner = match self.module.types[ty].inner { TypeInner::Scalar(scalar) => { log::trace!("{padding}hit the scalar leaf, replacing with an atomic"); TypeInner::Atomic(scalar) } TypeInner::Pointer { base, space } => TypeInner::Pointer { - base: self.upgrade_type(base, field_indices)?, + base: self.upgrade_type(base, upgrades)?, space, }, TypeInner::Array { base, size, stride } => TypeInner::Array { - base: self.upgrade_type(base, field_indices)?, + base: self.upgrade_type(base, upgrades)?, size, stride, }, TypeInner::Struct { ref members, span } => { - let index = field_indices.pop().ok_or(Error::UnexpectedEndOfIndices)? as usize; + // If no field or subfield of this struct was ever accessed + // atomically, no change is needed. We should never have arrived here. + let Some(fields) = upgrades.fields.get(&ty) else { + unreachable!("global or field incorrectly flagged as atomically accessed"); + }; + let mut new_members = members.clone(); - new_members[index].ty = self.upgrade_type(new_members[index].ty, field_indices)?; + for field in fields { + new_members[field].ty = self.upgrade_type(new_members[field].ty, upgrades)?; + } TypeInner::Struct { members: new_members, @@ -153,7 +201,7 @@ impl UpgradeState<'_> { } } TypeInner::BindingArray { base, size } => TypeInner::BindingArray { - base: self.upgrade_type(base, field_indices)?, + base: self.upgrade_type(base, upgrades)?, size, }, _ => return Ok(ty), @@ -172,31 +220,32 @@ impl UpgradeState<'_> { padding.debug("from: ", r#type); padding.debug("to: ", &new_type); let new_handle = self.module.types.insert(new_type, span); + self.upgraded_types.insert(ty, new_handle); Ok(new_handle) } - fn upgrade_global_variable( - &mut self, - mut global: ContainedGlobalVariable, - ) -> Result<(), Error> { - let padding = self.inc_padding(); - padding.trace("visiting global variable: ", &global); + fn upgrade_all(&mut self, upgrades: &Upgrades) -> Result<(), Error> { + for handle in upgrades.globals.iter() { + let padding = self.inc_padding(); - let var = &self.module.global_variables[global.handle]; - padding.trace("var: ", var); + let global = &self.module.global_variables[handle]; + padding.trace("visiting global variable: ", handle); + padding.trace("var: ", global); - if var.init.is_some() { - return Err(Error::GlobalInitUnsupported); - } + if global.init.is_some() { + return Err(Error::GlobalInitUnsupported); + } - let var_ty = var.ty; - let new_ty = self.upgrade_type(var.ty, &mut global.field_indices)?; - if new_ty != var_ty { - padding.debug("upgrading global variable: ", global.handle); - padding.debug("from ty: ", var_ty); - padding.debug("to ty: ", new_ty); - self.module.global_variables[global.handle].ty = new_ty; + let var_ty = global.ty; + let new_ty = self.upgrade_type(var_ty, upgrades)?; + if new_ty != var_ty { + padding.debug("upgrading global variable: ", handle); + padding.debug("from ty: ", var_ty); + padding.debug("to ty: ", new_ty); + self.module.global_variables[handle].ty = new_ty; + } } + Ok(()) } } @@ -205,18 +254,17 @@ impl Module { /// Upgrade `global_var_handles` to have [`Atomic`] leaf types. /// /// [`Atomic`]: TypeInner::Atomic - pub(crate) fn upgrade_atomics( - &mut self, - globals: impl IntoIterator, - ) -> Result<(), Error> { + pub(crate) fn upgrade_atomics(&mut self, upgrades: &Upgrades) -> Result<(), Error> { let mut state = UpgradeState { padding: Default::default(), module: self, + upgraded_types: crate::FastHashMap::with_capacity_and_hasher( + upgrades.fields.len(), + Default::default(), + ), }; - for global in globals { - state.upgrade_global_variable(global)?; - } + state.upgrade_all(upgrades)?; Ok(()) } diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 25010912ebb..2dd5a53e9e6 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -36,7 +36,6 @@ mod null; use convert::*; pub use error::Error; use function::*; -use indexmap::IndexSet; use crate::{ arena::{Arena, Handle, UniqueArena}, @@ -47,7 +46,7 @@ use crate::{ use petgraph::graphmap::GraphMap; use std::{convert::TryInto, mem, num::NonZeroU32, path::PathBuf}; -use super::atomic_upgrade::ContainedGlobalVariable; +use super::atomic_upgrade::Upgrades; pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[ spirv::Capability::Shader, @@ -559,50 +558,6 @@ struct BlockContext<'function> { parameter_sampling: &'function mut [image::SamplingFlags], } -impl BlockContext<'_> { - /// Descend into the expression with the given handle, locating a contained - /// global variable. - /// - /// If the expression doesn't actually refer to something in a global - /// variable, we can't upgrade its type in a way that Naga validation would - /// pass, so reject the input instead. - /// - /// This is used to track atomic upgrades. - fn get_contained_global_variable( - &self, - mut handle: Handle, - ) -> Result { - log::debug!("\t\tlocating global variable in {handle:?}"); - let mut accesses = vec![]; - loop { - match self.expressions[handle] { - crate::Expression::Access { base, index } => { - handle = base; - log::debug!("\t\t access {handle:?} {index:?}"); - } - crate::Expression::AccessIndex { base, index } => { - handle = base; - accesses.push(index); - log::debug!("\t\t access index {handle:?} {index:?}"); - } - crate::Expression::GlobalVariable(h) => { - log::debug!("\t\t found {h:?}"); - return Ok(ContainedGlobalVariable { - field_indices: accesses, - handle: h, - }); - } - _ => { - break; - } - } - } - Err(Error::AtomicUpgradeError( - crate::front::atomic_upgrade::Error::GlobalVariableMissing, - )) - } -} - enum SignAnchor { Result, Operand, @@ -619,11 +574,12 @@ pub struct Frontend { future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>, lookup_member: FastHashMap<(Handle, MemberIndex), LookupMember>, handle_sampling: FastHashMap, image::SamplingFlags>, - /// The set of all global variables accessed by [`Atomic`] statements we've + + /// A record of what is accessed by [`Atomic`] statements we've /// generated, so we can upgrade the types of their operands. /// /// [`Atomic`]: crate::Statement::Atomic - upgrade_atomics: IndexSet, + upgrade_atomics: Upgrades, lookup_type: FastHashMap, lookup_void_type: Option, @@ -1485,8 +1441,7 @@ impl> Frontend { block.push(stmt, span); // Store any associated global variables so we can upgrade their types later - self.upgrade_atomics - .insert(ctx.get_contained_global_variable(p_lexp_handle)?); + self.record_atomic_access(ctx, p_lexp_handle)?; Ok(()) } @@ -4182,8 +4137,7 @@ impl> Frontend { ); // Store any associated global variables so we can upgrade their types later - self.upgrade_atomics - .insert(ctx.get_contained_global_variable(p_lexp_handle)?); + self.record_atomic_access(ctx, p_lexp_handle)?; } Op::AtomicStore => { inst.expect(5)?; @@ -4212,8 +4166,7 @@ impl> Frontend { emitter.start(ctx.expressions); // Store any associated global variables so we can upgrade their types later - self.upgrade_atomics - .insert(ctx.get_contained_global_variable(p_lexp_handle)?); + self.record_atomic_access(ctx, p_lexp_handle)?; } Op::AtomicIIncrement | Op::AtomicIDecrement => { inst.expect(6)?; @@ -4277,8 +4230,7 @@ impl> Frontend { block.push(stmt, span); // Store any associated global variables so we can upgrade their types later - self.upgrade_atomics - .insert(ctx.get_contained_global_variable(p_exp_h)?); + self.record_atomic_access(ctx, p_exp_h)?; } Op::AtomicCompareExchange => { inst.expect(9)?; @@ -4373,8 +4325,7 @@ impl> Frontend { block.push(stmt, span); // Store any associated global variables so we can upgrade their types later - self.upgrade_atomics - .insert(ctx.get_contained_global_variable(p_exp_h)?); + self.record_atomic_access(ctx, p_exp_h)?; } Op::AtomicExchange | Op::AtomicIAdd @@ -4689,7 +4640,7 @@ impl> Frontend { if !self.upgrade_atomics.is_empty() { log::info!("Upgrading atomic pointers..."); - module.upgrade_atomics(mem::take(&mut self.upgrade_atomics))?; + module.upgrade_atomics(&self.upgrade_atomics)?; } // Do entry point specific processing after all functions are parsed so that we can @@ -5984,6 +5935,59 @@ impl> Frontend { ); Ok(()) } + + /// Record an atomic access to some component of a global variable. + /// + /// Given `handle`, an expression referring to a scalar that has had an + /// atomic operation applied to it, descend into the expression, noting + /// which global variable it ultimately refers to, and which struct fields + /// of that global's value it accesses. + /// + /// Return the handle of the type of the expression. + /// + /// If the expression doesn't actually refer to something in a global + /// variable, we can't upgrade its type in a way that Naga validation would + /// pass, so reject the input instead. + fn record_atomic_access( + &mut self, + ctx: &BlockContext, + handle: Handle, + ) -> Result, Error> { + log::debug!("\t\tlocating global variable in {handle:?}"); + match ctx.expressions[handle] { + crate::Expression::Access { base, index } => { + log::debug!("\t\t access {handle:?} {index:?}"); + let ty = self.record_atomic_access(ctx, base)?; + let crate::TypeInner::Array { base, .. } = ctx.module.types[ty].inner else { + unreachable!("Atomic operations on Access expressions only work for arrays"); + }; + Ok(base) + } + crate::Expression::AccessIndex { base, index } => { + log::debug!("\t\t access index {handle:?} {index:?}"); + let ty = self.record_atomic_access(ctx, base)?; + match ctx.module.types[ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + let index = index as usize; + self.upgrade_atomics.insert_field(ty, index); + Ok(members[index].ty) + } + crate::TypeInner::Array { base, .. } => { + Ok(base) + } + _ => unreachable!("Atomic operations on AccessIndex expressions only work for structs and arrays"), + } + } + crate::Expression::GlobalVariable(h) => { + log::debug!("\t\t found {h:?}"); + self.upgrade_atomics.insert_global(h); + Ok(ctx.module.global_variables[h].ty) + } + _ => Err(Error::AtomicUpgradeError( + crate::front::atomic_upgrade::Error::GlobalVariableMissing, + )), + } + } } fn make_index_literal(