Skip to content

Commit

Permalink
[mlir][llvm] Improve LoadOp and StoreOp import.
Browse files Browse the repository at this point in the history
The revision supports importing the volatile keyword and nontemporal
metadata for the LoadOp and StoreOp. Additionally, it updates the
builders and uses an assembly format for printing and parsing.

The operation type still requires custom parse and print methods
due to the current handling of typed and opaque pointers.

Reviewed By: Dinistro

Differential Revision: https://reviews.llvm.org/D143714
  • Loading branch information
gysit committed Feb 13, 2023
1 parent 0fad18c commit 240c6f2
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 112 deletions.
33 changes: 23 additions & 10 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,10 @@ class MemoryOpBase {
}];
code setNonTemporalMetadataCode = [{
if ($nontemporal) {
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::MDNode *metadata = llvm::MDNode::get(
inst->getContext(), llvm::ConstantAsMetadata::get(
builder.getInt32(1)));
inst->setMetadata(module->getMDKindID("nontemporal"), metadata);
inst->setMetadata(llvm::LLVMContext::MD_nontemporal, metadata);
}
}];
code setAccessGroupsMetadataCode = [{
Expand Down Expand Up @@ -355,6 +354,10 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpBase {
UnitAttr:$nontemporal);
let results = (outs LLVM_LoadableType:$res);
string llvmInstName = "Load";
let assemblyFormat = [{
(`volatile` $volatile_^)? $addr attr-dict `:`
custom<LoadType>(type($addr), type($res))
}];
string llvmBuilder = [{
auto *inst = builder.CreateLoad($_resultType, $addr, $volatile_);
}] # setAlignmentCode
Expand All @@ -365,9 +368,12 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpBase {
# [{
$res = inst;
}];
// FIXME: Import attributes.
string mlirBuilder = [{
$res = $_builder.create<LLVM::LoadOp>($_location, $_resultType, $addr);
auto *loadInst = cast<llvm::LoadInst>(inst);
unsigned alignment = loadInst->getAlign().value();
$res = $_builder.create<LLVM::LoadOp>($_location, $_resultType, $addr,
alignment, loadInst->isVolatile(),
loadInst->hasMetadata(llvm::LLVMContext::MD_nontemporal));
}];
let builders = [
OpBuilder<(ins "Value":$addr, CArg<"unsigned", "0">:$alignment,
Expand All @@ -378,9 +384,10 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpBase {
"when the pointer type is opaque");
build($_builder, $_state, type, addr, alignment, isVolatile, isNonTemporal);
}]>,
OpBuilder<(ins "Type":$t, "Value":$addr,
OpBuilder<(ins "Type":$type, "Value":$addr,
CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
CArg<"bool", "false">:$isNonTemporal)>,];
CArg<"bool", "false">:$isNonTemporal)>
];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
Expand All @@ -395,23 +402,29 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpBase {
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
UnitAttr:$nontemporal);
string llvmInstName = "Store";
let assemblyFormat = [{
(`volatile` $volatile_^)? $value `,` $addr attr-dict `:`
custom<StoreType>(type($value), type($addr))
}];
string llvmBuilder = [{
auto *inst = builder.CreateStore($value, $addr, $volatile_);
}] # setAlignmentCode
# setNonTemporalMetadataCode
# setAccessGroupsMetadataCode
# setAliasScopeMetadataCode
# setTBAAMetadataCode;
// FIXME: Import attributes.
string mlirBuilder = [{
$_op = $_builder.create<LLVM::StoreOp>($_location, $value, $addr);
auto *storeInst = cast<llvm::StoreInst>(inst);
unsigned alignment = storeInst->getAlign().value();
$_op = $_builder.create<LLVM::StoreOp>($_location, $value, $addr,
alignment, storeInst->isVolatile(),
storeInst->hasMetadata(llvm::LLVMContext::MD_nontemporal));
}];
let builders = [
OpBuilder<(ins "Value":$value, "Value":$addr,
CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
CArg<"bool", "false">:$isNonTemporal)>
];
let hasCustomAssemblyFormat = 1;
];
let hasVerifier = 1;
}

Expand Down
150 changes: 56 additions & 94 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ using mlir::LLVM::linkage::getMaxEnumValForLinkage;

#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"

static constexpr const char kVolatileAttrName[] = "volatile_";
static constexpr const char kNonTemporalAttrName[] = "nontemporal";
static constexpr const char kElemTypeAttrName[] = "elem_type";

#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
Expand Down Expand Up @@ -665,7 +663,7 @@ Type LLVM::GEPOp::getSourceElementType() {
}

//===----------------------------------------------------------------------===//
// Builder, printer and parser for for LLVM::LoadOp.
// LoadOp
//===----------------------------------------------------------------------===//

/// Verifies the given array attribute contains symbol references and checks the
Expand Down Expand Up @@ -759,29 +757,13 @@ LogicalResult verifyMemOpMetadata(OpTy memOp) {

LogicalResult LoadOp::verify() { return verifyMemOpMetadata(*this); }

void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
Value addr, unsigned alignment, bool isVolatile,
bool isNonTemporal) {
result.addOperands(addr);
result.addTypes(t);
if (isVolatile)
result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
if (isNonTemporal)
result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
if (alignment != 0)
result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
}

void LoadOp::print(OpAsmPrinter &p) {
p << ' ';
if (getVolatile_())
p << "volatile ";
p << getAddr();
p.printOptionalAttrDict((*this)->getAttrs(),
{kVolatileAttrName, kElemTypeAttrName});
p << " : " << getAddr().getType();
if (getAddr().getType().cast<LLVMPointerType>().isOpaque())
p << " -> " << getType();
build(builder, state, type, addr, /*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
isNonTemporal);
}

// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return
Expand All @@ -797,105 +779,85 @@ getLoadStoreElementType(OpAsmParser &parser, Type type, SMLoc trailingTypeLoc) {
return llvmTy.getElementType();
}

// <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
// (`->` type)?
ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::UnresolvedOperand addr;
Type type;
/// Parses the LoadOp type either using the typed or opaque pointer format.
// TODO: Drop once the typed pointer assembly format is not needed anymore.
static ParseResult parseLoadType(OpAsmParser &parser, Type &type,
Type &elementType) {
SMLoc trailingTypeLoc;

if (succeeded(parser.parseOptionalKeyword("volatile")))
result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());

if (parser.parseOperand(addr) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
parser.resolveOperand(addr, type, result.operands))
if (parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
return failure();

std::optional<Type> elemTy =
std::optional<Type> pointerElementType =
getLoadStoreElementType(parser, type, trailingTypeLoc);
if (!elemTy)
if (!pointerElementType)
return failure();
if (*elemTy) {
result.addTypes(*elemTy);
if (*pointerElementType) {
elementType = *pointerElementType;
return success();
}

Type trailingType;
if (parser.parseArrow() || parser.parseType(trailingType))
if (parser.parseArrow() || parser.parseType(elementType))
return failure();
result.addTypes(trailingType);
return success();
}

/// Prints the LoadOp type either using the typed or opaque pointer format.
// TODO: Drop once the typed pointer assembly format is not needed anymore.
static void printLoadType(OpAsmPrinter &printer, Operation *op, Type type,
Type elementType) {
printer << type;
auto pointerType = cast<LLVMPointerType>(type);
if (pointerType.isOpaque())
printer << " -> " << elementType;
}

//===----------------------------------------------------------------------===//
// Builder, printer and parser for LLVM::StoreOp.
// StoreOp
//===----------------------------------------------------------------------===//

LogicalResult StoreOp::verify() { return verifyMemOpMetadata(*this); }

void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
Value addr, unsigned alignment, bool isVolatile,
bool isNonTemporal) {
result.addOperands({value, addr});
result.addTypes({});
if (isVolatile)
result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
if (isNonTemporal)
result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
if (alignment != 0)
result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
build(builder, state, value, addr, /*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
isNonTemporal);
}

void StoreOp::print(OpAsmPrinter &p) {
p << ' ';
if (getVolatile_())
p << "volatile ";
p << getValue() << ", " << getAddr();
p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName});
p << " : ";
if (getAddr().getType().cast<LLVMPointerType>().isOpaque())
p << getValue().getType() << ", ";
p << getAddr().getType();
}

// <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
// attribute-dict? `:` type (`,` type)?
ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::UnresolvedOperand addr, value;
Type type;
/// Parses the StoreOp type either using the typed or opaque pointer format.
// TODO: Drop once the typed pointer assembly format is not needed anymore.
static ParseResult parseStoreType(OpAsmParser &parser, Type &elementType,
Type &type) {
SMLoc trailingTypeLoc;

if (succeeded(parser.parseOptionalKeyword("volatile")))
result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());

if (parser.parseOperand(value) || parser.parseComma() ||
parser.parseOperand(addr) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
if (parser.getCurrentLocation(&trailingTypeLoc) ||
parser.parseType(elementType))
return failure();

Type operandType;
if (succeeded(parser.parseOptionalComma())) {
operandType = type;
if (parser.parseType(type))
return failure();
} else {
std::optional<Type> maybeOperandType =
getLoadStoreElementType(parser, type, trailingTypeLoc);
if (!maybeOperandType)
return failure();
operandType = *maybeOperandType;
}
if (succeeded(parser.parseOptionalComma()))
return parser.parseType(type);

if (parser.resolveOperand(value, operandType, result.operands) ||
parser.resolveOperand(addr, type, result.operands))
// Extract the element type from the pointer type.
type = elementType;
std::optional<Type> pointerElementType =
getLoadStoreElementType(parser, type, trailingTypeLoc);
if (!pointerElementType)
return failure();

elementType = *pointerElementType;
return success();
}

/// Prints the StoreOp type either using the typed or opaque pointer format.
// TODO: Drop once the typed pointer assembly format is not needed anymore.
static void printStoreType(OpAsmPrinter &printer, Operation *op,
Type elementType, Type type) {
auto pointerType = cast<LLVMPointerType>(type);
if (pointerType.isOpaque())
printer << elementType << ", ";
printer << type;
}

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 11 additions & 4 deletions mlir/test/Target/LLVMIR/Import/instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ define void @integer_arith(i32 %arg1, i32 %arg2, i64 %arg3, i64 %arg4) {
; CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
; CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
define half @extract_element(ptr %vec, i32 %idx) {
; CHECK: %[[V1:.+]] = llvm.load %[[VEC]] : !llvm.ptr -> vector<4xf16>
; CHECK: %[[V1:.+]] = llvm.load %[[VEC]] {{.*}} : !llvm.ptr -> vector<4xf16>
; CHECK: %[[V2:.+]] = llvm.extractelement %[[V1]][%[[IDX]] : i32] : vector<4xf16>
; CHECK: llvm.return %[[V2]]
%1 = load <4 x half>, ptr %vec
Expand All @@ -266,7 +266,7 @@ define half @extract_element(ptr %vec, i32 %idx) {
; CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]]
; CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
define <4 x half> @insert_element(ptr %vec, half %val, i32 %idx) {
; CHECK: %[[V1:.+]] = llvm.load %[[VEC]] : !llvm.ptr -> vector<4xf16>
; CHECK: %[[V1:.+]] = llvm.load %[[VEC]] {{.*}} : !llvm.ptr -> vector<4xf16>
; CHECK: %[[V2:.+]] = llvm.insertelement %[[VAL]], %[[V1]][%[[IDX]] : i32] : vector<4xf16>
; CHECK: llvm.return %[[V2]]
%1 = load <4 x half>, ptr %vec
Expand Down Expand Up @@ -352,13 +352,20 @@ define ptr @alloca(i64 %size) {
; CHECK-LABEL: @load_store
; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]]
define void @load_store(ptr %ptr) {
; CHECK: %[[V1:[0-9]+]] = llvm.load %[[PTR]] : !llvm.ptr -> f64
; CHECK: llvm.store %[[V1]], %[[PTR]] : f64, !llvm.ptr
; CHECK: %[[V1:[0-9]+]] = llvm.load %[[PTR]] {alignment = 8 : i64} : !llvm.ptr -> f64
; CHECK: %[[V2:[0-9]+]] = llvm.load volatile %[[PTR]] {alignment = 16 : i64, nontemporal} : !llvm.ptr -> f64
%1 = load double, ptr %ptr
%2 = load volatile double, ptr %ptr, align 16, !nontemporal !0

; CHECK: llvm.store %[[V1]], %[[PTR]] {alignment = 8 : i64} : f64, !llvm.ptr
; CHECK: llvm.store volatile %[[V2]], %[[PTR]] {alignment = 16 : i64, nontemporal} : f64, !llvm.ptr
store double %1, ptr %ptr
store volatile double %2, ptr %ptr, align 16, !nontemporal !0
ret void
}

!0 = !{i32 1}

; // -----

; CHECK-LABEL: @atomic_rmw
Expand Down
7 changes: 3 additions & 4 deletions mlir/test/Target/LLVMIR/Import/metadata-loop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
; CHECK: }

; CHECK-LABEL: llvm.func @access_group
; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
define void @access_group(ptr %arg1) {
; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP0]], @__llvm_global_metadata::@[[$GROUP1]]]}
; CHECK: access_groups = [@__llvm_global_metadata::@[[$GROUP0]], @__llvm_global_metadata::@[[$GROUP1]]]
%1 = load i32, ptr %arg1, !llvm.access.group !0
; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP2]], @__llvm_global_metadata::@[[$GROUP0]]]}
; CHECK: access_groups = [@__llvm_global_metadata::@[[$GROUP2]], @__llvm_global_metadata::@[[$GROUP0]]]
%2 = load i32, ptr %arg1, !llvm.access.group !1
; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP3]]]}
; CHECK: access_groups = [@__llvm_global_metadata::@[[$GROUP3]]]
%3 = load i32, ptr %arg1, !llvm.access.group !2
ret void
}
Expand Down

0 comments on commit 240c6f2

Please sign in to comment.