Skip to content

Commit

Permalink
Implement Pack/Unpack for HLSL (#2353)
Browse files Browse the repository at this point in the history
  • Loading branch information
Elabajaba authored Jun 23, 2023
1 parent ffe2308 commit adf1cca
Show file tree
Hide file tree
Showing 6 changed files with 389 additions and 3 deletions.
233 changes: 231 additions & 2 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,40 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.need_bake_expressions.insert(fun_handle);
}

if let Expression::Math { fun, arg, .. } = *expr {
if let Expression::Math {
fun,
arg,
arg1,
arg2,
arg3,
} = *expr
{
match fun {
crate::MathFunction::Asinh
| crate::MathFunction::Acosh
| crate::MathFunction::Atanh
| crate::MathFunction::Unpack2x16float => {
| crate::MathFunction::Unpack2x16float
| crate::MathFunction::Unpack2x16snorm
| crate::MathFunction::Unpack2x16unorm
| crate::MathFunction::Unpack4x8snorm
| crate::MathFunction::Unpack4x8unorm
| crate::MathFunction::Pack2x16float
| crate::MathFunction::Pack2x16snorm
| crate::MathFunction::Pack2x16unorm
| crate::MathFunction::Pack4x8snorm
| crate::MathFunction::Pack4x8unorm => {
self.need_bake_expressions.insert(arg);
}
crate::MathFunction::ExtractBits => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
self.need_bake_expressions.insert(arg2.unwrap());
}
crate::MathFunction::InsertBits => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
self.need_bake_expressions.insert(arg2.unwrap());
self.need_bake_expressions.insert(arg3.unwrap());
}
crate::MathFunction::CountLeadingZeros => {
let inner = info[fun_handle].ty.inner_with(&module.types);
Expand Down Expand Up @@ -2589,7 +2616,18 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
enum Function {
Asincosh { is_sin: bool },
Atanh,
ExtractBits,
InsertBits,
Pack2x16float,
Pack2x16snorm,
Pack2x16unorm,
Pack4x8snorm,
Pack4x8unorm,
Unpack2x16float,
Unpack2x16snorm,
Unpack2x16unorm,
Unpack4x8snorm,
Unpack4x8unorm,
Regular(&'static str),
MissingIntOverload(&'static str),
MissingIntReturnType(&'static str),
Expand Down Expand Up @@ -2663,7 +2701,20 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"),
Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"),
Mf::ExtractBits => Function::ExtractBits,
Mf::InsertBits => Function::InsertBits,
// Data Packing
Mf::Pack2x16float => Function::Pack2x16float,
Mf::Pack2x16snorm => Function::Pack2x16snorm,
Mf::Pack2x16unorm => Function::Pack2x16unorm,
Mf::Pack4x8snorm => Function::Pack4x8snorm,
Mf::Pack4x8unorm => Function::Pack4x8unorm,
// Data Unpacking
Mf::Unpack2x16float => Function::Unpack2x16float,
Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
_ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
};

Expand All @@ -2687,13 +2738,191 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
}
Function::ExtractBits => {
// e: T,
// offset: u32,
// count: u32
// T is u32 or i32 or vecN<u32> or vecN<i32>
if let (Some(offset), Some(count)) = (arg1, arg2) {
let scalar_width: u8 = 32;
// Works for signed and unsigned
// (count == 0 ? 0 : (e << (32 - count - offset)) >> (32 - count))
write!(self.out, "(")?;
self.write_expr(module, count, func_ctx)?;
write!(self.out, " == 0 ? 0 : (")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " << ({scalar_width} - ")?;
self.write_expr(module, count, func_ctx)?;
write!(self.out, " - ")?;
self.write_expr(module, offset, func_ctx)?;
write!(self.out, ")) >> ({scalar_width} - ")?;
self.write_expr(module, count, func_ctx)?;
write!(self.out, "))")?;
}
}
Function::InsertBits => {
// e: T,
// newbits: T,
// offset: u32,
// count: u32
// returns T
// T is i32, u32, vecN<i32>, or vecN<u32>
if let (Some(newbits), Some(offset), Some(count)) = (arg1, arg2, arg3) {
let scalar_width: u8 = 32;
let scalar_max: u32 = 0xFFFFFFFF;
// mask = ((0xFFFFFFFFu >> (32 - count)) << offset)
// (count == 0 ? e : ((e & ~mask) | ((newbits << offset) & mask)))
write!(self.out, "(")?;
self.write_expr(module, count, func_ctx)?;
write!(self.out, " == 0 ? ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " : ")?;
write!(self.out, "(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " & ~")?;
// mask
write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?;
self.write_expr(module, count, func_ctx)?;
write!(self.out, ")) << ")?;
self.write_expr(module, offset, func_ctx)?;
write!(self.out, ")")?;
// end mask
write!(self.out, ") | ((")?;
self.write_expr(module, newbits, func_ctx)?;
write!(self.out, " << ")?;
self.write_expr(module, offset, func_ctx)?;
write!(self.out, ") & ")?;
// // mask
write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?;
self.write_expr(module, count, func_ctx)?;
write!(self.out, ")) << ")?;
self.write_expr(module, offset, func_ctx)?;
write!(self.out, ")")?;
// // end mask
write!(self.out, "))")?;
}
}
Function::Pack2x16float => {
write!(self.out, "(f32tof16(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "[0]) | f32tof16(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "[1]) << 16)")?;
}
Function::Pack2x16snorm => {
let scale = 32767;

write!(self.out, "uint((int(round(clamp(")?;
self.write_expr(module, arg, func_ctx)?;
write!(
self.out,
"[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
}
Function::Pack2x16unorm => {
let scale = 65535;

write!(self.out, "(uint(round(clamp(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
}
Function::Pack4x8snorm => {
let scale = 127;

write!(self.out, "uint((int(round(clamp(")?;
self.write_expr(module, arg, func_ctx)?;
write!(
self.out,
"[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(
self.out,
"[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(
self.out,
"[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
}
Function::Pack4x8unorm => {
let scale = 255;

write!(self.out, "(uint(round(clamp(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
self.write_expr(module, arg, func_ctx)?;
write!(
self.out,
"[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(
self.out,
"[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
}

Function::Unpack2x16float => {
write!(self.out, "float2(f16tof32(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "), f16tof32((")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ") >> 16))")?;
}
Function::Unpack2x16snorm => {
let scale = 32767;

write!(self.out, "(float2(int2(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " << 16, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ") >> 16) / {scale}.0)")?;
}
Function::Unpack2x16unorm => {
let scale = 65535;

write!(self.out, "(float2(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " & 0xFFFF, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 16) / {scale}.0)")?;
}
Function::Unpack4x8snorm => {
let scale = 127;

write!(self.out, "(float4(int4(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " << 24, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " << 16, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " << 8, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ") >> 24) / {scale}.0)")?;
}
Function::Unpack4x8unorm => {
let scale = 255;

write!(self.out, "(float4(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " & 0xFF, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 8 & 0xFF, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 16 & 0xFF, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 24) / {scale}.0)")?;
}
Function::Regular(fun_name) => {
write!(self.out, "{fun_name}(")?;
self.write_expr(module, arg, func_ctx)?;
Expand Down
11 changes: 11 additions & 0 deletions src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,17 @@ impl super::TypeInner {
}
}

pub const fn scalar_width(&self) -> Option<u8> {
// Multiply by 8 to get the bit width
match *self {
super::TypeInner::Scalar { width, .. } | super::TypeInner::Vector { width, .. } => {
Some(width * 8)
}
super::TypeInner::Matrix { width, .. } => Some(width * 8),
_ => None,
}
}

pub const fn pointer_space(&self) -> Option<crate::AddressSpace> {
match *self {
Self::Pointer { space, .. } => Some(space),
Expand Down
Loading

0 comments on commit adf1cca

Please sign in to comment.