Skip to content

Commit

Permalink
Infer the storage class for images/samplers/imagesamplers (#567)
Browse files Browse the repository at this point in the history
* Infer the storage class for images/samplers/imagesamplers

* format

* Move comment
  • Loading branch information
khyperia authored Apr 1, 2021
1 parent fc8efd4 commit 4fa73bd
Show file tree
Hide file tree
Showing 19 changed files with 137 additions and 70 deletions.
90 changes: 67 additions & 23 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,21 +228,33 @@ impl<'tcx> CodegenCx<'tcx> {
}
}

fn declare_parameter(
fn infer_param_ty_and_storage_class(
&self,
layout: TyAndLayout<'tcx>,
hir_param: &hir::Param<'tcx>,
decoration_locations: &mut HashMap<StorageClass, u32>,
attrs: &AggregatedSpirvAttributes,
) -> (Word, StorageClass) {
let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.hir().attrs(hir_param.hir_id));

// FIXME(eddyb) attribute validation should be done ahead of time.
// FIXME(eddyb) also take into account `&T` interior mutability,
// i.e. it's only immutable if `T: Freeze`, which we should check.
// FIXME(eddyb) also check the type for compatibility with being
// part of the interface, including potentially `Sync`ness etc.
let (value_ty, storage_class) = if let Some(storage_class_attr) = attrs.storage_class {
let (value_ty, mutbl, is_ref) = match *layout.ty.kind() {
TyKind::Ref(_, pointee_ty, mutbl) => (pointee_ty, mutbl, true),
_ => (layout.ty, hir::Mutability::Not, false),
};
let spirv_ty = self.layout_of(value_ty).spirv_type(hir_param.span, self);
// Some types automatically specify a storage class. Compute that here.
let inferred_storage_class_from_ty = match self.lookup_type(spirv_ty) {
SpirvType::Image { .. } | SpirvType::Sampler | SpirvType::SampledImage { .. } => {
Some(StorageClass::UniformConstant)
}
_ => None,
};
// Storage classes can be specified via attribute. Compute that here, and emit diagnostics.
let attr_storage_class = attrs.storage_class.map(|storage_class_attr| {
let storage_class = storage_class_attr.value;

let expected_mutbl = match storage_class {
StorageClass::UniformConstant
| StorageClass::Input
Expand All @@ -252,10 +264,8 @@ impl<'tcx> CodegenCx<'tcx> {
_ => hir::Mutability::Mut,
};

match *layout.ty.kind() {
TyKind::Ref(_, pointee_ty, m) if m == expected_mutbl => (pointee_ty, storage_class),

_ => self.tcx.sess.span_fatal(
if !is_ref {
self.tcx.sess.span_fatal(
hir_param.span,
&format!(
"invalid entry param type `{}` for storage class `{:?}` \
Expand All @@ -264,25 +274,62 @@ impl<'tcx> CodegenCx<'tcx> {
storage_class,
expected_mutbl.prefix_str()
),
)
}

match inferred_storage_class_from_ty {
Some(inferred) if storage_class == inferred => self.tcx.sess.span_warn(
storage_class_attr.span,
"redundant storage class specifier, storage class is inferred from type",
),
Some(inferred) => self
.tcx
.sess
.struct_span_err(
hir_param.span,
&format!(
"storage class {:?} was inferred from type, but {:?} was specified in attribute",
inferred, storage_class
),
)
.span_note(
storage_class_attr.span,
&format!("remove storage class attribute to use {:?} as storage class", inferred),
)
.emit(),
None => (),
}
} else {
match *layout.ty.kind() {
TyKind::Ref(_, pointee_ty, hir::Mutability::Mut) => {
(pointee_ty, StorageClass::Output)
}

TyKind::Ref(_, pointee_ty, hir::Mutability::Not) => self.tcx.sess.span_fatal(
storage_class
});
// If storage class was not inferred nor specified, compute the default (i.e. input/output)
let storage_class = inferred_storage_class_from_ty
.or(attr_storage_class)
.unwrap_or_else(|| match (is_ref, mutbl) {
(false, _) => StorageClass::Input,
(true, hir::Mutability::Mut) => StorageClass::Output,
(true, hir::Mutability::Not) => self.tcx.sess.span_fatal(
hir_param.span,
&format!(
"invalid entry param type `{}` (expected `{}` or `&mut {1}`)",
layout.ty, pointee_ty
layout.ty, value_ty
),
),
});

_ => (layout.ty, StorageClass::Input),
}
};
(spirv_ty, storage_class)
}

fn declare_parameter(
&self,
layout: TyAndLayout<'tcx>,
hir_param: &hir::Param<'tcx>,
decoration_locations: &mut HashMap<StorageClass, u32>,
) -> (Word, StorageClass) {
let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.hir().attrs(hir_param.hir_id));

let (value_ty, storage_class) =
self.infer_param_ty_and_storage_class(layout, hir_param, &attrs);

// Pre-allocate the module-scoped `OpVariable`'s *Result* ID.
let variable = self.emit_global().id();
Expand Down Expand Up @@ -357,10 +404,7 @@ impl<'tcx> CodegenCx<'tcx> {
}

// Emit the `OpVariable` with its *Result* ID set to `variable`.
let var_spirv_type = SpirvType::Pointer {
pointee: self.layout_of(value_ty).spirv_type(hir_param.span, self),
}
.def(hir_param.span, self);
let var_spirv_type = SpirvType::Pointer { pointee: value_ty }.def(hir_param.span, self);
self.emit_global()
.variable(var_spirv_type, Some(variable), storage_class, None);

Expand Down
5 changes: 1 addition & 4 deletions tests/ui/image/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
use spirv_std::{arch, Image2d};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image: &Image2d,
output: &mut glam::Vec4,
) {
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] image: &Image2d, output: &mut glam::Vec4) {
let texel = image.fetch(glam::IVec2::new(0, 1));
*output = texel;
}
2 changes: 1 addition & 1 deletion tests/ui/image/issue_527.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub struct PointsBuffer {
pub fn main_cs(
#[spirv(global_invocation_id)] id: UVec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] points_buffer: &mut PointsBuffer,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] image: &spirv_std::StorageImage2d,
#[spirv(descriptor_set = 1, binding = 1)] image: &spirv_std::StorageImage2d,
) {
unsafe { asm!("OpCapability StorageImageWriteWithoutFormat") };
let position = id.xy();
Expand Down
2 changes: 1 addition & 1 deletion tests/ui/image/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use spirv_std::{arch, StorageImage2d};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image: &StorageImage2d,
#[spirv(descriptor_set = 0, binding = 0)] image: &StorageImage2d,
output: &mut glam::Vec4,
) {
unsafe { asm!("OpCapability StorageImageReadWithoutFormat") };
Expand Down
8 changes: 4 additions & 4 deletions tests/ui/image/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use spirv_std::{arch, Cubemap, Image2d, Image2dArray, Sampler};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image2d: &Image2d,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] image2d_array: &Image2dArray,
#[spirv(uniform_constant, descriptor_set = 2, binding = 2)] cubemap: &Cubemap,
#[spirv(uniform_constant, descriptor_set = 3, binding = 3)] sampler: &Sampler,
#[spirv(descriptor_set = 0, binding = 0)] image2d: &Image2d,
#[spirv(descriptor_set = 1, binding = 1)] image2d_array: &Image2dArray,
#[spirv(descriptor_set = 2, binding = 2)] cubemap: &Cubemap,
#[spirv(descriptor_set = 3, binding = 3)] sampler: &Sampler,
output: &mut glam::Vec4,
) {
let v2 = glam::Vec2::new(0.0, 1.0);
Expand Down
8 changes: 4 additions & 4 deletions tests/ui/image/sample_depth_reference/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use spirv_std::{arch, Cubemap, Image2d, Image2dArray, Sampler};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image: &Image2d,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] image_array: &Image2dArray,
#[spirv(uniform_constant, descriptor_set = 2, binding = 2)] cubemap: &Cubemap,
#[spirv(uniform_constant, descriptor_set = 3, binding = 3)] sampler: &Sampler,
#[spirv(descriptor_set = 0, binding = 0)] image: &Image2d,
#[spirv(descriptor_set = 1, binding = 1)] image_array: &Image2dArray,
#[spirv(descriptor_set = 2, binding = 2)] cubemap: &Cubemap,
#[spirv(descriptor_set = 3, binding = 3)] sampler: &Sampler,
output: &mut f32,
) {
let v2 = glam::Vec2::new(0.0, 1.0);
Expand Down
8 changes: 4 additions & 4 deletions tests/ui/image/sample_depth_reference/sample_gradient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use spirv_std::{arch, Cubemap, Image2d, Image2dArray, Sampler};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image: &Image2d,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] image_array: &Image2dArray,
#[spirv(uniform_constant, descriptor_set = 2, binding = 2)] sampler: &Sampler,
#[spirv(uniform_constant, descriptor_set = 3, binding = 3)] cubemap: &Cubemap,
#[spirv(descriptor_set = 0, binding = 0)] image: &Image2d,
#[spirv(descriptor_set = 1, binding = 1)] image_array: &Image2dArray,
#[spirv(descriptor_set = 2, binding = 2)] sampler: &Sampler,
#[spirv(descriptor_set = 3, binding = 3)] cubemap: &Cubemap,
output: &mut f32,
) {
let v2 = glam::Vec2::new(0.0, 1.0);
Expand Down
8 changes: 4 additions & 4 deletions tests/ui/image/sample_depth_reference/sample_lod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use spirv_std::{arch, Cubemap, Image2d, Image2dArray, Sampler};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image: &Image2d,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] image_array: &Image2dArray,
#[spirv(uniform_constant, descriptor_set = 2, binding = 2)] sampler: &Sampler,
#[spirv(uniform_constant, descriptor_set = 3, binding = 3)] cubemap: &Cubemap,
#[spirv(descriptor_set = 0, binding = 0)] image: &Image2d,
#[spirv(descriptor_set = 1, binding = 1)] image_array: &Image2dArray,
#[spirv(descriptor_set = 2, binding = 2)] sampler: &Sampler,
#[spirv(descriptor_set = 3, binding = 3)] cubemap: &Cubemap,
output: &mut f32,
) {
let v2 = glam::Vec2::new(0.0, 1.0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use spirv_std::{arch, Image2d, Sampler};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image: &Image2d,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] sampler: &Sampler,
#[spirv(descriptor_set = 0, binding = 0)] image: &Image2d,
#[spirv(descriptor_set = 1, binding = 1)] sampler: &Sampler,
output: &mut f32,
) {
let v3 = glam::Vec3A::new(0.0, 0.0, 1.0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ use spirv_std::{arch, Image2d, Sampler};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image: &Image2d,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] sampler: &Sampler,
#[spirv(descriptor_set = 0, binding = 0)] image: &Image2d,
#[spirv(descriptor_set = 1, binding = 1)] sampler: &Sampler,
output: &mut f32,
) {
let v2_dx = glam::Vec2::new(0.0, 1.0);
let v2_dy = glam::Vec2::new(0.0, 1.0);
let v3 = glam::Vec3A::new(0.0, 0.0, 1.0);
*output = image.sample_depth_reference_with_project_coordinate_by_gradient(*sampler, v3, 1.0, v2_dx, v2_dy);
*output = image.sample_depth_reference_with_project_coordinate_by_gradient(
*sampler, v3, 1.0, v2_dx, v2_dy,
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use spirv_std::{arch, Image2d, Sampler};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image: &Image2d,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] sampler: &Sampler,
#[spirv(descriptor_set = 0, binding = 0)] image: &Image2d,
#[spirv(descriptor_set = 1, binding = 1)] sampler: &Sampler,
output: &mut f32,
) {
let v3 = glam::Vec3A::new(0.0, 0.0, 1.0);
Expand Down
8 changes: 4 additions & 4 deletions tests/ui/image/sample_gradient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use spirv_std::{arch, Cubemap, Image2d, Image2dArray, Sampler};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image2d: &Image2d,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] image2d_array: &Image2dArray,
#[spirv(uniform_constant, descriptor_set = 2, binding = 2)] cubemap: &Cubemap,
#[spirv(uniform_constant, descriptor_set = 3, binding = 3)] sampler: &Sampler,
#[spirv(descriptor_set = 0, binding = 0)] image2d: &Image2d,
#[spirv(descriptor_set = 1, binding = 1)] image2d_array: &Image2dArray,
#[spirv(descriptor_set = 2, binding = 2)] cubemap: &Cubemap,
#[spirv(descriptor_set = 3, binding = 3)] sampler: &Sampler,
output: &mut glam::Vec4,
) {
let v2 = glam::Vec2::new(0.0, 1.0);
Expand Down
8 changes: 4 additions & 4 deletions tests/ui/image/sample_lod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use spirv_std::{arch, Cubemap, Image2d, Image2dArray, Sampler};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image2d: &Image2d,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] image2d_array: &Image2dArray,
#[spirv(uniform_constant, descriptor_set = 2, binding = 2)] cubemap: &Cubemap,
#[spirv(uniform_constant, descriptor_set = 3, binding = 3)] sampler: &Sampler,
#[spirv(descriptor_set = 0, binding = 0)] image2d: &Image2d,
#[spirv(descriptor_set = 1, binding = 1)] image2d_array: &Image2dArray,
#[spirv(descriptor_set = 2, binding = 2)] cubemap: &Cubemap,
#[spirv(descriptor_set = 3, binding = 3)] sampler: &Sampler,
output: &mut glam::Vec4,
) {
let v2 = glam::Vec2::new(0.0, 1.0);
Expand Down
4 changes: 2 additions & 2 deletions tests/ui/image/sample_with_project_coordinate/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use spirv_std::{arch, Image2d, Sampler};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image2d: &Image2d,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] sampler: &Sampler,
#[spirv(descriptor_set = 0, binding = 0)] image2d: &Image2d,
#[spirv(descriptor_set = 1, binding = 1)] sampler: &Sampler,
output: &mut glam::Vec4,
) {
let v3 = glam::Vec3::new(0.0, 1.0, 0.5);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use spirv_std::{arch, Image2d, Sampler};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image2d: &Image2d,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] sampler: &Sampler,
#[spirv(descriptor_set = 0, binding = 0)] image2d: &Image2d,
#[spirv(descriptor_set = 1, binding = 1)] sampler: &Sampler,
output: &mut glam::Vec4,
) {
let v2 = glam::Vec2::new(0.0, 1.0);
Expand Down
4 changes: 2 additions & 2 deletions tests/ui/image/sample_with_project_coordinate/sample_lod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use spirv_std::{arch, Image2d, Sampler};

#[spirv(fragment)]
pub fn main(
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image2d: &Image2d,
#[spirv(uniform_constant, descriptor_set = 1, binding = 1)] sampler: &Sampler,
#[spirv(descriptor_set = 0, binding = 0)] image2d: &Image2d,
#[spirv(descriptor_set = 1, binding = 1)] sampler: &Sampler,
output: &mut glam::Vec4,
) {
let v3 = glam::Vec3::new(0.0, 1.0, 0.5);
Expand Down
5 changes: 1 addition & 4 deletions tests/ui/image/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
use spirv_std::{arch, StorageImage2d};

#[spirv(fragment)]
pub fn main(
texels: glam::Vec2,
#[spirv(uniform_constant, descriptor_set = 0, binding = 0)] image: &StorageImage2d,
) {
pub fn main(texels: glam::Vec2, #[spirv(descriptor_set = 0, binding = 0)] image: &StorageImage2d) {
unsafe {
asm!("OpCapability StorageImageWriteWithoutFormat");
image.write(glam::UVec2::new(0, 1), texels);
Expand Down
7 changes: 7 additions & 0 deletions tests/ui/spirv-attr/bad-infer-storage-class.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Tests that storage class inference fails correctly
// build-fail

use spirv_std::Image2d;

#[spirv(vertex)]
pub fn main(#[spirv(uniform)] error: &Image2d, #[spirv(uniform_constant)] warning: &Image2d) {}
20 changes: 20 additions & 0 deletions tests/ui/spirv-attr/bad-infer-storage-class.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
error: storage class UniformConstant was inferred from type, but Uniform was specified in attribute
--> $DIR/bad-infer-storage-class.rs:7:13
|
7 | pub fn main(#[spirv(uniform)] error: &Image2d, #[spirv(uniform_constant)] warning: &Image2d) {}
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: remove storage class attribute to use UniformConstant as storage class
--> $DIR/bad-infer-storage-class.rs:7:21
|
7 | pub fn main(#[spirv(uniform)] error: &Image2d, #[spirv(uniform_constant)] warning: &Image2d) {}
| ^^^^^^^

warning: redundant storage class specifier, storage class is inferred from type
--> $DIR/bad-infer-storage-class.rs:7:56
|
7 | pub fn main(#[spirv(uniform)] error: &Image2d, #[spirv(uniform_constant)] warning: &Image2d) {}
| ^^^^^^^^^^^^^^^^

error: aborting due to previous error; 1 warning emitted

0 comments on commit 4fa73bd

Please sign in to comment.