Skip to content
Merged
28 changes: 19 additions & 9 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1559,15 +1559,24 @@ def ComplexBinOp : CIR_Op<"complex.binop",
//===----------------------------------------------------------------------===//

class CIR_BitOp<string mnemonic, TypeConstraint inputTy>
: CIR_Op<mnemonic, [Pure]> {
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
let arguments = (ins inputTy:$input);
let results = (outs SInt32:$result);
let results = (outs inputTy:$result);

let assemblyFormat = [{
`(` $input `:` type($input) `)` `:` type($result) attr-dict
}];
}

class CIR_CountZerosBitOp<string mnemonic, TypeConstraint inputTy>
: CIR_BitOp<mnemonic, inputTy> {
let arguments = (ins inputTy:$input, UnitAttr:$is_zero_poison);
let assemblyFormat = [{
`(` $input `:` type($input) `)` (`zero_poison` $is_zero_poison^)?
`:` type($result) attr-dict
}];
}

def BitClrsbOp : CIR_BitOp<"bit.clrsb", AnyTypeOf<[SInt32, SInt64]>> {
let summary = "Get the number of leading redundant sign bits in the input";
let description = [{
Expand Down Expand Up @@ -1599,7 +1608,7 @@ def BitClrsbOp : CIR_BitOp<"bit.clrsb", AnyTypeOf<[SInt32, SInt64]>> {
}];
}

def BitClzOp : CIR_BitOp<"bit.clz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
def BitClzOp : CIR_CountZerosBitOp<"bit.clz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
let summary = "Get the number of leading 0-bits in the input";
let description = [{
Compute the number of leading 0-bits in the input.
Expand All @@ -1608,23 +1617,23 @@ def BitClzOp : CIR_BitOp<"bit.clz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
returns the number of consecutive 0-bits at the most significant bit
position in the input.

This operation invokes undefined behavior if the input value is 0.
Zero_poison attribute means this operation invokes undefined behavior if the
input value is 0.

Example:

```mlir
!s32i = !cir.int<s, 32>
!u32i = !cir.int<u, 32>

// %0 = 0b0000_0000_0000_0000_0000_0000_0000_1000
%0 = cir.const #cir.int<8> : !u32i
// %1 will be 28
%1 = cir.bit.clz(%0 : !u32i) : !s32i
%1 = cir.bit.clz(%0 : !u32i) zero_poison : !u32i
```
}];
}

def BitCtzOp : CIR_BitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
def BitCtzOp : CIR_CountZerosBitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
let summary = "Get the number of trailing 0-bits in the input";
let description = [{
Compute the number of trailing 0-bits in the input.
Expand All @@ -1633,7 +1642,8 @@ def BitCtzOp : CIR_BitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
returns the number of consecutive 0-bits at the least significant bit
position in the input.

This operation invokes undefined behavior if the input value is 0.
Zero_poison attribute means this operation invokes undefined behavior if the
input value is 0.

Example:

Expand All @@ -1644,7 +1654,7 @@ def BitCtzOp : CIR_BitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
// %0 = 0b1000
%0 = cir.const #cir.int<8> : !u32i
// %1 will be 3
%1 = cir.bit.ctz(%0 : !u32i) : !s32i
%1 = cir.bit.ctz(%0 : !u32i) : !u32i
```
}];
}
Expand Down
61 changes: 40 additions & 21 deletions clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,19 +128,30 @@ static mlir::Value emitBinaryMaybeConstrainedFPBuiltin(CIRGenFunction &CGF,
}

template <typename Op>
static RValue
emitBuiltinBitOp(CIRGenFunction &CGF, const CallExpr *E,
std::optional<CIRGenFunction::BuiltinCheckKind> CK) {
static RValue emitBuiltinBitOp(
CIRGenFunction &CGF, const CallExpr *E,
std::optional<CIRGenFunction::BuiltinCheckKind> CK = std::nullopt,
bool isZeroPoison = false, bool convertToInt = true) {
mlir::Value arg;
if (CK.has_value())
arg = CGF.emitCheckedArgForBuiltin(E->getArg(0), *CK);
else
arg = CGF.emitScalarExpr(E->getArg(0));

auto resultTy = CGF.convertType(E->getType());
auto op =
CGF.getBuilder().create<Op>(CGF.getLoc(E->getExprLoc()), resultTy, arg);
return RValue::get(op);
Op op;
if constexpr (std::is_same_v<Op, cir::BitClzOp> ||
std::is_same_v<Op, cir::BitCtzOp>) {
op = CGF.getBuilder().create<Op>(CGF.getLoc(E->getExprLoc()), arg,
isZeroPoison);
} else {
op = CGF.getBuilder().create<Op>(CGF.getLoc(E->getExprLoc()), arg);
}
const mlir::Value bitResult = op.getResult();
if (const auto si32Ty = CGF.getBuilder().getSInt32Ty();
convertToInt && arg.getType() != si32Ty) {
return RValue::get(CGF.getBuilder().createIntCast(bitResult, si32Ty));
}
return RValue::get(bitResult);
}

// Initialize the alloca with the given size and alignment according to the lang
Expand Down Expand Up @@ -1052,46 +1063,54 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,

case Builtin::BI__builtin_clrsb:
case Builtin::BI__builtin_clrsbl:
case Builtin::BI__builtin_clrsbll:
return emitBuiltinBitOp<cir::BitClrsbOp>(*this, E, std::nullopt);
case Builtin::BI__builtin_clrsbll: {
return emitBuiltinBitOp<cir::BitClrsbOp>(*this, E);
}

case Builtin::BI__builtin_ctzs:
case Builtin::BI__builtin_ctz:
case Builtin::BI__builtin_ctzl:
case Builtin::BI__builtin_ctzll:
case Builtin::BI__builtin_ctzg:
return emitBuiltinBitOp<cir::BitCtzOp>(*this, E, BCK_CTZPassedZero);
case Builtin::BI__builtin_ctzg: {
return emitBuiltinBitOp<cir::BitCtzOp>(*this, E, BCK_CTZPassedZero, true);
}

case Builtin::BI__builtin_clzs:
case Builtin::BI__builtin_clz:
case Builtin::BI__builtin_clzl:
case Builtin::BI__builtin_clzll:
case Builtin::BI__builtin_clzg:
return emitBuiltinBitOp<cir::BitClzOp>(*this, E, BCK_CLZPassedZero);
case Builtin::BI__builtin_clzg: {
return emitBuiltinBitOp<cir::BitClzOp>(*this, E, BCK_CLZPassedZero, true);
}

case Builtin::BI__builtin_ffs:
case Builtin::BI__builtin_ffsl:
case Builtin::BI__builtin_ffsll:
return emitBuiltinBitOp<cir::BitFfsOp>(*this, E, std::nullopt);
case Builtin::BI__builtin_ffsll: {
return emitBuiltinBitOp<cir::BitFfsOp>(*this, E);
}

case Builtin::BI__builtin_parity:
case Builtin::BI__builtin_parityl:
case Builtin::BI__builtin_parityll:
return emitBuiltinBitOp<cir::BitParityOp>(*this, E, std::nullopt);
case Builtin::BI__builtin_parityll: {
return emitBuiltinBitOp<cir::BitParityOp>(*this, E);
}

case Builtin::BI__lzcnt16:
case Builtin::BI__lzcnt:
case Builtin::BI__lzcnt64:
llvm_unreachable("BI__lzcnt16 like NYI");
case Builtin::BI__lzcnt64: {
return emitBuiltinBitOp<cir::BitClzOp>(*this, E, BCK_CLZPassedZero, false,
false);
}

case Builtin::BI__popcnt16:
case Builtin::BI__popcnt:
case Builtin::BI__popcnt64:
case Builtin::BI__builtin_popcount:
case Builtin::BI__builtin_popcountl:
case Builtin::BI__builtin_popcountll:
case Builtin::BI__builtin_popcountg:
return emitBuiltinBitOp<cir::BitPopcountOp>(*this, E, std::nullopt);
case Builtin::BI__builtin_popcountg: {
return emitBuiltinBitOp<cir::BitPopcountOp>(*this, E);
}

case Builtin::BI__builtin_unpredictable: {
if (CGM.getCodeGenOpts().OptimizationLevel != 0)
Expand Down
61 changes: 12 additions & 49 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3059,38 +3059,6 @@ mlir::LogicalResult CIRToLLVMAssumeSepStorageOpLowering::matchAndRewrite(
return mlir::success();
}

mlir::Value createLLVMBitOp(mlir::Location loc,
const llvm::Twine &llvmIntrinBaseName,
mlir::Type resultTy, mlir::Value operand,
std::optional<bool> poisonZeroInputFlag,
mlir::ConversionPatternRewriter &rewriter) {
auto operandIntTy = mlir::cast<mlir::IntegerType>(operand.getType());
auto resultIntTy = mlir::cast<mlir::IntegerType>(resultTy);

std::string llvmIntrinName =
llvmIntrinBaseName.concat(".i")
.concat(std::to_string(operandIntTy.getWidth()))
.str();

// Note that LLVM intrinsic calls to bit intrinsics have the same type as the
// operand.
mlir::LLVM::CallIntrinsicOp op;
if (poisonZeroInputFlag.has_value()) {
auto poisonZeroInputValue = rewriter.create<mlir::LLVM::ConstantOp>(
loc, rewriter.getI1Type(), static_cast<int64_t>(*poisonZeroInputFlag));
op = createCallLLVMIntrinsicOp(rewriter, loc, llvmIntrinName,
operand.getType(),
{operand, poisonZeroInputValue});
} else {
op = createCallLLVMIntrinsicOp(rewriter, loc, llvmIntrinName,
operand.getType(), operand);
}

return getLLVMIntCast(
rewriter, op->getResult(0), mlir::cast<mlir::IntegerType>(resultTy),
/*isUnsigned=*/true, operandIntTy.getWidth(), resultIntTy.getWidth());
}

mlir::LogicalResult CIRToLLVMBitClrsbOpLowering::matchAndRewrite(
cir::BitClrsbOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand All @@ -3111,8 +3079,8 @@ mlir::LogicalResult CIRToLLVMBitClrsbOpLowering::matchAndRewrite(
op.getLoc(), isNeg, flipped, adaptor.getInput());

auto resTy = getTypeConverter()->convertType(op.getType());
auto clz = createLLVMBitOp(op.getLoc(), "llvm.ctlz", resTy, select,
/*poisonZeroInputFlag=*/false, rewriter);
auto clz = rewriter.create<mlir::LLVM::CountLeadingZerosOp>(
op.getLoc(), resTy, select, false);

auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
auto res = rewriter.create<mlir::LLVM::SubOp>(op.getLoc(), clz, one);
Expand Down Expand Up @@ -3147,9 +3115,8 @@ mlir::LogicalResult CIRToLLVMBitClzOpLowering::matchAndRewrite(
cir::BitClzOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto llvmOp =
createLLVMBitOp(op.getLoc(), "llvm.ctlz", resTy, adaptor.getInput(),
/*poisonZeroInputFlag=*/true, rewriter);
auto llvmOp = rewriter.create<mlir::LLVM::CountLeadingZerosOp>(
op.getLoc(), resTy, adaptor.getInput(), op.getIsZeroPoison());
rewriter.replaceOp(op, llvmOp);
return mlir::LogicalResult::success();
}
Expand All @@ -3158,9 +3125,8 @@ mlir::LogicalResult CIRToLLVMBitCtzOpLowering::matchAndRewrite(
cir::BitCtzOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto llvmOp =
createLLVMBitOp(op.getLoc(), "llvm.cttz", resTy, adaptor.getInput(),
/*poisonZeroInputFlag=*/true, rewriter);
auto llvmOp = rewriter.create<mlir::LLVM::CountTrailingZerosOp>(
op.getLoc(), resTy, adaptor.getInput(), op.getIsZeroPoison());
rewriter.replaceOp(op, llvmOp);
return mlir::LogicalResult::success();
}
Expand All @@ -3169,9 +3135,8 @@ mlir::LogicalResult CIRToLLVMBitFfsOpLowering::matchAndRewrite(
cir::BitFfsOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto ctz =
createLLVMBitOp(op.getLoc(), "llvm.cttz", resTy, adaptor.getInput(),
/*poisonZeroInputFlag=*/false, rewriter);
auto ctz = rewriter.create<mlir::LLVM::CountTrailingZerosOp>(
op.getLoc(), resTy, adaptor.getInput(), false);

auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
auto ctzAddOne = rewriter.create<mlir::LLVM::AddOp>(op.getLoc(), ctz, one);
Expand All @@ -3196,9 +3161,8 @@ mlir::LogicalResult CIRToLLVMBitParityOpLowering::matchAndRewrite(
cir::BitParityOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto popcnt =
createLLVMBitOp(op.getLoc(), "llvm.ctpop", resTy, adaptor.getInput(),
/*poisonZeroInputFlag=*/std::nullopt, rewriter);
auto popcnt = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy,
adaptor.getInput());

auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
auto popcntMod2 =
Expand All @@ -3212,9 +3176,8 @@ mlir::LogicalResult CIRToLLVMBitPopcountOpLowering::matchAndRewrite(
cir::BitPopcountOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto llvmOp =
createLLVMBitOp(op.getLoc(), "llvm.ctpop", resTy, adaptor.getInput(),
/*poisonZeroInputFlag=*/std::nullopt, rewriter);
auto llvmOp = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy,
adaptor.getInput());
rewriter.replaceOp(op, llvmOp);
return mlir::LogicalResult::success();
}
Expand Down
6 changes: 0 additions & 6 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,6 @@ mlir::LLVM::CallIntrinsicOp replaceOpWithCallLLVMIntrinsicOp(
const llvm::Twine &intrinsicName, mlir::Type resultTy,
mlir::ValueRange operands);

mlir::Value createLLVMBitOp(mlir::Location loc,
const llvm::Twine &llvmIntrinBaseName,
mlir::Type resultTy, mlir::Value operand,
std::optional<bool> poisonZeroInputFlag,
mlir::ConversionPatternRewriter &rewriter);

class CIRToLLVMCopyOpLowering : public mlir::OpConversionPattern<cir::CopyOp> {
public:
using mlir::OpConversionPattern<cir::CopyOp>::OpConversionPattern;
Expand Down
Loading