diff --git a/python/tvm/micro/__init__.py b/python/tvm/micro/__init__.py
index 8e5807acec94..ade63f2da9e4 100644
--- a/python/tvm/micro/__init__.py
+++ b/python/tvm/micro/__init__.py
@@ -23,6 +23,7 @@
 from .debugger import GdbRemoteDebugger
 from .micro_library import MicroLibrary
 from .micro_binary import MicroBinary
+from .model_library_format import export_model_library_format, UnsupportedInModelLibraryFormatError
 from .session import (
     create_local_graph_runtime,
     create_local_debug_runtime,
diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py
new file mode 100644
index 000000000000..4ce80be647c1
--- /dev/null
+++ b/python/tvm/micro/model_library_format.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines functions for exporting to Model Library Format."""
+
+import datetime
+import json
+import os
+import re
+import tarfile
+
+from ..contrib import utils
+from ..relay.backend import graph_runtime_factory
+from ..relay import param_dict
+
+
+class UnsupportedInModelLibraryFormatError(Exception):
+    """Raised when export_model_library_format does not support the given Module tree."""
+
+
+def _populate_codegen_dir(mod, codegen_dir: str):
+    """Populate the codegen sub-directory as part of a Model Library Format export.
+
+    Parameters
+    ----------
+    mod : tvm.runtime.Module
+        Module which should be written to codegen_dir.
+    codegen_dir : str
+        Path to the codegen directory on disk.
+    """
+    dso_modules = mod._collect_dso_modules()
+    dso_module_handles = [m.handle.value for m in dso_modules]
+    non_dso_modules = mod._collect_from_import_tree(lambda m: m not in dso_modules)
+    if non_dso_modules:
+        raise UnsupportedInModelLibraryFormatError(
+            f"Don't know how to export non-c or non-llvm modules; found: {non_dso_modules!r}"
+        )
+
+    mod_indices = {"lib": 0, "src": 0}
+    host_codegen_dir = os.path.join(codegen_dir, "host")
+    for dso_mod in dso_modules:
+        if dso_mod.type_key == "c":
+            index = mod_indices["src"]
+            mod_indices["src"] += 1
+            parent_dir = os.path.join(host_codegen_dir, "src")
+            file_name = os.path.join(parent_dir, f"lib{index}.c")
+        elif dso_mod.type_key == "llvm":
+            index = mod_indices["lib"]
+            mod_indices["lib"] += 1
+            parent_dir = os.path.join(host_codegen_dir, "lib")
+            file_name = os.path.join(parent_dir, f"lib{index}.o")
+        else:
+            assert (
+                False
+            ), f"do not expect module with type_key={mod.type_key} from _collect_dso_modules"
+
+        if not os.path.exists(parent_dir):
+            os.makedirs(parent_dir)
+        dso_mod.save(file_name)
+
+
+def _build_memory_map(graph_json):
+    """Build a simpler memory map from graph JSON.
+
+    Parameters
+    ----------
+    graph_json : str
+        String representation of the graph_json created from tvm.relay.build().
+
+    Returns
+    -------
+    list :
+        A list with one entry per storage id describing that memory.
+    """
+    graph = json.loads(graph_json)
+
+    seen_storage_ids = set()
+    memory_map = []
+    for node_id, storage_id in enumerate(graph["attrs"]["storage_id"][1]):
+        if storage_id in seen_storage_ids:
+            continue
+
+        seen_storage_ids.add(storage_id)
+        num_elements = 1
+        for dim in graph["attrs"]["shape"][1][storage_id]:
+            num_elements *= dim
+
+        dltype = graph["attrs"]["dltype"][1][storage_id]
+        m = re.match(r"^[a-zA-Z]+([0-9]+)$", dltype)
+        assert m, f"Exported graph contains unknown dltype {dltype}"
+
+        elem_bits = int(m.group(1))
+
+        map_entry = {
+            "storage_id": storage_id,
+            "size_bytes": (num_elements * elem_bits + 7) // 8,
+        }
+        if node_id in graph["arg_nodes"]:
+            map_entry["input_binding"] = graph["nodes"][node_id]["name"]
+
+        memory_map.append(map_entry)
+
+    return memory_map
+
+
+def export_model_library_format(mod: graph_runtime_factory.GraphRuntimeFactoryModule, file_name):
+    """Export the build artifact in Model Library Format.
+
+    This function creates a .tar archive containing the build artifacts in a standardized
+    layout. It's intended to allow downstream automation to build TVM artifacts against the C
+    runtime.
+
+    Parameters
+    ----------
+    mod : tvm.relay.backend.graph_runtime_factory.GraphRuntimeFactoryModule
+        The return value of tvm.relay.build, which will be exported into Model Library Format.
+    file_name : str
+        Path to the .tar archive to generate.
+    """
+    tempdir = utils.tempdir()
+    metadata = {
+        "version": 1,
+        "model_name": mod.libmod_name,
+        "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
+        "memory": _build_memory_map(mod.graph_json),
+        "target": {int(k): str(v) for k, v in mod.target.items()},
+        "runtimes": ["graph"],
+    }
+    with open(tempdir.relpath("metadata.json"), "w") as json_f:
+        json.dump(metadata, json_f, indent=2, sort_keys=True)
+
+    codegen_dir_path = tempdir.relpath("codegen")
+    os.mkdir(codegen_dir_path)
+    _populate_codegen_dir(mod.lib, codegen_dir_path)
+
+    parameters_dir_path = tempdir.relpath("parameters")
+    os.mkdir(parameters_dir_path)
+    param_filename = os.path.join(parameters_dir_path, f"{mod.libmod_name}.params")
+    with open(param_filename, "wb") as f:
+        f.write(param_dict.save_param_dict(mod.params))
+
+    with open(tempdir.relpath("relay.txt"), "w") as f:
+        f.write(str(mod.ir_mod))
+
+    graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph"))
+    os.makedirs(graph_config_dir_path)
+    with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f:
+        f.write(mod.graph_json)
+
+    with tarfile.open(file_name, "w") as tar_f:
+
+        def reset(tarinfo):
+            tarinfo.uid = tarinfo.gid = 0
+            tarinfo.uname = tarinfo.gname = "root"
+            return tarinfo
+
+        tar_f.add(tempdir.temp_dir, arcname=".", filter=reset)
diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py
index 3427a62cd491..e92ae710ca0b 100644
--- a/python/tvm/relay/backend/graph_runtime_factory.py
+++ b/python/tvm/relay/backend/graph_runtime_factory.py
@@ -16,9 +16,9 @@
 # under the License.
 """Graph runtime factory."""
 import warnings
-from tvm._ffi.base import string_types
-from tvm._ffi.registry import get_global_func
-from tvm.runtime import ndarray
+from ..._ffi.base import string_types
+from ..._ffi.registry import get_global_func
+from ...runtime import ndarray
 
 
 class GraphRuntimeFactoryModule:
@@ -31,6 +31,8 @@ class GraphRuntimeFactoryModule:
         The graph to be deployed in json format output by graph compiler.
         The graph can contain operator(tvm_op) that points to the name of
         PackedFunc in the libmod.
+    target : tvm.Target
+        The Target used to build this module.
     libmod : tvm.Module
         The module of the corresponding function
     libmod_name: str
@@ -39,13 +41,15 @@ class GraphRuntimeFactoryModule:
         The parameters of module
     """
 
-    def __init__(self, graph_json_str, libmod, libmod_name, params):
+    def __init__(self, ir_mod, target, graph_json_str, libmod, libmod_name, params):
         assert isinstance(graph_json_str, string_types)
         fcreate = get_global_func("tvm.graph_runtime_factory.create")
         args = []
         for k, v in params.items():
             args.append(k)
             args.append(ndarray.array(v))
+        self.ir_mod = ir_mod
+        self.target = target
         self.module = fcreate(graph_json_str, libmod, libmod_name, *args)
         self.graph_json = graph_json_str
         self.lib = libmod
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 4c9a898f2374..8e69d288df12 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -208,14 +208,14 @@ def _build_module_no_factory(mod, target=None, target_host=None, params=None, mo
     return build(mod, target, target_host, params, mod_name).module
 
 
-def build(mod, target=None, target_host=None, params=None, mod_name="default"):
+def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"):
     # fmt: off
     # pylint: disable=line-too-long
     """Helper function that builds a Relay function to run on TVM graph runtime.
 
     Parameters
     ----------
-    mod : :py:class:`~tvm.IRModule`
+    ir_mod : :py:class:`~tvm.IRModule`
         The IR module to build. Using relay.Function is deprecated.
 
     target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional
@@ -251,13 +251,13 @@ def build(mod, target=None, target_host=None, params=None, mod_name="default"):
     """
     # pylint: enable=line-too-long
     # fmt: on
-    if not isinstance(mod, (IRModule, _function.Function)):
+    if not isinstance(ir_mod, (IRModule, _function.Function)):
         raise ValueError("Type of input parameter mod must be tvm.IRModule")
 
-    if isinstance(mod, _function.Function):
+    if isinstance(ir_mod, _function.Function):
         if params:
-            mod = bind_params_by_name(mod, params)
-        mod = IRModule.from_expr(mod)
+            ir_mod = bind_params_by_name(ir_mod, params)
+        ir_mod = IRModule.from_expr(ir_mod)
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
             "instead of deprecated parameter mod (tvm.relay.function.Function)",
@@ -280,9 +280,11 @@ def build(mod, target=None, target_host=None, params=None, mod_name="default"):
 
     with tophub_context:
         bld_mod = BuildModule()
-        graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
-        mod = _graph_runtime_factory.GraphRuntimeFactoryModule(graph_json, mod, mod_name, params)
-        return mod
+        graph_json, runtime_mod, params = bld_mod.build(ir_mod, target, target_host, params)
+        runtime_mod = _graph_runtime_factory.GraphRuntimeFactoryModule(
+            ir_mod, target, graph_json, runtime_mod, mod_name, params
+        )
+        return runtime_mod
 
 
 def optimize(mod, target=None, params=None):
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 63267969ab4e..53576a60f32f 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -105,6 +105,9 @@ def __getitem__(self, name):
             raise ValueError("Can only take string as function name")
         return self.get_function(name)
 
+    def __eq__(self, other):
+        return self.handle.value == other.handle.value
+
     def __call__(self, *args):
         if self._entry:
             return self._entry(*args)
@@ -233,15 +236,27 @@ def evaluator(*args):
         except NameError:
             raise NameError("time_evaluate is only supported when RPC is enabled")
 
-    def _collect_dso_modules(self):
-        """Helper function to collect dso modules, then return it."""
+    def _collect_from_import_tree(self, filter_func):
+        """Helper function to collect modules from the tree matching a filter_func, then return it.
+
+        Parameters
+        ----------
+        filter_func : Callable[[Module], bool]
+            A function which is invoked for each Module discovered in the import tree (including
+            self).
+
+        Returns
+        -------
+        list[Module] :
+            A list of matching Module.
+        """
         visited, stack, dso_modules = set(), [], []
         # append root module
         visited.add(self)
         stack.append(self)
         while stack:
             module = stack.pop()
-            if module._dso_exportable():
+            if filter_func(module):
                 dso_modules.append(module)
             for m in module.imported_modules:
                 if m not in visited:
@@ -249,8 +264,9 @@ def evaluator(*args):
                     visited.add(m)
                     stack.append(m)
         return dso_modules
 
-    def _dso_exportable(self):
-        return self.type_key == "llvm" or self.type_key == "c"
+    def _collect_dso_modules(self):
+        is_dso_exportable = lambda m: (m.type_key == "llvm" or m.type_key == "c")
+        return self._collect_from_import_tree(is_dso_exportable)
 
     def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs):
         """Export the module and its imported device code one library.
diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc
index 4d3993a9a36f..605d6b0ce892 100644
--- a/src/runtime/graph/graph_runtime_factory.cc
+++ b/src/runtime/graph/graph_runtime_factory.cc
@@ -156,7 +156,8 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args
                                  "graph_runtime_factory.create needs at least 3, "
                                  "but it has "
                               << args.num_args;
-  // The argument order is graph_json, module, module_name, params.
+  // The argument order is graph_json, module, module_name, param0_name, param0_tensor,
+  // [param1_name, param1_tensor], ...
   ICHECK_EQ((args.size() - 3) % 2, 0);
   std::unordered_map<std::string, tvm::runtime::NDArray> params;
   for (size_t i = 3; i < static_cast<size_t>(args.size()); i += 2) {
diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py
new file mode 100644
index 000000000000..c999091cc3cc
--- /dev/null
+++ b/tests/python/unittest/test_micro_model_library_format.py
@@ -0,0 +1,190 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import json
+import os
+import sys
+import tarfile
+
+import numpy
+import pytest
+
+import tvm
+import tvm.relay
+from tvm.relay.backend import graph_runtime_factory
+import tvm.runtime.module
+import tvm.testing
+from tvm.contrib import utils
+
+
+def validate_graph_json(extract_dir, factory):
+    with open(os.path.join(extract_dir, "runtime-config", "graph", "graph.json")) as graph_f:
+        graph_json = graph_f.read()
+        assert graph_json == factory.graph_json
+
+        # Just check it parses and looks roughly right.
+        graph = json.loads(graph_json)
+        assert "nodes" in graph
+        assert len(graph["nodes"]) == 4
+        assert "attrs" in graph
+
+
+@tvm.testing.requires_micro
+def test_export_model_library_format_c():
+    with utils.TempDirectory.set_keep_for_debug(True):
+        target = tvm.target.target.micro("host")
+        with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+            relay_mod = tvm.parser.fromtext(
+                """
+            #[version = "0.0.5"]
+            def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[(1, 2), float32]) {
+            %0 = cast(%a, dtype="float32") + %b * %c;
+            %0
+            }"""
+            )
+            factory = tvm.relay.build(
+                relay_mod,
+                target,
+                target_host=target,
+                mod_name="add",
+                params={"c": numpy.array([[2.0, 4.0]], dtype="float32")},
+            )
+
+        temp_dir = utils.tempdir()
+        mlf_tar_path = temp_dir.relpath("lib.tar")
+        import tvm.micro as micro
+
+        micro.export_model_library_format(factory, mlf_tar_path)
+        tf = tarfile.open(mlf_tar_path)
+
+        extract_dir = temp_dir.relpath("extract")
+        os.mkdir(extract_dir)
+        tf.extractall(extract_dir)
+
+        with open(os.path.join(extract_dir, "metadata.json")) as json_f:
+            metadata = json.load(json_f)
+            assert metadata["version"] == 1
+            assert metadata["model_name"] == "add"
+            export_datetime = datetime.datetime.strptime(
+                metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ"
+            )
+            assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5)
+            assert metadata["target"] == {"1": str(target)}
+            assert metadata["memory"] == [
+                {"storage_id": 0, "size_bytes": 2, "input_binding": "a"},
+                {"storage_id": 1, "size_bytes": 8, "input_binding": "b"},
+                {"storage_id": 2, "size_bytes": 8, "input_binding": "p0"},
+                {"storage_id": 3, "size_bytes": 8},
+            ]
+
+        assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib0.c"))
+        assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib1.c"))
+
+        validate_graph_json(extract_dir, factory)
+
+        with open(os.path.join(extract_dir, "relay.txt")) as relay_f:
+            assert relay_f.read() == str(relay_mod)
+
+        with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f:
+            params = tvm.relay.load_param_dict(params_f.read())
+            assert "p0" in params
+
+
+@tvm.testing.requires_micro
+def test_export_model_library_format_llvm():
+    with utils.TempDirectory.set_keep_for_debug(True):
+        target = tvm.target.target.micro("host")
+        assert str(target)[:2] == "c "
+        target = tvm.target.Target("llvm " + str(target)[2:])
+        with tvm.transform.PassContext(opt_level=3):
+            relay_mod = tvm.parser.fromtext(
+                """
+            #[version = "0.0.5"]
+            def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[(1, 2), float32]) {
+            %0 = cast(%a, dtype="float32") + %b * %c;
+            %0
+            }"""
+            )
+            factory = tvm.relay.build(
+                relay_mod,
+                target,
+                target_host=target,
+                mod_name="add",
+                params={"c": numpy.array([[2.0, 4.0]], dtype="float32")},
+            )
+
+        temp_dir = utils.tempdir()
+        mlf_tar_path = temp_dir.relpath("lib.tar")
+        import tvm.micro as micro
+
+        micro.export_model_library_format(factory, mlf_tar_path)
+        tf = tarfile.open(mlf_tar_path)
+
+        extract_dir = temp_dir.relpath("extract")
+        os.mkdir(extract_dir)
+        tf.extractall(extract_dir)
+
+        with open(os.path.join(extract_dir, "metadata.json")) as json_f:
+            metadata = json.load(json_f)
+            assert metadata["version"] == 1
+            assert metadata["model_name"] == "add"
+            export_datetime = datetime.datetime.strptime(
+                metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ"
+            )
+            assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5)
+            assert metadata["target"] == {"1": str(target)}
+            assert metadata["memory"] == [
+                {"storage_id": 0, "size_bytes": 2, "input_binding": "a"},
+                {"storage_id": 1, "size_bytes": 8, "input_binding": "b"},
+                {"storage_id": 2, "size_bytes": 8, "input_binding": "p0"},
+                {"storage_id": 3, "size_bytes": 8},
+            ]
+
+        assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "lib", "lib0.o"))
+
+        validate_graph_json(extract_dir, factory)
+
+        with open(os.path.join(extract_dir, "relay.txt")) as relay_f:
+            assert relay_f.read() == str(relay_mod)
+
+        with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f:
+            params = tvm.relay.load_param_dict(params_f.read())
+            assert "p0" in params
+
+
+@tvm.testing.requires_micro
+def test_export_model():
+    module = tvm.support.FrontendTestModule()
+    factory = graph_runtime_factory.GraphRuntimeFactoryModule(
+        None, tvm.target.target.micro("host"), '"graph_json"', module, "test_module", {}
+    )
+
+    temp_dir = utils.tempdir()
+    import tvm.micro as micro
+    import tvm.micro.model_library_format as model_library_format
+
+    with pytest.raises(micro.UnsupportedInModelLibraryFormatError) as exc:
+        model_library_format._populate_codegen_dir(module, temp_dir.relpath("codegen"))
+
+        assert str(exc.exception) == (
+            "Don't know how to export non-c or non-llvm modules; found: ffi_testing"
+        )
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
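Usage sketch (editor's addendum, not part of the diff): the snippet below strings the new pieces together the same way the added tests do. It builds a small Relay module, exports it with tvm.micro.export_model_library_format(), and unpacks the archive to look at metadata.json. The Relay source text, the "add" module name, and the temporary paths are illustrative assumptions borrowed from the tests above, and, like those tests, running it assumes a TVM build with microTVM enabled.

import json
import os
import tarfile

import numpy
import tvm
import tvm.micro as micro
import tvm.relay
from tvm.contrib import utils

# Build the same two-operator model the tests use; "c" is bound as a parameter.
relay_mod = tvm.parser.fromtext(
    """
    #[version = "0.0.5"]
    def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[(1, 2), float32]) {
        %0 = cast(%a, dtype="float32") + %b * %c;
        %0
    }"""
)
target = tvm.target.target.micro("host")
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
    factory = tvm.relay.build(
        relay_mod,
        target,
        target_host=target,
        mod_name="add",
        params={"c": numpy.array([[2.0, 4.0]], dtype="float32")},
    )

# Export to Model Library Format and unpack the .tar to inspect its layout.
temp_dir = utils.tempdir()
tar_path = temp_dir.relpath("model.tar")
micro.export_model_library_format(factory, tar_path)

extract_dir = temp_dir.relpath("extract")
os.mkdir(extract_dir)
with tarfile.open(tar_path) as tar_f:
    tar_f.extractall(extract_dir)

# metadata.json carries the format version, model name, memory map and target dict;
# codegen/, parameters/, relay.txt and runtime-config/graph/ sit alongside it.
with open(os.path.join(extract_dir, "metadata.json")) as json_f:
    metadata = json.load(json_f)
print(metadata["model_name"], metadata["memory"])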