Skip to content

Commit

Permalink
AIRSpecializeChannelWrapAndStride: More flexible wrap-and-stride offset canonicalization (Xilinx#791)
Browse files Browse the repository at this point in the history

* Enable more flexible offset canonicalization, rather than only considering folding to the next dimension

* Add a new board test for pack-peel gemm in i32, with 4x4 herd

* Fixup comparison types

* Remove debug prints

* Fixup: a modulo result of 0 cannot be implicitly converted to bool false

* Wrap-and-stride for loop folding taking into account complex affine maps with both gradients and offset
  • Loading branch information
erwei-xilinx authored Nov 25, 2024
1 parent 44f0c0b commit f3884b6
Show file tree
Hide file tree
Showing 5 changed files with 583 additions and 17 deletions.
55 changes: 42 additions & 13 deletions mlir/lib/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,24 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder,
builder.getUnknownLoc(), (*const_size) * (*const_size_next));
return true;
};
// For a given offset[i], find the first offset[j] such that stride[j] is
// divisible by stride[i], so that offset[i] can be composed onto offset[j].
auto findFirstComposableOffsetIdx = [](int i, SmallVector<Value> offsets,
SmallVector<Value> strides) {
auto constStrideI = getConstantIntValue(strides[i]);
std::optional<int> output = std::nullopt;
for (int j = i + 1; j < (int)strides.size(); j++) {
if (!getConstantIntValue(offsets[j]))
continue; // Currently unable to compose offset[i] expr onto another
// offset[j] expr.
auto constStrideJ = getConstantIntValue(strides[j]);
if ((*constStrideI) % (*constStrideJ) == 0) {
output = j;
return output;
}
}
return output;
};
for (auto i : erase_dims) {
auto const_offset = getConstantIntValue(offsets[i]);
if (const_offset && *const_offset == 0) {
Expand All @@ -855,13 +873,14 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder,
continue;
auto const_stride = getConstantIntValue(strides[i]);
assert(const_stride && "non-static stride, NYI.");
auto const_offset_next = getConstantIntValue(offsets[i + 1]);
if (!const_offset_next)
auto j = findFirstComposableOffsetIdx(i, offsets, strides);
if (!j)
continue;
auto const_stride_next = getConstantIntValue(strides[i + 1]);
assert(const_stride_next && "non-static stride, NYI.");
auto const_offset_next = getConstantIntValue(offsets[*j]);
auto const_stride_next = getConstantIntValue(strides[*j]);
// Attempting to compose i-th offset onto another offset.
if (const_offset) {
offsets[i + 1] = builder.create<arith::ConstantIndexOp>(
offsets[*j] = builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(),
(*const_stride) * (*const_offset) / (*const_stride_next) +
(*const_offset_next));
Expand Down Expand Up @@ -912,7 +931,7 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder,
auto next_offset_map = AffineMap::get(0, 1, offset_expr);
affine_apply.setMap(next_offset_map);
offsets[i] = affine_apply;
offsets[i + 1] = offsets[i];
offsets[*j] = offsets[i];
}
erased |= multiplyAdjWraps(builder, i, sizes);
offsets.erase(offsets.begin() + i);
Expand Down Expand Up @@ -1029,6 +1048,12 @@ LogicalResult air::foldForLoopNestAsExtendedSizesAndStrides(
}

std::map<Operation *, int> op_to_count;
// Evaluate offset from affine map.
auto evalOffsetFromAffineMap = [&](MLIRContext *ctx, AffineMap map) {
return air::evaluateConstantsInMap(
map, SmallVector<std::optional<int64_t>>{std::optional<int64_t>{0}},
ctx);
};
for (auto o : for_loops) {
int64_t stepSize = -1;
int loop_lower_bound = 0;
Expand Down Expand Up @@ -1067,14 +1092,18 @@ LogicalResult air::foldForLoopNestAsExtendedSizesAndStrides(
if (iv_is_symbol) {
auto map = affop.getAffineMap();
ind_var_factor = *getConstantIntValue(strides[i]);
ind_var_factor *= air::evaluateConstantsInMap(
map,
SmallVector<std::optional<int64_t>>{
std::optional<int64_t>{stepSize}},
for_op->getContext())
.value();
int64_t map_offset =
evalOffsetFromAffineMap(for_op->getContext(), map).value();
int64_t map_gradient = air::evaluateConstantsInMap(
map,
SmallVector<std::optional<int64_t>>{
std::optional<int64_t>{stepSize}},
for_op->getContext())
.value() -
map_offset;
ind_var_factor *= map_gradient;
offsets[i] = builder.template create<arith::ConstantIndexOp>(
loc, loop_lower_bound);
loc, loop_lower_bound + map_offset);
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ module {
// Offset propagation with wrap-and-stride canonicalization.
// CHECK-LABEL: test9
// CHECK: %[[VAL0:.*]] = affine.apply #map()[%arg1]
// CHECK: put @channel_21[] (%arg0[%c0, %c0, %[[VAL0]], %c0] [%c8, %c2, %c32, %c32] [%c32, %c8192, %c256, %c1]) : (memref<128x256xi32>)
// CHECK: put @channel_21[] (%arg0[%c0, %c0, %[[VAL0]]] [%c8, %c64, %c32] [%c32, %c256, %c1]) : (memref<128x256xi32>)
// CHECK: air.channel.put @channel_22[] (%arg2[%c256, %c0, %c0] [%c8, %c32, %c4] [%c4, %c32, %c1]) : (memref<1x2x32x32xi32, 1 : i32>)
// CHECK: air.channel.put @channel_23[] (%arg3[%c128, %c0, %c0] [%c4, %c32, %c8] [%c8, %c32, %c1]) : (memref<2x1x32x32xi32, 1 : i32>)
// CHECK: %[[VAL1:.*]] = affine.apply
Expand Down Expand Up @@ -386,25 +386,32 @@ module {
// Affine.apply with map joining two for loops in a loop nest.
// CHECK-LABEL: test11

// CHECK: air.channel.put async [%{{.*}}] @channel_26[%c0, %c0] (%{{.*}}[%c0, %c0, %c0] [%c4_0, %c18, %c4_0] [%c96, %c16, %c1]) : (memref<1x6x6x16xbf16, 1>)
// CHECK: air.channel.put async {{.*}}@channel_26[%c0{{.*}}, %c0{{.*}}] (%{{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}] [%c4{{.*}}, %c18{{.*}}, %c4{{.*}}] [%c96{{.*}}, %c16{{.*}}, %c1{{.*}}]) : (memref<1x6x6x16xbf16, 1>)
// CHECK: air.channel.put async {{.*}}@channel_26[%c0{{.*}}, %c0{{.*}}] (%{{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}, %c12{{.*}}] [%c3{{.*}}, %c3{{.*}}, %c4{{.*}}, %c4{{.*}}] [%c96{{.*}}, %c16{{.*}}, %c16{{.*}}, %c1{{.*}}]) : (memref<1x3x6x16xi32, 1>)

func.func @test11() {
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%0 = air.launch async (%arg3, %arg4, %arg5) in (%arg6=%c3, %arg7=%c3, %arg8=%c4) {
%1 = air.segment @segment_0 async {
%c576 = arith.constant 576 : index
%c288 = arith.constant 288 : index
%c96 = arith.constant 96 : index
%c3_0 = arith.constant 3 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c12 = arith.constant 12 : index
%c6 = arith.constant 6 : index
%c0 = arith.constant 0 : index
%c4_1 = arith.constant 4 : index
%async_token, %results = air.execute -> (memref<1x6x6x16xbf16, 1>) {
%alloc = memref.alloc() : memref<1x6x6x16xbf16, 1>
air.execute_terminator %alloc : memref<1x6x6x16xbf16, 1>
}
%async_token_23, %results_25 = air.execute -> (memref<1x3x6x16xi32, 1>) {
%alloc = memref.alloc() : memref<1x3x6x16xi32, 1>
air.execute_terminator %alloc : memref<1x3x6x16xi32, 1>
}
%4 = scf.for %arg9 = %c0 to %c4_1 step %c1 iter_args(%arg13 = %async_token) -> (!air.async.token) {
%2 = scf.for %arg10 = %c0 to %c3_0 step %c1 iter_args(%arg11 = %arg13) -> (!air.async.token) {
%async_token_2, %results_3 = air.execute [%arg11] -> (index) {
Expand All @@ -416,6 +423,15 @@ module {
}
scf.yield %2 : !air.async.token
}
scf.for %arg9 = %c0 to %c3_0 step %c1 {
%60 = scf.for %arg10 = %c0 to %c3_0 step %c1 iter_args(%arg13 = %async_token) -> (!air.async.token) {
%async_token_54, %results_55 = air.execute [%arg13] -> (index) {
air.execute_terminator %arg9 : index
}
%61 = air.channel.put async [%async_token_54] @channel_26[%c0, %c0] (%results_25[%c0, %results_55, %arg10, %c12] [%c1, %c1, %c4_1, %c4_1] [%c288, %c96, %c16, %c1]) : (memref<1x3x6x16xi32, 1>)
scf.yield %61 : !air.async.token
}
}
}
}
return
Expand Down Expand Up @@ -460,10 +476,11 @@ module {
// CHECK-LABEL: test13

// CHECK: air.channel.put async [%{{.*}}] @channel_14[] (%{{.*}}[%c0, %1, %results, %c0] [%c8, %c2_0, %c32, %c32] [%c32, %c8192, %c256, %c1]) : (memref<2x128x256xi32>)
// CHECK: air.channel.put async [%{{.*}}] @channel_15[%c0, %c0] (%{{.*}}[%c0, %results, %c32768] [%c8, %c32, %c32] [%c32, %c256, %c1]) : (memref<512x512xi32>)

func.func @test13(%arg0: memref<2x128x256xi32>, %arg1: memref<2x256x128xi32>) {
func.func @test13(%arg0: memref<2x128x256xi32>, %arg1: memref<512x512xi32>) {
%c2 = arith.constant 2 : index
%0 = air.launch async (%arg3, %arg4, %arg5) in (%arg6=%c2, %arg7=%c2, %arg8=%c2) args(%arg10=%arg0, %arg11=%arg1) : memref<2x128x256xi32>, memref<2x256x128xi32> {
%0 = air.launch async (%arg3, %arg4, %arg5) in (%arg6=%c2, %arg7=%c2, %arg8=%c2) args(%arg10=%arg0, %arg11=%arg1) : memref<2x128x256xi32>, memref<512x512xi32> {
%c4096 = arith.constant 4096 : index
%c8 = arith.constant 8 : index
%c16384 = arith.constant 16384 : index
Expand All @@ -484,6 +501,10 @@ module {
%7 = air.channel.put async [%arg13, %async_token] @channel_14[] (%arg10[%arg3, %c0, %c0, %results, %arg12] [%c1, %c2_0, %c1, %c32, %c32] [%c32768, %c8192, %c32, %c256, %c1]) : (memref<2x128x256xi32>)
scf.yield %7 : !air.async.token
}
%3 = scf.for %arg12 = %c0 to %c256 step %c32 iter_args(%arg13 = %async_token) -> (!air.async.token) {
%7 = air.channel.put async [%arg13, %async_token] @channel_15[%c0, %c0] (%arg11[%c2_0, %c0, %results, %arg12] [%c1, %c1, %c32, %c32] [%c16384, %c32, %c256, %c1]) {id = 1 : i32} : (memref<512x512xi32>)
scf.yield %7 : !air.async.token
}
}
return
}
Expand Down
Loading

0 comments on commit f3884b6

Please sign in to comment.