[compiler][flow] Move cast, reshape and bitcast after transfer op
We got incoming IR of the form
```mlir
  %cast = tensor.cast %0 : tensor<2xf32> to tensor<?xf32>
  %2 = flow.tensor.transfer %cast : tensor<?xf32>{%c2} to
    #hal.device.affinity<@__device_0>
  %cast_2 = tensor.cast %2 : tensor<?xf32> to tensor<2xf32>
```

We would like to allow the two casts to fold away. We would also like to
reduce the dynamism of the transfer op so that it operates on a tensor with
fewer dynamic dimensions.
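After the rewrite the transfer operates on the static tensor and the two casts
become adjacent, so existing cast folding can remove them. A minimal sketch of
the expected result, assuming the back-to-back casts fold and %cast has no
other uses:
```mlir
  %2 = flow.tensor.transfer %0 : tensor<2xf32> to
    #hal.device.affinity<@__device_0>
  // %cast_2 is replaced by %2 directly.
```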
sogartar committed Oct 10, 2024
1 parent ce4f098 commit f4c155a
Showing 5 changed files with 194 additions and 0 deletions.
125 changes: 125 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <memory>
#include <numeric>
#include <optional>

@@ -13,6 +14,7 @@
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Utils/Shape.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -30,6 +32,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/HomomorphismSimplification.h"

namespace mlir::iree_compiler::IREE::Flow {

@@ -1180,11 +1183,133 @@ struct ElideRedundantTransfer : public OpRewritePattern<TensorTransferOp> {
}
};

// Move an op of type Op after flow.tensor.transfer.
// E.g.
//
// %cast = tensor.cast %0 : tensor<2xf32> to tensor<?xf32>
// %1 = flow.tensor.transfer %cast : tensor<?xf32>{%c2} to
// #hal.device.affinity<@__device_0>
//
// is transformed into
//
// %1 = flow.tensor.transfer %0 : tensor<2xf32> to
//   #hal.device.affinity<@__device_0>
// %cast = tensor.cast %1 : tensor<2xf32> to tensor<?xf32>
//
// The op is only moved if the resulting transfer op would operate on a
// tensor with a number of dynamic dimensions less than or equal to that of
// the original. The goal is that after the rewrite the transfer op has no
// more (and ideally fewer) dynamic dimensions.
//
// E.g.
//
// %cast = tensor.cast %0 : tensor<1x?xf32> to tensor<?x2xf32>
// %1 = flow.tensor.transfer %cast : tensor<?x2xf32>{%c1} to
// #hal.device.affinity<@__device_0>
//
// will match, but this will not
//
// %cast = tensor.cast %0 : tensor<1x?xf32> to tensor<1x2xf32>
// %1 = flow.tensor.transfer %cast : tensor<1x2xf32> to
// #hal.device.affinity<@__device_0>
//
// The weaker condition (less than or equal rather than strictly less) acts
// as the tie breaker against the symmetric pattern that would move the op
// before the transfer instead. We strive to reduce the dynamism of the
// transfer op; when there is no strict improvement, we prefer to keep the
// other op after the transfer.
// TODO: add the analogous move-before-transfer pattern.
//
// Uses the homomorphism simplification pattern
// transfer(cast(x)) -> cast(transfer(x))
// where the cast op is the homomorphism and the transfer is the unary
// (mono-unary algebra) operation.
template <typename Op, unsigned homomorphismOpOperandIndex = 0,
unsigned homomorphismOpResultIndex = 0>
static void populateMoveOpAfterTransferPattern(RewritePatternSet &results) {

auto getHomomorphismOpOperandFn = [](Operation *op) {
// op is Op.
return &op->getOpOperand(homomorphismOpOperandIndex);
};

auto getHomomorphismOpResultFn = [](Operation *op) {
// op is Op.
return op->getResult(homomorphismOpResultIndex);
};

auto getSourceAlgebraicOpOperandsFn = [](Operation *op,
SmallVector<OpOperand *> &operands) {
// Op is transfer.
operands.push_back(&op->getOpOperand(0));
};

auto getSourceAlgebraicOpResultFn = [](Operation *op) {
// Op is transfer.
return op->getResult(0);
};

auto getTargetAlgebraicOpResultFn = [](Operation *op) {
// Op is transfer.
return op->getResult(0);
};

auto isHomomorphismOpFn = [](Operation *op,
std::optional<Operation *> referenceOp) {
auto operation = dyn_cast<Op>(op);
if (!operation) {
return false;
}
auto sourceType = op->getOperand(homomorphismOpOperandIndex).getType();
auto targetType = op->getResult(homomorphismOpResultIndex).getType();
return shapeHasLessOrEqualDynamicDimensions(cast<ShapedType>(sourceType),
cast<ShapedType>(targetType));
};

auto isSourceAlgebraicOpFn = [](Operation *op) {
return isa<TensorTransferOp>(op);
};

auto createTargetAlgebraicOpFn = [](Operation *originalOp,
IRMapping &operandsRemapping,
PatternRewriter &rewriter) {
// Create the transfer op.
auto originalTransferOp = cast<TensorTransferOp>(originalOp);
return rewriter.create<TensorTransferOp>(
originalTransferOp->getLoc(),
operandsRemapping.lookup(originalTransferOp.getOperand()),
originalTransferOp.getTargetAttr());
};

using CastTransferReorderPattern = HomomorphismSimplification<
std::decay_t<decltype(getHomomorphismOpOperandFn)>,
std::decay_t<decltype(getHomomorphismOpResultFn)>,
std::decay_t<decltype(getSourceAlgebraicOpOperandsFn)>,
std::decay_t<decltype(getSourceAlgebraicOpResultFn)>,
std::decay_t<decltype(getTargetAlgebraicOpResultFn)>,
std::decay_t<decltype(isHomomorphismOpFn)>,
std::decay_t<decltype(isSourceAlgebraicOpFn)>,
std::decay_t<decltype(createTargetAlgebraicOpFn)>>;
results.add(std::make_unique<CastTransferReorderPattern>(
std::move(getHomomorphismOpOperandFn),
std::move(getHomomorphismOpResultFn),
std::move(getSourceAlgebraicOpOperandsFn),
std::move(getSourceAlgebraicOpResultFn),
std::move(getTargetAlgebraicOpResultFn), std::move(isHomomorphismOpFn),
std::move(isSourceAlgebraicOpFn), std::move(createTargetAlgebraicOpFn),
TensorTransferOp::getOperationName(), 1, results.getContext()));
}

} // namespace

void TensorTransferOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<ElideRedundantTransfer>(context);
populateMoveOpAfterTransferPattern<tensor::BitcastOp>(results);
populateMoveOpAfterTransferPattern<TensorBitCastOp>(results);
populateMoveOpAfterTransferPattern<tensor::CastOp>(results);
populateMoveOpAfterTransferPattern<tensor::ReshapeOp>(results);
populateMoveOpAfterTransferPattern<TensorReshapeOp>(results);
}

//===----------------------------------------------------------------------===//
@@ -106,3 +106,39 @@ util.func public @tensor_cast_to_reshape(%reshape_17 : tensor<?x?x?x?xf32>, %65
// CHECK: flow.tensor.reshape
// CHECK-SAME: tensor<?x12x?x64xf32>
// CHECK-SAME: -> tensor<?x?x?x?xf32>

// -----

// CHECK-LABEL: util.func public @move_reshape_after_transfer
util.func public @move_reshape_after_transfer(
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>) -> tensor<?x2xf32> {
%arg0: tensor<1x?xf32>) -> tensor<?x2xf32> {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
// CHECK: %[[TRANSFER:.*]] = flow.tensor.transfer %[[ARG0]] : tensor<1x?xf32>{%[[C2]]} to #hal.device.affinity<@__device_0>
// CHECK: %[[RESHAPE:.*]] = flow.tensor.reshape %[[TRANSFER]] : tensor<1x?xf32>{%[[C2]]} -> tensor<?x2xf32>{%[[C1]]}
%1 = flow.tensor.reshape %arg0 : tensor<1x?xf32>{%c2} -> tensor<?x2xf32>{%c1}
%2 = flow.tensor.transfer %1 : tensor<?x2xf32>{%c1} to
#hal.device.affinity<@__device_0>
// CHECK: util.return %[[RESHAPE]] : tensor<?x2xf32>
util.return %2 : tensor<?x2xf32>
}

// -----

// CHECK-LABEL: util.func public @do_not_move_reshape_after_transfer
util.func public @do_not_move_reshape_after_transfer(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>) -> tensor<1xf32> {
%arg0: tensor<?xf32>) -> tensor<1xf32> {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
// CHECK: %[[RESHAPE:.*]] = flow.tensor.reshape %[[ARG0]] : tensor<?xf32>{%[[C1]]} -> tensor<1xf32>
%1 = flow.tensor.reshape %arg0 : tensor<?xf32>{%c1} -> tensor<1xf32>
// CHECK: %[[TRANSFER:.*]] = flow.tensor.transfer %[[RESHAPE]] : tensor<1xf32> to #hal.device.affinity<@__device_0>
%2 = flow.tensor.transfer %1 : tensor<1xf32> to
#hal.device.affinity<@__device_0>
// CHECK: util.return %[[TRANSFER]] : tensor<1xf32>
util.return %2 : tensor<1xf32>
}
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Utils/BUILD.bazel
@@ -42,6 +42,7 @@ iree_compiler_cc_library(
"PassUtils.h",
"PatternUtils.h",
"Permutation.h",
"Shape.h",
"SmallVectorDenseMapInfo.h",
"StringUtils.h",
"ToolUtils.h",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Utils/CMakeLists.txt
@@ -27,6 +27,7 @@ iree_cc_library(
"PassUtils.h"
"PatternUtils.h"
"Permutation.h"
"Shape.h"
"SmallVectorDenseMapInfo.h"
"StringUtils.h"
"ToolUtils.h"
31 changes: 31 additions & 0 deletions compiler/src/iree/compiler/Utils/Shape.h
@@ -0,0 +1,31 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_UTILS_SHAPE_H_
#define IREE_COMPILER_UTILS_SHAPE_H_

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"

namespace mlir::iree_compiler {

// Unranked shapes are always considered to have more dynamic dimensions than
// ranked.
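// E.g. (illustrative): for tensor<2xf32> vs. tensor<?xf32> this returns true
// (0 <= 1 dynamic dims); for tensor<?x2xf32> vs. tensor<1x2xf32> it returns
// false (1 > 0).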
inline bool shapeHasLessOrEqualDynamicDimensions(ShapedType t1,
                                                 ShapedType t2) {
if (!t2.hasRank()) {
return true;
}
if (!t1.hasRank()) {
return false;
}

return llvm::count_if(t1.getShape(), &ShapedType::isDynamic) <=
llvm::count_if(t2.getShape(), &ShapedType::isDynamic);
}

} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_UTILS_SHAPE_H_
