Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv] Define KHR cooperative matrix properties #66823

Merged
merged 1 commit into from
Sep 19, 2023

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Sep 19, 2023

Stacked on top of #66820.

@llvmbot
Copy link
Member

llvmbot commented Sep 19, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Changes

Stacked on top of #66820.


Full diff: https://github.com/llvm/llvm-project/pull/66823.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td (+29-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp (+2-1)
  • (modified) mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir (+46-5)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
index 259a96651abb3f3..f2c1ee5cfd56eab 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
@@ -54,11 +54,34 @@ def SPIRV_LinkageAttributesAttr : SPIRV_Attr<"LinkageAttributes", "linkage_attri
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+// Description of cooperative matrix operations supported on the
+// target. Represents `VkCooperativeMatrixPropertiesKHR`. See
+// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkCooperativeMatrixPropertiesKHR.html
+def SPIRV_CooperativeMatrixPropertiesKHRAttr :
+    SPIRV_Attr<"CooperativeMatrixPropertiesKHR", "coop_matrix_props_khr"> {
+  let parameters = (ins
+    "uint32_t":$m_size,
+    "uint32_t":$n_size,
+    "uint32_t":$k_size,
+    "mlir::Type":$a_type,
+    "mlir::Type":$b_type,
+    "mlir::Type":$c_type,
+    "mlir::Type":$result_type,
+    "bool":$acc_sat,
+    "mlir::spirv::ScopeAttr":$scope
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def SPIRV_CooperativeMatrixPropertiesKHRArrayAttr :
+    TypedArrayAttrBase<SPIRV_CooperativeMatrixPropertiesKHRAttr,
+                       "CooperativeMatrixPropertiesKHR array attribute">;
+
 // Description of cooperative matrix operations supported on the
 // target. Represents `VkCooperativeMatrixPropertiesNV`. See
 // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkCooperativeMatrixPropertiesNV.html
 def SPIRV_CooperativeMatrixPropertiesNVAttr :
-    SPIRV_Attr<"CooperativeMatrixPropertiesNV", "coop_matrix_props"> {
+    SPIRV_Attr<"CooperativeMatrixPropertiesNV", "coop_matrix_props_nv"> {
   let parameters = (ins
     "int":$m_size,
     "int":$n_size,
@@ -130,6 +153,11 @@ def SPIRV_ResourceLimitsAttr : SPIRV_Attr<"ResourceLimits", "resource_limits"> {
 
     // The configurations of cooperative matrix operations
     // supported. Default is an empty list.
+    DefaultValuedParameter<
+      "ArrayAttr",
+      "nullptr"
+    >:$cooperative_matrix_properties_khr,
+
     DefaultValuedParameter<
       "ArrayAttr",
       "nullptr"
diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
index 051b2cb9f1a88eb..5b7c0a59ba42009 100644
--- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
@@ -166,7 +166,8 @@ spirv::getDefaultResourceLimits(MLIRContext *context) {
       /*subgroup_size=*/32,
       /*min_subgroup_size=*/std::nullopt,
       /*max_subgroup_size=*/std::nullopt,
-      /*cooperative_matrix_properties_nv=*/ArrayAttr());
+      /*cooperative_matrix_properties_khr=*/ArrayAttr{},
+      /*cooperative_matrix_properties_nv=*/ArrayAttr{});
 }
 
 StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }
diff --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
index 82a7601dbd06e96..10fbcf06eb05203 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
@@ -208,14 +208,55 @@ func.func @target_env_extra_fields() attributes {
 
 // -----
 
-func.func @target_env_cooperative_matrix() attributes{
+func.func @target_env_cooperative_matrix_khr() attributes{
+  // CHECK:      spirv.target_env = #spirv.target_env<
+  // CHECK-SAME:   SPV_KHR_cooperative_matrix
+  // CHECK-SAME: #spirv.coop_matrix_props_khr<
+  // CHECK-SAME:   m_size = 8, n_size = 8, k_size = 32,
+  // CHECK-SAME:   a_type = i8, b_type = i8, c_type = i32,
+  // CHECK-SAME:   result_type = i32, acc_sat = true, scope = <Subgroup>>
+  // CHECK-SAME: #spirv.coop_matrix_props_khr<
+  // CHECK-SAME:   m_size = 8, n_size = 8, k_size = 16,
+  // CHECK-SAME:   a_type = f16, b_type = f16, c_type = f16,
+  // CHECK-SAME:   result_type = f16, acc_sat = false, scope = <Subgroup>>
+  spirv.target_env = #spirv.target_env<
+  #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class,
+                            SPV_KHR_cooperative_matrix]>,
+  #spirv.resource_limits<
+    cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<
+      m_size = 8,
+      n_size = 8,
+      k_size = 32,
+      a_type = i8,
+      b_type = i8,
+      c_type = i32,
+      result_type = i32,
+      acc_sat = true,
+      scope = #spirv.scope<Subgroup>
+    >, #spirv.coop_matrix_props_khr<
+      m_size = 8,
+      n_size = 8,
+      k_size = 16,
+      a_type = f16,
+      b_type = f16,
+      c_type = f16,
+      result_type = f16,
+      acc_sat = false,
+      scope = #spirv.scope<Subgroup>
+    >]
+  >>
+} { return }
+
+// -----
+
+func.func @target_env_cooperative_matrix_nv() attributes{
   // CHECK:      spirv.target_env = #spirv.target_env<
   // CHECK-SAME:   SPV_NV_cooperative_matrix
-  // CHECK-SAME: #spirv.coop_matrix_props<
+  // CHECK-SAME: #spirv.coop_matrix_props_nv<
   // CHECK-SAME:   m_size = 8, n_size = 8, k_size = 32,
   // CHECK-SAME:   a_type = i8, b_type = i8, c_type = i32,
   // CHECK-SAME:   result_type = i32, scope = <Subgroup>>
-  // CHECK-SAME: #spirv.coop_matrix_props<
+  // CHECK-SAME: #spirv.coop_matrix_props_nv<
   // CHECK-SAME:   m_size = 8, n_size = 8, k_size = 16,
   // CHECK-SAME:   a_type = f16, b_type = f16, c_type = f16,
   // CHECK-SAME:   result_type = f16, scope = <Subgroup>>
@@ -223,7 +264,7 @@ func.func @target_env_cooperative_matrix() attributes{
   #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class,
                             SPV_NV_cooperative_matrix]>,
   #spirv.resource_limits<
-    cooperative_matrix_properties_nv = [#spirv.coop_matrix_props<
+    cooperative_matrix_properties_nv = [#spirv.coop_matrix_props_nv<
       m_size = 8,
       n_size = 8,
       k_size = 32,
@@ -232,7 +273,7 @@ func.func @target_env_cooperative_matrix() attributes{
       c_type = i32,
       result_type = i32,
       scope = #spirv.scope<Subgroup>
-    >, #spirv.coop_matrix_props<
+    >, #spirv.coop_matrix_props_nv<
       m_size = 8,
       n_size = 8,
       k_size = 16,

Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kuhar
Copy link
Member Author

kuhar commented Sep 19, 2023

Rebased

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants