diff --git a/commits.md b/commits.md new file mode 100644 index 0000000000..c661ead71f --- /dev/null +++ b/commits.md @@ -0,0 +1,7 @@ +- [ ] Split out `ExprFedDepError` +- [ ] Combine handle validation pass errors at current implementation level +- [ ] Split out the pass! + - [ ] squash! add `ExpressionTypeResolver::check` + - [ ] squash! `impl Index for ExpressionTypeResolver` + - [ ] note: trade-offs were made, we're now doing multiple passes. Boo for + perf, yay for abstraction. diff --git a/src/arena.rs b/src/arena.rs index d21a786cb5..fba18db100 100644 --- a/src/arena.rs +++ b/src/arena.rs @@ -15,6 +15,15 @@ pub struct BadHandle { pub index: usize, } +impl BadHandle { + fn new(handle: Handle) -> Self { + Self { + kind: std::any::type_name::(), + index: handle.index(), + } + } +} + /// A strongly typed reference to an arena item. /// /// A `Handle` value can be used as an index into an [`Arena`] or [`UniqueArena`]. @@ -275,10 +284,9 @@ impl Arena { } pub fn try_get(&self, handle: Handle) -> Result<&T, BadHandle> { - self.data.get(handle.index()).ok_or_else(|| BadHandle { - kind: std::any::type_name::(), - index: handle.index(), - }) + self.data + .get(handle.index()) + .ok_or_else(|| BadHandle::new(handle)) } /// Get a mutable reference to an element in the arena. @@ -313,6 +321,12 @@ impl Arena { Span::default() } } + + pub fn check_contains_handle(&self, handle: Handle) -> Result<(), BadHandle> { + (handle.index() < self.data.len()) + .then_some(()) + .ok_or_else(|| BadHandle::new(handle)) + } } #[cfg(feature = "deserialize")] @@ -533,10 +547,15 @@ impl UniqueArena { /// Return this arena's value at `handle`, if that is a valid handle. pub fn get_handle(&self, handle: Handle) -> Result<&T, BadHandle> { - self.set.get_index(handle.index()).ok_or_else(|| BadHandle { - kind: std::any::type_name::(), - index: handle.index(), - }) + self.set + .get_index(handle.index()) + .ok_or_else(|| BadHandle::new(handle)) + } + + pub fn check_contains_handle(&self, handle: Handle) -> Result<(), BadHandle> { + (handle.index() < self.set.len()) + .then_some(()) + .ok_or_else(|| BadHandle::new(handle)) } } diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 039bfcce30..154eb107de 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -40,12 +40,12 @@ impl crate::TypeInner { } } - pub(super) fn try_size_hlsl( + pub(super) fn size_hlsl( &self, types: &crate::UniqueArena, constants: &crate::Arena, - ) -> Result { - Ok(match *self { + ) -> u32 { + match *self { Self::Matrix { columns, rows, @@ -58,17 +58,16 @@ impl crate::TypeInner { Self::Array { base, size, stride } => { let count = match size { crate::ArraySize::Constant(handle) => { - let constant = constants.try_get(handle)?; - constant.to_array_length().unwrap_or(1) + constants[handle].to_array_length().unwrap_or(1) } // A dynamically-sized array has to have at least one element crate::ArraySize::Dynamic => 1, }; - let last_el_size = types[base].inner.try_size_hlsl(types, constants)?; + let last_el_size = types[base].inner.size_hlsl(types, constants); ((count - 1) * stride) + last_el_size } - _ => self.try_size(constants)?, - }) + _ => self.size(constants), + } } /// Used to generate the name of the wrapped type constructor diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 42e3060798..e283e774e5 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -805,6 +805,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { /// /// # Notes /// Ends in a newline + /// + /// # Panics + /// + /// This function may panic when given a malformed IR module. fn write_struct( &mut self, module: &Module, @@ -829,10 +833,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } let ty_inner = &module.types[member.ty].inner; - last_offset = member.offset - + ty_inner - .try_size_hlsl(&module.types, &module.constants) - .unwrap(); + last_offset = member.offset + ty_inner.size_hlsl(&module.types, &module.constants); // The indentation is only for readability write!(self.out, "{}", back::INDENT)?; diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 0be7ef072a..817fa78b0a 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -1609,15 +1609,7 @@ impl Writer { Function::Regular(fun_name) => { write!(self.out, "{}(", fun_name)?; self.write_expr(module, arg, func_ctx)?; - if let Some(arg) = arg1 { - write!(self.out, ", ")?; - self.write_expr(module, arg, func_ctx)?; - } - if let Some(arg) = arg2 { - write!(self.out, ", ")?; - self.write_expr(module, arg, func_ctx)?; - } - if let Some(arg) = arg3 { + for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() { write!(self.out, ", ")?; self.write_expr(module, arg, func_ctx)?; } diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index 0c3a00db15..68a2a65192 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -1,4 +1,7 @@ -use crate::arena::{Arena, BadHandle, Handle, UniqueArena}; +use crate::{ + arena::{Arena, BadHandle, Handle, UniqueArena}, + WithSpan, +}; use std::{fmt::Display, num::NonZeroU32, ops}; /// A newtype struct where its only valid values are powers of 2 @@ -130,8 +133,6 @@ pub enum LayoutErrorInner { InvalidStructMemberType(u32, Handle), #[error("Type width must be a power of two")] NonPowerOfTwoWidth, - #[error("Array size is a bad handle")] - BadHandle(#[from] BadHandle), } #[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] @@ -153,6 +154,20 @@ impl Layouter { self.layouts.clear(); } + pub fn validate_handles( + &mut self, + types: &UniqueArena, + constants: &Arena, + ) -> Result<(), WithSpan> { + types + .iter() + .skip(self.layouts.len()) + .try_for_each(|(_handle, ty)| { + // TODO: track the type we fail on? + ty.inner.validate_handles(constants) + }) + } + /// Extend this `Layouter` with layouts for any new entries in `types`. /// /// Ensure that every type in `types` has a corresponding [TypeLayout] in @@ -166,7 +181,13 @@ impl Layouter { /// end can call this function at any time, passing its current type and /// constant arenas, and then assume that layouts are available for all /// types. + /// + /// # Panics + /// + /// If `types` contains invalid [`Handle`]s to `constants`, then this function will panic. You + /// can check for this condition by calling [`Self::validate_handles`]. #[allow(clippy::or_fun_call)] + #[track_caller] pub fn update( &mut self, types: &UniqueArena, @@ -175,10 +196,8 @@ impl Layouter { use crate::TypeInner as Ti; for (ty_handle, ty) in types.iter().skip(self.layouts.len()) { - let size = ty - .inner - .try_size(constants) - .map_err(|error| LayoutErrorInner::BadHandle(error).with(ty_handle))?; + // phase-id: layouter + let size = ty.inner.size(constants); let layout = match ty.inner { Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => { let alignment = Alignment::new(width as u32) @@ -255,3 +274,52 @@ impl Layouter { Ok(()) } } + +// #[derive(Debug)] +// struct HandlesValidated(T); + +// impl HandlesValidated { +// pub const fn new(t: T) -> Self { +// Self(t) +// } +// } + +// impl<'a, T> Clone for HandlesValidated<&'a T> { +// fn clone(&self) -> Self { +// let Self(ref_) = self; +// Self(ref_) +// } +// } + +// impl<'a, T> Copy for HandlesValidated<&'a T> {} + +// impl<'a, T> AsRef for HandlesValidated<&'a T> { +// fn as_ref(&self) -> &T { +// todo!() +// } +// } + +// impl<'a, T> Deref for HandlesValidated<&'a T> { +// type Target = T; + +// fn deref(&self) -> &Self::Target { +// let Self(inner) = self; +// inner +// } +// } + +// impl<'a, T> Deref for HandlesValidated<&'a mut T> { +// type Target = T; + +// fn deref(&self) -> &Self::Target { +// let Self(inner) = self; +// inner +// } +// } + +// impl<'a, T> DerefMut for HandlesValidated<&'a mut T> { +// fn deref_mut(&mut self) -> &mut Self::Target { +// let Self(inner) = self; +// inner +// } +// } diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 0b9f406af4..112a48afbe 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -16,6 +16,8 @@ pub use namer::{EntryPointIndex, NameKey, Namer}; pub use terminator::ensure_block_returns; pub use typifier::{ResolveContext, ResolveError, TypeResolution}; +use crate::{arena::BadHandle, WithSpan}; + impl From for super::ScalarKind { fn from(format: super::StorageFormat) -> Self { use super::{ScalarKind as Sk, StorageFormat as Sf}; @@ -97,11 +99,32 @@ impl super::TypeInner { } } - pub fn try_size( + pub fn validate_handles( &self, constants: &super::Arena, - ) -> Result { - Ok(match *self { + ) -> Result<(), WithSpan> { + match self { + &Self::Array { + base: _, + size: super::ArraySize::Constant(handle), + stride: _, + } => constants + .try_get(handle) + .map(|_| ()) + .map_err(|e| WithSpan::new(e).with_handle(handle, constants)), + _ => Ok(()), + } + } + + /// Get the size of this type. + /// + /// # Panics + /// + /// Panics if the `constants` doesn't contain a referenced handle. This may not happen in + /// a properly validated IR module. You can check for this condition with + /// [`Self::validate_handles`]. + pub fn size(&self, constants: &super::Arena) -> u32 { + match *self { Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32, Self::Vector { size, @@ -122,8 +145,7 @@ impl super::TypeInner { } => { let count = match size { super::ArraySize::Constant(handle) => { - let constant = constants.try_get(handle)?; - constant.to_array_length().unwrap_or(1) + constants[handle].to_array_length().unwrap_or(1) } // A dynamically-sized array has to have at least one element super::ArraySize::Dynamic => 1, @@ -132,13 +154,7 @@ impl super::TypeInner { } Self::Struct { span, .. } => span, Self::Image { .. } | Self::Sampler { .. } | Self::BindingArray { .. } => 0, - }) - } - - /// Get the size of this type. Panics if the `constants` doesn't contain - /// a referenced handle. This may not happen in a properly validated IR module. - pub fn size(&self, constants: &super::Arena) -> u32 { - self.try_size(constants).unwrap() + } } /// Return the canonical form of `self`, or `None` if it's already in diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 9a5922ea76..263c4c8394 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -162,8 +162,6 @@ impl crate::ConstantInner { #[derive(Clone, Debug, Error, PartialEq)] pub enum ResolveError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Index {index} is out of bounds for expression {expr:?}")] OutOfBoundsIndex { expr: Handle, @@ -195,8 +193,10 @@ pub enum ResolveError { IncompatibleOperands(String), #[error("Function argument {0} doesn't exist")] FunctionArgumentNotFound(u32), - #[error("Expression {0:?} depends on expressions that follow")] - ExpressionForwardDependency(Handle), + #[error("Expression {dependent:?} depends on expressions that follow")] + ExpressionForwardDependency { + dependent: Handle, + }, } pub struct ResolveContext<'a> { @@ -209,6 +209,15 @@ pub struct ResolveContext<'a> { } impl<'a> ResolveContext<'a> { + pub fn validate_resolution_handles(&self, expr: &crate::Expression) -> Result<(), BadHandle> { + match *expr { + crate::Expression::Constant(h) => self.constants.try_get(h).map(|_| ()), + crate::Expression::GlobalVariable(h) => self.global_vars.try_get(h).map(|_| ()), + crate::Expression::LocalVariable(h) => self.local_vars.try_get(h).map(|_| ()), + _ => Ok(()), + } + } + /// Determine the type of `expr`. /// /// The `past` argument must be a closure that can resolve the types of any @@ -403,20 +412,15 @@ impl<'a> ResolveContext<'a> { } } } - crate::Expression::Constant(h) => { - let constant = self.constants.try_get(h)?; - match constant.inner { - crate::ConstantInner::Scalar { width, ref value } => { - TypeResolution::Value(Ti::Scalar { - kind: value.scalar_kind(), - width, - }) - } - crate::ConstantInner::Composite { ty, components: _ } => { - TypeResolution::Handle(ty) - } + crate::Expression::Constant(h) => match self.constants[h].inner { + crate::ConstantInner::Scalar { width, ref value } => { + TypeResolution::Value(Ti::Scalar { + kind: value.scalar_kind(), + width, + }) } - } + crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty), + }, crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) { Ti::Scalar { kind, width } => { TypeResolution::Value(Ti::Vector { size, kind, width }) @@ -450,7 +454,7 @@ impl<'a> ResolveContext<'a> { TypeResolution::Handle(arg.ty) } crate::Expression::GlobalVariable(h) => { - let var = self.global_vars.try_get(h)?; + let var = &self.global_vars[h]; if var.space == crate::AddressSpace::Handle { TypeResolution::Handle(var.ty) } else { @@ -461,7 +465,7 @@ impl<'a> ResolveContext<'a> { } } crate::Expression::LocalVariable(h) => { - let var = self.local_vars.try_get(h)?; + let var = &self.local_vars[h]; TypeResolution::Value(Ti::Pointer { base: var.ty, space: crate::AddressSpace::Function, diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 3f13aba323..2686ef211f 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -706,11 +706,14 @@ impl FunctionInfo { }, }; + // FIXME: Break out this layer below into its own API. We can rebuild this layer on top of + // it once that's done. + let ty = resolve_context.resolve(expression, |h| { self.expressions .get(h.index()) .map(|ei| &ei.ty) - .ok_or(ResolveError::ExpressionForwardDependency(h)) + .ok_or(ResolveError::ExpressionForwardDependency { dependent: h }) })?; self.expressions[handle.index()] = ExpressionInfo { uniformity, diff --git a/src/valid/compose.rs b/src/valid/compose.rs index 6e5c499223..304de0fc47 100644 --- a/src/valid/compose.rs +++ b/src/valid/compose.rs @@ -9,8 +9,6 @@ use crate::arena::{BadHandle, Handle}; #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum ComposeError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Composing of type {0:?} can't be done")] Type(Handle), #[error("Composing expects {expected} components but {given} were given")] @@ -19,6 +17,14 @@ pub enum ComposeError { ComponentType { index: u32 }, } +#[cfg(feature = "validate")] +pub fn validate_compose_handles( + self_ty_handle: Handle, + type_arena: &UniqueArena, +) -> Result<(), BadHandle> { + type_arena.get_handle(self_ty_handle).map(|_| ()) +} + #[cfg(feature = "validate")] pub fn validate_compose( self_ty_handle: Handle, @@ -28,8 +34,7 @@ pub fn validate_compose( ) -> Result<(), ComposeError> { use crate::TypeInner as Ti; - let self_ty = type_arena.get_handle(self_ty_handle)?; - match self_ty.inner { + match type_arena[self_ty_handle].inner { // vectors are composed from scalars or other vectors Ti::Vector { size, kind, width } => { let mut total = 0; diff --git a/src/valid/expression.rs b/src/valid/expression.rs index bf639065b4..c2c9e2242c 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1,10 +1,12 @@ +use std::ops::Index; + #[cfg(feature = "validate")] use super::{compose::validate_compose, FunctionInfo, ShaderStages, TypeFlags}; #[cfg(feature = "validate")] use crate::arena::UniqueArena; use crate::{ - arena::{BadHandle, Handle}, + arena::Handle, proc::{IndexableLengthError, ResolveError}, }; @@ -15,10 +17,6 @@ pub enum ExpressionError { DoesntExist, #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")] NotInScope, - #[error("Depends on {0:?}, which has not been processed yet")] - ForwardDependency(Handle), - #[error(transparent)] - BadDependency(#[from] BadHandle), #[error("Base type {0:?} is not compatible with this expression")] InvalidBaseType(Handle), #[error("Accessing with index {0:?} can't be done")] @@ -121,6 +119,7 @@ pub enum ExpressionError { MissingCapabilities(super::Capabilities), } +/// TODO: document intended usage #[cfg(feature = "validate")] struct ExpressionTypeResolver<'a> { root: Handle, @@ -129,15 +128,17 @@ struct ExpressionTypeResolver<'a> { } #[cfg(feature = "validate")] -impl<'a> ExpressionTypeResolver<'a> { - fn resolve( - &self, - handle: Handle, - ) -> Result<&'a crate::TypeInner, ExpressionError> { +impl<'a> Index> for ExpressionTypeResolver<'a> { + type Output = crate::TypeInner; + + fn index(&self, handle: Handle) -> &Self::Output { if handle < self.root { - Ok(self.info[handle].ty.inner_with(self.types)) + self.info[handle].ty.inner_with(self.types) } else { - Err(ExpressionError::ForwardDependency(handle)) + panic!( + "Depends on {:?}, which has not been processed yet", + self.root + ) } } } @@ -163,7 +164,7 @@ impl super::Validator { let stages = match *expression { E::Access { base, index } => { - let base_type = resolver.resolve(base)?; + let base_type = &resolver[base]; // See the documentation for `Expression::Access`. let dynamic_indexing_restricted = match *base_type { Ti::Vector { .. } => false, @@ -176,7 +177,7 @@ impl super::Validator { return Err(ExpressionError::InvalidBaseType(base)); } }; - match *resolver.resolve(index)? { + match resolver[index] { //TODO: only allow one of these Ti::Scalar { kind: Sk::Sint | Sk::Uint, @@ -254,7 +255,7 @@ impl super::Validator { Ok(limit) } - let limit = resolve_index_limit(module, base, resolver.resolve(base)?, true)?; + let limit = resolve_index_limit(module, base, &resolver[base], true)?; if index >= limit { return Err(ExpressionError::IndexOutOfBounds( base, @@ -263,11 +264,8 @@ impl super::Validator { } ShaderStages::all() } - E::Constant(handle) => { - let _ = module.constants.try_get(handle)?; - ShaderStages::all() - } - E::Splat { size: _, value } => match *resolver.resolve(value)? { + E::Constant(_handle) => ShaderStages::all(), + E::Splat { size: _, value } => match resolver[value] { Ti::Scalar { .. } => ShaderStages::all(), ref other => { log::error!("Splat scalar type {:?}", other); @@ -279,7 +277,7 @@ impl super::Validator { vector, pattern, } => { - let vec_size = match *resolver.resolve(vector)? { + let vec_size = match resolver[vector] { Ti::Vector { size: vec_size, .. } => vec_size, ref other => { log::error!("Swizzle vector type {:?}", other); @@ -294,11 +292,6 @@ impl super::Validator { ShaderStages::all() } E::Compose { ref components, ty } => { - for &handle in components { - if handle >= root { - return Err(ExpressionError::ForwardDependency(handle)); - } - } validate_compose( ty, &module.constants, @@ -313,16 +306,10 @@ impl super::Validator { } ShaderStages::all() } - E::GlobalVariable(handle) => { - let _ = module.global_variables.try_get(handle)?; - ShaderStages::all() - } - E::LocalVariable(handle) => { - let _ = function.local_variables.try_get(handle)?; - ShaderStages::all() - } + E::GlobalVariable(_handle) => ShaderStages::all(), + E::LocalVariable(_handle) => ShaderStages::all(), E::Load { pointer } => { - match *resolver.resolve(pointer)? { + match resolver[pointer] { Ti::Pointer { base, .. } if self.types[base.index()] .flags @@ -365,7 +352,7 @@ impl super::Validator { return Err(ExpressionError::InvalidImageArrayIndex); } if let Some(expr) = array_index { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Sint, width: _, @@ -401,7 +388,7 @@ impl super::Validator { crate::ImageDimension::D2 => 2, crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3, }; - match *resolver.resolve(coordinate)? { + match resolver[coordinate] { Ti::Scalar { kind: Sk::Float, .. } if num_components == 1 => {} @@ -439,7 +426,7 @@ impl super::Validator { // check depth reference type if let Some(expr) = depth_ref { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } => {} @@ -476,7 +463,7 @@ impl super::Validator { crate::SampleLevel::Auto => ShaderStages::FRAGMENT, crate::SampleLevel::Zero => ShaderStages::all(), crate::SampleLevel::Exact(expr) => { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } => {} @@ -485,7 +472,7 @@ impl super::Validator { ShaderStages::all() } crate::SampleLevel::Bias(expr) => { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } => {} @@ -494,7 +481,7 @@ impl super::Validator { ShaderStages::all() } crate::SampleLevel::Gradient { x, y } => { - match *resolver.resolve(x)? { + match resolver[x] { Ti::Scalar { kind: Sk::Float, .. } if num_components == 1 => {} @@ -507,7 +494,7 @@ impl super::Validator { return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x)) } } - match *resolver.resolve(y)? { + match resolver[y] { Ti::Scalar { kind: Sk::Float, .. } if num_components == 1 => {} @@ -538,7 +525,7 @@ impl super::Validator { arrayed, dim, } => { - match resolver.resolve(coordinate)?.image_storage_coordinates() { + match resolver[coordinate].image_storage_coordinates() { Some(coord_dim) if coord_dim == dim => {} _ => { return Err(ExpressionError::InvalidImageCoordinateType( @@ -550,7 +537,7 @@ impl super::Validator { return Err(ExpressionError::InvalidImageArrayIndex); } if let Some(expr) = array_index { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Sint, width: _, @@ -562,7 +549,7 @@ impl super::Validator { match (sample, class.is_multisampled()) { (None, false) => {} (Some(sample), true) => { - if resolver.resolve(sample)?.scalar_kind() != Some(Sk::Sint) { + if resolver[sample].scalar_kind() != Some(Sk::Sint) { return Err(ExpressionError::InvalidImageOtherIndexType( sample, )); @@ -576,7 +563,7 @@ impl super::Validator { match (level, class.is_mipmapped()) { (None, false) => {} (Some(level), true) => { - if resolver.resolve(level)?.scalar_kind() != Some(Sk::Sint) { + if resolver[level].scalar_kind() != Some(Sk::Sint) { return Err(ExpressionError::InvalidImageOtherIndexType(level)); } } @@ -610,7 +597,7 @@ impl super::Validator { } E::Unary { op, expr } => { use crate::UnaryOperator as Uo; - let inner = resolver.resolve(expr)?; + let inner = &resolver[expr]; match (op, inner.scalar_kind()) { (_, Some(Sk::Sint | Sk::Bool)) //TODO: restrict Negate for bools? @@ -625,8 +612,8 @@ impl super::Validator { } E::Binary { op, left, right } => { use crate::BinaryOperator as Bo; - let left_inner = resolver.resolve(left)?; - let right_inner = resolver.resolve(right)?; + let left_inner = &resolver[left]; + let right_inner = &resolver[right]; let good = match op { Bo::Add | Bo::Subtract => match *left_inner { Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { @@ -807,9 +794,9 @@ impl super::Validator { accept, reject, } => { - let accept_inner = resolver.resolve(accept)?; - let reject_inner = resolver.resolve(reject)?; - let condition_good = match *resolver.resolve(condition)? { + let accept_inner = &resolver[accept]; + let reject_inner = &resolver[reject]; + let condition_good = match resolver[condition] { Ti::Scalar { kind: Sk::Bool, width: _, @@ -839,7 +826,7 @@ impl super::Validator { ShaderStages::all() } E::Derivative { axis: _, expr } => { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } @@ -852,7 +839,7 @@ impl super::Validator { } E::Relational { fun, argument } => { use crate::RelationalFunction as Rf; - let argument_inner = resolver.resolve(argument)?; + let argument_inner = &resolver[argument]; match fun { Rf::All | Rf::Any => match *argument_inner { Ti::Vector { kind: Sk::Bool, .. } => {} @@ -885,10 +872,11 @@ impl super::Validator { } => { use crate::MathFunction as Mf; - let arg_ty = resolver.resolve(arg)?; - let arg1_ty = arg1.map(|expr| resolver.resolve(expr)).transpose()?; - let arg2_ty = arg2.map(|expr| resolver.resolve(expr)).transpose()?; - let arg3_ty = arg3.map(|expr| resolver.resolve(expr)).transpose()?; + let resolve = |arg| &resolver[arg]; + let arg_ty = resolve(arg); + let arg1_ty = arg1.map(resolve); + let arg2_ty = arg2.map(resolve); + let arg3_ty = arg3.map(resolve); match fun { Mf::Abs => { if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { @@ -1372,7 +1360,7 @@ impl super::Validator { kind, convert, } => { - let base_width = match *resolver.resolve(expr)? { + let base_width = match resolver[expr] { crate::TypeInner::Scalar { width, .. } | crate::TypeInner::Vector { width, .. } | crate::TypeInner::Matrix { width, .. } => width, @@ -1401,9 +1389,9 @@ impl super::Validator { } ShaderStages::all() } - E::ArrayLength(expr) => match *resolver.resolve(expr)? { + E::ArrayLength(expr) => match resolver[expr] { Ti::Pointer { base, .. } => { - let base_ty = resolver.types.get_handle(base)?; + let base_ty = &resolver.types[base]; if let Ti::Array { size: crate::ArraySize::Dynamic, .. diff --git a/src/valid/function.rs b/src/valid/function.rs index 2107e71e09..20d617801d 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -16,8 +16,6 @@ use bit_set::BitSet; #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum CallError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("The callee is declared after the caller")] ForwardDeclaredFunction, #[error("Argument {index} expression is invalid")] @@ -67,8 +65,6 @@ pub enum LocalVariableError { #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum FunctionError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Expression {handle:?} is invalid")] Expression { handle: Handle, @@ -203,11 +199,8 @@ impl<'a> BlockContext<'a> { BlockContext { abilities, ..*self } } - fn get_expression( - &self, - handle: Handle, - ) -> Result<&'a crate::Expression, FunctionError> { - Ok(self.expressions.try_get(handle)?) + fn get_expression(&self, handle: Handle) -> &'a crate::Expression { + &self.expressions[handle] } fn resolve_type_impl( @@ -249,6 +242,19 @@ impl<'a> BlockContext<'a> { } impl super::Validator { + #[cfg(feature = "validate")] + fn validate_call_handles( + function: Handle, + module_functions: &Arena, + ) -> Result<(), WithSpan> { + module_functions + .try_get(function) + .map(|_| ()) + // TODO: use `e.with_span()` instead + // TODO: add better spans yo + .map_err(WithSpan::new) + } + #[cfg(feature = "validate")] fn validate_call( &mut self, @@ -257,11 +263,7 @@ impl super::Validator { result: Option>, context: &BlockContext, ) -> Result> { - let fun = context - .functions - .try_get(function) - .map_err(CallError::BadHandle) - .map_err(WithSpan::new)?; + let fun = &context.functions[function]; if fun.arguments.len() != arguments.len() { return Err(CallError::ArgumentCount { required: fun.arguments.len(), @@ -380,6 +382,84 @@ impl super::Validator { Ok(()) } + #[cfg(feature = "validate")] + fn validate_block_impl_handles( + &mut self, + statements: &crate::Block, + block_expressions: &Arena, + module_functions: &Arena, + ) -> Result<(), WithSpan> { + statements + // TODO: improve spans with `span_iter`, yo + .iter() + // NOTE: Keep this in roughly the same order as corresponding `match` in + // `validate_block_impl_handles`! + .try_for_each(|statement| match statement { + &crate::Statement::Emit(_) => Ok(()), + &crate::Statement::Block(ref block) => { + self.validate_block_handles(block, block_expressions, module_functions) + } + &crate::Statement::If { + condition: _, + accept: _, + reject: _, + } + | &crate::Statement::Switch { + selector: _, + cases: _, + } => Ok(()), + &crate::Statement::Loop { + ref body, + ref continuing, + break_if: _, + } => { + self.validate_block_impl_handles(body, block_expressions, module_functions)?; + self.validate_block_impl_handles( + continuing, + block_expressions, + module_functions, + )?; + Ok(()) + } + &crate::Statement::Break + | &crate::Statement::Continue + | &crate::Statement::Return { value: _ } + | &crate::Statement::Kill + | &crate::Statement::Barrier(_) + | &crate::Statement::Store { + pointer: _, + value: _, + } => Ok(()), + &crate::Statement::ImageStore { + image, + coordinate: _, + array_index: _, + value: _, + } => match block_expressions + .try_get(image) + .map_err(|e| e.with_span())? + { + &crate::Expression::Access { base, .. } + | &crate::Expression::AccessIndex { base, .. } => block_expressions + .try_get(base) + .map(|_| ()) + .map_err(|e| e.with_span()), + _ => Ok(()), + }, + &crate::Statement::Call { + function, + arguments: _, + result: _, + } => Self::validate_call_handles(function, module_functions), + &crate::Statement::Atomic { + pointer: _, + fun: _, + value: _, + result: _, + } => Ok(()), + }) + } + #[cfg(feature = "validate")] fn validate_block_impl( &mut self, @@ -674,14 +754,14 @@ impl super::Validator { } => { //Note: this code uses a lot of `FunctionError::InvalidImageStore`, // and could probably be refactored. - let var = match *context.get_expression(image).map_err(|e| e.with_span())? { + let var = match *context.get_expression(image) { crate::Expression::GlobalVariable(var_handle) => { &context.global_vars[var_handle] } // We're looking at a binding index situation, so punch through the index and look at the global behind it. crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => { - match *context.get_expression(base).map_err(|e| e.with_span())? { + match *context.get_expression(base) { crate::Expression::GlobalVariable(var_handle) => { &context.global_vars[var_handle] } @@ -804,6 +884,16 @@ impl super::Validator { Ok(BlockInfo { stages, finished }) } + #[cfg(feature = "validate")] + fn validate_block_handles( + &mut self, + statements: &crate::Block, + block_expressions: &Arena, + module_functions: &Arena, + ) -> Result<(), WithSpan> { + self.validate_block_impl_handles(statements, block_expressions, module_functions) + } + #[cfg(feature = "validate")] fn validate_block( &mut self, @@ -884,10 +974,7 @@ impl super::Validator { #[cfg(feature = "validate")] for (index, argument) in fun.arguments.iter().enumerate() { - let ty = module.types.get_handle(argument.ty).map_err(|err| { - FunctionError::from(err).with_span_handle(argument.ty, &module.types) - })?; - match ty.inner.pointer_space() { + match module.types[argument.ty].inner.pointer_space() { Some( crate::AddressSpace::Private | crate::AddressSpace::Function diff --git a/src/valid/handles.rs b/src/valid/handles.rs new file mode 100644 index 0000000000..c392086087 --- /dev/null +++ b/src/valid/handles.rs @@ -0,0 +1,610 @@ +use std::{borrow::Cow, convert::TryInto, fmt, num::NonZeroU32}; + +use crate::{arena::BadHandle, Arena, Handle}; + +impl super::Validator { + #[warn(clippy::todo)] + pub(super) fn validate_module_handles( + module: &crate::Module, + ) -> Result<(), InvalidHandleError> { + let &crate::Module { + ref constants, + ref entry_points, + ref functions, + ref global_variables, + ref types, + } = module; + + // TODO: validate error quality + fn desc_name_defer_kind<'a, T>( + name: Option<&'a str>, + handle: Handle, + ) -> impl FnOnce(&'static str) -> HandleDescriptor> { + move |type_| { + HandleDescriptor::new(handle, KindAndMaybeName::from_type(type_).with_name(name)) + } + } + + const fn desc( + handle: Handle, + kind: &'static str, + ) -> HandleDescriptor { + HandleDescriptor::new(handle, kind) + } + + // NOTE: Types being first is important. All other forms of validation depend on this. + types + .iter() + .try_for_each(|(handle, type_)| -> Result<_, InvalidHandleError> { + let span = types.get_span(handle); + + let &crate::Type { + ref name, + ref inner, + } = type_; + let this_handle = desc_name_defer_kind(name.as_deref(), handle); + + match inner { + &crate::TypeInner::Scalar { .. } + | &crate::TypeInner::Vector { .. } + | &crate::TypeInner::Matrix { .. } + | &crate::TypeInner::ValuePointer { .. } + | &crate::TypeInner::Atomic { .. } + | &crate::TypeInner::Image { .. } + | &crate::TypeInner::Sampler { .. } => Ok(()), + &crate::TypeInner::Pointer { base, .. } => this_handle("pointer type") + .check_dep(HandleDescriptor::new(base, "base type"))? + .ok(), + &crate::TypeInner::Array { base, .. } => this_handle("array type") + .check_dep(HandleDescriptor::new(base, "base type"))? + .ok(), + &crate::TypeInner::Struct { ref members, .. } => { + let this_handle = this_handle("structure"); + + members + .iter() + .map(|&crate::StructMember { ref name, ty, .. }| { + desc_name_defer_kind(name.as_deref(), ty)("member type") + }) + .try_fold(this_handle, HandleDescriptor::check_dep)? + .ok() + } + &crate::TypeInner::BindingArray { base, .. } => { + this_handle("binding array type") + .check_dep(HandleDescriptor::new(base, "base type"))? + .ok() + } + } + })?; + + let validate_type = |type_handle| -> Result<(), InvalidHandleError> { + types.check_contains_handle(type_handle)?; + Ok(()) + }; + + constants + .iter() + .try_for_each(|(handle, constant)| -> Result<_, InvalidHandleError> { + let &crate::Constant { + ref name, + specialization: _, + ref inner, + } = constant; + match *inner { + crate::ConstantInner::Scalar { .. } => Ok(()), + crate::ConstantInner::Composite { ty, ref components } => { + validate_type(ty)?; + + let this_handle = desc_name_defer_kind(name.as_deref(), handle)("constant"); + components + .iter() + .copied() + .map(|component| desc_name_defer_kind(None, component)("component")) + .try_fold(this_handle, HandleDescriptor::check_dep)? + .ok() + } + } + })?; + + let validate_constant = |constant_handle| -> Result<(), InvalidHandleError> { + constants.check_contains_handle(constant_handle)?; + Ok(()) + }; + + global_variables.iter().try_for_each( + |(global_variable_handle, global_variable)| -> Result<_, InvalidHandleError> { + let &crate::GlobalVariable { + ref name, + space: _, + binding: _, + ty, + init, + } = global_variable; + let span = global_variables.get_span(global_variable_handle); + validate_type(ty)?; + if let Some(init_expr) = init { + validate_constant(init_expr)?; + } + Ok(()) + }, + )?; + + let validate_expressions = |expressions: &Arena, + local_variables: &Arena| + -> Result<(), InvalidHandleError> { + expressions + .iter() + .try_for_each(|(this_handle, expression)| { + let expr = |handle, kind| { + HandleDescriptor::new(handle, ExpressionHandleDescription { kind }) + }; + let this_expr = |kind| expr(this_handle, kind); + let expr_opt = |opt: Option<_>, desc| opt.map(|handle| expr(handle, desc)); + + match expression { + &crate::Expression::Access { base, .. } + | &crate::Expression::AccessIndex { base, .. } => this_expr("access") + .check_dep(expr(base, "access base"))? + .ok(), + &crate::Expression::Constant(constant) => { + validate_constant(constant)?; + Ok(()) + } + &crate::Expression::Splat { value, .. } => this_expr("splat") + .check_dep(expr(value, "splat value"))? + .ok(), + &crate::Expression::Swizzle { vector, .. } => { + this_expr("swizzle").check_dep(expr(vector, "vector"))?.ok() + } + &crate::Expression::Compose { ty, ref components } => { + validate_type(ty)?; + let this_handle = this_expr("composite"); + components + .iter() + .copied() + .map(|component| expr(component, "component")) + .try_fold(this_handle, HandleDescriptor::check_dep)? + .ok() + } + // TODO: Should we validate the length of function args? + &crate::Expression::FunctionArgument(_arg_idx) => Ok(()), + &crate::Expression::GlobalVariable(global_variable) => { + global_variables.check_contains_handle(global_variable)?; + Ok(()) + } + &crate::Expression::LocalVariable(local_variable) => { + // TODO: Shouldn't we be checking for forward deps here, too? + local_variables.check_contains_handle(local_variable)?; + Ok(()) + } + &crate::Expression::Load { pointer } => { + // TODO: right naming? + this_expr("load").check_dep(expr(pointer, "pointee"))?.ok() + } + &crate::Expression::ImageSample { + image, + sampler, + gather: _, + coordinate, + array_index, + offset, + level: _, + depth_ref, + } => { + // TODO: is there a better order for validation? + + if let Some(offset) = offset { + validate_constant(offset)?; + } + + this_expr("image sample") + .check_dep(expr(image, "image"))? + .check_dep(expr(sampler, "sampler"))? // TODO: Is this name correct? :think: + .check_dep(expr(coordinate, "coordinate"))? + .check_dep_opt(expr_opt(array_index, "array index"))? + .check_dep_opt(expr_opt(depth_ref, "depth reference"))? + .ok() + } + &crate::Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => this_expr("image load") + .check_dep(expr(image, "image"))? + .check_dep(expr(coordinate, "coordinate"))? + .check_dep_opt(expr_opt(array_index, "array index"))? + .check_dep_opt(expr_opt(sample, "sample index"))? + .check_dep_opt(expr_opt(level, "level of detail"))? + .ok(), + &crate::Expression::ImageQuery { image, query } => this_expr("image query") + .check_dep(expr(image, "image"))? + .check_dep_opt(match query { + crate::ImageQuery::Size { level } => { + expr_opt(level, "level of detail") + } + crate::ImageQuery::NumLevels + | crate::ImageQuery::NumLayers + | crate::ImageQuery::NumSamples => None, + })? + .ok(), + &crate::Expression::Unary { + op: _, + expr: operand, + } => this_expr("unary") + // TODO: maybe use operator names? + .check_dep(expr(operand, "unary operand"))? + .ok(), + &crate::Expression::Binary { op: _, left, right } => this_expr("binary") + // TODO: maybe use operator names? + .check_dep(expr(left, "left operand"))? + .check_dep(expr(right, "right operand"))? + .ok(), + &crate::Expression::Select { + condition, + accept, + reject, + } => desc(this_handle, "`select` function call") // TODO: use function name/more platform-generic name? + .check_dep(expr(condition, "condition"))? + .check_dep(expr(accept, "accept"))? + .check_dep(expr(reject, "reject"))? + .ok(), + &crate::Expression::Derivative { + axis: _, + expr: argument, + } => { + // TODO: use function name/more platform-generic name? + this_expr("derivative") + .check_dep(expr(argument, "argument"))? + .ok() + } + &crate::Expression::Relational { fun: _, argument } => { + // TODO: use function name/more platform-generic name? + desc(this_handle, "relational function call") + .check_dep(expr(argument, "argument"))? + .ok() + } + &crate::Expression::Math { + fun: _, + arg, + arg1, + arg2, + arg3, + } => { + // TODO: use function name/more platform-generic name? + desc(this_handle, "math function call") + .check_dep(expr(arg, "first argument"))? + .check_dep_opt(expr_opt(arg1, "second argument"))? + .check_dep_opt(expr_opt(arg2, "third argument"))? + .check_dep_opt(expr_opt(arg3, "fourth argument"))? + .ok() + } + &crate::Expression::As { + expr: input, + kind: _, + convert: _, + } => { + // TODO: use `kind` (ex., "cast to ...")? + this_expr("cast").check_dep(expr(input, "input"))?.ok() + } + &crate::Expression::CallResult(function) => { + functions.check_contains_handle(function)?; + Ok(()) + } + &crate::Expression::AtomicResult { .. } => Ok(()), + &crate::Expression::ArrayLength(array) => this_expr("array length") + .check_dep(expr(array, "array"))? + .ok(), + } + }) + }; + + let validate_function = |span, function: &_| -> Result<(), InvalidHandleError> { + let &crate::Function { + name: _, + ref arguments, + ref result, + ref local_variables, + ref expressions, + ref named_expressions, + ref body, + } = function; + + local_variables.iter().try_for_each( + |(handle, local_variable)| -> Result<_, InvalidHandleError> { + let &crate::LocalVariable { ref name, ty, init } = local_variable; + validate_type(ty)?; + if let Some(init_constant) = init { + // TODO: wait, where's the context? :( + validate_constant(init_constant)?; + } + Ok(()) + }, + )?; + + validate_expressions(expressions, local_variables)?; + Ok(()) + }; + + entry_points + .iter() + .try_for_each(|entry_point| -> Result<_, InvalidHandleError> { + // TODO: Why don't we have a `handle`/`Span` here? + validate_function(crate::Span::default(), &entry_point.function) + })?; + + functions.iter().try_for_each( + |(function_handle, function)| -> Result<_, InvalidHandleError> { + let span = functions.get_span(function_handle); + validate_function(span, function) + }, + )?; + + Ok(()) + } +} + +#[derive(Clone, Debug)] +struct KindAndMaybeName<'a> { + kind: &'static str, + name: Option>, +} + +impl<'a> KindAndMaybeName<'a> { + pub const fn from_type(type_: &'static str) -> Self { + Self { + kind: type_, + name: None, + } + } + + pub fn with_name<'b>(self, name: Option>>) -> KindAndMaybeName<'b> { + let Self { + kind: type_, + name: _, + } = self; + + KindAndMaybeName { + kind: type_, + name: name.map(Into::into), + } + } + + pub fn into_static(self) -> KindAndMaybeName<'static> { + let Self { kind: type_, name } = self; + + KindAndMaybeName { + kind: type_, + name: name.map(|n| n.into_owned().into()), + } + } +} + +impl fmt::Display for KindAndMaybeName<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let &Self { ref kind, ref name } = self; + write!(f, "{kind}")?; + if let Some(name) = name.as_ref() { + write!(f, " {name:?}")?; + } + Ok(()) + } +} + +impl HandleDescription for KindAndMaybeName<'_> { + fn into_erased(self) -> Box { + Box::new(self.into_static()) + } +} + +#[derive(Clone, Debug)] +struct ExpressionHandleDescription { + kind: &'static str, +} + +impl fmt::Display for ExpressionHandleDescription { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let &Self { kind } = self; + write!(f, "{kind} expression") + } +} + +impl HandleDescription for ExpressionHandleDescription { + fn into_erased(self) -> Box { + Box::new(self) + } +} + +// TODO: use a more concrete model for better diagnostics? +#[derive(Debug, thiserror::Error)] +pub enum InvalidHandleError { + #[error(transparent)] + Bad(#[from] BadHandle), + #[error(transparent)] + ForwardDependency(#[from] FwdDepError), +} + +// TODO: use a more concrete model for better diagnostics? +#[derive(Debug, thiserror::Error)] +#[error("{subject} depends on {depends_on}, which has not been processed yet")] +pub struct FwdDepError { + // TODO: context of what's being validated? + subject: HandleDescriptor<(), Box>, + depends_on: HandleDescriptor<(), Box>, +} + +#[derive(Clone, Copy, Debug)] +pub struct HandleDescriptor { + pub(crate) handle: Handle, + pub(crate) description: D, + // TODO: track type name? +} + +impl HandleDescriptor { + pub const fn new(handle: Handle, description: D) -> Self { + Self { + handle, + description, + } + } + + pub fn description_mut(&mut self) -> &mut D { + &mut self.description + } +} + +impl HandleDescriptor +where + D: HandleDescription, +{ + /// Check that `self`'s handle is valid for `arena`. + /// + /// As with all [`Arena`] handles, it is the responsibility of the caller to ensure that + /// `self`'s handle is valid for the provided `arena`. Otherwise, the result + pub(self) fn check_valid_for(self, arena: &Arena) -> Result { + arena.check_contains_handle(self.handle)?; + Ok(self) + } + + /// Check that `depends_on`'s handle is "ready" to be consumed by `self`'s handle by comparing + /// handle indices. If `self` describes a valid value (i.e., it has been validated using + /// [`Self::is_good_in`] and this function returns [`Ok`], then it may be assumed that + /// `depends_on` also passes that validation. + /// + /// In [`naga`](crate)'s current arena-based implementation, this is useful for validating + /// recursive definitions of arena-based values in linear time. + /// + /// As with all [`Arena`] handles, it is the responsibility of the caller to ensure that `self` + /// and `depends_on` contain handles from the same arena. Otherwise, calling this likely isn't + /// correct! + /// + /// # Errors + /// + /// If `depends_on`'s handle is from the same [`Arena`] as `self'`s, but not constructed earlier + /// than `self`'s, this function returns an error. + pub(self) fn check_dep( + self, + depends_on: HandleDescriptor, + ) -> Result + where + D2: HandleDescription, + { + if depends_on.handle < self.handle { + Ok(self) + } else { + Err(FwdDepError { + subject: self.into_erased(), + depends_on: depends_on.into_erased(), + }) + } + } + + /// Like [`Self::check_dep`], except for [`Optional`] handle values. + pub(self) fn check_dep_opt( + self, + depends_on: Option>, + ) -> Result + where + D2: HandleDescription, + { + if let Some(depends_on) = depends_on { + self.check_dep(depends_on) + } else { + Ok(self) + } + } + + fn into_erased(self) -> HandleDescriptor<(), Box> { + let Self { + handle, + description, + } = self; + + HandleDescriptor { + handle: Handle::new(NonZeroU32::new(handle.index().try_into().unwrap()).unwrap()), + description: description.into_erased(), + } + } + + /// Finishes a chain of checks done with this handle descriptor with [`Ok`]. + /// + /// This API exists because the current method API design favors chained `?` calls. It's often + /// more convenient to write: + /// + /// ``` + /// # fn main() -> Result<(), InvalidHandleError> { + /// # let first_handle = HandleDescriptor::new(Handle::new(0), "asdf"); + /// # let second_handle = HandleDescriptor::new(Handle::new(0), "asdf"); + /// # let third_handle = HandleDescriptor::new(Handle::new(0), "asdf"); + /// # let fourth_handle = HandleDescriptor::new(Handle::new(0), "asdf"); + /// # let fifth_handle = HandleDescriptor::new(Handle::new(0), "asdf"); + /// first_handle + /// .check_dep(second_handle)? + /// .check_dep(third_handle)? + /// .check_dep(fourth_handle)? + /// .check_dep(fifth_handle)? + /// .ok() // requires no type inference, single expression + /// # } + /// ``` + /// + /// ...than this: + /// + /// ``` + /// # fn main() -> Result<(), InvalidHandleError> { + /// # let first_handle = HandleDescriptor::new(Handle::new(0), "asdf"); + /// # let second_handle = HandleDescriptor::new(Handle::new(0), "asdf"); + /// # let third_handle = HandleDescriptor::new(Handle::new(0), "asdf"); + /// # let fourth_handle = HandleDescriptor::new(Handle::new(0), "asdf"); + /// # let fifth_handle = HandleDescriptor::new(Handle::new(0), "asdf"); + /// first_handle + /// .check_dep(second_handle)? + /// .check_dep(third_handle)? + /// .check_dep(fourth_handle)? + /// .check_dep(fifth_handle)?; + /// Ok(()) // may require explicit type specification to use `?`, requires a block + /// # } + /// ``` + #[allow(clippy::missing_const_for_fn)] // NOTE: This fires incorrectly without this. :< + pub(self) fn ok(self) -> Result<(), InvalidHandleError> { + Ok(()) + } +} + +impl fmt::Display for HandleDescriptor +where + D: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let &Self { + ref handle, + ref description, + } = self; + write!(f, "{description} (handle {handle:?})") + } +} + +// impl PartialEq for HandleDescriptor { +// fn eq(&self, other: &Self) -> bool { +// self.handle.eq(&other.handle) +// } +// } + +pub trait HandleDescription +where + Self: fmt::Debug + fmt::Display, +{ + fn into_erased(self) -> Box; +} + +impl HandleDescription for Box { + fn into_erased(self) -> Box { + self + } +} + +impl HandleDescription for &'static str { + fn into_erased(self) -> Box { + Box::new(self) + } +} diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 072550e9b0..1d3108b8ca 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -12,8 +12,6 @@ const MAX_WORKGROUP_SIZE: u32 = 0x4000; #[derive(Clone, Debug, thiserror::Error)] pub enum GlobalVariableError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Usage isn't compatible with address space {0:?}")] InvalidUsage(crate::AddressSpace), #[error("Type isn't compatible with address space {0:?}")] @@ -355,6 +353,18 @@ impl VaryingContext<'_> { } impl super::Validator { + #[cfg(feature = "validate")] + pub(super) fn validate_global_var_handles( + &self, + var: &crate::GlobalVariable, + ) -> Result<(), BadHandle> { + let index = var.ty.index(); + self.types.get(index).map(|_| ()).ok_or(BadHandle { + kind: "type", + index, + }) + } + #[cfg(feature = "validate")] pub(super) fn validate_global_var( &self, @@ -364,10 +374,7 @@ impl super::Validator { use super::TypeFlags; log::debug!("var {:?}", var); - let type_info = self.types.get(var.ty.index()).ok_or_else(|| BadHandle { - kind: "type", - index: var.ty.index(), - })?; + let type_info = &self.types[var.ty.index()]; let (required_type_flags, is_resource) = match var.space { crate::AddressSpace::Function => { diff --git a/src/valid/mod.rs b/src/valid/mod.rs index b746f6abd4..bcf889905a 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -6,6 +6,7 @@ mod analyzer; mod compose; mod expression; mod function; +mod handles; mod interface; mod r#type; @@ -18,7 +19,7 @@ use crate::{ FastHashSet, }; use bit_set::BitSet; -use std::ops; +use std::{ops, sync::Arc}; //TODO: analyze the model at the same time as we validate it, // merge the corresponding matches over expressions and statements. @@ -31,6 +32,8 @@ pub use function::{CallError, FunctionError, LocalVariableError}; pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; pub use r#type::{Disalignment, TypeError, TypeFlags}; +use self::handles::InvalidHandleError; + bitflags::bitflags! { /// Validation flags. /// @@ -143,8 +146,6 @@ pub struct Validator { #[derive(Clone, Debug, thiserror::Error)] pub enum ConstantError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("The type doesn't match the constant")] InvalidType, #[error("The component handle {0:?} can not be resolved")] @@ -157,6 +158,9 @@ pub enum ConstantError { #[derive(Clone, Debug, thiserror::Error)] pub enum ValidationError { + // TODO: I hate these diagnostics. + #[error(transparent)] + InvalidHandle(#[from] Arc), #[error(transparent)] Layouter(#[from] LayoutError), #[error("Type {handle:?} '{name}' is invalid")] @@ -198,6 +202,13 @@ pub enum ValidationError { Corrupted, } +// TODO: remove this +impl From for ValidationError { + fn from(e: BadHandle) -> Self { + Self::InvalidHandle(Arc::new(e.into())) + } +} + impl crate::TypeInner { #[cfg(feature = "validate")] const fn is_sized(&self) -> bool { @@ -270,6 +281,22 @@ impl Validator { self.valid_expression_set.clear(); } + #[cfg(feature = "validate")] + fn validate_constant_handles( + &self, + handle: Handle, + constants: &Arena, + types: &UniqueArena, + ) -> Result<(), WithSpan> { + match &constants[handle].inner { + &crate::ConstantInner::Composite { ty, components: _ } => types + .get_handle(ty) + .map(|_| ()) + .map_err(|e| e.with_span_handle(ty, types)), + _ => Ok(()), + } + } + #[cfg(feature = "validate")] fn validate_constant( &self, @@ -285,7 +312,7 @@ impl Validator { } } crate::ConstantInner::Composite { ty, ref components } => { - match types.get_handle(ty)?.inner { + match types[ty].inner { crate::TypeInner::Array { size: crate::ArraySize::Constant(size_handle), .. @@ -315,9 +342,13 @@ impl Validator { &mut self, module: &crate::Module, ) -> Result> { + Self::validate_module_handles(module).unwrap(); // TODO: y u no `return` + self.reset(); self.reset_types(module.types.len()); + // TODO: change call tree beyond here; use `` instead of + // `Arena::try_get`. self.layouter .update(&module.types, &module.constants) .map_err(|e| { diff --git a/src/valid/type.rs b/src/valid/type.rs index f103017dd9..d988920075 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -2,6 +2,8 @@ use super::Capabilities; use crate::{ arena::{Arena, BadHandle, Handle, UniqueArena}, proc::Alignment, + span::AddSpan, + WithSpan, }; bitflags::bitflags! { @@ -88,8 +90,6 @@ pub enum Disalignment { #[derive(Clone, Debug, thiserror::Error)] pub enum TypeError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("The {0:?} scalar width {1} is not supported")] InvalidWidth(crate::ScalarKind, crate::Bytes), #[error("The {0:?} scalar width {1} is not supported for an atomic")] @@ -219,6 +219,25 @@ impl super::Validator { self.layouter.clear(); } + pub(super) fn validate_type_handles( + &self, + handle: Handle, + types: &UniqueArena, + constants: &Arena, + ) -> Result<(), WithSpan> { + match types[handle].inner { + crate::TypeInner::Array { + base: _, + size: crate::ArraySize::Constant(const_handle), + stride: _, + } => constants + .try_get(const_handle) + .map(|_| ()) + .map_err(|e| e.with_span_handle(const_handle, constants)), + _ => Ok(()), + } + } + pub(super) fn validate_type( &self, handle: Handle, @@ -418,7 +437,7 @@ impl super::Validator { let sized_flag = match size { crate::ArraySize::Constant(const_handle) => { - let constant = constants.try_get(const_handle)?; + let constant = &constants[const_handle]; let length_is_positive = match *constant { crate::Constant { specialization: Some(_),