From 1fb9617113782af9fb6e3f8a9108dbb629084e8e Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Mon, 5 Jun 2023 23:00:59 -0700 Subject: [PATCH] spv-out: implement OpArrayLength on array buffer bindings --- src/back/spv/index.rs | 79 +++++++-- src/back/spv/mod.rs | 7 + src/back/spv/writer.rs | 67 ++++---- tests/in/binding-buffer-arrays.wgsl | 4 +- tests/out/spv/binding-buffer-arrays.spvasm | 179 +++++++++++---------- tests/out/wgsl/binding-buffer-arrays.wgsl | 7 +- 6 files changed, 210 insertions(+), 133 deletions(-) diff --git a/src/back/spv/index.rs b/src/back/spv/index.rs index d2cbdf4d6d..0b4335a1bd 100644 --- a/src/back/spv/index.rs +++ b/src/back/spv/index.rs @@ -42,32 +42,81 @@ impl<'w> BlockContext<'w> { array: Handle, block: &mut Block, ) -> Result { - // Naga IR permits runtime-sized arrays as global variables or as the - // final member of a struct that is a global variable. SPIR-V permits - // only the latter, so this back end wraps bare runtime-sized arrays - // in a made-up struct; see `helpers::global_needs_wrapper` and its uses. - // This code must handle both cases. - let (structure_id, last_member_index) = match self.ir_function.expressions[array] { + // Naga IR permits runtime-sized arrays as global variables, or as the + // final member of a struct that is a global variable, or one of these + // inside a buffer that is itself an element in a buffer bindings array. + // SPIR-V requires that runtime-sized arrays are wrapped in structs. + // See `helpers::global_needs_wrapper` and its uses. + let (opt_array_index, global_handle, opt_last_member_index) = match self + .ir_function + .expressions[array] + { + // Note that SPIR-V forbids `OpArrayLength` on a variable pointer, + // so we aren't handling `crate::Expression::Access` here. crate::Expression::AccessIndex { base, index } => { match self.ir_function.expressions[base] { - crate::Expression::GlobalVariable(handle) => ( - self.writer.global_variables[handle.index()].access_id, - index, - ), - _ => return Err(Error::Validation("array length expression")), + // The global variable is an array of buffer bindings of structs, + // and we are accessing the last member. + crate::Expression::AccessIndex { + base: base_outer, + index: index_outer, + } => match self.ir_function.expressions[base_outer] { + crate::Expression::GlobalVariable(handle) => { + (Some(index_outer), handle, Some(index)) + } + _ => return Err(Error::Validation("array length expression case-1a")), + }, + crate::Expression::GlobalVariable(handle) => { + let global = &self.ir_module.global_variables[handle]; + match self.ir_module.types[global.ty].inner { + // The global variable is an array of buffer bindings of run-time arrays. + crate::TypeInner::BindingArray { .. } => (Some(index), handle, None), + // The global variable is a struct, and we are accessing the last member + _ => (None, handle, Some(index)), + } + } + _ => return Err(Error::Validation("array length expression case-1c")), } } + // The global variable is a run-time array. crate::Expression::GlobalVariable(handle) => { let global = &self.ir_module.global_variables[handle]; if !global_needs_wrapper(self.ir_module, global) { - return Err(Error::Validation("array length expression")); + return Err(Error::Validation("array length expression case-2")); } - - (self.writer.global_variables[handle.index()].var_id, 0) + (None, handle, None) } - _ => return Err(Error::Validation("array length expression")), + _ => return Err(Error::Validation("array length expression case-3")), }; + let gvar = self.writer.global_variables[global_handle.index()].clone(); + let (last_member_index, gvar_id) = match opt_last_member_index { + Some(index) => (index, gvar.access_id), + None => { + let global = &self.ir_module.global_variables[global_handle]; + if !global_needs_wrapper(self.ir_module, global) { + return Err(Error::Validation( + "pointer to a global that is not a wrapped array", + )); + } + (0, gvar.var_id) + } + }; + let structure_id = match opt_array_index { + // We are indexing inside a binding array, generate the access op. + Some(index) => { + let index_id = self.get_index_constant(index); + let structure_id = self.gen_id(); + block.body.push(Instruction::access_chain( + gvar.pointer_to_binding_array_element_type_id, + structure_id, + gvar_id, + &[index_id], + )); + structure_id + } + None => gvar_id, + }; let length_id = self.gen_id(); block.body.push(Instruction::array_length( self.writer.get_uint_type_id(), diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index 1e10d2e9c6..53ffdac43f 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -495,6 +495,11 @@ struct GlobalVariable { /// to refer to the global in the function body. This is the id of that access, /// updated for each function in `write_function`. access_id: Word, + + /// The ID of a type that is the base for the array of bindings. + /// This is needed because the block encoder isn't expected to generate new types, + /// and there are cases where we need to get a pointer to one of those bases. + pointer_to_binding_array_element_type_id: Word, } impl GlobalVariable { @@ -503,6 +508,7 @@ impl GlobalVariable { var_id: 0, handle_id: 0, access_id: 0, + pointer_to_binding_array_element_type_id: 0, } } @@ -511,6 +517,7 @@ impl GlobalVariable { var_id: id, handle_id: 0, access_id: 0, + pointer_to_binding_array_element_type_id: 0, } } diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 64e27fae13..1fd2ae5416 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -597,36 +597,43 @@ impl Writer { // Handle globals are pre-emitted and should be loaded automatically. // // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing. - let is_binding_array = match ir_module.types[var.ty].inner { - crate::TypeInner::BindingArray { .. } => true, - _ => false, - }; - - if var.space == crate::AddressSpace::Handle && !is_binding_array { - let var_type_id = self.get_type_id(LookupType::Handle(var.ty)); - let id = self.id_gen.next(); - prelude - .body - .push(Instruction::load(var_type_id, id, gv.var_id, None)); - gv.access_id = gv.var_id; - gv.handle_id = id; - } else if global_needs_wrapper(ir_module, var) { - let class = map_storage_class(var.space); - let pointer_type_id = self.get_pointer_id(&ir_module.types, var.ty, class)?; - let index_id = self.get_index_constant(0); - - let id = self.id_gen.next(); - prelude.body.push(Instruction::access_chain( - pointer_type_id, - id, - gv.var_id, - &[index_id], - )); - gv.access_id = id; - } else { - // by default, the variable ID is accessed as is - gv.access_id = gv.var_id; - }; + match ir_module.types[var.ty].inner { + crate::TypeInner::BindingArray { base, size: _ } => { + gv.access_id = gv.var_id; + if var.space != crate::AddressSpace::Handle { + let class = map_storage_class(var.space); + gv.pointer_to_binding_array_element_type_id = + self.get_pointer_id(&ir_module.types, base, class)?; + } + } + _ => { + if var.space == crate::AddressSpace::Handle { + let var_type_id = self.get_type_id(LookupType::Handle(var.ty)); + let id = self.id_gen.next(); + prelude + .body + .push(Instruction::load(var_type_id, id, gv.var_id, None)); + gv.access_id = gv.var_id; + gv.handle_id = id; + } else if global_needs_wrapper(ir_module, var) { + let class = map_storage_class(var.space); + let pointer_type_id = + self.get_pointer_id(&ir_module.types, var.ty, class)?; + let index_id = self.get_index_constant(0); + let id = self.id_gen.next(); + prelude.body.push(Instruction::access_chain( + pointer_type_id, + id, + gv.var_id, + &[index_id], + )); + gv.access_id = id; + } else { + // by default, the variable ID is accessed as is + gv.access_id = gv.var_id; + }; + } + } // work around borrow checking in the presence of `self.xxx()` calls self.global_variables[handle.index()] = gv; diff --git a/tests/in/binding-buffer-arrays.wgsl b/tests/in/binding-buffer-arrays.wgsl index a76d52c200..e0acc3af48 100644 --- a/tests/in/binding-buffer-arrays.wgsl +++ b/tests/in/binding-buffer-arrays.wgsl @@ -2,7 +2,7 @@ struct UniformIndex { index: u32 } -struct Foo { x: u32 } +struct Foo { x: u32, far: array } @group(0) @binding(0) var storage_array: binding_array; @group(0) @binding(10) @@ -23,5 +23,7 @@ fn main(fragment_in: FragmentIn) -> @location(0) u32 { u1 += storage_array[uniform_index].x; u1 += storage_array[non_uniform_index].x; + u1 += arrayLength(&storage_array[0].far); + return u1; } diff --git a/tests/out/spv/binding-buffer-arrays.spvasm b/tests/out/spv/binding-buffer-arrays.spvasm index b09377b6b0..61811a494e 100644 --- a/tests/out/spv/binding-buffer-arrays.spvasm +++ b/tests/out/spv/binding-buffer-arrays.spvasm @@ -1,104 +1,113 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 65 +; Bound: 71 OpCapability Shader OpCapability ShaderNonUniform OpExtension "SPV_KHR_storage_buffer_storage_class" OpExtension "SPV_EXT_descriptor_indexing" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %28 "main" %23 %26 -OpExecutionMode %28 OriginUpperLeft +OpEntryPoint Fragment %29 "main" %24 %27 +OpExecutionMode %29 OriginUpperLeft OpMemberDecorate %6 0 Offset 0 -OpMemberDecorate %7 0 Offset 0 -OpMemberDecorate %10 0 Offset 0 -OpDecorate %11 NonWritable -OpDecorate %11 DescriptorSet 0 -OpDecorate %11 Binding 0 -OpDecorate %7 Block -OpDecorate %15 DescriptorSet 0 -OpDecorate %15 Binding 10 -OpDecorate %16 Block -OpMemberDecorate %16 0 Offset 0 -OpDecorate %23 Location 0 -OpDecorate %23 Flat -OpDecorate %26 Location 0 -OpDecorate %56 NonUniform +OpDecorate %7 ArrayStride 4 +OpMemberDecorate %8 0 Offset 0 +OpMemberDecorate %8 1 Offset 4 +OpMemberDecorate %11 0 Offset 0 +OpDecorate %12 NonWritable +OpDecorate %12 DescriptorSet 0 +OpDecorate %12 Binding 0 +OpDecorate %8 Block +OpDecorate %16 DescriptorSet 0 +OpDecorate %16 Binding 10 +OpDecorate %17 Block +OpMemberDecorate %17 0 Offset 0 +OpDecorate %24 Location 0 +OpDecorate %24 Flat +OpDecorate %27 Location 0 +OpDecorate %57 NonUniform %2 = OpTypeVoid %4 = OpTypeInt 32 1 %3 = OpConstant %4 1 %5 = OpTypeInt 32 0 %6 = OpTypeStruct %5 -%7 = OpTypeStruct %5 -%9 = OpConstant %5 1 -%8 = OpTypeArray %7 %9 -%10 = OpTypeStruct %5 -%14 = OpConstant %5 10 -%13 = OpTypeArray %7 %14 -%12 = OpTypePointer StorageBuffer %13 -%11 = OpVariable %12 StorageBuffer -%16 = OpTypeStruct %6 -%17 = OpTypePointer Uniform %16 -%15 = OpVariable %17 Uniform -%19 = OpTypePointer Function %5 -%20 = OpConstantNull %5 -%24 = OpTypePointer Input %5 -%23 = OpVariable %24 Input -%27 = OpTypePointer Output %5 -%26 = OpVariable %27 Output -%29 = OpTypeFunction %2 -%30 = OpTypePointer Uniform %6 -%31 = OpConstant %5 0 -%33 = OpTypePointer StorageBuffer %8 -%35 = OpTypePointer Uniform %5 -%39 = OpTypePointer StorageBuffer %7 -%40 = OpTypePointer StorageBuffer %5 -%46 = OpTypeBool -%48 = OpConstantNull %5 -%57 = OpConstantNull %5 -%28 = OpFunction %2 None %29 -%21 = OpLabel -%18 = OpVariable %19 Function %20 -%25 = OpLoad %5 %23 -%22 = OpCompositeConstruct %10 %25 -%32 = OpAccessChain %30 %15 %31 -OpBranch %34 -%34 = OpLabel -%36 = OpAccessChain %35 %32 %31 -%37 = OpLoad %5 %36 -%38 = OpCompositeExtract %5 %22 0 -OpStore %18 %31 -%41 = OpAccessChain %40 %11 %31 %31 -%42 = OpLoad %5 %41 -%43 = OpLoad %5 %18 -%44 = OpIAdd %5 %43 %42 -OpStore %18 %44 -%45 = OpULessThan %46 %37 %9 -OpSelectionMerge %49 None -OpBranchConditional %45 %50 %49 +%7 = OpTypeRuntimeArray %4 +%8 = OpTypeStruct %5 %7 +%10 = OpConstant %5 1 +%9 = OpTypeArray %8 %10 +%11 = OpTypeStruct %5 +%15 = OpConstant %5 10 +%14 = OpTypeArray %8 %15 +%13 = OpTypePointer StorageBuffer %14 +%12 = OpVariable %13 StorageBuffer +%17 = OpTypeStruct %6 +%18 = OpTypePointer Uniform %17 +%16 = OpVariable %18 Uniform +%20 = OpTypePointer Function %5 +%21 = OpConstantNull %5 +%25 = OpTypePointer Input %5 +%24 = OpVariable %25 Input +%28 = OpTypePointer Output %5 +%27 = OpVariable %28 Output +%30 = OpTypeFunction %2 +%31 = OpTypePointer StorageBuffer %8 +%32 = OpTypePointer Uniform %6 +%33 = OpConstant %5 0 +%35 = OpTypePointer StorageBuffer %9 +%37 = OpTypePointer Uniform %5 +%41 = OpTypePointer StorageBuffer %5 +%47 = OpTypeBool +%49 = OpConstantNull %5 +%58 = OpConstantNull %5 +%65 = OpTypePointer StorageBuffer %7 +%29 = OpFunction %2 None %30 +%22 = OpLabel +%19 = OpVariable %20 Function %21 +%26 = OpLoad %5 %24 +%23 = OpCompositeConstruct %11 %26 +%34 = OpAccessChain %32 %16 %33 +OpBranch %36 +%36 = OpLabel +%38 = OpAccessChain %37 %34 %33 +%39 = OpLoad %5 %38 +%40 = OpCompositeExtract %5 %23 0 +OpStore %19 %33 +%42 = OpAccessChain %41 %12 %33 %33 +%43 = OpLoad %5 %42 +%44 = OpLoad %5 %19 +%45 = OpIAdd %5 %44 %43 +OpStore %19 %45 +%46 = OpULessThan %47 %39 %10 +OpSelectionMerge %50 None +OpBranchConditional %46 %51 %50 +%51 = OpLabel +%48 = OpAccessChain %41 %12 %39 %33 +%52 = OpLoad %5 %48 +OpBranch %50 %50 = OpLabel -%47 = OpAccessChain %40 %11 %37 %31 -%51 = OpLoad %5 %47 -OpBranch %49 -%49 = OpLabel -%52 = OpPhi %5 %48 %34 %51 %50 -%53 = OpLoad %5 %18 -%54 = OpIAdd %5 %53 %52 -OpStore %18 %54 -%55 = OpULessThan %46 %38 %9 -OpSelectionMerge %58 None -OpBranchConditional %55 %59 %58 +%53 = OpPhi %5 %49 %36 %52 %51 +%54 = OpLoad %5 %19 +%55 = OpIAdd %5 %54 %53 +OpStore %19 %55 +%56 = OpULessThan %47 %40 %10 +OpSelectionMerge %59 None +OpBranchConditional %56 %60 %59 +%60 = OpLabel +%57 = OpAccessChain %41 %12 %40 %33 +%61 = OpLoad %5 %57 +OpBranch %59 %59 = OpLabel -%56 = OpAccessChain %40 %11 %38 %31 -%60 = OpLoad %5 %56 -OpBranch %58 -%58 = OpLabel -%61 = OpPhi %5 %57 %49 %60 %59 -%62 = OpLoad %5 %18 -%63 = OpIAdd %5 %62 %61 -OpStore %18 %63 -%64 = OpLoad %5 %18 -OpStore %26 %64 +%62 = OpPhi %5 %58 %50 %61 %60 +%63 = OpLoad %5 %19 +%64 = OpIAdd %5 %63 %62 +OpStore %19 %64 +%66 = OpAccessChain %31 %12 %33 +%67 = OpArrayLength %5 %66 1 +%68 = OpLoad %5 %19 +%69 = OpIAdd %5 %68 %67 +OpStore %19 %69 +%70 = OpLoad %5 %19 +OpStore %27 %70 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/binding-buffer-arrays.wgsl b/tests/out/wgsl/binding-buffer-arrays.wgsl index 4a36faa399..6aac5d254a 100644 --- a/tests/out/wgsl/binding-buffer-arrays.wgsl +++ b/tests/out/wgsl/binding-buffer-arrays.wgsl @@ -4,6 +4,7 @@ struct UniformIndex { struct Foo { x: u32, + far: array, } struct FragmentIn { @@ -31,6 +32,8 @@ fn main(fragment_in: FragmentIn) -> @location(0) @interpolate(flat) u32 { let _e23 = storage_array[non_uniform_index].x; let _e24 = u1_; u1_ = (_e24 + _e23); - let _e26 = u1_; - return _e26; + let _e31 = u1_; + u1_ = (_e31 + arrayLength((&storage_array[0].far))); + let _e33 = u1_; + return _e33; }