diff --git a/Cargo.lock b/Cargo.lock index a13aefcf3f..6fa89bf5b0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1884,6 +1884,7 @@ dependencies = [ "rustc-demangle", "serde", "serde_json", + "smallvec", "spirv-tools", "tar", "tempfile", diff --git a/crates/rustc_codegen_spirv/Cargo.toml b/crates/rustc_codegen_spirv/Cargo.toml index 22c89f053b..d3abee849b 100644 --- a/crates/rustc_codegen_spirv/Cargo.toml +++ b/crates/rustc_codegen_spirv/Cargo.toml @@ -11,6 +11,7 @@ repository = "https://github.com/EmbarkStudios/rust-gpu" [lib] crate-type = ["dylib"] +test = false [features] # By default, the use-compiled-tools is enabled, as doesn't require additional @@ -33,6 +34,7 @@ rspirv = { git = "https://github.com/gfx-rs/rspirv.git", rev = "279cc519166b6a0b rustc-demangle = "0.1.18" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +smallvec = "1.6.1" spirv-tools = { version = "0.4.0", default-features = false } tar = "0.4.30" topological-sort = "0.1" diff --git a/crates/rustc_codegen_spirv/src/abi.rs b/crates/rustc_codegen_spirv/src/abi.rs index 1f10d1e5c4..65b5867539 100644 --- a/crates/rustc_codegen_spirv/src/abi.rs +++ b/crates/rustc_codegen_spirv/src/abi.rs @@ -27,19 +27,12 @@ use std::fmt; /// tracking. #[derive(Default)] pub struct RecursivePointeeCache<'tcx> { - map: RefCell, StorageClass), PointeeDefState>>, + map: RefCell, PointeeDefState>>, } impl<'tcx> RecursivePointeeCache<'tcx> { - fn begin( - &self, - cx: &CodegenCx<'tcx>, - span: Span, - pointee: PointeeTy<'tcx>, - storage_class: StorageClass, - ) -> Option { - // Warning: storage_class must match the one called with end() - match self.map.borrow_mut().entry((pointee, storage_class)) { + fn begin(&self, cx: &CodegenCx<'tcx>, span: Span, pointee: PointeeTy<'tcx>) -> Option { + match self.map.borrow_mut().entry(pointee) { // State: This is the first time we've seen this type. Record that we're beginning to translate this type, // and start doing the translation. Entry::Vacant(entry) => { @@ -52,7 +45,11 @@ impl<'tcx> RecursivePointeeCache<'tcx> { // emit an OpTypeForwardPointer, and use that ID. (This is the juicy part of this algorithm) PointeeDefState::Defining => { let new_id = cx.emit_global().id(); - cx.emit_global().type_forward_pointer(new_id, storage_class); + // NOTE(eddyb) we emit `StorageClass::Generic` here, but later + // the linker will specialize the entire SPIR-V module to use + // storage classes inferred from `OpVariable`s. + cx.emit_global() + .type_forward_pointer(new_id, StorageClass::Generic); entry.insert(PointeeDefState::DefiningWithForward(new_id)); if !cx.builder.has_capability(Capability::Addresses) && !cx @@ -81,11 +78,9 @@ impl<'tcx> RecursivePointeeCache<'tcx> { cx: &CodegenCx<'tcx>, span: Span, pointee: PointeeTy<'tcx>, - storage_class: StorageClass, pointee_spv: Word, ) -> Word { - // Warning: storage_class must match the one called with begin() - match self.map.borrow_mut().entry((pointee, storage_class)) { + match self.map.borrow_mut().entry(pointee) { // We should have hit begin() on this type already, which always inserts an entry. Entry::Vacant(_) => bug!("RecursivePointeeCache::end should always have entry"), Entry::Occupied(mut entry) => match *entry.get() { @@ -93,7 +88,6 @@ impl<'tcx> RecursivePointeeCache<'tcx> { // OpTypeForwardPointer has been emitted. This is the most common case. PointeeDefState::Defining => { let id = SpirvType::Pointer { - storage_class, pointee: pointee_spv, } .def(span, cx); @@ -105,7 +99,6 @@ impl<'tcx> RecursivePointeeCache<'tcx> { PointeeDefState::DefiningWithForward(id) => { entry.insert(PointeeDefState::Defined(id)); SpirvType::Pointer { - storage_class, pointee: pointee_spv, } .def_with_id(cx, span, id) @@ -261,11 +254,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> { PassMode::Cast(cast_target) => cast_target.spirv_type(span, cx), PassMode::Indirect { .. } => { let pointee = self.ret.layout.spirv_type(span, cx); - let pointer = SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee, - } - .def(span, cx); + let pointer = SpirvType::Pointer { pointee }.def(span, cx); // Important: the return pointer comes *first*, not last. argument_types.push(pointer); SpirvType::Void.def(span, cx) @@ -304,11 +293,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> { extra_attrs: None, .. } => { let pointee = arg.layout.spirv_type(span, cx); - SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee, - } - .def(span, cx) + SpirvType::Pointer { pointee }.def(span, cx) } }; argument_types.push(arg_type); @@ -465,26 +450,20 @@ fn trans_scalar<'tcx>( Primitive::F32 => SpirvType::Float(32).def(span, cx), Primitive::F64 => SpirvType::Float(64).def(span, cx), Primitive::Pointer => { - let (storage_class, pointee_ty) = dig_scalar_pointee(cx, ty, index); - // Default to function storage class. - let storage_class = storage_class.unwrap_or(StorageClass::Function); + let pointee_ty = dig_scalar_pointee(cx, ty, index); // Pointers can be recursive. So, record what we're currently translating, and if we're already translating // the same type, emit an OpTypeForwardPointer and use that ID. - if let Some(predefined_result) = - cx.type_cache - .recursive_pointee_cache - .begin(cx, span, pointee_ty, storage_class) + if let Some(predefined_result) = cx + .type_cache + .recursive_pointee_cache + .begin(cx, span, pointee_ty) { predefined_result } else { let pointee = pointee_ty.spirv_type(span, cx); - cx.type_cache.recursive_pointee_cache.end( - cx, - span, - pointee_ty, - storage_class, - pointee, - ) + cx.type_cache + .recursive_pointee_cache + .end(cx, span, pointee_ty, pointee) } } } @@ -504,12 +483,12 @@ fn dig_scalar_pointee<'tcx>( cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>, index: Option, -) -> (Option, PointeeTy<'tcx>) { +) -> PointeeTy<'tcx> { match *ty.ty.kind() { TyKind::Ref(_, elem_ty, _) | TyKind::RawPtr(TypeAndMut { ty: elem_ty, .. }) => { let elem = cx.layout_of(elem_ty); match index { - None => (None, PointeeTy::Ty(elem)), + None => PointeeTy::Ty(elem), Some(index) => { if elem.is_unsized() { dig_scalar_pointee(cx, ty.field(cx, index), None) @@ -518,12 +497,12 @@ fn dig_scalar_pointee<'tcx>( // of ScalarPair could be deduced, but it's actually e.g. a sized pointer followed by some other // completely unrelated type, not a wide pointer. So, translate this as a single scalar, one // component of that ScalarPair. - (None, PointeeTy::Ty(elem)) + PointeeTy::Ty(elem) } } } } - TyKind::FnPtr(sig) if index.is_none() => (None, PointeeTy::Fn(sig)), + TyKind::FnPtr(sig) if index.is_none() => PointeeTy::Fn(sig), TyKind::Adt(def, _) if def.is_box() => { let ptr_ty = cx.layout_of(cx.tcx.mk_mut_ptr(ty.ty.boxed_ty())); dig_scalar_pointee(cx, ptr_ty, index) @@ -542,11 +521,8 @@ fn dig_scalar_pointee_adt<'tcx>( cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>, index: Option, -) -> (Option, PointeeTy<'tcx>) { - // Storage classes can only be applied on structs containing a single pointer field (because we said so), so we only - // need to handle the attribute here. - let storage_class = get_storage_class(cx, ty); - let result = match &ty.variants { +) -> PointeeTy<'tcx> { + match &ty.variants { // If it's a Variants::Multiple, then we want to emit the type of the dataful variant, not the type of the // discriminant. This is because the discriminant can e.g. have type *mut(), whereas we want the full underlying // type, only available in the dataful variant. @@ -603,21 +579,16 @@ fn dig_scalar_pointee_adt<'tcx>( }, } } - }; - match (storage_class, result) { - (storage_class, (None, result)) => (storage_class, result), - (None, (storage_class, result)) => (storage_class, result), - (Some(one), (Some(two), _)) => cx.tcx.sess.fatal(&format!( - "Double-applied storage class ({:?} and {:?}) on type {}", - one, two, ty.ty - )), } } -/// Handles `#[spirv(storage_class="blah")]`. Note this is only called in the scalar translation code, because this is only +/// Handles `#[spirv(storage_class="blah")]`. Note this is only called in the entry interface variables code, because this is only /// used for spooky builtin stuff, and we pinky promise to never have more than one pointer field in one of these. // TODO: Enforce this is only used in spirv-std. -fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Option { +pub(crate) fn get_storage_class<'tcx>( + cx: &CodegenCx<'tcx>, + ty: TyAndLayout<'tcx>, +) -> Option { if let TyKind::Adt(adt, _substs) = ty.ty.kind() { for attr in parse_attrs(cx, cx.tcx.get_attrs(adt.did)) { if let SpirvAttribute::StorageClass(storage_class) = attr { diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 9db323ca16..e517887a74 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -727,11 +727,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn alloca(&mut self, ty: Self::Type, _align: Align) -> Self::Value { - let ptr_ty = SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee: ty, - } - .def(self.span(), self); + let ptr_ty = SpirvType::Pointer { pointee: ty }.def(self.span(), self); // "All OpVariable instructions in a function must be the first instructions in the first block." let mut builder = self.emit(); builder.select_block(Some(0)).unwrap(); @@ -779,10 +775,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { return value; } let ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => self.fatal(&format!( "load called on variable that wasn't a pointer: {:?}", ty @@ -803,10 +796,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn atomic_load(&mut self, ptr: Self::Value, order: AtomicOrdering, _size: Size) -> Self::Value { let ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => self.fatal(&format!( "atomic_load called on variable that wasn't a pointer: {:?}", ty @@ -906,10 +896,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn store(&mut self, val: Self::Value, ptr: Self::Value, _align: Align) -> Self::Value { let ptr_elem_ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => self.fatal(&format!( "store called on variable that wasn't a pointer: {:?}", ty @@ -946,10 +933,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { _size: Size, ) { let ptr_elem_ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => self.fatal(&format!( "atomic_store called on variable that wasn't a pointer: {:?}", ty @@ -979,15 +963,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn struct_gep(&mut self, ptr: Self::Value, idx: u64) -> Self::Value { - let (storage_class, result_pointee_type) = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { - storage_class, - pointee, - } => match self.lookup_type(pointee) { - SpirvType::Adt { field_types, .. } => (storage_class, field_types[idx as usize]), + let result_pointee_type = match self.lookup_type(ptr.ty) { + SpirvType::Pointer { pointee } => match self.lookup_type(pointee) { + SpirvType::Adt { field_types, .. } => field_types[idx as usize], SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element, .. } - | SpirvType::Vector { element, .. } => (storage_class, element), + | SpirvType::Vector { element, .. } => element, other => self.fatal(&format!( "struct_gep not on struct, array, or vector type: {:?}, index {}", other, idx @@ -999,7 +980,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { )), }; let result_type = SpirvType::Pointer { - storage_class, pointee: result_pointee_type, } .def(self.span(), self); @@ -1225,14 +1205,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn pointercast(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value { let val_pointee = match self.lookup_type(val.ty) { - SpirvType::Pointer { pointee, .. } => pointee, + SpirvType::Pointer { pointee } => pointee, other => self.fatal(&format!( "pointercast called on non-pointer source type: {:?}", other )), }; let dest_pointee = match self.lookup_type(dest_ty) { - SpirvType::Pointer { pointee, .. } => pointee, + SpirvType::Pointer { pointee } => pointee, other => self.fatal(&format!( "pointercast called on non-pointer dest type: {:?}", other @@ -1249,12 +1229,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { .access_chain(dest_ty, None, val.def(self), indices) .unwrap() .with_type(dest_ty) - } else if self - .really_unsafe_ignore_bitcasts - .borrow() - .contains(&self.current_fn) - { - val } else { let result = self .emit() @@ -1531,7 +1505,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { return; } let src_pointee = match self.lookup_type(src.ty) { - SpirvType::Pointer { pointee, .. } => Some(pointee), + SpirvType::Pointer { pointee } => Some(pointee), _ => None, }; let src_element_size = src_pointee.and_then(|p| self.lookup_type(p).sizeof(self)); @@ -1592,7 +1566,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { )); } let elem_ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { pointee, .. } => pointee, + SpirvType::Pointer { pointee } => pointee, _ => self.fatal(&format!( "memset called on non-pointer type: {}", self.debug_type(ptr.ty) @@ -1780,10 +1754,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { _weak: bool, ) -> Self::Value { let dst_pointee_ty = match self.lookup_type(dst.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => self.fatal(&format!( "atomic_cmpxchg called on variable that wasn't a pointer: {:?}", ty @@ -1820,10 +1791,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { order: AtomicOrdering, ) -> Self::Value { let dst_pointee_ty = match self.lookup_type(dst.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => self.fatal(&format!( "atomic_rmw called on variable that wasn't a pointer: {:?}", ty diff --git a/crates/rustc_codegen_spirv/src/builder/intrinsics.rs b/crates/rustc_codegen_spirv/src/builder/intrinsics.rs index 0ca218fa30..af7a172288 100644 --- a/crates/rustc_codegen_spirv/src/builder/intrinsics.rs +++ b/crates/rustc_codegen_spirv/src/builder/intrinsics.rs @@ -3,7 +3,7 @@ use crate::abi::ConvSpirvType; use crate::builder_spirv::{SpirvValue, SpirvValueExt}; use crate::codegen_cx::CodegenCx; use crate::spirv_type::SpirvType; -use rspirv::spirv::{CLOp, GLOp, StorageClass}; +use rspirv::spirv::{CLOp, GLOp}; use rustc_codegen_ssa::mir::operand::OperandRef; use rustc_codegen_ssa::mir::place::PlaceRef; use rustc_codegen_ssa::traits::{BuilderMethods, IntrinsicCallMethods}; @@ -103,11 +103,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { let mut ptr = args[0].immediate(); if let PassMode::Cast(ty) = fn_abi.ret.mode { let pointee = ty.spirv_type(self.span(), self); - let pointer = SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee, - } - .def(self.span(), self); + let pointer = SpirvType::Pointer { pointee }.def(self.span(), self); ptr = self.pointercast(ptr, pointer); } let load = self.volatile_load(ptr); @@ -383,6 +379,9 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { if self.kernel_mode { self.cl_op(CLOp::ctz, ret_ty, [args[0].immediate()]) } else { + self.ext_inst + .borrow_mut() + .import_integer_functions_2_intel(self); self.emit() .u_count_trailing_zeros_intel( args[0].immediate().ty, diff --git a/crates/rustc_codegen_spirv/src/builder/mod.rs b/crates/rustc_codegen_spirv/src/builder/mod.rs index 2895e8a241..5f74b76d3a 100644 --- a/crates/rustc_codegen_spirv/src/builder/mod.rs +++ b/crates/rustc_codegen_spirv/src/builder/mod.rs @@ -12,7 +12,7 @@ use crate::abi::ConvSpirvType; use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt}; use crate::codegen_cx::CodegenCx; use crate::spirv_type::SpirvType; -use rspirv::spirv::{StorageClass, Word}; +use rspirv::spirv::Word; use rustc_codegen_ssa::mir::operand::OperandValue; use rustc_codegen_ssa::mir::place::PlaceRef; use rustc_codegen_ssa::traits::{ @@ -104,11 +104,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // "An OpAccessChain instruction is the equivalent of an LLVM getelementptr instruction where the first index element is zero." // https://github.com/gpuweb/gpuweb/issues/33 let mut result_indices = Vec::with_capacity(indices.len() - 1); - let (storage_class, mut result_pointee_type) = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { - storage_class, - pointee, - } => (storage_class, pointee), + let mut result_pointee_type = match self.lookup_type(ptr.ty) { + SpirvType::Pointer { pointee } => pointee, other_type => self.fatal(&format!( "GEP first deref not implemented for type {:?}", other_type @@ -125,7 +122,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { }; } let result_type = SpirvType::Pointer { - storage_class, pointee: result_pointee_type, } .def(self.span(), self); @@ -330,11 +326,7 @@ impl<'a, 'tcx> ArgAbiMethods<'tcx> for Builder<'a, 'tcx> { self.fatal("unsized `ArgAbi` must be handled through `store_fn_arg`"); } else if let PassMode::Cast(cast) = arg_abi.mode { let cast_ty = cast.spirv_type(self.span(), self); - let cast_ptr_ty = SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee: cast_ty, - } - .def(self.span(), self); + let cast_ptr_ty = SpirvType::Pointer { pointee: cast_ty }.def(self.span(), self); let cast_dst = self.pointercast(dst.llval, cast_ptr_ty); self.store(val, cast_dst, arg_abi.layout.align.abi); } else { diff --git a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs index 10da830cfc..13c0c4caf1 100644 --- a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs +++ b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs @@ -263,11 +263,25 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { element: inst.operands[0].unwrap_id_ref(), } .def(self.span(), self), - Op::TypePointer => SpirvType::Pointer { - storage_class: inst.operands[0].unwrap_storage_class(), - pointee: inst.operands[1].unwrap_id_ref(), + Op::TypePointer => { + let storage_class = inst.operands[0].unwrap_storage_class(); + if storage_class != StorageClass::Generic { + self.struct_err("TypePointer in asm! requires `Generic` storage class") + .note(&format!( + "`{:?}` storage class was specified", + storage_class + )) + .help(&format!( + "the storage class will be inferred automatically (e.g. to `{:?}`)", + storage_class + )) + .emit(); + } + SpirvType::Pointer { + pointee: inst.operands[1].unwrap_id_ref(), + } + .def(self.span(), self) } - .def(self.span(), self), Op::TypeImage => SpirvType::Image { sampled_type: inst.operands[0].unwrap_id_ref(), dim: inst.operands[1].unwrap_dim(), @@ -511,26 +525,28 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { use crate::spirv_type_constraints::{instruction_signatures, InstSig, TyListPat, TyPat}; #[derive(Debug)] - struct Mismatch; + struct Unapplicable; /// Recursively match `ty` against `pat`, returning one of: /// * `Ok(None)`: `pat` matched but contained no type variables /// * `Ok(Some(var))`: `pat` matched and `var` is the type variable /// * `Err(Mismatch)`: `pat` didn't match or isn't supported right now - fn apply_ty_pat( + fn match_ty_pat( cx: &CodegenCx<'_>, pat: &TyPat<'_>, ty: Word, - ) -> Result, Mismatch> { + ) -> Result, Unapplicable> { match pat { TyPat::Any => Ok(None), &TyPat::T => Ok(Some(ty)), TyPat::Either(a, b) => { - apply_ty_pat(cx, a, ty).or_else(|Mismatch| apply_ty_pat(cx, b, ty)) + match_ty_pat(cx, a, ty).or_else(|Unapplicable| match_ty_pat(cx, b, ty)) } _ => match (pat, cx.lookup_type(ty)) { + (TyPat::Any, _) | (&TyPat::T, _) | (TyPat::Either(..), _) => unreachable!(), + (TyPat::Void, SpirvType::Void) => Ok(None), - (TyPat::Pointer(pat), SpirvType::Pointer { pointee: ty, .. }) + (TyPat::Pointer(_, pat), SpirvType::Pointer { pointee: ty, .. }) | (TyPat::Vector(pat), SpirvType::Vector { element: ty, .. }) | ( TyPat::Vector4(pat), @@ -546,17 +562,19 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { }, ) | (TyPat::SampledImage(pat), SpirvType::SampledImage { image_type: ty }) => { - apply_ty_pat(cx, pat, ty) + match_ty_pat(cx, pat, ty) } - _ => Err(Mismatch), + _ => Err(Unapplicable), }, } } // FIXME(eddyb) try multiple signatures until one fits. let mut sig = match instruction_signatures(instruction.class.opcode)? { - [sig @ InstSig { - output: Some(_), .. + [sig + @ InstSig { + output_type: Some(_), + .. }] => *sig, _ => return None, }; @@ -564,9 +582,20 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { let mut combined_var = None; let mut ids = instruction.operands.iter().filter_map(|o| o.id_ref_any()); - while let TyListPat::Cons { first: pat, suffix } = *sig.inputs { - let &ty = id_to_type_map.get(&ids.next()?)?; - match apply_ty_pat(self, pat, ty) { + while let TyListPat::Cons { first: pat, suffix } = *sig.input_types { + sig.input_types = suffix; + + let match_result = match id_to_type_map.get(&ids.next()?) { + Some(&ty) => match_ty_pat(self, pat, ty), + + // Non-value ID operand (or value operand of unknown type), + // only `TyPat::Any` is valid. + None => match pat { + TyPat::Any => Ok(None), + _ => Err(Unapplicable), + }, + }; + match match_result { Ok(Some(var)) => match combined_var { Some(combined_var) => { // FIXME(eddyb) this could use some error reporting @@ -580,11 +609,12 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { None => combined_var = Some(var), }, Ok(None) => {} - Err(Mismatch) => return None, + Err(Unapplicable) => return None, } - sig.inputs = suffix; } - match sig.inputs { + match sig.input_types { + TyListPat::Cons { .. } => unreachable!(), + TyListPat::Any => {} TyListPat::Nil => { if ids.next().is_some() { @@ -595,7 +625,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { } let var = combined_var?; - match sig.output.unwrap() { + match sig.output_type.unwrap() { &TyPat::T => Some(var), TyPat::Vector4(&TyPat::T) => Some( SpirvType::Vector { @@ -731,7 +761,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { Some(match kind { TypeofKind::Plain => ty, TypeofKind::Dereference => match self.lookup_type(ty) { - SpirvType::Pointer { pointee, .. } => pointee, + SpirvType::Pointer { pointee } => pointee, other => { self.tcx.sess.span_err( span, @@ -753,7 +783,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { self.check_reg(span, reg); match place { Some(place) => match self.lookup_type(place.llval.ty) { - SpirvType::Pointer { pointee, .. } => Some(pointee), + SpirvType::Pointer { pointee } => Some(pointee), other => { self.tcx.sess.span_err( span, diff --git a/crates/rustc_codegen_spirv/src/builder_spirv.rs b/crates/rustc_codegen_spirv/src/builder_spirv.rs index 12c705f050..caf3de1cdb 100644 --- a/crates/rustc_codegen_spirv/src/builder_spirv.rs +++ b/crates/rustc_codegen_spirv/src/builder_spirv.rs @@ -44,10 +44,7 @@ impl SpirvValue { global_var: _, } => { let ty = match cx.lookup_type(self.ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, ty => bug!("load called on variable that wasn't a pointer: {:?}", ty), }; Some(initializer.with_type(ty)) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs index c9bbbbdf76..0245828133 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs @@ -271,7 +271,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { let (base_addr, _base_addr_space) = match self.tcx.global_alloc(ptr.alloc_id) { GlobalAlloc::Memory(alloc) => { let pointee = match self.lookup_type(ty) { - SpirvType::Pointer { pointee, .. } => pointee, + SpirvType::Pointer { pointee } => pointee, other => self.tcx.sess.fatal(&format!( "GlobalAlloc::Memory type not implemented: {}", other.debug(ty, self) @@ -306,28 +306,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { .fatal("Non-pointer-typed scalar_to_backend Scalar::Ptr not supported"); // unsafe { llvm::LLVMConstPtrToInt(llval, llty) } } else { - match (self.lookup_type(value.ty), self.lookup_type(ty)) { - ( - SpirvType::Pointer { - storage_class: a_space, - pointee: a, - }, - SpirvType::Pointer { - storage_class: b_space, - pointee: b, - }, - ) => { - if a_space != b_space { - // TODO: Emit the correct type that is passed into this function. - self.zombie_no_span( - value.def_cx(self), - "invalid pointer space in constant", - ); - } - assert_ty_eq!(self, a, b); - } - _ => assert_ty_eq!(self, value.ty, ty), - } + assert_ty_eq!(self, value.ty, ty); value } } diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs index 6d0156fe29..2f001e14e7 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -117,11 +117,6 @@ impl<'tcx> CodegenCx<'tcx> { let crate_relative_name = instance.to_string(); self.entry_stub(&instance, &fn_abi, declared, crate_relative_name, entry) } - SpirvAttribute::ReallyUnsafeIgnoreBitcasts => { - self.really_unsafe_ignore_bitcasts - .borrow_mut() - .insert(declared); - } SpirvAttribute::UnrollLoops => { self.unroll_loops_decorations .borrow_mut() @@ -184,14 +179,11 @@ impl<'tcx> CodegenCx<'tcx> { } fn declare_global(&self, span: Span, ty: Word) -> SpirvValue { - let ptr_ty = SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee: ty, - } - .def(span, self); + let ptr_ty = SpirvType::Pointer { pointee: ty }.def(span, self); + // FIXME(eddyb) figure out what the correct storage class is. let result = self .emit_global() - .variable(ptr_ty, None, StorageClass::Function, None) + .variable(ptr_ty, None, StorageClass::Private, None) .with_type(ptr_ty); // TODO: These should be StorageClass::Private, so just zombie for now. self.zombie_with_span(result.def_cx(self), span, "Globals are not supported yet"); @@ -264,7 +256,7 @@ impl<'tcx> StaticMethods for CodegenCx<'tcx> { Err(_) => return, }; let value_ty = match self.lookup_type(g.ty) { - SpirvType::Pointer { pointee, .. } => pointee, + SpirvType::Pointer { pointee } => pointee, other => self.tcx.sess.fatal(&format!( "global had non-pointer type {}", other.debug(g.ty, self) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 08cc9bdd0d..560da694fa 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -1,10 +1,12 @@ use super::CodegenCx; +use crate::abi::ConvSpirvType; use crate::builder_spirv::SpirvValue; use crate::spirv_type::SpirvType; use crate::symbols::{parse_attrs, Entry, SpirvAttribute}; use rspirv::dr::Operand; use rspirv::spirv::{Decoration, ExecutionModel, FunctionControl, StorageClass, Word}; -use rustc_hir::{Param, PatKind}; +use rustc_hir as hir; +use rustc_middle::ty::layout::TyAndLayout; use rustc_middle::ty::{Instance, Ty}; use rustc_span::Span; use rustc_target::abi::call::{FnAbi, PassMode}; @@ -18,7 +20,7 @@ impl<'tcx> CodegenCx<'tcx> { pub fn entry_stub( &self, instance: &Instance<'_>, - fn_abi: &FnAbi<'_, Ty<'_>>, + fn_abi: &FnAbi<'tcx, Ty<'tcx>>, entry_func: SpirvValue, name: String, entry: Entry, @@ -60,6 +62,7 @@ impl<'tcx> CodegenCx<'tcx> { self.shader_entry_stub( self.tcx.def_span(instance.def_id()), entry_func, + fn_abi, body.params, name, execution_model, @@ -78,7 +81,8 @@ impl<'tcx> CodegenCx<'tcx> { &self, span: Span, entry_func: SpirvValue, - hir_params: &[Param<'tcx>], + entry_fn_abi: &FnAbi<'tcx, Ty<'tcx>>, + hir_params: &[hir::Param<'tcx>], name: String, execution_model: ExecutionModel, ) -> Word { @@ -88,11 +92,11 @@ impl<'tcx> CodegenCx<'tcx> { arguments: vec![], } .def(span, self); - let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) { + let entry_func_return_type = match self.lookup_type(entry_func.ty) { SpirvType::Function { return_type, - arguments, - } => (return_type, arguments), + arguments: _, + } => return_type, other => self.tcx.sess.fatal(&format!( "Invalid entry_stub type: {}", other.debug(entry_func.ty, self) @@ -100,11 +104,12 @@ impl<'tcx> CodegenCx<'tcx> { }; let mut decoration_locations = HashMap::new(); // Create OpVariables before OpFunction so they're global instead of local vars. - let arguments = entry_func_args + let arguments = entry_fn_abi + .args .iter() .zip(hir_params) - .map(|(&arg, hir_param)| { - self.declare_parameter(arg, hir_param, &mut decoration_locations) + .map(|(entry_fn_arg, hir_param)| { + self.declare_parameter(entry_fn_arg.layout, hir_param, &mut decoration_locations) }) .collect::>(); let mut emit = self.emit_global(); @@ -113,7 +118,7 @@ impl<'tcx> CodegenCx<'tcx> { .unwrap(); emit.begin_block(None).unwrap(); emit.function_call( - entry_func_return, + entry_func_return_type, None, entry_func.def_cx(self), arguments.iter().map(|&(a, _)| a), @@ -139,24 +144,26 @@ impl<'tcx> CodegenCx<'tcx> { fn declare_parameter( &self, - arg: Word, - hir_param: &Param<'tcx>, + layout: TyAndLayout<'tcx>, + hir_param: &hir::Param<'tcx>, decoration_locations: &mut HashMap, ) -> (Word, StorageClass) { - let storage_class = match self.lookup_type(arg) { - SpirvType::Pointer { storage_class, .. } => storage_class, - other => self.tcx.sess.fatal(&format!( - "Invalid entry arg type {}", - other.debug(arg, self) - )), - }; + let storage_class = crate::abi::get_storage_class(self, layout).unwrap_or_else(|| { + self.tcx.sess.span_fatal( + hir_param.span, + &format!("invalid entry param type `{}`", layout.ty), + ); + }); let mut has_location = matches!( storage_class, StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant ); // Note: this *declares* the variable too. - let variable = self.emit_global().variable(arg, None, storage_class, None); - if let PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind { + let spirv_type = layout.spirv_type(hir_param.span, self); + let variable = self + .emit_global() + .variable(spirv_type, None, storage_class, None); + if let hir::PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind { self.emit_global().name(variable, ident.to_string()); } for attr in parse_attrs(self, hir_param.attrs) { diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index 3158d4c433..83acd6274a 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -30,7 +30,7 @@ use rustc_target::abi::call::FnAbi; use rustc_target::abi::{HasDataLayout, TargetDataLayout}; use rustc_target::spec::{HasTargetSpec, Target}; use std::cell::{Cell, RefCell}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::iter::once; pub struct CodegenCx<'tcx> { @@ -58,7 +58,6 @@ pub struct CodegenCx<'tcx> { /// Cache of all the builtin symbols we need pub sym: Box, pub instruction_table: InstructionTable, - pub really_unsafe_ignore_bitcasts: RefCell>, pub libm_intrinsics: RefCell>, /// Simple `panic!("...")` and builtin panics (from MIR `Assert`s) call `#[lang = "panic"]`. @@ -116,7 +115,6 @@ impl<'tcx> CodegenCx<'tcx> { kernel_mode, sym, instruction_table: InstructionTable::new(), - really_unsafe_ignore_bitcasts: Default::default(), libm_intrinsics: Default::default(), panic_fn_id: Default::default(), panic_bounds_check_fn_id: Default::default(), @@ -217,18 +215,15 @@ impl<'tcx> CodegenCx<'tcx> { /// See note on `SpirvValueKind::ConstantPointer` pub fn make_constant_pointer(&self, span: Span, value: SpirvValue) -> SpirvValue { - let ty = SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee: value.ty, - } - .def(span, self); + let ty = SpirvType::Pointer { pointee: value.ty }.def(span, self); let initializer = value.def_cx(self); // Create these up front instead of on demand in SpirvValue::def because // SpirvValue::def can't use cx.emit() + // FIXME(eddyb) figure out what the correct storage class is. let global_var = self.emit_global() - .variable(ty, None, StorageClass::Function, Some(initializer)); + .variable(ty, None, StorageClass::Private, Some(initializer)); // In all likelihood, this zombie message will get overwritten in SpirvValue::def_with_span // to the use site of this constant. However, if this constant happens to never get used, we diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs b/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs index eab5588a96..281dab8e76 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs @@ -1,7 +1,6 @@ use super::CodegenCx; use crate::abi::ConvSpirvType; use crate::spirv_type::SpirvType; -use rspirv::spirv::StorageClass; use rspirv::spirv::Word; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeMethods, LayoutTypeMethods}; @@ -181,25 +180,14 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> { } } fn type_ptr_to(&self, ty: Self::Type) -> Self::Type { - SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee: ty, - } - .def(DUMMY_SP, self) + SpirvType::Pointer { pointee: ty }.def(DUMMY_SP, self) } fn type_ptr_to_ext(&self, ty: Self::Type, _address_space: AddressSpace) -> Self::Type { - SpirvType::Pointer { - storage_class: StorageClass::Function, - pointee: ty, - } - .def(DUMMY_SP, self) + SpirvType::Pointer { pointee: ty }.def(DUMMY_SP, self) } fn element_type(&self, ty: Self::Type) -> Self::Type { match self.lookup_type(ty) { - SpirvType::Pointer { - storage_class: _, - pointee, - } => pointee, + SpirvType::Pointer { pointee } => pointee, SpirvType::Vector { element, .. } => element, spirv_type => self.tcx.sess.fatal(&format!( "element_type called on invalid type: {:?}", diff --git a/crates/rustc_codegen_spirv/src/link.rs b/crates/rustc_codegen_spirv/src/link.rs index 2d22ebd659..fba0100f9e 100644 --- a/crates/rustc_codegen_spirv/src/link.rs +++ b/crates/rustc_codegen_spirv/src/link.rs @@ -124,6 +124,13 @@ fn link_exe( let spv_binary = do_link(sess, &objects, &rlibs, legalize); + if let Ok(ref path) = std::env::var("DUMP_POST_LINK") { + File::create(path) + .unwrap() + .write_all(spirv_tools::binary::from_binary(&spv_binary)) + .unwrap(); + } + let spv_binary = if sess.opts.optimize != OptLevel::No || sess.opts.debuginfo == DebugInfo::None { let _timer = sess.timer("link_spirv_opt"); diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index 1d434ae11f..f09684338b 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -9,13 +9,14 @@ mod inline; mod mem2reg; mod new_structurizer; mod simple_passes; +mod specializer; mod structurizer; mod zombies; use crate::decorations::{CustomDecoration, UnrollLoopsDecoration}; use rspirv::binary::Consumer; -use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader}; -use rspirv::spirv::{Op, Word}; +use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader, Operand}; +use rspirv::spirv::{Op, StorageClass, Word}; use rustc_errors::ErrorReported; use rustc_session::Session; use std::collections::HashMap; @@ -106,6 +107,17 @@ pub fn link(sess: &Session, mut inputs: Vec, opts: &Options) -> Result, opts: &Options) -> Result>, +//! ``` + +use crate::spirv_type_constraints::{self, InstSig, StorageClassPat, TyListPat, TyPat}; +use indexmap::{IndexMap, IndexSet}; +use rspirv::dr::{Builder, Function, Instruction, Module, Operand}; +use rspirv::spirv::{Op, StorageClass, Word}; +use rustc_data_structures::captures::Captures; +use smallvec::SmallVec; +use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; +use std::convert::{TryFrom, TryInto}; +use std::ops::{Range, RangeTo}; +use std::{fmt, io, iter, mem, slice}; + +// FIXME(eddyb) move this elsewhere. +struct FmtBy) -> fmt::Result>(F); + +impl) -> fmt::Result> fmt::Debug for FmtBy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0(f) + } +} + +impl) -> fmt::Result> fmt::Display for FmtBy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0(f) + } +} + +pub trait Specialization { + /// Return `true` if the specializer should replace every occurence of + /// `operand` with some other inferred `Operand`. + fn specialize_operand(&self, operand: &Operand) -> bool; + + /// The operand that should be used to replace unresolved inference variables, + /// i.e. the uses of operands for which `specialize_operand` returns `true`, + /// but which none of the instructions in the same SPIR-V function require + /// any particular concrete value or relate it to the function's signature, + /// so an arbitrary choice can be made (as long as it's valid SPIR-V etc.). + fn concrete_fallback(&self) -> Operand; +} + +/// Helper to avoid needing an `impl` of `Specialization`, while allowing the rest +/// of this module to use `Specialization` (instead of `Fn(&Operand) -> bool`). +pub struct SimpleSpecialization bool> { + pub specialize_operand: SO, + pub concrete_fallback: Operand, +} + +impl bool> Specialization for SimpleSpecialization { + fn specialize_operand(&self, operand: &Operand) -> bool { + (self.specialize_operand)(operand) + } + fn concrete_fallback(&self) -> Operand { + self.concrete_fallback.clone() + } +} + +pub fn specialize(module: Module, specialization: impl Specialization) -> Module { + // FIXME(eddyb) use `log`/`tracing` instead. + let debug = std::env::var("SPECIALIZER_DEBUG").is_ok(); + let dump_instances = std::env::var("SPECIALIZER_DUMP_INSTANCES").ok(); + + let mut debug_names = HashMap::new(); + if debug || dump_instances.is_some() { + debug_names = module + .debugs + .iter() + .filter(|inst| inst.class.opcode == Op::Name) + .map(|inst| { + ( + inst.operands[0].unwrap_id_ref(), + inst.operands[1].unwrap_literal_string().to_string(), + ) + }) + .collect(); + } + + let mut specializer = Specializer { + specialization, + + debug, + debug_names, + + generics: IndexMap::new(), + int_consts: HashMap::new(), + }; + + specializer.collect_generics(&module); + + let call_graph = CallGraph::collect(&module); + let mut non_generic_replacements = vec![]; + for func_idx in call_graph.post_order() { + if let Some(replacements) = specializer.infer_function(&module.functions[func_idx]) { + non_generic_replacements.push((func_idx, replacements)); + } + } + + let mut expander = Expander::new(&specializer, module); + + // For non-"generic" functions, we can apply `replacements` right away, + // though not before finishing inference for all functions first + // (because `expander` needs to borrow `specializer` immutably). + if debug { + eprintln!("non-generic replacements:"); + } + for (func_idx, replacements) in non_generic_replacements { + let mut func = mem::replace( + &mut expander.builder.module_mut().functions[func_idx], + Function::new(), + ); + if debug { + let empty = replacements.with_instance.is_empty() + && replacements.with_concrete_or_param.is_empty(); + if !empty { + eprintln!(" in %{}:", func.def_id().unwrap()); + } + } + for (loc, operand) in + replacements.to_concrete(&[], |instance| expander.alloc_instance_id(instance)) + { + if debug { + eprintln!(" {} -> {:?}", operand, loc); + } + func.index_set(loc, operand.into()); + } + expander.builder.module_mut().functions[func_idx] = func; + } + expander.propagate_instances(); + + if let Some(path) = dump_instances { + expander + .dump_instances(&mut std::fs::File::create(path).unwrap()) + .unwrap(); + } + + expander.expand_module() +} + +// FIXME(eddyb) use newtyped indices and `IndexVec`. +type FuncIdx = usize; + +struct CallGraph { + entry_points: IndexSet, + + /// `callees[i].contains(j)` implies `functions[i]` calls `functions[j]`. + callees: Vec>, +} + +impl CallGraph { + fn collect(module: &Module) -> Self { + let func_id_to_idx: HashMap<_, _> = module + .functions + .iter() + .enumerate() + .map(|(i, func)| (func.def_id().unwrap(), i)) + .collect(); + let entry_points = module + .entry_points + .iter() + .map(|entry| { + assert_eq!(entry.class.opcode, Op::EntryPoint); + func_id_to_idx[&entry.operands[1].unwrap_id_ref()] + }) + .collect(); + let callees = module + .functions + .iter() + .map(|func| { + func.all_inst_iter() + .filter(|inst| inst.class.opcode == Op::FunctionCall) + .map(|inst| func_id_to_idx[&inst.operands[0].unwrap_id_ref()]) + .collect() + }) + .collect(); + Self { + entry_points, + callees, + } + } + + /// Order functions using a post-order traversal, i.e. callees before callers. + // FIXME(eddyb) replace this with `rustc_data_structures::graph::iterate` + // (or similar). + fn post_order(&self) -> Vec { + let num_funcs = self.callees.len(); + + // FIXME(eddyb) use a proper bitset. + let mut visited = vec![false; num_funcs]; + let mut post_order = Vec::with_capacity(num_funcs); + + // Visit the call graph with entry points as roots. + for &entry in &self.entry_points { + self.post_order_step(entry, &mut visited, &mut post_order); + } + + // Also visit any functions that were not reached from entry points + // (they might be dead but they should be processed nonetheless). + for func in 0..num_funcs { + if !visited[func] { + self.post_order_step(func, &mut visited, &mut post_order); + } + } + + post_order + } + + fn post_order_step(&self, func: FuncIdx, visited: &mut [bool], post_order: &mut Vec) { + if visited[func] { + return; + } + visited[func] = true; + + for &callee in &self.callees[func] { + self.post_order_step(callee, visited, post_order) + } + + post_order.push(func); + } +} + +// HACK(eddyb) `Copy` version of `Operand` that only includes the cases that +// are relevant to the inference algorithm (and is also smaller). +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +enum CopyOperand { + IdRef(Word), + StorageClass(StorageClass), +} + +#[derive(Debug)] +struct NotSupportedAsCopyOperand(Operand); + +impl TryFrom<&Operand> for CopyOperand { + type Error = NotSupportedAsCopyOperand; + fn try_from(operand: &Operand) -> Result { + match *operand { + Operand::IdRef(id) => Ok(Self::IdRef(id)), + Operand::StorageClass(s) => Ok(Self::StorageClass(s)), + _ => Err(NotSupportedAsCopyOperand(operand.clone())), + } + } +} + +impl Into for CopyOperand { + fn into(self) -> Operand { + match self { + Self::IdRef(id) => Operand::IdRef(id), + Self::StorageClass(s) => Operand::StorageClass(s), + } + } +} + +impl fmt::Display for CopyOperand { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::IdRef(id) => write!(f, "%{}", id), + Self::StorageClass(s) => write!(f, "{:?}", s), + } + } +} + +/// The "value" of a `Param`/`InferVar`, if we know anything about it. +// FIXME(eddyb) find a more specific name. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +enum Value { + /// The value of this `Param`/`InferVar` is completely known. + Unknown, + + /// The value of this `Param`/`InferVar` is known to be a specific `Operand`. + Known(CopyOperand), + + /// The value of this `Param`/`InferVar` is the same as another `Param`/`InferVar`. + /// + /// For consistency, and to allow some `Param` <-> `InferVar` mapping, + /// all cases of `values[y] == Value::SameAs(x)` should have `x < y`, + /// i.e. "newer" variables must be redirected to "older" ones. + SameAs(T), +} + +// FIXME(eddyb) clippy bug suggests `Self` even when it'd be a type mismatch. +#[allow(clippy::use_self)] +impl Value { + fn map_var(self, f: impl FnOnce(T) -> U) -> Value { + match self { + Value::Unknown => Value::Unknown, + Value::Known(o) => Value::Known(o), + Value::SameAs(var) => Value::SameAs(f(var)), + } + } +} + +/// Newtype'd "generic" parameter index. +// FIXME(eddyb) use `rustc_index` for this instead. +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct Param(u32); + +impl fmt::Display for Param { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "${}", self.0) + } +} + +impl Param { + // HACK(eddyb) this works around `Range` not being iterable + // because `Param` doesn't implement the (unstable) `Step` trait. + fn range_iter(range: &Range) -> impl Iterator + Clone { + (range.start.0..range.end.0).map(Self) + } +} + +/// A specific instance of a "generic" global/function. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct Instance { + generic_id: Word, + generic_args: GA, +} + +// FIXME(eddyb) clippy bug suggests `Self` even when it'd be a type mismatch. +#[allow(clippy::use_self)] +impl Instance { + fn as_ref(&self) -> Instance<&GA> { + Instance { + generic_id: self.generic_id, + generic_args: &self.generic_args, + } + } + + fn map_generic_args(self, f: impl FnMut(T) -> U) -> Instance + where + GA: IntoIterator, + GA2: std::iter::FromIterator, + { + Instance { + generic_id: self.generic_id, + generic_args: self.generic_args.into_iter().map(f).collect(), + } + } + + // FIXME(eddyb) implement `Step` for `Param` and `InferVar` instead. + fn display<'a, T: fmt::Display, GAI: Iterator + Clone>( + &'a self, + f: impl FnOnce(&'a GA) -> GAI, + ) -> impl fmt::Display { + let &Self { + generic_id, + ref generic_args, + } = self; + let generic_args_iter = f(generic_args); + FmtBy(move |f| { + write!(f, "%{}<", generic_id)?; + for (i, arg) in generic_args_iter.clone().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", arg)?; + } + write!(f, ">") + }) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum InstructionLocation { + Module, + FnParam(usize), + FnBody { + /// Block index within a function. + block_idx: usize, + + /// Instruction index within the block with index `block_idx`. + inst_idx: usize, + }, +} + +trait OperandIndexGetSet { + fn index_get(&self, index: I) -> Operand; + fn index_set(&mut self, index: I, operand: Operand); +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum OperandIdx { + ResultType, + Input(usize), +} + +impl OperandIndexGetSet for Instruction { + fn index_get(&self, idx: OperandIdx) -> Operand { + match idx { + OperandIdx::ResultType => Operand::IdRef(self.result_type.unwrap()), + OperandIdx::Input(i) => self.operands[i].clone(), + } + } + fn index_set(&mut self, idx: OperandIdx, operand: Operand) { + match idx { + OperandIdx::ResultType => self.result_type = Some(operand.unwrap_id_ref()), + OperandIdx::Input(i) => self.operands[i] = operand, + } + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +struct OperandLocation { + inst_loc: InstructionLocation, + operand_idx: OperandIdx, +} + +impl OperandIndexGetSet for Instruction { + fn index_get(&self, loc: OperandLocation) -> Operand { + assert_eq!(loc.inst_loc, InstructionLocation::Module); + self.index_get(loc.operand_idx) + } + fn index_set(&mut self, loc: OperandLocation, operand: Operand) { + assert_eq!(loc.inst_loc, InstructionLocation::Module); + self.index_set(loc.operand_idx, operand); + } +} + +impl OperandIndexGetSet for Function { + fn index_get(&self, loc: OperandLocation) -> Operand { + let inst = match loc.inst_loc { + InstructionLocation::Module => self.def.as_ref().unwrap(), + InstructionLocation::FnParam(i) => &self.parameters[i], + InstructionLocation::FnBody { + block_idx, + inst_idx, + } => &self.blocks[block_idx].instructions[inst_idx], + }; + inst.index_get(loc.operand_idx) + } + fn index_set(&mut self, loc: OperandLocation, operand: Operand) { + let inst = match loc.inst_loc { + InstructionLocation::Module => self.def.as_mut().unwrap(), + InstructionLocation::FnParam(i) => &mut self.parameters[i], + InstructionLocation::FnBody { + block_idx, + inst_idx, + } => &mut self.blocks[block_idx].instructions[inst_idx], + }; + inst.index_set(loc.operand_idx, operand); + } +} + +// FIXME(eddyb) this is a bit like `Value` but more explicit, +// and the name isn't too nice, but at least it's very clear. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +enum ConcreteOrParam { + Concrete(CopyOperand), + Param(Param), +} + +impl ConcreteOrParam { + /// Replace `Param(i)` with `generic_args[i]` while preserving `Concrete`. + fn apply_generic_args(self, generic_args: &[CopyOperand]) -> CopyOperand { + match self { + Self::Concrete(x) => x, + Self::Param(Param(i)) => generic_args[i as usize], + } + } +} + +#[derive(Debug)] +struct Replacements { + /// Operands that need to be replaced with instances of "generic" globals. + /// Keyed by instance to optimize for few instances used many times. + // FIXME(eddyb) fine-tune the length of `SmallVec<[_; 4]>` here. + with_instance: IndexMap>, Vec>, + + /// Operands that need to be replaced with a concrete operand or a parameter. + with_concrete_or_param: Vec<(OperandLocation, ConcreteOrParam)>, +} + +impl Replacements { + /// Apply `generic_args` to all the `ConcreteOrParam`s in this `Replacements` + /// (i.e. replacing `Param(i)` with `generic_args[i]`), producing a stream of + /// "replace the operand at `OperandLocation` with this concrete `CopyOperand`". + /// The `concrete_instance_id` closure should look up and/or allocate an ID + /// for a specific concrete `Instance`. + fn to_concrete<'a>( + &'a self, + generic_args: &'a [CopyOperand], + mut concrete_instance_id: impl FnMut(Instance>) -> Word + 'a, + ) -> impl Iterator + 'a { + self.with_instance + .iter() + .flat_map(move |(instance, locations)| { + let concrete = CopyOperand::IdRef(concrete_instance_id( + instance + .as_ref() + .map_generic_args(|x| x.apply_generic_args(generic_args)), + )); + locations.iter().map(move |&loc| (loc, concrete)) + }) + .chain( + self.with_concrete_or_param + .iter() + .map(move |&(loc, x)| (loc, x.apply_generic_args(generic_args))), + ) + } +} + +/// Computed "generic" shape for a SPIR-V global/function. In the interest of efficient +/// representation, the parameters of operands that are themselves "generic", +/// are concatenated by default, i.e. parameters come from disjoint leaves. +/// +/// As an example, for `%T = OpTypeStruct %A %B`, if `%A` and `%B` have 2 and 3 +/// parameters, respectively, `%T` will have `A0, A1, B0, B1, B2` as parameters. +struct Generic { + param_count: u32, + + /// Defining instruction for this global (`OpType...`, `OpConstant...`, etc.) + /// or function (`OpFunction`). + // FIXME(eddyb) consider using `SmallVec` for the operands, or converting + // the operands into something more like `InferOperand`, but that would + // complicate `InferOperandList`, which has to be able to iterate them. + def: Instruction, + + /// `param_values[p]` constrains what "generic" args `Param(p)` could take. + /// This is only present if any constraints were inferred from the defining + /// instruction of a global, or the body of a function. Inference performed + /// after `collect_generics` (e.g. from instructions in function bodies) is + /// monotonic, i.e. it may only introduce more constraints, not remove any. + // FIXME(eddyb) use `rustc_index`'s `IndexVec` for this. + param_values: Option>>, + + /// Operand replacements that need to be performed on the defining instruction + /// of a global, or an entire function (including all instructions in its body), + /// in order to expand an instance of it. + replacements: Replacements, +} + +struct Specializer { + specialization: S, + + // FIXME(eddyb) use `log`/`tracing` instead. + debug: bool, + + // HACK(eddyb) if debugging is requested, this is used to quickly get `OpName`s. + debug_names: HashMap, + + // FIXME(eddyb) compact SPIR-V IDs to allow flatter maps. + generics: IndexMap, + + /// Integer `OpConstant`s (i.e. containing a `LiteralInt32`), to be used + /// for interpreting `TyPat::IndexComposite` (such as for `OpAccessChain`). + int_consts: HashMap, +} + +impl Specializer { + /// Returns the number of "generic" parameters `operand` "takes", either + /// because it's specialized by, or it refers to a "generic" global/function. + /// In the latter case, the `&Generic` for that global/function is also returned. + fn params_needed_by(&self, operand: &Operand) -> (u32, Option<&Generic>) { + if self.specialization.specialize_operand(operand) { + // Each operand we specialize by is one leaf "generic" parameter. + (1, None) + } else if let Operand::IdRef(id) = operand { + self.generics + .get(id) + .map_or((0, None), |generic| (generic.param_count, Some(generic))) + } else { + (0, None) + } + } + + fn collect_generics(&mut self, module: &Module) { + // Process all defining instructions for globals (types, constants, + // and module-scoped variables), and functions' `OpFunction` instructions, + // but note that for `OpFunction`s only the signature is considered, + // actual inference based on bodies happens later, in `infer_function`. + let types_global_values_and_functions = module + .types_global_values + .iter() + .chain(module.functions.iter().filter_map(|f| f.def.as_ref())); + + let mut forward_declared_pointers = HashSet::new(); + for inst in types_global_values_and_functions { + let result_id = inst.result_id.unwrap_or_else(|| { + unreachable!( + "Op{:?} is in `types_global_values` but not have a result ID", + inst.class.opcode + ); + }); + + if inst.class.opcode == Op::TypeForwardPointer { + forward_declared_pointers.insert(inst.operands[0].unwrap_id_ref()); + } + if forward_declared_pointers.remove(&result_id) { + // HACK(eddyb) this is a forward-declared pointer, pretend + // it's not "generic" at all to avoid breaking the rest of + // the logic - see module-level docs for how this should be + // handled in the future to support recursive data types. + assert_eq!(inst.class.opcode, Op::TypePointer); + continue; + } + + // Record all integer `OpConstant`s (used for `IndexComposite`). + if inst.class.opcode == Op::Constant { + if let Operand::LiteralInt32(x) = inst.operands[0] { + self.int_consts.insert(result_id, x); + } + } + + // Instantiate `inst` in a fresh inference context, to determine + // how many parameters it needs, and how they might be constrained. + let (param_count, param_values, replacements) = { + let mut infer_cx = InferCx::new(self); + infer_cx.instantiate_instruction(inst, InstructionLocation::Module); + + let param_count = infer_cx.infer_var_values.len() as u32; + + // FIXME(eddyb) dedup this with `infer_function`. + let param_values = infer_cx + .infer_var_values + .iter() + .map(|v| v.map_var(|InferVar(i)| Param(i))); + // Only allocate `param_values` if they constrain parameters. + let param_values = if param_values.clone().any(|v| v != Value::Unknown) { + Some(param_values.collect()) + } else { + None + }; + + ( + param_count, + param_values, + infer_cx.into_replacements(..Param(param_count)), + ) + }; + + // Inference variables become "generic" parameters. + if param_count > 0 { + self.generics.insert( + result_id, + Generic { + param_count, + def: inst.clone(), + param_values, + replacements, + }, + ); + } + } + } + + /// Perform inference across the entire definition of `func`, including all + /// the instructions in its body, and either store the resulting `Replacements` + /// in its `Generic` (if `func` is "generic"), or return them otherwise. + fn infer_function(&mut self, func: &Function) -> Option { + let func_id = func.def_id().unwrap(); + + let param_count = self + .generics + .get(&func_id) + .map_or(0, |generic| generic.param_count); + + let (param_values, replacements) = { + let mut infer_cx = InferCx::new(self); + infer_cx.instantiate_function(func); + + // FIXME(eddyb) dedup this with `collect_generics`. + let param_values = infer_cx.infer_var_values[..param_count as usize] + .iter() + .map(|v| v.map_var(|InferVar(i)| Param(i))); + // Only allocate `param_values` if they constrain parameters. + let param_values = if param_values.clone().any(|v| v != Value::Unknown) { + Some(param_values.collect()) + } else { + None + }; + + ( + param_values, + infer_cx.into_replacements(..Param(param_count)), + ) + }; + + if let Some(generic) = self.generics.get_mut(&func_id) { + // All constraints `func` could have from `collect_generics` + // would have to come from its `OpTypeFunction`, but types don't have + // internal constraints like e.g. `OpConstant*` and `OpVariable` do. + assert!(generic.param_values.is_none()); + + generic.param_values = param_values; + generic.replacements = replacements; + + None + } else { + Some(replacements) + } + } +} + +/// Newtype'd inference variable index. +// FIXME(eddyb) use `rustc_index` for this instead. +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct InferVar(u32); + +impl fmt::Display for InferVar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "?{}", self.0) + } +} + +impl InferVar { + // HACK(eddyb) this works around `Range` not being iterable + // because `InferVar` doesn't implement the (unstable) `Step` trait. + fn range_iter(range: &Range) -> impl Iterator + Clone { + (range.start.0..range.end.0).map(Self) + } +} + +struct InferCx<'a, S: Specialization> { + specializer: &'a Specializer, + + /// `infer_var_values[i]` holds the current state of `InferVar(i)`. + /// Each inference variable starts out as `Unknown`, may become `SameAs` + /// pointing to another inference variable, but eventually inference must + /// result in `Known` values (i.e. concrete `Operand`s). + // FIXME(eddyb) use `rustc_index`'s `IndexVec` for this. + infer_var_values: Vec>, + + /// Instantiated *Result Type* of each instruction that has any `InferVar`s, + /// used when an instruction's result is an input to a later instruction. + /// + /// Note that for consistency, for `OpFunction` this contains *Function Type* + /// instead of *Result Type*, which is inexplicably specified as: + /// > *Result Type* must be the same as the *Return Type* declared in *Function Type* + type_of_result: IndexMap, + + /// Operands that need to be replaced with instances of "generic" globals/functions + /// (taking as "generic" arguments the results of inference). + instantiated_operands: Vec<(OperandLocation, Instance>)>, + + /// Operands that need to be replaced with results of inference. + inferred_operands: Vec<(OperandLocation, InferVar)>, +} + +impl<'a, S: Specialization> InferCx<'a, S> { + fn new(specializer: &'a Specializer) -> Self { + InferCx { + specializer, + + infer_var_values: vec![], + type_of_result: IndexMap::new(), + instantiated_operands: vec![], + inferred_operands: vec![], + } + } +} + +#[derive(Clone, Debug, PartialEq)] +enum InferOperand { + Unknown, + Var(InferVar), + Concrete(CopyOperand), + Instance(Instance>), +} + +impl InferOperand { + /// Construct an `InferOperand` based on whether `operand` refers to some + /// "generic" definition, or we're specializing by it. + /// Also returns the remaining inference variables, not used by this operand. + fn from_operand_and_generic_args( + operand: &Operand, + generic_args: Range, + cx: &InferCx<'_, impl Specialization>, + ) -> (Self, Range) { + let (needed, generic) = cx.specializer.params_needed_by(operand); + let split = InferVar(generic_args.start.0 + needed); + let (generic_args, rest) = (generic_args.start..split, split..generic_args.end); + ( + if generic.is_some() { + Self::Instance(Instance { + generic_id: operand.unwrap_id_ref(), + generic_args, + }) + } else if needed == 0 { + CopyOperand::try_from(operand).map_or(Self::Unknown, Self::Concrete) + } else { + assert_eq!(needed, 1); + Self::Var(generic_args.start) + }, + rest, + ) + } + + fn display_with_infer_var_values<'a>( + &'a self, + infer_var_value: impl Fn(InferVar) -> Value + Copy + 'a, + ) -> impl fmt::Display + '_ { + FmtBy(move |f| { + let var_with_value = |v| { + FmtBy(move |f| { + write!(f, "{}", v)?; + match infer_var_value(v) { + Value::Unknown => Ok(()), + Value::Known(o) => write!(f, " = {}", o), + Value::SameAs(v) => write!(f, " = {}", v), + } + }) + }; + match self { + Self::Unknown => write!(f, "_"), + Self::Var(v) => write!(f, "{}", var_with_value(*v)), + Self::Concrete(o) => write!(f, "{}", o), + Self::Instance(instance) => write!( + f, + "{}", + instance.display(|generic_args| { + InferVar::range_iter(generic_args).map(var_with_value) + }) + ), + } + }) + } + + fn display_with_infer_cx<'a>( + &'a self, + cx: &'a InferCx<'_, impl Specialization>, + ) -> impl fmt::Display + '_ { + self.display_with_infer_var_values(move |v| { + // HACK(eddyb) can't use `resolve_infer_var` because that mutates + // `InferCx` (for the "path compression" union-find optimization). + let get = |v: InferVar| cx.infer_var_values[v.0 as usize]; + let mut value = get(v); + while let Value::SameAs(v) = value { + let next = get(v); + if next == Value::Unknown { + break; + } + value = next; + } + value + }) + } +} + +impl fmt::Display for InferOperand { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.display_with_infer_var_values(|_| Value::Unknown) + .fmt(f) + } +} + +/// How to filter and/or map the operands in an `InferOperandList`, while iterating. +/// +/// Having this in `InferOperandList` itself, instead of using iterator combinators, +/// allows storing `InferOperandList`s directly in `Match`, for `TyPatList` matches. +#[derive(Copy, Clone, PartialEq, Eq)] +enum InferOperandListTransform { + /// The list is the result of keeping only ID operands, and mapping them to + /// their types (or `InferOperand::Unknown` for non-value operands, or + /// value operands which don't have a "generic" type). + /// + /// This is used to match against the `inputs` `TyListPat` of `InstSig`. + TypeOfId, +} + +#[derive(Clone, PartialEq)] +struct InferOperandList<'a> { + operands: &'a [Operand], + + /// Joined ranges of all `InferVar`s needed by individual `Operand`s, + /// either for `InferOperand::Instance` or `InferOperand::Var`. + all_generic_args: Range, + + transform: Option, +} + +impl<'a> InferOperandList<'a> { + fn split_first( + &self, + cx: &InferCx<'_, impl Specialization>, + ) -> Option<(InferOperand, InferOperandList<'a>)> { + let mut list = self.clone(); + loop { + let (first_operand, rest) = list.operands.split_first()?; + list.operands = rest; + + let (first, rest_args) = InferOperand::from_operand_and_generic_args( + first_operand, + list.all_generic_args.clone(), + cx, + ); + list.all_generic_args = rest_args; + + // Maybe filter this operand, but only *after* consuming the "generic" args for it. + match self.transform { + None => {} + + // Skip a non-ID operand. + Some(InferOperandListTransform::TypeOfId) => { + if first_operand.id_ref_any().is_none() { + continue; + } + } + } + + // Maybe replace this operand with a different one. + let first = match self.transform { + None => first, + + // Map `first` to its type. + Some(InferOperandListTransform::TypeOfId) => match first { + InferOperand::Concrete(CopyOperand::IdRef(id)) => cx + .type_of_result + .get(&id) + .cloned() + .unwrap_or(InferOperand::Unknown), + InferOperand::Unknown | InferOperand::Var(_) | InferOperand::Concrete(_) => { + InferOperand::Unknown + } + InferOperand::Instance(instance) => { + let generic = &cx.specializer.generics[&instance.generic_id]; + + // HACK(eddyb) work around the inexplicable fact that `OpFunction` is + // specified with a *Result Type* that isn't the type of its *Result*: + // > *Result Type* must be the same as the *Return Type* declared in *Function Type* + // So we use *Function Type* instead as the type of its *Result*, and + // we are helped by `instantiate_instruction`, which ensures that the + // "generic" args we have are specifically meant for *Function Type*. + let type_of_result = match generic.def.class.opcode { + Op::Function => Some(generic.def.operands[1].unwrap_id_ref()), + _ => generic.def.result_type, + }; + + match type_of_result { + Some(type_of_result) => { + InferOperand::from_operand_and_generic_args( + &Operand::IdRef(type_of_result), + instance.generic_args, + cx, + ) + .0 + } + None => InferOperand::Unknown, + } + } + }, + }; + + return Some((first, list)); + } + } + + fn iter<'b>( + &self, + cx: &'b InferCx<'_, impl Specialization>, + ) -> impl Iterator + 'b + where + 'a: 'b, + { + let mut list = self.clone(); + iter::from_fn(move || { + let (next, rest) = list.split_first(cx)?; + list = rest; + Some(next) + }) + } + + fn display_with_infer_cx<'b>( + &'b self, + cx: &'b InferCx<'a, impl Specialization>, + ) -> impl fmt::Display + '_ { + FmtBy(move |f| { + f.debug_list() + .entries(self.iter(cx).map(|operand| { + FmtBy(move |f| write!(f, "{}", operand.display_with_infer_cx(cx))) + })) + .finish() + }) + } +} + +/// `SmallVec` with a map interface. +#[derive(Default)] +struct SmallIntMap(SmallVec); + +impl SmallIntMap { + fn get(&self, i: usize) -> Option<&A::Item> { + self.0.get(i) + } + + fn get_mut_or_default(&mut self, i: usize) -> &mut A::Item + where + A::Item: Default, + { + let needed = i + 1; + if self.0.len() < needed { + self.0.resize_with(needed, Default::default); + } + &mut self.0[i] + } +} + +impl IntoIterator for SmallIntMap { + type Item = (usize, A::Item); + type IntoIter = iter::Enumerate>; + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter().enumerate() + } +} + +impl<'a, A: smallvec::Array> IntoIterator for &'a mut SmallIntMap { + type Item = (usize, &'a mut A::Item); + type IntoIter = iter::Enumerate>; + fn into_iter(self) -> Self::IntoIter { + self.0.iter_mut().enumerate() + } +} + +#[derive(PartialEq)] +struct IndexCompositeMatch<'a> { + /// *Indexes* `Operand`s (see `TyPat::IndexComposite`'s doc comment for details). + indices: &'a [Operand], + + /// The result of indexing the composite type with all `indices`. + leaf: InferOperand, +} + +/// Inference success (e.g. type matched type pattern). +#[must_use] +#[derive(Default)] +struct Match<'a> { + /// Whether this success isn't guaranteed, because of missing information + /// (such as the defining instructions of non-"generic" types). + /// + /// If there are other alternatives, they will be attempted as well, + /// and merged using `Match::or` (if they don't result in `Unapplicable`). + ambiguous: bool, + + // FIXME(eddyb) create some type for these that allows providing common methods + // + /// `storage_class_var_found[i][..]` holds all the `InferOperand`s matched by + /// `StorageClassPat::Var(i)` (currently `i` is always `0`, aka `StorageClassPat::S`). + storage_class_var_found: SmallIntMap<[SmallVec<[InferOperand; 2]>; 1]>, + + /// `ty_var_found[i][..]` holds all the `InferOperand`s matched by + /// `TyPat::Var(i)` (currently `i` is always `0`, aka `TyPat::T`). + ty_var_found: SmallIntMap<[SmallVec<[InferOperand; 4]>; 1]>, + + /// `index_composite_found[i][..]` holds all the `InferOperand`s matched by + /// `TyPat::IndexComposite(TyPat::Var(i))` (currently `i` is always `0`, aka `TyPat::T`). + index_composite_ty_var_found: SmallIntMap<[SmallVec<[IndexCompositeMatch<'a>; 1]>; 1]>, + + /// `ty_list_var_found[i][..]` holds all the `InferOperandList`s matched by + /// `TyListPat::Var(i)` (currently `i` is always `0`, aka `TyListPat::TS`). + ty_list_var_found: SmallIntMap<[SmallVec<[InferOperandList<'a>; 2]>; 1]>, +} + +impl<'a> Match<'a> { + /// Combine two `Match`es such that the result implies both of them apply, + /// i.e. contains the union of their constraints. + fn and(mut self, other: Self) -> Self { + let Match { + ambiguous, + storage_class_var_found, + ty_var_found, + index_composite_ty_var_found, + ty_list_var_found, + } = &mut self; + + *ambiguous |= other.ambiguous; + for (i, other_found) in other.storage_class_var_found { + storage_class_var_found + .get_mut_or_default(i) + .extend(other_found); + } + for (i, other_found) in other.ty_var_found { + ty_var_found.get_mut_or_default(i).extend(other_found); + } + for (i, other_found) in other.index_composite_ty_var_found { + index_composite_ty_var_found + .get_mut_or_default(i) + .extend(other_found); + } + for (i, other_found) in other.ty_list_var_found { + ty_list_var_found.get_mut_or_default(i).extend(other_found); + } + self + } + + /// Combine two `Match`es such that the result allows for either applying, + /// i.e. contains the intersection of their constraints. + fn or(mut self, other: Self) -> Self { + let Match { + ambiguous, + storage_class_var_found, + ty_var_found, + index_composite_ty_var_found, + ty_list_var_found, + } = &mut self; + + *ambiguous |= other.ambiguous; + for (i, self_found) in storage_class_var_found { + let other_found = other + .storage_class_var_found + .get(i) + .map(|xs| &xs[..]) + .unwrap_or(&[]); + self_found.retain(|x| other_found.contains(x)); + } + for (i, self_found) in ty_var_found { + let other_found = other.ty_var_found.get(i).map(|xs| &xs[..]).unwrap_or(&[]); + self_found.retain(|x| other_found.contains(x)); + } + for (i, self_found) in index_composite_ty_var_found { + let other_found = other + .index_composite_ty_var_found + .get(i) + .map(|xs| &xs[..]) + .unwrap_or(&[]); + self_found.retain(|x| other_found.contains(x)); + } + for (i, self_found) in ty_list_var_found { + let other_found = other + .ty_list_var_found + .get(i) + .map(|xs| &xs[..]) + .unwrap_or(&[]); + self_found.retain(|x| other_found.contains(x)); + } + self + } + + fn debug_with_infer_cx<'b>( + &'b self, + cx: &'b InferCx<'a, impl Specialization>, + ) -> impl fmt::Debug + Captures<'a> + '_ { + fn debug_var_found<'a, A: smallvec::Array + 'a, T: 'a, TD: fmt::Display>( + var_found: &'a SmallIntMap>>, + display: &'a impl Fn(&'a T) -> TD, + ) -> impl Iterator + 'a { + var_found + .0 + .iter() + .filter(|found| !found.is_empty()) + .map(move |found| { + FmtBy(move |f| { + let mut found = found.iter().map(display); + write!(f, "{}", found.next().unwrap())?; + for x in found { + write!(f, " = {}", x)?; + } + Ok(()) + }) + }) + } + FmtBy(move |f| { + let Self { + ambiguous, + storage_class_var_found, + ty_var_found, + index_composite_ty_var_found, + ty_list_var_found, + } = self; + write!(f, "Match{} ", if *ambiguous { " (ambiguous)" } else { "" })?; + let mut list = f.debug_list(); + list.entries(debug_var_found(storage_class_var_found, &move |operand| { + operand.display_with_infer_cx(cx) + })); + list.entries(debug_var_found(ty_var_found, &move |operand| { + operand.display_with_infer_cx(cx) + })); + list.entries( + index_composite_ty_var_found + .0 + .iter() + .enumerate() + .filter(|(_, found)| !found.is_empty()) + .flat_map(|(i, found)| found.iter().map(move |x| (i, x))) + .map(move |(i, IndexCompositeMatch { indices, leaf })| { + FmtBy(move |f| { + match ty_var_found.get(i) { + Some(found) if found.len() == 1 => { + write!(f, "{}", found[0].display_with_infer_cx(cx))?; + } + found => { + let found = found.map_or(&[][..], |xs| &xs[..]); + write!(f, "(")?; + for (j, operand) in found.iter().enumerate() { + if j != 0 { + write!(f, " = ")?; + } + write!(f, "{}", operand.display_with_infer_cx(cx))?; + } + write!(f, ")")?; + } + } + for operand in &indices[..] { + // Show the value for literals and IDs pointing to + // known `OpConstant`s (e.g. struct field indices). + let maybe_idx = match operand { + Operand::IdRef(id) => cx.specializer.int_consts.get(id), + Operand::LiteralInt32(idx) => Some(idx), + _ => None, + }; + match maybe_idx { + Some(idx) => write!(f, ".{}", idx)?, + None => write!(f, "[{}]", operand)?, + } + } + write!(f, " = {}", leaf.display_with_infer_cx(cx)) + }) + }), + ); + list.entries(debug_var_found(ty_list_var_found, &move |list| { + list.display_with_infer_cx(cx) + })); + list.finish() + }) + } +} + +/// Pattern-matching failure, returned by `match_*` when the pattern doesn't apply. +struct Unapplicable; + +impl<'a, S: Specialization> InferCx<'a, S> { + /// Match `storage_class` against `pat`, returning a `Match` with found `Var`s. + fn match_storage_class_pat( + &self, + pat: &StorageClassPat, + storage_class: InferOperand, + ) -> Match<'a> { + match pat { + StorageClassPat::Any => Match::default(), + StorageClassPat::Var(i) => { + let mut m = Match::default(); + m.storage_class_var_found + .get_mut_or_default(*i) + .push(storage_class); + m + } + } + } + + /// Match `ty` against `pat`, returning a `Match` with found `Var`s. + fn match_ty_pat(&self, pat: &TyPat<'_>, ty: InferOperand) -> Result, Unapplicable> { + match pat { + TyPat::Any => Ok(Match::default()), + TyPat::Var(i) => { + let mut m = Match::default(); + m.ty_var_found.get_mut_or_default(*i).push(ty); + Ok(m) + } + TyPat::Either(a, b) => match self.match_ty_pat(a, ty.clone()) { + Ok(m) if !m.ambiguous => Ok(m), + a_result => match (a_result, self.match_ty_pat(b, ty)) { + (Ok(ma), Ok(mb)) => Ok(ma.or(mb)), + (Ok(m), _) | (_, Ok(m)) => Ok(m), + (Err(Unapplicable), Err(Unapplicable)) => Err(Unapplicable), + }, + }, + TyPat::IndexComposite(composite_pat) => match composite_pat { + TyPat::Var(i) => { + let mut m = Match::default(); + m.index_composite_ty_var_found.get_mut_or_default(*i).push( + IndexCompositeMatch { + // HACK(eddyb) leave empty `indices` in here for + // `match_inst_sig` to fill in, as it has access + // to the whole `Instruction` but we don't. + indices: &[], + leaf: ty, + }, + ); + Ok(m) + } + _ => unreachable!( + "`IndexComposite({:?})` isn't supported, only type variable + patterns are (for the composite type), e.g. `IndexComposite(T)`", + composite_pat + ), + }, + _ => { + let instance = match ty { + InferOperand::Unknown | InferOperand::Concrete(_) => { + return Ok(Match { + ambiguous: true, + ..Match::default() + }) + } + InferOperand::Var(_) => return Err(Unapplicable), + InferOperand::Instance(instance) => instance, + }; + let generic = &self.specializer.generics[&instance.generic_id]; + + let ty_operands = InferOperandList { + operands: &generic.def.operands, + all_generic_args: instance.generic_args, + transform: None, + }; + let simple = |op, inner_pat| { + if generic.def.class.opcode == op { + self.match_ty_pat(inner_pat, ty_operands.split_first(self).unwrap().0) + } else { + Err(Unapplicable) + } + }; + match pat { + TyPat::Any | TyPat::Var(_) | TyPat::Either(..) | TyPat::IndexComposite(_) => { + unreachable!() + } + + // HACK(eddyb) `TyPat::Void` can't be observed because it's + // not "generic", so it would return early as ambiguous. + TyPat::Void => unreachable!(), + + TyPat::Pointer(storage_class_pat, pointee_pat) => { + let mut ty_operands = ty_operands.iter(self); + let (storage_class, pointee_ty) = + (ty_operands.next().unwrap(), ty_operands.next().unwrap()); + Ok(self + .match_storage_class_pat(storage_class_pat, storage_class) + .and(self.match_ty_pat(pointee_pat, pointee_ty)?)) + } + TyPat::Array(pat) => simple(Op::TypeArray, pat), + TyPat::Vector(pat) => simple(Op::TypeVector, pat), + TyPat::Vector4(pat) => match ty_operands.operands { + [_, Operand::LiteralInt32(4)] => simple(Op::TypeVector, pat), + _ => Err(Unapplicable), + }, + TyPat::Matrix(pat) => simple(Op::TypeMatrix, pat), + TyPat::Image(pat) => simple(Op::TypeImage, pat), + TyPat::Pipe(_pat) => { + if generic.def.class.opcode == Op::TypePipe { + Ok(Match::default()) + } else { + Err(Unapplicable) + } + } + TyPat::SampledImage(pat) => simple(Op::TypeSampledImage, pat), + TyPat::Struct(fields_pat) => self.match_ty_list_pat(fields_pat, ty_operands), + TyPat::Function(ret_pat, params_pat) => { + let (ret_ty, params_ty_list) = ty_operands.split_first(self).unwrap(); + Ok(self + .match_ty_pat(ret_pat, ret_ty)? + .and(self.match_ty_list_pat(params_pat, params_ty_list)?)) + } + } + } + } + } + + /// Match `ty_list` against `pat`, returning a `Match` with found `Var`s. + fn match_ty_list_pat( + &self, + mut list_pat: &TyListPat<'_>, + mut ty_list: InferOperandList<'a>, + ) -> Result, Unapplicable> { + let mut m = Match::default(); + + while let TyListPat::Cons { first: pat, suffix } = list_pat { + list_pat = suffix; + + let (ty, rest) = ty_list.split_first(self).ok_or(Unapplicable)?; + ty_list = rest; + + m = m.and(self.match_ty_pat(pat, ty)?); + } + + match list_pat { + TyListPat::Cons { .. } => unreachable!(), + + TyListPat::Any => {} + TyListPat::Var(i) => { + m.ty_list_var_found.get_mut_or_default(*i).push(ty_list); + } + TyListPat::Repeat(repeat_list_pat) => { + let mut tys = ty_list.iter(self).peekable(); + loop { + let mut list_pat = repeat_list_pat; + while let TyListPat::Cons { first: pat, suffix } = list_pat { + m = m.and(self.match_ty_pat(pat, tys.next().ok_or(Unapplicable)?)?); + list_pat = suffix; + } + assert!(matches!(list_pat, TyListPat::Nil)); + if tys.peek().is_none() { + break; + } + } + } + TyListPat::Nil => { + if ty_list.split_first(self).is_some() { + return Err(Unapplicable); + } + } + } + + Ok(m) + } + + /// Match `inst`'s input operands (with `inputs_generic_args` as "generic" args), + /// and `result_type`, against `sig`, returning a `Match` with found `Var`s. + fn match_inst_sig( + &self, + sig: &InstSig<'_>, + inst: &'a Instruction, + inputs_generic_args: Range, + result_type: Option, + ) -> Result, Unapplicable> { + let mut m = Match::default(); + + if let Some(pat) = sig.storage_class { + // FIXME(eddyb) going through all the operands to find the one that + // is a storage class is inefficient, storage classes should be part + // of a single unified list of operand patterns. + let all_operands = InferOperandList { + operands: &inst.operands, + all_generic_args: inputs_generic_args.clone(), + transform: None, + }; + let storage_class = all_operands + .iter(self) + .zip(&inst.operands) + .filter(|(_, original)| matches!(original, Operand::StorageClass(_))) + .map(|(operand, _)| operand) + .next() + .ok_or(Unapplicable)?; + m = m.and(self.match_storage_class_pat(pat, storage_class)); + } + + let input_ty_list = InferOperandList { + operands: &inst.operands, + all_generic_args: inputs_generic_args, + transform: Some(InferOperandListTransform::TypeOfId), + }; + + m = m.and(self.match_ty_list_pat(sig.input_types, input_ty_list.clone())?); + + match (sig.output_type, result_type) { + (Some(pat), Some(result_type)) => { + m = m.and(self.match_ty_pat(pat, result_type)?); + } + (None, None) => {} + _ => return Err(Unapplicable), + } + + if !m.index_composite_ty_var_found.0.is_empty() { + let composite_indices = { + // Drain the `input_types` prefix (everything before `..`). + let mut ty_list = input_ty_list; + let mut list_pat = sig.input_types; + while let TyListPat::Cons { first: _, suffix } = list_pat { + list_pat = suffix; + ty_list = ty_list.split_first(self).ok_or(Unapplicable)?.1; + } + + assert_eq!( + list_pat, + &TyListPat::Any, + "`IndexComposite` must have input types end in `..`" + ); + + // Extract the underlying remaining `operands` - while iterating on + // the `TypeOfId` list would skip over non-ID operands, and replace + // ID operands with their types, the `operands` slice is still a + // subslice of `inst.operands` (minus the prefix we drained above). + ty_list.operands + }; + + // Fill in all the `indices` fields left empty by `match_ty_pat`. + for (_, found) in &mut m.index_composite_ty_var_found { + for index_composite_match in found { + let empty = mem::replace(&mut index_composite_match.indices, composite_indices); + assert_eq!(empty, &[]); + } + } + } + + Ok(m) + } + + /// Match `inst`'s input operands (with `inputs_generic_args` as "generic" args), + /// and `result_type`, against `sigs`, returning a `Match` with found `Var`s. + fn match_inst_sigs( + &self, + sigs: &[InstSig<'_>], + inst: &'a Instruction, + inputs_generic_args: Range, + result_type: Option, + ) -> Result, Unapplicable> { + let mut result = Err(Unapplicable); + for sig in sigs { + result = match ( + result, + self.match_inst_sig(sig, inst, inputs_generic_args.clone(), result_type.clone()), + ) { + (Err(Unapplicable), Ok(m)) if !m.ambiguous => return Ok(m), + (Ok(a), Ok(b)) => Ok(a.or(b)), + (Ok(m), _) | (_, Ok(m)) => Ok(m), + (Err(Unapplicable), Err(Unapplicable)) => Err(Unapplicable), + }; + } + result + } +} + +enum InferError { + /// Mismatch between operands, returned by `equate_*(a, b)` when `a != b`. + // FIXME(eddyb) track where the mismatched operands come from. + Conflict(InferOperand, InferOperand), +} + +impl InferError { + fn report(self, inst: &Instruction) { + // FIXME(eddyb) better error reporting than this. + match self { + Self::Conflict(a, b) => { + eprintln!("inference conflict: {:?} vs {:?}", a, b); + } + } + eprint!(" in "); + // FIXME(eddyb) deduplicate this with other instruction printing logic. + if let Some(result_id) = inst.result_id { + eprint!("%{} = ", result_id); + } + eprint!("Op{:?}", inst.class.opcode); + for operand in inst + .result_type + .map(Operand::IdRef) + .iter() + .chain(inst.operands.iter()) + { + eprint!(" {}", operand); + } + eprintln!(); + + std::process::exit(1); + } +} + +impl<'a, S: Specialization> InferCx<'a, S> { + /// Traverse `SameAs` chains starting at `x` and return the first `InferVar` + /// that isn't `SameAs` (i.e. that is `Unknown` or `Known`). + /// This corresponds to `find(v)` from union-find. + fn resolve_infer_var(&mut self, v: InferVar) -> InferVar { + match self.infer_var_values[v.0 as usize] { + Value::Unknown | Value::Known(_) => v, + Value::SameAs(next) => { + let resolved = self.resolve_infer_var(next); + if resolved != next { + // Update the `SameAs` entry for faster lookup next time + // (also known as "path compression" in union-find). + self.infer_var_values[v.0 as usize] = Value::SameAs(resolved); + } + resolved + } + } + } + + /// Enforce that `a = b`, returning a combined `InferVar`, if successful. + /// This corresponds to `union(a, b)` from union-find. + fn equate_infer_vars(&mut self, a: InferVar, b: InferVar) -> Result { + let (a, b) = (self.resolve_infer_var(a), self.resolve_infer_var(b)); + + if a == b { + return Ok(a); + } + + // Maintain the invariant that "newer" variables are redirected to "older" ones. + let (older, newer) = (a.min(b), a.max(b)); + let newer_value = mem::replace( + &mut self.infer_var_values[newer.0 as usize], + Value::SameAs(older), + ); + match (self.infer_var_values[older.0 as usize], newer_value) { + // Guaranteed by `resolve_infer_var`. + (Value::SameAs(_), _) | (_, Value::SameAs(_)) => unreachable!(), + + // Both `newer` and `older` had a `Known` value, they must match. + (Value::Known(x), Value::Known(y)) => { + if x != y { + return Err(InferError::Conflict( + InferOperand::Concrete(x), + InferOperand::Concrete(y), + )); + } + } + + // Move the `Known` value from `newer` to `older`. + (Value::Unknown, Value::Known(_)) => { + self.infer_var_values[older.0 as usize] = newer_value; + } + + (_, Value::Unknown) => {} + } + + Ok(older) + } + + /// Enforce that `a = b`, returning a combined `Range`, if successful. + fn equate_infer_var_ranges( + &mut self, + a: Range, + b: Range, + ) -> Result, InferError> { + if a == b { + return Ok(a); + } + + assert_eq!(a.end.0 - a.start.0, b.end.0 - b.start.0); + + for (a, b) in InferVar::range_iter(&a).zip(InferVar::range_iter(&b)) { + self.equate_infer_vars(a, b)?; + } + + // Pick the "oldest" range to maintain the invariant that "newer" variables + // are redirected to "older" ones, while keeping a contiguous range + // (instead of splitting it into individual variables), for performance. + Ok(if a.start < b.start { a } else { b }) + } + + /// Enforce that `a = b`, returning a combined `InferOperand`, if successful. + fn equate_infer_operands( + &mut self, + a: InferOperand, + b: InferOperand, + ) -> Result { + if a == b { + return Ok(a); + } + + Ok(match (a.clone(), b.clone()) { + // Instances of "generic" globals/functions must be of the same ID, + // and their `generic_args` inference variables must be unified. + ( + InferOperand::Instance(Instance { + generic_id: a_id, + generic_args: a_args, + }), + InferOperand::Instance(Instance { + generic_id: b_id, + generic_args: b_args, + }), + ) => { + if a_id != b_id { + return Err(InferError::Conflict(a, b)); + } + InferOperand::Instance(Instance { + generic_id: a_id, + generic_args: self.equate_infer_var_ranges(a_args, b_args)?, + }) + } + + // Instances of "generic" globals/functions can never equal anything else. + (InferOperand::Instance(_), _) | (_, InferOperand::Instance(_)) => { + return Err(InferError::Conflict(a, b)); + } + + // Inference variables must be unified. + (InferOperand::Var(a), InferOperand::Var(b)) => { + InferOperand::Var(self.equate_infer_vars(a, b)?) + } + + // An inference variable can be assigned a concrete value. + (InferOperand::Var(v), InferOperand::Concrete(new)) + | (InferOperand::Concrete(new), InferOperand::Var(v)) => { + let v = self.resolve_infer_var(v); + match &mut self.infer_var_values[v.0 as usize] { + // Guaranteed by `resolve_infer_var`. + Value::SameAs(_) => unreachable!(), + + &mut Value::Known(old) => { + if new != old { + return Err(InferError::Conflict( + InferOperand::Concrete(old), + InferOperand::Concrete(new), + )); + } + } + + value @ Value::Unknown => *value = Value::Known(new), + } + InferOperand::Var(v) + } + + // Concrete `Operand`s must simply match. + (InferOperand::Concrete(_), InferOperand::Concrete(_)) => { + // Success case is handled by `if a == b` early return above. + return Err(InferError::Conflict(a, b)); + } + + // Unknowns can be ignored in favor of non-`Unknown`. + // NOTE(eddyb) `x` cannot be `Instance`, that is handled above. + (InferOperand::Unknown, x) | (x, InferOperand::Unknown) => x, + }) + } + + /// Compute the result ("leaf") type for a `TyPat::IndexComposite` pattern, + /// by applying each index in `indices` to `composite_ty`, extracting the + /// element type (for `OpType{Array,RuntimeArray,Vector,Matrix}`), or the + /// field type for `OpTypeStruct`, where `indices` contains the field index. + fn index_composite(&self, composite_ty: InferOperand, indices: &[Operand]) -> InferOperand { + let mut ty = composite_ty; + for idx in &indices[..] { + let instance = match ty { + InferOperand::Unknown | InferOperand::Concrete(_) | InferOperand::Var(_) => { + return InferOperand::Unknown; + } + InferOperand::Instance(instance) => instance, + }; + let generic = &self.specializer.generics[&instance.generic_id]; + + let ty_opcode = generic.def.class.opcode; + let ty_operands = InferOperandList { + operands: &generic.def.operands, + all_generic_args: instance.generic_args, + transform: None, + }; + + let ty_operands_idx = match ty_opcode { + Op::TypeArray | Op::TypeRuntimeArray | Op::TypeVector | Op::TypeMatrix => 0, + Op::TypeStruct => match idx { + Operand::IdRef(id) => { + *self.specializer.int_consts.get(id).unwrap_or_else(|| { + unreachable!("non-constant `OpTypeStruct` field index {}", id); + }) + } + &Operand::LiteralInt32(i) => i, + _ => { + unreachable!("invalid `OpTypeStruct` field index operand {:?}", idx); + } + }, + _ => unreachable!("indexing non-composite type `Op{:?}`", ty_opcode), + }; + + ty = ty_operands + .iter(self) + .nth(ty_operands_idx as usize) + .unwrap_or_else(|| { + unreachable!( + "out of bounds index {} for `Op{:?}`", + ty_operands_idx, ty_opcode + ); + }); + } + ty + } + + /// Enforce that all the `InferOperand`/`InferOperandList`s found for the + /// same pattern variable (i.e. `*Pat::Var(i)` with the same `i`), are equal. + fn equate_match_findings(&mut self, m: Match<'_>) -> Result<(), InferError> { + let Match { + ambiguous: _, + + storage_class_var_found, + ty_var_found, + index_composite_ty_var_found, + ty_list_var_found, + } = m; + + for (_, found) in storage_class_var_found { + let mut found = found.into_iter(); + if let Some(first) = found.next() { + found.try_fold(first, |a, b| self.equate_infer_operands(a, b))?; + } + } + + for (i, found) in ty_var_found { + let mut found = found.into_iter(); + if let Some(first) = found.next() { + let equated_ty = found.try_fold(first, |a, b| self.equate_infer_operands(a, b))?; + + // Apply any `IndexComposite(Var(i))`'s indices to `equated_ty`, + // and equate the resulting "leaf" type with the found "leaf" type. + let index_composite_found = index_composite_ty_var_found + .get(i) + .map_or(&[][..], |xs| &xs[..]); + for IndexCompositeMatch { indices, leaf } in index_composite_found { + let indexing_result_ty = self.index_composite(equated_ty.clone(), indices); + self.equate_infer_operands(indexing_result_ty, leaf.clone())?; + } + } + } + + for (_, mut found) in ty_list_var_found { + if let Some((first_list, other_lists)) = found.split_first_mut() { + // Advance all the lists in lock-step so that we don't have to + // allocate state proportional to list length and/or `found.len()`. + while let Some((first, rest)) = first_list.split_first(self) { + *first_list = rest; + + other_lists.iter_mut().try_fold(first, |a, b_list| { + let (b, rest) = b_list + .split_first(self) + .expect("list length mismatch (invalid SPIR-V?)"); + *b_list = rest; + self.equate_infer_operands(a, b) + })?; + } + + for other_list in other_lists { + assert!( + other_list.split_first(self).is_none(), + "list length mismatch (invalid SPIR-V?)" + ); + } + } + } + + Ok(()) + } + + /// Track an instantiated operand, to be included in the `Replacements` + /// (produced by `into_replacements`), if it has any `InferVar`s at all. + fn record_instantiated_operand(&mut self, loc: OperandLocation, operand: InferOperand) { + match operand { + InferOperand::Var(v) => { + self.inferred_operands.push((loc, v)); + } + InferOperand::Instance(instance) => { + self.instantiated_operands.push((loc, instance)); + } + InferOperand::Unknown | InferOperand::Concrete(_) => {} + } + } + + /// Instantiate all of `inst`'s operands (and *Result Type*) that refer to + /// "generic" globals/functions, or we need to specialize by, with fresh + /// inference variables, and enforce any inference constraints applicable. + fn instantiate_instruction(&mut self, inst: &'a Instruction, inst_loc: InstructionLocation) { + let mut all_generic_args = { + let next_infer_var = InferVar(self.infer_var_values.len().try_into().unwrap()); + next_infer_var..next_infer_var + }; + + // HACK(eddyb) work around the inexplicable fact that `OpFunction` is + // specified with a *Result Type* that isn't the type of its *Result*: + // > *Result Type* must be the same as the *Return Type* declared in *Function Type* + // Specifically, we don't instantiate *Result Type* (to avoid ending + // up with redundant `InferVar`s), and instead overlap its "generic" args + // with that of the *Function Type*, for `instantiations. + let (instantiate_result_type, record_fn_ret_ty, type_of_result) = match inst.class.opcode { + Op::Function => ( + None, + inst.result_type, + Some(inst.operands[1].unwrap_id_ref()), + ), + _ => (inst.result_type, None, inst.result_type), + }; + + for (operand_idx, operand) in instantiate_result_type + .map(Operand::IdRef) + .iter() + .map(|o| (OperandIdx::ResultType, o)) + .chain( + inst.operands + .iter() + .enumerate() + .map(|(i, o)| (OperandIdx::Input(i), o)), + ) + { + // HACK(eddyb) use `v..InferVar(u32::MAX)` as an open-ended range of sorts. + let (operand, rest) = InferOperand::from_operand_and_generic_args( + operand, + all_generic_args.end..InferVar(u32::MAX), + self, + ); + let generic_args = all_generic_args.end..rest.start; + all_generic_args.end = generic_args.end; + + let generic = match &operand { + InferOperand::Instance(instance) => { + Some(&self.specializer.generics[&instance.generic_id]) + } + _ => None, + }; + + // Initialize the new inference variables (for `operand`'s "generic" args) + // with either `generic.param_values` (if present) or all `Unknown`s. + match generic { + Some(Generic { + param_values: Some(values), + .. + }) => self.infer_var_values.extend( + values + .iter() + .map(|v| v.map_var(|Param(p)| InferVar(generic_args.start.0 + p))), + ), + + _ => { + self.infer_var_values + .extend(InferVar::range_iter(&generic_args).map(|_| Value::Unknown)); + } + } + + self.record_instantiated_operand( + OperandLocation { + inst_loc, + operand_idx, + }, + operand, + ); + } + + // HACK(eddyb) workaround for `OpFunction`, see earlier HACK commment. + if let Some(ret_ty) = record_fn_ret_ty { + let (ret_ty, _) = InferOperand::from_operand_and_generic_args( + &Operand::IdRef(ret_ty), + all_generic_args.clone(), + self, + ); + self.record_instantiated_operand( + OperandLocation { + inst_loc, + operand_idx: OperandIdx::ResultType, + }, + ret_ty, + ); + } + + // *Result Type* comes first in `all_generic_args`, extract it back out. + let (type_of_result, inputs_generic_args) = match type_of_result { + Some(type_of_result) => { + let (type_of_result, rest) = InferOperand::from_operand_and_generic_args( + &Operand::IdRef(type_of_result), + all_generic_args.clone(), + self, + ); + ( + Some(type_of_result), + // HACK(eddyb) workaround for `OpFunction`, see earlier HACK commment. + match inst.class.opcode { + Op::Function => all_generic_args, + _ => rest, + }, + ) + } + None => (None, all_generic_args), + }; + + let debug_dump_if_enabled = |cx: &Self, prefix| { + if cx.specializer.debug { + let result_type = match inst.class.opcode { + // HACK(eddyb) workaround for `OpFunction`, see earlier HACK commment. + Op::Function => Some( + InferOperand::from_operand_and_generic_args( + &Operand::IdRef(inst.result_type.unwrap()), + inputs_generic_args.clone(), + cx, + ) + .0, + ), + _ => type_of_result.clone(), + }; + let inputs = InferOperandList { + operands: &inst.operands, + all_generic_args: inputs_generic_args.clone(), + transform: None, + }; + + if inst_loc != InstructionLocation::Module { + eprint!(" "); + } + eprint!("{}", prefix); + if let Some(result_id) = inst.result_id { + eprint!("%{} = ", result_id); + } + eprint!("Op{:?}", inst.class.opcode); + for operand in result_type.into_iter().chain(inputs.iter(cx)) { + eprint!(" {}", operand.display_with_infer_cx(cx)); + } + eprintln!(); + } + }; + + // If we have some instruction signatures for `inst`, enforce them. + if let Some(sigs) = spirv_type_constraints::instruction_signatures(inst.class.opcode) { + // HACK(eddyb) workaround for `OpFunction`, see earlier HACK commment. + // (specifically, `type_of_result` isn't *Result Type* for `OpFunction`) + assert_ne!(inst.class.opcode, Op::Function); + + debug_dump_if_enabled(self, " -> "); + + let m = match self.match_inst_sigs( + sigs, + inst, + inputs_generic_args.clone(), + type_of_result.clone(), + ) { + Ok(m) => m, + + // While this could be an user error *in theory*, we haven't really + // unified any of the `InferOperand`s found by pattern match variables, + // at this point, so the only the possible error case is that `inst` + // doesn't match the *shapes* specified in `sigs`, i.e. this is likely + // a bug in `spirv_type_constraints`, not some kind of inference conflict. + Err(Unapplicable) => unreachable!( + "spirv_type_constraints(Op{:?}) = `{:?}` doesn't match `{:?}`", + inst.class.opcode, sigs, inst + ), + }; + + if self.specializer.debug { + if inst_loc != InstructionLocation::Module { + eprint!(" "); + } + eprintln!(" found {:?}", m.debug_with_infer_cx(self)); + } + + if let Err(e) = self.equate_match_findings(m) { + e.report(inst); + } + + debug_dump_if_enabled(self, " <- "); + } else { + debug_dump_if_enabled(self, ""); + } + + if let Some(type_of_result) = type_of_result { + // Keep the (instantiated) *Result Type*, for future instructions to use + // (but only if it has any `InferVar`s at all). + match type_of_result { + InferOperand::Var(_) | InferOperand::Instance(_) => { + self.type_of_result + .insert(inst.result_id.unwrap(), type_of_result); + } + InferOperand::Unknown | InferOperand::Concrete(_) => {} + } + } + } + + /// Instantiate `func`'s definition and all instructions in its body, + /// effectively performing inference across the entire function body. + fn instantiate_function(&mut self, func: &'a Function) { + let func_id = func.def_id().unwrap(); + + if self.specializer.debug { + eprintln!(); + eprint!("specializer::instantiate_function(%{}", func_id); + if let Some(name) = self.specializer.debug_names.get(&func_id) { + eprint!(" {}", name); + } + eprintln!("):"); + } + + // Instantiate the defining `OpFunction` first, so that the first + // inference variables match the parameters from the `Generic` + // (if the `OpTypeFunction` is "generic", that is). + assert!(self.infer_var_values.is_empty()); + self.instantiate_instruction(func.def.as_ref().unwrap(), InstructionLocation::Module); + + if self.specializer.debug { + eprintln!("infer body {{"); + } + + // If the `OpTypeFunction` is indeed "generic", we have to extract the + // return / parameter types for `OpReturnValue` and `OpFunctionParameter`. + let ret_ty = match self.type_of_result.get(&func_id).cloned() { + Some(InferOperand::Instance(instance)) => { + let generic = &self.specializer.generics[&instance.generic_id]; + assert_eq!(generic.def.class.opcode, Op::TypeFunction); + + let (ret_ty, mut params_ty_list) = InferOperandList { + operands: &generic.def.operands, + all_generic_args: instance.generic_args, + transform: None, + } + .split_first(self) + .unwrap(); + + // HACK(eddyb) manual iteration to avoid borrowing `self`. + let mut params = func.parameters.iter().enumerate(); + while let Some((param_ty, rest)) = params_ty_list.split_first(self) { + params_ty_list = rest; + + let (i, param) = params.next().unwrap(); + assert_eq!(param.class.opcode, Op::FunctionParameter); + + if self.specializer.debug { + eprintln!( + " %{} = Op{:?} {}", + param.result_id.unwrap(), + param.class.opcode, + param_ty.display_with_infer_cx(self) + ); + } + + self.record_instantiated_operand( + OperandLocation { + inst_loc: InstructionLocation::FnParam(i), + operand_idx: OperandIdx::ResultType, + }, + param_ty.clone(), + ); + match param_ty { + InferOperand::Var(_) | InferOperand::Instance(_) => { + self.type_of_result + .insert(param.result_id.unwrap(), param_ty); + } + InferOperand::Unknown | InferOperand::Concrete(_) => {} + } + } + assert_eq!(params.next(), None); + + Some(ret_ty) + } + + _ => None, + }; + + for (block_idx, block) in func.blocks.iter().enumerate() { + for (inst_idx, inst) in block.instructions.iter().enumerate() { + // Manually handle `OpReturnValue`/`OpReturn` because there's no + // way to inject `ret_ty` into `spirv_type_constraints` rules. + match inst.class.opcode { + Op::ReturnValue => { + let ret_val_id = inst.operands[0].unwrap_id_ref(); + if let (Some(expected), Some(found)) = ( + ret_ty.clone(), + self.type_of_result.get(&ret_val_id).cloned(), + ) { + if let Err(e) = self.equate_infer_operands(expected, found) { + e.report(inst); + } + } + } + + Op::Return => {} + + _ => self.instantiate_instruction( + inst, + InstructionLocation::FnBody { + block_idx, + inst_idx, + }, + ), + } + } + } + + if self.specializer.debug { + eprint!("}}"); + if let Some(func_ty) = self.type_of_result.get(&func_id) { + eprint!(" -> %{}: {}", func_id, func_ty.display_with_infer_cx(self)); + } + eprintln!(); + } + } + + /// Helper for `into_replacements`, that computes a single `ConcreteOrParam`. + /// For all `Param(p)` in `generic_params`, inference variables that resolve + /// to `InferVar(p)` are replaced with `Param(p)`, whereas other inference + /// variables are considered unconstrained, and are instead replaced with + /// `S::concrete_fallback()` (which is chosen by the specialization). + fn resolve_infer_var_to_concrete_or_param( + &mut self, + v: InferVar, + generic_params: RangeTo, + ) -> ConcreteOrParam { + let v = self.resolve_infer_var(v); + let InferVar(i) = v; + match self.infer_var_values[i as usize] { + // Guaranteed by `resolve_infer_var`. + Value::SameAs(_) => unreachable!(), + + Value::Unknown => { + if i < generic_params.end.0 { + ConcreteOrParam::Param(Param(i)) + } else { + ConcreteOrParam::Concrete( + CopyOperand::try_from(&self.specializer.specialization.concrete_fallback()) + .unwrap(), + ) + } + } + Value::Known(x) => ConcreteOrParam::Concrete(x), + } + } + + /// Consume the `InferCx` and return a set of replacements that need to be + /// performed to instantiate the global/function inferred with this `InferCx`. + /// See `resolve_infer_var_to_concrete_or_param` for how inference variables + /// are handled (using `generic_params` and `S::concrete_fallback()`). + fn into_replacements(mut self, generic_params: RangeTo) -> Replacements { + let mut with_instance: IndexMap<_, Vec<_>> = IndexMap::new(); + for (loc, instance) in mem::replace(&mut self.instantiated_operands, vec![]) { + with_instance + .entry(Instance { + generic_id: instance.generic_id, + generic_args: InferVar::range_iter(&instance.generic_args) + .map(|v| self.resolve_infer_var_to_concrete_or_param(v, generic_params)) + .collect(), + }) + .or_default() + .push(loc); + } + + let with_concrete_or_param = mem::replace(&mut self.inferred_operands, vec![]) + .into_iter() + .map(|(loc, v)| { + ( + loc, + self.resolve_infer_var_to_concrete_or_param(v, generic_params), + ) + }) + .collect(); + + Replacements { + with_instance, + with_concrete_or_param, + } + } +} + +// HACK(eddyb) this state could live in `Specializer` except for the fact that +// it's commonly mutated at the same time as parts of `Specializer` are read, +// and in particular this arrangement allows calling `&mut self` methods on +// `Expander` while (immutably) iterating over data inside the `Specializer`. +struct Expander<'a, S: Specialization> { + specializer: &'a Specializer, + + builder: Builder, + + /// All the instances of "generic" globals/functions that need to be expanded, + /// and their cached IDs (which are allocated as-needed, before expansion). + // NOTE(eddyb) this relies on `BTreeMap` so that `all_instances_of` can use + // `BTreeMap::range` to get all `Instances` that share a certain ID. + // FIXME(eddyb) fine-tune the length of `SmallVec<[_; 4]>` here. + instances: BTreeMap>, Word>, + + /// Instances of "generic" globals/functions that have yet to have had their + /// own `replacements` analyzed in order to fully collect all instances. + // FIXME(eddyb) fine-tune the length of `SmallVec<[_; 4]>` here. + propagate_instances_queue: VecDeque>>, +} + +impl<'a, S: Specialization> Expander<'a, S> { + fn new(specializer: &'a Specializer, module: Module) -> Self { + Expander { + specializer, + + builder: Builder::new_from_module(module), + + instances: BTreeMap::new(), + propagate_instances_queue: VecDeque::new(), + } + } + + /// Return the subset of `instances` that have `generic_id`. + /// This is efficiently implemented via `BTreeMap::range`, taking advantage + /// of the derived `Ord` on `Instance`, which orders by `generic_id` first, + /// resulting in `instances` being grouped by `generic_id`. + fn all_instances_of( + &self, + generic_id: Word, + ) -> std::collections::btree_map::Range<'_, Instance>, Word> { + let first_instance_of = |generic_id| Instance { + generic_id, + generic_args: SmallVec::new(), + }; + self.instances + .range(first_instance_of(generic_id)..first_instance_of(generic_id + 1)) + } + + /// Allocate a new ID for `instance`, or return a cached one if it exists. + /// If a new ID is created, `instance` is added to `propagate_instances_queue`, + /// so that `propagate_instances` can later find all transitive dependencies. + fn alloc_instance_id(&mut self, instance: Instance>) -> Word { + use std::collections::btree_map::Entry; + + match self.instances.entry(instance) { + Entry::Occupied(entry) => *entry.get(), + Entry::Vacant(entry) => { + // Get the `Instance` back from the map key, to avoid having to + // clone it earlier when calling `self.instances.entry(instance)`. + let instance = entry.key().clone(); + + self.propagate_instances_queue.push_back(instance); + *entry.insert(self.builder.id()) + } + } + } + + /// Process all instances seen (by `alloc_instance_id`) up until this point, + /// to find the full set of instances (transitively) needed by the module. + /// + /// **Warning**: calling `alloc_instance_id` later, without another call to + /// `propagate_instances`, will potentially result in missed instances, i.e. + /// that are added to `propagate_instances_queue` but never processed. + fn propagate_instances(&mut self) { + while let Some(instance) = self.propagate_instances_queue.pop_back() { + // Drain the iterator to generate all the `alloc_instance_id` calls. + for _ in self.specializer.generics[&instance.generic_id] + .replacements + .to_concrete(&instance.generic_args, |i| self.alloc_instance_id(i)) + {} + } + } + + /// Expand every "generic" global/function, and `OpName`/decorations applied + /// to them, to their respective full set of instances, treating the original + /// "generic" definition and its inferred `Replacements` as a template. + fn expand_module(mut self) -> Module { + // From here on out we assume all instances are known, so ensure there + // aren't any left unpropagated. + self.propagate_instances(); + + // HACK(eddyb) steal `Vec`s so that we can still call methods on `self` below. + let module = self.builder.module_mut(); + let mut entry_points = mem::replace(&mut module.entry_points, vec![]); + let debugs = mem::replace(&mut module.debugs, vec![]); + let annotations = mem::replace(&mut module.annotations, vec![]); + let types_global_values = mem::replace(&mut module.types_global_values, vec![]); + let functions = mem::replace(&mut module.functions, vec![]); + + // Adjust `OpEntryPoint ...` in-place to use the new IDs for *Interface* + // module-scoped `OpVariable`s (which should each have one instance). + for inst in &mut entry_points { + let func_id = inst.operands[1].unwrap_id_ref(); + assert!( + !self.specializer.generics.contains_key(&func_id), + "entry-point %{} shouldn't be \"generic\"", + func_id + ); + + for interface_operand in &mut inst.operands[3..] { + let interface_id = interface_operand.unwrap_id_ref(); + let mut instances = self.all_instances_of(interface_id); + match (instances.next(), instances.next()) { + (None, _) => unreachable!( + "entry-point %{} has overly-\"generic\" \ + interface variable %{}, with no instances", + func_id, interface_id + ), + (Some(_), Some(_)) => unreachable!( + "entry-point %{} has overly-\"generic\" \ + interface variable %{}, with too many instances: {:?}", + func_id, + interface_id, + FmtBy(|f| f + .debug_list() + .entries(self.all_instances_of(interface_id).map( + |(instance, _)| FmtBy(move |f| write!( + f, + "{}", + instance.display(|generic_args| generic_args.iter().copied()) + )) + )) + .finish()) + ), + (Some((_, &instance_id)), None) => { + *interface_operand = Operand::IdRef(instance_id); + } + } + } + } + + // FIXME(eddyb) bucket `instances` into global vs function, and count + // annotations separately, so that we can know exact capacities below. + + // Expand `Op* %target ...` when `target` is "generic". + let expand_debug_or_annotation = |insts: Vec| { + let mut expanded_insts = Vec::with_capacity(insts.len().next_power_of_two()); + for inst in insts { + if let [Operand::IdRef(target), ..] = inst.operands[..] { + if self.specializer.generics.contains_key(&target) { + expanded_insts.extend(self.all_instances_of(target).map( + |(_, &instance_id)| { + let mut expanded_inst = inst.clone(); + expanded_inst.operands[0] = Operand::IdRef(instance_id); + expanded_inst + }, + )); + continue; + } + } + expanded_insts.push(inst); + } + expanded_insts + }; + + // Expand `Op(Member)Name %target ...` when `target` is "generic". + let expanded_debugs = expand_debug_or_annotation(debugs); + + // Expand `Op(Member)Decorate* %target ...`, when `target` is "generic". + let expanded_annotations = expand_debug_or_annotation(annotations); + + // Expand "generic" globals (types, constants and module-scoped variables). + let mut expanded_types_global_values = + Vec::with_capacity(types_global_values.len().next_power_of_two()); + for inst in types_global_values { + if let Some(result_id) = inst.result_id { + if let Some(generic) = self.specializer.generics.get(&result_id) { + expanded_types_global_values.extend(self.all_instances_of(result_id).map( + |(instance, &instance_id)| { + let mut expanded_inst = inst.clone(); + expanded_inst.result_id = Some(instance_id); + for (loc, operand) in generic + .replacements + .to_concrete(&instance.generic_args, |i| self.instances[&i]) + { + expanded_inst.index_set(loc, operand.into()); + } + expanded_inst + }, + )); + continue; + } + } + expanded_types_global_values.push(inst); + } + + // Expand "generic" functions. + let mut expanded_functions = Vec::with_capacity(functions.len().next_power_of_two()); + for func in functions { + let func_id = func.def_id().unwrap(); + if let Some(generic) = self.specializer.generics.get(&func_id) { + let old_expanded_functions_len = expanded_functions.len(); + expanded_functions.extend(self.all_instances_of(func_id).map( + |(instance, &instance_id)| { + let mut expanded_func = func.clone(); + expanded_func.def.as_mut().unwrap().result_id = Some(instance_id); + for (loc, operand) in generic + .replacements + .to_concrete(&instance.generic_args, |i| self.instances[&i]) + { + expanded_func.index_set(loc, operand.into()); + } + expanded_func + }, + )); + + // Renumber all of the IDs defined within the function itself, + // to avoid conflicts between all the expanded copies. + // While some passes (such as inlining) may handle IDs reuse + // between different function bodies (mostly because they do + // their own renumbering), it's better not to tempt fate here. + // FIXME(eddyb) use compact IDs for more efficient renumbering. + let newly_expanded_functions = + &mut expanded_functions[old_expanded_functions_len..]; + if newly_expanded_functions.len() > 1 { + // NOTE(eddyb) this is defined outside the loop to avoid + // allocating it for every expanded copy of the function. + let mut rewrite_rules = HashMap::new(); + + for func in newly_expanded_functions { + rewrite_rules.extend(func.parameters.iter_mut().map(|param| { + let old_id = param.result_id.unwrap(); + let new_id = self.builder.id(); + + // HACK(eddyb) this is only needed because we're using + // `apply_rewrite_rules` and that only works on `Block`s, + // it should be generalized to handle `Function`s too. + param.result_id = Some(new_id); + + (old_id, new_id) + })); + rewrite_rules.extend( + func.blocks + .iter() + .flat_map(|b| b.label.iter().chain(b.instructions.iter())) + .filter_map(|inst| inst.result_id) + .map(|old_id| (old_id, self.builder.id())), + ); + + super::apply_rewrite_rules(&rewrite_rules, &mut func.blocks); + } + } + + continue; + } + expanded_functions.push(func); + } + + // No new instances should've been found during expansion - they would've + // panicked while attempting to get `self.instances[&instance]` anyway. + assert!(self.propagate_instances_queue.is_empty()); + + let module = self.builder.module_mut(); + module.entry_points = entry_points; + module.debugs = expanded_debugs; + module.annotations = expanded_annotations; + module.types_global_values = expanded_types_global_values; + module.functions = expanded_functions; + + self.builder.module() + } + + fn dump_instances(&self, w: &mut impl io::Write) -> io::Result<()> { + writeln!(w, "; All specializer \"generic\"s and their instances:")?; + writeln!(w)?; + + // FIXME(eddyb) maybe dump (transitive) dependencies? could use a def-use graph. + for (&generic_id, generic) in &self.specializer.generics { + if let Some(name) = self.specializer.debug_names.get(&generic_id) { + writeln!(w, "; {}", name)?; + } + + write!( + w, + "{} = Op{:?}", + Instance { + generic_id, + generic_args: Param(0)..Param(generic.param_count) + } + .display(Param::range_iter), + generic.def.class.opcode + )?; + let mut next_param = Param(0); + for operand in generic + .def + .result_type + .map(Operand::IdRef) + .iter() + .chain(generic.def.operands.iter()) + { + write!(w, " ")?; + let (needed, used_generic) = self.specializer.params_needed_by(operand); + let params = next_param..Param(next_param.0 + needed); + + // NOTE(eddyb) see HACK comment in `instantiate_instruction`. + if generic.def.class.opcode != Op::Function { + next_param = params.end; + } + + if used_generic.is_some() { + write!( + w, + "{}", + Instance { + generic_id: operand.unwrap_id_ref(), + generic_args: params + } + .display(Param::range_iter) + )?; + } else if needed == 1 { + write!(w, "{}", params.start)?; + } else { + write!(w, "{}", operand)?; + } + } + writeln!(w)?; + + if let Some(param_values) = &generic.param_values { + write!(w, " where")?; + for (i, v) in param_values.iter().enumerate() { + let p = Param(i as u32); + match v { + Value::Unknown => {} + Value::Known(o) => write!(w, " {} = {},", p, o)?, + Value::SameAs(q) => write!(w, " {} = {},", p, q)?, + } + } + writeln!(w)?; + } + + for (instance, instance_id) in self.all_instances_of(generic_id) { + assert_eq!(instance.generic_id, generic_id); + writeln!( + w, + " %{} = {}", + instance_id, + instance.display(|generic_args| generic_args.iter().copied()) + )?; + } + + writeln!(w)?; + } + Ok(()) + } +} diff --git a/crates/rustc_codegen_spirv/src/spirv_type.rs b/crates/rustc_codegen_spirv/src/spirv_type.rs index 8efeeda30e..36568d69d1 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type.rs @@ -60,7 +60,6 @@ pub enum SpirvType { element: Word, }, Pointer { - storage_class: StorageClass, pointee: Word, }, Function { @@ -194,11 +193,13 @@ impl SpirvType { } result } - Self::Pointer { - storage_class, - pointee, - } => { - let result = cx.emit_global().type_pointer(None, storage_class, pointee); + Self::Pointer { pointee } => { + // NOTE(eddyb) we emit `StorageClass::Generic` here, but later + // the linker will specialize the entire SPIR-V module to use + // storage classes inferred from `OpVariable`s. + let result = cx + .emit_global() + .type_pointer(None, StorageClass::Generic, pointee); // no pointers to functions if let Self::Function { .. } = cx.lookup_type(pointee) { cx.zombie_even_in_user_code( @@ -249,13 +250,13 @@ impl SpirvType { return cached; } let result = match self { - Self::Pointer { - storage_class, - pointee, - } => { - let result = cx - .emit_global() - .type_pointer(Some(id), storage_class, pointee); + Self::Pointer { pointee } => { + // NOTE(eddyb) we emit `StorageClass::Generic` here, but later + // the linker will specialize the entire SPIR-V module to use + // storage classes inferred from `OpVariable`s. + let result = + cx.emit_global() + .type_pointer(Some(id), StorageClass::Generic, pointee); // no pointers to functions if let Self::Function { .. } = cx.lookup_type(pointee) { cx.zombie_even_in_user_code( @@ -440,13 +441,9 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> { .field("id", &self.id) .field("element", &self.cx.debug_type(element)) .finish(), - SpirvType::Pointer { - storage_class, - pointee, - } => f + SpirvType::Pointer { pointee } => f .debug_struct("Pointer") .field("id", &self.id) - .field("storage_class", &storage_class) .field("pointee", &self.cx.debug_type(pointee)) .finish(), SpirvType::Function { @@ -599,11 +596,8 @@ impl SpirvTypePrinter<'_, '_> { ty(self.cx, stack, f, element)?; f.write_str("]") } - SpirvType::Pointer { - storage_class, - pointee, - } => { - write!(f, "*{{{:?}}} ", storage_class)?; + SpirvType::Pointer { pointee } => { + f.write_str("*")?; ty(self.cx, stack, f, pointee) } SpirvType::Function { diff --git a/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs b/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs index 31c5607324..059fa59cf4 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs @@ -1,9 +1,10 @@ //! SPIR-V type constraints. Can be used to perform a subset of validation, //! or for inference purposes. //! -//! Only type equality is currently handled here, no concrete type constraints, -//! nor anything involving non-type operands. While more constraints could be -//! supported, encoding all the possible rules for them may be challenging. +//! Only type/storage-class equality is currently handled here, no concrete +//! type/storage-class constraints, nor anything involving non-type/storage-class +//! operands. While more constraints could be supported, encoding all the possible +//! rules for them may be challenging. //! //! Type constraints could be provided in two representations: //! * static/generic: the constraints are built up from generic types @@ -21,6 +22,32 @@ use rspirv::spirv::Op; +/// Helper trait to allow macros to work uniformly across different pattern types. +trait Pat { + /// Unconstrained pattern, i.e. that matches everything. + const ANY: Self; +} + +/// Pattern for a SPIR-V storage class, dynamic representation (see module-level docs). +#[derive(Debug, PartialEq, Eq)] +pub enum StorageClassPat { + /// Unconstrained storage class. + Any, + + /// Storage class variable: all occurrences of `Var(i)` with the same `i` must be + /// identical storage classes. For convenience, these associated consts are provided: + /// * `StorageClassPat::S` for `StorageClassPat::Var(0)` + Var(usize), +} + +impl Pat for StorageClassPat { + const ANY: Self = Self::Any; +} + +impl StorageClassPat { + pub const S: Self = Self::Var(0); +} + /// Pattern for a SPIR-V type, dynamic representation (see module-level docs). #[derive(Debug, PartialEq, Eq)] pub enum TyPat<'a> { @@ -42,11 +69,12 @@ pub enum TyPat<'a> { /// > `OpTypeImage` (unless that underlying *Sampled Type* is `OpTypeVoid`). Void, + /// `OpTypePointer`, with an inner pattern for its *Storage Class* operand, + /// and another for its *Type* operand. + Pointer(&'a StorageClassPat, &'a TyPat<'a>), + // FIXME(eddyb) try to DRY the same-shape patterns below. // - /// `OpTypePointer`, with an inner pattern for its *Type* operand. - Pointer(&'a TyPat<'a>), - /// `OpTypeArray`, with an inner pattern for its *Element Type* operand. Array(&'a TyPat<'a>), @@ -89,6 +117,10 @@ pub enum TyPat<'a> { IndexComposite(&'a TyPat<'a>), } +impl Pat for TyPat<'_> { + const ANY: Self = Self::Any; +} + impl TyPat<'_> { pub const T: Self = Self::Var(0); } @@ -107,9 +139,9 @@ pub enum TyListPat<'a> { /// * `TyListPat::TS` for `TyListPat::Var(0)` Var(usize), - /// Uniform repeat type list: all types in the list must be identical with - /// eachother, and each of them must also match the inner pattern. - Repeat(&'a TyPat<'a>), + /// Uniform repeat type list: equivalent to repeating the inner type list + /// pattern (which must be finite), enough times to cover the whole list. + Repeat(&'a TyListPat<'a>), /// Empty type list. Nil, @@ -121,6 +153,10 @@ pub enum TyListPat<'a> { }, } +impl Pat for TyListPat<'_> { + const ANY: Self = Self::Any; +} + impl TyListPat<'_> { pub const TS: Self = Self::Var(0); } @@ -128,26 +164,45 @@ impl TyListPat<'_> { /// Instruction "signature", dynamic representation (see module-level docs). #[derive(Copy, Clone, Debug)] pub struct InstSig<'a> { - /// Patterns for the complete list of types of the instruction's value operands. - pub inputs: &'a TyListPat<'a>, + /// Pattern for an instruction's sole storage class operand, if applicable. + // FIXME(eddyb) integrate this with `input_types` - it's non-trivial because + // that matches *the types of* ID operands, not the operands themselves. + pub storage_class: Option<&'a StorageClassPat>, + + /// Patterns for the complete list of types of the instruction's ID operands, + /// where non-value operands (i.e. IDs of instructions without a *Result Type*) + /// can only match `TyPat::Any`. + pub input_types: &'a TyListPat<'a>, /// Pattern for the instruction's *Result Type* operand, if applicable. - pub output: Option<&'a TyPat<'a>>, + pub output_type: Option<&'a TyPat<'a>>, } /// Returns an array of valid signatures for an instruction with opcode `op`, /// or `None` if there aren't any known type constraints for that instruction. pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { + // Restrict the names the `pat!` macro can take as pattern constructors. + mod pat_ctors { + pub const S: super::StorageClassPat = super::StorageClassPat::S; + // NOTE(eddyb) it would be really nice if we could import `TyPat::{* - Any, Var}`, + // i.e. all but those two variants. + pub use super::TyPat::{ + Array, Either, Function, Image, IndexComposite, Matrix, Pipe, Pointer, SampledImage, + Struct, Vector, Vector4, Void, + }; + pub const T: super::TyPat<'_> = super::TyPat::T; + pub use super::TyListPat::Repeat; + pub const TS: super::TyListPat<'_> = super::TyListPat::TS; + } + macro_rules! pat { - (_) => { &TyPat::Any }; + (_) => { &Pat::ANY }; ($ctor:ident $(($($inner:tt $(($($inner_args:tt)+))?),+))?) => { - &TyPat::$ctor $(($(pat!($inner $(($($inner_args)+))?)),+))? + &pat_ctors::$ctor $(($(pat!($inner $(($($inner_args)+))?)),+))? }; ([]) => { &TyListPat::Nil }; ([..]) => { &TyListPat::Any }; - ([...$ctor:ident $(($($inner:tt $(($($inner_args:tt)+))?),+))?]) => { - &TyListPat::$ctor $(($(pat!($inner $(($($inner_args)+))?)),+))? - }; + ([...$($rest:tt)+]) => { pat!($($rest)+) }; ([$first:tt $(($($first_args:tt)+))? $(, $($rest:tt)*)?]) => { &TyListPat::Cons { first: pat!($first $(($($first_args)+))?), @@ -164,10 +219,15 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { }; } macro_rules! sig { - ($(($($in_tys:tt)*) $(-> $out_ty:tt $(($($out_ty_args:tt)+))?)?)|+) => { + ($( + $({$($storage_class:tt)*})? + ($($in_tys:tt)*) + $(-> $out_ty:tt $(($($out_ty_args:tt)+))?)? + )|+) => { return Some(&[$(InstSig { - inputs: pat!([$($in_tys)*]), - output: optionify!($(pat!($out_ty $(($($out_ty_args)+))?))?), + storage_class: optionify!($(pat!($($storage_class)*))?), + input_types: pat!([$($in_tys)*]), + output_type: optionify!($(pat!($out_ty $(($($out_ty_args)+))?))?), }),*]); }; } @@ -246,10 +306,10 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { // 3.37.7. Constant-Creation Instructions Op::ConstantTrue | Op::ConstantFalse | Op::Constant => {} Op::ConstantComposite | Op::SpecConstantComposite => sig! { - (...TS) -> Struct([...TS]) | - (...Repeat(T)) -> Array(T) | - (...Repeat(T)) -> Vector(T) | - (...Repeat(T)) -> Matrix(T) + (...TS) -> Struct(TS) | + (...Repeat([T])) -> Array(T) | + (...Repeat([T])) -> Vector(T) | + (...Repeat([T])) -> Matrix(T) }, Op::ConstantSampler | Op::ConstantNull @@ -258,41 +318,42 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { | Op::SpecConstant => {} Op::SpecConstantOp => { unreachable!( - "Op{:?} should be specially handled outside type_constraints", + "Op{:?} should be specially handled outside spirv_type_constraints", op ); } // 3.37.8. Memory Instructions Op::Variable => sig! { - () -> _ | - (T) -> Pointer(T) + {S} () -> Pointer(S, _) | + {S} (T) -> Pointer(S, T) }, - Op::ImageTexelPointer => sig! { (Pointer(Image(T)), _, _) -> Pointer(T) }, - Op::Load => sig! { (Pointer(T)) -> T }, - Op::Store => sig! { (Pointer(T), T) }, - Op::CopyMemory => sig! { (Pointer(T), Pointer(T)) }, + Op::ImageTexelPointer => sig! { (Pointer(_, Image(T)), _, _) -> Pointer(_, T) }, + Op::Load => sig! { (Pointer(_, T)) -> T }, + Op::Store => sig! { (Pointer(_, T), T) }, + Op::CopyMemory => sig! { (Pointer(_, T), Pointer(_, T)) }, Op::CopyMemorySized => {} Op::AccessChain | Op::InBoundsAccessChain => sig! { - (Pointer(T), ../*indices*/) -> Pointer(IndexComposite(T)) + (Pointer(S, T), ../*indices*/) -> Pointer(S, IndexComposite(T)) }, Op::PtrAccessChain | Op::InBoundsPtrAccessChain => sig! { - (Pointer(T), _, ../*indices*/) -> Pointer(IndexComposite(T)) + (Pointer(S, T), _, ../*indices*/) -> Pointer(S, IndexComposite(T)) }, Op::ArrayLength | Op::GenericPtrMemSemantics => {} // SPIR-V 1.4 Op::PtrEqual | Op::PtrNotEqual | Op::PtrDiff => sig! { - (Pointer(T), Pointer(T)) -> _ + (Pointer(_, T), Pointer(_, T)) -> _ }, // 3.37.9. Function Instructions - Op::Function | Op::FunctionParameter | Op::FunctionEnd => { + Op::Function => {} + Op::FunctionParameter | Op::FunctionEnd => { unreachable!( - "Op{:?} should be specially handled outside type_constraints", + "Op{:?} should be specially handled outside spirv_type_constraints", op ); } - Op::FunctionCall => sig! { (Function(T, [...TS]), ...TS) -> T }, + Op::FunctionCall => sig! { (Function(T, TS), ...TS) -> T }, // 3.37.10. Image Instructions Op::SampledImage => sig! { (T, _) -> SampledImage(T) }, @@ -363,9 +424,8 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { | Op::FConvert => {} Op::QuantizeToF16 => sig! { (T) -> T }, Op::ConvertPtrToU | Op::SatConvertSToU | Op::SatConvertUToS | Op::ConvertUToPtr => {} - Op::PtrCastToGeneric | Op::GenericCastToPtr | Op::GenericCastToPtrExplicit => sig! { - (Pointer(T)) -> Pointer(T) - }, + Op::PtrCastToGeneric | Op::GenericCastToPtr => sig! { (Pointer(_, T)) -> Pointer(_, T) }, + Op::GenericCastToPtrExplicit => sig! { {S} (Pointer(_, T)) -> Pointer(S, T) }, Op::Bitcast => {} // 3.37.12. Composite Instructions @@ -379,10 +439,10 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { }, Op::VectorShuffle => sig! { (Vector(T), Vector(T)) -> Vector(T) }, Op::CompositeConstruct => sig! { - (...TS) -> Struct([...TS]) | - (...Repeat(T)) -> Array(T) | - (...Repeat(T)) -> Matrix(T) | - (...Repeat(Either(Vector(T), T))) -> Vector(T) + (...TS) -> Struct(TS) | + (...Repeat([T])) -> Array(T) | + (...Repeat([T])) -> Matrix(T) | + (...Repeat([Either(Vector(T), T)])) -> Vector(T) }, Op::CompositeExtract => sig! { (T, ../*indices*/) -> IndexComposite(T) }, Op::CompositeInsert => sig! { (IndexComposite(T), T, ../*indices*/) -> T }, @@ -392,7 +452,7 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { Op::CopyLogical => sig! { // FIXME(eddyb) this is shallow right now, it should recurse instead (Array(T)) -> Array(T) | - (Struct([...TS])) -> Struct([...TS]) + (Struct(TS)) -> Struct(TS) }, // 3.37.13. Arithmetic Instructions @@ -496,7 +556,7 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { | Op::FwidthCoarse => sig! { (T) -> T }, // 3.37.17. Control-Flow Instructions - Op::Phi => sig! { (...Repeat(T)) -> T }, + Op::Phi => sig! { (...Repeat([T, _])) -> T }, Op::LoopMerge | Op::SelectionMerge | Op::Label @@ -506,7 +566,7 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { | Op::Kill => {} Op::Return | Op::ReturnValue => { unreachable!( - "Op{:?} should be specially handled outside type_constraints", + "Op{:?} should be specially handled outside spirv_type_constraints", op ); } @@ -514,9 +574,9 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { // 3.37.18. Atomic Instructions Op::AtomicLoad | Op::AtomicIIncrement | Op::AtomicIDecrement => sig! { - (Pointer(T), _, _) -> T + (Pointer(_, T), _, _) -> T }, - Op::AtomicStore => sig! { (Pointer(T), _, _, T) }, + Op::AtomicStore => sig! { (Pointer(_, T), _, _, T) }, Op::AtomicExchange | Op::AtomicIAdd | Op::AtomicISub @@ -526,14 +586,14 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { | Op::AtomicUMax | Op::AtomicAnd | Op::AtomicOr - | Op::AtomicXor => sig! { (Pointer(T), _, _, T) -> T }, + | Op::AtomicXor => sig! { (Pointer(_, T), _, _, T) -> T }, Op::AtomicCompareExchange | Op::AtomicCompareExchangeWeak => sig! { - (Pointer(T), _, _, _, T, T) -> T + (Pointer(_, T), _, _, _, T, T) -> T }, // Capability: Kernel Op::AtomicFlagTestAndSet | Op::AtomicFlagClear => {} // SPV_EXT_shader_atomic_float_add - Op::AtomicFAddEXT => sig! { (Pointer(T), _, _, T) -> T }, + Op::AtomicFAddEXT => sig! { (Pointer(_, T), _, _, T) -> T }, // 3.37.19. Primitive Instructions Op::EmitVertex | Op::EndPrimitive | Op::EmitStreamVertex | Op::EndStreamPrimitive => {} @@ -544,7 +604,7 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { Op::NamedBarrierInitialize | Op::MemoryNamedBarrier => {} // 3.37.21. Group and Subgroup Instructions - Op::GroupAsyncCopy => sig! { (_, Pointer(T), Pointer(T), _, _, _) -> _ }, + Op::GroupAsyncCopy => sig! { (_, Pointer(_, T), Pointer(_, T), _, _, _) -> _ }, Op::GroupWaitEvents => {} Op::GroupAll | Op::GroupAny => {} Op::GroupBroadcast => sig! { (_, T, _) -> T }, @@ -575,8 +635,8 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { // SPV_INTEL_subgroups Op::SubgroupShuffleINTEL | Op::SubgroupShuffleXorINTEL => sig! { (T, _) -> T }, Op::SubgroupShuffleDownINTEL | Op::SubgroupShuffleUpINTEL => sig! { (T, T, _) -> T }, - Op::SubgroupBlockReadINTEL => sig! { (Pointer(T)) -> T }, - Op::SubgroupBlockWriteINTEL => sig! { (Pointer(T), T) }, + Op::SubgroupBlockReadINTEL => sig! { (Pointer(_, T)) -> T }, + Op::SubgroupBlockWriteINTEL => sig! { (Pointer(_, T), T) }, Op::SubgroupImageBlockReadINTEL | Op::SubgroupImageBlockWriteINTEL => {} // SPV_INTEL_media_block_io Op::SubgroupImageMediaBlockReadINTEL | Op::SubgroupImageMediaBlockWriteINTEL => {} @@ -600,9 +660,9 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { Op::GetKernelLocalSizeForSubgroupCount | Op::GetKernelMaxNumSubgroups => {} // 3.37.23. Pipe Instructions - Op::ReadPipe | Op::WritePipe => sig! { (Pipe(T), Pointer(T), _, _) -> _ }, + Op::ReadPipe | Op::WritePipe => sig! { (Pipe(T), Pointer(_, T), _, _) -> _ }, Op::ReservedReadPipe | Op::ReservedWritePipe => sig! { - (Pipe(T), _, _, Pointer(T), _, _) -> _ + (Pipe(T), _, _, Pointer(_, T), _, _) -> _ }, Op::ReserveReadPipePackets | Op::ReserveWritePipePackets @@ -619,7 +679,7 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { Op::ConstantPipeStorage | Op::CreatePipeFromPipeStorage => {} // SPV_INTEL_blocking_pipes Op::ReadPipeBlockingINTEL | Op::WritePipeBlockingINTEL => sig! { - (Pipe(T), Pointer(T), _, _) + (Pipe(T), Pointer(_, T), _, _) }, // 3.37.24. Non-Uniform Instructions @@ -700,12 +760,15 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { } // SPV_EXT_demote_to_helper_invocation Op::DemoteToHelperInvocationEXT | Op::IsHelperInvocationEXT => { - reserved!(SPV_EXT_demote_to_helper_invocation) + // NOTE(eddyb) we actually use these despite not being in the standard yet. + // reserved!(SPV_EXT_demote_to_helper_invocation) } // SPV_INTEL_shader_integer_functions2 - Op::UCountLeadingZerosINTEL - | Op::UCountTrailingZerosINTEL - | Op::AbsISubINTEL + Op::UCountLeadingZerosINTEL | Op::UCountTrailingZerosINTEL => { + // NOTE(eddyb) we actually use these despite not being in the standard yet. + // reserved!(SPV_INTEL_shader_integer_functions2) + } + Op::AbsISubINTEL | Op::AbsUSubINTEL | Op::IAddSatINTEL | Op::UAddSatINTEL diff --git a/crates/rustc_codegen_spirv/src/symbols.rs b/crates/rustc_codegen_spirv/src/symbols.rs index fc5b2cae4e..65c687ca17 100644 --- a/crates/rustc_codegen_spirv/src/symbols.rs +++ b/crates/rustc_codegen_spirv/src/symbols.rs @@ -331,10 +331,6 @@ impl Symbols { .iter() .map(|&(a, b)| (a, SpirvAttribute::Entry(b.into()))); let custom_attributes = [ - ( - "really_unsafe_ignore_bitcasts", - SpirvAttribute::ReallyUnsafeIgnoreBitcasts, - ), ("sampler", SpirvAttribute::Sampler), ("block", SpirvAttribute::Block), ("flat", SpirvAttribute::Flat), @@ -444,7 +440,6 @@ pub enum SpirvAttribute { Entry(Entry), DescriptorSet(u32), Binding(u32), - ReallyUnsafeIgnoreBitcasts, Image { dim: Dim, depth: u32, diff --git a/crates/spirv-builder/src/test/basic.rs b/crates/spirv-builder/src/test/basic.rs index 144111c31c..47655bd620 100644 --- a/crates/spirv-builder/src/test/basic.rs +++ b/crates/spirv-builder/src/test/basic.rs @@ -148,9 +148,12 @@ fn asm_op_decorate() { "%image_2d = OpTypeImage %float Dim2D 0 0 0 1 Unknown", "%sampled_image_2d = OpTypeSampledImage %image_2d", "%image_array = OpTypeRuntimeArray %sampled_image_2d", - "%ptr_image_array = OpTypePointer UniformConstant %image_array", + // NOTE(eddyb) `Generic` is used here because it's the placeholder + // for storage class inference - both of the two `OpTypePointer` + // types below should end up inferring to `UniformConstant`. + "%ptr_image_array = OpTypePointer Generic %image_array", "%image_2d_var = OpVariable %ptr_image_array UniformConstant", - "%ptr_sampled_image_2d = OpTypePointer UniformConstant %sampled_image_2d", + "%ptr_sampled_image_2d = OpTypePointer Generic %sampled_image_2d", "", // ^^ type preamble "%offset = OpLoad _ {0}", "%24 = OpAccessChain %ptr_sampled_image_2d %image_2d_var %offset", @@ -282,7 +285,7 @@ pub struct ShaderConstants { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(constants: PushConstant) { - let _constants = constants.load(); + let _constants = *constants; } "#); } @@ -306,7 +309,7 @@ pub struct ShaderConstants { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(constants: PushConstant) { - let _constants = constants.load(); + let _constants = *constants; } "#, ); @@ -390,7 +393,7 @@ fn signum() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input, mut o: Output) { - o.store(i.load().signum()); + *o = i.signum(); }"#); } @@ -487,9 +490,9 @@ fn mat3_vec3_multiply() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(input: Input, mut output: Output) { - let input = input.load(); + let input = *input; let vector = input * glam::Vec3::new(1.0, 2.0, 3.0); - output.store(vector); + *output = vector; } "#); } @@ -516,9 +519,12 @@ fn complex_image_sample_inst() { "%image_2d = OpTypeImage %float Dim2D 0 0 0 1 Unknown", "%sampled_image_2d = OpTypeSampledImage %image_2d", "%image_array = OpTypeRuntimeArray %sampled_image_2d", - "%ptr_image_array = OpTypePointer UniformConstant %image_array", + // NOTE(eddyb) `Generic` is used here because it's the placeholder + // for storage class inference - both of the two `OpTypePointer` + // types below should end up inferring to `UniformConstant`. + "%ptr_image_array = OpTypePointer Generic %image_array", "%image_2d_var = OpVariable %ptr_image_array UniformConstant", - "%ptr_sampled_image_2d = OpTypePointer UniformConstant %sampled_image_2d", + "%ptr_sampled_image_2d = OpTypePointer Generic %sampled_image_2d", "", // ^^ type preamble "%offset = OpLoad _ {1}", "%24 = OpAccessChain %ptr_sampled_image_2d %image_2d_var %offset", @@ -568,10 +574,9 @@ fn image_read() { val(r#" #[allow(unused_attributes)] #[spirv(fragment)] -pub fn main(input: UniformConstant, mut output: Output) { - let image = input.load(); +pub fn main(image: UniformConstant, mut output: Output) { let coords = image.read(glam::IVec2::new(0, 1)); - output.store(coords); + *output = coords; } "#); } @@ -582,8 +587,7 @@ fn image_write() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(input: Input, image: UniformConstant) { - let texels = input.load(); - let image = image.load(); + let texels = *input; image.write(glam::UVec2::new(0, 1), texels); } "#); diff --git a/crates/spirv-builder/src/test/control_flow.rs b/crates/spirv-builder/src/test/control_flow.rs index 511b7cf413..65a2bde0fc 100644 --- a/crates/spirv-builder/src/test/control_flow.rs +++ b/crates/spirv-builder/src/test/control_flow.rs @@ -6,7 +6,7 @@ fn cf_while() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 10 { + while *i < 10 { } } "#); @@ -18,8 +18,8 @@ fn cf_while_while() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 20 { - while i.load() < 10 { + while *i < 20 { + while *i < 10 { } } } @@ -32,8 +32,8 @@ fn cf_while_while_break() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 20 { - while i.load() < 10 { + while *i < 20 { + while *i < 10 { break; } } @@ -47,9 +47,9 @@ fn cf_while_while_if_break() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 20 { - while i.load() < 10 { - if i.load() > 10 { + while *i < 20 { + while *i < 10 { + if *i > 10 { break; } } @@ -64,7 +64,7 @@ fn cf_while_break() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 10 { + while *i < 10 { break; } } @@ -77,8 +77,8 @@ fn cf_while_if_break() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 10 { - if i.load() == 0 { + while *i < 10 { + if *i == 0 { break; } } @@ -92,8 +92,8 @@ fn cf_while_if_break_else_break() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 10 { - if i.load() == 0 { + while *i < 10 { + if *i == 0 { break; } else { break; @@ -109,11 +109,11 @@ fn cf_while_if_break_if_break() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 10 { - if i.load() == 0 { + while *i < 10 { + if *i == 0 { break; } - if i.load() == 1 { + if *i == 1 { break; } } @@ -127,8 +127,8 @@ fn cf_while_while_continue() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 20 { - while i.load() < 10 { + while *i < 20 { + while *i < 10 { continue; } } @@ -142,9 +142,9 @@ fn cf_while_while_if_continue() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 20 { - while i.load() < 10 { - if i.load() > 5 { + while *i < 20 { + while *i < 10 { + if *i > 5 { continue; } } @@ -159,7 +159,7 @@ fn cf_while_continue() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 10 { + while *i < 10 { continue; } } @@ -172,8 +172,8 @@ fn cf_while_if_continue() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 10 { - if i.load() == 0 { + while *i < 10 { + if *i == 0 { continue; } } @@ -187,8 +187,8 @@ fn cf_while_if_continue_else_continue() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 10 { - if i.load() == 0 { + while *i < 10 { + if *i == 0 { continue; } else { continue; @@ -204,7 +204,7 @@ fn cf_while_return() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 10 { + while *i < 10 { return; } } @@ -217,7 +217,7 @@ fn cf_if_return_else() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - if i.load() < 10 { + if *i < 10 { return; } else { } @@ -231,7 +231,7 @@ fn cf_if_return_else_return() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - if i.load() < 10 { + if *i < 10 { return; } else { return; @@ -246,8 +246,8 @@ fn cf_if_while() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - if i.load() == 0 { - while i.load() < 10 { + if *i == 0 { + while *i < 10 { } } } @@ -260,7 +260,7 @@ fn cf_if() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - if i.load() > 0 { + if *i > 0 { } } @@ -272,10 +272,10 @@ fn cf_ifx2() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - if i.load() > 0 { + if *i > 0 { } - if i.load() > 1 { + if *i > 1 { } } @@ -288,7 +288,7 @@ fn cf_if_else() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - if i.load() > 0 { + if *i > 0 { } else { @@ -303,9 +303,9 @@ fn cf_if_elseif_else() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - if i.load() > 0 { + if *i > 0 { - } else if i.load() < 0 { + } else if *i < 0 { } else { @@ -320,8 +320,8 @@ fn cf_if_if() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - if i.load() > 0 { - if i.load() < 10 { + if *i > 0 { + if *i < 10 { } } @@ -335,12 +335,12 @@ fn cf_defer() { #[allow(unused_attributes)] #[spirv(fragment)] pub fn main(i: Input) { - while i.load() < 32 { + while *i < 32 { let current_position = 0; - if i.load() < current_position { + if *i < current_position { break; } - if i.load() < current_position { + if *i < current_position { break; } } diff --git a/crates/spirv-std/src/storage_class.rs b/crates/spirv-std/src/storage_class.rs index 268ae5cfe3..d5a31fb6c1 100644 --- a/crates/spirv-std/src/storage_class.rs +++ b/crates/spirv-std/src/storage_class.rs @@ -7,21 +7,28 @@ //! form a storage class, and unless stated otherwise, storage class-based //! restrictions are not restrictions on intermediate objects and their types. +use core::ops::{Deref, DerefMut}; + macro_rules! storage_class { ($(#[$($meta:meta)+])* storage_class $name:ident ; $($tt:tt)*) => { $(#[$($meta)+])* #[allow(unused_attributes)] - pub struct $name<'value, T> { - value: &'value mut T, + pub struct $name<'value, T: ?Sized> { + reference: &'value mut T, + } + + impl Deref for $name<'_, T> { + type Target = T; + fn deref(&self) -> &T { + self.reference + } } impl $name<'_, T> { /// Load the value into memory. - #[inline] - #[allow(unused_attributes)] - #[spirv(really_unsafe_ignore_bitcasts)] + #[deprecated(note = "storage_class::Foo types now implement Deref, and can be used like &T")] pub fn load(&self) -> T { - *self.value + **self } } @@ -32,18 +39,23 @@ macro_rules! storage_class { ($(#[$($meta:meta)+])* writeable storage_class $name:ident $($tt:tt)+) => { storage_class!($(#[$($meta)+])* storage_class $name $($tt)+); - impl $name<'_, T> { + impl DerefMut for $name<'_, T> { + fn deref_mut(&mut self) -> &mut T { + self.reference + } + } + + impl $name<'_, T> { /// Store the value in storage. - #[inline] - #[allow(unused_attributes)] - #[spirv(really_unsafe_ignore_bitcasts)] + #[deprecated(note = "storage_class::Foo types now implement DerefMut, and can be used like &mut T")] pub fn store(&mut self, v: T) { - *self.value = v + **self = v } /// A convenience function to load a value into memory and store it. - pub fn then(&mut self, then: impl FnOnce(T) -> T) { - self.store((then)(self.load())); + #[deprecated(note = "storage_class::Foo types now implement DerefMut, and can be used like &mut T")] + pub fn then(&mut self, f: impl FnOnce(T) -> T) { + **self = f(**self); } } }; diff --git a/examples/shaders/compute-shader/src/lib.rs b/examples/shaders/compute-shader/src/lib.rs index 14cc52dd5e..7eb5938b03 100644 --- a/examples/shaders/compute-shader/src/lib.rs +++ b/examples/shaders/compute-shader/src/lib.rs @@ -4,6 +4,8 @@ feature(register_attr), register_attr(spirv) )] +// HACK(eddyb) can't easily see warnings otherwise from `spirv-builder` builds. +#![deny(warnings)] extern crate spirv_std; diff --git a/examples/shaders/mouse-shader/src/lib.rs b/examples/shaders/mouse-shader/src/lib.rs index c5f7804434..7d75da9ad4 100644 --- a/examples/shaders/mouse-shader/src/lib.rs +++ b/examples/shaders/mouse-shader/src/lib.rs @@ -1,7 +1,11 @@ -#![cfg_attr(target_arch = "spirv", no_std)] -#![feature(lang_items)] -#![feature(register_attr)] -#![register_attr(spirv)] +#![cfg_attr( + target_arch = "spirv", + no_std, + feature(register_attr), + register_attr(spirv) +)] +// HACK(eddyb) can't easily see warnings otherwise from `spirv-builder` builds. +#![deny(warnings)] use core::f32::consts::PI; use shared::*; @@ -145,9 +149,7 @@ pub fn main_fs( constants: PushConstant, mut output: Output, ) { - let constants = constants.load(); - - let frag_coord = vec2(in_frag_coord.load().x, in_frag_coord.load().y); + let frag_coord = vec2(in_frag_coord.x, in_frag_coord.y); let cursor = vec2(constants.cursor_x, constants.cursor_y); let drag_start = vec2(constants.drag_start_x, constants.drag_start_y); @@ -250,7 +252,7 @@ pub fn main_fs( WHITE, ); - output.store(painter.color.extend(1.0)); + *output = painter.color.extend(1.0); } #[allow(unused_attributes)] @@ -259,12 +261,12 @@ pub fn main_vs( #[spirv(vertex_index)] vert_idx: Input, #[spirv(position)] mut builtin_pos: Output, ) { - let vert_idx = vert_idx.load(); + let vert_idx = *vert_idx; // Create a "full screen triangle" by mapping the vertex index. // ported from https://www.saschawillems.de/blog/2016/08/13/vulkan-tutorial-on-rendering-a-fullscreen-quad-without-buffers/ let uv = vec2(((vert_idx << 1) & 2) as f32, (vert_idx & 2) as f32); let pos = 2.0 * uv - Vec2::one(); - builtin_pos.store(pos.extend(0.0).extend(1.0)); + *builtin_pos = pos.extend(0.0).extend(1.0); } diff --git a/examples/shaders/simplest-shader/src/lib.rs b/examples/shaders/simplest-shader/src/lib.rs index 974d63839a..5d7fd331c4 100644 --- a/examples/shaders/simplest-shader/src/lib.rs +++ b/examples/shaders/simplest-shader/src/lib.rs @@ -4,6 +4,8 @@ feature(register_attr), register_attr(spirv) )] +// HACK(eddyb) can't easily see warnings otherwise from `spirv-builder` builds. +#![deny(warnings)] #[cfg(not(target_arch = "spirv"))] #[macro_use] @@ -14,7 +16,7 @@ use spirv_std::storage_class::{Input, Output}; #[allow(unused_attributes)] #[spirv(fragment)] pub fn main_fs(mut output: Output) { - output.store(vec4(1.0, 0.0, 0.0, 1.0)) + *output = vec4(1.0, 0.0, 0.0, 1.0); } #[allow(unused_attributes)] @@ -23,11 +25,11 @@ pub fn main_vs( #[spirv(vertex_index)] vert_id: Input, #[spirv(position)] mut out_pos: Output, ) { - let vert_id = vert_id.load(); - out_pos.store(vec4( + let vert_id = *vert_id; + *out_pos = vec4( (vert_id - 1) as f32, ((vert_id & 1) * 2 - 1) as f32, 0.0, 1.0, - )); + ); } diff --git a/examples/shaders/sky-shader/src/lib.rs b/examples/shaders/sky-shader/src/lib.rs index 613e43a692..9a9072d6b9 100644 --- a/examples/shaders/sky-shader/src/lib.rs +++ b/examples/shaders/sky-shader/src/lib.rs @@ -6,6 +6,8 @@ feature(register_attr, lang_items), register_attr(spirv) )] +// HACK(eddyb) can't easily see warnings otherwise from `spirv-builder` builds. +#![deny(warnings)] #[cfg(not(target_arch = "spirv"))] #[macro_use] @@ -163,11 +165,8 @@ pub fn main_fs( constants: PushConstant, mut output: Output, ) { - let constants = constants.load(); - - let frag_coord = vec2(in_frag_coord.load().x, in_frag_coord.load().y); - let color = fs(&constants, frag_coord); - output.store(color); + let frag_coord = vec2(in_frag_coord.x, in_frag_coord.y); + *output = fs(&constants, frag_coord); } #[allow(unused_attributes)] @@ -176,12 +175,12 @@ pub fn main_vs( #[spirv(vertex_index)] vert_idx: Input, #[spirv(position)] mut builtin_pos: Output, ) { - let vert_idx = vert_idx.load(); + let vert_idx = *vert_idx; // Create a "full screen triangle" by mapping the vertex index. // ported from https://www.saschawillems.de/blog/2016/08/13/vulkan-tutorial-on-rendering-a-fullscreen-quad-without-buffers/ let uv = vec2(((vert_idx << 1) & 2) as f32, (vert_idx & 2) as f32); let pos = 2.0 * uv - Vec2::one(); - builtin_pos.store(pos.extend(0.0).extend(1.0)); + *builtin_pos = pos.extend(0.0).extend(1.0); }