From 587b62ddd1ca11cdecc43cd578bc173822e54e40 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Tue, 28 Nov 2023 13:22:16 +0100 Subject: [PATCH 1/3] [wgsl-in] add support for override declarations [ir] split overrides from constants --- naga/src/back/dot/mod.rs | 1 + naga/src/back/glsl/mod.rs | 1 + naga/src/back/hlsl/writer.rs | 3 + naga/src/back/msl/writer.rs | 3 + naga/src/back/spv/block.rs | 3 + naga/src/back/wgsl/writer.rs | 3 + naga/src/compact/expressions.rs | 21 +++++++ naga/src/compact/functions.rs | 4 ++ naga/src/compact/mod.rs | 35 +++++++++++ naga/src/front/spv/function.rs | 2 + naga/src/front/spv/mod.rs | 3 +- naga/src/front/wgsl/error.rs | 17 ++++-- naga/src/front/wgsl/index.rs | 1 + naga/src/front/wgsl/lower/mod.rs | 70 ++++++++++++++++++++-- naga/src/front/wgsl/parse/ast.rs | 9 +++ naga/src/front/wgsl/parse/mod.rs | 30 ++++++++++ naga/src/front/wgsl/to_wgsl.rs | 1 + naga/src/lib.rs | 37 ++++++++---- naga/src/proc/constant_evaluator.rs | 3 + naga/src/proc/mod.rs | 2 + naga/src/proc/typifier.rs | 3 + naga/src/valid/analyzer.rs | 3 +- naga/src/valid/expression.rs | 2 +- naga/src/valid/handles.rs | 39 +++++++++++- naga/src/valid/mod.rs | 57 ++++++++++++++++++ naga/tests/in/overrides.wgsl | 12 ++++ naga/tests/out/analysis/overrides.info.ron | 22 +++++++ naga/tests/out/ir/access.compact.ron | 1 + naga/tests/out/ir/access.ron | 1 + naga/tests/out/ir/collatz.compact.ron | 1 + naga/tests/out/ir/collatz.ron | 1 + naga/tests/out/ir/overrides.compact.ron | 64 ++++++++++++++++++++ naga/tests/out/ir/overrides.ron | 67 +++++++++++++++++++++ naga/tests/out/ir/shadow.compact.ron | 1 + naga/tests/out/ir/shadow.ron | 1 + naga/tests/snapshots.rs | 8 +++ naga/tests/wgsl_errors.rs | 4 +- 37 files changed, 510 insertions(+), 26 deletions(-) create mode 100644 naga/tests/in/overrides.wgsl create mode 100644 naga/tests/out/analysis/overrides.info.ron create mode 100644 naga/tests/out/ir/overrides.compact.ron create mode 100644 naga/tests/out/ir/overrides.ron diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 1556371df1..d128c855ca 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -404,6 +404,7 @@ fn write_function_expressions( let (label, color_id) = match *expression { E::Literal(_) => ("Literal".into(), 2), E::Constant(_) => ("Constant".into(), 2), + E::Override(_) => ("Override".into(), 2), E::ZeroValue(_) => ("ZeroValue".into(), 2), E::Compose { ref components, .. } => { payload = Some(Payload::Arguments(components)); diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 7944e46162..83e270e757 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2525,6 +2525,7 @@ impl<'a, W: Write> Writer<'a, W> { |writer, expr| writer.write_expr(expr, ctx), )?; } + Expression::Override(_) => return Err(Error::Custom("overrides are WIP".into())), // `Access` is applied to arrays, vectors and matrices and is written as indexing Expression::Access { base, index } => { self.write_expr(base, ctx)?; diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 900d3cbc6d..a3fa5b48fa 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2151,6 +2151,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } + Expression::Override(_) => { + return Err(Error::Unimplemented("overrides are WIP".into())) + } // All of the multiplication can be expressed as `mul`, // except vector * vector, which needs to use the "*" operator. Expression::Binary { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 17154c3cd5..21e1a2a106 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1394,6 +1394,9 @@ impl Writer { |writer, context, expr| writer.put_expression(expr, context, true), )?; } + crate::Expression::Override(_) => { + return Err(Error::FeatureNotImplemented("overrides are WIP".into())) + } crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => { // This is an acceptable place to generate a `ReadZeroSkipWrite` check. diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index df6ecd00ff..2354527c12 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -239,6 +239,9 @@ impl<'w> BlockContext<'w> { let init = self.ir_module.constants[handle].init; self.writer.constant_ids[init.index()] } + crate::Expression::Override(_) => { + return Err(Error::FeatureNotImplemented("overrides are WIP")) + } crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), crate::Expression::Compose { ty, ref components } => { self.temp_list.clear(); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 10da339968..ac400dd366 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1180,6 +1180,9 @@ impl Writer { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } + Expression::Override(_) => { + return Err(Error::Unimplemented("overrides are WIP".into())) + } Expression::FunctionArgument(pos) => { let name_key = func_ctx.argument_key(pos); let name = &self.names[&name_key]; diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index 301bbe3240..ccc3409613 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -3,6 +3,7 @@ use crate::arena::{Arena, Handle}; pub struct ExpressionTracer<'tracer> { pub constants: &'tracer Arena, + pub overrides: &'tracer Arena, /// The arena in which we are currently tracing expressions. pub expressions: &'tracer Arena, @@ -13,6 +14,9 @@ pub struct ExpressionTracer<'tracer> { /// The used map for `constants`. pub constants_used: &'tracer mut HandleSet, + /// The used map for `overrides`. + pub overrides_used: &'tracer mut HandleSet, + /// The used set for `arena`. /// /// This points to whatever arena holds the expressions we are @@ -88,6 +92,22 @@ impl<'tracer> ExpressionTracer<'tracer> { None => self.expressions_used.insert(init), } } + Ex::Override(handle) => { + self.overrides_used.insert(handle); + // Overrides and expressions are mutually recursive, which + // complicates our nice one-pass algorithm. However, since + // overrides don't refer to each other, we can get around + // this by looking *through* each override and marking its + // initializer as used. Since `expr` refers to the override, + // and the override refers to the initializer, it must + // precede `expr` in the arena. + if let Some(init) = self.overrides[handle].init { + match self.const_expressions_used { + Some(ref mut used) => used.insert(init), + None => self.expressions_used.insert(init), + } + } + } Ex::ZeroValue(ty) => self.types_used.insert(ty), Ex::Compose { ty, ref components } => { self.types_used.insert(ty); @@ -221,6 +241,7 @@ impl ModuleMap { // Expressions that contain handles that need to be adjusted. Ex::Constant(ref mut constant) => self.constants.adjust(constant), + Ex::Override(ref mut override_) => self.overrides.adjust(override_), Ex::ZeroValue(ref mut ty) => self.types.adjust(ty), Ex::Compose { ref mut ty, diff --git a/naga/src/compact/functions.rs b/naga/src/compact/functions.rs index b0d08c7e96..f11df7be1a 100644 --- a/naga/src/compact/functions.rs +++ b/naga/src/compact/functions.rs @@ -4,9 +4,11 @@ use super::{FunctionMap, ModuleMap}; pub struct FunctionTracer<'a> { pub function: &'a crate::Function, pub constants: &'a crate::Arena, + pub overrides: &'a crate::Arena, pub types_used: &'a mut HandleSet, pub constants_used: &'a mut HandleSet, + pub overrides_used: &'a mut HandleSet, pub const_expressions_used: &'a mut HandleSet, /// Function-local expressions used. @@ -47,10 +49,12 @@ impl<'a> FunctionTracer<'a> { fn as_expression(&mut self) -> super::expressions::ExpressionTracer { super::expressions::ExpressionTracer { constants: self.constants, + overrides: self.overrides, expressions: &self.function.expressions, types_used: self.types_used, constants_used: self.constants_used, + overrides_used: self.overrides_used, expressions_used: &mut self.expressions_used, const_expressions_used: Some(&mut self.const_expressions_used), } diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index 7dfb8ee80d..3aad85f348 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -54,6 +54,14 @@ pub fn compact(module: &mut crate::Module) { } } + // We treat all overrides as used by definition. + for (handle, override_) in module.overrides.iter() { + module_tracer.overrides_used.insert(handle); + if let Some(init) = override_.init { + module_tracer.const_expressions_used.insert(init); + } + } + // We assume that all functions are used. // // Observe which types, constant expressions, constants, and @@ -99,6 +107,11 @@ pub fn compact(module: &mut crate::Module) { module_tracer.types_used.insert(constant.ty); } } + for (handle, override_) in module.overrides.iter() { + if module_tracer.overrides_used.contains(handle) { + module_tracer.types_used.insert(override_.ty); + } + } // Treat all named types as used. for (handle, ty) in module.types.iter() { @@ -158,6 +171,20 @@ pub fn compact(module: &mut crate::Module) { } }); + // Drop unused overrides in place, reusing existing storage. + log::trace!("adjusting overrides"); + module.overrides.retain_mut(|handle, override_| { + if module_map.overrides.used(handle) { + module_map.types.adjust(&mut override_.ty); + if let Some(init) = override_.init.as_mut() { + module_map.const_expressions.adjust(init); + } + true + } else { + false + } + }); + // Adjust global variables' types and initializers. log::trace!("adjusting global variables"); for (_, global) in module.global_variables.iter_mut() { @@ -193,6 +220,7 @@ struct ModuleTracer<'module> { module: &'module crate::Module, types_used: HandleSet, constants_used: HandleSet, + overrides_used: HandleSet, const_expressions_used: HandleSet, } @@ -202,6 +230,7 @@ impl<'module> ModuleTracer<'module> { module, types_used: HandleSet::for_arena(&module.types), constants_used: HandleSet::for_arena(&module.constants), + overrides_used: HandleSet::for_arena(&module.overrides), const_expressions_used: HandleSet::for_arena(&module.const_expressions), } } @@ -235,8 +264,10 @@ impl<'module> ModuleTracer<'module> { expressions::ExpressionTracer { expressions: &self.module.const_expressions, constants: &self.module.constants, + overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, + overrides_used: &mut self.overrides_used, expressions_used: &mut self.const_expressions_used, const_expressions_used: None, } @@ -249,8 +280,10 @@ impl<'module> ModuleTracer<'module> { FunctionTracer { function, constants: &self.module.constants, + overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, + overrides_used: &mut self.overrides_used, const_expressions_used: &mut self.const_expressions_used, expressions_used: HandleSet::for_arena(&function.expressions), } @@ -260,6 +293,7 @@ impl<'module> ModuleTracer<'module> { struct ModuleMap { types: HandleMap, constants: HandleMap, + overrides: HandleMap, const_expressions: HandleMap, } @@ -268,6 +302,7 @@ impl From> for ModuleMap { ModuleMap { types: HandleMap::from_set(used.types_used), constants: HandleMap::from_set(used.constants_used), + overrides: HandleMap::from_set(used.overrides_used), const_expressions: HandleMap::from_set(used.const_expressions_used), } } diff --git a/naga/src/front/spv/function.rs b/naga/src/front/spv/function.rs index 198d9c52dd..8a7e736edd 100644 --- a/naga/src/front/spv/function.rs +++ b/naga/src/front/spv/function.rs @@ -128,6 +128,7 @@ impl> super::Frontend { expressions: &mut fun.expressions, local_arena: &mut fun.local_variables, const_arena: &mut module.constants, + overrides: &mut module.overrides, const_expressions: &mut module.const_expressions, type_arena: &module.types, global_arena: &module.global_variables, @@ -573,6 +574,7 @@ impl<'function> BlockContext<'function> { crate::proc::GlobalCtx { types: self.type_arena, constants: self.const_arena, + overrides: self.overrides, const_expressions: self.const_expressions, } } diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index ec67dc5524..4a5d34a452 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -532,6 +532,7 @@ struct BlockContext<'function> { local_arena: &'function mut Arena, /// Constants arena of the module being processed const_arena: &'function mut Arena, + overrides: &'function mut Arena, const_expressions: &'function mut Arena, /// Type arena of the module being processed type_arena: &'function UniqueArena, @@ -3933,7 +3934,7 @@ impl> Frontend { Op::TypeImage => self.parse_type_image(inst, &mut module), Op::TypeSampledImage => self.parse_type_sampled_image(inst), Op::TypeSampler => self.parse_type_sampler(inst, &mut module), - Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module), + Op::Constant => self.parse_constant(inst, &mut module), Op::ConstantComposite => self.parse_composite_constant(inst, &mut module), Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module), Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module), diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index f5acbe2d65..cea101b3d6 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -190,7 +190,7 @@ pub enum Error<'a> { expected: String, got: String, }, - MissingType(Span), + DeclMissingTypeAndInit(Span), MissingAttribute(&'static str, Span), InvalidAtomicPointer(Span), InvalidAtomicOperandType(Span), @@ -251,6 +251,7 @@ pub enum Error<'a> { ExpectedPositiveArrayLength(Span), MissingWorkgroupSize(Span), ConstantEvaluatorError(ConstantEvaluatorError, Span), + PipelineConstantIDValue(Span), } impl<'a> Error<'a> { @@ -500,11 +501,11 @@ impl<'a> Error<'a> { notes: vec![], } } - Error::MissingType(name_span) => ParseError { - message: format!("variable `{}` needs a type", &source[name_span]), + Error::DeclMissingTypeAndInit(name_span) => ParseError { + message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]), labels: vec![( name_span, - format!("definition of `{}`", &source[name_span]).into(), + "needs a type specifier or initializer".into(), )], notes: vec![], }, @@ -712,6 +713,14 @@ impl<'a> Error<'a> { )], notes: vec![], }, + Error::PipelineConstantIDValue(span) => ParseError { + message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(), + labels: vec![( + span, + "must be between 0 and 65535 inclusive".into(), + )], + notes: vec![], + }, } } } diff --git a/naga/src/front/wgsl/index.rs b/naga/src/front/wgsl/index.rs index a5524fe8f1..593405508f 100644 --- a/naga/src/front/wgsl/index.rs +++ b/naga/src/front/wgsl/index.rs @@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> { ast::GlobalDeclKind::Fn(ref f) => f.name, ast::GlobalDeclKind::Var(ref v) => v.name, ast::GlobalDeclKind::Const(ref c) => c.name, + ast::GlobalDeclKind::Override(ref o) => o.name, ast::GlobalDeclKind::Struct(ref s) => s.name, ast::GlobalDeclKind::Type(ref t) => t.name, } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index b7550d3931..154dbc4472 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -784,6 +784,7 @@ enum LoweredGlobalDecl { Function(Handle), Var(Handle), Const(Handle), + Override(Handle), Type(Handle), EntryPoint, } @@ -933,6 +934,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ctx.globals .insert(c.name.name, LoweredGlobalDecl::Const(handle)); } + ast::GlobalDeclKind::Override(ref o) => { + let init = o + .init + .map(|init| self.expression(init, &mut ctx.as_const())) + .transpose()?; + let inferred_type = init + .map(|init| ctx.as_const().register_type(init)) + .transpose()?; + + let explicit_ty = + o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx)) + .transpose()?; + + let id = + o.id.map(|id| self.const_u32(id, &mut ctx.as_const())) + .transpose()?; + + let id = if let Some((id, id_span)) = id { + Some( + u16::try_from(id) + .map_err(|_| Error::PipelineConstantIDValue(id_span))?, + ) + } else { + None + }; + + let ty = match (explicit_ty, inferred_type) { + (Some(explicit_ty), Some(inferred_type)) => { + if explicit_ty == inferred_type { + explicit_ty + } else { + let gctx = ctx.module.to_ctx(); + return Err(Error::InitializationTypeMismatch { + name: o.name.span, + expected: explicit_ty.to_wgsl(&gctx), + got: inferred_type.to_wgsl(&gctx), + }); + } + } + (Some(explicit_ty), None) => explicit_ty, + (None, Some(inferred_type)) => inferred_type, + (None, None) => { + return Err(Error::DeclMissingTypeAndInit(o.name.span)); + } + }; + + let handle = ctx.module.overrides.append( + crate::Override { + name: Some(o.name.name.to_string()), + id, + ty, + init, + }, + span, + ); + + ctx.globals + .insert(o.name.name, LoweredGlobalDecl::Override(handle)); + } ast::GlobalDeclKind::Struct(ref s) => { let handle = self.r#struct(s, span, &mut ctx)?; ctx.globals @@ -1160,7 +1220,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .as_expression(block, &mut emitter) .register_type(initializer)?, (None, None) => { - return Err(Error::MissingType(v.name.span)); + return Err(Error::DeclMissingTypeAndInit(v.name.span)); } }; @@ -1715,9 +1775,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; Ok(Some(handle)) } - Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => { - Err(Error::Unexpected(function.span, ExpectedToken::Function)) - } + Some( + &LoweredGlobalDecl::Const(_) + | &LoweredGlobalDecl::Override(_) + | &LoweredGlobalDecl::Var(_), + ) => Err(Error::Unexpected(function.span, ExpectedToken::Function)), Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)), Some(&LoweredGlobalDecl::Function(function)) => { let arguments = arguments diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index dbaac523cb..ea8013ee7c 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> { Fn(Function<'a>), Var(GlobalVariable<'a>), Const(Const<'a>), + Override(Override<'a>), Struct(Struct<'a>), Type(TypeAlias<'a>), } @@ -200,6 +201,14 @@ pub struct Const<'a> { pub init: Handle>, } +#[derive(Debug)] +pub struct Override<'a> { + pub name: Ident<'a>, + pub id: Option>>, + pub ty: Option>>, + pub init: Option>>, +} + /// The size of an [`Array`] or [`BindingArray`]. /// /// [`Array`]: Type::Array diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index 51fc2f013b..810e67f9fe 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -2170,6 +2170,7 @@ impl Parser { let mut early_depth_test = ParsedAttribute::default(); let (mut bind_index, mut bind_group) = (ParsedAttribute::default(), ParsedAttribute::default()); + let mut id = ParsedAttribute::default(); let mut dependencies = FastIndexSet::default(); let mut ctx = ExpressionContext { @@ -2193,6 +2194,11 @@ impl Parser { bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?; lexer.expect(Token::Paren(')'))?; } + ("id", name_span) => { + lexer.expect(Token::Paren('('))?; + id.set(self.general_expression(lexer, &mut ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } ("vertex", name_span) => { stage.set(crate::ShaderStage::Vertex, name_span)?; } @@ -2283,6 +2289,30 @@ impl Parser { Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init })) } + (Token::Word("override"), _) => { + let name = lexer.next_ident()?; + + let ty = if lexer.skip(Token::Separator(':')) { + Some(self.type_decl(lexer, &mut ctx)?) + } else { + None + }; + + let init = if lexer.skip(Token::Operation('=')) { + Some(self.general_expression(lexer, &mut ctx)?) + } else { + None + }; + + lexer.expect(Token::Separator(';'))?; + + Some(ast::GlobalDeclKind::Override(ast::Override { + name, + id: id.value, + ty, + init, + })) + } (Token::Word("var"), _) => { let mut var = self.variable_decl(lexer, &mut ctx)?; var.binding = binding.take(); diff --git a/naga/src/front/wgsl/to_wgsl.rs b/naga/src/front/wgsl/to_wgsl.rs index cdfa1f0b1f..b9065d10f4 100644 --- a/naga/src/front/wgsl/to_wgsl.rs +++ b/naga/src/front/wgsl/to_wgsl.rs @@ -224,6 +224,7 @@ mod tests { let gctx = crate::proc::GlobalCtx { types: &types, constants: &crate::Arena::new(), + overrides: &crate::Arena::new(), const_expressions: &crate::Arena::new(), }; let array = crate::TypeInner::Array { diff --git a/naga/src/lib.rs b/naga/src/lib.rs index c9c12e8bfa..c3a43873cd 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -175,7 +175,7 @@ tree. A Naga *constant expression* is one of the following [`Expression`] variants, whose operands (if any) are also constant expressions: - [`Literal`] -- [`Constant`], for [`Constant`s][const_type] whose `override` is `None` +- [`Constant`], for [`Constant`]s - [`ZeroValue`], for fixed-size types - [`Compose`] - [`Access`] @@ -194,8 +194,7 @@ A constant expression can be evaluated at module translation time. ## Override expressions A Naga *override expression* is the same as a [constant expression], -except that it is also allowed to refer to [`Constant`s][const_type] -whose `override` is something other than `None`. +except that it is also allowed to reference other [`Override`]s. An override expression can be evaluated at pipeline creation time. @@ -238,8 +237,6 @@ An override expression can be evaluated at pipeline creation time. [`Math`]: Expression::Math [`As`]: Expression::As -[const_type]: Constant - [constant expression]: index.html#constant-expressions */ @@ -871,6 +868,25 @@ pub enum Literal { Bool(bool), } +/// Pipeline-overridable constant. +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct Override { + pub name: Option, + /// Pipeline Constant ID. + pub id: Option, + pub ty: Handle, + + /// The default value of the pipeline-overridable constant. + /// + /// This [`Handle`] refers to [`Module::const_expressions`], not + /// any [`Function::expressions`] arena. + pub init: Option>, +} + /// Constant value. #[derive(Debug, PartialEq)] #[cfg_attr(feature = "clone", derive(Clone))] @@ -885,13 +901,6 @@ pub struct Constant { /// /// This [`Handle`] refers to [`Module::const_expressions`], not /// any [`Function::expressions`] arena. - /// - /// If `override` is `None`, then this must be a Naga - /// [constant expression]. Otherwise, this may be a Naga - /// [override expression] or [constant expression]. - /// - /// [constant expression]: index.html#constant-expressions - /// [override expression]: index.html#override-expressions pub init: Handle, } @@ -1277,6 +1286,8 @@ pub enum Expression { Literal(Literal), /// Constant value. Constant(Handle), + /// Pipeline-overridable constant. + Override(Handle), /// Zero value of a type. ZeroValue(Handle), /// Composite expression. @@ -2019,6 +2030,8 @@ pub struct Module { pub special_types: SpecialTypes, /// Arena for the constants defined in this module. pub constants: Arena, + /// Arena for the pipeline-overridable constants defined in this module. + pub overrides: Arena, /// Arena for the global variables defined in this module. pub global_variables: Arena, /// [Constant expressions] and [override expressions] used by this module. diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 2985253057..752228385e 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -337,6 +337,9 @@ impl<'a> ConstantEvaluator<'a> { // This is mainly done to avoid having constants pointing to other constants. Ok(self.constants[c].init) } + Expression::Override(_) => Err(ConstantEvaluatorError::NotImplemented( + "overrides are WIP".into(), + )), Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { self.register_evaluated_expr(expr.clone(), span) } diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index fd99cef6d1..9ccd7240d3 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -604,6 +604,7 @@ impl crate::Module { GlobalCtx { types: &self.types, constants: &self.constants, + overrides: &self.overrides, const_expressions: &self.const_expressions, } } @@ -619,6 +620,7 @@ pub(super) enum U32EvalError { pub struct GlobalCtx<'a> { pub types: &'a crate::UniqueArena, pub constants: &'a crate::Arena, + pub overrides: &'a crate::Arena, pub const_expressions: &'a crate::Arena, } diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 9c4403445c..845b35cb4d 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -185,6 +185,7 @@ pub enum ResolveError { pub struct ResolveContext<'a> { pub constants: &'a Arena, + pub overrides: &'a Arena, pub types: &'a UniqueArena, pub special_types: &'a crate::SpecialTypes, pub global_vars: &'a Arena, @@ -202,6 +203,7 @@ impl<'a> ResolveContext<'a> { ) -> Self { Self { constants: &module.constants, + overrides: &module.overrides, types: &module.types, special_types: &module.special_types, global_vars: &module.global_variables, @@ -407,6 +409,7 @@ impl<'a> ResolveContext<'a> { }, crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()), crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty), + crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty), crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty), crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty), crate::Expression::FunctionArgument(index) => { diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index df6fc5e9b0..17c76b2738 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -527,7 +527,7 @@ impl FunctionInfo { non_uniform_result: self.add_ref(vector), requirements: UniformityRequirements::empty(), }, - E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(), + E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(), E::Compose { ref components, .. } => { let non_uniform_result = components .iter() @@ -1139,6 +1139,7 @@ fn uniform_control_flow() { }; let resolve_context = ResolveContext { constants: &Arena::new(), + overrides: &Arena::new(), types: &type_arena, special_types: &crate::SpecialTypes::default(), global_vars: &global_var_arena, diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 733b9bb614..0c82fd951c 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -343,7 +343,7 @@ impl super::Validator { self.validate_literal(literal)?; ShaderStages::all() } - E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(), E::Compose { ref components, ty } => { validate_compose( ty, diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 1884c01303..0643b1c9f5 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -31,6 +31,7 @@ impl super::Validator { pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> { let &crate::Module { ref constants, + ref overrides, ref entry_points, ref functions, ref global_variables, @@ -68,7 +69,7 @@ impl super::Validator { } for handle_and_expr in const_expressions.iter() { - Self::validate_const_expression_handles(handle_and_expr, constants, types)?; + Self::validate_const_expression_handles(handle_and_expr, constants, overrides, types)?; } let validate_type = |handle| Self::validate_type_handle(handle, types); @@ -81,6 +82,19 @@ impl super::Validator { validate_const_expr(init)?; } + for (_handle, override_) in overrides.iter() { + let &crate::Override { + name: _, + id: _, + ty, + init, + } = override_; + validate_type(ty)?; + if let Some(init_expr) = init { + validate_const_expr(init_expr)?; + } + } + for (_handle, global_variable) in global_variables.iter() { let &crate::GlobalVariable { name: _, @@ -135,6 +149,7 @@ impl super::Validator { Self::validate_expression_handles( handle_and_expr, constants, + overrides, const_expressions, types, local_variables, @@ -181,6 +196,13 @@ impl super::Validator { handle.check_valid_for(constants).map(|_| ()) } + fn validate_override_handle( + handle: Handle, + overrides: &Arena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(overrides).map(|_| ()) + } + fn validate_expression_handle( handle: Handle, expressions: &Arena, @@ -198,9 +220,11 @@ impl super::Validator { fn validate_const_expression_handles( (handle, expression): (Handle, &crate::Expression), constants: &Arena, + overrides: &Arena, types: &UniqueArena, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_type = |handle| Self::validate_type_handle(handle, types); match *expression { @@ -209,6 +233,12 @@ impl super::Validator { validate_constant(constant)?; handle.check_dep(constants[constant].init)?; } + crate::Expression::Override(override_) => { + validate_override(override_)?; + if let Some(init) = overrides[override_].init { + handle.check_dep(init)?; + } + } crate::Expression::ZeroValue(ty) => { validate_type(ty)?; } @@ -225,6 +255,7 @@ impl super::Validator { fn validate_expression_handles( (handle, expression): (Handle, &crate::Expression), constants: &Arena, + overrides: &Arena, const_expressions: &Arena, types: &UniqueArena, local_variables: &Arena, @@ -234,6 +265,7 @@ impl super::Validator { current_function: Option>, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_const_expr = |handle| Self::validate_expression_handle(handle, const_expressions); let validate_type = |handle| Self::validate_type_handle(handle, types); @@ -255,6 +287,9 @@ impl super::Validator { crate::Expression::Constant(constant) => { validate_constant(constant)?; } + crate::Expression::Override(override_) => { + validate_override(override_)?; + } crate::Expression::ZeroValue(ty) => { validate_type(ty)?; } @@ -659,6 +694,7 @@ fn constant_deps() { let mut const_exprs = Arena::new(); let mut fun_exprs = Arena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let i32_handle = types.insert( Type { @@ -686,6 +722,7 @@ fn constant_deps() { assert!(super::Validator::validate_const_expression_handles( handle_and_expr, &constants, + &overrides, &types, ) .is_err()); diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 70a4d39d2a..f5a2414e2f 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -182,6 +182,16 @@ pub enum ConstantError { NonConstructibleType, } +#[derive(Clone, Debug, thiserror::Error)] +pub enum OverrideError { + #[error("The type doesn't match the override")] + InvalidType, + #[error("The type is not constructible")] + NonConstructibleType, + #[error("The type is not a scalar")] + TypeNotScalar, +} + #[derive(Clone, Debug, thiserror::Error)] pub enum ValidationError { #[error(transparent)] @@ -205,6 +215,12 @@ pub enum ValidationError { name: String, source: ConstantError, }, + #[error("Override {handle:?} '{name}' is invalid")] + Override { + handle: Handle, + name: String, + source: OverrideError, + }, #[error("Global variable {handle:?} '{name}' is invalid")] GlobalVariable { handle: Handle, @@ -327,6 +343,35 @@ impl Validator { Ok(()) } + fn validate_override( + &self, + handle: Handle, + gctx: crate::proc::GlobalCtx, + mod_info: &ModuleInfo, + ) -> Result<(), OverrideError> { + let o = &gctx.overrides[handle]; + + let type_info = &self.types[o.ty.index()]; + if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) { + return Err(OverrideError::NonConstructibleType); + } + + let decl_ty = &gctx.types[o.ty].inner; + match decl_ty { + &crate::TypeInner::Scalar(_) => {} + _ => return Err(OverrideError::TypeNotScalar), + } + + if let Some(init) = o.init { + let init_ty = mod_info[init].inner_with(gctx.types); + if !decl_ty.equivalent(init_ty, gctx.types) { + return Err(OverrideError::InvalidType); + } + } + + Ok(()) + } + /// Check the given module to be valid. pub fn validate( &mut self, @@ -404,6 +449,18 @@ impl Validator { .with_span_handle(handle, &module.constants) })? } + + for (handle, override_) in module.overrides.iter() { + self.validate_override(handle, module.to_ctx(), &mod_info) + .map_err(|source| { + ValidationError::Override { + handle, + name: override_.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.overrides) + })? + } } for (var_handle, var) in module.global_variables.iter() { diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl new file mode 100644 index 0000000000..4cbc469f54 --- /dev/null +++ b/naga/tests/in/overrides.wgsl @@ -0,0 +1,12 @@ +@id(0) override has_point_light: bool = true; // Algorithmic control +@id(1200) override specular_param: f32 = 2.3; // Numeric control +@id(1300) override gain: f32; // Must be overridden + override width: f32 = 0.0; // Specified at the API level using + // the name "width". + override depth: f32; // Specified at the API level using + // the name "depth". + // Must be overridden. + // override height = 2 * depth; // The default value + // (if not set at the API level), + // depends on another + // overridable constant. \ No newline at end of file diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron new file mode 100644 index 0000000000..d9932bf034 --- /dev/null +++ b/naga/tests/out/analysis/overrides.info.ron @@ -0,0 +1,22 @@ +( + type_flags: [ + ("DATA | SIZED | COPY | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), + ], + functions: [], + entry_points: [], + const_expression_types: [ + Value(Scalar(( + kind: Bool, + width: 1, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), + ], +) \ No newline at end of file diff --git a/naga/tests/out/ir/access.compact.ron b/naga/tests/out/ir/access.compact.ron index 70ea0c4bb5..b330cae647 100644 --- a/naga/tests/out/ir/access.compact.ron +++ b/naga/tests/out/ir/access.compact.ron @@ -324,6 +324,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("global_const"), diff --git a/naga/tests/out/ir/access.ron b/naga/tests/out/ir/access.ron index 55d27c97eb..3ac6295e6b 100644 --- a/naga/tests/out/ir/access.ron +++ b/naga/tests/out/ir/access.ron @@ -363,6 +363,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("global_const"), diff --git a/naga/tests/out/ir/collatz.compact.ron b/naga/tests/out/ir/collatz.compact.ron index cfc3bfa0ee..fe4af55c1b 100644 --- a/naga/tests/out/ir/collatz.compact.ron +++ b/naga/tests/out/ir/collatz.compact.ron @@ -46,6 +46,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("v_indices"), diff --git a/naga/tests/out/ir/collatz.ron b/naga/tests/out/ir/collatz.ron index effde120a5..ecb6144603 100644 --- a/naga/tests/out/ir/collatz.ron +++ b/naga/tests/out/ir/collatz.ron @@ -46,6 +46,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("v_indices"), diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron new file mode 100644 index 0000000000..a68d887c16 --- /dev/null +++ b/naga/tests/out/ir/overrides.compact.ron @@ -0,0 +1,64 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("has_point_light"), + id: Some(0), + ty: 1, + init: Some(1), + ), + ( + name: Some("specular_param"), + id: Some(1200), + ty: 2, + init: Some(2), + ), + ( + name: Some("gain"), + id: Some(1300), + ty: 2, + init: None, + ), + ( + name: Some("width"), + id: None, + ty: 2, + init: Some(3), + ), + ( + name: Some("depth"), + id: None, + ty: 2, + init: None, + ), + ], + global_variables: [], + const_expressions: [ + Literal(Bool(true)), + Literal(F32(2.3)), + Literal(F32(0.0)), + ], + functions: [], + entry_points: [], +) \ No newline at end of file diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron new file mode 100644 index 0000000000..696e18d1b4 --- /dev/null +++ b/naga/tests/out/ir/overrides.ron @@ -0,0 +1,67 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("has_point_light"), + id: Some(0), + ty: 1, + init: Some(1), + ), + ( + name: Some("specular_param"), + id: Some(1200), + ty: 2, + init: Some(3), + ), + ( + name: Some("gain"), + id: Some(1300), + ty: 2, + init: None, + ), + ( + name: Some("width"), + id: None, + ty: 2, + init: Some(6), + ), + ( + name: Some("depth"), + id: None, + ty: 2, + init: None, + ), + ], + global_variables: [], + const_expressions: [ + Literal(Bool(true)), + Literal(I32(0)), + Literal(F32(2.3)), + Literal(I32(1200)), + Literal(I32(1300)), + Literal(F32(0.0)), + ], + functions: [], + entry_points: [], +) \ No newline at end of file diff --git a/naga/tests/out/ir/shadow.compact.ron b/naga/tests/out/ir/shadow.compact.ron index 4e65180691..fab0f1e2f6 100644 --- a/naga/tests/out/ir/shadow.compact.ron +++ b/naga/tests/out/ir/shadow.compact.ron @@ -253,6 +253,7 @@ init: 22, ), ], + overrides: [], global_variables: [ ( name: Some("t_shadow"), diff --git a/naga/tests/out/ir/shadow.ron b/naga/tests/out/ir/shadow.ron index 0b2310284a..9acbbdaadd 100644 --- a/naga/tests/out/ir/shadow.ron +++ b/naga/tests/out/ir/shadow.ron @@ -456,6 +456,7 @@ init: 38, ), ], + overrides: [], global_variables: [ ( name: Some("t_shadow"), diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index af21857bc7..6d2c3f4242 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -787,6 +787,14 @@ fn convert_wgsl() { "f64", Targets::SPIRV | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), + ( + "overrides", + Targets::IR | Targets::ANALYSIS, // | Targets::SPIRV + // | Targets::METAL + // | Targets::GLSL + // | Targets::HLSL + // | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() { diff --git a/naga/tests/wgsl_errors.rs b/naga/tests/wgsl_errors.rs index 99257457fb..7d4739806b 100644 --- a/naga/tests/wgsl_errors.rs +++ b/naga/tests/wgsl_errors.rs @@ -553,11 +553,11 @@ fn local_var_missing_type() { var x; } "#, - r#"error: variable `x` needs a type + r#"error: declaration of `x` needs a type specifier or initializer ┌─ wgsl:3:21 │ 3 │ var x; - │ ^ definition of `x` + │ ^ needs a type specifier or initializer "#, ); From 7e130d68653506f7eaf6e09f395648caf4fe90f1 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 7 Dec 2023 10:06:00 -0800 Subject: [PATCH 2/3] Compaction never removes overrides. --- naga/src/compact/expressions.rs | 26 +++++++------------------- naga/src/compact/functions.rs | 2 -- naga/src/compact/mod.rs | 32 ++++++++------------------------ 3 files changed, 15 insertions(+), 45 deletions(-) diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index ccc3409613..21c4c9cdc2 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -14,9 +14,6 @@ pub struct ExpressionTracer<'tracer> { /// The used map for `constants`. pub constants_used: &'tracer mut HandleSet, - /// The used map for `overrides`. - pub overrides_used: &'tracer mut HandleSet, - /// The used set for `arena`. /// /// This points to whatever arena holds the expressions we are @@ -92,21 +89,10 @@ impl<'tracer> ExpressionTracer<'tracer> { None => self.expressions_used.insert(init), } } - Ex::Override(handle) => { - self.overrides_used.insert(handle); - // Overrides and expressions are mutually recursive, which - // complicates our nice one-pass algorithm. However, since - // overrides don't refer to each other, we can get around - // this by looking *through* each override and marking its - // initializer as used. Since `expr` refers to the override, - // and the override refers to the initializer, it must - // precede `expr` in the arena. - if let Some(init) = self.overrides[handle].init { - match self.const_expressions_used { - Some(ref mut used) => used.insert(init), - None => self.expressions_used.insert(init), - } - } + Ex::Override(_) => { + // All overrides are considered used by definition. We mark + // their types and initialization expressions as used in + // `compact::compact`, so we have no more work to do here. } Ex::ZeroValue(ty) => self.types_used.insert(ty), Ex::Compose { ty, ref components } => { @@ -239,9 +225,11 @@ impl ModuleMap { | Ex::CallResult(_) | Ex::RayQueryProceedResult => {} + // All overrides are retained, so their handles never change. + Ex::Override(_) => {} + // Expressions that contain handles that need to be adjusted. Ex::Constant(ref mut constant) => self.constants.adjust(constant), - Ex::Override(ref mut override_) => self.overrides.adjust(override_), Ex::ZeroValue(ref mut ty) => self.types.adjust(ty), Ex::Compose { ref mut ty, diff --git a/naga/src/compact/functions.rs b/naga/src/compact/functions.rs index f11df7be1a..98a23acee0 100644 --- a/naga/src/compact/functions.rs +++ b/naga/src/compact/functions.rs @@ -8,7 +8,6 @@ pub struct FunctionTracer<'a> { pub types_used: &'a mut HandleSet, pub constants_used: &'a mut HandleSet, - pub overrides_used: &'a mut HandleSet, pub const_expressions_used: &'a mut HandleSet, /// Function-local expressions used. @@ -54,7 +53,6 @@ impl<'a> FunctionTracer<'a> { types_used: self.types_used, constants_used: self.constants_used, - overrides_used: self.overrides_used, expressions_used: &mut self.expressions_used, const_expressions_used: Some(&mut self.const_expressions_used), } diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index 3aad85f348..843a0ccf53 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -55,8 +55,8 @@ pub fn compact(module: &mut crate::Module) { } // We treat all overrides as used by definition. - for (handle, override_) in module.overrides.iter() { - module_tracer.overrides_used.insert(handle); + for (_, override_) in module.overrides.iter() { + module_tracer.types_used.insert(override_.ty); if let Some(init) = override_.init { module_tracer.const_expressions_used.insert(init); } @@ -107,11 +107,6 @@ pub fn compact(module: &mut crate::Module) { module_tracer.types_used.insert(constant.ty); } } - for (handle, override_) in module.overrides.iter() { - if module_tracer.overrides_used.contains(handle) { - module_tracer.types_used.insert(override_.ty); - } - } // Treat all named types as used. for (handle, ty) in module.types.iter() { @@ -171,19 +166,14 @@ pub fn compact(module: &mut crate::Module) { } }); - // Drop unused overrides in place, reusing existing storage. + // Adjust override types and initializers. log::trace!("adjusting overrides"); - module.overrides.retain_mut(|handle, override_| { - if module_map.overrides.used(handle) { - module_map.types.adjust(&mut override_.ty); - if let Some(init) = override_.init.as_mut() { - module_map.const_expressions.adjust(init); - } - true - } else { - false + for (_, override_) in module.overrides.iter_mut() { + module_map.types.adjust(&mut override_.ty); + if let Some(init) = override_.init.as_mut() { + module_map.const_expressions.adjust(init); } - }); + } // Adjust global variables' types and initializers. log::trace!("adjusting global variables"); @@ -220,7 +210,6 @@ struct ModuleTracer<'module> { module: &'module crate::Module, types_used: HandleSet, constants_used: HandleSet, - overrides_used: HandleSet, const_expressions_used: HandleSet, } @@ -230,7 +219,6 @@ impl<'module> ModuleTracer<'module> { module, types_used: HandleSet::for_arena(&module.types), constants_used: HandleSet::for_arena(&module.constants), - overrides_used: HandleSet::for_arena(&module.overrides), const_expressions_used: HandleSet::for_arena(&module.const_expressions), } } @@ -267,7 +255,6 @@ impl<'module> ModuleTracer<'module> { overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, - overrides_used: &mut self.overrides_used, expressions_used: &mut self.const_expressions_used, const_expressions_used: None, } @@ -283,7 +270,6 @@ impl<'module> ModuleTracer<'module> { overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, - overrides_used: &mut self.overrides_used, const_expressions_used: &mut self.const_expressions_used, expressions_used: HandleSet::for_arena(&function.expressions), } @@ -293,7 +279,6 @@ impl<'module> ModuleTracer<'module> { struct ModuleMap { types: HandleMap, constants: HandleMap, - overrides: HandleMap, const_expressions: HandleMap, } @@ -302,7 +287,6 @@ impl From> for ModuleMap { ModuleMap { types: HandleMap::from_set(used.types_used), constants: HandleMap::from_set(used.constants_used), - overrides: HandleMap::from_set(used.overrides_used), const_expressions: HandleMap::from_set(used.const_expressions_used), } } From 85cd5455219ce98748d3b588bd5f252d352034a5 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 7 Dec 2023 10:59:05 -0800 Subject: [PATCH 3/3] Add test for concretization of abstract override initializers. --- naga/tests/in/overrides.wgsl | 4 +++- naga/tests/out/analysis/overrides.info.ron | 4 ++++ naga/tests/out/ir/overrides.compact.ron | 7 +++++++ naga/tests/out/ir/overrides.ron | 7 +++++++ 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl index 4cbc469f54..803269a656 100644 --- a/naga/tests/in/overrides.wgsl +++ b/naga/tests/in/overrides.wgsl @@ -9,4 +9,6 @@ // override height = 2 * depth; // The default value // (if not set at the API level), // depends on another - // overridable constant. \ No newline at end of file + // overridable constant. + +override inferred_f32 = 2.718; diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron index d9932bf034..9ad1b3914e 100644 --- a/naga/tests/out/analysis/overrides.info.ron +++ b/naga/tests/out/analysis/overrides.info.ron @@ -18,5 +18,9 @@ kind: Float, width: 4, ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), ], ) \ No newline at end of file diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index a68d887c16..5ac9ade6f6 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -52,12 +52,19 @@ ty: 2, init: None, ), + ( + name: Some("inferred_f32"), + id: None, + ty: 2, + init: Some(4), + ), ], global_variables: [], const_expressions: [ Literal(Bool(true)), Literal(F32(2.3)), Literal(F32(0.0)), + Literal(F32(2.718)), ], functions: [], entry_points: [], diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index 696e18d1b4..f087317544 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -52,6 +52,12 @@ ty: 2, init: None, ), + ( + name: Some("inferred_f32"), + id: None, + ty: 2, + init: Some(7), + ), ], global_variables: [], const_expressions: [ @@ -61,6 +67,7 @@ Literal(I32(1200)), Literal(I32(1300)), Literal(F32(0.0)), + Literal(F32(2.718)), ], functions: [], entry_points: [],