Skip to content

Commit

Permalink
lower print
Browse files Browse the repository at this point in the history
  • Loading branch information
tpoterba committed Nov 8, 2022
1 parent 53475a3 commit f934a20
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 8 deletions.
4 changes: 2 additions & 2 deletions query/include/Dialect/Sandbox/IR/SandboxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def Sandbox_ComparisonOp : Sandbox_Op<"compare", [NoSideEffect]> {
}

def Sandbox_PrintOp : Sandbox_Op<"print"> {
let arguments = (ins Sandbox_Int:$value);
let assemblyFormat = "$value attr-dict";
let arguments = (ins AnyType:$value);
let assemblyFormat = "$value attr-dict `:` type($value)";
}

#endif // DIALECT_SANDBOX_SANDBOXOPS
59 changes: 57 additions & 2 deletions query/lib/Conversion/LowerSandbox/LowerSandbox.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,58 @@ struct ConstantOpConversion : public mlir::OpConversionPattern<ir::ConstantOp> {
}
};

struct ComparisonOpConversion : public mlir::OpConversionPattern<ir::ComparisonOp> {
ComparisonOpConversion(mlir::MLIRContext *context)
: OpConversionPattern<ir::ComparisonOp>(context, /*benefit=*/1) {}

mlir::LogicalResult
matchAndRewrite(ir::ComparisonOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto lhs = castToMLIR(rewriter, op->getLoc(), adaptor.lhs());
auto rhs = castToMLIR(rewriter, op->getLoc(), adaptor.rhs());

mlir::arith::CmpIPredicate pred;

switch(adaptor.predicate()) {
case CmpPredicate::LT:
pred = mlir::arith::CmpIPredicate::slt;
break;
case CmpPredicate::LTEQ:
pred = mlir::arith::CmpIPredicate::sle;
break;
case CmpPredicate::GT:
pred = mlir::arith::CmpIPredicate::sgt;
break;
case CmpPredicate::GTEQ:
pred = mlir::arith::CmpIPredicate::sge;
break;
case CmpPredicate::EQ:
pred = mlir::arith::CmpIPredicate::eq;
break;
case CmpPredicate::NEQ:
pred = mlir::arith::CmpIPredicate::ne;
break;
}

rewriter.replaceOpWithNewOp<mlir::arith::CmpIOp>(op,pred, lhs, rhs);
return mlir::success();
}
};

struct PrintOpConversion : public mlir::OpConversionPattern<ir::PrintOp> {
PrintOpConversion(mlir::MLIRContext *context)
: OpConversionPattern<ir::PrintOp>(context, /*benefit=*/1) {}

mlir::LogicalResult
matchAndRewrite(ir::PrintOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto value2 = castToMLIR(rewriter, op->getLoc(), adaptor.value());
rewriter.replaceOpWithNewOp<ir::PrintOp>(op, value2);
return mlir::success();
}
};


} // end namespace

void LowerSandboxPass::runOnOperation() {
Expand All @@ -57,7 +109,7 @@ void LowerSandboxPass::runOnOperation() {

// Configure conversion to lower out SCF operations.
mlir::ConversionTarget target(getContext());
target.addIllegalOp<ir::ConstantOp, ir::AddIOp>();
target.addIllegalOp<ir::ConstantOp, ir::AddIOp, ir::ComparisonOp>();
target.markUnknownOpDynamicallyLegal([](mlir::Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
Expand All @@ -66,7 +118,10 @@ void LowerSandboxPass::runOnOperation() {

void populateLowerSandboxConversionPatterns(
mlir::RewritePatternSet &patterns) {
patterns.add<ConstantOpConversion, AddIOpConversion>(patterns.getContext());
patterns.add<
ConstantOpConversion,
AddIOpConversion,
ComparisonOpConversion>(patterns.getContext());
}

std::unique_ptr<mlir::Pass> createLowerSandboxPass() {
Expand Down
10 changes: 6 additions & 4 deletions query/test/current.mlir
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
func.func @foo(%in: !sb.int) -> (!sb.bool) {
func.func @foo(%in: !sb.int) -> () {
%i1 = sb.constant (5) : !sb.int
%i2 = sb.constant (7) : !sb.int
%i3 = sb.addi %in %i1
%i4 = sb.addi %i3 %i2
%i5 = sb.constant (6) : !sb.int
%i6 = sb.compare eq, %i4, %i5 : !sb.bool
func.return %i6 : !sb.bool
sb.print %i6 : !sb.bool
func.return
}

func.func @bar() -> (!sb.bool) {
func.func @bar() -> () {
%i1 = sb.constant (true) : !sb.bool
func.return %i1 : !sb.bool
sb.print %i1 : !sb.bool
func.return
}

0 comments on commit f934a20

Please sign in to comment.