Skip to content

Commit

Permalink
Add LinalgExt canonicalization patterns to tile+distribute pass. (#11143
Browse files Browse the repository at this point in the history
)

The cast operations can be created during tiling. We have to fold it
into the tiled op if possible. The pattern for LinAlg dialect is
included, but the pattern for LinalgExt dialect is not included. This
commit adds the pattern to the pass.

Fixes #11038
  • Loading branch information
hanhanW authored Nov 11, 2022
1 parent a34e553 commit 4ffe00d
Showing 1 changed file with 14 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
// utility method.
//
//===---------------------------------------------------------------------===//

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Transforms.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
Expand Down Expand Up @@ -291,9 +293,10 @@ struct TileAndDistributeToWorkgroupsPass
: public TileAndDistributeToWorkgroupsBase<
TileAndDistributeToWorkgroupsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, IREE::Flow::FlowDialect,
IREE::HAL::HALDialect, linalg::LinalgDialect,
scf::SCFDialect, tensor::TensorDialect>();
registry
.insert<AffineDialect, IREE::Flow::FlowDialect, IREE::HAL::HALDialect,
linalg::LinalgDialect, IREE::LinalgExt::IREELinalgExtDialect,
scf::SCFDialect, tensor::TensorDialect>();
}

void runOnOperation() override;
Expand Down Expand Up @@ -409,13 +412,14 @@ void TileAndDistributeToWorkgroupsPass::runOnOperation() {
});

{
// Apply linalg tiling optimization patterns.
RewritePatternSet canonicalizationPatterns(context);
linalg::populateLinalgTilingCanonicalizationPatterns(
canonicalizationPatterns);
populateFoldAffineMinInDistributedLoopsPatterns(canonicalizationPatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(canonicalizationPatterns)))) {
// Apply linalg tiling optimization patterns, which includes folding
// casting ops into tiled operations.
RewritePatternSet patterns(context);
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
populateFoldAffineMinInDistributedLoopsPatterns(patterns);
context->getOrLoadDialect<IREE::LinalgExt::IREELinalgExtDialect>()
->getCanonicalizationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
funcOp.emitOpError("tiling canonicalizations failed");
return signalPassFailure();
}
Expand Down

0 comments on commit 4ffe00d

Please sign in to comment.