diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h index 995fcd027919..c36923a41e35 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h @@ -122,6 +122,10 @@ class CIRCXXABI { mlir::Value loweredRhs, mlir::OpBuilder &builder) const = 0; + virtual mlir::Value lowerMethodCmp(cir::CmpOp op, mlir::Value loweredLhs, + mlir::Value loweredRhs, + mlir::OpBuilder &builder) const = 0; + virtual mlir::Value lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy, mlir::Value loweredSrc, diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp index 992cf88efaea..2819adb25b8c 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp @@ -103,6 +103,10 @@ class ItaniumCXXABI : public CIRCXXABI { mlir::Value loweredRhs, mlir::OpBuilder &builder) const override; + mlir::Value lowerMethodCmp(cir::CmpOp op, mlir::Value loweredLhs, + mlir::Value loweredRhs, + mlir::OpBuilder &builder) const override; + mlir::Value lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy, mlir::Value loweredSrc, mlir::OpBuilder &builder) const override; @@ -478,6 +482,61 @@ mlir::Value ItaniumCXXABI::lowerDataMemberCmp(cir::CmpOp op, loweredRhs); } +mlir::Value ItaniumCXXABI::lowerMethodCmp(cir::CmpOp op, mlir::Value loweredLhs, + mlir::Value loweredRhs, + mlir::OpBuilder &builder) const { + assert(op.getKind() == cir::CmpOpKind::eq || + op.getKind() == cir::CmpOpKind::ne); + + cir::IntType ptrdiffCIRTy = getPtrDiffCIRTy(LM); + mlir::Value ptrdiffZero = builder.create( + op.getLoc(), ptrdiffCIRTy, cir::IntAttr::get(ptrdiffCIRTy, 0)); + + mlir::Value lhsPtrField = builder.create( + op.getLoc(), ptrdiffCIRTy, loweredLhs, 0); + mlir::Value rhsPtrField = builder.create( + op.getLoc(), ptrdiffCIRTy, loweredRhs, 0); + mlir::Value ptrCmp = builder.create(op.getLoc(), op.getKind(), + lhsPtrField, rhsPtrField); + mlir::Value ptrCmpToNull = builder.create( + op.getLoc(), op.getKind(), lhsPtrField, ptrdiffZero); + + mlir::Value lhsAdjField = builder.create( + op.getLoc(), ptrdiffCIRTy, loweredLhs, 1); + mlir::Value rhsAdjField = builder.create( + op.getLoc(), ptrdiffCIRTy, loweredRhs, 1); + mlir::Value adjCmp = builder.create(op.getLoc(), op.getKind(), + lhsAdjField, rhsAdjField); + + // We use cir.select to represent "||" and "&&" operations below: + // - cir.select if %a then %b else false => %a && %b + // - cir.select if %a then true else %b => %a || %b + // TODO: Do we need to invent dedicated "cir.logical_or" and "cir.logical_and" + // operations for this? + auto boolTy = cir::BoolType::get(op.getContext()); + mlir::Value trueValue = builder.create( + op.getLoc(), boolTy, cir::BoolAttr::get(op.getContext(), boolTy, true)); + mlir::Value falseValue = builder.create( + op.getLoc(), boolTy, cir::BoolAttr::get(op.getContext(), boolTy, false)); + auto create_and = [&](mlir::Value lhs, mlir::Value rhs) { + return builder.create(op.getLoc(), lhs, rhs, falseValue); + }; + auto create_or = [&](mlir::Value lhs, mlir::Value rhs) { + return builder.create(op.getLoc(), lhs, trueValue, rhs); + }; + + mlir::Value result; + if (op.getKind() == cir::CmpOpKind::eq) { + // (lhs.ptr == null || lhs.adj == rhs.adj) && lhs.ptr == rhs.ptr + result = create_and(create_or(ptrCmpToNull, adjCmp), ptrCmp); + } else { + // (lhs.ptr != null && lhs.adj != rhs.adj) || lhs.ptr != rhs.ptr + result = create_or(create_and(ptrCmpToNull, adjCmp), ptrCmp); + } + + return result; +} + mlir::Value ItaniumCXXABI::lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy, mlir::Value loweredSrc, diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index e1a2d49fc5d2..12ce2fb6eda6 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -2893,10 +2893,17 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( mlir::ConversionPatternRewriter &rewriter) const { auto type = cmpOp.getLhs().getType(); - if (mlir::isa(type)) { + if (mlir::isa(type)) { assert(lowerMod && "lowering module is not available"); - mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberCmp( - cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter); + + mlir::Value loweredResult; + if (mlir::isa(type)) + loweredResult = lowerMod->getCXXABI().lowerDataMemberCmp( + cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter); + else + loweredResult = lowerMod->getCXXABI().lowerMethodCmp( + cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter); + rewriter.replaceOp(cmpOp, loweredResult); return mlir::success(); } diff --git a/clang/test/CIR/CodeGen/pointer-to-member-func.cpp b/clang/test/CIR/CodeGen/pointer-to-member-func.cpp index a1a42f4d494c..5baf9c9bd23a 100644 --- a/clang/test/CIR/CodeGen/pointer-to-member-func.cpp +++ b/clang/test/CIR/CodeGen/pointer-to-member-func.cpp @@ -78,3 +78,43 @@ void call(Foo *obj, void (Foo::*func)(int), int arg) { // LLVM-NEXT: %[[#arg:]] = load i32, ptr %{{.+}} // LLVM-NEXT: call void %[[#callee_ptr]](ptr %[[#adjusted_this]], i32 %[[#arg]]) // LLVM: } + +bool cmp_eq(void (Foo::*lhs)(int), void (Foo::*rhs)(int)) { + return lhs == rhs; +} + +// CHECK-LABEL: @_Z6cmp_eqM3FooFviES1_ +// CHECK: %{{.+}} = cir.cmp(eq, %{{.+}}, %{{.+}}) : !cir.method in !ty_Foo>, !cir.bool + +// LLVM-LABEL: @_Z6cmp_eqM3FooFviES1_ +// LLVM: %[[#lhs:]] = load { i64, i64 }, ptr %{{.+}} +// LLVM-NEXT: %[[#rhs:]] = load { i64, i64 }, ptr %{{.+}} +// LLVM-NEXT: %[[#lhs_ptr:]] = extractvalue { i64, i64 } %[[#lhs]], 0 +// LLVM-NEXT: %[[#rhs_ptr:]] = extractvalue { i64, i64 } %[[#rhs]], 0 +// LLVM-NEXT: %[[#ptr_cmp:]] = icmp eq i64 %[[#lhs_ptr]], %[[#rhs_ptr]] +// LLVM-NEXT: %[[#ptr_null:]] = icmp eq i64 %[[#lhs_ptr]], 0 +// LLVM-NEXT: %[[#lhs_adj:]] = extractvalue { i64, i64 } %[[#lhs]], 1 +// LLVM-NEXT: %[[#rhs_adj:]] = extractvalue { i64, i64 } %[[#rhs]], 1 +// LLVM-NEXT: %[[#adj_cmp:]] = icmp eq i64 %[[#lhs_adj]], %[[#rhs_adj]] +// LLVM-NEXT: %[[#tmp:]] = or i1 %[[#ptr_null]], %[[#adj_cmp]] +// LLVM-NEXT: %{{.+}} = and i1 %[[#tmp]], %[[#ptr_cmp]] + +bool cmp_ne(void (Foo::*lhs)(int), void (Foo::*rhs)(int)) { + return lhs != rhs; +} + +// CHECK-LABEL: @_Z6cmp_neM3FooFviES1_ +// CHECK: %{{.+}} = cir.cmp(ne, %{{.+}}, %{{.+}}) : !cir.method in !ty_Foo>, !cir.bool + +// LLVM-LABEL: @_Z6cmp_neM3FooFviES1_ +// LLVM: %[[#lhs:]] = load { i64, i64 }, ptr %{{.+}} +// LLVM-NEXT: %[[#rhs:]] = load { i64, i64 }, ptr %{{.+}} +// LLVM-NEXT: %[[#lhs_ptr:]] = extractvalue { i64, i64 } %[[#lhs]], 0 +// LLVM-NEXT: %[[#rhs_ptr:]] = extractvalue { i64, i64 } %[[#rhs]], 0 +// LLVM-NEXT: %[[#ptr_cmp:]] = icmp ne i64 %[[#lhs_ptr]], %[[#rhs_ptr]] +// LLVM-NEXT: %[[#ptr_null:]] = icmp ne i64 %[[#lhs_ptr]], 0 +// LLVM-NEXT: %[[#lhs_adj:]] = extractvalue { i64, i64 } %[[#lhs]], 1 +// LLVM-NEXT: %[[#rhs_adj:]] = extractvalue { i64, i64 } %[[#rhs]], 1 +// LLVM-NEXT: %[[#adj_cmp:]] = icmp ne i64 %[[#lhs_adj]], %[[#rhs_adj]] +// LLVM-NEXT: %[[#tmp:]] = and i1 %[[#ptr_null]], %[[#adj_cmp]] +// LLVM-NEXT: %{{.+}} = or i1 %[[#tmp]], %[[#ptr_cmp]]