Skip to content

Commit

Permalink
[Comb] delete slow canonicalizer (#8014)
Browse files Browse the repository at this point in the history
This canonicalizer is functionally correct, but it causes severe performance
problems.  This change removes it until it can be fixed
properly.
  • Loading branch information
youngar authored Dec 23, 2024
1 parent 40c2014 commit 6f7cba6
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 192 deletions.
111 changes: 0 additions & 111 deletions lib/Dialect/Comb/CombFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1178,107 +1178,6 @@ OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
return constFoldAssociativeOp(inputs, hw::PEO::Or);
}

/// Merge two concat operands of an `or` into one concat of narrower `or`s,
/// provided that at least one of the two concats has a constant operand.
///
/// The rewrite turns or(concat, concat) inside-out into concat(or, or, ...).
/// The narrower or ops are created with createOrFold, so they frequently fold
/// away on the spot.
///
/// For example:
///
///   or(..., concat(x, 0), concat(0, y))
///     ==> or(..., concat(x, 0, y)), when x and y don't overlap.
///
///   or(..., concat(x: i2, cst1: i4), concat(cst2: i5, y: i1))
///     ==> or(..., concat(or(x: i2,               extract(cst2, 4..3)),
///                        or(extract(cst1, 3..1), extract(cst2, 2..0)),
///                        or(extract(cst1, 0..0), y: i1))
///
/// `concatIdx1` and `concatIdx2` are positions in the inputs of `op`; both must
/// refer to values defined by a ConcatOp and `concatIdx1` must be the smaller
/// one (both conditions are asserted).
///
/// Returns true when `op` has been replaced. Returns false, without creating
/// or modifying any IR, when neither concat has a constant operand.
static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1,
                                                   size_t concatIdx2,
                                                   PatternRewriter &rewriter) {
  assert(concatIdx1 < concatIdx2 && "concatIdx1 must be < concatIdx2");

  auto orInputs = op.getInputs();
  auto lhsConcat = orInputs[concatIdx1].getDefiningOp<ConcatOp>();
  auto rhsConcat = orInputs[concatIdx2].getDefiningOp<ConcatOp>();

  assert(lhsConcat && rhsConcat && "expected indexes to point to ConcatOps");

  // The rewrite applies only if some operand of either concat is a constant.
  // The second concat is inspected only when the first one has none.
  auto hasConstantOperand = [](ConcatOp concat) {
    return llvm::any_of(concat->getOperands(), [](Value operand) -> bool {
      return operand.getDefiningOp<hw::ConstantOp>();
    });
  };
  if (!hasConstantOperand(lhsConcat) && !hasConstantOperand(rhsConcat))
    return false;

  auto loc = op.getLoc();
  auto lhsOperands = lhsConcat->getOperands();
  auto rhsOperands = rhsConcat->getOperands();
  auto lhsIt = lhsOperands.begin();
  auto lhsEnd = lhsOperands.end();
  auto rhsIt = rhsOperands.begin();
  auto rhsEnd = rhsOperands.end();

  // Bits of the current lhs/rhs operand that earlier iterations already
  // covered, counted from that operand's MSB.
  unsigned lhsUsed = 0;
  unsigned rhsUsed = 0;

  // Walk both operand lists in lock-step from MSB to LSB. Each step takes the
  // widest slice that stays inside the current operand on both sides and emits
  // the or of those two slices.
  SmallVector<Value> slices;
  while (lhsIt != lhsEnd && rhsIt != rhsEnd) {
    auto lhs = *lhsIt;
    auto rhs = *rhsIt;

    unsigned lhsLeft = hw::getBitWidth(lhs.getType()) - lhsUsed;
    unsigned rhsLeft = hw::getBitWidth(rhs.getType()) - rhsUsed;
    unsigned sliceWidth = std::min(lhsLeft, rhsLeft);
    auto sliceType = rewriter.getIntegerType(sliceWidth);

    // The slice is the top `sliceWidth` bits of what is left of each operand,
    // so its low bit sits at (bits left - slice width).
    auto lhsSlice = rewriter.createOrFold<ExtractOp>(loc, sliceType, lhs,
                                                     lhsLeft - sliceWidth);
    auto rhsSlice = rewriter.createOrFold<ExtractOp>(loc, sliceType, rhs,
                                                     rhsLeft - sliceWidth);
    slices.push_back(
        rewriter.createOrFold<OrOp>(loc, lhsSlice, rhsSlice, false));

    // Move on to the next operand on whichever side was fully covered;
    // otherwise just record how far into the operand we are.
    if (sliceWidth == lhsLeft) {
      ++lhsIt;
      lhsUsed = 0;
    } else {
      lhsUsed += sliceWidth;
    }
    if (sliceWidth == rhsLeft) {
      ++rhsIt;
      rhsUsed = 0;
    } else {
      rhsUsed += sliceWidth;
    }
  }

  ConcatOp mergedConcat = rewriter.create<ConcatOp>(loc, slices);

  // The replacement or keeps every original input except the two concats, in
  // their original order, followed by the merged concat.
  SmallVector<Value> newOrOperands;
  for (size_t idx = 0, numInputs = orInputs.size(); idx != numInputs; ++idx)
    if (idx != concatIdx1 && idx != concatIdx2)
      newOrOperands.push_back(orInputs[idx]);
  newOrOperands.push_back(mergedConcat);

  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
                                      newOrOperands);
  return true;
}

LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
auto inputs = op.getInputs();
auto size = inputs.size();
Expand Down Expand Up @@ -1328,16 +1227,6 @@ LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
}
}

// or(..., concat(x, cst1), concat(cst2, y))
// ==> or(..., concat(x, cst3, y)), when x and y don't overlap.
for (size_t i = 0; i < size - 1; ++i) {
if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
for (size_t j = i + 1; j < size; ++j)
if (auto concat = inputs[j].getDefiningOp<ConcatOp>())
if (canonicalizeOrOfConcatsWithCstOperands(op, i, j, rewriter))
return success();
}

// extracts only of or(...) -> or(extract()...)
if (narrowOperationWidth(op, true, rewriter))
return success();
Expand Down
81 changes: 0 additions & 81 deletions test/Dialect/Comb/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -181,87 +181,6 @@ hw.module @dedupLong(in %arg0 : i7, in %arg1 : i7, in %arg2: i7, out resAnd: i7,
hw.output %0, %1 : i7, i7
}

// Two concats are or'd together; each has a zero constant wherever the other
// has a variable field, so the expected result is a single concat.
// Bit patterns are MSB first; a = %arg0 bits, b = %arg1 bits.
// %0:       000aaaaaa
// %1:       bb0000000
// expected: bb0aaaaaa
// CHECK-LABEL: hw.module @orExclusiveConcats
hw.module @orExclusiveConcats(in %arg0 : i6, in %arg1 : i2, out o: i9) {
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %0 = comb.concat %arg1, %false, %arg0 : i2, i1, i6
// CHECK-NEXT: hw.output %0 : i9
%c0 = hw.constant 0 : i3
%0 = comb.concat %c0, %arg0 : i3, i6
%c1 = hw.constant 0 : i7
%1 = comb.concat %arg1, %c1 : i2, i7
%2 = comb.or %0, %1 : i9
hw.output %2 : i9
}

// When two concats are or'd together and have mutually-exclusive fields, they
// can be merged together into a single concat.
// Bit patterns are MSB first; a/b/c/d mark bits of %arg0/%arg1/%arg2/%arg3.
// concat0: 0aaa aaa0 0000 0bb0
// concat1: 0000 0000 ccdd d000
// merged: 0aaa aaa0 ccdd dbb0
// CHECK-LABEL: hw.module @orExclusiveConcats2
hw.module @orExclusiveConcats2(in %arg0 : i6, in %arg1 : i2, in %arg2: i2, in %arg3: i3, out o: i16) {
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %0 = comb.concat %false, %arg0, %false, %arg2, %arg3, %arg1, %false : i1, i6, i1, i2, i3, i2, i1
// CHECK-NEXT: hw.output %0 : i16
%c0 = hw.constant 0 : i1
%c1 = hw.constant 0 : i6
%c2 = hw.constant 0 : i1
%0 = comb.concat %c0, %arg0, %c1, %arg1, %c2: i1, i6, i6, i2, i1
%c3 = hw.constant 0 : i8
%c4 = hw.constant 0 : i3
%1 = comb.concat %c3, %arg2, %arg3, %c4 : i8, i2, i3, i3
%2 = comb.or %0, %1 : i16
hw.output %2 : i16
}

// When two concats are or'd together and every bit position holds a constant
// one in at least one of them, the variable fields are absorbed and the whole
// or is expected to fold to an all-ones constant.
// Bit patterns are MSB first; a = %arg0 bits, b = %arg1 bits.
// concat0: aaaa 1111
// concat1: 1111 10bb
// result: 1111 1111
// CHECK-LABEL: hw.module @orExclusiveConcats3
hw.module @orExclusiveConcats3(in %arg0 : i4, in %arg1 : i2, out o: i8) {
// CHECK-NEXT: [[RES:%[a-z0-9_-]+]] = hw.constant -1 : i8
// CHECK-NEXT: hw.output [[RES]] : i8
%c0 = hw.constant -1 : i4
%0 = comb.concat %arg0, %c0: i4, i4
%c1 = hw.constant -1 : i5
%c2 = hw.constant 0 : i1
%1 = comb.concat %c1, %c2, %arg1 : i5, i1, i2
%2 = comb.or %0, %1 : i8
hw.output %2 : i8
}

// A single or with three concat operands, each carrying one 2-bit variable
// field and zero constants everywhere else. The fields do not overlap, so the
// expected result is one concat of the three arguments.
// Bit patterns are MSB first; a/b/c mark bits of %arg0/%arg1/%arg2.
// %0:       aa 00 00
// %1:       00 bb 00
// %2:       00 00 cc
// expected: aa bb cc
// CHECK-LABEL: hw.module @orMultipleExclusiveConcats
hw.module @orMultipleExclusiveConcats(in %arg0 : i2, in %arg1 : i2, in %arg2: i2, out o: i6) {
// CHECK-NEXT: %0 = comb.concat %arg0, %arg1, %arg2 : i2, i2, i2
// CHECK-NEXT: hw.output %0 : i6
%c2 = hw.constant 0 : i2
%c4 = hw.constant 0 : i4
%0 = comb.concat %arg0, %c4: i2, i4
%1 = comb.concat %c2, %arg1, %c2: i2, i2, i2
%2 = comb.concat %c4, %arg2: i4, i2
%out = comb.or %0, %1, %2 : i6
hw.output %out : i6
}

// One of the or'd concats ends in a 2-bit mux that selects between the
// constants 2 and 0, while the other concat has %bit in its lowest position.
// The expected output contains no mux and no or: %cond appears directly in
// bit 1 and %bit in bit 0, below four constant zero bits.
// CHECK-LABEL: hw.module @orConcatsWithMux
hw.module @orConcatsWithMux(in %bit: i1, in %cond: i1, out o: i6) {
// CHECK-NEXT: [[RES:%[a-z0-9_-]+]] = hw.constant 0 : i4
// CHECK-NEXT: %0 = comb.concat [[RES]], %cond, %bit : i4, i1, i1
// CHECK-NEXT: hw.output %0 : i6
%c0 = hw.constant 0 : i5
%0 = comb.concat %c0, %bit: i5, i1
%c1 = hw.constant 0 : i4
%c2 = hw.constant 2 : i2
%c3 = hw.constant 0 : i2
%1 = comb.mux %cond, %c2, %c3 : i2
%2 = comb.concat %c1, %1 : i4, i2
%3 = comb.or %0, %2 : i6
hw.output %3 : i6
}

// CHECK-LABEL: @extractNested
hw.module @extractNested(in %0: i5, out o1 : i1) {
// Multiple layers of nested extract are weak evidence that the canonicalization
Expand Down

0 comments on commit 6f7cba6

Please sign in to comment.