diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td index fe612319193378..f2c1ee5cfd56ea 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td @@ -54,6 +54,29 @@ def SPIRV_LinkageAttributesAttr : SPIRV_Attr<"LinkageAttributes", "linkage_attri let assemblyFormat = "`<` struct(params) `>`"; } +// Description of cooperative matrix operations supported on the +// target. Represents `VkCooperativeMatrixPropertiesKHR`. See +// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkCooperativeMatrixPropertiesKHR.html +def SPIRV_CooperativeMatrixPropertiesKHRAttr : + SPIRV_Attr<"CooperativeMatrixPropertiesKHR", "coop_matrix_props_khr"> { + let parameters = (ins + "uint32_t":$m_size, + "uint32_t":$n_size, + "uint32_t":$k_size, + "mlir::Type":$a_type, + "mlir::Type":$b_type, + "mlir::Type":$c_type, + "mlir::Type":$result_type, + "bool":$acc_sat, + "mlir::spirv::ScopeAttr":$scope + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def SPIRV_CooperativeMatrixPropertiesKHRArrayAttr : + TypedArrayAttrBase; + // Description of cooperative matrix operations supported on the // target. Represents `VkCooperativeMatrixPropertiesNV`. See // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkCooperativeMatrixPropertiesNV.html @@ -130,6 +153,11 @@ def SPIRV_ResourceLimitsAttr : SPIRV_Attr<"ResourceLimits", "resource_limits"> { // The configurations of cooperative matrix operations // supported. Default is an empty list. + DefaultValuedParameter< + "ArrayAttr", + "nullptr" + >:$cooperative_matrix_properties_khr, + DefaultValuedParameter< "ArrayAttr", "nullptr" diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp index 051b2cb9f1a88e..5b7c0a59ba4200 100644 --- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp @@ -166,7 +166,8 @@ spirv::getDefaultResourceLimits(MLIRContext *context) { /*subgroup_size=*/32, /*min_subgroup_size=*/std::nullopt, /*max_subgroup_size=*/std::nullopt, - /*cooperative_matrix_properties_nv=*/ArrayAttr()); + /*cooperative_matrix_properties_khr=*/ArrayAttr{}, + /*cooperative_matrix_properties_nv=*/ArrayAttr{}); } StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; } diff --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir index 2a9272568d44bc..10fbcf06eb0520 100644 --- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir +++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir @@ -208,6 +208,47 @@ func.func @target_env_extra_fields() attributes { // ----- +func.func @target_env_cooperative_matrix_khr() attributes{ + // CHECK: spirv.target_env = #spirv.target_env< + // CHECK-SAME: SPV_KHR_cooperative_matrix + // CHECK-SAME: #spirv.coop_matrix_props_khr< + // CHECK-SAME: m_size = 8, n_size = 8, k_size = 32, + // CHECK-SAME: a_type = i8, b_type = i8, c_type = i32, + // CHECK-SAME: result_type = i32, acc_sat = true, scope = > + // CHECK-SAME: #spirv.coop_matrix_props_khr< + // CHECK-SAME: m_size = 8, n_size = 8, k_size = 16, + // CHECK-SAME: a_type = f16, b_type = f16, c_type = f16, + // CHECK-SAME: result_type = f16, acc_sat = false, scope = > + spirv.target_env = #spirv.target_env< + #spirv.vce, + #spirv.resource_limits< + cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr< + m_size = 8, + n_size = 8, + k_size = 32, + a_type = i8, + b_type = i8, + c_type = i32, + result_type = i32, + acc_sat = true, + scope = #spirv.scope + >, #spirv.coop_matrix_props_khr< + m_size = 8, + n_size = 8, + k_size = 16, + a_type = f16, + b_type = f16, + c_type = f16, + result_type = f16, + acc_sat = false, + scope = #spirv.scope + >] + >> +} { return } + +// ----- + func.func @target_env_cooperative_matrix_nv() attributes{ // CHECK: spirv.target_env = #spirv.target_env< // CHECK-SAME: SPV_NV_cooperative_matrix