diff --git a/src/back/glsl/features.rs b/src/back/glsl/features.rs
index 74db74c70f..0d78558991 100644
--- a/src/back/glsl/features.rs
+++ b/src/back/glsl/features.rs
@@ -1,7 +1,7 @@
 use super::{BackendResult, Error, Version, Writer};
 use crate::{
-    Binding, Bytes, Handle, ImageClass, ImageDimension, Interpolation, Sampling, ScalarKind,
-    ShaderStage, StorageClass, StorageFormat, Type, TypeInner,
+    Binding, Bytes, Expression, Handle, ImageClass, ImageDimension, Interpolation, MathFunction,
+    Sampling, ScalarKind, ShaderStage, StorageClass, StorageFormat, Type, TypeInner,
 };
 use std::fmt::Write;
 
@@ -33,6 +33,8 @@ bitflags::bitflags! {
         /// Arrays with a dynamic length
         const DYNAMIC_ARRAY_SIZE = 1 << 16;
         const MULTI_VIEW = 1 << 17;
+        /// Adds support for fused multiply-add
+        const FMA = 1 << 18;
     }
 }
 
@@ -98,6 +100,7 @@ impl FeaturesManager {
         check_feature!(SAMPLE_VARIABLES, 400, 300);
         check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
         check_feature!(MULTI_VIEW, 140, 310);
+        check_feature!(FMA, 400, 310);
 
         // Return an error if there are missing features
         if missing.is_empty() {
@@ -201,6 +204,13 @@ impl FeaturesManager {
             writeln!(out, "#extension GL_EXT_multiview : require")?;
         }
 
+        // `fma` is a core built-in from GLSL ES 3.20 onwards, so only older ES versions
+        // need the extension to expose it. Desktop GLSL never needs it.
+        if self.0.contains(Features::FMA) && version.is_es() && version < Version::Embedded(320) {
+            // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_gpu_shader5.txt
+            writeln!(out, "#extension GL_EXT_gpu_shader5 : require")?;
+        }
+
         Ok(())
     }
 }
@@ -347,6 +357,26 @@ impl<'a, W> Writer<'a, W> {
             }
         }
 
+        // Versions without the `fma` built-in get `a * b + c` written out instead, which
+        // needs no feature, so only look for `fma` calls when the built-in will be used.
+        if self.options.version.supports_fma_function() {
+            let has_fma = self
+                .module
+                .functions
+                .iter()
+                .flat_map(|(_, f)| f.expressions.iter())
+                .chain(
+                    self.module
+                        .entry_points
+                        .iter()
+                        .flat_map(|e| e.function.expressions.iter()),
+                )
+                .any(|(_, e)| matches!(*e, Expression::Math { fun: MathFunction::Fma, .. }));
+            if has_fma {
+                self.features.request(Features::FMA);
+            }
+        }
+
         self.features.check_availability(self.options.version)
     }
 
diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs
index 53f376a660..c0b6285685 100644
--- a/src/back/glsl/mod.rs
+++ b/src/back/glsl/mod.rs
@@ -133,6 +133,13 @@ impl Version {
     fn supports_std430_layout(&self) -> bool {
         *self >= Version::Desktop(430) || *self >= Version::Embedded(310)
     }
+
+    /// Checks if the `fma` built-in function can be used with this version.
+    ///
+    /// GLSL ES 3.10 only provides it through the `GL_EXT_gpu_shader5` extension.
+    fn supports_fma_function(&self) -> bool {
+        *self >= Version::Desktop(400) || *self >= Version::Embedded(310)
+    }
 }
 
 impl PartialOrd for Version {
@@ -2433,7 +2440,32 @@ impl<'a, W: Write> Writer<'a, W> {
                     Mf::Refract => "refract",
                     // computational
                     Mf::Sign => "sign",
-                    Mf::Fma => "fma",
+                    Mf::Fma => {
+                        if self.options.version.supports_fma_function() {
+                            // Use the fma function when available
+                            "fma"
+                        } else {
+                            // No fma support. Transform the function call into an arithmetic
+                            // expression (the result is not fused, so rounding may differ).
+                            //
+                            // Both operands are checked before anything is written so that a
+                            // malformed expression doesn't leave partial output behind.
+                            let arg1 =
+                                arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?;
+                            let arg2 =
+                                arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?;
+
+                            write!(self.out, "(")?;
+                            self.write_expr(arg, ctx)?;
+                            write!(self.out, " * ")?;
+                            self.write_expr(arg1, ctx)?;
+                            write!(self.out, " + ")?;
+                            self.write_expr(arg2, ctx)?;
+                            write!(self.out, ")")?;
+
+                            return Ok(());
+                        }
+                    }
                     Mf::Mix => "mix",
                     Mf::Step => "step",
                     Mf::SmoothStep => "smoothstep",
diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs
index 6f05f1bd10..00b19b5f23 100644
--- a/src/back/hlsl/writer.rs
+++ b/src/back/hlsl/writer.rs
@@ -1861,7 +1861,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
                     Mf::Refract => Function::Regular("refract"),
                     // computational
                     Mf::Sign => Function::Regular("sign"),
-                    Mf::Fma => Function::Regular("fma"),
+                    // HLSL's `fma` only accepts `double` operands, use `mad` for `f32`.
+                    // f64 needs a different intrinsic: https://github.com/gfx-rs/naga/issues/1579
+                    Mf::Fma => Function::Regular("mad"),
                     Mf::Mix => Function::Regular("lerp"),
                     Mf::Step => Function::Regular("step"),
                     Mf::SmoothStep => Function::Regular("smoothstep"),
diff --git a/tests/in/functions-webgl.param.ron b/tests/in/functions-webgl.param.ron
new file mode 100644
index 0000000000..bc3dd4a8dd
--- /dev/null
+++ b/tests/in/functions-webgl.param.ron
@@ -0,0 +1,7 @@
+(
+    glsl: (
+        version: Embedded(300),
+        writer_flags: (bits: 0),
+        binding_map: {},
+    ),
+)
diff --git a/tests/in/functions-webgl.wgsl b/tests/in/functions-webgl.wgsl
new file mode 100644
index 0000000000..2ec56f88f3
--- /dev/null
+++ b/tests/in/functions-webgl.wgsl
@@ -0,0 +1,13 @@
+fn test_fma() -> vec2<f32> {
+    let a = vec2<f32>(2.0, 2.0);
+    let b = vec2<f32>(0.5, 0.5);
+    let c = vec2<f32>(0.5, 0.5);
+
+    return fma(a, b, c);
+}
+
+
+[[stage(vertex)]]
+fn main() {
+    let a = test_fma();
+}
diff --git a/tests/in/functions.param.ron b/tests/in/functions.param.ron
new file mode 100644
index 0000000000..72873dd667
--- /dev/null
+++ b/tests/in/functions.param.ron
@@ -0,0 +1,2 @@
+(
+)
diff --git a/tests/in/functions.wgsl b/tests/in/functions.wgsl
new file mode 100644
index 0000000000..e9d64a99ff
--- /dev/null
+++ b/tests/in/functions.wgsl
@@ -0,0 +1,15 @@
+fn test_fma() -> vec2<f32> {
+    let a = vec2<f32>(2.0, 2.0);
+    let b = vec2<f32>(0.5, 0.5);
+    let c = vec2<f32>(0.5, 0.5);
+
+    // Hazard: HLSL needs a different intrinsic function for f32 and f64
+    // See: https://github.com/gfx-rs/naga/issues/1579
+    return fma(a, b, c);
+}
+
+
+[[stage(compute), workgroup_size(1)]]
+fn main() {
+    let a = test_fma();
+}
diff --git a/tests/out/glsl/functions-webgl.main.Vertex.glsl b/tests/out/glsl/functions-webgl.main.Vertex.glsl
new file mode 100644
index 0000000000..3522f1a655
--- /dev/null
+++ b/tests/out/glsl/functions-webgl.main.Vertex.glsl
@@ -0,0 +1,18 @@
+#version 300 es
+
+precision highp float;
+precision highp int;
+
+
+vec2 test_fma() {
+    vec2 a = vec2(2.0, 2.0);
+    vec2 b = vec2(0.5, 0.5);
+    vec2 c = vec2(0.5, 0.5);
+    return (a * b + c);
+}
+
+void main() {
+    vec2 _e0 = test_fma();
+    return;
+}
+
diff --git a/tests/out/glsl/functions.main.Compute.glsl b/tests/out/glsl/functions.main.Compute.glsl
new file mode 100644
index 0000000000..d3e7f7d171
--- /dev/null
+++ b/tests/out/glsl/functions.main.Compute.glsl
@@ -0,0 +1,21 @@
+#version 310 es
+#extension GL_EXT_gpu_shader5 : require
+
+precision highp float;
+precision highp int;
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+
+
+vec2 test_fma() {
+    vec2 a = vec2(2.0, 2.0);
+    vec2 b = vec2(0.5, 0.5);
+    vec2 c = vec2(0.5, 0.5);
+    return fma(a, b, c);
+}
+
+void main() {
+    vec2 _e0 = test_fma();
+    return;
+}
+
diff --git a/tests/out/hlsl/functions.hlsl b/tests/out/hlsl/functions.hlsl
new file mode 100644
index 0000000000..37a4ecfafa
--- /dev/null
+++ b/tests/out/hlsl/functions.hlsl
@@ -0,0 +1,15 @@
+
+float2 test_fma()
+{
+    float2 a = float2(2.0, 2.0);
+    float2 b = float2(0.5, 0.5);
+    float2 c = float2(0.5, 0.5);
+    return mad(a, b, c);
+}
+
+[numthreads(1, 1, 1)]
+void main()
+{
+    const float2 _e0 = test_fma();
+    return;
+}
diff --git a/tests/out/hlsl/functions.hlsl.config b/tests/out/hlsl/functions.hlsl.config
new file mode 100644
index 0000000000..246c485cf7
--- /dev/null
+++ b/tests/out/hlsl/functions.hlsl.config
@@ -0,0 +1,3 @@
+vertex=()
+fragment=()
+compute=(main:cs_5_1 )
diff --git a/tests/out/msl/functions.msl b/tests/out/msl/functions.msl
new file mode 100644
index 0000000000..6ace92f749
--- /dev/null
+++ b/tests/out/msl/functions.msl
@@ -0,0 +1,18 @@
+// language: metal1.1
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+
+metal::float2 test_fma(
+) {
+    metal::float2 a = metal::float2(2.0, 2.0);
+    metal::float2 b = metal::float2(0.5, 0.5);
+    metal::float2 c = metal::float2(0.5, 0.5);
+    return metal::fma(a, b, c);
+}
+
+kernel void main_(
+) {
+    metal::float2 _e0 = test_fma();
+    return;
+}
diff --git a/tests/out/spv/functions.spvasm b/tests/out/spv/functions.spvasm
new file mode 100644
index 0000000000..9c1a62ae4c
--- /dev/null
+++ b/tests/out/spv/functions.spvasm
@@ -0,0 +1,33 @@
+; SPIR-V
+; Version: 1.1
+; Generator: rspirv
+; Bound: 20
+OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %16 "main"
+OpExecutionMode %16 LocalSize 1 1 1
+%2 = OpTypeVoid
+%4 = OpTypeFloat 32
+%3 = OpConstant %4 2.0
+%5 = OpConstant %4 0.5
+%6 = OpTypeVector %4 2
+%9 = OpTypeFunction %6
+%17 = OpTypeFunction %2
+%8 = OpFunction %6 None %9
+%7 = OpLabel
+OpBranch %10
+%10 = OpLabel
+%11 = OpCompositeConstruct %6 %3 %3
+%12 = OpCompositeConstruct %6 %5 %5
+%13 = OpCompositeConstruct %6 %5 %5
+%14 = OpExtInst %6 %1 Fma %11 %12 %13
+OpReturnValue %14
+OpFunctionEnd
+%16 = OpFunction %2 None %17
+%15 = OpLabel
+OpBranch %18
+%18 = OpLabel
+%19 = OpFunctionCall %6 %8
+OpReturn
+OpFunctionEnd
\ No newline at end of file
diff --git a/tests/out/wgsl/functions.wgsl b/tests/out/wgsl/functions.wgsl
new file mode 100644
index 0000000000..9b6b41a09f
--- /dev/null
+++ b/tests/out/wgsl/functions.wgsl
@@ -0,0 +1,12 @@
+fn test_fma() -> vec2<f32> {
+    let a = vec2<f32>(2.0, 2.0);
+    let b = vec2<f32>(0.5, 0.5);
+    let c = vec2<f32>(0.5, 0.5);
+    return fma(a, b, c);
+}
+
+[[stage(compute), workgroup_size(1, 1, 1)]]
+fn main() {
+    let _e0 = test_fma();
+    return;
+}
diff --git a/tests/snapshots.rs b/tests/snapshots.rs
index 565f326190..378a322df6 100644
--- a/tests/snapshots.rs
+++ b/tests/snapshots.rs
@@ -443,6 +443,11 @@ fn convert_wgsl() {
             "operators",
             Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
         ),
+        (
+            "functions",
+            Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
+        ),
+        ("functions-webgl", Targets::GLSL),
         (
             "interpolate",
             Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,