diff --git a/CHANGELOG.md b/CHANGELOG.md index 80b811e89d..8572220adf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,6 +84,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216). - Support for more atomic ops in the SPIR-V frontend. By @schell in [#5824](https://github.com/gfx-rs/wgpu/pull/5824). - Support local `const` declarations in WGSL. By @sagudev in [#6156](https://github.com/gfx-rs/wgpu/pull/6156). - Implemented `const_assert` in WGSL. By @sagudev in [#6198](https://github.com/gfx-rs/wgpu/pull/6198). +- Support polyfilling `inverse` in WGSL. By @chyyran in [#6385](https://github.com/gfx-rs/wgpu/pull/6385). #### General diff --git a/naga/src/back/wgsl/mod.rs b/naga/src/back/wgsl/mod.rs index d731b1ca0c..ecf59698a8 100644 --- a/naga/src/back/wgsl/mod.rs +++ b/naga/src/back/wgsl/mod.rs @@ -4,6 +4,7 @@ Backend for [WGSL][wgsl] (WebGPU Shading Language). [wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html */ +mod polyfill; mod writer; use thiserror::Error; diff --git a/naga/src/back/wgsl/polyfill/inverse/inverse_2x2_f16.wgsl b/naga/src/back/wgsl/polyfill/inverse/inverse_2x2_f16.wgsl new file mode 100644 index 0000000000..33d4c37acf --- /dev/null +++ b/naga/src/back/wgsl/polyfill/inverse/inverse_2x2_f16.wgsl @@ -0,0 +1,10 @@ +fn _naga_inverse_2x2_f16(m: mat2x2) -> mat2x2 { + var adj: mat2x2; + adj[0][0] = m[1][1]; + adj[0][1] = -m[0][1]; + adj[1][0] = -m[1][0]; + adj[1][1] = m[0][0]; + + let det: f16 = m[0][0] * m[1][1] - m[1][0] * m[0][1]; + return adj * (1 / det); +} \ No newline at end of file diff --git a/naga/src/back/wgsl/polyfill/inverse/inverse_2x2_f32.wgsl b/naga/src/back/wgsl/polyfill/inverse/inverse_2x2_f32.wgsl new file mode 100644 index 0000000000..1a2d06e511 --- /dev/null +++ b/naga/src/back/wgsl/polyfill/inverse/inverse_2x2_f32.wgsl @@ -0,0 +1,10 @@ +fn _naga_inverse_2x2_f32(m: mat2x2) -> mat2x2 { + var adj: mat2x2; + adj[0][0] = m[1][1]; + adj[0][1] = -m[0][1]; + adj[1][0] = -m[1][0]; + adj[1][1] = m[0][0]; + + let det: f32 = m[0][0] * m[1][1] - m[1][0] * m[0][1]; + return adj * (1 / det); +} \ No newline at end of file diff --git a/naga/src/back/wgsl/polyfill/inverse/inverse_3x3_f16.wgsl b/naga/src/back/wgsl/polyfill/inverse/inverse_3x3_f16.wgsl new file mode 100644 index 0000000000..ddab745254 --- /dev/null +++ b/naga/src/back/wgsl/polyfill/inverse/inverse_3x3_f16.wgsl @@ -0,0 +1,19 @@ +fn _naga_inverse_3x3_f16(m: mat3x3) -> mat3x3 { + var adj: mat3x3; + + adj[0][0] = (m[1][1] * m[2][2] - m[2][1] * m[1][2]); + adj[1][0] = - (m[1][0] * m[2][2] - m[2][0] * m[1][2]); + adj[2][0] = (m[1][0] * m[2][1] - m[2][0] * m[1][1]); + adj[0][1] = - (m[0][1] * m[2][2] - m[2][1] * m[0][2]); + adj[1][1] = (m[0][0] * m[2][2] - m[2][0] * m[0][2]); + adj[2][1] = - (m[0][0] * m[2][1] - m[2][0] * m[0][1]); + adj[0][2] = (m[0][1] * m[1][2] - m[1][1] * m[0][2]); + adj[1][2] = - (m[0][0] * m[1][2] - m[1][0] * m[0][2]); + adj[2][2] = (m[0][0] * m[1][1] - m[1][0] * m[0][1]); + + let det: f16 = (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1]) + - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0]) + + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])); + + return adj * (1 / det); +} \ No newline at end of file diff --git a/naga/src/back/wgsl/polyfill/inverse/inverse_3x3_f32.wgsl b/naga/src/back/wgsl/polyfill/inverse/inverse_3x3_f32.wgsl new file mode 100644 index 0000000000..270198e232 --- /dev/null +++ b/naga/src/back/wgsl/polyfill/inverse/inverse_3x3_f32.wgsl @@ -0,0 +1,19 @@ +fn _naga_inverse_3x3_f32(m: mat3x3) -> mat3x3 { + var adj: mat3x3; + + adj[0][0] = (m[1][1] * m[2][2] - m[2][1] * m[1][2]); + adj[1][0] = - (m[1][0] * m[2][2] - m[2][0] * m[1][2]); + adj[2][0] = (m[1][0] * m[2][1] - m[2][0] * m[1][1]); + adj[0][1] = - (m[0][1] * m[2][2] - m[2][1] * m[0][2]); + adj[1][1] = (m[0][0] * m[2][2] - m[2][0] * m[0][2]); + adj[2][1] = - (m[0][0] * m[2][1] - m[2][0] * m[0][1]); + adj[0][2] = (m[0][1] * m[1][2] - m[1][1] * m[0][2]); + adj[1][2] = - (m[0][0] * m[1][2] - m[1][0] * m[0][2]); + adj[2][2] = (m[0][0] * m[1][1] - m[1][0] * m[0][1]); + + let det: f32 = (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1]) + - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0]) + + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])); + + return adj * (1 / det); +} \ No newline at end of file diff --git a/naga/src/back/wgsl/polyfill/inverse/inverse_4x4_f16.wgsl b/naga/src/back/wgsl/polyfill/inverse/inverse_4x4_f16.wgsl new file mode 100644 index 0000000000..ce88fc2055 --- /dev/null +++ b/naga/src/back/wgsl/polyfill/inverse/inverse_4x4_f16.wgsl @@ -0,0 +1,43 @@ +fn _naga_inverse_4x4_f16(m: mat4x4) -> mat4x4 { + let sub_factor00: f16 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; + let sub_factor01: f16 = m[2][1] * m[3][3] - m[3][1] * m[2][3]; + let sub_factor02: f16 = m[2][1] * m[3][2] - m[3][1] * m[2][2]; + let sub_factor03: f16 = m[2][0] * m[3][3] - m[3][0] * m[2][3]; + let sub_factor04: f16 = m[2][0] * m[3][2] - m[3][0] * m[2][2]; + let sub_factor05: f16 = m[2][0] * m[3][1] - m[3][0] * m[2][1]; + let sub_factor06: f16 = m[1][2] * m[3][3] - m[3][2] * m[1][3]; + let sub_factor07: f16 = m[1][1] * m[3][3] - m[3][1] * m[1][3]; + let sub_factor08: f16 = m[1][1] * m[3][2] - m[3][1] * m[1][2]; + let sub_factor09: f16 = m[1][0] * m[3][3] - m[3][0] * m[1][3]; + let sub_factor10: f16 = m[1][0] * m[3][2] - m[3][0] * m[1][2]; + let sub_factor11: f16 = m[1][1] * m[3][3] - m[3][1] * m[1][3]; + let sub_factor12: f16 = m[1][0] * m[3][1] - m[3][0] * m[1][1]; + let sub_factor13: f16 = m[1][2] * m[2][3] - m[2][2] * m[1][3]; + let sub_factor14: f16 = m[1][1] * m[2][3] - m[2][1] * m[1][3]; + let sub_factor15: f16 = m[1][1] * m[2][2] - m[2][1] * m[1][2]; + let sub_factor16: f16 = m[1][0] * m[2][3] - m[2][0] * m[1][3]; + let sub_factor17: f16 = m[1][0] * m[2][2] - m[2][0] * m[1][2]; + let sub_factor18: f16 = m[1][0] * m[2][1] - m[2][0] * m[1][1]; + + var adj: mat4x4; + adj[0][0] = (m[1][1] * sub_factor00 - m[1][2] * sub_factor01 + m[1][3] * sub_factor02); + adj[1][0] = - (m[1][0] * sub_factor00 - m[1][2] * sub_factor03 + m[1][3] * sub_factor04); + adj[2][0] = (m[1][0] * sub_factor01 - m[1][1] * sub_factor03 + m[1][3] * sub_factor05); + adj[3][0] = - (m[1][0] * sub_factor02 - m[1][1] * sub_factor04 + m[1][2] * sub_factor05); + adj[0][1] = - (m[0][1] * sub_factor00 - m[0][2] * sub_factor01 + m[0][3] * sub_factor02); + adj[1][1] = (m[0][0] * sub_factor00 - m[0][2] * sub_factor03 + m[0][3] * sub_factor04); + adj[2][1] = - (m[0][0] * sub_factor01 - m[0][1] * sub_factor03 + m[0][3] * sub_factor05); + adj[3][1] = (m[0][0] * sub_factor02 - m[0][1] * sub_factor04 + m[0][2] * sub_factor05); + adj[0][2] = (m[0][1] * sub_factor06 - m[0][2] * sub_factor07 + m[0][3] * sub_factor08); + adj[1][2] = - (m[0][0] * sub_factor06 - m[0][2] * sub_factor09 + m[0][3] * sub_factor10); + adj[2][2] = (m[0][0] * sub_factor11 - m[0][1] * sub_factor09 + m[0][3] * sub_factor12); + adj[3][2] = - (m[0][0] * sub_factor08 - m[0][1] * sub_factor10 + m[0][2] * sub_factor12); + adj[0][3] = - (m[0][1] * sub_factor13 - m[0][2] * sub_factor14 + m[0][3] * sub_factor15); + adj[1][3] = (m[0][0] * sub_factor13 - m[0][2] * sub_factor16 + m[0][3] * sub_factor17); + adj[2][3] = - (m[0][0] * sub_factor14 - m[0][1] * sub_factor16 + m[0][3] * sub_factor18); + adj[3][3] = (m[0][0] * sub_factor15 - m[0][1] * sub_factor17 + m[0][2] * sub_factor18); + + let det = (m[0][0] * adj[0][0] + m[0][1] * adj[1][0] + m[0][2] * adj[2][0] + m[0][3] * adj[3][0]); + + return adj * (1 / det); +} \ No newline at end of file diff --git a/naga/src/back/wgsl/polyfill/inverse/inverse_4x4_f32.wgsl b/naga/src/back/wgsl/polyfill/inverse/inverse_4x4_f32.wgsl new file mode 100644 index 0000000000..a1bbca97bb --- /dev/null +++ b/naga/src/back/wgsl/polyfill/inverse/inverse_4x4_f32.wgsl @@ -0,0 +1,43 @@ +fn _naga_inverse_4x4_f32(m: mat4x4) -> mat4x4 { + let sub_factor00: f32 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; + let sub_factor01: f32 = m[2][1] * m[3][3] - m[3][1] * m[2][3]; + let sub_factor02: f32 = m[2][1] * m[3][2] - m[3][1] * m[2][2]; + let sub_factor03: f32 = m[2][0] * m[3][3] - m[3][0] * m[2][3]; + let sub_factor04: f32 = m[2][0] * m[3][2] - m[3][0] * m[2][2]; + let sub_factor05: f32 = m[2][0] * m[3][1] - m[3][0] * m[2][1]; + let sub_factor06: f32 = m[1][2] * m[3][3] - m[3][2] * m[1][3]; + let sub_factor07: f32 = m[1][1] * m[3][3] - m[3][1] * m[1][3]; + let sub_factor08: f32 = m[1][1] * m[3][2] - m[3][1] * m[1][2]; + let sub_factor09: f32 = m[1][0] * m[3][3] - m[3][0] * m[1][3]; + let sub_factor10: f32 = m[1][0] * m[3][2] - m[3][0] * m[1][2]; + let sub_factor11: f32 = m[1][1] * m[3][3] - m[3][1] * m[1][3]; + let sub_factor12: f32 = m[1][0] * m[3][1] - m[3][0] * m[1][1]; + let sub_factor13: f32 = m[1][2] * m[2][3] - m[2][2] * m[1][3]; + let sub_factor14: f32 = m[1][1] * m[2][3] - m[2][1] * m[1][3]; + let sub_factor15: f32 = m[1][1] * m[2][2] - m[2][1] * m[1][2]; + let sub_factor16: f32 = m[1][0] * m[2][3] - m[2][0] * m[1][3]; + let sub_factor17: f32 = m[1][0] * m[2][2] - m[2][0] * m[1][2]; + let sub_factor18: f32 = m[1][0] * m[2][1] - m[2][0] * m[1][1]; + + var adj: mat4x4; + adj[0][0] = (m[1][1] * sub_factor00 - m[1][2] * sub_factor01 + m[1][3] * sub_factor02); + adj[1][0] = - (m[1][0] * sub_factor00 - m[1][2] * sub_factor03 + m[1][3] * sub_factor04); + adj[2][0] = (m[1][0] * sub_factor01 - m[1][1] * sub_factor03 + m[1][3] * sub_factor05); + adj[3][0] = - (m[1][0] * sub_factor02 - m[1][1] * sub_factor04 + m[1][2] * sub_factor05); + adj[0][1] = - (m[0][1] * sub_factor00 - m[0][2] * sub_factor01 + m[0][3] * sub_factor02); + adj[1][1] = (m[0][0] * sub_factor00 - m[0][2] * sub_factor03 + m[0][3] * sub_factor04); + adj[2][1] = - (m[0][0] * sub_factor01 - m[0][1] * sub_factor03 + m[0][3] * sub_factor05); + adj[3][1] = (m[0][0] * sub_factor02 - m[0][1] * sub_factor04 + m[0][2] * sub_factor05); + adj[0][2] = (m[0][1] * sub_factor06 - m[0][2] * sub_factor07 + m[0][3] * sub_factor08); + adj[1][2] = - (m[0][0] * sub_factor06 - m[0][2] * sub_factor09 + m[0][3] * sub_factor10); + adj[2][2] = (m[0][0] * sub_factor11 - m[0][1] * sub_factor09 + m[0][3] * sub_factor12); + adj[3][2] = - (m[0][0] * sub_factor08 - m[0][1] * sub_factor10 + m[0][2] * sub_factor12); + adj[0][3] = - (m[0][1] * sub_factor13 - m[0][2] * sub_factor14 + m[0][3] * sub_factor15); + adj[1][3] = (m[0][0] * sub_factor13 - m[0][2] * sub_factor16 + m[0][3] * sub_factor17); + adj[2][3] = - (m[0][0] * sub_factor14 - m[0][1] * sub_factor16 + m[0][3] * sub_factor18); + adj[3][3] = (m[0][0] * sub_factor15 - m[0][1] * sub_factor17 + m[0][2] * sub_factor18); + + let det = (m[0][0] * adj[0][0] + m[0][1] * adj[1][0] + m[0][2] * adj[2][0] + m[0][3] * adj[3][0]); + + return adj * (1 / det); +} \ No newline at end of file diff --git a/naga/src/back/wgsl/polyfill/mod.rs b/naga/src/back/wgsl/polyfill/mod.rs new file mode 100644 index 0000000000..970a83a53c --- /dev/null +++ b/naga/src/back/wgsl/polyfill/mod.rs @@ -0,0 +1,66 @@ +use crate::{ScalarKind, TypeInner, VectorSize}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct InversePolyfill { + pub fun_name: &'static str, + pub source: &'static str, +} + +impl InversePolyfill { + pub fn find_overload(ty: &TypeInner) -> Option { + let &TypeInner::Matrix { + columns, + rows, + scalar, + } = ty + else { + return None; + }; + + if columns != rows || scalar.kind != ScalarKind::Float { + return None; + }; + + Self::polyfill_overload(columns, scalar.width) + } + + const fn polyfill_overload( + dimension: VectorSize, + width: crate::Bytes, + ) -> Option { + const INVERSE_2X2_F32: &str = include_str!("inverse/inverse_2x2_f32.wgsl"); + const INVERSE_3X3_F32: &str = include_str!("inverse/inverse_3x3_f32.wgsl"); + const INVERSE_4X4_F32: &str = include_str!("inverse/inverse_4x4_f32.wgsl"); + const INVERSE_2X2_F16: &str = include_str!("inverse/inverse_2x2_f16.wgsl"); + const INVERSE_3X3_F16: &str = include_str!("inverse/inverse_3x3_f16.wgsl"); + const INVERSE_4X4_F16: &str = include_str!("inverse/inverse_4x4_f16.wgsl"); + + match (dimension, width) { + (VectorSize::Bi, 4) => Some(InversePolyfill { + fun_name: "_naga_inverse_2x2_f32", + source: INVERSE_2X2_F32, + }), + (VectorSize::Tri, 4) => Some(InversePolyfill { + fun_name: "_naga_inverse_3x3_f32", + source: INVERSE_3X3_F32, + }), + (VectorSize::Quad, 4) => Some(InversePolyfill { + fun_name: "_naga_inverse_4x4_f32", + source: INVERSE_4X4_F32, + }), + (VectorSize::Bi, 2) => Some(InversePolyfill { + fun_name: "_naga_inverse_2x2_f16", + source: INVERSE_2X2_F16, + }), + (VectorSize::Tri, 2) => Some(InversePolyfill { + fun_name: "_naga_inverse_3x3_f16", + source: INVERSE_3X3_F16, + }), + (VectorSize::Quad, 2) => Some(InversePolyfill { + fun_name: "_naga_inverse_4x4_f16", + source: INVERSE_4X4_F16, + }), + _ => None, + } + } +} diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index f7555df0b0..db3ca32e61 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1,4 +1,5 @@ use super::Error; +use crate::back::wgsl::polyfill::InversePolyfill; use crate::{ back::{self, Baked}, proc::{self, ExpressionKindTracker, NameKey}, @@ -68,6 +69,7 @@ pub struct Writer { namer: proc::Namer, named_expressions: crate::NamedExpressions, ep_results: Vec<(ShaderStage, Handle)>, + required_polyfills: crate::FastIndexSet, } impl Writer { @@ -79,6 +81,7 @@ impl Writer { namer: proc::Namer::default(), named_expressions: crate::NamedExpressions::default(), ep_results: vec![], + required_polyfills: crate::FastIndexSet::default(), } } @@ -95,6 +98,7 @@ impl Writer { ); self.named_expressions.clear(); self.ep_results.clear(); + self.required_polyfills.clear(); } fn is_builtin_wgsl_struct(&self, module: &Module, handle: Handle) -> bool { @@ -203,6 +207,13 @@ impl Writer { } } + // Write any polyfills that were required. + for polyfill in &self.required_polyfills { + writeln!(self.out)?; + write!(self.out, "{}", polyfill.source)?; + writeln!(self.out)?; + } + Ok(()) } @@ -1653,6 +1664,7 @@ impl Writer { enum Function { Regular(&'static str), + InversePolyfill(InversePolyfill), } let function = match fun { @@ -1736,9 +1748,16 @@ impl Writer { Mf::Unpack2x16float => Function::Regular("unpack2x16float"), Mf::Unpack4xI8 => Function::Regular("unpack4xI8"), Mf::Unpack4xU8 => Function::Regular("unpack4xU8"), - Mf::Inverse | Mf::Outer => { - return Err(Error::UnsupportedMathFunction(fun)); + Mf::Inverse => { + let typ = func_ctx.resolve_type(arg, &module.types); + + let Some(overload) = InversePolyfill::find_overload(typ) else { + return Err(Error::UnsupportedMathFunction(fun)); + }; + + Function::InversePolyfill(overload) } + Mf::Outer => return Err(Error::UnsupportedMathFunction(fun)), }; match function { @@ -1751,6 +1770,12 @@ impl Writer { } write!(self.out, ")")? } + Function::InversePolyfill(inverse) => { + write!(self.out, "{}(", inverse.fun_name)?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")")?; + self.required_polyfills.insert(inverse); + } } }