Skip to content

Commit

Permalink
track struct field accesses and unroll them to upgrade the correct field
Browse files Browse the repository at this point in the history
  • Loading branch information
schell committed Dec 10, 2024
1 parent 50c20fc commit a2baae9
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 61 deletions.
90 changes: 47 additions & 43 deletions naga/src/front/atomic_upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
//!
//! Future work:
//!
//! - Atomics in structs are not implemented yet.
//!
//! - The GLSL front end could use this transformation as well.
//!
//! [`Atomic`]: TypeInner::Atomic
Expand All @@ -32,16 +30,19 @@
//! [`Struct`]: TypeInner::Struct
//! [`Load`]: crate::Expression::Load
//! [`Store`]: crate::Statement::Store
use std::sync::{atomic::AtomicUsize, Arc};
use std::{
collections::VecDeque,
sync::{atomic::AtomicUsize, Arc},
};

use crate::{GlobalVariable, Handle, Module, Type, TypeInner};

#[derive(Clone, Debug, thiserror::Error)]
pub enum Error {
#[error("encountered an unsupported expression")]
Unsupported,
#[error("upgrading structs of more than one member is not yet implemented")]
MultiMemberStruct,
#[error("unexpected end of struct field access indices")]
UnexpectedEndOfIndices,
#[error("encountered unsupported global initializer in an atomic variable")]
GlobalInitUnsupported,
#[error("expected to find a global variable")]
Expand Down Expand Up @@ -87,6 +88,12 @@ impl Padding {
}
}

#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)]
pub struct ContainedGlobalVariable {
pub field_indices: VecDeque<u32>,
pub handle: Handle<GlobalVariable>,
}

struct UpgradeState<'a> {
padding: Padding,
module: &'a mut Module,
Expand All @@ -99,55 +106,48 @@ impl UpgradeState<'_> {

/// Upgrade the type, recursing until we reach the leaves.
/// At the leaves, replace scalars with atomic scalars.
fn upgrade_type(&mut self, ty: Handle<Type>) -> Result<Handle<Type>, Error> {
fn upgrade_type(
&mut self,
ty: Handle<Type>,
field_indices: &mut VecDeque<u32>,
) -> Result<Handle<Type>, Error> {
let padding = self.inc_padding();
padding.trace("upgrading type: ", ty);
padding.trace("visiting type: ", ty);

let inner = match self.module.types[ty].inner {
TypeInner::Scalar(scalar) => {
log::trace!("{padding}hit the scalar leaf, replacing with an atomic");
TypeInner::Atomic(scalar)
}
TypeInner::Pointer { base, space } => TypeInner::Pointer {
base: self.upgrade_type(base)?,
base: self.upgrade_type(base, field_indices)?,
space,
},
TypeInner::Array { base, size, stride } => TypeInner::Array {
base: self.upgrade_type(base)?,
base: self.upgrade_type(base, field_indices)?,
size,
stride,
},
TypeInner::Struct { ref members, span } => {
// In the future we should have to figure out which member needs
// upgrading, but for now we'll only cover the single-member
// case.
let &[crate::StructMember {
ref name,
ty,
ref binding,
offset,
}] = &members[..]
else {
return Err(Error::MultiMemberStruct);
};

// Take our own clones of these values now, so that
// `upgrade_type` can mutate the module.
let name = name.clone();
let binding = binding.clone();
let upgraded_member_type = self.upgrade_type(ty)?;
let index = field_indices
.pop_back()
.ok_or(Error::UnexpectedEndOfIndices)?;

let mut new_members = vec![];
for (i, mut member) in members.clone().into_iter().enumerate() {
if i == index as usize {
member.ty = self.upgrade_type(member.ty, field_indices)?;
}
new_members.push(member);
}

TypeInner::Struct {
members: vec![crate::StructMember {
name,
ty: upgraded_member_type,
binding,
offset,
}],
members: new_members,
span,
}
}
TypeInner::BindingArray { base, size } => TypeInner::BindingArray {
base: self.upgrade_type(base)?,
base: self.upgrade_type(base, field_indices)?,
size,
},
_ => return Ok(ty),
Expand All @@ -168,23 +168,27 @@ impl UpgradeState<'_> {
Ok(new_handle)
}

fn upgrade_global_variable(&mut self, handle: Handle<GlobalVariable>) -> Result<(), Error> {
fn upgrade_global_variable(
&mut self,
mut global: ContainedGlobalVariable,
) -> Result<(), Error> {
let padding = self.inc_padding();
padding.trace("upgrading global variable: ", handle);
padding.trace("visiting global variable: ", &global);

let var = &self.module.global_variables[handle];
let var = &self.module.global_variables[global.handle];
padding.trace("var: ", var);

if var.init.is_some() {
return Err(Error::GlobalInitUnsupported);
}

let var_ty = var.ty;
let new_ty = self.upgrade_type(var.ty)?;
let new_ty = self.upgrade_type(var.ty, &mut global.field_indices)?;
if new_ty != var_ty {
padding.debug("upgrading global variable: ", handle);
padding.debug("upgrading global variable: ", global.handle);
padding.debug("from ty: ", var_ty);
padding.debug("to ty: ", new_ty);
self.module.global_variables[handle].ty = new_ty;
self.module.global_variables[global.handle].ty = new_ty;
}
Ok(())
}
Expand All @@ -196,15 +200,15 @@ impl Module {
/// [`Atomic`]: TypeInner::Atomic
pub(crate) fn upgrade_atomics(
&mut self,
global_var_handles: impl IntoIterator<Item = Handle<GlobalVariable>>,
globals: impl IntoIterator<Item = ContainedGlobalVariable>,
) -> Result<(), Error> {
let mut state = UpgradeState {
padding: Default::default(),
module: self,
};

for handle in global_var_handles {
state.upgrade_global_variable(handle)?;
for global in globals {
state.upgrade_global_variable(global)?;
}

Ok(())
Expand Down
23 changes: 15 additions & 8 deletions naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ use crate::{
};

use petgraph::graphmap::GraphMap;
use std::{convert::TryInto, mem, num::NonZeroU32, path::PathBuf};
use std::{collections::VecDeque, convert::TryInto, mem, num::NonZeroU32, path::PathBuf};

use super::atomic_upgrade::ContainedGlobalVariable;

pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[
spirv::Capability::Shader,
Expand Down Expand Up @@ -569,21 +571,26 @@ impl BlockContext<'_> {
fn get_contained_global_variable(
&self,
mut handle: Handle<crate::Expression>,
) -> Result<Handle<crate::GlobalVariable>, Error> {
) -> Result<ContainedGlobalVariable, Error> {
log::debug!("\t\tlocating global variable in {handle:?}");
let mut accesses = VecDeque::default();
loop {
match self.expressions[handle] {
crate::Expression::Access { base, index: _ } => {
crate::Expression::Access { base, index } => {
handle = base;
log::debug!("\t\t access {handle:?}");
log::debug!("\t\t access {handle:?} {index:?}");
}
crate::Expression::AccessIndex { base, index: _ } => {
crate::Expression::AccessIndex { base, index } => {
handle = base;
log::debug!("\t\t access index {handle:?}");
accesses.push_back(index);
log::debug!("\t\t access index {handle:?} {index:?}");
}
crate::Expression::GlobalVariable(h) => {
log::debug!("\t\t found {h:?}");
return Ok(h);
return Ok(ContainedGlobalVariable {
field_indices: accesses,
handle: h,
});
}
_ => {
break;
Expand Down Expand Up @@ -616,7 +623,7 @@ pub struct Frontend<I> {
/// generated, so we can upgrade the types of their operands.
///
/// [`Atomic`]: crate::Statement::Atomic
upgrade_atomics: IndexSet<Handle<crate::GlobalVariable>>,
upgrade_atomics: IndexSet<ContainedGlobalVariable>,

lookup_type: FastHashMap<spirv::Word, LookupType>,
lookup_void_type: Option<spirv::Word>,
Expand Down
Binary file modified naga/tests/in/spv/atomic_global_struct_field_vertex.spv
Binary file not shown.
20 changes: 10 additions & 10 deletions naga/tests/in/spv/atomic_global_struct_field_vertex.spvasm
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
OpMemoryModel Logical Vulkan
OpEntryPoint Vertex %1 "global_field_vertex" %2 %gl_Position
OpMemberDecorate %_struct_9 0 Offset 0
OpMemberDecorate %_struct_9 1 Offset 4
OpMemberDecorate %_struct_9 2 Offset 8
OpMemberDecorate %_struct_9 1 Offset 8
OpMemberDecorate %_struct_9 2 Offset 16
OpDecorate %_struct_10 Block
OpMemberDecorate %_struct_10 0 Offset 0
OpDecorate %2 Binding 0
Expand All @@ -18,7 +18,7 @@
%uint = OpTypeInt 32 0
%float = OpTypeFloat 32
%v2float = OpTypeVector %float 2
%_struct_9 = OpTypeStruct %uint %uint %v2float
%_struct_9 = OpTypeStruct %uint %v2float %uint
%_struct_10 = OpTypeStruct %_struct_9
%_ptr_StorageBuffer__struct_10 = OpTypePointer StorageBuffer %_struct_10
%v4float = OpTypeVector %float 4
Expand All @@ -28,20 +28,20 @@
%2 = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer
%uint_0 = OpConstant %uint 0
%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%uint_5 = OpConstant %uint 5
%_ptr_StorageBuffer_v2float = OpTypePointer StorageBuffer %v2float
%uint_2 = OpConstant %uint 2
%uint_1 = OpConstant %uint 1
%float_0 = OpConstant %float 0
%gl_Position = OpVariable %_ptr_Output_v4float Output
%1 = OpFunction %void None %18
%28 = OpLabel
%30 = OpInBoundsAccessChain %_ptr_StorageBuffer_uint %2 %uint_0 %uint_1
%31 = OpLoad %uint %30
%32 = OpInBoundsAccessChain %_ptr_StorageBuffer_uint %2 %uint_0 %uint_0
%33 = OpAtomicIAdd %uint %32 %uint_5 %uint_0 %31
%30 = OpInBoundsAccessChain %_ptr_StorageBuffer_uint %2 %uint_0 %uint_2
%31 = OpInBoundsAccessChain %_ptr_StorageBuffer_uint %2 %uint_0 %uint_0
%32 = OpLoad %uint %31
%33 = OpAtomicIAdd %uint %30 %uint_5 %uint_0 %32
%34 = OpConvertUToF %float %33
%35 = OpInBoundsAccessChain %_ptr_StorageBuffer_v2float %2 %uint_0 %uint_2
%35 = OpInBoundsAccessChain %_ptr_StorageBuffer_v2float %2 %uint_0 %uint_1
%36 = OpLoad %v2float %35
%37 = OpCompositeExtract %float %36 0
%38 = OpCompositeExtract %float %36 1
Expand Down

0 comments on commit a2baae9

Please sign in to comment.