Skip to content

Commit

Permalink
[Bugfix][TIR][VTA] Update host-side target, even without device func
Browse files Browse the repository at this point in the history
This resolves an issue introduced by the combination of
apache#14918 and
apache#14945.  The bug occurred for
targets that do not require device-side codegen, but do require a
`device_type` other than `kDLCPU`.  It wasn't caught by CI, as the
issue only occurred with the combination of both PRs.

1. apache#14918 updated `SplitHostDevice` to only modify the `"target"`
   attribute when a device-side function has been extracted.

2. For VTA, there is no device-side function, as everything is done
   through host-side API calls.

3. From (1) and (2), the VTA examples kept the target
   `T.target("ext_dev", host="llvm")` after the `SplitHostDevice`
   pass, instead of being updated to `T.target("llvm")`.

4. apache#14945 restricted CombineContextCall to only apply to host-side
   passes.

5. From (4) and (5), the `CombineContextCall` pass was no longer
   applied to the VTA context calls.

This PR fixes `SplitHostDevice`, updating the target from
`T.target("ext_dev", host="llvm")` to `T.target("llvm")`, even if no
device sections have been extracted from the function.
  • Loading branch information
Lunderberg committed May 30, 2023
1 parent 4267fbf commit 4fdf1d1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/tir/transforms/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g

HostDeviceSplitter splitter(device_mod, name_prefix);

auto body = splitter(func->body);

if (!body.same_as(func->body)) {
if (auto body = splitter(func->body); !body.same_as(func->body)) {
func.CopyOnWrite()->body = body;
auto target_host = target->GetHost().value_or(Target("llvm"));
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host);
}

if (auto target_host = target->GetHost()) {
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value());
}

return func;
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_tir_transform_split_host_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,21 @@ def main_kernel(n: T.int32):
return mod


class TestSplitHostDevice(BaseCompare):
"""Like TestSplitHostDevice, but no device regions to extract
Even if there are no device regions, the host-side function should
still have its "target" attribute updated.
"""

def before():
T.func_attr({"target": T.target("ext_dev", host="llvm")})
T.evaluate(0)

def expected():
T.func_attr({"target": T.target("llvm")})
T.evaluate(0)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 4fdf1d1

Please sign in to comment.