Skip to content

Commit

Permalink
Add handle validation pass to Validator (#2090)
Browse files Browse the repository at this point in the history
Before proceeding with any other validation, check that all Handles are valid for their arenas, and refer only to older handles than themselves. This allows subsequent stages to simply use indexing without panics, assuming validation has passed.
  • Loading branch information
ErichDonGubler authored Dec 17, 2022
1 parent 420c984 commit 461fdda
Show file tree
Hide file tree
Showing 14 changed files with 796 additions and 170 deletions.
86 changes: 78 additions & 8 deletions src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ pub struct BadHandle {
pub index: usize,
}

impl BadHandle {
fn new<T>(handle: Handle<T>) -> Self {
Self {
kind: std::any::type_name::<T>(),
index: handle.index(),
}
}
}

/// A strongly typed reference to an arena item.
///
/// A `Handle` value can be used as an index into an [`Arena`] or [`UniqueArena`].
Expand Down Expand Up @@ -123,6 +132,35 @@ pub struct Range<T> {
marker: PhantomData<T>,
}

impl<T> Range<T> {
pub(crate) const fn erase_type(self) -> Range<()> {
let Self { inner, marker: _ } = self;
Range {
inner,
marker: PhantomData,
}
}
}

// NOTE: Keep this diagnostic in sync with that of [`BadHandle`].
#[derive(Clone, Debug, thiserror::Error)]
#[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")]
pub struct BadRangeError {
// This error is used for many `Handle` types, but there's no point in making this generic, so
// we just flatten them all to `Handle<()>` here.
kind: &'static str,
range: Range<()>,
}

impl BadRangeError {
pub fn new<T>(range: Range<T>) -> Self {
Self {
kind: std::any::type_name::<T>(),
range: range.erase_type(),
}
}
}

impl<T> Clone for Range<T> {
fn clone(&self) -> Self {
Range {
Expand Down Expand Up @@ -282,10 +320,9 @@ impl<T> Arena<T> {
}

pub fn try_get(&self, handle: Handle<T>) -> Result<&T, BadHandle> {
self.data.get(handle.index()).ok_or_else(|| BadHandle {
kind: std::any::type_name::<T>(),
index: handle.index(),
})
self.data
.get(handle.index())
.ok_or_else(|| BadHandle::new(handle))
}

/// Get a mutable reference to an element in the arena.
Expand Down Expand Up @@ -320,6 +357,31 @@ impl<T> Arena<T> {
Span::default()
}
}

/// Assert that `handle` is valid for this arena.
pub fn check_contains_handle(&self, handle: Handle<T>) -> Result<(), BadHandle> {
if handle.index() < self.data.len() {
Ok(())
} else {
Err(BadHandle::new(handle))
}
}

/// Assert that `range` is valid for this arena.
pub fn check_contains_range(&self, range: &Range<T>) -> Result<(), BadRangeError> {
// Since `range.inner` is a `Range<u32>`, we only need to
// check that the start precedes the end, and that the end is
// in range.
if range.inner.start > range.inner.end
|| self
.check_contains_handle(Handle::new(range.inner.end.try_into().unwrap()))
.is_err()
{
Err(BadRangeError::new(range.clone()))
} else {
Ok(())
}
}
}

#[cfg(feature = "deserialize")]
Expand Down Expand Up @@ -540,10 +602,18 @@ impl<T: Eq + hash::Hash> UniqueArena<T> {

/// Return this arena's value at `handle`, if that is a valid handle.
pub fn get_handle(&self, handle: Handle<T>) -> Result<&T, BadHandle> {
self.set.get_index(handle.index()).ok_or_else(|| BadHandle {
kind: std::any::type_name::<T>(),
index: handle.index(),
})
self.set
.get_index(handle.index())
.ok_or_else(|| BadHandle::new(handle))
}

/// Assert that `handle` is valid for this arena.
pub fn check_contains_handle(&self, handle: Handle<T>) -> Result<(), BadHandle> {
if handle.index() < self.set.len() {
Ok(())
} else {
Err(BadHandle::new(handle))
}
}
}

Expand Down
15 changes: 7 additions & 8 deletions src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ impl crate::TypeInner {
}
}

pub(super) fn try_size_hlsl(
pub(super) fn size_hlsl(
&self,
types: &crate::UniqueArena<crate::Type>,
constants: &crate::Arena<crate::Constant>,
) -> Result<u32, crate::arena::BadHandle> {
Ok(match *self {
) -> u32 {
match *self {
Self::Matrix {
columns,
rows,
Expand All @@ -58,17 +58,16 @@ impl crate::TypeInner {
Self::Array { base, size, stride } => {
let count = match size {
crate::ArraySize::Constant(handle) => {
let constant = constants.try_get(handle)?;
constant.to_array_length().unwrap_or(1)
constants[handle].to_array_length().unwrap_or(1)
}
// A dynamically-sized array has to have at least one element
crate::ArraySize::Dynamic => 1,
};
let last_el_size = types[base].inner.try_size_hlsl(types, constants)?;
let last_el_size = types[base].inner.size_hlsl(types, constants);
((count - 1) * stride) + last_el_size
}
_ => self.try_size(constants)?,
})
_ => self.size(constants),
}
}

/// Used to generate the name of the wrapped type constructor
Expand Down
5 changes: 1 addition & 4 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -829,10 +829,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
let ty_inner = &module.types[member.ty].inner;
last_offset = member.offset
+ ty_inner
.try_size_hlsl(&module.types, &module.constants)
.unwrap();
last_offset = member.offset + ty_inner.size_hlsl(&module.types, &module.constants);

// The indentation is only for readability
write!(self.out, "{}", back::INDENT)?;
Expand Down
9 changes: 2 additions & 7 deletions src/proc/layouter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::arena::{Arena, BadHandle, Handle, UniqueArena};
use crate::arena::{Arena, Handle, UniqueArena};
use std::{fmt::Display, num::NonZeroU32, ops};

/// A newtype struct where its only valid values are powers of 2
Expand Down Expand Up @@ -130,8 +130,6 @@ pub enum LayoutErrorInner {
InvalidStructMemberType(u32, Handle<crate::Type>),
#[error("Type width must be a power of two")]
NonPowerOfTwoWidth,
#[error("Array size is a bad handle")]
BadHandle(#[from] BadHandle),
}

#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
Expand Down Expand Up @@ -175,10 +173,7 @@ impl Layouter {
use crate::TypeInner as Ti;

for (ty_handle, ty) in types.iter().skip(self.layouts.len()) {
let size = ty
.inner
.try_size(constants)
.map_err(|error| LayoutErrorInner::BadHandle(error).with(ty_handle))?;
let size = ty.inner.size(constants);
let layout = match ty.inner {
Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => {
let alignment = Alignment::new(width as u32)
Expand Down
19 changes: 5 additions & 14 deletions src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,9 @@ impl super::TypeInner {
}
}

pub fn try_size(
&self,
constants: &super::Arena<super::Constant>,
) -> Result<u32, crate::arena::BadHandle> {
Ok(match *self {
/// Get the size of this type.
pub fn size(&self, constants: &super::Arena<super::Constant>) -> u32 {
match *self {
Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32,
Self::Vector {
size,
Expand All @@ -122,8 +120,7 @@ impl super::TypeInner {
} => {
let count = match size {
super::ArraySize::Constant(handle) => {
let constant = constants.try_get(handle)?;
constant.to_array_length().unwrap_or(1)
constants[handle].to_array_length().unwrap_or(1)
}
// A dynamically-sized array has to have at least one element
super::ArraySize::Dynamic => 1,
Expand All @@ -132,13 +129,7 @@ impl super::TypeInner {
}
Self::Struct { span, .. } => span,
Self::Image { .. } | Self::Sampler { .. } | Self::BindingArray { .. } => 0,
})
}

/// Get the size of this type. Panics if the `constants` doesn't contain
/// a referenced handle. This may not happen in a properly validated IR module.
pub fn size(&self, constants: &super::Arena<super::Constant>) -> u32 {
self.try_size(constants).unwrap()
}
}

/// Return the canonical form of `self`, or `None` if it's already in
Expand Down
31 changes: 11 additions & 20 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::arena::{Arena, BadHandle, Handle, UniqueArena};
use crate::arena::{Arena, Handle, UniqueArena};

use thiserror::Error;

Expand Down Expand Up @@ -162,8 +162,6 @@ impl crate::ConstantInner {

#[derive(Clone, Debug, Error, PartialEq)]
pub enum ResolveError {
#[error(transparent)]
BadHandle(#[from] BadHandle),
#[error("Index {index} is out of bounds for expression {expr:?}")]
OutOfBoundsIndex {
expr: Handle<crate::Expression>,
Expand Down Expand Up @@ -195,8 +193,6 @@ pub enum ResolveError {
IncompatibleOperands(String),
#[error("Function argument {0} doesn't exist")]
FunctionArgumentNotFound(u32),
#[error("Expression {0:?} depends on expressions that follow")]
ExpressionForwardDependency(Handle<crate::Expression>),
}

pub struct ResolveContext<'a> {
Expand Down Expand Up @@ -403,20 +399,15 @@ impl<'a> ResolveContext<'a> {
}
}
}
crate::Expression::Constant(h) => {
let constant = self.constants.try_get(h)?;
match constant.inner {
crate::ConstantInner::Scalar { width, ref value } => {
TypeResolution::Value(Ti::Scalar {
kind: value.scalar_kind(),
width,
})
}
crate::ConstantInner::Composite { ty, components: _ } => {
TypeResolution::Handle(ty)
}
crate::Expression::Constant(h) => match self.constants[h].inner {
crate::ConstantInner::Scalar { width, ref value } => {
TypeResolution::Value(Ti::Scalar {
kind: value.scalar_kind(),
width,
})
}
}
crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty),
},
crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) {
Ti::Scalar { kind, width } => {
TypeResolution::Value(Ti::Vector { size, kind, width })
Expand Down Expand Up @@ -450,7 +441,7 @@ impl<'a> ResolveContext<'a> {
TypeResolution::Handle(arg.ty)
}
crate::Expression::GlobalVariable(h) => {
let var = self.global_vars.try_get(h)?;
let var = &self.global_vars[h];
if var.space == crate::AddressSpace::Handle {
TypeResolution::Handle(var.ty)
} else {
Expand All @@ -461,7 +452,7 @@ impl<'a> ResolveContext<'a> {
}
}
crate::Expression::LocalVariable(h) => {
let var = self.local_vars.try_get(h)?;
let var = &self.local_vars[h];
TypeResolution::Value(Ti::Pointer {
base: var.ty,
space: crate::AddressSpace::Function,
Expand Down
9 changes: 2 additions & 7 deletions src/valid/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages,
use crate::span::{AddSpan as _, WithSpan};
use crate::{
arena::{Arena, Handle},
proc::{ResolveContext, ResolveError, TypeResolution},
proc::{ResolveContext, TypeResolution},
};
use std::ops;

Expand Down Expand Up @@ -706,12 +706,7 @@ impl FunctionInfo {
},
};

let ty = resolve_context.resolve(expression, |h| {
self.expressions
.get(h.index())
.map(|ei| &ei.ty)
.ok_or(ResolveError::ExpressionForwardDependency(h))
})?;
let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
self.expressions[handle.index()] = ExpressionInfo {
uniformity,
ref_count: 0,
Expand Down
7 changes: 2 additions & 5 deletions src/valid/compose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@ use crate::{
proc::TypeResolution,
};

use crate::arena::{BadHandle, Handle};
use crate::arena::Handle;

#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ComposeError {
#[error(transparent)]
BadHandle(#[from] BadHandle),
#[error("Composing of type {0:?} can't be done")]
Type(Handle<crate::Type>),
#[error("Composing expects {expected} components but {given} were given")]
Expand All @@ -28,8 +26,7 @@ pub fn validate_compose(
) -> Result<(), ComposeError> {
use crate::TypeInner as Ti;

let self_ty = type_arena.get_handle(self_ty_handle)?;
match self_ty.inner {
match type_arena[self_ty_handle].inner {
// vectors are composed from scalars or other vectors
Ti::Vector { size, kind, width } => {
let mut total = 0;
Expand Down
Loading

0 comments on commit 461fdda

Please sign in to comment.