Skip to content

Commit

Permalink
[Flow] Add pass statistics to ConvertDispatchRegionsToWorkgroups. (#…
Browse files Browse the repository at this point in the history
…17900)

The pass statistics are used to track the number of dispatches created.
`iree-scheduling-dump-statistics-file` emits statistics for the whole
program, which includes all the dispatches formed for const-expr hoisted
ops. The pass statistics added here are aimed more at developers, for
tracking pass-level statistics.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
  • Loading branch information
MaheshRavishankar authored Jul 16, 2024
1 parent 6b87a9f commit 058432d
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 8 deletions.
1 change: 1 addition & 0 deletions build_tools/cmake/iree_llvm.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ macro(iree_llvm_set_bundled_cmake_options)
set(LLVM_ENABLE_TERMINFO OFF CACHE BOOL "Default disable")
set(LLVM_ENABLE_ZLIB OFF CACHE BOOL "Default disable")
set(LLVM_ENABLE_ZSTD OFF CACHE BOOL "Default disable")
set(LLVM_FORCE_ENABLE_STATS ON CACHE BOOL "Default enable")

# LLVM defaults to building all targets. We always enable targets that we need
# as we need them, so default to none. The user can override this as needed,
Expand Down
1 change: 1 addition & 0 deletions build_tools/llvm/llvm_config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ set(LLVM_ENABLE_TERMINFO OFF CACHE BOOL "")
set(LLVM_ENABLE_Z3_SOLVER OFF CACHE BOOL "")
set(LLVM_INCLUDE_DOCS OFF CACHE BOOL "")
set(LLVM_INCLUDE_GO_TESTS OFF CACHE BOOL "")
set(LLVM_FORCE_ENABLE_STATS ON CACHE BOOL "")

# Do not store debug information by default.
set(CMAKE_BUILD_TYPE Release CACHE STRING "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ void ConvertDispatchRegionsToWorkgroupsPass::runOnOperation() {
SmallVector<IREE::Flow::DispatchRegionOp> regionOps;
funcOp.walk([&](Flow::DispatchRegionOp op) { regionOps.push_back(op); });

numDispatches += regionOps.size();

// Clone additional producers and rewrite to DispatchWorkgroupsOp.
for (auto regionOp : regionOps) {
auto maybeWorkgroupOp =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ wrapInWorkgroupsOp(mlir::TensorDimTrackingRewriter &rewriter,
}

/// Rewrite top-level InsertSliceOps to FlowUpdateOps or wrap them in a
/// dispatch region.
static LogicalResult convertInsertSliceOps(
/// dispatch region. Returns the number of dispatches created for
/// non-contiguous insert slices.
static FailureOr<int> convertInsertSliceOps(
mlir::TensorDimTrackingRewriter &rewriter, mlir::FunctionOpInterface funcOp,
SmallVector<IREE::Flow::DispatchWorkgroupsOp> &workgroupsOps) {
// Find eligible InsertSliceOps.
Expand All @@ -94,6 +95,8 @@ static LogicalResult convertInsertSliceOps(
remainingInsertSliceOps.push_back(insertSliceOp);
}
}
int64_t numRemainingInsertSliceOps =
static_cast<int64_t>(remainingInsertSliceOps.size());

// Create a DispatchWorkgroupsOp for every remaining InsertSliceOp.
FailureOr<SmallVector<IREE::Flow::DispatchWorkgroupsOp>> newWorkgroupsOps =
Expand All @@ -102,12 +105,13 @@ static LogicalResult convertInsertSliceOps(
return failure();
workgroupsOps.append(newWorkgroupsOps->begin(), newWorkgroupsOps->end());

return success();
return numRemainingInsertSliceOps;
}

/// Rewrite top-level ExtractSliceOps to FlowSliceOps or wrap them in a
/// dispatch region.
static LogicalResult convertExtractSliceOps(
/// dispatch region. Returns the number of dispatches created for
/// non-contiguous extract slices.
static FailureOr<size_t> convertExtractSliceOps(
mlir::TensorDimTrackingRewriter &rewriter, mlir::FunctionOpInterface funcOp,
SmallVector<IREE::Flow::DispatchWorkgroupsOp> &workgroupsOps) {
// Find eligible ExtractSliceOps.
Expand All @@ -125,14 +129,17 @@ static LogicalResult convertExtractSliceOps(
}
}

int64_t numRemainingExtractSliceOps =
static_cast<int64_t>(remainingExtractSliceOps.size());

// Create a DispatchWorkgroupsOp for every remaining ExtractSliceOp.
FailureOr<SmallVector<IREE::Flow::DispatchWorkgroupsOp>> newWorkgroupsOps =
wrapInWorkgroupsOp(rewriter, remainingExtractSliceOps);
if (failed(newWorkgroupsOps))
return failure();
workgroupsOps.append(newWorkgroupsOps->begin(), newWorkgroupsOps->end());

return success();
return numRemainingExtractSliceOps;
}

namespace {
Expand All @@ -156,18 +163,24 @@ void ConvertTensorToFlowPass::runOnOperation() {
});

// Rewrite InsertSliceOps to FlowUpdateOps.
if (failed(convertInsertSliceOps(rewriter, funcOp, workgroupsOps))) {
FailureOr<size_t> numSlowInsertSliceDispatches =
convertInsertSliceOps(rewriter, funcOp, workgroupsOps);
if (failed(numSlowInsertSliceDispatches)) {
funcOp->emitOpError(
"failed to create dispatch region for `tensor.insert_slice`");
return signalPassFailure();
}
numSlowCopyDispatches += numSlowInsertSliceDispatches.value();

// Rewrite ExtractSliceOps to FlowSliceOps.
if (failed(convertExtractSliceOps(rewriter, funcOp, workgroupsOps))) {
FailureOr<size_t> numSlowExtractSliceDispatches =
convertExtractSliceOps(rewriter, funcOp, workgroupsOps);
if (failed(numSlowExtractSliceDispatches)) {
funcOp->emitOpError(
"failed to create dispatch region for `tensor.extract_slice`");
return signalPassFailure();
}
numSlowCopyDispatches += numSlowExtractSliceDispatches.value();

// Canonicalize to flow.tensor ops.
RewritePatternSet convertToFlowPatterns(context);
Expand Down
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def ConvertDispatchRegionsToWorkgroupsPass :
"mlir::tensor::TensorDialect",
"IREE::Flow::FlowDialect",
];
let statistics = [
Statistic<"numDispatches", "num-dispatches", "Number of dispatches created">
];
}

def ConvertTensorToFlowPass :
Expand All @@ -182,6 +185,10 @@ def ConvertTensorToFlowPass :
"mlir::tensor::TensorDialect",
"IREE::Flow::FlowDialect",
];
let statistics = [
Statistic<"numSlowCopyDispatches", "num-slow-copy-dispatches",
"Number of slow copy dispatches (for handling slices) created">
];
}

def DispatchWithTransformDialectPass : Pass<"iree-flow-dispatch-with-transform-dialect"> {
Expand Down

0 comments on commit 058432d

Please sign in to comment.