Skip to content

Commit

Permalink
[MLIR][OpenMP] Support lowering of host_eval to LLVM IR
Browse files Browse the repository at this point in the history
This patch updates the MLIR to LLVM IR lowering of `omp.target` to support
passing `num_teams`, `num_threads`, `thread_limit` and SPMD loop bounds through
the `host_eval` argument of `omp.target`.

This replaces the previous implementation where this information was directly
attached to the `omp.target` operation rather than captured to be used by the
corresponding nested operation.
  • Loading branch information
skatrak committed Oct 10, 2024
1 parent 44b6230 commit e2ee789
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 72 deletions.
62 changes: 18 additions & 44 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1756,55 +1756,29 @@ LogicalResult TargetOp::verify() {
Operation *TargetOp::getInnermostCapturedOmpOp() {
Dialect *ompDialect = (*this)->getDialect();
Operation *capturedOp = nullptr;
Region *capturedParentRegion = nullptr;

walk<WalkOrder::PostOrder>([&](Operation *op) {
walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op == *this)
return;

// Reset captured op if crossing through an omp.loop_nest, so that the top
// level one will be the one captured.
if (llvm::isa<LoopNestOp>(op)) {
capturedOp = nullptr;
capturedParentRegion = nullptr;
}
return WalkResult::advance();

// Ignore operations of other dialects or omp operations with no regions,
// because these will only be checked if they are siblings of an omp
// operation that can potentially be captured.
bool isOmpDialect = op->getDialect() == ompDialect;
bool hasRegions = op->getNumRegions() > 0;

if (capturedOp) {
bool isImmediateParent = false;
for (Region &region : op->getRegions()) {
if (&region == capturedParentRegion) {
isImmediateParent = true;
capturedParentRegion = op->getParentRegion();
break;
}
}

// Make sure the captured op is part of a (possibly multi-level) nest of
// OpenMP-only operations containing no unsupported siblings at any level.
if ((hasRegions && isOmpDialect != isImmediateParent) ||
(!isImmediateParent && !siblingAllowedInCapture(op))) {
capturedOp = nullptr;
capturedParentRegion = nullptr;
}
} else {
// The first OpenMP dialect op containing a region found while visiting
// in post-order should be the innermost captured OpenMP operation.
if (isOmpDialect && hasRegions) {
capturedOp = op;
capturedParentRegion = op->getParentRegion();

// Don't capture this op if it has a not-allowed sibling.
for (Operation &sibling : op->getParentRegion()->getOps()) {
if (&sibling != op && !siblingAllowedInCapture(&sibling)) {
capturedOp = nullptr;
capturedParentRegion = nullptr;
}
}
}
}
if (!isOmpDialect || !hasRegions)
return WalkResult::skip();

// Don't capture this op if it has a not-allowed sibling, and stop recursing
// into nested operations.
for (Operation &sibling : op->getParentRegion()->getOps())
if (&sibling != op && !siblingAllowedInCapture(&sibling))
return WalkResult::interrupt();

// Don't continue capturing nested operations if we reach an omp.loop_nest.
capturedOp = op;
return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
: WalkResult::advance();
});

return capturedOp;
Expand Down
Loading

0 comments on commit e2ee789

Please sign in to comment.