
[mlir][linalg] Add linalg.transpose constant folding #92589

Merged: 1 commit into llvm:main on May 28, 2024

Conversation

ryan-holt-1 (Contributor)

There was existing support for constant folding a linalg.generic that was actually a transpose. This commit adds support for the named op, linalg.transpose, as well by making use of the LinalgOp interface.
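
For context, here is a minimal sketch of the two forms, assuming illustrative shapes, values, and function names (nothing below is taken from the patch itself):

```mlir
// Generic form of a 2-D transpose: already constant-folded before this PR.
#transposed = affine_map<(d0, d1) -> (d1, d0)>
#identity = affine_map<(d0, d1) -> (d0, d1)>
func.func @generic_form(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
  %0 = linalg.generic {indexing_maps = [#transposed, #identity],
                       iterator_types = ["parallel", "parallel"]}
      ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<3x2xf32>
  return %0 : tensor<3x2xf32>
}

// Named form: the case this PR teaches the fold to handle.
func.func @named_form(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
  %0 = linalg.transpose ins(%input : tensor<2x3xf32>)
                        outs(%init : tensor<3x2xf32>) permutation = [1, 0]
  return %0 : tensor<3x2xf32>
}
```

With this change, both forms should fold to the same `arith.constant` of type `tensor<3x2xf32>`.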

llvmbot (Collaborator) commented May 17, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Ryan Holt (ryan-holt-1)

Changes

There was existing support for constant folding a linalg.generic that was actually a transpose. This commit adds support for the named op, linalg.transpose, as well by making use of the LinalgOp interface.


Full diff: https://github.com/llvm/llvm-project/pull/92589.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp (+32-30)
  • (modified) mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir (+12)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 8fffabf11f3fd..2e6079e1402e1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -23,21 +23,21 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 namespace {
-/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
-/// and permutation indexing maps.
+/// Base class for constant folding linalg structured ops with N inputs, 1
+/// output, and permutation indexing maps.
 ///
 /// `ConcreteType` should provide methods with signatures
 ///
 /// ```c++
-///   bool matchIndexingMaps(GenericOp genericOp) const;
-///   RegionComputationFn getRegionComputeFn(GenericOp) const;
+///   bool matchIndexingMaps(LinalgOp linalgOp) const;
+///   RegionComputationFn getRegionComputeFn(LinalgOp) const;
 /// ```
 ///
 /// The latter inspects the region and returns the computation inside as a
 /// functor. The functor will be invoked with constant elements for all inputs
 /// and should return the corresponding computed constant element for output.
 template <typename ConcreteType>
-class FoldConstantBase : public OpRewritePattern<GenericOp> {
+class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
 public:
   struct APIntOrFloat {
     std::optional<APInt> apInt;
@@ -52,25 +52,26 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
 
   FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
                    PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+        controlFn(controlFn) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
     // Mixed and buffer sematics aren't supported.
-    if (!genericOp.hasPureTensorSemantics())
+    if (!linalgOp.hasPureTensorSemantics())
       return failure();
 
     // Only support ops generating one output for now.
-    if (genericOp.getNumDpsInits() != 1)
+    if (linalgOp.getNumDpsInits() != 1)
       return failure();
 
-    auto outputType = dyn_cast<ShapedType>(genericOp.getResultTypes().front());
+    auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
     // Require the output types to be static given that we are generating
     // constants.
     if (!outputType || !outputType.hasStaticShape())
       return failure();
 
-    if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
+    if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
           return isa<ShapedType>(input.getType());
         }))
       return failure();
@@ -80,7 +81,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
       return cast<ShapedType>(value.getType()).getElementType();
     };
     if (!llvm::all_equal(
-            llvm::map_range(genericOp->getOperands(), getOperandElementType)))
+            llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
       return failure();
 
     // We can only handle the case where we have int/float elements.
@@ -93,30 +94,30 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
     // entirely in the compiler, without needing to turn all indices into
     // Values, and then do affine apply on them, and then match back the
     // constant again.
-    if (!llvm::all_of(genericOp.getIndexingMapsArray(),
+    if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
                       [](AffineMap map) { return map.isPermutation(); }))
       return failure();
 
-    for (OpOperand &operand : genericOp.getDpsInitsMutable()) {
-      if (genericOp.payloadUsesValueFromOperand(&operand))
+    for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
+      if (linalgOp.payloadUsesValueFromOperand(&operand))
         return failure();
     }
 
     // Further check the indexing maps are okay for the ConcreteType.
-    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
+    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
       return failure();
 
     // Defer to the concrete type to check the region and discover the
     // computation inside.
     RegionComputationFn computeFn =
-        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
+        static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
     if (!computeFn)
       return failure();
 
     // All inputs should be constants.
-    int numInputs = genericOp.getNumDpsInputs();
+    int numInputs = linalgOp.getNumDpsInputs();
     SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
-    for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
       if (!matchPattern(en.value()->get(),
                         m_Constant(&inputValues[en.index()])))
         return failure();
@@ -124,12 +125,11 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
 
     // Identified this as a potential candidate for folding. Now check the
     // policy to see whether we are allowed to proceed.
-    for (OpOperand *operand : genericOp.getDpsInputOperands()) {
+    for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
       if (!controlFn(operand))
         return failure();
     }
 
-    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
     SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
     int64_t numElements = outputType.getNumElements();
 
@@ -155,8 +155,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
 
     SmallVector<SmallVector<unsigned>> inputDims;
     for (int i = 0; i < numInputs; ++i)
-      inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i]));
-    auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back());
+      inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
+    auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
     auto outputShape = outputType.getShape();
 
     // Allocate small vectors for index delinearization. Initial values do not
@@ -173,7 +173,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
     APIntOrFloatArray computeFnInputs;
 
     auto inputShapes = llvm::to_vector<4>(
-        llvm::map_range(genericOp.getInputs(), [](Value value) {
+        llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
           return cast<ShapedType>(value.getType()).getShape();
         }));
 
@@ -254,7 +254,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
         isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                 : DenseElementsAttr::get(outputType, intOutputValues);
 
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
     return success();
   }
 
@@ -262,18 +262,20 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
   ControlFusionFn controlFn;
 };
 
-// Folds linalg.generic ops that are actually transposes on constant values.
+// Folds linalg.transpose (and linalg.generic ops that are actually transposes)
+// on constant values.
 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
+
   using FoldConstantBase::FoldConstantBase;
 
-  bool matchIndexingMaps(GenericOp genericOp) const {
+  bool matchIndexingMaps(LinalgOp linalgOp) const {
     // We should have one input and one output.
-    return genericOp.getIndexingMapsArray().size() == 2;
+    return linalgOp.getIndexingMapsArray().size() == 2;
   }
 
-  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
+  RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
     // Make sure the region only contains a yield op.
-    Block &body = genericOp.getRegion().front();
+    Block &body = linalgOp->getRegion(0).front();
     if (!llvm::hasSingleElement(body))
       return nullptr;
     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 15a4f6cdd3bbe..70f43885712b7 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -910,6 +910,18 @@ func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tenso
 
 // -----
 
+// CHECK-LABEL: @named_transpose_fold_2d_fp32
+func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  //               CHECK: %[[CST:.+]] = arith.constant
+  // CHECK-SAME{LITERAL}:   dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+  %1 = linalg.transpose ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) permutation = [1, 0]
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
 // Fusing the broadcast into a reduction would require to insert extra knowledge
 // about the size of the reduction dimension. As long, as this is not
 // implemented, we check that two linalg operations remain.
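
As the guards in the diff above show, the fold only fires for ops with pure tensor semantics, a single statically shaped result, permutation indexing maps, and constant inputs. A minimal hedged sketch (function name illustrative) of a transpose the pattern leaves alone:

```mlir
// Not folded: %input is a function argument rather than an arith.constant,
// so the "all inputs should be constants" check in the pattern fails.
func.func @transpose_nofold_nonconstant(%input: tensor<2x3xf32>,
                                        %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  %0 = linalg.transpose ins(%input : tensor<2x3xf32>)
                        outs(%init : tensor<3x2xf32>) permutation = [1, 0]
  return %0 : tensor<3x2xf32>
}
```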

MaheshRavishankar (Contributor) left a comment:

Change itself looks fine, but fair warning: these patterns blow up compilation time to be completely unreasonable, and this is probably not the best way to handle constant folding. FWIW, IREE dropped the use of such patterns.

MaheshRavishankar (Contributor) left a comment:

Actually, could you move this test (and similar tests) into a separate file? They are here only for legacy reasons; this has nothing to do with fusion. Not blocking, but it would be much appreciated.

ryan-holt-1 (Contributor, Author)

Thanks for the heads up. We have not run into compilation time problems with these yet but we will keep that in mind.

I moved these tests out of fusion-elementwise-ops.mlir. However, linalg-fuse-elementwise-ops is currently the only pass which uses this pattern. Perhaps this constant folding pattern should be removed from there as well and placed into a separate linalg-constant-fold pass?

MaheshRavishankar (Contributor) left a comment:

Thanks! Looks ok to me.

ryan-holt-1 merged commit 74ed79f into llvm:main on May 28, 2024. 4 checks passed.
vg0204 pushed a commit to vg0204/llvm-project that referenced this pull request May 29, 2024