diff --git a/CHANGELOG.md b/CHANGELOG.md
index 01c5f55c63..d4c84d5431 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,11 +30,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added ⭐
+- [PR#1081](https://github.com/EmbarkStudios/rust-gpu/pull/1081) added the ability
+ to access SPIR-V specialization constants (`OpSpecConstant`) via entry-point
+ inputs declared as `#[spirv(spec_constant(id = ..., default = ...))] x: u32`
+ (see also [the `#[spirv(spec_constant)]` attribute documentation](docs/src/attributes.md#specialization-constants))
- [PR#1036](https://github.com/EmbarkStudios/rust-gpu/pull/1036) added a `--force-spirv-passthru` flag to `example-runner-wgpu`, to bypass Naga (`wgpu`'s shader translator),
used it to test `debugPrintf` for `wgpu`, and updated `ShaderPanicStrategy::DebugPrintfThenExit` docs to reflect what "enabling `debugPrintf`" looks like for `wgpu`
(e.g. `VK_LOADER_LAYERS_ENABLE=VK_LAYER_KHRONOS_validation VK_LAYER_ENABLES=VK_VALIDATION_FEATURE_ENABLE_DEBUG_PRINTF_EXT DEBUG_PRINTF_TO_STDOUT=1`)
- [PR#1080](https://github.com/EmbarkStudios/rust-gpu/pull/1080) added `debugPrintf`-based
- panic reporting, with the desired behavior selected via `spirv_builder::ShaderPanicStrategy`
+ panic reporting, with the desired behavior selected via `spirv_builder::ShaderPanicStrategy`
(see its documentation for more details about each available panic handling strategy)
### Changed 🛠
diff --git a/crates/rustc_codegen_spirv/src/attr.rs b/crates/rustc_codegen_spirv/src/attr.rs
index d668146113..37915a2b09 100644
--- a/crates/rustc_codegen_spirv/src/attr.rs
+++ b/crates/rustc_codegen_spirv/src/attr.rs
@@ -68,6 +68,12 @@ pub enum IntrinsicType {
Matrix,
}
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub struct SpecConstant {
+ pub id: u32,
+ pub default: Option,
+}
+
// NOTE(eddyb) when adding new `#[spirv(...)]` attributes, the tests found inside
// `tests/ui/spirv-attr` should be updated (and new ones added if necessary).
#[derive(Debug, Clone)]
@@ -87,6 +93,7 @@ pub enum SpirvAttribute {
Flat,
Invariant,
InputAttachmentIndex(u32),
+ SpecConstant(SpecConstant),
// `fn`/closure attributes:
BufferLoadIntrinsic,
@@ -121,6 +128,7 @@ pub struct AggregatedSpirvAttributes {
pub flat: Option>,
pub invariant: Option>,
pub input_attachment_index: Option>,
+ pub spec_constant: Option>,
// `fn`/closure attributes:
pub buffer_load_intrinsic: Option>,
@@ -211,6 +219,12 @@ impl AggregatedSpirvAttributes {
span,
"#[spirv(attachment_index)]",
),
+ SpecConstant(value) => try_insert(
+ &mut self.spec_constant,
+ value,
+ span,
+ "#[spirv(spec_constant)]",
+ ),
BufferLoadIntrinsic => try_insert(
&mut self.buffer_load_intrinsic,
(),
@@ -300,7 +314,8 @@ impl CheckSpirvAttrVisitor<'_> {
| SpirvAttribute::Binding(_)
| SpirvAttribute::Flat
| SpirvAttribute::Invariant
- | SpirvAttribute::InputAttachmentIndex(_) => match target {
+ | SpirvAttribute::InputAttachmentIndex(_)
+ | SpirvAttribute::SpecConstant(_) => match target {
Target::Param => {
let parent_hir_id = self.tcx.hir().parent_id(hir_id);
let parent_is_entry_point =
diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
index 0cd3c21fb1..04e8268568 100644
--- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
+++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
@@ -1,6 +1,6 @@
use super::CodegenCx;
use crate::abi::ConvSpirvType;
-use crate::attr::{AggregatedSpirvAttributes, Entry, Spanned};
+use crate::attr::{AggregatedSpirvAttributes, Entry, Spanned, SpecConstant};
use crate::builder::Builder;
use crate::builder_spirv::{SpirvValue, SpirvValueExt};
use crate::spirv_type::SpirvType;
@@ -40,7 +40,10 @@ struct EntryParamDeducedFromRustRefOrValue<'tcx> {
/// The SPIR-V storage class to declare the shader interface variable in,
/// either deduced from the type (e.g. opaque handles use `UniformConstant`),
/// provided via `#[spirv(...)]` attributes, or an `Input`/`Output` default.
- storage_class: StorageClass,
+ //
+ // HACK(eddyb) this can be `Err(SpecConstant)` to indicate this is actually
+ // an `OpSpecConstant` being exposed as if it were an `Input`
+ storage_class: Result,
/// Whether this entry-point parameter doesn't allow writes to the underlying
/// shader interface variable (i.e. is by-value, or `&T` where `T: Freeze`).
@@ -387,6 +390,30 @@ impl<'tcx> CodegenCx<'tcx> {
}
}
+ // HACK(eddyb) only handle `attrs.spec_constant` after everything above
+ // would've assumed it was actually an implicitly-`Input`.
+ let mut storage_class = Ok(storage_class);
+ if let Some(spec_constant) = attrs.spec_constant {
+ if ref_or_value_layout.ty != self.tcx.types.u32 {
+ self.tcx.sess.span_err(
+ hir_param.ty_span,
+ format!(
+ "unsupported `#[spirv(spec_constant)]` type `{}` (expected `{}`)",
+ ref_or_value_layout.ty, self.tcx.types.u32
+ ),
+ );
+ } else if let Some(storage_class) = attrs.storage_class {
+ self.tcx.sess.span_err(
+ storage_class.span,
+ "`#[spirv(spec_constant)]` cannot have a storage class",
+ );
+ } else {
+ assert_eq!(storage_class, Ok(StorageClass::Input));
+ assert!(!is_ref);
+ storage_class = Err(spec_constant.value);
+ }
+ }
+
EntryParamDeducedFromRustRefOrValue {
value_layout,
storage_class,
@@ -407,9 +434,6 @@ impl<'tcx> CodegenCx<'tcx> {
) {
let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.hir().attrs(hir_param.hir_id));
- // Pre-allocate the module-scoped `OpVariable`'s *Result* ID.
- let var = self.emit_global().id();
-
let EntryParamDeducedFromRustRefOrValue {
value_layout,
storage_class,
@@ -417,14 +441,35 @@ impl<'tcx> CodegenCx<'tcx> {
} = self.entry_param_deduce_from_rust_ref_or_value(entry_arg_abi.layout, hir_param, &attrs);
let value_spirv_type = value_layout.spirv_type(hir_param.ty_span, self);
+ let (var_id, spec_const_id) = match storage_class {
+ // Pre-allocate the module-scoped `OpVariable` *Result* ID.
+ Ok(_) => (
+ Ok(self.emit_global().id()),
+ Err("entry-point interface variable is not a `#[spirv(spec_constant)]`"),
+ ),
+ Err(SpecConstant { id, default }) => {
+ let mut emit = self.emit_global();
+ let spec_const_id = emit.spec_constant_u32(value_spirv_type, default.unwrap_or(0));
+ emit.decorate(
+ spec_const_id,
+ Decoration::SpecId,
+ [Operand::LiteralInt32(id)],
+ );
+ (
+ Err("`#[spirv(spec_constant)]` is not an entry-point interface variable"),
+ Ok(spec_const_id),
+ )
+ }
+ };
+
// Emit decorations deduced from the reference/value Rust type.
if read_only {
// NOTE(eddyb) it appears only `StorageBuffer`s simultaneously:
// - allow `NonWritable` decorations on shader interface variables
// - default to writable (i.e. the decoration actually has an effect)
- if storage_class == StorageClass::StorageBuffer {
+ if storage_class == Ok(StorageClass::StorageBuffer) {
self.emit_global()
- .decorate(var, Decoration::NonWritable, []);
+ .decorate(var_id.unwrap(), Decoration::NonWritable, []);
}
}
@@ -454,14 +499,20 @@ impl<'tcx> CodegenCx<'tcx> {
}
let var_ptr_spirv_type;
let (value_ptr, value_len) = match storage_class {
- StorageClass::PushConstant | StorageClass::Uniform | StorageClass::StorageBuffer => {
+ Ok(
+ StorageClass::PushConstant | StorageClass::Uniform | StorageClass::StorageBuffer,
+ ) => {
let var_spirv_type = SpirvType::InterfaceBlock {
inner_type: value_spirv_type,
}
.def(hir_param.span, self);
var_ptr_spirv_type = self.type_ptr_to(var_spirv_type);
- let value_ptr = bx.struct_gep(var_spirv_type, var.with_type(var_ptr_spirv_type), 0);
+ let value_ptr = bx.struct_gep(
+ var_spirv_type,
+ var_id.unwrap().with_type(var_ptr_spirv_type),
+ 0,
+ );
let value_len = if is_unsized_with_len {
match self.lookup_type(value_spirv_type) {
@@ -478,7 +529,7 @@ impl<'tcx> CodegenCx<'tcx> {
let len_spirv_type = self.type_isize();
let len = bx
.emit()
- .array_length(len_spirv_type, None, var, 0)
+ .array_length(len_spirv_type, None, var_id.unwrap(), 0)
.unwrap();
Some(len.with_type(len_spirv_type))
@@ -493,9 +544,9 @@ impl<'tcx> CodegenCx<'tcx> {
None
};
- (value_ptr, value_len)
+ (Ok(value_ptr), value_len)
}
- StorageClass::UniformConstant => {
+ Ok(StorageClass::UniformConstant) => {
var_ptr_spirv_type = self.type_ptr_to(value_spirv_type);
match self.lookup_type(value_spirv_type) {
@@ -524,7 +575,7 @@ impl<'tcx> CodegenCx<'tcx> {
None
};
- (var.with_type(var_ptr_spirv_type), value_len)
+ (Ok(var_id.unwrap().with_type(var_ptr_spirv_type)), value_len)
}
_ => {
var_ptr_spirv_type = self.type_ptr_to(value_spirv_type);
@@ -533,12 +584,19 @@ impl<'tcx> CodegenCx<'tcx> {
self.tcx.sess.span_fatal(
hir_param.ty_span,
format!(
- "unsized types are not supported for storage class {storage_class:?}"
+ "unsized types are not supported for {}",
+ match storage_class {
+ Ok(storage_class) => format!("storage class {storage_class:?}"),
+ Err(SpecConstant { .. }) => "`#[spirv(spec_constant)]`".into(),
+ },
),
);
}
- (var.with_type(var_ptr_spirv_type), None)
+ (
+ var_id.map(|var_id| var_id.with_type(var_ptr_spirv_type)),
+ None,
+ )
}
};
@@ -546,21 +604,26 @@ impl<'tcx> CodegenCx<'tcx> {
// starting from the `value_ptr` pointing to a `value_spirv_type`
// (e.g. `Input` doesn't use indirection, so we have to load from it).
if let ty::Ref(..) = entry_arg_abi.layout.ty.kind() {
- call_args.push(value_ptr);
+ call_args.push(value_ptr.unwrap());
match entry_arg_abi.mode {
PassMode::Direct(_) => assert_eq!(value_len, None),
PassMode::Pair(..) => call_args.push(value_len.unwrap()),
_ => unreachable!(),
}
} else {
- assert_eq!(storage_class, StorageClass::Input);
assert_matches!(entry_arg_abi.mode, PassMode::Direct(_));
- let value = bx.load(
- entry_arg_abi.layout.spirv_type(hir_param.ty_span, bx),
- value_ptr,
- entry_arg_abi.layout.align.abi,
- );
+ let value = match storage_class {
+ Ok(_) => {
+ assert_eq!(storage_class, Ok(StorageClass::Input));
+ bx.load(
+ entry_arg_abi.layout.spirv_type(hir_param.ty_span, bx),
+ value_ptr.unwrap(),
+ entry_arg_abi.layout.align.abi,
+ )
+ }
+ Err(SpecConstant { .. }) => spec_const_id.unwrap().with_type(value_spirv_type),
+ };
call_args.push(value);
assert_eq!(value_len, None);
}
@@ -573,48 +636,76 @@ impl<'tcx> CodegenCx<'tcx> {
// name (e.g. "foo" for `foo: Vec3`). While `OpName` is *not* suppposed
// to be semantic, OpenGL and some tooling rely on it for reflection.
if let hir::PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind {
- self.emit_global().name(var, ident.to_string());
+ self.emit_global()
+ .name(var_id.or(spec_const_id).unwrap(), ident.to_string());
}
// Emit `OpDecorate`s based on attributes.
let mut decoration_supersedes_location = false;
- if let Some(builtin) = attrs.builtin.map(|attr| attr.value) {
+ if let Some(builtin) = attrs.builtin {
+ if let Err(SpecConstant { .. }) = storage_class {
+ self.tcx.sess.span_fatal(
+ builtin.span,
+ format!(
+ "`#[spirv(spec_constant)]` cannot be `{:?}` builtin",
+ builtin.value
+ ),
+ );
+ }
self.emit_global().decorate(
- var,
+ var_id.unwrap(),
Decoration::BuiltIn,
- std::iter::once(Operand::BuiltIn(builtin)),
+ std::iter::once(Operand::BuiltIn(builtin.value)),
);
decoration_supersedes_location = true;
}
- if let Some(index) = attrs.descriptor_set.map(|attr| attr.value) {
+ if let Some(descriptor_set) = attrs.descriptor_set {
+ if let Err(SpecConstant { .. }) = storage_class {
+ self.tcx.sess.span_fatal(
+ descriptor_set.span,
+ "`#[spirv(descriptor_set = ...)]` cannot apply to `#[spirv(spec_constant)]`",
+ );
+ }
self.emit_global().decorate(
- var,
+ var_id.unwrap(),
Decoration::DescriptorSet,
- std::iter::once(Operand::LiteralInt32(index)),
+ std::iter::once(Operand::LiteralInt32(descriptor_set.value)),
);
decoration_supersedes_location = true;
}
- if let Some(index) = attrs.binding.map(|attr| attr.value) {
+ if let Some(binding) = attrs.binding {
+ if let Err(SpecConstant { .. }) = storage_class {
+ self.tcx.sess.span_fatal(
+ binding.span,
+ "`#[spirv(binding = ...)]` cannot apply to `#[spirv(spec_constant)]`",
+ );
+ }
self.emit_global().decorate(
- var,
+ var_id.unwrap(),
Decoration::Binding,
- std::iter::once(Operand::LiteralInt32(index)),
+ std::iter::once(Operand::LiteralInt32(binding.value)),
);
decoration_supersedes_location = true;
}
- if attrs.flat.is_some() {
+ if let Some(flat) = attrs.flat {
+ if let Err(SpecConstant { .. }) = storage_class {
+ self.tcx.sess.span_fatal(
+ flat.span,
+ "`#[spirv(flat)]` cannot apply to `#[spirv(spec_constant)]`",
+ );
+ }
self.emit_global()
- .decorate(var, Decoration::Flat, std::iter::empty());
+ .decorate(var_id.unwrap(), Decoration::Flat, std::iter::empty());
}
if let Some(invariant) = attrs.invariant {
- self.emit_global()
- .decorate(var, Decoration::Invariant, std::iter::empty());
- if storage_class != StorageClass::Output {
- self.tcx.sess.span_err(
+ if storage_class != Ok(StorageClass::Output) {
+ self.tcx.sess.span_fatal(
invariant.span,
- "#[spirv(invariant)] is only valid on Output variables",
+ "`#[spirv(invariant)]` is only valid on Output variables",
);
}
+ self.emit_global()
+ .decorate(var_id.unwrap(), Decoration::Invariant, std::iter::empty());
}
let is_subpass_input = match self.lookup_type(value_spirv_type) {
@@ -635,7 +726,7 @@ impl<'tcx> CodegenCx<'tcx> {
if let Some(attachment_index) = attrs.input_attachment_index {
if is_subpass_input && self.builder.has_capability(Capability::InputAttachment) {
self.emit_global().decorate(
- var,
+ var_id.unwrap(),
Decoration::InputAttachmentIndex,
std::iter::once(Operand::LiteralInt32(attachment_index.value)),
);
@@ -657,14 +748,16 @@ impl<'tcx> CodegenCx<'tcx> {
);
}
- self.check_for_bad_types(
- execution_model,
- hir_param.ty_span,
- var_ptr_spirv_type,
- storage_class,
- attrs.builtin.is_some(),
- attrs.flat,
- );
+ if let Ok(storage_class) = storage_class {
+ self.check_for_bad_types(
+ execution_model,
+ hir_param.ty_span,
+ var_ptr_spirv_type,
+ storage_class,
+ attrs.builtin.is_some(),
+ attrs.flat,
+ );
+ }
// Assign locations from left to right, incrementing each storage class
// individually.
@@ -673,34 +766,43 @@ impl<'tcx> CodegenCx<'tcx> {
let has_location = !decoration_supersedes_location
&& matches!(
storage_class,
- StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant
+ Ok(StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant)
);
if has_location {
let location = decoration_locations
- .entry(storage_class)
+ .entry(storage_class.unwrap())
.or_insert_with(|| 0);
self.emit_global().decorate(
- var,
+ var_id.unwrap(),
Decoration::Location,
std::iter::once(Operand::LiteralInt32(*location)),
);
*location += 1;
}
- // Emit the `OpVariable` with its *Result* ID set to `var`.
- self.emit_global()
- .variable(var_ptr_spirv_type, Some(var), storage_class, None);
+ match storage_class {
+ Ok(storage_class) => {
+ let var = var_id.unwrap();
- // Record this `OpVariable` as needing to be added (if applicable),
- // to the *Interface* operands of the `OpEntryPoint` instruction.
- if self.emit_global().version().unwrap() > (1, 3) {
- // SPIR-V >= v1.4 includes all OpVariables in the interface.
- op_entry_point_interface_operands.push(var);
- } else {
- // SPIR-V <= v1.3 only includes Input and Output in the interface.
- if storage_class == StorageClass::Input || storage_class == StorageClass::Output {
- op_entry_point_interface_operands.push(var);
+ // Emit the `OpVariable` with its *Result* ID set to `var_id`.
+ self.emit_global()
+ .variable(var_ptr_spirv_type, Some(var), storage_class, None);
+
+ // Record this `OpVariable` as needing to be added (if applicable),
+ // to the *Interface* operands of the `OpEntryPoint` instruction.
+ if self.emit_global().version().unwrap() > (1, 3) {
+ // SPIR-V >= v1.4 includes all OpVariables in the interface.
+ op_entry_point_interface_operands.push(var);
+ } else {
+ // SPIR-V <= v1.3 only includes Input and Output in the interface.
+ if storage_class == StorageClass::Input || storage_class == StorageClass::Output
+ {
+ op_entry_point_interface_operands.push(var);
+ }
+ }
}
+ // Emitted earlier.
+ Err(SpecConstant { .. }) => {}
}
}
diff --git a/crates/rustc_codegen_spirv/src/symbols.rs b/crates/rustc_codegen_spirv/src/symbols.rs
index b0af1948aa..4129187b39 100644
--- a/crates/rustc_codegen_spirv/src/symbols.rs
+++ b/crates/rustc_codegen_spirv/src/symbols.rs
@@ -1,4 +1,4 @@
-use crate::attr::{Entry, ExecutionModeExtra, IntrinsicType, SpirvAttribute};
+use crate::attr::{Entry, ExecutionModeExtra, IntrinsicType, SpecConstant, SpirvAttribute};
use crate::builder::libm_intrinsics;
use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
use rustc_ast::ast::{AttrKind, Attribute, LitIntType, LitKind, MetaItemLit, NestedMetaItem};
@@ -30,6 +30,10 @@ pub struct Symbols {
binding: Symbol,
input_attachment_index: Symbol,
+ spec_constant: Symbol,
+ id: Symbol,
+ default: Symbol,
+
attributes: FxHashMap,
execution_modes: FxHashMap,
pub libm_intrinsics: FxHashMap,
@@ -392,6 +396,10 @@ impl Symbols {
binding: Symbol::intern("binding"),
input_attachment_index: Symbol::intern("input_attachment_index"),
+ spec_constant: Symbol::intern("spec_constant"),
+ id: Symbol::intern("id"),
+ default: Symbol::intern("default"),
+
attributes,
execution_modes,
libm_intrinsics,
@@ -466,6 +474,8 @@ pub(crate) fn parse_attrs_for_checking<'a>(
SpirvAttribute::Binding(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.input_attachment_index) {
SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
+ } else if arg.has_name(sym.spec_constant) {
+ SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?)
} else {
let name = match arg.ident() {
Some(i) => i,
@@ -494,6 +504,38 @@ pub(crate) fn parse_attrs_for_checking<'a>(
})
}
+fn parse_spec_constant_attr(
+ sym: &Symbols,
+ arg: &NestedMetaItem,
+) -> Result {
+ let mut id = None;
+ let mut default = None;
+
+ if let Some(attrs) = arg.meta_item_list() {
+ for attr in attrs {
+ if attr.has_name(sym.id) {
+ if id.is_none() {
+ id = Some(parse_attr_int_value(attr)?);
+ } else {
+ return Err((attr.span(), "`id` may only be specified once".into()));
+ }
+ } else if attr.has_name(sym.default) {
+ if default.is_none() {
+ default = Some(parse_attr_int_value(attr)?);
+ } else {
+ return Err((attr.span(), "`default` may only be specified once".into()));
+ }
+ } else {
+ return Err((attr.span(), "expected `id = ...` or `default = ...`".into()));
+ }
+ }
+ }
+ Ok(SpecConstant {
+ id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
+ default,
+ })
+}
+
fn parse_attr_int_value(arg: &NestedMetaItem) -> Result {
let arg = match arg.meta_item() {
Some(arg) => arg,
diff --git a/docs/src/attributes.md b/docs/src/attributes.md
index 35af009a47..627f3d752f 100644
--- a/docs/src/attributes.md
+++ b/docs/src/attributes.md
@@ -110,3 +110,47 @@ fn main(#[spirv(workgroup)] var: &mut [Vec4; 4]) { }
## Generic storage classes
The SPIR-V storage class of types is inferred for function signatures. The inference logic can be guided by attributes on the interface specification in the entry points. This also means it needs to be clear from the documentation if an API requires a certain storage class (e.g `workgroup`) for a variable. Storage class attributes are only permitted on entry points.
+
+## Specialization constants
+
+Entry point inputs also allow access to [SPIR-V "specialization constants"](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#SpecializationSection),
+which are each associated with an user-specified numeric "ID" (SPIR-V `SpecId`),
+used to override them later ("specializing" the shader):
+* in Vulkan: [during pipeline creation, via `VkSpecializationInfo`](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/chap10.html#pipelines-specialization-constants)
+* in WebGPU: [during pipeline creation, via `GPUProgrammableStage``#constants`](https://www.w3.org/TR/webgpu/#gpuprogrammablestage)
+ * note: WebGPU calls them ["pipeline-overridable constants"](https://gpuweb.github.io/gpuweb/wgsl/#pipeline-overridable)
+* in OpenCL: [via `clSetProgramSpecializationConstant()` calls, before `clBuildProgram()`](https://registry.khronos.org/OpenCL/sdk/3.0/docs/man/html/clSetProgramSpecializationConstant.html)
+
+If a "specialization constant" is not overriden, it falls back to its *default*
+value, which is either user-specified (via `default = ...`), or `0` otherwise.
+
+While only "specialization constants" of type `u32` are currently supported, it's
+always possible to *manually* create values of other types, from one or more `u32`s.
+
+Example:
+
+```rust
+#[spirv(vertex)]
+fn main(
+ // Default is implicitly `0`, if not specified.
+ #[spirv(spec_constant(id = 1))] no_default: u32,
+
+ // IDs don't need to be sequential or obey any order.
+ #[spirv(spec_constant(id = 9000, default = 123))] default_123: u32,
+
+ // Assembling a larger value out of multiple `u32` is also possible.
+ #[spirv(spec_constant(id = 100))] x_u64_lo: u32,
+ #[spirv(spec_constant(id = 101))] x_u64_hi: u32,
+) {
+ let x_u64 = ((x_u64_hi as u64) << 32) | (x_u64_lo as u64);
+}
+```
+
+**Note**: despite the name "constants", they are *runtime values* from the
+perspective of compiled Rust code (or at most similar to "link-time constants"),
+and as such have no connection to *Rust constants*, especially not Rust type-level
+constants and `const` generics - while specializing some e.g. `fn foo`
+by `N` long after it was compiled to SPIR-V, or using "specialization constants"
+as Rust array lengths, Rust would sadly require *dependent types* to type-check
+such code (as it would for e.g. expressing C `T[n]` types with runtime `n`),
+and the main benefit over truly dynamic inputs is a (potential) performance boost.
diff --git a/examples/runners/ash/src/main.rs b/examples/runners/ash/src/main.rs
index 9ff4430b60..03c2e1111a 100644
--- a/examples/runners/ash/src/main.rs
+++ b/examples/runners/ash/src/main.rs
@@ -187,6 +187,19 @@ pub fn main() {
});
}
}
+ Some(key @ (VirtualKeyCode::NumpadAdd | VirtualKeyCode::NumpadSubtract)) => {
+ let factor =
+ &mut ctx.sky_fs_spec_id_0x5007_sun_intensity_extra_spec_const_factor;
+ *factor = if key == VirtualKeyCode::NumpadAdd {
+ factor.saturating_add(1)
+ } else {
+ factor.saturating_sub(1)
+ };
+
+ // HACK(eddyb) to see any changes, re-specializing the
+ // shader module is needed (e.g. during pipeline rebuild).
+ ctx.rebuild_pipelines(vk::PipelineCache::null());
+ }
_ => *control_flow = ControlFlow::Wait,
},
WindowEvent::Resized(_) => {
@@ -657,6 +670,9 @@ pub struct RenderCtx {
pub rendering_paused: bool,
pub recompiling_shaders: bool,
pub start: std::time::Instant,
+
+ // NOTE(eddyb) this acts like an integration test for specialization constants.
+ pub sky_fs_spec_id_0x5007_sun_intensity_extra_spec_const_factor: u32,
}
impl RenderCtx {
@@ -702,6 +718,8 @@ impl RenderCtx {
rendering_paused: false,
recompiling_shaders: false,
start: std::time::Instant::now(),
+
+ sky_fs_spec_id_0x5007_sun_intensity_extra_spec_const_factor: 100,
}
}
@@ -723,6 +741,18 @@ impl RenderCtx {
}
pub fn rebuild_pipelines(&mut self, pipeline_cache: vk::PipelineCache) {
+ // NOTE(eddyb) this acts like an integration test for specialization constants.
+ let spec_const_entries = [vk::SpecializationMapEntry::builder()
+ .constant_id(0x5007)
+ .offset(0)
+ .size(4)
+ .build()];
+ let spec_const_data =
+ u32::to_le_bytes(self.sky_fs_spec_id_0x5007_sun_intensity_extra_spec_const_factor);
+ let specialization_info = vk::SpecializationInfo::builder()
+ .map_entries(&spec_const_entries)
+ .data(&spec_const_data);
+
self.cleanup_pipelines();
let pipeline_layout = self.create_pipeline_layout();
let viewport = vk::PipelineViewportStateCreateInfo::builder()
@@ -754,6 +784,7 @@ impl RenderCtx {
module: *frag_module,
p_name: (*frag_name).as_ptr(),
stage: vk::ShaderStageFlags::FRAGMENT,
+ p_specialization_info: &*specialization_info,
..Default::default()
},
]))
diff --git a/examples/runners/cpu/src/main.rs b/examples/runners/cpu/src/main.rs
index 7fccbc6975..30c12a9420 100644
--- a/examples/runners/cpu/src/main.rs
+++ b/examples/runners/cpu/src/main.rs
@@ -141,7 +141,7 @@ fn main() {
* vec2(WIDTH as f32, HEIGHT as f32);
// evaluate the fragment shader for the specific pixel
- let color = shader_module::fs(&push_constants, frag_coord);
+ let color = shader_module::fs(&push_constants, frag_coord, 1);
color_u32_from_vec4(color)
})
diff --git a/examples/shaders/sky-shader/src/lib.rs b/examples/shaders/sky-shader/src/lib.rs
index 4bebaf38f3..5096df0dd9 100644
--- a/examples/shaders/sky-shader/src/lib.rs
+++ b/examples/shaders/sky-shader/src/lib.rs
@@ -71,7 +71,7 @@ fn sun_intensity(zenith_angle_cos: f32) -> f32 {
)
}
-fn sky(dir: Vec3, sun_position: Vec3) -> Vec3 {
+fn sky(dir: Vec3, sun_position: Vec3, sun_intensity_extra_spec_const_factor: u32) -> Vec3 {
let up = vec3(0.0, 1.0, 0.0);
let sunfade = 1.0 - (1.0 - saturate(sun_position.y / 450000.0).exp());
let rayleigh_coefficient = RAYLEIGH - (1.0 * (1.0 - sunfade));
@@ -96,7 +96,12 @@ fn sky(dir: Vec3, sun_position: Vec3) -> Vec3 {
let beta_r_theta = beta_r * rayleigh_phase(cos_theta * 0.5 + 0.5);
let beta_m_theta = beta_m * henyey_greenstein_phase(cos_theta, MIE_DIRECTIONAL_G);
- let sun_e = sun_intensity(sun_direction.dot(up));
+ let sun_e = sun_intensity(sun_direction.dot(up))
+
+ // HACK(eddyb) this acts like an integration test for specialization constants,
+ // but the correct value is only obtained when this is a noop (multiplies by `1`).
+ * (sun_intensity_extra_spec_const_factor as f32 / 100.0);
+
let mut lin = pow(
sun_e * ((beta_r_theta + beta_m_theta) / (beta_r + beta_m)) * (Vec3::splat(1.0) - fex),
1.5,
@@ -130,7 +135,11 @@ fn get_ray_dir(uv: Vec2, pos: Vec3, look_at_pos: Vec3) -> Vec3 {
(forward + uv.x * right + uv.y * up).normalize()
}
-pub fn fs(constants: &ShaderConstants, frag_coord: Vec2) -> Vec4 {
+pub fn fs(
+ constants: &ShaderConstants,
+ frag_coord: Vec2,
+ sun_intensity_extra_spec_const_factor: u32,
+) -> Vec4 {
let mut uv = (frag_coord - 0.5 * vec2(constants.width as f32, constants.height as f32))
/ constants.height as f32;
uv.y = -uv.y;
@@ -141,7 +150,7 @@ pub fn fs(constants: &ShaderConstants, frag_coord: Vec2) -> Vec4 {
let dir = get_ray_dir(uv, eye_pos, sun_pos);
// evaluate Preetham sky model
- let color = sky(dir, sun_pos);
+ let color = sky(dir, sun_pos, sun_intensity_extra_spec_const_factor);
// Tonemapping
let color = color.max(Vec3::splat(0.0)).min(Vec3::splat(1024.0));
@@ -154,9 +163,12 @@ pub fn main_fs(
#[spirv(frag_coord)] in_frag_coord: Vec4,
#[spirv(push_constant)] constants: &ShaderConstants,
output: &mut Vec4,
+
+ // NOTE(eddyb) this acts like an integration test for specialization constants.
+ #[spirv(spec_constant(id = 0x5007, default = 100))] sun_intensity_extra_spec_const_factor: u32,
) {
let frag_coord = vec2(in_frag_coord.x, in_frag_coord.y);
- *output = fs(constants, frag_coord);
+ *output = fs(constants, frag_coord, sun_intensity_extra_spec_const_factor);
}
#[spirv(vertex)]
diff --git a/tests/ui/dis/spec_constant-attr.rs b/tests/ui/dis/spec_constant-attr.rs
new file mode 100644
index 0000000000..ce04ba674b
--- /dev/null
+++ b/tests/ui/dis/spec_constant-attr.rs
@@ -0,0 +1,29 @@
+#![crate_name = "spec_constant_attr"]
+
+// Tests the various forms of `#[spirv(spec_constant)]`.
+
+// build-pass
+// compile-flags: -C llvm-args=--disassemble-globals
+// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
+// normalize-stderr-test "OpSource .*\n" -> ""
+// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> ""
+// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
+
+// FIXME(eddyb) this should use revisions to track both the `vulkan1.2` output
+// and the pre-`vulkan1.2` output, but per-revisions `{only,ignore}-*` directives
+// are not supported in `compiletest-rs`.
+// ignore-vulkan1.2
+
+use spirv_std::spirv;
+
+#[spirv(fragment)]
+pub fn main(
+ #[spirv(spec_constant(id = 1))] no_default: u32,
+ #[spirv(spec_constant(id = 2, default = 0))] default_0: u32,
+ #[spirv(spec_constant(id = 123, default = 123))] default_123: u32,
+ #[spirv(spec_constant(id = 0xffff_ffff, default = 0xffff_ffff))] max_id_and_default: u32,
+
+ out: &mut u32,
+) {
+ *out = no_default + default_0 + default_123 + max_id_and_default;
+}
diff --git a/tests/ui/dis/spec_constant-attr.stderr b/tests/ui/dis/spec_constant-attr.stderr
new file mode 100644
index 0000000000..5cdc1f968d
--- /dev/null
+++ b/tests/ui/dis/spec_constant-attr.stderr
@@ -0,0 +1,30 @@
+OpCapability Shader
+OpCapability Float64
+OpCapability Int64
+OpCapability Int16
+OpCapability Int8
+OpCapability ShaderClockKHR
+OpExtension "SPV_KHR_shader_clock"
+OpMemoryModel Logical Simple
+OpEntryPoint Fragment %1 "main" %2
+OpExecutionMode %1 OriginUpperLeft
+%3 = OpString "$OPSTRING_FILENAME/spec_constant-attr.rs"
+OpName %4 "no_default"
+OpName %5 "default_0"
+OpName %6 "default_123"
+OpName %7 "max_id_and_default"
+OpName %2 "out"
+OpDecorate %4 SpecId 1
+OpDecorate %5 SpecId 2
+OpDecorate %6 SpecId 123
+OpDecorate %7 SpecId 4294967295
+OpDecorate %2 Location 0
+%8 = OpTypeInt 32 0
+%9 = OpTypePointer Output %8
+%10 = OpTypeVoid
+%11 = OpTypeFunction %10
+%4 = OpSpecConstant %8 0
+%5 = OpSpecConstant %8 0
+%6 = OpSpecConstant %8 123
+%7 = OpSpecConstant %8 4294967295
+%2 = OpVariable %9 Output
diff --git a/tests/ui/spirv-attr/invariant-invalid.stderr b/tests/ui/spirv-attr/invariant-invalid.stderr
index c0e917418f..4ba05ebacc 100644
--- a/tests/ui/spirv-attr/invariant-invalid.stderr
+++ b/tests/ui/spirv-attr/invariant-invalid.stderr
@@ -1,4 +1,4 @@
-error: #[spirv(invariant)] is only valid on Output variables
+error: `#[spirv(invariant)]` is only valid on Output variables
--> $DIR/invariant-invalid.rs:7:21
|
7 | pub fn main(#[spirv(invariant)] input: f32) {}