diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp index fe49c44c71ff..c9cd79d8affa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp @@ -13,6 +13,8 @@ // utility method. // //===---------------------------------------------------------------------===// + +#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree-dialects/Dialect/LinalgExt/Passes/Transforms.h" #include "iree/compiler/Codegen/Common/Transforms.h" @@ -291,9 +293,10 @@ struct TileAndDistributeToWorkgroupsPass : public TileAndDistributeToWorkgroupsBase< TileAndDistributeToWorkgroupsPass> { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); } void runOnOperation() override; @@ -409,13 +412,14 @@ void TileAndDistributeToWorkgroupsPass::runOnOperation() { }); { - // Apply linalg tiling optimization patterns. - RewritePatternSet canonicalizationPatterns(context); - linalg::populateLinalgTilingCanonicalizationPatterns( - canonicalizationPatterns); - populateFoldAffineMinInDistributedLoopsPatterns(canonicalizationPatterns); - if (failed(applyPatternsAndFoldGreedily( - funcOp, std::move(canonicalizationPatterns)))) { + // Apply linalg tiling optimization patterns, which includes folding + // casting ops into tiled operations. + RewritePatternSet patterns(context); + linalg::populateLinalgTilingCanonicalizationPatterns(patterns); + populateFoldAffineMinInDistributedLoopsPatterns(patterns); + context->getOrLoadDialect() + ->getCanonicalizationPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { funcOp.emitOpError("tiling canonicalizations failed"); return signalPassFailure(); }