diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 1556371df11..d128c855ca9 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -404,6 +404,7 @@ fn write_function_expressions( let (label, color_id) = match *expression { E::Literal(_) => ("Literal".into(), 2), E::Constant(_) => ("Constant".into(), 2), + E::Override(_) => ("Override".into(), 2), E::ZeroValue(_) => ("ZeroValue".into(), 2), E::Compose { ref components, .. } => { payload = Some(Payload::Arguments(components)); diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 6fad36eb675..4b67522053d 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2520,6 +2520,7 @@ impl<'a, W: Write> Writer<'a, W> { |writer, expr| writer.write_expr(expr, ctx), )?; } + Expression::Override(_) => return Err(Error::Custom("overrides are WIP".into())), // `Access` is applied to arrays, vectors and matrices and is written as indexing Expression::Access { base, index } => { self.write_expr(base, ctx)?; diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 081ee4f2a9b..5aad4a20770 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2156,6 +2156,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } + Expression::Override(_) => { + return Err(Error::Unimplemented("overrides are WIP".into())) + } // All of the multiplication can be expressed as `mul`, // except vector * vector, which needs to use the "*" operator. Expression::Binary { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 1e496b5f50f..c3d4d734c97 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1401,6 +1401,9 @@ impl Writer { |writer, context, expr| writer.put_expression(expr, context, true), )?; } + crate::Expression::Override(_) => { + return Err(Error::FeatureNotImplemented("overrides are WIP".into())) + } crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => { // This is an acceptable place to generate a `ReadZeroSkipWrite` check. diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 6c96fa09e39..4eca34168c5 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -239,6 +239,9 @@ impl<'w> BlockContext<'w> { let init = self.ir_module.constants[handle].init; self.writer.constant_ids[init.index()] } + crate::Expression::Override(_) => { + return Err(Error::FeatureNotImplemented("overrides are WIP")) + } crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), crate::Expression::Compose { ty, ref components } => { self.temp_list.clear(); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index c737934f5e3..bd4d5f17d73 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1190,6 +1190,9 @@ impl Writer { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } + Expression::Override(_) => { + return Err(Error::Unimplemented("overrides are WIP".into())) + } Expression::FunctionArgument(pos) => { let name_key = func_ctx.argument_key(pos); let name = &self.names[&name_key]; diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index 301bbe32405..21c4c9cdc20 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -3,6 +3,7 @@ use crate::arena::{Arena, Handle}; pub struct ExpressionTracer<'tracer> { pub constants: &'tracer Arena, + pub overrides: &'tracer Arena, /// The arena in which we are currently tracing expressions. pub expressions: &'tracer Arena, @@ -88,6 +89,11 @@ impl<'tracer> ExpressionTracer<'tracer> { None => self.expressions_used.insert(init), } } + Ex::Override(_) => { + // All overrides are considered used by definition. We mark + // their types and initialization expressions as used in + // `compact::compact`, so we have no more work to do here. + } Ex::ZeroValue(ty) => self.types_used.insert(ty), Ex::Compose { ty, ref components } => { self.types_used.insert(ty); @@ -219,6 +225,9 @@ impl ModuleMap { | Ex::CallResult(_) | Ex::RayQueryProceedResult => {} + // All overrides are retained, so their handles never change. + Ex::Override(_) => {} + // Expressions that contain handles that need to be adjusted. Ex::Constant(ref mut constant) => self.constants.adjust(constant), Ex::ZeroValue(ref mut ty) => self.types.adjust(ty), diff --git a/naga/src/compact/functions.rs b/naga/src/compact/functions.rs index b0d08c7e96e..98a23acee0a 100644 --- a/naga/src/compact/functions.rs +++ b/naga/src/compact/functions.rs @@ -4,6 +4,7 @@ use super::{FunctionMap, ModuleMap}; pub struct FunctionTracer<'a> { pub function: &'a crate::Function, pub constants: &'a crate::Arena, + pub overrides: &'a crate::Arena, pub types_used: &'a mut HandleSet, pub constants_used: &'a mut HandleSet, @@ -47,6 +48,7 @@ impl<'a> FunctionTracer<'a> { fn as_expression(&mut self) -> super::expressions::ExpressionTracer { super::expressions::ExpressionTracer { constants: self.constants, + overrides: self.overrides, expressions: &self.function.expressions, types_used: self.types_used, diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index b4e57ed5c9f..2b49d349952 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -54,6 +54,14 @@ pub fn compact(module: &mut crate::Module) { } } + // We treat all overrides as used by definition. + for (_, override_) in module.overrides.iter() { + module_tracer.types_used.insert(override_.ty); + if let Some(init) = override_.init { + module_tracer.const_expressions_used.insert(init); + } + } + // We assume that all functions are used. // // Observe which types, constant expressions, constants, and @@ -158,6 +166,15 @@ pub fn compact(module: &mut crate::Module) { } }); + // Adjust override types and initializers. + log::trace!("adjusting overrides"); + for (_, override_) in module.overrides.iter_mut() { + module_map.types.adjust(&mut override_.ty); + if let Some(init) = override_.init.as_mut() { + module_map.const_expressions.adjust(init); + } + } + // Adjust global variables' types and initializers. log::trace!("adjusting global variables"); for (_, global) in module.global_variables.iter_mut() { @@ -235,6 +252,7 @@ impl<'module> ModuleTracer<'module> { expressions::ExpressionTracer { expressions: &self.module.const_expressions, constants: &self.module.constants, + overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, expressions_used: &mut self.const_expressions_used, @@ -249,6 +267,7 @@ impl<'module> ModuleTracer<'module> { FunctionTracer { function, constants: &self.module.constants, + overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, const_expressions_used: &mut self.const_expressions_used, diff --git a/naga/src/front/spv/function.rs b/naga/src/front/spv/function.rs index 198d9c52dd2..8a7e736eddd 100644 --- a/naga/src/front/spv/function.rs +++ b/naga/src/front/spv/function.rs @@ -128,6 +128,7 @@ impl> super::Frontend { expressions: &mut fun.expressions, local_arena: &mut fun.local_variables, const_arena: &mut module.constants, + overrides: &mut module.overrides, const_expressions: &mut module.const_expressions, type_arena: &module.types, global_arena: &module.global_variables, @@ -573,6 +574,7 @@ impl<'function> BlockContext<'function> { crate::proc::GlobalCtx { types: self.type_arena, constants: self.const_arena, + overrides: self.overrides, const_expressions: self.const_expressions, } } diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 5f0ebdf7659..ba42b1706df 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -531,6 +531,7 @@ struct BlockContext<'function> { local_arena: &'function mut Arena, /// Constants arena of the module being processed const_arena: &'function mut Arena, + overrides: &'function mut Arena, const_expressions: &'function mut Arena, /// Type arena of the module being processed type_arena: &'function UniqueArena, @@ -3932,7 +3933,7 @@ impl> Frontend { Op::TypeImage => self.parse_type_image(inst, &mut module), Op::TypeSampledImage => self.parse_type_sampled_image(inst), Op::TypeSampler => self.parse_type_sampler(inst, &mut module), - Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module), + Op::Constant => self.parse_constant(inst, &mut module), Op::ConstantComposite => self.parse_composite_constant(inst, &mut module), Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module), Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module), diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 07e68f8dd93..e607f0794f8 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -190,7 +190,7 @@ pub enum Error<'a> { expected: String, got: String, }, - MissingType(Span), + DeclMissingTypeAndInit(Span), MissingAttribute(&'static str, Span), InvalidAtomicPointer(Span), InvalidAtomicOperandType(Span), @@ -269,6 +269,7 @@ pub enum Error<'a> { scalar: String, inner: ConstantEvaluatorError, }, + PipelineConstantIDValue(Span), } impl<'a> Error<'a> { @@ -518,11 +519,11 @@ impl<'a> Error<'a> { notes: vec![], } } - Error::MissingType(name_span) => ParseError { - message: format!("variable `{}` needs a type", &source[name_span]), + Error::DeclMissingTypeAndInit(name_span) => ParseError { + message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]), labels: vec![( name_span, - format!("definition of `{}`", &source[name_span]).into(), + "needs a type specifier or initializer".into(), )], notes: vec![], }, @@ -770,6 +771,14 @@ impl<'a> Error<'a> { format!("the expression should have been converted to have {} scalar type", scalar), ] }, + Error::PipelineConstantIDValue(span) => ParseError { + message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(), + labels: vec![( + span, + "must be between 0 and 65535 inclusive".into(), + )], + notes: vec![], + }, } } } diff --git a/naga/src/front/wgsl/index.rs b/naga/src/front/wgsl/index.rs index a5524fe8f11..593405508fd 100644 --- a/naga/src/front/wgsl/index.rs +++ b/naga/src/front/wgsl/index.rs @@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> { ast::GlobalDeclKind::Fn(ref f) => f.name, ast::GlobalDeclKind::Var(ref v) => v.name, ast::GlobalDeclKind::Const(ref c) => c.name, + ast::GlobalDeclKind::Override(ref o) => o.name, ast::GlobalDeclKind::Struct(ref s) => s.name, ast::GlobalDeclKind::Type(ref t) => t.name, } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index c3aa6a932b2..87b3732effc 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -786,6 +786,7 @@ enum LoweredGlobalDecl { Function(Handle), Var(Handle), Const(Handle), + Override(Handle), Type(Handle), EntryPoint, } @@ -965,6 +966,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ctx.globals .insert(c.name.name, LoweredGlobalDecl::Const(handle)); } + ast::GlobalDeclKind::Override(ref o) => { + let init = o + .init + .map(|init| self.expression(init, &mut ctx.as_const())) + .transpose()?; + let inferred_type = init + .map(|init| ctx.as_const().register_type(init)) + .transpose()?; + + let explicit_ty = + o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx)) + .transpose()?; + + let id = + o.id.map(|id| self.const_u32(id, &mut ctx.as_const())) + .transpose()?; + + let id = if let Some((id, id_span)) = id { + Some( + u16::try_from(id) + .map_err(|_| Error::PipelineConstantIDValue(id_span))?, + ) + } else { + None + }; + + let ty = match (explicit_ty, inferred_type) { + (Some(explicit_ty), Some(inferred_type)) => { + if explicit_ty == inferred_type { + explicit_ty + } else { + let gctx = ctx.module.to_ctx(); + return Err(Error::InitializationTypeMismatch { + name: o.name.span, + expected: explicit_ty.to_wgsl(&gctx), + got: inferred_type.to_wgsl(&gctx), + }); + } + } + (Some(explicit_ty), None) => explicit_ty, + (None, Some(inferred_type)) => inferred_type, + (None, None) => { + return Err(Error::DeclMissingTypeAndInit(o.name.span)); + } + }; + + let handle = ctx.module.overrides.append( + crate::Override { + name: Some(o.name.name.to_string()), + id, + ty, + init, + }, + span, + ); + + ctx.globals + .insert(o.name.name, LoweredGlobalDecl::Override(handle)); + } ast::GlobalDeclKind::Struct(ref s) => { let handle = self.r#struct(s, span, &mut ctx)?; ctx.globals @@ -1202,7 +1262,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty = explicit_ty; initializer = None; } - (None, None) => return Err(Error::MissingType(v.name.span)), + (None, None) => return Err(Error::DeclMissingTypeAndInit(v.name.span)), } let (const_initializer, initializer) = { @@ -1816,9 +1876,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; Ok(Some(handle)) } - Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => { - Err(Error::Unexpected(function.span, ExpectedToken::Function)) - } + Some( + &LoweredGlobalDecl::Const(_) + | &LoweredGlobalDecl::Override(_) + | &LoweredGlobalDecl::Var(_), + ) => Err(Error::Unexpected(function.span, ExpectedToken::Function)), Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)), Some(&LoweredGlobalDecl::Function(function)) => { let arguments = arguments diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index dbaac523cbe..ea8013ee7c2 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> { Fn(Function<'a>), Var(GlobalVariable<'a>), Const(Const<'a>), + Override(Override<'a>), Struct(Struct<'a>), Type(TypeAlias<'a>), } @@ -200,6 +201,14 @@ pub struct Const<'a> { pub init: Handle>, } +#[derive(Debug)] +pub struct Override<'a> { + pub name: Ident<'a>, + pub id: Option>>, + pub ty: Option>>, + pub init: Option>>, +} + /// The size of an [`Array`] or [`BindingArray`]. /// /// [`Array`]: Type::Array diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index 51fc2f013b1..810e67f9fe9 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -2170,6 +2170,7 @@ impl Parser { let mut early_depth_test = ParsedAttribute::default(); let (mut bind_index, mut bind_group) = (ParsedAttribute::default(), ParsedAttribute::default()); + let mut id = ParsedAttribute::default(); let mut dependencies = FastIndexSet::default(); let mut ctx = ExpressionContext { @@ -2193,6 +2194,11 @@ impl Parser { bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?; lexer.expect(Token::Paren(')'))?; } + ("id", name_span) => { + lexer.expect(Token::Paren('('))?; + id.set(self.general_expression(lexer, &mut ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } ("vertex", name_span) => { stage.set(crate::ShaderStage::Vertex, name_span)?; } @@ -2283,6 +2289,30 @@ impl Parser { Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init })) } + (Token::Word("override"), _) => { + let name = lexer.next_ident()?; + + let ty = if lexer.skip(Token::Separator(':')) { + Some(self.type_decl(lexer, &mut ctx)?) + } else { + None + }; + + let init = if lexer.skip(Token::Operation('=')) { + Some(self.general_expression(lexer, &mut ctx)?) + } else { + None + }; + + lexer.expect(Token::Separator(';'))?; + + Some(ast::GlobalDeclKind::Override(ast::Override { + name, + id: id.value, + ty, + init, + })) + } (Token::Word("var"), _) => { let mut var = self.variable_decl(lexer, &mut ctx)?; var.binding = binding.take(); diff --git a/naga/src/front/wgsl/to_wgsl.rs b/naga/src/front/wgsl/to_wgsl.rs index c8331ace095..ba6063ab46d 100644 --- a/naga/src/front/wgsl/to_wgsl.rs +++ b/naga/src/front/wgsl/to_wgsl.rs @@ -226,6 +226,7 @@ mod tests { let gctx = crate::proc::GlobalCtx { types: &types, constants: &crate::Arena::new(), + overrides: &crate::Arena::new(), const_expressions: &crate::Arena::new(), }; let array = crate::TypeInner::Array { diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 372214600d8..f5280e489e9 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -175,7 +175,7 @@ tree. A Naga *constant expression* is one of the following [`Expression`] variants, whose operands (if any) are also constant expressions: - [`Literal`] -- [`Constant`], for [`Constant`s][const_type] whose `override` is `None` +- [`Constant`], for [`Constant`]s - [`ZeroValue`], for fixed-size types - [`Compose`] - [`Access`] @@ -194,8 +194,7 @@ A constant expression can be evaluated at module translation time. ## Override expressions A Naga *override expression* is the same as a [constant expression], -except that it is also allowed to refer to [`Constant`s][const_type] -whose `override` is something other than `None`. +except that it is also allowed to reference other [`Override`]s. An override expression can be evaluated at pipeline creation time. @@ -238,8 +237,6 @@ An override expression can be evaluated at pipeline creation time. [`Math`]: Expression::Math [`As`]: Expression::As -[const_type]: Constant - [constant expression]: index.html#constant-expressions */ @@ -888,6 +885,25 @@ pub enum Literal { AbstractFloat(f64), } +/// Pipeline-overridable constant. +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct Override { + pub name: Option, + /// Pipeline Constant ID. + pub id: Option, + pub ty: Handle, + + /// The default value of the pipeline-overridable constant. + /// + /// This [`Handle`] refers to [`Module::const_expressions`], not + /// any [`Function::expressions`] arena. + pub init: Option>, +} + /// Constant value. #[derive(Debug, PartialEq)] #[cfg_attr(feature = "clone", derive(Clone))] @@ -902,13 +918,6 @@ pub struct Constant { /// /// This [`Handle`] refers to [`Module::const_expressions`], not /// any [`Function::expressions`] arena. - /// - /// If `override` is `None`, then this must be a Naga - /// [constant expression]. Otherwise, this may be a Naga - /// [override expression] or [constant expression]. - /// - /// [constant expression]: index.html#constant-expressions - /// [override expression]: index.html#override-expressions pub init: Handle, } @@ -1294,6 +1303,8 @@ pub enum Expression { Literal(Literal), /// Constant value. Constant(Handle), + /// Pipeline-overridable constant. + Override(Handle), /// Zero value of a type. ZeroValue(Handle), /// Composite expression. @@ -2036,6 +2047,8 @@ pub struct Module { pub special_types: SpecialTypes, /// Arena for the constants defined in this module. pub constants: Arena, + /// Arena for the pipeline-overridable constants defined in this module. + pub overrides: Arena, /// Arena for the global variables defined in this module. pub global_variables: Arena, /// [Constant expressions] and [override expressions] used by this module. diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index e54e3f3bc01..3f1f10a2622 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -4,8 +4,8 @@ use arrayvec::ArrayVec; use crate::{ arena::{Arena, Handle, UniqueArena}, - ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner, - UnaryOperator, + ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type, + TypeInner, UnaryOperator, }; /// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating @@ -289,6 +289,9 @@ pub struct ConstantEvaluator<'a> { /// The module's constant arena. constants: &'a Arena, + /// The module's override arena. + overrides: &'a Arena, + /// The arena to which we are contributing expressions. expressions: &'a mut Arena, @@ -454,6 +457,7 @@ impl<'a> ConstantEvaluator<'a> { behavior, types: &mut module.types, constants: &module.constants, + overrides: &module.overrides, expressions: &mut module.const_expressions, function_local_data: None, } @@ -513,6 +517,7 @@ impl<'a> ConstantEvaluator<'a> { behavior, types: &mut module.types, constants: &module.constants, + overrides: &module.overrides, expressions, function_local_data: Some(FunctionLocalData { const_expressions: &module.const_expressions, @@ -527,6 +532,7 @@ impl<'a> ConstantEvaluator<'a> { crate::proc::GlobalCtx { types: self.types, constants: self.constants, + overrides: self.overrides, const_expressions: match self.function_local_data { Some(ref data) => data.const_expressions, None => self.expressions, @@ -603,6 +609,9 @@ impl<'a> ConstantEvaluator<'a> { // This is mainly done to avoid having constants pointing to other constants. Ok(self.constants[c].init) } + Expression::Override(_) => Err(ConstantEvaluatorError::NotImplemented( + "overrides are WIP".into(), + )), Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { self.register_evaluated_expr(expr.clone(), span) } @@ -1892,6 +1901,7 @@ mod tests { fn unary_op() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let scalar_ty = types.insert( @@ -1970,6 +1980,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; @@ -2021,6 +2032,7 @@ mod tests { fn cast() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let scalar_ty = types.insert( @@ -2053,6 +2065,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; @@ -2071,6 +2084,7 @@ mod tests { fn access() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let matrix_ty = types.insert( @@ -2168,6 +2182,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; @@ -2221,6 +2236,7 @@ mod tests { fn compose_of_constants() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let i32_ty = types.insert( @@ -2258,6 +2274,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; @@ -2300,6 +2317,7 @@ mod tests { fn splat_of_constant() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let i32_ty = types.insert( @@ -2337,6 +2355,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 93315faa716..96c0dbcff2c 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -633,6 +633,7 @@ impl crate::Module { GlobalCtx { types: &self.types, constants: &self.constants, + overrides: &self.overrides, const_expressions: &self.const_expressions, } } @@ -648,6 +649,7 @@ pub(super) enum U32EvalError { pub struct GlobalCtx<'a> { pub types: &'a crate::UniqueArena, pub constants: &'a crate::Arena, + pub overrides: &'a crate::Arena, pub const_expressions: &'a crate::Arena, } diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 9c4403445c1..845b35cb4de 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -185,6 +185,7 @@ pub enum ResolveError { pub struct ResolveContext<'a> { pub constants: &'a Arena, + pub overrides: &'a Arena, pub types: &'a UniqueArena, pub special_types: &'a crate::SpecialTypes, pub global_vars: &'a Arena, @@ -202,6 +203,7 @@ impl<'a> ResolveContext<'a> { ) -> Self { Self { constants: &module.constants, + overrides: &module.overrides, types: &module.types, special_types: &module.special_types, global_vars: &module.global_variables, @@ -407,6 +409,7 @@ impl<'a> ResolveContext<'a> { }, crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()), crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty), + crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty), crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty), crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty), crate::Expression::FunctionArgument(index) => { diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index df6fc5e9b02..17c76b2738e 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -527,7 +527,7 @@ impl FunctionInfo { non_uniform_result: self.add_ref(vector), requirements: UniformityRequirements::empty(), }, - E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(), + E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(), E::Compose { ref components, .. } => { let non_uniform_result = components .iter() @@ -1139,6 +1139,7 @@ fn uniform_control_flow() { }; let resolve_context = ResolveContext { constants: &Arena::new(), + overrides: &Arena::new(), types: &type_arena, special_types: &crate::SpecialTypes::default(), global_vars: &global_var_arena, diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 54d8b3b3570..f41948b9102 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -343,7 +343,7 @@ impl super::Validator { self.validate_literal(literal)?; ShaderStages::all() } - E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(), E::Compose { ref components, ty } => { validate_compose( ty, diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 1884c01303f..0643b1c9f51 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -31,6 +31,7 @@ impl super::Validator { pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> { let &crate::Module { ref constants, + ref overrides, ref entry_points, ref functions, ref global_variables, @@ -68,7 +69,7 @@ impl super::Validator { } for handle_and_expr in const_expressions.iter() { - Self::validate_const_expression_handles(handle_and_expr, constants, types)?; + Self::validate_const_expression_handles(handle_and_expr, constants, overrides, types)?; } let validate_type = |handle| Self::validate_type_handle(handle, types); @@ -81,6 +82,19 @@ impl super::Validator { validate_const_expr(init)?; } + for (_handle, override_) in overrides.iter() { + let &crate::Override { + name: _, + id: _, + ty, + init, + } = override_; + validate_type(ty)?; + if let Some(init_expr) = init { + validate_const_expr(init_expr)?; + } + } + for (_handle, global_variable) in global_variables.iter() { let &crate::GlobalVariable { name: _, @@ -135,6 +149,7 @@ impl super::Validator { Self::validate_expression_handles( handle_and_expr, constants, + overrides, const_expressions, types, local_variables, @@ -181,6 +196,13 @@ impl super::Validator { handle.check_valid_for(constants).map(|_| ()) } + fn validate_override_handle( + handle: Handle, + overrides: &Arena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(overrides).map(|_| ()) + } + fn validate_expression_handle( handle: Handle, expressions: &Arena, @@ -198,9 +220,11 @@ impl super::Validator { fn validate_const_expression_handles( (handle, expression): (Handle, &crate::Expression), constants: &Arena, + overrides: &Arena, types: &UniqueArena, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_type = |handle| Self::validate_type_handle(handle, types); match *expression { @@ -209,6 +233,12 @@ impl super::Validator { validate_constant(constant)?; handle.check_dep(constants[constant].init)?; } + crate::Expression::Override(override_) => { + validate_override(override_)?; + if let Some(init) = overrides[override_].init { + handle.check_dep(init)?; + } + } crate::Expression::ZeroValue(ty) => { validate_type(ty)?; } @@ -225,6 +255,7 @@ impl super::Validator { fn validate_expression_handles( (handle, expression): (Handle, &crate::Expression), constants: &Arena, + overrides: &Arena, const_expressions: &Arena, types: &UniqueArena, local_variables: &Arena, @@ -234,6 +265,7 @@ impl super::Validator { current_function: Option>, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_const_expr = |handle| Self::validate_expression_handle(handle, const_expressions); let validate_type = |handle| Self::validate_type_handle(handle, types); @@ -255,6 +287,9 @@ impl super::Validator { crate::Expression::Constant(constant) => { validate_constant(constant)?; } + crate::Expression::Override(override_) => { + validate_override(override_)?; + } crate::Expression::ZeroValue(ty) => { validate_type(ty)?; } @@ -659,6 +694,7 @@ fn constant_deps() { let mut const_exprs = Arena::new(); let mut fun_exprs = Arena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let i32_handle = types.insert( Type { @@ -686,6 +722,7 @@ fn constant_deps() { assert!(super::Validator::validate_const_expression_handles( handle_and_expr, &constants, + &overrides, &types, ) .is_err()); diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 388495a3ac1..c0dfd44e1ab 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -182,6 +182,16 @@ pub enum ConstantError { NonConstructibleType, } +#[derive(Clone, Debug, thiserror::Error)] +pub enum OverrideError { + #[error("The type doesn't match the override")] + InvalidType, + #[error("The type is not constructible")] + NonConstructibleType, + #[error("The type is not a scalar")] + TypeNotScalar, +} + #[derive(Clone, Debug, thiserror::Error)] pub enum ValidationError { #[error(transparent)] @@ -205,6 +215,12 @@ pub enum ValidationError { name: String, source: ConstantError, }, + #[error("Override {handle:?} '{name}' is invalid")] + Override { + handle: Handle, + name: String, + source: OverrideError, + }, #[error("Global variable {handle:?} '{name}' is invalid")] GlobalVariable { handle: Handle, @@ -327,6 +343,35 @@ impl Validator { Ok(()) } + fn validate_override( + &self, + handle: Handle, + gctx: crate::proc::GlobalCtx, + mod_info: &ModuleInfo, + ) -> Result<(), OverrideError> { + let o = &gctx.overrides[handle]; + + let type_info = &self.types[o.ty.index()]; + if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) { + return Err(OverrideError::NonConstructibleType); + } + + let decl_ty = &gctx.types[o.ty].inner; + match decl_ty { + &crate::TypeInner::Scalar(_) => {} + _ => return Err(OverrideError::TypeNotScalar), + } + + if let Some(init) = o.init { + let init_ty = mod_info[init].inner_with(gctx.types); + if !decl_ty.equivalent(init_ty, gctx.types) { + return Err(OverrideError::InvalidType); + } + } + + Ok(()) + } + /// Check the given module to be valid. pub fn validate( &mut self, @@ -404,6 +449,18 @@ impl Validator { .with_span_handle(handle, &module.constants) })? } + + for (handle, override_) in module.overrides.iter() { + self.validate_override(handle, module.to_ctx(), &mod_info) + .map_err(|source| { + ValidationError::Override { + handle, + name: override_.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.overrides) + })? + } } for (var_handle, var) in module.global_variables.iter() { diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl new file mode 100644 index 00000000000..803269a656b --- /dev/null +++ b/naga/tests/in/overrides.wgsl @@ -0,0 +1,14 @@ +@id(0) override has_point_light: bool = true; // Algorithmic control +@id(1200) override specular_param: f32 = 2.3; // Numeric control +@id(1300) override gain: f32; // Must be overridden + override width: f32 = 0.0; // Specified at the API level using + // the name "width". + override depth: f32; // Specified at the API level using + // the name "depth". + // Must be overridden. + // override height = 2 * depth; // The default value + // (if not set at the API level), + // depends on another + // overridable constant. + +override inferred_f32 = 2.718; diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron new file mode 100644 index 00000000000..9ad1b3914ed --- /dev/null +++ b/naga/tests/out/analysis/overrides.info.ron @@ -0,0 +1,26 @@ +( + type_flags: [ + ("DATA | SIZED | COPY | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), + ], + functions: [], + entry_points: [], + const_expression_types: [ + Value(Scalar(( + kind: Bool, + width: 1, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), + ], +) \ No newline at end of file diff --git a/naga/tests/out/ir/access.compact.ron b/naga/tests/out/ir/access.compact.ron index 0670534e90c..37ace5283f3 100644 --- a/naga/tests/out/ir/access.compact.ron +++ b/naga/tests/out/ir/access.compact.ron @@ -324,6 +324,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("global_const"), diff --git a/naga/tests/out/ir/access.ron b/naga/tests/out/ir/access.ron index 0670534e90c..37ace5283f3 100644 --- a/naga/tests/out/ir/access.ron +++ b/naga/tests/out/ir/access.ron @@ -324,6 +324,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("global_const"), diff --git a/naga/tests/out/ir/collatz.compact.ron b/naga/tests/out/ir/collatz.compact.ron index cfc3bfa0ee4..fe4af55c1b3 100644 --- a/naga/tests/out/ir/collatz.compact.ron +++ b/naga/tests/out/ir/collatz.compact.ron @@ -46,6 +46,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("v_indices"), diff --git a/naga/tests/out/ir/collatz.ron b/naga/tests/out/ir/collatz.ron index cfc3bfa0ee4..fe4af55c1b3 100644 --- a/naga/tests/out/ir/collatz.ron +++ b/naga/tests/out/ir/collatz.ron @@ -46,6 +46,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("v_indices"), diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron new file mode 100644 index 00000000000..5ac9ade6f68 --- /dev/null +++ b/naga/tests/out/ir/overrides.compact.ron @@ -0,0 +1,71 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("has_point_light"), + id: Some(0), + ty: 1, + init: Some(1), + ), + ( + name: Some("specular_param"), + id: Some(1200), + ty: 2, + init: Some(2), + ), + ( + name: Some("gain"), + id: Some(1300), + ty: 2, + init: None, + ), + ( + name: Some("width"), + id: None, + ty: 2, + init: Some(3), + ), + ( + name: Some("depth"), + id: None, + ty: 2, + init: None, + ), + ( + name: Some("inferred_f32"), + id: None, + ty: 2, + init: Some(4), + ), + ], + global_variables: [], + const_expressions: [ + Literal(Bool(true)), + Literal(F32(2.3)), + Literal(F32(0.0)), + Literal(F32(2.718)), + ], + functions: [], + entry_points: [], +) \ No newline at end of file diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron new file mode 100644 index 00000000000..5ac9ade6f68 --- /dev/null +++ b/naga/tests/out/ir/overrides.ron @@ -0,0 +1,71 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("has_point_light"), + id: Some(0), + ty: 1, + init: Some(1), + ), + ( + name: Some("specular_param"), + id: Some(1200), + ty: 2, + init: Some(2), + ), + ( + name: Some("gain"), + id: Some(1300), + ty: 2, + init: None, + ), + ( + name: Some("width"), + id: None, + ty: 2, + init: Some(3), + ), + ( + name: Some("depth"), + id: None, + ty: 2, + init: None, + ), + ( + name: Some("inferred_f32"), + id: None, + ty: 2, + init: Some(4), + ), + ], + global_variables: [], + const_expressions: [ + Literal(Bool(true)), + Literal(F32(2.3)), + Literal(F32(0.0)), + Literal(F32(2.718)), + ], + functions: [], + entry_points: [], +) \ No newline at end of file diff --git a/naga/tests/out/ir/shadow.compact.ron b/naga/tests/out/ir/shadow.compact.ron index 4e651806911..fab0f1e2f60 100644 --- a/naga/tests/out/ir/shadow.compact.ron +++ b/naga/tests/out/ir/shadow.compact.ron @@ -253,6 +253,7 @@ init: 22, ), ], + overrides: [], global_variables: [ ( name: Some("t_shadow"), diff --git a/naga/tests/out/ir/shadow.ron b/naga/tests/out/ir/shadow.ron index 0b2310284a7..9acbbdaadd4 100644 --- a/naga/tests/out/ir/shadow.ron +++ b/naga/tests/out/ir/shadow.ron @@ -456,6 +456,7 @@ init: 38, ), ], + overrides: [], global_variables: [ ( name: Some("t_shadow"), diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 5bb48bbfe4e..4d078d86b21 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -811,6 +811,14 @@ fn convert_wgsl() { "abstract-types-operators", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL, ), + ( + "overrides", + Targets::IR | Targets::ANALYSIS, // | Targets::SPIRV + // | Targets::METAL + // | Targets::GLSL + // | Targets::HLSL + // | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() { diff --git a/naga/tests/wgsl_errors.rs b/naga/tests/wgsl_errors.rs index 5624b3098e3..a9c80536613 100644 --- a/naga/tests/wgsl_errors.rs +++ b/naga/tests/wgsl_errors.rs @@ -570,11 +570,11 @@ fn local_var_missing_type() { var x; } "#, - r#"error: variable `x` needs a type + r#"error: declaration of `x` needs a type specifier or initializer ┌─ wgsl:3:21 │ 3 │ var x; - │ ^ definition of `x` + │ ^ needs a type specifier or initializer "#, );