Skip to content

Commit

Permalink
Add F8_16x16x32_F32 support for MFMA (#17792)
Browse files Browse the repository at this point in the history
Added F8_16x16x32xF32 MFMA layout support and their e2e tests. 
Needed to adjust/branch in the e2e matmul test's cmake because only gfx94x GPUs have FP8 MFMA layouts and it has different I8 intrinsic shape/layout as opposed to what is present in gfx90x.
  • Loading branch information
raikonenfnu committed Jul 23, 2024
1 parent 4529567 commit 09da7be
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 16 deletions.
4 changes: 2 additions & 2 deletions compiler/plugins/target/ROCM/test/target_device_features.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
// GFX942-SAME: chip = <wgp_count = 304>>

// GFX940: target = #iree_gpu.target<arch = "gfx940",
// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],

// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>]
Expand Down
12 changes: 12 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ getContractionLayout(vector::ContractionOp contract, ConcreteMmaLayout layout) {

static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
MMAIntrinsic type) {
Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
Type f16 = Float16Type::get(context);
Type f32 = Float32Type::get(context);

Expand All @@ -216,6 +217,9 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
}
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32: {
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
}
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return OpaqueMmaLayout{16, 16, 32, i8, i8, i32};
}
Expand Down Expand Up @@ -290,6 +294,7 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 8]>
Expand Down Expand Up @@ -416,6 +421,7 @@ MMAAttr::getABCVectorTypes() const {
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
Expand Down Expand Up @@ -452,6 +458,7 @@ int64_t MMAAttr::getBlockSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32:
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
Expand All @@ -467,6 +474,7 @@ int64_t MMAAttr::getSubgroupSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
return 64;
Expand All @@ -490,6 +498,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
/*element=*/{1, 4}};
}
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 8}};
Expand Down Expand Up @@ -517,6 +526,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{8, 1}};
Expand All @@ -537,6 +547,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
Expand Down Expand Up @@ -573,6 +584,7 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
auto [m, n, k] = getMNKShape();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,18 @@ class IREEGPU_I32MmaEnumAttr<string name, string summary, list<I32EnumAttrCase>
// Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>;
def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 2>;
def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 3>;
def MFMA_F8E4M3FNUZ_16x16x32_F32 : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 2>;
def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 4>;
def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 5>;
// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 4>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 5>;
def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 6>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 7>;

def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F16_16x16x16_F32,
MFMA_F16_32x32x8_F32,
MFMA_F8E4M3FNUZ_16x16x32_F32,
MFMA_I8_16x16x32_I32,
MFMA_I8_32x32x16_I32,
WMMA_F16_16x16x16_F32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ const WgpDetails *getCDNA3WgpDetails() {
static const MMAIntrinsic cdna3MMAOps[] = {
MMAIntrinsic::MFMA_F16_16x16x16_F32,
MMAIntrinsic::MFMA_F16_32x32x8_F32,
MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32,
MMAIntrinsic::MFMA_I8_16x16x32_I32,
MMAIntrinsic::MFMA_I8_32x32x16_I32,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,56 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {

// -----

// Basic f8, f8 -> f32 matmul.

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>
hal.executable @matmul_256x256x256_f8_f32 {
hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export @matmul_256x256x256_f8_f32 layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @matmul_256x256x256_f8_f32() {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
%5 = tensor.empty() : tensor<256x256xf32>
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
%7 = linalg.matmul ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
return
}
}
}
}

// Make sure it generates the mfma instructions we expect for f8 inputs.

// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F8E4M3FNUZ_16x16x32_F32>,
// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2>
// CHECK-SAME: prefetch_shared_memory

// CHECK-LABEL: func.func @matmul_256x256x256_f8_f32()
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
// along the K dimension. So in total 32 mfma ops.
// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>

// -----

// Basic i8, i8 -> i32 matmul.

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,7 @@ std::optional<int32_t> ElementTypeOp::getTypeValue(Type type) {
return makeElementTypeValue(numericalType, intType.getWidth());
} else if (auto floatType = llvm::dyn_cast_if_present<FloatType>(type)) {
switch (APFloat::SemanticsToEnum(floatType.getFloatSemantics())) {
case APFloat::S_Float8E4M3FNUZ:
case APFloat::S_IEEEhalf:
case APFloat::S_IEEEsingle:
case APFloat::S_IEEEdouble:
Expand Down
34 changes: 34 additions & 0 deletions tests/e2e/matmul/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2318,6 +2318,39 @@ iree_generated_e2e_runner_test(
"requires-gpu-cdna3"
)

if(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx94")

# I8 Intrinsics has different layout on CDNA3/gfx94x,
# and only CDNA3/gfx94x has F8 intrinsics.

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rocm_f8_large_cdna3_mfma
TEST_TYPE
matmul
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f8E4M3FNUZ"
"--acc_type=f32"
"--shapes=gpu_large_aligned"
"--compilation_info=LLVMGPUVectorDistributeMFMA"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
"rocm"
DRIVERS
"hip"
COMPILER_FLAGS
${IREE_HIP_TEST_COMPILER_FLAGS}
LABELS
"noasan"
"nomsan"
"notsan"
"noubsan"
"requires-gpu-cdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rocm_i8_large_cdna3_mfma_tb
Expand Down Expand Up @@ -2375,6 +2408,7 @@ iree_generated_e2e_runner_test(
"noubsan"
"requires-gpu-cdna3"
)
endif()

elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")

Expand Down
56 changes: 46 additions & 10 deletions tests/e2e/matmul/generate_e2e_matmul_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class MatrixElemTypeId(enum.Enum):
I32 = "i32"
F32 = "f32"
F16 = "f16"
F8E4M3FNUZ = "f8E4M3FNUZ"
BF16 = "bf16"


Expand Down Expand Up @@ -271,6 +272,10 @@ def get_rocm_test_compilation_infos(
MMASchedule("MFMA_F16_32x32x8_F32", 2, 2, 1, 1, 1),
MMASchedule("MFMA_F16_32x32x8_F32", 1, 4, 2, 1, 2),
MMASchedule("MFMA_F16_32x32x8_F32", 4, 2, 1, 2, 4),
MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 2, 2, 1, 1, 2),
MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 4, 1, 4, 1, 1),
MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 4, 2, 4, 2, 1),
MMASchedule("MFMA_I8_16x16x32_I32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_I8_16x16x32_I32", 2, 2, 1, 1, 2),
MMASchedule("MFMA_I8_16x16x32_I32", 4, 1, 4, 1, 1),
Expand Down Expand Up @@ -310,7 +315,10 @@ def get_rocm_test_compilation_infos(
wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
wg_tile_k = schedule.k_tile_count * 8
elif schedule.intrinsic == "MFMA_I8_16x16x32_I32":
elif (
schedule.intrinsic == "MFMA_I8_16x16x32_I32"
or schedule.intrinsic == "MFMA_F8E4M3FNUZ_16x16x32_F32"
):
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 32
Expand Down Expand Up @@ -454,6 +462,20 @@ def int_or_DYN(s: DimSize):
return s.value or "DYN"


# Gets friendlier form/type that we can use as arg types which we can cast into the target_type.
def cast_argtype_if_required(target_type: MatrixElemTypeId):
if target_type == MatrixElemTypeId.F8E4M3FNUZ:
return MatrixElemTypeId.F32
return target_type


# Gets the op needed to cast/convert from the friendly form/type into the target_type.
def get_castback_from_arg_op(target_type: MatrixElemTypeId):
if target_type == MatrixElemTypeId.F8E4M3FNUZ:
return "arith.truncf"
return ValueError(f"Unhandled castback type of {t}")


# Describes the fully resolved shape dimensions of all 3 input matrices,
# LHS, RHS, and Accumulator, in a testcase.
# Each value is a string, which may either represent a positive integer such as "123",
Expand Down Expand Up @@ -559,8 +581,10 @@ def generate_function(
rhs_c = int_or_question_mark(shapes.rhs_cols)
acc_r = int_or_question_mark(shapes.acc_rows)
acc_c = int_or_question_mark(shapes.acc_cols)
lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"

casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{casted_lhs_rhs_type.value}>"
rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{casted_lhs_rhs_type.value}>"
acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>"

if transpose_rhs:
Expand Down Expand Up @@ -603,13 +627,22 @@ def generate_function(
)
func_definition = func_definition + compilation_info_string
generate_function.compilation_index += 1

compute = f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
if casted_lhs_rhs_type != lhs_rhs_type:
castback_op = get_castback_from_arg_op(lhs_rhs_type)
compute_lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
compute_rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
compute = (
f" %lhs_casted = {castback_op} %lhs: {lhs_tensor_type} to {compute_lhs_tensor_type}\n"
f" %rhs_casted = {castback_op} %rhs: {rhs_tensor_type} to {compute_rhs_tensor_type}\n"
f" %result = {op_name} {compilation_info_attr}ins(%lhs_casted, %rhs_casted: {compute_lhs_tensor_type}, {compute_rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}"
)
if shape.accumulate:
signature = f"({lhs_tensor_type}, {rhs_tensor_type}, {acc_tensor_type}) -> {acc_tensor_type}"
import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view"
func_definition = func_definition + (
f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n"
f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
f"{compute}\n"
f" return %result: {acc_tensor_type}\n"
f"}}\n"
)
Expand All @@ -627,7 +660,7 @@ def generate_function(
f" %init_acc = tensor.empty(%acc_dim0, %acc_dim1) : {acc_tensor_type}\n"
f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
f"{compute}"
f" return %result: {acc_tensor_type}\n"
f"}}\n"
)
Expand All @@ -639,7 +672,7 @@ def generate_function(
f" %init_acc = tensor.empty() : {acc_tensor_type}\n"
f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
f"{compute}"
f" return %result: {acc_tensor_type}\n"
f"}}\n"
)
Expand Down Expand Up @@ -733,8 +766,9 @@ def generate_call(
rhs_shape = [shape.k, shape.n]
transpose_rhs = 0

op = op + generate_random_matrix("lhs", lhs_shape, lhs_rhs_type)
op = op + generate_random_matrix("rhs", rhs_shape, lhs_rhs_type)
casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
op = op + generate_random_matrix("lhs", lhs_shape, casted_lhs_rhs_type)
op = op + generate_random_matrix("rhs", rhs_shape, casted_lhs_rhs_type)
if shape.accumulate:
op = op + generate_random_matrix("acc", [shape.m, shape.n], acc_type)
# TODO(#16168): there's a bug with in-place input->output aliasing and
Expand Down Expand Up @@ -822,7 +856,7 @@ def parse_arguments():
parser.add_argument(
"--lhs_rhs_type",
type=str,
choices=["i32", "i8", "f32", "f16", "bf16"],
choices=["i32", "i8", "f32", "f16", "f8E4M3FNUZ", "bf16"],
help="Numeric type of input matrices",
required=True,
)
Expand Down Expand Up @@ -915,6 +949,8 @@ def write_calls_file(functions, calls, filename, requirements):
def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId):
if acc_type != MatrixElemTypeId.NONE:
return acc_type
if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ:
return MatrixElemTypeId.F32
if lhs_rhs_type == MatrixElemTypeId.I8:
return MatrixElemTypeId.I32
return lhs_rhs_type
Expand Down

0 comments on commit 09da7be

Please sign in to comment.