diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 1167357e8d..1eebbee067 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -252,6 +252,28 @@ impl StatementGraph { } "Atomic" } + S::RayQuery { query, ref fun } => { + self.dependencies.push((id, query, "query")); + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + self.dependencies.push(( + id, + acceleration_structure, + "acceleration_structure", + )); + self.dependencies.push((id, descriptor, "descriptor")); + "RayQueryInitialize" + } + crate::RayQueryFunction::Proceed { result } => { + self.emits.push((id, result)); + "RayQueryProceed" + } + crate::RayQueryFunction::Terminate => "RayQueryTerminate", + } + } }; // Set the last node to the merge node last_node = merge_id; @@ -550,6 +572,12 @@ fn write_function_expressions( edges.insert("", expr); ("ArrayLength".into(), 7) } + E::RayQueryProceedResult => ("rayQueryProceedResult".into(), 4), + E::RayQueryGetIntersection { query, committed } => { + edges.insert("", query); + let ty = if committed { "Committed" } else { "Candidate" }; + (format!("rayQueryGet{}Intersection", ty).into(), 4) + } }; // give uniform expressions an outline diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 44685fb99e..9195b96837 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -879,6 +879,8 @@ impl<'a, W: Write> Writer<'a, W> { | TypeInner::Struct { .. } | TypeInner::Image { .. } | TypeInner::Sampler { .. } + | TypeInner::AccelerationStructure + | TypeInner::RayQuery | TypeInner::BindingArray { .. } => { return Err(Error::Custom(format!("Unable to write type {inner:?}"))) } @@ -2195,6 +2197,7 @@ impl<'a, W: Write> Writer<'a, W> { self.write_expr(value, ctx)?; writeln!(self.out, ");")?; } + Statement::RayQuery { .. } => unreachable!(), } Ok(()) @@ -3277,13 +3280,17 @@ impl<'a, W: Write> Writer<'a, W> { } } // These expressions never show up in `Emit`. - Expression::CallResult(_) | Expression::AtomicResult { .. } => unreachable!(), + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::RayQueryProceedResult => unreachable!(), // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { write!(self.out, "uint(")?; self.write_expr(expr, ctx)?; write!(self.out, ".length())")? } + // not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), } Ok(()) diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index d11032bbf5..f9e52914f7 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1980,6 +1980,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}}}")? } + Statement::RayQuery { .. } => unreachable!(), } Ok(()) @@ -2878,8 +2879,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_expr(module, reject, func_ctx)?; write!(self.out, ")")? } + // Not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), // Nothing to do here, since call expression already cached - Expression::CallResult(_) | Expression::AtomicResult { .. } => {} + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::RayQueryProceedResult => {} } if !closing_bracket.is_empty() { diff --git a/src/back/mod.rs b/src/back/mod.rs index 6755983c07..8467ee787b 100644 --- a/src/back/mod.rs +++ b/src/back/mod.rs @@ -218,3 +218,33 @@ impl crate::Statement { } } } + +bitflags::bitflags! { + /// Ray flags, for a [`RayDesc`]'s `flags` field. + /// + /// Note that these exactly correspond to the SPIR-V "Ray Flags" mask, and + /// the SPIR-V backend passes them directly through to the + /// `OpRayQueryInitializeKHR` instruction. (We have to choose something, so + /// we might as well make one back end's life easier.) + /// + /// [`RayDesc`]: crate::Module::generate_ray_desc_type + #[derive(Default)] + pub struct RayFlag: u32 { + const OPAQUE = 0x01; + const NO_OPAQUE = 0x02; + const TERMINATE_ON_FIRST_HIT = 0x04; + const SKIP_CLOSEST_HIT_SHADER = 0x08; + const CULL_BACK_FACING = 0x10; + const CULL_FRONT_FACING = 0x20; + const CULL_OPAQUE = 0x40; + const CULL_NO_OPAQUE = 0x80; + const SKIP_TRIANGLES = 0x100; + const SKIP_AABBS = 0x200; + } +} + +#[repr(u32)] +enum RayIntersectionType { + Triangle = 1, + BoundingBox = 4, +} diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index d794976602..3174d4b756 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -314,10 +314,7 @@ impl Options { match slot { Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { buffer: Some(slot), - texture: None, - sampler: None, - binding_array_size: None, - mutable: false, + ..Default::default() })), None if self.fake_missing_bindings => Ok(ResolvedBinding::User { prefix: "fake", @@ -338,10 +335,7 @@ impl Options { match slot { Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { buffer: Some(slot), - texture: None, - sampler: None, - binding_array_size: None, - mutable: false, + ..Default::default() })), None if self.fake_missing_bindings => Ok(ResolvedBinding::User { prefix: "fake", diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index ee23ca294a..88600a8f11 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -25,6 +25,13 @@ const WRAPPED_ARRAY_FIELD: &str = "inner"; // Some more general handling of pointers is needed to be implemented here. const ATOMIC_REFERENCE: &str = "&"; +const RT_NAMESPACE: &str = "metal::raytracing"; +const RAY_QUERY_TYPE: &str = "_RayQuery"; +const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector"; +const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection"; +const RAY_QUERY_FIELD_READY: &str = "ready"; +const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type"; + /// Write the Metal name for a Naga numeric type: scalar, vector, or matrix. /// /// The `sizes` slice determines whether this function writes a @@ -194,6 +201,12 @@ impl<'a> Display for TypeContext<'a> { crate::TypeInner::Sampler { comparison: _ } => { write!(out, "{NAMESPACE}::sampler") } + crate::TypeInner::AccelerationStructure => { + write!(out, "{RT_NAMESPACE}::instance_acceleration_structure") + } + crate::TypeInner::RayQuery => { + write!(out, "{RAY_QUERY_TYPE}") + } crate::TypeInner::BindingArray { base, size } => { let base_tyname = Self { handle: base, @@ -485,7 +498,11 @@ impl crate::Type { // composite types are better to be aliased, regardless of the name Ti::Struct { .. } | Ti::Array { .. } => true, // handle types may be different, depending on the global var access, so we always inline them - Ti::Image { .. } | Ti::Sampler { .. } | Ti::BindingArray { .. } => false, + Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => false, } } } @@ -1831,7 +1848,9 @@ impl Writer { _ => return Err(Error::Validation), }, // has to be a named expression - crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } => { + crate::Expression::CallResult(_) + | crate::Expression::AtomicResult { .. } + | crate::Expression::RayQueryProceedResult => { unreachable!() } crate::Expression::ArrayLength(expr) => { @@ -1856,6 +1875,39 @@ impl Writer { write!(self.out, ")")?; } } + crate::Expression::RayQueryGetIntersection { query, committed } => { + if !committed { + unimplemented!() + } + let ty = context.module.special_types.ray_intersection.unwrap(); + let type_name = &self.names[&NameKey::Type(ty)]; + write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?; + self.put_expression(query, context, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?; + let fields = [ + "distance", + "user_instance_id", + "instance_id", + "", // SBT offset + "geometry_id", + "primitive_id", + "triangle_barycentric_coord", + "triangle_front_facing", + "", // padding + "object_to_world_transform", + "world_to_object_transform", + ]; + for field in fields { + write!(self.out, ", ")?; + if field.is_empty() { + write!(self.out, "{{}}")?; + } else { + self.put_expression(query, context, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?; + } + } + write!(self.out, "}}")?; + } } Ok(()) } @@ -2309,6 +2361,7 @@ impl Writer { ) { use crate::Expression; self.need_bake_expressions.clear(); + for (expr_handle, expr) in func.expressions.iter() { // Expressions whose reference count is above the // threshold should always be stored in temporaries. @@ -2316,6 +2369,16 @@ impl Writer { let min_ref_count = func.expressions[expr_handle].bake_ref_count(); if min_ref_count <= expr_info.ref_count { self.need_bake_expressions.insert(expr_handle); + } else { + match expr_info.ty { + // force ray desc to be baked: it's used multiple times internally + TypeResolution::Handle(h) + if Some(h) == context.module.special_types.ray_desc => + { + self.need_bake_expressions.insert(expr_handle); + } + _ => {} + } } if let Expression::Math { fun, arg, arg1, .. } = *expr { @@ -2327,11 +2390,11 @@ impl Writer { // times, once for each component (see `put_dot_product`), so to // avoid duplicated evaluation, we must bake integer operands. - use crate::TypeInner; // check what kind of product this is depending // on the resolve type of the Dot function itself - let inner = context.resolve_type(expr_handle); - if let TypeInner::Scalar { kind, .. } = *inner { + if let crate::TypeInner::Scalar { kind, .. } = + *context.resolve_type(expr_handle) + { match kind { crate::ScalarKind::Sint | crate::ScalarKind::Uint => { self.need_bake_expressions.insert(arg); @@ -2752,6 +2815,100 @@ impl Writer { // done writeln!(self.out, ";")?; } + crate::Statement::RayQuery { query, ref fun } => { + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + //TODO: how to deal with winding? + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?; + { + let f_opaque = back::RayFlag::CULL_OPAQUE.bits(); + let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode((" + )?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?; + writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?; + } + { + let f_opaque = back::RayFlag::OPAQUE.bits(); + let f_no_opaque = back::RayFlag::NO_OPAQUE.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?; + writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?; + } + { + let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection((" + )?; + self.put_expression(descriptor, &context.expression, true)?; + writeln!(self.out, ".flags & {flag}) != 0);")?; + } + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray(" + )?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".origin, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".dir, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".tmin, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".tmax), ")?; + self.put_expression(acceleration_structure, &context.expression, true)?; + write!(self.out, ", ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".cull_mask);")?; + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?; + } + crate::RayQueryFunction::Proceed { result } => { + write!(self.out, "{level}")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?; + //TODO: actually proceed? + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?; + } + crate::RayQueryFunction::Terminate => { + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?; + } + } + } } } @@ -2863,14 +3020,41 @@ impl Writer { writeln!(self.out)?; // Work around Metal bug where `uint` is not available by default writeln!(self.out, "using {NAMESPACE}::uint;")?; - writeln!(self.out)?; + if module.types.iter().any(|(_, t)| match t.inner { + crate::TypeInner::RayQuery => true, + _ => false, + }) { + let tab = back::INDENT; + writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?; + let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>"); + writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?; + writeln!( + self.out, + "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};" + )?; + writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?; + writeln!(self.out, "}};")?; + writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?; + let v_triangle = back::RayIntersectionType::Triangle as u32; + let v_bbox = back::RayIntersectionType::BoundingBox as u32; + writeln!( + self.out, + "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : " + )?; + writeln!( + self.out, + "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;" + )?; + writeln!(self.out, "}}")?; + } if options .bounds_check_policies .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite) { self.put_default_constructible()?; } + writeln!(self.out)?; { let mut indices = vec![]; @@ -2912,11 +3096,12 @@ impl Writer { /// /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite fn put_default_constructible(&mut self) -> BackendResult { + let tab = back::INDENT; writeln!(self.out, "struct DefaultConstructible {{")?; - writeln!(self.out, " template")?; - writeln!(self.out, " operator T() && {{")?; - writeln!(self.out, " return T {{}};")?; - writeln!(self.out, " }}")?; + writeln!(self.out, "{tab}template")?; + writeln!(self.out, "{tab}operator T() && {{")?; + writeln!(self.out, "{tab}{tab}return T {{}};")?; + writeln!(self.out, "{tab}}}")?; writeln!(self.out, "}};")?; Ok(()) } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index c3fa8455e9..b28b94fe91 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -1084,9 +1084,9 @@ impl<'w> BlockContext<'w> { } } crate::Expression::FunctionArgument(index) => self.function.parameter_id(index), - crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } => { - self.cached[expr_handle] - } + crate::Expression::CallResult(_) + | crate::Expression::AtomicResult { .. } + | crate::Expression::RayQueryProceedResult => self.cached[expr_handle], crate::Expression::As { expr, kind, @@ -1386,6 +1386,12 @@ impl<'w> BlockContext<'w> { id } crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?, + crate::Expression::RayQueryGetIntersection { query, committed } => { + if !committed { + return Err(Error::FeatureNotImplemented("candidate intersection")); + } + self.write_ray_query_get_intersection(query, block) + } }; self.cached[expr_handle] = id; @@ -2196,6 +2202,9 @@ impl<'w> BlockContext<'w> { block.body.push(instruction); } + crate::Statement::RayQuery { query, ref fun } => { + self.write_ray_query_function(query, fun, &mut block); + } } } diff --git a/src/back/spv/image.rs b/src/back/spv/image.rs index 1a136af77e..dc4f249949 100644 --- a/src/back/spv/image.rs +++ b/src/back/spv/image.rs @@ -373,7 +373,7 @@ impl<'w> BlockContext<'w> { }) } - fn get_image_id(&mut self, expr_handle: Handle) -> Word { + pub(super) fn get_image_id(&mut self, expr_handle: Handle) -> Word { let id = match self.ir_function.expressions[expr_handle] { crate::Expression::GlobalVariable(handle) => { self.writer.global_variables[handle.index()].handle_id diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index c213790188..96d0278285 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -249,6 +249,18 @@ impl super::Instruction { instruction } + pub(super) fn type_acceleration_structure(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeAccelerationStructureKHR); + instruction.set_result(id); + instruction + } + + pub(super) fn type_ray_query(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeRayQueryKHR); + instruction.set_result(id); + instruction + } + pub(super) fn type_sampled_image(id: Word, image_type_id: Word) -> Self { let mut instruction = Self::new(Op::TypeSampledImage); instruction.set_result(id); @@ -627,6 +639,55 @@ impl super::Instruction { instruction } + // + // Ray Query Instructions + // + #[allow(clippy::too_many_arguments)] + pub(super) fn ray_query_initialize( + query: Word, + acceleration_structure: Word, + ray_flags: Word, + cull_mask: Word, + ray_origin: Word, + ray_tmin: Word, + ray_dir: Word, + ray_tmax: Word, + ) -> Self { + let mut instruction = Self::new(Op::RayQueryInitializeKHR); + instruction.add_operand(query); + instruction.add_operand(acceleration_structure); + instruction.add_operand(ray_flags); + instruction.add_operand(cull_mask); + instruction.add_operand(ray_origin); + instruction.add_operand(ray_tmin); + instruction.add_operand(ray_dir); + instruction.add_operand(ray_tmax); + instruction + } + + pub(super) fn ray_query_proceed(result_type_id: Word, id: Word, query: Word) -> Self { + let mut instruction = Self::new(Op::RayQueryProceedKHR); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(query); + instruction + } + + pub(super) fn ray_query_get_intersection( + op: Op, + result_type_id: Word, + id: Word, + query: Word, + intersection: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(query); + instruction.add_operand(intersection); + instruction + } + // // Conversion Instructions // diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index 72898219af..9b084911b1 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -10,6 +10,7 @@ mod image; mod index; mod instructions; mod layout; +mod ray; mod recyclable; mod selection; mod writer; @@ -295,6 +296,8 @@ enum LocalType { base: Handle, size: u64, }, + AccelerationStructure, + RayQuery, } /// A type encountered during SPIR-V generation. @@ -383,7 +386,11 @@ fn make_local(inner: &crate::TypeInner) -> Option { class, } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)), crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler, - _ => return None, + crate::TypeInner::AccelerationStructure => LocalType::AccelerationStructure, + crate::TypeInner::RayQuery => LocalType::RayQuery, + crate::TypeInner::Array { .. } + | crate::TypeInner::Struct { .. } + | crate::TypeInner::BindingArray { .. } => return None, }) } diff --git a/src/back/spv/ray.rs b/src/back/spv/ray.rs new file mode 100644 index 0000000000..79eb2ff971 --- /dev/null +++ b/src/back/spv/ray.rs @@ -0,0 +1,273 @@ +/*! +Generating SPIR-V for ray query operations. +*/ + +use super::{Block, BlockContext, Instruction, LocalType, LookupType}; +use crate::arena::Handle; + +impl<'w> BlockContext<'w> { + pub(super) fn write_ray_query_function( + &mut self, + query: Handle, + function: &crate::RayQueryFunction, + block: &mut Block, + ) { + let query_id = self.cached[query]; + match *function { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + //Note: composite extract indices and types must match `generate_ray_desc_type` + let desc_id = self.cached[descriptor]; + let acc_struct_id = self.get_image_id(acceleration_structure); + let width = 4; + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + })); + let ray_flags_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + ray_flags_id, + desc_id, + &[0], + )); + let cull_mask_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + cull_mask_id, + desc_id, + &[1], + )); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let tmin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmin_id, + desc_id, + &[2], + )); + let tmax_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmax_id, + desc_id, + &[3], + )); + + let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let ray_origin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_origin_id, + desc_id, + &[4], + )); + let ray_dir_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_dir_id, + desc_id, + &[5], + )); + + block.body.push(Instruction::ray_query_initialize( + query_id, + acc_struct_id, + ray_flags_id, + cull_mask_id, + ray_origin_id, + tmin_id, + ray_dir_id, + tmax_id, + )); + } + crate::RayQueryFunction::Proceed { result } => { + let id = self.gen_id(); + self.cached[result] = id; + let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty); + + block + .body + .push(Instruction::ray_query_proceed(result_type_id, id, query_id)); + } + crate::RayQueryFunction::Terminate => {} + } + } + + pub(super) fn write_ray_query_get_intersection( + &mut self, + query: Handle, + block: &mut Block, + ) -> spirv::Word { + let width = 4; + let query_id = self.cached[query]; + let intersection_id = self.writer.get_constant_scalar( + crate::ScalarValue::Uint( + spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, + ), + width, + ); + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + })); + let kind_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTypeKHR, + flag_type_id, + kind_id, + query_id, + intersection_id, + )); + let instance_custom_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR, + flag_type_id, + instance_custom_index_id, + query_id, + intersection_id, + )); + let instance_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceIdKHR, + flag_type_id, + instance_id, + query_id, + intersection_id, + )); + let sbt_record_offset_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR, + flag_type_id, + sbt_record_offset_id, + query_id, + intersection_id, + )); + let geometry_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionGeometryIndexKHR, + flag_type_id, + geometry_index_id, + query_id, + intersection_id, + )); + let primitive_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR, + flag_type_id, + primitive_index_id, + query_id, + intersection_id, + )); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let t_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTKHR, + scalar_type_id, + t_id, + query_id, + intersection_id, + )); + + let barycentrics_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Bi), + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let barycentrics_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionBarycentricsKHR, + barycentrics_type_id, + barycentrics_id, + query_id, + intersection_id, + )); + + let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + pointer_space: None, + })); + let front_face_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionFrontFaceKHR, + bool_type_id, + front_face_id, + query_id, + intersection_id, + )); + + let transform_type_id = self.get_type_id(LookupType::Local(LocalType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width, + })); + let object_to_world_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionObjectToWorldKHR, + transform_type_id, + object_to_world_id, + query_id, + intersection_id, + )); + let world_to_object_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionWorldToObjectKHR, + transform_type_id, + world_to_object_id, + query_id, + intersection_id, + )); + + let id = self.gen_id(); + let intersection_type_id = self.get_type_id(LookupType::Handle( + self.ir_module.special_types.ray_intersection.unwrap(), + )); + //Note: the arguments must match `generate_ray_intersection_type` layout + block.body.push(Instruction::composite_construct( + intersection_type_id, + id, + &[ + kind_id, + t_id, + instance_custom_index_id, + instance_id, + sbt_record_offset_id, + geometry_index_id, + primitive_index_id, + barycentrics_id, + front_face_id, + object_to_world_id, + world_to_object_id, + ], + )); + id + } +} diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index f264c107d3..ba235e6d03 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -350,9 +350,12 @@ impl Writer { pointer_type_id, id, spirv::StorageClass::Function, - init_word.or_else(|| { - let type_id = self.get_type_id(LookupType::Handle(variable.ty)); - Some(self.write_constant_null(type_id)) + init_word.or_else(|| match ir_module.types[variable.ty].inner { + crate::TypeInner::RayQuery => None, + _ => { + let type_id = self.get_type_id(LookupType::Handle(variable.ty)); + Some(self.write_constant_null(type_id)) + } }), ); function @@ -814,47 +817,54 @@ impl Writer { } } - fn request_image_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> { - if let crate::TypeInner::Image { - dim, - arrayed, - class, - } = *inner - { - let sampled = match class { - crate::ImageClass::Sampled { .. } => true, - crate::ImageClass::Depth { .. } => true, - crate::ImageClass::Storage { format, .. } => { - self.request_image_format_capabilities(format.into())?; - false - } - }; + fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> { + match *inner { + crate::TypeInner::Image { + dim, + arrayed, + class, + } => { + let sampled = match class { + crate::ImageClass::Sampled { .. } => true, + crate::ImageClass::Depth { .. } => true, + crate::ImageClass::Storage { format, .. } => { + self.request_image_format_capabilities(format.into())?; + false + } + }; - match dim { - crate::ImageDimension::D1 => { - if sampled { - self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?; - } else { - self.require_any("1D storage images", &[spirv::Capability::Image1D])?; + match dim { + crate::ImageDimension::D1 => { + if sampled { + self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?; + } else { + self.require_any("1D storage images", &[spirv::Capability::Image1D])?; + } } - } - crate::ImageDimension::Cube if arrayed => { - if sampled { - self.require_any( - "sampled cube array images", - &[spirv::Capability::SampledCubeArray], - )?; - } else { - self.require_any( - "cube array storage images", - &[spirv::Capability::ImageCubeArray], - )?; + crate::ImageDimension::Cube if arrayed => { + if sampled { + self.require_any( + "sampled cube array images", + &[spirv::Capability::SampledCubeArray], + )?; + } else { + self.require_any( + "cube array storage images", + &[spirv::Capability::ImageCubeArray], + )?; + } } + _ => {} } - _ => {} } + crate::TypeInner::AccelerationStructure => { + self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?; + } + crate::TypeInner::RayQuery => { + self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?; + } + _ => {} } - Ok(()) } @@ -935,6 +945,8 @@ impl Writer { self.get_type_id(LookupType::Local(LocalType::BindingArray { base, size })); Instruction::type_pointer(id, spirv::StorageClass::UniformConstant, inner_ty) } + LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id), + LocalType::RayQuery => Instruction::type_ray_query(id), }; instruction.to_words(&mut self.logical_layout.declarations); @@ -961,9 +973,9 @@ impl Writer { self.write_type_declaration_local(id, local); - // If it's an image type, request SPIR-V capabilities here, so - // write_type_declaration_local can stay infallible. - self.request_image_capabilities(&ty.inner)?; + // If it's a type that needs SPIR-V capabilities, request them now, + // so write_type_declaration_local can stay infallible. + self.request_type_capabilities(&ty.inner)?; id } @@ -1017,7 +1029,9 @@ impl Writer { | crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Image { .. } - | crate::TypeInner::Sampler { .. } => unreachable!(), + | crate::TypeInner::Sampler { .. } + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => unreachable!(), }; instruction.to_words(&mut self.logical_layout.declarations); @@ -1756,6 +1770,8 @@ impl Writer { .iter() .flat_map(|entry| entry.function.arguments.iter()) .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty)); + let has_ray_query = ir_module.special_types.ray_desc.is_some() + | ir_module.special_types.ray_intersection.is_some(); if self.physical_layout.version < 0x10300 && has_storage_buffers { // enable the storage buffer class on < SPV-1.3 @@ -1766,6 +1782,10 @@ impl Writer { Instruction::extension("SPV_KHR_multiview") .to_words(&mut self.logical_layout.extensions) } + if has_ray_query { + Instruction::extension("SPV_KHR_ray_query") + .to_words(&mut self.logical_layout.extensions) + } Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations); Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450") .to_words(&mut self.logical_layout.ext_inst_imports); diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index f24f4a9c26..92086c94a8 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -937,6 +937,7 @@ impl Writer { writeln!(self.out, "{level}workgroupBarrier();")?; } } + Statement::RayQuery { .. } => unreachable!(), } Ok(()) @@ -1621,8 +1622,12 @@ impl Writer { write!(self.out, ")")? } + // Not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), // Nothing to do here, since call expression already cached - Expression::CallResult(_) | Expression::AtomicResult { .. } => {} + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::RayQueryProceedResult => {} } Ok(()) diff --git a/src/front/glsl/constants.rs b/src/front/glsl/constants.rs index d9a6fc7cd7..045a9c6ffb 100644 --- a/src/front/glsl/constants.rs +++ b/src/front/glsl/constants.rs @@ -37,6 +37,8 @@ pub enum ConstantSolvingError { Load, #[error("Constants don't support image expressions")] ImageExpression, + #[error("Constants don't support ray query expressions")] + RayQueryExpression, #[error("Cannot access the type")] InvalidAccessBase, #[error("Cannot access at the index")] @@ -295,6 +297,9 @@ impl<'a> ConstantSolver<'a> { Expression::ImageSample { .. } | Expression::ImageLoad { .. } | Expression::ImageQuery { .. } => Err(ConstantSolvingError::ImageExpression), + Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { + Err(ConstantSolvingError::RayQueryExpression) + } } } diff --git a/src/front/glsl/types.rs b/src/front/glsl/types.rs index 632378c60b..a7967848d5 100644 --- a/src/front/glsl/types.rs +++ b/src/front/glsl/types.rs @@ -246,14 +246,7 @@ impl Frontend { expr: Handle, meta: Span, ) -> Result<()> { - let resolve_ctx = ResolveContext { - constants: &self.module.constants, - types: &self.module.types, - global_vars: &self.module.global_variables, - local_vars: &ctx.locals, - functions: &self.module.functions, - arguments: &ctx.arguments, - }; + let resolve_ctx = ResolveContext::with_locals(&self.module, &ctx.locals, &ctx.arguments); ctx.typifier .grow(expr, &ctx.expressions, &resolve_ctx) @@ -312,14 +305,7 @@ impl Frontend { expr: Handle, meta: Span, ) -> Result<()> { - let resolve_ctx = ResolveContext { - constants: &self.module.constants, - types: &self.module.types, - global_vars: &self.module.global_variables, - local_vars: &ctx.locals, - functions: &self.module.functions, - arguments: &ctx.arguments, - }; + let resolve_ctx = ResolveContext::with_locals(&self.module, &ctx.locals, &ctx.arguments); ctx.typifier .invalidate(expr, &ctx.expressions, &resolve_ctx) diff --git a/src/front/mod.rs b/src/front/mod.rs index 071e805a69..d6f38671ea 100644 --- a/src/front/mod.rs +++ b/src/front/mod.rs @@ -3,6 +3,7 @@ Frontend parsers that consume binary and text shaders and load them into [`Modul */ mod interpolator; +mod type_gen; #[cfg(feature = "glsl-in")] pub mod glsl; diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index ce42be35b2..c69a230cb0 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -3672,7 +3672,8 @@ impl> Frontend { | S::Barrier(_) | S::Store { .. } | S::ImageStore { .. } - | S::Atomic { .. } => {} + | S::Atomic { .. } + | S::RayQuery { .. } => {} S::Call { function: ref mut callee, ref arguments, diff --git a/src/front/type_gen.rs b/src/front/type_gen.rs new file mode 100644 index 0000000000..1ee454c448 --- /dev/null +++ b/src/front/type_gen.rs @@ -0,0 +1,314 @@ +/*! +Type generators. +*/ + +use crate::{arena::Handle, span::Span}; + +impl crate::Module { + pub fn generate_atomic_compare_exchange_result( + &mut self, + kind: crate::ScalarKind, + width: crate::Bytes, + ) -> Handle { + let bool_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + }, + }, + Span::UNDEFINED, + ); + let scalar_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { kind, width }, + }, + Span::UNDEFINED, + ); + + self.types.insert( + crate::Type { + name: Some(format!( + "__atomic_compare_exchange_result<{kind:?},{width}>" + )), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("old_value".to_string()), + ty: scalar_ty, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("exchanged".to_string()), + ty: bool_ty, + binding: None, + offset: 4, + }, + ], + span: 8, + }, + }, + Span::UNDEFINED, + ) + } + /// Populate this module's [`SpecialTypes::ray_desc`] type. + /// + /// [`SpecialTypes::ray_desc`] is the type of the [`descriptor`] operand of + /// an [`Initialize`] [`RayQuery`] statement. In WGSL, it is a struct type + /// referred to as `RayDesc`. + /// + /// Backends consume values of this type to drive platform APIs, so if you + /// change any its fields, you must update the backends to match. Look for + /// backend code dealing with [`RayQueryFunction::Initialize`]. + /// + /// [`SpecialTypes::ray_desc`]: crate::SpecialTypes::ray_desc + /// [`descriptor`]: crate::RayQueryFunction::Initialize::descriptor + /// [`Initialize`]: crate::RayQueryFunction::Initialize + /// [`RayQuery`]: crate::Statement::RayQuery + /// [`RayQueryFunction::Initialize`]: crate::RayQueryFunction::Initialize + pub fn generate_ray_desc_type(&mut self) -> Handle { + if let Some(handle) = self.special_types.ray_desc { + return handle; + } + + let width = 4; + let ty_flag = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + width, + kind: crate::ScalarKind::Uint, + }, + }, + Span::UNDEFINED, + ); + let ty_scalar = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + width, + kind: crate::ScalarKind::Float, + }, + }, + Span::UNDEFINED, + ); + let ty_vector = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + kind: crate::ScalarKind::Float, + width, + }, + }, + Span::UNDEFINED, + ); + + let handle = self.types.insert( + crate::Type { + name: Some("RayDesc".to_string()), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("flags".to_string()), + ty: ty_flag, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("cull_mask".to_string()), + ty: ty_flag, + binding: None, + offset: 4, + }, + crate::StructMember { + name: Some("tmin".to_string()), + ty: ty_scalar, + binding: None, + offset: 8, + }, + crate::StructMember { + name: Some("tmax".to_string()), + ty: ty_scalar, + binding: None, + offset: 12, + }, + crate::StructMember { + name: Some("origin".to_string()), + ty: ty_vector, + binding: None, + offset: 16, + }, + crate::StructMember { + name: Some("dir".to_string()), + ty: ty_vector, + binding: None, + offset: 32, + }, + ], + span: 48, + }, + }, + Span::UNDEFINED, + ); + + self.special_types.ray_desc = Some(handle); + handle + } + + /// Populate this module's [`SpecialTypes::ray_intersection`] type. + /// + /// [`SpecialTypes::ray_intersection`] is the type of a + /// `RayQueryGetIntersection` expression. In WGSL, it is a struct type + /// referred to as `RayIntersection`. + /// + /// Backends construct values of this type based on platform APIs, so if you + /// change any its fields, you must update the backends to match. Look for + /// the backend's handling for [`Expression::RayQueryGetIntersection`]. + /// + /// [`SpecialTypes::ray_intersection`]: crate::SpecialTypes::ray_intersection + /// [`Expression::RayQueryGetIntersection`]: crate::Expression::RayQueryGetIntersection + pub fn generate_ray_intersection_type(&mut self) -> Handle { + if let Some(handle) = self.special_types.ray_intersection { + return handle; + } + + let width = 4; + let ty_flag = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + width, + kind: crate::ScalarKind::Uint, + }, + }, + Span::UNDEFINED, + ); + let ty_scalar = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + width, + kind: crate::ScalarKind::Float, + }, + }, + Span::UNDEFINED, + ); + let ty_barycentrics = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + width, + size: crate::VectorSize::Bi, + kind: crate::ScalarKind::Float, + }, + }, + Span::UNDEFINED, + ); + let ty_bool = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + width: crate::BOOL_WIDTH, + kind: crate::ScalarKind::Bool, + }, + }, + Span::UNDEFINED, + ); + let ty_transform = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width, + }, + }, + Span::UNDEFINED, + ); + + let handle = self.types.insert( + crate::Type { + name: Some("RayIntersection".to_string()), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("kind".to_string()), + ty: ty_flag, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("t".to_string()), + ty: ty_scalar, + binding: None, + offset: 4, + }, + crate::StructMember { + name: Some("instance_custom_index".to_string()), + ty: ty_flag, + binding: None, + offset: 8, + }, + crate::StructMember { + name: Some("instance_id".to_string()), + ty: ty_flag, + binding: None, + offset: 12, + }, + crate::StructMember { + name: Some("sbt_record_offset".to_string()), + ty: ty_flag, + binding: None, + offset: 16, + }, + crate::StructMember { + name: Some("geometry_index".to_string()), + ty: ty_flag, + binding: None, + offset: 20, + }, + crate::StructMember { + name: Some("primitive_index".to_string()), + ty: ty_flag, + binding: None, + offset: 24, + }, + crate::StructMember { + name: Some("barycentrics".to_string()), + ty: ty_barycentrics, + binding: None, + offset: 28, + }, + crate::StructMember { + name: Some("front_face".to_string()), + ty: ty_bool, + binding: None, + offset: 36, + }, + crate::StructMember { + name: Some("object_to_world".to_string()), + ty: ty_transform, + binding: None, + offset: 48, + }, + crate::StructMember { + name: Some("world_to_object".to_string()), + ty: ty_transform, + binding: None, + offset: 112, + }, + ], + span: 176, + }, + }, + Span::UNDEFINED, + ); + + self.special_types.ray_intersection = Some(handle); + handle + } +} diff --git a/src/front/wgsl/error.rs b/src/front/wgsl/error.rs index a4e6540237..2e71a76624 100644 --- a/src/front/wgsl/error.rs +++ b/src/front/wgsl/error.rs @@ -188,6 +188,7 @@ pub enum Error<'a> { MissingAttribute(&'static str, Span), InvalidAtomicPointer(Span), InvalidAtomicOperandType(Span), + InvalidRayQueryPointer(Span), Pointer(&'static str, Span), NotPointer(Span), NotReference(&'static str, Span), @@ -526,6 +527,11 @@ impl<'a> Error<'a> { labels: vec![(span, "atomic operand type is invalid".into())], notes: vec![], }, + Error::InvalidRayQueryPointer(span) => ParseError { + message: "ray query operation is done on a pointer to a non-ray-query".to_string(), + labels: vec![(span, "ray query pointer is invalid".into())], + notes: vec![], + }, Error::NotPointer(span) => ParseError { message: "the operand of the `*` operator must be a pointer".to_string(), labels: vec![(span, "expression is not a pointer".into())], diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index f3b157caa7..bc9cce1bee 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -234,14 +234,7 @@ impl<'a> ExpressionContext<'a, '_, '_> { /// [`self.resolved_inner(handle)`]: ExpressionContext::resolved_inner /// [`Typifier`]: Typifier fn grow_types(&mut self, handle: Handle) -> Result<&mut Self, Error<'a>> { - let resolve_ctx = ResolveContext { - constants: &self.module.constants, - types: &self.module.types, - global_vars: &self.module.global_variables, - local_vars: self.local_vars, - functions: &self.module.functions, - arguments: self.arguments, - }; + let resolve_ctx = ResolveContext::with_locals(self.module, self.local_vars, self.arguments); self.typifier .grow(handle, self.naga_expressions, &resolve_ctx) .map_err(Error::InvalidResolve)?; @@ -644,14 +637,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { module: &mut module, }; - for decl in self.index.visit_ordered() { - let span = tu.decls.get_span(decl); - let decl = &tu.decls[decl]; + for decl_handle in self.index.visit_ordered() { + let span = tu.decls.get_span(decl_handle); + let decl = &tu.decls[decl_handle]; match decl.kind { ast::GlobalDeclKind::Fn(ref f) => { - let decl = self.function(f, span, ctx.reborrow())?; - ctx.globals.insert(f.name.name, decl); + let lowered_decl = self.function(f, span, ctx.reborrow())?; + ctx.globals.insert(f.name.name, lowered_decl); } ast::GlobalDeclKind::Var(ref v) => { let ty = self.resolve_ast_type(v.ty, ctx.reborrow())?; @@ -1733,50 +1726,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let expression = match *ctx.resolved_inner(value) { crate::TypeInner::Scalar { kind, width } => { - let bool_ty = ctx.module.types.insert( - crate::Type { - name: None, - inner: crate::TypeInner::Scalar { - kind: crate::ScalarKind::Bool, - width: crate::BOOL_WIDTH, - }, - }, - Span::UNDEFINED, - ); - let scalar_ty = ctx.module.types.insert( - crate::Type { - name: None, - inner: crate::TypeInner::Scalar { kind, width }, - }, - Span::UNDEFINED, - ); - let struct_ty = ctx.module.types.insert( - crate::Type { - name: Some( - "__atomic_compare_exchange_result".to_string(), - ), - inner: crate::TypeInner::Struct { - members: vec![ - crate::StructMember { - name: Some("old_value".to_string()), - ty: scalar_ty, - binding: None, - offset: 0, - }, - crate::StructMember { - name: Some("exchanged".to_string()), - ty: bool_ty, - binding: None, - offset: 4, - }, - ], - span: 8, - }, - }, - Span::UNDEFINED, - ); crate::Expression::AtomicResult { - ty: struct_ty, + //TODO: cache this to avoid generating duplicate types + ty: ctx + .module + .generate_atomic_compare_exchange_result(kind, width), comparison: true, } } @@ -1919,6 +1873,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { query: crate::ImageQuery::NumSamples, } } + "rayQueryInitialize" => { + let mut args = ctx.prepare_args(arguments, 3, span); + let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?; + let acceleration_structure = + self.expression(args.next()?, ctx.reborrow())?; + let descriptor = self.expression(args.next()?, ctx.reborrow())?; + args.finish()?; + + let _ = ctx.module.generate_ray_desc_type(); + let fun = crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + }; + + ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions)); + ctx.emitter.start(ctx.naga_expressions); + ctx.block + .push(crate::Statement::RayQuery { query, fun }, span); + return Ok(None); + } + "rayQueryProceed" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?; + args.finish()?; + + ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions)); + let result = ctx + .naga_expressions + .append(crate::Expression::RayQueryProceedResult, span); + let fun = crate::RayQueryFunction::Proceed { result }; + + ctx.emitter.start(ctx.naga_expressions); + ctx.block + .push(crate::Statement::RayQuery { query, fun }, span); + return Ok(Some(result)); + } + "rayQueryGetCommittedIntersection" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?; + args.finish()?; + + let _ = ctx.module.generate_ray_intersection_type(); + + crate::Expression::RayQueryGetIntersection { + query, + committed: true, + } + } + "RayDesc" => { + let ty = ctx.module.generate_ray_desc_type(); + let handle = self.construct( + span, + &ast::ConstructorType::Type(ty), + function.span, + arguments, + ctx.reborrow(), + )?; + return Ok(Some(handle)); + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; @@ -2245,6 +2258,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { class, }, ast::Type::Sampler { comparison } => crate::TypeInner::Sampler { comparison }, + ast::Type::AccelerationStructure => crate::TypeInner::AccelerationStructure, + ast::Type::RayQuery => crate::TypeInner::RayQuery, ast::Type::BindingArray { base, size } => { let base = self.resolve_ast_type(base, ctx.reborrow())?; @@ -2259,6 +2274,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, } } + ast::Type::RayDesc => { + return Ok(ctx.module.generate_ray_desc_type()); + } + ast::Type::RayIntersection => { + return Ok(ctx.module.generate_ray_intersection_type()); + } ast::Type::User(ref ident) => { return match ctx.globals.get(ident.name) { Some(&LoweredGlobalDecl::Type(handle)) => Ok(handle), @@ -2378,4 +2399,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { binding } + + fn ray_query_pointer( + &mut self, + expr: Handle>, + mut ctx: ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let pointer = self.expression(expr, ctx.reborrow())?; + + ctx.grow_types(pointer)?; + match *ctx.resolved_inner(pointer) { + crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { + crate::TypeInner::RayQuery => Ok(pointer), + ref other => { + log::error!("Pointer type to {:?} passed to ray query op", other); + Err(Error::InvalidRayQueryPointer(span)) + } + }, + ref other => { + log::error!("Type {:?} passed to ray query op", other); + Err(Error::InvalidRayQueryPointer(span)) + } + } + } } diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 8ac82fe45e..eb21fae6c9 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -206,6 +206,8 @@ impl crate::TypeInner { format!("texture{class_suffix}{dim_suffix}{array_suffix}{type_in_brackets}") } Ti::Sampler { .. } => "sampler".to_string(), + Ti::AccelerationStructure => "acceleration_structure".to_string(), + Ti::RayQuery => "ray_query".to_string(), Ti::BindingArray { base, size, .. } => { let member_type = &types[base]; let base = member_type.name.as_deref().unwrap_or("unknown"); diff --git a/src/front/wgsl/parse/ast.rs b/src/front/wgsl/parse/ast.rs index 734d9769fe..2a56ac6f80 100644 --- a/src/front/wgsl/parse/ast.rs +++ b/src/front/wgsl/parse/ast.rs @@ -229,6 +229,10 @@ pub enum Type<'a> { Sampler { comparison: bool, }, + AccelerationStructure, + RayQuery, + RayDesc, + RayIntersection, BindingArray { base: Handle>, size: ArraySize<'a>, diff --git a/src/front/wgsl/parse/mod.rs b/src/front/wgsl/parse/mod.rs index 7ff762d673..7a030259b8 100644 --- a/src/front/wgsl/parse/mod.rs +++ b/src/front/wgsl/parse/mod.rs @@ -622,6 +622,18 @@ impl Parser { let num = res.map_err(|err| Error::BadNumber(span, err))?; ast::Expression::Literal(ast::Literal::Number(num)) } + (Token::Word("RAY_FLAG_NONE"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Number(Number::U32(0))) + } + (Token::Word("RAY_FLAG_TERMINATE_ON_FIRST_HIT"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Number(Number::U32(4))) + } + (Token::Word("RAY_QUERY_INTERSECTION_NONE"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Number(Number::U32(0))) + } (Token::Word(word), span) => { let start = lexer.start_byte_offset(); let _ = lexer.next(); @@ -1367,6 +1379,10 @@ impl Parser { class: crate::ImageClass::Storage { format, access }, } } + "acceleration_structure" => ast::Type::AccelerationStructure, + "ray_query" => ast::Type::RayQuery, + "RayDesc" => ast::Type::RayDesc, + "RayIntersection" => ast::Type::RayIntersection, _ => return Ok(None), })) } diff --git a/src/lib.rs b/src/lib.rs index c1b48b8991..a70015d16d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -107,6 +107,11 @@ Naga's rules for when `Expression`s are evaluated are as follows: [`Atomic`] statement, representing the result of the atomic operation, is evaluated when the `Atomic` statement is executed. +- A [`RayQueryProceedResult`] expression, which is a boolean + indicating if the ray query is finished, is evaluated when the + [`RayQuery`] statement whose [`Proceed::result`] points to it is + executed. + - All other expressions are evaluated when the (unique) [`Statement::Emit`] statement that covers them is executed. @@ -166,6 +171,7 @@ need to be stored in a local variable to be carried upwards in the statement tree. [`AtomicResult`]: Expression::AtomicResult +[`RayQueryProceedResult`]: Expression::RayQueryProceedResult [`CallResult`]: Expression::CallResult [`Constant`]: Expression::Constant [`Derivative`]: Expression::Derivative @@ -180,6 +186,9 @@ tree. [`Call`]: Statement::Call [`Emit`]: Statement::Emit [`Store`]: Statement::Store +[`RayQuery`]: Statement::RayQuery + +[`Proceed::result`]: RayQueryFunction::Proceed::result [`Validator::validate`]: valid::Validator::validate [`ModuleInfo`]: valid::ModuleInfo @@ -721,6 +730,12 @@ pub enum TypeInner { /// Can be used to sample values from images. Sampler { comparison: bool }, + /// Opaque object representing an acceleration structure of geometry. + AccelerationStructure, + + /// Locally used handle for ray queries. + RayQuery, + /// Array of bindings. /// /// A `BindingArray` represents an array where each element draws its value @@ -1436,6 +1451,20 @@ pub enum Expression { /// This doesn't match the semantics of spirv's `OpArrayLength`, which must be passed /// a pointer to a structure containing a runtime array in its' last field. ArrayLength(Handle), + + /// Result of a [`Proceed`] [`RayQuery`] statement. + /// + /// [`Proceed`]: RayQueryFunction::Proceed + /// [`RayQuery`]: Statement::RayQuery + RayQueryProceedResult, + + /// Return an intersection found by `query`. + /// + /// If `committed` is true, return the committed result available when + RayQueryGetIntersection { + query: Handle, + committed: bool, + }, } pub use block::Block; @@ -1468,6 +1497,48 @@ pub struct SwitchCase { pub fall_through: bool, } +/// An operation that a [`RayQuery` statement] applies to its [`query`] operand. +/// +/// [`RayQuery` statement]: Statement::RayQuery +/// [`query`]: Statement::RayQuery::query +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum RayQueryFunction { + /// Initialize the `RayQuery` object. + Initialize { + /// The acceleration structure within which this query should search for hits. + /// + /// The expression must be an [`AccelerationStructure`]. + /// + /// [`AccelerationStructure`]: TypeInner::AccelerationStructure + acceleration_structure: Handle, + + #[allow(rustdoc::private_intra_doc_links)] + /// A struct of detailed parameters for the ray query. + /// + /// This expression should have the struct type given in + /// [`SpecialTypes::ray_desc`]. This is available in the WGSL + /// front end as the `RayDesc` type. + descriptor: Handle, + }, + + /// Start or continue the query given by the statement's [`query`] operand. + /// + /// After executing this statement, the `result` expression is a + /// [`Bool`] scalar indicating whether there are more intersection + /// candidates to consider. + /// + /// [`query`]: Statement::RayQuery::query + /// [`Bool`]: ScalarKind::Bool + Proceed { + result: Handle, + }, + + Terminate, +} + //TODO: consider removing `Clone`. It's not valid to clone `Statement::Emit` anyway. /// Instructions which make up an executable block. // Clone is used only for error reporting and is not intended for end users @@ -1641,6 +1712,15 @@ pub enum Statement { arguments: Vec>, result: Option>, }, + RayQuery { + /// The [`RayQuery`] object this statement operates on. + /// + /// [`RayQuery`]: TypeInner::RayQuery + query: Handle, + + /// The specific operation we're performing on `query`. + fun: RayQueryFunction, + }, } /// A function argument. @@ -1757,6 +1837,26 @@ pub struct EntryPoint { pub function: Function, } +/// Set of special types that can be optionally generated by the frontends. +#[derive(Debug, Default)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct SpecialTypes { + /// Type for `RayDesc`. + /// + /// Call [`Module::generate_ray_desc_type`] to populate this if + /// needed and return the handle. + pub ray_desc: Option>, + + /// Type for `RayIntersection`. + /// + /// Call [`Module::generate_ray_intersection_type`] to populate + /// this if needed and return the handle. + pub ray_intersection: Option>, +} + /// Shader module. /// /// A module is a set of constants, global variables and functions, as well as @@ -1776,6 +1876,8 @@ pub struct EntryPoint { pub struct Module { /// Arena for the types defined in this module. pub types: UniqueArena, + /// Dictionary of special type handles. + pub special_types: SpecialTypes, /// Arena for the constants defined in this module. pub constants: Arena, /// Arena for the global variables defined in this module. diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index db07f261a4..65369d1cc8 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -238,7 +238,11 @@ impl Layouter { alignment, } } - Ti::Image { .. } | Ti::Sampler { .. } | Ti::BindingArray { .. } => TypeLayout { + Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => TypeLayout { size, alignment: Alignment::ONE, }, diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 6a8bfa03c7..a775272a19 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -134,7 +134,11 @@ impl super::TypeInner { count * stride } Self::Struct { span, .. } => span, - Self::Image { .. } | Self::Sampler { .. } | Self::BindingArray { .. } => 0, + Self::Image { .. } + | Self::Sampler { .. } + | Self::AccelerationStructure + | Self::RayQuery + | Self::BindingArray { .. } => 0, } } diff --git a/src/proc/terminator.rs b/src/proc/terminator.rs index 5915616cc5..ca0c3f10bc 100644 --- a/src/proc/terminator.rs +++ b/src/proc/terminator.rs @@ -34,6 +34,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::Store { .. } | S::ImageStore { .. } | S::Call { .. } + | S::RayQuery { .. } | S::Atomic { .. } | S::Barrier(_)), ) diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index b9ac468313..0bb9019a29 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -193,11 +193,14 @@ pub enum ResolveError { IncompatibleOperands(String), #[error("Function argument {0} doesn't exist")] FunctionArgumentNotFound(u32), + #[error("Special type is not registered within the module")] + MissingSpecialType, } pub struct ResolveContext<'a> { pub constants: &'a Arena, pub types: &'a UniqueArena, + pub special_types: &'a crate::SpecialTypes, pub global_vars: &'a Arena, pub local_vars: &'a Arena, pub functions: &'a Arena, @@ -205,6 +208,23 @@ pub struct ResolveContext<'a> { } impl<'a> ResolveContext<'a> { + /// Initialize a resolve context from the module. + pub const fn with_locals( + module: &'a crate::Module, + local_vars: &'a Arena, + arguments: &'a [crate::FunctionArgument], + ) -> Self { + Self { + constants: &module.constants, + types: &module.types, + special_types: &module.special_types, + global_vars: &module.global_variables, + local_vars, + functions: &module.functions, + arguments, + } + } + /// Determine the type of `expr`. /// /// The `past` argument must be a closure that can resolve the types of any @@ -867,6 +887,17 @@ impl<'a> ResolveContext<'a> { kind: crate::ScalarKind::Uint, width: 4, }), + crate::Expression::RayQueryProceedResult => TypeResolution::Value(Ti::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + }), + crate::Expression::RayQueryGetIntersection { .. } => { + let result = self + .special_types + .ray_intersection + .ok_or(ResolveError::MissingSpecialType)?; + TypeResolution::Handle(result) + } }) } } diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 40d5f95c10..e9b155b6eb 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -686,7 +686,7 @@ impl FunctionInfo { requirements: UniformityRequirements::empty(), }, E::CallResult(function) => other_functions[function.index()].uniformity.clone(), - E::AtomicResult { .. } => Uniformity { + E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity { non_uniform_result: Some(handle), requirements: UniformityRequirements::empty(), }, @@ -694,6 +694,13 @@ impl FunctionInfo { non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY), requirements: UniformityRequirements::empty(), }, + E::RayQueryGetIntersection { + query, + committed: _, + } => Uniformity { + non_uniform_result: self.add_ref(query), + requirements: UniformityRequirements::empty(), + }, }; let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; @@ -893,6 +900,18 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::RayQuery { query, ref fun } => { + let _ = self.add_ref(query); + if let crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } = *fun + { + let _ = self.add_ref(acceleration_structure); + let _ = self.add_ref(descriptor); + } + FunctionUniformity::new() + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); @@ -922,14 +941,8 @@ impl ModuleInfo { expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(), sampling: crate::FastHashSet::default(), }; - let resolve_context = ResolveContext { - constants: &module.constants, - types: &module.types, - global_vars: &module.global_variables, - local_vars: &fun.local_variables, - functions: &module.functions, - arguments: &fun.arguments, - }; + let resolve_context = + ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments); for (handle, expr) in fun.expressions.iter() { if let Err(source) = info.process_expression( @@ -1052,6 +1065,7 @@ fn uniform_control_flow() { let resolve_context = ResolveContext { constants: &constant_arena, types: &type_arena, + special_types: &crate::SpecialTypes::default(), global_vars: &global_var_arena, local_vars: &Arena::new(), functions: &Arena::new(), diff --git a/src/valid/expression.rs b/src/valid/expression.rs index af080fc183..1a91fe4d0a 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -35,6 +35,8 @@ pub enum ExpressionError { InvalidPointerType(Handle), #[error("Array length of {0:?} can't be done")] InvalidArrayType(Handle), + #[error("Get intersection of {0:?} can't be done")] + InvalidRayQueryType(Handle), #[error("Splatting {0:?} can't be done")] InvalidSplatType(Handle), #[error("Swizzling {0:?} can't be done")] @@ -1379,7 +1381,7 @@ impl super::Validator { _ => return Err(ExpressionError::InvalidCastArgument), }; let width = convert.unwrap_or(base_width); - if !self.check_width(kind, width) { + if self.check_width(kind, width).is_err() { return Err(ExpressionError::InvalidCastArgument); } ShaderStages::all() @@ -1390,7 +1392,7 @@ impl super::Validator { &crate::TypeInner::Scalar { kind: kind @ (crate::ScalarKind::Uint | crate::ScalarKind::Sint), width, - } => self.check_width(kind, width), + } => self.check_width(kind, width).is_ok(), _ => false, }; let good = match &module.types[ty].inner { @@ -1427,6 +1429,26 @@ impl super::Validator { return Err(ExpressionError::InvalidArrayType(expr)); } }, + E::RayQueryProceedResult => ShaderStages::all(), + E::RayQueryGetIntersection { + query, + committed: _, + } => match resolver[query] { + Ti::Pointer { + base, + space: crate::AddressSpace::Function, + } => match resolver.types[base].inner { + Ti::RayQuery => ShaderStages::all(), + ref other => { + log::error!("Intersection result of a pointer to {:?}", other); + return Err(ExpressionError::InvalidRayQueryType(query)); + } + }, + ref other => { + log::error!("Intersection result of {:?}", other); + return Err(ExpressionError::InvalidRayQueryType(query)); + } + }, }; Ok(stages) } diff --git a/src/valid/function.rs b/src/valid/function.rs index 464496f6d6..737f33dc28 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -47,8 +47,6 @@ pub enum AtomicError { InvalidPointer(Handle), #[error("Operand {0:?} has invalid type.")] InvalidOperand(Handle), - #[error("Result expression {0:?} has already been introduced earlier")] - ResultAlreadyInScope(Handle), #[error("Result type for {0:?} doesn't match the statement")] ResultTypeMismatch(Handle), } @@ -131,6 +129,14 @@ pub enum FunctionError { }, #[error("Atomic operation is invalid")] InvalidAtomic(#[from] AtomicError), + #[error("Ray Query {0:?} is not a local variable")] + InvalidRayQueryExpression(Handle), + #[error("Acceleration structure {0:?} is not a matching expression")] + InvalidAccelerationStructure(Handle), + #[error("Ray descriptor {0:?} is not a matching expression")] + InvalidRayDescriptor(Handle), + #[error("Ray Query {0:?} does not have a matching type")] + InvalidRayQueryType(Handle), #[error( "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}" )] @@ -169,8 +175,10 @@ struct BlockContext<'a> { info: &'a FunctionInfo, expressions: &'a Arena, types: &'a UniqueArena, + local_vars: &'a Arena, global_vars: &'a Arena, functions: &'a Arena, + special_types: &'a crate::SpecialTypes, prev_infos: &'a [FunctionInfo], return_type: Option>, } @@ -188,8 +196,10 @@ impl<'a> BlockContext<'a> { info, expressions: &fun.expressions, types: &module.types, + local_vars: &fun.local_variables, global_vars: &module.global_variables, functions: &module.functions, + special_types: &module.special_types, prev_infos, return_type: fun.result.as_ref().map(|fr| fr.ty), } @@ -299,6 +309,21 @@ impl super::Validator { Ok(callee_info.available_stages) } + #[cfg(feature = "validate")] + fn emit_expression( + &mut self, + handle: Handle, + context: &BlockContext, + ) -> Result<(), WithSpan> { + if self.valid_expression_set.insert(handle.index()) { + self.valid_expression_list.push(handle); + Ok(()) + } else { + Err(FunctionError::ExpressionAlreadyInScope(handle) + .with_span_handle(handle, context.expressions)) + } + } + #[cfg(feature = "validate")] fn validate_atomic( &mut self, @@ -347,13 +372,7 @@ impl super::Validator { } } - if self.valid_expression_set.insert(result.index()) { - self.valid_expression_list.push(result); - } else { - return Err(AtomicError::ResultAlreadyInScope(result) - .with_span_handle(result, context.expressions) - .into_other()); - } + self.emit_expression(result, context)?; match context.expressions[result] { crate::Expression::AtomicResult { ty, comparison } if { @@ -401,12 +420,7 @@ impl super::Validator { match *statement { S::Emit(ref range) => { for handle in range.clone() { - if self.valid_expression_set.insert(handle.index()) { - self.valid_expression_list.push(handle); - } else { - return Err(FunctionError::ExpressionAlreadyInScope(handle) - .with_span_handle(handle, context.expressions)); - } + self.emit_expression(handle, context)?; } } S::Block(ref block) => { @@ -807,6 +821,56 @@ impl super::Validator { } => { self.validate_atomic(pointer, fun, value, result, context)?; } + S::RayQuery { query, ref fun } => { + let query_var = match *context.get_expression(query) { + crate::Expression::LocalVariable(var) => &context.local_vars[var], + ref other => { + log::error!("Unexpected ray query expression {other:?}"); + return Err(FunctionError::InvalidRayQueryExpression(query) + .with_span_static(span, "invalid query expression")); + } + }; + match context.types[query_var.ty].inner { + Ti::RayQuery => {} + ref other => { + log::error!("Unexpected ray query type {other:?}"); + return Err(FunctionError::InvalidRayQueryType(query_var.ty) + .with_span_static(span, "invalid query type")); + } + } + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + match *context + .resolve_type(acceleration_structure, &self.valid_expression_set)? + { + Ti::AccelerationStructure => {} + _ => { + return Err(FunctionError::InvalidAccelerationStructure( + acceleration_structure, + ) + .with_span_static(span, "invalid acceleration structure")) + } + } + let desc_ty_given = + context.resolve_type(descriptor, &self.valid_expression_set)?; + let desc_ty_expected = context + .special_types + .ray_desc + .map(|handle| &context.types[handle].inner); + if Some(desc_ty_given) != desc_ty_expected { + return Err(FunctionError::InvalidRayDescriptor(descriptor) + .with_span_static(span, "invalid ray descriptor")); + } + } + crate::RayQueryFunction::Proceed { result } => { + self.emit_expression(result, context)?; + } + crate::RayQueryFunction::Terminate => {} + } + } } } Ok(BlockInfo { stages, finished }) diff --git a/src/valid/handles.rs b/src/valid/handles.rs index e3f9fe2531..fdd43cd585 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -1,4 +1,4 @@ -//! Implementation of [`super::Validator::validate_module_handles`]. +//! Implementation of `Validator::validate_module_handles`. use crate::{ arena::{BadHandle, BadRangeError}, @@ -39,6 +39,7 @@ impl super::Validator { ref functions, ref global_variables, ref types, + ref special_types, } = module; // NOTE: Types being first is important. All other forms of validation depend on this. @@ -76,7 +77,9 @@ impl super::Validator { | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Atomic { .. } | crate::TypeInner::Image { .. } - | crate::TypeInner::Sampler { .. } => (), + | crate::TypeInner::Sampler { .. } + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => (), crate::TypeInner::Pointer { base, space: _ } => { this_handle.check_dep(base)?; } @@ -192,6 +195,13 @@ impl super::Validator { validate_function(Some(function_handle), function)?; } + if let Some(ty) = special_types.ray_desc { + validate_type(ty)?; + } + if let Some(ty) = special_types.ray_intersection { + validate_type(ty)?; + } + Ok(()) } @@ -377,10 +387,16 @@ impl super::Validator { handle.check_dep(function)?; } } - crate::Expression::AtomicResult { .. } => (), + crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult => (), crate::Expression::ArrayLength(array) => { handle.check_dep(array)?; } + crate::Expression::RayQueryGetIntersection { + query, + committed: _, + } => { + handle.check_dep(query)?; + } } Ok(()) } @@ -494,6 +510,23 @@ impl super::Validator { validate_expr_opt(result)?; Ok(()) } + crate::Statement::RayQuery { query, ref fun } => { + validate_expr(query)?; + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + validate_expr(acceleration_structure)?; + validate_expr(descriptor)?; + } + crate::RayQueryFunction::Proceed { result } => { + validate_expr(result)?; + } + crate::RayQueryFunction::Terminate => {} + } + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 13dbd75761..d9ee9f5402 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -440,7 +440,9 @@ impl super::Validator { match types[var.ty].inner { crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } - | crate::TypeInner::BindingArray { .. } => {} + | crate::TypeInner::BindingArray { .. } + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => {} _ => { return Err(GlobalVariableError::InvalidType(var.space)); } diff --git a/src/valid/mod.rs b/src/valid/mod.rs index eb92e8892d..6b3a2e1456 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -111,6 +111,8 @@ bitflags::bitflags! { const EARLY_DEPTH_TEST = 0x400; /// Support for [`Builtin::SampleIndex`] and [`Sampling::Sample`]. const MULTISAMPLED_SHADING = 0x800; + /// Support for ray queries and acceleration structures. + const RAY_QUERY = 0x1000; } } @@ -238,6 +240,8 @@ impl crate::TypeInner { Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } + | Self::AccelerationStructure + | Self::RayQuery | Self::BindingArray { .. } => false, } } @@ -302,7 +306,7 @@ impl Validator { let con = &constants[handle]; match con.inner { crate::ConstantInner::Scalar { width, ref value } => { - if !self.check_width(value.scalar_kind(), width) { + if self.check_width(value.scalar_kind(), width).is_err() { return Err(ConstantError::InvalidType); } } diff --git a/src/valid/type.rs b/src/valid/type.rs index 4fcc1a1c58..d8dd37d09b 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -90,6 +90,8 @@ pub enum Disalignment { #[derive(Clone, Debug, thiserror::Error)] pub enum TypeError { + #[error("Capability {0:?} is required")] + MissingCapability(Capabilities), #[error("The {0:?} scalar width {1} is not supported")] InvalidWidth(crate::ScalarKind, crate::Bytes), #[error("The {0:?} scalar width {1} is not supported for an atomic")] @@ -203,13 +205,35 @@ impl TypeInfo { } impl super::Validator { - pub(super) const fn check_width(&self, kind: crate::ScalarKind, width: crate::Bytes) -> bool { - match kind { + const fn require_type_capability(&self, capability: Capabilities) -> Result<(), TypeError> { + if self.capabilities.contains(capability) { + Ok(()) + } else { + Err(TypeError::MissingCapability(capability)) + } + } + + pub(super) fn check_width( + &self, + kind: crate::ScalarKind, + width: crate::Bytes, + ) -> Result<(), TypeError> { + let good = match kind { crate::ScalarKind::Bool => width == crate::BOOL_WIDTH, crate::ScalarKind::Float => { - width == 4 || (width == 8 && self.capabilities.contains(Capabilities::FLOAT64)) + if width == 8 { + self.require_type_capability(Capabilities::FLOAT64)?; + true + } else { + width == 4 + } } crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4, + }; + if good { + Ok(()) + } else { + Err(TypeError::InvalidWidth(kind, width)) } } @@ -228,9 +252,7 @@ impl super::Validator { use crate::TypeInner as Ti; Ok(match types[handle].inner { Ti::Scalar { kind, width } => { - if !self.check_width(kind, width) { - return Err(TypeError::InvalidWidth(kind, width)); - } + self.check_width(kind, width)?; let shareable = if kind.is_numeric() { TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE } else { @@ -247,9 +269,7 @@ impl super::Validator { ) } Ti::Vector { size, kind, width } => { - if !self.check_width(kind, width) { - return Err(TypeError::InvalidWidth(kind, width)); - } + self.check_width(kind, width)?; let shareable = if kind.is_numeric() { TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE } else { @@ -271,9 +291,7 @@ impl super::Validator { rows, width, } => { - if !self.check_width(crate::ScalarKind::Float, width) { - return Err(TypeError::InvalidWidth(crate::ScalarKind::Float, width)); - } + self.check_width(crate::ScalarKind::Float, width)?; TypeInfo::new( TypeFlags::DATA | TypeFlags::SIZED @@ -355,9 +373,7 @@ impl super::Validator { // However, some cases are trivial: All our implicit base types // are DATA and SIZED, so we can never return // `InvalidPointerBase` or `InvalidPointerToUnsized`. - if !self.check_width(kind, width) { - return Err(TypeError::InvalidWidth(kind, width)); - } + self.check_width(kind, width)?; // `Validator::validate_function` actually checks the storage // space of pointer arguments explicitly before checking the @@ -606,6 +622,14 @@ impl super::Validator { Ti::Image { .. } | Ti::Sampler { .. } => { TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE) } + Ti::AccelerationStructure => { + self.require_type_capability(Capabilities::RAY_QUERY)?; + TypeInfo::new(TypeFlags::empty(), Alignment::ONE) + } + Ti::RayQuery => { + self.require_type_capability(Capabilities::RAY_QUERY)?; + TypeInfo::new(TypeFlags::DATA | TypeFlags::SIZED, Alignment::ONE) + } Ti::BindingArray { .. } => TypeInfo::new(TypeFlags::empty(), Alignment::ONE), }) } diff --git a/tests/in/ray-query.param.ron b/tests/in/ray-query.param.ron new file mode 100644 index 0000000000..c400db8c64 --- /dev/null +++ b/tests/in/ray-query.param.ron @@ -0,0 +1,14 @@ +( + god_mode: true, + spv: ( + version: (1, 4), + ), + msl: ( + lang_version: (2, 4), + spirv_cross_compatibility: false, + fake_missing_bindings: true, + zero_initialize_workgroup_memory: false, + per_entry_point_map: {}, + inline_samplers: [], + ), +) diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl new file mode 100644 index 0000000000..4826547ded --- /dev/null +++ b/tests/in/ray-query.wgsl @@ -0,0 +1,73 @@ +@group(0) @binding(0) +var acc_struct: acceleration_structure; + +/* +let RAY_FLAG_NONE = 0x00u; +let RAY_FLAG_OPAQUE = 0x01u; +let RAY_FLAG_NO_OPAQUE = 0x02u; +let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 0x04u; +let RAY_FLAG_SKIP_CLOSEST_HIT_SHADER = 0x08u; +let RAY_FLAG_CULL_BACK_FACING = 0x10u; +let RAY_FLAG_CULL_FRONT_FACING = 0x20u; +let RAY_FLAG_CULL_OPAQUE = 0x40u; +let RAY_FLAG_CULL_NO_OPAQUE = 0x80u; +let RAY_FLAG_SKIP_TRIANGLES = 0x100u; +let RAY_FLAG_SKIP_AABBS = 0x200u; + +let RAY_QUERY_INTERSECTION_NONE = 0u; +let RAY_QUERY_INTERSECTION_TRIANGLE = 1u; +let RAY_QUERY_INTERSECTION_GENERATED = 2u; +let RAY_QUERY_INTERSECTION_AABB = 4u; + +struct RayDesc { + flags: u32, + cull_mask: u32, + t_min: f32, + t_max: f32, + origin: vec3, + dir: vec3, +} + +struct RayIntersection { + kind: u32, + t: f32, + instance_custom_index: u32, + instance_id: u32, + sbt_record_offset: u32, + geometry_index: u32, + primitive_index: u32, + barycentrics: vec2, + front_face: bool, + object_to_world: mat4x3, + world_to_object: mat4x3, +} +*/ + +struct Output { + visible: u32, + normal: vec3, +} + +@group(0) @binding(1) +var output: Output; + +fn get_torus_normal(world_point: vec3, intersection: RayIntersection) -> vec3 { + let local_point = intersection.world_to_object * vec4(world_point, 1.0); + let point_on_guiding_line = normalize(local_point.xy) * 2.4; + let world_point_on_guiding_line = intersection.object_to_world * vec4(point_on_guiding_line, 0.0, 1.0); + return normalize(world_point - world_point_on_guiding_line); +} + +@compute @workgroup_size(1) +fn main() { + var rq: ray_query; + + let dir = vec3(0.0, 1.0, 0.0); + rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3(0.0), dir)); + + while (rayQueryProceed(&rq)) {} + + let intersection = rayQueryGetCommittedIntersection(&rq); + output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE); + output.normal = get_torus_normal(dir * intersection.t, intersection); +} diff --git a/tests/out/ir/access.ron b/tests/out/ir/access.ron index e544ee1a5d..41772b9332 100644 --- a/tests/out/ir/access.ron +++ b/tests/out/ir/access.ron @@ -333,6 +333,10 @@ ), ), ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ), constants: [ ( name: None, diff --git a/tests/out/ir/collatz.ron b/tests/out/ir/collatz.ron index 00cab8e885..1be31e6eff 100644 --- a/tests/out/ir/collatz.ron +++ b/tests/out/ir/collatz.ron @@ -38,6 +38,10 @@ ), ), ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ), constants: [ ( name: None, diff --git a/tests/out/ir/shadow.ron b/tests/out/ir/shadow.ron index 8956076ef3..9311f9e188 100644 --- a/tests/out/ir/shadow.ron +++ b/tests/out/ir/shadow.ron @@ -286,6 +286,10 @@ ), ), ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ), constants: [ ( name: None, diff --git a/tests/out/msl/binding-arrays.msl b/tests/out/msl/binding-arrays.msl index da1078b5d8..694f79452d 100644 --- a/tests/out/msl/binding-arrays.msl +++ b/tests/out/msl/binding-arrays.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + struct UniformIndex { uint index; }; diff --git a/tests/out/msl/bounds-check-image-rzsw.msl b/tests/out/msl/bounds-check-image-rzsw.msl index 9032af14ca..eeb03c9849 100644 --- a/tests/out/msl/bounds-check-image-rzsw.msl +++ b/tests/out/msl/bounds-check-image-rzsw.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + constant metal::int2 const_type_4_ = {0, 0}; constant metal::int3 const_type_7_ = {0, 0, 0}; constant metal::float4 const_type_2_ = {0.0, 0.0, 0.0, 0.0}; diff --git a/tests/out/msl/bounds-check-zero-atomic.msl b/tests/out/msl/bounds-check-zero-atomic.msl index 95028ee796..daaa079233 100644 --- a/tests/out/msl/bounds-check-zero-atomic.msl +++ b/tests/out/msl/bounds-check-zero-atomic.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + struct _mslBufferSizes { uint size0; }; diff --git a/tests/out/msl/bounds-check-zero.msl b/tests/out/msl/bounds-check-zero.msl index fece92de35..816983d98b 100644 --- a/tests/out/msl/bounds-check-zero.msl +++ b/tests/out/msl/bounds-check-zero.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + struct _mslBufferSizes { uint size0; }; diff --git a/tests/out/msl/policy-mix.msl b/tests/out/msl/policy-mix.msl index 842c57e58c..7eb0c61ede 100644 --- a/tests/out/msl/policy-mix.msl +++ b/tests/out/msl/policy-mix.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + struct type_1 { metal::float4 inner[10]; }; diff --git a/tests/out/msl/ray-query.msl b/tests/out/msl/ray-query.msl new file mode 100644 index 0000000000..dc24f80674 --- /dev/null +++ b/tests/out/msl/ray-query.msl @@ -0,0 +1,79 @@ +// language: metal2.4 +#include +#include + +using metal::uint; +struct _RayQuery { + metal::raytracing::intersector intersector; + metal::raytracing::intersector::result_type intersection; + bool ready = false; +}; +constexpr metal::uint _map_intersection_type(const metal::raytracing::intersection_type ty) { + return ty==metal::raytracing::intersection_type::triangle ? 1 : + ty==metal::raytracing::intersection_type::bounding_box ? 4 : 0; +} + +struct Output { + uint visible_; + char _pad1[12]; + metal::float3 normal; +}; +struct RayIntersection { + uint kind; + float t; + uint instance_custom_index; + uint instance_id; + uint sbt_record_offset; + uint geometry_index; + uint primitive_index; + metal::float2 barycentrics; + bool front_face; + char _pad9[11]; + metal::float4x3 object_to_world; + metal::float4x3 world_to_object; +}; +struct RayDesc { + uint flags; + uint cull_mask; + float tmin; + float tmax; + metal::float3 origin; + metal::float3 dir; +}; + +metal::float3 get_torus_normal( + metal::float3 world_point, + RayIntersection intersection +) { + metal::float3 local_point = intersection.world_to_object * metal::float4(world_point, 1.0); + metal::float2 point_on_guiding_line = metal::normalize(local_point.xy) * 2.4000000953674316; + metal::float3 world_point_on_guiding_line = intersection.object_to_world * metal::float4(point_on_guiding_line, 0.0, 1.0); + return metal::normalize(world_point - world_point_on_guiding_line); +} + +kernel void main_( + metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]] +, device Output& output [[user(fake0)]] +) { + _RayQuery rq = {}; + metal::float3 dir = metal::float3(0.0, 1.0, 0.0); + RayDesc _e12 = RayDesc {4u, 255u, 0.10000000149011612, 100.0, metal::float3(0.0), dir}; + rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle); + rq.intersector.set_opacity_cull_mode((_e12.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e12.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none); + rq.intersector.force_opacity((_e12.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e12.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none); + rq.intersector.accept_any_intersection((_e12.flags & 4) != 0); + rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e12.origin, _e12.dir, _e12.tmin, _e12.tmax), acc_struct, _e12.cull_mask); rq.ready = true; + while(true) { + bool _e13 = rq.ready; + rq.ready = false; + if (_e13) { + } else { + break; + } + } + RayIntersection intersection_1 = RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform}; + output.visible_ = static_cast(intersection_1.kind == 0u); + metal::float3 _e25 = get_torus_normal(dir * intersection_1.t, intersection_1); + output.normal = _e25; + return; +} diff --git a/tests/out/msl/resource-binding-map.msl b/tests/out/msl/resource-binding-map.msl index 4e0b601320..b4a53d97b5 100644 --- a/tests/out/msl/resource-binding-map.msl +++ b/tests/out/msl/resource-binding-map.msl @@ -3,7 +3,6 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { @@ -11,6 +10,7 @@ struct DefaultConstructible { } }; + struct entry_point_oneInput { }; struct entry_point_oneOutput { diff --git a/tests/out/spv/ray-query.spvasm b/tests/out/spv/ray-query.spvasm new file mode 100644 index 0000000000..31dc7d75e6 --- /dev/null +++ b/tests/out/spv/ray-query.spvasm @@ -0,0 +1,152 @@ +; SPIR-V +; Version: 1.4 +; Generator: rspirv +; Bound: 95 +OpCapability RayQueryKHR +OpCapability Shader +OpExtension "SPV_KHR_ray_query" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %48 "main" %23 %25 +OpExecutionMode %48 LocalSize 1 1 1 +OpMemberDecorate %15 0 Offset 0 +OpMemberDecorate %15 1 Offset 16 +OpMemberDecorate %19 0 Offset 0 +OpMemberDecorate %19 1 Offset 4 +OpMemberDecorate %19 2 Offset 8 +OpMemberDecorate %19 3 Offset 12 +OpMemberDecorate %19 4 Offset 16 +OpMemberDecorate %19 5 Offset 20 +OpMemberDecorate %19 6 Offset 24 +OpMemberDecorate %19 7 Offset 28 +OpMemberDecorate %19 8 Offset 36 +OpMemberDecorate %19 9 Offset 48 +OpMemberDecorate %19 9 ColMajor +OpMemberDecorate %19 9 MatrixStride 16 +OpMemberDecorate %19 10 Offset 112 +OpMemberDecorate %19 10 ColMajor +OpMemberDecorate %19 10 MatrixStride 16 +OpMemberDecorate %22 0 Offset 0 +OpMemberDecorate %22 1 Offset 4 +OpMemberDecorate %22 2 Offset 8 +OpMemberDecorate %22 3 Offset 12 +OpMemberDecorate %22 4 Offset 16 +OpMemberDecorate %22 5 Offset 32 +OpDecorate %23 DescriptorSet 0 +OpDecorate %23 Binding 0 +OpDecorate %25 DescriptorSet 0 +OpDecorate %25 Binding 1 +OpDecorate %26 Block +OpMemberDecorate %26 0 Offset 0 +%2 = OpTypeVoid +%4 = OpTypeFloat 32 +%3 = OpConstant %4 1.0 +%5 = OpConstant %4 2.4 +%6 = OpConstant %4 0.0 +%8 = OpTypeInt 32 0 +%7 = OpConstant %8 4 +%9 = OpConstant %8 255 +%10 = OpConstant %4 0.1 +%11 = OpConstant %4 100.0 +%12 = OpConstant %8 0 +%13 = OpTypeAccelerationStructureNV +%14 = OpTypeVector %4 3 +%15 = OpTypeStruct %8 %14 +%16 = OpTypeVector %4 2 +%17 = OpTypeBool +%18 = OpTypeMatrix %14 4 +%19 = OpTypeStruct %8 %4 %8 %8 %8 %8 %8 %16 %17 %18 %18 +%20 = OpTypeVector %4 4 +%21 = OpTypeRayQueryKHR +%22 = OpTypeStruct %8 %8 %4 %4 %14 %14 +%24 = OpTypePointer UniformConstant %13 +%23 = OpVariable %24 UniformConstant +%26 = OpTypeStruct %15 +%27 = OpTypePointer StorageBuffer %26 +%25 = OpVariable %27 StorageBuffer +%32 = OpTypeFunction %14 %14 %19 +%46 = OpTypePointer Function %21 +%49 = OpTypeFunction %2 +%51 = OpTypePointer StorageBuffer %15 +%72 = OpConstant %8 1 +%85 = OpTypePointer StorageBuffer %8 +%90 = OpTypePointer StorageBuffer %14 +%31 = OpFunction %14 None %32 +%29 = OpFunctionParameter %14 +%30 = OpFunctionParameter %19 +%28 = OpLabel +OpBranch %33 +%33 = OpLabel +%34 = OpCompositeExtract %18 %30 10 +%35 = OpCompositeConstruct %20 %29 %3 +%36 = OpMatrixTimesVector %14 %34 %35 +%37 = OpVectorShuffle %16 %36 %36 0 1 +%38 = OpExtInst %16 %1 Normalize %37 +%39 = OpVectorTimesScalar %16 %38 %5 +%40 = OpCompositeExtract %18 %30 9 +%41 = OpCompositeConstruct %20 %39 %6 %3 +%42 = OpMatrixTimesVector %14 %40 %41 +%43 = OpFSub %14 %29 %42 +%44 = OpExtInst %14 %1 Normalize %43 +OpReturnValue %44 +OpFunctionEnd +%48 = OpFunction %2 None %49 +%47 = OpLabel +%45 = OpVariable %46 Function +%50 = OpLoad %13 %23 +%52 = OpAccessChain %51 %25 %12 +OpBranch %53 +%53 = OpLabel +%54 = OpCompositeConstruct %14 %6 %3 %6 +%55 = OpCompositeConstruct %14 %6 %6 %6 +%56 = OpCompositeConstruct %22 %7 %9 %10 %11 %55 %54 +%57 = OpCompositeExtract %8 %56 0 +%58 = OpCompositeExtract %8 %56 1 +%59 = OpCompositeExtract %4 %56 2 +%60 = OpCompositeExtract %4 %56 3 +%61 = OpCompositeExtract %14 %56 4 +%62 = OpCompositeExtract %14 %56 5 +OpRayQueryInitializeKHR %45 %50 %57 %58 %61 %59 %62 %60 +OpBranch %63 +%63 = OpLabel +OpLoopMerge %64 %66 None +OpBranch %65 +%65 = OpLabel +%67 = OpRayQueryProceedKHR %17 %45 +OpSelectionMerge %68 None +OpBranchConditional %67 %68 %69 +%69 = OpLabel +OpBranch %64 +%68 = OpLabel +OpBranch %70 +%70 = OpLabel +OpBranch %71 +%71 = OpLabel +OpBranch %66 +%66 = OpLabel +OpBranch %63 +%64 = OpLabel +%73 = OpRayQueryGetIntersectionTypeKHR %8 %45 %72 +%74 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %8 %45 %72 +%75 = OpRayQueryGetIntersectionInstanceIdKHR %8 %45 %72 +%76 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %8 %45 %72 +%77 = OpRayQueryGetIntersectionGeometryIndexKHR %8 %45 %72 +%78 = OpRayQueryGetIntersectionPrimitiveIndexKHR %8 %45 %72 +%79 = OpRayQueryGetIntersectionTKHR %4 %45 %72 +%80 = OpRayQueryGetIntersectionBarycentricsKHR %16 %45 %72 +%81 = OpRayQueryGetIntersectionFrontFaceKHR %17 %45 %72 +%82 = OpRayQueryGetIntersectionObjectToWorldKHR %18 %45 %72 +%83 = OpRayQueryGetIntersectionWorldToObjectKHR %18 %45 %72 +%84 = OpCompositeConstruct %19 %73 %79 %74 %75 %76 %77 %78 %80 %81 %82 %83 +%86 = OpCompositeExtract %8 %84 0 +%87 = OpIEqual %17 %86 %12 +%88 = OpSelect %8 %87 %72 %12 +%89 = OpAccessChain %85 %52 %12 +OpStore %89 %88 +%91 = OpCompositeExtract %4 %84 1 +%92 = OpVectorTimesScalar %14 %54 %91 +%93 = OpFunctionCall %14 %31 %92 %84 +%94 = OpAccessChain %90 %52 %72 +OpStore %94 %93 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/atomicCompareExchange.wgsl b/tests/out/wgsl/atomicCompareExchange.wgsl index 2c213c8fec..bfad298fab 100644 --- a/tests/out/wgsl/atomicCompareExchange.wgsl +++ b/tests/out/wgsl/atomicCompareExchange.wgsl @@ -1,9 +1,9 @@ -struct gen___atomic_compare_exchange_result { +struct gen___atomic_compare_exchange_resultSint4_ { old_value: i32, exchanged: bool, } -struct gen___atomic_compare_exchange_result_1 { +struct gen___atomic_compare_exchange_resultUint4_ { old_value: u32, exchanged: bool, } diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 691c074a93..d968a0dfc1 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -571,6 +571,7 @@ fn convert_wgsl() { ("sprite", Targets::SPIRV), ("force_point_size_vertex_shader_webgl", Targets::GLSL), ("invariant", Targets::GLSL), + ("ray-query", Targets::SPIRV | Targets::METAL), ]; for &(name, targets) in inputs.iter() {