Skip to content

Commit

Permalink
remove createPass wrapper
Browse files Browse the repository at this point in the history
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
  • Loading branch information
Max191 committed Sep 19, 2024
1 parent b1085ae commit 352dcb0
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-decompose-pack-unpack-ops"
Expand All @@ -31,6 +30,8 @@ namespace mlir::iree_compiler {

namespace {

using PackUnPackControlFn = std::function<LogicalResult(Operation *)>;

/// A wrapper pattern that calls linalg::lowerPack on tensor::PackOp. It lowers
/// a tensor.pack op to tensor.pad + tensor.expand_shape + linalg.transpose ops.
struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
Expand Down
11 changes: 0 additions & 11 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,6 @@ using ConfigFn =
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createConvolutionToIGEMMPass(ConfigFn configFn);

using PackUnPackControlFn = std::function<LogicalResult(Operation *)>;
// Pass to decompose pack and unpack ops into pad/extract_slice and reshape ops.
// If specified, `controlFn` controls which ops get decomposed. The `controlFn`
// should be used with `useOnlyReshapes` set to true.
// TODO(Max191): Add a controlFn upstream for `GeneralizeOuterUnitDim*` patterns
// and remove the need to have `useOnlyReshapes = true` when using `controlFn`.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createDecomposePackUnPackOpsPass(
bool tileOuterToOne, bool useOnlyReshapes = false,
std::optional<PackUnPackControlFn> controlFn = std::nullopt);

std::unique_ptr<Pass> createDecomposeSoftmaxPass(bool useFusion);

/// Pass to perform linalg on tensor bufferization. The function passed into
Expand Down
10 changes: 9 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,15 @@ def DecomposePackUnPackOpsPass :
Option<"tileOuterToOne", "tile-outer-to-one", "bool", "false",
"Always apply tiling to make outer dimension be ones">,
Option<"useOnlyReshapes", "use-only-reshapes", "bool", "false",
"Use decomposition into reshape ops, even when packing unit dimensions.">
"Use decomposition into reshape ops, even when packing unit dimensions.">,
// If specified, `controlFn` controls which ops get decomposed.
// TODO(Max191): The `controlFn` should be used with `useOnlyReshapes` set
// to true. We should add a controlFn upstream for `GeneralizeOuterUnitDim*`
// patterns and remove the need to have `useOnlyReshapes = true` when using
// `controlFn`.
Option<"controlFn", "control-fn",
"std::optional<std::function<LogicalResult(Operation *)>>", "std::nullopt",
"Controls which pack and unpack ops get decomposed.">
];
}

Expand Down
15 changes: 10 additions & 5 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,11 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
}

// Step 3. Decompose pack and unpack ops and propagate the resulting reshapes.
funcPassManager.addPass(
createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/false,
/*useOnlyReshapes=*/true));
{
DecomposePackUnPackOpsPassOptions options;
options.useOnlyReshapes = true;
funcPassManager.addPass(createDecomposePackUnPackOpsPass(options));
}

// Step 3.5. Expand the inner dimensions of MultiMma ops in preparation for
// distribution to lanes.
Expand Down Expand Up @@ -945,8 +947,11 @@ void addGPUPackUnPackPasses(OpPassManager &funcPassManager) {
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

funcPassManager.addPass(
createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/true));
{
DecomposePackUnPackOpsPassOptions options;
options.tileOuterToOne = true;
funcPassManager.addPass(createDecomposePackUnPackOpsPass(options));
}
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
addGPUVectorizationPasses(funcPassManager);
Expand Down

0 comments on commit 352dcb0

Please sign in to comment.