Skip to content

Commit

Permalink
Merge pull request #4 from tpoterba/tp-mlir-tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
tpoterba authored Nov 8, 2022
2 parents ebe1395 + f934a20 commit 38f6f1d
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 23 deletions.
1 change: 1 addition & 0 deletions query/include/Dialect/Sandbox/IR/SandboxBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def Sandbox_CmpPredicateAttr : I64EnumAttr<
I64EnumAttrCase<"GT", 2, "gt">,
I64EnumAttrCase<"GTEQ", 3, "gteq">,
I64EnumAttrCase<"EQ", 4, "eq">,
I64EnumAttrCase<"NEQ", 5, "eq">,
]> {
}

Expand Down
30 changes: 15 additions & 15 deletions query/include/Dialect/Sandbox/IR/SandboxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,26 @@

include "Dialect/Sandbox/IR/SandboxBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

include "mlir/IR/OpBase.td"

class Sandbox_Op<string mnemonic, list<Trait> traits = []> :
Op<Sandbox_Dialect, mnemonic, traits>;

def Sandbox_ConstantOp : Sandbox_Op<"constant", [ConstantLike, NoSideEffect]> {
let arguments = (ins SI32Attr:$value);
let results = (outs Sandbox_Int:$output);
let assemblyFormat = "$value attr-dict `:` type($output)";
let arguments = (ins AnyAttr:$value);
let results = (outs AnyType:$output);
let assemblyFormat = "`(`$value`)` attr-dict `:` type($output)";

let builders = [
OpBuilder<(ins "::mlir::Attribute":$value, "::mlir::Type":$type),
[{ build($_builder, $_state, type, value); }]>,
];

let hasFolder = true;
let hasVerifier = true;
}


def Sandbox_AddIOp : Sandbox_Op<"addi", [NoSideEffect, Commutative]> {
let arguments = (ins Sandbox_Int:$lhs, Sandbox_Int:$rhs);
let results = (outs Sandbox_Int:$result);
Expand All @@ -25,14 +32,7 @@ def Sandbox_AddIOp : Sandbox_Op<"addi", [NoSideEffect, Commutative]> {
let hasCanonicalizer = true;
}

def BoolOp : Sandbox_Op<"bool", [ConstantLike, NoSideEffect]> {
let arguments = (ins BoolAttr:$value);
let results = (outs Sandbox_Bool:$output);
let assemblyFormat = "$value attr-dict `:` type($output)";
}


def ComparisonOp : Sandbox_Op<"compare", [NoSideEffect]> {
def Sandbox_ComparisonOp : Sandbox_Op<"compare", [NoSideEffect]> {
let arguments = (ins Sandbox_CmpPredicateAttr:$predicate, Sandbox_Int:$lhs, Sandbox_Int:$rhs);
let results = (outs Sandbox_Bool:$output);
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($output)";
Expand All @@ -41,13 +41,13 @@ def ComparisonOp : Sandbox_Op<"compare", [NoSideEffect]> {
static CmpPredicate getPredicateByName(llvm::StringRef name);
}];

let hasFolder = 0;
let hasFolder = 1;
let hasCanonicalizer = 0;
}

def Sandbox_PrintOp : Sandbox_Op<"print"> {
let arguments = (ins Sandbox_Int:$value);
let assemblyFormat = "$value attr-dict";
let arguments = (ins AnyType:$value);
let assemblyFormat = "$value attr-dict `:` type($value)";
}

#endif // DIALECT_SANDBOX_SANDBOXOPS
59 changes: 57 additions & 2 deletions query/lib/Conversion/LowerSandbox/LowerSandbox.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,58 @@ struct ConstantOpConversion : public mlir::OpConversionPattern<ir::ConstantOp> {
}
};

struct ComparisonOpConversion : public mlir::OpConversionPattern<ir::ComparisonOp> {
ComparisonOpConversion(mlir::MLIRContext *context)
: OpConversionPattern<ir::ComparisonOp>(context, /*benefit=*/1) {}

mlir::LogicalResult
matchAndRewrite(ir::ComparisonOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto lhs = castToMLIR(rewriter, op->getLoc(), adaptor.lhs());
auto rhs = castToMLIR(rewriter, op->getLoc(), adaptor.rhs());

mlir::arith::CmpIPredicate pred;

switch(adaptor.predicate()) {
case CmpPredicate::LT:
pred = mlir::arith::CmpIPredicate::slt;
break;
case CmpPredicate::LTEQ:
pred = mlir::arith::CmpIPredicate::sle;
break;
case CmpPredicate::GT:
pred = mlir::arith::CmpIPredicate::sgt;
break;
case CmpPredicate::GTEQ:
pred = mlir::arith::CmpIPredicate::sge;
break;
case CmpPredicate::EQ:
pred = mlir::arith::CmpIPredicate::eq;
break;
case CmpPredicate::NEQ:
pred = mlir::arith::CmpIPredicate::ne;
break;
}

rewriter.replaceOpWithNewOp<mlir::arith::CmpIOp>(op,pred, lhs, rhs);
return mlir::success();
}
};

struct PrintOpConversion : public mlir::OpConversionPattern<ir::PrintOp> {
PrintOpConversion(mlir::MLIRContext *context)
: OpConversionPattern<ir::PrintOp>(context, /*benefit=*/1) {}

mlir::LogicalResult
matchAndRewrite(ir::PrintOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto value2 = castToMLIR(rewriter, op->getLoc(), adaptor.value());
rewriter.replaceOpWithNewOp<ir::PrintOp>(op, value2);
return mlir::success();
}
};


} // end namespace

void LowerSandboxPass::runOnOperation() {
Expand All @@ -57,7 +109,7 @@ void LowerSandboxPass::runOnOperation() {

// Configure conversion to lower out SCF operations.
mlir::ConversionTarget target(getContext());
target.addIllegalOp<ir::ConstantOp, ir::AddIOp>();
target.addIllegalOp<ir::ConstantOp, ir::AddIOp, ir::ComparisonOp>();
target.markUnknownOpDynamicallyLegal([](mlir::Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
Expand All @@ -66,7 +118,10 @@ void LowerSandboxPass::runOnOperation() {

void populateLowerSandboxConversionPatterns(
mlir::RewritePatternSet &patterns) {
patterns.add<ConstantOpConversion, AddIOpConversion>(patterns.getContext());
patterns.add<
ConstantOpConversion,
AddIOpConversion,
ComparisonOpConversion>(patterns.getContext());
}

std::unique_ptr<mlir::Pass> createLowerSandboxPass() {
Expand Down
66 changes: 65 additions & 1 deletion query/lib/Dialect/Sandbox/IR/SandboxOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APSInt.h"
#include <cstddef>

#include "Dialect/Sandbox/IR/SandboxOpsEnums.cpp.inc"
Expand All @@ -20,6 +22,23 @@ mlir::OpFoldResult ConstantOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
return valueAttr();
}

mlir::LogicalResult ConstantOp::verify() {
auto type = getType();
// The value's type must match the return type.
auto valueType = value().getType();

if (valueType.isa<mlir::IntegerType>() && type.isa<IntType>()) {
return mlir::success();
}

if (valueType == mlir::IntegerType::get(getContext(), 1) && type.isa<BooleanType>()) {
return mlir::success();
}

return emitOpError() << "bad constant: type=" << type << ", valueType=" << valueType;
}


mlir::OpFoldResult AddIOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
assert(operands.size() == 2 && "binary op takes two operands");
if (!operands[0] || !operands[1])
Expand All @@ -37,6 +56,50 @@ mlir::OpFoldResult AddIOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
return {};
}


mlir::OpFoldResult ComparisonOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
assert(operands.size() == 2 && "comparison op takes two operands");
if (!operands[0] || !operands[1])
return {};


if (operands[0].isa<mlir::IntegerAttr>() && operands[1].isa<mlir::IntegerAttr>()) {
auto pred = predicate();
auto lhs = operands[0].cast<mlir::IntegerAttr>();
auto rhs = operands[1].cast<mlir::IntegerAttr>();


bool x;
auto l = lhs.getValue();
auto r = rhs.getValue();
switch(pred) {
case CmpPredicate::LT:
x = l.slt(r);
break;
case CmpPredicate::LTEQ:
x = l.sle(r);
break;
case CmpPredicate::GT:
x = l.sgt(r);
break;
case CmpPredicate::GTEQ:
x = l.sge(r);
break;
case CmpPredicate::EQ:
x = l == r;
break;
case CmpPredicate::NEQ:
x = l != r;
break;
}

auto result = mlir::BoolAttr::get(getContext(), x);
return result;
}

return {};
}

struct SimplifyAddConstAddConst : public mlir::OpRewritePattern<AddIOp> {
SimplifyAddConstAddConst(mlir::MLIRContext *context)
: OpRewritePattern<AddIOp>(context, /*benefit=*/1) {}
Expand All @@ -51,7 +114,8 @@ struct SimplifyAddConstAddConst : public mlir::OpRewritePattern<AddIOp> {

auto sumConst = rewriter.create<ConstantOp>(
mlir::FusedLoc::get(op->getContext(), {lConst->getLoc(), rConst.getLoc()}, nullptr),
lConst.getType(), lConst.value() + rConst.value());
lConst.getType(),
mlir::IntegerAttr::get(lConst.getType(), lConst.value().cast<mlir::IntegerAttr>().getValue() + rConst.value().cast<mlir::IntegerAttr>().getValue()));
rewriter.replaceOpWithNewOp<AddIOp>(op, lhs.result().getType(), lhs.lhs(), sumConst);
return mlir::success();
}
Expand Down
17 changes: 12 additions & 5 deletions query/test/current.mlir
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
func.func @foo(%in: !sb.int) -> (!sb.bool) {
%i1 = sb.constant 5 : !sb.int
%i2 = sb.constant 7 : !sb.int
func.func @foo(%in: !sb.int) -> () {
%i1 = sb.constant (5) : !sb.int
%i2 = sb.constant (7) : !sb.int
%i3 = sb.addi %in %i1
%i4 = sb.addi %i3 %i2
%i5 = sb.constant 6 : !sb.int
%i5 = sb.constant (6) : !sb.int
%i6 = sb.compare eq, %i4, %i5 : !sb.bool
func.return %i6 : !sb.bool
sb.print %i6 : !sb.bool
func.return
}

func.func @bar() -> () {
%i1 = sb.constant (true) : !sb.bool
sb.print %i1 : !sb.bool
func.return
}

0 comments on commit 38f6f1d

Please sign in to comment.