[compiler][flow] Move cast, reshape and bitcast after transfer op #18742

Open
wants to merge 4 commits into
base: main
Changes from 3 commits
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/BUILD.bazel
@@ -85,6 +85,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:ViewLikeInterface",
],
)
@@ -56,6 +56,7 @@ iree_cc_library(
MLIRSupport
MLIRTensorDialect
MLIRTransformUtils
MLIRTransforms
MLIRViewLikeInterface
iree::compiler::Dialect::Util::IR
iree::compiler::Utils
121 changes: 121 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <memory>
#include <numeric>
#include <optional>

@@ -13,6 +14,7 @@
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Utils/Shape.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -30,6 +32,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/HomomorphismSimplification.h"

namespace mlir::iree_compiler::IREE::Flow {

@@ -1180,11 +1183,129 @@ struct ElideRedundantTransfer : public OpRewritePattern<TensorTransferOp> {
}
};

// Move an op of type Op after flow.tensor.transfer.
// E.g.
//
// %cast = tensor.cast %0 : tensor<2xf32> to tensor<?xf32>
// %1 = flow.tensor.transfer %cast : tensor<?xf32>{%c2} to
// #hal.device.affinity<@__device_0>
//
// is transformed into
//
// %1 = flow.tensor.transfer %0 : tensor<2xf32> to
// #hal.device.affinity<@__device_0>
// %cast = tensor.cast %1 : tensor<2xf32> to tensor<?xf32>
//
// Will only move the op if the resulting transfer op would operate on a
// tensor with fewer or the same number of dynamic dimensions.
// The goal is that after application we would have one more op
// (the transfer op) that has less dynamism.
//
// E.g.
//
// %cast = tensor.cast %0 : tensor<1x?xf32> to tensor<?x2xf32>
// %1 = flow.tensor.transfer %cast : tensor<?x2xf32>{%c1} to
// #hal.device.affinity<@__device_0>
//
// will match, but this will not
//
// %cast = tensor.cast %0 : tensor<1x?xf32> to tensor<1x2xf32>
// %1 = flow.tensor.transfer %cast : tensor<1x2xf32> to
// #hal.device.affinity<@__device_0>
//
// Accepting an equal (not strictly smaller) number of dynamic dimensions is
// the tie-breaker against the symmetric pattern that would move the op before
// the transfer instead. We strive to reduce the dynamism of the transfer op.
// If there is no strict dynamism improvement, we prefer to place the other op
// after the transfer.
// TODO: add the analogous move-befor-transfer pattern.
Collaborator

Suggested change
- // TODO: add the analogous move-befor-transfer pattern.
+ // TODO: add the analogous move-before-transfer pattern.

(a spell checker can help in your IDE)

Contributor Author

Done. I have one, but it ignored the hyphenated word.

//
// Uses the homomorphism simplification pattern
// transfer(cast(x)) -> cast(transfer(x))
// where the cast op is the homomorphism
// and the transfer is the unar (mono-unary algebra) operation.
template <typename Op, unsigned homomorphismOpOperandIndex = 0,
unsigned homomorphismOpResultIndex = 0>
static void populateMoveOpAfterTransferPattern(RewritePatternSet &results) {

Collaborator

Suggested change (remove the blank line at the start of the function body)

Contributor Author

Done.

auto getHomomorphismOpOperandFn = [](Operation *op) {
// op is Op.
return &op->getOpOperand(homomorphismOpOperandIndex);
};

auto getHomomorphismOpResultFn = [](Operation *op) {
// op is Op.
return op->getResult(homomorphismOpResultIndex);
};

auto getSourceAlgebraicOpOperandsFn = [](Operation *op,
SmallVector<OpOperand *> &operands) {
// Op is transfer.
operands.push_back(&op->getOpOperand(0));
};

auto getSourceAlgebraicOpResultFn = [](Operation *op) {
// Op is transfer.
return op->getResult(0);
};
auto getTargetAlgebraicOpResultFn = getSourceAlgebraicOpResultFn;

auto isHomomorphismOpFn = [](Operation *op,
std::optional<Operation *> referenceOp) {
auto operation = dyn_cast<Op>(op);
if (!operation) {
return false;
}
auto sourceType = op->getOperand(homomorphismOpOperandIndex).getType();
auto targetType = op->getResult(homomorphismOpResultIndex).getType();
return shapeHasLessOrEqualDynamicDimensions(cast<ShapedType>(sourceType),
cast<ShapedType>(targetType));
};

auto isSourceAlgebraicOpFn = [](Operation *op) {
return isa<TensorTransferOp>(op);
};

auto createTargetAlgebraicOpFn = [](Operation *originalOp,
IRMapping &operandsRemapping,
PatternRewriter &rewriter) {
// Create the transfer op.
auto originalTransferOp = cast<TensorTransferOp>(originalOp);
return rewriter.create<TensorTransferOp>(
originalTransferOp->getLoc(),
operandsRemapping.lookup(originalTransferOp.getOperand()),
originalTransferOp.getTargetAttr());
};

using CastTransferReorderPattern = HomomorphismSimplification<
std::decay_t<decltype(getHomomorphismOpOperandFn)>,
std::decay_t<decltype(getHomomorphismOpResultFn)>,
std::decay_t<decltype(getSourceAlgebraicOpOperandsFn)>,
std::decay_t<decltype(getSourceAlgebraicOpResultFn)>,
std::decay_t<decltype(getTargetAlgebraicOpResultFn)>,
std::decay_t<decltype(isHomomorphismOpFn)>,
std::decay_t<decltype(isSourceAlgebraicOpFn)>,
std::decay_t<decltype(createTargetAlgebraicOpFn)>>;
results.add(std::make_unique<CastTransferReorderPattern>(
std::move(getHomomorphismOpOperandFn),
std::move(getHomomorphismOpResultFn),
std::move(getSourceAlgebraicOpOperandsFn),
std::move(getSourceAlgebraicOpResultFn),
std::move(getTargetAlgebraicOpResultFn), std::move(isHomomorphismOpFn),
std::move(isSourceAlgebraicOpFn), std::move(createTargetAlgebraicOpFn),
TensorTransferOp::getOperationName(), 1, results.getContext()));
}

} // namespace

void TensorTransferOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<ElideRedundantTransfer>(context);
populateMoveOpAfterTransferPattern<tensor::BitcastOp>(results);
populateMoveOpAfterTransferPattern<TensorBitCastOp>(results);
Collaborator

Suggested change
- populateMoveOpAfterTransferPattern<TensorBitCastOp>(results);
+ populateMoveOpAfterTransferPattern<IREE::Flow::TensorBitCastOp>(results);

Use namespaced names where there may be ambiguity, to make it clearer to readers.

Contributor Author

Done.

populateMoveOpAfterTransferPattern<tensor::CastOp>(results);
populateMoveOpAfterTransferPattern<tensor::ReshapeOp>(results);
populateMoveOpAfterTransferPattern<TensorReshapeOp>(results);
Collaborator

Suggested change
- populateMoveOpAfterTransferPattern<TensorReshapeOp>(results);
+ populateMoveOpAfterTransferPattern<IREE::Flow::TensorReshapeOp>(results);

Contributor Author

Done.

}

//===----------------------------------------------------------------------===//
@@ -106,3 +106,39 @@ util.func public @tensor_cast_to_reshape(%reshape_17 : tensor<?x?x?x?xf32>, %65
// CHECK: flow.tensor.reshape
// CHECK-SAME: tensor<?x12x?x64xf32>
// CHECK-SAME: -> tensor<?x?x?x?xf32>

// -----

// CHECK-LABEL: util.func public @move_reshape_after_transfer
util.func public @move_reshape_after_transfer(
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>) -> tensor<?x2xf32> {
%arg0: tensor<1x?xf32>) -> tensor<?x2xf32> {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
// CHECK: %[[TRANSFER:.*]] = flow.tensor.transfer %[[ARG0]] : tensor<1x?xf32>{%[[C2]]} to #hal.device.affinity<@__device_0>
// CHECK: %[[RESHAPE:.*]] = flow.tensor.reshape %[[TRANSFER]] : tensor<1x?xf32>{%[[C2]]} -> tensor<?x2xf32>{%[[C1]]}
%1 = flow.tensor.reshape %arg0 : tensor<1x?xf32>{%c2} -> tensor<?x2xf32>{%c1}
%2 = flow.tensor.transfer %1 : tensor<?x2xf32>{%c1} to
#hal.device.affinity<@__device_0>
// CHECK: util.return %[[RESHAPE]] : tensor<?x2xf32>
util.return %2 : tensor<?x2xf32>
}

// -----

// CHECK-LABEL: util.func public @do_not_move_reshape_after_transfer
util.func public @do_not_move_reshape_after_transfer(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>) -> tensor<1xf32> {
%arg0: tensor<?xf32>) -> tensor<1xf32> {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
// CHECK: %[[RESHAPE:.*]] = flow.tensor.reshape %[[ARG0]] : tensor<?xf32>{%[[C1]]} -> tensor<1xf32>
%1 = flow.tensor.reshape %arg0 : tensor<?xf32>{%c1} -> tensor<1xf32>
// CHECK: %[[TRANSFER:.*]] = flow.tensor.transfer %[[RESHAPE]] : tensor<1xf32> to #hal.device.affinity<@__device_0>
%2 = flow.tensor.transfer %1 : tensor<1xf32> to
#hal.device.affinity<@__device_0>
// CHECK: util.return %[[TRANSFER]] : tensor<1xf32>
util.return %2 : tensor<1xf32>
}
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Utils/BUILD.bazel
@@ -42,6 +42,7 @@ iree_compiler_cc_library(
"PassUtils.h",
"PatternUtils.h",
"Permutation.h",
"Shape.h",
"SmallVectorDenseMapInfo.h",
"StringUtils.h",
"ToolUtils.h",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Utils/CMakeLists.txt
@@ -27,6 +27,7 @@ iree_cc_library(
"PassUtils.h"
"PatternUtils.h"
"Permutation.h"
"Shape.h"
"SmallVectorDenseMapInfo.h"
"StringUtils.h"
"ToolUtils.h"
31 changes: 31 additions & 0 deletions compiler/src/iree/compiler/Utils/Shape.h
@@ -0,0 +1,31 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_UTILS_SHAPE_H_
#define IREE_COMPILER_UTILS_SHAPE_H_

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"

namespace mlir::iree_compiler {

// Unranked shapes are always considered to have more dynamic dimensions than
// ranked.
inline bool shapeHasLessOrEqualDynamicDimensions(ShapedType t1, ShapedType t2) {
Collaborator

YAGNI
On the 3rd or 4th time a 3-line block of code is used, it's worth hoisting into a global shared util, but there's a bar for doing this: every shared util pollutes the project, adds a maintenance burden, and has to be worth it. A single use in a single location is not worth it. Just inline this as static where it is used.

(this would also need to be static here - you can't put inline functions in header files)

Contributor Author

I moved it.

if (!t2.hasRank()) {
return true;
}
if (!t1.hasRank()) {
return false;
}

return llvm::count_if(t1.getShape(), &ShapedType::isDynamic) <=
llvm::count_if(t2.getShape(), &ShapedType::isDynamic);
}

} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_UTILS_SHAPE_H_