Skip to content

Commit

Permalink
Introduce Model Library Format export format.
Browse files Browse the repository at this point in the history
 * This function produces a stable on-disk representation of TVM's
   compiler output.
 * It's intended just for use with the C runtime for microTVM right
   now. It could be expanded for other use cases.
 * This PR implements the Model Library Format RFC, which ultimately
   is intended to support the Project Generator API (RFC
   forthcoming).
 * There may be some changes to the format without revving the version
   number until downstream consumers are known. The Project Generator
   API is the first such known downstream consumer.
 * There are no plans currently to support generating old Model
   Library Format from TVM. The version number is intended as a
   compatibility check between the generator and downstream consumers.
  • Loading branch information
areusch committed Feb 26, 2021
1 parent 43b15a8 commit 72df6e4
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 13 deletions.
98 changes: 94 additions & 4 deletions python/tvm/relay/backend/graph_runtime_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@
# specific language governing permissions and limitations
# under the License.
"""Graph runtime factory."""
import datetime
import os
import json
import re
import tarfile
import warnings
from tvm._ffi.base import string_types
from tvm._ffi.registry import get_global_func
from tvm.runtime import ndarray
from ...contrib import utils
from ..._ffi.base import string_types
from ..._ffi.registry import get_global_func
from ...runtime import ndarray
from .. import param_dict


class GraphRuntimeFactoryModule(object):
Expand All @@ -31,6 +38,8 @@ class GraphRuntimeFactoryModule(object):
The graph to be deployed in json format output by graph compiler.
The graph can contain operator(tvm_op) that points to the name of
PackedFunc in the libmod.
target : tvm.Target
The Target used to build this module.
libmod : tvm.Module
The module of the corresponding function
libmod_name: str
Expand All @@ -39,13 +48,15 @@ class GraphRuntimeFactoryModule(object):
The parameters of module
"""

def __init__(self, graph_json_str, libmod, libmod_name, params):
def __init__(self, ir_mod, target, graph_json_str, libmod, libmod_name, params):
assert isinstance(graph_json_str, string_types)
fcreate = get_global_func("tvm.graph_runtime_factory.create")
args = []
for k, v in params.items():
args.append(k)
args.append(ndarray.array(v))
self.ir_mod = ir_mod
self.target = target
self.module = fcreate(graph_json_str, libmod, libmod_name, *args)
self.graph_json = graph_json_str
self.lib = libmod
Expand All @@ -56,6 +67,85 @@ def __init__(self, graph_json_str, libmod, libmod_name, params):
def export_library(self, file_name, fcompile=None, addons=None, **kwargs):
return self.module.export_library(file_name, fcompile, addons, **kwargs)

def _build_memory_map(self):
graph = json.loads(self.graph_json)

seen_storage_ids = set()
memory_map = []
for node_id, storage_id in enumerate(graph["attrs"]["storage_id"][1]):
if storage_id in seen_storage_ids:
continue

seen_storage_ids.add(storage_id)
num_elements = 1
for dim in graph["attrs"]["shape"][1][storage_id]:
num_elements *= dim

dltype = graph["attrs"]["dltype"][1][storage_id]
m = re.match(r"^[a-zA-Z]+([0-9]+)$", dltype)
assert m, f"Exported graph contains unknown dltype {dltype}"

elem_bits = int(m.group(1))

map_entry = {
"storage_id": storage_id,
"size_bytes": (num_elements * elem_bits + 7) // 8,
}
if node_id in graph["arg_nodes"]:
map_entry["input_binding"] = graph["nodes"][node_id]["name"]

memory_map.append(map_entry)

return memory_map

def export_model_library_format(self, file_name):
"""Export the build artifact in Model Library Format.
This function creates a .tar archive containing the build artifacts in a standardized
layout. It's intended to allow downstream automation to build TVM artifacts against the C
runtime.
Parameters
----------
file_name : str
Path to the .tar archive to generate.
"""
tempdir = utils.tempdir()
metadata = {
"version": 1,
"model_name": self.libmod_name,
"export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
"memory": self._build_memory_map(),
"target": {int(k): str(v) for k, v in self.target.items()},
"runtimes": ["graph"],
}
with open(tempdir.relpath("metadata.json"), "w") as json_f:
json.dump(metadata, json_f, indent=2, sort_keys=True)

codegen_dir_path = tempdir.relpath("codegen")
print("codegen_dir", codegen_dir_path)
os.mkdir(codegen_dir_path)
self.lib.export_model_library_format(codegen_dir_path)
parameters_dir_path = tempdir.relpath("parameters")
os.mkdir(parameters_dir_path)
param_filename = os.path.join(parameters_dir_path, f"{self.libmod_name}.params")
with open(param_filename, "wb") as f:
f.write(param_dict.save_param_dict(self.params))
with open(tempdir.relpath("relay.txt"), "w") as f:
f.write(str(self.ir_mod))
graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph"))
os.makedirs(graph_config_dir_path)
with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f:
f.write(self.graph_json)
with tarfile.open(file_name, "w") as tar_f:

def reset(tarinfo):
tarinfo.uid = tarinfo.gid = 0
tarinfo.uname = tarinfo.gname = "root"
return tarinfo

tar_f.add(tempdir.temp_dir, arcname=".", filter=reset)

# Sometimes we want to get params explicitly.
# For example, we want to save its params value to
# an independent file.
Expand Down
20 changes: 11 additions & 9 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,14 @@ def get_params(self):
return ret


def build(mod, target=None, target_host=None, params=None, mod_name="default"):
def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"):
# fmt: off
# pylint: disable=line-too-long
"""Helper function that builds a Relay function to run on TVM graph runtime.
Parameters
----------
mod : :py:class:`~tvm.IRModule`
ir_mod : :py:class:`~tvm.IRModule`
The IR module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional
Expand Down Expand Up @@ -237,13 +237,13 @@ def build(mod, target=None, target_host=None, params=None, mod_name="default"):
"""
# pylint: enable=line-too-long
# fmt: on
if not isinstance(mod, (IRModule, _function.Function)):
if not isinstance(ir_mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")

if isinstance(mod, _function.Function):
if isinstance(ir_mod, _function.Function):
if params:
mod = bind_params_by_name(mod, params)
mod = IRModule.from_expr(mod)
ir_mod = bind_params_by_name(mod, params)
ir_mod = IRModule.from_expr(mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter mod (tvm.relay.function.Function)",
Expand All @@ -266,9 +266,11 @@ def build(mod, target=None, target_host=None, params=None, mod_name="default"):

with tophub_context:
bld_mod = BuildModule()
graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
mod = _graph_runtime_factory.GraphRuntimeFactoryModule(graph_json, mod, mod_name, params)
return mod
graph_json, runtime_mod, params = bld_mod.build(ir_mod, target, target_host, params)
runtime_mod = _graph_runtime_factory.GraphRuntimeFactoryModule(
ir_mod, target, graph_json, runtime_mod, mod_name, params
)
return runtime_mod


def optimize(mod, target=None, params=None):
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,33 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No

return fcompile(file_name, files, **kwargs)

def export_model_library_format(self, codegen_dir: str):
"""Populate the codegen sub-directory as part of a Model Library Format export.
Parameters
----------
codegen_dir : str
Path to the codegen directory on disk.
"""
dso_modules = self._collect_dso_modules()
mod_indices = {"lib": 0, "src": 0}
host_codegen_dir = os.path.join(codegen_dir, "host")
for mod in dso_modules:
if mod.type_key == "c":
index = mod_indices["src"]
mod_indices["src"] += 1
parent_dir = os.path.join(host_codegen_dir, "src")
file_name = os.path.join(parent_dir, f"lib{index}.c")
else:
index = mod_indices["lib"]
mod_indices["lib"] += 1
parent_dir = os.path.join(host_codegen_dir, "lib")
file_name = os.path.join(parent_dir, f"lib{index}.o")

if not os.path.exists(parent_dir):
os.makedirs(parent_dir)
mod.save(file_name)


def system_lib():
"""Get system-wide library module singleton.
Expand Down
163 changes: 163 additions & 0 deletions tests/python/unittest/test_graph_runtime_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import datetime
import json
import os
import sys
import tarfile

import numpy
import pytest

import tvm
import tvm.relay
from tvm.contrib import utils


def validate_graph_json(extract_dir, factory):
with open(os.path.join(extract_dir, "runtime-config", "graph", "graph.json")) as graph_f:
graph_json = graph_f.read()
assert graph_json == factory.graph_json

# Just check it parses and looks roughly right.
graph = json.loads(graph_json)
assert "nodes" in graph
assert len(graph["nodes"]) == 4
assert "attrs" in graph


def test_export_model_library_format_c():
with utils.TempDirectory.set_keep_for_debug(True):
target = tvm.target.target.micro("host")
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
relay_mod = tvm.parser.fromtext(
"""
#[version = "0.0.5"]
def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[(1, 2), float32]) {
%0 = cast(%a, dtype="float32") + %b * %c;
%0
}"""
)
factory = tvm.relay.build(
relay_mod,
target,
target_host=target,
mod_name="add",
params={"c": numpy.array([[2.0, 4.0]], dtype="float32")},
)

temp_dir = utils.tempdir()
mlf_tar_path = temp_dir.relpath("lib.tar")
print("fac", factory)
factory.export_model_library_format(mlf_tar_path)
tf = tarfile.open(mlf_tar_path)

extract_dir = temp_dir.relpath("extract")
os.mkdir(extract_dir)
tf.extractall(extract_dir)

with open(os.path.join(extract_dir, "metadata.json")) as json_f:
metadata = json.load(json_f)
assert metadata["version"] == 1
assert metadata["model_name"] == "add"
export_datetime = datetime.datetime.strptime(
metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ"
)
assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5)
assert metadata["target"] == {"1": str(target)}
assert metadata["memory"] == [
{"storage_id": 0, "size_bytes": 2, "input_binding": "a"},
{"storage_id": 1, "size_bytes": 8, "input_binding": "b"},
{"storage_id": 2, "size_bytes": 8, "input_binding": "p0"},
{"storage_id": 3, "size_bytes": 8},
]

assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib0.c"))
assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib1.c"))

validate_graph_json(extract_dir, factory)

with open(os.path.join(extract_dir, "relay.txt")) as relay_f:
assert relay_f.read() == str(relay_mod)

with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f:
params = tvm.relay.load_param_dict(params_f.read())
assert "p0" in params


def test_export_model_library_format_llvm():
with utils.TempDirectory.set_keep_for_debug(True):
target = tvm.target.target.micro("host")
assert str(target)[:2] == "c "
target = tvm.target.Target("llvm " + str(target)[2:])
with tvm.transform.PassContext(opt_level=3):
relay_mod = tvm.parser.fromtext(
"""
#[version = "0.0.5"]
def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[(1, 2), float32]) {
%0 = cast(%a, dtype="float32") + %b * %c;
%0
}"""
)
factory = tvm.relay.build(
relay_mod,
target,
target_host=target,
mod_name="add",
params={"c": numpy.array([[2.0, 4.0]], dtype="float32")},
)

temp_dir = utils.tempdir()
mlf_tar_path = temp_dir.relpath("lib.tar")
factory.export_model_library_format(mlf_tar_path)
tf = tarfile.open(mlf_tar_path)

extract_dir = temp_dir.relpath("extract")
os.mkdir(extract_dir)
tf.extractall(extract_dir)

with open(os.path.join(extract_dir, "metadata.json")) as json_f:
metadata = json.load(json_f)
assert metadata["version"] == 1
assert metadata["model_name"] == "add"
export_datetime = datetime.datetime.strptime(
metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ"
)
assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5)
assert metadata["target"] == {"1": str(target)}
assert metadata["memory"] == [
{"storage_id": 0, "size_bytes": 2, "input_binding": "a"},
{"storage_id": 1, "size_bytes": 8, "input_binding": "b"},
{"storage_id": 2, "size_bytes": 8, "input_binding": "p0"},
{"storage_id": 3, "size_bytes": 8},
]

assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "lib", "lib0.o"))

validate_graph_json(extract_dir, factory)

with open(os.path.join(extract_dir, "relay.txt")) as relay_f:
assert relay_f.read() == str(relay_mod)

with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f:
params = tvm.relay.load_param_dict(params_f.read())
assert "p0" in params


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 72df6e4

Please sign in to comment.