Skip to content

Commit

Permalink
[hlsl-out] fix matCx2's nested inside global arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Jun 17, 2022
1 parent c231efe commit fbe927c
Show file tree
Hide file tree
Showing 7 changed files with 360 additions and 287 deletions.
108 changes: 60 additions & 48 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -662,14 +662,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if global.space == crate::AddressSpace::Uniform {
write!(self.out, " {{ ")?;

let matrix_data = get_inner_matrix_data(module, global.ty);

// We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
// See the module-level block comment in mod.rs for details.
if let TypeInner::Matrix {
rows: crate::VectorSize::Bi,
columns,
width,
} = module.types[global.ty].inner
{
if let Some((columns, crate::VectorSize::Bi, width)) = matrix_data {
let vec_ty = crate::TypeInner::Vector {
size: crate::VectorSize::Bi,
kind: crate::ScalarKind::Float,
Expand All @@ -692,18 +689,18 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Even though Naga IR matrices are column-major, we must describe
// matrices passed from the CPU as being in row-major order.
// See the module-level block comment in mod.rs for details.
let is_matrix = matches!(module.types[global.ty].inner, TypeInner::Matrix { .. });
if is_matrix || is_array_of_matrices(module, global.ty) {
if matrix_data.is_some() {
write!(self.out, "row_major ")?;
}

self.write_type(module, global.ty)?;
let sub_name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, " {}", sub_name)?;
// need to write the array size if the type was emitted with `write_type`
if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
self.write_array_size(module, base, size)?;
}
}

// need to write the array size if the type was emitted with `write_type`
if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
self.write_array_size(module, base, size)?;
}

writeln!(self.out, "; }}")?;
Expand Down Expand Up @@ -843,7 +840,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Even though Naga IR matrices are column-major, we must describe
// matrices passed from the CPU as being in row-major order.
// See the module-level block comment in mod.rs for details.
if is_array_of_matrices(module, member.ty) {
if get_inner_matrix_data(module, member.ty).is_some() {
write!(self.out, "row_major ")?;
}

Expand Down Expand Up @@ -1919,11 +1916,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
Expression::AccessIndex { base, index } => {
if let Some(crate::AddressSpace::Storage { .. }) = func_ctx.info[expr]
.ty
.inner_with(&module.types)
.pointer_space()
{
let res_ty = func_ctx.info[expr].ty.inner_with(&module.types);
if let Some(crate::AddressSpace::Storage { .. }) = res_ty.pointer_space() {
// do nothing, the chain is written on `Load`/`Store`
} else {
let base_ty_res = &func_ctx.info[base].ty;
Expand Down Expand Up @@ -1962,6 +1956,24 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
};

let mut close_paren = false;
if let TypeInner::Pointer {
base,
space: crate::AddressSpace::Uniform,
} = *res_ty
{
if let TypeInner::Matrix {
rows: crate::VectorSize::Bi,
..
} = module.types[base].inner
{
write!(self.out, "((")?;
self.write_type(module, base)?;
write!(self.out, ")")?;
close_paren = true;
}
}

self.write_expr(module, base, func_ctx)?;

match *resolved {
Expand All @@ -1988,6 +2000,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
return Err(Error::Custom(format!("Cannot index {:?}", other)))
}
}

if close_paren {
write!(self.out, ")")?;
}
}
}
Expression::FunctionArgument(pos) => {
Expand Down Expand Up @@ -2160,34 +2176,25 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
// See the module-level block comment in mod.rs for details.
Some(crate::AddressSpace::Uniform) => {
let mut close_paren = false;
if let Expression::GlobalVariable(handle) = func_ctx.expressions[pointer] {
let ty = module.global_variables[handle].ty;
match module.types[ty].inner {
TypeInner::Matrix {
rows: crate::VectorSize::Bi,
columns,
..
} => {
self.write_type(module, ty)?;
write!(self.out, "(")?;

let name = &NameKey::GlobalVariable(handle);
if let TypeInner::Matrix {
rows: crate::VectorSize::Bi,
..
} = module.types[ty].inner
{
write!(self.out, "((")?;
self.write_type(module, ty)?;
write!(self.out, ")")?;
close_paren = true;
}
}

for i in 0..columns as u8 {
if i != 0 {
write!(self.out, ", ")?;
}
write!(self.out, "{}._{}", &self.names[name], i)?;
}
self.write_expr(module, pointer, func_ctx)?;

write!(self.out, ")")?;
}
_ => {
self.write_expr(module, pointer, func_ctx)?;
}
}
} else {
self.write_expr(module, pointer, func_ctx)?;
if close_paren {
write!(self.out, ")")?;
}
}
_ => {
Expand Down Expand Up @@ -2651,12 +2658,17 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}

fn is_array_of_matrices(module: &Module, handle: Handle<crate::Type>) -> bool {
fn get_inner_matrix_data(
module: &Module,
handle: Handle<crate::Type>,
) -> Option<(crate::VectorSize, crate::VectorSize, u8)> {
match module.types[handle].inner {
TypeInner::Array { base, .. } => match module.types[base].inner {
TypeInner::Matrix { .. } => true,
_ => is_array_of_matrices(module, base),
},
_ => false,
TypeInner::Matrix {
columns,
rows,
width,
} => Some((columns, rows, width)),
TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
_ => None,
}
}
6 changes: 5 additions & 1 deletion tests/in/globals.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ var<uniform> global_vec: vec3<f32>;
var<uniform> global_mat: mat3x2<f32>;

@group(0) @binding(6)
var<uniform> global_nested_arrays_of_matrices: array<array<mat4x3<f32>, 2>, 2>;
var<uniform> global_nested_arrays_of_matrices_4x4: array<array<mat4x4<f32>, 2>, 2>;

@group(0) @binding(7)
var<uniform> global_nested_arrays_of_matrices_4x2: array<array<mat4x2<f32>, 2>, 2>;

fn test_msl_packed_vec3_as_arg(arg: vec3<f32>) {}

Expand Down Expand Up @@ -59,6 +62,7 @@ fn test_msl_packed_vec3() {
fn main() {
test_msl_packed_vec3();

wg[7] = (global_nested_arrays_of_matrices_4x2[0][0] * global_nested_arrays_of_matrices_4x4[0][0][0]).x;
wg[6] = (global_mat * global_vec).x;
wg[5] = dummy[1].y;
wg[4] = float_vecs[0].w;
Expand Down
33 changes: 20 additions & 13 deletions tests/out/glsl/globals.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ uniform type_4_block_3Compute { vec3 _group_0_binding_4_cs; };

uniform type_9_block_4Compute { mat3x2 _group_0_binding_5_cs; };

uniform type_12_block_5Compute { mat4x4 _group_0_binding_6_cs[2][2]; };

uniform type_15_block_6Compute { mat4x2 _group_0_binding_7_cs[2][2]; };


void test_msl_packed_vec3_as_arg(vec3 arg) {
return;
Expand All @@ -33,8 +37,8 @@ void test_msl_packed_vec3_() {
_group_0_binding_1_cs.v3_ = vec3(1.0);
_group_0_binding_1_cs.v3_.x = 1.0;
_group_0_binding_1_cs.v3_.x = 2.0;
int _e22 = idx;
_group_0_binding_1_cs.v3_[_e22] = 3.0;
int _e23 = idx;
_group_0_binding_1_cs.v3_[_e23] = 3.0;
Foo data = _group_0_binding_1_cs;
vec3 unnamed = data.v3_;
vec2 unnamed_1 = data.v3_.zx;
Expand All @@ -49,17 +53,20 @@ void main() {
float Foo_1 = 1.0;
bool at = true;
test_msl_packed_vec3_();
mat3x2 _e11 = _group_0_binding_5_cs;
vec3 _e12 = _group_0_binding_4_cs;
wg[6] = (_e11 * _e12).x;
float _e20 = _group_0_binding_2_cs[1].y;
wg[5] = _e20;
float _e26 = _group_0_binding_3_cs[0].w;
wg[4] = _e26;
float _e30 = _group_0_binding_1_cs.v1_;
wg[3] = _e30;
float _e35 = _group_0_binding_1_cs.v3_.x;
wg[2] = _e35;
mat4x2 _e16 = _group_0_binding_7_cs[0][0];
vec4 _e23 = _group_0_binding_6_cs[0][0][0];
wg[7] = (_e16 * _e23).x;
mat3x2 _e28 = _group_0_binding_5_cs;
vec3 _e29 = _group_0_binding_4_cs;
wg[6] = (_e28 * _e29).x;
float _e37 = _group_0_binding_2_cs[1].y;
wg[5] = _e37;
float _e43 = _group_0_binding_3_cs[0].w;
wg[4] = _e43;
float _e47 = _group_0_binding_1_cs.v1_;
wg[3] = _e47;
float _e52 = _group_0_binding_1_cs.v3_.x;
wg[2] = _e52;
_group_0_binding_1_cs.v1_ = 4.0;
wg[1] = float(uint(_group_0_binding_2_cs.length()));
at_1 = 2u;
Expand Down
32 changes: 18 additions & 14 deletions tests/out/hlsl/globals.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ ByteAddressBuffer dummy : register(t2);
cbuffer float_vecs : register(b3) { float4 float_vecs[20]; }
cbuffer global_vec : register(b4) { float3 global_vec; }
cbuffer global_mat : register(b5) { struct { float2 _0; float2 _1; float2 _2; } global_mat; }
cbuffer global_nested_arrays_of_matrices : register(b6) { row_major float4x3 global_nested_arrays_of_matrices[2][2]; }
cbuffer global_nested_arrays_of_matrices_4x4_ : register(b6) { row_major float4x4 global_nested_arrays_of_matrices_4x4_[2][2]; }
cbuffer global_nested_arrays_of_matrices_4x2_ : register(b7) { struct { float2 _0; float2 _1; float2 _2; float2 _3; } global_nested_arrays_of_matrices_4x2_[2][2]; }

void test_msl_packed_vec3_as_arg(float3 arg)
{
Expand All @@ -33,8 +34,8 @@ void test_msl_packed_vec3_()
alignment.Store3(0, asuint((1.0).xxx));
alignment.Store(0+0, asuint(1.0));
alignment.Store(0+0, asuint(2.0));
int _expr22 = idx;
alignment.Store(_expr22*4+0, asuint(3.0));
int _expr23 = idx;
alignment.Store(_expr23*4+0, asuint(3.0));
Foo data = ConstructFoo(asfloat(alignment.Load3(0)), asfloat(alignment.Load(12)));
float3 unnamed = data.v3_;
float2 unnamed_1 = data.v3_.zx;
Expand All @@ -59,17 +60,20 @@ void main()
bool at = true;

test_msl_packed_vec3_();
float3x2 _expr11 = float3x2(global_mat._0, global_mat._1, global_mat._2);
float3 _expr12 = global_vec;
wg[6] = mul(_expr12, _expr11).x;
float _expr20 = asfloat(dummy.Load(4+8));
wg[5] = _expr20;
float _expr26 = float_vecs[0].w;
wg[4] = _expr26;
float _expr30 = asfloat(alignment.Load(12));
wg[3] = _expr30;
float _expr35 = asfloat(alignment.Load(0+0));
wg[2] = _expr35;
float4x2 _expr16 = ((float4x2)global_nested_arrays_of_matrices_4x2_[0][0]);
float4 _expr23 = global_nested_arrays_of_matrices_4x4_[0][0][0];
wg[7] = mul(_expr23, _expr16).x;
float3x2 _expr28 = ((float3x2)global_mat);
float3 _expr29 = global_vec;
wg[6] = mul(_expr29, _expr28).x;
float _expr37 = asfloat(dummy.Load(4+8));
wg[5] = _expr37;
float _expr43 = float_vecs[0].w;
wg[4] = _expr43;
float _expr47 = asfloat(alignment.Load(12));
wg[3] = _expr47;
float _expr52 = asfloat(alignment.Load(0+0));
wg[2] = _expr52;
alignment.Store(12, asuint(4.0));
wg[1] = float(((NagaBufferLength(dummy) - 0) / 8));
at_1 = 2u;
Expand Down
45 changes: 28 additions & 17 deletions tests/out/msl/globals.msl
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,19 @@ struct type_8 {
metal::float4 inner[20];
};
struct type_11 {
metal::float4x3 inner[2];
metal::float4x4 inner[2];
};
struct type_12 {
type_11 inner[2];
};
struct type_14 {
metal::float4x2 inner[2];
};
struct type_15 {
type_14 inner[2];
};
constant metal::float3 const_type_4_ = {0.0, 0.0, 0.0};
constant metal::float3x3 const_type_14_ = {const_type_4_, const_type_4_, const_type_4_};
constant metal::float3x3 const_type_17_ = {const_type_4_, const_type_4_, const_type_4_};

void test_msl_packed_vec3_as_arg(
metal::float3 arg
Expand All @@ -42,14 +48,14 @@ void test_msl_packed_vec3_(
alignment.v3_ = metal::float3(1.0);
alignment.v3_[0] = 1.0;
alignment.v3_[0] = 2.0;
int _e22 = idx;
alignment.v3_[_e22] = 3.0;
int _e23 = idx;
alignment.v3_[_e23] = 3.0;
Foo data = alignment;
metal::float3 unnamed = data.v3_;
metal::float2 unnamed_1 = metal::float3(data.v3_).zx;
test_msl_packed_vec3_as_arg(data.v3_);
metal::float3 unnamed_2 = metal::float3(data.v3_) * const_type_14_;
metal::float3 unnamed_3 = const_type_14_ * metal::float3(data.v3_);
metal::float3 unnamed_2 = metal::float3(data.v3_) * const_type_17_;
metal::float3 unnamed_3 = const_type_17_ * metal::float3(data.v3_);
metal::float3 unnamed_4 = data.v3_ * 2.0;
metal::float3 unnamed_5 = 2.0 * data.v3_;
}
Expand All @@ -62,22 +68,27 @@ kernel void main_(
, constant type_8& float_vecs [[user(fake0)]]
, constant metal::float3& global_vec [[user(fake0)]]
, constant metal::float3x2& global_mat [[user(fake0)]]
, constant type_12& global_nested_arrays_of_matrices_4x4_ [[user(fake0)]]
, constant type_15& global_nested_arrays_of_matrices_4x2_ [[user(fake0)]]
, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]]
) {
float Foo_1 = 1.0;
bool at = true;
test_msl_packed_vec3_(alignment);
metal::float3x2 _e11 = global_mat;
metal::float3 _e12 = global_vec;
wg.inner[6] = (_e11 * _e12).x;
float _e20 = dummy[1].y;
wg.inner[5] = _e20;
float _e26 = float_vecs.inner[0].w;
wg.inner[4] = _e26;
float _e30 = alignment.v1_;
wg.inner[3] = _e30;
float _e35 = alignment.v3_[0];
wg.inner[2] = _e35;
metal::float4x2 _e16 = global_nested_arrays_of_matrices_4x2_.inner[0].inner[0];
metal::float4 _e23 = global_nested_arrays_of_matrices_4x4_.inner[0].inner[0][0];
wg.inner[7] = (_e16 * _e23).x;
metal::float3x2 _e28 = global_mat;
metal::float3 _e29 = global_vec;
wg.inner[6] = (_e28 * _e29).x;
float _e37 = dummy[1].y;
wg.inner[5] = _e37;
float _e43 = float_vecs.inner[0].w;
wg.inner[4] = _e43;
float _e47 = alignment.v1_;
wg.inner[3] = _e47;
float _e52 = alignment.v3_[0];
wg.inner[2] = _e52;
alignment.v1_ = 4.0;
wg.inner[1] = static_cast<float>(1 + (_buffer_sizes.size3 - 0 - 8) / 8);
metal::atomic_store_explicit(&at_1, 2u, metal::memory_order_relaxed);
Expand Down
Loading

0 comments on commit fbe927c

Please sign in to comment.