diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 06a91dc80f..8f71869712 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -208,7 +208,7 @@ pub fn process_overrides<'a>( // recompute their types and other metadata. For the time being, // do a full re-validation. let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); - let module_info = validator.validate(&module)?; + let module_info = validator.validate_resolved_overrides(&module)?; Ok((Cow::Owned(module), Cow::Owned(module_info))) } diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 9ef3a9edfb..a723ac824b 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -224,7 +224,7 @@ impl super::Validator { crate::TypeInner::Scalar { .. } => {} _ => return Err(ConstExpressionError::InvalidSplatType(value)), }, - _ if global_expr_kind.is_const(handle) || !self.allow_overrides => { + _ if global_expr_kind.is_const(handle) => { return Err(ConstExpressionError::NonFullyEvaluatedConst) } // the constant evaluator will report errors about override-expressions diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 63806067c2..584b24fc53 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -274,7 +274,7 @@ pub struct Validator { valid_expression_list: Vec>, valid_expression_set: HandleSet, override_ids: FastHashSet, - allow_overrides: bool, + overrides_resolved: bool, /// A checklist of expressions that must be visited by a specific kind of /// statement. @@ -359,6 +359,12 @@ pub enum ValidationError { name: String, source: OverrideError, }, + #[error("Override {handle:?} '{name}' is unresolved")] + UnresolvedOverride { + handle: Handle, + name: String, + source: ConstExpressionError, + }, #[error("Global variable {handle:?} '{name}' is invalid")] GlobalVariable { handle: Handle, @@ -465,7 +471,7 @@ impl Validator { valid_expression_list: Vec::new(), valid_expression_set: HandleSet::new(), override_ids: FastHashSet::default(), - allow_overrides: true, + overrides_resolved: false, needs_visit: HandleSet::new(), } } @@ -525,10 +531,6 @@ impl Validator { gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, ) -> Result<(), OverrideError> { - if !self.allow_overrides { - return Err(OverrideError::NotAllowed); - } - let o = &gctx.overrides[handle]; if let Some(id) = o.id { @@ -569,18 +571,18 @@ impl Validator { &mut self, module: &crate::Module, ) -> Result> { - self.allow_overrides = true; + self.overrides_resolved = false; self.validate_impl(module) } /// Check the given module to be valid. /// - /// With the additional restriction that overrides are not present. - pub fn validate_no_overrides( + /// With the additional restriction that overrides are all resolved. + pub fn validate_resolved_overrides( &mut self, module: &crate::Module, ) -> Result> { - self.allow_overrides = false; + self.overrides_resolved = true; self.validate_impl(module) } @@ -623,20 +625,6 @@ impl Validator { } .with_span_handle(handle, &module.types) })?; - if !self.allow_overrides { - if let crate::TypeInner::Array { - size: crate::ArraySize::Pending(_), - .. - } = ty.inner - { - return Err((ValidationError::Type { - handle, - name: ty.name.clone().unwrap_or_default(), - source: TypeError::UnresolvedOverride(handle), - }) - .with_span_handle(handle, &module.types)); - } - } mod_info.type_flags.push(ty_info.flags); self.types[handle.index()] = ty_info; } @@ -691,7 +679,24 @@ impl Validator { source, } .with_span_handle(handle, &module.overrides) - })? + })?; + if self.overrides_resolved { + if let Some(expr) = r#override.init { + self.validate_const_expression( + expr, + module.to_ctx(), + &mod_info, + &global_expr_kind + ).map_err(|source| { + ValidationError::UnresolvedOverride { + handle, + name: r#override.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.overrides) + })?; + } + } } } diff --git a/naga/tests/validation.rs b/naga/tests/validation.rs index 3ffb231e25..b2de03ed4f 100644 --- a/naga/tests/validation.rs +++ b/naga/tests/validation.rs @@ -307,7 +307,7 @@ fn main() {{ ); let module = naga::front::wgsl::parse_str(&source).unwrap(); let err = valid::Validator::new(Default::default(), valid::Capabilities::all()) - .validate_no_overrides(&module) + .validate(&module) .expect_err("module should be invalid"); assert_eq!(err.emit_to_string(&source), expected_err); } @@ -381,7 +381,7 @@ fn incompatible_interpolation_and_sampling_types() { for (invalid_source, invalid_module, interpolation, sampling, interpolate_attr) in invalid_cases { let err = valid::Validator::new(Default::default(), valid::Capabilities::all()) - .validate_no_overrides(&invalid_module) + .validate(&invalid_module) .expect_err(&format!( "module should be invalid for {interpolate_attr:?}" )); @@ -679,7 +679,7 @@ error: Entry point main at Compute is invalid for (source, expected_err) in cases { let module = naga::front::wgsl::parse_str(source).unwrap(); let err = valid::Validator::new(Default::default(), valid::Capabilities::all()) - .validate_no_overrides(&module) + .validate(&module) .expect_err("module should be invalid"); println!("{}", err.emit_to_string(source)); assert_eq!(err.emit_to_string(source), expected_err);