Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement const lowering rule? #4

Merged
merged 3 commits into from
Nov 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions query/include/Dialect/Sandbox/IR/SandboxBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def Sandbox_CmpPredicateAttr : I64EnumAttr<
I64EnumAttrCase<"GT", 2, "gt">,
I64EnumAttrCase<"GTEQ", 3, "gteq">,
I64EnumAttrCase<"EQ", 4, "eq">,
I64EnumAttrCase<"NEQ", 5, "eq">,
]> {
}

Expand Down
30 changes: 15 additions & 15 deletions query/include/Dialect/Sandbox/IR/SandboxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,26 @@

include "Dialect/Sandbox/IR/SandboxBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

include "mlir/IR/OpBase.td"

class Sandbox_Op<string mnemonic, list<Trait> traits = []> :
Op<Sandbox_Dialect, mnemonic, traits>;

def Sandbox_ConstantOp : Sandbox_Op<"constant", [ConstantLike, NoSideEffect]> {
let arguments = (ins SI32Attr:$value);
let results = (outs Sandbox_Int:$output);
let assemblyFormat = "$value attr-dict `:` type($output)";
let arguments = (ins AnyAttr:$value);
let results = (outs AnyType:$output);
let assemblyFormat = "`(`$value`)` attr-dict `:` type($output)";

let builders = [
OpBuilder<(ins "::mlir::Attribute":$value, "::mlir::Type":$type),
[{ build($_builder, $_state, type, value); }]>,
];

let hasFolder = true;
let hasVerifier = true;
}


def Sandbox_AddIOp : Sandbox_Op<"addi", [NoSideEffect, Commutative]> {
let arguments = (ins Sandbox_Int:$lhs, Sandbox_Int:$rhs);
let results = (outs Sandbox_Int:$result);
Expand All @@ -25,14 +32,7 @@ def Sandbox_AddIOp : Sandbox_Op<"addi", [NoSideEffect, Commutative]> {
let hasCanonicalizer = true;
}

def BoolOp : Sandbox_Op<"bool", [ConstantLike, NoSideEffect]> {
let arguments = (ins BoolAttr:$value);
let results = (outs Sandbox_Bool:$output);
let assemblyFormat = "$value attr-dict `:` type($output)";
}


def ComparisonOp : Sandbox_Op<"compare", [NoSideEffect]> {
def Sandbox_ComparisonOp : Sandbox_Op<"compare", [NoSideEffect]> {
let arguments = (ins Sandbox_CmpPredicateAttr:$predicate, Sandbox_Int:$lhs, Sandbox_Int:$rhs);
let results = (outs Sandbox_Bool:$output);
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($output)";
Expand All @@ -41,13 +41,13 @@ def ComparisonOp : Sandbox_Op<"compare", [NoSideEffect]> {
static CmpPredicate getPredicateByName(llvm::StringRef name);
}];

let hasFolder = 0;
let hasFolder = 1;
let hasCanonicalizer = 0;
}

def Sandbox_PrintOp : Sandbox_Op<"print"> {
let arguments = (ins Sandbox_Int:$value);
let assemblyFormat = "$value attr-dict";
let arguments = (ins AnyType:$value);
let assemblyFormat = "$value attr-dict `:` type($value)";
}

#endif // DIALECT_SANDBOX_SANDBOXOPS
59 changes: 57 additions & 2 deletions query/lib/Conversion/LowerSandbox/LowerSandbox.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,58 @@ struct ConstantOpConversion : public mlir::OpConversionPattern<ir::ConstantOp> {
}
};

struct ComparisonOpConversion : public mlir::OpConversionPattern<ir::ComparisonOp> {
ComparisonOpConversion(mlir::MLIRContext *context)
: OpConversionPattern<ir::ComparisonOp>(context, /*benefit=*/1) {}

mlir::LogicalResult
matchAndRewrite(ir::ComparisonOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto lhs = castToMLIR(rewriter, op->getLoc(), adaptor.lhs());
auto rhs = castToMLIR(rewriter, op->getLoc(), adaptor.rhs());

mlir::arith::CmpIPredicate pred;

switch(adaptor.predicate()) {
case CmpPredicate::LT:
pred = mlir::arith::CmpIPredicate::slt;
break;
case CmpPredicate::LTEQ:
pred = mlir::arith::CmpIPredicate::sle;
break;
case CmpPredicate::GT:
pred = mlir::arith::CmpIPredicate::sgt;
break;
case CmpPredicate::GTEQ:
pred = mlir::arith::CmpIPredicate::sge;
break;
case CmpPredicate::EQ:
pred = mlir::arith::CmpIPredicate::eq;
break;
case CmpPredicate::NEQ:
pred = mlir::arith::CmpIPredicate::ne;
break;
}

rewriter.replaceOpWithNewOp<mlir::arith::CmpIOp>(op,pred, lhs, rhs);
return mlir::success();
}
};

struct PrintOpConversion : public mlir::OpConversionPattern<ir::PrintOp> {
PrintOpConversion(mlir::MLIRContext *context)
: OpConversionPattern<ir::PrintOp>(context, /*benefit=*/1) {}

mlir::LogicalResult
matchAndRewrite(ir::PrintOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto value2 = castToMLIR(rewriter, op->getLoc(), adaptor.value());
rewriter.replaceOpWithNewOp<ir::PrintOp>(op, value2);
return mlir::success();
}
};


} // end namespace

void LowerSandboxPass::runOnOperation() {
Expand All @@ -57,7 +109,7 @@ void LowerSandboxPass::runOnOperation() {

// Configure conversion to lower out SCF operations.
mlir::ConversionTarget target(getContext());
target.addIllegalOp<ir::ConstantOp, ir::AddIOp>();
target.addIllegalOp<ir::ConstantOp, ir::AddIOp, ir::ComparisonOp>();
target.markUnknownOpDynamicallyLegal([](mlir::Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
Expand All @@ -66,7 +118,10 @@ void LowerSandboxPass::runOnOperation() {

void populateLowerSandboxConversionPatterns(
mlir::RewritePatternSet &patterns) {
patterns.add<ConstantOpConversion, AddIOpConversion>(patterns.getContext());
patterns.add<
ConstantOpConversion,
AddIOpConversion,
ComparisonOpConversion>(patterns.getContext());
}

std::unique_ptr<mlir::Pass> createLowerSandboxPass() {
Expand Down
66 changes: 65 additions & 1 deletion query/lib/Dialect/Sandbox/IR/SandboxOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APSInt.h"
#include <cstddef>

#include "Dialect/Sandbox/IR/SandboxOpsEnums.cpp.inc"
Expand All @@ -20,6 +22,23 @@ mlir::OpFoldResult ConstantOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
return valueAttr();
}

mlir::LogicalResult ConstantOp::verify() {
auto type = getType();
// The value's type must match the return type.
auto valueType = value().getType();

if (valueType.isa<mlir::IntegerType>() && type.isa<IntType>()) {
return mlir::success();
}

if (valueType == mlir::IntegerType::get(getContext(), 1) && type.isa<BooleanType>()) {
return mlir::success();
}

return emitOpError() << "bad constant: type=" << type << ", valueType=" << valueType;
}


mlir::OpFoldResult AddIOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
assert(operands.size() == 2 && "binary op takes two operands");
if (!operands[0] || !operands[1])
Expand All @@ -37,6 +56,50 @@ mlir::OpFoldResult AddIOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
return {};
}


mlir::OpFoldResult ComparisonOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
assert(operands.size() == 2 && "comparison op takes two operands");
if (!operands[0] || !operands[1])
return {};


if (operands[0].isa<mlir::IntegerAttr>() && operands[1].isa<mlir::IntegerAttr>()) {
auto pred = predicate();
auto lhs = operands[0].cast<mlir::IntegerAttr>();
auto rhs = operands[1].cast<mlir::IntegerAttr>();


bool x;
auto l = lhs.getValue();
auto r = rhs.getValue();
switch(pred) {
case CmpPredicate::LT:
x = l.slt(r);
break;
case CmpPredicate::LTEQ:
x = l.sle(r);
break;
case CmpPredicate::GT:
x = l.sgt(r);
break;
case CmpPredicate::GTEQ:
x = l.sge(r);
break;
case CmpPredicate::EQ:
x = l == r;
break;
case CmpPredicate::NEQ:
x = l != r;
break;
}

auto result = mlir::BoolAttr::get(getContext(), x);
return result;
}

return {};
}

struct SimplifyAddConstAddConst : public mlir::OpRewritePattern<AddIOp> {
SimplifyAddConstAddConst(mlir::MLIRContext *context)
: OpRewritePattern<AddIOp>(context, /*benefit=*/1) {}
Expand All @@ -51,7 +114,8 @@ struct SimplifyAddConstAddConst : public mlir::OpRewritePattern<AddIOp> {

auto sumConst = rewriter.create<ConstantOp>(
mlir::FusedLoc::get(op->getContext(), {lConst->getLoc(), rConst.getLoc()}, nullptr),
lConst.getType(), lConst.value() + rConst.value());
lConst.getType(),
mlir::IntegerAttr::get(lConst.getType(), lConst.value().cast<mlir::IntegerAttr>().getValue() + rConst.value().cast<mlir::IntegerAttr>().getValue()));
rewriter.replaceOpWithNewOp<AddIOp>(op, lhs.result().getType(), lhs.lhs(), sumConst);
return mlir::success();
}
Expand Down
17 changes: 12 additions & 5 deletions query/test/current.mlir
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
func.func @foo(%in: !sb.int) -> (!sb.bool) {
%i1 = sb.constant 5 : !sb.int
%i2 = sb.constant 7 : !sb.int
func.func @foo(%in: !sb.int) -> () {
%i1 = sb.constant (5) : !sb.int
%i2 = sb.constant (7) : !sb.int
%i3 = sb.addi %in %i1
%i4 = sb.addi %i3 %i2
%i5 = sb.constant 6 : !sb.int
%i5 = sb.constant (6) : !sb.int
%i6 = sb.compare eq, %i4, %i5 : !sb.bool
func.return %i6 : !sb.bool
sb.print %i6 : !sb.bool
func.return
}

func.func @bar() -> () {
%i1 = sb.constant (true) : !sb.bool
sb.print %i1 : !sb.bool
func.return
}