Skip to content

Commit

Permalink
spv-out: implement OpArrayLength on array buffer bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Jun 6, 2023
1 parent 907b7c7 commit 9fdfb1a
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 173 deletions.
69 changes: 59 additions & 10 deletions src/back/spv/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,27 +47,76 @@ impl<'w> BlockContext<'w> {
// only the latter, so this back end wraps bare runtime-sized arrays
// in a made-up struct; see `helpers::global_needs_wrapper` and its uses.
// This code must handle both cases.
let (structure_id, last_member_index) = match self.ir_function.expressions[array] {
let (opt_array_index, global_handle, opt_last_member_index) = match self
.ir_function
.expressions[array]
{
// Note that SPIR-V forbids `OpArrayLength` on a variable pointer,
// so we aren't handling `crate::Expression::Access` here.
crate::Expression::AccessIndex { base, index } => {
match self.ir_function.expressions[base] {
crate::Expression::GlobalVariable(handle) => (
self.writer.global_variables[handle.index()].access_id,
index,
),
_ => return Err(Error::Validation("array length expression")),
// The global variable is an array of buffer bindings of structs,
// and we are accessing the last member.
crate::Expression::AccessIndex {
base: base_outer,
index: index_outer,
} => match self.ir_function.expressions[base_outer] {
crate::Expression::GlobalVariable(handle) => {
(Some(index_outer), handle, Some(index))
}
_ => return Err(Error::Validation("array length expression case-1a")),
},
crate::Expression::GlobalVariable(handle) => {
let global = &self.ir_module.global_variables[handle];
match self.ir_module.types[global.ty].inner {
// The global variable is an array of buffer bindings of run-time arrays.
crate::TypeInner::BindingArray { .. } => (Some(index), handle, None),
// The global variable is a struct, and we are accessing the last member
_ => (None, handle, Some(index)),
}
}
_ => return Err(Error::Validation("array length expression case-1c")),
}
}
// The global variable is a run-time array.
crate::Expression::GlobalVariable(handle) => {
let global = &self.ir_module.global_variables[handle];
if !global_needs_wrapper(self.ir_module, global) {
return Err(Error::Validation("array length expression"));
return Err(Error::Validation("array length expression case-2"));
}

(self.writer.global_variables[handle.index()].var_id, 0)
(None, handle, None)
}
_ => return Err(Error::Validation("array length expression")),
_ => return Err(Error::Validation("array length expression case-3")),
};

let gvar = self.writer.global_variables[global_handle.index()].clone();
let (last_member_index, gvar_id) = match opt_last_member_index {
Some(index) => (index, gvar.access_id),
None => {
let global = &self.ir_module.global_variables[global_handle];
if !global_needs_wrapper(self.ir_module, global) {
return Err(Error::Validation(
"pointer to a global that is not a wrapped array",
));
}
(0, gvar.var_id)
}
};
let structure_id = match opt_array_index {
// We are indexing inside a binding array, generate the access op.
Some(index) => {
let index_id = self.get_index_constant(index);
let structure_id = self.gen_id();
block.body.push(Instruction::access_chain(
gvar.pointer_to_binding_array_element_type_id,
structure_id,
gvar_id,
&[index_id],
));
structure_id
}
None => gvar_id,
};
let length_id = self.gen_id();
block.body.push(Instruction::array_length(
self.writer.get_uint_type_id(),
Expand Down
7 changes: 7 additions & 0 deletions src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,11 @@ struct GlobalVariable {
/// to refer to the global in the function body. This is the id of that access,
/// updated for each function in `write_function`.
access_id: Word,

/// The ID of a type that is the base for the array of bindings.
/// This is needed because the block encoder isn't expected to generate new types,
/// and there are cases where we need to get a pointer to one of those bases.
pointer_to_binding_array_element_type_id: Word,
}

impl GlobalVariable {
Expand All @@ -502,6 +507,7 @@ impl GlobalVariable {
var_id: 0,
handle_id: 0,
access_id: 0,
pointer_to_binding_array_element_type_id: 0,
}
}

Expand All @@ -510,6 +516,7 @@ impl GlobalVariable {
var_id: id,
handle_id: 0,
access_id: 0,
pointer_to_binding_array_element_type_id: 0,
}
}

Expand Down
67 changes: 37 additions & 30 deletions src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -597,36 +597,43 @@ impl Writer {
// Handle globals are pre-emitted and should be loaded automatically.
//
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
let is_binding_array = match ir_module.types[var.ty].inner {
crate::TypeInner::BindingArray { .. } => true,
_ => false,
};

if var.space == crate::AddressSpace::Handle && !is_binding_array {
let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
let id = self.id_gen.next();
prelude
.body
.push(Instruction::load(var_type_id, id, gv.var_id, None));
gv.access_id = gv.var_id;
gv.handle_id = id;
} else if global_needs_wrapper(ir_module, var) {
let class = map_storage_class(var.space);
let pointer_type_id = self.get_pointer_id(&ir_module.types, var.ty, class)?;
let index_id = self.get_index_constant(0);

let id = self.id_gen.next();
prelude.body.push(Instruction::access_chain(
pointer_type_id,
id,
gv.var_id,
&[index_id],
));
gv.access_id = id;
} else {
// by default, the variable ID is accessed as is
gv.access_id = gv.var_id;
};
match ir_module.types[var.ty].inner {
crate::TypeInner::BindingArray { base, size: _ } => {
gv.access_id = gv.var_id;
if var.space != crate::AddressSpace::Handle {
let class = map_storage_class(var.space);
gv.pointer_to_binding_array_element_type_id =
self.get_pointer_id(&ir_module.types, base, class)?;
}
}
_ => {
if var.space == crate::AddressSpace::Handle {
let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
let id = self.id_gen.next();
prelude
.body
.push(Instruction::load(var_type_id, id, gv.var_id, None));
gv.access_id = gv.var_id;
gv.handle_id = id;
} else if global_needs_wrapper(ir_module, var) {
let class = map_storage_class(var.space);
let pointer_type_id =
self.get_pointer_id(&ir_module.types, var.ty, class)?;
let index_id = self.get_index_constant(0);
let id = self.id_gen.next();
prelude.body.push(Instruction::access_chain(
pointer_type_id,
id,
gv.var_id,
&[index_id],
));
gv.access_id = id;
} else {
// by default, the variable ID is accessed as is
gv.access_id = gv.var_id;
};
}
}

// work around borrow checking in the presence of `self.xxx()` calls
self.global_variables[handle.index()] = gv;
Expand Down
4 changes: 3 additions & 1 deletion tests/in/binding-buffer-arrays.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ struct UniformIndex {
index: u32
}

struct Foo { x: u32 }
struct Foo { x: u32, far: array<i32> }
@group(0) @binding(0)
var<storage, read> storage_array: binding_array<Foo, 1>;
@group(0) @binding(10)
Expand All @@ -23,5 +23,7 @@ fn main(fragment_in: FragmentIn) -> @location(0) u32 {
u1 += storage_array[uniform_index].x;
u1 += storage_array[non_uniform_index].x;

u1 += arrayLength(&storage_array[0].far);

return u1;
}
2 changes: 1 addition & 1 deletion tests/out/msl/workgroup-uniform-load.msl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using metal::uint;

constexpr constant unsigned SIZE = 128u;
struct type_2 {
int inner[SIZE];
int inner[128];
};

struct test_workgroupUniformLoadInput {
Expand Down
Loading

0 comments on commit 9fdfb1a

Please sign in to comment.