Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga] resolve the size of override-sized arrays in backends #6787

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,8 @@ pub enum Error {
/// [`crate::Sampling::First`] is unsupported.
#[error("`{:?}` sampling is unsupported", crate::Sampling::First)]
FirstSamplingNotSupported,
#[error(transparent)]
ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
}

/// Binary operation with a different logic on the GLSL side.
Expand Down Expand Up @@ -976,13 +978,12 @@ impl<'a, W: Write> Writer<'a, W> {
write!(self.out, "[")?;

// Write the array size
// Writes nothing if `ArraySize::Dynamic`
match size {
crate::ArraySize::Constant(size) => {
// Writes nothing if `ResolvedSize::Runtime`
match size.resolve(self.module.to_ctx())? {
proc::ResolvedSize::Constant(size) => {
write!(self.out, "{size}")?;
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => (),
proc::ResolvedSize::Runtime => (),
}

write!(self.out, "]")?;
Expand Down Expand Up @@ -4519,13 +4520,9 @@ impl<'a, W: Write> Writer<'a, W> {
write!(self.out, ")")?;
}
TypeInner::Array { base, size, .. } => {
let count = match size
.to_indexable_length(self.module)
.expect("Bad array size")
{
proc::IndexableLength::Known(count) => count,
proc::IndexableLength::Pending => unreachable!(),
proc::IndexableLength::Dynamic => return Ok(()),
let count = match size.resolve(self.module.to_ctx())? {
proc::ResolvedSize::Constant(size) => size,
proc::ResolvedSize::Runtime => return Ok(()),
};
self.write_type(base)?;
self.write_array_size(base, size)?;
Expand Down
19 changes: 9 additions & 10 deletions naga/src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl crate::TypeInner {
}
}

pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> u32 {
pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> Result<u32, Error> {
match *self {
Self::Matrix {
columns,
Expand All @@ -62,19 +62,18 @@ impl crate::TypeInner {
} => {
let stride = Alignment::from(rows) * scalar.width as u32;
let last_row_size = rows as u32 * scalar.width as u32;
((columns as u32 - 1) * stride) + last_row_size
Ok(((columns as u32 - 1) * stride) + last_row_size)
}
Self::Array { base, size, stride } => {
let count = match size {
crate::ArraySize::Constant(size) => size.get(),
// A dynamically-sized array has to have at least one element
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => 1,
let count = match size.resolve(gctx)? {
crate::proc::ResolvedSize::Constant(size) => size,
// A runtime-sized array has to have at least one element
crate::proc::ResolvedSize::Runtime => 1,
};
let last_el_size = gctx.types[base].inner.size_hlsl(gctx);
((count - 1) * stride) + last_el_size
let last_el_size = gctx.types[base].inner.size_hlsl(gctx)?;
Ok(((count - 1) * stride) + last_el_size)
}
_ => self.size(gctx),
_ => Ok(self.size(gctx)),
}
}

Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ pub enum Error {
Custom(String),
#[error("overrides should not be present at this stage")]
Override,
#[error(transparent)]
ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
}

#[derive(Default)]
Expand Down
10 changes: 4 additions & 6 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1090,12 +1090,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
) -> BackendResult {
write!(self.out, "[")?;

match size {
crate::ArraySize::Constant(size) => {
match size.resolve(module.to_ctx())? {
proc::ResolvedSize::Constant(size) => {
write!(self.out, "{size}")?;
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => unreachable!(),
proc::ResolvedSize::Runtime => unreachable!(),
}

write!(self.out, "]")?;
Expand Down Expand Up @@ -1140,7 +1139,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
let ty_inner = &module.types[member.ty].inner;
last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx());
last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;

// The indentation is only for readability
write!(self.out, "{}", back::INDENT)?;
Expand Down Expand Up @@ -2851,7 +2850,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
index::IndexableLength::Known(limit) => {
write!(self.out, "{}u", limit - 1)?;
}
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => unreachable!(),
}
write!(self.out, ")")?;
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ pub enum Error {
Override,
#[error("bitcasting to {0:?} is not supported")]
UnsupportedBitCast(crate::TypeInner),
#[error(transparent)]
ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError),
}

#[derive(Clone, Debug, PartialEq, thiserror::Error)]
Expand Down
24 changes: 8 additions & 16 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, Transl
use crate::{
arena::{Handle, HandleSet},
back::{self, Baked},
proc::index,
proc::{self, NameKey, TypeResolution},
proc::{self, index, NameKey, TypeResolution},
valid, FastHashMap, FastHashSet,
};
#[cfg(test)]
Expand Down Expand Up @@ -2555,7 +2554,6 @@ impl<W: Write> Writer<W> {
self.out.write_str(") < ")?;
match length {
index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => {
let global =
context.function.originating_global(base).ok_or_else(|| {
Expand Down Expand Up @@ -2692,7 +2690,7 @@ impl<W: Write> Writer<W> {
) -> BackendResult {
let accessing_wrapped_array = match *base_ty {
crate::TypeInner::Array {
size: crate::ArraySize::Constant(_),
size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
..
} => true,
_ => false,
Expand Down Expand Up @@ -2720,7 +2718,6 @@ impl<W: Write> Writer<W> {
index::IndexableLength::Known(limit) => {
write!(self.out, "{}u", limit - 1)?;
}
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => {
let global = context.function.originating_global(base).ok_or_else(|| {
Error::GenericValidation("Could not find originating global".into())
Expand Down Expand Up @@ -3911,8 +3908,8 @@ impl<W: Write> Writer<W> {
first_time: false,
};

match size {
crate::ArraySize::Constant(size) => {
match size.resolve(module.to_ctx())? {
proc::ResolvedSize::Constant(size) => {
writeln!(self.out, "struct {name} {{")?;
writeln!(
self.out,
Expand All @@ -3924,10 +3921,7 @@ impl<W: Write> Writer<W> {
)?;
writeln!(self.out, "}};")?;
}
crate::ArraySize::Pending(_) => {
unreachable!()
}
crate::ArraySize::Dynamic => {
proc::ResolvedSize::Runtime => {
writeln!(self.out, "typedef {base_name} {name}[1];")?;
}
}
Expand Down Expand Up @@ -6321,11 +6315,9 @@ mod workgroup_mem_init {
writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
}
crate::TypeInner::Array { base, size, .. } => {
let count = match size.to_indexable_length(module).expect("Bad array size")
{
proc::IndexableLength::Known(count) => count,
proc::IndexableLength::Pending => unreachable!(),
proc::IndexableLength::Dynamic => unreachable!(),
let count = match size.resolve(module.to_ctx())? {
proc::ResolvedSize::Constant(size) => size,
proc::ResolvedSize::Runtime => unreachable!(),
};

access_stack.enter_array(|access_stack, array_depth| {
Expand Down
113 changes: 62 additions & 51 deletions naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
Span, Statement, TypeInner, WithSpan,
Span, Statement, Type, TypeInner, UniqueArena, WithSpan,
};
use std::{borrow::Cow, collections::HashSet, mem};
use thiserror::Error;
Expand All @@ -29,19 +29,29 @@ pub enum PipelineConstantError {
NegativeWorkgroupSize,
}

/// Replace all overrides in `module` with constants.
/// Replace all overrides in `module` with fully-evaluated constant expressions.
///
/// If no changes are needed, this just returns `Cow::Borrowed`
/// references to `module` and `module_info`. Otherwise, it clones
/// `module`, edits its [`global_expressions`] arena to contain only
/// fully-evaluated expressions, and returns `Cow::Owned` values
/// holding the simplified module and its validation results.
/// Given `pipeline_constants`, providing values for all overrides in
/// `module`:
///
/// In either case, the module returned has an empty `overrides`
/// arena, and the `global_expressions` arena contains only
/// fully-evaluated expressions.
/// - Replace all [`Override`] expressions with fully-evaluated
/// constant expressions.
///
/// [`global_expressions`]: Module::global_expressions
/// - Replace all [`Override`][paso] array sizes with [`Expression`]
/// array sizes, referring to fully-evaluated constant expressions.
///
/// - Empty out the `module.overrides` arena.
///
/// Although the above is described in terms of changes to `module`'s
/// contents, this function only actually has shared access to
/// `module`. When changes are needed, this function clones `module`
/// and returns a [`Cow::Owned`] value. If no changes are needed, this
/// function returns a [`Cow::Borrowed`] value that just passes along
/// the original reference.
///
/// [`Override`]: Expression::Override
/// [paso]: crate::PendingArraySize::Override
/// [`Expression`]: crate::PendingArraySize::Expression
pub fn process_overrides<'a>(
module: &'a Module,
module_info: &'a ModuleInfo,
Expand All @@ -51,6 +61,7 @@ pub fn process_overrides<'a>(
return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
}

let original_module_types = &module.types;
let mut module = module.clone();

// A map from override handles to the handles of the constants
Expand Down Expand Up @@ -196,7 +207,12 @@ pub fn process_overrides<'a>(
}
module.entry_points = entry_points;

process_pending(&mut module, &override_map, &adjusted_global_expressions)?;
process_pending(
&mut module,
original_module_types,
&override_map,
&adjusted_global_expressions,
);

// Now that we've rewritten all the expressions, we need to
// recompute their types and other metadata. For the time being,
Expand All @@ -209,60 +225,55 @@ pub fn process_overrides<'a>(

fn process_pending(
module: &mut Module,
original_module_types: &UniqueArena<Type>,
override_map: &HandleVec<Override, Handle<Constant>>,
adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
) -> Result<(), PipelineConstantError> {
for (handle, ty) in module.types.clone().iter() {
) {
for (handle, ty) in original_module_types.iter() {
if let TypeInner::Array {
base,
size: crate::ArraySize::Pending(size),
stride,
} = ty.inner
{
let expr = match size {
match size {
crate::PendingArraySize::Expression(size_expr) => {
adjusted_global_expressions[size_expr]
let expr = adjusted_global_expressions[size_expr];
if expr != size_expr {
module.types.replace(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UniqueArena::replace is specified to panic if the new value already exists in the arena. The WGSL front end also creates PendingArraySize::Expression array sizes. Why do we believe that these replace calls will never panic?

If this isn't a bug, it definitely deserves a comment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expression handle in PendingArraySize::Expression is the only handle to that expression.

I realize that this is not validated though, I will add some validation for this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expression handle in PendingArraySize::Expression is the only handle to that expression.

This is needed for type equivalency, arrays sized via an override expression should never be equal to other arrays.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The validation is not necessary, I missed the fact that we are working with a UniqueArena #6787 (comment).

handle,
Type {
name: ty.name.clone(),
inner: TypeInner::Array {
base,
size: crate::ArraySize::Pending(
crate::PendingArraySize::Expression(expr),
),
stride,
},
},
);
}
}
crate::PendingArraySize::Override(size_override) => {
module.constants[override_map[size_override]].init
let expr = module.constants[override_map[size_override]].init;
module.types.replace(
handle,
Type {
name: ty.name.clone(),
inner: TypeInner::Array {
base,
size: crate::ArraySize::Pending(
crate::PendingArraySize::Expression(expr),
),
stride,
},
},
);
}
};
let value = module
.to_ctx()
.eval_expr_to_u32(expr)
.map(|n| {
if n == 0 {
Err(PipelineConstantError::ValidationError(
WithSpan::new(ValidationError::ArraySizeError { handle: expr })
.with_span(
module.global_expressions.get_span(expr),
"evaluated to zero",
),
))
} else {
Ok(std::num::NonZeroU32::new(n).unwrap())
}
})
.map_err(|_| {
PipelineConstantError::ValidationError(
WithSpan::new(ValidationError::ArraySizeError { handle: expr })
.with_span(module.global_expressions.get_span(expr), "negative"),
)
})??;
module.types.replace(
handle,
crate::Type {
name: None,
inner: TypeInner::Array {
base,
size: crate::ArraySize::Constant(value),
stride,
},
},
);
}
}
Ok(())
}

fn process_workgroup_size_override(
Expand Down
7 changes: 3 additions & 4 deletions naga/src/back/spv/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,12 @@ impl BlockContext<'_> {
block: &mut Block,
) -> Result<MaybeKnown<u32>, Error> {
let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types);
match sequence_ty.indexable_length(self.ir_module) {
match sequence_ty
.indexable_length(self.ir_module, crate::ArraySize::indexable_length_resolved)
{
Ok(crate::proc::IndexableLength::Known(known_length)) => {
Ok(MaybeKnown::Known(known_length))
}
Ok(crate::proc::IndexableLength::Pending) => {
unreachable!()
}
Ok(crate::proc::IndexableLength::Dynamic) => {
let length_id = self.write_runtime_array_length(sequence, block)?;
Ok(MaybeKnown::Computed(length_id))
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ pub enum Error {
Validation(&'static str),
#[error("overrides should not be present at this stage")]
Override,
#[error(transparent)]
ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError),
}

#[derive(Default)]
Expand Down
Loading
Loading