From beddc713a969d9ea149a85c63258868ab1d2a9cd Mon Sep 17 00:00:00 2001 From: Jay Oster Date: Tue, 7 Dec 2021 07:41:53 -0800 Subject: [PATCH 1/7] [hlsl-out] Write `mad` intrinsic for `fma` function - This should be enough because we only support f32 for now. - Adds a new test for WGSL functions, in the spirit of operators.wgsl. - Closes #1579 --- src/back/hlsl/writer.rs | 2 +- tests/in/functions.param.ron | 2 ++ tests/in/functions.wgsl | 15 ++++++++++ tests/out/glsl/functions.main.Compute.glsl | 20 +++++++++++++ tests/out/hlsl/functions.hlsl | 15 ++++++++++ tests/out/hlsl/functions.hlsl.config | 3 ++ tests/out/msl/functions.msl | 18 ++++++++++++ tests/out/spv/functions.spvasm | 33 ++++++++++++++++++++++ tests/out/wgsl/functions.wgsl | 12 ++++++++ tests/snapshots.rs | 4 +++ 10 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 tests/in/functions.param.ron create mode 100644 tests/in/functions.wgsl create mode 100644 tests/out/glsl/functions.main.Compute.glsl create mode 100644 tests/out/hlsl/functions.hlsl create mode 100644 tests/out/hlsl/functions.hlsl.config create mode 100644 tests/out/msl/functions.msl create mode 100644 tests/out/spv/functions.spvasm create mode 100644 tests/out/wgsl/functions.wgsl diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 6f05f1bd10..00b19b5f23 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1861,7 +1861,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Mf::Refract => Function::Regular("refract"), // computational Mf::Sign => Function::Regular("sign"), - Mf::Fma => Function::Regular("fma"), + Mf::Fma => Function::Regular("mad"), Mf::Mix => Function::Regular("lerp"), Mf::Step => Function::Regular("step"), Mf::SmoothStep => Function::Regular("smoothstep"), diff --git a/tests/in/functions.param.ron b/tests/in/functions.param.ron new file mode 100644 index 0000000000..72873dd667 --- /dev/null +++ b/tests/in/functions.param.ron @@ -0,0 +1,2 @@ +( +) diff --git a/tests/in/functions.wgsl b/tests/in/functions.wgsl new file mode 100644 index 0000000000..e9d64a99ff --- /dev/null +++ b/tests/in/functions.wgsl @@ -0,0 +1,15 @@ +fn test_fma() -> vec2 { + let a = vec2(2.0, 2.0); + let b = vec2(0.5, 0.5); + let c = vec2(0.5, 0.5); + + // Hazard: HLSL needs a different intrinsic function for f32 and f64 + // See: https://github.com/gfx-rs/naga/issues/1579 + return fma(a, b, c); +} + + +[[stage(compute), workgroup_size(1)]] +fn main() { + let a = test_fma(); +} diff --git a/tests/out/glsl/functions.main.Compute.glsl b/tests/out/glsl/functions.main.Compute.glsl new file mode 100644 index 0000000000..138ecee6bf --- /dev/null +++ b/tests/out/glsl/functions.main.Compute.glsl @@ -0,0 +1,20 @@ +#version 310 es + +precision highp float; +precision highp int; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + +vec2 test_fma() { + vec2 a = vec2(2.0, 2.0); + vec2 b = vec2(0.5, 0.5); + vec2 c = vec2(0.5, 0.5); + return fma(a, b, c); +} + +void main() { + vec2 _e0 = test_fma(); + return; +} + diff --git a/tests/out/hlsl/functions.hlsl b/tests/out/hlsl/functions.hlsl new file mode 100644 index 0000000000..37a4ecfafa --- /dev/null +++ b/tests/out/hlsl/functions.hlsl @@ -0,0 +1,15 @@ + +float2 test_fma() +{ + float2 a = float2(2.0, 2.0); + float2 b = float2(0.5, 0.5); + float2 c = float2(0.5, 0.5); + return mad(a, b, c); +} + +[numthreads(1, 1, 1)] +void main() +{ + const float2 _e0 = test_fma(); + return; +} diff --git a/tests/out/hlsl/functions.hlsl.config b/tests/out/hlsl/functions.hlsl.config new file mode 100644 index 0000000000..246c485cf7 --- /dev/null +++ b/tests/out/hlsl/functions.hlsl.config @@ -0,0 +1,3 @@ +vertex=() +fragment=() +compute=(main:cs_5_1 ) diff --git a/tests/out/msl/functions.msl b/tests/out/msl/functions.msl new file mode 100644 index 0000000000..6ace92f749 --- /dev/null +++ b/tests/out/msl/functions.msl @@ -0,0 +1,18 @@ +// language: metal1.1 +#include +#include + + +metal::float2 test_fma( +) { + metal::float2 a = metal::float2(2.0, 2.0); + metal::float2 b = metal::float2(0.5, 0.5); + metal::float2 c = metal::float2(0.5, 0.5); + return metal::fma(a, b, c); +} + +kernel void main_( +) { + metal::float2 _e0 = test_fma(); + return; +} diff --git a/tests/out/spv/functions.spvasm b/tests/out/spv/functions.spvasm new file mode 100644 index 0000000000..9c1a62ae4c --- /dev/null +++ b/tests/out/spv/functions.spvasm @@ -0,0 +1,33 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 20 +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %16 "main" +OpExecutionMode %16 LocalSize 1 1 1 +%2 = OpTypeVoid +%4 = OpTypeFloat 32 +%3 = OpConstant %4 2.0 +%5 = OpConstant %4 0.5 +%6 = OpTypeVector %4 2 +%9 = OpTypeFunction %6 +%17 = OpTypeFunction %2 +%8 = OpFunction %6 None %9 +%7 = OpLabel +OpBranch %10 +%10 = OpLabel +%11 = OpCompositeConstruct %6 %3 %3 +%12 = OpCompositeConstruct %6 %5 %5 +%13 = OpCompositeConstruct %6 %5 %5 +%14 = OpExtInst %6 %1 Fma %11 %12 %13 +OpReturnValue %14 +OpFunctionEnd +%16 = OpFunction %2 None %17 +%15 = OpLabel +OpBranch %18 +%18 = OpLabel +%19 = OpFunctionCall %6 %8 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/functions.wgsl b/tests/out/wgsl/functions.wgsl new file mode 100644 index 0000000000..9b6b41a09f --- /dev/null +++ b/tests/out/wgsl/functions.wgsl @@ -0,0 +1,12 @@ +fn test_fma() -> vec2 { + let a = vec2(2.0, 2.0); + let b = vec2(0.5, 0.5); + let c = vec2(0.5, 0.5); + return fma(a, b, c); +} + +[[stage(compute), workgroup_size(1, 1, 1)]] +fn main() { + let _e0 = test_fma(); + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 565f326190..3153485c33 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -443,6 +443,10 @@ fn convert_wgsl() { "operators", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), + ( + "functions", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ( "interpolate", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, From 5cec68697b270ba24bf0aaa0e64c93c9909d1293 Mon Sep 17 00:00:00 2001 From: Jay Oster Date: Tue, 7 Dec 2021 09:25:11 -0800 Subject: [PATCH 2/7] Add FMA feature to glsl backend - I think this is right. Just iterate all known expressions in all functions and entry points to locate any `fma` function call. Should not need to walk the statement DAG. --- src/back/glsl/features.rs | 31 ++++++++++++++++++++-- tests/out/glsl/functions.main.Compute.glsl | 1 + 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/back/glsl/features.rs b/src/back/glsl/features.rs index 74db74c70f..30d790724f 100644 --- a/src/back/glsl/features.rs +++ b/src/back/glsl/features.rs @@ -1,7 +1,7 @@ use super::{BackendResult, Error, Version, Writer}; use crate::{ - Binding, Bytes, Handle, ImageClass, ImageDimension, Interpolation, Sampling, ScalarKind, - ShaderStage, StorageClass, StorageFormat, Type, TypeInner, + Binding, Bytes, Expression, Handle, ImageClass, ImageDimension, Interpolation, MathFunction, + Sampling, ScalarKind, ShaderStage, StorageClass, StorageFormat, Type, TypeInner, }; use std::fmt::Write; @@ -33,6 +33,8 @@ bitflags::bitflags! { /// Arrays with a dynamic length const DYNAMIC_ARRAY_SIZE = 1 << 16; const MULTI_VIEW = 1 << 17; + /// Adds support for fused multiply-add + const FMA = 1 << 18; } } @@ -98,6 +100,7 @@ impl FeaturesManager { check_feature!(SAMPLE_VARIABLES, 400, 300); check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310); check_feature!(MULTI_VIEW, 140, 310); + check_feature!(FMA, 400, 310); // Return an error if there are missing features if missing.is_empty() { @@ -201,6 +204,11 @@ impl FeaturesManager { writeln!(out, "#extension GL_EXT_multiview : require")?; } + if self.0.contains(Features::FMA) && version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_gpu_shader5.txt + writeln!(out, "#extension GL_EXT_gpu_shader5 : require")?; + } + Ok(()) } } @@ -347,6 +355,25 @@ impl<'a, W> Writer<'a, W> { } } + let has_fma = self + .module + .functions + .iter() + .flat_map(|(_, f)| f.expressions.iter()) + .chain( + self.module + .entry_points + .iter() + .flat_map(|e| e.function.expressions.iter()), + ) + .any(|(_, e)| match *e { + Expression::Math { fun, .. } if fun == MathFunction::Fma => true, + _ => false, + }); + if has_fma { + self.features.request(Features::FMA); + } + self.features.check_availability(self.options.version) } diff --git a/tests/out/glsl/functions.main.Compute.glsl b/tests/out/glsl/functions.main.Compute.glsl index 138ecee6bf..d3e7f7d171 100644 --- a/tests/out/glsl/functions.main.Compute.glsl +++ b/tests/out/glsl/functions.main.Compute.glsl @@ -1,4 +1,5 @@ #version 310 es +#extension GL_EXT_gpu_shader5 : require precision highp float; precision highp int; From df975739cbe3470eb7457d16edad336dd24f6eba Mon Sep 17 00:00:00 2001 From: Jay Oster Date: Sat, 18 Dec 2021 15:28:36 -0800 Subject: [PATCH 3/7] Transform GLSL fma function into an airthmetic expression when necessary --- src/back/glsl/features.rs | 38 +++++++++++++++++++++----------------- src/back/glsl/mod.rs | 32 +++++++++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/src/back/glsl/features.rs b/src/back/glsl/features.rs index 30d790724f..5f5bdc9b77 100644 --- a/src/back/glsl/features.rs +++ b/src/back/glsl/features.rs @@ -355,23 +355,27 @@ impl<'a, W> Writer<'a, W> { } } - let has_fma = self - .module - .functions - .iter() - .flat_map(|(_, f)| f.expressions.iter()) - .chain( - self.module - .entry_points - .iter() - .flat_map(|e| e.function.expressions.iter()), - ) - .any(|(_, e)| match *e { - Expression::Math { fun, .. } if fun == MathFunction::Fma => true, - _ => false, - }); - if has_fma { - self.features.request(Features::FMA); + if self.options.version >= Version::Desktop(400) + || (self.options.version.is_es() && self.options.version >= Version::Embedded(310)) + { + let has_fma = self + .module + .functions + .iter() + .flat_map(|(_, f)| f.expressions.iter()) + .chain( + self.module + .entry_points + .iter() + .flat_map(|e| e.function.expressions.iter()), + ) + .any(|(_, e)| match *e { + Expression::Math { fun, .. } if fun == MathFunction::Fma => true, + _ => false, + }); + if has_fma { + self.features.request(Features::FMA); + } } self.features.check_availability(self.options.version) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 53f376a660..56ebfe7576 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2433,7 +2433,37 @@ impl<'a, W: Write> Writer<'a, W> { Mf::Refract => "refract", // computational Mf::Sign => "sign", - Mf::Fma => "fma", + Mf::Fma => { + let version = self.options.version; + if version >= Version::Desktop(400) + || (version.is_es() && version >= Version::Embedded(310)) + { + // Use the fma function when available + "fma" + } else { + // No fma support. Transform the function call into an arithmetic expression + write!(self.out, "(")?; + + self.write_expr(arg, ctx)?; + write!(self.out, " * ")?; + + let arg1 = match arg1 { + Some(arg1) => arg1, + None => return Err(Error::Custom("Missing fma arg1".to_owned())), + }; + self.write_expr(arg1, ctx)?; + write!(self.out, " + ")?; + + let arg2 = match arg2 { + Some(arg2) => arg2, + None => return Err(Error::Custom("Missing fma arg2".to_owned())), + }; + self.write_expr(arg2, ctx)?; + write!(self.out, ")")?; + + return Ok(()); + } + } Mf::Mix => "mix", Mf::Step => "step", Mf::SmoothStep => "smoothstep", From f4e241666c3b5d2fb11a9b2d4b8bf3575adae684 Mon Sep 17 00:00:00 2001 From: Jay Oster Date: Sat, 18 Dec 2021 15:50:19 -0800 Subject: [PATCH 4/7] Add tests for GLSL fma function tranformation --- tests/in/functions-webgl.param.ron | 7 +++++++ tests/in/functions-webgl.wgsl | 15 +++++++++++++++ .../out/glsl/functions-webgl.main.Vertex.glsl | 18 ++++++++++++++++++ tests/snapshots.rs | 1 + 4 files changed, 41 insertions(+) create mode 100644 tests/in/functions-webgl.param.ron create mode 100644 tests/in/functions-webgl.wgsl create mode 100644 tests/out/glsl/functions-webgl.main.Vertex.glsl diff --git a/tests/in/functions-webgl.param.ron b/tests/in/functions-webgl.param.ron new file mode 100644 index 0000000000..bc3dd4a8dd --- /dev/null +++ b/tests/in/functions-webgl.param.ron @@ -0,0 +1,7 @@ +( + glsl: ( + version: Embedded(300), + writer_flags: (bits: 0), + binding_map: {}, + ), +) diff --git a/tests/in/functions-webgl.wgsl b/tests/in/functions-webgl.wgsl new file mode 100644 index 0000000000..e372069e90 --- /dev/null +++ b/tests/in/functions-webgl.wgsl @@ -0,0 +1,15 @@ +fn test_fma() -> vec2 { + let a = vec2(2.0, 2.0); + let b = vec2(0.5, 0.5); + let c = vec2(0.5, 0.5); + + // Hazard: HLSL needs a different intrinsic function for f32 and f64 + // See: https://github.com/gfx-rs/naga/issues/1579 + return fma(a, b, c); +} + + +[[stage(vertex)]] +fn main() { + let a = test_fma(); +} diff --git a/tests/out/glsl/functions-webgl.main.Vertex.glsl b/tests/out/glsl/functions-webgl.main.Vertex.glsl new file mode 100644 index 0000000000..3522f1a655 --- /dev/null +++ b/tests/out/glsl/functions-webgl.main.Vertex.glsl @@ -0,0 +1,18 @@ +#version 300 es + +precision highp float; +precision highp int; + + +vec2 test_fma() { + vec2 a = vec2(2.0, 2.0); + vec2 b = vec2(0.5, 0.5); + vec2 c = vec2(0.5, 0.5); + return (a * b + c); +} + +void main() { + vec2 _e0 = test_fma(); + return; +} + diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 3153485c33..378a322df6 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -447,6 +447,7 @@ fn convert_wgsl() { "functions", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), + ("functions-webgl", Targets::GLSL), ( "interpolate", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, From 7454d1fc04cd3b74474016a56e870b640897e1eb Mon Sep 17 00:00:00 2001 From: Jay Oster Date: Sat, 18 Dec 2021 15:52:05 -0800 Subject: [PATCH 5/7] Remove the hazard comment from the webgl test input --- tests/in/functions-webgl.wgsl | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/in/functions-webgl.wgsl b/tests/in/functions-webgl.wgsl index e372069e90..2ec56f88f3 100644 --- a/tests/in/functions-webgl.wgsl +++ b/tests/in/functions-webgl.wgsl @@ -3,8 +3,6 @@ fn test_fma() -> vec2 { let b = vec2(0.5, 0.5); let c = vec2(0.5, 0.5); - // Hazard: HLSL needs a different intrinsic function for f32 and f64 - // See: https://github.com/gfx-rs/naga/issues/1579 return fma(a, b, c); } From 4f9addbb2d95c1d70b7b79ab01250db2e98079c9 Mon Sep 17 00:00:00 2001 From: Jay Oster Date: Mon, 20 Dec 2021 13:49:55 -0800 Subject: [PATCH 6/7] Add helper method for fma function support checks --- src/back/glsl/features.rs | 4 +--- src/back/glsl/mod.rs | 9 +++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/back/glsl/features.rs b/src/back/glsl/features.rs index 5f5bdc9b77..0d78558991 100644 --- a/src/back/glsl/features.rs +++ b/src/back/glsl/features.rs @@ -355,9 +355,7 @@ impl<'a, W> Writer<'a, W> { } } - if self.options.version >= Version::Desktop(400) - || (self.options.version.is_es() && self.options.version >= Version::Embedded(310)) - { + if self.options.version.supports_fma_function() { let has_fma = self .module .functions diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 56ebfe7576..661a86b9d4 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -133,6 +133,10 @@ impl Version { fn supports_std430_layout(&self) -> bool { *self >= Version::Desktop(430) || *self >= Version::Embedded(310) } + + fn supports_fma_function(&self) -> bool { + *self >= Version::Desktop(400) || *self >= Version::Embedded(310) + } } impl PartialOrd for Version { @@ -2434,10 +2438,7 @@ impl<'a, W: Write> Writer<'a, W> { // computational Mf::Sign => "sign", Mf::Fma => { - let version = self.options.version; - if version >= Version::Desktop(400) - || (version.is_es() && version >= Version::Embedded(310)) - { + if self.options.version.supports_fma_function() { // Use the fma function when available "fma" } else { From 7c8bedcb1c6f9bb14b4e9fc8f48496bea3a67491 Mon Sep 17 00:00:00 2001 From: Jay Oster Date: Mon, 20 Dec 2021 13:52:53 -0800 Subject: [PATCH 7/7] Address review comment --- src/back/glsl/mod.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 661a86b9d4..c0b6285685 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2448,17 +2448,13 @@ impl<'a, W: Write> Writer<'a, W> { self.write_expr(arg, ctx)?; write!(self.out, " * ")?; - let arg1 = match arg1 { - Some(arg1) => arg1, - None => return Err(Error::Custom("Missing fma arg1".to_owned())), - }; + let arg1 = + arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?; self.write_expr(arg1, ctx)?; write!(self.out, " + ")?; - let arg2 = match arg2 { - Some(arg2) => arg2, - None => return Err(Error::Custom("Missing fma arg2".to_owned())), - }; + let arg2 = + arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?; self.write_expr(arg2, ctx)?; write!(self.out, ")")?;