diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index f77a13bcd2ed..7f65f2e88dac 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -19,7 +19,7 @@ from ... import ir_pass, build, build_config, nd, TVMError, register_func, \ rpc as _rpc, target as _target -from ...contrib import nvcc, ndk +from ...contrib import nvcc, ndk, tar from ..util import get_const_tuple from ..env import AutotvmGlobalScope @@ -58,20 +58,20 @@ class LocalBuilder(Builder): build_func: callable or str If is 'default', use default build function If is 'ndk', use function for android ndk - If is callable, use it as custom build function + If is callable, use it as custom build function, expect lib_format field. """ def __init__(self, timeout=10, n_parallel=None, build_func='default'): super(LocalBuilder, self).__init__(timeout, n_parallel) if isinstance(build_func, str): if build_func == 'default': - build_func = default_build_func + build_func = tar.tar elif build_func == 'ndk': - build_func = android_ndk_build_func + build_func = ndk.create_shared else: raise ValueError("Invalid build_func" + build_func) - self.build_func = build_func + self.build_func = _wrap_build_func(build_func) self.executor = LocalExecutor(timeout=timeout) self.tmp_dir = tempfile.mkdtemp() @@ -349,46 +349,47 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args) -def default_build_func(measure_input, tmp_dir, **kwargs): +def _wrap_build_func(build_func): """ - Default build func. This can work for cuda, opencl, llvm backend + Wrap build_func to a function that can be used in measure. Parameters ---------- - measure_input: MeasureInput - The input of measurement - tmp_dir: str - The path of temporary directory to export generated library - """ - tic = time.time() - try: - filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64)) - func, arg_info = _build_func_common(measure_input, **kwargs) - func.export_library(filename) - except Exception as e: # pylint: disable=broad-except - return BuildResult(None, None, e, time.time() - tic) - return BuildResult(filename, arg_info, None, time.time() - tic) - - -def android_ndk_build_func(measure_input, tmp_dir, **kwargs): - """ - Build function for android device using ndk. + build_func : The compilation function + We expect fcompile to contain an attr "output_format" - Parameters - ---------- - measure_input: MeasureInput - The input of measurement - tmp_dir: str - The path of temporary directory to export generated library + Returns + ------- + wrapped_build_func : function + The wrapped build function """ - tic = time.time() - try: - filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64)) - func, arg_info = _build_func_common(measure_input, **kwargs) - func.export_library(filename, ndk.create_shared) - except Exception as e: # pylint: disable=broad-except - return BuildResult(None, None, e, time.time() - tic) - return BuildResult(filename, arg_info, None, time.time() - tic) + if not hasattr(build_func, "output_format"): + raise AttributeError("Expect build_func to have the attribute output_format.") + output_format = build_func.output_format + + def _wrapped(measure_input, tmp_dir, **kwargs): + """ + Wrapped build func. + + Parameters + ---------- + measure_input: MeasureInput + The input of measurement + + tmp_dir: str + The path of temporary directory to export generated library + """ + tic = time.time() + try: + filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % ( + getrandbits(64), output_format)) + # TODO(tvm-team) consider linline _build_func_common + func, arg_info = _build_func_common(measure_input, **kwargs) + func.export_library(filename, build_func) + except Exception as e: # pylint: disable=broad-except + return BuildResult(None, None, e, time.time() - tic) + return BuildResult(filename, arg_info, None, time.time() - tic) + return _wrapped def run_through_rpc(measure_input, build_result, diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index ee84da820902..09822e594b75 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -29,7 +29,7 @@ def create_shared(output, cc : str, optional The compile string. """ - if sys.platform == "darwin" or sys.platform.startswith('linux'): + if sys.platform == "darwin" or sys.platform.startswith("linux"): _linux_shared(output, objects, options, cc) elif sys.platform == "win32": _windows_shared(output, objects, options) @@ -37,6 +37,38 @@ def create_shared(output, raise ValueError("Unsupported platform") +# assign so as default output format +create_shared.output_format = "so" if sys.platform != "win32" else "dll" + + +def cross_compiler(cc, options=None, output_format="so"): + """Create a cross compiler function. + + Parameters + ---------- + cc : str + The cross compiler name. + + options : list, optional + List of additional optional string. + + output_format : str, optional + Library output format. + + Returns + ------- + fcompile : function + A compilation function that can be passed to export_library. + """ + def _fcompile(outputs, objects, opts=None): + opts = opts if opts else [] + if options: + opts += options + _linux_shared(outputs, objects, opts, cc=cc) + _fcompile.output_format = output_format + return _fcompile + + def _linux_shared(output, objects, options, cc="g++"): cmd = [cc] cmd += ["-shared", "-fPIC"] diff --git a/python/tvm/contrib/tar.py b/python/tvm/contrib/tar.py index 7e075d7a5697..741a9140d741 100644 --- a/python/tvm/contrib/tar.py +++ b/python/tvm/contrib/tar.py @@ -42,6 +42,9 @@ def tar(output, files): msg += py_str(out) raise RuntimeError(msg) +# assign output format +tar.output_format = "tar" + def untar(tar_file, directory): """Unpack all tar files into the directory diff --git a/python/tvm/contrib/xcode.py b/python/tvm/contrib/xcode.py index a43dc9ae2bfe..99f593863522 100644 --- a/python/tvm/contrib/xcode.py +++ b/python/tvm/contrib/xcode.py @@ -98,6 +98,9 @@ def create_dylib(output, objects, arch, sdk="macosx"): raise RuntimeError(msg) +# assign so as default output format +create_dylib.output_format = "dylib" + def compile_metal(code, path_target=None, sdk="macosx"): """Compile metal with CLI tool from env.