[mlir][tensor] Add a tensor.concat operation (llvm#72779)
This adds an operation for concatenating ranked tensors along a static
dimension, as well as a decomposition mirroring the existing lowering
from TOSA to Tensor. This offers a convergence point for "input"-like
dialects that include various lowerings for concatenation operations,
easing later analysis. In the future, this op can implement the
necessary interfaces for tiling, and conversions to a linalg and/or
memref counterpart can be added.

This patch adds the op, the decomposition, and some basic
folding/canonicalization. Replacing lowerings with the op (such as the
TOSA lowering) will come as a follow-up.

See
https://discourse.llvm.org/t/rfc-tensor-add-a-tensor-concatenate-operation/74858
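
As a rough sketch of the kind of IR the decomposition produces (the shapes, SSA
names, and slice parameters below are illustrative, not taken from the patch's
tests):

```mlir
// Before decomposition.
%concat = tensor.concat dim(0) %a, %b :
    (tensor<3x6xf32>, tensor<4x6xf32>) -> tensor<7x6xf32>

// After decomposition: materialize a destination of the concatenated size and
// insert each input at its offset along the concatenated dimension.
%empty = tensor.empty() : tensor<7x6xf32>
%slice0 = tensor.insert_slice %a into %empty[0, 0] [3, 6] [1, 1]
    : tensor<3x6xf32> into tensor<7x6xf32>
%result = tensor.insert_slice %b into %slice0[3, 0] [4, 6] [1, 1]
    : tensor<4x6xf32> into tensor<7x6xf32>
```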
qedawkins authored Dec 1, 2023
1 parent 4c44dcf commit f310a5d
Showing 13 changed files with 554 additions and 58 deletions.
64 changes: 64 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -121,6 +121,70 @@ def Tensor_CastOp : Tensor_Op<"cast", [
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

def Tensor_ConcatOp : Tensor_Op<"concat",
    [Pure,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
  let summary = "tensor concatenation operation";
  let description = [{
    The "concat" operation constructs a tensor out of a variadic list of input
    tensors, concatenated along a static dimension number. All inputs and the
    result type must share the same rank.

    `dim` specifies the dimension along which to concatenate. The size of the
    concatenated dimension in the result must be equal to the sum of the sizes
    of the inputs along that dimension. All other dimensions in both the inputs
    and result must be the same size.

    Example:

    ```mlir
    %concat = tensor.concat dim(0) %0, %1, %2 :
        (tensor<3x6xf32>, tensor<3x6xf32>, tensor<1x6xf32>) -> tensor<7x6xf32>

    // Dynamic + dynamic -> static
    %concat = tensor.concat dim(1) %0, %1, %2 :
        (tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32>) -> tensor<3x10xf32>
    ```
  }];
  let arguments = (ins I64Attr:$dim,
                       Variadic<AnyRankedTensor>:$inputs);
  let results = (outs AnyRankedTensor:$result);
  let assemblyFormat = [{
    `dim` `(` $dim `)` $inputs attr-dict
    `:` functional-type(operands, results)
  }];

  let builders = [
    // Builder with an inferred result type.
    OpBuilder<(ins "int64_t":$dim, "ValueRange":$inputs)>,
  ];

  let extraClassDeclaration = [{
    // Helper to infer the concatenated result type for the given list of input
    // types, being concatenated along `dim`. Because concatenation can specify
    // more static information than can automatically be inferred,
    // InferTypeOpInterface is not used.
    static RankedTensorType inferResultType(int64_t dim, TypeRange inputTypes);

    RankedTensorType getResultType() {
      return ::llvm::cast<RankedTensorType>(getResult().getType());
    }

    int64_t getRank() {
      return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
    }
  }];

  let hasCanonicalizer = 1;
  let hasFolder = 1;
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//
mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -15,6 +15,18 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def ApplyDecomposeTensorConcatPatternsOp : Op<Transform_Dialect,
    "apply_patterns.tensor.decompose_concat",
    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
  let description = [{
    Indicates that tensor.concat ops should be decomposed into a chain of
    tensor.insert_slice operations inserting into a materialized destination.
  }];

  let assemblyFormat = "attr-dict";
}
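
A minimal usage sketch of this pattern-descriptor op; the surrounding
`transform.sequence` / `transform.structured.match` plumbing and the handle
names are illustrative assumptions, not part of this patch:

```mlir
transform.sequence failures(propagate) {
^bb0(%root: !transform.any_op):
  // Illustrative plumbing: match the enclosing functions, then decompose any
  // tensor.concat ops they contain.
  %funcs = transform.structured.match ops{["func.func"]} in %root
      : (!transform.any_op) -> !transform.any_op
  transform.apply_patterns to %funcs {
    transform.apply_patterns.tensor.decompose_concat
  } : !transform.any_op
}
```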


def ApplyDropRedundantInsertSliceRankExpansionPatternsOp : Op<Transform_Dialect,
"apply_patterns.tensor.drop_redundant_insert_slice_rank_expansion",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -67,6 +67,13 @@ void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
bool foldSingleUseOnly = false);

/// Populates `patterns` with patterns that decompose `tensor.concat` into a
/// `tensor.empty` of the concatenated size, followed by a chain of
/// `tensor.insert_slice` operations on the inputs. This is intended to be used
/// as a fallback tensor -> tensor lowering that decomposes concat such that it
/// can be bufferized into a sequence of copies.
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that fold operations like `tensor.pad`
/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
/// respectively.
33 changes: 33 additions & 0 deletions mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -151,6 +151,39 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
OpFoldResult step);

/// Idiomatic saturated operations on values like offsets, sizes, and strides.
struct SaturatedInteger {
  // A dynamic (ShapedType::kDynamic) value wraps to the saturated state.
  static SaturatedInteger wrap(int64_t v) {
    return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0}
                                      : SaturatedInteger{false, v};
  }
  int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; }
  // Returns `other` if this value is saturated and `other` is static, fails if
  // both are static but disagree, and otherwise returns this value unchanged.
  FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) {
    if (saturated && !other.saturated)
      return other;
    if (!saturated && !other.saturated && v != other.v)
      return failure();
    return *this;
  }
  bool operator==(SaturatedInteger other) {
    return (saturated && other.saturated) ||
           (!saturated && !other.saturated && v == other.v);
  }
  bool operator!=(SaturatedInteger other) { return !(*this == other); }
  SaturatedInteger operator+(SaturatedInteger other) {
    if (saturated || other.saturated)
      return SaturatedInteger{true, 0};
    return SaturatedInteger{false, other.v + v};
  }
  SaturatedInteger operator*(SaturatedInteger other) {
    if (saturated || other.saturated)
      return SaturatedInteger{true, 0};
    return SaturatedInteger{false, other.v * v};
  }
  bool saturated = true;
  int64_t v = 0;
};

} // namespace mlir

#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
76 changes: 18 additions & 58 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -26,43 +26,6 @@
using namespace mlir;
using namespace mlir::memref;

namespace {
/// Idiomatic saturated operations on offsets, sizes and strides.
namespace saturated_arith {
struct Wrapper {
static Wrapper stride(int64_t v) {
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
}
static Wrapper offset(int64_t v) {
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
}
static Wrapper size(int64_t v) {
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
}
int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
bool operator==(Wrapper other) {
return (saturated && other.saturated) ||
(!saturated && !other.saturated && v == other.v);
}
bool operator!=(Wrapper other) { return !(*this == other); }
Wrapper operator+(Wrapper other) {
if (saturated || other.saturated)
return Wrapper{true, 0};
return Wrapper{false, other.v + v};
}
Wrapper operator*(Wrapper other) {
if (saturated || other.saturated)
return Wrapper{true, 0};
return Wrapper{false, other.v * v};
}
bool saturated;
int64_t v;
};
} // namespace saturated_arith
} // namespace

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
@@ -2208,11 +2171,11 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
ReassociationIndices reassoc = std::get<0>(it);
int64_t currentStrideToExpand = std::get<1>(it);
for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
using saturated_arith::Wrapper;
reverseResultStrides.push_back(currentStrideToExpand);
currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
Wrapper::size(resultShape[shapeIndex--]))
.asStride();
currentStrideToExpand =
(SaturatedInteger::wrap(currentStrideToExpand) *
SaturatedInteger::wrap(resultShape[shapeIndex--]))
.asInteger();
}
}
auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
@@ -2332,10 +2295,9 @@ computeCollapsedLayoutMap(MemRefType srcType,
unsigned resultStrideIndex = resultStrides.size() - 1;
for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
using saturated_arith::Wrapper;
auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
for (int64_t idx : llvm::reverse(trailingReassocs)) {
stride = stride * Wrapper::size(srcShape[idx]);
stride = stride * SaturatedInteger::wrap(srcShape[idx]);

// Both source and result stride must have the same static value. In that
// case, we can be sure, that the dimensions are collapsible (because they
@@ -2345,7 +2307,7 @@ computeCollapsedLayoutMap(MemRefType srcType,
// ops where obviously non-contiguous dims are collapsed, but accept ops
// where we cannot be sure statically. Such ops may fail at runtime. See
// the op documentation for details.
auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
if (strict && (stride.saturated || srcStride.saturated))
return failure();

Expand All @@ -2371,11 +2333,11 @@ MemRefType CollapseShapeOp::computeCollapsedType(
SmallVector<int64_t> resultShape;
resultShape.reserve(reassociation.size());
for (const ReassociationIndices &group : reassociation) {
using saturated_arith::Wrapper;
auto groupSize = Wrapper::size(1);
auto groupSize = SaturatedInteger::wrap(1);
for (int64_t srcDim : group)
groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
resultShape.push_back(groupSize.asSize());
groupSize =
groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
resultShape.push_back(groupSize.asInteger());
}

if (srcType.getLayout().isIdentity()) {
@@ -2586,11 +2548,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
int64_t targetOffset = sourceOffset;
for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
using saturated_arith::Wrapper;
targetOffset =
(Wrapper::offset(targetOffset) +
Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
.asOffset();
targetOffset = (SaturatedInteger::wrap(targetOffset) +
SaturatedInteger::wrap(staticOffset) *
SaturatedInteger::wrap(targetStride))
.asInteger();
}

// Compute target stride whose value is:
Expand All @@ -2599,10 +2560,9 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
targetStrides.reserve(staticOffsets.size());
for (auto it : llvm::zip(sourceStrides, staticStrides)) {
auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
using saturated_arith::Wrapper;
targetStrides.push_back(
(Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
.asStride());
targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
SaturatedInteger::wrap(staticStride))
.asInteger());
}

// The type is now known.