diff --git a/query/include/Dialect/Sandbox/IR/SandboxBase.td b/query/include/Dialect/Sandbox/IR/SandboxBase.td index cb64a2b8205..3cd79d266ef 100644 --- a/query/include/Dialect/Sandbox/IR/SandboxBase.td +++ b/query/include/Dialect/Sandbox/IR/SandboxBase.td @@ -49,6 +49,7 @@ def Sandbox_CmpPredicateAttr : I64EnumAttr< I64EnumAttrCase<"GT", 2, "gt">, I64EnumAttrCase<"GTEQ", 3, "gteq">, I64EnumAttrCase<"EQ", 4, "eq">, + I64EnumAttrCase<"NEQ", 5, "eq">, ]> { } diff --git a/query/include/Dialect/Sandbox/IR/SandboxOps.td b/query/include/Dialect/Sandbox/IR/SandboxOps.td index 8fe51ec723d..5fd43c9850e 100644 --- a/query/include/Dialect/Sandbox/IR/SandboxOps.td +++ b/query/include/Dialect/Sandbox/IR/SandboxOps.td @@ -3,19 +3,26 @@ include "Dialect/Sandbox/IR/SandboxBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" - +include "mlir/IR/OpBase.td" class Sandbox_Op traits = []> : Op; def Sandbox_ConstantOp : Sandbox_Op<"constant", [ConstantLike, NoSideEffect]> { - let arguments = (ins SI32Attr:$value); - let results = (outs Sandbox_Int:$output); - let assemblyFormat = "$value attr-dict `:` type($output)"; + let arguments = (ins AnyAttr:$value); + let results = (outs AnyType:$output); + let assemblyFormat = "`(`$value`)` attr-dict `:` type($output)"; + + let builders = [ + OpBuilder<(ins "::mlir::Attribute":$value, "::mlir::Type":$type), + [{ build($_builder, $_state, type, value); }]>, + ]; let hasFolder = true; + let hasVerifier = true; } + def Sandbox_AddIOp : Sandbox_Op<"addi", [NoSideEffect, Commutative]> { let arguments = (ins Sandbox_Int:$lhs, Sandbox_Int:$rhs); let results = (outs Sandbox_Int:$result); @@ -25,14 +32,7 @@ def Sandbox_AddIOp : Sandbox_Op<"addi", [NoSideEffect, Commutative]> { let hasCanonicalizer = true; } -def BoolOp : Sandbox_Op<"bool", [ConstantLike, NoSideEffect]> { - let arguments = (ins BoolAttr:$value); - let results = (outs Sandbox_Bool:$output); - let assemblyFormat = "$value attr-dict `:` type($output)"; -} - - -def ComparisonOp : Sandbox_Op<"compare", [NoSideEffect]> { +def Sandbox_ComparisonOp : Sandbox_Op<"compare", [NoSideEffect]> { let arguments = (ins Sandbox_CmpPredicateAttr:$predicate, Sandbox_Int:$lhs, Sandbox_Int:$rhs); let results = (outs Sandbox_Bool:$output); let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($output)"; @@ -41,13 +41,13 @@ def ComparisonOp : Sandbox_Op<"compare", [NoSideEffect]> { static CmpPredicate getPredicateByName(llvm::StringRef name); }]; - let hasFolder = 0; + let hasFolder = 1; let hasCanonicalizer = 0; } def Sandbox_PrintOp : Sandbox_Op<"print"> { - let arguments = (ins Sandbox_Int:$value); - let assemblyFormat = "$value attr-dict"; + let arguments = (ins AnyType:$value); + let assemblyFormat = "$value attr-dict `:` type($value)"; } #endif // DIALECT_SANDBOX_SANDBOXOPS diff --git a/query/lib/Conversion/LowerSandbox/LowerSandbox.cpp b/query/lib/Conversion/LowerSandbox/LowerSandbox.cpp index 3f77e80dceb..85213bf4ae5 100644 --- a/query/lib/Conversion/LowerSandbox/LowerSandbox.cpp +++ b/query/lib/Conversion/LowerSandbox/LowerSandbox.cpp @@ -49,6 +49,58 @@ struct ConstantOpConversion : public mlir::OpConversionPattern { } }; +struct ComparisonOpConversion : public mlir::OpConversionPattern { + ComparisonOpConversion(mlir::MLIRContext *context) + : OpConversionPattern(context, /*benefit=*/1) {} + + mlir::LogicalResult + matchAndRewrite(ir::ComparisonOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto lhs = castToMLIR(rewriter, op->getLoc(), adaptor.lhs()); + auto rhs = castToMLIR(rewriter, op->getLoc(), adaptor.rhs()); + + mlir::arith::CmpIPredicate pred; + + switch(adaptor.predicate()) { + case CmpPredicate::LT: + pred = mlir::arith::CmpIPredicate::slt; + break; + case CmpPredicate::LTEQ: + pred = mlir::arith::CmpIPredicate::sle; + break; + case CmpPredicate::GT: + pred = mlir::arith::CmpIPredicate::sgt; + break; + case CmpPredicate::GTEQ: + pred = mlir::arith::CmpIPredicate::sge; + break; + case CmpPredicate::EQ: + pred = mlir::arith::CmpIPredicate::eq; + break; + case CmpPredicate::NEQ: + pred = mlir::arith::CmpIPredicate::ne; + break; + } + + rewriter.replaceOpWithNewOp(op,pred, lhs, rhs); + return mlir::success(); + } +}; + +struct PrintOpConversion : public mlir::OpConversionPattern { + PrintOpConversion(mlir::MLIRContext *context) + : OpConversionPattern(context, /*benefit=*/1) {} + + mlir::LogicalResult + matchAndRewrite(ir::PrintOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto value2 = castToMLIR(rewriter, op->getLoc(), adaptor.value()); + rewriter.replaceOpWithNewOp(op, value2); + return mlir::success(); + } +}; + + } // end namespace void LowerSandboxPass::runOnOperation() { @@ -57,7 +109,7 @@ void LowerSandboxPass::runOnOperation() { // Configure conversion to lower out SCF operations. mlir::ConversionTarget target(getContext()); - target.addIllegalOp(); + target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](mlir::Operation *) { return true; }); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -66,7 +118,10 @@ void LowerSandboxPass::runOnOperation() { void populateLowerSandboxConversionPatterns( mlir::RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns.add< + ConstantOpConversion, + AddIOpConversion, + ComparisonOpConversion>(patterns.getContext()); } std::unique_ptr createLowerSandboxPass() { diff --git a/query/lib/Dialect/Sandbox/IR/SandboxOps.cpp b/query/lib/Dialect/Sandbox/IR/SandboxOps.cpp index 7ff136c01ae..37666cb0cbc 100644 --- a/query/lib/Dialect/Sandbox/IR/SandboxOps.cpp +++ b/query/lib/Dialect/Sandbox/IR/SandboxOps.cpp @@ -5,6 +5,8 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Location.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/APSInt.h" #include #include "Dialect/Sandbox/IR/SandboxOpsEnums.cpp.inc" @@ -20,6 +22,23 @@ mlir::OpFoldResult ConstantOp::fold(llvm::ArrayRef operands) { return valueAttr(); } +mlir::LogicalResult ConstantOp::verify() { + auto type = getType(); + // The value's type must match the return type. + auto valueType = value().getType(); + + if (valueType.isa() && type.isa()) { + return mlir::success(); + } + + if (valueType == mlir::IntegerType::get(getContext(), 1) && type.isa()) { + return mlir::success(); + } + + return emitOpError() << "bad constant: type=" << type << ", valueType=" << valueType; +} + + mlir::OpFoldResult AddIOp::fold(llvm::ArrayRef operands) { assert(operands.size() == 2 && "binary op takes two operands"); if (!operands[0] || !operands[1]) @@ -37,6 +56,50 @@ mlir::OpFoldResult AddIOp::fold(llvm::ArrayRef operands) { return {}; } + +mlir::OpFoldResult ComparisonOp::fold(llvm::ArrayRef operands) { + assert(operands.size() == 2 && "comparison op takes two operands"); + if (!operands[0] || !operands[1]) + return {}; + + + if (operands[0].isa() && operands[1].isa()) { + auto pred = predicate(); + auto lhs = operands[0].cast(); + auto rhs = operands[1].cast(); + + + bool x; + auto l = lhs.getValue(); + auto r = rhs.getValue(); + switch(pred) { + case CmpPredicate::LT: + x = l.slt(r); + break; + case CmpPredicate::LTEQ: + x = l.sle(r); + break; + case CmpPredicate::GT: + x = l.sgt(r); + break; + case CmpPredicate::GTEQ: + x = l.sge(r); + break; + case CmpPredicate::EQ: + x = l == r; + break; + case CmpPredicate::NEQ: + x = l != r; + break; + } + + auto result = mlir::BoolAttr::get(getContext(), x); + return result; + } + + return {}; +} + struct SimplifyAddConstAddConst : public mlir::OpRewritePattern { SimplifyAddConstAddConst(mlir::MLIRContext *context) : OpRewritePattern(context, /*benefit=*/1) {} @@ -51,7 +114,8 @@ struct SimplifyAddConstAddConst : public mlir::OpRewritePattern { auto sumConst = rewriter.create( mlir::FusedLoc::get(op->getContext(), {lConst->getLoc(), rConst.getLoc()}, nullptr), - lConst.getType(), lConst.value() + rConst.value()); + lConst.getType(), + mlir::IntegerAttr::get(lConst.getType(), lConst.value().cast().getValue() + rConst.value().cast().getValue())); rewriter.replaceOpWithNewOp(op, lhs.result().getType(), lhs.lhs(), sumConst); return mlir::success(); } diff --git a/query/test/current.mlir b/query/test/current.mlir index 46bfe998494..0306ac97af2 100644 --- a/query/test/current.mlir +++ b/query/test/current.mlir @@ -1,9 +1,16 @@ -func.func @foo(%in: !sb.int) -> (!sb.bool) { - %i1 = sb.constant 5 : !sb.int - %i2 = sb.constant 7 : !sb.int +func.func @foo(%in: !sb.int) -> () { + %i1 = sb.constant (5) : !sb.int + %i2 = sb.constant (7) : !sb.int %i3 = sb.addi %in %i1 %i4 = sb.addi %i3 %i2 - %i5 = sb.constant 6 : !sb.int + %i5 = sb.constant (6) : !sb.int %i6 = sb.compare eq, %i4, %i5 : !sb.bool - func.return %i6 : !sb.bool + sb.print %i6 : !sb.bool + func.return +} + +func.func @bar() -> () { + %i1 = sb.constant (true) : !sb.bool + sb.print %i1 : !sb.bool + func.return } \ No newline at end of file