-
Notifications
You must be signed in to change notification settings - Fork 598
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[compiler][flow] Move cast, reshape and bitcast after transfer op #18742
base: main
Are you sure you want to change the base?
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -5,6 +5,7 @@ | |||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||||
|
||||||
#include <algorithm> | ||||||
#include <memory> | ||||||
#include <numeric> | ||||||
#include <optional> | ||||||
|
||||||
|
@@ -13,6 +14,7 @@ | |||||
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h" | ||||||
#include "iree/compiler/Dialect/Util/IR/UtilOps.h" | ||||||
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" | ||||||
#include "iree/compiler/Utils/Shape.h" | ||||||
#include "llvm/ADT/MapVector.h" | ||||||
#include "llvm/ADT/STLExtras.h" | ||||||
#include "llvm/ADT/SmallPtrSet.h" | ||||||
|
@@ -30,6 +32,7 @@ | |||||
#include "mlir/IR/TypeUtilities.h" | ||||||
#include "mlir/IR/Value.h" | ||||||
#include "mlir/Support/LogicalResult.h" | ||||||
#include "mlir/Transforms/HomomorphismSimplification.h" | ||||||
|
||||||
namespace mlir::iree_compiler::IREE::Flow { | ||||||
|
||||||
|
@@ -1180,11 +1183,129 @@ struct ElideRedundantTransfer : public OpRewritePattern<TensorTransferOp> { | |||||
} | ||||||
}; | ||||||
|
||||||
// Move an op of type Op after flow.tensor.transfer. | ||||||
// E.g. | ||||||
// | ||||||
// %cast = tensor.cast %0 : tensor<2xf32> to tensor<?xf32> | ||||||
// %1 = flow.tensor.transfer %cast : tensor<?xf32>{%c2} to | ||||||
// #hal.device.affinity<@__device_0> | ||||||
// | ||||||
// is transformed into | ||||||
// | ||||||
// %1 = flow.tensor.transfer %0 : tensor<2xf32> to | ||||||
// #hal.device.affinity<@__device_0> %cast = tensor.cast %1 : tensor<2xf32> to | ||||||
// tensor<?xf32> | ||||||
// | ||||||
// Will only move the op if the resulting transfer op would operate on a | ||||||
// tensor with less or equal number of dynamic dimensions. | ||||||
// The goal is that after application we would have one more op | ||||||
// (the transfer op) that has less dynamism. | ||||||
// | ||||||
// E.g. | ||||||
// | ||||||
// %cast = tensor.cast %0 : tensor<1x?xf32> to tensor<?x2xf32> | ||||||
// %1 = flow.tensor.transfer %cast : tensor<?x2xf32>{%c1} to | ||||||
// #hal.device.affinity<@__device_0> | ||||||
// | ||||||
// will match, but this will not | ||||||
// | ||||||
// %cast = tensor.cast %0 : tensor<1x?xf32> to tensor<1x2xf32> | ||||||
// %1 = flow.tensor.transfer %cast : tensor<1x2xf32> to | ||||||
// #hal.device.affinity<@__device_0> | ||||||
// | ||||||
// The weaker condition that the number of dynamic dimensions is not strictly | ||||||
// less is the tie breaker between the symmetric pattern where we move the op | ||||||
// before the transfer instead. We strive to reduce the dynamism of the | ||||||
// transfer op. If there will be no strict dynamism improvement, we prefer the | ||||||
// other op after the transfer. | ||||||
// TODO: add the analogous move-befor-transfer pattern. | ||||||
// | ||||||
// Uses the homomorphism simplification pattern | ||||||
// transfer(cast(x)) -> cast(transfer(x)) | ||||||
// where the cast op is the homomorphism | ||||||
// and the transfer is the unar (mono-unary algebra) operation. | ||||||
template <typename Op, unsigned homomorphismOpOperandIndex = 0, | ||||||
unsigned homomorphismOpResultIndex = 0> | ||||||
static void populateMoveOpAfterTransferPattern(RewritePatternSet &results) { | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||||||
auto getHomomorphismOpOperandFn = [](Operation *op) { | ||||||
// op is Op. | ||||||
return &op->getOpOperand(homomorphismOpOperandIndex); | ||||||
}; | ||||||
|
||||||
auto getHomomorphismOpResultFn = [](Operation *op) { | ||||||
// op is Op. | ||||||
return op->getResult(homomorphismOpResultIndex); | ||||||
}; | ||||||
|
||||||
auto getSourceAlgebraicOpOperandsFn = [](Operation *op, | ||||||
SmallVector<OpOperand *> &operands) { | ||||||
// Op is transfer. | ||||||
operands.push_back(&op->getOpOperand(0)); | ||||||
}; | ||||||
|
||||||
auto getSourceAlgebraicOpResultFn = [](Operation *op) { | ||||||
// Op is transfer. | ||||||
return op->getResult(0); | ||||||
}; | ||||||
auto getTargetAlgebraicOpResultFn = getSourceAlgebraicOpResultFn; | ||||||
|
||||||
auto isHomomorphismOpFn = [](Operation *op, | ||||||
std::optional<Operation *> referenceOp) { | ||||||
auto operation = dyn_cast<Op>(op); | ||||||
if (!operation) { | ||||||
return false; | ||||||
} | ||||||
auto sourceType = op->getOperand(homomorphismOpOperandIndex).getType(); | ||||||
auto targetType = op->getResult(homomorphismOpOperandIndex).getType(); | ||||||
return shapeHasLessOrEqualDynamicDimensions(cast<ShapedType>(sourceType), | ||||||
cast<ShapedType>(targetType)); | ||||||
}; | ||||||
|
||||||
auto isSourceAlgebraicOpFn = [](Operation *op) { | ||||||
return static_cast<bool>(dyn_cast<TensorTransferOp>(op)); | ||||||
}; | ||||||
|
||||||
auto createTargetAlgebraicOpFn = [](Operation *originalOp, | ||||||
IRMapping &operandsRemapping, | ||||||
PatternRewriter &rewriter) { | ||||||
// Create the transfer op. | ||||||
auto originalTransferOp = cast<TensorTransferOp>(originalOp); | ||||||
return rewriter.create<TensorTransferOp>( | ||||||
originalTransferOp->getLoc(), | ||||||
operandsRemapping.lookup(originalTransferOp.getOperand()), | ||||||
originalTransferOp.getTargetAttr()); | ||||||
}; | ||||||
|
||||||
using CastTransferReorderPattern = HomomorphismSimplification< | ||||||
std::decay_t<decltype(getHomomorphismOpOperandFn)>, | ||||||
std::decay_t<decltype(getHomomorphismOpResultFn)>, | ||||||
std::decay_t<decltype(getSourceAlgebraicOpOperandsFn)>, | ||||||
std::decay_t<decltype(getSourceAlgebraicOpResultFn)>, | ||||||
std::decay_t<decltype(getTargetAlgebraicOpResultFn)>, | ||||||
std::decay_t<decltype(isHomomorphismOpFn)>, | ||||||
std::decay_t<decltype(isSourceAlgebraicOpFn)>, | ||||||
std::decay_t<decltype(createTargetAlgebraicOpFn)>>; | ||||||
results.add(std::make_unique<CastTransferReorderPattern>( | ||||||
std::move(getHomomorphismOpOperandFn), | ||||||
std::move(getHomomorphismOpResultFn), | ||||||
std::move(getSourceAlgebraicOpOperandsFn), | ||||||
std::move(getSourceAlgebraicOpResultFn), | ||||||
std::move(getTargetAlgebraicOpResultFn), std::move(isHomomorphismOpFn), | ||||||
std::move(isSourceAlgebraicOpFn), std::move(createTargetAlgebraicOpFn), | ||||||
TensorTransferOp::getOperationName(), 1, results.getContext())); | ||||||
} | ||||||
|
||||||
} // namespace | ||||||
|
||||||
void TensorTransferOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // Drop transfers that do not change the affinity of their operand.
  results.add<ElideRedundantTransfer>(context);
  // Reorder metadata-only ops (bitcast/cast/reshape) to occur after the
  // transfer so that the transfer itself operates on the less dynamic type.
  // Both the upstream tensor-dialect ops and the Flow-dialect variants are
  // registered; namespaced names disambiguate the similarly named pairs.
  populateMoveOpAfterTransferPattern<tensor::BitcastOp>(results);
  populateMoveOpAfterTransferPattern<TensorBitCastOp>(results);
  populateMoveOpAfterTransferPattern<tensor::CastOp>(results);
  populateMoveOpAfterTransferPattern<tensor::ReshapeOp>(results);
  populateMoveOpAfterTransferPattern<TensorReshapeOp>(results);
}
|
||||||
//===----------------------------------------------------------------------===// | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#ifndef IREE_COMPILER_UTILS_SHAPE_H_ | ||
#define IREE_COMPILER_UTILS_SHAPE_H_ | ||
|
||
#include "llvm/ADT/STLExtras.h" | ||
#include "mlir/IR/BuiltinTypeInterfaces.h" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
// Returns true if |t1| has no more dynamic dimensions than |t2|.
// Unranked shapes are always considered to have more dynamic dimensions than
// ranked ones (and at least as many as other unranked shapes).
inline bool shapeHasLessOrEqualDynamicDimensions(ShapedType t1, ShapedType t2) {
  // An unranked |t2| dominates every shape, including an unranked |t1|.
  if (!t2.hasRank()) {
    return true;
  }
  if (!t1.hasRank()) {
    return false;
  }

  // Both ranked: use ShapedType's built-in dynamic-dimension count instead of
  // hand-rolling it with llvm::count_if.
  return t1.getNumDynamicDims() <= t2.getNumDynamicDims();
}
|
||
} // namespace mlir::iree_compiler | ||
|
||
#endif // IREE_COMPILER_UTILS_SHAPE_H_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(a spell checker can help in your IDE)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. I have one but it ignored the hyphened word.