Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIRSpecializeChannelWrapAndStride: More flexible wrap-and-stride offset canonicalization #791

Merged
merged 7 commits into from
Nov 25, 2024
55 changes: 42 additions & 13 deletions mlir/lib/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,24 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder,
builder.getUnknownLoc(), (*const_size) * (*const_size_next));
return true;
};
// For a given offset[i], find the first offset[j] such that stride[j] is
// divisible by stride[i], so that offset[i] can be composed onto offset[j].
auto findFirstComposableOffsetIdx = [](int i, SmallVector<Value> offsets,
SmallVector<Value> strides) {
auto constStrideI = getConstantIntValue(strides[i]);
std::optional<int> output = std::nullopt;
for (int j = i + 1; j < (int)strides.size(); j++) {
if (!getConstantIntValue(offsets[j]))
continue; // Currently unable to compose offset[i] expr onto another
// offset[j] expr.
auto constStrideJ = getConstantIntValue(strides[j]);
if ((*constStrideI) % (*constStrideJ) == 0) {
output = j;
return output;
}
}
return output;
};
for (auto i : erase_dims) {
auto const_offset = getConstantIntValue(offsets[i]);
if (const_offset && *const_offset == 0) {
Expand All @@ -855,13 +873,14 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder,
continue;
auto const_stride = getConstantIntValue(strides[i]);
assert(const_stride && "non-static stride, NYI.");
auto const_offset_next = getConstantIntValue(offsets[i + 1]);
if (!const_offset_next)
auto j = findFirstComposableOffsetIdx(i, offsets, strides);
if (!j)
continue;
auto const_stride_next = getConstantIntValue(strides[i + 1]);
assert(const_stride_next && "non-static stride, NYI.");
auto const_offset_next = getConstantIntValue(offsets[*j]);
auto const_stride_next = getConstantIntValue(strides[*j]);
// Attempting to compose i-th offset onto another offset.
if (const_offset) {
offsets[i + 1] = builder.create<arith::ConstantIndexOp>(
offsets[*j] = builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(),
(*const_stride) * (*const_offset) / (*const_stride_next) +
(*const_offset_next));
Expand Down Expand Up @@ -912,7 +931,7 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder,
auto next_offset_map = AffineMap::get(0, 1, offset_expr);
affine_apply.setMap(next_offset_map);
offsets[i] = affine_apply;
offsets[i + 1] = offsets[i];
offsets[*j] = offsets[i];
}
erased |= multiplyAdjWraps(builder, i, sizes);
offsets.erase(offsets.begin() + i);
Expand Down Expand Up @@ -1029,6 +1048,12 @@ LogicalResult air::foldForLoopNestAsExtendedSizesAndStrides(
}

std::map<Operation *, int> op_to_count;
// Evaluate offset from affine map.
auto evalOffsetFromAffineMap = [&](MLIRContext *ctx, AffineMap map) {
return air::evaluateConstantsInMap(
map, SmallVector<std::optional<int64_t>>{std::optional<int64_t>{0}},
ctx);
};
for (auto o : for_loops) {
int64_t stepSize = -1;
int loop_lower_bound = 0;
Expand Down Expand Up @@ -1067,14 +1092,18 @@ LogicalResult air::foldForLoopNestAsExtendedSizesAndStrides(
if (iv_is_symbol) {
auto map = affop.getAffineMap();
ind_var_factor = *getConstantIntValue(strides[i]);
ind_var_factor *= air::evaluateConstantsInMap(
map,
SmallVector<std::optional<int64_t>>{
std::optional<int64_t>{stepSize}},
for_op->getContext())
.value();
int64_t map_offset =
evalOffsetFromAffineMap(for_op->getContext(), map).value();
int64_t map_gradient = air::evaluateConstantsInMap(
map,
SmallVector<std::optional<int64_t>>{
std::optional<int64_t>{stepSize}},
for_op->getContext())
.value() -
map_offset;
ind_var_factor *= map_gradient;
offsets[i] = builder.template create<arith::ConstantIndexOp>(
loc, loop_lower_bound);
loc, loop_lower_bound + map_offset);
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ module {
// Offset propagation with wrap-and-stride canonicalization.
// CHECK-LABEL: test9
// CHECK: %[[VAL0:.*]] = affine.apply #map()[%arg1]
// CHECK: put @channel_21[] (%arg0[%c0, %c0, %[[VAL0]], %c0] [%c8, %c2, %c32, %c32] [%c32, %c8192, %c256, %c1]) : (memref<128x256xi32>)
// CHECK: put @channel_21[] (%arg0[%c0, %c0, %[[VAL0]]] [%c8, %c64, %c32] [%c32, %c256, %c1]) : (memref<128x256xi32>)
// CHECK: air.channel.put @channel_22[] (%arg2[%c256, %c0, %c0] [%c8, %c32, %c4] [%c4, %c32, %c1]) : (memref<1x2x32x32xi32, 1 : i32>)
// CHECK: air.channel.put @channel_23[] (%arg3[%c128, %c0, %c0] [%c4, %c32, %c8] [%c8, %c32, %c1]) : (memref<2x1x32x32xi32, 1 : i32>)
// CHECK: %[[VAL1:.*]] = affine.apply
Expand Down Expand Up @@ -386,25 +386,32 @@ module {
// Affine.apply with map joining two for loops in a loop nest.
// CHECK-LABEL: test11

// CHECK: air.channel.put async [%{{.*}}] @channel_26[%c0, %c0] (%{{.*}}[%c0, %c0, %c0] [%c4_0, %c18, %c4_0] [%c96, %c16, %c1]) : (memref<1x6x6x16xbf16, 1>)
// CHECK: air.channel.put async {{.*}}@channel_26[%c0{{.*}}, %c0{{.*}}] (%{{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}] [%c4{{.*}}, %c18{{.*}}, %c4{{.*}}] [%c96{{.*}}, %c16{{.*}}, %c1{{.*}}]) : (memref<1x6x6x16xbf16, 1>)
// CHECK: air.channel.put async {{.*}}@channel_26[%c0{{.*}}, %c0{{.*}}] (%{{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}, %c12{{.*}}] [%c3{{.*}}, %c3{{.*}}, %c4{{.*}}, %c4{{.*}}] [%c96{{.*}}, %c16{{.*}}, %c16{{.*}}, %c1{{.*}}]) : (memref<1x3x6x16xi32, 1>)

func.func @test11() {
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%0 = air.launch async (%arg3, %arg4, %arg5) in (%arg6=%c3, %arg7=%c3, %arg8=%c4) {
%1 = air.segment @segment_0 async {
%c576 = arith.constant 576 : index
%c288 = arith.constant 288 : index
%c96 = arith.constant 96 : index
%c3_0 = arith.constant 3 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c12 = arith.constant 12 : index
%c6 = arith.constant 6 : index
%c0 = arith.constant 0 : index
%c4_1 = arith.constant 4 : index
%async_token, %results = air.execute -> (memref<1x6x6x16xbf16, 1>) {
%alloc = memref.alloc() : memref<1x6x6x16xbf16, 1>
air.execute_terminator %alloc : memref<1x6x6x16xbf16, 1>
}
%async_token_23, %results_25 = air.execute -> (memref<1x3x6x16xi32, 1>) {
%alloc = memref.alloc() : memref<1x3x6x16xi32, 1>
air.execute_terminator %alloc : memref<1x3x6x16xi32, 1>
}
%4 = scf.for %arg9 = %c0 to %c4_1 step %c1 iter_args(%arg13 = %async_token) -> (!air.async.token) {
%2 = scf.for %arg10 = %c0 to %c3_0 step %c1 iter_args(%arg11 = %arg13) -> (!air.async.token) {
%async_token_2, %results_3 = air.execute [%arg11] -> (index) {
Expand All @@ -416,6 +423,15 @@ module {
}
scf.yield %2 : !air.async.token
}
scf.for %arg9 = %c0 to %c3_0 step %c1 {
%60 = scf.for %arg10 = %c0 to %c3_0 step %c1 iter_args(%arg13 = %async_token) -> (!air.async.token) {
%async_token_54, %results_55 = air.execute [%arg13] -> (index) {
air.execute_terminator %arg9 : index
}
%61 = air.channel.put async [%async_token_54] @channel_26[%c0, %c0] (%results_25[%c0, %results_55, %arg10, %c12] [%c1, %c1, %c4_1, %c4_1] [%c288, %c96, %c16, %c1]) : (memref<1x3x6x16xi32, 1>)
scf.yield %61 : !air.async.token
}
}
}
}
return
Expand Down Expand Up @@ -460,10 +476,11 @@ module {
// CHECK-LABEL: test13

// CHECK: air.channel.put async [%{{.*}}] @channel_14[] (%{{.*}}[%c0, %1, %results, %c0] [%c8, %c2_0, %c32, %c32] [%c32, %c8192, %c256, %c1]) : (memref<2x128x256xi32>)
// CHECK: air.channel.put async [%{{.*}}] @channel_15[%c0, %c0] (%{{.*}}[%c0, %results, %c32768] [%c8, %c32, %c32] [%c32, %c256, %c1]) : (memref<512x512xi32>)

func.func @test13(%arg0: memref<2x128x256xi32>, %arg1: memref<2x256x128xi32>) {
func.func @test13(%arg0: memref<2x128x256xi32>, %arg1: memref<512x512xi32>) {
%c2 = arith.constant 2 : index
%0 = air.launch async (%arg3, %arg4, %arg5) in (%arg6=%c2, %arg7=%c2, %arg8=%c2) args(%arg10=%arg0, %arg11=%arg1) : memref<2x128x256xi32>, memref<2x256x128xi32> {
%0 = air.launch async (%arg3, %arg4, %arg5) in (%arg6=%c2, %arg7=%c2, %arg8=%c2) args(%arg10=%arg0, %arg11=%arg1) : memref<2x128x256xi32>, memref<512x512xi32> {
%c4096 = arith.constant 4096 : index
%c8 = arith.constant 8 : index
%c16384 = arith.constant 16384 : index
Expand All @@ -484,6 +501,10 @@ module {
%7 = air.channel.put async [%arg13, %async_token] @channel_14[] (%arg10[%arg3, %c0, %c0, %results, %arg12] [%c1, %c2_0, %c1, %c32, %c32] [%c32768, %c8192, %c32, %c256, %c1]) : (memref<2x128x256xi32>)
scf.yield %7 : !air.async.token
}
%3 = scf.for %arg12 = %c0 to %c256 step %c32 iter_args(%arg13 = %async_token) -> (!air.async.token) {
%7 = air.channel.put async [%arg13, %async_token] @channel_15[%c0, %c0] (%arg11[%c2_0, %c0, %results, %arg12] [%c1, %c1, %c32, %c32] [%c16384, %c32, %c256, %c1]) {id = 1 : i32} : (memref<512x512xi32>)
scf.yield %7 : !air.async.token
}
}
return
}
Expand Down
Loading
Loading