From 461fdda4250cdf9aa49de960afd0fa5042a51bc9 Mon Sep 17 00:00:00 2001 From: Erich Gubler Date: Fri, 16 Dec 2022 17:26:43 -0700 Subject: [PATCH] Add handle validation pass to `Validator` (#2090) Before proceeding with any other validation, check that all Handles are valid for their arenas, and refer only to older handles than themselves. This allows subsequent stages to simply use indexing without panics, assuming validation has passed. --- src/arena.rs | 86 +++++- src/back/hlsl/conv.rs | 15 +- src/back/hlsl/writer.rs | 5 +- src/proc/layouter.rs | 9 +- src/proc/mod.rs | 19 +- src/proc/typifier.rs | 31 +- src/valid/analyzer.rs | 9 +- src/valid/compose.rs | 7 +- src/valid/expression.rs | 112 ++++---- src/valid/function.rs | 28 +- src/valid/handles.rs | 616 ++++++++++++++++++++++++++++++++++++++++ src/valid/interface.rs | 9 +- src/valid/mod.rs | 14 +- src/valid/type.rs | 6 +- 14 files changed, 796 insertions(+), 170 deletions(-) create mode 100644 src/valid/handles.rs diff --git a/src/arena.rs b/src/arena.rs index 99d977b2ff..1e7b659371 100644 --- a/src/arena.rs +++ b/src/arena.rs @@ -15,6 +15,15 @@ pub struct BadHandle { pub index: usize, } +impl BadHandle { + fn new(handle: Handle) -> Self { + Self { + kind: std::any::type_name::(), + index: handle.index(), + } + } +} + /// A strongly typed reference to an arena item. /// /// A `Handle` value can be used as an index into an [`Arena`] or [`UniqueArena`]. @@ -123,6 +132,35 @@ pub struct Range { marker: PhantomData, } +impl Range { + pub(crate) const fn erase_type(self) -> Range<()> { + let Self { inner, marker: _ } = self; + Range { + inner, + marker: PhantomData, + } + } +} + +// NOTE: Keep this diagnostic in sync with that of [`BadHandle`]. +#[derive(Clone, Debug, thiserror::Error)] +#[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")] +pub struct BadRangeError { + // This error is used for many `Handle` types, but there's no point in making this generic, so + // we just flatten them all to `Handle<()>` here. + kind: &'static str, + range: Range<()>, +} + +impl BadRangeError { + pub fn new(range: Range) -> Self { + Self { + kind: std::any::type_name::(), + range: range.erase_type(), + } + } +} + impl Clone for Range { fn clone(&self) -> Self { Range { @@ -282,10 +320,9 @@ impl Arena { } pub fn try_get(&self, handle: Handle) -> Result<&T, BadHandle> { - self.data.get(handle.index()).ok_or_else(|| BadHandle { - kind: std::any::type_name::(), - index: handle.index(), - }) + self.data + .get(handle.index()) + .ok_or_else(|| BadHandle::new(handle)) } /// Get a mutable reference to an element in the arena. @@ -320,6 +357,31 @@ impl Arena { Span::default() } } + + /// Assert that `handle` is valid for this arena. + pub fn check_contains_handle(&self, handle: Handle) -> Result<(), BadHandle> { + if handle.index() < self.data.len() { + Ok(()) + } else { + Err(BadHandle::new(handle)) + } + } + + /// Assert that `range` is valid for this arena. + pub fn check_contains_range(&self, range: &Range) -> Result<(), BadRangeError> { + // Since `range.inner` is a `Range`, we only need to + // check that the start precedes the end, and that the end is + // in range. + if range.inner.start > range.inner.end + || self + .check_contains_handle(Handle::new(range.inner.end.try_into().unwrap())) + .is_err() + { + Err(BadRangeError::new(range.clone())) + } else { + Ok(()) + } + } } #[cfg(feature = "deserialize")] @@ -540,10 +602,18 @@ impl UniqueArena { /// Return this arena's value at `handle`, if that is a valid handle. pub fn get_handle(&self, handle: Handle) -> Result<&T, BadHandle> { - self.set.get_index(handle.index()).ok_or_else(|| BadHandle { - kind: std::any::type_name::(), - index: handle.index(), - }) + self.set + .get_index(handle.index()) + .ok_or_else(|| BadHandle::new(handle)) + } + + /// Assert that `handle` is valid for this arena. + pub fn check_contains_handle(&self, handle: Handle) -> Result<(), BadHandle> { + if handle.index() < self.set.len() { + Ok(()) + } else { + Err(BadHandle::new(handle)) + } } } diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 3feb80bb68..525930e1cf 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -40,12 +40,12 @@ impl crate::TypeInner { } } - pub(super) fn try_size_hlsl( + pub(super) fn size_hlsl( &self, types: &crate::UniqueArena, constants: &crate::Arena, - ) -> Result { - Ok(match *self { + ) -> u32 { + match *self { Self::Matrix { columns, rows, @@ -58,17 +58,16 @@ impl crate::TypeInner { Self::Array { base, size, stride } => { let count = match size { crate::ArraySize::Constant(handle) => { - let constant = constants.try_get(handle)?; - constant.to_array_length().unwrap_or(1) + constants[handle].to_array_length().unwrap_or(1) } // A dynamically-sized array has to have at least one element crate::ArraySize::Dynamic => 1, }; - let last_el_size = types[base].inner.try_size_hlsl(types, constants)?; + let last_el_size = types[base].inner.size_hlsl(types, constants); ((count - 1) * stride) + last_el_size } - _ => self.try_size(constants)?, - }) + _ => self.size(constants), + } } /// Used to generate the name of the wrapped type constructor diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index e1325e5abf..bf841a121c 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -829,10 +829,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } let ty_inner = &module.types[member.ty].inner; - last_offset = member.offset - + ty_inner - .try_size_hlsl(&module.types, &module.constants) - .unwrap(); + last_offset = member.offset + ty_inner.size_hlsl(&module.types, &module.constants); // The indentation is only for readability write!(self.out, "{}", back::INDENT)?; diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index 0c3a00db15..db07f261a4 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -1,4 +1,4 @@ -use crate::arena::{Arena, BadHandle, Handle, UniqueArena}; +use crate::arena::{Arena, Handle, UniqueArena}; use std::{fmt::Display, num::NonZeroU32, ops}; /// A newtype struct where its only valid values are powers of 2 @@ -130,8 +130,6 @@ pub enum LayoutErrorInner { InvalidStructMemberType(u32, Handle), #[error("Type width must be a power of two")] NonPowerOfTwoWidth, - #[error("Array size is a bad handle")] - BadHandle(#[from] BadHandle), } #[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] @@ -175,10 +173,7 @@ impl Layouter { use crate::TypeInner as Ti; for (ty_handle, ty) in types.iter().skip(self.layouts.len()) { - let size = ty - .inner - .try_size(constants) - .map_err(|error| LayoutErrorInner::BadHandle(error).with(ty_handle))?; + let size = ty.inner.size(constants); let layout = match ty.inner { Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => { let alignment = Alignment::new(width as u32) diff --git a/src/proc/mod.rs b/src/proc/mod.rs index a5731de896..c718c33b24 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -97,11 +97,9 @@ impl super::TypeInner { } } - pub fn try_size( - &self, - constants: &super::Arena, - ) -> Result { - Ok(match *self { + /// Get the size of this type. + pub fn size(&self, constants: &super::Arena) -> u32 { + match *self { Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32, Self::Vector { size, @@ -122,8 +120,7 @@ impl super::TypeInner { } => { let count = match size { super::ArraySize::Constant(handle) => { - let constant = constants.try_get(handle)?; - constant.to_array_length().unwrap_or(1) + constants[handle].to_array_length().unwrap_or(1) } // A dynamically-sized array has to have at least one element super::ArraySize::Dynamic => 1, @@ -132,13 +129,7 @@ impl super::TypeInner { } Self::Struct { span, .. } => span, Self::Image { .. } | Self::Sampler { .. } | Self::BindingArray { .. } => 0, - }) - } - - /// Get the size of this type. Panics if the `constants` doesn't contain - /// a referenced handle. This may not happen in a properly validated IR module. - pub fn size(&self, constants: &super::Arena) -> u32 { - self.try_size(constants).unwrap() + } } /// Return the canonical form of `self`, or `None` if it's already in diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 9df538cc2b..47ad05e06c 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -1,4 +1,4 @@ -use crate::arena::{Arena, BadHandle, Handle, UniqueArena}; +use crate::arena::{Arena, Handle, UniqueArena}; use thiserror::Error; @@ -162,8 +162,6 @@ impl crate::ConstantInner { #[derive(Clone, Debug, Error, PartialEq)] pub enum ResolveError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Index {index} is out of bounds for expression {expr:?}")] OutOfBoundsIndex { expr: Handle, @@ -195,8 +193,6 @@ pub enum ResolveError { IncompatibleOperands(String), #[error("Function argument {0} doesn't exist")] FunctionArgumentNotFound(u32), - #[error("Expression {0:?} depends on expressions that follow")] - ExpressionForwardDependency(Handle), } pub struct ResolveContext<'a> { @@ -403,20 +399,15 @@ impl<'a> ResolveContext<'a> { } } } - crate::Expression::Constant(h) => { - let constant = self.constants.try_get(h)?; - match constant.inner { - crate::ConstantInner::Scalar { width, ref value } => { - TypeResolution::Value(Ti::Scalar { - kind: value.scalar_kind(), - width, - }) - } - crate::ConstantInner::Composite { ty, components: _ } => { - TypeResolution::Handle(ty) - } + crate::Expression::Constant(h) => match self.constants[h].inner { + crate::ConstantInner::Scalar { width, ref value } => { + TypeResolution::Value(Ti::Scalar { + kind: value.scalar_kind(), + width, + }) } - } + crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty), + }, crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) { Ti::Scalar { kind, width } => { TypeResolution::Value(Ti::Vector { size, kind, width }) @@ -450,7 +441,7 @@ impl<'a> ResolveContext<'a> { TypeResolution::Handle(arg.ty) } crate::Expression::GlobalVariable(h) => { - let var = self.global_vars.try_get(h)?; + let var = &self.global_vars[h]; if var.space == crate::AddressSpace::Handle { TypeResolution::Handle(var.ty) } else { @@ -461,7 +452,7 @@ impl<'a> ResolveContext<'a> { } } crate::Expression::LocalVariable(h) => { - let var = self.local_vars.try_get(h)?; + let var = &self.local_vars[h]; TypeResolution::Value(Ti::Pointer { base: var.ty, space: crate::AddressSpace::Function, diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index ec8e9d96e5..eb6c1fc4a7 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -10,7 +10,7 @@ use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages, use crate::span::{AddSpan as _, WithSpan}; use crate::{ arena::{Arena, Handle}, - proc::{ResolveContext, ResolveError, TypeResolution}, + proc::{ResolveContext, TypeResolution}, }; use std::ops; @@ -706,12 +706,7 @@ impl FunctionInfo { }, }; - let ty = resolve_context.resolve(expression, |h| { - self.expressions - .get(h.index()) - .map(|ei| &ei.ty) - .ok_or(ResolveError::ExpressionForwardDependency(h)) - })?; + let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; self.expressions[handle.index()] = ExpressionInfo { uniformity, ref_count: 0, diff --git a/src/valid/compose.rs b/src/valid/compose.rs index 6e5c499223..e77d538255 100644 --- a/src/valid/compose.rs +++ b/src/valid/compose.rs @@ -4,13 +4,11 @@ use crate::{ proc::TypeResolution, }; -use crate::arena::{BadHandle, Handle}; +use crate::arena::Handle; #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum ComposeError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Composing of type {0:?} can't be done")] Type(Handle), #[error("Composing expects {expected} components but {given} were given")] @@ -28,8 +26,7 @@ pub fn validate_compose( ) -> Result<(), ComposeError> { use crate::TypeInner as Ti; - let self_ty = type_arena.get_handle(self_ty_handle)?; - match self_ty.inner { + match type_arena[self_ty_handle].inner { // vectors are composed from scalars or other vectors Ti::Vector { size, kind, width } => { let mut total = 0; diff --git a/src/valid/expression.rs b/src/valid/expression.rs index a3afc0535a..14f52fb93c 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "validate")] +use std::ops::Index; + #[cfg(feature = "validate")] use super::{ compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ShaderStages, @@ -7,7 +10,7 @@ use super::{ use crate::arena::UniqueArena; use crate::{ - arena::{BadHandle, Handle}, + arena::Handle, proc::{IndexableLengthError, ResolveError}, }; @@ -18,10 +21,6 @@ pub enum ExpressionError { DoesntExist, #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")] NotInScope, - #[error("Depends on {0:?}, which has not been processed yet")] - ForwardDependency(Handle), - #[error(transparent)] - BadDependency(#[from] BadHandle), #[error("Base type {0:?} is not compatible with this expression")] InvalidBaseType(Handle), #[error("Accessing with index {0:?} can't be done")] @@ -132,15 +131,19 @@ struct ExpressionTypeResolver<'a> { } #[cfg(feature = "validate")] -impl<'a> ExpressionTypeResolver<'a> { - fn resolve( - &self, - handle: Handle, - ) -> Result<&'a crate::TypeInner, ExpressionError> { +impl<'a> Index> for ExpressionTypeResolver<'a> { + type Output = crate::TypeInner; + + #[allow(clippy::panic)] + fn index(&self, handle: Handle) -> &Self::Output { if handle < self.root { - Ok(self.info[handle].ty.inner_with(self.types)) + self.info[handle].ty.inner_with(self.types) } else { - Err(ExpressionError::ForwardDependency(handle)) + // `Validator::validate_module_handles` should have caught this. + panic!( + "Depends on {:?}, which has not been processed yet", + self.root + ) } } } @@ -166,7 +169,7 @@ impl super::Validator { let stages = match *expression { E::Access { base, index } => { - let base_type = resolver.resolve(base)?; + let base_type = &resolver[base]; // See the documentation for `Expression::Access`. let dynamic_indexing_restricted = match *base_type { Ti::Vector { .. } => false, @@ -179,7 +182,7 @@ impl super::Validator { return Err(ExpressionError::InvalidBaseType(base)); } }; - match *resolver.resolve(index)? { + match resolver[index] { //TODO: only allow one of these Ti::Scalar { kind: Sk::Sint | Sk::Uint, @@ -257,7 +260,7 @@ impl super::Validator { Ok(limit) } - let limit = resolve_index_limit(module, base, resolver.resolve(base)?, true)?; + let limit = resolve_index_limit(module, base, &resolver[base], true)?; if index >= limit { return Err(ExpressionError::IndexOutOfBounds( base, @@ -266,11 +269,8 @@ impl super::Validator { } ShaderStages::all() } - E::Constant(handle) => { - let _ = module.constants.try_get(handle)?; - ShaderStages::all() - } - E::Splat { size: _, value } => match *resolver.resolve(value)? { + E::Constant(_handle) => ShaderStages::all(), + E::Splat { size: _, value } => match resolver[value] { Ti::Scalar { .. } => ShaderStages::all(), ref other => { log::error!("Splat scalar type {:?}", other); @@ -282,7 +282,7 @@ impl super::Validator { vector, pattern, } => { - let vec_size = match *resolver.resolve(vector)? { + let vec_size = match resolver[vector] { Ti::Vector { size: vec_size, .. } => vec_size, ref other => { log::error!("Swizzle vector type {:?}", other); @@ -297,11 +297,6 @@ impl super::Validator { ShaderStages::all() } E::Compose { ref components, ty } => { - for &handle in components { - if handle >= root { - return Err(ExpressionError::ForwardDependency(handle)); - } - } validate_compose( ty, &module.constants, @@ -316,16 +311,10 @@ impl super::Validator { } ShaderStages::all() } - E::GlobalVariable(handle) => { - let _ = module.global_variables.try_get(handle)?; - ShaderStages::all() - } - E::LocalVariable(handle) => { - let _ = function.local_variables.try_get(handle)?; - ShaderStages::all() - } + E::GlobalVariable(_handle) => ShaderStages::all(), + E::LocalVariable(_handle) => ShaderStages::all(), E::Load { pointer } => { - match *resolver.resolve(pointer)? { + match resolver[pointer] { Ti::Pointer { base, .. } if self.types[base.index()] .flags @@ -368,7 +357,7 @@ impl super::Validator { return Err(ExpressionError::InvalidImageArrayIndex); } if let Some(expr) = array_index { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Sint, width: _, @@ -408,7 +397,7 @@ impl super::Validator { crate::ImageDimension::D2 => 2, crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3, }; - match *resolver.resolve(coordinate)? { + match resolver[coordinate] { Ti::Scalar { kind: Sk::Float, .. } if num_components == 1 => {} @@ -446,7 +435,7 @@ impl super::Validator { // check depth reference type if let Some(expr) = depth_ref { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } => {} @@ -483,7 +472,7 @@ impl super::Validator { crate::SampleLevel::Auto => ShaderStages::FRAGMENT, crate::SampleLevel::Zero => ShaderStages::all(), crate::SampleLevel::Exact(expr) => { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } => {} @@ -492,7 +481,7 @@ impl super::Validator { ShaderStages::all() } crate::SampleLevel::Bias(expr) => { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } => {} @@ -501,7 +490,7 @@ impl super::Validator { ShaderStages::all() } crate::SampleLevel::Gradient { x, y } => { - match *resolver.resolve(x)? { + match resolver[x] { Ti::Scalar { kind: Sk::Float, .. } if num_components == 1 => {} @@ -514,7 +503,7 @@ impl super::Validator { return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x)) } } - match *resolver.resolve(y)? { + match resolver[y] { Ti::Scalar { kind: Sk::Float, .. } if num_components == 1 => {} @@ -545,7 +534,7 @@ impl super::Validator { arrayed, dim, } => { - match resolver.resolve(coordinate)?.image_storage_coordinates() { + match resolver[coordinate].image_storage_coordinates() { Some(coord_dim) if coord_dim == dim => {} _ => { return Err(ExpressionError::InvalidImageCoordinateType( @@ -557,7 +546,7 @@ impl super::Validator { return Err(ExpressionError::InvalidImageArrayIndex); } if let Some(expr) = array_index { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Sint, width: _, @@ -569,7 +558,7 @@ impl super::Validator { match (sample, class.is_multisampled()) { (None, false) => {} (Some(sample), true) => { - if resolver.resolve(sample)?.scalar_kind() != Some(Sk::Sint) { + if resolver[sample].scalar_kind() != Some(Sk::Sint) { return Err(ExpressionError::InvalidImageOtherIndexType( sample, )); @@ -583,7 +572,7 @@ impl super::Validator { match (level, class.is_mipmapped()) { (None, false) => {} (Some(level), true) => { - if resolver.resolve(level)?.scalar_kind() != Some(Sk::Sint) { + if resolver[level].scalar_kind() != Some(Sk::Sint) { return Err(ExpressionError::InvalidImageOtherIndexType(level)); } } @@ -617,7 +606,7 @@ impl super::Validator { } E::Unary { op, expr } => { use crate::UnaryOperator as Uo; - let inner = resolver.resolve(expr)?; + let inner = &resolver[expr]; match (op, inner.scalar_kind()) { (_, Some(Sk::Sint | Sk::Bool)) //TODO: restrict Negate for bools? @@ -632,8 +621,8 @@ impl super::Validator { } E::Binary { op, left, right } => { use crate::BinaryOperator as Bo; - let left_inner = resolver.resolve(left)?; - let right_inner = resolver.resolve(right)?; + let left_inner = &resolver[left]; + let right_inner = &resolver[right]; let good = match op { Bo::Add | Bo::Subtract => match *left_inner { Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { @@ -814,9 +803,9 @@ impl super::Validator { accept, reject, } => { - let accept_inner = resolver.resolve(accept)?; - let reject_inner = resolver.resolve(reject)?; - let condition_good = match *resolver.resolve(condition)? { + let accept_inner = &resolver[accept]; + let reject_inner = &resolver[reject]; + let condition_good = match resolver[condition] { Ti::Scalar { kind: Sk::Bool, width: _, @@ -846,7 +835,7 @@ impl super::Validator { ShaderStages::all() } E::Derivative { axis: _, expr } => { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } @@ -859,7 +848,7 @@ impl super::Validator { } E::Relational { fun, argument } => { use crate::RelationalFunction as Rf; - let argument_inner = resolver.resolve(argument)?; + let argument_inner = &resolver[argument]; match fun { Rf::All | Rf::Any => match *argument_inner { Ti::Vector { kind: Sk::Bool, .. } => {} @@ -892,10 +881,11 @@ impl super::Validator { } => { use crate::MathFunction as Mf; - let arg_ty = resolver.resolve(arg)?; - let arg1_ty = arg1.map(|expr| resolver.resolve(expr)).transpose()?; - let arg2_ty = arg2.map(|expr| resolver.resolve(expr)).transpose()?; - let arg3_ty = arg3.map(|expr| resolver.resolve(expr)).transpose()?; + let resolve = |arg| &resolver[arg]; + let arg_ty = resolve(arg); + let arg1_ty = arg1.map(resolve); + let arg2_ty = arg2.map(resolve); + let arg3_ty = arg3.map(resolve); match fun { Mf::Abs => { if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { @@ -1379,7 +1369,7 @@ impl super::Validator { kind, convert, } => { - let base_width = match *resolver.resolve(expr)? { + let base_width = match resolver[expr] { crate::TypeInner::Scalar { width, .. } | crate::TypeInner::Vector { width, .. } | crate::TypeInner::Matrix { width, .. } => width, @@ -1416,9 +1406,9 @@ impl super::Validator { } ShaderStages::all() } - E::ArrayLength(expr) => match *resolver.resolve(expr)? { + E::ArrayLength(expr) => match resolver[expr] { Ti::Pointer { base, .. } => { - let base_ty = resolver.types.get_handle(base)?; + let base_ty = &resolver.types[base]; if let Ti::Array { size: crate::ArraySize::Dynamic, .. diff --git a/src/valid/function.rs b/src/valid/function.rs index 0f0a7b89f5..3c555491c3 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -1,6 +1,6 @@ +use crate::arena::Handle; #[cfg(feature = "validate")] use crate::arena::{Arena, UniqueArena}; -use crate::arena::{BadHandle, Handle}; #[cfg(feature = "validate")] use super::validate_atomic_compare_exchange_struct; @@ -19,8 +19,6 @@ use bit_set::BitSet; #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum CallError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("The callee is declared after the caller")] ForwardDeclaredFunction, #[error("Argument {index} expression is invalid")] @@ -69,8 +67,6 @@ pub enum LocalVariableError { #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum FunctionError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Expression {handle:?} is invalid")] Expression { handle: Handle, @@ -203,11 +199,8 @@ impl<'a> BlockContext<'a> { BlockContext { abilities, ..*self } } - fn get_expression( - &self, - handle: Handle, - ) -> Result<&'a crate::Expression, FunctionError> { - Ok(self.expressions.try_get(handle)?) + fn get_expression(&self, handle: Handle) -> &'a crate::Expression { + &self.expressions[handle] } fn resolve_type_impl( @@ -257,11 +250,7 @@ impl super::Validator { result: Option>, context: &BlockContext, ) -> Result> { - let fun = context - .functions - .try_get(function) - .map_err(CallError::BadHandle) - .map_err(WithSpan::new)?; + let fun = &context.functions[function]; if fun.arguments.len() != arguments.len() { return Err(CallError::ArgumentCount { required: fun.arguments.len(), @@ -689,14 +678,14 @@ impl super::Validator { } => { //Note: this code uses a lot of `FunctionError::InvalidImageStore`, // and could probably be refactored. - let var = match *context.get_expression(image).map_err(|e| e.with_span())? { + let var = match *context.get_expression(image) { crate::Expression::GlobalVariable(var_handle) => { &context.global_vars[var_handle] } // We're looking at a binding index situation, so punch through the index and look at the global behind it. crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => { - match *context.get_expression(base).map_err(|e| e.with_span())? { + match *context.get_expression(base) { crate::Expression::GlobalVariable(var_handle) => { &context.global_vars[var_handle] } @@ -899,10 +888,7 @@ impl super::Validator { #[cfg(feature = "validate")] for (index, argument) in fun.arguments.iter().enumerate() { - let ty = module.types.get_handle(argument.ty).map_err(|err| { - FunctionError::from(err).with_span_handle(argument.ty, &module.types) - })?; - match ty.inner.pointer_space() { + match module.types[argument.ty].inner.pointer_space() { Some( crate::AddressSpace::Private | crate::AddressSpace::Function diff --git a/src/valid/handles.rs b/src/valid/handles.rs new file mode 100644 index 0000000000..5b3375a873 --- /dev/null +++ b/src/valid/handles.rs @@ -0,0 +1,616 @@ +//! Implementation of [`super::Validator::validate_module_handles`]. + +use crate::{ + arena::{BadHandle, BadRangeError}, + Handle, +}; + +#[cfg(feature = "validate")] +use crate::{Arena, UniqueArena}; + +#[cfg(feature = "validate")] +use super::{TypeError, ValidationError}; + +#[cfg(feature = "validate")] +use std::{convert::TryInto, hash::Hash, num::NonZeroU32}; + +#[cfg(feature = "validate")] +impl super::Validator { + /// Validates that all handles within `module` are: + /// + /// * Valid, in the sense that they contain indices within each arena structure inside the + /// [`crate::Module`] type. + /// * No arena contents contain any items that have forward dependencies; that is, the value + /// associated with a handle only may contain references to handles in the same arena that + /// were constructed before it. + /// + /// By validating the above conditions, we free up subsequent logic to assume that handle + /// accesses are infallible. + /// + /// # Errors + /// + /// Errors returned by this method are intentionally sparse, for simplicity of implementation. + /// It is expected that only buggy frontends or fuzzers should ever emit IR that fails this + /// validation pass. + pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> { + let &crate::Module { + ref constants, + ref entry_points, + ref functions, + ref global_variables, + ref types, + } = module; + + // NOTE: Types being first is important. All other forms of validation depend on this. + for (this_handle, ty) in types.iter() { + let &crate::Type { + ref name, + ref inner, + } = ty; + + let validate_array_size = |size| { + match size { + crate::ArraySize::Constant(constant) => { + let &crate::Constant { + name: _, + specialization: _, + ref inner, + } = constants.try_get(constant)?; + if !matches!(inner, &crate::ConstantInner::Scalar { .. }) { + return Err(ValidationError::Type { + handle: this_handle, + name: name.clone().unwrap_or_default(), + source: TypeError::InvalidArraySizeConstant(constant), + }); + } + } + crate::ArraySize::Dynamic => (), + }; + Ok(this_handle) + }; + + match *inner { + crate::TypeInner::Scalar { .. } + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } + | crate::TypeInner::ValuePointer { .. } + | crate::TypeInner::Atomic { .. } + | crate::TypeInner::Image { .. } + | crate::TypeInner::Sampler { .. } => (), + crate::TypeInner::Pointer { base, space: _ } => { + this_handle.check_dep(base)?; + } + crate::TypeInner::Array { + base, + size, + stride: _, + } + | crate::TypeInner::BindingArray { base, size } => { + this_handle.check_dep(base)?; + validate_array_size(size)?; + } + crate::TypeInner::Struct { + ref members, + span: _, + } => { + this_handle.check_dep_iter(members.iter().map(|m| m.ty))?; + } + } + } + + let validate_type = |handle| Self::validate_type_handle(handle, types); + + for (this_handle, constant) in constants.iter() { + let &crate::Constant { + name: _, + specialization: _, + ref inner, + } = constant; + match *inner { + crate::ConstantInner::Scalar { .. } => (), + crate::ConstantInner::Composite { ty, ref components } => { + validate_type(ty)?; + this_handle.check_dep_iter(components.iter().copied())?; + } + } + } + + let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + + for (_handle, global_variable) in global_variables.iter() { + let &crate::GlobalVariable { + name: _, + space: _, + binding: _, + ty, + init, + } = global_variable; + validate_type(ty)?; + if let Some(init_expr) = init { + validate_constant(init_expr)?; + } + } + + let validate_function = |function: &_| -> Result<_, InvalidHandleError> { + let &crate::Function { + name: _, + ref arguments, + ref result, + ref local_variables, + ref expressions, + ref named_expressions, + ref body, + } = function; + + for arg in arguments.iter() { + let &crate::FunctionArgument { + name: _, + ty, + binding: _, + } = arg; + validate_type(ty)?; + } + + if let &Some(crate::FunctionResult { ty, binding: _ }) = result { + validate_type(ty)?; + } + + for (_handle, local_variable) in local_variables.iter() { + let &crate::LocalVariable { name: _, ty, init } = local_variable; + validate_type(ty)?; + if let Some(init_constant) = init { + validate_constant(init_constant)?; + } + } + + for handle in named_expressions.keys().copied() { + Self::validate_expression_handle(handle, expressions)?; + } + + for handle_and_expr in expressions.iter() { + Self::validate_expression_handles( + handle_and_expr, + constants, + types, + local_variables, + global_variables, + functions, + )?; + } + + Self::validate_block_handles(body, expressions, functions)?; + + Ok(()) + }; + + for entry_point in entry_points.iter() { + validate_function(&entry_point.function)?; + } + + for (_function_handle, function) in functions.iter() { + validate_function(function)?; + } + + Ok(()) + } + + fn validate_type_handle( + handle: Handle, + types: &UniqueArena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for_uniq(types).map(|_| ()) + } + + fn validate_constant_handle( + handle: Handle, + constants: &Arena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(constants).map(|_| ()) + } + + fn validate_expression_handle( + handle: Handle, + expressions: &Arena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(expressions).map(|_| ()) + } + + fn validate_function_handle( + handle: Handle, + functions: &Arena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(functions).map(|_| ()) + } + + fn validate_expression_handles( + (handle, expression): (Handle, &crate::Expression), + constants: &Arena, + types: &UniqueArena, + local_variables: &Arena, + global_variables: &Arena, + functions: &Arena, + ) -> Result<(), InvalidHandleError> { + let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_type = |handle| Self::validate_type_handle(handle, types); + + match *expression { + crate::Expression::Access { base, index } => { + handle.check_dep(base)?.check_dep(index)?; + } + crate::Expression::AccessIndex { base, .. } => { + handle.check_dep(base)?; + } + crate::Expression::Constant(constant) => { + validate_constant(constant)?; + } + crate::Expression::Splat { value, .. } => { + handle.check_dep(value)?; + } + crate::Expression::Swizzle { vector, .. } => { + handle.check_dep(vector)?; + } + crate::Expression::Compose { ty, ref components } => { + validate_type(ty)?; + handle.check_dep_iter(components.iter().copied())?; + } + crate::Expression::FunctionArgument(_arg_idx) => (), + crate::Expression::GlobalVariable(global_variable) => { + global_variable.check_valid_for(global_variables)?; + } + crate::Expression::LocalVariable(local_variable) => { + local_variable.check_valid_for(local_variables)?; + } + crate::Expression::Load { pointer } => { + handle.check_dep(pointer)?; + } + crate::Expression::ImageSample { + image, + sampler, + gather: _, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + if let Some(offset) = offset { + validate_constant(offset)?; + } + + handle + .check_dep(image)? + .check_dep(sampler)? + .check_dep(coordinate)? + .check_dep_opt(array_index)?; + + match level { + crate::SampleLevel::Auto | crate::SampleLevel::Zero => (), + crate::SampleLevel::Exact(expr) => { + handle.check_dep(expr)?; + } + crate::SampleLevel::Bias(expr) => { + handle.check_dep(expr)?; + } + crate::SampleLevel::Gradient { x, y } => { + handle.check_dep(x)?.check_dep(y)?; + } + }; + + handle.check_dep_opt(depth_ref)?; + } + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + handle + .check_dep(image)? + .check_dep(coordinate)? + .check_dep_opt(array_index)? + .check_dep_opt(sample)? + .check_dep_opt(level)?; + } + crate::Expression::ImageQuery { image, query } => { + handle.check_dep(image)?; + match query { + crate::ImageQuery::Size { level } => { + handle.check_dep_opt(level)?; + } + crate::ImageQuery::NumLevels + | crate::ImageQuery::NumLayers + | crate::ImageQuery::NumSamples => (), + }; + } + crate::Expression::Unary { + op: _, + expr: operand, + } => { + handle.check_dep(operand)?; + } + crate::Expression::Binary { op: _, left, right } => { + handle.check_dep(left)?.check_dep(right)?; + } + crate::Expression::Select { + condition, + accept, + reject, + } => { + handle + .check_dep(condition)? + .check_dep(accept)? + .check_dep(reject)?; + } + crate::Expression::Derivative { + axis: _, + expr: argument, + } => { + handle.check_dep(argument)?; + } + crate::Expression::Relational { fun: _, argument } => { + handle.check_dep(argument)?; + } + crate::Expression::Math { + fun: _, + arg, + arg1, + arg2, + arg3, + } => { + handle + .check_dep(arg)? + .check_dep_opt(arg1)? + .check_dep_opt(arg2)? + .check_dep_opt(arg3)?; + } + crate::Expression::As { + expr: input, + kind: _, + convert: _, + } => { + handle.check_dep(input)?; + } + crate::Expression::CallResult(function) => { + Self::validate_function_handle(function, functions)?; + } + crate::Expression::AtomicResult { .. } => (), + crate::Expression::ArrayLength(array) => { + handle.check_dep(array)?; + } + } + Ok(()) + } + + fn validate_block_handles( + block: &crate::Block, + expressions: &Arena, + functions: &Arena, + ) -> Result<(), InvalidHandleError> { + let validate_block = |block| Self::validate_block_handles(block, expressions, functions); + let validate_expr = |handle| Self::validate_expression_handle(handle, expressions); + let validate_expr_opt = |handle_opt| { + if let Some(handle) = handle_opt { + validate_expr(handle)?; + } + Ok(()) + }; + + block.iter().try_for_each(|stmt| match *stmt { + crate::Statement::Emit(ref expr_range) => { + expr_range.check_valid_for(expressions)?; + Ok(()) + } + crate::Statement::Block(ref block) => { + validate_block(block)?; + Ok(()) + } + crate::Statement::If { + condition, + ref accept, + ref reject, + } => { + validate_expr(condition)?; + validate_block(accept)?; + validate_block(reject)?; + Ok(()) + } + crate::Statement::Switch { + selector, + ref cases, + } => { + validate_expr(selector)?; + for &crate::SwitchCase { + value: _, + ref body, + fall_through: _, + } in cases + { + validate_block(body)?; + } + Ok(()) + } + crate::Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + validate_block(body)?; + validate_block(continuing)?; + validate_expr_opt(break_if)?; + Ok(()) + } + crate::Statement::Return { value } => validate_expr_opt(value), + crate::Statement::Store { pointer, value } => { + validate_expr(pointer)?; + validate_expr(value)?; + Ok(()) + } + crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + validate_expr(image)?; + validate_expr(coordinate)?; + validate_expr_opt(array_index)?; + validate_expr(value)?; + Ok(()) + } + crate::Statement::Atomic { + pointer, + fun, + value, + result, + } => { + validate_expr(pointer)?; + match fun { + crate::AtomicFunction::Add + | crate::AtomicFunction::Subtract + | crate::AtomicFunction::And + | crate::AtomicFunction::ExclusiveOr + | crate::AtomicFunction::InclusiveOr + | crate::AtomicFunction::Min + | crate::AtomicFunction::Max => (), + crate::AtomicFunction::Exchange { compare } => validate_expr_opt(compare)?, + }; + validate_expr(value)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::Call { + function, + ref arguments, + result, + } => { + Self::validate_function_handle(function, functions)?; + for arg in arguments.iter().copied() { + validate_expr(arg)?; + } + validate_expr_opt(result)?; + Ok(()) + } + crate::Statement::Break + | crate::Statement::Continue + | crate::Statement::Kill + | crate::Statement::Barrier(_) => Ok(()), + }) + } +} + +#[cfg(feature = "validate")] +impl From for ValidationError { + fn from(source: BadHandle) -> Self { + Self::InvalidHandle(source.into()) + } +} + +#[cfg(feature = "validate")] +impl From for ValidationError { + fn from(source: FwdDepError) -> Self { + Self::InvalidHandle(source.into()) + } +} + +#[cfg(feature = "validate")] +impl From for ValidationError { + fn from(source: BadRangeError) -> Self { + Self::InvalidHandle(source.into()) + } +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum InvalidHandleError { + #[error(transparent)] + BadHandle(#[from] BadHandle), + #[error(transparent)] + ForwardDependency(#[from] FwdDepError), + #[error(transparent)] + BadRange(#[from] BadRangeError), +} + +#[derive(Clone, Debug, thiserror::Error)] +#[error( + "{subject:?} of kind depends on {depends_on:?} of kind {depends_on_kind}, which has not been \ + processed yet" +)] +pub struct FwdDepError { + // This error is used for many `Handle` types, but there's no point in making this generic, so + // we just flatten them all to `Handle<()>` here. + subject: Handle<()>, + subject_kind: &'static str, + depends_on: Handle<()>, + depends_on_kind: &'static str, +} + +#[cfg(feature = "validate")] +impl Handle { + /// Check that `self` is valid within `arena` using [`Arena::check_contains_handle`]. + pub(self) fn check_valid_for(self, arena: &Arena) -> Result<(), InvalidHandleError> { + arena.check_contains_handle(self)?; + Ok(()) + } + + /// Check that `self` is valid within `arena` using [`UniqueArena::check_contains_handle`]. + pub(self) fn check_valid_for_uniq( + self, + arena: &UniqueArena, + ) -> Result<(), InvalidHandleError> + where + T: Eq + Hash, + { + arena.check_contains_handle(self)?; + Ok(()) + } + + /// Check that `depends_on` was constructed before `self` by comparing handle indices. + /// + /// If `self` is a valid handle (i.e., it has been validated using [`Self::check_valid_for`]) + /// and this function returns [`Ok`], then it may be assumed that `depends_on` is also valid. + /// In [`naga`](crate)'s current arena-based implementation, this is useful for validating + /// recursive definitions of arena-based values in linear time. + /// + /// # Errors + /// + /// If `depends_on`'s handle is from the same [`Arena`] as `self'`s, but not constructed earlier + /// than `self`'s, this function returns an error. + pub(self) fn check_dep(self, depends_on: Self) -> Result { + if depends_on < self { + Ok(self) + } else { + let erase_handle_type = |handle: Handle<_>| { + Handle::new(NonZeroU32::new(handle.index().try_into().unwrap()).unwrap()) + }; + Err(FwdDepError { + subject: erase_handle_type(self), + subject_kind: std::any::type_name::(), + depends_on: erase_handle_type(depends_on), + depends_on_kind: std::any::type_name::(), + }) + } + } + + /// Like [`Self::check_dep`], except for [`Option`]al handle values. + pub(self) fn check_dep_opt(self, depends_on: Option) -> Result { + self.check_dep_iter(depends_on.into_iter()) + } + + /// Like [`Self::check_dep`], except for [`Iterator`]s over handle values. + pub(self) fn check_dep_iter( + self, + depends_on: impl Iterator, + ) -> Result { + for handle in depends_on { + self.check_dep(handle)?; + } + Ok(self) + } +} + +#[cfg(feature = "validate")] +impl crate::arena::Range { + pub(self) fn check_valid_for(&self, arena: &Arena) -> Result<(), BadRangeError> { + arena.check_contains_range(self) + } +} diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 85610b068e..289a068f75 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -2,7 +2,7 @@ use super::{ analyzer::{FunctionInfo, GlobalUse}, Capabilities, Disalignment, FunctionError, ModuleInfo, }; -use crate::arena::{BadHandle, Handle, UniqueArena}; +use crate::arena::{Handle, UniqueArena}; use crate::span::{AddSpan as _, MapErrWithSpan as _, SpanProvider as _, WithSpan}; use bit_set::BitSet; @@ -12,8 +12,6 @@ const MAX_WORKGROUP_SIZE: u32 = 0x4000; #[derive(Clone, Debug, thiserror::Error)] pub enum GlobalVariableError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Usage isn't compatible with address space {0:?}")] InvalidUsage(crate::AddressSpace), #[error("Type isn't compatible with address space {0:?}")] @@ -380,10 +378,7 @@ impl super::Validator { use super::TypeFlags; log::debug!("var {:?}", var); - let type_info = self.types.get(var.ty.index()).ok_or_else(|| BadHandle { - kind: "type", - index: var.ty.index(), - })?; + let type_info = &self.types[var.ty.index()]; let (required_type_flags, is_resource) = match var.space { crate::AddressSpace::Function => { diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 255d6f428e..4e62a2ca78 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -6,6 +6,7 @@ mod analyzer; mod compose; mod expression; mod function; +mod handles; mod interface; mod r#type; @@ -13,7 +14,7 @@ mod r#type; use crate::arena::{Arena, UniqueArena}; use crate::{ - arena::{BadHandle, Handle}, + arena::Handle, proc::{LayoutError, Layouter}, FastHashSet, }; @@ -31,6 +32,8 @@ pub use function::{CallError, FunctionError, LocalVariableError}; pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; pub use r#type::{Disalignment, TypeError, TypeFlags}; +use self::handles::InvalidHandleError; + bitflags::bitflags! { /// Validation flags. /// @@ -146,8 +149,6 @@ pub struct Validator { #[derive(Clone, Debug, thiserror::Error)] pub enum ConstantError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("The type doesn't match the constant")] InvalidType, #[error("The component handle {0:?} can not be resolved")] @@ -160,6 +161,8 @@ pub enum ConstantError { #[derive(Clone, Debug, thiserror::Error)] pub enum ValidationError { + #[error(transparent)] + InvalidHandle(#[from] InvalidHandleError), #[error(transparent)] Layouter(#[from] LayoutError), #[error("Type {handle:?} '{name}' is invalid")] @@ -283,7 +286,7 @@ impl Validator { } } crate::ConstantInner::Composite { ty, ref components } => { - match types.get_handle(ty)?.inner { + match types[ty].inner { crate::TypeInner::Array { size: crate::ArraySize::Constant(size_handle), .. @@ -316,6 +319,9 @@ impl Validator { self.reset(); self.reset_types(module.types.len()); + #[cfg(feature = "validate")] + Self::validate_module_handles(module).map_err(|e| e.with_span())?; + self.layouter .update(&module.types, &module.constants) .map_err(|e| { diff --git a/src/valid/type.rs b/src/valid/type.rs index f103017dd9..172f110724 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -1,6 +1,6 @@ use super::Capabilities; use crate::{ - arena::{Arena, BadHandle, Handle, UniqueArena}, + arena::{Arena, Handle, UniqueArena}, proc::Alignment, }; @@ -88,8 +88,6 @@ pub enum Disalignment { #[derive(Clone, Debug, thiserror::Error)] pub enum TypeError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("The {0:?} scalar width {1} is not supported")] InvalidWidth(crate::ScalarKind, crate::Bytes), #[error("The {0:?} scalar width {1} is not supported for an atomic")] @@ -418,7 +416,7 @@ impl super::Validator { let sized_flag = match size { crate::ArraySize::Constant(const_handle) => { - let constant = constants.try_get(const_handle)?; + let constant = &constants[const_handle]; let length_is_positive = match *constant { crate::Constant { specialization: Some(_),