Skip to content

Commit

Permalink
[naga wgsl-in] Implement abstract types for consts, constructors.
Browse files Browse the repository at this point in the history
  • Loading branch information
jimblandy committed Nov 22, 2023
1 parent 3765b12 commit cfab70f
Show file tree
Hide file tree
Showing 20 changed files with 1,487 additions and 701 deletions.
2 changes: 1 addition & 1 deletion naga/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ deserialize = ["serde", "bitflags/serde", "indexmap/serde"]
arbitrary = ["dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary"]
spv-in = ["petgraph", "spirv"]
spv-out = ["spirv"]
wgsl-in = ["hexf-parse", "unicode-xid"]
wgsl-in = ["hexf-parse", "unicode-xid", "compact"]
wgsl-out = []
hlsl-out = []
compact = []
Expand Down
20 changes: 20 additions & 0 deletions naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,12 @@ pub enum Error<'a> {
ExpectedPositiveArrayLength(Span),
MissingWorkgroupSize(Span),
ConstantEvaluatorError(ConstantEvaluatorError, Span),
AutoConversion {
dest_span: Span,
dest_type: String,
source_span: Span,
source_type: String,
},
}

impl<'a> Error<'a> {
Expand Down Expand Up @@ -712,6 +718,20 @@ impl<'a> Error<'a> {
)],
notes: vec![],
},
Error::AutoConversion { dest_span, ref dest_type, source_span, ref source_type } => ParseError {
message: format!("automatic conversions cannot convert `{source_type}` to `{dest_type}`"),
labels: vec![
(
dest_span,
format!("a value of type {dest_type} is required here").into(),
),
(
source_span,
format!("this expression has type {source_type}").into(),
)
],
notes: vec![],
}
}
}
}
208 changes: 153 additions & 55 deletions naga/src/front/wgsl/lower/construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
components: &[Handle<ast::Expression<'source>>],
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<Handle<crate::Expression>, Error<'source>> {
use crate::proc::TypeResolution as Tr;

let constructor_h = self.constructor(constructor, ctx)?;

let components = match *components {
[] => Components::None,
[component] => {
let span = ctx.ast_expressions.get_span(component);
let component = self.expression(component, ctx)?;
let component = self.expression_for_abstract(component, ctx)?;
let ty_inner = super::resolve_inner!(ctx, component);

Components::One {
Expand All @@ -134,13 +136,17 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ref ast_components @ [_, _, ..] => {
let components = ast_components
.iter()
.map(|&expr| self.expression(expr, ctx))
.map(|&expr| self.expression_for_abstract(expr, ctx))
.collect::<Result<_, _>>()?;
let spans = ast_components
.iter()
.map(|&expr| ctx.ast_expressions.get_span(expr))
.collect();

for &component in &components {
ctx.grow_types(component)?;
}

Components::Many { components, spans }
}
};
Expand Down Expand Up @@ -288,56 +294,96 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// Vector constructor (splat)
(
Components::One {
component,
ty_inner: &crate::TypeInner::Scalar(src_scalar),
mut component,
ty_inner: &crate::TypeInner::Scalar(_),
..
},
Constructor::Type((
_,
&crate::TypeInner::Vector {
size,
scalar: dst_scalar,
},
)),
) if dst_scalar == src_scalar => {
Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })),
) => {
ctx.convert_slice_to_common_scalar(std::slice::from_mut(&mut component), scalar)?;
expr = crate::Expression::Splat {
size,
value: component,
};
}

// Vector constructor (by elements), partial
(Components::Many { components, spans }, Constructor::PartialVector { size }) => {
let scalar =
component_scalar_from_constructor_args(&components, ctx).map_err(|index| {
(
Components::Many {
mut components,
spans,
},
Constructor::PartialVector { size },
) => {
let consensus_scalar =
automatic_conversion_consensus(&components, ctx).map_err(|index| {
Error::InvalidConstructorComponentType(spans[index], index as i32)
})?;
let inner = scalar.to_inner_vector(size);
ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?;
let inner = consensus_scalar.to_inner_vector(size);
let ty = ctx.ensure_type_exists(inner);
expr = crate::Expression::Compose { ty, components };
}

// Vector constructor (by elements), full type given
(
Components::Many { components, .. },
Constructor::Type((ty, &crate::TypeInner::Vector { .. })),
Components::Many { mut components, .. },
Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })),
) => {
ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?;
expr = crate::Expression::Compose { ty, components };
}

// Matrix constructor (by elements)
// Matrix constructor (by elements), partial
(
Components::Many { components, spans },
Components::Many {
mut components,
spans,
},
Constructor::PartialMatrix { columns, rows },
)
| (
Components::Many { components, spans },
Constructor::Type((_, &crate::TypeInner::Matrix { columns, rows, .. })),
) if components.len() == columns as usize * rows as usize => {
let scalar =
component_scalar_from_constructor_args(&components, ctx).map_err(|index| {
let consensus_scalar =
automatic_conversion_consensus(&components, ctx).map_err(|index| {
Error::InvalidConstructorComponentType(spans[index], index as i32)
})?;
ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?;
let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows));

let components = components
.chunks(rows as usize)
.map(|vec_components| {
ctx.append_expression(
crate::Expression::Compose {
ty: vec_ty,
components: Vec::from(vec_components),
},
Default::default(),
)
})
.collect::<Result<Vec<_>, _>>()?;

let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
columns,
rows,
scalar: consensus_scalar,
});
expr = crate::Expression::Compose { ty, components };
}

// Matrix constructor (by elements), type given
(
Components::Many { mut components, .. },
Constructor::Type((
_,
&crate::TypeInner::Matrix {
columns,
rows,
scalar,
},
)),
) if components.len() == columns as usize * rows as usize => {
let element = Tr::Value(crate::TypeInner::Scalar(scalar));
ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?;
let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows));

let components = components
Expand All @@ -363,28 +409,55 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

// Matrix constructor (by columns)
(
Components::Many { components, spans },
Components::Many {
mut components,
spans,
},
Constructor::PartialMatrix { columns, rows },
)
| (
Components::Many { components, spans },
Components::Many {
mut components,
spans,
},
Constructor::Type((_, &crate::TypeInner::Matrix { columns, rows, .. })),
) => {
let scalar =
component_scalar_from_constructor_args(&components, ctx).map_err(|index| {
let consensus_scalar =
automatic_conversion_consensus(&components, ctx).map_err(|index| {
Error::InvalidConstructorComponentType(spans[index], index as i32)
})?;
ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?;
let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
columns,
rows,
scalar,
scalar: consensus_scalar,
});
expr = crate::Expression::Compose { ty, components };
}

// Array constructor - infer type
(components, Constructor::PartialArray) => {
let components = components.into_components_vec();
let mut components = components.into_components_vec();
if let Ok(consensus_scalar) = automatic_conversion_consensus(&components, ctx) {
// Note that this will *not* necessarily convert all the
// components to the same type! The `automatic_conversion_consensus`
// function only considers the parameters' leaf scalar
// types; the parameters themselves could be any mix of
// vectors, matrices, and scalars.
//
// But *if* it is possible for this array construction
// expression to be well-typed at all, then all the
// parameters must have the same type constructors (vec,
// matrix, scalar) applied to their leaf scalars, so
// reconciling their scalars is always the right thing to
// do. And if this array construction is not well-typed,
// these conversions will not make it so, and we can let
// validation catch the error.
ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?;
} else {
// There's no consensus scalar. Emit the `Compose`
// expression anyway, and let validation catch the problem.
}

let base = ctx.register_type(components[0])?;

Expand All @@ -403,15 +476,30 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
expr = crate::Expression::Compose { ty, components };
}

// Array or Struct constructor
// Array constructor, explicit type
(components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => {
let mut components = components.into_components_vec();
ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), span)?;
expr = crate::Expression::Compose { ty, components };
}

// Struct constructor
(
components,
Constructor::Type((
ty,
&crate::TypeInner::Array { .. } | &crate::TypeInner::Struct { .. },
)),
Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })),
) => {
let components = components.into_components_vec();
let mut components = components.into_components_vec();
let struct_ty_span = ctx.module.types.get_span(ty);

// Make a vector of the members' type handles in advance, to
// avoid borrowing `members` from `ctx` while we generate
// new code.
let members: Vec<Handle<crate::Type>> = members.iter().map(|m| m.ty).collect();

for (component, &ty) in components.iter_mut().zip(&members) {
*component =
ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?;
}
expr = crate::Expression::Compose { ty, components };
}

Expand Down Expand Up @@ -504,12 +592,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
}

/// Compute a vector or matrix's scalar type from those of its
/// constructor arguments.
/// Find the consensus scalar of `components` under WGSL's automatic
/// conversions.
///
/// Given `components`, the arguments given to a vector or matrix
/// constructor, return the scalar type of the vector or matrix's
/// elements.
/// If `components` can all be converted to any common scalar via
/// WGSL's automatic conversions, return the best such scalar.
///
/// The `components` slice must not be empty. All elements' types must
/// have been resolved.
Expand All @@ -518,20 +605,31 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
/// constructors, return `Err(i)`, where `i` is the index in
/// `components` of some problematic argument.
///
/// This function doesn't fully type-check the arguments, so it may
/// return `Ok` even when the Naga validator will reject the resulting
/// This function doesn't fully type-check the arguments - it only
/// considers their leaf scalar types. This means it may return `Ok`
/// even when the Naga validator will reject the resulting
/// construction expression later.
fn component_scalar_from_constructor_args(
fn automatic_conversion_consensus(
components: &[Handle<crate::Expression>],
ctx: &mut ExpressionContext<'_, '_, '_>,
ctx: &ExpressionContext<'_, '_, '_>,
) -> Result<Scalar, usize> {
// Since we don't yet implement abstract types, we can settle for
// just inspecting the first element.
let first = components[0];
ctx.grow_types(first).map_err(|_| 0_usize)?;
let inner = ctx.typifier()[first].inner_with(&ctx.module.types);
match inner.scalar() {
Some(scalar) => Ok(scalar),
None => Err(0),
log::trace!("JIMB: automatic_conversion_consensus");
let types = &ctx.module.types;
let mut inners = components
.iter()
.map(|&c| ctx.typifier()[c].inner_with(types));
let mut best = inners.next().unwrap().scalar().ok_or(0_usize)?;
log::trace!(" start: {best:?}");
for (inner, i) in inners.zip(1..) {
let scalar = inner.scalar().ok_or(i)?;
match best.automatic_conversion_join(scalar) {
Some(new_best) => {
best = new_best;
log::trace!(" new: {best:?}");
}
None => return Err(i),
}
}

Ok(best)
}
Loading

0 comments on commit cfab70f

Please sign in to comment.