Skip to content

Commit

Permalink
Differentiate between i32 and u32 in switch (gfx-rs#2269)
Browse files Browse the repository at this point in the history
* Differentiate between i32 and u32 in switch

* Use similar wording to other error messages

* Remove duplicate enum
  • Loading branch information
evahop authored and kvark committed Mar 18, 2023
1 parent 8442b95 commit 046f644
Show file tree
Hide file tree
Showing 12 changed files with 105 additions and 116 deletions.
2 changes: 1 addition & 1 deletion src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ impl StatementGraph {
for case in cases {
let (case_id, case_last) = self.add(&case.body, targets);
let label = match case.value {
crate::SwitchValue::Integer(_) => "case",
crate::SwitchValue::Default => "default",
_ => "case",
};
self.flow.push((id, case_id, label));
// Link the last node of the branch to the merge node
Expand Down
12 changes: 2 additions & 10 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1908,21 +1908,13 @@ impl<'a, W: Write> Writer<'a, W> {
write!(self.out, "switch(")?;
self.write_expr(selector, ctx)?;
writeln!(self.out, ") {{")?;
let type_postfix = match *ctx.info[selector].ty.inner_with(&self.module.types) {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
} => "u",
_ => "",
};

// Write all cases
let l2 = level.next();
for case in cases {
match case.value {
crate::SwitchValue::Integer(value) => {
write!(self.out, "{l2}case {value}{type_postfix}:")?
}
crate::SwitchValue::I32(value) => write!(self.out, "{l2}case {value}:")?,
crate::SwitchValue::U32(value) => write!(self.out, "{l2}case {value}u:")?,
crate::SwitchValue::Default => write!(self.out, "{l2}default:")?,
}

Expand Down
14 changes: 5 additions & 9 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1883,22 +1883,18 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, "switch(")?;
self.write_expr(module, selector, func_ctx)?;
writeln!(self.out, ") {{")?;
let type_postfix = match *func_ctx.info[selector].ty.inner_with(&module.types) {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
} => "u",
_ => "",
};

// Write all cases
let indent_level_1 = level.next();
let indent_level_2 = indent_level_1.next();

for (i, case) in cases.iter().enumerate() {
match case.value {
crate::SwitchValue::Integer(value) => {
write!(self.out, "{indent_level_1}case {value}{type_postfix}:")?
crate::SwitchValue::I32(value) => {
write!(self.out, "{indent_level_1}case {value}:")?
}
crate::SwitchValue::U32(value) => {
write!(self.out, "{indent_level_1}case {value}u:")?
}
crate::SwitchValue::Default => {
write!(self.out, "{indent_level_1}default:")?
Expand Down
14 changes: 5 additions & 9 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2614,19 +2614,15 @@ impl<W: Write> Writer<W> {
} => {
write!(self.out, "{level}switch(")?;
self.put_expression(selector, &context.expression, true)?;
let type_postfix = match *context.expression.resolve_type(selector) {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
} => "u",
_ => "",
};
writeln!(self.out, ") {{")?;
let lcase = level.next();
for case in cases.iter() {
match case.value {
crate::SwitchValue::Integer(value) => {
write!(self.out, "{lcase}case {value}{type_postfix}:")?;
crate::SwitchValue::I32(value) => {
write!(self.out, "{lcase}case {value}:")?;
}
crate::SwitchValue::U32(value) => {
write!(self.out, "{lcase}case {value}u:")?;
}
crate::SwitchValue::Default => {
write!(self.out, "{lcase}default:")?;
Expand Down
5 changes: 4 additions & 1 deletion src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1776,12 +1776,15 @@ impl<'w> BlockContext<'w> {
case_ids.push(label_id);

match case.value {
crate::SwitchValue::Integer(value) => {
crate::SwitchValue::I32(value) => {
raw_cases.push(super::instructions::Case {
value: value as Word,
label_id,
});
}
crate::SwitchValue::U32(value) => {
raw_cases.push(super::instructions::Case { value, label_id });
}
crate::SwitchValue::Default => {
default_id = Some(label_id);
}
Expand Down
18 changes: 8 additions & 10 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -865,14 +865,6 @@ impl<W: Write> Writer<W> {
self.write_expr(module, selector, func_ctx)?;
writeln!(self.out, " {{")?;

let type_postfix = match *func_ctx.info[selector].ty.inner_with(&module.types) {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
} => "u",
_ => "",
};

let l2 = level.next();
let mut new_case = true;
for case in cases {
Expand All @@ -884,11 +876,17 @@ impl<W: Write> Writer<W> {
}

match case.value {
crate::SwitchValue::Integer(value) => {
crate::SwitchValue::I32(value) => {
if new_case {
write!(self.out, "{l2}case ")?;
}
write!(self.out, "{value}")?;
}
crate::SwitchValue::U32(value) => {
if new_case {
write!(self.out, "{l2}case ")?;
}
write!(self.out, "{value}{type_postfix}")?;
write!(self.out, "{value}u")?;
}
crate::SwitchValue::Default => {
if new_case {
Expand Down
54 changes: 26 additions & 28 deletions src/front/glsl/parser/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,35 +186,33 @@ impl<'source> ParsingContext<'source> {
let value = match self.expect_peek(frontend)?.value {
TokenValue::Case => {
self.bump(frontend)?;
let value = {
let mut stmt = ctx.stmt_ctx();
let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
let (root, meta) =
ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?;
let constant = frontend.solve_constant(ctx, root, meta)?;

match frontend.module.constants[constant].inner {
ConstantInner::Scalar {
value: ScalarValue::Sint(int),
..
} => int as i32,
ConstantInner::Scalar {
value: ScalarValue::Uint(int),
..
} => int as i32,
_ => {
frontend.errors.push(Error {
kind: ErrorKind::SemanticError(
"Case values can only be integers".into(),
),
meta,
});

0
}

let mut stmt = ctx.stmt_ctx();
let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
let (root, meta) =
ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?;
let constant = frontend.solve_constant(ctx, root, meta)?;

match frontend.module.constants[constant].inner {
ConstantInner::Scalar {
value: ScalarValue::Sint(int),
..
} => crate::SwitchValue::I32(int as i32),
ConstantInner::Scalar {
value: ScalarValue::Uint(int),
..
} => crate::SwitchValue::U32(int as u32),
_ => {
frontend.errors.push(Error {
kind: ErrorKind::SemanticError(
"Case values can only be integers".into(),
),
meta,
});

crate::SwitchValue::I32(0)
}
};
crate::SwitchValue::Integer(value)
}
}
TokenValue::Default => {
self.bump(frontend)?;
Expand Down
2 changes: 1 addition & 1 deletion src/front/spv/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ impl<'function> BlockContext<'function> {
let fall_through = body.last().map_or(true, |s| !s.is_terminator());

crate::SwitchCase {
value: crate::SwitchValue::Integer(value),
value: crate::SwitchValue::I32(value),
body,
fall_through,
}
Expand Down
8 changes: 4 additions & 4 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1006,11 +1006,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.map(|case| {
Ok(crate::SwitchCase {
value: match case.value {
ast::SwitchValue::I32(num) if !uint => {
crate::SwitchValue::Integer(num)
ast::SwitchValue::I32(value) if !uint => {
crate::SwitchValue::I32(value)
}
ast::SwitchValue::U32(num) if uint => {
crate::SwitchValue::Integer(num as i32)
ast::SwitchValue::U32(value) if uint => {
crate::SwitchValue::U32(value)
}
ast::SwitchValue::Default => crate::SwitchValue::Default,
_ => {
Expand Down
5 changes: 3 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1445,12 +1445,13 @@ pub use block::Block;

/// The value of the switch case.
// Clone is used only for error reporting and is not intended for end users
#[derive(Clone, Debug)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum SwitchValue {
Integer(i32),
I32(i32),
U32(u32),
Default,
}

Expand Down
81 changes: 43 additions & 38 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ pub enum FunctionError {
#[error("The `switch` value {0:?} is not an integer scalar")]
InvalidSwitchType(Handle<crate::Expression>),
#[error("Multiple `switch` cases for {0:?} are present")]
ConflictingSwitchCase(i32),
ConflictingSwitchCase(crate::SwitchValue),
#[error("The `switch` contains cases with conflicting types")]
ConflictingCaseType,
#[error("The `switch` is missing a `default` case")]
MissingDefaultCase,
#[error("Multiple `default` cases are present")]
Expand Down Expand Up @@ -434,52 +436,55 @@ impl super::Validator {
selector,
ref cases,
} => {
match *context.resolve_type(selector, &self.valid_expression_set)? {
Ti::Scalar {
kind: crate::ScalarKind::Uint,
width: _,
} => {}
Ti::Scalar {
kind: crate::ScalarKind::Sint,
width: _,
} => {}
let uint = match context
.resolve_type(selector, &self.valid_expression_set)?
.scalar_kind()
{
Some(crate::ScalarKind::Uint) => true,
Some(crate::ScalarKind::Sint) => false,
_ => {
return Err(FunctionError::InvalidSwitchType(selector)
.with_span_handle(selector, context.expressions))
}
}
self.select_cases.clear();
let mut default = false;
};
self.switch_values.clear();
for case in cases {
match case.value {
crate::SwitchValue::Integer(value) => {
if !self.select_cases.insert(value) {
return Err(FunctionError::ConflictingSwitchCase(value)
.with_span_static(
case.body
.span_iter()
.next()
.map_or(Default::default(), |(_, s)| *s),
"conflicting switch arm here",
));
}
}
crate::SwitchValue::Default => {
if default {
return Err(FunctionError::MultipleDefaultCases
.with_span_static(
case.body
.span_iter()
.next()
.map_or(Default::default(), |(_, s)| *s),
"duplicated switch arm here",
));
}
default = true
crate::SwitchValue::I32(_) if !uint => {}
crate::SwitchValue::U32(_) if uint => {}
crate::SwitchValue::Default => {}
_ => {
return Err(FunctionError::ConflictingCaseType.with_span_static(
case.body
.span_iter()
.next()
.map_or(Default::default(), |(_, s)| *s),
"conflicting switch arm here",
));
}
};
if !self.switch_values.insert(case.value) {
return Err(match case.value {
crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
.with_span_static(
case.body
.span_iter()
.next()
.map_or(Default::default(), |(_, s)| *s),
"duplicated switch arm here",
),
_ => FunctionError::ConflictingSwitchCase(case.value)
.with_span_static(
case.body
.span_iter()
.next()
.map_or(Default::default(), |(_, s)| *s),
"conflicting switch arm here",
),
});
}
}
if !default {
if !self.switch_values.contains(&crate::SwitchValue::Default) {
return Err(FunctionError::MissingDefaultCase
.with_span_static(span, "missing default case"));
}
Expand Down
6 changes: 3 additions & 3 deletions src/valid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ pub struct Validator {
location_mask: BitSet,
bind_group_masks: Vec<BitSet>,
#[allow(dead_code)]
select_cases: FastHashSet<i32>,
switch_values: FastHashSet<crate::SwitchValue>,
valid_expression_list: Vec<Handle<crate::Expression>>,
valid_expression_set: BitSet,
}
Expand Down Expand Up @@ -279,7 +279,7 @@ impl Validator {
layouter: Layouter::default(),
location_mask: BitSet::new(),
bind_group_masks: Vec::new(),
select_cases: FastHashSet::default(),
switch_values: FastHashSet::default(),
valid_expression_list: Vec::new(),
valid_expression_set: BitSet::new(),
}
Expand All @@ -291,7 +291,7 @@ impl Validator {
self.layouter.clear();
self.location_mask.clear();
self.bind_group_masks.clear();
self.select_cases.clear();
self.switch_values.clear();
self.valid_expression_list.clear();
self.valid_expression_set.clear();
}
Expand Down

0 comments on commit 046f644

Please sign in to comment.