Skip to content

Commit

Permalink
Remove canonicalizeFalseDependencies; move canonicalizeFalseDependenc…
Browse files Browse the repository at this point in the history
…ies to CanonicalizeAsyncOpDeps (Xilinx#844)
  • Loading branch information
erwei-xilinx authored Jan 6, 2025
1 parent 71c2c9b commit 383a75d
Showing 1 changed file with 70 additions and 97 deletions.
167 changes: 70 additions & 97 deletions mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,33 +196,84 @@ static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
template <class OpT>
static LogicalResult CanonicalizeAsyncOpDeps(OpT op,
PatternRewriter &rewriter) {

SmallVector<Value> depsOfDeps;
for (auto v : op.getAsyncDependencies()) {
if (auto asyncOperand =
dyn_cast_if_present<AsyncOpInterface>(v.getDefiningOp())) {
auto deps = asyncOperand.getAsyncDependencies();
depsOfDeps.append(deps.begin(), deps.end());
auto getMemrefsFromVec = [](SmallVector<Value> vec) {
SmallVector<Value> memrefs;
for (auto v : vec)
if (isa<MemRefType>(v.getType()))
memrefs.push_back(v);
return memrefs;
};
auto getAllMemrefsTouchedbyOp = [getMemrefsFromVec](Operation *o) {
llvm::SetVector<Value> memrefs;
SmallVector<Value> vals = o->getOperands();
vals.insert(vals.end(), o->getResults().begin(), o->getResults().end());
SmallVector<Region *> regions;
for (auto &region : o->getRegions())
regions.push_back(&region);
// If air.wait_all, then we analyze the dependency by collecting all
// operations that depend on it.
auto waitAllOp = dyn_cast_if_present<air::WaitAllOp>(o);
if (waitAllOp && waitAllOp.getAsyncToken()) {
for (auto user : waitAllOp.getAsyncToken().getUsers()) {
vals.insert(vals.end(), user->getOperands().begin(),
user->getOperands().end());
vals.insert(vals.end(), user->getResults().begin(),
user->getResults().end());
for (auto &region : user->getRegions())
regions.push_back(&region);
}
}
}
auto memrefvals = getMemrefsFromVec(vals);
memrefs.insert(memrefvals.begin(), memrefvals.end());
for (auto region : regions) {
llvm::SetVector<Value> usedVals;
getUsedValuesDefinedAbove(*region, usedVals);
auto usedMemrefs = getMemrefsFromVec(usedVals.takeVector());
memrefs.insert(usedMemrefs.begin(), usedMemrefs.end());
}
return memrefs;
};
auto memrefsTouchedByOp = getAllMemrefsTouchedbyOp(op.getOperation());
// make a list of new async token operands
SmallVector<Value> newAsyncDeps;
llvm::SetVector<Value> newAsyncDeps; // don't include duplicates
for (auto v : op.getAsyncDependencies()) {
// don't include duplicates
if (std::find(std::begin(newAsyncDeps), std::end(newAsyncDeps), v) !=
std::end(newAsyncDeps))
continue;
// don't include wait_all ops with no operands
if (auto wa = dyn_cast_if_present<WaitAllOp>(v.getDefiningOp()))
if (wa.getAsyncDependencies().size() == 0)
continue;
// don't include a dependency of another dependency
if (std::find(std::begin(depsOfDeps), std::end(depsOfDeps), v) !=
std::end(depsOfDeps))
continue;
newAsyncDeps.push_back(v);
// don't include any wrong dependencies
if (v.getDefiningOp()) {
auto memrefsTouchedByDefOp = getAllMemrefsTouchedbyOp(v.getDefiningOp());
if (!memrefsTouchedByDefOp.empty() && !memrefsTouchedByOp.empty() &&
llvm::none_of(memrefsTouchedByDefOp, [&memrefsTouchedByOp](Value v) {
return llvm::is_contained(memrefsTouchedByOp, v);
})) {
continue;
}
}
newAsyncDeps.insert(v);
}

// don't include a dependency of another dependency
auto getDepsOfDeps = [](llvm::SetVector<Value> deps) {
llvm::SetVector<Value> depsOfDeps;
for (auto v : deps) {
if (auto asyncOperand =
dyn_cast_if_present<AsyncOpInterface>(v.getDefiningOp())) {
auto deps = asyncOperand.getAsyncDependencies();
depsOfDeps.insert(deps.begin(), deps.end());
}
}
return depsOfDeps;
};
llvm::SetVector<Value> erased;
for (auto v : newAsyncDeps) {
if (llvm::is_contained(getDepsOfDeps(newAsyncDeps), v))
erased.insert(v);
}
for (auto e : erased)
newAsyncDeps.remove(e);

// if the operands won't change, return
if (newAsyncDeps.size() == op.getAsyncDependencies().size())
return failure();
Expand Down Expand Up @@ -301,77 +352,6 @@ CanonicalizeAsyncLoopCarriedDepsInRegion(OpT op, PatternRewriter &rewriter) {
return success();
}

// Break any wrong async dependencies.
template <class T>
static LogicalResult canonicalizeFalseDependencies(T op,
PatternRewriter &rewriter) {
auto asyncOp = dyn_cast_if_present<air::AsyncOpInterface>(op.getOperation());
if (!asyncOp)
return failure();
if (asyncOp.getAsyncDependencies().empty())
return failure();

auto getMemrefsFromVec = [](SmallVector<Value> vec) {
SmallVector<Value> memrefs;
for (auto v : vec)
if (isa<MemRefType>(v.getType()))
memrefs.push_back(v);
return memrefs;
};
auto getAllMemrefsTouchedbyOp = [getMemrefsFromVec](Operation *o) {
llvm::SetVector<Value> memrefs;
SmallVector<Value> vals = o->getOperands();
vals.insert(vals.end(), o->getResults().begin(), o->getResults().end());
SmallVector<Region *> regions;
for (auto &region : o->getRegions())
regions.push_back(&region);
// If air.wait_all, then we analyze the dependency by collecting all
// operations that depend on it.
auto waitAllOp = dyn_cast_if_present<air::WaitAllOp>(o);
if (waitAllOp && waitAllOp.getAsyncToken()) {
for (auto user : waitAllOp.getAsyncToken().getUsers()) {
vals.insert(vals.end(), user->getOperands().begin(),
user->getOperands().end());
vals.insert(vals.end(), user->getResults().begin(),
user->getResults().end());
for (auto &region : user->getRegions())
regions.push_back(&region);
}
}
auto memrefvals = getMemrefsFromVec(vals);
memrefs.insert(memrefvals.begin(), memrefvals.end());
for (auto region : regions) {
llvm::SetVector<Value> usedVals;
getUsedValuesDefinedAbove(*region, usedVals);
auto usedMemrefs = getMemrefsFromVec(usedVals.takeVector());
memrefs.insert(usedMemrefs.begin(), usedMemrefs.end());
}
return memrefs;
};

auto memrefsTouchedByOp = getAllMemrefsTouchedbyOp(op.getOperation());
if (memrefsTouchedByOp.empty())
return failure();
SmallVector<Value> depList = asyncOp.getAsyncDependencies();
for (int i = depList.size() - 1; i >= 0; i--) {
auto tokDefOp = depList[i].getDefiningOp();
if (!tokDefOp)
continue;
auto memrefsTouchedByDefOp = getAllMemrefsTouchedbyOp(tokDefOp);
if (memrefsTouchedByDefOp.empty())
continue;
if (llvm::none_of(memrefsTouchedByDefOp, [&memrefsTouchedByOp](Value v) {
return llvm::is_contained(memrefsTouchedByOp, v);
})) {
auto newOp = rewriter.clone(*op);
dyn_cast<air::AsyncOpInterface>(newOp).eraseAsyncDependency(i);
rewriter.replaceOp(op, newOp);
return success();
}
}
return failure();
}

//
// LaunchOp
//
Expand Down Expand Up @@ -587,7 +567,6 @@ void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add(canonicalizeHierarchyOpArgs<LaunchOp>);
patterns.add(CanonicalizeAsyncOpDeps<LaunchOp>);
patterns.add(CanonicalizeAsyncLoopCarriedDepsInRegion<LaunchOp>);
patterns.add(canonicalizeFalseDependencies<LaunchOp>);
}

ArrayRef<BlockArgument> LaunchOp::getIds() {
Expand Down Expand Up @@ -850,7 +829,6 @@ void SegmentOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add(canonicalizeHierarchyOpArgs<SegmentOp>);
patterns.add(CanonicalizeAsyncOpDeps<SegmentOp>);
patterns.add(CanonicalizeAsyncLoopCarriedDepsInRegion<SegmentOp>);
patterns.add(canonicalizeFalseDependencies<SegmentOp>);
}

ArrayRef<BlockArgument> SegmentOp::getIds() {
Expand Down Expand Up @@ -1112,7 +1090,6 @@ void HerdOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add(canonicalizeHierarchyOpArgs<HerdOp>);
patterns.add(CanonicalizeAsyncOpDeps<HerdOp>);
patterns.add(CanonicalizeAsyncLoopCarriedDepsInRegion<HerdOp>);
patterns.add(canonicalizeFalseDependencies<HerdOp>);
}

ArrayRef<BlockArgument> HerdOp::getIds() {
Expand Down Expand Up @@ -1234,7 +1211,6 @@ void ExecuteOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add(FoldExecute);
patterns.add(CanonicalizeAsyncOpDeps<ExecuteOp>);
patterns.add(CanonicalizeAsyncLoopCarriedDepsInRegion<ExecuteOp>);
patterns.add(canonicalizeFalseDependencies<ExecuteOp>);
}

//
Expand Down Expand Up @@ -1286,7 +1262,6 @@ void WaitAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(FoldWaitAll);
patterns.add(CanonicalizeAsyncOpDeps<WaitAllOp>);
patterns.add(canonicalizeFalseDependencies<WaitAllOp>);
}

// Get strides from MemRefType.
Expand Down Expand Up @@ -1549,7 +1524,7 @@ void DmaMemcpyNdOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(ComposeMemrefOpOnDmaMemcpyNdSrc);
patterns.add(ComposeMemrefOpOnDmaMemcpyNdDst);
patterns.add(canonicalizeFalseDependencies<DmaMemcpyNdOp>);
patterns.add(CanonicalizeAsyncOpDeps<DmaMemcpyNdOp>);
}

//
Expand Down Expand Up @@ -1593,7 +1568,6 @@ void ChannelPutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(ComposeMemrefOpOnChannelOp<ChannelPutOp>);
patterns.add(CanonicalizeAsyncOpDeps<ChannelPutOp>);
patterns.add(canonicalizeFalseDependencies<ChannelPutOp>);
}

//
Expand All @@ -1604,7 +1578,6 @@ void ChannelGetOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(ComposeMemrefOpOnChannelOp<ChannelGetOp>);
patterns.add(CanonicalizeAsyncOpDeps<ChannelGetOp>);
patterns.add(canonicalizeFalseDependencies<ChannelGetOp>);
}

//
Expand Down

0 comments on commit 383a75d

Please sign in to comment.