From 4bfac7f17d43e3a8f774994eb7c6215192b04203 Mon Sep 17 00:00:00 2001 From: Erich Gubler Date: Fri, 14 Oct 2022 13:30:34 -0400 Subject: [PATCH] WIP: lots of changes, but not enough clean, beware! --- commits.md | 7 ++ src/back/hlsl/conv.rs | 15 ++- src/back/hlsl/writer.rs | 9 +- src/proc/layouter.rs | 82 +++++++++++-- src/proc/mod.rs | 40 +++++-- src/proc/typifier.rs | 36 +++--- src/valid/analyzer.rs | 3 + src/valid/compose.rs | 13 +- src/valid/expression.rs | 257 +++++++++++++++++++++++++++++++--------- src/valid/function.rs | 167 ++++++++++++++++++++++---- src/valid/interface.rs | 19 ++- src/valid/mod.rs | 57 +++++++-- src/valid/type.rs | 25 +++- 13 files changed, 584 insertions(+), 146 deletions(-) create mode 100644 commits.md diff --git a/commits.md b/commits.md new file mode 100644 index 0000000000..c661ead71f --- /dev/null +++ b/commits.md @@ -0,0 +1,7 @@ +- [ ] Split out `ExprFedDepError` +- [ ] Combine handle validation pass errors at current implementation level +- [ ] Split out the pass! + - [ ] squash! add `ExpressionTypeResolver::check` + - [ ] squash! `impl Index for ExpressionTypeResolver` + - [ ] note: trade-offs were made, we're now doing multiple passes. Boo for + perf, yay for abstraction. diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 039bfcce30..154eb107de 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -40,12 +40,12 @@ impl crate::TypeInner { } } - pub(super) fn try_size_hlsl( + pub(super) fn size_hlsl( &self, types: &crate::UniqueArena, constants: &crate::Arena, - ) -> Result { - Ok(match *self { + ) -> u32 { + match *self { Self::Matrix { columns, rows, @@ -58,17 +58,16 @@ impl crate::TypeInner { Self::Array { base, size, stride } => { let count = match size { crate::ArraySize::Constant(handle) => { - let constant = constants.try_get(handle)?; - constant.to_array_length().unwrap_or(1) + constants[handle].to_array_length().unwrap_or(1) } // A dynamically-sized array has to have at least one element crate::ArraySize::Dynamic => 1, }; - let last_el_size = types[base].inner.try_size_hlsl(types, constants)?; + let last_el_size = types[base].inner.size_hlsl(types, constants); ((count - 1) * stride) + last_el_size } - _ => self.try_size(constants)?, - }) + _ => self.size(constants), + } } /// Used to generate the name of the wrapped type constructor diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 42e3060798..e283e774e5 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -805,6 +805,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { /// /// # Notes /// Ends in a newline + /// + /// # Panics + /// + /// This function may panic when given a malformed IR module. fn write_struct( &mut self, module: &Module, @@ -829,10 +833,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } let ty_inner = &module.types[member.ty].inner; - last_offset = member.offset - + ty_inner - .try_size_hlsl(&module.types, &module.constants) - .unwrap(); + last_offset = member.offset + ty_inner.size_hlsl(&module.types, &module.constants); // The indentation is only for readability write!(self.out, "{}", back::INDENT)?; diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index 0c3a00db15..68a2a65192 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -1,4 +1,7 @@ -use crate::arena::{Arena, BadHandle, Handle, UniqueArena}; +use crate::{ + arena::{Arena, BadHandle, Handle, UniqueArena}, + WithSpan, +}; use std::{fmt::Display, num::NonZeroU32, ops}; /// A newtype struct where its only valid values are powers of 2 @@ -130,8 +133,6 @@ pub enum LayoutErrorInner { InvalidStructMemberType(u32, Handle), #[error("Type width must be a power of two")] NonPowerOfTwoWidth, - #[error("Array size is a bad handle")] - BadHandle(#[from] BadHandle), } #[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] @@ -153,6 +154,20 @@ impl Layouter { self.layouts.clear(); } + pub fn validate_handles( + &mut self, + types: &UniqueArena, + constants: &Arena, + ) -> Result<(), WithSpan> { + types + .iter() + .skip(self.layouts.len()) + .try_for_each(|(_handle, ty)| { + // TODO: track the type we fail on? + ty.inner.validate_handles(constants) + }) + } + /// Extend this `Layouter` with layouts for any new entries in `types`. /// /// Ensure that every type in `types` has a corresponding [TypeLayout] in @@ -166,7 +181,13 @@ impl Layouter { /// end can call this function at any time, passing its current type and /// constant arenas, and then assume that layouts are available for all /// types. + /// + /// # Panics + /// + /// If `types` contains invalid [`Handle`]s to `constants`, then this function will panic. You + /// can check for this condition by calling [`Self::validate_handles`]. #[allow(clippy::or_fun_call)] + #[track_caller] pub fn update( &mut self, types: &UniqueArena, @@ -175,10 +196,8 @@ impl Layouter { use crate::TypeInner as Ti; for (ty_handle, ty) in types.iter().skip(self.layouts.len()) { - let size = ty - .inner - .try_size(constants) - .map_err(|error| LayoutErrorInner::BadHandle(error).with(ty_handle))?; + // phase-id: layouter + let size = ty.inner.size(constants); let layout = match ty.inner { Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => { let alignment = Alignment::new(width as u32) @@ -255,3 +274,52 @@ impl Layouter { Ok(()) } } + +// #[derive(Debug)] +// struct HandlesValidated(T); + +// impl HandlesValidated { +// pub const fn new(t: T) -> Self { +// Self(t) +// } +// } + +// impl<'a, T> Clone for HandlesValidated<&'a T> { +// fn clone(&self) -> Self { +// let Self(ref_) = self; +// Self(ref_) +// } +// } + +// impl<'a, T> Copy for HandlesValidated<&'a T> {} + +// impl<'a, T> AsRef for HandlesValidated<&'a T> { +// fn as_ref(&self) -> &T { +// todo!() +// } +// } + +// impl<'a, T> Deref for HandlesValidated<&'a T> { +// type Target = T; + +// fn deref(&self) -> &Self::Target { +// let Self(inner) = self; +// inner +// } +// } + +// impl<'a, T> Deref for HandlesValidated<&'a mut T> { +// type Target = T; + +// fn deref(&self) -> &Self::Target { +// let Self(inner) = self; +// inner +// } +// } + +// impl<'a, T> DerefMut for HandlesValidated<&'a mut T> { +// fn deref_mut(&mut self) -> &mut Self::Target { +// let Self(inner) = self; +// inner +// } +// } diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 0b9f406af4..112a48afbe 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -16,6 +16,8 @@ pub use namer::{EntryPointIndex, NameKey, Namer}; pub use terminator::ensure_block_returns; pub use typifier::{ResolveContext, ResolveError, TypeResolution}; +use crate::{arena::BadHandle, WithSpan}; + impl From for super::ScalarKind { fn from(format: super::StorageFormat) -> Self { use super::{ScalarKind as Sk, StorageFormat as Sf}; @@ -97,11 +99,32 @@ impl super::TypeInner { } } - pub fn try_size( + pub fn validate_handles( &self, constants: &super::Arena, - ) -> Result { - Ok(match *self { + ) -> Result<(), WithSpan> { + match self { + &Self::Array { + base: _, + size: super::ArraySize::Constant(handle), + stride: _, + } => constants + .try_get(handle) + .map(|_| ()) + .map_err(|e| WithSpan::new(e).with_handle(handle, constants)), + _ => Ok(()), + } + } + + /// Get the size of this type. + /// + /// # Panics + /// + /// Panics if the `constants` doesn't contain a referenced handle. This may not happen in + /// a properly validated IR module. You can check for this condition with + /// [`Self::validate_handles`]. + pub fn size(&self, constants: &super::Arena) -> u32 { + match *self { Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32, Self::Vector { size, @@ -122,8 +145,7 @@ impl super::TypeInner { } => { let count = match size { super::ArraySize::Constant(handle) => { - let constant = constants.try_get(handle)?; - constant.to_array_length().unwrap_or(1) + constants[handle].to_array_length().unwrap_or(1) } // A dynamically-sized array has to have at least one element super::ArraySize::Dynamic => 1, @@ -132,13 +154,7 @@ impl super::TypeInner { } Self::Struct { span, .. } => span, Self::Image { .. } | Self::Sampler { .. } | Self::BindingArray { .. } => 0, - }) - } - - /// Get the size of this type. Panics if the `constants` doesn't contain - /// a referenced handle. This may not happen in a properly validated IR module. - pub fn size(&self, constants: &super::Arena) -> u32 { - self.try_size(constants).unwrap() + } } /// Return the canonical form of `self`, or `None` if it's already in diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 921700fd49..263c4c8394 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -162,8 +162,6 @@ impl crate::ConstantInner { #[derive(Clone, Debug, Error, PartialEq)] pub enum ResolveError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Index {index} is out of bounds for expression {expr:?}")] OutOfBoundsIndex { expr: Handle, @@ -211,6 +209,15 @@ pub struct ResolveContext<'a> { } impl<'a> ResolveContext<'a> { + pub fn validate_resolution_handles(&self, expr: &crate::Expression) -> Result<(), BadHandle> { + match *expr { + crate::Expression::Constant(h) => self.constants.try_get(h).map(|_| ()), + crate::Expression::GlobalVariable(h) => self.global_vars.try_get(h).map(|_| ()), + crate::Expression::LocalVariable(h) => self.local_vars.try_get(h).map(|_| ()), + _ => Ok(()), + } + } + /// Determine the type of `expr`. /// /// The `past` argument must be a closure that can resolve the types of any @@ -405,20 +412,15 @@ impl<'a> ResolveContext<'a> { } } } - crate::Expression::Constant(h) => { - let constant = self.constants.try_get(h)?; - match constant.inner { - crate::ConstantInner::Scalar { width, ref value } => { - TypeResolution::Value(Ti::Scalar { - kind: value.scalar_kind(), - width, - }) - } - crate::ConstantInner::Composite { ty, components: _ } => { - TypeResolution::Handle(ty) - } + crate::Expression::Constant(h) => match self.constants[h].inner { + crate::ConstantInner::Scalar { width, ref value } => { + TypeResolution::Value(Ti::Scalar { + kind: value.scalar_kind(), + width, + }) } - } + crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty), + }, crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) { Ti::Scalar { kind, width } => { TypeResolution::Value(Ti::Vector { size, kind, width }) @@ -452,7 +454,7 @@ impl<'a> ResolveContext<'a> { TypeResolution::Handle(arg.ty) } crate::Expression::GlobalVariable(h) => { - let var = self.global_vars.try_get(h)?; + let var = &self.global_vars[h]; if var.space == crate::AddressSpace::Handle { TypeResolution::Handle(var.ty) } else { @@ -463,7 +465,7 @@ impl<'a> ResolveContext<'a> { } } crate::Expression::LocalVariable(h) => { - let var = self.local_vars.try_get(h)?; + let var = &self.local_vars[h]; TypeResolution::Value(Ti::Pointer { base: var.ty, space: crate::AddressSpace::Function, diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 26bd7132fd..2686ef211f 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -706,6 +706,9 @@ impl FunctionInfo { }, }; + // FIXME: Break out this layer below into its own API. We can rebuild this layer on top of + // it once that's done. + let ty = resolve_context.resolve(expression, |h| { self.expressions .get(h.index()) diff --git a/src/valid/compose.rs b/src/valid/compose.rs index 6e5c499223..304de0fc47 100644 --- a/src/valid/compose.rs +++ b/src/valid/compose.rs @@ -9,8 +9,6 @@ use crate::arena::{BadHandle, Handle}; #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum ComposeError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Composing of type {0:?} can't be done")] Type(Handle), #[error("Composing expects {expected} components but {given} were given")] @@ -19,6 +17,14 @@ pub enum ComposeError { ComponentType { index: u32 }, } +#[cfg(feature = "validate")] +pub fn validate_compose_handles( + self_ty_handle: Handle, + type_arena: &UniqueArena, +) -> Result<(), BadHandle> { + type_arena.get_handle(self_ty_handle).map(|_| ()) +} + #[cfg(feature = "validate")] pub fn validate_compose( self_ty_handle: Handle, @@ -28,8 +34,7 @@ pub fn validate_compose( ) -> Result<(), ComposeError> { use crate::TypeInner as Ti; - let self_ty = type_arena.get_handle(self_ty_handle)?; - match self_ty.inner { + match type_arena[self_ty_handle].inner { // vectors are composed from scalars or other vectors Ti::Vector { size, kind, width } => { let mut total = 0; diff --git a/src/valid/expression.rs b/src/valid/expression.rs index bf639065b4..49bb488d57 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1,11 +1,15 @@ +use std::ops::Index; + #[cfg(feature = "validate")] use super::{compose::validate_compose, FunctionInfo, ShaderStages, TypeFlags}; +use super::{ExprFwdDepError, InvalidHandleError}; #[cfg(feature = "validate")] use crate::arena::UniqueArena; use crate::{ - arena::{BadHandle, Handle}, + arena::Handle, proc::{IndexableLengthError, ResolveError}, + valid::compose::validate_compose_handles, }; #[derive(Clone, Debug, thiserror::Error)] @@ -15,10 +19,6 @@ pub enum ExpressionError { DoesntExist, #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")] NotInScope, - #[error("Depends on {0:?}, which has not been processed yet")] - ForwardDependency(Handle), - #[error(transparent)] - BadDependency(#[from] BadHandle), #[error("Base type {0:?} is not compatible with this expression")] InvalidBaseType(Handle), #[error("Accessing with index {0:?} can't be done")] @@ -121,6 +121,7 @@ pub enum ExpressionError { MissingCapabilities(super::Capabilities), } +/// TODO: document intended usage #[cfg(feature = "validate")] struct ExpressionTypeResolver<'a> { root: Handle, @@ -130,20 +131,177 @@ struct ExpressionTypeResolver<'a> { #[cfg(feature = "validate")] impl<'a> ExpressionTypeResolver<'a> { + /// TODO: document intended usage + // TODO: make this just return a `()` as a checking feature fn resolve( &self, handle: Handle, - ) -> Result<&'a crate::TypeInner, ExpressionError> { + ) -> Result<&'a crate::TypeInner, ExprFwdDepError> { if handle < self.root { Ok(self.info[handle].ty.inner_with(self.types)) } else { - Err(ExpressionError::ForwardDependency(handle)) + Err(ExprFwdDepError { + depends_on: self.root, + }) // TODO: Should this be `self.root`? } } + + /// TODO: document intended usage + fn check(&self, handle: Handle) -> Result<(), ExprFwdDepError> { + self.resolve(handle).map(|_| ()) + } +} + +#[cfg(feature = "validate")] +impl<'a> Index> for ExpressionTypeResolver<'a> { + type Output = crate::TypeInner; + + fn index(&self, handle: Handle) -> &Self::Output { + self.resolve(handle).unwrap() + } } #[cfg(feature = "validate")] impl super::Validator { + pub(super) fn validate_expression_handles( + &self, + root: Handle, + expression: &crate::Expression, + function: &crate::Function, + module: &crate::Module, + info: &FunctionInfo, + ) -> Result<(), InvalidHandleError> { + use crate::{Expression as E, TypeInner as Ti}; + + let resolver = ExpressionTypeResolver { + root, + types: &module.types, + info, + }; + + // NOTE: TODO: nag to keep in sync + match *expression { + E::Access { base, index } => { + resolver.check(base)?; + resolver.check(index)?; + } + E::AccessIndex { base, index: _ } => resolver.check(base)?, + E::Constant(handle) => module.constants.try_get(handle).map(|_| ())?, + E::Splat { size: _, value } => resolver.check(value)?, + E::Swizzle { + size: _, + vector, + pattern: _, + } => resolver.check(vector)?, + E::Compose { ty, ref components } => { + for &handle in components { + if handle >= root { + return Err(ExprFwdDepError { depends_on: handle }.into()); + } + } + validate_compose_handles(ty, &module.types)?; + } + E::FunctionArgument(_) => (), + E::GlobalVariable(handle) => module.global_variables.try_get(handle).map(|_| ())?, + E::LocalVariable(handle) => function.local_variables.try_get(handle).map(|_| ())?, + E::Load { pointer } => resolver.check(pointer)?, + E::ImageSample { + image, + coordinate, + sampler, + array_index, + depth_ref, + level, + .. + } => { + if let Ok(image_ty) = Self::global_var_ty(module, function, image) { + if let Ti::Image { .. } = module.types[image_ty].inner { + if let Some(expr) = array_index { + resolver.check(expr)?; + resolver.check(coordinate)?; + if let Some(expr) = depth_ref { + resolver.check(expr)?; + } + match level { + crate::SampleLevel::Auto | crate::SampleLevel::Zero => (), + crate::SampleLevel::Exact(expr) => { + resolver.check(expr)?; + } + crate::SampleLevel::Bias(expr) => resolver.check(expr)?, + crate::SampleLevel::Gradient { x, y } => { + resolver.check(x)?; + resolver.check(y)?; + } + } + } + } + } + } + E::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + if let Ok(image_ty) = Self::global_var_ty(module, function, image) { + if let Ti::Image { class, .. } = module.types[image_ty].inner { + if let Some(expr) = array_index { + resolver.check(coordinate)?; + resolver.check(expr)?; + if let (Some(sample), true) = (sample, class.is_multisampled()) { + resolver.check(sample)?; + } + if let (Some(level), true) = (level, class.is_mipmapped()) { + resolver.check(level)?; + } + } + } + } + } + E::ImageQuery { .. } => (), + E::Unary { expr, op: _ } => resolver.check(expr)?, + E::Binary { op: _, left, right } => { + resolver.check(left)?; + resolver.check(right)?; + } + E::Select { + condition, + accept, + reject, + } => { + resolver.check(accept)?; + resolver.check(reject)?; + resolver.check(condition)?; + } + E::Derivative { axis: _, expr } => resolver.check(expr)?, + E::Relational { fun: _, argument } => resolver.check(argument)?, + E::Math { + fun: _, + arg, + arg1, + arg2, + arg3, + } => { + resolver.check(arg)?; + IntoIterator::into_iter([arg1, arg2, arg3]) + .flatten() + .try_for_each(|a| resolver.check(a))?; + } + E::As { + expr, + kind: _, + convert: _, + } => resolver.check(expr)?, + E::CallResult(_) | E::AtomicResult { .. } => (), + E::ArrayLength(expr) => match resolver.resolve(expr)? { + Ti::Pointer { base, .. } => resolver.types.get_handle(*base).map(|_| ())?, + _ => (), + }, + }; + Ok(()) + } + pub(super) fn validate_expression( &self, root: Handle, @@ -163,7 +321,7 @@ impl super::Validator { let stages = match *expression { E::Access { base, index } => { - let base_type = resolver.resolve(base)?; + let base_type = &resolver[base]; // See the documentation for `Expression::Access`. let dynamic_indexing_restricted = match *base_type { Ti::Vector { .. } => false, @@ -176,7 +334,7 @@ impl super::Validator { return Err(ExpressionError::InvalidBaseType(base)); } }; - match *resolver.resolve(index)? { + match resolver[index] { //TODO: only allow one of these Ti::Scalar { kind: Sk::Sint | Sk::Uint, @@ -254,7 +412,7 @@ impl super::Validator { Ok(limit) } - let limit = resolve_index_limit(module, base, resolver.resolve(base)?, true)?; + let limit = resolve_index_limit(module, base, &resolver[base], true)?; if index >= limit { return Err(ExpressionError::IndexOutOfBounds( base, @@ -263,11 +421,8 @@ impl super::Validator { } ShaderStages::all() } - E::Constant(handle) => { - let _ = module.constants.try_get(handle)?; - ShaderStages::all() - } - E::Splat { size: _, value } => match *resolver.resolve(value)? { + E::Constant(handle) => ShaderStages::all(), + E::Splat { size: _, value } => match resolver[value] { Ti::Scalar { .. } => ShaderStages::all(), ref other => { log::error!("Splat scalar type {:?}", other); @@ -279,7 +434,7 @@ impl super::Validator { vector, pattern, } => { - let vec_size = match *resolver.resolve(vector)? { + let vec_size = match resolver[vector] { Ti::Vector { size: vec_size, .. } => vec_size, ref other => { log::error!("Swizzle vector type {:?}", other); @@ -294,11 +449,6 @@ impl super::Validator { ShaderStages::all() } E::Compose { ref components, ty } => { - for &handle in components { - if handle >= root { - return Err(ExpressionError::ForwardDependency(handle)); - } - } validate_compose( ty, &module.constants, @@ -313,16 +463,10 @@ impl super::Validator { } ShaderStages::all() } - E::GlobalVariable(handle) => { - let _ = module.global_variables.try_get(handle)?; - ShaderStages::all() - } - E::LocalVariable(handle) => { - let _ = function.local_variables.try_get(handle)?; - ShaderStages::all() - } + E::GlobalVariable(handle) => ShaderStages::all(), + E::LocalVariable(handle) => ShaderStages::all(), E::Load { pointer } => { - match *resolver.resolve(pointer)? { + match resolver[pointer] { Ti::Pointer { base, .. } if self.types[base.index()] .flags @@ -365,7 +509,7 @@ impl super::Validator { return Err(ExpressionError::InvalidImageArrayIndex); } if let Some(expr) = array_index { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Sint, width: _, @@ -401,7 +545,7 @@ impl super::Validator { crate::ImageDimension::D2 => 2, crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3, }; - match *resolver.resolve(coordinate)? { + match resolver[coordinate] { Ti::Scalar { kind: Sk::Float, .. } if num_components == 1 => {} @@ -439,7 +583,7 @@ impl super::Validator { // check depth reference type if let Some(expr) = depth_ref { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } => {} @@ -476,7 +620,7 @@ impl super::Validator { crate::SampleLevel::Auto => ShaderStages::FRAGMENT, crate::SampleLevel::Zero => ShaderStages::all(), crate::SampleLevel::Exact(expr) => { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } => {} @@ -485,7 +629,7 @@ impl super::Validator { ShaderStages::all() } crate::SampleLevel::Bias(expr) => { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } => {} @@ -494,7 +638,7 @@ impl super::Validator { ShaderStages::all() } crate::SampleLevel::Gradient { x, y } => { - match *resolver.resolve(x)? { + match resolver[x] { Ti::Scalar { kind: Sk::Float, .. } if num_components == 1 => {} @@ -507,7 +651,7 @@ impl super::Validator { return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x)) } } - match *resolver.resolve(y)? { + match resolver[y] { Ti::Scalar { kind: Sk::Float, .. } if num_components == 1 => {} @@ -538,7 +682,7 @@ impl super::Validator { arrayed, dim, } => { - match resolver.resolve(coordinate)?.image_storage_coordinates() { + match resolver[coordinate].image_storage_coordinates() { Some(coord_dim) if coord_dim == dim => {} _ => { return Err(ExpressionError::InvalidImageCoordinateType( @@ -550,7 +694,7 @@ impl super::Validator { return Err(ExpressionError::InvalidImageArrayIndex); } if let Some(expr) = array_index { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Sint, width: _, @@ -562,7 +706,7 @@ impl super::Validator { match (sample, class.is_multisampled()) { (None, false) => {} (Some(sample), true) => { - if resolver.resolve(sample)?.scalar_kind() != Some(Sk::Sint) { + if resolver[sample].scalar_kind() != Some(Sk::Sint) { return Err(ExpressionError::InvalidImageOtherIndexType( sample, )); @@ -576,7 +720,7 @@ impl super::Validator { match (level, class.is_mipmapped()) { (None, false) => {} (Some(level), true) => { - if resolver.resolve(level)?.scalar_kind() != Some(Sk::Sint) { + if resolver[level].scalar_kind() != Some(Sk::Sint) { return Err(ExpressionError::InvalidImageOtherIndexType(level)); } } @@ -610,7 +754,7 @@ impl super::Validator { } E::Unary { op, expr } => { use crate::UnaryOperator as Uo; - let inner = resolver.resolve(expr)?; + let inner = resolver[expr]; match (op, inner.scalar_kind()) { (_, Some(Sk::Sint | Sk::Bool)) //TODO: restrict Negate for bools? @@ -625,8 +769,8 @@ impl super::Validator { } E::Binary { op, left, right } => { use crate::BinaryOperator as Bo; - let left_inner = resolver.resolve(left)?; - let right_inner = resolver.resolve(right)?; + let left_inner = &resolver[left]; + let right_inner = &resolver[right]; let good = match op { Bo::Add | Bo::Subtract => match *left_inner { Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { @@ -807,9 +951,9 @@ impl super::Validator { accept, reject, } => { - let accept_inner = resolver.resolve(accept)?; - let reject_inner = resolver.resolve(reject)?; - let condition_good = match *resolver.resolve(condition)? { + let accept_inner = &resolver[accept]; + let reject_inner = &resolver[reject]; + let condition_good = match resolver[condition] { Ti::Scalar { kind: Sk::Bool, width: _, @@ -839,7 +983,7 @@ impl super::Validator { ShaderStages::all() } E::Derivative { axis: _, expr } => { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } @@ -852,7 +996,7 @@ impl super::Validator { } E::Relational { fun, argument } => { use crate::RelationalFunction as Rf; - let argument_inner = resolver.resolve(argument)?; + let argument_inner = &resolver[argument]; match fun { Rf::All | Rf::Any => match *argument_inner { Ti::Vector { kind: Sk::Bool, .. } => {} @@ -885,10 +1029,11 @@ impl super::Validator { } => { use crate::MathFunction as Mf; - let arg_ty = resolver.resolve(arg)?; - let arg1_ty = arg1.map(|expr| resolver.resolve(expr)).transpose()?; - let arg2_ty = arg2.map(|expr| resolver.resolve(expr)).transpose()?; - let arg3_ty = arg3.map(|expr| resolver.resolve(expr)).transpose()?; + let resolve = |arg| &resolver[arg]; + let arg_ty = resolve(arg); + let arg1_ty = arg1.map(resolve); + let arg2_ty = arg2.map(resolve); + let arg3_ty = arg3.map(resolve); match fun { Mf::Abs => { if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { @@ -1372,7 +1517,7 @@ impl super::Validator { kind, convert, } => { - let base_width = match *resolver.resolve(expr)? { + let base_width = match resolver[expr] { crate::TypeInner::Scalar { width, .. } | crate::TypeInner::Vector { width, .. } | crate::TypeInner::Matrix { width, .. } => width, @@ -1401,9 +1546,9 @@ impl super::Validator { } ShaderStages::all() } - E::ArrayLength(expr) => match *resolver.resolve(expr)? { + E::ArrayLength(expr) => match resolver[expr] { Ti::Pointer { base, .. } => { - let base_ty = resolver.types.get_handle(base)?; + let base_ty = resolver.types[base]; if let Ti::Array { size: crate::ArraySize::Dynamic, .. diff --git a/src/valid/function.rs b/src/valid/function.rs index 2107e71e09..45211a9000 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -4,7 +4,7 @@ use crate::arena::{BadHandle, Handle}; use super::{ analyzer::{UniformityDisruptor, UniformityRequirements}, - ExpressionError, FunctionInfo, ModuleInfo, + ExpressionError, FunctionInfo, InvalidHandleError, ModuleInfo, }; use crate::span::WithSpan; #[cfg(feature = "validate")] @@ -16,8 +16,6 @@ use bit_set::BitSet; #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum CallError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("The callee is declared after the caller")] ForwardDeclaredFunction, #[error("Argument {index} expression is invalid")] @@ -67,8 +65,6 @@ pub enum LocalVariableError { #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum FunctionError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Expression {handle:?} is invalid")] Expression { handle: Handle, @@ -203,11 +199,8 @@ impl<'a> BlockContext<'a> { BlockContext { abilities, ..*self } } - fn get_expression( - &self, - handle: Handle, - ) -> Result<&'a crate::Expression, FunctionError> { - Ok(self.expressions.try_get(handle)?) + fn get_expression(&self, handle: Handle) -> &'a crate::Expression { + &self.expressions[handle] } fn resolve_type_impl( @@ -249,6 +242,19 @@ impl<'a> BlockContext<'a> { } impl super::Validator { + #[cfg(feature = "validate")] + fn validate_call_handles( + function: Handle, + module_functions: &Arena, + ) -> Result<(), WithSpan> { + module_functions + .try_get(function) + .map(|_| ()) + // TODO: use `e.with_span()` instead + // TODO: add better spans yo + .map_err(WithSpan::new) + } + #[cfg(feature = "validate")] fn validate_call( &mut self, @@ -257,11 +263,7 @@ impl super::Validator { result: Option>, context: &BlockContext, ) -> Result> { - let fun = context - .functions - .try_get(function) - .map_err(CallError::BadHandle) - .map_err(WithSpan::new)?; + let fun = &context.functions[function]; if fun.arguments.len() != arguments.len() { return Err(CallError::ArgumentCount { required: fun.arguments.len(), @@ -380,6 +382,84 @@ impl super::Validator { Ok(()) } + #[cfg(feature = "validate")] + fn validate_block_impl_handles( + &mut self, + statements: &crate::Block, + block_expressions: &Arena, + module_functions: &Arena, + ) -> Result<(), WithSpan> { + statements + // TODO: improve spans with `span_iter`, yo + .iter() + // NOTE: Keep this in roughly the same order as corresponding `match` in + // `validate_block_impl_handles`! + .try_for_each(|statement| match statement { + &crate::Statement::Emit(_) => Ok(()), + &crate::Statement::Block(ref block) => { + self.validate_block_handles(block, block_expressions, module_functions) + } + &crate::Statement::If { + condition: _, + accept: _, + reject: _, + } + | &crate::Statement::Switch { + selector: _, + cases: _, + } => Ok(()), + &crate::Statement::Loop { + ref body, + ref continuing, + break_if: _, + } => { + self.validate_block_impl_handles(body, block_expressions, module_functions)?; + self.validate_block_impl_handles( + continuing, + block_expressions, + module_functions, + )?; + Ok(()) + } + &crate::Statement::Break + | &crate::Statement::Continue + | &crate::Statement::Return { value: _ } + | &crate::Statement::Kill + | &crate::Statement::Barrier(_) + | &crate::Statement::Store { + pointer: _, + value: _, + } => Ok(()), + &crate::Statement::ImageStore { + image, + coordinate: _, + array_index: _, + value: _, + } => match block_expressions + .try_get(image) + .map_err(|e| e.with_span())? + { + &crate::Expression::Access { base, .. } + | &crate::Expression::AccessIndex { base, .. } => block_expressions + .try_get(base) + .map(|_| ()) + .map_err(|e| e.with_span()), + _ => Ok(()), + }, + &crate::Statement::Call { + function, + arguments: _, + result: _, + } => Self::validate_call_handles(function, module_functions), + &crate::Statement::Atomic { + pointer: _, + fun: _, + value: _, + result: _, + } => Ok(()), + }) + } + #[cfg(feature = "validate")] fn validate_block_impl( &mut self, @@ -674,14 +754,14 @@ impl super::Validator { } => { //Note: this code uses a lot of `FunctionError::InvalidImageStore`, // and could probably be refactored. - let var = match *context.get_expression(image).map_err(|e| e.with_span())? { + let var = match *context.get_expression(image) { crate::Expression::GlobalVariable(var_handle) => { &context.global_vars[var_handle] } // We're looking at a binding index situation, so punch through the index and look at the global behind it. crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => { - match *context.get_expression(base).map_err(|e| e.with_span())? { + match *context.get_expression(base) { crate::Expression::GlobalVariable(var_handle) => { &context.global_vars[var_handle] } @@ -804,6 +884,16 @@ impl super::Validator { Ok(BlockInfo { stages, finished }) } + #[cfg(feature = "validate")] + fn validate_block_handles( + &mut self, + statements: &crate::Block, + block_expressions: &Arena, + module_functions: &Arena, + ) -> Result<(), WithSpan> { + self.validate_block_impl_handles(statements, block_expressions, module_functions) + } + #[cfg(feature = "validate")] fn validate_block( &mut self, @@ -858,6 +948,44 @@ impl super::Validator { Ok(()) } + pub(super) fn validate_function_handles( + &mut self, + fun: &crate::Function, + module: &crate::Module, + mod_info: &ModuleInfo, + ) -> Result<(), WithSpan> { + // FIXME: Break out expression type resolution from `FunctionInfo`'s analyses, just consume + // that here instead. + #[cfg_attr(not(feature = "validate"), allow(unused_mut))] + let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?; + + fun.arguments + .iter() + .try_for_each(|argument| { + module + .types + .get_handle(argument.ty) + .map(|_| ()) + .map_err(|e| e.with_span_handle(argument.ty, &module.types)) + }) + .map_err(|e| e.into_other())?; + + #[cfg(feature = "validate")] + if self.flags.contains(super::ValidationFlags::EXPRESSIONS) { + for (handle, expr) in fun.expressions.iter() { + self.validate_expression_handles(handle, expr, fun, module, &info) + .map_err(|e| e.with_span_handle(handle, &fun.expressions))?; + } + } + + #[cfg(feature = "validate")] + if self.flags.contains(super::ValidationFlags::BLOCKS) { + self.validate_block_handles(&fun.body, &fun.expressions, &module.functions) + .map_err(|e| e.into_other())?; + } + Ok(()) + } + pub(super) fn validate_function( &mut self, fun: &crate::Function, @@ -884,10 +1012,7 @@ impl super::Validator { #[cfg(feature = "validate")] for (index, argument) in fun.arguments.iter().enumerate() { - let ty = module.types.get_handle(argument.ty).map_err(|err| { - FunctionError::from(err).with_span_handle(argument.ty, &module.types) - })?; - match ty.inner.pointer_space() { + match module.types[argument.ty].inner.pointer_space() { Some( crate::AddressSpace::Private | crate::AddressSpace::Function diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 072550e9b0..1d3108b8ca 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -12,8 +12,6 @@ const MAX_WORKGROUP_SIZE: u32 = 0x4000; #[derive(Clone, Debug, thiserror::Error)] pub enum GlobalVariableError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Usage isn't compatible with address space {0:?}")] InvalidUsage(crate::AddressSpace), #[error("Type isn't compatible with address space {0:?}")] @@ -355,6 +353,18 @@ impl VaryingContext<'_> { } impl super::Validator { + #[cfg(feature = "validate")] + pub(super) fn validate_global_var_handles( + &self, + var: &crate::GlobalVariable, + ) -> Result<(), BadHandle> { + let index = var.ty.index(); + self.types.get(index).map(|_| ()).ok_or(BadHandle { + kind: "type", + index, + }) + } + #[cfg(feature = "validate")] pub(super) fn validate_global_var( &self, @@ -364,10 +374,7 @@ impl super::Validator { use super::TypeFlags; log::debug!("var {:?}", var); - let type_info = self.types.get(var.ty.index()).ok_or_else(|| BadHandle { - kind: "type", - index: var.ty.index(), - })?; + let type_info = &self.types[var.ty.index()]; let (required_type_flags, is_resource) = match var.space { crate::AddressSpace::Function => { diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 14e76be61f..b7dd7b2b96 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -143,8 +143,6 @@ pub struct Validator { #[derive(Clone, Debug, thiserror::Error)] pub enum ConstantError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("The type doesn't match the constant")] InvalidType, #[error("The component handle {0:?} can not be resolved")] @@ -157,6 +155,9 @@ pub enum ConstantError { #[derive(Clone, Debug, thiserror::Error)] pub enum ValidationError { + // TODO: I hate these diagnostics. + #[error(transparent)] + InvalidHandle(#[from] InvalidHandleError), #[error(transparent)] Layouter(#[from] LayoutError), #[error("Type {handle:?} '{name}' is invalid")] @@ -198,6 +199,13 @@ pub enum ValidationError { Corrupted, } +// TODO: remove this +impl From for ValidationError { + fn from(e: BadHandle) -> Self { + Self::InvalidHandle(e.into()) + } +} + impl crate::TypeInner { #[cfg(feature = "validate")] const fn is_sized(&self) -> bool { @@ -270,6 +278,22 @@ impl Validator { self.valid_expression_set.clear(); } + #[cfg(feature = "validate")] + fn validate_constant_handles( + &self, + handle: Handle, + constants: &Arena, + types: &UniqueArena, + ) -> Result<(), WithSpan> { + match &constants[handle].inner { + &crate::ConstantInner::Composite { ty, components: _ } => types + .get_handle(ty) + .map(|_| ()) + .map_err(|e| e.with_span_handle(ty, types)), + _ => Ok(()), + } + } + #[cfg(feature = "validate")] fn validate_constant( &self, @@ -285,7 +309,7 @@ impl Validator { } } crate::ConstantInner::Composite { ty, ref components } => { - match types.get_handle(ty)?.inner { + match types[ty].inner { crate::TypeInner::Array { size: crate::ArraySize::Constant(size_handle), .. @@ -369,11 +393,10 @@ impl Validator { #[cfg(feature = "validate")] for (var_handle, var) in module.global_variables.iter() { - self.validate_global_var_handles(var, &module.types) - .map_err(|e| { - e.with_span_handle(var_handle, &module.global_variables) - .into_other() - })?; + self.validate_global_var_handles(var).map_err(|e| { + e.with_span_handle(var_handle, &module.global_variables) + .into_other() + })?; self.validate_global_var(var, &module.types) .map_err(|error| { ValidationError::GlobalVariable { @@ -391,6 +414,8 @@ impl Validator { }; for (handle, fun) in module.functions.iter() { + self.validate_function_handles(fun, module, &mod_info) + .map_err(|e| e.with_handle(handle, &module.functions).into_other())?; match self.validate_function(fun, module, &mod_info, false) { Ok(info) => mod_info.functions.push(info), Err(error) => { @@ -435,3 +460,19 @@ impl Validator { Ok(mod_info) } } + +// TODO: use a more concrete model for better diagnostics? +#[derive(Clone, Debug, thiserror::Error)] +pub enum InvalidHandleError { + #[error(transparent)] + Bad(#[from] BadHandle), + #[error(transparent)] + ExpressionWithForwardDependency(#[from] ExprFwdDepError), +} + +// TODO: use a more concrete model for better diagnostics? +#[derive(Clone, Debug, thiserror::Error)] +#[error("Depends on {depends_on:?}, which has not been processed yet")] +pub struct ExprFwdDepError { + depends_on: Handle, +} diff --git a/src/valid/type.rs b/src/valid/type.rs index f103017dd9..d988920075 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -2,6 +2,8 @@ use super::Capabilities; use crate::{ arena::{Arena, BadHandle, Handle, UniqueArena}, proc::Alignment, + span::AddSpan, + WithSpan, }; bitflags::bitflags! { @@ -88,8 +90,6 @@ pub enum Disalignment { #[derive(Clone, Debug, thiserror::Error)] pub enum TypeError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("The {0:?} scalar width {1} is not supported")] InvalidWidth(crate::ScalarKind, crate::Bytes), #[error("The {0:?} scalar width {1} is not supported for an atomic")] @@ -219,6 +219,25 @@ impl super::Validator { self.layouter.clear(); } + pub(super) fn validate_type_handles( + &self, + handle: Handle, + types: &UniqueArena, + constants: &Arena, + ) -> Result<(), WithSpan> { + match types[handle].inner { + crate::TypeInner::Array { + base: _, + size: crate::ArraySize::Constant(const_handle), + stride: _, + } => constants + .try_get(const_handle) + .map(|_| ()) + .map_err(|e| e.with_span_handle(const_handle, constants)), + _ => Ok(()), + } + } + pub(super) fn validate_type( &self, handle: Handle, @@ -418,7 +437,7 @@ impl super::Validator { let sized_flag = match size { crate::ArraySize::Constant(const_handle) => { - let constant = constants.try_get(const_handle)?; + let constant = &constants[const_handle]; let length_is_positive = match *constant { crate::Constant { specialization: Some(_),