diff --git a/crates/rustc_codegen_spirv/src/abi.rs b/crates/rustc_codegen_spirv/src/abi.rs index 302baec0c6..593b1b2a58 100644 --- a/crates/rustc_codegen_spirv/src/abi.rs +++ b/crates/rustc_codegen_spirv/src/abi.rs @@ -370,6 +370,40 @@ impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> { .def_with_name(cx, span, TyLayoutNameKey::from(*self)), Abi::Scalar(ref scalar) => trans_scalar(cx, span, *self, scalar, Size::ZERO), Abi::ScalarPair(ref a, ref b) => { + // NOTE(eddyb) unlike `Abi::Scalar`'s simpler newtype-unpacking + // behavior, `Abi::ScalarPair` can be composed in two ways: + // * two `Abi::Scalar` fields (and any number of ZST fields), + // gets handled the same as a `struct { a, b }`, further below + // * an `Abi::ScalarPair` field (and any number of ZST fields), + // which requires more work to allow taking a reference to + // that field, and there are two potential approaches: + // 1. wrapping that field's SPIR-V type in a single-field + // `OpTypeStruct` - this has the disadvantage that GEPs + // would have to inject an extra `0` field index, and other + // field-related operations would also need additional work + // 2. reusing that field's SPIR-V type, instead of defining + // a new one, offering the `(a, b)` shape `rustc_codegen_ssa` + // expects, while letting noop pointercasts access the sole + // `Abi::ScalarPair` field - this is the approach taken here + let mut non_zst_fields = (0..self.fields.count()) + .map(|i| (i, self.field(cx, i))) + .filter(|(_, field)| !field.is_zst()); + let sole_non_zst_field = match (non_zst_fields.next(), non_zst_fields.next()) { + (Some(field), None) => Some(field), + _ => None, + }; + if let Some((i, field)) = sole_non_zst_field { + // Only unpack a newtype if the field and the newtype line up + // perfectly, in every way that could potentially affect ABI. + if self.fields.offset(i) == Size::ZERO + && field.size == self.size + && field.align == self.align + && field.abi == self.abi + { + return field.spirv_type(span, cx); + } + } + // Note: We can't use auto_struct_layout here because the spirv types here might be undefined due to // recursive pointer types. let a_offset = Size::ZERO; diff --git a/tests/ui/lang/issue-836.rs b/tests/ui/lang/issue-836.rs new file mode 100644 index 0000000000..9cced028c4 --- /dev/null +++ b/tests/ui/lang/issue-836.rs @@ -0,0 +1,42 @@ +// Test that newtypes of `ScalarPair` can have references taken to their field. + +// build-pass + +use spirv_std as _; + +struct Newtype(T); + +impl Newtype { + fn get(&self) -> &T { + &self.0 + } +} + +impl Newtype<&[u32]> { + fn slice_get(&self) -> &&[u32] { + &self.0 + } +} + +impl> Newtype { + fn deref_index(&self, i: usize) -> &u32 { + &self.0[i] + } +} + +struct CustomPair(u32, u32); + +#[spirv(fragment)] +pub fn main( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] slice: &[u32], + #[spirv(flat)] out: &mut u32, +) { + let newtype_slice = Newtype(slice); + *out = newtype_slice.get()[0]; + *out += newtype_slice.slice_get()[1]; + *out += newtype_slice.deref_index(2); + + let newtype_custom_pair = Newtype(CustomPair(*out, *out + 1)); + *out += newtype_custom_pair.get().0; + *out += newtype_custom_pair.get().1; +}