Skip to content

Commit 0c41eea

Browse files
author
Jeff Niu
authored
[mlir][llvm] Port overflowFlags to a native operation property (#89312)
This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on operations as native properties.
1 parent 2132ebf commit 0c41eea

File tree

11 files changed

+165
-109
lines changed

11 files changed

+165
-109
lines changed

mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h

+11-11
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,6 @@ convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
3131
LLVM::IntegerOverflowFlags
3232
convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
3333

34-
/// Creates an LLVM overflow attribute from a given arithmetic overflow
35-
/// attribute.
36-
LLVM::IntegerOverflowFlagsAttr
37-
convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
38-
3934
/// Creates an LLVM rounding mode enum value from a given arithmetic rounding
4035
/// mode enum value.
4136
LLVM::RoundingMode
@@ -72,6 +67,9 @@ class AttrConvertFastMathToLLVM {
7267
}
7368

7469
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
70+
LLVM::IntegerOverflowFlags getOverflowFlags() const {
71+
return LLVM::IntegerOverflowFlags::none;
72+
}
7573

7674
private:
7775
NamedAttrList convertedAttr;
@@ -89,19 +87,18 @@ class AttrConvertOverflowToLLVM {
8987
// Get the name of the arith overflow attribute.
9088
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
9189
// Remove the source overflow attribute.
92-
auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
93-
convertedAttr.erase(arithAttrName));
94-
if (arithAttr) {
95-
StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
96-
convertedAttr.set(targetAttrName,
97-
convertArithOverflowAttrToLLVM(arithAttr));
90+
if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
91+
convertedAttr.erase(arithAttrName))) {
92+
overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
9893
}
9994
}
10095

10196
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
97+
LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }
10298

10399
private:
104100
NamedAttrList convertedAttr;
101+
LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
105102
};
106103

107104
template <typename SourceOp, typename TargetOp>
@@ -132,6 +129,9 @@ class AttrConverterConstrainedFPToLLVM {
132129
}
133130

134131
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
132+
LLVM::IntegerOverflowFlags getOverflowFlags() const {
133+
return LLVM::IntegerOverflowFlags::none;
134+
}
135135

136136
private:
137137
NamedAttrList convertedAttr;

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

+9-5
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,24 @@
1111

1212
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
1313
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14+
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
1415
#include "mlir/Transforms/DialectConversion.h"
1516

1617
namespace mlir {
1718
class CallOpInterface;
1819

1920
namespace LLVM {
2021
namespace detail {
22+
/// Handle generically setting flags as native properties on LLVM operations.
23+
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
24+
2125
/// Replaces the given operation "op" with a new operation of type "targetOp"
2226
/// and given operands.
23-
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
24-
ValueRange operands,
25-
ArrayRef<NamedAttribute> targetAttrs,
26-
const LLVMTypeConverter &typeConverter,
27-
ConversionPatternRewriter &rewriter);
27+
LogicalResult oneToOneRewrite(
28+
Operation *op, StringRef targetOp, ValueRange operands,
29+
ArrayRef<NamedAttribute> targetAttrs,
30+
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
31+
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
2832

2933
} // namespace detail
3034
} // namespace LLVM

mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h

+10-6
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ LogicalResult handleMultidimensionalVectors(
5454
std::function<Value(Type, ValueRange)> createOperand,
5555
ConversionPatternRewriter &rewriter);
5656

57-
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
58-
ValueRange operands,
59-
ArrayRef<NamedAttribute> targetAttrs,
60-
const LLVMTypeConverter &typeConverter,
61-
ConversionPatternRewriter &rewriter);
57+
LogicalResult vectorOneToOneRewrite(
58+
Operation *op, StringRef targetOp, ValueRange operands,
59+
ArrayRef<NamedAttribute> targetAttrs,
60+
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
61+
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
6262
} // namespace detail
6363
} // namespace LLVM
6464

@@ -70,6 +70,9 @@ class AttrConvertPassThrough {
7070
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
7171

7272
ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
73+
LLVM::IntegerOverflowFlags getOverflowFlags() const {
74+
return LLVM::IntegerOverflowFlags::none;
75+
}
7376

7477
private:
7578
ArrayRef<NamedAttribute> srcAttrs;
@@ -100,7 +103,8 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
100103

101104
return LLVM::detail::vectorOneToOneRewrite(
102105
op, TargetOp::getOperationName(), adaptor.getOperands(),
103-
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter);
106+
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
107+
attrConvert.getOverflowFlags());
104108
}
105109
};
106110
} // namespace mlir

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td

+29-47
Original file line numberDiff line numberDiff line change
@@ -50,58 +50,40 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
5050

5151
def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
5252
let description = [{
53-
Access to op integer overflow flags.
53+
This interface defines an LLVM operation with integer overflow flags and
54+
provides a uniform API for accessing them.
5455
}];
5556

5657
let cppNamespace = "::mlir::LLVM";
5758

5859
let methods = [
59-
InterfaceMethod<
60-
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation",
61-
/*returnType=*/ "IntegerOverflowFlagsAttr",
62-
/*methodName=*/ "getOverflowAttr",
63-
/*args=*/ (ins),
64-
/*methodBody=*/ [{}],
65-
/*defaultImpl=*/ [{
66-
auto op = cast<ConcreteOp>(this->getOperation());
67-
return op.getOverflowFlagsAttr();
68-
}]
69-
>,
70-
InterfaceMethod<
71-
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
72-
/*returnType=*/ "bool",
73-
/*methodName=*/ "hasNoUnsignedWrap",
74-
/*args=*/ (ins),
75-
/*methodBody=*/ [{}],
76-
/*defaultImpl=*/ [{
77-
auto op = cast<ConcreteOp>(this->getOperation());
78-
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
79-
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
80-
}]
81-
>,
82-
InterfaceMethod<
83-
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
84-
/*returnType=*/ "bool",
85-
/*methodName=*/ "hasNoSignedWrap",
86-
/*args=*/ (ins),
87-
/*methodBody=*/ [{}],
88-
/*defaultImpl=*/ [{
89-
auto op = cast<ConcreteOp>(this->getOperation());
90-
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
91-
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
92-
}]
93-
>,
94-
StaticInterfaceMethod<
95-
/*desc=*/ [{Returns the name of the IntegerOverflowFlagsAttr attribute
96-
for the operation}],
97-
/*returnType=*/ "StringRef",
98-
/*methodName=*/ "getIntegerOverflowAttrName",
99-
/*args=*/ (ins),
100-
/*methodBody=*/ [{}],
101-
/*defaultImpl=*/ [{
102-
return "overflowFlags";
103-
}]
104-
>
60+
InterfaceMethod<[{
61+
Get the integer overflow flags for the operation.
62+
}], "IntegerOverflowFlags", "getOverflowFlags", (ins), [{}], [{
63+
return $_op.getProperties().overflowFlags;
64+
}]>,
65+
InterfaceMethod<[{
66+
Set the integer overflow flags for the operation.
67+
}], "void", "setOverflowFlags", (ins "IntegerOverflowFlags":$flags), [{}], [{
68+
$_op.getProperties().overflowFlags = flags;
69+
}]>,
70+
InterfaceMethod<[{
71+
Returns whether the operation has the No Unsigned Wrap keyword.
72+
}], "bool", "hasNoUnsignedWrap", (ins), [{}], [{
73+
return bitEnumContainsAll($_op.getOverflowFlags(),
74+
IntegerOverflowFlags::nuw);
75+
}]>,
76+
InterfaceMethod<[{
77+
Returns whether the operation has the No Signed Wrap keyword.
78+
}], "bool", "hasNoSignedWrap", (ins), [{}], [{
79+
return bitEnumContainsAll($_op.getOverflowFlags(),
80+
IntegerOverflowFlags::nsw);
81+
}]>,
82+
StaticInterfaceMethod<[{
83+
Get the attribute name of the overflow flags property.
84+
}], "StringRef", "getOverflowFlagsAttrName", (ins), [{}], [{
85+
return "overflowFlags";
86+
}]>,
10587
];
10688
}
10789

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

+4-4
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,16 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
6060
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
6161
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
6262
dag iofArg = (
63-
ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
63+
ins EnumProperty<"IntegerOverflowFlags">:$overflowFlags);
6464
let arguments = !con(commonArgs, iofArg);
6565
string mlirBuilder = [{
6666
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
67-
moduleImport.setIntegerOverflowFlagsAttr(inst, op);
67+
moduleImport.setIntegerOverflowFlags(inst, op);
6868
$res = op;
6969
}];
7070
let assemblyFormat = [{
71-
$lhs `,` $rhs (`overflow` `` $overflowFlags^)?
72-
custom<LLVMOpAttrs>(attr-dict) `:` type($res)
71+
$lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags)
72+
`` custom<LLVMOpAttrs>(attr-dict) `:` type($res)
7373
}];
7474
string llvmBuilder =
7575
"$res = builder.Create" # instName #

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,7 @@ class ModuleImport {
183183
/// Sets the integer overflow flags (nsw/nuw) attribute for the imported
184184
/// operation `op` given the original instruction `inst`. Asserts if the
185185
/// operation does not implement the integer overflow flag interface.
186-
void setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
187-
Operation *op) const;
186+
void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;
188187

189188
/// Sets the fastmath flags attribute for the imported operation `op` given
190189
/// the original instruction `inst`. Asserts if the operation does not

mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp

-7
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,6 @@ LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
4949
return llvmFlags;
5050
}
5151

52-
LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
53-
arith::IntegerOverflowFlagsAttr flagsAttr) {
54-
arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
55-
return LLVM::IntegerOverflowFlagsAttr::get(
56-
flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
57-
}
58-
5952
LLVM::RoundingMode
6053
mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
6154
switch (roundingMode) {

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

+13-6
Original file line numberDiff line numberDiff line change
@@ -329,14 +329,19 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
329329
// Detail methods
330330
//===----------------------------------------------------------------------===//
331331

332+
void LLVM::detail::setNativeProperties(Operation *op,
333+
IntegerOverflowFlags overflowFlags) {
334+
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
335+
iface.setOverflowFlags(overflowFlags);
336+
}
337+
332338
/// Replaces the given operation "op" with a new operation of type "targetOp"
333339
/// and given operands.
334-
LogicalResult
335-
LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
336-
ValueRange operands,
337-
ArrayRef<NamedAttribute> targetAttrs,
338-
const LLVMTypeConverter &typeConverter,
339-
ConversionPatternRewriter &rewriter) {
340+
LogicalResult LLVM::detail::oneToOneRewrite(
341+
Operation *op, StringRef targetOp, ValueRange operands,
342+
ArrayRef<NamedAttribute> targetAttrs,
343+
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
344+
IntegerOverflowFlags overflowFlags) {
340345
unsigned numResults = op->getNumResults();
341346

342347
SmallVector<Type> resultTypes;
@@ -352,6 +357,8 @@ LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
352357
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
353358
resultTypes, targetAttrs);
354359

360+
setNativeProperties(newOp, overflowFlags);
361+
355362
// If the operation produced 0 or 1 result, return them immediately.
356363
if (numResults == 0)
357364
return rewriter.eraseOp(op), success();

mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp

+13-13
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
103103
return success();
104104
}
105105

106-
LogicalResult
107-
LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
108-
ValueRange operands,
109-
ArrayRef<NamedAttribute> targetAttrs,
110-
const LLVMTypeConverter &typeConverter,
111-
ConversionPatternRewriter &rewriter) {
106+
LogicalResult LLVM::detail::vectorOneToOneRewrite(
107+
Operation *op, StringRef targetOp, ValueRange operands,
108+
ArrayRef<NamedAttribute> targetAttrs,
109+
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
110+
IntegerOverflowFlags overflowFlags) {
112111
assert(!operands.empty());
113112

114113
// Cannot convert ops if their operands are not of LLVM type.
@@ -118,14 +117,15 @@ LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
118117
auto llvmNDVectorTy = operands[0].getType();
119118
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
120119
return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
121-
rewriter);
120+
rewriter, overflowFlags);
122121

123-
auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy,
124-
ValueRange operands) {
125-
return rewriter
126-
.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
127-
llvm1DVectorTy, targetAttrs)
128-
->getResult(0);
122+
auto callback = [op, targetOp, targetAttrs, overflowFlags,
123+
&rewriter](Type llvm1DVectorTy, ValueRange operands) {
124+
Operation *newOp =
125+
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
126+
operands, llvm1DVectorTy, targetAttrs);
127+
LLVM::detail::setNativeProperties(newOp, overflowFlags);
128+
return newOp->getResult(0);
129129
};
130130

131131
return handleMultidimensionalVectors(op, operands, typeConverter, callback,

0 commit comments

Comments
 (0)