Skip to content

Commit

Permalink
Annotate entire body as target region if no subregions found
Browse files Browse the repository at this point in the history
Otherwise, in cases of a custom codegen, the device specification may
be dropped entirely.
  • Loading branch information
Lunderberg committed Jun 13, 2023
1 parent f3160a9 commit c35a8b5
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions src/tir/transforms/annotate_device_regions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,37 @@ namespace tir {

class DeviceRegionAnnotater : public StmtMutator {
public:
static Stmt Apply(Target host_target, Target device_target, Stmt body) {
DeviceRegionAnnotater mutator(device_target);
body = mutator(body);

bool same_host_and_device = host_target->str() == device_target->str();

// If no region was found that must be on the device, but the
// device and host differ (e.g. `T.target('c', host='llvm')`),
// then the entire region should be annotated. This preserves the
// host-side handling of DLTensor arguments, while ensuring that
// any device targets are used for the codegen.
if (!mutator.found_target_region_ && !same_host_and_device) {
body = AttrStmt(device_target, tvm::attr::kTarget, 0, body);
}

return body;
}

private:
explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tvm::attr::kTarget) {
// If a target attribute already exists, use it as-is.
found_target_region_ = true;
return GetRef<Stmt>(op);
} else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) {
// These attributes are only allowed in device-side code, so
// they should be annotated with the function's default target.
found_target_region_ = true;
Stmt body = GetRef<Stmt>(op);
return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
} else {
Expand All @@ -52,8 +73,8 @@ class DeviceRegionAnnotater : public StmtMutator {
}
}

private:
Target device_target_;
bool found_target_region_{false};
};

namespace transform {
Expand All @@ -64,9 +85,12 @@ Pass AnnotateDeviceRegions() {
ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute";
Target target = opt_target.value();

if (target->GetHost()) {
DeviceRegionAnnotater mutator(target.WithoutHost());
func.CopyOnWrite()->body = mutator(func->body);
if (auto opt_host = target->GetHost()) {
auto new_body =
DeviceRegionAnnotater::Apply(opt_host.value(), target.WithoutHost(), func->body);
if (!new_body.same_as(func->body)) {
func.CopyOnWrite()->body = new_body;
}
}
return func;
};
Expand Down

0 comments on commit c35a8b5

Please sign in to comment.