
Commit aff2a49

Refactor AIRSplitL2Memref to support transposed data movement (Xilinx#563)
* Refactor method to map dims between offsets and memref shape, taking into account transpose

* Test
erwei-xilinx authored May 6, 2024
1 parent 4b0b5b3 commit aff2a49
Showing 4 changed files with 293 additions and 58 deletions.
14 changes: 14 additions & 0 deletions mlir/include/air/Util/Util.h
@@ -65,6 +65,7 @@ void renumberChannelOps(Block *region, std::map<int, int> &reverse_map);

// Return op name as string
std::string to_string(Operation *op);
// Return type name as string
std::string to_string(mlir::Type t);

// Generate a new unique channel name
@@ -203,6 +204,19 @@ getUpdatedOffsetsAfterShrinkage(SmallVector<int> old_memref_shape,
SmallVector<int64_t> new_memref_shape,
SmallVector<Value> offsets);

// Given a dimension on the wrap-and-stride list, infer the memref dimension
// that this pattern completely spans.
std::optional<int> getMemrefDimFromOffsetDim(int dimOnOffset,
SmallVector<Value> offsets,
SmallVector<Value> strides,
SmallVector<int> memrefShape);

// Given a dimension on the memref shape, infer the dimension on the
// wrap-and-stride list that spans this memref dimension.
std::optional<int> getOffsetDimFromMemrefDim(int dimOnMemref,
SmallVector<Value> strides,
SmallVector<int> memrefShape);

} // namespace air
} // namespace xilinx

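The two new helpers are inverses: each matches a constant stride from the wrap-and-stride list against the row-major strides implied by the memref shape, which is what makes the mapping robust to transposed access patterns. Below is a minimal standalone sketch of that matching logic, assuming constant strides and using plain integers in place of mlir::Value; the function names here are hypothetical and not part of the library API.

#include <cstdint>
#include <optional>
#include <vector>

// Row-major stride of each memref dimension, e.g. shape {64, 128} -> {128, 1}.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> dimStrides(shape.size(), 1);
  int64_t current = 1;
  for (int i = (int)shape.size() - 1; i >= 0; i--) {
    dimStrides[i] = current;
    current *= shape[i];
  }
  return dimStrides;
}

// Simplified analogue of getMemrefDimFromOffsetDim: the memref dim whose
// row-major stride equals the pattern's stride is the dim this offset spans.
std::optional<int> memrefDimFromOffsetDim(int dimOnOffset,
                                          const std::vector<int64_t> &strides,
                                          const std::vector<int64_t> &shape) {
  std::vector<int64_t> dimStrides = rowMajorStrides(shape);
  for (unsigned i = 0; i < dimStrides.size(); i++)
    if (strides[dimOnOffset] == dimStrides[i])
      return i;
  return std::nullopt;
}

int main() {
  // Transposed walk of a 64x128 memref: pattern strides are {1, 128}, so
  // offset dim 0 steps by 1 and therefore spans memref dim 1, not dim 0.
  std::vector<int64_t> shape{64, 128}, strides{1, 128};
  return memrefDimFromOffsetDim(0, strides, shape) == 1 ? 0 : 1;
}
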
100 changes: 42 additions & 58 deletions mlir/lib/Transform/AIRMiscPasses.cpp
@@ -661,7 +661,6 @@ class AIRSpecializeDmaBroadcast
auto loc = memcpyOp->getLoc();
bool opIsUpdated = false;
for (unsigned i = 0; i < current_shape_expr.size(); i++) {
// std::cout << herdDimToDmaOffsetDimMap[i] << " ";
if (!current_shape_expr[i])
continue;
if (!herdDimToDmaOffsetDimMap[i])
@@ -1064,8 +1063,9 @@ class AIRSplitL2MemrefForBufferConstraintPass
getTargetMemrefAllocs(func::FuncOp func,
std::map<memref::AllocOp, SmallVector<int>>
&targetMemrefsToColTilingFactors);
int getMemrefSplitDim(SmallVector<air::ChannelInterface> putgets,
int memrefRank);
std::optional<int>
getMemrefSplitDim(SmallVector<air::ChannelInterface> putgets,
SmallVector<int> memrefShape);
};

template <typename T> void push_back_if_unique(SmallVector<T> &vec, T entry) {
@@ -1171,34 +1171,6 @@ Value tileChannelOpByFactor(air::ChannelInterface originalChanOp, int factor,
return newWaitAll.getAsyncToken();
}

std::optional<int> getFirstConstantOffsetValue(SmallVector<Value> offsets,
int memrefRank,
int &initialDim) {
int offsetDim = (int)offsets.size() >= memrefRank
? offsets.size() - memrefRank + initialDim
: 0;
auto offset = getConstantIntValue(offsets[offsetDim]);
// Find the first constant offset to use as key for memref splitting.
while (!offset && offsetDim < (int)offsets.size()) {
offset = getConstantIntValue(offsets[++offsetDim]);
initialDim++;
}
return offset;
}

int getFirstConstantOffsetValueIndex(SmallVector<Value> offsets, int memrefRank,
int initialDim = 0) {
int offsetDim = (int)offsets.size() >= memrefRank
? offsets.size() - memrefRank + initialDim
: 0;
auto offset = getConstantIntValue(offsets[offsetDim]);
// Find the first constant offset to use as key for memref splitting.
while (!offset && offsetDim < (int)offsets.size()) {
offset = getConstantIntValue(offsets[++offsetDim]);
}
return offsetDim;
}

// Partition L2 memref.
void AIRSplitL2MemrefForBufferConstraintPass::partitionMemref(
SmallVector<air::ChannelPutOp> &puts, SmallVector<air::ChannelGetOp> &gets,
@@ -1223,17 +1195,27 @@ void AIRSplitL2MemrefForBufferConstraintPass::partitionMemref(
std::map<int, SmallVector<air::ChannelInterface>> chanOpPartitions;
SmallVector<int> keys;
for (auto op : puts) {
auto offset = getFirstConstantOffsetValue(
op.getOffsets(), air::getTensorShape(ty).size(), dim);
auto offsetDim = air::getOffsetDimFromMemrefDim(dim, op.getStrides(),
air::getTensorShape(ty));
if (!offsetDim)
continue;
auto offset = getConstantIntValue(op.getOffsets()[*offsetDim]);
if (!offset)
continue;
push_back_if_unique<int>(keys, *offset);
if (!chanOpPartitions.count(*offset))
chanOpPartitions[*offset] = SmallVector<air::ChannelInterface>{op};
else
chanOpPartitions[*offset].push_back(op);
}
for (auto op : gets) {
auto offset = getFirstConstantOffsetValue(
op.getOffsets(), air::getTensorShape(ty).size(), dim);
auto offsetDim = air::getOffsetDimFromMemrefDim(dim, op.getStrides(),
air::getTensorShape(ty));
if (!offsetDim)
continue;
auto offset = getConstantIntValue(op.getOffsets()[*offsetDim]);
if (!offset)
continue;
push_back_if_unique<int>(keys, *offset);
if (!chanOpPartitions.count(*offset))
chanOpPartitions[*offset] = SmallVector<air::ChannelInterface>{op};
@@ -1248,12 +1230,12 @@ void AIRSplitL2MemrefForBufferConstraintPass::partitionMemref(
newMemrefShape.push_back(air::getTensorShape(ty)[i]);
}
for (auto op : chanOpPartitions[key]) {
int offsetDim =
op.getOffsets().size() >= air::getTensorShape(ty).size()
? op.getOffsets().size() - air::getTensorShape(ty).size() + dim
: 0;
auto offsetDim = air::getOffsetDimFromMemrefDim(dim, op.getStrides(),
air::getTensorShape(ty));
if (!offsetDim)
continue;
if (op.getSizes().size() == newMemrefShape.size()) {
newMemrefShape[dim] = *getConstantIntValue(op.getSizes()[offsetDim]);
newMemrefShape[dim] = *getConstantIntValue(op.getSizes()[*offsetDim]);
break;
}
}
@@ -1300,9 +1282,11 @@ void AIRSplitL2MemrefForBufferConstraintPass::partitionMemref(
op.getIndices().size();
auto &memrefOpOper = op->getOpOperand(memrefOperandOffset);
memrefOpOper.assign(newMemref);
int offsetDim = getFirstConstantOffsetValueIndex(
op.getOffsets(), air::getTensorShape(ty).size(), dim);
int offsetOperandOffset = memrefOperandOffset + offsetDim + 1;
auto offsetDim = air::getOffsetDimFromMemrefDim(dim, op.getStrides(),
air::getTensorShape(ty));
if (!offsetDim)
continue;
int offsetOperandOffset = memrefOperandOffset + *offsetDim + 1;
auto &offsetOpOper = op->getOpOperand(offsetOperandOffset);
offsetOpOper.assign(builder.create<arith::ConstantIndexOp>(loc, 0));
// Update strides (contiguous, row-major) after memref tiling.
@@ -1342,9 +1326,9 @@ void AIRSplitL2MemrefForBufferConstraintPass::partitionMemref(

// Infer the dimension along which the join / distribute pattern happens, as
// the basis for memref splitting.
int AIRSplitL2MemrefForBufferConstraintPass::getMemrefSplitDim(
SmallVector<air::ChannelInterface> putgets, int memrefRank) {
int split_dim = 0;
std::optional<int> AIRSplitL2MemrefForBufferConstraintPass::getMemrefSplitDim(
SmallVector<air::ChannelInterface> putgets, SmallVector<int> memrefShape) {
std::optional<int> memrefDim = std::nullopt;
for (unsigned i = 0; i < putgets.size() - 1; i++) {
for (unsigned j = i + 1; j < putgets.size(); j++) {
if (putgets[i].getOffsets().size() != putgets[j].getOffsets().size())
@@ -1354,16 +1338,16 @@ int AIRSplitL2MemrefForBufferConstraintPass::getMemrefSplitDim(
getConstantIntValue(putgets[j].getOffsets()[k])) {
if (*getConstantIntValue(putgets[i].getOffsets()[k]) !=
*getConstantIntValue(putgets[j].getOffsets()[k]))
split_dim = k;
memrefDim = k;
}
}
}
}
// Match offset dims with memref shape.
if (split_dim)
split_dim = split_dim + memrefRank - putgets[0].getOffsets().size();
split_dim = std::max(split_dim, 0);
return split_dim;
// Match offset dims with memref dims.
if (!memrefDim)
return std::nullopt;
return air::getMemrefDimFromOffsetDim(*memrefDim, putgets[0].getOffsets(),
putgets[0].getStrides(), memrefShape);
}
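
For intuition: the pass compares the constant offsets of every pair of puts (or gets) on the memref; the wrap-and-stride dimension on which any two patterns disagree is where the join / distribute happens, and that dimension is then translated to a memref dimension through the strides. A hedged standalone sketch of the offset-comparison step, with hypothetical values and plain integer offset lists in place of air::ChannelInterface:

#include <cstdint>
#include <optional>
#include <vector>

// Simplified stand-in for the comparison loop in getMemrefSplitDim: each
// access pattern is reduced to its list of constant offsets.
std::optional<int>
splitDimFromOffsets(const std::vector<std::vector<int64_t>> &patterns) {
  std::optional<int> dim;
  for (unsigned i = 0; i + 1 < patterns.size(); i++)
    for (unsigned j = i + 1; j < patterns.size(); j++) {
      if (patterns[i].size() != patterns[j].size())
        continue;
      for (unsigned k = 0; k < patterns[i].size(); k++)
        if (patterns[i][k] != patterns[j][k])
          dim = k;
    }
  return dim;
}

int main() {
  // Two channel puts writing halves of an L2 buffer: offsets [0, 0] and
  // [32, 0] differ at offset dim 0, so the memref splits along whichever
  // memref dim that offset dim spans (resolved via getMemrefDimFromOffsetDim).
  std::vector<std::vector<int64_t>> patterns{{0, 0}, {32, 0}};
  return splitDimFromOffsets(patterns).value_or(-1) == 0 ? 0 : 1;
}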

SmallVector<memref::AllocOp>
@@ -1451,12 +1435,12 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs(
for (auto chanOp : MM2SChannels)
for (auto put : air::getChannelPutOpThroughSymbol(chanOp))
putgets.push_back(put);
int split_dim = getMemrefSplitDim(
putgets, air::getTensorShape(memref.getType()).size());
auto split_dim =
getMemrefSplitDim(putgets, air::getTensorShape(memref.getType()));
if (split_dim)
allocOp->setAttr(
"split_dim",
IntegerAttr::get(IntegerType::get(ctx, 32), split_dim));
IntegerAttr::get(IntegerType::get(ctx, 32), *split_dim));
}
if (S2MMChannels.size() > 1) {
targetMemrefsToColTilingFactors[allocOp].push_back(S2MMChannels.size());
Expand All @@ -1465,12 +1449,12 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs(
for (auto chanOp : MM2SChannels)
for (auto get : air::getChannelGetOpThroughSymbol(chanOp))
putgets.push_back(get);
int split_dim = getMemrefSplitDim(
putgets, air::getTensorShape(memref.getType()).size());
auto split_dim =
getMemrefSplitDim(putgets, air::getTensorShape(memref.getType()));
if (split_dim)
allocOp->setAttr(
"split_dim",
IntegerAttr::get(IntegerType::get(ctx, 32), split_dim));
IntegerAttr::get(IntegerType::get(ctx, 32), *split_dim));
}
}
return targetMemrefs;
66 changes: 66 additions & 0 deletions mlir/lib/Util/Util.cpp
@@ -1239,3 +1239,69 @@ air::getUpdatedOffsetsAfterShrinkage(SmallVector<int> old_memref_shape,
}
return new_offsets;
}

// Given a dimension on the wrap-and-stride list, infer the memref dimension
// that this pattern completely spans.
std::optional<int>
air::getMemrefDimFromOffsetDim(int dimOnOffset, SmallVector<Value> offsets,
SmallVector<Value> strides,
SmallVector<int> memrefShape) {
std::optional<int> output = std::nullopt;
if (memrefShape.empty())
return std::nullopt;
if (offsets.empty())
return std::nullopt;
assert(dimOnOffset < (int)offsets.size() && "Dimension exceeds offsets rank");

// Get stride value which corresponds to accessing each memref dimension,
// highest dimension first.
int memrefRank = memrefShape.size();
SmallVector<int> memrefDimStrides(memrefRank, 1);
int currentStride = 1;
for (int i = memrefRank - 1; i >= 0; i--) {
memrefDimStrides[i] = currentStride;
currentStride *= memrefShape[i];
}

// Find the dimension on memref shape with given stride value.
auto strideVal = getConstantIntValue(strides[dimOnOffset]);
assert(strideVal && "Non-static stride value in data access pattern, NYI.");
for (unsigned i = 0; i < memrefDimStrides.size(); i++)
if (*strideVal == memrefDimStrides[i]) {
output = i;
return output;
}
return std::nullopt;
}

// Given a dimension on the memref shape, infer the dimension on the
// wrap-and-stride list that spans this memref dimension.
std::optional<int>
air::getOffsetDimFromMemrefDim(int dimOnMemref, SmallVector<Value> strides,
SmallVector<int> memrefShape) {
std::optional<int> output = std::nullopt;
if (memrefShape.empty())
return std::nullopt;
if (strides.empty())
return std::nullopt;
assert(dimOnMemref < (int)memrefShape.size() &&
"Dimension exceeds memref rank");

// Get stride value which corresponds to accessing the current memref
// dimension.
int memrefRank = memrefShape.size();
int memrefStride = 1;
for (int i = memrefRank - 1; i > dimOnMemref; i--)
memrefStride *= memrefShape[i];

// Find the dimension on wrap-and-stride list with given stride value.
for (unsigned i = 0; i < strides.size(); i++) {
auto strideVal = getConstantIntValue(strides[i]);
assert(strideVal && "Non-static stride value in data access pattern, NYI.");
if (*strideVal == memrefStride) {
output = i;
return output;
}
}
return std::nullopt;
}
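
A worked trace of the two helpers, assuming constant strides (values are illustrative, not taken from this commit's test):

// memrefShape = {2, 3, 4}; row-major dim strides -> {12, 4, 1}.
// Pattern strides {4, 1}: a plain 2-D walk of memref dims 1 and 2.
//   getMemrefDimFromOffsetDim(0, offsets, strides, memrefShape) -> 1
//   getOffsetDimFromMemrefDim(1, strides, memrefShape)          -> 0
// Transposed pattern strides {1, 12}: offset dim 0 now steps by 1, so it
// maps to memref dim 2 instead, which is how the split logic follows a
// transposed data movement.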