Skip to content

Commit

Permalink
[naga wgsl-in] Automatic conversions for global var initializers.
Browse files Browse the repository at this point in the history
  • Loading branch information
jimblandy authored and teoxoy committed Dec 6, 2023
1 parent 1970210 commit 1676ee0
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 4 deletions.
28 changes: 24 additions & 4 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -875,10 +875,30 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ast::GlobalDeclKind::Var(ref v) => {
let ty = self.resolve_ast_type(v.ty, &mut ctx)?;

let init = v
.init
.map(|init| self.expression(init, &mut ctx.as_const()))
.transpose()?;
let init;
if let Some(init_ast) = v.init {
let mut ectx = ctx.as_const();
let lowered = self.expression_for_abstract(init_ast, &mut ectx)?;
let ty_res = crate::proc::TypeResolution::Handle(ty);
let converted = ectx
.try_automatic_conversions(lowered, &ty_res, v.name.span)
.map_err(|error| match error {
Error::AutoConversion {
dest_span: _,
dest_type,
source_span: _,
source_type,
} => Error::InitializationTypeMismatch {
name: v.name.span,
expected: dest_type,
got: source_type,
},
other => other,
})?;
init = Some(converted);
} else {
init = None;
}

let binding = if let Some(ref binding) = v.binding {
Some(crate::ResourceBinding {
Expand Down
44 changes: 44 additions & 0 deletions naga/tests/in/abstract-types-var.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// i/x: type inferred / explicit
// vX/mX/aX: vector / matrix / array of X
// where X: u/i/f: u32 / i32 / f32
// s: vector splat
// r: vector spread (vector arg to vector constructor)
// p: "partial" constructor (type parameter inferred)
// u/i/f/ai/af: u32 / i32 / f32 / abstract float / abstract integer as parameter
// _: just for alignment

// Ensure that:
// - the inferred type is correct.
// - all parameters' types are considered.
// - all parameters are converted to the consensus type.

var<private> xvipaiai: vec2<i32> = vec2(42, 43);
var<private> xvupaiai: vec2<u32> = vec2(44, 45);
var<private> xvfpaiai: vec2<f32> = vec2(46, 47);

var<private> xvupuai: vec2<u32> = vec2(42u, 43);
var<private> xvupaiu: vec2<u32> = vec2(42, 43u);

var<private> xvuuai: vec2<u32> = vec2<u32>(42u, 43);
var<private> xvuaiu: vec2<u32> = vec2<u32>(42, 43u);

var<private> xmfpaiaiaiai: mat2x2<f32> = mat2x2(1, 2, 3, 4);
var<private> xmfpafaiaiai: mat2x2<f32> = mat2x2(1.0, 2, 3, 4);
var<private> xmfpaiafaiai: mat2x2<f32> = mat2x2(1, 2.0, 3, 4);
var<private> xmfpaiaiafai: mat2x2<f32> = mat2x2(1, 2, 3.0, 4);
var<private> xmfpaiaiaiaf: mat2x2<f32> = mat2x2(1, 2, 3, 4.0);

var<private> xvispai: vec2<i32> = vec2(1);
var<private> xvfspaf: vec2<f32> = vec2(1.0);
var<private> xvis_ai: vec2<i32> = vec2<i32>(1);
var<private> xvus_ai: vec2<u32> = vec2<u32>(1);
var<private> xvfs_ai: vec2<f32> = vec2<f32>(1);
var<private> xvfs_af: vec2<f32> = vec2<f32>(1.0);

var<private> xafafaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafaiai: array<f32, 2> = array<f32, 2>(1, 2);

var<private> xafpaiai: array<i32, 2> = array(1, 2);
var<private> xafpaiaf: array<f32, 2> = array(1, 2.0);
var<private> xafpafai: array<f32, 2> = array(1.0, 2);
var<private> xafpafaf: array<f32, 2> = array(1.0, 2.0);
12 changes: 12 additions & 0 deletions naga/tests/out/msl/abstract-types-var.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;

struct type_5 {
float inner[2];
};
struct type_7 {
int inner[2];
};
78 changes: 78 additions & 0 deletions naga/tests/out/spv/abstract-types-var.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 70
OpCapability Shader
OpCapability Linkage
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpDecorate %10 ArrayStride 4
OpDecorate %12 ArrayStride 4
%2 = OpTypeVoid
%4 = OpTypeInt 32 1
%3 = OpTypeVector %4 2
%6 = OpTypeInt 32 0
%5 = OpTypeVector %6 2
%8 = OpTypeFloat 32
%7 = OpTypeVector %8 2
%9 = OpTypeMatrix %7 2
%11 = OpConstant %6 2
%10 = OpTypeArray %8 %11
%12 = OpTypeArray %4 %11
%13 = OpConstant %4 42
%14 = OpConstant %4 43
%15 = OpConstantComposite %3 %13 %14
%16 = OpConstant %6 44
%17 = OpConstant %6 45
%18 = OpConstantComposite %5 %16 %17
%19 = OpConstant %8 46.0
%20 = OpConstant %8 47.0
%21 = OpConstantComposite %7 %19 %20
%22 = OpConstant %6 42
%23 = OpConstant %6 43
%24 = OpConstantComposite %5 %22 %23
%25 = OpConstant %8 1.0
%26 = OpConstant %8 2.0
%27 = OpConstantComposite %7 %25 %26
%28 = OpConstant %8 3.0
%29 = OpConstant %8 4.0
%30 = OpConstantComposite %7 %28 %29
%31 = OpConstantComposite %9 %27 %30
%32 = OpConstant %4 1
%33 = OpConstantComposite %3 %32 %32
%34 = OpConstantComposite %7 %25 %25
%35 = OpConstant %6 1
%36 = OpConstantComposite %5 %35 %35
%37 = OpConstantComposite %10 %25 %26
%38 = OpConstant %4 2
%39 = OpConstantComposite %12 %32 %38
%41 = OpTypePointer Private %3
%40 = OpVariable %41 Private %15
%43 = OpTypePointer Private %5
%42 = OpVariable %43 Private %18
%45 = OpTypePointer Private %7
%44 = OpVariable %45 Private %21
%46 = OpVariable %43 Private %24
%47 = OpVariable %43 Private %24
%48 = OpVariable %43 Private %24
%49 = OpVariable %43 Private %24
%51 = OpTypePointer Private %9
%50 = OpVariable %51 Private %31
%52 = OpVariable %51 Private %31
%53 = OpVariable %51 Private %31
%54 = OpVariable %51 Private %31
%55 = OpVariable %51 Private %31
%56 = OpVariable %41 Private %33
%57 = OpVariable %45 Private %34
%58 = OpVariable %41 Private %33
%59 = OpVariable %43 Private %36
%60 = OpVariable %45 Private %34
%61 = OpVariable %45 Private %34
%63 = OpTypePointer Private %10
%62 = OpVariable %63 Private %37
%64 = OpVariable %63 Private %37
%66 = OpTypePointer Private %12
%65 = OpVariable %66 Private %39
%67 = OpVariable %63 Private %37
%68 = OpVariable %63 Private %37
%69 = OpVariable %63 Private %37
25 changes: 25 additions & 0 deletions naga/tests/out/wgsl/abstract-types-var.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
var<private> xvipaiai: vec2<i32> = vec2<i32>(42, 43);
var<private> xvupaiai: vec2<u32> = vec2<u32>(44u, 45u);
var<private> xvfpaiai: vec2<f32> = vec2<f32>(46.0, 47.0);
var<private> xvupuai: vec2<u32> = vec2<u32>(42u, 43u);
var<private> xvupaiu: vec2<u32> = vec2<u32>(42u, 43u);
var<private> xvuuai: vec2<u32> = vec2<u32>(42u, 43u);
var<private> xvuaiu: vec2<u32> = vec2<u32>(42u, 43u);
var<private> xmfpaiaiaiai: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xmfpafaiaiai: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xmfpaiafaiai: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xmfpaiaiafai: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xmfpaiaiaiaf: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xvispai: vec2<i32> = vec2(1);
var<private> xvfspaf: vec2<f32> = vec2(1.0);
var<private> xvis_ai: vec2<i32> = vec2(1);
var<private> xvus_ai: vec2<u32> = vec2(1u);
var<private> xvfs_ai: vec2<f32> = vec2(1.0);
var<private> xvfs_af: vec2<f32> = vec2(1.0);
var<private> xafafaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafaiai: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafpaiai: array<i32, 2> = array<i32, 2>(1, 2);
var<private> xafpaiaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafpafai: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafpafaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);

4 changes: 4 additions & 0 deletions naga/tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,10 @@ fn convert_wgsl() {
"abstract-types-const",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL,
),
(
"abstract-types-var",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL,
),
];

for &(name, targets) in inputs.iter() {
Expand Down
16 changes: 16 additions & 0 deletions naga/tests/wgsl_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2003,6 +2003,22 @@ fn constructor_type_error_span() {
)
}

#[test]
fn global_initialization_type_mismatch() {
check(
"
var<private> a: vec2<f32> = vec2<i32>(1i, 2i);
",
r###"error: the type of `a` is expected to be `vec2<f32>`, but got `vec2<i32>`
┌─ wgsl:2:22
2 │ var<private> a: vec2<f32> = vec2<i32>(1i, 2i);
│ ^ definition of `a`
"###,
)
}

#[test]
fn binding_array_local() {
check_validation! {
Expand Down

0 comments on commit 1676ee0

Please sign in to comment.