Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[spv/msl/hlsl-out] support pipeline constant value replacements #4998

Merged
merged 2 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion naga/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ all-features = true

[features]
default = []
clone = []
dot-out = []
glsl-in = ["pp-rs"]
glsl-out = []
Expand Down
15 changes: 13 additions & 2 deletions naga/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ impl<T> Range<T> {
/// Adding new items to the arena produces a strongly-typed [`Handle`].
/// The arena can be indexed using the given handle to obtain
/// a reference to the stored item.
#[cfg_attr(feature = "clone", derive(Clone))]
#[derive(Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "serialize", serde(transparent))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
Expand Down Expand Up @@ -297,6 +297,17 @@ impl<T> Arena<T> {
.map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) })
}

/// Drains the arena, returning an iterator over the items stored.
pub fn drain(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, T, Span)> {
let arena = std::mem::take(self);
arena
.data
.into_iter()
.zip(arena.span_info.into_iter())
.enumerate()
.map(|(i, (v, span))| unsafe { (Handle::from_usize_unchecked(i), v, span) })
}

/// Returns a iterator over the items stored in this arena,
/// returning both the item's handle and a mutable reference to it.
pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, &mut T)> {
Expand Down Expand Up @@ -531,7 +542,7 @@ mod tests {
///
/// `UniqueArena` is similar to [`Arena`]: If `Arena` is vector-like,
/// `UniqueArena` is `HashSet`-like.
#[cfg_attr(feature = "clone", derive(Clone))]
#[derive(Clone)]
pub struct UniqueArena<T> {
set: FastIndexSet<T>,

Expand Down
6 changes: 6 additions & 0 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,12 @@ impl<'a, W: Write> Writer<'a, W> {
pipeline_options: &'a PipelineOptions,
policies: proc::BoundsCheckPolicies,
) -> Result<Self, Error> {
if !module.overrides.is_empty() {
return Err(Error::Custom(
"Pipeline constants are not yet supported for this back-end".to_string(),
));
}

// Check if the requested version is supported
if !options.version.is_supported() {
log::error!("Version {}", options.version);
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ pub enum Error {
Unimplemented(String), // TODO: Error used only during development
#[error("{0}")]
Custom(String),
#[error(transparent)]
PipelineConstant(#[from] back::pipeline_constants::PipelineConstantError),
}

#[derive(Default)]
Expand Down
6 changes: 5 additions & 1 deletion naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
&mut self,
module: &Module,
module_info: &valid::ModuleInfo,
_pipeline_options: &PipelineOptions,
pipeline_options: &PipelineOptions,
) -> Result<super::ReflectionInfo, Error> {
let module =
back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?;
let module = module.as_ref();

self.reset(module);

// Write special constants, if needed
Expand Down
8 changes: 8 additions & 0 deletions naga/src/back/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ pub mod spv;
#[cfg(feature = "wgsl-out")]
pub mod wgsl;

#[cfg(any(
feature = "hlsl-out",
feature = "msl-out",
feature = "spv-out",
feature = "glsl-out"
))]
mod pipeline_constants;

const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
const INDENT: &str = " ";
const BAKE_PREFIX: &str = "_e";
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ pub enum Error {
UnsupportedArrayOfType(Handle<crate::Type>),
#[error("ray tracing is not supported prior to MSL 2.3")]
UnsupportedRayTracing,
#[error(transparent)]
PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError),
}

#[derive(Clone, Debug, PartialEq, thiserror::Error)]
Expand Down
4 changes: 4 additions & 0 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3119,6 +3119,10 @@ impl<W: Write> Writer<W> {
options: &Options,
pipeline_options: &PipelineOptions,
) -> Result<TranslationInfo, Error> {
let module =
back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?;
let module = module.as_ref();

self.names.clear();
self.namer.reset(
module,
Expand Down
213 changes: 213 additions & 0 deletions naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
use super::PipelineConstants;
use crate::{Constant, Expression, Literal, Module, Scalar, Span, TypeInner};
use std::borrow::Cow;
use thiserror::Error;

#[derive(Error, Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub enum PipelineConstantError {
#[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
MissingValue(String),
#[error("Source f64 value needs to be finite (NaNs and Inifinites are not allowed) for number destinations")]
SrcNeedsToBeFinite,
#[error("Source f64 value doesn't fit in destination")]
DstRangeTooSmall,
}

pub(super) fn process_overrides<'a>(
module: &'a Module,
pipeline_constants: &PipelineConstants,
) -> Result<Cow<'a, Module>, PipelineConstantError> {
if module.overrides.is_empty() {
return Ok(Cow::Borrowed(module));
}

let mut module = module.clone();

for (_handle, override_, span) in module.overrides.drain() {
let key = if let Some(id) = override_.id {
Cow::Owned(id.to_string())
} else if let Some(ref name) = override_.name {
Cow::Borrowed(name)
} else {
unreachable!();
};
let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
let literal = match module.types[override_.ty].inner {
TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
_ => unreachable!(),
};
module
.const_expressions
.append(Expression::Literal(literal), Span::UNDEFINED)
} else if let Some(init) = override_.init {
init
} else {
return Err(PipelineConstantError::MissingValue(key.to_string()));
};
let constant = Constant {
name: override_.name,
ty: override_.ty,
init,
};
module.constants.append(constant, span);
}

Ok(Cow::Owned(module))
}

fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
teoxoy marked this conversation as resolved.
Show resolved Hide resolved
jimblandy marked this conversation as resolved.
Show resolved Hide resolved
// note that in rust 0.0 == -0.0
match scalar {
Scalar::BOOL => {
// https://webidl.spec.whatwg.org/#js-boolean
let value = value != 0.0 && !value.is_nan();
Ok(Literal::Bool(value))
}
Scalar::I32 => {
// https://webidl.spec.whatwg.org/#js-long
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}

let value = value.trunc();
if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
return Err(PipelineConstantError::DstRangeTooSmall);
}

let value = value as i32;
Ok(Literal::I32(value))
}
Scalar::U32 => {
// https://webidl.spec.whatwg.org/#js-unsigned-long
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}

let value = value.trunc();
if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
return Err(PipelineConstantError::DstRangeTooSmall);
}

let value = value as u32;
Ok(Literal::U32(value))
}
Scalar::F32 => {
// https://webidl.spec.whatwg.org/#js-float
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}

let value = value as f32;
if !value.is_finite() {
return Err(PipelineConstantError::DstRangeTooSmall);
}

Ok(Literal::F32(value))
}
Scalar::F64 => {
// https://webidl.spec.whatwg.org/#js-double
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}

Ok(Literal::F64(value))
}
_ => unreachable!(),
}
}

#[test]
fn test_map_value_to_literal() {
let bool_test_cases = [
(0.0, false),
(-0.0, false),
(f64::NAN, false),
(1.0, true),
(f64::INFINITY, true),
(f64::NEG_INFINITY, true),
];
for (value, out) in bool_test_cases {
let res = Ok(Literal::Bool(out));
assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
}

for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] {
for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
assert_eq!(map_value_to_literal(value, scalar), res);
}
}

// i32
assert_eq!(
map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
Ok(Literal::I32(i32::MIN))
);
assert_eq!(
map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
Ok(Literal::I32(i32::MAX))
);
assert_eq!(
map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
Err(PipelineConstantError::DstRangeTooSmall)
);
assert_eq!(
map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
Err(PipelineConstantError::DstRangeTooSmall)
);

// u32
assert_eq!(
map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
Ok(Literal::U32(u32::MIN))
);
assert_eq!(
map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
Ok(Literal::U32(u32::MAX))
);
assert_eq!(
map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
Err(PipelineConstantError::DstRangeTooSmall)
);
assert_eq!(
map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
Err(PipelineConstantError::DstRangeTooSmall)
);

// f32
assert_eq!(
map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
Ok(Literal::F32(f32::MIN))
);
assert_eq!(
map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
Ok(Literal::F32(f32::MAX))
);
assert_eq!(
map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
Ok(Literal::F32(f32::MIN))
);
assert_eq!(
map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
Ok(Literal::F32(f32::MAX))
);
assert_eq!(
map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
Err(PipelineConstantError::DstRangeTooSmall)
);
assert_eq!(
map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
Err(PipelineConstantError::DstRangeTooSmall)
);

// f64
assert_eq!(
map_value_to_literal(f64::MIN, Scalar::F64),
Ok(Literal::F64(f64::MIN))
);
assert_eq!(
map_value_to_literal(f64::MAX, Scalar::F64),
Ok(Literal::F64(f64::MAX))
);
}
2 changes: 2 additions & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ pub enum Error {
FeatureNotImplemented(&'static str),
#[error("module is not validated properly: {0}")]
Validation(&'static str),
#[error(transparent)]
PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError),
}

#[derive(Default)]
Expand Down
10 changes: 10 additions & 0 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2018,6 +2018,16 @@ impl Writer {
debug_info: &Option<DebugInfo>,
words: &mut Vec<Word>,
) -> Result<(), Error> {
let ir_module = if let Some(pipeline_options) = pipeline_options {
crate::back::pipeline_constants::process_overrides(
ir_module,
&pipeline_options.constants,
)?
} else {
std::borrow::Cow::Borrowed(ir_module)
};
let ir_module = ir_module.as_ref();

self.reset();

// Try to find the entry point and corresponding index
Expand Down
6 changes: 6 additions & 0 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ impl<W: Write> Writer<W> {
}

pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
if !module.overrides.is_empty() {
return Err(Error::Unimplemented(
"Pipeline constants are not yet supported for this back-end".to_string(),
));
}

self.reset(module);

// Save all ep result types
Expand Down
Loading
Loading