Skip to content

Commit

Permalink
[tvmc] Add a --config option to tvmc compile (apache#8253)
Browse files Browse the repository at this point in the history
* Allow to send some configurations to the PassContext via command line

 * Add various validations to the new option with appropriate error messages

 * Add unit testing
  • Loading branch information
leandron authored and trevor-m committed Jun 17, 2021
1 parent b72b71c commit b15d74c
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 1 deletion.
100 changes: 100 additions & 0 deletions python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,103 @@ def parse_shape_string(inputs_string):
shape_dict[name] = shape

return shape_dict


def get_pass_config_value(name, value, config_type):
"""Get a PassContext configuration value, based on its config data type.
Parameters
----------
name: str
config identifier name.
value: str
value assigned to the config, provided via command line.
config_type: str
data type defined to the config, as string.
Returns
-------
parsed_value: bool, int or str
a representation of the input value, converted to the type
specified by config_type.
"""

if config_type == "IntImm":
# "Bool" configurations in the PassContext are recognized as
# IntImm, so deal with this case here
mapping_values = {
"false": False,
"true": True,
}

if value.isdigit():
parsed_value = int(value)
else:
# if not an int, accept only values on the mapping table, case insensitive
parsed_value = mapping_values.get(value.lower(), None)

if parsed_value is None:
raise TVMCException(f"Invalid value '{value}' for configuration '{name}'. ")

if config_type == "runtime.String":
parsed_value = value

return parsed_value


def parse_configs(input_configs):
"""Parse configuration values set via command line.
Parameters
----------
input_configs: list of str
list of configurations provided via command line.
Returns
-------
pass_context_configs: dict
a dict containing key-value configs to be used in the PassContext.
"""
if not input_configs:
return {}

all_configs = tvm.ir.transform.PassContext.list_configs()
supported_config_types = ("IntImm", "runtime.String")
supported_configs = [
name for name in all_configs.keys() if all_configs[name]["type"] in supported_config_types
]

pass_context_configs = {}

for config in input_configs:
if not config:
raise TVMCException(
f"Invalid format for configuration '{config}', use <config>=<value>"
)

# Each config is expected to be provided as "name=value"
try:
name, value = config.split("=")
name = name.strip()
value = value.strip()
except ValueError:
raise TVMCException(
f"Invalid format for configuration '{config}', use <config>=<value>"
)

if name not in all_configs:
raise TVMCException(
f"Configuration '{name}' is not defined in TVM. "
f"These are the existing configurations: {', '.join(all_configs)}"
)

if name not in supported_configs:
raise TVMCException(
f"Configuration '{name}' uses a data type not supported by TVMC. "
f"The following configurations are supported: {', '.join(supported_configs)}"
)

parsed_value = get_pass_config_value(name, value, all_configs[name]["type"])
pass_context_configs[name] = parsed_value

return pass_context_configs
15 changes: 14 additions & 1 deletion python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ def add_compile_parser(subparsers):
help="output format. Use 'so' for shared object or 'mlf' for Model Library Format "
"(only for µTVM targets). Defaults to 'so'.",
)
parser.add_argument(
"--pass-config",
action="append",
metavar=("name=value"),
help="configurations to be used at compile time. This option can be provided multiple "
"times, each one to set one configuration value, "
"e.g. '--pass-config relay.backend.use_auto_scheduler=0'.",
)
parser.add_argument(
"--target",
help="compilation targets as comma separated string, inline JSON or path to a JSON file.",
Expand Down Expand Up @@ -145,6 +153,7 @@ def drive_compile(args):
target_host=None,
desired_layout=args.desired_layout,
disabled_pass=args.disabled_pass,
pass_context_configs=args.pass_config,
)

return 0
Expand All @@ -162,6 +171,7 @@ def compile_model(
target_host: Optional[str] = None,
desired_layout: Optional[str] = None,
disabled_pass: Optional[str] = None,
pass_context_configs: Optional[str] = None,
):
"""Compile a model from a supported framework into a TVM module.
Expand Down Expand Up @@ -202,6 +212,9 @@ def compile_model(
disabled_pass: str, optional
Comma-separated list of passes which needs to be disabled
during compilation
pass_context_configs: str, optional
String containing a set of configurations to be passed to the
PassContext.
Returns
Expand All @@ -212,7 +225,7 @@ def compile_model(
"""
mod, params = tvmc_model.mod, tvmc_model.params

config = {}
config = common.parse_configs(pass_context_configs)

if desired_layout:
mod = common.convert_graph_layout(mod, desired_layout)
Expand Down
51 changes: 51 additions & 0 deletions tests/python/driver/tvmc/test_tvmc_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest

import tvm
from tvm.contrib.target.vitis_ai import vitis_ai_available
from tvm.driver import tvmc

from tvm.driver.tvmc.common import TVMCException
Expand Down Expand Up @@ -306,3 +307,53 @@ def test_parse_quotes_and_separators_on_options():

assert len(targets_double_quote) == 1
assert "+v1.0x,+value" == targets_double_quote[0]["opts"]["option1"]


def test_config_invalid_format():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"])


def test_config_missing_from_tvm():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"])


def test_config_unsupported_tvmc_config():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs(["tir.LoopPartition=value"])


def test_config_empty():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs([""])


def test_config_valid_config_bool():
configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"])

assert len(configs) == 1
assert "relay.backend.use_auto_scheduler" in configs.keys()
assert configs["relay.backend.use_auto_scheduler"] == True


@pytest.mark.skipif(
not vitis_ai_available(),
reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'",
)
def test_config_valid_multiple_configs():
configs = tvmc.common.parse_configs(
[
"relay.backend.use_auto_scheduler=false",
"tir.detect_global_barrier=10",
"relay.ext.vitis_ai.options.build_dir=mystring",
]
)

assert len(configs) == 3
assert "relay.backend.use_auto_scheduler" in configs.keys()
assert configs["relay.backend.use_auto_scheduler"] == False
assert "tir.detect_global_barrier" in configs.keys()
assert configs["tir.detect_global_barrier"] == 10
assert "relay.ext.vitis_ai.options.build_dir" in configs.keys()
assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring"

0 comments on commit b15d74c

Please sign in to comment.