From 6cb4a43a4cdbf0f1a84e9f8279dc33185d5d8850 Mon Sep 17 00:00:00 2001 From: Ilya Veselov Date: Wed, 25 Sep 2024 19:38:29 +0200 Subject: [PATCH] Allow 16 bit floating point operand for LLVM_AtomicRMWOp As far as AMDGPU target supports vectorization for atomic_rmw operation, allow construction of LLVM_AtomicRMWOp with 16 bit floating point values. See also: #94845, #95393, #95394 Signed-off-by: Ilya Veselov --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 3 ++- mlir/include/mlir/IR/BuiltinTypes.h | 2 +- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 16 +++++++++++++--- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 2 +- mlir/test/Dialect/LLVMIR/invalid.mlir | 8 ++++++++ 5 files changed, 25 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 030160821bd823..615c0a39f3acd0 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1737,7 +1737,8 @@ def LLVM_ConstantOp // Atomic operations. // -def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger]>; +def LLVM_AtomicRMWType + : AnyTypeOf<[LLVM_AnyPointer, AnySignlessInteger, LLVM_ScalarOrVectorOf]>; def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [ TypesMatchWith<"result #0 and operand #1 have the same type", diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 91e68b4066dd67..fc84401d5d6a4a 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -327,8 +327,8 @@ class VectorType::Builder { Builder &setShape(ArrayRef newShape, ArrayRef newIsScalableDim = {}) { - shape = newShape; scalableDims = newIsScalableDim; + shape = newShape; return *this; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 0561c364c7d591..99b3dc79fda664 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -3008,9 +3008,19 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, LogicalResult AtomicRMWOp::verify() { auto valType = getVal().getType(); - if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub || - getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) { - if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) + if (getBinOp() == AtomicBinOp::fadd && isCompatibleVectorType(valType)) { + // Currently, only fadd operation supports fixed vector operands. + if (isScalableVectorType(valType)) + return emitOpError("expected LLVM IR fixed vector type"); + Type elemType = getVectorElementType(valType); + if (!(isCompatibleFloatingPointType(elemType) && + elemType.getIntOrFloatBitWidth() == 16)) + return emitOpError("unexpected LLVM IR type for vector element"); + } else if (getBinOp() == AtomicBinOp::fadd || + getBinOp() == AtomicBinOp::fsub || + getBinOp() == AtomicBinOp::fmin || + getBinOp() == AtomicBinOp::fmax) { + if (!isCompatibleFloatingPointType(valType)) return emitOpError("expected LLVM IR floating point type"); } else if (getBinOp() == AtomicBinOp::xchg) { DataLayout dataLayout = DataLayout::closest(*this); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index 7f10a15ff31ff9..baac02d82a25d7 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -663,7 +663,7 @@ verifyVectorConstructionInvariants(function_ref emitError, if (numElements == 0) return emitError() << "the number of vector elements must be positive"; - if (!VecTy::isValidElementType(elementType)) + if (!VecTy::isValidElementType(elementType) ^ VectorType::isValidElementType(elementType)) return emitError() << "invalid vector element type"; return success(); diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 9388d7ef24936e..51769bc04437b0 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -643,6 +643,14 @@ func.func @atomicrmw_expected_float(%i32_ptr : !llvm.ptr, %i32 : i32) { // ----- +func.func @atomicrmw_unexpected_vector_element(%i32_ptr : !llvm.ptr, %i16_fvec : vector<[3]xi16>) { + // expected-error@+1 {{unexpected LLVM IR type for vector element}} + %0 = llvm.atomicrmw fadd %i32_ptr, %i16_fvec unordered : !llvm.ptr, i32 + llvm.return +} + +// ----- + func.func @atomicrmw_unexpected_xchg_type(%i1_ptr : !llvm.ptr, %i1 : i1) { // expected-error@+1 {{unexpected LLVM IR type for 'xchg' bin_op}} %0 = llvm.atomicrmw xchg %i1_ptr, %i1 unordered : !llvm.ptr, i1