Skip to content

Commit

Permalink
[mlir][linalg] Add linalg.transpose constant folding (llvm#92589)
Browse files Browse the repository at this point in the history
There was existing support for constant folding a `linalg.generic` that
was actually a transpose. This commit adds support for the named op,
`linalg.transpose`, as well by making use of the `LinalgOp` interface.
  • Loading branch information
ryan-holt-1 authored May 28, 2024
1 parent d2a103e commit 74ed79f
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 163 deletions.
62 changes: 32 additions & 30 deletions mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,21 @@ using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
/// Base class for constant folding linalg structured ops with N inputs, 1
/// output, and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
/// bool matchIndexingMaps(GenericOp genericOp) const;
/// RegionComputationFn getRegionComputeFn(GenericOp) const;
/// bool matchIndexingMaps(LinalgOp linalgOp) const;
/// RegionComputationFn getRegionComputeFn(LinalgOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for output.
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
public:
struct APIntOrFloat {
std::optional<APInt> apInt;
Expand All @@ -52,25 +52,26 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {

FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlFn(controlFn) {}

LogicalResult matchAndRewrite(GenericOp genericOp,
LogicalResult matchAndRewrite(LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
// Mixed and buffer sematics aren't supported.
if (!genericOp.hasPureTensorSemantics())
if (!linalgOp.hasPureTensorSemantics())
return failure();

// Only support ops generating one output for now.
if (genericOp.getNumDpsInits() != 1)
if (linalgOp.getNumDpsInits() != 1)
return failure();

auto outputType = dyn_cast<ShapedType>(genericOp.getResultTypes().front());
auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
// Require the output types to be static given that we are generating
// constants.
if (!outputType || !outputType.hasStaticShape())
return failure();

if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
return isa<ShapedType>(input.getType());
}))
return failure();
Expand All @@ -80,7 +81,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
return cast<ShapedType>(value.getType()).getElementType();
};
if (!llvm::all_equal(
llvm::map_range(genericOp->getOperands(), getOperandElementType)))
llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
return failure();

// We can only handle the case where we have int/float elements.
Expand All @@ -93,43 +94,42 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
// entirely in the compiler, without needing to turn all indices into
// Values, and then do affine apply on them, and then match back the
// constant again.
if (!llvm::all_of(genericOp.getIndexingMapsArray(),
if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
[](AffineMap map) { return map.isPermutation(); }))
return failure();

for (OpOperand &operand : genericOp.getDpsInitsMutable()) {
if (genericOp.payloadUsesValueFromOperand(&operand))
for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
if (linalgOp.payloadUsesValueFromOperand(&operand))
return failure();
}

// Further check the indexing maps are okay for the ConcreteType.
if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
return failure();

// Defer to the concrete type to check the region and discover the
// computation inside.
RegionComputationFn computeFn =
static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
if (!computeFn)
return failure();

// All inputs should be constants.
int numInputs = genericOp.getNumDpsInputs();
int numInputs = linalgOp.getNumDpsInputs();
SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
if (!matchPattern(en.value()->get(),
m_Constant(&inputValues[en.index()])))
return failure();
}

// Identified this as a potential candidate for folding. Now check the
// policy to see whether we are allowed to proceed.
for (OpOperand *operand : genericOp.getDpsInputOperands()) {
for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
if (!controlFn(operand))
return failure();
}

auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
int64_t numElements = outputType.getNumElements();

Expand All @@ -155,8 +155,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {

SmallVector<SmallVector<unsigned>> inputDims;
for (int i = 0; i < numInputs; ++i)
inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i]));
auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back());
inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
auto outputShape = outputType.getShape();

// Allocate small vectors for index delinearization. Initial values do not
Expand All @@ -173,7 +173,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
APIntOrFloatArray computeFnInputs;

auto inputShapes = llvm::to_vector<4>(
llvm::map_range(genericOp.getInputs(), [](Value value) {
llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
return cast<ShapedType>(value.getType()).getShape();
}));

Expand Down Expand Up @@ -254,26 +254,28 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
: DenseElementsAttr::get(outputType, intOutputValues);

rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
return success();
}

private:
ControlFusionFn controlFn;
};

// Folds linalg.generic ops that are actually transposes on constant values.
// Folds linalg.transpose (and linalg.generic ops that are actually transposes)
// on constant values.
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {

using FoldConstantBase::FoldConstantBase;

bool matchIndexingMaps(GenericOp genericOp) const {
bool matchIndexingMaps(LinalgOp linalgOp) const {
// We should have one input and one output.
return genericOp.getIndexingMapsArray().size() == 2;
return linalgOp.getIndexingMapsArray().size() == 2;
}

RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
// Make sure the region only contains a yield op.
Block &body = genericOp.getRegion().front();
Block &body = linalgOp->getRegion(0).front();
if (!llvm::hasSingleElement(body))
return nullptr;
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
Expand Down
148 changes: 148 additions & 0 deletions mlir/test/Dialect/Linalg/constant-fold.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// RUN: mlir-opt %s -linalg-fuse-elementwise-ops -split-input-file | FileCheck %s

// CHECK-LABEL: @transpose_fold_2d_fp32
func.func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
%input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
// CHECK: %[[CST:.+]] = arith.constant
// CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<3x2xf32>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}

// -----

// CHECK-LABEL: @transpose_fold_2d_fp64
func.func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> {
%input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64>
// CHECK: %[[CST:.+]] = arith.constant
// CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) {
^bb0(%arg1: f64, %arg2: f64):
linalg.yield %arg1 : f64
} -> tensor<3x2xf64>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf64>
}

// -----

// CHECK-LABEL: @transpose_fold_4d_i32
func.func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
%input = arith.constant dense<[[
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
]]> : tensor<1x2x3x4xi32>
// CHECK: %[[CST:.+]] = arith.constant dense<[
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
// CHECK-SAME{LITERAL}: ]>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
^bb0(%arg1: i32, %arg2: i32):
linalg.yield %arg1 : i32
} -> tensor<3x1x4x2xi32>
// CHECK: return %[[CST]]
return %1 : tensor<3x1x4x2xi32>
}

// -----

// CHECK-LABEL: @transpose_fold_4d_i16
func.func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> {
%input = arith.constant dense<[[
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
]]> : tensor<1x2x3x4xi16>
// CHECK: %[[CST:.+]] = arith.constant dense<[
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
// CHECK-SAME{LITERAL}: ]>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) {
^bb0(%arg1: i16, %arg2: i16):
linalg.yield %arg1 : i16
} -> tensor<3x1x4x2xi16>
// CHECK: return %[[CST]]
return %1 : tensor<3x1x4x2xi16>
}

// -----

// CHECK-LABEL: @transpose_nofold_non_cst_input
func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
// CHECK: linalg.generic
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}

// -----

// CHECK-LABEL: @transpose_nofold_yield_const
func.func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
%input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
%cst = arith.constant 8.0 : f32
// CHECK: linalg.generic
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %cst : f32
} -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}

// -----

// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
%input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
// CHECK: linalg.generic
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%add = arith.addf %arg1, %arg1 : f32
linalg.yield %add : f32
} -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}

// -----

// CHECK-LABEL: @named_transpose_fold_2d_fp32
func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
%input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
// CHECK: %[[CST:.+]] = arith.constant
// CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
%1 = linalg.transpose ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) permutation = [1, 0]
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}

// -----


Loading

0 comments on commit 74ed79f

Please sign in to comment.