Skip to content

Commit

Permalink
Add LinalgFusionInterface to support fusion for linalg_ext ops (add…
Browse files Browse the repository at this point in the history
…ed `scatter` and `reverse`) (iree-org#17428)

`LinalgFusionOpInterface` allows for fusion of both `Linalg` and
`LinalgExt` operations. The new interface provides access to methods
essential for performing fusion, allowing existing fusion logic to be
used with `LinalgExt` operations.

As noted in iree-org#17392, it probably makes sense to move this into the
`TilingInterface` + probably make it a bit more abstracted


#### Changes
- **`LinalgFusionOpInterface`**: Interface for fusion operations for
both `Linalg` and `LinalgExt` ops.
  - Implements methods to access indexing maps (or null
- **Implementation for Linalg Ops**: The interface is implemented for
standard Linalg operations by forwarding to preexisting methods (e.g
`getIndexingMaps()`). No changes to the ops themselves.
- **Implementation for LinalgExt Ops**: The interface currently only
implemented for `iree_linalg_ext.scatter/reverse`.

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
  • Loading branch information
IanWood1 authored and gglangg committed Jun 4, 2024
1 parent b8c4e35 commit 6fa5336
Show file tree
Hide file tree
Showing 14 changed files with 359 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ArithUtils",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:DestinationStyleOpInterface",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ iree_cc_library(
MLIRArithDialect
MLIRArithUtils
MLIRComplexDialect
MLIRDestinationStyleOpInterface
MLIRFuncDialect
MLIRFunctionInterfaces
MLIRIR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand All @@ -30,6 +34,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Pass/Pass.h"
Expand Down Expand Up @@ -390,15 +395,21 @@ static bool hasCompatibleOuterParallelLoops(
// relationship through `operand` have compatible outer-parallel loops.
static bool hasCompatibleOuterParallelLoops(
OpOperand &operand, const llvm::SmallBitVector &rootOuterParallelLoops) {
auto producer = operand.get().getDefiningOp<linalg::LinalgOp>();
auto consumer = dyn_cast<linalg::LinalgOp>(operand.getOwner());
auto producer =
operand.get().getDefiningOp<LinalgExt::LinalgFusionOpInterface>();
auto consumer =
dyn_cast<LinalgExt::LinalgFusionOpInterface>(operand.getOwner());
if (!producer || !consumer)
return false;

auto producerIndexingMap = producer.getIndexingMapMatchingResult(
llvm::cast<OpResult>(operand.get()));
auto consumerIndexingMap = consumer.getMatchingIndexingMap(&operand);

if (!producerIndexingMap || !consumerIndexingMap) {
return false;
}

return hasCompatibleOuterParallelLoops(
cast<TilingInterface>(producer.getOperation()),
producerIndexingMap, rootOuterParallelLoops) &&
Expand Down Expand Up @@ -605,14 +616,16 @@ isFusableWithConsumer(OpOperand &fusedOperand,
return false;
}

auto producerLinalgOp = dyn_cast<linalg::LinalgOp>(producer);
auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer);
if (!producerLinalgOp || !consumerLinalgOp)
auto producerFusionOp =
dyn_cast<LinalgExt::LinalgFusionOpInterface>(producer);
auto consumerFusionOp =
dyn_cast<LinalgExt::LinalgFusionOpInterface>(consumer);
if (!producerFusionOp || !consumerFusionOp)
return false;

// Check that the consumer is all parallel.
if (consumerLinalgOp.getNumLoops() !=
consumerLinalgOp.getNumParallelLoops()) {
if (consumerFusionOp.getNumLoops() !=
consumerFusionOp.getNumParallelLoops()) {
return false;
}

Expand All @@ -623,8 +636,8 @@ isFusableWithConsumer(OpOperand &fusedOperand,
// Check if the iteration spaces of the producer and consumer are same.
// TODO(#12664): This is unnecessary requirement, but we need a better config
// to tile the consumer with a larger iteration space.
auto producerIterationSpace = producerLinalgOp.getStaticLoopRanges();
auto consumerIterationSpace = consumerLinalgOp.getStaticLoopRanges();
auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
if (producerIterationSpace.size() < consumerIterationSpace.size()) {
return false;
}
Expand All @@ -640,12 +653,18 @@ isFusableWithConsumer(OpOperand &fusedOperand,
// While fusing with consumer, the result of the root might not be the final
// result of the dispatch. To avoid a stack allocation we have to ensure that
// all operations can bufferize without needing additional memory.
for (OpOperand *inputOperand : consumerLinalgOp.getDpsInputOperands()) {
auto consumerDstOp =
dyn_cast<DestinationStyleOpInterface>(consumerFusionOp.getOperation());
if (!consumerDstOp) {
return true;
}

for (OpOperand *inputOperand : consumerDstOp.getDpsInputOperands()) {
if (inputOperand->get().getDefiningOp() != producer)
continue;
if (isa<linalg::ConvolutionOpInterface>(producer) &&
!llvm::any_of(
consumerLinalgOp.getDpsInitsMutable(), [&](OpOperand &initOperand) {
consumerDstOp.getDpsInitsMutable(), [&](OpOperand &initOperand) {
return canUseInOperandAsInitOperand(inputOperand, &initOperand);
})) {
return false;
Expand Down Expand Up @@ -744,13 +763,14 @@ isFusableWithProducer(OpOperand &operand,
.Default([](Operation *) { return false; });
}

if (!isa<linalg::LinalgOp>(consumer) || !isa<linalg::LinalgOp>(producer)) {
if (!isa<LinalgExt::LinalgFusionOpInterface>(consumer) ||
!isa<LinalgExt::LinalgFusionOpInterface>(producer)) {
return false;
}

if (!options.aggressiveFusion) {
auto consumerLinalgOp = cast<linalg::LinalgOp>(consumer);
if (!consumerLinalgOp.isDpsInit(&operand)) {
auto consumerFusionOp = dyn_cast<DestinationStyleOpInterface>(consumer);
if (consumerFusionOp && !consumerFusionOp.isDpsInit(&operand)) {
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def FormDispatchRegionsPass :
"mlir::scf::SCFDialect",
"mlir::tensor::TensorDialect",
"IREE::Flow::FlowDialect",
"IREE::LinalgExt::IREELinalgExtDialect",
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ iree_lit_test_suite(
"form_dispatch_regions.mlir",
"form_dispatch_workgroups.mlir",
"form_scalar_dispatches.mlir",
"dispatch_linalg_ext_fusion.mlir",
"fusion_of_tensor_ops.mlir",
"fusion_preprocessing.mlir",
"initialize_empty_tensors.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ iree_lit_test_suite(
"collapse_reduction.mlir"
"convert_region_to_workgroups.mlir"
"deduplicate_executables.mlir"
"dispatch_linalg_ext_fusion.mlir"
"dispatch_linalg_on_tensors.mlir"
"dispatch_linalg_on_tensors_default.mlir"
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-form-dispatch-workgroups), cse, canonicalize, cse)" %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> {
%0 = tensor.empty() : tensor<4x1xi32>
%1 = tensor.empty() : tensor<4x1xi64>
%2 = tensor.empty() : tensor<4x1x16x8x128xf32>
%3 = tensor.empty() : tensor<4x1x16x8x128xf32>
%4 = tensor.empty() : tensor<8192x16x8x128xf32>
%5 = tensor.empty() : tensor<8192x16x8x128xf32>
%6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%1 : tensor<4x1xi64>) outs(%0 : tensor<4x1xi32>) {
^bb0(%in: i64, %out: i32):
%10 = arith.trunci %in : i64 to i32
linalg.yield %10 : i32
} -> tensor<4x1xi32>

%7 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<4x1x16x8x128xf32>) outs(%3 : tensor<4x1x16x8x128xf32>) {
^bb0(%in: f32, %out: f32):
%10 = arith.addf %in, %out : f32
linalg.yield %10 : f32
} -> tensor<4x1x16x8x128xf32>

%8 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%7, %6 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>) outs(%4 : tensor<8192x16x8x128xf32>) {
^bb0(%arg0: f32, %arg1: f32):
iree_linalg_ext.yield %arg0 : f32
} -> tensor<8192x16x8x128xf32>

// Dont fuse with scatter's consumer
%9 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%8 : tensor<8192x16x8x128xf32>) outs(%5 : tensor<8192x16x8x128xf32>) {
^bb0(%in: f32, %out: f32):
%10 = arith.addf %in, %out : f32
linalg.yield %10 : f32
} -> tensor<8192x16x8x128xf32>
util.return %9 : tensor<8192x16x8x128xf32>
}

// CHECK: util.func public @linalgext_scatter_fusion
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
// CHECK: %[[INDICES:.+]] = linalg.generic
// CHECK: %[[UPDATE:.+]] = linalg.generic
// CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>)
// CHECK: flow.dispatch.workgroups
// CHECK: %[[GEN2:.+]] = linalg.generic
// CHECK-SAME: ins(%[[INPUT:.+]] : tensor<8192x16x8x128xf32>)



// -----


#map = affine_map<(d0, d1) -> (d0, d1)>
util.func public @linalgext_reverse_fusion() -> tensor<10x10xi32> {
%0 = tensor.empty() : tensor<10x10xi64>
%1 = tensor.empty() : tensor<10x10xi32>
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x10xi64>) outs(%1 : tensor<10x10xi32>) {
^bb0(%in: i64, %out: i32):
%7 = arith.trunci %in : i64 to i32
linalg.yield %7 : i32
} -> tensor<10x10xi32>
%3 = tensor.empty() : tensor<10x10xi32>
%4 = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%2 : tensor<10x10xi32>) outs(%3 : tensor<10x10xi32>) : tensor<10x10xi32>

// dont fuse with with reverse's consumer
%5 = tensor.empty() : tensor<10x10xi32>
%6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<10x10xi32>) outs(%5 : tensor<10x10xi32>) {
^bb0(%in: i32, %out: i32):
%7 = arith.addi %in, %out : i32
linalg.yield %7 : i32
} -> tensor<10x10xi32>
util.return %6 : tensor<10x10xi32>
}

// CHECK: util.func public @linalgext_reverse_fusion
// CHECK: flow.dispatch.workgroups
// CHECK: %[[SHRUNK:.+]] = linalg.generic
// CHECK: %[[REVERSED:.+]] = iree_linalg_ext.reverse
// CHECK: ins(%[[SHRUNK]] : tensor<10x10xi32>)
// CHECK: flow.dispatch.workgroups
// CHECK: %[[GEN:.+]] = linalg.generic
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:InliningUtils",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LinalgStructuredOpsIncGen",
"@llvm-project//mlir:LinalgUtils",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ iree_cc_library(
MLIRIR
MLIRInferTypeOpInterface
MLIRLinalgDialect
MLIRLinalgStructuredOpsIncGenLib
MLIRLinalgUtils
MLIRMathDialect
MLIRMemRefDialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"

using namespace mlir;
Expand Down Expand Up @@ -46,7 +51,83 @@ struct IREELinalgExtInlinerInterface : public DialectInlinerInterface {
}
};

// Used to register the LinalgFusionOpInterface with the linalg ops.
namespace {
template <typename ConcreteType>
struct LinalgFusionOpInterfaceAdapter
: public LinalgFusionOpInterface::ExternalModel<
LinalgFusionOpInterfaceAdapter<ConcreteType>, ConcreteType> {
public:
SmallVector<AffineMap> getIndexingMapsForOperands(mlir::Operation *op) const {
auto maps = llvm::cast<ConcreteType>(op)
.getIndexingMaps()
.template getAsValueRange<AffineMapAttr>();
return {maps.begin(),
maps.end() - llvm::cast<ConcreteType>(op).getNumResults()};
}

SmallVector<AffineMap> getIndexingMapsForResults(mlir::Operation *op) const {
auto maps = llvm::cast<ConcreteType>(op)
.getIndexingMaps()
.template getAsValueRange<AffineMapAttr>();
return {maps.end() - llvm::cast<ConcreteType>(op).getNumResults(),
maps.end()};
}

// Forward all the interface methods to the corresponding linalg op.
unsigned getNumParallelLoops(mlir::Operation *op) const {
return (llvm::cast<ConcreteType>(op).getNumParallelLoops());
}

unsigned getNumLoops(mlir::Operation *op) const {
return (llvm::cast<ConcreteType>(op).getNumLoops());
}

SmallVector<int64_t, 4> getStaticLoopRanges(mlir::Operation *op) const {
return (llvm::cast<ConcreteType>(op).getStaticLoopRanges());
}

AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
OpResult result) const {
return (llvm::cast<ConcreteType>(op).getIndexingMapMatchingResult(result));
}

AffineMap getMatchingIndexingMap(mlir::Operation *op,
OpOperand *operand) const {
return (llvm::cast<ConcreteType>(op).getMatchingIndexingMap(operand));
}

SmallVector<AffineMap> getIndexingMaps(mlir::Operation *op) const {
// Note: this is different from linalg's implementation
// of `getIndexingMaps`. Call interface methods to get
// the vector of indexing maps for operands and results.
auto inputMaps = getIndexingMapsForOperands(op);
llvm::append_range(inputMaps, getIndexingMapsForResults(op));
return inputMaps;
}
};
} // namespace

template <typename... Args>
static void registerOpsWithLinalgExtOpInterface(mlir::MLIRContext *context) {
(Args::template attachInterface<LinalgFusionOpInterfaceAdapter<Args>>(
*context),
...);
}

void IREELinalgExtDialect::initialize() {
mlir::MLIRContext *context = getContext();
context->loadDialect<mlir::linalg::LinalgDialect>();

#define GET_OP_LIST
declarePromisedInterfaces<LinalgFusionOpInterface,
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>();

#define GET_OP_LIST
registerOpsWithLinalgExtOpInterface<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>(context);
addInterfaces<IREELinalgExtInlinerInterface>();

addAttributes<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@
#ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
#define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_

#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"

#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LLVM.h"

Expand Down
Loading

0 comments on commit 6fa5336

Please sign in to comment.