[compiler][flow] Move cast, reshape and bitcast after transfer op #18742

Open: wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/BUILD.bazel
@@ -85,6 +85,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:ViewLikeInterface",
],
)
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
@@ -56,6 +56,7 @@ iree_cc_library(
MLIRSupport
MLIRTensorDialect
MLIRTransformUtils
MLIRTransforms
MLIRViewLikeInterface
iree::compiler::Dialect::Util::IR
iree::compiler::Utils
134 changes: 134 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <memory>
#include <numeric>
#include <optional>

@@ -22,6 +23,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
@@ -30,6 +32,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/HomomorphismSimplification.h"

namespace mlir::iree_compiler::IREE::Flow {

@@ -1180,11 +1183,142 @@ struct ElideRedundantTransfer : public OpRewritePattern<TensorTransferOp> {
}
};

// Unranked shapes are always considered to have more dynamic dimensions than
// ranked.
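// For example, tensor<1x?xf32> and tensor<?x2xf32> each have one dynamic
// dimension, so the predicate holds for that pair in either argument order.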
static bool shapeHasLessOrEqualDynamicDimensions(ShapedType t1,
                                                 ShapedType t2) {
  if (!t2.hasRank()) {
    return true;
  }
  if (!t1.hasRank()) {
    return false;
  }

  return llvm::count_if(t1.getShape(), &ShapedType::isDynamic) <=
         llvm::count_if(t2.getShape(), &ShapedType::isDynamic);
}

// Move an op of type Op after flow.tensor.transfer.
// E.g.
//
//   %cast = tensor.cast %0 : tensor<2xf32> to tensor<?xf32>
//   %1 = flow.tensor.transfer %cast : tensor<?xf32>{%c2} to
//       #hal.device.affinity<@__device_0>
//
// is transformed into
//
//   %1 = flow.tensor.transfer %0 : tensor<2xf32> to
//       #hal.device.affinity<@__device_0>
//   %cast = tensor.cast %1 : tensor<2xf32> to tensor<?xf32>
//
// The op is only moved if the resulting transfer op would operate on a tensor
// with no more dynamic dimensions than before. The goal is that after the
// rewrite the transfer op exhibits less dynamism.
//
// E.g.
//
//   %cast = tensor.cast %0 : tensor<1x?xf32> to tensor<?x2xf32>
//   %1 = flow.tensor.transfer %cast : tensor<?x2xf32>{%c1} to
//       #hal.device.affinity<@__device_0>
//
// will match, but this will not:
//
//   %cast = tensor.cast %0 : tensor<1x?xf32> to tensor<1x2xf32>
//   %1 = flow.tensor.transfer %cast : tensor<1x2xf32> to
//       #hal.device.affinity<@__device_0>
//
// Accepting an equal (rather than strictly smaller) number of dynamic
// dimensions is the tie breaker against the symmetric pattern that would move
// the op before the transfer instead. We strive to reduce the dynamism of the
// transfer op; when there is no strict improvement, we prefer to keep the
// other op after the transfer.
// TODO: add the analogous move-before-transfer pattern.
//
// Uses the homomorphism simplification pattern
//   transfer(cast(x)) -> cast(transfer(x))
// where the cast op is the homomorphism and the transfer is the unary
// (mono-unary algebra) operation.
template <typename Op, unsigned homomorphismOpOperandIndex = 0,
          unsigned homomorphismOpResultIndex = 0>
static void populateMoveOpAfterTransferPattern(RewritePatternSet &results) {
  auto getHomomorphismOpOperandFn = [](Operation *op) {
    // op is an instance of Op.
    return &op->getOpOperand(homomorphismOpOperandIndex);
  };

  auto getHomomorphismOpResultFn = [](Operation *op) {
    // op is an instance of Op.
    return op->getResult(homomorphismOpResultIndex);
  };

  auto getSourceAlgebraicOpOperandsFn = [](Operation *op,
                                           SmallVector<OpOperand *> &operands) {
    // op is the transfer op.
    operands.push_back(&op->getOpOperand(0));
  };

  auto getSourceAlgebraicOpResultFn = [](Operation *op) {
    // op is the transfer op.
    return op->getResult(0);
  };
  auto getTargetAlgebraicOpResultFn = getSourceAlgebraicOpResultFn;

  auto isHomomorphismOpFn = [](Operation *op,
                               std::optional<Operation *> referenceOp) {
    auto operation = dyn_cast<Op>(op);
    if (!operation) {
      return false;
    }
    auto sourceType = op->getOperand(homomorphismOpOperandIndex).getType();
    auto targetType = op->getResult(homomorphismOpResultIndex).getType();
    return shapeHasLessOrEqualDynamicDimensions(cast<ShapedType>(sourceType),
                                                cast<ShapedType>(targetType));
  };

  auto isSourceAlgebraicOpFn = [](Operation *op) {
    return static_cast<bool>(dyn_cast<TensorTransferOp>(op));
  };

  auto createTargetAlgebraicOpFn = [](Operation *originalOp,
                                      IRMapping &operandsRemapping,
                                      PatternRewriter &rewriter) {
    // Create the transfer op.
    auto originalTransferOp = cast<TensorTransferOp>(originalOp);
    return rewriter.create<TensorTransferOp>(
        originalTransferOp->getLoc(),
        operandsRemapping.lookup(originalTransferOp.getOperand()),
        originalTransferOp.getTargetAttr());
  };

  using CastTransferReorderPattern = HomomorphismSimplification<
      std::decay_t<decltype(getHomomorphismOpOperandFn)>,
      std::decay_t<decltype(getHomomorphismOpResultFn)>,
      std::decay_t<decltype(getSourceAlgebraicOpOperandsFn)>,
      std::decay_t<decltype(getSourceAlgebraicOpResultFn)>,
      std::decay_t<decltype(getTargetAlgebraicOpResultFn)>,
      std::decay_t<decltype(isHomomorphismOpFn)>,
      std::decay_t<decltype(isSourceAlgebraicOpFn)>,
      std::decay_t<decltype(createTargetAlgebraicOpFn)>>;
  results.add(std::make_unique<CastTransferReorderPattern>(
      std::move(getHomomorphismOpOperandFn),
      std::move(getHomomorphismOpResultFn),
      std::move(getSourceAlgebraicOpOperandsFn),
      std::move(getSourceAlgebraicOpResultFn),
      std::move(getTargetAlgebraicOpResultFn), std::move(isHomomorphismOpFn),
      std::move(isSourceAlgebraicOpFn), std::move(createTargetAlgebraicOpFn),
      TensorTransferOp::getOperationName(), 1, results.getContext()));
}

} // namespace

void TensorTransferOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.insert<ElideRedundantTransfer>(context);
  populateMoveOpAfterTransferPattern<tensor::BitcastOp>(results);
  populateMoveOpAfterTransferPattern<IREE::Flow::TensorBitCastOp>(results);
  populateMoveOpAfterTransferPattern<tensor::CastOp>(results);
  populateMoveOpAfterTransferPattern<tensor::ReshapeOp>(results);
  populateMoveOpAfterTransferPattern<IREE::Flow::TensorReshapeOp>(results);
}

//===----------------------------------------------------------------------===//
@@ -106,3 +106,39 @@ util.func public @tensor_cast_to_reshape(%reshape_17 : tensor<?x?x?x?xf32>, %65
// CHECK: flow.tensor.reshape
// CHECK-SAME: tensor<?x12x?x64xf32>
// CHECK-SAME: -> tensor<?x?x?x?xf32>

// -----

// CHECK-LABEL: util.func public @move_reshape_after_transfer
util.func public @move_reshape_after_transfer(
    // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>) -> tensor<?x2xf32> {
    %arg0: tensor<1x?xf32>) -> tensor<?x2xf32> {
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  // CHECK: %[[TRANSFER:.*]] = flow.tensor.transfer %[[ARG0]] : tensor<1x?xf32>{%[[C2]]} to #hal.device.affinity<@__device_0>
  // CHECK: %[[RESHAPE:.*]] = flow.tensor.reshape %[[TRANSFER]] : tensor<1x?xf32>{%[[C2]]} -> tensor<?x2xf32>{%[[C1]]}
  %1 = flow.tensor.reshape %arg0 : tensor<1x?xf32>{%c2} -> tensor<?x2xf32>{%c1}
  %2 = flow.tensor.transfer %1 : tensor<?x2xf32>{%c1} to
      #hal.device.affinity<@__device_0>
  // CHECK: util.return %[[RESHAPE]] : tensor<?x2xf32>
  util.return %2 : tensor<?x2xf32>
}

// -----

// CHECK-LABEL: util.func public @do_not_move_reshape_after_transfer
util.func public @do_not_move_reshape_after_transfer(
    // CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>) -> tensor<1xf32> {
    %arg0: tensor<?xf32>) -> tensor<1xf32> {
  // CHECK: %[[C1:.*]] = arith.constant 1 : index
  %c1 = arith.constant 1 : index
  // CHECK: %[[RESHAPE:.*]] = flow.tensor.reshape %[[ARG0]] : tensor<?xf32>{%[[C1]]} -> tensor<1xf32>
  %1 = flow.tensor.reshape %arg0 : tensor<?xf32>{%c1} -> tensor<1xf32>
  // CHECK: %[[TRANSFER:.*]] = flow.tensor.transfer %[[RESHAPE]] : tensor<1xf32> to #hal.device.affinity<@__device_0>
  %2 = flow.tensor.transfer %1 : tensor<1xf32> to
      #hal.device.affinity<@__device_0>
  // CHECK: util.return %[[TRANSFER]] : tensor<1xf32>
  util.return %2 : tensor<1xf32>
}
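The pattern is also registered for tensor.cast, tensor.bitcast, tensor.reshape, and flow.tensor.bitcast. As an illustration, here is a sketch of the expected rewrite for the cast case, based on the example in the FlowOpFolders.cpp comment above (not one of the tests added by this PR):

// Before canonicalization: the transfer operates on the dynamically shaped type.
%cast = tensor.cast %0 : tensor<2xf32> to tensor<?xf32>
%1 = flow.tensor.transfer %cast : tensor<?xf32>{%c2} to #hal.device.affinity<@__device_0>

// After canonicalization: the transfer operates on the fully static type and
// the cast is moved after it.
%1 = flow.tensor.transfer %0 : tensor<2xf32> to #hal.device.affinity<@__device_0>
%cast = tensor.cast %1 : tensor<2xf32> to tensor<?xf32>

The transfer then carries the static tensor<2xf32> type, which is the dynamism reduction the pattern aims for.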