diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index a6d95ce30b..c593d1fb72 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -157,8 +157,8 @@ impl StatementGraph { for case in cases { let (case_id, case_last) = self.add(&case.body, targets); let label = match case.value { - crate::SwitchValue::Integer(_) => "case", crate::SwitchValue::Default => "default", + _ => "case", }; self.flow.push((id, case_id, label)); // Link the last node of the branch to the merge node diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 264a9e7b1f..ed72ad6578 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1908,21 +1908,13 @@ impl<'a, W: Write> Writer<'a, W> { write!(self.out, "switch(")?; self.write_expr(selector, ctx)?; writeln!(self.out, ") {{")?; - let type_postfix = match *ctx.info[selector].ty.inner_with(&self.module.types) { - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Uint, - .. - } => "u", - _ => "", - }; // Write all cases let l2 = level.next(); for case in cases { match case.value { - crate::SwitchValue::Integer(value) => { - write!(self.out, "{l2}case {value}{type_postfix}:")? - } + crate::SwitchValue::I32(value) => write!(self.out, "{l2}case {value}:")?, + crate::SwitchValue::U32(value) => write!(self.out, "{l2}case {value}u:")?, crate::SwitchValue::Default => write!(self.out, "{l2}default:")?, } diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 4ae4fc9f60..18128af687 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1883,13 +1883,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, "switch(")?; self.write_expr(module, selector, func_ctx)?; writeln!(self.out, ") {{")?; - let type_postfix = match *func_ctx.info[selector].ty.inner_with(&module.types) { - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Uint, - .. - } => "u", - _ => "", - }; // Write all cases let indent_level_1 = level.next(); @@ -1897,8 +1890,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { for (i, case) in cases.iter().enumerate() { match case.value { - crate::SwitchValue::Integer(value) => { - write!(self.out, "{indent_level_1}case {value}{type_postfix}:")? + crate::SwitchValue::I32(value) => { + write!(self.out, "{indent_level_1}case {value}:")? + } + crate::SwitchValue::U32(value) => { + write!(self.out, "{indent_level_1}case {value}u:")? } crate::SwitchValue::Default => { write!(self.out, "{indent_level_1}default:")? diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index d50ac2d496..8454ddd903 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -2614,19 +2614,15 @@ impl Writer { } => { write!(self.out, "{level}switch(")?; self.put_expression(selector, &context.expression, true)?; - let type_postfix = match *context.expression.resolve_type(selector) { - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Uint, - .. - } => "u", - _ => "", - }; writeln!(self.out, ") {{")?; let lcase = level.next(); for case in cases.iter() { match case.value { - crate::SwitchValue::Integer(value) => { - write!(self.out, "{lcase}case {value}{type_postfix}:")?; + crate::SwitchValue::I32(value) => { + write!(self.out, "{lcase}case {value}:")?; + } + crate::SwitchValue::U32(value) => { + write!(self.out, "{lcase}case {value}u:")?; } crate::SwitchValue::Default => { write!(self.out, "{lcase}default:")?; diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 17eabea0c0..d1f0bc7140 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -1776,12 +1776,15 @@ impl<'w> BlockContext<'w> { case_ids.push(label_id); match case.value { - crate::SwitchValue::Integer(value) => { + crate::SwitchValue::I32(value) => { raw_cases.push(super::instructions::Case { value: value as Word, label_id, }); } + crate::SwitchValue::U32(value) => { + raw_cases.push(super::instructions::Case { value, label_id }); + } crate::SwitchValue::Default => { default_id = Some(label_id); } diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 9c1034614e..4a28cb25ec 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -865,14 +865,6 @@ impl Writer { self.write_expr(module, selector, func_ctx)?; writeln!(self.out, " {{")?; - let type_postfix = match *func_ctx.info[selector].ty.inner_with(&module.types) { - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Uint, - .. - } => "u", - _ => "", - }; - let l2 = level.next(); let mut new_case = true; for case in cases { @@ -884,11 +876,17 @@ impl Writer { } match case.value { - crate::SwitchValue::Integer(value) => { + crate::SwitchValue::I32(value) => { + if new_case { + write!(self.out, "{l2}case ")?; + } + write!(self.out, "{value}")?; + } + crate::SwitchValue::U32(value) => { if new_case { write!(self.out, "{l2}case ")?; } - write!(self.out, "{value}{type_postfix}")?; + write!(self.out, "{value}u")?; } crate::SwitchValue::Default => { if new_case { diff --git a/src/front/glsl/parser/functions.rs b/src/front/glsl/parser/functions.rs index e6c86ad6d6..8bfb9040b8 100644 --- a/src/front/glsl/parser/functions.rs +++ b/src/front/glsl/parser/functions.rs @@ -186,35 +186,33 @@ impl<'source> ParsingContext<'source> { let value = match self.expect_peek(frontend)?.value { TokenValue::Case => { self.bump(frontend)?; - let value = { - let mut stmt = ctx.stmt_ctx(); - let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?; - let (root, meta) = - ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?; - let constant = frontend.solve_constant(ctx, root, meta)?; - - match frontend.module.constants[constant].inner { - ConstantInner::Scalar { - value: ScalarValue::Sint(int), - .. - } => int as i32, - ConstantInner::Scalar { - value: ScalarValue::Uint(int), - .. - } => int as i32, - _ => { - frontend.errors.push(Error { - kind: ErrorKind::SemanticError( - "Case values can only be integers".into(), - ), - meta, - }); - - 0 - } + + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?; + let (root, meta) = + ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?; + let constant = frontend.solve_constant(ctx, root, meta)?; + + match frontend.module.constants[constant].inner { + ConstantInner::Scalar { + value: ScalarValue::Sint(int), + .. + } => crate::SwitchValue::I32(int as i32), + ConstantInner::Scalar { + value: ScalarValue::Uint(int), + .. + } => crate::SwitchValue::U32(int as u32), + _ => { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Case values can only be integers".into(), + ), + meta, + }); + + crate::SwitchValue::I32(0) } - }; - crate::SwitchValue::Integer(value) + } } TokenValue::Default => { self.bump(frontend)?; diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs index 898d00f796..c26c61e1e9 100644 --- a/src/front/spv/function.rs +++ b/src/front/spv/function.rs @@ -596,7 +596,7 @@ impl<'function> BlockContext<'function> { let fall_through = body.last().map_or(true, |s| !s.is_terminator()); crate::SwitchCase { - value: crate::SwitchValue::Integer(value), + value: crate::SwitchValue::I32(value), body, fall_through, } diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index f4699888c8..56127090f9 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -1006,11 +1006,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .map(|case| { Ok(crate::SwitchCase { value: match case.value { - ast::SwitchValue::I32(num) if !uint => { - crate::SwitchValue::Integer(num) + ast::SwitchValue::I32(value) if !uint => { + crate::SwitchValue::I32(value) } - ast::SwitchValue::U32(num) if uint => { - crate::SwitchValue::Integer(num as i32) + ast::SwitchValue::U32(value) if uint => { + crate::SwitchValue::U32(value) } ast::SwitchValue::Default => crate::SwitchValue::Default, _ => { diff --git a/src/lib.rs b/src/lib.rs index 1ab1ce07f0..87975d88c8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1445,12 +1445,13 @@ pub use block::Block; /// The value of the switch case. // Clone is used only for error reporting and is not intended for end users -#[derive(Clone, Debug)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum SwitchValue { - Integer(i32), + I32(i32), + U32(u32), Default, } diff --git a/src/valid/function.rs b/src/valid/function.rs index a6461596e3..a13a07bcfa 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -103,7 +103,9 @@ pub enum FunctionError { #[error("The `switch` value {0:?} is not an integer scalar")] InvalidSwitchType(Handle), #[error("Multiple `switch` cases for {0:?} are present")] - ConflictingSwitchCase(i32), + ConflictingSwitchCase(crate::SwitchValue), + #[error("The `switch` contains cases with conflicting types")] + ConflictingCaseType, #[error("The `switch` is missing a `default` case")] MissingDefaultCase, #[error("Multiple `default` cases are present")] @@ -434,52 +436,55 @@ impl super::Validator { selector, ref cases, } => { - match *context.resolve_type(selector, &self.valid_expression_set)? { - Ti::Scalar { - kind: crate::ScalarKind::Uint, - width: _, - } => {} - Ti::Scalar { - kind: crate::ScalarKind::Sint, - width: _, - } => {} + let uint = match context + .resolve_type(selector, &self.valid_expression_set)? + .scalar_kind() + { + Some(crate::ScalarKind::Uint) => true, + Some(crate::ScalarKind::Sint) => false, _ => { return Err(FunctionError::InvalidSwitchType(selector) .with_span_handle(selector, context.expressions)) } - } - self.select_cases.clear(); - let mut default = false; + }; + self.switch_values.clear(); for case in cases { match case.value { - crate::SwitchValue::Integer(value) => { - if !self.select_cases.insert(value) { - return Err(FunctionError::ConflictingSwitchCase(value) - .with_span_static( - case.body - .span_iter() - .next() - .map_or(Default::default(), |(_, s)| *s), - "conflicting switch arm here", - )); - } - } - crate::SwitchValue::Default => { - if default { - return Err(FunctionError::MultipleDefaultCases - .with_span_static( - case.body - .span_iter() - .next() - .map_or(Default::default(), |(_, s)| *s), - "duplicated switch arm here", - )); - } - default = true + crate::SwitchValue::I32(_) if !uint => {} + crate::SwitchValue::U32(_) if uint => {} + crate::SwitchValue::Default => {} + _ => { + return Err(FunctionError::ConflictingCaseType.with_span_static( + case.body + .span_iter() + .next() + .map_or(Default::default(), |(_, s)| *s), + "conflicting switch arm here", + )); } + }; + if !self.switch_values.insert(case.value) { + return Err(match case.value { + crate::SwitchValue::Default => FunctionError::MultipleDefaultCases + .with_span_static( + case.body + .span_iter() + .next() + .map_or(Default::default(), |(_, s)| *s), + "duplicated switch arm here", + ), + _ => FunctionError::ConflictingSwitchCase(case.value) + .with_span_static( + case.body + .span_iter() + .next() + .map_or(Default::default(), |(_, s)| *s), + "conflicting switch arm here", + ), + }); } } - if !default { + if !self.switch_values.contains(&crate::SwitchValue::Default) { return Err(FunctionError::MissingDefaultCase .with_span_static(span, "missing default case")); } diff --git a/src/valid/mod.rs b/src/valid/mod.rs index f9f3930b22..6b3a2e1456 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -165,7 +165,7 @@ pub struct Validator { location_mask: BitSet, bind_group_masks: Vec, #[allow(dead_code)] - select_cases: FastHashSet, + switch_values: FastHashSet, valid_expression_list: Vec>, valid_expression_set: BitSet, } @@ -279,7 +279,7 @@ impl Validator { layouter: Layouter::default(), location_mask: BitSet::new(), bind_group_masks: Vec::new(), - select_cases: FastHashSet::default(), + switch_values: FastHashSet::default(), valid_expression_list: Vec::new(), valid_expression_set: BitSet::new(), } @@ -291,7 +291,7 @@ impl Validator { self.layouter.clear(); self.location_mask.clear(); self.bind_group_masks.clear(); - self.select_cases.clear(); + self.switch_values.clear(); self.valid_expression_list.clear(); self.valid_expression_set.clear(); }