Skip to content

Commit

Permalink
[wgsl-in] implement firstTrailingBit/firstLeadingBit u32 overloads (#…
Browse files Browse the repository at this point in the history
…1865)

* [wgsl-in] implement firstTrailingBit/firstLeadingBit u32 overloads

* fix MSL type issue

reverts b9162e4
  • Loading branch information
teoxoy authored Apr 27, 2022
1 parent 062b66c commit f2e7818
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 70 deletions.
30 changes: 29 additions & 1 deletion src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2806,6 +2806,30 @@ impl<'a, W: Write> Writer<'a, W> {
let extract_bits = fun == Mf::ExtractBits;
let insert_bits = fun == Mf::InsertBits;

// we might need to cast to unsigned integers since
// GLSL's findLSB / findMSB always return signed integers
let need_extra_paren = {
(fun == Mf::FindLsb || fun == Mf::FindMsb)
&& match *ctx.info[arg].ty.inner_with(&self.module.types) {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
} => {
write!(self.out, "uint(")?;
true
}
crate::TypeInner::Vector {
kind: crate::ScalarKind::Uint,
size,
..
} => {
write!(self.out, "uvec{}(", size as u8)?;
true
}
_ => false,
}
};

write!(self.out, "{}(", fun_name)?;
self.write_expr(arg, ctx)?;
if let Some(arg) = arg1 {
Expand Down Expand Up @@ -2838,7 +2862,11 @@ impl<'a, W: Write> Writer<'a, W> {
self.write_expr(arg, ctx)?;
}
}
write!(self.out, ")")?
write!(self.out, ")")?;

if need_extra_paren {
write!(self.out, ")")?
}
}
// `As` is always a call.
// If `convert` is true the function name is the type
Expand Down
31 changes: 4 additions & 27 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1591,21 +1591,6 @@ impl<W: Write> Writer<W> {
crate::TypeInner::Scalar { .. } => true,
_ => false,
};
let argument_size_suffix = match *context.resolve_type(arg) {
crate::TypeInner::Vector {
size: crate::VectorSize::Bi,
..
} => "2",
crate::TypeInner::Vector {
size: crate::VectorSize::Tri,
..
} => "3",
crate::TypeInner::Vector {
size: crate::VectorSize::Quad,
..
} => "4",
_ => "",
};

let fun_name = match fun {
// comparison
Expand Down Expand Up @@ -1705,21 +1690,13 @@ impl<W: Write> Writer<W> {
self.put_expression(arg1.unwrap(), context, false)?;
write!(self.out, ")")?;
} else if fun == Mf::FindLsb {
write!(
self.out,
"(((1 + int{}({}::ctz(",
argument_size_suffix, NAMESPACE
)?;
write!(self.out, "((({}::ctz(", NAMESPACE)?;
self.put_expression(arg, context, true)?;
write!(self.out, "))) % 33) - 1)")?;
write!(self.out, ") + 1) % 33) - 1)")?;
} else if fun == Mf::FindMsb {
write!(
self.out,
"(((1 + int{}({}::clz(",
argument_size_suffix, NAMESPACE
)?;
write!(self.out, "((({}::clz(", NAMESPACE)?;
self.put_expression(arg, context, true)?;
write!(self.out, "))) % 33) - 1)")?;
write!(self.out, ") + 1) % 33) - 1)")?
} else if fun == Mf::Unpack2x16float {
write!(self.out, "float2(as_type<half2>(")?;
self.put_expression(arg, context, false)?;
Expand Down
40 changes: 40 additions & 0 deletions src/front/glsl/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,17 @@ fn inject_standard_builtins(
_ => {}
}

// we need to cast the return type of findLsb / findMsb
let mc = if kind == Sk::Uint {
match mc {
MacroCall::MathFunction(MathFunction::FindLsb) => MacroCall::FindLsbUint,
MacroCall::MathFunction(MathFunction::FindMsb) => MacroCall::FindMsbUint,
mc => mc,
}
} else {
mc
};

declaration.overloads.push(module.add_builtin(args, mc))
}
}
Expand Down Expand Up @@ -1580,6 +1591,8 @@ pub enum MacroCall {
},
ImageStore,
MathFunction(MathFunction),
FindLsbUint,
FindMsbUint,
BitfieldExtract,
BitfieldInsert,
Relational(RelationalFunction),
Expand Down Expand Up @@ -1848,6 +1861,33 @@ impl MacroCall {
Span::default(),
body,
),
mc @ (MacroCall::FindLsbUint | MacroCall::FindMsbUint) => {
let fun = match mc {
MacroCall::FindLsbUint => MathFunction::FindLsb,
MacroCall::FindMsbUint => MathFunction::FindMsb,
_ => unreachable!(),
};
let res = ctx.add_expression(
Expression::Math {
fun,
arg: args[0],
arg1: None,
arg2: None,
arg3: None,
},
Span::default(),
body,
);
ctx.add_expression(
Expression::As {
expr: res,
kind: Sk::Sint,
convert: Some(4),
},
Span::default(),
body,
)
}
MacroCall::BitfieldInsert => {
let conv_arg_2 = ctx.add_expression(
Expression::As {
Expand Down
10 changes: 5 additions & 5 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -821,13 +821,13 @@ impl<'a> ResolveContext<'a> {
Mf::CountOneBits |
Mf::ReverseBits |
Mf::ExtractBits |
Mf::InsertBits => res_arg.clone(),
Mf::InsertBits |
Mf::FindLsb |
Mf::FindMsb => match *res_arg.inner_with(types) {
Ti::Scalar { kind: _, width } =>
TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Sint, width }),
Ti::Vector { size, kind: _, width } =>
TypeResolution::Value(Ti::Vector { size, kind: crate::ScalarKind::Sint, width }),
Ti::Scalar { kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } =>
TypeResolution::Value(Ti::Scalar { kind, width }),
Ti::Vector { size, kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } =>
TypeResolution::Value(Ti::Vector { size, kind, width }),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{:?}({:?})", fun, other)
)),
Expand Down
4 changes: 2 additions & 2 deletions tests/in/bits.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fn main() {
u3 = extractBits(u3, 5u, 10u);
u4 = extractBits(u4, 5u, 10u);
i = firstTrailingBit(i);
i2 = firstTrailingBit(u2);
u2 = firstTrailingBit(u2);
i3 = firstLeadingBit(i3);
i = firstLeadingBit(u);
u = firstLeadingBit(u);
}
4 changes: 2 additions & 2 deletions tests/out/glsl/bits.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ void main() {
int _e120 = i;
i = findLSB(_e120);
uvec2 _e122 = u2_;
i2_ = findLSB(_e122);
u2_ = uvec2(findLSB(_e122));
ivec3 _e124 = i3_;
i3_ = findMSB(_e124);
uint _e126 = u;
i = findMSB(_e126);
u = uint(findMSB(_e126));
return;
}

8 changes: 4 additions & 4 deletions tests/out/msl/bits.msl
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ kernel void main_(
metal::uint4 _e116 = u4_;
u4_ = metal::extract_bits(_e116, 5u, 10u);
int _e120 = i;
i = (((1 + int(metal::ctz(_e120))) % 33) - 1);
i = (((metal::ctz(_e120) + 1) % 33) - 1);
metal::uint2 _e122 = u2_;
i2_ = (((1 + int2(metal::ctz(_e122))) % 33) - 1);
u2_ = (((metal::ctz(_e122) + 1) % 33) - 1);
metal::int3 _e124 = i3_;
i3_ = (((1 + int3(metal::clz(_e124))) % 33) - 1);
i3_ = (((metal::clz(_e124) + 1) % 33) - 1);
uint _e126 = u;
i = (((1 + int(metal::clz(_e126))) % 33) - 1);
u = (((metal::clz(_e126) + 1) % 33) - 1);
return;
}
8 changes: 4 additions & 4 deletions tests/out/spv/bits.spvasm
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,13 @@ OpStore %33 %110
%112 = OpExtInst %4 %1 FindILsb %111
OpStore %19 %112
%113 = OpLoad %14 %29
%114 = OpExtInst %11 %1 FindILsb %113
OpStore %21 %114
%114 = OpExtInst %14 %1 FindILsb %113
OpStore %29 %114
%115 = OpLoad %12 %23
%116 = OpExtInst %12 %1 FindSMsb %115
OpStore %23 %116
%117 = OpLoad %6 %27
%118 = OpExtInst %4 %1 FindUMsb %117
OpStore %19 %118
%118 = OpExtInst %6 %1 FindUMsb %117
OpStore %27 %118
OpReturn
OpFunctionEnd
4 changes: 2 additions & 2 deletions tests/out/wgsl/bits.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ fn main() {
let _e120 = i;
i = firstTrailingBit(_e120);
let _e122 = u2_;
i2_ = firstTrailingBit(_e122);
u2_ = firstTrailingBit(_e122);
let _e124 = i3_;
i3_ = firstLeadingBit(_e124);
let _e126 = u;
i = firstLeadingBit(_e126);
u = firstLeadingBit(_e126);
return;
}
46 changes: 23 additions & 23 deletions tests/out/wgsl/bits_glsl-frag.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -79,29 +79,29 @@ fn main_1() {
let _e232 = i4_;
i4_ = firstTrailingBit(_e232);
let _e235 = u;
i = firstTrailingBit(_e235);
let _e238 = u2_;
i2_ = firstTrailingBit(_e238);
let _e241 = u3_;
i3_ = firstTrailingBit(_e241);
let _e244 = u4_;
i4_ = firstTrailingBit(_e244);
let _e247 = i;
i = firstLeadingBit(_e247);
let _e250 = i2_;
i2_ = firstLeadingBit(_e250);
let _e253 = i3_;
i3_ = firstLeadingBit(_e253);
let _e256 = i4_;
i4_ = firstLeadingBit(_e256);
let _e259 = u;
i = firstLeadingBit(_e259);
let _e262 = u2_;
i2_ = firstLeadingBit(_e262);
let _e265 = u3_;
i3_ = firstLeadingBit(_e265);
let _e268 = u4_;
i4_ = firstLeadingBit(_e268);
i = i32(firstTrailingBit(_e235));
let _e239 = u2_;
i2_ = vec2<i32>(firstTrailingBit(_e239));
let _e243 = u3_;
i3_ = vec3<i32>(firstTrailingBit(_e243));
let _e247 = u4_;
i4_ = vec4<i32>(firstTrailingBit(_e247));
let _e251 = i;
i = firstLeadingBit(_e251);
let _e254 = i2_;
i2_ = firstLeadingBit(_e254);
let _e257 = i3_;
i3_ = firstLeadingBit(_e257);
let _e260 = i4_;
i4_ = firstLeadingBit(_e260);
let _e263 = u;
i = i32(firstLeadingBit(_e263));
let _e267 = u2_;
i2_ = vec2<i32>(firstLeadingBit(_e267));
let _e271 = u3_;
i3_ = vec3<i32>(firstLeadingBit(_e271));
let _e275 = u4_;
i4_ = vec4<i32>(firstLeadingBit(_e275));
return;
}

Expand Down

0 comments on commit f2e7818

Please sign in to comment.