diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index c2791fe92b4..378ee664749 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -5,6 +5,7 @@ use crate::{Handle, Span}; use crate::front::wgsl::error::Error; use crate::front::wgsl::lower::{ExpressionContext, Lowerer}; +use crate::front::wgsl::Scalar; /// A cooked form of `ast::ConstructorType` that uses Naga types whenever /// possible. @@ -80,7 +81,6 @@ enum Components<'a> { Many { components: Vec>, spans: Vec, - first_component_ty_inner: &'a crate::TypeInner, }, } @@ -131,30 +131,17 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty_inner, } } - [component, ref rest @ ..] => { - let span = ctx.ast_expressions.get_span(component); - let component = self.expression(component, ctx)?; - - let components = std::iter::once(Ok(component)) - .chain( - rest.iter() - .map(|&component| self.expression(component, ctx)), - ) + ref ast_components @ [_, _, ..] => { + let components = ast_components + .iter() + .map(|&expr| self.expression(expr, ctx)) .collect::>()?; - let spans = std::iter::once(span) - .chain( - rest.iter() - .map(|&component| ctx.ast_expressions.get_span(component)), - ) + let spans = ast_components + .iter() + .map(|&expr| ctx.ast_expressions.get_span(expr)) .collect(); - let first_component_ty_inner = super::resolve_inner!(ctx, component); - - Components::Many { - components, - spans, - first_component_ty_inner, - } + Components::Many { components, spans } } }; @@ -255,7 +242,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ) if dst_columns == src_columns && dst_rows == src_rows => { expr = crate::Expression::As { expr: component, - kind: crate::ScalarKind::Float, + kind: dst_scalar.kind, convert: Some(dst_scalar.width), }; } @@ -319,56 +306,39 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }; } - // Vector constructor (by elements) + // Vector constructor (by elements), partial + (Components::Many { components, spans }, Constructor::PartialVector { size }) => { + let scalar = + component_scalar_from_constructor_args(&components, ctx).map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + let inner = scalar.to_inner_vector(size); + let ty = ctx.ensure_type_exists(inner); + expr = crate::Expression::Compose { ty, components }; + } + + // Vector constructor (by elements), full type given ( - Components::Many { - components, - first_component_ty_inner: - &crate::TypeInner::Scalar(scalar) | &crate::TypeInner::Vector { scalar, .. }, - .. - }, - Constructor::PartialVector { size }, - ) - | ( - Components::Many { - components, - first_component_ty_inner: - &crate::TypeInner::Scalar { .. } | &crate::TypeInner::Vector { .. }, - .. - }, - Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })), + Components::Many { components, .. }, + Constructor::Type((ty, &crate::TypeInner::Vector { .. })), ) => { - let inner = crate::TypeInner::Vector { size, scalar }; - let ty = ctx.ensure_type_exists(inner); expr = crate::Expression::Compose { ty, components }; } // Matrix constructor (by elements) ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Scalar(scalar), - .. - }, + Components::Many { components, spans }, Constructor::PartialMatrix { columns, rows }, ) | ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Scalar { .. }, - .. - }, - Constructor::Type(( - _, - &crate::TypeInner::Matrix { - columns, - rows, - scalar, - }, - )), - ) => { - let vec_ty = - ctx.ensure_type_exists(crate::TypeInner::Vector { scalar, size: rows }); + Components::Many { components, spans }, + Constructor::Type((_, &crate::TypeInner::Matrix { columns, rows, .. })), + ) if components.len() == columns as usize * rows as usize => { + let scalar = + component_scalar_from_constructor_args(&components, ctx).map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows)); let components = components .chunks(rows as usize) @@ -393,28 +363,17 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Matrix constructor (by columns) ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Vector { scalar, .. }, - .. - }, + Components::Many { components, spans }, Constructor::PartialMatrix { columns, rows }, ) | ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Vector { .. }, - .. - }, - Constructor::Type(( - _, - &crate::TypeInner::Matrix { - columns, - rows, - scalar, - }, - )), + Components::Many { components, spans }, + Constructor::Type((_, &crate::TypeInner::Matrix { columns, rows, .. })), ) => { + let scalar = + component_scalar_from_constructor_args(&components, ctx).map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { columns, rows, @@ -477,22 +436,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Err(Error::UnexpectedComponents(span)); } - // Parameters are of the wrong type for vector or matrix constructor - ( - Components::Many { spans, .. }, - Constructor::Type(( - _, - &crate::TypeInner::Vector { .. } | &crate::TypeInner::Matrix { .. }, - )) - | Constructor::PartialVector { .. } - | Constructor::PartialMatrix { .. }, - ) => { - return Err(Error::InvalidConstructorComponentType(spans[0], 0)); - } - // Other types can't be constructed _ => return Err(Error::TypeNotConstructible(ty_span)), - }; + } let expr = ctx.append_expression(expr, span)?; Ok(expr) @@ -557,3 +503,35 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(handle) } } + +/// Compute a vector or matrix's scalar type from those of its +/// constructor arguments. +/// +/// Given `components`, the arguments given to a vector or matrix +/// constructor, return the scalar type of the vector or matrix's +/// elements. +/// +/// The `components` slice must not be empty. All elements' types must +/// have been resolved. +/// +/// If `components` are definitely not acceptable as arguments to such +/// constructors, return `Err(i)`, where `i` is the index in +/// `components` of some problematic argument. +/// +/// This function doesn't fully type-check the arguments, so it may +/// return `Ok` even when the Naga validator will reject the resulting +/// construction expression later. +fn component_scalar_from_constructor_args( + components: &[Handle], + ctx: &mut ExpressionContext<'_, '_, '_>, +) -> Result { + // Since we don't yet implement abstract types, we can settle for + // just inspecting the first element. + let first = components[0]; + ctx.grow_types(first).map_err(|_| 0_usize)?; + let inner = ctx.typifier()[first].inner_with(&ctx.module.types); + match inner.scalar() { + Some(scalar) => Ok(scalar), + None => Err(0), + } +} diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 40a342f6ce1..e375bb1af31 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -237,6 +237,10 @@ impl crate::Literal { pub const POINTER_SPAN: u32 = 4; impl super::TypeInner { + /// Return the scalar type of `self`. + /// + /// If `inner` is a scalar, vector, or matrix type, return + /// its scalar type. Otherwise, return `None`. pub const fn scalar(&self) -> Option { use crate::TypeInner as Ti; match *self {