Skip to content

Commit

Permalink
Improved device region annotation
Browse files Browse the repository at this point in the history
Previously, if no device-specific attribute was found, the entire
function was assumed to execute on the device.  Now, host-specific
Call nodes (e.g. `builtin::tvm_call_packed()`) are identified and
kept on the host.
  • Loading branch information
Lunderberg committed Jun 16, 2023
1 parent d3ddd40 commit cc7d384
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 8 deletions.
92 changes: 84 additions & 8 deletions src/tir/transforms/annotate_device_regions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,32 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <tuple>
#include <vector>

namespace tvm {
namespace tir {

class DeviceRegionAnnotater : public StmtMutator {
class DeviceRegionAnnotater : public StmtExprMutator {
using Parent = StmtExprMutator;

public:
static Stmt Apply(Target host_target, Target device_target, Stmt body) {
bool same_host_and_device = host_target->str() == device_target->str();
if (same_host_and_device) {
return body;
}

DeviceRegionAnnotater mutator(device_target);
body = mutator(body);

bool same_host_and_device = host_target->str() == device_target->str();

// If no region was found that must be on the device, but the
// device and host differ (e.g. `T.target('c', host='llvm')`),
// then the entire region should be annotated. This preserves the
// host-side handling of DLTensor arguments, while ensuring that
// any device targets are used for the codegen.
if (!mutator.found_target_region_ && !same_host_and_device) {
if (mutator.current_region_ == Region::Either && !same_host_and_device) {
body = AttrStmt(device_target, tvm::attr::kTarget, 0, body);
}

Expand All @@ -58,23 +67,90 @@ class DeviceRegionAnnotater : public StmtMutator {
// Classify an attribute statement and, where required, wrap it in a
// device-target annotation.
//
// - An existing `kTarget` attribute is returned unchanged; its body
//   is not visited.
// - `thread_extent`, `pipeline_exec_scope` and `device_scope`
//   attributes are wrapped in a new `kTarget` AttrStmt that uses
//   `device_target_`; the wrapped statement is not visited further.
// - Every other attribute is forwarded to the base-class visitor.
//
// In the first two cases the region state is marked as Device.
//
// NOTE(review): this listing appears to show the pre-change lines
// (`found_target_region_`, `StmtMutator::`) next to their post-change
// replacements (`current_region_`, `Parent::`).  As the text stands,
// the second `return` in the final branch is unreachable -- confirm
// against the actual file before relying on either form.
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tvm::attr::kTarget) {
// If a target attribute already exists, use it as-is.
found_target_region_ = true;
current_region_ = Region::Device;
return GetRef<Stmt>(op);
} else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) {
// These attributes are only allowed in device-side code, so
// they should be annotated with the function's default target.
found_target_region_ = true;
current_region_ = Region::Device;
Stmt body = GetRef<Stmt>(op);
return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
} else {
// All other annotations are ignored
return StmtMutator::VisitStmt_(op);
return Parent::VisitStmt_(op);
}
}

// Visit every statement of a sequence, record the region each one is
// classified as, and group the statements accordingly.
//
// If at least one statement is classified as Host:
//   - Host statements are kept in place, unwrapped.
//   - Each maximal run of consecutive non-Host statements is wrapped
//     in a single `kTarget` AttrStmt that uses `device_target_`.
//   - The sequence as a whole is marked as Host.
// Otherwise no annotation is added here; the original node is
// returned if no child changed, or a new SeqStmt if one did.
Stmt VisitStmt_(const SeqStmtNode* op) final {
// regions[i] is the classification of the i-th (mutated) statement.
std::vector<Region> regions;
Array<Stmt> seq = op->seq.Map([&](Stmt stmt) {
// Reset before each child so that the value read back afterwards
// reflects only what that child set.
current_region_ = Region::Either;
stmt = VisitStmt(stmt);
regions.push_back(current_region_);
return stmt;
});

bool has_host_function = std::any_of(regions.begin(), regions.end(),
[](const auto& reg) { return reg == Region::Host; });
if (has_host_function) {
// A host-only statement anywhere in the sequence makes the whole
// sequence host-side from the caller's point of view.
current_region_ = Region::Host;

Array<Stmt> new_seq;
// Pending run of consecutive non-Host statements.
Array<Stmt> device_seq;
// Emit the pending run (if any) as one device-annotated block.
auto finish_device_seq = [&]() {
if (device_seq.size()) {
new_seq.push_back(
AttrStmt(device_target_, tvm::attr::kTarget, 0, SeqStmt::Flatten(device_seq)));
device_seq.clear();
}
};

for (size_t i = 0; i < seq.size(); i++) {
if (regions[i] == Region::Host) {
// A Host statement terminates the current run.
finish_device_seq();
new_seq.push_back(seq[i]);
} else {
// NOTE(review): this branch takes Region::Device statements as
// well as Region::Either ones, so a statement that the AttrStmt
// visitor has already marked Device is placed inside another
// `kTarget` wrapper -- confirm that is intended.
device_seq.push_back(seq[i]);
}
}
// Flush a run that extends to the end of the sequence.
finish_device_seq();

return SeqStmt::Flatten(new_seq);
} else if (seq.same_as(op->seq)) {
// NOTE(review): in both non-Host branches current_region_ keeps the
// value left by the last child only; an earlier Device child
// followed by an Either child reads back as Either -- confirm.
return GetRef<Stmt>(op);
} else {
return SeqStmt(seq);
}
}

// Mark the current region as Host when the call targets one of the
// builtins listed below, then continue with the default traversal of
// the call node.
PrimExpr VisitExpr_(const CallNode* op) final {
  // TODO(Lunderberg): Make a new attribute in builtin.cc to label
  // host-only operations.
  const auto calls = [op](const auto& builtin_op) { return op->op.same_as(builtin_op); };

  const bool host_only =
      // Packed-function calls.
      calls(builtin::tvm_call_packed()) || calls(builtin::tvm_call_cpacked()) ||
      calls(builtin::tvm_call_packed_lowered()) || calls(builtin::tvm_call_cpacked_lowered()) ||
      // Struct field access.
      calls(builtin::tvm_struct_get()) || calls(builtin::tvm_struct_set()) ||
      // Extern calls.
      calls(builtin::call_extern()) || calls(builtin::call_pure_extern()) ||
      // Error propagation.
      calls(builtin::tvm_throw_last_error()) ||
      // Stack helpers.
      calls(builtin::tvm_stack_alloca()) || calls(builtin::tvm_stack_make_shape()) ||
      calls(builtin::tvm_stack_make_array());

  if (host_only) {
    current_region_ = Region::Host;
  }

  return Parent::VisitExpr_(op);
}

// Target used for every `kTarget` annotation this mutator inserts.
Target device_target_;
// Set by the AttrStmt visitor when it meets an existing target
// attribute or a device-only attribute.
bool found_target_region_{false};

// Classification of the code visited most recently.
enum class Region {
// Neither a host-only call nor a device attribute has been seen.
// This is the initial value and the value each child of a SeqStmt
// starts from.
Either,
// A host-only builtin call was seen, or a sequence containing one.
Host,
// An existing target attribute or a device-only attribute was seen.
Device,
};
// Written by the visitors as they traverse and read back by the
// SeqStmt visitor after each child.
Region current_region_{Region::Either};
};

namespace transform {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,76 @@ def expected(A: T.Buffer(1, "float32")):
A[0] = 0.0


class TestAnnotateEntireBody(BaseCompare):
"""Annotation inserted to wrap entire function

The body contains no host-only call and no existing target
annotation, so the function is assumed to belong on the device.
"""

# Device ("cuda") and host ("llvm") differ; the body is one buffer store.
def before(A: T.Buffer(1, "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
A[0] = 0.0

# Same function with a "target" attribute for the device target
# placed ahead of the store.
def expected(A: T.Buffer(1, "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
T.attr(T.target("cuda"), "target", 0)
A[0] = 0.0


class TestNoAnnotationForSameHostDevice(BaseCompare):
"""No annotation is needed if host/device are the same"""

# Both the target and its host are "llvm".
def before(A: T.Buffer(1, "float32")):
T.func_attr({"target": T.target("llvm", host="llvm")})
A[0] = 0.0

# The expected output is the input itself, i.e. no annotation is added.
# (Presumably BaseCompare checks that transforming `before` yields
# `expected` -- confirm against the base class.)
expected = before


class TestAnnotationAvoidsHostConstructs(BaseCompare):
"""Device annotation does not contain host-only functions

Calls that must be on the host side (e.g. T.call_packed) remain on
the host.
"""

# A buffer store placed between two packed-function calls.
def before(A: T.Buffer(1, "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
T.call_packed("dummy_function", dtype="void")
A[0] = 0.0
T.call_packed("dummy_function", dtype="void")

# The store is placed under a "cuda" target attribute; the two
# T.call_packed statements are the host-side calls the docstring
# refers to.
def expected(A: T.Buffer(1, "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
T.call_packed("dummy_function", dtype="void")
with T.attr(T.target("cuda"), "target", 0):
A[0] = 0.0
T.call_packed("dummy_function", dtype="void")


class TestAnnotationNoRepetition(BaseCompare):
"""Consecutive device statements share a single annotation

When placing everything that isn't a host-specific function into
target block, sequential device statements should be in the same
block.
"""

# Two consecutive buffer stores placed between two packed-function calls.
def before(A: T.Buffer(2, "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
T.call_packed("dummy_function", dtype="void")
A[0] = 0.0
A[1] = 1.0
T.call_packed("dummy_function", dtype="void")

# A single "cuda" target attribute is introduced for the pair of
# stores, rather than one attribute per store.
def expected(A: T.Buffer(2, "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
T.call_packed("dummy_function", dtype="void")
with T.attr(T.target("cuda"), "target", 0):
A[0] = 0.0
A[1] = 1.0
T.call_packed("dummy_function", dtype="void")


# Entry point used when this file is executed directly rather than imported.
if __name__ == "__main__":
tvm.testing.main()

0 comments on commit cc7d384

Please sign in to comment.