Skip to content

Commit

Permalink
[mlir][spirv] Define KHR cooperative matrix properties
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhar committed Sep 19, 2023
1 parent ab2c104 commit 676026a
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 1 deletion.
28 changes: 28 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,29 @@ def SPIRV_LinkageAttributesAttr : SPIRV_Attr<"LinkageAttributes", "linkage_attri
let assemblyFormat = "`<` struct(params) `>`";
}

// Description of cooperative matrix operations supported on the
// target. Represents `VkCooperativeMatrixPropertiesKHR`. See
// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkCooperativeMatrixPropertiesKHR.html
def SPIRV_CooperativeMatrixPropertiesKHRAttr :
SPIRV_Attr<"CooperativeMatrixPropertiesKHR", "coop_matrix_props_khr"> {
let parameters = (ins
"uint32_t":$m_size,
"uint32_t":$n_size,
"uint32_t":$k_size,
"mlir::Type":$a_type,
"mlir::Type":$b_type,
"mlir::Type":$c_type,
"mlir::Type":$result_type,
"bool":$acc_sat,
"mlir::spirv::ScopeAttr":$scope
);
let assemblyFormat = "`<` struct(params) `>`";
}

def SPIRV_CooperativeMatrixPropertiesKHRArrayAttr :
TypedArrayAttrBase<SPIRV_CooperativeMatrixPropertiesKHRAttr,
"CooperativeMatrixPropertiesKHR array attribute">;

// Description of cooperative matrix operations supported on the
// target. Represents `VkCooperativeMatrixPropertiesNV`. See
// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkCooperativeMatrixPropertiesNV.html
Expand Down Expand Up @@ -130,6 +153,11 @@ def SPIRV_ResourceLimitsAttr : SPIRV_Attr<"ResourceLimits", "resource_limits"> {

// The configurations of cooperative matrix operations
// supported. Default is an empty list.
DefaultValuedParameter<
"ArrayAttr",
"nullptr"
>:$cooperative_matrix_properties_khr,

DefaultValuedParameter<
"ArrayAttr",
"nullptr"
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ spirv::getDefaultResourceLimits(MLIRContext *context) {
/*subgroup_size=*/32,
/*min_subgroup_size=*/std::nullopt,
/*max_subgroup_size=*/std::nullopt,
/*cooperative_matrix_properties_nv=*/ArrayAttr());
/*cooperative_matrix_properties_khr=*/ArrayAttr{},
/*cooperative_matrix_properties_nv=*/ArrayAttr{});
}

StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }
Expand Down
41 changes: 41 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,47 @@ func.func @target_env_extra_fields() attributes {

// -----

func.func @target_env_cooperative_matrix_khr() attributes{
// CHECK: spirv.target_env = #spirv.target_env<
// CHECK-SAME: SPV_KHR_cooperative_matrix
// CHECK-SAME: #spirv.coop_matrix_props_khr<
// CHECK-SAME: m_size = 8, n_size = 8, k_size = 32,
// CHECK-SAME: a_type = i8, b_type = i8, c_type = i32,
// CHECK-SAME: result_type = i32, acc_sat = true, scope = <Subgroup>>
// CHECK-SAME: #spirv.coop_matrix_props_khr<
// CHECK-SAME: m_size = 8, n_size = 8, k_size = 16,
// CHECK-SAME: a_type = f16, b_type = f16, c_type = f16,
// CHECK-SAME: result_type = f16, acc_sat = false, scope = <Subgroup>>
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class,
SPV_KHR_cooperative_matrix]>,
#spirv.resource_limits<
cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<
m_size = 8,
n_size = 8,
k_size = 32,
a_type = i8,
b_type = i8,
c_type = i32,
result_type = i32,
acc_sat = true,
scope = #spirv.scope<Subgroup>
>, #spirv.coop_matrix_props_khr<
m_size = 8,
n_size = 8,
k_size = 16,
a_type = f16,
b_type = f16,
c_type = f16,
result_type = f16,
acc_sat = false,
scope = #spirv.scope<Subgroup>
>]
>>
} { return }

// -----

func.func @target_env_cooperative_matrix_nv() attributes{
// CHECK: spirv.target_env = #spirv.target_env<
// CHECK-SAME: SPV_NV_cooperative_matrix
Expand Down

0 comments on commit 676026a

Please sign in to comment.