From 3728f655f042dbd9020a3c92506d455af8dbf841 Mon Sep 17 00:00:00 2001
From: Ilya Veselov
Date: Wed, 25 Sep 2024 19:38:29 +0200
Subject: [PATCH] Allow 16 bit floating point operand for LLVM_AtomicRMWOp

Since the AMDGPU target supports vectorized atomic_rmw operations,
allow constructing LLVM_AtomicRMWOp with 16-bit floating point values.
This patch enables building LLVM_AtomicRMWOp with fixed vectors of
16-bit floating point values as operands.

See also: #94845, #95393, #95394

Signed-off-by: Ilya Veselov
---
 .../include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 10 +++++
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  3 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 14 +++++--
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  8 ++++
 mlir/test/Dialect/LLVMIR/roundtrip.mlir       |  6 ++-
 mlir/test/Target/LLVMIR/llvmir.mlir           | 41 ++++++++++---------
 6 files changed, 57 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index c3d352d8d0dd48..fa16f098cc6a2f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -139,6 +139,16 @@ class LLVM_VectorOf<Type element> : Type<
 class LLVM_ScalarOrVectorOf<Type element> :
     AnyTypeOf<[element, LLVM_VectorOf<element>]>;
 
+// Type constraint accepting an LLVM fixed vector type with an additional
+// constraint on the vector element type.
+class LLVM_FixedVectorOf<Type element> : Type<
+  And<[LLVM_AnyFixedVector.predicate,
+       SubstLeaves<
+         "$_self",
+         "::mlir::LLVM::getVectorElementType($_self)",
+         element.predicate>]>,
+  "LLVM dialect-compatible fixed vector of " # element.summary>;
+
 // Base class for LLVM operations. Defines the interface to the llvm::IRBuilder
 // used to translate to proper LLVM IR and the interface to the mlir::OpBuilder
 // used to import from LLVM IR.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 030160821bd823..beb8723f3bcd3b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1737,7 +1737,8 @@ def LLVM_ConstantOp
 // Atomic operations.
 //
 
-def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger]>;
+def LLVM_AtomicRMWType
+    : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger, LLVM_FixedVectorOf<LLVM_AnyFloat>]>;
 
 def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [
     TypesMatchWith<"result #0 and operand #1 have the same type",
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0561c364c7d591..ccdc4c79d7d189 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3008,9 +3008,17 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
 
 LogicalResult AtomicRMWOp::verify() {
   auto valType = getVal().getType();
-  if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
-      getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) {
-    if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
+  if (getBinOp() == AtomicBinOp::fadd && isCompatibleVectorType(valType)) {
+    Type elemType = getVectorElementType(valType);
+    // Only 16 bit floating point elements are supported for now.
+    if (!(isCompatibleFloatingPointType(elemType) &&
+          elemType.getIntOrFloatBitWidth() == 16))
+      return emitOpError("unexpected LLVM IR type for vector element");
+  } else if (getBinOp() == AtomicBinOp::fadd ||
+             getBinOp() == AtomicBinOp::fsub ||
+             getBinOp() == AtomicBinOp::fmin ||
+             getBinOp() == AtomicBinOp::fmax) {
+    if (!isCompatibleFloatingPointType(valType))
       return emitOpError("expected LLVM IR floating point type");
   } else if (getBinOp() == AtomicBinOp::xchg) {
     DataLayout dataLayout = DataLayout::closest(*this);
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 9388d7ef24936e..0e8c473fc3257d 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -643,6 +643,14 @@ func.func @atomicrmw_expected_float(%i32_ptr : !llvm.ptr, %i32 : i32) {
 
 // -----
 
+func.func @atomicrmw_unexpected_vector_element(%ptr : !llvm.ptr, %f32_vec : vector<3xf32>) {
+  // expected-error@+1 {{unexpected LLVM IR type for vector element}}
+  %0 = llvm.atomicrmw fadd %ptr, %f32_vec unordered : !llvm.ptr, vector<3xf32>
+  llvm.return
+}
+
+// -----
+
 func.func @atomicrmw_unexpected_xchg_type(%i1_ptr : !llvm.ptr, %i1 : i1) {
   // expected-error@+1 {{unexpected LLVM IR type for 'xchg' bin_op}}
   %0 = llvm.atomicrmw xchg %i1_ptr, %i1 unordered : !llvm.ptr, i1
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 62f1de2b7fe7d4..23db1fc1446791 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -420,11 +420,13 @@ func.func @atomic_store(%val : f32, %large_val : i256, %ptr : !llvm.ptr) {
 }
 
 // CHECK-LABEL: @atomicrmw
-func.func @atomicrmw(%ptr : !llvm.ptr, %val : f32) {
+func.func @atomicrmw(%ptr : !llvm.ptr, %val : f32, %f16_vec : vector<2xf16>) {
   // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} monotonic : !llvm.ptr, f32
   %0 = llvm.atomicrmw fadd %ptr, %val monotonic : !llvm.ptr, f32
+  // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} monotonic : !llvm.ptr, vector<2xf16>
+  %1 = llvm.atomicrmw fadd %ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
   // CHECK: llvm.atomicrmw volatile fsub %{{.*}}, %{{.*}} syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
-  %1 = llvm.atomicrmw volatile fsub %ptr, %val syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
+  %2 = llvm.atomicrmw volatile fsub %ptr, %val syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
   llvm.return
 }
 
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 007284d0ca4435..96a1ad90dbef22 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1496,50 +1496,53 @@ llvm.func @elements_constant_3d_array() -> !llvm.array<2 x array<2 x array<2 x i
 // CHECK-LABEL: @atomicrmw
 llvm.func @atomicrmw(
     %f32_ptr : !llvm.ptr, %f32 : f32,
+    %f16_vec_ptr : !llvm.ptr, %f16_vec : vector<2xf16>,
     %i32_ptr : !llvm.ptr, %i32 : i32) {
   // CHECK: atomicrmw fadd ptr %{{.*}}, float %{{.*}} monotonic
   %0 = llvm.atomicrmw fadd %f32_ptr, %f32 monotonic : !llvm.ptr, f32
+  // CHECK: atomicrmw fadd ptr %{{.*}}, <2 x half> %{{.*}} monotonic
+  %1 = llvm.atomicrmw fadd %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
   // CHECK: atomicrmw fsub ptr %{{.*}}, float %{{.*}} monotonic
-  %1 = llvm.atomicrmw fsub %f32_ptr, %f32 monotonic : !llvm.ptr, f32
+  %2 = llvm.atomicrmw fsub %f32_ptr, %f32 monotonic : !llvm.ptr, f32
   // CHECK: atomicrmw fmax ptr %{{.*}}, float %{{.*}} monotonic
-  %2 = llvm.atomicrmw fmax %f32_ptr, %f32 monotonic : !llvm.ptr, f32
+  %3 = llvm.atomicrmw fmax %f32_ptr, %f32 monotonic : !llvm.ptr, f32
   // CHECK: atomicrmw fmin ptr %{{.*}}, float %{{.*}} monotonic
-  %3 = llvm.atomicrmw fmin %f32_ptr, %f32 monotonic : !llvm.ptr, f32
+  %4 = llvm.atomicrmw fmin %f32_ptr, %f32 monotonic : !llvm.ptr, f32
   // CHECK: atomicrmw xchg ptr %{{.*}}, float %{{.*}} monotonic
-  %4 = llvm.atomicrmw xchg %f32_ptr, %f32 monotonic : !llvm.ptr, f32
+  %5 = llvm.atomicrmw xchg %f32_ptr, %f32 monotonic : !llvm.ptr, f32
   // CHECK: atomicrmw add ptr %{{.*}}, i32 %{{.*}} acquire
-  %5 = llvm.atomicrmw add %i32_ptr, %i32 acquire : !llvm.ptr, i32
+  %6 = llvm.atomicrmw add %i32_ptr, %i32 acquire : !llvm.ptr, i32
   // CHECK: atomicrmw sub ptr %{{.*}}, i32 %{{.*}} release
-  %6 = llvm.atomicrmw sub %i32_ptr, %i32 release : !llvm.ptr, i32
+  %7 = llvm.atomicrmw sub %i32_ptr, %i32 release : !llvm.ptr, i32
   // CHECK: atomicrmw and ptr %{{.*}}, i32 %{{.*}} acq_rel
-  %7 = llvm.atomicrmw _and %i32_ptr, %i32 acq_rel : !llvm.ptr, i32
+  %8 = llvm.atomicrmw _and %i32_ptr, %i32 acq_rel : !llvm.ptr, i32
   // CHECK: atomicrmw nand ptr %{{.*}}, i32 %{{.*}} seq_cst
-  %8 = llvm.atomicrmw nand %i32_ptr, %i32 seq_cst : !llvm.ptr, i32
+  %9 = llvm.atomicrmw nand %i32_ptr, %i32 seq_cst : !llvm.ptr, i32
   // CHECK: atomicrmw or ptr %{{.*}}, i32 %{{.*}} monotonic
-  %9 = llvm.atomicrmw _or %i32_ptr, %i32 monotonic : !llvm.ptr, i32
+  %10 = llvm.atomicrmw _or %i32_ptr, %i32 monotonic : !llvm.ptr, i32
   // CHECK: atomicrmw xor ptr %{{.*}}, i32 %{{.*}} monotonic
-  %10 = llvm.atomicrmw _xor %i32_ptr, %i32 monotonic : !llvm.ptr, i32
+  %11 = llvm.atomicrmw _xor %i32_ptr, %i32 monotonic : !llvm.ptr, i32
   // CHECK: atomicrmw max ptr %{{.*}}, i32 %{{.*}} monotonic
-  %11 = llvm.atomicrmw max %i32_ptr, %i32 monotonic : !llvm.ptr, i32
+  %12 = llvm.atomicrmw max %i32_ptr, %i32 monotonic : !llvm.ptr, i32
   // CHECK: atomicrmw min ptr %{{.*}}, i32 %{{.*}} monotonic
-  %12 = llvm.atomicrmw min %i32_ptr, %i32 monotonic : !llvm.ptr, i32
+  %13 = llvm.atomicrmw min %i32_ptr, %i32 monotonic : !llvm.ptr, i32
   // CHECK: atomicrmw umax ptr %{{.*}}, i32 %{{.*}} monotonic
-  %13 = llvm.atomicrmw umax %i32_ptr, %i32 monotonic : !llvm.ptr, i32
+  %14 = llvm.atomicrmw umax %i32_ptr, %i32 monotonic : !llvm.ptr, i32
   // CHECK: atomicrmw umin ptr %{{.*}}, i32 %{{.*}} monotonic
-  %14 = llvm.atomicrmw umin %i32_ptr, %i32 monotonic : !llvm.ptr, i32
+  %15 = llvm.atomicrmw umin %i32_ptr, %i32 monotonic : !llvm.ptr, i32
   // CHECK: atomicrmw uinc_wrap ptr %{{.*}}, i32 %{{.*}} monotonic
-  %15 = llvm.atomicrmw uinc_wrap %i32_ptr, %i32 monotonic : !llvm.ptr, i32
+  %16 = llvm.atomicrmw uinc_wrap %i32_ptr, %i32 monotonic : !llvm.ptr, i32
   // CHECK: atomicrmw udec_wrap ptr %{{.*}}, i32 %{{.*}} monotonic
-  %16 = llvm.atomicrmw udec_wrap %i32_ptr, %i32 monotonic : !llvm.ptr, i32
+  %17 = llvm.atomicrmw udec_wrap %i32_ptr, %i32 monotonic : !llvm.ptr, i32
   // CHECK: atomicrmw usub_cond ptr %{{.*}}, i32 %{{.*}} monotonic
-  %17 = llvm.atomicrmw usub_cond %i32_ptr, %i32 monotonic : !llvm.ptr, i32
+  %18 = llvm.atomicrmw usub_cond %i32_ptr, %i32 monotonic : !llvm.ptr, i32
   // CHECK: atomicrmw usub_sat ptr %{{.*}}, i32 %{{.*}} monotonic
-  %18 = llvm.atomicrmw usub_sat %i32_ptr, %i32 monotonic : !llvm.ptr, i32
+  %19 = llvm.atomicrmw usub_sat %i32_ptr, %i32 monotonic : !llvm.ptr, i32
   // CHECK: atomicrmw volatile
   // CHECK-SAME: syncscope("singlethread")
   // CHECK-SAME: align 8
-  %19 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32
+  %20 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32
   llvm.return
 }
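
Illustrative usage (not part of the patch): a minimal sketch of the IR form
this change accepts, modeled on the roundtrip and llvmir.mlir tests above.
The function name @fadd_f16_vec and the value names are hypothetical; such a
module would normally be exercised through mlir-translate --mlir-to-llvmir,
the path covered by llvmir.mlir.

  // Sketch: atomic fadd over a fixed vector of 16-bit floats, which the new
  // verifier rule accepts. Any other bin_op with a vector operand, or a
  // non-16-bit element type, is still rejected by AtomicRMWOp::verify().
  llvm.func @fadd_f16_vec(%ptr : !llvm.ptr, %val : vector<2xf16>) -> vector<2xf16> {
    // atomicrmw returns the value stored at %ptr before the update.
    %old = llvm.atomicrmw fadd %ptr, %val monotonic : !llvm.ptr, vector<2xf16>
    llvm.return %old : vector<2xf16>
  }

Passing a vector<3xf32> operand instead, as in the invalid.mlir test, would
trip the "unexpected LLVM IR type for vector element" diagnostic.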