Skip to content

Commit

Permalink
[TIR] Move SplitHostDevice to before MakePackedAPI (apache#14986)
Browse files Browse the repository at this point in the history
* [TIR] Move SplitHostDevice to before MakePackedAPI

This simplifies the logic used in MakePackedAPI, as it is the last user
of the host parameter in a function's target.  After MakePackedAPI,
every PrimFunc has a "target" attribute without a "host".

* Roofline plots, update location for SaveLoweredTIR

* Update ethos-u tests to include host prior to MakeUnpackedAPI
  • Loading branch information
Lunderberg authored and junrushao committed Jun 22, 2023
1 parent f35cc9e commit 8a3406f
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 41 deletions.
12 changes: 9 additions & 3 deletions python/tvm/utils/roofline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,16 @@ def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=Non

@pass_instrument
class SaveLoweredTIR:
"""Save TIR functions from right before final lowering. Right now this
means right before tir.MakePackedAPI."""
"""Save TIR functions for analysis.
def __init__(self, before_pass: str = "tir.MakePackedAPI"):
We need the TIR function in a form that can be handled by
`auto_scheduler.feature.named_features_from_primfunc`, but which
is as close to the final lowered form as possible. Right now this
means right before tir.SplitHostDevice.
"""

def __init__(self, before_pass: str = "tir.SplitHostDevice"):
"""
Parameters
----------
Expand Down
5 changes: 3 additions & 2 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::InjectPTXLDG32());
}

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());

bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
.value_or(relay::Executor::Create("graph", {}))
->GetAttr<Bool>("unpacked-api")
Expand All @@ -590,8 +593,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

return transform::Sequential(mixed_pass_list);
Expand Down
11 changes: 10 additions & 1 deletion src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,14 @@ PrimFunc MakePackedAPI(PrimFunc func) {
}();
int target_device_type = target->GetTargetDeviceType();

// A function without a host target has already been lowered.
Target target_host;
if (auto opt = target->GetHost()) {
target_host = opt.value();
} else {
return func;
}

auto* func_ptr = func.CopyOnWrite();
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
Expand Down Expand Up @@ -325,7 +333,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
name_hint + "." + kv.first->name_hint);
}

func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)},
{tvm::attr::kTarget, target_host}});

Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
Expand Down
10 changes: 9 additions & 1 deletion src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
}();
int target_device_type = target->GetTargetDeviceType();

// A function without a host target has already been lowered.
Target target_host;
if (auto opt = target->GetHost()) {
target_host = opt.value();
} else {
return func;
}

auto* func_ptr = func.CopyOnWrite();

// Setup device context
Expand Down Expand Up @@ -145,7 +153,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
func_ptr->buffer_map = Map<Var, Buffer>();

// return the function.
return func;
return WithAttrs(std::move(func), {{tvm::attr::kTarget, target_host}});
}

namespace transform {
Expand Down
8 changes: 0 additions & 8 deletions src/tir/transforms/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,6 @@ class HostDeviceSplitter : public StmtMutator {
};

PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& gvar) {
auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt_target) << "SplitHostDevice: Require the target attribute";
Target target = opt_target.value();

auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
auto name_prefix = global_symbol.value_or(gvar->name_hint);

Expand All @@ -112,10 +108,6 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g
func.CopyOnWrite()->body = body;
}

if (auto target_host = target->GetHost()) {
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value());
}

return func;
}

Expand Down
4 changes: 3 additions & 1 deletion tests/python/contrib/test_ethosu/test_encode_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,9 @@ def get_graph():
# nothing else was overwritten.
# With Target Hooks the TIR module needs a target attached
# and lowered via make unpacked API.
tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u"))
tir_mod["main"] = tir_mod["main"].with_attr(
"target", tvm.target.Target("ethos-u", host="ethos-u")
)
tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
tir_to_cs_translator.translate(tir_mod, params)

Expand Down
8 changes: 6 additions & 2 deletions tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ def test_buffer_info_extraction():
# With Target Hooks the TIR module needs a target attached
# and lowered via make unpacked API.
tir_mod = test_case["tir_module"]
tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u"))
tir_mod["main"] = tir_mod["main"].with_attr(
"target", tvm.target.Target("ethos-u", host="ethos-u")
)
tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"])
for buffer_var, info in buffer_info.items():
Expand Down Expand Up @@ -959,7 +961,9 @@ def check_buffer(address, region, length, buffer_var):

for test_case in test_cases:
tir_mod = test_case["tir_module"]
tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u"))
tir_mod["main"] = tir_mod["main"].with_attr(
"target", tvm.target.Target("ethos-u", host="ethos-u")
)
tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
candidate_regions_for_scratch = [5, 2, 1]
(
Expand Down
40 changes: 33 additions & 7 deletions tests/python/unittest/test_tir_transform_make_packed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_makeapi():
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr(
{
"target": tvm.target.Target("llvm"),
"target": tvm.target.Target("llvm", host="llvm"),
"global_symbol": "main",
}
)
Expand Down Expand Up @@ -90,7 +90,9 @@ def test_variable_passed_from_args():
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, not_device_context], stmt))
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod)
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
)(mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
func = tvm.tir.transform.MakePackedAPI()(mod)["main"]

Expand Down Expand Up @@ -132,7 +134,9 @@ def test_device_api_context_implicit_resource_handle():
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, device_context], stmt))
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod)
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
)(mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
func = tvm.tir.transform.MakePackedAPI()(mod)["main"]

Expand Down Expand Up @@ -161,7 +165,7 @@ def test_device_api_context_implicit_resource_handle():

@pytest.mark.parametrize("use_global_symbol", [True, False])
def test_no_op_when_global_symbol_is_absent(use_global_symbol):
func_attr = {"target": tvm.target.Target("llvm")}
func_attr = {"target": tvm.target.Target("llvm", host="llvm")}
if use_global_symbol:
func_attr["global_symbol"] = "main"

Expand All @@ -177,6 +181,28 @@ def before():
tvm.ir.assert_structural_equal(before, after)


def test_target_host_removed():
"""After MakePackedAPI, host-side target should be the host
MakePackedAPI is the last transform that requires both the device
and the host. After MakePackedAPI, the target attribute should
only contain the host-side target.
"""

host = tvm.target.Target("llvm")

@I.ir_module
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)})
T.evaluate(0)

after = tvm.tir.transform.MakePackedAPI()(before)
target_attr = after["main"].attrs["target"]
assert str(host) == str(target_attr)


def test_internal_subroutine_call():
"""Internal subroutines should not use the PackedFunc API
Expand All @@ -190,7 +216,7 @@ def test_internal_subroutine_call():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
before.subroutine(A.data)

@T.prim_func
Expand Down Expand Up @@ -222,12 +248,12 @@ def test_subroutine_call_to_externally_visible_subroutine():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
before.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")})
T.evaluate(A_data)

after = tvm.tir.transform.MakePackedAPI()(before)
Expand Down
65 changes: 56 additions & 9 deletions tests/python/unittest/test_tir_transform_make_unpacked_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ def mod(mod_without_attrs):


def test_noop_if_not_global_symbol(mod_without_attrs):
before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(
mod_without_attrs
)
target = tvm.target.Target("llvm", host="llvm")
before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_without_attrs)
after = tvm.tir.transform.MakeUnpackedAPI()(before)
tvm.ir.assert_structural_equal(before, after)

Expand All @@ -59,7 +58,8 @@ def test_fails_if_no_target(mod_without_attrs):

@tvm.testing.parametrize_targets("c", "llvm", "cuda")
def test_device_setup(mod, target, dev):
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod)
target = tvm.target.Target(target, host="llvm")
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
assert len(f.params) == 1
assert f.params[0].name == "A"
Expand Down Expand Up @@ -138,6 +138,49 @@ def test_body():
assert f.params[2].name == "A"


class TestTargetHostRemoved(tvm.testing.CompareBeforeAfter):
"""After MakeUnpackedAPI, host-side target should be the host
MakeUnpackedAPI is the last transform that requires both the device
and the host. After MakeUnpackedAPI, the target attribute should
only contain the host-side target.
"""

transform = tvm.tir.transform.MakeUnpackedAPI()

def before(self):
@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")})
mod.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"target": T.target("cuda")})
T.evaluate(A_data)

return mod

def expected(self):
@I.ir_module
class mod:
@T.prim_func
def main(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 2)
mod.subroutine(A_data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"target": T.target("cuda")})
T.evaluate(A_data)

return mod


class TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter):
"""Internal subroutines do not require modification
Expand All @@ -153,7 +196,7 @@ def before(self):
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
mod.subroutine(A.data)

@T.prim_func
Expand Down Expand Up @@ -195,12 +238,14 @@ def before(self):
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
mod.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.func_attr(
{"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")}
)
T.evaluate(A_data)

return mod
Expand Down Expand Up @@ -240,7 +285,7 @@ def before(self):
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
mod.subroutine(
T.tvm_stack_make_array(
A.data,
Expand All @@ -255,7 +300,9 @@ def main(A: T.Buffer(1, "float32")):

@T.prim_func
def subroutine(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.func_attr(
{"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")}
)
T.evaluate(A.data)

return mod
Expand Down
14 changes: 7 additions & 7 deletions tests/python/unittest/test_tir_transform_split_host_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_split_host_device_func_attr():
[
tvm.tir.transform.AnnotateDeviceRegions(),
tvm.tir.transform.SplitHostDevice(),
tvm.tir.transform.MakePackedAPI(),
tvm.tir.transform.LowerDeviceKernelLaunch(),
]
)(mod)
Expand Down Expand Up @@ -111,7 +112,7 @@ def expected(self):
class mod:
@T.prim_func
def main(n: T.int32):
T.func_attr({"target": T.target("llvm -opt-level=0")})
T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")})
mod.main_kernel(n)

@T.prim_func
Expand Down Expand Up @@ -168,20 +169,19 @@ def main_kernel(n: T.int32):
return mod


class TestSplitHostDevice(BaseCompare):
class TestSplitHostDeviceWithoutDeviceRegion(BaseCompare):
"""Like TestSplitHostDevice, but no device regions to extract
Even if there are no device regions, the host-side function should
still have its "target" attribute updated.
Because MakePackedAPI/MakeUnpackedAPI still require both the
device and host, SplitHostDevice does not modify the "target"
attribute.
"""

def before():
T.func_attr({"target": T.target("ext_dev", host="llvm")})
T.evaluate(0)

def expected():
T.func_attr({"target": T.target("llvm")})
T.evaluate(0)
expected = before


if __name__ == "__main__":
Expand Down

0 comments on commit 8a3406f

Please sign in to comment.