Skip to content

Commit

Permalink
[AUTOTVM] Refactor measure build func (apache#2927)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and wweic committed Apr 10, 2019
1 parent a38aa40 commit 4260e42
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 40 deletions.
79 changes: 40 additions & 39 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
rpc as _rpc, target as _target
from ...contrib import nvcc, ndk
from ...contrib import nvcc, ndk, tar

from ..util import get_const_tuple
from ..env import AutotvmGlobalScope
Expand Down Expand Up @@ -58,20 +58,20 @@ class LocalBuilder(Builder):
build_func: callable or str
If is 'default', use default build function
If is 'ndk', use function for android ndk
If is callable, use it as custom build function
If is callable, use it as custom build function, expect lib_format field.
"""
def __init__(self, timeout=10, n_parallel=None, build_func='default'):
super(LocalBuilder, self).__init__(timeout, n_parallel)

if isinstance(build_func, str):
if build_func == 'default':
build_func = default_build_func
build_func = tar.tar
elif build_func == 'ndk':
build_func = android_ndk_build_func
build_func = ndk.create_shared
else:
raise ValueError("Invalid build_func" + build_func)

self.build_func = build_func
self.build_func = _wrap_build_func(build_func)
self.executor = LocalExecutor(timeout=timeout)
self.tmp_dir = tempfile.mkdtemp()

Expand Down Expand Up @@ -349,46 +349,47 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)


def default_build_func(measure_input, tmp_dir, **kwargs):
def _wrap_build_func(build_func):
"""
Default build func. This can work for cuda, opencl, llvm backend
Wrap build_func to a function that can be used in measure.
Parameters
----------
measure_input: MeasureInput
The input of measurement
tmp_dir: str
The path of temporary directory to export generated library
"""
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)


def android_ndk_build_func(measure_input, tmp_dir, **kwargs):
"""
Build function for android device using ndk.
build_func : The compilation function
We expect fcompile to contain an attr "output_format"
Parameters
----------
measure_input: MeasureInput
The input of measurement
tmp_dir: str
The path of temporary directory to export generated library
Returns
-------
wrapped_build_func : function
The wrapped build function
"""
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename, ndk.create_shared)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
if not hasattr(build_func, "output_format"):
raise AttributeError("Expect build_func to have the attribute output_format.")
output_format = build_func.output_format

def _wrapped(measure_input, tmp_dir, **kwargs):
"""
Wrapped build func.
Parameters
----------
measure_input: MeasureInput
The input of measurement
tmp_dir: str
The path of temporary directory to export generated library
"""
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
getrandbits(64), output_format))
# TODO(tvm-team) consider linline _build_func_common
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename, build_func)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
return _wrapped


def run_through_rpc(measure_input, build_result,
Expand Down
34 changes: 33 additions & 1 deletion python/tvm/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,46 @@ def create_shared(output,
cc : str, optional
The compile string.
"""
if sys.platform == "darwin" or sys.platform.startswith('linux'):
if sys.platform == "darwin" or sys.platform.startswith("linux"):
_linux_shared(output, objects, options, cc)
elif sys.platform == "win32":
_windows_shared(output, objects, options)
else:
raise ValueError("Unsupported platform")


# assign so as default output format
create_shared.output_format = "so" if sys.platform != "win32" else "dll"


def cross_compiler(cc, options=None, output_format="so"):
"""Create a cross compiler function.
Parameters
----------
cc : str
The cross compiler name.
options : list, optional
List of additional optional string.
output_format : str, optional
Library output format.
Returns
-------
fcompile : function
A compilation function that can be passed to export_library.
"""
def _fcompile(outputs, objects, opts=None):
opts = opts if opts else []
if options:
opts += options
_linux_shared(outputs, objects, opts, cc=cc)
_fcompile.output_format = output_format
return _fcompile


def _linux_shared(output, objects, options, cc="g++"):
cmd = [cc]
cmd += ["-shared", "-fPIC"]
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/contrib/tar.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def tar(output, files):
msg += py_str(out)
raise RuntimeError(msg)

# assign output format
tar.output_format = "tar"


def untar(tar_file, directory):
"""Unpack all tar files into the directory
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/contrib/xcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def create_dylib(output, objects, arch, sdk="macosx"):
raise RuntimeError(msg)


# assign so as default output format
create_dylib.output_format = "dylib"

def compile_metal(code, path_target=None, sdk="macosx"):
"""Compile metal with CLI tool from env.
Expand Down

0 comments on commit 4260e42

Please sign in to comment.