From a5d2d091ecb35ac705c1df1f5d4a1d812b8e476c Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sat, 18 Mar 2023 13:28:30 -0700 Subject: [PATCH] spv-out: refactor non-uniform indexing semantics to support buffers --- src/back/spv/block.rs | 68 +++++++++++++++------- src/back/spv/writer.rs | 7 +++ tests/out/spv/binding-arrays.spvasm | 44 +++++++------- tests/out/spv/binding-buffer-arrays.spvasm | 3 + 4 files changed, 78 insertions(+), 44 deletions(-) diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 45a02ac7df..f69bc36014 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -3,8 +3,9 @@ Implementations for `BlockContext` methods. */ use super::{ - index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext, Dimension, - Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer, WriterFlags, + helpers, index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext, + Dimension, Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer, + WriterFlags, }; use crate::{arena::Handle, proc::TypeResolution}; use spirv::Word; @@ -220,7 +221,6 @@ impl<'w> BlockContext<'w> { block: &mut Block, ) -> Result<(), Error> { let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty); - let id = match self.ir_function.expressions[expr_handle] { crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => { // See `is_intermediate`; we'll handle this later in @@ -236,9 +236,15 @@ impl<'w> BlockContext<'w> { crate::TypeInner::BindingArray { base: binding_type, .. } => { + let space = match self.ir_function.expressions[base] { + crate::Expression::GlobalVariable(gvar) => { + self.ir_module.global_variables[gvar].space + } + _ => unreachable!(), + }; let binding_array_false_pointer = LookupType::Local(LocalType::Pointer { base: binding_type, - class: spirv::StorageClass::UniformConstant, + class: helpers::map_storage_class(space), }); let result_id = match self.write_expression_pointer( @@ -264,15 +270,6 @@ impl<'w> BlockContext<'w> { None, )); - if self.fun_info[index].uniformity.non_uniform_result.is_some() { - self.writer.require_any( - "NonUniformEXT", - &[spirv::Capability::ShaderNonUniform], - )?; - self.writer.use_extension("SPV_EXT_descriptor_indexing"); - self.writer - .decorate(load_id, spirv::Decoration::NonUniform, &[]); - } load_id } ref other => { @@ -315,9 +312,15 @@ impl<'w> BlockContext<'w> { crate::TypeInner::BindingArray { base: binding_type, .. } => { + let space = match self.ir_function.expressions[base] { + crate::Expression::GlobalVariable(gvar) => { + self.ir_module.global_variables[gvar].space + } + _ => unreachable!(), + }; let binding_array_false_pointer = LookupType::Local(LocalType::Pointer { base: binding_type, - class: spirv::StorageClass::UniformConstant, + class: helpers::map_storage_class(space), }); let result_id = match self.write_expression_pointer( @@ -1433,11 +1436,25 @@ impl<'w> BlockContext<'w> { // but we expect these checks to almost always succeed, and keeping branches to a // minimum is essential. let mut accumulated_checks = None; + // Is true if we are accessing into a binding array of buffers with a non-uniform index. + let mut is_non_uniform_binding_array = false; self.temp_list.clear(); let root_id = loop { expr_handle = match self.ir_function.expressions[expr_handle] { crate::Expression::Access { base, index } => { + if let crate::Expression::GlobalVariable(var_handle) = + self.ir_function.expressions[base] + { + let gvar = &self.ir_module.global_variables[var_handle]; + if let crate::TypeInner::BindingArray { .. } = + self.ir_module.types[gvar.ty].inner + { + is_non_uniform_binding_array |= + self.fun_info[index].uniformity.non_uniform_result.is_some(); + } + } + let index_id = match self.write_bounds_check(base, index, block)? { BoundsCheckResult::KnownInBounds(known_index) => { // Even if the index is known, `OpAccessIndex` @@ -1470,7 +1487,6 @@ impl<'w> BlockContext<'w> { } }; self.temp_list.push(index_id); - base } crate::Expression::AccessIndex { base, index } => { @@ -1493,10 +1509,13 @@ impl<'w> BlockContext<'w> { } }; - let pointer = if self.temp_list.is_empty() { - ExpressionPointer::Ready { - pointer_id: root_id, - } + let (pointer_id, expr_pointer) = if self.temp_list.is_empty() { + ( + root_id, + ExpressionPointer::Ready { + pointer_id: root_id, + }, + ) } else { self.temp_list.reverse(); let pointer_id = self.gen_id(); @@ -1507,16 +1526,21 @@ impl<'w> BlockContext<'w> { // caller to generate the branch, the access, the load or store, and // the zero value (for loads). Otherwise, we can emit the access // ourselves, and just hand them the id of the pointer. - match accumulated_checks { + let expr_pointer = match accumulated_checks { Some(condition) => ExpressionPointer::Conditional { condition, access }, None => { block.body.push(access); ExpressionPointer::Ready { pointer_id } } - } + }; + (pointer_id, expr_pointer) }; + if is_non_uniform_binding_array { + self.writer + .decorate_non_uniform_binding_array_access(pointer_id)?; + } - Ok(pointer) + Ok(expr_pointer) } /// Build the instructions for matrix - matrix column operations diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index ba5c572ab9..90c6b5089d 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1966,6 +1966,13 @@ impl Writer { pub const fn get_capabilities_used(&self) -> &crate::FastHashSet { &self.capabilities_used } + + pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> { + self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?; + self.use_extension("SPV_EXT_descriptor_indexing"); + self.decorate(id, spirv::Decoration::NonUniform, &[]); + Ok(()) + } } #[test] diff --git a/tests/out/spv/binding-arrays.spvasm b/tests/out/spv/binding-arrays.spvasm index 30cf4847ce..31929d155b 100644 --- a/tests/out/spv/binding-arrays.spvasm +++ b/tests/out/spv/binding-arrays.spvasm @@ -35,28 +35,28 @@ OpMemberDecorate %49 0 Offset 0 OpDecorate %65 Location 0 OpDecorate %65 Flat OpDecorate %68 Location 0 -OpDecorate %97 NonUniform -OpDecorate %121 NonUniform -OpDecorate %123 NonUniform -OpDecorate %148 NonUniform -OpDecorate %150 NonUniform -OpDecorate %188 NonUniform -OpDecorate %219 NonUniform -OpDecorate %238 NonUniform -OpDecorate %257 NonUniform -OpDecorate %279 NonUniform -OpDecorate %281 NonUniform -OpDecorate %303 NonUniform -OpDecorate %305 NonUniform -OpDecorate %327 NonUniform -OpDecorate %329 NonUniform -OpDecorate %351 NonUniform -OpDecorate %353 NonUniform -OpDecorate %375 NonUniform -OpDecorate %377 NonUniform -OpDecorate %399 NonUniform -OpDecorate %401 NonUniform -OpDecorate %424 NonUniform +OpDecorate %96 NonUniform +OpDecorate %120 NonUniform +OpDecorate %122 NonUniform +OpDecorate %147 NonUniform +OpDecorate %149 NonUniform +OpDecorate %187 NonUniform +OpDecorate %218 NonUniform +OpDecorate %237 NonUniform +OpDecorate %256 NonUniform +OpDecorate %278 NonUniform +OpDecorate %280 NonUniform +OpDecorate %302 NonUniform +OpDecorate %304 NonUniform +OpDecorate %326 NonUniform +OpDecorate %328 NonUniform +OpDecorate %350 NonUniform +OpDecorate %352 NonUniform +OpDecorate %374 NonUniform +OpDecorate %376 NonUniform +OpDecorate %398 NonUniform +OpDecorate %400 NonUniform +OpDecorate %423 NonUniform %2 = OpTypeVoid %4 = OpTypeInt 32 1 %3 = OpConstant %4 5 diff --git a/tests/out/spv/binding-buffer-arrays.spvasm b/tests/out/spv/binding-buffer-arrays.spvasm index 413dab3726..cd73bd756c 100644 --- a/tests/out/spv/binding-buffer-arrays.spvasm +++ b/tests/out/spv/binding-buffer-arrays.spvasm @@ -3,7 +3,9 @@ ; Generator: rspirv ; Bound: 66 OpCapability Shader +OpCapability ShaderNonUniform OpExtension "SPV_KHR_storage_buffer_storage_class" +OpExtension "SPV_EXT_descriptor_indexing" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %29 "main" %24 %27 @@ -22,6 +24,7 @@ OpMemberDecorate %17 0 Offset 0 OpDecorate %24 Location 0 OpDecorate %24 Flat OpDecorate %27 Location 0 +OpDecorate %57 NonUniform %2 = OpTypeVoid %4 = OpTypeInt 32 1 %3 = OpConstant %4 1