Skip to content

Commit

Permalink
[naga] Implement quantizeToF16
Browse files Browse the repository at this point in the history
Implement WGSL frontend and WGSL, SPIR-V, HLSL, MSL, and GLSL
backends. WGSL and SPIR-V backends natively support the instruction.
MSL and HLSL emulate it by casting to f16 and back to f32. GLSL does
similar but must (mis)use (un)pack2x16 to do so.
  • Loading branch information
jamienicol committed Nov 11, 2024
1 parent 6a60458 commit 829f4e2
Show file tree
Hide file tree
Showing 17 changed files with 210 additions and 75 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Bottom level categories:
- Parse `diagnostic(…)` directives, but don't implement any triggering rules yet. By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456).
- Fix an issue where `naga` CLI would incorrectly skip the first positional argument when `--stdin-file-path` was specified. By @ErichDonGubler in [#6480](https://github.com/gfx-rs/wgpu/pull/6480).
- Fix textureNumLevels in the GLSL backend. By @magcius in [#6483](https://github.com/gfx-rs/wgpu/pull/6483).
- Implement `quantizeToF16()` for WGSL frontend, and WGSL, SPIR-V, HLSL, MSL, and GLSL backends. By @jamienicol in [#6519](https://github.com/gfx-rs/wgpu/pull/6519).

#### General

Expand Down
43 changes: 43 additions & 0 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3722,6 +3722,49 @@ impl<'a, W: Write> Writer<'a, W> {

return Ok(());
}
// FIXME: move to correct category
Mf::QuantizeToF16 => match *ctx.resolve_type(arg, &self.module.types) {
crate::TypeInner::Scalar { .. } => {
write!(self.out, "unpackHalf2x16(packHalf2x16(vec2(")?;
self.write_expr(arg, ctx)?;
write!(self.out, "))).x")?;
return Ok(());
}
crate::TypeInner::Vector {
size: crate::VectorSize::Bi,
..
} => {
write!(self.out, "unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, "))")?;
return Ok(());
}
crate::TypeInner::Vector {
size: crate::VectorSize::Tri,
..
} => {
write!(self.out, "vec3(unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".xy)), unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".zz)).x)")?;
return Ok(());
}
crate::TypeInner::Vector {
size: crate::VectorSize::Quad,
..
} => {
write!(self.out, "vec4(unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".xy)), unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".zw)))")?;
return Ok(());
}
_ => unreachable!(
"Correct TypeInner for QuantizeToF16 should be already validated"
),
},
};

let extract_bits = fun == Mf::ExtractBits;
Expand Down
8 changes: 8 additions & 0 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3036,6 +3036,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Unpack4x8unorm,
Unpack4xI8,
Unpack4xU8,
QuantizeToF16,
Regular(&'static str),
MissingIntOverload(&'static str),
MissingIntReturnType(&'static str),
Expand Down Expand Up @@ -3127,6 +3128,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
Mf::Unpack4xI8 => Function::Unpack4xI8,
Mf::Unpack4xU8 => Function::Unpack4xU8,
// FIXME: move to correct location
Mf::QuantizeToF16 => Function::QuantizeToF16,
_ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
};

Expand Down Expand Up @@ -3303,6 +3306,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24")?;
}
Function::QuantizeToF16 => {
write!(self.out, "f16tof32(f32tof16(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
}
Function::Regular(fun_name) => {
write!(self.out, "{fun_name}(")?;
self.write_expr(module, arg, func_ctx)?;
Expand Down
18 changes: 18 additions & 0 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1961,6 +1961,8 @@ impl<W: Write> Writer<W> {
Mf::Unpack2x16float => "",
Mf::Unpack4xI8 => "",
Mf::Unpack4xU8 => "",
// FIXME: move to correct category
Mf::QuantizeToF16 => "",
};

match fun {
Expand Down Expand Up @@ -2144,6 +2146,22 @@ impl<W: Write> Writer<W> {
self.put_expression(arg, context, true)?;
write!(self.out, " >> 24) << 24 >> 24")?;
}
Mf::QuantizeToF16 => {
match *context.resolve_type(arg) {
crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
crate::TypeInner::Vector { size, .. } => write!(
self.out,
"{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
size = back::vector_size_str(size),
)?,
_ => unreachable!(
"Correct TypeInner for QuantizeToF16 should be already validated"
),
};

self.put_expression(arg, context, true)?;
write!(self.out, "))")?;
}
_ => {
write!(self.out, "{NAMESPACE}::{fun_name}")?;
self.put_call_parameters(
Expand Down
7 changes: 7 additions & 0 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,13 @@ impl<'w> BlockContext<'w> {

MathOp::Custom(Instruction::composite_construct(result_type_id, id, &parts))
}
// FIXME: move to correct category
Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
spirv::Op::QuantizeToF16,
result_type_id,
id,
arg0_id,
)),
};

block.body.push(match math_op {
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1758,6 +1758,8 @@ impl<W: Write> Writer<W> {
Function::InversePolyfill(overload)
}
Mf::Outer => return Err(Error::UnsupportedMathFunction(fun)),
// FIXME: move to correct category
Mf::QuantizeToF16 => Function::Regular("quantizeToF16"),
};

match function {
Expand Down
2 changes: 2 additions & 0 deletions naga/src/front/wgsl/parse/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
"unpack2x16float" => Mf::Unpack2x16float,
"unpack4xI8" => Mf::Unpack4xI8,
"unpack4xU8" => Mf::Unpack4xU8,
// FIXME: move to correct category
"quantizeToF16" => Mf::QuantizeToF16,
_ => return None,
})
}
Expand Down
2 changes: 2 additions & 0 deletions naga/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,8 @@ pub enum MathFunction {
Unpack2x16float,
Unpack4xI8,
Unpack4xU8,
// FIXME: put in correct category
QuantizeToF16,
}

/// Sampling modifier to control the level of detail.
Expand Down
2 changes: 2 additions & 0 deletions naga/src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,8 @@ impl super::MathFunction {
Self::Unpack2x16float => 1,
Self::Unpack4xI8 => 1,
Self::Unpack4xU8 => 1,
// FIXME: move to correct category
Self::QuantizeToF16 => 1,
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion naga/src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,8 @@ impl<'a> ResolveContext<'a> {
| Mf::Exp2
| Mf::Log
| Mf::Log2
| Mf::Pow => res_arg.clone(),
| Mf::Pow
| Mf::QuantizeToF16 => res_arg.clone(),
Mf::Modf | Mf::Frexp => {
let (size, width) = match res_arg.inner_with(types) {
&Ti::Scalar(crate::Scalar {
Expand Down
21 changes: 21 additions & 0 deletions naga/src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1570,6 +1570,27 @@ impl super::Validator {
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
// FIXME: move to correct category
Mf::QuantizeToF16 => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar(Sc {
kind: Sk::Float,
width: 4,
})
| Ti::Vector {
scalar:
Sc {
kind: Sk::Float,
width: 4,
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
}
ShaderStages::all()
}
Expand Down
4 changes: 4 additions & 0 deletions naga/tests/in/math-functions.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,8 @@ fn main() {
let frexp_b = frexp(1.5).fract;
let frexp_c: i32 = frexp(1.5).exp;
let frexp_d: i32 = frexp(vec4(1.5, 1.5, 1.5, 1.5)).exp.x;
let quantizeToF16_a: f32 = quantizeToF16(1.0);
let quantizeToF16_b: vec2<f32> = quantizeToF16(vec2(1.0, 1.0));
let quantizeToF16_c: vec3<f32> = quantizeToF16(vec3(1.0, 1.0, 1.0));
let quantizeToF16_d: vec4<f32> = quantizeToF16(vec4(1.0, 1.0, 1.0, 1.0));
}
4 changes: 4 additions & 0 deletions naga/tests/out/glsl/math-functions.main.Fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,9 @@ void main() {
float frexp_b = naga_frexp(1.5).fract_;
int frexp_c = naga_frexp(1.5).exp_;
int frexp_d = naga_frexp(vec4(1.5, 1.5, 1.5, 1.5)).exp_.x;
float quantizeToF16_a = unpackHalf2x16(packHalf2x16(vec2(1.0))).x;
vec2 quantizeToF16_b = unpackHalf2x16(packHalf2x16(vec2(1.0, 1.0)));
vec3 quantizeToF16_c = vec3(unpackHalf2x16(packHalf2x16(vec3(1.0, 1.0, 1.0).xy)), unpackHalf2x16(packHalf2x16(vec3(1.0, 1.0, 1.0).zz)).x);
vec4 quantizeToF16_d = vec4(unpackHalf2x16(packHalf2x16(vec4(1.0, 1.0, 1.0, 1.0).xy)), unpackHalf2x16(packHalf2x16(vec4(1.0, 1.0, 1.0, 1.0).zw)));
}

4 changes: 4 additions & 0 deletions naga/tests/out/hlsl/math-functions.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,8 @@ void main()
float frexp_b = naga_frexp(1.5).fract;
int frexp_c = naga_frexp(1.5).exp_;
int frexp_d = naga_frexp(float4(1.5, 1.5, 1.5, 1.5)).exp_.x;
float quantizeToF16_a = f16tof32(f32tof16(1.0));
float2 quantizeToF16_b = f16tof32(f32tof16(float2(1.0, 1.0)));
float3 quantizeToF16_c = f16tof32(f32tof16(float3(1.0, 1.0, 1.0)));
float4 quantizeToF16_d = f16tof32(f32tof16(float4(1.0, 1.0, 1.0, 1.0)));
}
4 changes: 4 additions & 0 deletions naga/tests/out/msl/math-functions.msl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,8 @@ fragment void main_(
float frexp_b = naga_frexp(1.5).fract;
int frexp_c = naga_frexp(1.5).exp;
int frexp_d = naga_frexp(metal::float4(1.5, 1.5, 1.5, 1.5)).exp.x;
float quantizeToF16_a = float(half(1.0));
metal::float2 quantizeToF16_b = metal::float2(metal::half2(metal::float2(1.0, 1.0)));
metal::float3 quantizeToF16_c = metal::float3(metal::half3(metal::float3(1.0, 1.0, 1.0)));
metal::float4 quantizeToF16_d = metal::float4(metal::half4(metal::float4(1.0, 1.0, 1.0, 1.0)));
}
Loading

0 comments on commit 829f4e2

Please sign in to comment.