Skip to content

Commit

Permalink
[CombToAIG] Add a pattern for mul (#8015)
Browse files Browse the repository at this point in the history
This commit adds a pattern to lower mul op. The pattern lowers mul op into chains of comb.add + comb.mux. There must be more efficient implementation but for now this naive pattern should work fine. LEC is verified. 

```
{a_{n}, a_{n-1}, ..., a_0} * b
= sum_{i=0}^{n} a_i * 2^i * b
= sum_{i=0}^{n} (a_i ? b : 0) << i
```
  • Loading branch information
uenoku authored Dec 27, 2024
1 parent 550a9a7 commit 5b128a1
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 3 deletions.
7 changes: 7 additions & 0 deletions integration_test/circt-synth/comb-lowering-lec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,10 @@ hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
%0 = comb.sub %lhs, %rhs : i4
hw.output %0 : i4
}

// RUN: circt-lec %t.mlir %s -c1=mul -c2=mul --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MUL
// COMB_MUL: c1 == c2
hw.module @mul(in %arg0: i3, in %arg1: i3, in %arg2: i3, out add: i3) {
%0 = comb.mul %arg0, %arg1, %arg2 : i3
hw.output %0 : i3
}
51 changes: 48 additions & 3 deletions lib/Conversion/CombToAIG/CombToAIG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,51 @@ struct CombSubOpConversion : OpConversionPattern<SubOp> {
}
};

struct CombMulOpConversion : OpConversionPattern<MulOp> {
using OpConversionPattern<MulOp>::OpConversionPattern;
using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
LogicalResult
matchAndRewrite(MulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (adaptor.getInputs().size() != 2)
return failure();

// FIXME: Currently it's lowered to a really naive implementation that
// chains add operations.

// a_{n}a_{n-1}...a_0 * b
// = sum_{i=0}^{n} a_i * 2^i * b
// = sum_{i=0}^{n} (a_i ? b : 0) << i
int64_t width = op.getType().getIntOrFloatBitWidth();
auto aBits = extractBits(rewriter, adaptor.getInputs()[0]);
SmallVector<Value> results;
auto rhs = op.getInputs()[1];
auto zero = rewriter.create<hw::ConstantOp>(op.getLoc(),
llvm::APInt::getZero(width));
for (int64_t i = 0; i < width; ++i) {
auto aBit = aBits[i];
auto andBit =
rewriter.createOrFold<comb::MuxOp>(op.getLoc(), aBit, rhs, zero);
auto upperBits = rewriter.createOrFold<comb::ExtractOp>(
op.getLoc(), andBit, 0, width - i);
if (i == 0) {
results.push_back(upperBits);
continue;
}

auto lowerBits =
rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(i));

auto shifted = rewriter.createOrFold<comb::ConcatOp>(
op.getLoc(), op.getType(), ValueRange{upperBits, lowerBits});
results.push_back(shifted);
}

rewriter.replaceOpWithNewOp<comb::AddOp>(op, results, true);
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -304,10 +349,10 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
CombMuxOpConversion,
// Arithmetic Ops
CombAddOpConversion, CombSubOpConversion,
CombAddOpConversion, CombSubOpConversion, CombMulOpConversion,
// Variadic ops that must be lowered to binary operations
CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>>(
patterns.getContext());
CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
CombLowerVariadicOp<MulOp>>(patterns.getContext());
}

void ConvertCombToAIGPass::runOnOperation() {
Expand Down
19 changes: 19 additions & 0 deletions test/Conversion/CombToAIG/comb-to-aig-arith.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,22 @@ hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
%0 = comb.sub %lhs, %rhs : i4
hw.output %0 : i4
}


// CHECK-LABEL: @mul
// ALLOW_ADD-LABEL: @mul
// ALLOW_ADD-NEXT: %[[EXT_0:.+]] = comb.extract %lhs from 0 : (i2) -> i1
// ALLOW_ADD-NEXT: %[[EXT_1:.+]] = comb.extract %lhs from 1 : (i2) -> i1
// ALLOW_ADD-NEXT: %c0_i2 = hw.constant 0 : i2
// ALLOW_ADD-NEXT: %[[MUX_0:.+]] = comb.mux %0, %rhs, %c0_i2 : i2
// ALLOW_ADD-NEXT: %[[MUX_1:.+]] = comb.mux %1, %rhs, %c0_i2 : i2
// ALLOW_ADD-NEXT: %[[EXT_MUX_1:.+]] = comb.extract %3 from 0 : (i2) -> i1
// ALLOW_ADD-NEXT: %false = hw.constant false
// ALLOW_ADD-NEXT: %[[SHIFT:.+]] = comb.concat %4, %false : i1, i1
// ALLOW_ADD-NEXT: %[[ADD:.+]] = comb.add bin %[[MUX_0]], %[[SHIFT]] : i2
// ALLOW_ADD-NEXT: hw.output %[[ADD]] : i2
// ALLOW_ADD-NEXT: }
hw.module @mul(in %lhs: i2, in %rhs: i2, out out: i2) {
%0 = comb.mul %lhs, %rhs : i2
hw.output %0 : i2
}

0 comments on commit 5b128a1

Please sign in to comment.