diff --git a/Makefile b/Makefile index 28bd8e7cb4a9..2448609845c2 100644 --- a/Makefile +++ b/Makefile @@ -454,6 +454,7 @@ SOURCE_FILES = \ EmulateFloat16Math.cpp \ Error.cpp \ Expr.cpp \ + ExtractTileOperations.cpp \ FastIntegerDivide.cpp \ FindCalls.cpp \ FindIntrinsics.cpp \ @@ -627,6 +628,7 @@ HEADER_FILES = \ ExprUsesVar.h \ Extern.h \ ExternFuncArgument.h \ + ExtractTileOperations.h \ FastIntegerDivide.h \ FindCalls.h \ FindIntrinsics.h \ @@ -841,6 +843,7 @@ RUNTIME_LL_COMPONENTS = \ x86_avx \ x86_avx2 \ x86_avx512 \ + x86_amx \ x86_sse41 RUNTIME_EXPORTED_INCLUDES = $(INCLUDE_DIR)/HalideRuntime.h \ diff --git a/dependencies/llvm/CMakeLists.txt b/dependencies/llvm/CMakeLists.txt index bbf94f0010ef..eba329c76335 100644 --- a/dependencies/llvm/CMakeLists.txt +++ b/dependencies/llvm/CMakeLists.txt @@ -20,6 +20,9 @@ message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") message(STATUS "Using ClangConfig.cmake in: ${Clang_DIR}") +# LLVM_PACKAGE_VERSION does not propagate to higher scopes +set(Halide_LLVM_VERSION ${LLVM_PACKAGE_VERSION} CACHE INTERNAL "Provided LLVM version") + if (LLVM_PACKAGE_VERSION VERSION_LESS 11.0) message(FATAL_ERROR "LLVM version must be 11.0 or newer") endif () diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8b3be6a070b1..27ece3dbe28b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -59,6 +59,7 @@ set(HEADER_FILES ExprUsesVar.h Extern.h ExternFuncArgument.h + ExtractTileOperations.h FastIntegerDivide.h FindCalls.h FindIntrinsics.h @@ -219,6 +220,7 @@ set(SOURCE_FILES EmulateFloat16Math.cpp Error.cpp Expr.cpp + ExtractTileOperations.cpp FastIntegerDivide.cpp FindCalls.cpp FindIntrinsics.cpp diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index e51e4ffd5a77..af2950e0f6b0 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -2564,11 +2564,9 @@ void CodeGen_LLVM::visit(const Call *op) { internal_assert(op->is_extern() || op->is_intrinsic()) << "Can only codegen extern calls and intrinsics\n"; - if (op->type.is_vector()) { - value = call_overloaded_intrin(op->type, op->name, op->args); - if (value) { - return; - } + value = call_overloaded_intrin(op->type, op->name, op->args); + if (value) { + return; } // Some call nodes are actually injected at various stages as a diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 1626c3d7b7f4..32c8abbdacbc 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -79,15 +79,21 @@ class CodeGen_X86 : public CodeGen_Posix { void visit(const EQ *) override; void visit(const NE *) override; void visit(const Select *) override; + void visit(const Allocate *) override; + void visit(const Load *) override; + void visit(const Store *) override; void codegen_vector_reduce(const VectorReduce *, const Expr &init) override; // @} + +private: + Scope mem_type; }; CodeGen_X86::CodeGen_X86(Target t) : CodeGen_Posix(complete_x86_target(t)) { } -const int max_intrinsic_args = 4; +const int max_intrinsic_args = 6; struct x86Intrinsic { const char *intrin_name; @@ -95,6 +101,10 @@ struct x86Intrinsic { const char *name; halide_type_t arg_types[max_intrinsic_args]; Target::Feature feature = Target::FeatureEnd; + uint32_t flags = 0; + enum Options { + AccessesMemory = 1 << 0, + }; }; // clang-format off @@ -199,6 +209,19 @@ const x86Intrinsic intrinsic_defs[] = { {"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids}, {"dpwssdsx8", Int(32, 8), "saturating_dot_product", 
{Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
     {"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},
+
+    {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
+    {"tileloadd64_i8", UInt(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
+    {"tileloadd64_bf16", BFloat(16, 512), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
+    {"tdpbssd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids},
+    {"tdpbsud", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), UInt(8, 1024)}, Target::AVX512_SapphireRapids},
+    {"tdpbusd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), UInt(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids},
+    {"tdpbuud", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), UInt(8, 1024), UInt(8, 1024)}, Target::AVX512_SapphireRapids},
+    {"tdpbf16ps", Float(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Float(32, 256), BFloat(16, 512), BFloat(16, 512)}, Target::AVX512_SapphireRapids},
+    {"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids},
+    {"tilezero_f32", Float(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids},
+    {"tilestored64_i32", Int(32), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
+    {"tilestored64_f32", Int(32), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Float(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
 };
 // clang-format on
 
@@ -221,7 +244,9 @@ void CodeGen_X86::init_module() {
         }
 
         auto *fn = declare_intrin_overload(i.name, ret_type, i.intrin_name, std::move(arg_types));
-        fn->addFnAttr(llvm::Attribute::ReadNone);
+        if ((i.flags & x86Intrinsic::AccessesMemory) == 0) {
+            fn->addFnAttr(llvm::Attribute::ReadNone);
+        }
         fn->addFnAttr(llvm::Attribute::NoUnwind);
     }
 }
@@ -584,6 +609,38 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
     CodeGen_Posix::codegen_vector_reduce(op, init);
 }
 
+void CodeGen_X86::visit(const Allocate *op) {
+    ScopedBinding<MemoryType> bind(mem_type, op->name, op->memory_type);
+    CodeGen_Posix::visit(op);
+}
+
+void CodeGen_X86::visit(const Load *op) {
+    if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) {
+        const Ramp *ramp = op->index.as<Ramp>();
+        internal_assert(ramp) << "Expected AMXTile to have index ramp\n";
+        Value *ptr = codegen_buffer_pointer(op->name, op->type, ramp->base);
+        LoadInst *load = builder->CreateAlignedLoad(ptr->getType()->getPointerElementType(), ptr, llvm::Align(op->type.bytes()));
+        add_tbaa_metadata(load, op->name, op->index);
+        value = load;
+        return;
+    }
+    CodeGen_Posix::visit(op);
+}
+
+void CodeGen_X86::visit(const Store *op) {
+    if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) {
+        Value *val = codegen(op->value);
+        Halide::Type value_type = op->value.type();
+        const Ramp *ramp = op->index.as<Ramp>();
+        internal_assert(ramp) << "Expected AMXTile to have index ramp\n";
+        Value *ptr = codegen_buffer_pointer(op->name, value_type, ramp->base);
StoreInst *store = builder->CreateAlignedStore(val, ptr, llvm::Align(value_type.bytes())); + add_tbaa_metadata(store, op->name, op->index); + return; + } + CodeGen_Posix::visit(op); +} + string CodeGen_X86::mcpu() const { if (target.has_feature(Target::AVX512_SapphireRapids)) { #if LLVM_VERSION >= 120 @@ -644,7 +701,7 @@ string CodeGen_X86::mattrs() const { } if (target.has_feature(Target::AVX512_SapphireRapids)) { #if LLVM_VERSION >= 120 - features += ",+avx512bf16,+avx512vnni"; + features += ",+avx512bf16,+avx512vnni,+amx-int8,+amx-bf16"; #else user_error << "AVX512 SapphireRapids requires LLVM 12 or later."; #endif diff --git a/src/Expr.h b/src/Expr.h index c5472e766fa4..b70d608d290b 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -379,6 +379,10 @@ enum class MemoryType { * intermediate buffers. Necessary for vgather-vscatter instructions * on Hexagon */ VTCM, + + /** AMX Tile register for X86. Any data that would be used in an AMX matrix + * multiplication must first be loaded into an AMX tile register. */ + AMXTile, }; namespace Internal { diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp new file mode 100644 index 000000000000..08df3ff7e39f --- /dev/null +++ b/src/ExtractTileOperations.cpp @@ -0,0 +1,444 @@ +#include "ExtractTileOperations.h" + +#include "IRMatch.h" +#include "IRMutator.h" +#include "IROperator.h" +#include "Util.h" + +namespace Halide { +namespace Internal { + +using std::string; +using std::vector; + +namespace { + +template +struct Tile { + bool result; + Expr base; + Expr stride[Dim]; + int extent[Dim]; +}; + +enum class AMXOpType { + Int8, + Bfloat16, +}; + +/// returns the appropriate `Halide::Type` for the given operation type +Type amx_op_type_result_type(AMXOpType op_ty) { + switch (op_ty) { + case AMXOpType::Int8: + return Int(32, 256); + case AMXOpType::Bfloat16: + return Float(32, 256); + default: + internal_error << "Unexpected"; + return Type(); + } +} + +const auto wild_i32 = Variable::make(Int(32), "*"); +const auto wild_i32x = Variable::make(Int(32, 0), "*"); + +Tile<1> get_1d_tile_index(const Expr &e) { + if (const auto *r1 = e.as()) { + return {true, r1->base, {r1->stride}, {r1->lanes}}; + } + + return {}; +} + +Tile<2> get_2d_tile_index(const Expr &e) { + // ramp(ramp(base, 1, 4), x4(stride), 4) + vector matches; + if (const auto *r1 = e.as()) { + if (const auto *r2 = r1->base.as()) { + auto ramp_2d_pattern = Ramp::make(Ramp::make(wild_i32, wild_i32, r2->lanes), Broadcast::make(wild_i32, r2->lanes), r1->lanes); + if (expr_match(ramp_2d_pattern, e, matches)) { + return {true, std::move(matches[0]), {std::move(matches[2]), std::move(matches[1])}, {r1->lanes, r2->lanes}}; + } + } + } + return {}; +} + +Tile<3> get_3d_tile_index(const Expr &e) { + vector matches; + + // there could be a sub node + const Sub *sub = e.as(); + const Add *add = nullptr; + + if (sub) { + add = sub->a.as(); + } else { + add = e.as(); + } + + if (!add) { + return {}; + } + + const auto &first = add->a; + const auto &second = add->b; + + // ramp(x[x*r](base), x[x*r](stride), x) + x[x*y](ramp(idx, 1, r)) + + const auto *r1 = first.as(); + const auto *b2 = second.as(); + if (!r1 && !b2) { + // Try switching the order + r1 = second.as(); + b2 = first.as(); + } + if (!r1 || !b2) { + return {}; + } + + const auto *b1 = r1->base.as(); + const auto *r2 = b2->value.as(); + + if (!b1 || !r2) { + return {}; + } + + int x_tile = r1->lanes; + int r_tile = r2->lanes; + int y_tile = b1->lanes / r_tile; + if (y_tile != b2->lanes / x_tile) { + return {}; + } + + auto 
pattern1 = Ramp::make(Broadcast::make(wild_i32, b1->lanes), Broadcast::make(wild_i32, b1->lanes), r1->lanes); + if (!expr_match(pattern1, first, matches)) { + return {}; + } + Expr base = std::move(matches[0]); + Expr x_stride = std::move(matches[1]); + + auto pattern2 = Broadcast::make(Ramp::make(wild_i32, wild_i32, r2->lanes), b2->lanes); + if (!expr_match(pattern2, second, matches)) { + return {}; + } + base += std::move(matches[0]); + Expr r_stride = std::move(matches[1]); + + if (sub) { + Expr adj = sub->b; + const Broadcast *bcast = adj.as(); + + if (!bcast) { + return {}; + } + + if (bcast->lanes != b1->lanes * r1->lanes) { + return {}; + } + + base -= bcast->value; + } + + return {true, base, {x_stride, 0, r_stride}, {x_tile, y_tile, r_tile}}; +} + +struct Matmul { + bool result = false; + Stmt stmt; + int tile_x; + int tile_y; + int tile_r; +}; + +Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_type) { + // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] + const auto wild_i8x = Variable::make(Int(8, 0), "*"); + const auto wild_u8x = Variable::make(UInt(8, 0), "*"); + const auto wild_bf16x = Variable::make(BFloat(16, 0), "*"); + const auto wild_f32x = Variable::make(Float(32, 0), "*"); + + vector matches; + if (op_type == AMXOpType::Int8) { + const auto pattern1 = wild_i32x + wild_i32x; + if (!expr_match(pattern1, op->value, matches)) { + return {}; + } + } else { // AMXOpType::Bfloat16 + const auto pattern1 = wild_f32x + wild_f32x; + if (!expr_match(pattern1, op->value, matches)) { + return {}; + } + } + + const auto *reduce = matches[0].as(); + const auto *load = matches[1].as(); + if (!reduce || reduce->op != VectorReduce::Add) { + return {}; + } + if (!load || load->name != op->name || !equal(load->index, op->index)) { + return {}; + } + + if (op_type == AMXOpType::Int8) { + auto pattern2 = cast(Int(32, 0), cast(Int(32, 0), wild_i8x) * wild_i32x); + auto pattern2_unsigned = cast(Int(32, 0), cast(Int(32, 0), wild_u8x) * wild_i32x); + + if (!(expr_match(pattern2, reduce->value, matches) || expr_match(pattern2_unsigned, reduce->value, matches))) { + return {}; + } + } else { + auto pattern2 = cast(Float(32, 0), cast(Float(32, 0), wild_bf16x) * wild_f32x); + + if (!expr_match(pattern2, reduce->value, matches)) { + return {}; + } + } + + const auto *lhs_load = matches[0].as(); + const auto *rhs_broadcast = matches[1].as(); + if (!lhs_load || !rhs_broadcast) { + return {}; + } + const auto *rhs_cast = rhs_broadcast->value.as(); + if (rhs_cast) { + if (op_type == AMXOpType::Int8) { + if (!(rhs_cast->value.type().element_of() == Int(8) || rhs_cast->value.type().element_of() == UInt(8))) { + user_assert(false) << "Expected rhs cast of i8/u8"; + } + } else { // AMXOpType::Bfloat16 + user_assert(rhs_cast->value.type().element_of() == BFloat(16)) << "Expected rhs cast of bf16"; + } + } else { + return {}; + } + + const auto *rhs_load = rhs_cast->value.as(); + if (!rhs_load) { + return {}; + } + + const auto lhs_tile = get_3d_tile_index(lhs_load->index); + + if (!lhs_tile.result) { + return {}; + } + + const int tile_x = lhs_tile.extent[0]; + const int tile_y = lhs_tile.extent[1]; + const int tile_r = lhs_tile.extent[2]; + const int factor = reduce->value.type().lanes() / reduce->type.lanes(); + + Expr rhs_base; + Expr rhs_stride; + + const auto rhs_tile2 = get_2d_tile_index(rhs_load->index); + if (!rhs_tile2.result) { + const auto rhs_tile1 = get_1d_tile_index(rhs_load->index); + + if (!rhs_tile1.result) { + return {}; + 
} + + if (rhs_tile1.extent[0] != tile_y * tile_r) { + return {}; + } + + rhs_base = rhs_tile1.base; + rhs_stride = rhs_tile1.stride[0]; + } else { + if (tile_y != rhs_tile2.extent[0] || tile_r != rhs_tile2.extent[1]) { + return {}; + } + + rhs_base = rhs_tile2.base; + rhs_stride = rhs_tile2.stride[0]; + } + + if (op->index.type().lanes() != tile_x * tile_y || + factor != tile_r) { + return {}; + } + +#if LLVM_VERSION < 130 + user_assert(op_type != AMXOpType::Bfloat16 && + lhs_load->type.is_int() && rhs_cast->value.type().is_int()) + << "LLVM 13 or above is required for unsigned or float AMX instructions"; +#endif + + // {rows, colbytes, var, index} + auto lhs_var = Variable::make(Handle(), lhs_load->name); + const auto &lhs_load_type = lhs_load->type; + int element_width = lhs_load_type.bytes(); + auto lhs_type = lhs_load_type.with_lanes(1024 / element_width); + auto lhs = Call::make(lhs_type, "tile_load", {tile_x, tile_r * element_width, lhs_var, lhs_tile.base * element_width, lhs_tile.stride[0] * element_width}, Call::Intrinsic); + + auto rhs_var = Variable::make(Handle(), rhs_load->name); + const auto &rhs_load_type = rhs_load->type; + auto rhs_type = rhs_load_type.with_lanes(1024 / element_width); + auto rhs = Call::make(rhs_type, "tile_load", {1, tile_y * tile_r * element_width, rhs_var, rhs_base * element_width, rhs_stride * tile_y * element_width}, Call::Intrinsic); + auto res_type = amx_op_type_result_type(op_type); + + // {rows, colbytes, acc, out, lhs, rhs} + auto out = Load::make(res_type, new_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); + + // 4 bytes for i32, f32 + auto colbytes = tile_y * 4; + auto matmul = Call::make(res_type, "tile_matmul", {tile_x, colbytes, tile_r, out, lhs, rhs}, Call::Intrinsic); + auto store = Store::make(new_name, matmul, Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); + return {true, std::move(store), tile_x, tile_y, tile_r}; +} + +Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_name) { + if (const auto *ramp = op->index.as()) { + if (const auto *bcast = op->value.as()) { + if (is_const_one(ramp->stride) && + is_const_zero(bcast->value) && + (bcast->lanes == tile_x * tile_y)) { + auto rows = Cast::make(Int(16), tile_x); + auto bytes = op->value.type().bytes(); + auto colbytes = Cast::make(Int(16), tile_y * bytes); + const auto &store_type = op->value.type(); + // will be f32 or i32 + auto tile_zero_type = store_type.with_lanes(1024 / store_type.bytes()); + auto val = Call::make(tile_zero_type, "tile_zero", {rows, colbytes}, Call::Intrinsic); + auto store = Store::make(new_name, std::move(val), Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); + return store; + } + } + } + return {}; +} + +Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, int tile_y) { + auto tile = get_2d_tile_index(op->index); + if (tile.result && tile.extent[0] == tile_x && tile.extent[1] == tile_y) { + auto out = Variable::make(Handle(), op->name); + auto tile_type = op->value.type().with_lanes(256); + auto tile_val = Load::make(tile_type, amx_name, Ramp::make(0, 1, 256), {}, {}, const_true(256), {}); + auto bytes = op->value.type().bytes(); + internal_assert(bytes == 4) << "AMX store only supported for int32 and float32 output, not for " << op->value.type() << "\n"; + // {tile_x, tile_y, var, base, stride} + auto store = Call::make(Int(32), "tile_store", {tile_x, tile_y * bytes, std::move(out), tile.base * bytes, tile.stride[0] * bytes, 
std::move(tile_val)}, Call::Intrinsic);
+        return Evaluate::make(std::move(store));
+    }
+    return {};
+}
+
+class ExtractTileOperations : public IRMutator {
+    using IRMutator::visit;
+
+    string tile_name;
+    string amx_name;
+    vector<Stmt> pending_stores;
+    bool in_allocate = false;
+    int found_tile_x = -1;
+    int found_tile_y = -1;
+    int found_tile_r = -1;
+    AMXOpType op_type;
+
+    Stmt visit(const Allocate *op) override {
+        if (op->memory_type == MemoryType::AMXTile) {
+            user_assert(
+                (op->type.is_int() && op->type.bits() == 32) ||
+                (op->type.is_float() && op->type.bits() == 32))
+                << "scheduled tile operations must yield 32-bit integers or 32-bit floats";
+
+            if (op->type.is_int() && op->type.bits() == 32) {
+                op_type = AMXOpType::Int8;
+            } else {
+                op_type = AMXOpType::Bfloat16;
+            }
+
+            user_assert(!in_allocate) << "Already in AMX allocation: " << amx_name;
+            ScopedValue<string> old_amx_name(amx_name, op->name + ".amx");
+            ScopedValue<string> old_tile_name(tile_name, op->name);
+            ScopedValue<bool> old_in_alloc(in_allocate, true);
+            Stmt body = op->body;
+
+            pending_stores.clear();
+            body = mutate(body);
+            if (found_tile_x < 0 || found_tile_y < 0 || found_tile_r < 0) {
+                return op;
+            }
+            if (!pending_stores.empty()) {
+                // Really only need to go over the pending stores
+                body = mutate(body);
+            }
+
+            auto alloc_type = amx_op_type_result_type(op_type);
+            return Allocate::make(amx_name, alloc_type, MemoryType::AMXTile, {1}, const_true(), body);
+        }
+        return IRMutator::visit(op);
+    }
+
+    Stmt visit(const Free *op) override {
+        if (op->name != tile_name) {
+            return op;
+        }
+        return Free::make(amx_name);
+    }
+
+    Stmt visit(const ProducerConsumer *op) override {
+        if (op->name != tile_name) {
+            return IRMutator::visit(op);
+        }
+
+        auto body = mutate(op->body);
+        return ProducerConsumer::make(amx_name, op->is_producer, std::move(body));
+    }
+
+    Expr visit(const Load *op) override {
+        // Any tile load will be matched elsewhere, so a load here means that
+        // the AMX tile is used outside of a tile instruction.
+        user_assert(op->name != tile_name) << "AMX tile allocation used outside a tile instruction";
+        return IRMutator::visit(op);
+    }
+
+    Stmt visit(const Store *op) override {
+        if (op->name != tile_name) {
+            const auto *load = op->value.as<Load>();
+            if (!load || load->name != tile_name) {
+                return op;
+            }
+            auto store = convert_to_tile_store(op, amx_name, found_tile_x, found_tile_y);
+            user_assert(store.defined()) << "Store to AMX tile allocation of a non-tile value";
+            return store;
+        }
+
+        auto matmul = convert_to_matmul(op, amx_name, op_type);
+        if (matmul.result) {
+            user_assert(
+                (found_tile_x < 0 || matmul.tile_x == found_tile_x) &&
+                (found_tile_y < 0 || matmul.tile_y == found_tile_y) &&
+                (found_tile_r < 0 || matmul.tile_r == found_tile_r))
+                << "Found different tile sizes for AMX tile allocation";
+            found_tile_x = matmul.tile_x;
+            found_tile_y = matmul.tile_y;
+            found_tile_r = matmul.tile_r;
+            return matmul.stmt;
+        }
+
+        if (found_tile_x < 0 || found_tile_y < 0) {
+            pending_stores.emplace_back(op);
+            return op;
+        }
+
+        auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_name);
+        if (zero.defined()) {
+            return zero;
+        }
+
+        // Otherwise there is some other operation using the allocation, so we cannot use the AMX instructions
+        user_assert(false) << "Found non-tile operations for AMX tile allocation";
+        return op;
+    }
+};
+
+} // namespace
+
+Stmt extract_tile_operations(const Stmt &s) {
+    return ExtractTileOperations().mutate(s);
+}
+} // namespace Internal
+} // namespace Halide
diff --git a/src/ExtractTileOperations.h b/src/ExtractTileOperations.h
new file mode 100644
index 000000000000..918e3b1b9940
--- /dev/null
+++ b/src/ExtractTileOperations.h
@@ -0,0 +1,21 @@
+#ifndef HALIDE_EXTRACT_TILE_OPERATIONS_H
+#define HALIDE_EXTRACT_TILE_OPERATIONS_H
+
+/** \file
+ * Defines the lowering pass that injects calls to tile intrinsics that support
+ * AMX instructions.
+ */
+
+#include "Expr.h"
+
+namespace Halide {
+namespace Internal {
+
+/** Rewrite any AMX tile operations that have been stored in the AMXTile memory
+ * type as intrinsic calls, to be used in the X86 backend.
*/ +Stmt extract_tile_operations(const Stmt &s); + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index b78463f55a95..e9c6364879ae 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -1279,6 +1279,7 @@ class InjectThreadBarriers : public IRMutator { case MemoryType::Register: case MemoryType::LockedCache: case MemoryType::VTCM: + case MemoryType::AMXTile: break; } @@ -1303,6 +1304,7 @@ class InjectThreadBarriers : public IRMutator { case MemoryType::Register: case MemoryType::LockedCache: case MemoryType::VTCM: + case MemoryType::AMXTile: break; } diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index ad8c0057ff48..b1b826f46118 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -135,6 +135,9 @@ std::ostream &operator<<(std::ostream &out, const MemoryType &t) { case MemoryType::VTCM: out << "VTCM"; break; + case MemoryType::AMXTile: + out << "AMXTile"; + break; } return out; } diff --git a/src/LLVM_Runtime_Linker.cpp b/src/LLVM_Runtime_Linker.cpp index 4308f1f8ae87..62203f247a4b 100644 --- a/src/LLVM_Runtime_Linker.cpp +++ b/src/LLVM_Runtime_Linker.cpp @@ -230,6 +230,7 @@ DECLARE_NO_INITMOD(windows_d3d12compute_arm) #endif // WITH_D3D12 #ifdef WITH_X86 +DECLARE_LL_INITMOD(x86_amx) DECLARE_LL_INITMOD(x86_avx512) DECLARE_LL_INITMOD(x86_avx2) DECLARE_LL_INITMOD(x86_avx) @@ -237,6 +238,7 @@ DECLARE_LL_INITMOD(x86) DECLARE_LL_INITMOD(x86_sse41) DECLARE_CPP_INITMOD(x86_cpu_features) #else +DECLARE_NO_INITMOD(x86_amx) DECLARE_NO_INITMOD(x86_avx512) DECLARE_NO_INITMOD(x86_avx2) DECLARE_NO_INITMOD(x86_avx) @@ -1064,6 +1066,11 @@ std::unique_ptr get_initial_module_for_target(Target t, llvm::LLVM if (t.has_feature(Target::AVX512)) { modules.push_back(get_initmod_x86_avx512_ll(c)); } +#if LLVM_VERSION >= 120 + if (t.has_feature(Target::AVX512_SapphireRapids)) { + modules.push_back(get_initmod_x86_amx_ll(c)); + } +#endif if (t.has_feature(Target::Profile)) { user_assert(t.os != Target::WebAssemblyRuntime) << "The profiler cannot be used in a threadless environment."; modules.push_back(get_initmod_profiler_inlined(c, bits_64, debug)); diff --git a/src/Lower.cpp b/src/Lower.cpp index 4b99414b2ef4..e225bfde122a 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -22,6 +22,7 @@ #include "DebugToFile.h" #include "Deinterleave.h" #include "EarlyFree.h" +#include "ExtractTileOperations.h" #include "FindCalls.h" #include "FindIntrinsics.h" #include "FlattenNestedRamps.h" @@ -375,6 +376,14 @@ void lower_impl(const vector &output_funcs, s = lower_unsafe_promises(s, t); log("Lowering after lowering unsafe promises:", s); +#if LLVM_VERSION >= 120 + if (t.has_feature(Target::AVX512_SapphireRapids)) { + debug(1) << "Extracting tile operations...\n"; + s = extract_tile_operations(s); + log("Lowering after extracting tile operations:", s); + } +#endif + debug(1) << "Flattening nested ramps...\n"; s = flatten_nested_ramps(s); log("Lowering after flattening nested ramps:", s); diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt index 4cc19de8013d..c0282fccb99a 100644 --- a/src/runtime/CMakeLists.txt +++ b/src/runtime/CMakeLists.txt @@ -115,6 +115,11 @@ set(RUNTIME_LL x86_sse41 ) +if (Halide_LLVM_VERSION VERSION_GREATER_EQUAL 12.0) + # AMX instructions require LLVM 12 or newer + list(APPEND RUNTIME_LL x86_amx) +endif () + set(RUNTIME_BC compute_20 compute_30 diff --git a/src/runtime/x86_amx.ll b/src/runtime/x86_amx.ll new file mode 100644 index 000000000000..6c7f4659f0c8 --- /dev/null +++ 
b/src/runtime/x86_amx.ll @@ -0,0 +1,97 @@ +define weak_odr <1024 x i8> @tileloadd64_i8(i16 %rows, i16 %colbytes, i8* %ptr, i64 %off, i64 %stride) nounwind alwaysinline readonly { + %1 = getelementptr i8, i8* %ptr, i64 %off + %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) nounwind readonly + %3 = bitcast x86_amx %2 to <1024 x i8> + ret <1024 x i8> %3 +} +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) + +define weak_odr <512 x i16> @tileloadd64_bf16(i16 %rows, i16 %colbytes, i8* %ptr, i64 %off, i64 %stride) nounwind alwaysinline readonly { + %1 = getelementptr i8, i8* %ptr, i64 %off + %2 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %rows, i16 %colbytes, i8* %1, i64 %stride) nounwind readonly + %3 = bitcast x86_amx %2 to <512 x i16> + ret <512 x i16> %3 +} + +define weak_odr <256 x i32> @tdpbssd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone { + %1 = bitcast <1024 x i8> %lhs to x86_amx + %2 = bitcast <1024 x i8> %rhs to x86_amx + %3 = bitcast <256 x i32> %out to x86_amx + %4 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone + %5 = bitcast x86_amx %4 to <256 x i32> + ret <256 x i32> %5 +} +declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + +define weak_odr <256 x i32> @tdpbsud(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone { + %1 = bitcast <1024 x i8> %lhs to x86_amx + %2 = bitcast <1024 x i8> %rhs to x86_amx + %3 = bitcast <256 x i32> %out to x86_amx + %4 = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone + %5 = bitcast x86_amx %4 to <256 x i32> + ret <256 x i32> %5 +} +declare x86_amx @llvm.x86.tdpbsud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + +define weak_odr <256 x i32> @tdpbusd(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone { + %1 = bitcast <1024 x i8> %lhs to x86_amx + %2 = bitcast <1024 x i8> %rhs to x86_amx + %3 = bitcast <256 x i32> %out to x86_amx + %4 = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone + %5 = bitcast x86_amx %4 to <256 x i32> + ret <256 x i32> %5 +} +declare x86_amx @llvm.x86.tdpbusd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + +define weak_odr <256 x i32> @tdpbuud(i16 %rows, i16 %colbytes, i16 %acc, <256 x i32> %out, <1024 x i8> %lhs, <1024 x i8> %rhs) nounwind alwaysinline readnone { + %1 = bitcast <1024 x i8> %lhs to x86_amx + %2 = bitcast <1024 x i8> %rhs to x86_amx + %3 = bitcast <256 x i32> %out to x86_amx + %4 = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 %rows, i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone + %5 = bitcast x86_amx %4 to <256 x i32> + ret <256 x i32> %5 +} +declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + +define weak_odr <256 x float> @tdpbf16ps(i16 %rows, i16 %colbytes, i16 %acc, <256 x float> %out, <512 x i16> %lhs, <512 x i16> %rhs) nounwind alwaysinline readnone { + %1 = bitcast <512 x i16> %lhs to x86_amx + %2 = bitcast <512 x i16> %rhs to x86_amx + %3 = bitcast <256 x float> %out to x86_amx + %4 = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %rows, 
i16 %colbytes, i16 %acc, x86_amx %3, x86_amx %1, x86_amx %2) nounwind readnone + %5 = bitcast x86_amx %4 to <256 x float> + ret <256 x float> %5 +} +declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) + +define weak_odr i32 @tilestored64_i32(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x i32> %val) nounwind alwaysinline writeonly { + %1 = getelementptr i8, i8* %ptr, i64 %off + %2 = bitcast <256 x i32> %val to x86_amx + tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) nounwind writeonly + ret i32 zeroinitializer ; return 0 since Halide has no void return value +} +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) + +define weak_odr i32 @tilestored64_f32(i16 %rows, i16 %cols, i8* %ptr, i64 %off, i64 %stride, <256 x float> %val) nounwind alwaysinline writeonly { + %1 = getelementptr i8, i8* %ptr, i64 %off + %2 = bitcast <256 x float> %val to x86_amx + tail call void @llvm.x86.tilestored64.internal(i16 %rows, i16 %cols, i8* %1, i64 %stride, x86_amx %2) nounwind writeonly + ret i32 zeroinitializer +} + +; NB: Even though this should be readnone, that will cause LLVM to try to +; generate a single zero tile, and copy it each time it is used. However the AMX +; registers cannot be copied, so this causes compilation failures: +; LLVM ERROR: Cannot emit physreg copy instruction +; renamable $tmm1 = COPY renamable $tmm0 +define weak_odr <256 x i32> @tilezero_i32(i16 %rows, i16 %colbytes) nounwind alwaysinline { + %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %rows, i16 %colbytes) nounwind + %2 = bitcast x86_amx %1 to <256 x i32> + ret <256 x i32> %2 +} + +define weak_odr <256 x float> @tilezero_f32(i16 %rows, i16 %colbytes) nounwind alwaysinline { + %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %rows, i16 %colbytes) nounwind + %2 = bitcast x86_amx %1 to <256 x float> + ret <256 x float> %2 +} +declare x86_amx @llvm.x86.tilezero.internal(i16, i16) diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 370fe711c663..743ece0565ff 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -309,6 +309,7 @@ tests(GROUPS correctness strided_load.cpp target.cpp thread_safety.cpp + tiled_matmul.cpp tracing.cpp tracing_bounds.cpp tracing_broadcast.cpp diff --git a/test/correctness/tiled_matmul.cpp b/test/correctness/tiled_matmul.cpp new file mode 100644 index 000000000000..7fbeedef3ecc --- /dev/null +++ b/test/correctness/tiled_matmul.cpp @@ -0,0 +1,265 @@ +#include "Halide.h" +#include + +using namespace Halide; + +void fill_buffer_a_bf16(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; ++iy) { + for (int ix = 0; ix < acc; ++ix) { + // value between 0 and 100 + bfloat16_t val = bfloat16_t(((float)rand() / (float)(RAND_MAX)) * 100.f); + buf(ix, iy) = val; + } + } +} + +void fill_buffer_b_bf16(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 2; ++iy) { + for (int ix = 0; ix < col; ++ix) { + for (int ik = 0; ik < 2; ++ik) { + bfloat16_t val = bfloat16_t(((float)rand() / (float)(RAND_MAX)) * 100.f); + buf(ik, ix, iy) = val; + } + } + } +} + +template +void fill_buffer_a(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +void fill_buffer_b(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 4; iy++) { + for (int ix = 0; ix < col; ix++) { + 
for (int ik = 0; ik < 4; ++ik) { + buf(ik, ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } + } +} + +bool equal_eps(float lhs, float rhs, float eps) { + return std::abs(lhs - rhs) < eps; +} + +struct make_uint_t { + template + Type operator()(Args &&...args) const { + return UInt(static_cast(args)...); + } +}; + +struct make_int_t { + template + Type operator()(Args &&...args) const { + return Int(static_cast(args)...); + } +}; + +template +bool matmul() { + constexpr int row = 16; + constexpr int col = 16; + constexpr int acc = 16; + + Buffer A_buf(acc, row); + Buffer B_buf(4, col, acc / 4); + + Var x("x"), y("y"); + RDom r(0, acc); + + Func mm("matmul"); + mm(x, y) = cast(0); + mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(r % 4, x, r / 4)); + + constexpr int tile_x = 8; + constexpr int tile_y = 8; + constexpr int tile_r = 4; + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r, rro, rri, tile_r) + .reorder(rri, rxi, ryi, rro, x, y) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + // schedule the consumer + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + fill_buffer_a(A_buf, row, acc); + fill_buffer_b(B_buf, col, acc); + + Buffer out(col, row); + + result.realize(out); + + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { + int32_t val = 0; + for (int k = 0; k < acc; ++k) { + val += static_cast(A_buf(k, j)) * static_cast(B_buf(k % 4, i, k / 4)); + } + if (val != out(i, j)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n"; + return false; + } + } + } + + return true; +} + +bool matmul_bf16() { + // lhs: 32x16, rhs: 16x32 + const int row = 32; + const int col = 32; + const int acc = 16; + + Var x("x"), y("y"); + Buffer A(acc, row); + Buffer B(2, col, acc / 2); + + RDom r(0, acc, "acc"); + + Func mm("matmul"); + mm(x, y) = cast(0); + mm(x, y) += cast(cast(A(r.x, y))) * cast(B(r.x % 2, x, r.x / 2)); + + int tile_x = 8; + int tile_y = 8; + int tile_r = 2; + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r.x, rro, rri, tile_r) + .reorder({rri, rxi, ryi, rro, x, y}) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + // schedule the consumer + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + fill_buffer_a_bf16(A, row, acc); + fill_buffer_b_bf16(B, col, acc); + + Buffer out(col, row); + + // Uncomment to check the asm + //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target); + //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); + + result.realize(out); + + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { + float val = 0.f; + for (int k = 0; k < acc; ++k) { + val += 
static_cast(A(k, j)) * static_cast(B(k % 2, i, k / 2)); + } + if (!equal_eps(val, out(i, j), 0.01f)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n"; + return false; + } + } + } + + return true; +} + +auto matmul_ss = &matmul; +auto matmul_us = &matmul; +auto matmul_su = &matmul; +auto matmul_uu = &matmul; + +int main(int argc, char **argv) { + Target t = get_jit_target_from_environment(); + if (!t.has_feature(Target::AVX512_SapphireRapids)) { + printf("[SKIP] No AMX target enabled\n"); + return 0; + } + + printf("Running AMX matmul (signed/signed)\n"); + if (!matmul_ss()) { + return -1; + } else { + printf("Success!\n"); + } + + // llvm >= 13.0 is required for unsigned and float AMX instructions + if (Halide::Internal::get_llvm_version() >= 130) { + printf("Running AMX matmul (signed/unsigned)\n"); + if (!matmul_su()) { + return -1; + } else { + printf("Success!\n"); + } + + printf("Running AMX matmul (unsigned/signed)\n"); + if (!matmul_us()) { + return -1; + } else { + printf("Success!\n"); + } + + printf("Running AMX matmul (unsigned/unsigned)\n"); + if (!matmul_uu()) { + return -1; + } else { + printf("Success!\n"); + } + + printf("Running AMX matmul (bf16)\n"); + if (!matmul_bf16()) { + return -1; + } else { + printf("Success!\n"); + } + } + return 0; +} \ No newline at end of file diff --git a/test/performance/CMakeLists.txt b/test/performance/CMakeLists.txt index 65aa41da00f3..80f58a16afae 100644 --- a/test/performance/CMakeLists.txt +++ b/test/performance/CMakeLists.txt @@ -1,5 +1,6 @@ tests(GROUPS performance SOURCES + tiled_matmul.cpp async_gpu.cpp block_transpose.cpp boundary_conditions.cpp diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp new file mode 100644 index 000000000000..2fd90683bd38 --- /dev/null +++ b/test/performance/tiled_matmul.cpp @@ -0,0 +1,256 @@ +#include "Halide.h" +#include "halide_benchmark.h" +#include "halide_test_dirs.h" + +#include +#include + +using namespace Halide; + +void fill_buffer_a_bf16(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; ++iy) { + for (int ix = 0; ix < acc; ++ix) { + // value between 0 and 100 + bfloat16_t val = bfloat16_t(((float)rand() / (float)(RAND_MAX)) * 100.f); + buf(ix, iy) = val; + } + } +} + +void fill_buffer_b_bf16(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 2; ++iy) { + for (int ix = 0; ix < col; ++ix) { + for (int ik = 0; ik < 2; ++ik) { + bfloat16_t val = bfloat16_t(((float)rand() / (float)(RAND_MAX)) * 100.f); + buf(ik, ix, iy) = val; + } + } + } +} + +struct make_uint_t { + template + Type operator()(Args &&...args) const { + return UInt(static_cast(args)...); + } +}; + +struct make_int_t { + template + Type operator()(Args &&...args) const { + return Int(static_cast(args)...); + } +}; + +template +void fill_buffer_a(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +void fill_buffer_b(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 4; iy++) { + for (int ix = 0; ix < col; ix++) { + for (int ik = 0; ik < 4; ++ik) { + buf(ik, ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } + } +} + +template +bool matmul(Halide::Target target) { + // used for compiling to llvm IR or asm + (void)target; + + constexpr bool lhs_signed = std::is_signed::value; + constexpr bool rhs_signed = std::is_signed::value; + + auto lhs = typename 
std::conditional::type{}; + auto rhs = typename std::conditional::type{}; + + const int row = 16; + const int col = 16; + const int acc = 16; + + Var x("x"), y("y"); + ImageParam A(lhs(8), 2, "lhs"); + // NB the RHS matrix in AMX instructions should be tiled in "VNNI format", + // where instead of being (cols, rows) where rows are adjacent in memory it + // should be (4, cols, rows / 4) for int8, or (2, cols, rows / 2) for bf16. + // This means that the rows must always be divisible by 4 (or 2 for bf16). + ImageParam B(rhs(8), 3, "rhs"); + + RDom r(0, acc); + + Func mm("matmul"); + mm(y, x) = cast(0); + mm(y, x) += cast(A(r.x, x)) * B(r.x % 4, y, r.x / 4); + + // Ensure all (x, y) tile sizes are the same so that loops are fused. + int tile_y = 8; + int tile_x = 6; + int tile_r = 4; + + // Schedule the reduction + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + mm.compute_at(mm.in(), y) + .store_in(MemoryType::AMXTile) + .update() + // Split into (x,y) tile + .tile(y, x, ryi, rxi, tile_y, tile_x, TailStrategy::GuardWithIf) + // Split reduction dim by tile_r + .split(r.x, rro, rri, tile_r) + // Reorder so that the (x,y) tile is inside the inner ro loop + .reorder({rri, ryi, rxi, rro, y, x}) + .atomic() + .vectorize(rri) + .vectorize(ryi) + .vectorize(rxi); + + // Schedule the initialization + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), y) + .tile(y, x, iyi, ixi, tile_y, tile_x) + .vectorize(iyi) + .vectorize(ixi); + + // Schedule the consumer + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(y, x, mmyi, mmxi, tile_y, tile_x) + .vectorize(mmyi) + .vectorize(mmxi); + + Buffer a_buf(acc, row); + fill_buffer_a(a_buf, row, acc); + A.set(a_buf); + + Buffer b_buf(4, col, acc / 4); + fill_buffer_b(b_buf, col, acc); + B.set(b_buf); + + Buffer out(col, row); + + Func result = mm.in(); + + // Uncomment to check the asm + //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A, B}, target); + //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); + + auto time = Tools::benchmark(20, 20, [&]() { + result.realize(out); + }); + std::cout << "Exec time: " << time << "\n"; + std::cout << "Success!\n"; + return true; +} + +auto matmul_ss = &matmul; +auto matmul_us = &matmul; +auto matmul_su = &matmul; +auto matmul_uu = &matmul; + +bool equal_eps(float lhs, float rhs, float eps) { + return std::abs(lhs - rhs) < eps; +} + +bool matmul_bf16(Halide::Target target) { + (void)target; + + // lhs: 32x16, rhs: 16x32 + const int row = 32; + const int col = 32; + const int acc = 16; + + Var x("x"), y("y"); + ImageParam A(BFloat(16), 2, "lhs"); + ImageParam B(BFloat(16), 3, "rhs"); + + RDom r(0, acc, "acc"); + + Func mm("matmul"); + mm(x, y) = cast(0); + mm(x, y) += cast(cast(A(r.x, y))) * cast(B(r.x % 2, x, r.x / 2)); + + int tile_x = 8; + int tile_y = 8; + int tile_r = 2; + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r.x, rro, rri, tile_r) + .reorder({rri, rxi, ryi, rro, x, y}) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + // schedule the consumer + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + 
Buffer<bfloat16_t> a_buf(acc, row);
+    fill_buffer_a_bf16(a_buf, row, acc);
+    A.set(a_buf);
+
+    Buffer<bfloat16_t> b_buf(2, col, acc / 2);
+    fill_buffer_b_bf16(b_buf, col, acc);
+    B.set(b_buf);
+
+    Buffer<float> out(col, row);
+
+    // Uncomment to check the asm
+    //result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target);
+    //result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target);
+
+    auto time = Tools::benchmark(20, 20, [&]() {
+        result.realize(out);
+    });
+
+    std::cout << "Exec time: " << time << "\n";
+    std::cout << "Success!\n";
+    return true;
+}
+
+int main(int argc, char **argv) {
+    Target target = get_jit_target_from_environment();
+    if (!target.has_feature(Target::AVX512_SapphireRapids)) {
+        std::cout << "[SKIP] The tiled matmul test is only designed to test AMX support.\n";
+        return 0;
+    }
+
+    printf("Running AMX (signed/signed)\n");
+    matmul_ss(target);
+    printf("Running AMX (unsigned/signed)\n");
+    matmul_us(target);
+    printf("Running AMX (signed/unsigned)\n");
+    matmul_su(target);
+    printf("Running AMX (unsigned/unsigned)\n");
+    matmul_uu(target);
+
+    printf("Running AMX (bf16)\n");
+    matmul_bf16(target);
+    return 0;
+}
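
Note (not part of the diff): the comment in test/performance/tiled_matmul.cpp points out that the rhs of an AMX matmul must be laid out in "VNNI format", i.e. (4, cols, rows / 4) for int8 or (2, cols, rows / 2) for bf16, rather than a plain (cols, rows) buffer. The tests above simply construct their B buffers directly in that shape. For data that starts out in the ordinary layout, a host-side repack is needed first. A minimal sketch of that repacking, using a hypothetical helper name to_vnni_int8 and assuming a plain Halide::Buffer<int8_t> indexed as plain(col, k):

#include "Halide.h"

// Repack an ordinary int8 matrix (indexed plain(col, k), with k the reduction
// dimension) into the (4, cols, rows / 4) VNNI layout expected for the AMX rhs.
// `rows` (the reduction extent) must be divisible by 4; bf16 would use 2 instead.
Halide::Buffer<int8_t> to_vnni_int8(const Halide::Buffer<int8_t> &plain, int rows, int cols) {
    Halide::Buffer<int8_t> vnni(4, cols, rows / 4);
    for (int k = 0; k < rows; k++) {
        for (int c = 0; c < cols; c++) {
            // Element (c, k) of the plain matrix lands at (k % 4, c, k / 4),
            // matching the B_buf(r % 4, x, r / 4) access pattern in the tests above.
            vnni(k % 4, c, k / 4) = plain(c, k);
        }
    }
    return vnni;
}

The repacked buffer can then be bound to the 3-D ImageParam (or passed as the 3-D Buffer) exactly as the tests do with their hand-built B buffers.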