Skip to content

Commit

Permalink
Revert "Data tiling: transpose narrow-N into narrow-M" (iree-org#17503)
Browse files Browse the repository at this point in the history
Reverts iree-org#17446

Reason: postsubmit failures on arm64,
iree-org#17446 (comment)
  • Loading branch information
bjacob authored and gglangg committed Jun 4, 2024
1 parent 7e41631 commit f5e3020
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 411 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,20 @@ materializeEncodingForTarget(RankedTensorType tensorType,
if (enumeratedTileMxNxK.empty()) {
return failure();
}
// Check if the encoding specifies static narrow sizes for the M/N dimensions.
// This can be used to choose a correspondingly narrow tile shape.
// With microkernels, we keep this logic in sync with the set of actual
// optimized microkernel tile functions to avoid a tile shape specialization
// causing a fallback to a slow generic tile function. At the moment,
// microkernel tile functions are only specialized for narrow M, not for narrow
// N. Accordingly, we leave matmulNarrowN as 0 (default) when microkernels are
// used. Generally it would be best to deal with narrow-N cases by transposing
// the whole matmul and swapping LHS<->RHS, reducing the narrow-N case to
// narrow-M.
int64_t matmulNarrowM = getIntOrZero(encoding.getMatmulNarrow_M());
int64_t matmulNarrowN = getIntOrZero(encoding.getMatmulNarrow_N());
int64_t matmulNarrowN = hasUkernel(targetAttr, "mmt4d")
? 0
: getIntOrZero(encoding.getMatmulNarrow_N());
// Choose a final matmul TileMxNxK from the above-enumerated tile shapes,
// taking narrow dimensions into account.
TileMxNxK chosenTileMxNxK =
Expand Down

Large diffs are not rendered by default.

107 changes: 11 additions & 96 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,107 +9,30 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"

#include <numeric>

namespace mlir::iree_compiler {

using IREE::Encoding::EncodingAttr;
using IREE::Encoding::EncodingRole;
using IREE::Encoding::getEncodingAttr;
using IREE::Encoding::getEncodingContractionDims;

// If tensorType has the encoding of a matmul RESULT with narrow N, returns
// the transposed type (narrow-N turned into narrow-M by swapping the M and N
// dimensions everywhere in the encoding). Otherwise, just returns tensorType.
static RankedTensorType transposeIfNarrowNResult(RankedTensorType tensorType) {
  auto encoding =
      llvm::dyn_cast_or_null<EncodingAttr>(tensorType.getEncoding());
  if (!encoding) {
    return tensorType;
  }
  if (!isNarrowNResult(encoding)) {
    return tensorType;
  }
  // The role itself is preserved; only shapes, maps, and narrow hints change.
  auto newRole = encoding.getRole().getValue();
  TypeAttr originalTypeAttr = encoding.getOriginalType();
  RankedTensorType originalType = tensorType;
  if (originalTypeAttr) {
    originalType =
        llvm::dyn_cast<RankedTensorType>(originalTypeAttr.getValue());
  }
  SmallVector<int64_t> newOriginalShape(originalType.getShape());
  auto userIndexingMaps = encoding.getUserIndexingMaps();
  SmallVector<AffineMap> maps;
  for (auto a : userIndexingMaps) {
    maps.push_back(cast<AffineMapAttr>(a).getAffineMap());
  }
  auto cDims = linalg::inferContractionDims(maps);
  SmallVector<int64_t> newShape(tensorType.getShape());
  // Start from the identity permutation over the iteration space dims.
  SmallVector<int64_t> permIndices(maps[0].getNumDims());
  std::iota(std::begin(permIndices), std::end(permIndices), 0);
  // Matrix case: there are both M and N dimensions. Transposing means swapping
  // them, both in the permutation and in the (original) result shapes.
  if (cDims->m.size() == 1 && cDims->n.size() == 1) {
    int m = cDims->m[0];
    int n = cDims->n[0];
    std::swap(permIndices[m], permIndices[n]);
    int mDim = encoding.mapDimToRoleIndex(m);
    int nDim = encoding.mapDimToRoleIndex(n);
    std::swap(newShape[mDim], newShape[nDim]);
    std::swap(newOriginalShape[mDim], newOriginalShape[nDim]);
  }
  // Vector case: there is no N dimension to swap the M dimension with. We
  // swap the maps themselves.
  if (cDims->n.empty()) {
    std::swap(maps[0], maps[1]);
  }
  // roundDimsTo is either empty or an MxNxK triple; swap its M and N entries
  // to track the transposed dimensions.
  SmallVector<int64_t> newRoundDimsTo(encoding.getRoundDimsToArray());
  assert(newRoundDimsTo.size() == 0 || newRoundDimsTo.size() == 3);
  if (newRoundDimsTo.size() != 0) {
    std::swap(newRoundDimsTo[0], newRoundDimsTo[1]);
  }
  auto context = tensorType.getContext();
  AffineMap permutation = AffineMap::getPermutationMap(permIndices, context);
  for (auto &map : maps) {
    map = map.compose(permutation);
  }
  SmallVector<Attribute> newMaps;
  for (auto map : maps) {
    newMaps.push_back(AffineMapAttr::get(map));
  }
  ArrayAttr newIndexingMaps = ArrayAttr::get(context, newMaps);
  auto elemType = tensorType.getElementType();
  // NOTE(review): Narrow_N and Narrow_M getters are passed in swapped order
  // here — presumably into the (narrowM, narrowN) slots of EncodingAttr::get,
  // transposing the narrow-dimension hints along with the dims. Confirm
  // against the EncodingAttr::get signature.
  auto newEncoding = IREE::Encoding::EncodingAttr::get(
      context, IREE::Encoding::EncodingRoleAttr::get(context, newRole),
      encoding.getElementTypes(),
      TypeAttr::get(RankedTensorType::get(newOriginalShape, elemType)),
      encoding.getMatmulNarrow_N(), encoding.getMatmulNarrow_M(),
      newIndexingMaps, DenseI64ArrayAttr::get(context, newRoundDimsTo));
  return RankedTensorType::get(newShape, elemType, newEncoding);
}

/// For a given tensor type with an encoding, return the materialized
/// type to use for it. If no encoding is set, then return the tensor type
/// itself.
static RankedTensorType
getMaterializedType(RankedTensorType tensorType,
                    MaterializeEncodingFn materializeEncodingFn) {
  FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
      materializeEncodingFn(tensorType);
  // If no materialization info is available, fall back to the same tensor
  // type with the encoding stripped.
  if (failed(materializeEncodingInfo)) {
    return dropEncoding(tensorType);
  }
  // Infer the packed (tiled) type from the original, pre-encoding shape and
  // the chosen inner tile sizes / dim permutations.
  return cast<RankedTensorType>(
      tensor::PackOp::inferPackedType(getOriginalTypeWithEncoding(tensorType)
                                          .clone(tensorType.getElementType()),
                                      materializeEncodingInfo->innerTileSizes,
                                      materializeEncodingInfo->innerDimsPos,
                                      materializeEncodingInfo->outerDimsPerm));
}

MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter(
Expand All @@ -119,9 +42,10 @@ MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter(
addConversion([](IndexType indexType) { return indexType; });
addConversion([](FloatType floatType) { return floatType; });
addConversion([](MemRefType memrefType) { return memrefType; });
addConversion([=](RankedTensorType t) -> RankedTensorType {
return getMaterializedType(t, materializeEncodingFn);
});
addConversion(
[materializeEncodingFn](RankedTensorType t) -> RankedTensorType {
return getMaterializedType(t, materializeEncodingFn);
});
}

MaterializeEncodingConversionTarget::MaterializeEncodingConversionTarget(
Expand Down Expand Up @@ -203,13 +127,4 @@ MaterializeEncodingInfo getEncodingInfoForMatmul(EncodingAttr encoding,
return encodingInfo;
}

bool isNarrowNResult(EncodingAttr encoding) {
  // Only RESULT-role encodings can represent a narrow-N matmul result.
  if (encoding.getRole().getValue() != EncodingRole::RESULT) {
    return false;
  }
  // Without a narrow-N hint there is nothing to report.
  IntegerAttr nAttr = encoding.getMatmulNarrow_N();
  if (!nAttr) {
    return false;
  }
  // Narrow-N wins when there is no narrow-M hint, or N is strictly narrower.
  IntegerAttr mAttr = encoding.getMatmulNarrow_M();
  return !mAttr || mAttr.getInt() > nAttr.getInt();
}

} // namespace mlir::iree_compiler
4 changes: 0 additions & 4 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,6 @@ void populateMaterializeEncodingIntoPackUnPackPatterns(
void populateMaterializeUpperBoundTileSizePatterns(
RewritePatternSet &patterns, MaterializeEncodingFn materializeEncodingFn);

// Returns true if `encoding` represents a narrow-N matmul RESULT, e.g. the
// result of a matvec.
bool isNarrowNResult(IREE::Encoding::EncodingAttr encoding);

} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGUTILS_H_
Loading

0 comments on commit f5e3020

Please sign in to comment.