From 2e7aa930c5e99112d69e90c6cafaf7659e693389 Mon Sep 17 00:00:00 2001
From: Sayan Saha
Date: Wed, 6 Mar 2024 15:03:04 -0500
Subject: [PATCH 1/6] [BugFix] : Move DimOp canonicalization from memref to tensor.

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  1 -
 mlir/lib/Dialect/Linalg/Transforms/Loops.cpp  |  1 -
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 33 --------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 27 ++++++-
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 42 ----------
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 80 +++++++++++++++++++
 6 files changed, 106 insertions(+), 78 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609c..2333c92fd7b12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,7 +629,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     Speculation::Speculatability getSpeculatability();
   }];
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b0a4de2da1e869..e1cb5b477debbc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,7 +317,6 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
   MLIRContext *context = enclosingOp->getContext();
   RewritePatternSet patterns(context);
   patterns.add<LinalgRewritePattern<LoopType>>(context);
-  memref::DimOp::getCanonicalizationPatterns(patterns, context);
   tensor::DimOp::getCanonicalizationPatterns(patterns, context);
   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 248193481acfc6..00b7fa122a6c96 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,39 +1069,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-namespace {
-/// Fold dim of a memref reshape operation to a load into the reshape's shape
-/// operand.
-struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dim,
-                                PatternRewriter &rewriter) const override {
-    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
-
-    if (!reshape)
-      return failure();
-
-    // Place the load directly after the reshape to ensure that the shape memref
-    // was not mutated.
-    rewriter.setInsertionPointAfter(reshape);
-    Location loc = dim.getLoc();
-    Value load =
-        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
-    if (load.getType() != dim.getType())
-      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
-    rewriter.replaceOp(dim, load);
-    return success();
-  }
-};
-
-} // namespace
-
-void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<DimOfMemRefReshape>(context);
-}
-
 // ---------------------------------------------------------------------------
 // DmaStartOp
 // ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fe2f250e6b9290..ce9792f813cbb3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,36 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a tensor reshape operation to an extract into the reshape's shape
+/// operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Since tensors are immutable we don't need to worry about where to place
+    // the load call
+    rewriter.setInsertionPointAfter(dim);
+    Location loc = dim.getLoc();
+    Value load =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (load.getType() != dim.getType())
+      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
+    rewriter.replaceOp(dim, load);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a772a25da57382..0054a8ac785a89 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,48 +242,6 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
 
 // -----
 
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
-// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: memref.store
-// CHECK-NOT: memref.dim
-// CHECK: return %[[DIM]] : index
-func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
-  // Update the shape to test that he load ends up in the right place.
-  memref.store %c3, %arg1[%c3] : memref<?xindex>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
-// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[DIM]]
-// CHECK-NOT: memref.dim
-// CHECK: return %[[CAST]] : index
-func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d17c23adfb14d8..45d37c553a0025 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
 // CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
 // CHECK: return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+// CHECK-NOT: tensor.store
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // Update the shape to test that the extract ends up in the right place.
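+  // Note: tensor.insert produces a new tensor and leaves %arg1 itself
+  // unchanged, so (unlike the memref case) the fold is safe wherever the
+  // extract lands; the insert and the reshape then fold away as dead code.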
+  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+// CHECK: tensor.extract
+// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+// CHECK: scf.for
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+    %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+      %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+      %3 = arith.muli %arg3, %2 : index
+      scf.yield %3 : index
+    }
+    return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+// CHECK: arith.muli
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+    %c4 = arith.constant 4 : index
+    %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+    %0 = arith.muli %arg2, %c4 : index
+    %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+    return %dim : index
+  }

From 0de5e5997181a2fefb5a86dda1193b1c37a18382 Mon Sep 17 00:00:00 2001
From: Sayan Saha
Date: Wed, 6 Mar 2024 15:03:04 -0500
Subject: [PATCH 2/6] [BugFix] : Move DimOp canonicalization from memref to tensor.
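
Same change as the previous revision; the folded value and comments are
renamed from "load" to "extract" to match tensor dialect terminology.
Roughly, the rewrite is (illustrative IR, not part of the diff):

  %0 = tensor.reshape %t(%shape)
      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %d = tensor.dim %0, %idx : tensor<*xf32>
  -->
  %d = tensor.extract %shape[%idx] : tensor<?xindex>

with an arith.index_cast inserted when the shape element type is not
index.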
---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  1 -
 mlir/lib/Dialect/Linalg/Transforms/Loops.cpp  |  1 -
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 33 --------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 28 +++++++-
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 42 ----------
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 80 +++++++++++++++++++
 6 files changed, 107 insertions(+), 78 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609c..2333c92fd7b12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,7 +629,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     Speculation::Speculatability getSpeculatability();
   }];
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b0a4de2da1e869..e1cb5b477debbc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,7 +317,6 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
   MLIRContext *context = enclosingOp->getContext();
   RewritePatternSet patterns(context);
   patterns.add<LinalgRewritePattern<LoopType>>(context);
-  memref::DimOp::getCanonicalizationPatterns(patterns, context);
   tensor::DimOp::getCanonicalizationPatterns(patterns, context);
   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 248193481acfc6..00b7fa122a6c96 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,39 +1069,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-namespace {
-/// Fold dim of a memref reshape operation to a load into the reshape's shape
-/// operand.
-struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dim,
-                                PatternRewriter &rewriter) const override {
-    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
-
-    if (!reshape)
-      return failure();
-
-    // Place the load directly after the reshape to ensure that the shape memref
-    // was not mutated.
-    rewriter.setInsertionPointAfter(reshape);
-    Location loc = dim.getLoc();
-    Value load =
-        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
-    if (load.getType() != dim.getType())
-      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
-    rewriter.replaceOp(dim, load);
-    return success();
-  }
-};
-
-} // namespace
-
-void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<DimOfMemRefReshape>(context);
-}
-
 // ---------------------------------------------------------------------------
 // DmaStartOp
 // ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fe2f250e6b9290..038b8c3122b527 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,37 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a tensor reshape operation to an extract into the reshape's shape
+/// operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Since tensors are immutable we don't need to worry about where to place
+    // the extract call
+    rewriter.setInsertionPointAfter(dim);
+    Location loc = dim.getLoc();
+    Value extract =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (extract.getType() != dim.getType())
+      extract =
+          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
+    rewriter.replaceOp(dim, extract);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a772a25da57382..0054a8ac785a89 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,48 +242,6 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
 
 // -----
 
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
-// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: memref.store
-// CHECK-NOT: memref.dim
-// CHECK: return %[[DIM]] : index
-func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
-  // Update the shape to test that he load ends up in the right place.
-  memref.store %c3, %arg1[%c3] : memref<?xindex>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
-// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[DIM]]
-// CHECK-NOT: memref.dim
-// CHECK: return %[[CAST]] : index
-func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d17c23adfb14d8..1ecf076b6ca2bc 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
 // CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
 // CHECK: return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+// CHECK-NOT: tensor.store
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // Update the shape to test that the extract ends up in the right place.
+  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+// CHECK: tensor.extract
+// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+// CHECK: scf.for
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+    %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+      %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+      %3 = arith.muli %arg3, %2 : index
+      scf.yield %3 : index
+    }
+    return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+// CHECK: arith.muli
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+    %c4 = arith.constant 4 : index
+    %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+    %0 = arith.muli %arg2, %c4 : index
+    %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+    return %dim : index
+  }

From 1e9c352b427c369f88b74bf148d6f6ab59434b1b Mon Sep 17 00:00:00 2001
From: Sayan Saha
Date: Fri, 8 Mar 2024 08:33:47 -0500
Subject: [PATCH 3/6] [Task] : Add back memref.dim canonicalization with dominator fix.

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td    |  1 +
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 46 +++++++++
 mlir/test/Dialect/MemRef/canonicalize.mlir | 95 +++++++++++++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir |  4 +-
 4 files changed, 144 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2333c92fd7b12c..c71517666b609c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,6 +629,7 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     Speculation::Speculatability getSpeculatability();
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 00b7fa122a6c96..f69a10334050b9 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,6 +1069,52 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+namespace {
+/// Fold dim of a memref reshape operation to a load into the reshape's shape
+/// operand.
+struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return rewriter.notifyMatchFailure(
+          dim, "Dim op is not defined by a reshape op.");
+
+    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
+      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
+        if (reshape->isBeforeInBlock(definingOp))
+          return rewriter.notifyMatchFailure(
+              dim,
+              "dim.getIndex is not defined before reshape in the same block.");
+      } // else dim.getIndex is a block argument to reshape->getBlock
+    } else if (!dim.getIndex().getParentRegion()->isProperAncestor(
+                   reshape->getParentRegion()))
+      return rewriter.notifyMatchFailure(
+          dim, "dim.getIndex does not dominate reshape.");
+
+    // Place the load directly after the reshape to ensure that the shape memref
+    // was not mutated.
+    rewriter.setInsertionPointAfter(reshape);
+    Location loc = dim.getLoc();
+    Value load =
+        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
+    if (load.getType() != dim.getType())
+      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
+    rewriter.replaceOp(dim, load);
+    return success();
+  }
+};
+
+} // namespace
+
+void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  results.add<DimOfMemRefReshape>(context);
+}
+
 // ---------------------------------------------------------------------------
 // DmaStartOp
 // ---------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 0054a8ac785a89..0c6157445fbfae 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,6 +242,101 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
 
 // -----
 
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NEXT: memref.store
+// CHECK-NOT: memref.dim
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = memref.reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  // Update the shape to test that the load ends up in the right place.
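+  // The folded load has to land above this store (right after the reshape);
+  // otherwise it would observe the mutated shape value.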
+  memref.store %c3, %arg1[%c3] : memref<?xindex>
+  %1 = memref.dim %0, %c3 : memref<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_i32(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[DIM]]
+// CHECK-NOT: memref.dim
+// CHECK: return %[[CAST]] : index
+func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = memref.reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+  %1 = memref.dim %0, %c3 : memref<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>,
+// CHECK-SAME: %[[IDX:[0-9a-z]+]]: index
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NOT: memref.dim
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  %dim = memref.dim %reshape, %arg2 : memref<*xf32>
+  return %dim : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_for(
+// CHECK: memref.reshape
+// CHECK: memref.dim
+// CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_for( %arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+
+    %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+      %2 = memref.dim %0, %arg2 : memref<*xf32>
+      %3 = arith.muli %arg3, %2 : index
+      scf.yield %3 : index
+    }
+    return %1 : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
+// CHECK: memref.reshape
+// CHECK: memref.dim
+// CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+    %c4 = arith.constant 4 : index
+    %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+    %0 = arith.muli %arg2, %c4 : index
+    %dim = memref.dim %reshape, %0 : memref<*xf32>
+    return %dim : index
+  }
+
+// -----
+
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1ecf076b6ca2bc..7a4cf0dd2fc50a 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2294,7 +2294,7 @@ func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
 
 // -----
 
-// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
 // CHECK-LABEL: func @dim_of_reshape_for(
 // CHECK: scf.for
 // CHECK-NEXT: tensor.extract
 // CHECK-NOT: tensor.dim
 // CHECK-NOT: tensor.reshape
@@ -2317,7 +2317,7 @@ func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) ->
 
 // -----
 
-// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
 // CHECK-LABEL: func @dim_of_reshape_undominated(
 // CHECK: arith.muli
 // CHECK-NEXT: tensor.extract
 // CHECK-NOT: tensor.dim
 // CHECK-NOT: tensor.reshape

From 64b7c159b5b23aecbc71f09df5a9d46339375add Mon Sep 17 00:00:00 2001
From: Sayan Saha
Date: Fri, 8 Mar 2024 09:40:32 -0500
Subject: [PATCH 4/6] [Task] : Add back memref.dim canonicalization to Loops.cpp.

---
 mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index e1cb5b477debbc..b0a4de2da1e869 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,6 +317,7 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
   MLIRContext *context = enclosingOp->getContext();
   RewritePatternSet patterns(context);
   patterns.add<LinalgRewritePattern<LoopType>>(context);
+  memref::DimOp::getCanonicalizationPatterns(patterns, context);
   tensor::DimOp::getCanonicalizationPatterns(patterns, context);
   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);

From d7f3bd23f7e4c18d0c5598deb950d76ed002b73f Mon Sep 17 00:00:00 2001
From: Sayan Saha
Date: Fri, 8 Mar 2024 15:53:57 -0500
Subject: [PATCH 5/6] [Task] : Add comments + enhance check for index in parent block of reshape.

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 25 ++++++++++++++++++++----
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f69a10334050b9..3594b9669e3c6d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1083,17 +1083,34 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
       return rewriter.notifyMatchFailure(
           dim, "Dim op is not defined by a reshape op.");
 
+    // dim of a memref reshape can be folded if dim.getIndex() dominates the
+    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
+    // cheaply check that either of the following conditions hold:
+    //      1. dim.getIndex() is defined in the same block as reshape but before
+    //      reshape.
+    //      2. dim.getIndex() is defined in a parent block of
+    //      reshape.
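+    //      E.g., an scf.for induction variable used as the index (see the
+    //      @dim_of_memref_reshape_for test) satisfies neither condition, so
+    //      the fold is rejected.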
+
+    // Check condition 1
     if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
       if (auto *definingOp = dim.getIndex().getDefiningOp()) {
-        if (reshape->isBeforeInBlock(definingOp))
+        if (reshape->isBeforeInBlock(definingOp)) {
           return rewriter.notifyMatchFailure(
               dim,
               "dim.getIndex is not defined before reshape in the same block.");
+        }
-      } // else dim.getIndex is a block argument to reshape->getBlock
-    } else if (!dim.getIndex().getParentRegion()->isProperAncestor(
-                   reshape->getParentRegion()))
+      } // else dim.getIndex is a block argument to reshape->getBlock and
+        // dominates reshape
+  } // Check condition 2
+    else if (dim->getBlock() != reshape->getBlock() &&
+             !dim.getIndex().getParentRegion()->isProperAncestor(
+                 reshape->getParentRegion())) {
+      // If dim and reshape are in the same block but dim.getIndex() isn't, we
+      // already know dim.getIndex() dominates reshape without calling
+      // `isProperAncestor`
       return rewriter.notifyMatchFailure(
           dim, "dim.getIndex does not dominate reshape.");
+    }
 
     // Place the load directly after the reshape to ensure that the shape memref
     // was not mutated.

From 64b7c159b5b23aecbc71f09df5a9d46339375add Mon Sep 17 00:00:00 2001
From: Sayan Saha
Date: Fri, 8 Mar 2024 21:24:58 -0500
Subject: [PATCH 6/6] [Task] : Run clang-format.

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 3594b9669e3c6d..cb5599def1efb4 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1101,7 +1101,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
         }
       } // else dim.getIndex is a block argument to reshape->getBlock and
         // dominates reshape
-  } // Check condition 2
+    } // Check condition 2
     else if (dim->getBlock() != reshape->getBlock() &&
              !dim.getIndex().getParentRegion()->isProperAncestor(
                  reshape->getParentRegion())) {