diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc index a81af7d7805bf..56835e2b73d4a 100644 --- a/src/tir/transforms/annotate_device_regions.cc +++ b/src/tir/transforms/annotate_device_regions.cc @@ -34,16 +34,37 @@ namespace tir { class DeviceRegionAnnotater : public StmtMutator { public: + static Stmt Apply(Target host_target, Target device_target, Stmt body) { + DeviceRegionAnnotater mutator(device_target); + body = mutator(body); + + bool same_host_and_device = host_target->str() == device_target->str(); + + // If no region was found that must be on the device, but the + // device and host differ (e.g. `T.target('c', host='llvm')`), + // then the entire region should be annotated. This preserves the + // host-side handling of DLTensor arguments, while ensuring that + // any device targets are used for the codegen. + if (!mutator.found_target_region_ && !same_host_and_device) { + body = AttrStmt(device_target, tvm::attr::kTarget, 0, body); + } + + return body; + } + + private: explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tvm::attr::kTarget) { // If a target attribute already exists, use it as-is. + found_target_region_ = true; return GetRef(op); } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { // These attributes are only allowed in device-side code, so // they should be annotated with the function's default target. + found_target_region_ = true; Stmt body = GetRef(op); return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); } else { @@ -52,8 +73,8 @@ class DeviceRegionAnnotater : public StmtMutator { } } - private: Target device_target_; + bool found_target_region_{false}; }; namespace transform { @@ -64,9 +85,12 @@ Pass AnnotateDeviceRegions() { ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; Target target = opt_target.value(); - if (target->GetHost()) { - DeviceRegionAnnotater mutator(target.WithoutHost()); - func.CopyOnWrite()->body = mutator(func->body); + if (auto opt_host = target->GetHost()) { + auto new_body = + DeviceRegionAnnotater::Apply(opt_host.value(), target.WithoutHost(), func->body); + if (!new_body.same_as(func->body)) { + func.CopyOnWrite()->body = new_body; + } } return func; };