diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 1a60d1cc284e..c14995c701e0 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include <algorithm>
+#include <memory>
 #include <numeric>
 #include <optional>
 
@@ -13,6 +14,7 @@
 #include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Utils/Shape.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -30,6 +32,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/HomomorphismSimplification.h"
 
 namespace mlir::iree_compiler::IREE::Flow {
 
@@ -1180,11 +1183,133 @@ struct ElideRedundantTransfer : public OpRewritePattern<TensorTransferOp> {
   }
 };
 
+// Move an op of type Op after flow.tensor.transfer.
+// E.g.
+//
+//   %cast = tensor.cast %0 : tensor<2xf32> to tensor<?xf32>
+//   %1 = flow.tensor.transfer %cast : tensor<?xf32>{%c2} to
+//       #hal.device.affinity<@__device_0>
+//
+// is transformed into
+//
+//   %1 = flow.tensor.transfer %0 : tensor<2xf32> to
+//       #hal.device.affinity<@__device_0>
+//   %cast = tensor.cast %1 : tensor<2xf32> to tensor<?xf32>
+//
+// The op is only moved if the resulting transfer op would operate on a
+// tensor with a smaller or equal number of dynamic dimensions, so that after
+// application we end up with one more op (the transfer) operating on a less
+// dynamic tensor.
+//
+// E.g.
+//
+//   %cast = tensor.cast %0 : tensor<1x?xf32> to tensor<?x2xf32>
+//   %1 = flow.tensor.transfer %cast : tensor<?x2xf32>{%c1} to
+//       #hal.device.affinity<@__device_0>
+//
+// will match, but this will not:
+//
+//   %cast = tensor.cast %0 : tensor<1x?xf32> to tensor<1x2xf32>
+//   %1 = flow.tensor.transfer %cast : tensor<1x2xf32> to
+//       #hal.device.affinity<@__device_0>
+//
+// Accepting an equal (rather than strictly smaller) number of dynamic
+// dimensions is the tie breaker against the symmetric pattern that would
+// move the op before the transfer instead. We strive to reduce the dynamism
+// of the transfer op; when there is no strict improvement, we prefer the
+// other op after the transfer.
+// TODO: add the analogous move-before-transfer pattern.
+//
+// This uses the homomorphism simplification pattern
+//   transfer(cast(x)) -> cast(transfer(x))
+// where the cast op is the homomorphism and the transfer is the unary
+// (mono-unary algebra) operation.
+template <typename Op, size_t homomorphismOpOperandIndex,
+          size_t homomorphismOpResultIndex>
+static void populateMoveOpAfterTransferPattern(RewritePatternSet &results) {
+  auto getHomomorphismOpOperandFn = [](Operation *op) {
+    // op is Op.
+    return &op->getOpOperand(homomorphismOpOperandIndex);
+  };
+
+  auto getHomomorphismOpResultFn = [](Operation *op) {
+    // op is Op.
+    return op->getResult(homomorphismOpResultIndex);
+  };
+
+  auto getSourceAlgebraicOpOperandsFn =
+      [](Operation *op, SmallVector<OpOperand *> &operands) {
+        // op is flow.tensor.transfer.
+        operands.push_back(&op->getOpOperand(0));
+      };
+
+  auto getSourceAlgebraicOpResultFn = [](Operation *op) {
+    // op is flow.tensor.transfer.
+    return op->getResult(0);
+  };
+
+  auto getTargetAlgebraicOpResultFn = [](Operation *op) {
+    // op is flow.tensor.transfer.
+    return op->getResult(0);
+  };
+
+  auto isHomomorphismOpFn = [](Operation *op,
+                               std::optional<Operation *> referenceOp) {
+    if (!isa<Op>(op)) {
+      return false;
+    }
+    auto sourceType = op->getOperand(homomorphismOpOperandIndex).getType();
+    auto targetType = op->getResult(homomorphismOpResultIndex).getType();
+    return shapeHasLessOrEqualDynamicDimensions(cast<ShapedType>(sourceType),
+                                                cast<ShapedType>(targetType));
+  };
+
+  auto isSourceAlgebraicOpFn = [](Operation *op) {
+    return isa<TensorTransferOp>(op);
+  };
+
+  auto createTargetAlgebraicOpFn = [](Operation *originalOp,
+                                      IRMapping &operandsRemapping,
+                                      PatternRewriter &rewriter) {
+    // Recreate the transfer op on the remapped (pre-cast) operand.
+    auto originalTransferOp = cast<TensorTransferOp>(originalOp);
+    return rewriter.create<TensorTransferOp>(
+        originalTransferOp->getLoc(),
+        operandsRemapping.lookup(originalTransferOp.getOperand()),
+        originalTransferOp.getTargetAttr());
+  };
+
+  using CastTransferReorderPattern = HomomorphismSimplification<
+      std::decay_t<decltype(getHomomorphismOpOperandFn)>,
+      std::decay_t<decltype(getHomomorphismOpResultFn)>,
+      std::decay_t<decltype(getSourceAlgebraicOpOperandsFn)>,
+      std::decay_t<decltype(getSourceAlgebraicOpResultFn)>,
+      std::decay_t<decltype(getTargetAlgebraicOpResultFn)>,
+      std::decay_t<decltype(isHomomorphismOpFn)>,
+      std::decay_t<decltype(isSourceAlgebraicOpFn)>,
+      std::decay_t<decltype(createTargetAlgebraicOpFn)>>;
+  results.add(std::make_unique<CastTransferReorderPattern>(
+      std::move(getHomomorphismOpOperandFn),
+      std::move(getHomomorphismOpResultFn),
+      std::move(getSourceAlgebraicOpOperandsFn),
+      std::move(getSourceAlgebraicOpResultFn),
+      std::move(getTargetAlgebraicOpResultFn), std::move(isHomomorphismOpFn),
+      std::move(isSourceAlgebraicOpFn), std::move(createTargetAlgebraicOpFn),
+      TensorTransferOp::getOperationName(), 1, results.getContext()));
+}
+
 } // namespace
 
 void TensorTransferOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
   results.insert<ElideRedundantTransfer>(context);
+  populateMoveOpAfterTransferPattern<tensor::CastOp, 0, 0>(results);
+  populateMoveOpAfterTransferPattern<tensor::CollapseShapeOp, 0, 0>(results);
+  populateMoveOpAfterTransferPattern<tensor::ExpandShapeOp, 0, 0>(results);
+  populateMoveOpAfterTransferPattern<TensorReshapeOp, 0, 0>(results);
+  populateMoveOpAfterTransferPattern<TensorBitCastOp, 0, 0>(results);
 }
 
 //===----------------------------------------------------------------------===//
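The less-or-equal gate above is the crux of the pattern and easy to misread, so here is a minimal standalone sketch of just that rule, using plain `std::vector<int64_t>` shapes and a `kDynamic` sentinel as stand-ins for MLIR's `ShapedType` machinery (all names below are illustrative, not part of the patch):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamic = -1;

static int64_t countDynamicDims(const std::vector<int64_t> &shape) {
  return std::count(shape.begin(), shape.end(), kDynamic);
}

// The transfer is moved above the cast-like op only when the cast's source
// shape has no more dynamic dimensions than its result shape, so the
// rewritten transfer is never more dynamic than the original one.
static bool shouldMoveCastAfterTransfer(const std::vector<int64_t> &castSrc,
                                        const std::vector<int64_t> &castDst) {
  return countDynamicDims(castSrc) <= countDynamicDims(castDst);
}

int main() {
  // tensor<1x?xf32> -> tensor<?x2xf32>: 1 <= 1 dynamic dims, the rewrite fires.
  bool fires = shouldMoveCastAfterTransfer({1, kDynamic}, {kDynamic, 2});
  // tensor<1x?xf32> -> tensor<1x2xf32>: 1 > 0 dynamic dims, it does not.
  bool blocked = !shouldMoveCastAfterTransfer({1, kDynamic}, {1, 2});
  return (fires && blocked) ? 0 : 1;
}
```

The two cases in `main` mirror the matching and non-matching examples from the comment block in the hunk above.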
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir
index effadc47da5d..9f40f0a16803 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir
@@ -106,3 +106,39 @@ util.func public @tensor_cast_to_reshape(%reshape_17 : tensor<?x?x?x?xf32>, %65
 // CHECK: flow.tensor.reshape
 // CHECK-SAME: tensor<?x?x?x?xf32>
 // CHECK-SAME: -> tensor<?x12x?x64xf32>
+
+// -----
+
+// CHECK-LABEL: util.func public @move_reshape_after_transfer
+util.func public @move_reshape_after_transfer(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>) -> tensor<?x2xf32> {
+    %arg0: tensor<1x?xf32>) -> tensor<?x2xf32> {
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  // CHECK: %[[TRANSFER:.*]] = flow.tensor.transfer %[[ARG0]] : tensor<1x?xf32>{%[[C2]]} to #hal.device.affinity<@__device_0>
+  // CHECK: %[[RESHAPE:.*]] = flow.tensor.reshape %[[TRANSFER]] : tensor<1x?xf32>{%[[C2]]} -> tensor<?x2xf32>{%[[C1]]}
+  %1 = flow.tensor.reshape %arg0 : tensor<1x?xf32>{%c2} -> tensor<?x2xf32>{%c1}
+  %2 = flow.tensor.transfer %1 : tensor<?x2xf32>{%c1} to
+      #hal.device.affinity<@__device_0>
+  // CHECK: util.return %[[RESHAPE]] : tensor<?x2xf32>
+  util.return %2 : tensor<?x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: util.func public @do_not_move_reshape_after_transfer
+util.func public @do_not_move_reshape_after_transfer(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>) -> tensor<1xf32> {
+    %arg0: tensor<?xf32>) -> tensor<1xf32> {
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[RESHAPE:.*]] = flow.tensor.reshape %[[ARG0]] : tensor<?xf32>{%[[C1]]} -> tensor<1xf32>
+  %1 = flow.tensor.reshape %arg0 : tensor<?xf32>{%c1} -> tensor<1xf32>
+  // CHECK: %[[TRANSFER:.*]] = flow.tensor.transfer %[[RESHAPE]] : tensor<1xf32> to #hal.device.affinity<@__device_0>
+  %2 = flow.tensor.transfer %1 : tensor<1xf32> to
+      #hal.device.affinity<@__device_0>
+  // CHECK: util.return %[[TRANSFER]] : tensor<1xf32>
+  util.return %2 : tensor<1xf32>
+}
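The second test is the tie-breaker side of the story: the transfer already sits on the fully static type, so moving the reshape past it would only add dynamism. One consequence of the tie-breaker argument worth spelling out: if the TODO'd move-before-transfer pattern also accepted ties, canonicalization could swap the two ops back and forth without terminating. A sketch of the intended asymmetry, with hypothetical names:

```cpp
#include <cstdint>

// Dynamic-dimension counts on either side of the cast-like op.
struct DynCounts {
  int64_t src;
  int64_t dst;
};

// Gate of the pattern in this patch: ties are allowed.
static bool moveAfterTransfer(DynCounts d) { return d.src <= d.dst; }

// Hypothetical gate for the TODO'd symmetric pattern: must be strict.
static bool moveBeforeTransfer(DynCounts d) { return d.dst < d.src; }

int main() {
  DynCounts tie{1, 1};
  // On a tie only one direction fires, so a rewriter reaches a fixed point.
  return (moveAfterTransfer(tie) && !moveBeforeTransfer(tie)) ? 0 : 1;
}
```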
diff --git a/compiler/src/iree/compiler/Utils/BUILD.bazel b/compiler/src/iree/compiler/Utils/BUILD.bazel
index dbcdc0156dae..e8619564064b 100644
--- a/compiler/src/iree/compiler/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Utils/BUILD.bazel
@@ -42,6 +42,7 @@ iree_compiler_cc_library(
         "PassUtils.h",
         "PatternUtils.h",
         "Permutation.h",
+        "Shape.h",
         "SmallVectorDenseMapInfo.h",
         "StringUtils.h",
         "ToolUtils.h",
diff --git a/compiler/src/iree/compiler/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Utils/CMakeLists.txt
index c4f20b2ac74f..405c9180c549 100644
--- a/compiler/src/iree/compiler/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Utils/CMakeLists.txt
@@ -27,6 +27,7 @@ iree_cc_library(
     "PassUtils.h"
     "PatternUtils.h"
     "Permutation.h"
+    "Shape.h"
     "SmallVectorDenseMapInfo.h"
    "StringUtils.h"
     "ToolUtils.h"
diff --git a/compiler/src/iree/compiler/Utils/Shape.h b/compiler/src/iree/compiler/Utils/Shape.h
new file mode 100644
index 000000000000..94fe9b8f4e93
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/Shape.h
@@ -0,0 +1,31 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_SHAPE_H_
+#define IREE_COMPILER_UTILS_SHAPE_H_
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+
+namespace mlir::iree_compiler {
+
+// Returns true if |t1| has no more dynamic dimensions than |t2|. Unranked
+// shapes are always considered to have more dynamic dimensions than ranked
+// ones.
+inline bool shapeHasLessOrEqualDynamicDimensions(ShapedType t1,
+                                                 ShapedType t2) {
+  if (!t2.hasRank()) {
+    return true;
+  }
+  if (!t1.hasRank()) {
+    return false;
+  }
+
+  return llvm::count_if(t1.getShape(), &ShapedType::isDynamic) <=
+         llvm::count_if(t2.getShape(), &ShapedType::isDynamic);
+}
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_UTILS_SHAPE_H_
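To make the unranked-shape convention of `shapeHasLessOrEqualDynamicDimensions` concrete, here is a small illustrative driver (it assumes a working MLIR build to compile and link against; it is not part of the patch):

```cpp
#include <cstdint>

#include "iree/compiler/Utils/Shape.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  auto f32 = mlir::Float32Type::get(&ctx);
  const int64_t d = mlir::ShapedType::kDynamic;
  auto static2d = mlir::RankedTensorType::get({1, 2}, f32);  // 0 dynamic dims
  auto dynamic2d = mlir::RankedTensorType::get({1, d}, f32); // 1 dynamic dim
  auto unranked = mlir::UnrankedTensorType::get(f32);        // "most dynamic"

  using mlir::iree_compiler::shapeHasLessOrEqualDynamicDimensions;
  bool a = shapeHasLessOrEqualDynamicDimensions(static2d, dynamic2d); // true
  bool b = shapeHasLessOrEqualDynamicDimensions(dynamic2d, static2d); // false
  bool c = shapeHasLessOrEqualDynamicDimensions(dynamic2d, unranked); // true
  bool e = shapeHasLessOrEqualDynamicDimensions(unranked, dynamic2d); // false
  return (a && !b && c && !e) ? 0 : 1;
}
```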