From 943235cd5e91df9a1d41c60f525d26734bc0d261 Mon Sep 17 00:00:00 2001 From: Igor Shaposhnik Date: Thu, 7 Oct 2021 23:59:39 +0300 Subject: [PATCH] [glsl-out] Convert modulo operator on float to SPIR-V OpFRem equivalent function (#1452) --- src/back/glsl/mod.rs | 95 +++++++++++++++++----- src/lib.rs | 1 + tests/in/operators.wgsl | 9 ++ tests/out/glsl/operators.main.Compute.glsl | 8 ++ tests/out/hlsl/operators.hlsl | 9 ++ tests/out/msl/operators.msl | 9 ++ tests/out/spv/operators.spvasm | 31 +++++-- tests/out/wgsl/operators.wgsl | 8 ++ 8 files changed, 143 insertions(+), 27 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index b79aabb870..e472c68268 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -311,6 +311,16 @@ pub enum Error { Custom(String), } +/// Binary operation with a different logic on the GLSL side +enum BinaryOperation { + /// Vector comparison should use the function like `greaterThan()`, etc. + VectorCompare, + /// GLSL `%` is SPIR-V `OpUMod/OpSMod` and `mod()` is `OpFMod`, but [`BinaryOperator::Modulo`](crate::BinaryOperator::Modulo) is `OpFRem` + Modulo, + /// Any plain operation. No additional logic required + Other, +} + /// Main structure of the glsl backend responsible for all code generation pub struct Writer<'a, W> { // Inputs @@ -2214,36 +2224,81 @@ impl<'a, W: Write> Writer<'a, W> { let right_inner = ctx.info[right].ty.inner_with(&self.module.types); let function = match (left_inner, right_inner) { - (&Ti::Vector { .. }, &Ti::Vector { .. }) => match op { - Bo::Less => Some("lessThan"), - Bo::LessEqual => Some("lessThanEqual"), - Bo::Greater => Some("greaterThan"), - Bo::GreaterEqual => Some("greaterThanEqual"), - Bo::Equal => Some("equal"), - Bo::NotEqual => Some("notEqual"), - _ => None, + ( + &Ti::Vector { + kind: left_kind, .. + }, + &Ti::Vector { + kind: right_kind, .. + }, + ) => match op { + Bo::Less + | Bo::LessEqual + | Bo::Greater + | Bo::GreaterEqual + | Bo::Equal + | Bo::NotEqual => BinaryOperation::VectorCompare, + Bo::Modulo => match (left_kind, right_kind) { + (Sk::Float, _) | (_, Sk::Float) => match op { + Bo::Modulo => BinaryOperation::Modulo, + _ => BinaryOperation::Other, + }, + _ => BinaryOperation::Other, + }, + _ => BinaryOperation::Other, }, _ => match (left_inner.scalar_kind(), right_inner.scalar_kind()) { (Some(Sk::Float), _) | (_, Some(Sk::Float)) => match op { - Bo::Modulo => Some("mod"), - _ => None, + Bo::Modulo => BinaryOperation::Modulo, + _ => BinaryOperation::Other, }, - _ => None, + _ => BinaryOperation::Other, }, }; - write!(self.out, "{}(", function.unwrap_or(""))?; - self.write_expr(left, ctx)?; + match function { + BinaryOperation::VectorCompare => { + let op_str = match op { + Bo::Less => "lessThan(", + Bo::LessEqual => "lessThanEqual(", + Bo::Greater => "greaterThan(", + Bo::GreaterEqual => "greaterThanEqual(", + Bo::Equal => "equal(", + Bo::NotEqual => "notEqual(", + _ => unreachable!(), + }; + write!(self.out, "{}", op_str)?; + self.write_expr(left, ctx)?; + write!(self.out, ", ")?; + self.write_expr(right, ctx)?; + write!(self.out, ")")?; + } + BinaryOperation::Modulo => { + write!(self.out, "(")?; - if function.is_some() { - write!(self.out, ",")? - } else { - write!(self.out, " {} ", super::binary_operation_str(op))?; - } + // write `e1 - e2 * trunc(e1 / e2)` + self.write_expr(left, ctx)?; + write!(self.out, " - ")?; + self.write_expr(right, ctx)?; + write!(self.out, " * ")?; + write!(self.out, "trunc(")?; + self.write_expr(left, ctx)?; + write!(self.out, " / ")?; + self.write_expr(right, ctx)?; + write!(self.out, ")")?; - self.write_expr(right, ctx)?; + write!(self.out, ")")?; + } + BinaryOperation::Other => { + write!(self.out, "(")?; - write!(self.out, ")")? + self.write_expr(left, ctx)?; + write!(self.out, " {} ", super::binary_operation_str(op))?; + self.write_expr(right, ctx)?; + + write!(self.out, ")")?; + } + } } // `Select` is written as `condition ? accept : reject` // We wrap everything in parentheses to avoid precedence issues diff --git a/src/lib.rs b/src/lib.rs index 737e0bfca3..036b197a13 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -813,6 +813,7 @@ pub enum BinaryOperator { Subtract, Multiply, Divide, + /// Equivalent of the WGSL's `%` operator or SPIR-V's `OpFRem` Modulo, Equal, NotEqual, diff --git a/tests/in/operators.wgsl b/tests/in/operators.wgsl index 1b7ec7585e..7335f57755 100644 --- a/tests/in/operators.wgsl +++ b/tests/in/operators.wgsl @@ -44,10 +44,19 @@ fn constructors() -> f32 { return foo.a.x; } +fn modulo() { + // Modulo operator on float scalar or vector must be converted to mod function for GLSL + let a = 1 % 1; + let b = 1.0 % 1.0; + let c = vec3(1) % vec3(1); + let d = vec3(1.0) % vec3(1.0); +} + [[stage(compute), workgroup_size(1)]] fn main() { let a = builtins(); let b = splat(); let c = unary(); let d = constructors(); + modulo(); } diff --git a/tests/out/glsl/operators.main.Compute.glsl b/tests/out/glsl/operators.main.Compute.glsl index 99c42467f3..17e9966226 100644 --- a/tests/out/glsl/operators.main.Compute.glsl +++ b/tests/out/glsl/operators.main.Compute.glsl @@ -44,11 +44,19 @@ float constructors() { return _e11; } +void modulo() { + int a1 = (1 % 1); + float b1 = (1.0 - 1.0 * trunc(1.0 / 1.0)); + ivec3 c = (ivec3(1) % ivec3(1)); + vec3 d = (vec3(1.0) - vec3(1.0) * trunc(vec3(1.0) / vec3(1.0))); +} + void main() { vec4 _e4 = builtins(); vec4 _e5 = splat(); int _e6 = unary(); float _e7 = constructors(); + modulo(); return; } diff --git a/tests/out/hlsl/operators.hlsl b/tests/out/hlsl/operators.hlsl index 904ca2e1cf..31aacb1640 100644 --- a/tests/out/hlsl/operators.hlsl +++ b/tests/out/hlsl/operators.hlsl @@ -53,6 +53,14 @@ float constructors() return _expr11; } +void modulo() +{ + int a1 = (1 % 1); + float b1 = (1.0 % 1.0); + int3 c = (int3(1.xxx) % int3(1.xxx)); + float3 d = (float3(1.0.xxx) % float3(1.0.xxx)); +} + [numthreads(1, 1, 1)] void main() { @@ -60,5 +68,6 @@ void main() const float4 _e5 = splat(); const int _e6 = unary(); const float _e7 = constructors(); + modulo(); return; } diff --git a/tests/out/msl/operators.msl b/tests/out/msl/operators.msl index daf4b3a10b..a552a33e94 100644 --- a/tests/out/msl/operators.msl +++ b/tests/out/msl/operators.msl @@ -48,11 +48,20 @@ float constructors( return _e11; } +void modulo( +) { + int a1 = 1 % 1; + float b1 = metal::fmod(1.0, 1.0); + metal::int3 c = metal::int3(1) % metal::int3(1); + metal::float3 d = metal::fmod(metal::float3(1.0), metal::float3(1.0)); +} + kernel void main1( ) { metal::float4 _e4 = builtins(); metal::float4 _e5 = splat(); int _e6 = unary(); float _e7 = constructors(); + modulo(); return; } diff --git a/tests/out/spv/operators.spvasm b/tests/out/spv/operators.spvasm index 475241ab75..688917bf34 100644 --- a/tests/out/spv/operators.spvasm +++ b/tests/out/spv/operators.spvasm @@ -1,12 +1,12 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 101 +; Bound: 115 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %94 "main" -OpExecutionMode %94 LocalSize 1 1 1 +OpEntryPoint GLCompute %108 "main" +OpExecutionMode %108 LocalSize 1 1 1 OpMemberDecorate %22 0 Offset 0 OpMemberDecorate %22 1 Offset 16 %2 = OpTypeVoid @@ -45,6 +45,8 @@ OpMemberDecorate %22 1 Offset 16 %90 = OpTypeInt 32 0 %89 = OpConstant %90 0 %95 = OpTypeFunction %2 +%99 = OpTypeVector %8 3 +%103 = OpTypeVector %4 3 %28 = OpFunction %19 None %29 %27 = OpLabel OpBranch %30 @@ -122,9 +124,24 @@ OpFunctionEnd %93 = OpLabel OpBranch %96 %96 = OpLabel -%97 = OpFunctionCall %19 %28 -%98 = OpFunctionCall %19 %53 -%99 = OpFunctionCall %8 %70 -%100 = OpFunctionCall %4 %82 +%97 = OpSMod %8 %7 %7 +%98 = OpFMod %4 %3 %3 +%100 = OpCompositeConstruct %99 %7 %7 %7 +%101 = OpCompositeConstruct %99 %7 %7 %7 +%102 = OpSMod %99 %100 %101 +%104 = OpCompositeConstruct %103 %3 %3 %3 +%105 = OpCompositeConstruct %103 %3 %3 %3 +%106 = OpFMod %103 %104 %105 +OpReturn +OpFunctionEnd +%108 = OpFunction %2 None %95 +%107 = OpLabel +OpBranch %109 +%109 = OpLabel +%110 = OpFunctionCall %19 %28 +%111 = OpFunctionCall %19 %53 +%112 = OpFunctionCall %8 %70 +%113 = OpFunctionCall %4 %82 +%114 = OpFunctionCall %2 %94 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/operators.wgsl b/tests/out/wgsl/operators.wgsl index 6c7ab1147d..ed5b9f78fe 100644 --- a/tests/out/wgsl/operators.wgsl +++ b/tests/out/wgsl/operators.wgsl @@ -41,11 +41,19 @@ fn constructors() -> f32 { return e11; } +fn modulo() { + let a1: i32 = (1 % 1); + let b1: f32 = (1.0 % 1.0); + let c: vec3 = (vec3(1) % vec3(1)); + let d: vec3 = (vec3(1.0) % vec3(1.0)); +} + [[stage(compute), workgroup_size(1, 1, 1)]] fn main() { let e4: vec4 = builtins(); let e5: vec4 = splat(); let e6: i32 = unary(); let e7: f32 = constructors(); + modulo(); return; }