Revert "Data tiling: transpose narrow-N into narrow-M" #17503

Merged 1 commit on May 28, 2024
@@ -498,8 +498,20 @@ materializeEncodingForTarget(RankedTensorType tensorType,
if (enumeratedTileMxNxK.empty()) {
return failure();
}
// Check if the encoding specifies static narrow sizes for the M/N dimensions.
// This can be used to choose a correspondingly narrow tile shape.
// With microkernels, we keep this logic in sync with the set of actual
// optimized microkernel tile functions to avoid a tile shape specialization
// causing a fallback to a slow generic tile function. At the moment,
// microkernel tile functions are only specialized for narrow M, not for narrow
// N. Accordingly, we leave matmulNarrowN as 0 (default) when microkernels are
// used. Generally it would be best to deal with narrow-N cases by transposing
// the whole matmul and swapping LHS<->RHS, reducing the narrow-N case to
// narrow-M.
int64_t matmulNarrowM = getIntOrZero(encoding.getMatmulNarrow_M());
int64_t matmulNarrowN = getIntOrZero(encoding.getMatmulNarrow_N());
int64_t matmulNarrowN = hasUkernel(targetAttr, "mmt4d")
? 0
: getIntOrZero(encoding.getMatmulNarrow_N());
// Choose a final matmul TileMxNxK from the above-enumerated tile shapes,
// taking narrow dimensions into account.
TileMxNxK chosenTileMxNxK =
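
For context on the selection step this hunk feeds into: below is a minimal standalone sketch of how narrow-M/narrow-N hints can steer the choice among enumerated tile shapes. The TileMxNxK struct mirrors the one referenced in this file, but pickTile and its heuristic are hypothetical, not IREE's actual chooser.

#include <cstdint>
#include <vector>

struct TileMxNxK {
  int64_t M, N, K;
};

// Sketch only: choose a tile whose M/N do not overshoot a known-narrow
// dimension, so e.g. a matvec (narrow N == 1) is not padded up to a wide
// N tile. A narrow size of 0 means "not narrow".
static TileMxNxK pickTile(const std::vector<TileMxNxK> &enumerated,
                          int64_t narrowM, int64_t narrowN) {
  TileMxNxK chosen = enumerated.front(); // fallback if nothing fits
  bool haveFit = false;
  for (const TileMxNxK &tile : enumerated) {
    bool fits = (narrowM == 0 || tile.M <= narrowM) &&
                (narrowN == 0 || tile.N <= narrowN);
    if (!fits)
      continue;
    // Among fitting tiles, prefer the largest M*N footprint.
    if (!haveFit || tile.M * tile.N > chosen.M * chosen.N) {
      chosen = tile;
      haveFit = true;
    }
  }
  return chosen;
}

// pickTile({{8,8,1},{4,8,1},{1,8,1}}, /*narrowM=*/1, /*narrowN=*/0)
// yields {1,8,1}. Forcing narrowN to 0 when an mmt4d microkernel exists,
// as the hunk above does, keeps narrow-N inputs on the wide-N tiles that
// the optimized microkernel variants actually cover.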

(One file's large diff is not rendered here.)

107 changes: 11 additions & 96 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -9,107 +9,30 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"

#include <numeric>

namespace mlir::iree_compiler {

using IREE::Encoding::EncodingAttr;
using IREE::Encoding::EncodingRole;
using IREE::Encoding::getEncodingAttr;
using IREE::Encoding::getEncodingContractionDims;

// If tensorType has the encoding of a matmul RESULT with narrow N, returns
// the transposed type. Otherwise, just returns tensorType.
static RankedTensorType transposeIfNarrowNResult(RankedTensorType tensorType) {
auto encoding =
llvm::dyn_cast_or_null<EncodingAttr>(tensorType.getEncoding());
if (!encoding) {
return tensorType;
}
if (!isNarrowNResult(encoding)) {
return tensorType;
}
auto newRole = encoding.getRole().getValue();
TypeAttr originalTypeAttr = encoding.getOriginalType();
RankedTensorType originalType = tensorType;
if (originalTypeAttr) {
originalType =
llvm::dyn_cast<RankedTensorType>(originalTypeAttr.getValue());
}
SmallVector<int64_t> newOriginalShape(originalType.getShape());
auto userIndexingMaps = encoding.getUserIndexingMaps();
SmallVector<AffineMap> maps;
for (auto a : userIndexingMaps) {
maps.push_back(cast<AffineMapAttr>(a).getAffineMap());
}
auto cDims = linalg::inferContractionDims(maps);
SmallVector<int64_t> newShape(tensorType.getShape());
SmallVector<int64_t> permIndices(maps[0].getNumDims());
std::iota(std::begin(permIndices), std::end(permIndices), 0);
// Matrix case: there are both M and N dimensions. Transposing means swapping
// them.
if (cDims->m.size() == 1 && cDims->n.size() == 1) {
int m = cDims->m[0];
int n = cDims->n[0];
std::swap(permIndices[m], permIndices[n]);
int mDim = encoding.mapDimToRoleIndex(m);
int nDim = encoding.mapDimToRoleIndex(n);
std::swap(newShape[mDim], newShape[nDim]);
std::swap(newOriginalShape[mDim], newOriginalShape[nDim]);
}
// Vector case: there is no N dimension to swap the M dimension with. We
// swap the maps themselves.
if (cDims->n.empty()) {
std::swap(maps[0], maps[1]);
}

// auto newRoundDimsTo = encoding.getRoundDimsToArray();
SmallVector<int64_t> newRoundDimsTo(encoding.getRoundDimsToArray());
assert(newRoundDimsTo.size() == 0 || newRoundDimsTo.size() == 3);
if (newRoundDimsTo.size() != 0)
std::swap(newRoundDimsTo[0], newRoundDimsTo[1]);

auto context = tensorType.getContext();
AffineMap permutation = AffineMap::getPermutationMap(permIndices, context);
for (auto &map : maps) {
map = map.compose(permutation);
}
SmallVector<Attribute> newMaps;
for (auto map : maps) {
newMaps.push_back(AffineMapAttr::get(map));
}
ArrayAttr newIndexingMaps = ArrayAttr::get(context, newMaps);
auto elemType = tensorType.getElementType();
OpBuilder builder(context);

auto newEncoding = IREE::Encoding::EncodingAttr::get(
context, IREE::Encoding::EncodingRoleAttr::get(context, newRole),
encoding.getElementTypes(),
TypeAttr::get(RankedTensorType::get(newOriginalShape, elemType)),
encoding.getMatmulNarrow_N(), encoding.getMatmulNarrow_M(),
newIndexingMaps, DenseI64ArrayAttr::get(context, newRoundDimsTo));
return RankedTensorType::get(newShape, elemType, newEncoding);
}
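
The helper deleted above is the encoding-level analogue of the identity (A*B)^T = B^T * A^T: swapping the M and N dimensions (and, in the matvec case, the LHS/RHS maps) turns a narrow-N result into a narrow-M one. A minimal numeric sketch of that identity, independent of IREE's types:

#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

static Matrix transpose(const Matrix &a) {
  Matrix t(a[0].size(), std::vector<double>(a.size()));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t j = 0; j < a[0].size(); ++j)
      t[j][i] = a[i][j];
  return t;
}

static Matrix matmul(const Matrix &a, const Matrix &b) {
  Matrix c(a.size(), std::vector<double>(b[0].size(), 0.0));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t k = 0; k < b.size(); ++k)
      for (std::size_t j = 0; j < b[0].size(); ++j)
        c[i][j] += a[i][k] * b[k][j];
  return c;
}

// For any A (MxK) and B (KxN): matmul(A, B) equals
// transpose(matmul(transpose(B), transpose(A))). When N is 1 (a matvec),
// the right-hand side is an ordinary narrow-M matmul, which is the case
// the tile logic and microkernels are specialized for.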

/// For a given tensor type with an encoding, return the materialized
/// type to use for it. If no encoding is set, then return the tensor type
/// itself.
static RankedTensorType
getMaterializedType(RankedTensorType tensorType,
MaterializeEncodingFn materializeEncodingFn) {
RankedTensorType maybeTransposedTensorType =
transposeIfNarrowNResult(tensorType);
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
materializeEncodingFn(maybeTransposedTensorType);
materializeEncodingFn(tensorType);
if (failed(materializeEncodingInfo)) {
return dropEncoding(tensorType);
}
return cast<RankedTensorType>(tensor::PackOp::inferPackedType(
getOriginalTypeWithEncoding(maybeTransposedTensorType)
.clone(tensorType.getElementType()),
materializeEncodingInfo->innerTileSizes,
materializeEncodingInfo->innerDimsPos,
materializeEncodingInfo->outerDimsPerm));
return cast<RankedTensorType>(
tensor::PackOp::inferPackedType(getOriginalTypeWithEncoding(tensorType)
.clone(tensorType.getElementType()),
materializeEncodingInfo->innerTileSizes,
materializeEncodingInfo->innerDimsPos,
materializeEncodingInfo->outerDimsPerm));
}
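
For intuition about the tensor::PackOp::inferPackedType calls above: packing ceil-divides each tiled dimension by its tile size and appends the tile sizes as trailing dimensions. A simplified sketch, assuming static shapes and ignoring outerDimsPerm (which additionally permutes the outer dimensions):

#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<int64_t>
inferPackedShape(std::vector<int64_t> shape,
                 const std::vector<int64_t> &innerTileSizes,
                 const std::vector<int64_t> &innerDimsPos) {
  for (std::size_t i = 0; i < innerDimsPos.size(); ++i) {
    int64_t &dim = shape[innerDimsPos[i]];
    dim = (dim + innerTileSizes[i] - 1) / innerTileSizes[i]; // ceilDiv
  }
  // Tile sizes become the innermost dimensions of the packed tensor.
  shape.insert(shape.end(), innerTileSizes.begin(), innerTileSizes.end());
  return shape;
}

// e.g. inferPackedShape({100, 50}, {8, 8}, {0, 1}) == {13, 7, 8, 8}:
// tensor<100x50xf32> packs to tensor<13x7x8x8xf32> for an 8x8 tile.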

MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter(
@@ -119,9 +42,10 @@ MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter(
addConversion([](IndexType indexType) { return indexType; });
addConversion([](FloatType floatType) { return floatType; });
addConversion([](MemRefType memrefType) { return memrefType; });
addConversion([=](RankedTensorType t) -> RankedTensorType {
return getMaterializedType(t, materializeEncodingFn);
});
addConversion(
[materializeEncodingFn](RankedTensorType t) -> RankedTensorType {
return getMaterializedType(t, materializeEncodingFn);
});
}

MaterializeEncodingConversionTarget::MaterializeEncodingConversionTarget(
@@ -203,13 +127,4 @@ MaterializeEncodingInfo getEncodingInfoForMatmul(EncodingAttr encoding,
return encodingInfo;
}

bool isNarrowNResult(EncodingAttr encoding) {
if (encoding.getRole().getValue() != EncodingRole::RESULT) {
return false;
}
IntegerAttr narrowM = encoding.getMatmulNarrow_M();
IntegerAttr narrowN = encoding.getMatmulNarrow_N();
return narrowN && (!narrowM || narrowM.getInt() > narrowN.getInt());
}
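
The deleted predicate fires only when the result's N side is strictly narrower than its M side. A standalone restatement with example outcomes, using 0 to stand for an absent matmul_narrow attribute:

#include <cstdint>

static bool isNarrowNResultValues(int64_t narrowM, int64_t narrowN) {
  return narrowN != 0 && (narrowM == 0 || narrowM > narrowN);
}

// isNarrowNResultValues(/*narrowM=*/0, /*narrowN=*/1) -> true   (matvec result)
// isNarrowNResultValues(/*narrowM=*/1, /*narrowN=*/0) -> false  (vecmat: already narrow-M)
// isNarrowNResultValues(/*narrowM=*/4, /*narrowN=*/2) -> true   (N strictly narrower)
// isNarrowNResultValues(/*narrowM=*/2, /*narrowN=*/2) -> false  (tie: no transpose)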

} // namespace mlir::iree_compiler
4 changes: 0 additions & 4 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -99,10 +99,6 @@ void populateMaterializeEncodingIntoPackUnPackPatterns(
void populateMaterializeUpperBoundTileSizePatterns(
RewritePatternSet &patterns, MaterializeEncodingFn materializeEncodingFn);

// Returns true if `encoding` represents a narrow-N matmul RESULT, e.g. the
// result of a matvec.
bool isNarrowNResult(IREE::Encoding::EncodingAttr encoding);

} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGUTILS_H_