From 5d8f84e0f7702ba06e963645d90e490c53a7bcd2 Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Sun, 5 Feb 2023 12:34:04 +0100 Subject: [PATCH 01/15] Add countTrailingZeros --- src/back/glsl/mod.rs | 75 ++++++++++++++++++- src/back/hlsl/writer.rs | 44 ++++++++++- src/back/msl/writer.rs | 1 + src/back/spv/block.rs | 65 ++++++++++++++++ src/back/wgsl/writer.rs | 1 + src/front/wgsl/parse/conv.rs | 1 + src/lib.rs | 1 + src/proc/mod.rs | 1 + src/proc/typifier.rs | 1 + src/valid/expression.rs | 3 +- tests/in/math-functions.wgsl | 6 ++ .../out/glsl/math-functions.main.Vertex.glsl | 11 ++- tests/out/hlsl/math-functions.hlsl | 11 ++- tests/out/spv/math-functions.spvasm | 40 +++++++--- tests/out/wgsl/math-functions.wgsl | 6 ++ 15 files changed, 250 insertions(+), 17 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index aaed745f26..3db68ad921 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1149,7 +1149,8 @@ impl<'a, W: Write> Writer<'a, W> { } } } - crate::MathFunction::CountLeadingZeros => { + crate::MathFunction::CountTrailingZeros + | crate::MathFunction::CountLeadingZeros => { if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { self.need_bake_expressions.insert(arg); } @@ -2960,6 +2961,78 @@ impl<'a, W: Write> Writer<'a, W> { Mf::Transpose => "transpose", Mf::Determinant => "determinant", // bits + Mf::CountTrailingZeros => { + if self.options.version.supports_integer_functions() { + match *ctx.info[arg].ty.inner_with(&self.module.types) { + crate::TypeInner::Vector { size, kind, .. } => { + let s = back::vector_size_str(size); + + if let crate::ScalarKind::Uint = kind { + write!(self.out, "uvec{s}(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") - ivec{s}(1))")?; + } else { + write!(self.out, "mix(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") - ivec{s}(1), ivec{s}(0), lessThan(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ", ivec{s}(0)))")?; + } + } + crate::TypeInner::Scalar { kind, .. } => { + if let crate::ScalarKind::Uint = kind { + write!(self.out, "uint(findLSB(")?; + } else { + write!(self.out, "(")?; + self.write_expr(arg, ctx)?; + write!(self.out, " == 0 ? -1 : findLSB(")?; + } + + self.write_expr(arg, ctx)?; + write!(self.out, ") - 1)")?; + } + _ => unreachable!(), + }; + } else { + match *ctx.info[arg].ty.inner_with(&self.module.types) { + crate::TypeInner::Vector { size, kind, .. } => { + let s = back::vector_size_str(size); + + if let crate::ScalarKind::Uint = kind { + write!(self.out, "uvec{s}(")?; + write!(self.out, "floor(log2(vec{s}(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") + 0.5)) - vec{s}(1.0))")?; + } else { + write!(self.out, "ivec{s}(")?; + write!(self.out, "mix(floor(log2(vec{s}(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") + 0.5)) - vec{s}(1.0), ")?; + write!(self.out, "vec{s}(0.0), lessThan(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ", ivec{s}(0u))))")?; + } + } + crate::TypeInner::Scalar { kind, .. } => { + if let crate::ScalarKind::Uint = kind { + write!(self.out, "uint(floor(log2(float(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") + 0.5)) - 1.0)")?; + } else { + write!(self.out, "(")?; + self.write_expr(arg, ctx)?; + write!(self.out, " == 0 ? -1 : int(")?; + write!(self.out, "floor(log2(float(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") + 0.5))) - 1.0)")?; + } + } + _ => unreachable!(), + }; + } + + return Ok(()); + } Mf::CountLeadingZeros => { if self.options.version.supports_integer_functions() { match *ctx.info[arg].ty.inner_with(&self.module.types) { diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 2af7a3524b..35f9d55010 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -124,7 +124,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { | crate::MathFunction::Unpack2x16float => { self.need_bake_expressions.insert(arg); } - crate::MathFunction::CountLeadingZeros => { + crate::MathFunction::CountTrailingZeros + | crate::MathFunction::CountLeadingZeros => { let inner = info[fun_handle].ty.inner_with(&module.types); if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { self.need_bake_expressions.insert(arg); @@ -2551,6 +2552,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Unpack2x16float, Regular(&'static str), MissingIntOverload(&'static str), + CountTrailingZeros, CountLeadingZeros, } @@ -2614,6 +2616,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Mf::Transpose => Function::Regular("transpose"), Mf::Determinant => Function::Regular("determinant"), // bits + Mf::CountTrailingZeros => Function::CountTrailingZeros, Mf::CountLeadingZeros => Function::CountLeadingZeros, Mf::CountOneBits => Function::MissingIntOverload("countbits"), Mf::ReverseBits => Function::MissingIntOverload("reversebits"), @@ -2682,6 +2685,45 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")?; } } + Function::CountTrailingZeros => { + match *func_ctx.info[arg].ty.inner_with(&module.types) { + TypeInner::Vector { size, kind, .. } => { + let s = match size { + crate::VectorSize::Bi => ".xx", + crate::VectorSize::Tri => ".xxx", + crate::VectorSize::Quad => ".xxxx", + }; + + if let ScalarKind::Uint = kind { + write!(self.out, "asuint(firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") - (1){s})")?; + } else { + write!(self.out, "(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " == (0){s} ? (-1){s} : firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") - (1){s})")?; + } + } + TypeInner::Scalar { kind, .. } => { + if let ScalarKind::Uint = kind { + write!(self.out, "asuint(firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") - 1)")?; + } else { + write!(self.out, "(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " == 0 ? -1 : firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") - 1)")?; + } + } + _ => unreachable!(), + } + + return Ok(()); + } Function::CountLeadingZeros => { match *func_ctx.info[arg].ty.inner_with(&module.types) { TypeInner::Vector { size, kind, .. } => { diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index a4635ac2e7..07fd45647e 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1689,6 +1689,7 @@ impl Writer { Mf::Transpose => "transpose", Mf::Determinant => "determinant", // bits + Mf::CountTrailingZeros => "ctz", Mf::CountLeadingZeros => "clz", Mf::CountOneBits => "popcount", Mf::ReverseBits => "reverse_bits", diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index bf09fc3ad1..a6e1fa4630 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -888,6 +888,71 @@ impl<'w> BlockContext<'w> { id, arg0_id, )), + Mf::CountTrailingZeros => { + let int = crate::ScalarValue::Sint(1); + + let (int_type_id, int_id) = match *arg_ty { + crate::TypeInner::Vector { size, width, .. } => { + let ty = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(size), + kind: crate::ScalarKind::Sint, + width, + pointer_space: None, + })); + + self.temp_list.clear(); + self.temp_list + .resize(size as _, self.writer.get_constant_scalar(int, width)); + + let id = self.gen_id(); + block.body.push(Instruction::constant_composite( + ty, + id, + &self.temp_list, + )); + + (ty, id) + } + crate::TypeInner::Scalar { width, .. } => ( + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Sint, + width, + pointer_space: None, + })), + self.writer.get_constant_scalar(int, width), + ), + _ => unreachable!(), + }; + + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FindILsb, + int_type_id, + id, + &[arg0_id], + )); + + let sub_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::ISub, + int_type_id, + sub_id, + id, + int_id, + )); + + if let Some(crate::ScalarKind::Uint) = arg_scalar_kind { + block.body.push(Instruction::unary( + spirv::Op::Bitcast, + result_type_id, + self.gen_id(), + sub_id, + )); + } + + return Ok(()); + } Mf::CountLeadingZeros => { let int = crate::ScalarValue::Sint(31); diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 4f58cd9ee8..be4d9e4423 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -1578,6 +1578,7 @@ impl Writer { Mf::Transpose => Function::Regular("transpose"), Mf::Determinant => Function::Regular("determinant"), // bits + Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"), Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"), Mf::CountOneBits => Function::Regular("countOneBits"), Mf::ReverseBits => Function::Regular("reverseBits"), diff --git a/src/front/wgsl/parse/conv.rs b/src/front/wgsl/parse/conv.rs index 4164332166..0e20774dc5 100644 --- a/src/front/wgsl/parse/conv.rs +++ b/src/front/wgsl/parse/conv.rs @@ -191,6 +191,7 @@ pub fn map_standard_fun(word: &str) -> Option { "transpose" => Mf::Transpose, "determinant" => Mf::Determinant, // bits + "countTrailingZeros" => Mf::CountTrailingZeros, "countLeadingZeros" => Mf::CountLeadingZeros, "countOneBits" => Mf::CountOneBits, "reverseBits" => Mf::ReverseBits, diff --git a/src/lib.rs b/src/lib.rs index 136205ca17..e60d8e0cc4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1066,6 +1066,7 @@ pub enum MathFunction { Transpose, Determinant, // bits + CountTrailingZeros, CountLeadingZeros, CountOneBits, ReverseBits, diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 0a8e4a961a..6a8bfa03c7 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -279,6 +279,7 @@ impl super::MathFunction { Self::Transpose => 1, Self::Determinant => 1, // bits + Self::CountTrailingZeros => 1, Self::CountLeadingZeros => 1, Self::CountOneBits => 1, Self::ReverseBits => 1, diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 64896ec413..e384340570 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -793,6 +793,7 @@ impl<'a> ResolveContext<'a> { )), }, // bits + Mf::CountTrailingZeros | Mf::CountLeadingZeros | Mf::CountOneBits | Mf::ReverseBits | diff --git a/src/valid/expression.rs b/src/valid/expression.rs index d599e6b62f..01d6910eba 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1223,7 +1223,8 @@ impl super::Validator { _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } - Mf::CountLeadingZeros + Mf::CountTrailingZeros + | Mf::CountLeadingZeros | Mf::CountOneBits | Mf::ReverseBits | Mf::FindLsb diff --git a/tests/in/math-functions.wgsl b/tests/in/math-functions.wgsl index 1c2a8e4579..aa2679155f 100644 --- a/tests/in/math-functions.wgsl +++ b/tests/in/math-functions.wgsl @@ -10,6 +10,12 @@ fn main() { let g = refract(v, v, f); let const_dot = dot(vec2(), vec2()); let first_leading_bit_abs = firstLeadingBit(abs(0u)); + let ctz_a = countTrailingZeros(-1); + let ctz_b = countTrailingZeros(1u); + let ctz_c = countTrailingZeros(vec2(-1)); + let ctz_d = countTrailingZeros(vec2(1u)); + let ctz_e = countTrailingZeros(0); + let ctz_f = countTrailingZeros(0u); let clz_a = countLeadingZeros(-1); let clz_b = countLeadingZeros(1u); let clz_c = countLeadingZeros(vec2(-1)); diff --git a/tests/out/glsl/math-functions.main.Vertex.glsl b/tests/out/glsl/math-functions.main.Vertex.glsl index 8cfa6a10b5..b62a425ae7 100644 --- a/tests/out/glsl/math-functions.main.Vertex.glsl +++ b/tests/out/glsl/math-functions.main.Vertex.glsl @@ -14,10 +14,17 @@ void main() { vec4 g = refract(v, v, 1.0); int const_dot = ( + ivec2(0, 0).x * ivec2(0, 0).x + ivec2(0, 0).y * ivec2(0, 0).y); uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u))))); + int ctz_a = (-1 == 0 ? -1 : findLSB(-1) - 1); + uint ctz_b = uint(findLSB(1u) - 1); + ivec2 _e20 = ivec2(-1); + ivec2 ctz_c = mix(findLSB(_e20) - ivec2(1), ivec2(0), lessThan(_e20, ivec2(0))); + uvec2 ctz_d = uvec2(findLSB(uvec2(1u)) - ivec2(1)); + int ctz_e = (0 == 0 ? -1 : findLSB(0) - 1); + uint ctz_f = uint(findLSB(0u) - 1); int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1)); uint clz_b = uint(31 - findMSB(1u)); - ivec2 _e20 = ivec2(-1); - ivec2 clz_c = mix(ivec2(31) - findMSB(_e20), ivec2(0), lessThan(_e20, ivec2(0))); + ivec2 _e34 = ivec2(-1); + ivec2 clz_c = mix(ivec2(31) - findMSB(_e34), ivec2(0), lessThan(_e34, ivec2(0))); uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u))); } diff --git a/tests/out/hlsl/math-functions.hlsl b/tests/out/hlsl/math-functions.hlsl index 2a95c849c9..1c052ad3d6 100644 --- a/tests/out/hlsl/math-functions.hlsl +++ b/tests/out/hlsl/math-functions.hlsl @@ -10,9 +10,16 @@ void main() float4 g = refract(v, v, 1.0); int const_dot = dot(int2(0, 0), int2(0, 0)); uint first_leading_bit_abs = firstbithigh(abs(0u)); + int ctz_a = (-1 == 0 ? -1 : firstbitlow(-1) - 1); + uint ctz_b = asuint(firstbitlow(1u) - 1); + int2 _expr20 = (-1).xx; + int2 ctz_c = (_expr20 == (0).xx ? (-1).xx : firstbitlow(_expr20) - (1).xx); + uint2 ctz_d = asuint(firstbitlow((1u).xx) - (1).xx); + int ctz_e = (0 == 0 ? -1 : firstbitlow(0) - 1); + uint ctz_f = asuint(firstbitlow(0u) - 1); int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1)); uint clz_b = asuint(31 - firstbithigh(1u)); - int2 _expr20 = (-1).xx; - int2 clz_c = (_expr20 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr20)); + int2 _expr34 = (-1).xx; + int2 clz_c = (_expr34 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr34)); uint2 clz_d = asuint((31).xx - firstbithigh((1u).xx)); } diff --git a/tests/out/spv/math-functions.spvasm b/tests/out/spv/math-functions.spvasm index 131c6caceb..ff0797bea1 100644 --- a/tests/out/spv/math-functions.spvasm +++ b/tests/out/spv/math-functions.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 55 +; Bound: 75 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -20,13 +20,16 @@ OpEntryPoint Vertex %16 "main" %13 = OpTypeVector %7 2 %14 = OpConstantComposite %13 %6 %6 %17 = OpTypeFunction %2 -%40 = OpConstant %7 31 +%40 = OpConstant %7 1 %49 = OpTypeVector %9 2 +%61 = OpConstant %7 31 %25 = OpConstantComposite %12 %5 %5 %5 %5 %26 = OpConstantComposite %12 %3 %3 %3 %3 %29 = OpConstantNull %7 %47 = OpConstantComposite %13 %40 %40 %52 = OpConstantComposite %13 %40 %40 +%68 = OpConstantComposite %13 %61 %61 +%72 = OpConstantComposite %13 %61 %61 %16 = OpFunction %2 None %17 %15 = OpLabel OpBranch %18 @@ -48,17 +51,34 @@ OpBranch %18 %28 = OpIAdd %7 %33 %36 %37 = OpCopyObject %9 %8 %38 = OpExtInst %9 %1 FindUMsb %37 -%39 = OpExtInst %7 %1 FindUMsb %10 -%41 = OpISub %7 %40 %39 -%42 = OpExtInst %7 %1 FindUMsb %11 -%43 = OpISub %7 %40 %42 +%39 = OpExtInst %7 %1 FindILsb %10 +%41 = OpISub %7 %39 %40 +%42 = OpExtInst %7 %1 FindILsb %11 +%43 = OpISub %7 %42 %40 %44 = OpBitcast %9 %43 %45 = OpCompositeConstruct %13 %10 %10 -%46 = OpExtInst %13 %1 FindUMsb %45 -%48 = OpISub %13 %47 %46 +%46 = OpExtInst %13 %1 FindILsb %45 +%48 = OpISub %13 %46 %47 %50 = OpCompositeConstruct %49 %11 %11 -%51 = OpExtInst %13 %1 FindUMsb %50 -%53 = OpISub %13 %52 %51 +%51 = OpExtInst %13 %1 FindILsb %50 +%53 = OpISub %13 %51 %52 %54 = OpBitcast %49 %53 +%55 = OpExtInst %7 %1 FindILsb %6 +%56 = OpISub %7 %55 %40 +%57 = OpExtInst %7 %1 FindILsb %8 +%58 = OpISub %7 %57 %40 +%59 = OpBitcast %9 %58 +%60 = OpExtInst %7 %1 FindUMsb %10 +%62 = OpISub %7 %61 %60 +%63 = OpExtInst %7 %1 FindUMsb %11 +%64 = OpISub %7 %61 %63 +%65 = OpBitcast %9 %64 +%66 = OpCompositeConstruct %13 %10 %10 +%67 = OpExtInst %13 %1 FindUMsb %66 +%69 = OpISub %13 %68 %67 +%70 = OpCompositeConstruct %49 %11 %11 +%71 = OpExtInst %13 %1 FindUMsb %70 +%73 = OpISub %13 %72 %71 +%74 = OpBitcast %49 %73 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/math-functions.wgsl b/tests/out/wgsl/math-functions.wgsl index d91a26cff4..e2db446387 100644 --- a/tests/out/wgsl/math-functions.wgsl +++ b/tests/out/wgsl/math-functions.wgsl @@ -9,6 +9,12 @@ fn main() { let g = refract(v, v, 1.0); let const_dot = dot(vec2(0, 0), vec2(0, 0)); let first_leading_bit_abs = firstLeadingBit(abs(0u)); + let ctz_a = countTrailingZeros(-1); + let ctz_b = countTrailingZeros(1u); + let ctz_c = countTrailingZeros(vec2(-1)); + let ctz_d = countTrailingZeros(vec2(1u)); + let ctz_e = countTrailingZeros(0); + let ctz_f = countTrailingZeros(0u); let clz_a = countLeadingZeros(-1); let clz_b = countLeadingZeros(1u); let clz_c = countLeadingZeros(vec2(-1)); From ad9f67d2300f5910c027fbc2cc7bf52e2c1dbdaa Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Sun, 5 Feb 2023 12:56:11 +0100 Subject: [PATCH 02/15] Adding missing msl --- tests/out/msl/math-functions.msl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/out/msl/math-functions.msl b/tests/out/msl/math-functions.msl index 3fdb7b75a5..2423815da6 100644 --- a/tests/out/msl/math-functions.msl +++ b/tests/out/msl/math-functions.msl @@ -18,6 +18,12 @@ vertex void main_( int const_dot = ( + const_type_1_.x * const_type_1_.x + const_type_1_.y * const_type_1_.y); uint _e13 = metal::abs(0u); uint first_leading_bit_abs = metal::select(31 - metal::clz(_e13), uint(-1), _e13 == 0 || _e13 == -1); + int ctz_a = metal::ctz(-1); + uint ctz_b = metal::ctz(1u); + metal::int2 ctz_c = metal::ctz(metal::int2(-1)); + metal::uint2 ctz_d = metal::ctz(metal::uint2(1u)); + int ctz_e = metal::ctz(0); + uint ctz_f = metal::ctz(0u); int clz_a = metal::clz(-1); uint clz_b = metal::clz(1u); metal::int2 clz_c = metal::clz(metal::int2(-1)); From 737cfcd05fed9434e10926312180220b15f35d01 Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Sun, 5 Feb 2023 19:20:00 +0100 Subject: [PATCH 03/15] Use lsb to retrieve counting trailing zeros --- src/back/glsl/mod.rs | 76 +------------------ src/back/hlsl/writer.rs | 42 +--------- src/back/spv/block.rs | 66 +--------------- .../out/glsl/math-functions.main.Vertex.glsl | 13 ++-- tests/out/hlsl/math-functions.hlsl | 12 +-- tests/out/spv/math-functions.spvasm | 60 ++++++--------- 6 files changed, 40 insertions(+), 229 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 3db68ad921..35f57e86fb 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1149,8 +1149,7 @@ impl<'a, W: Write> Writer<'a, W> { } } } - crate::MathFunction::CountTrailingZeros - | crate::MathFunction::CountLeadingZeros => { + crate::MathFunction::CountLeadingZeros => { if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { self.need_bake_expressions.insert(arg); } @@ -2961,78 +2960,7 @@ impl<'a, W: Write> Writer<'a, W> { Mf::Transpose => "transpose", Mf::Determinant => "determinant", // bits - Mf::CountTrailingZeros => { - if self.options.version.supports_integer_functions() { - match *ctx.info[arg].ty.inner_with(&self.module.types) { - crate::TypeInner::Vector { size, kind, .. } => { - let s = back::vector_size_str(size); - - if let crate::ScalarKind::Uint = kind { - write!(self.out, "uvec{s}(findLSB(")?; - self.write_expr(arg, ctx)?; - write!(self.out, ") - ivec{s}(1))")?; - } else { - write!(self.out, "mix(findLSB(")?; - self.write_expr(arg, ctx)?; - write!(self.out, ") - ivec{s}(1), ivec{s}(0), lessThan(")?; - self.write_expr(arg, ctx)?; - write!(self.out, ", ivec{s}(0)))")?; - } - } - crate::TypeInner::Scalar { kind, .. } => { - if let crate::ScalarKind::Uint = kind { - write!(self.out, "uint(findLSB(")?; - } else { - write!(self.out, "(")?; - self.write_expr(arg, ctx)?; - write!(self.out, " == 0 ? -1 : findLSB(")?; - } - - self.write_expr(arg, ctx)?; - write!(self.out, ") - 1)")?; - } - _ => unreachable!(), - }; - } else { - match *ctx.info[arg].ty.inner_with(&self.module.types) { - crate::TypeInner::Vector { size, kind, .. } => { - let s = back::vector_size_str(size); - - if let crate::ScalarKind::Uint = kind { - write!(self.out, "uvec{s}(")?; - write!(self.out, "floor(log2(vec{s}(")?; - self.write_expr(arg, ctx)?; - write!(self.out, ") + 0.5)) - vec{s}(1.0))")?; - } else { - write!(self.out, "ivec{s}(")?; - write!(self.out, "mix(floor(log2(vec{s}(")?; - self.write_expr(arg, ctx)?; - write!(self.out, ") + 0.5)) - vec{s}(1.0), ")?; - write!(self.out, "vec{s}(0.0), lessThan(")?; - self.write_expr(arg, ctx)?; - write!(self.out, ", ivec{s}(0u))))")?; - } - } - crate::TypeInner::Scalar { kind, .. } => { - if let crate::ScalarKind::Uint = kind { - write!(self.out, "uint(floor(log2(float(")?; - self.write_expr(arg, ctx)?; - write!(self.out, ") + 0.5)) - 1.0)")?; - } else { - write!(self.out, "(")?; - self.write_expr(arg, ctx)?; - write!(self.out, " == 0 ? -1 : int(")?; - write!(self.out, "floor(log2(float(")?; - self.write_expr(arg, ctx)?; - write!(self.out, ") + 0.5))) - 1.0)")?; - } - } - _ => unreachable!(), - }; - } - - return Ok(()); - } + Mf::CountTrailingZeros => "findLSB", Mf::CountLeadingZeros => { if self.options.version.supports_integer_functions() { match *ctx.info[arg].ty.inner_with(&self.module.types) { diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 35f9d55010..17445bd396 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2552,7 +2552,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Unpack2x16float, Regular(&'static str), MissingIntOverload(&'static str), - CountTrailingZeros, CountLeadingZeros, } @@ -2616,7 +2615,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Mf::Transpose => Function::Regular("transpose"), Mf::Determinant => Function::Regular("determinant"), // bits - Mf::CountTrailingZeros => Function::CountTrailingZeros, + Mf::CountTrailingZeros => Function::Regular("firstbitlow"), Mf::CountLeadingZeros => Function::CountLeadingZeros, Mf::CountOneBits => Function::MissingIntOverload("countbits"), Mf::ReverseBits => Function::MissingIntOverload("reversebits"), @@ -2685,45 +2684,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")?; } } - Function::CountTrailingZeros => { - match *func_ctx.info[arg].ty.inner_with(&module.types) { - TypeInner::Vector { size, kind, .. } => { - let s = match size { - crate::VectorSize::Bi => ".xx", - crate::VectorSize::Tri => ".xxx", - crate::VectorSize::Quad => ".xxxx", - }; - - if let ScalarKind::Uint = kind { - write!(self.out, "asuint(firstbitlow(")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, ") - (1){s})")?; - } else { - write!(self.out, "(")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, " == (0){s} ? (-1){s} : firstbitlow(")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, ") - (1){s})")?; - } - } - TypeInner::Scalar { kind, .. } => { - if let ScalarKind::Uint = kind { - write!(self.out, "asuint(firstbitlow(")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, ") - 1)")?; - } else { - write!(self.out, "(")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, " == 0 ? -1 : firstbitlow(")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, ") - 1)")?; - } - } - _ => unreachable!(), - } - - return Ok(()); - } Function::CountLeadingZeros => { match *func_ctx.info[arg].ty.inner_with(&module.types) { TypeInner::Vector { size, kind, .. } => { diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index a6e1fa4630..c0730ad068 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -888,71 +888,7 @@ impl<'w> BlockContext<'w> { id, arg0_id, )), - Mf::CountTrailingZeros => { - let int = crate::ScalarValue::Sint(1); - - let (int_type_id, int_id) = match *arg_ty { - crate::TypeInner::Vector { size, width, .. } => { - let ty = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(size), - kind: crate::ScalarKind::Sint, - width, - pointer_space: None, - })); - - self.temp_list.clear(); - self.temp_list - .resize(size as _, self.writer.get_constant_scalar(int, width)); - - let id = self.gen_id(); - block.body.push(Instruction::constant_composite( - ty, - id, - &self.temp_list, - )); - - (ty, id) - } - crate::TypeInner::Scalar { width, .. } => ( - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - kind: crate::ScalarKind::Sint, - width, - pointer_space: None, - })), - self.writer.get_constant_scalar(int, width), - ), - _ => unreachable!(), - }; - - block.body.push(Instruction::ext_inst( - self.writer.gl450_ext_inst_id, - spirv::GLOp::FindILsb, - int_type_id, - id, - &[arg0_id], - )); - - let sub_id = self.gen_id(); - block.body.push(Instruction::binary( - spirv::Op::ISub, - int_type_id, - sub_id, - id, - int_id, - )); - - if let Some(crate::ScalarKind::Uint) = arg_scalar_kind { - block.body.push(Instruction::unary( - spirv::Op::Bitcast, - result_type_id, - self.gen_id(), - sub_id, - )); - } - - return Ok(()); - } + Mf::CountTrailingZeros => MathOp::Ext(spirv::GLOp::FindILsb), Mf::CountLeadingZeros => { let int = crate::ScalarValue::Sint(31); diff --git a/tests/out/glsl/math-functions.main.Vertex.glsl b/tests/out/glsl/math-functions.main.Vertex.glsl index b62a425ae7..1967bf01d2 100644 --- a/tests/out/glsl/math-functions.main.Vertex.glsl +++ b/tests/out/glsl/math-functions.main.Vertex.glsl @@ -14,13 +14,12 @@ void main() { vec4 g = refract(v, v, 1.0); int const_dot = ( + ivec2(0, 0).x * ivec2(0, 0).x + ivec2(0, 0).y * ivec2(0, 0).y); uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u))))); - int ctz_a = (-1 == 0 ? -1 : findLSB(-1) - 1); - uint ctz_b = uint(findLSB(1u) - 1); - ivec2 _e20 = ivec2(-1); - ivec2 ctz_c = mix(findLSB(_e20) - ivec2(1), ivec2(0), lessThan(_e20, ivec2(0))); - uvec2 ctz_d = uvec2(findLSB(uvec2(1u)) - ivec2(1)); - int ctz_e = (0 == 0 ? -1 : findLSB(0) - 1); - uint ctz_f = uint(findLSB(0u) - 1); + int ctz_a = findLSB(-1); + uint ctz_b = findLSB(1u); + ivec2 ctz_c = findLSB(ivec2(-1)); + uvec2 ctz_d = findLSB(uvec2(1u)); + int ctz_e = findLSB(0); + uint ctz_f = findLSB(0u); int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1)); uint clz_b = uint(31 - findMSB(1u)); ivec2 _e34 = ivec2(-1); diff --git a/tests/out/hlsl/math-functions.hlsl b/tests/out/hlsl/math-functions.hlsl index 1c052ad3d6..37b52b634d 100644 --- a/tests/out/hlsl/math-functions.hlsl +++ b/tests/out/hlsl/math-functions.hlsl @@ -10,13 +10,13 @@ void main() float4 g = refract(v, v, 1.0); int const_dot = dot(int2(0, 0), int2(0, 0)); uint first_leading_bit_abs = firstbithigh(abs(0u)); - int ctz_a = (-1 == 0 ? -1 : firstbitlow(-1) - 1); - uint ctz_b = asuint(firstbitlow(1u) - 1); + int ctz_a = firstbitlow(-1); + uint ctz_b = firstbitlow(1u); int2 _expr20 = (-1).xx; - int2 ctz_c = (_expr20 == (0).xx ? (-1).xx : firstbitlow(_expr20) - (1).xx); - uint2 ctz_d = asuint(firstbitlow((1u).xx) - (1).xx); - int ctz_e = (0 == 0 ? -1 : firstbitlow(0) - 1); - uint ctz_f = asuint(firstbitlow(0u) - 1); + int2 ctz_c = firstbitlow(_expr20); + uint2 ctz_d = firstbitlow((1u).xx); + int ctz_e = firstbitlow(0); + uint ctz_f = firstbitlow(0u); int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1)); uint clz_b = asuint(31 - firstbithigh(1u)); int2 _expr34 = (-1).xx; diff --git a/tests/out/spv/math-functions.spvasm b/tests/out/spv/math-functions.spvasm index ff0797bea1..056adfa8f0 100644 --- a/tests/out/spv/math-functions.spvasm +++ b/tests/out/spv/math-functions.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 75 +; Bound: 63 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -20,16 +20,13 @@ OpEntryPoint Vertex %16 "main" %13 = OpTypeVector %7 2 %14 = OpConstantComposite %13 %6 %6 %17 = OpTypeFunction %2 -%40 = OpConstant %7 1 -%49 = OpTypeVector %9 2 -%61 = OpConstant %7 31 +%43 = OpTypeVector %9 2 +%49 = OpConstant %7 31 %25 = OpConstantComposite %12 %5 %5 %5 %5 %26 = OpConstantComposite %12 %3 %3 %3 %3 %29 = OpConstantNull %7 -%47 = OpConstantComposite %13 %40 %40 -%52 = OpConstantComposite %13 %40 %40 -%68 = OpConstantComposite %13 %61 %61 -%72 = OpConstantComposite %13 %61 %61 +%56 = OpConstantComposite %13 %49 %49 +%60 = OpConstantComposite %13 %49 %49 %16 = OpFunction %2 None %17 %15 = OpLabel OpBranch %18 @@ -52,33 +49,24 @@ OpBranch %18 %37 = OpCopyObject %9 %8 %38 = OpExtInst %9 %1 FindUMsb %37 %39 = OpExtInst %7 %1 FindILsb %10 -%41 = OpISub %7 %39 %40 -%42 = OpExtInst %7 %1 FindILsb %11 -%43 = OpISub %7 %42 %40 -%44 = OpBitcast %9 %43 -%45 = OpCompositeConstruct %13 %10 %10 -%46 = OpExtInst %13 %1 FindILsb %45 -%48 = OpISub %13 %46 %47 -%50 = OpCompositeConstruct %49 %11 %11 -%51 = OpExtInst %13 %1 FindILsb %50 -%53 = OpISub %13 %51 %52 -%54 = OpBitcast %49 %53 -%55 = OpExtInst %7 %1 FindILsb %6 -%56 = OpISub %7 %55 %40 -%57 = OpExtInst %7 %1 FindILsb %8 -%58 = OpISub %7 %57 %40 -%59 = OpBitcast %9 %58 -%60 = OpExtInst %7 %1 FindUMsb %10 -%62 = OpISub %7 %61 %60 -%63 = OpExtInst %7 %1 FindUMsb %11 -%64 = OpISub %7 %61 %63 -%65 = OpBitcast %9 %64 -%66 = OpCompositeConstruct %13 %10 %10 -%67 = OpExtInst %13 %1 FindUMsb %66 -%69 = OpISub %13 %68 %67 -%70 = OpCompositeConstruct %49 %11 %11 -%71 = OpExtInst %13 %1 FindUMsb %70 -%73 = OpISub %13 %72 %71 -%74 = OpBitcast %49 %73 +%40 = OpExtInst %9 %1 FindILsb %11 +%41 = OpCompositeConstruct %13 %10 %10 +%42 = OpExtInst %13 %1 FindILsb %41 +%44 = OpCompositeConstruct %43 %11 %11 +%45 = OpExtInst %43 %1 FindILsb %44 +%46 = OpExtInst %7 %1 FindILsb %6 +%47 = OpExtInst %9 %1 FindILsb %8 +%48 = OpExtInst %7 %1 FindUMsb %10 +%50 = OpISub %7 %49 %48 +%51 = OpExtInst %7 %1 FindUMsb %11 +%52 = OpISub %7 %49 %51 +%53 = OpBitcast %9 %52 +%54 = OpCompositeConstruct %13 %10 %10 +%55 = OpExtInst %13 %1 FindUMsb %54 +%57 = OpISub %13 %56 %55 +%58 = OpCompositeConstruct %43 %11 %11 +%59 = OpExtInst %13 %1 FindUMsb %58 +%61 = OpISub %13 %60 %59 +%62 = OpBitcast %43 %61 OpReturn OpFunctionEnd \ No newline at end of file From ec2995540ae64ad9046c0ba477e508948172033a Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Sun, 5 Feb 2023 19:27:14 +0100 Subject: [PATCH 04/15] Simplified tests --- tests/in/math-functions.wgsl | 10 +- .../out/glsl/math-functions.main.Vertex.glsl | 14 +-- tests/out/hlsl/math-functions.hlsl | 15 +-- tests/out/msl/math-functions.msl | 10 +- tests/out/spv/math-functions.spvasm | 109 +++++++++--------- tests/out/wgsl/math-functions.wgsl | 10 +- 6 files changed, 78 insertions(+), 90 deletions(-) diff --git a/tests/in/math-functions.wgsl b/tests/in/math-functions.wgsl index aa2679155f..81e6ae2187 100644 --- a/tests/in/math-functions.wgsl +++ b/tests/in/math-functions.wgsl @@ -10,12 +10,10 @@ fn main() { let g = refract(v, v, f); let const_dot = dot(vec2(), vec2()); let first_leading_bit_abs = firstLeadingBit(abs(0u)); - let ctz_a = countTrailingZeros(-1); - let ctz_b = countTrailingZeros(1u); - let ctz_c = countTrailingZeros(vec2(-1)); - let ctz_d = countTrailingZeros(vec2(1u)); - let ctz_e = countTrailingZeros(0); - let ctz_f = countTrailingZeros(0u); + let ctz_a = countTrailingZeros(0u); + let ctz_b = countTrailingZeros(0xFFFFFFFFu); + let ctz_c = countTrailingZeros(vec2(1u)); + let ctz_d = countTrailingZeros(vec2(0u)); let clz_a = countLeadingZeros(-1); let clz_b = countLeadingZeros(1u); let clz_c = countLeadingZeros(vec2(-1)); diff --git a/tests/out/glsl/math-functions.main.Vertex.glsl b/tests/out/glsl/math-functions.main.Vertex.glsl index 1967bf01d2..e724e40ea1 100644 --- a/tests/out/glsl/math-functions.main.Vertex.glsl +++ b/tests/out/glsl/math-functions.main.Vertex.glsl @@ -14,16 +14,14 @@ void main() { vec4 g = refract(v, v, 1.0); int const_dot = ( + ivec2(0, 0).x * ivec2(0, 0).x + ivec2(0, 0).y * ivec2(0, 0).y); uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u))))); - int ctz_a = findLSB(-1); - uint ctz_b = findLSB(1u); - ivec2 ctz_c = findLSB(ivec2(-1)); - uvec2 ctz_d = findLSB(uvec2(1u)); - int ctz_e = findLSB(0); - uint ctz_f = findLSB(0u); + uint ctz_a = findLSB(0u); + uint ctz_b = findLSB(4294967295u); + uvec2 ctz_c = findLSB(uvec2(1u)); + uvec2 ctz_d = findLSB(uvec2(0u)); int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1)); uint clz_b = uint(31 - findMSB(1u)); - ivec2 _e34 = ivec2(-1); - ivec2 clz_c = mix(ivec2(31) - findMSB(_e34), ivec2(0), lessThan(_e34, ivec2(0))); + ivec2 _e30 = ivec2(-1); + ivec2 clz_c = mix(ivec2(31) - findMSB(_e30), ivec2(0), lessThan(_e30, ivec2(0))); uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u))); } diff --git a/tests/out/hlsl/math-functions.hlsl b/tests/out/hlsl/math-functions.hlsl index 37b52b634d..e070ba7210 100644 --- a/tests/out/hlsl/math-functions.hlsl +++ b/tests/out/hlsl/math-functions.hlsl @@ -10,16 +10,13 @@ void main() float4 g = refract(v, v, 1.0); int const_dot = dot(int2(0, 0), int2(0, 0)); uint first_leading_bit_abs = firstbithigh(abs(0u)); - int ctz_a = firstbitlow(-1); - uint ctz_b = firstbitlow(1u); - int2 _expr20 = (-1).xx; - int2 ctz_c = firstbitlow(_expr20); - uint2 ctz_d = firstbitlow((1u).xx); - int ctz_e = firstbitlow(0); - uint ctz_f = firstbitlow(0u); + uint ctz_a = firstbitlow(0u); + uint ctz_b = firstbitlow(4294967295u); + uint2 ctz_c = firstbitlow((1u).xx); + uint2 ctz_d = firstbitlow((0u).xx); int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1)); uint clz_b = asuint(31 - firstbithigh(1u)); - int2 _expr34 = (-1).xx; - int2 clz_c = (_expr34 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr34)); + int2 _expr30 = (-1).xx; + int2 clz_c = (_expr30 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr30)); uint2 clz_d = asuint((31).xx - firstbithigh((1u).xx)); } diff --git a/tests/out/msl/math-functions.msl b/tests/out/msl/math-functions.msl index 2423815da6..818bd98d23 100644 --- a/tests/out/msl/math-functions.msl +++ b/tests/out/msl/math-functions.msl @@ -18,12 +18,10 @@ vertex void main_( int const_dot = ( + const_type_1_.x * const_type_1_.x + const_type_1_.y * const_type_1_.y); uint _e13 = metal::abs(0u); uint first_leading_bit_abs = metal::select(31 - metal::clz(_e13), uint(-1), _e13 == 0 || _e13 == -1); - int ctz_a = metal::ctz(-1); - uint ctz_b = metal::ctz(1u); - metal::int2 ctz_c = metal::ctz(metal::int2(-1)); - metal::uint2 ctz_d = metal::ctz(metal::uint2(1u)); - int ctz_e = metal::ctz(0); - uint ctz_f = metal::ctz(0u); + uint ctz_a = metal::ctz(0u); + uint ctz_b = metal::ctz(4294967295u); + metal::uint2 ctz_c = metal::ctz(metal::uint2(1u)); + metal::uint2 ctz_d = metal::ctz(metal::uint2(0u)); int clz_a = metal::clz(-1); uint clz_b = metal::clz(1u); metal::int2 clz_c = metal::clz(metal::int2(-1)); diff --git a/tests/out/spv/math-functions.spvasm b/tests/out/spv/math-functions.spvasm index 056adfa8f0..3857a6033b 100644 --- a/tests/out/spv/math-functions.spvasm +++ b/tests/out/spv/math-functions.spvasm @@ -1,11 +1,11 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 63 +; Bound: 62 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint Vertex %16 "main" +OpEntryPoint Vertex %17 "main" %2 = OpTypeVoid %4 = OpTypeFloat 32 %3 = OpConstant %4 1.0 @@ -14,59 +14,58 @@ OpEntryPoint Vertex %16 "main" %6 = OpConstant %7 0 %9 = OpTypeInt 32 0 %8 = OpConstant %9 0 -%10 = OpConstant %7 -1 +%10 = OpConstant %9 4294967295 %11 = OpConstant %9 1 -%12 = OpTypeVector %4 4 -%13 = OpTypeVector %7 2 -%14 = OpConstantComposite %13 %6 %6 -%17 = OpTypeFunction %2 -%43 = OpTypeVector %9 2 -%49 = OpConstant %7 31 -%25 = OpConstantComposite %12 %5 %5 %5 %5 -%26 = OpConstantComposite %12 %3 %3 %3 %3 -%29 = OpConstantNull %7 -%56 = OpConstantComposite %13 %49 %49 -%60 = OpConstantComposite %13 %49 %49 -%16 = OpFunction %2 None %17 -%15 = OpLabel -OpBranch %18 -%18 = OpLabel -%19 = OpCompositeConstruct %12 %5 %5 %5 %5 -%20 = OpExtInst %4 %1 Degrees %3 -%21 = OpExtInst %4 %1 Radians %3 -%22 = OpExtInst %12 %1 Degrees %19 -%23 = OpExtInst %12 %1 Radians %19 -%24 = OpExtInst %12 %1 FClamp %19 %25 %26 -%27 = OpExtInst %12 %1 Refract %19 %19 %3 -%30 = OpCompositeExtract %7 %14 0 -%31 = OpCompositeExtract %7 %14 0 -%32 = OpIMul %7 %30 %31 -%33 = OpIAdd %7 %29 %32 -%34 = OpCompositeExtract %7 %14 1 -%35 = OpCompositeExtract %7 %14 1 -%36 = OpIMul %7 %34 %35 -%28 = OpIAdd %7 %33 %36 -%37 = OpCopyObject %9 %8 -%38 = OpExtInst %9 %1 FindUMsb %37 -%39 = OpExtInst %7 %1 FindILsb %10 -%40 = OpExtInst %9 %1 FindILsb %11 -%41 = OpCompositeConstruct %13 %10 %10 -%42 = OpExtInst %13 %1 FindILsb %41 -%44 = OpCompositeConstruct %43 %11 %11 -%45 = OpExtInst %43 %1 FindILsb %44 -%46 = OpExtInst %7 %1 FindILsb %6 -%47 = OpExtInst %9 %1 FindILsb %8 -%48 = OpExtInst %7 %1 FindUMsb %10 -%50 = OpISub %7 %49 %48 -%51 = OpExtInst %7 %1 FindUMsb %11 -%52 = OpISub %7 %49 %51 -%53 = OpBitcast %9 %52 -%54 = OpCompositeConstruct %13 %10 %10 -%55 = OpExtInst %13 %1 FindUMsb %54 -%57 = OpISub %13 %56 %55 -%58 = OpCompositeConstruct %43 %11 %11 -%59 = OpExtInst %13 %1 FindUMsb %58 -%61 = OpISub %13 %60 %59 -%62 = OpBitcast %43 %61 +%12 = OpConstant %7 -1 +%13 = OpTypeVector %4 4 +%14 = OpTypeVector %7 2 +%15 = OpConstantComposite %14 %6 %6 +%18 = OpTypeFunction %2 +%42 = OpTypeVector %9 2 +%48 = OpConstant %7 31 +%26 = OpConstantComposite %13 %5 %5 %5 %5 +%27 = OpConstantComposite %13 %3 %3 %3 %3 +%30 = OpConstantNull %7 +%55 = OpConstantComposite %14 %48 %48 +%59 = OpConstantComposite %14 %48 %48 +%17 = OpFunction %2 None %18 +%16 = OpLabel +OpBranch %19 +%19 = OpLabel +%20 = OpCompositeConstruct %13 %5 %5 %5 %5 +%21 = OpExtInst %4 %1 Degrees %3 +%22 = OpExtInst %4 %1 Radians %3 +%23 = OpExtInst %13 %1 Degrees %20 +%24 = OpExtInst %13 %1 Radians %20 +%25 = OpExtInst %13 %1 FClamp %20 %26 %27 +%28 = OpExtInst %13 %1 Refract %20 %20 %3 +%31 = OpCompositeExtract %7 %15 0 +%32 = OpCompositeExtract %7 %15 0 +%33 = OpIMul %7 %31 %32 +%34 = OpIAdd %7 %30 %33 +%35 = OpCompositeExtract %7 %15 1 +%36 = OpCompositeExtract %7 %15 1 +%37 = OpIMul %7 %35 %36 +%29 = OpIAdd %7 %34 %37 +%38 = OpCopyObject %9 %8 +%39 = OpExtInst %9 %1 FindUMsb %38 +%40 = OpExtInst %9 %1 FindILsb %8 +%41 = OpExtInst %9 %1 FindILsb %10 +%43 = OpCompositeConstruct %42 %11 %11 +%44 = OpExtInst %42 %1 FindILsb %43 +%45 = OpCompositeConstruct %42 %8 %8 +%46 = OpExtInst %42 %1 FindILsb %45 +%47 = OpExtInst %7 %1 FindUMsb %12 +%49 = OpISub %7 %48 %47 +%50 = OpExtInst %7 %1 FindUMsb %11 +%51 = OpISub %7 %48 %50 +%52 = OpBitcast %9 %51 +%53 = OpCompositeConstruct %14 %12 %12 +%54 = OpExtInst %14 %1 FindUMsb %53 +%56 = OpISub %14 %55 %54 +%57 = OpCompositeConstruct %42 %11 %11 +%58 = OpExtInst %14 %1 FindUMsb %57 +%60 = OpISub %14 %59 %58 +%61 = OpBitcast %42 %60 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/math-functions.wgsl b/tests/out/wgsl/math-functions.wgsl index e2db446387..639505c4a1 100644 --- a/tests/out/wgsl/math-functions.wgsl +++ b/tests/out/wgsl/math-functions.wgsl @@ -9,12 +9,10 @@ fn main() { let g = refract(v, v, 1.0); let const_dot = dot(vec2(0, 0), vec2(0, 0)); let first_leading_bit_abs = firstLeadingBit(abs(0u)); - let ctz_a = countTrailingZeros(-1); - let ctz_b = countTrailingZeros(1u); - let ctz_c = countTrailingZeros(vec2(-1)); - let ctz_d = countTrailingZeros(vec2(1u)); - let ctz_e = countTrailingZeros(0); - let ctz_f = countTrailingZeros(0u); + let ctz_a = countTrailingZeros(0u); + let ctz_b = countTrailingZeros(4294967295u); + let ctz_c = countTrailingZeros(vec2(1u)); + let ctz_d = countTrailingZeros(vec2(0u)); let clz_a = countLeadingZeros(-1); let clz_b = countLeadingZeros(1u); let clz_c = countLeadingZeros(vec2(-1)); From eff9db7389b78c3e5ba64b2bd3be2810bd98286f Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Sun, 5 Feb 2023 19:44:36 +0100 Subject: [PATCH 05/15] Fixing glsl --- src/back/glsl/mod.rs | 34 +++++++++++++++++-- .../out/glsl/math-functions.main.Vertex.glsl | 8 ++--- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 35f57e86fb..18e71f138b 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1149,7 +1149,8 @@ impl<'a, W: Write> Writer<'a, W> { } } } - crate::MathFunction::CountLeadingZeros => { + crate::MathFunction::CountTrailingZeros + | crate::MathFunction::CountLeadingZeros => { if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { self.need_bake_expressions.insert(arg); } @@ -2960,7 +2961,36 @@ impl<'a, W: Write> Writer<'a, W> { Mf::Transpose => "transpose", Mf::Determinant => "determinant", // bits - Mf::CountTrailingZeros => "findLSB", + Mf::CountTrailingZeros => { + match *ctx.info[arg].ty.inner_with(&self.module.types) { + crate::TypeInner::Vector { size, kind, .. } => { + let s = back::vector_size_str(size); + + if let crate::ScalarKind::Uint = kind { + write!(self.out, "uvec{s}(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "findMSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")")?; + } + } + crate::TypeInner::Scalar { kind, .. } => { + if let crate::ScalarKind::Uint = kind { + write!(self.out, "uint(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")")?; + } + } + _ => unreachable!(), + }; + return Ok(()); + } Mf::CountLeadingZeros => { if self.options.version.supports_integer_functions() { match *ctx.info[arg].ty.inner_with(&self.module.types) { diff --git a/tests/out/glsl/math-functions.main.Vertex.glsl b/tests/out/glsl/math-functions.main.Vertex.glsl index e724e40ea1..58f0df8ff7 100644 --- a/tests/out/glsl/math-functions.main.Vertex.glsl +++ b/tests/out/glsl/math-functions.main.Vertex.glsl @@ -14,10 +14,10 @@ void main() { vec4 g = refract(v, v, 1.0); int const_dot = ( + ivec2(0, 0).x * ivec2(0, 0).x + ivec2(0, 0).y * ivec2(0, 0).y); uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u))))); - uint ctz_a = findLSB(0u); - uint ctz_b = findLSB(4294967295u); - uvec2 ctz_c = findLSB(uvec2(1u)); - uvec2 ctz_d = findLSB(uvec2(0u)); + uint ctz_a = uint(findLSB(0u)); + uint ctz_b = uint(findLSB(4294967295u)); + uvec2 ctz_c = uvec2(findLSB(uvec2(1u))); + uvec2 ctz_d = uvec2(findLSB(uvec2(0u))); int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1)); uint clz_b = uint(31 - findMSB(1u)); ivec2 _e30 = ivec2(-1); From 0a98f82d0684ff41d0478aab87d1d98ad0976693 Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Sat, 11 Feb 2023 11:27:47 +0100 Subject: [PATCH 06/15] Updated with min(findlsb, 32) --- src/back/glsl/mod.rs | 17 +-- src/back/hlsl/writer.rs | 25 +++- src/back/spv/block.rs | 75 +++++++++- tests/in/math-functions.wgsl | 10 +- .../out/glsl/math-functions.main.Vertex.glsl | 18 ++- tests/out/hlsl/math-functions.hlsl | 18 ++- tests/out/msl/math-functions.msl | 10 +- tests/out/spv/math-functions.spvasm | 140 +++++++++++------- tests/out/wgsl/math-functions.wgsl | 10 +- 9 files changed, 237 insertions(+), 86 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 18e71f138b..96473e48d9 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2965,26 +2965,25 @@ impl<'a, W: Write> Writer<'a, W> { match *ctx.info[arg].ty.inner_with(&self.module.types) { crate::TypeInner::Vector { size, kind, .. } => { let s = back::vector_size_str(size); - if let crate::ScalarKind::Uint = kind { - write!(self.out, "uvec{s}(findLSB(")?; + write!(self.out, "min(uvec{s}(findLSB(")?; self.write_expr(arg, ctx)?; - write!(self.out, "))")?; + write!(self.out, ")), uvec{s}(32u))")?; } else { - write!(self.out, "findMSB(")?; + write!(self.out, "ivec{s}(min(uvec{s}(findLSB(")?; self.write_expr(arg, ctx)?; - write!(self.out, ")")?; + write!(self.out, ")), uvec{s}(32u)))")?; } } crate::TypeInner::Scalar { kind, .. } => { if let crate::ScalarKind::Uint = kind { - write!(self.out, "uint(findLSB(")?; + write!(self.out, "min(uint(findLSB(")?; self.write_expr(arg, ctx)?; - write!(self.out, "))")?; + write!(self.out, ")), 32u)")?; } else { - write!(self.out, "findLSB(")?; + write!(self.out, "int(min(uint(findLSB(")?; self.write_expr(arg, ctx)?; - write!(self.out, ")")?; + write!(self.out, ")), 32u))")?; } } _ => unreachable!(), diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 17445bd396..7472c1afef 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2552,6 +2552,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Unpack2x16float, Regular(&'static str), MissingIntOverload(&'static str), + CountTrailingZeros, CountLeadingZeros, } @@ -2615,7 +2616,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Mf::Transpose => Function::Regular("transpose"), Mf::Determinant => Function::Regular("determinant"), // bits - Mf::CountTrailingZeros => Function::Regular("firstbitlow"), + Mf::CountTrailingZeros => Function::CountTrailingZeros, Mf::CountLeadingZeros => Function::CountLeadingZeros, Mf::CountOneBits => Function::MissingIntOverload("countbits"), Mf::ReverseBits => Function::MissingIntOverload("reversebits"), @@ -2684,6 +2685,28 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")?; } } + Function::CountTrailingZeros => { + match *func_ctx.info[arg].ty.inner_with(&module.types) { + TypeInner::Vector { size, kind, .. } => { + let s = match size { + crate::VectorSize::Bi => ".xx", + crate::VectorSize::Tri => ".xxx", + crate::VectorSize::Quad => ".xxxx", + }; + + write!(self.out, "min(asuint((32){s}, asuint(firstbitlow(")?; + } + TypeInner::Scalar { kind, .. } => { + write!(self.out, "min(asuint((32), asuint(firstbitlow(")?; + } + _ => unreachable!(), + } + + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))))")?; + + return Ok(()); + } Function::CountLeadingZeros => { match *func_ctx.info[arg].ty.inner_with(&module.types) { TypeInner::Vector { size, kind, .. } => { diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index c0730ad068..fa3701d3f1 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -888,7 +888,80 @@ impl<'w> BlockContext<'w> { id, arg0_id, )), - Mf::CountTrailingZeros => MathOp::Ext(spirv::GLOp::FindILsb), + Mf::CountTrailingZeros => { + let uint = crate::ScalarValue::Uint(32); + + let (uint_type_id, uint_id) = match *arg_ty { + crate::TypeInner::Vector { size, width, .. } => { + let uty = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(size), + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + })); + + self.temp_list.clear(); + self.temp_list.resize( + size as _, + self.writer.get_constant_scalar(uint, width), + ); + + let uid = self.gen_id(); + block.body.push(Instruction::constant_composite( + uty, + uid, + &self.temp_list, + )); + (uty, uid) + } + crate::TypeInner::Scalar { width, .. } => ( + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + })), + self.writer.get_constant_scalar(uint, width), + ), + _ => unreachable!(), + }; + + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FindILsb, + result_type_id, + id, + &[arg0_id], + )); + + let cast_id = self.gen_id(); + block.body.push(Instruction::unary( + spirv::Op::Bitcast, + uint_type_id, + self.gen_id(), + id, + )); + + let min_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + uint_type_id, + min_id, + &[uint_id, cast_id], + )); + + if let Some(crate::ScalarKind::Sint) = arg_scalar_kind { + block.body.push(Instruction::unary( + spirv::Op::Bitcast, + result_type_id, + self.gen_id(), + min_id, + )); + } + + return Ok(()); + } Mf::CountLeadingZeros => { let int = crate::ScalarValue::Sint(31); diff --git a/tests/in/math-functions.wgsl b/tests/in/math-functions.wgsl index 81e6ae2187..db50880d14 100644 --- a/tests/in/math-functions.wgsl +++ b/tests/in/math-functions.wgsl @@ -11,9 +11,13 @@ fn main() { let const_dot = dot(vec2(), vec2()); let first_leading_bit_abs = firstLeadingBit(abs(0u)); let ctz_a = countTrailingZeros(0u); - let ctz_b = countTrailingZeros(0xFFFFFFFFu); - let ctz_c = countTrailingZeros(vec2(1u)); - let ctz_d = countTrailingZeros(vec2(0u)); + let ctz_b = countTrailingZeros(0); + let ctz_c = countTrailingZeros(0xFFFFFFFFu); + let ctz_d = countTrailingZeros(-1); + let ctz_e = countTrailingZeros(vec2(0u)); + let ctz_f = countTrailingZeros(vec2(0)); + let ctz_g = countTrailingZeros(vec2(1u)); + let ctz_h = countTrailingZeros(vec2(1)); let clz_a = countLeadingZeros(-1); let clz_b = countLeadingZeros(1u); let clz_c = countLeadingZeros(vec2(-1)); diff --git a/tests/out/glsl/math-functions.main.Vertex.glsl b/tests/out/glsl/math-functions.main.Vertex.glsl index 58f0df8ff7..4f4e7b718e 100644 --- a/tests/out/glsl/math-functions.main.Vertex.glsl +++ b/tests/out/glsl/math-functions.main.Vertex.glsl @@ -14,14 +14,20 @@ void main() { vec4 g = refract(v, v, 1.0); int const_dot = ( + ivec2(0, 0).x * ivec2(0, 0).x + ivec2(0, 0).y * ivec2(0, 0).y); uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u))))); - uint ctz_a = uint(findLSB(0u)); - uint ctz_b = uint(findLSB(4294967295u)); - uvec2 ctz_c = uvec2(findLSB(uvec2(1u))); - uvec2 ctz_d = uvec2(findLSB(uvec2(0u))); + uint ctz_a = min(uint(findLSB(0u)), 32u); + int ctz_b = int(min(uint(findLSB(0)), 32u)); + uint ctz_c = min(uint(findLSB(4294967295u)), 32u); + int ctz_d = int(min(uint(findLSB(-1)), 32u)); + uvec2 ctz_e = min(uvec2(findLSB(uvec2(0u))), uvec2(32u)); + ivec2 _e27 = ivec2(0); + ivec2 ctz_f = ivec2(min(uvec2(findLSB(_e27)), uvec2(32u))); + uvec2 ctz_g = min(uvec2(findLSB(uvec2(1u))), uvec2(32u)); + ivec2 _e33 = ivec2(1); + ivec2 ctz_h = ivec2(min(uvec2(findLSB(_e33)), uvec2(32u))); int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1)); uint clz_b = uint(31 - findMSB(1u)); - ivec2 _e30 = ivec2(-1); - ivec2 clz_c = mix(ivec2(31) - findMSB(_e30), ivec2(0), lessThan(_e30, ivec2(0))); + ivec2 _e40 = ivec2(-1); + ivec2 clz_c = mix(ivec2(31) - findMSB(_e40), ivec2(0), lessThan(_e40, ivec2(0))); uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u))); } diff --git a/tests/out/hlsl/math-functions.hlsl b/tests/out/hlsl/math-functions.hlsl index e070ba7210..ad6e5e9046 100644 --- a/tests/out/hlsl/math-functions.hlsl +++ b/tests/out/hlsl/math-functions.hlsl @@ -10,13 +10,19 @@ void main() float4 g = refract(v, v, 1.0); int const_dot = dot(int2(0, 0), int2(0, 0)); uint first_leading_bit_abs = firstbithigh(abs(0u)); - uint ctz_a = firstbitlow(0u); - uint ctz_b = firstbitlow(4294967295u); - uint2 ctz_c = firstbitlow((1u).xx); - uint2 ctz_d = firstbitlow((0u).xx); + uint ctz_a = min(asuint((32), asuint(firstbitlow(0u)))); + int ctz_b = min(asuint((32), asuint(firstbitlow(0)))); + uint ctz_c = min(asuint((32), asuint(firstbitlow(4294967295u)))); + int ctz_d = min(asuint((32), asuint(firstbitlow(-1)))); + uint2 ctz_e = min(asuint((32).xx, asuint(firstbitlow((0u).xx)))); + int2 _expr27 = (0).xx; + int2 ctz_f = min(asuint((32).xx, asuint(firstbitlow(_expr27)))); + uint2 ctz_g = min(asuint((32).xx, asuint(firstbitlow((1u).xx)))); + int2 _expr33 = (1).xx; + int2 ctz_h = min(asuint((32).xx, asuint(firstbitlow(_expr33)))); int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1)); uint clz_b = asuint(31 - firstbithigh(1u)); - int2 _expr30 = (-1).xx; - int2 clz_c = (_expr30 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr30)); + int2 _expr40 = (-1).xx; + int2 clz_c = (_expr40 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr40)); uint2 clz_d = asuint((31).xx - firstbithigh((1u).xx)); } diff --git a/tests/out/msl/math-functions.msl b/tests/out/msl/math-functions.msl index 818bd98d23..c2aac6ef98 100644 --- a/tests/out/msl/math-functions.msl +++ b/tests/out/msl/math-functions.msl @@ -19,9 +19,13 @@ vertex void main_( uint _e13 = metal::abs(0u); uint first_leading_bit_abs = metal::select(31 - metal::clz(_e13), uint(-1), _e13 == 0 || _e13 == -1); uint ctz_a = metal::ctz(0u); - uint ctz_b = metal::ctz(4294967295u); - metal::uint2 ctz_c = metal::ctz(metal::uint2(1u)); - metal::uint2 ctz_d = metal::ctz(metal::uint2(0u)); + int ctz_b = metal::ctz(0); + uint ctz_c = metal::ctz(4294967295u); + int ctz_d = metal::ctz(-1); + metal::uint2 ctz_e = metal::ctz(metal::uint2(0u)); + metal::int2 ctz_f = metal::ctz(metal::int2(0)); + metal::uint2 ctz_g = metal::ctz(metal::uint2(1u)); + metal::int2 ctz_h = metal::ctz(metal::int2(1)); int clz_a = metal::clz(-1); uint clz_b = metal::clz(1u); metal::int2 clz_c = metal::clz(metal::int2(-1)); diff --git a/tests/out/spv/math-functions.spvasm b/tests/out/spv/math-functions.spvasm index 3857a6033b..31437702c2 100644 --- a/tests/out/spv/math-functions.spvasm +++ b/tests/out/spv/math-functions.spvasm @@ -1,11 +1,11 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 62 +; Bound: 102 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint Vertex %17 "main" +OpEntryPoint Vertex %18 "main" %2 = OpTypeVoid %4 = OpTypeFloat 32 %3 = OpConstant %4 1.0 @@ -15,57 +15,89 @@ OpEntryPoint Vertex %17 "main" %9 = OpTypeInt 32 0 %8 = OpConstant %9 0 %10 = OpConstant %9 4294967295 -%11 = OpConstant %9 1 -%12 = OpConstant %7 -1 -%13 = OpTypeVector %4 4 -%14 = OpTypeVector %7 2 -%15 = OpConstantComposite %14 %6 %6 -%18 = OpTypeFunction %2 -%42 = OpTypeVector %9 2 -%48 = OpConstant %7 31 -%26 = OpConstantComposite %13 %5 %5 %5 %5 -%27 = OpConstantComposite %13 %3 %3 %3 %3 -%30 = OpConstantNull %7 -%55 = OpConstantComposite %14 %48 %48 -%59 = OpConstantComposite %14 %48 %48 -%17 = OpFunction %2 None %18 -%16 = OpLabel -OpBranch %19 -%19 = OpLabel -%20 = OpCompositeConstruct %13 %5 %5 %5 %5 -%21 = OpExtInst %4 %1 Degrees %3 -%22 = OpExtInst %4 %1 Radians %3 -%23 = OpExtInst %13 %1 Degrees %20 -%24 = OpExtInst %13 %1 Radians %20 -%25 = OpExtInst %13 %1 FClamp %20 %26 %27 -%28 = OpExtInst %13 %1 Refract %20 %20 %3 -%31 = OpCompositeExtract %7 %15 0 -%32 = OpCompositeExtract %7 %15 0 -%33 = OpIMul %7 %31 %32 -%34 = OpIAdd %7 %30 %33 -%35 = OpCompositeExtract %7 %15 1 -%36 = OpCompositeExtract %7 %15 1 -%37 = OpIMul %7 %35 %36 -%29 = OpIAdd %7 %34 %37 -%38 = OpCopyObject %9 %8 -%39 = OpExtInst %9 %1 FindUMsb %38 -%40 = OpExtInst %9 %1 FindILsb %8 -%41 = OpExtInst %9 %1 FindILsb %10 -%43 = OpCompositeConstruct %42 %11 %11 -%44 = OpExtInst %42 %1 FindILsb %43 -%45 = OpCompositeConstruct %42 %8 %8 -%46 = OpExtInst %42 %1 FindILsb %45 -%47 = OpExtInst %7 %1 FindUMsb %12 -%49 = OpISub %7 %48 %47 -%50 = OpExtInst %7 %1 FindUMsb %11 -%51 = OpISub %7 %48 %50 -%52 = OpBitcast %9 %51 -%53 = OpCompositeConstruct %14 %12 %12 -%54 = OpExtInst %14 %1 FindUMsb %53 -%56 = OpISub %14 %55 %54 -%57 = OpCompositeConstruct %42 %11 %11 -%58 = OpExtInst %14 %1 FindUMsb %57 -%60 = OpISub %14 %59 %58 -%61 = OpBitcast %42 %60 +%11 = OpConstant %7 -1 +%12 = OpConstant %9 1 +%13 = OpConstant %7 1 +%14 = OpTypeVector %4 4 +%15 = OpTypeVector %7 2 +%16 = OpConstantComposite %15 %6 %6 +%19 = OpTypeFunction %2 +%42 = OpConstant %9 32 +%60 = OpTypeVector %9 2 +%88 = OpConstant %7 31 +%27 = OpConstantComposite %14 %5 %5 %5 %5 +%28 = OpConstantComposite %14 %3 %3 %3 %3 +%31 = OpConstantNull %7 +%63 = OpConstantComposite %60 %42 %42 +%69 = OpConstantComposite %60 %42 %42 +%76 = OpConstantComposite %60 %42 %42 +%82 = OpConstantComposite %60 %42 %42 +%95 = OpConstantComposite %15 %88 %88 +%99 = OpConstantComposite %15 %88 %88 +%18 = OpFunction %2 None %19 +%17 = OpLabel +OpBranch %20 +%20 = OpLabel +%21 = OpCompositeConstruct %14 %5 %5 %5 %5 +%22 = OpExtInst %4 %1 Degrees %3 +%23 = OpExtInst %4 %1 Radians %3 +%24 = OpExtInst %14 %1 Degrees %21 +%25 = OpExtInst %14 %1 Radians %21 +%26 = OpExtInst %14 %1 FClamp %21 %27 %28 +%29 = OpExtInst %14 %1 Refract %21 %21 %3 +%32 = OpCompositeExtract %7 %16 0 +%33 = OpCompositeExtract %7 %16 0 +%34 = OpIMul %7 %32 %33 +%35 = OpIAdd %7 %31 %34 +%36 = OpCompositeExtract %7 %16 1 +%37 = OpCompositeExtract %7 %16 1 +%38 = OpIMul %7 %36 %37 +%30 = OpIAdd %7 %35 %38 +%39 = OpCopyObject %9 %8 +%40 = OpExtInst %9 %1 FindUMsb %39 +%41 = OpExtInst %9 %1 FindILsb %8 +%44 = OpBitcast %9 %41 +%45 = OpExtInst %9 %1 UMin %42 %43 +%46 = OpExtInst %7 %1 FindILsb %6 +%48 = OpBitcast %9 %46 +%49 = OpExtInst %9 %1 UMin %42 %47 +%50 = OpBitcast %7 %49 +%51 = OpExtInst %9 %1 FindILsb %10 +%53 = OpBitcast %9 %51 +%54 = OpExtInst %9 %1 UMin %42 %52 +%55 = OpExtInst %7 %1 FindILsb %11 +%57 = OpBitcast %9 %55 +%58 = OpExtInst %9 %1 UMin %42 %56 +%59 = OpBitcast %7 %58 +%61 = OpCompositeConstruct %60 %8 %8 +%62 = OpExtInst %60 %1 FindILsb %61 +%65 = OpBitcast %60 %62 +%66 = OpExtInst %60 %1 UMin %63 %64 +%67 = OpCompositeConstruct %15 %6 %6 +%68 = OpExtInst %15 %1 FindILsb %67 +%71 = OpBitcast %60 %68 +%72 = OpExtInst %60 %1 UMin %69 %70 +%73 = OpBitcast %15 %72 +%74 = OpCompositeConstruct %60 %12 %12 +%75 = OpExtInst %60 %1 FindILsb %74 +%78 = OpBitcast %60 %75 +%79 = OpExtInst %60 %1 UMin %76 %77 +%80 = OpCompositeConstruct %15 %13 %13 +%81 = OpExtInst %15 %1 FindILsb %80 +%84 = OpBitcast %60 %81 +%85 = OpExtInst %60 %1 UMin %82 %83 +%86 = OpBitcast %15 %85 +%87 = OpExtInst %7 %1 FindUMsb %11 +%89 = OpISub %7 %88 %87 +%90 = OpExtInst %7 %1 FindUMsb %12 +%91 = OpISub %7 %88 %90 +%92 = OpBitcast %9 %91 +%93 = OpCompositeConstruct %15 %11 %11 +%94 = OpExtInst %15 %1 FindUMsb %93 +%96 = OpISub %15 %95 %94 +%97 = OpCompositeConstruct %60 %12 %12 +%98 = OpExtInst %15 %1 FindUMsb %97 +%100 = OpISub %15 %99 %98 +%101 = OpBitcast %60 %100 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/math-functions.wgsl b/tests/out/wgsl/math-functions.wgsl index 639505c4a1..71ae0fd749 100644 --- a/tests/out/wgsl/math-functions.wgsl +++ b/tests/out/wgsl/math-functions.wgsl @@ -10,9 +10,13 @@ fn main() { let const_dot = dot(vec2(0, 0), vec2(0, 0)); let first_leading_bit_abs = firstLeadingBit(abs(0u)); let ctz_a = countTrailingZeros(0u); - let ctz_b = countTrailingZeros(4294967295u); - let ctz_c = countTrailingZeros(vec2(1u)); - let ctz_d = countTrailingZeros(vec2(0u)); + let ctz_b = countTrailingZeros(0); + let ctz_c = countTrailingZeros(4294967295u); + let ctz_d = countTrailingZeros(-1); + let ctz_e = countTrailingZeros(vec2(0u)); + let ctz_f = countTrailingZeros(vec2(0)); + let ctz_g = countTrailingZeros(vec2(1u)); + let ctz_h = countTrailingZeros(vec2(1)); let clz_a = countLeadingZeros(-1); let clz_b = countLeadingZeros(1u); let clz_c = countLeadingZeros(vec2(-1)); From e936d9ce565859f23efc8fdc903d1e9f493c151a Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Sat, 11 Feb 2023 11:36:35 +0100 Subject: [PATCH 07/15] Using kind and fixing spv --- src/back/hlsl/writer.rs | 20 ++++++++++++++++++-- src/back/spv/block.rs | 2 +- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 7472c1afef..00ef56204f 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2694,10 +2694,26 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::VectorSize::Quad => ".xxxx", }; - write!(self.out, "min(asuint((32){s}, asuint(firstbitlow(")?; + if let ScalarKind::Uint = kind { + write!(self.out, "min((32){s}, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "asint(min((32){s}, asuint(firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))))")?; + } } TypeInner::Scalar { kind, .. } => { - write!(self.out, "min(asuint((32), asuint(firstbitlow(")?; + if let ScalarKind::Uint = kind { + write!(self.out, "min((32), firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "asint(min((32), asuint(firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))))")?; + } } _ => unreachable!(), } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index fa3701d3f1..6d877a28bb 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -938,7 +938,7 @@ impl<'w> BlockContext<'w> { block.body.push(Instruction::unary( spirv::Op::Bitcast, uint_type_id, - self.gen_id(), + cast_id, id, )); From 799874cc48bdef4d27ba0c87a0ced50d129d28c0 Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Sat, 11 Feb 2023 12:44:19 +0100 Subject: [PATCH 08/15] Casting when needed for spv --- src/back/spv/block.rs | 61 ++++++++++++++++++++--- tests/out/spv/math-functions.spvasm | 76 ++++++++++++++++------------- 2 files changed, 96 insertions(+), 41 deletions(-) diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 6d877a28bb..93ba7e5103 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -890,10 +890,45 @@ impl<'w> BlockContext<'w> { )), Mf::CountTrailingZeros => { let uint = crate::ScalarValue::Uint(32); + let int = crate::ScalarValue::Sint(0); + + let (int_type_id, _int_id) = match *arg_ty { + crate::TypeInner::Vector { size, width, .. } => { + let ty = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(size), + kind: crate::ScalarKind::Sint, + width, + pointer_space: None, + })); + + self.temp_list.clear(); + self.temp_list + .resize(size as _, self.writer.get_constant_scalar(int, width)); + + let id = self.gen_id(); + block.body.push(Instruction::constant_composite( + ty, + id, + &self.temp_list, + )); + + (ty, id) + } + crate::TypeInner::Scalar { width, .. } => ( + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Sint, + width, + pointer_space: None, + })), + self.writer.get_constant_scalar(int, width), + ), + _ => unreachable!(), + }; let (uint_type_id, uint_id) = match *arg_ty { crate::TypeInner::Vector { size, width, .. } => { - let uty = self.get_type_id(LookupType::Local(LocalType::Value { + let ty = self.get_type_id(LookupType::Local(LocalType::Value { vector_size: Some(size), kind: crate::ScalarKind::Uint, width, @@ -906,13 +941,14 @@ impl<'w> BlockContext<'w> { self.writer.get_constant_scalar(uint, width), ); - let uid = self.gen_id(); + let id = self.gen_id(); block.body.push(Instruction::constant_composite( - uty, - uid, + ty, + id, &self.temp_list, )); - (uty, uid) + + (ty, id) } crate::TypeInner::Scalar { width, .. } => ( self.get_type_id(LookupType::Local(LocalType::Value { @@ -926,12 +962,23 @@ impl<'w> BlockContext<'w> { _ => unreachable!(), }; + let mut arg_id = arg0_id; + if let Some(crate::ScalarKind::Uint) = arg_scalar_kind { + arg_id = self.gen_id(); + block.body.push(Instruction::unary( + spirv::Op::Bitcast, + int_type_id, + arg_id, + arg0_id, + )); + } + block.body.push(Instruction::ext_inst( self.writer.gl450_ext_inst_id, spirv::GLOp::FindILsb, - result_type_id, + int_type_id, id, - &[arg0_id], + &[arg_id], )); let cast_id = self.gen_id(); diff --git a/tests/out/spv/math-functions.spvasm b/tests/out/spv/math-functions.spvasm index 31437702c2..35253ff104 100644 --- a/tests/out/spv/math-functions.spvasm +++ b/tests/out/spv/math-functions.spvasm @@ -23,15 +23,19 @@ OpEntryPoint Vertex %18 "main" %16 = OpConstantComposite %15 %6 %6 %19 = OpTypeFunction %2 %42 = OpConstant %9 32 -%60 = OpTypeVector %9 2 +%58 = OpTypeVector %9 2 %88 = OpConstant %7 31 %27 = OpConstantComposite %14 %5 %5 %5 %5 %28 = OpConstantComposite %14 %3 %3 %3 %3 %31 = OpConstantNull %7 -%63 = OpConstantComposite %60 %42 %42 -%69 = OpConstantComposite %60 %42 %42 -%76 = OpConstantComposite %60 %42 %42 -%82 = OpConstantComposite %60 %42 %42 +%61 = OpConstantComposite %15 %6 %6 +%62 = OpConstantComposite %58 %42 %42 +%68 = OpConstantComposite %15 %6 %6 +%69 = OpConstantComposite %58 %42 %42 +%75 = OpConstantComposite %15 %6 %6 +%76 = OpConstantComposite %58 %42 %42 +%82 = OpConstantComposite %15 %6 %6 +%83 = OpConstantComposite %58 %42 %42 %95 = OpConstantComposite %15 %88 %88 %99 = OpConstantComposite %15 %88 %88 %18 = OpFunction %2 None %19 @@ -55,37 +59,41 @@ OpBranch %20 %30 = OpIAdd %7 %35 %38 %39 = OpCopyObject %9 %8 %40 = OpExtInst %9 %1 FindUMsb %39 -%41 = OpExtInst %9 %1 FindILsb %8 +%43 = OpBitcast %7 %8 +%41 = OpExtInst %7 %1 FindILsb %43 %44 = OpBitcast %9 %41 -%45 = OpExtInst %9 %1 UMin %42 %43 +%45 = OpExtInst %9 %1 UMin %42 %44 %46 = OpExtInst %7 %1 FindILsb %6 -%48 = OpBitcast %9 %46 -%49 = OpExtInst %9 %1 UMin %42 %47 -%50 = OpBitcast %7 %49 -%51 = OpExtInst %9 %1 FindILsb %10 -%53 = OpBitcast %9 %51 -%54 = OpExtInst %9 %1 UMin %42 %52 -%55 = OpExtInst %7 %1 FindILsb %11 -%57 = OpBitcast %9 %55 -%58 = OpExtInst %9 %1 UMin %42 %56 -%59 = OpBitcast %7 %58 -%61 = OpCompositeConstruct %60 %8 %8 -%62 = OpExtInst %60 %1 FindILsb %61 -%65 = OpBitcast %60 %62 -%66 = OpExtInst %60 %1 UMin %63 %64 -%67 = OpCompositeConstruct %15 %6 %6 -%68 = OpExtInst %15 %1 FindILsb %67 -%71 = OpBitcast %60 %68 -%72 = OpExtInst %60 %1 UMin %69 %70 -%73 = OpBitcast %15 %72 -%74 = OpCompositeConstruct %60 %12 %12 -%75 = OpExtInst %60 %1 FindILsb %74 -%78 = OpBitcast %60 %75 -%79 = OpExtInst %60 %1 UMin %76 %77 +%47 = OpBitcast %9 %46 +%48 = OpExtInst %9 %1 UMin %42 %47 +%49 = OpBitcast %7 %48 +%51 = OpBitcast %7 %10 +%50 = OpExtInst %7 %1 FindILsb %51 +%52 = OpBitcast %9 %50 +%53 = OpExtInst %9 %1 UMin %42 %52 +%54 = OpExtInst %7 %1 FindILsb %11 +%55 = OpBitcast %9 %54 +%56 = OpExtInst %9 %1 UMin %42 %55 +%57 = OpBitcast %7 %56 +%59 = OpCompositeConstruct %58 %8 %8 +%63 = OpBitcast %15 %59 +%60 = OpExtInst %15 %1 FindILsb %63 +%64 = OpBitcast %58 %60 +%65 = OpExtInst %58 %1 UMin %62 %64 +%66 = OpCompositeConstruct %15 %6 %6 +%67 = OpExtInst %15 %1 FindILsb %66 +%70 = OpBitcast %58 %67 +%71 = OpExtInst %58 %1 UMin %69 %70 +%72 = OpBitcast %15 %71 +%73 = OpCompositeConstruct %58 %12 %12 +%77 = OpBitcast %15 %73 +%74 = OpExtInst %15 %1 FindILsb %77 +%78 = OpBitcast %58 %74 +%79 = OpExtInst %58 %1 UMin %76 %78 %80 = OpCompositeConstruct %15 %13 %13 %81 = OpExtInst %15 %1 FindILsb %80 -%84 = OpBitcast %60 %81 -%85 = OpExtInst %60 %1 UMin %82 %83 +%84 = OpBitcast %58 %81 +%85 = OpExtInst %58 %1 UMin %83 %84 %86 = OpBitcast %15 %85 %87 = OpExtInst %7 %1 FindUMsb %11 %89 = OpISub %7 %88 %87 @@ -95,9 +103,9 @@ OpBranch %20 %93 = OpCompositeConstruct %15 %11 %11 %94 = OpExtInst %15 %1 FindUMsb %93 %96 = OpISub %15 %95 %94 -%97 = OpCompositeConstruct %60 %12 %12 +%97 = OpCompositeConstruct %58 %12 %12 %98 = OpExtInst %15 %1 FindUMsb %97 %100 = OpISub %15 %99 %98 -%101 = OpBitcast %60 %100 +%101 = OpBitcast %58 %100 OpReturn OpFunctionEnd \ No newline at end of file From 7208727a7d56f8f2f32b0cd2f9e45b2130df85aa Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Sat, 11 Feb 2023 13:00:24 +0100 Subject: [PATCH 09/15] Fixing hlsl --- src/back/hlsl/writer.rs | 17 ++++------------- tests/out/hlsl/math-functions.hlsl | 16 ++++++++-------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 00ef56204f..cff013845b 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2695,29 +2695,20 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { }; if let ScalarKind::Uint = kind { - write!(self.out, "min((32){s}, firstbitlow(")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, "))")?; + write!(self.out, "min((32u){s}, asuint(firstbitlow(asint(")?; } else { - write!(self.out, "asint(min((32){s}, asuint(firstbitlow(")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, "))))")?; + write!(self.out, "asint(min((32u){s}, asuint(firstbitlow(")?; } } TypeInner::Scalar { kind, .. } => { if let ScalarKind::Uint = kind { - write!(self.out, "min((32), firstbitlow(")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, "))")?; + write!(self.out, "min(32u, asuint(firstbitlow(asint(")?; } else { - write!(self.out, "asint(min((32), asuint(firstbitlow(")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, "))))")?; + write!(self.out, "asint(min(32u, asuint(firstbitlow(")?; } } _ => unreachable!(), } - self.write_expr(module, arg, func_ctx)?; write!(self.out, "))))")?; diff --git a/tests/out/hlsl/math-functions.hlsl b/tests/out/hlsl/math-functions.hlsl index ad6e5e9046..5af15eb622 100644 --- a/tests/out/hlsl/math-functions.hlsl +++ b/tests/out/hlsl/math-functions.hlsl @@ -10,16 +10,16 @@ void main() float4 g = refract(v, v, 1.0); int const_dot = dot(int2(0, 0), int2(0, 0)); uint first_leading_bit_abs = firstbithigh(abs(0u)); - uint ctz_a = min(asuint((32), asuint(firstbitlow(0u)))); - int ctz_b = min(asuint((32), asuint(firstbitlow(0)))); - uint ctz_c = min(asuint((32), asuint(firstbitlow(4294967295u)))); - int ctz_d = min(asuint((32), asuint(firstbitlow(-1)))); - uint2 ctz_e = min(asuint((32).xx, asuint(firstbitlow((0u).xx)))); + uint ctz_a = min(32u, asuint(firstbitlow(asint(0u)))); + int ctz_b = asint(min(32u, asuint(firstbitlow(0)))); + uint ctz_c = min(32u, asuint(firstbitlow(asint(4294967295u)))); + int ctz_d = asint(min(32u, asuint(firstbitlow(-1)))); + uint2 ctz_e = min((32u).xx, asuint(firstbitlow(asint((0u).xx)))); int2 _expr27 = (0).xx; - int2 ctz_f = min(asuint((32).xx, asuint(firstbitlow(_expr27)))); - uint2 ctz_g = min(asuint((32).xx, asuint(firstbitlow((1u).xx)))); + int2 ctz_f = asint(min((32u).xx, asuint(firstbitlow(_expr27)))); + uint2 ctz_g = min((32u).xx, asuint(firstbitlow(asint((1u).xx)))); int2 _expr33 = (1).xx; - int2 ctz_h = min(asuint((32).xx, asuint(firstbitlow(_expr33)))); + int2 ctz_h = asint(min((32u).xx, asuint(firstbitlow(_expr33)))); int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1)); uint clz_b = asuint(31 - firstbithigh(1u)); int2 _expr40 = (-1).xx; From 5fffd56c1848b1d31aa8e0d1e7a406e3abd7504a Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Mon, 13 Feb 2023 21:03:43 +0100 Subject: [PATCH 10/15] Removing unnecessary baking --- src/back/glsl/mod.rs | 3 +-- src/back/hlsl/writer.rs | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 96473e48d9..e1d270daf6 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1149,8 +1149,7 @@ impl<'a, W: Write> Writer<'a, W> { } } } - crate::MathFunction::CountTrailingZeros - | crate::MathFunction::CountLeadingZeros => { + crate::MathFunction::CountLeadingZeros => { if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { self.need_bake_expressions.insert(arg); } diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index cff013845b..3088f79732 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -124,8 +124,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { | crate::MathFunction::Unpack2x16float => { self.need_bake_expressions.insert(arg); } - crate::MathFunction::CountTrailingZeros - | crate::MathFunction::CountLeadingZeros => { + crate::MathFunction::CountLeadingZeros => { let inner = info[fun_handle].ty.inner_with(&module.types); if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { self.need_bake_expressions.insert(arg); From 75663197db4a3b70126955fa8f55954a285050dc Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Mon, 13 Feb 2023 21:12:25 +0100 Subject: [PATCH 11/15] Adding generated code --- tests/out/glsl/math-functions.main.Vertex.glsl | 6 ++---- tests/out/hlsl/math-functions.hlsl | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/out/glsl/math-functions.main.Vertex.glsl b/tests/out/glsl/math-functions.main.Vertex.glsl index 4f4e7b718e..3c5c1dd345 100644 --- a/tests/out/glsl/math-functions.main.Vertex.glsl +++ b/tests/out/glsl/math-functions.main.Vertex.glsl @@ -19,11 +19,9 @@ void main() { uint ctz_c = min(uint(findLSB(4294967295u)), 32u); int ctz_d = int(min(uint(findLSB(-1)), 32u)); uvec2 ctz_e = min(uvec2(findLSB(uvec2(0u))), uvec2(32u)); - ivec2 _e27 = ivec2(0); - ivec2 ctz_f = ivec2(min(uvec2(findLSB(_e27)), uvec2(32u))); + ivec2 ctz_f = ivec2(min(uvec2(findLSB(ivec2(0))), uvec2(32u))); uvec2 ctz_g = min(uvec2(findLSB(uvec2(1u))), uvec2(32u)); - ivec2 _e33 = ivec2(1); - ivec2 ctz_h = ivec2(min(uvec2(findLSB(_e33)), uvec2(32u))); + ivec2 ctz_h = ivec2(min(uvec2(findLSB(ivec2(1))), uvec2(32u))); int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1)); uint clz_b = uint(31 - findMSB(1u)); ivec2 _e40 = ivec2(-1); diff --git a/tests/out/hlsl/math-functions.hlsl b/tests/out/hlsl/math-functions.hlsl index 5af15eb622..40d04685ca 100644 --- a/tests/out/hlsl/math-functions.hlsl +++ b/tests/out/hlsl/math-functions.hlsl @@ -15,11 +15,9 @@ void main() uint ctz_c = min(32u, asuint(firstbitlow(asint(4294967295u)))); int ctz_d = asint(min(32u, asuint(firstbitlow(-1)))); uint2 ctz_e = min((32u).xx, asuint(firstbitlow(asint((0u).xx)))); - int2 _expr27 = (0).xx; - int2 ctz_f = asint(min((32u).xx, asuint(firstbitlow(_expr27)))); + int2 ctz_f = asint(min((32u).xx, asuint(firstbitlow((0).xx)))); uint2 ctz_g = min((32u).xx, asuint(firstbitlow(asint((1u).xx)))); - int2 _expr33 = (1).xx; - int2 ctz_h = asint(min((32u).xx, asuint(firstbitlow(_expr33)))); + int2 ctz_h = asint(min((32u).xx, asuint(firstbitlow((1).xx)))); int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1)); uint clz_b = asuint(31 - firstbithigh(1u)); int2 _expr40 = (-1).xx; From 63abeaa2994a04c8040288777496faf1d722f9b1 Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Sun, 19 Feb 2023 10:17:25 +0100 Subject: [PATCH 12/15] Removing bitcasts as suggested --- src/back/hlsl/writer.rs | 14 ++-- src/back/spv/block.rs | 50 +++----------- tests/out/hlsl/math-functions.hlsl | 8 +-- tests/out/spv/math-functions.spvasm | 102 ++++++++++++---------------- 4 files changed, 65 insertions(+), 109 deletions(-) diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 3088f79732..b2de975b95 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2694,22 +2694,28 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { }; if let ScalarKind::Uint = kind { - write!(self.out, "min((32u){s}, asuint(firstbitlow(asint(")?; + write!(self.out, "min((32u){s}, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; } else { write!(self.out, "asint(min((32u){s}, asuint(firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))))")?; } } TypeInner::Scalar { kind, .. } => { if let ScalarKind::Uint = kind { - write!(self.out, "min(32u, asuint(firstbitlow(asint(")?; + write!(self.out, "min(32u, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; } else { write!(self.out, "asint(min(32u, asuint(firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))))")?; } } _ => unreachable!(), } - self.write_expr(module, arg, func_ctx)?; - write!(self.out, "))))")?; return Ok(()); } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 93ba7e5103..3ea245ebdc 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -926,7 +926,7 @@ impl<'w> BlockContext<'w> { _ => unreachable!(), }; - let (uint_type_id, uint_id) = match *arg_ty { + let uint_id = match *arg_ty { crate::TypeInner::Vector { size, width, .. } => { let ty = self.get_type_id(LookupType::Local(LocalType::Value { vector_size: Some(size), @@ -948,65 +948,31 @@ impl<'w> BlockContext<'w> { &self.temp_list, )); - (ty, id) + id + } + crate::TypeInner::Scalar { width, .. } => { + self.writer.get_constant_scalar(uint, width) } - crate::TypeInner::Scalar { width, .. } => ( - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - kind: crate::ScalarKind::Uint, - width, - pointer_space: None, - })), - self.writer.get_constant_scalar(uint, width), - ), _ => unreachable!(), }; - let mut arg_id = arg0_id; - if let Some(crate::ScalarKind::Uint) = arg_scalar_kind { - arg_id = self.gen_id(); - block.body.push(Instruction::unary( - spirv::Op::Bitcast, - int_type_id, - arg_id, - arg0_id, - )); - } - block.body.push(Instruction::ext_inst( self.writer.gl450_ext_inst_id, spirv::GLOp::FindILsb, int_type_id, id, - &[arg_id], - )); - - let cast_id = self.gen_id(); - block.body.push(Instruction::unary( - spirv::Op::Bitcast, - uint_type_id, - cast_id, - id, + &[arg0_id], )); let min_id = self.gen_id(); block.body.push(Instruction::ext_inst( self.writer.gl450_ext_inst_id, spirv::GLOp::UMin, - uint_type_id, + result_type_id, min_id, - &[uint_id, cast_id], + &[uint_id, id], )); - if let Some(crate::ScalarKind::Sint) = arg_scalar_kind { - block.body.push(Instruction::unary( - spirv::Op::Bitcast, - result_type_id, - self.gen_id(), - min_id, - )); - } - return Ok(()); } Mf::CountLeadingZeros => { diff --git a/tests/out/hlsl/math-functions.hlsl b/tests/out/hlsl/math-functions.hlsl index 40d04685ca..958e77d80a 100644 --- a/tests/out/hlsl/math-functions.hlsl +++ b/tests/out/hlsl/math-functions.hlsl @@ -10,13 +10,13 @@ void main() float4 g = refract(v, v, 1.0); int const_dot = dot(int2(0, 0), int2(0, 0)); uint first_leading_bit_abs = firstbithigh(abs(0u)); - uint ctz_a = min(32u, asuint(firstbitlow(asint(0u)))); + uint ctz_a = min(32u, firstbitlow(0u)); int ctz_b = asint(min(32u, asuint(firstbitlow(0)))); - uint ctz_c = min(32u, asuint(firstbitlow(asint(4294967295u)))); + uint ctz_c = min(32u, firstbitlow(4294967295u)); int ctz_d = asint(min(32u, asuint(firstbitlow(-1)))); - uint2 ctz_e = min((32u).xx, asuint(firstbitlow(asint((0u).xx)))); + uint2 ctz_e = min((32u).xx, firstbitlow((0u).xx)); int2 ctz_f = asint(min((32u).xx, asuint(firstbitlow((0).xx)))); - uint2 ctz_g = min((32u).xx, asuint(firstbitlow(asint((1u).xx)))); + uint2 ctz_g = min((32u).xx, firstbitlow((1u).xx)); int2 ctz_h = asint(min((32u).xx, asuint(firstbitlow((1).xx)))); int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1)); uint clz_b = asuint(31 - firstbithigh(1u)); diff --git a/tests/out/spv/math-functions.spvasm b/tests/out/spv/math-functions.spvasm index 35253ff104..5774d67f6a 100644 --- a/tests/out/spv/math-functions.spvasm +++ b/tests/out/spv/math-functions.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 102 +; Bound: 86 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -23,21 +23,21 @@ OpEntryPoint Vertex %18 "main" %16 = OpConstantComposite %15 %6 %6 %19 = OpTypeFunction %2 %42 = OpConstant %9 32 -%58 = OpTypeVector %9 2 -%88 = OpConstant %7 31 +%50 = OpTypeVector %9 2 +%72 = OpConstant %7 31 %27 = OpConstantComposite %14 %5 %5 %5 %5 %28 = OpConstantComposite %14 %3 %3 %3 %3 %31 = OpConstantNull %7 -%61 = OpConstantComposite %15 %6 %6 -%62 = OpConstantComposite %58 %42 %42 +%53 = OpConstantComposite %15 %6 %6 +%54 = OpConstantComposite %50 %42 %42 +%58 = OpConstantComposite %15 %6 %6 +%59 = OpConstantComposite %50 %42 %42 +%63 = OpConstantComposite %15 %6 %6 +%64 = OpConstantComposite %50 %42 %42 %68 = OpConstantComposite %15 %6 %6 -%69 = OpConstantComposite %58 %42 %42 -%75 = OpConstantComposite %15 %6 %6 -%76 = OpConstantComposite %58 %42 %42 -%82 = OpConstantComposite %15 %6 %6 -%83 = OpConstantComposite %58 %42 %42 -%95 = OpConstantComposite %15 %88 %88 -%99 = OpConstantComposite %15 %88 %88 +%69 = OpConstantComposite %50 %42 %42 +%79 = OpConstantComposite %15 %72 %72 +%83 = OpConstantComposite %15 %72 %72 %18 = OpFunction %2 None %19 %17 = OpLabel OpBranch %20 @@ -59,53 +59,37 @@ OpBranch %20 %30 = OpIAdd %7 %35 %38 %39 = OpCopyObject %9 %8 %40 = OpExtInst %9 %1 FindUMsb %39 -%43 = OpBitcast %7 %8 -%41 = OpExtInst %7 %1 FindILsb %43 -%44 = OpBitcast %9 %41 -%45 = OpExtInst %9 %1 UMin %42 %44 -%46 = OpExtInst %7 %1 FindILsb %6 -%47 = OpBitcast %9 %46 -%48 = OpExtInst %9 %1 UMin %42 %47 -%49 = OpBitcast %7 %48 -%51 = OpBitcast %7 %10 -%50 = OpExtInst %7 %1 FindILsb %51 -%52 = OpBitcast %9 %50 -%53 = OpExtInst %9 %1 UMin %42 %52 -%54 = OpExtInst %7 %1 FindILsb %11 -%55 = OpBitcast %9 %54 -%56 = OpExtInst %9 %1 UMin %42 %55 -%57 = OpBitcast %7 %56 -%59 = OpCompositeConstruct %58 %8 %8 -%63 = OpBitcast %15 %59 -%60 = OpExtInst %15 %1 FindILsb %63 -%64 = OpBitcast %58 %60 -%65 = OpExtInst %58 %1 UMin %62 %64 -%66 = OpCompositeConstruct %15 %6 %6 +%41 = OpExtInst %7 %1 FindILsb %8 +%43 = OpExtInst %9 %1 UMin %42 %41 +%44 = OpExtInst %7 %1 FindILsb %6 +%45 = OpExtInst %7 %1 UMin %42 %44 +%46 = OpExtInst %7 %1 FindILsb %10 +%47 = OpExtInst %9 %1 UMin %42 %46 +%48 = OpExtInst %7 %1 FindILsb %11 +%49 = OpExtInst %7 %1 UMin %42 %48 +%51 = OpCompositeConstruct %50 %8 %8 +%52 = OpExtInst %15 %1 FindILsb %51 +%55 = OpExtInst %50 %1 UMin %54 %52 +%56 = OpCompositeConstruct %15 %6 %6 +%57 = OpExtInst %15 %1 FindILsb %56 +%60 = OpExtInst %15 %1 UMin %59 %57 +%61 = OpCompositeConstruct %50 %12 %12 +%62 = OpExtInst %15 %1 FindILsb %61 +%65 = OpExtInst %50 %1 UMin %64 %62 +%66 = OpCompositeConstruct %15 %13 %13 %67 = OpExtInst %15 %1 FindILsb %66 -%70 = OpBitcast %58 %67 -%71 = OpExtInst %58 %1 UMin %69 %70 -%72 = OpBitcast %15 %71 -%73 = OpCompositeConstruct %58 %12 %12 -%77 = OpBitcast %15 %73 -%74 = OpExtInst %15 %1 FindILsb %77 -%78 = OpBitcast %58 %74 -%79 = OpExtInst %58 %1 UMin %76 %78 -%80 = OpCompositeConstruct %15 %13 %13 -%81 = OpExtInst %15 %1 FindILsb %80 -%84 = OpBitcast %58 %81 -%85 = OpExtInst %58 %1 UMin %83 %84 -%86 = OpBitcast %15 %85 -%87 = OpExtInst %7 %1 FindUMsb %11 -%89 = OpISub %7 %88 %87 -%90 = OpExtInst %7 %1 FindUMsb %12 -%91 = OpISub %7 %88 %90 -%92 = OpBitcast %9 %91 -%93 = OpCompositeConstruct %15 %11 %11 -%94 = OpExtInst %15 %1 FindUMsb %93 -%96 = OpISub %15 %95 %94 -%97 = OpCompositeConstruct %58 %12 %12 -%98 = OpExtInst %15 %1 FindUMsb %97 -%100 = OpISub %15 %99 %98 -%101 = OpBitcast %58 %100 +%70 = OpExtInst %15 %1 UMin %69 %67 +%71 = OpExtInst %7 %1 FindUMsb %11 +%73 = OpISub %7 %72 %71 +%74 = OpExtInst %7 %1 FindUMsb %12 +%75 = OpISub %7 %72 %74 +%76 = OpBitcast %9 %75 +%77 = OpCompositeConstruct %15 %11 %11 +%78 = OpExtInst %15 %1 FindUMsb %77 +%80 = OpISub %15 %79 %78 +%81 = OpCompositeConstruct %50 %12 %12 +%82 = OpExtInst %15 %1 FindUMsb %81 +%84 = OpISub %15 %83 %82 +%85 = OpBitcast %50 %84 OpReturn OpFunctionEnd \ No newline at end of file From 62cd01d06887c671958cd0a502878be62d9b388f Mon Sep 17 00:00:00 2001 From: Mauro Gentile <62186646+gents83@users.noreply.github.com> Date: Mon, 20 Feb 2023 20:02:27 +0100 Subject: [PATCH 13/15] Update src/back/spv/block.rs with suggestion of @teoxoy Co-authored-by: Teodor Tanasoaia <28601907+teoxoy@users.noreply.github.com> --- src/back/spv/block.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 3ea245ebdc..1ad32b36c3 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -964,16 +964,13 @@ impl<'w> BlockContext<'w> { &[arg0_id], )); - let min_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + MathOp::Custom(Instruction::ext_inst( self.writer.gl450_ext_inst_id, spirv::GLOp::UMin, result_type_id, - min_id, - &[uint_id, id], - )); - - return Ok(()); + id, + &[uint_id, lsb_id], + )) } Mf::CountLeadingZeros => { let int = crate::ScalarValue::Sint(31); From a7789bff16d4bd9f717e8c231fbabe534463248a Mon Sep 17 00:00:00 2001 From: Mauro Gentile <62186646+gents83@users.noreply.github.com> Date: Mon, 20 Feb 2023 20:02:43 +0100 Subject: [PATCH 14/15] Update src/back/spv/block.rs with suggestion of @teoxoy Co-authored-by: Teodor Tanasoaia <28601907+teoxoy@users.noreply.github.com> --- src/back/spv/block.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 1ad32b36c3..9d1bc87661 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -956,11 +956,12 @@ impl<'w> BlockContext<'w> { _ => unreachable!(), }; + let lsb_id = self.gen_id(); block.body.push(Instruction::ext_inst( self.writer.gl450_ext_inst_id, spirv::GLOp::FindILsb, - int_type_id, - id, + result_type_id, + lsb_id, &[arg0_id], )); From 25921d0673e88b39f0ad60a19db03093b08921f3 Mon Sep 17 00:00:00 2001 From: Mauro Gentile Date: Mon, 20 Feb 2023 20:25:17 +0100 Subject: [PATCH 15/15] Using get_constant_composite + remove unused code --- src/back/spv/block.rs | 50 ++---------------- tests/out/spv/math-functions.spvasm | 82 +++++++++++++---------------- 2 files changed, 41 insertions(+), 91 deletions(-) diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index d9f93a5b7a..11d5782633 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -876,50 +876,15 @@ impl<'w> BlockContext<'w> { )), Mf::CountTrailingZeros => { let uint = crate::ScalarValue::Uint(32); - let int = crate::ScalarValue::Sint(0); - - let (int_type_id, _int_id) = match *arg_ty { - crate::TypeInner::Vector { size, width, .. } => { - let ty = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(size), - kind: crate::ScalarKind::Sint, - width, - pointer_space: None, - })); - - self.temp_list.clear(); - self.temp_list - .resize(size as _, self.writer.get_constant_scalar(int, width)); - - let id = self.gen_id(); - block.body.push(Instruction::constant_composite( - ty, - id, - &self.temp_list, - )); - - (ty, id) - } - crate::TypeInner::Scalar { width, .. } => ( - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - kind: crate::ScalarKind::Sint, - width, - pointer_space: None, - })), - self.writer.get_constant_scalar(int, width), - ), - _ => unreachable!(), - }; - let uint_id = match *arg_ty { crate::TypeInner::Vector { size, width, .. } => { - let ty = self.get_type_id(LookupType::Local(LocalType::Value { + let ty = LocalType::Value { vector_size: Some(size), kind: crate::ScalarKind::Uint, width, pointer_space: None, - })); + } + .into(); self.temp_list.clear(); self.temp_list.resize( @@ -927,14 +892,7 @@ impl<'w> BlockContext<'w> { self.writer.get_constant_scalar(uint, width), ); - let id = self.gen_id(); - block.body.push(Instruction::constant_composite( - ty, - id, - &self.temp_list, - )); - - id + self.writer.get_constant_composite(ty, &self.temp_list) } crate::TypeInner::Scalar { width, .. } => { self.writer.get_constant_scalar(uint, width) diff --git a/tests/out/spv/math-functions.spvasm b/tests/out/spv/math-functions.spvasm index 5774d67f6a..d1a98b4e43 100644 --- a/tests/out/spv/math-functions.spvasm +++ b/tests/out/spv/math-functions.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 86 +; Bound: 78 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -22,22 +22,14 @@ OpEntryPoint Vertex %18 "main" %15 = OpTypeVector %7 2 %16 = OpConstantComposite %15 %6 %6 %19 = OpTypeFunction %2 -%42 = OpConstant %9 32 -%50 = OpTypeVector %9 2 -%72 = OpConstant %7 31 %27 = OpConstantComposite %14 %5 %5 %5 %5 %28 = OpConstantComposite %14 %3 %3 %3 %3 +%42 = OpConstant %9 32 +%50 = OpTypeVector %9 2 +%53 = OpConstantComposite %50 %42 %42 +%65 = OpConstant %7 31 +%72 = OpConstantComposite %15 %65 %65 %31 = OpConstantNull %7 -%53 = OpConstantComposite %15 %6 %6 -%54 = OpConstantComposite %50 %42 %42 -%58 = OpConstantComposite %15 %6 %6 -%59 = OpConstantComposite %50 %42 %42 -%63 = OpConstantComposite %15 %6 %6 -%64 = OpConstantComposite %50 %42 %42 -%68 = OpConstantComposite %15 %6 %6 -%69 = OpConstantComposite %50 %42 %42 -%79 = OpConstantComposite %15 %72 %72 -%83 = OpConstantComposite %15 %72 %72 %18 = OpFunction %2 None %19 %17 = OpLabel OpBranch %20 @@ -59,37 +51,37 @@ OpBranch %20 %30 = OpIAdd %7 %35 %38 %39 = OpCopyObject %9 %8 %40 = OpExtInst %9 %1 FindUMsb %39 -%41 = OpExtInst %7 %1 FindILsb %8 -%43 = OpExtInst %9 %1 UMin %42 %41 -%44 = OpExtInst %7 %1 FindILsb %6 -%45 = OpExtInst %7 %1 UMin %42 %44 -%46 = OpExtInst %7 %1 FindILsb %10 -%47 = OpExtInst %9 %1 UMin %42 %46 -%48 = OpExtInst %7 %1 FindILsb %11 -%49 = OpExtInst %7 %1 UMin %42 %48 +%43 = OpExtInst %9 %1 FindILsb %8 +%41 = OpExtInst %9 %1 UMin %42 %43 +%45 = OpExtInst %7 %1 FindILsb %6 +%44 = OpExtInst %7 %1 UMin %42 %45 +%47 = OpExtInst %9 %1 FindILsb %10 +%46 = OpExtInst %9 %1 UMin %42 %47 +%49 = OpExtInst %7 %1 FindILsb %11 +%48 = OpExtInst %7 %1 UMin %42 %49 %51 = OpCompositeConstruct %50 %8 %8 -%52 = OpExtInst %15 %1 FindILsb %51 -%55 = OpExtInst %50 %1 UMin %54 %52 -%56 = OpCompositeConstruct %15 %6 %6 -%57 = OpExtInst %15 %1 FindILsb %56 -%60 = OpExtInst %15 %1 UMin %59 %57 -%61 = OpCompositeConstruct %50 %12 %12 -%62 = OpExtInst %15 %1 FindILsb %61 -%65 = OpExtInst %50 %1 UMin %64 %62 -%66 = OpCompositeConstruct %15 %13 %13 -%67 = OpExtInst %15 %1 FindILsb %66 -%70 = OpExtInst %15 %1 UMin %69 %67 -%71 = OpExtInst %7 %1 FindUMsb %11 -%73 = OpISub %7 %72 %71 -%74 = OpExtInst %7 %1 FindUMsb %12 -%75 = OpISub %7 %72 %74 -%76 = OpBitcast %9 %75 -%77 = OpCompositeConstruct %15 %11 %11 -%78 = OpExtInst %15 %1 FindUMsb %77 -%80 = OpISub %15 %79 %78 -%81 = OpCompositeConstruct %50 %12 %12 -%82 = OpExtInst %15 %1 FindUMsb %81 -%84 = OpISub %15 %83 %82 -%85 = OpBitcast %50 %84 +%54 = OpExtInst %50 %1 FindILsb %51 +%52 = OpExtInst %50 %1 UMin %53 %54 +%55 = OpCompositeConstruct %15 %6 %6 +%57 = OpExtInst %15 %1 FindILsb %55 +%56 = OpExtInst %15 %1 UMin %53 %57 +%58 = OpCompositeConstruct %50 %12 %12 +%60 = OpExtInst %50 %1 FindILsb %58 +%59 = OpExtInst %50 %1 UMin %53 %60 +%61 = OpCompositeConstruct %15 %13 %13 +%63 = OpExtInst %15 %1 FindILsb %61 +%62 = OpExtInst %15 %1 UMin %53 %63 +%64 = OpExtInst %7 %1 FindUMsb %11 +%66 = OpISub %7 %65 %64 +%67 = OpExtInst %7 %1 FindUMsb %12 +%68 = OpISub %7 %65 %67 +%69 = OpBitcast %9 %68 +%70 = OpCompositeConstruct %15 %11 %11 +%71 = OpExtInst %15 %1 FindUMsb %70 +%73 = OpISub %15 %72 %71 +%74 = OpCompositeConstruct %50 %12 %12 +%75 = OpExtInst %15 %1 FindUMsb %74 +%76 = OpISub %15 %72 %75 +%77 = OpBitcast %50 %76 OpReturn OpFunctionEnd \ No newline at end of file