Skip to content

Commit

Permalink
Allow 16 bit floating point operand for LLVM_AtomicRMWOp
Browse files Browse the repository at this point in the history
Since the AMDGPU target supports vectorized atomicrmw operations,
allow construction of LLVM_AtomicRMWOp with vectors of 16-bit floating point values.
This patch enables building LLVM_AtomicRMWOp with fixed vectors of
16-bit fp values as operands.

See also: llvm#94845, llvm#95393, llvm#95394

Signed-off-by: Ilya Veselov <iveselov.nn@gmail.com>
  • Loading branch information
joviliast committed Oct 1, 2024
1 parent 2d3119c commit 3728f65
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 25 deletions.
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ class LLVM_VectorOf<Type element> : Type<
class LLVM_ScalarOrVectorOf<Type element> :
AnyTypeOf<[element, LLVM_VectorOf<element>]>;

// Type constraint accepting an LLVM dialect-compatible fixed vector type whose
// element type additionally satisfies the `element` constraint.
//
// The predicate is the conjunction of `LLVM_AnyFixedVector`'s predicate and
// `element`'s predicate, where every `$_self` leaf of the latter is rewritten
// to `::mlir::LLVM::getVectorElementType($_self)` so that it is evaluated on
// the vector's element type rather than on the vector type itself. The
// constraint summary is derived from `element.summary`.
class LLVM_FixedVectorOf<Type element> : Type<
And<[LLVM_AnyFixedVector.predicate,
SubstLeaves<
"$_self",
"::mlir::LLVM::getVectorElementType($_self)",
element.predicate>]>,
"LLVM dialect-compatible fixed vector of " # element.summary>;

// Base class for LLVM operations. Defines the interface to the llvm::IRBuilder
// used to translate to proper LLVM IR and the interface to the mlir::OpBuilder
// used to import from LLVM IR.
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1737,7 +1737,8 @@ def LLVM_ConstantOp
// Atomic operations.
//

def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger]>;
def LLVM_AtomicRMWType
: AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger, LLVM_FixedVectorOf<LLVM_AnyFloat>]>;

def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [
TypesMatchWith<"result #0 and operand #1 have the same type",
Expand Down
14 changes: 11 additions & 3 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3008,9 +3008,17 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,

LogicalResult AtomicRMWOp::verify() {
auto valType = getVal().getType();
if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) {
if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
if (getBinOp() == AtomicBinOp::fadd && isCompatibleVectorType(valType)) {
Type elemType = getVectorElementType(valType);
// Only 16 bit floating point elements are supported for now.
if (!(isCompatibleFloatingPointType(elemType) &&
elemType.getIntOrFloatBitWidth() == 16))
return emitOpError("unexpected LLVM IR type for vector element");
} else if (getBinOp() == AtomicBinOp::fadd ||
getBinOp() == AtomicBinOp::fsub ||
getBinOp() == AtomicBinOp::fmin ||
getBinOp() == AtomicBinOp::fmax) {
if (!isCompatibleFloatingPointType(valType))
return emitOpError("expected LLVM IR floating point type");
} else if (getBinOp() == AtomicBinOp::xchg) {
DataLayout dataLayout = DataLayout::closest(*this);
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,14 @@ func.func @atomicrmw_expected_float(%i32_ptr : !llvm.ptr, %i32 : i32) {

// -----

// Negative test: `llvm.atomicrmw fadd` with a vector operand whose element
// type is f32 must be rejected with the vector-element diagnostic below.
func.func @atomicrmw_unexpected_vector_element(%ptr : !llvm.ptr, %f32_vec : vector<3xf32>) {
// expected-error@+1 {{unexpected LLVM IR type for vector element}}
%0 = llvm.atomicrmw fadd %ptr, %f32_vec unordered : !llvm.ptr, vector<3xf32>
llvm.return
}

// -----

func.func @atomicrmw_unexpected_xchg_type(%i1_ptr : !llvm.ptr, %i1 : i1) {
// expected-error@+1 {{unexpected LLVM IR type for 'xchg' bin_op}}
%0 = llvm.atomicrmw xchg %i1_ptr, %i1 unordered : !llvm.ptr, i1
Expand Down
6 changes: 4 additions & 2 deletions mlir/test/Dialect/LLVMIR/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,13 @@ func.func @atomic_store(%val : f32, %large_val : i256, %ptr : !llvm.ptr) {
}

// CHECK-LABEL: @atomicrmw
func.func @atomicrmw(%ptr : !llvm.ptr, %val : f32) {
func.func @atomicrmw(%ptr : !llvm.ptr, %val : f32, %f16_vec : vector<2xf16>) {
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} monotonic : !llvm.ptr, f32
%0 = llvm.atomicrmw fadd %ptr, %val monotonic : !llvm.ptr, f32
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} monotonic : !llvm.ptr, vector<2xf16>
%1 = llvm.atomicrmw fadd %ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
// CHECK: llvm.atomicrmw volatile fsub %{{.*}}, %{{.*}} syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
%1 = llvm.atomicrmw volatile fsub %ptr, %val syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
%2 = llvm.atomicrmw volatile fsub %ptr, %val syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
llvm.return
}

Expand Down
41 changes: 22 additions & 19 deletions mlir/test/Target/LLVMIR/llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1496,50 +1496,53 @@ llvm.func @elements_constant_3d_array() -> !llvm.array<2 x array<2 x array<2 x i
// CHECK-LABEL: @atomicrmw
llvm.func @atomicrmw(
%f32_ptr : !llvm.ptr, %f32 : f32,
%f16_vec_ptr : !llvm.ptr, %f16_vec : vector<2xf16>,
%i32_ptr : !llvm.ptr, %i32 : i32) {
// CHECK: atomicrmw fadd ptr %{{.*}}, float %{{.*}} monotonic
%0 = llvm.atomicrmw fadd %f32_ptr, %f32 monotonic : !llvm.ptr, f32
// CHECK: atomicrmw fadd ptr %{{.*}}, <2 x half> %{{.*}} monotonic
%1 = llvm.atomicrmw fadd %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
// CHECK: atomicrmw fsub ptr %{{.*}}, float %{{.*}} monotonic
%1 = llvm.atomicrmw fsub %f32_ptr, %f32 monotonic : !llvm.ptr, f32
%2 = llvm.atomicrmw fsub %f32_ptr, %f32 monotonic : !llvm.ptr, f32
// CHECK: atomicrmw fmax ptr %{{.*}}, float %{{.*}} monotonic
%2 = llvm.atomicrmw fmax %f32_ptr, %f32 monotonic : !llvm.ptr, f32
%3 = llvm.atomicrmw fmax %f32_ptr, %f32 monotonic : !llvm.ptr, f32
// CHECK: atomicrmw fmin ptr %{{.*}}, float %{{.*}} monotonic
%3 = llvm.atomicrmw fmin %f32_ptr, %f32 monotonic : !llvm.ptr, f32
%4 = llvm.atomicrmw fmin %f32_ptr, %f32 monotonic : !llvm.ptr, f32
// CHECK: atomicrmw xchg ptr %{{.*}}, float %{{.*}} monotonic
%4 = llvm.atomicrmw xchg %f32_ptr, %f32 monotonic : !llvm.ptr, f32
%5 = llvm.atomicrmw xchg %f32_ptr, %f32 monotonic : !llvm.ptr, f32
// CHECK: atomicrmw add ptr %{{.*}}, i32 %{{.*}} acquire
%5 = llvm.atomicrmw add %i32_ptr, %i32 acquire : !llvm.ptr, i32
%6 = llvm.atomicrmw add %i32_ptr, %i32 acquire : !llvm.ptr, i32
// CHECK: atomicrmw sub ptr %{{.*}}, i32 %{{.*}} release
%6 = llvm.atomicrmw sub %i32_ptr, %i32 release : !llvm.ptr, i32
%7 = llvm.atomicrmw sub %i32_ptr, %i32 release : !llvm.ptr, i32
// CHECK: atomicrmw and ptr %{{.*}}, i32 %{{.*}} acq_rel
%7 = llvm.atomicrmw _and %i32_ptr, %i32 acq_rel : !llvm.ptr, i32
%8 = llvm.atomicrmw _and %i32_ptr, %i32 acq_rel : !llvm.ptr, i32
// CHECK: atomicrmw nand ptr %{{.*}}, i32 %{{.*}} seq_cst
%8 = llvm.atomicrmw nand %i32_ptr, %i32 seq_cst : !llvm.ptr, i32
%9 = llvm.atomicrmw nand %i32_ptr, %i32 seq_cst : !llvm.ptr, i32
// CHECK: atomicrmw or ptr %{{.*}}, i32 %{{.*}} monotonic
%9 = llvm.atomicrmw _or %i32_ptr, %i32 monotonic : !llvm.ptr, i32
%10 = llvm.atomicrmw _or %i32_ptr, %i32 monotonic : !llvm.ptr, i32
// CHECK: atomicrmw xor ptr %{{.*}}, i32 %{{.*}} monotonic
%10 = llvm.atomicrmw _xor %i32_ptr, %i32 monotonic : !llvm.ptr, i32
%11 = llvm.atomicrmw _xor %i32_ptr, %i32 monotonic : !llvm.ptr, i32
// CHECK: atomicrmw max ptr %{{.*}}, i32 %{{.*}} monotonic
%11 = llvm.atomicrmw max %i32_ptr, %i32 monotonic : !llvm.ptr, i32
%12 = llvm.atomicrmw max %i32_ptr, %i32 monotonic : !llvm.ptr, i32
// CHECK: atomicrmw min ptr %{{.*}}, i32 %{{.*}} monotonic
%12 = llvm.atomicrmw min %i32_ptr, %i32 monotonic : !llvm.ptr, i32
%13 = llvm.atomicrmw min %i32_ptr, %i32 monotonic : !llvm.ptr, i32
// CHECK: atomicrmw umax ptr %{{.*}}, i32 %{{.*}} monotonic
%13 = llvm.atomicrmw umax %i32_ptr, %i32 monotonic : !llvm.ptr, i32
%14 = llvm.atomicrmw umax %i32_ptr, %i32 monotonic : !llvm.ptr, i32
// CHECK: atomicrmw umin ptr %{{.*}}, i32 %{{.*}} monotonic
%14 = llvm.atomicrmw umin %i32_ptr, %i32 monotonic : !llvm.ptr, i32
%15 = llvm.atomicrmw umin %i32_ptr, %i32 monotonic : !llvm.ptr, i32
// CHECK: atomicrmw uinc_wrap ptr %{{.*}}, i32 %{{.*}} monotonic
%15 = llvm.atomicrmw uinc_wrap %i32_ptr, %i32 monotonic : !llvm.ptr, i32
%16 = llvm.atomicrmw uinc_wrap %i32_ptr, %i32 monotonic : !llvm.ptr, i32
// CHECK: atomicrmw udec_wrap ptr %{{.*}}, i32 %{{.*}} monotonic
%16 = llvm.atomicrmw udec_wrap %i32_ptr, %i32 monotonic : !llvm.ptr, i32
%17 = llvm.atomicrmw udec_wrap %i32_ptr, %i32 monotonic : !llvm.ptr, i32
// CHECK: atomicrmw usub_cond ptr %{{.*}}, i32 %{{.*}} monotonic
%17 = llvm.atomicrmw usub_cond %i32_ptr, %i32 monotonic : !llvm.ptr, i32
%18 = llvm.atomicrmw usub_cond %i32_ptr, %i32 monotonic : !llvm.ptr, i32
// CHECK: atomicrmw usub_sat ptr %{{.*}}, i32 %{{.*}} monotonic
%18 = llvm.atomicrmw usub_sat %i32_ptr, %i32 monotonic : !llvm.ptr, i32
%19 = llvm.atomicrmw usub_sat %i32_ptr, %i32 monotonic : !llvm.ptr, i32

// CHECK: atomicrmw volatile
// CHECK-SAME: syncscope("singlethread")
// CHECK-SAME: align 8
%19 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32
%20 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32
llvm.return
}

Expand Down

0 comments on commit 3728f65

Please sign in to comment.