From d974f2f832681366e9e6d99c4c3b0096b37efcdd Mon Sep 17 00:00:00 2001 From: Connor Fitzgerald Date: Thu, 20 Oct 2022 02:27:30 -0400 Subject: [PATCH] [hlsl-out] Properly implement bitcast --- src/back/hlsl/writer.rs | 64 +++++++++++++++--------------- tests/out/hlsl/bitcast.hlsl | 43 ++++++++++++++++++++ tests/out/hlsl/bitcast.hlsl.config | 3 ++ tests/out/hlsl/operators.hlsl | 8 ++-- tests/snapshots.rs | 2 +- 5 files changed, 84 insertions(+), 36 deletions(-) create mode 100644 tests/out/hlsl/bitcast.hlsl create mode 100644 tests/out/hlsl/bitcast.hlsl.config diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 09cdeed48e..e29d2c41db 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2404,39 +2404,41 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { convert, } => { let inner = func_ctx.info[expr].ty.inner_with(&module.types); - let get_width = |src_width| kind.to_hlsl_str(convert.unwrap_or(src_width)); - match *inner { - TypeInner::Vector { size, width, .. } => { - write!( - self.out, - "{}{}(", - get_width(width)?, - back::vector_size_str(size) - )?; - } - TypeInner::Scalar { width, .. } => { - write!(self.out, "{}(", get_width(width)?,)?; - } - TypeInner::Matrix { - columns, - rows, - width, - } => { - write!( - self.out, - "{}{}x{}(", - get_width(width)?, - back::vector_size_str(columns), - back::vector_size_str(rows) - )?; + match convert { + Some(dst_width) => { + match *inner { + TypeInner::Vector { size, .. } => { + write!( + self.out, + "{}{}(", + kind.to_hlsl_str(dst_width)?, + back::vector_size_str(size) + )?; + } + TypeInner::Scalar { .. } => { + write!(self.out, "{}(", kind.to_hlsl_str(dst_width)?,)?; + } + TypeInner::Matrix { columns, rows, .. } => { + write!( + self.out, + "{}{}x{}(", + kind.to_hlsl_str(dst_width)?, + back::vector_size_str(columns), + back::vector_size_str(rows) + )?; + } + _ => { + return Err(Error::Unimplemented(format!( + "write_expr expression::as {:?}", + inner + ))); + } + }; } - _ => { - return Err(Error::Unimplemented(format!( - "write_expr expression::as {:?}", - inner - ))); + None => { + write!(self.out, "{}(", kind.to_hlsl_cast(),)?; } - }; + } self.write_expr(module, expr, func_ctx)?; write!(self.out, ")")?; } diff --git a/tests/out/hlsl/bitcast.hlsl b/tests/out/hlsl/bitcast.hlsl new file mode 100644 index 0000000000..eda86b0abc --- /dev/null +++ b/tests/out/hlsl/bitcast.hlsl @@ -0,0 +1,43 @@ + +[numthreads(1, 1, 1)] +void main() +{ + int2 i2_ = (int2)0; + int3 i3_ = (int3)0; + int4 i4_ = (int4)0; + uint2 u2_ = (uint2)0; + uint3 u3_ = (uint3)0; + uint4 u4_ = (uint4)0; + float2 f2_ = (float2)0; + float3 f3_ = (float3)0; + float4 f4_ = (float4)0; + + i2_ = (0).xx; + i3_ = (0).xxx; + i4_ = (0).xxxx; + u2_ = (0u).xx; + u3_ = (0u).xxx; + u4_ = (0u).xxxx; + f2_ = (0.0).xx; + f3_ = (0.0).xxx; + f4_ = (0.0).xxxx; + int2 _expr27 = i2_; + u2_ = asuint(_expr27); + int3 _expr29 = i3_; + u3_ = asuint(_expr29); + int4 _expr31 = i4_; + u4_ = asuint(_expr31); + uint2 _expr33 = u2_; + i2_ = asint(_expr33); + uint3 _expr35 = u3_; + i3_ = asint(_expr35); + uint4 _expr37 = u4_; + i4_ = asint(_expr37); + int2 _expr39 = i2_; + f2_ = asfloat(_expr39); + int3 _expr41 = i3_; + f3_ = asfloat(_expr41); + int4 _expr43 = i4_; + f4_ = asfloat(_expr43); + return; +} diff --git a/tests/out/hlsl/bitcast.hlsl.config b/tests/out/hlsl/bitcast.hlsl.config new file mode 100644 index 0000000000..246c485cf7 --- /dev/null +++ b/tests/out/hlsl/bitcast.hlsl.config @@ -0,0 +1,3 @@ +vertex=() +fragment=() +compute=(main:cs_5_1 ) diff --git a/tests/out/hlsl/operators.hlsl b/tests/out/hlsl/operators.hlsl index 48e8fee09e..62e028f804 100644 --- a/tests/out/hlsl/operators.hlsl +++ b/tests/out/hlsl/operators.hlsl @@ -31,8 +31,8 @@ float4 builtins() float4 s3_ = (bool4(false, false, false, false) ? float4(0.0, 0.0, 0.0, 0.0) : float4(1.0, 1.0, 1.0, 1.0)); float4 m1_ = lerp(float4(0.0, 0.0, 0.0, 0.0), float4(1.0, 1.0, 1.0, 1.0), float4(0.5, 0.5, 0.5, 0.5)); float4 m2_ = lerp(float4(0.0, 0.0, 0.0, 0.0), float4(1.0, 1.0, 1.0, 1.0), 0.10000000149011612); - float b1_ = float(int4(1, 1, 1, 1).x); - float4 b2_ = float4(int4(1, 1, 1, 1)); + float b1_ = asfloat(int4(1, 1, 1, 1).x); + float4 b2_ = asfloat(int4(1, 1, 1, 1)); int4 v_i32_zero = int4(float4(0.0, 0.0, 0.0, 0.0)); return (((((float4(((s1_).xxxx + v_i32_zero)) + s2_) + m1_) + m2_) + (b1_).xxxx) + b2_); } @@ -87,8 +87,8 @@ float constructors() float unnamed_6 = float(0.0); uint2 unnamed_7 = uint2(uint2(0u, 0u)); float2x3 unnamed_8 = float2x3(float2x3(float3(0.0, 0.0, 0.0), float3(0.0, 0.0, 0.0))); - uint2 unnamed_9 = uint2(uint2(0u, 0u)); - float2x3 unnamed_10 = float2x3(float2x3(float3(0.0, 0.0, 0.0), float3(0.0, 0.0, 0.0))); + uint2 unnamed_9 = asuint(uint2(0u, 0u)); + float2x3 unnamed_10 = asfloat(float2x3(float3(0.0, 0.0, 0.0), float3(0.0, 0.0, 0.0))); float _expr75 = foo.a.x; return _expr75; } diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 59a3e314bd..cebeade8f8 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -443,7 +443,7 @@ fn convert_wgsl() { ), ( "bitcast", - Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL, + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), ( "boids",