diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 2f177a242835c..47ffde4327c42 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -64,6 +64,18 @@ MAX_FLOAT = 1e10 +class BuildFunc: + """store build_func name and callable to class variable. + name: str = "default" + The name of registered build function. + build_func: callable = tar.tar + The callable of registered build function. + """ + + name = "default" + build_func = tar.tar + + @tvm._ffi.register_object("auto_scheduler.MeasureCallback") class MeasureCallback(Object): """ The base class of measurement callback functions. """ @@ -303,12 +315,28 @@ class LocalBuilder(ProgramBuilder): This is used in a wrapper of the multiprocessing.Process.join(). n_parallel : int = multiprocessing.cpu_count() Number of threads used to build in parallel. - build_func : str = 'default' - The name of registered build function. + build_func: callable or str = "default" + If is 'default', use default build function + If is 'ndk', use function for android ndk + If is callable, use it as custom build function, expect lib_format field. """ def __init__(self, timeout=15, n_parallel=multiprocessing.cpu_count(), build_func="default"): - self.__init_handle_by_constructor__(_ffi_api.LocalBuilder, timeout, n_parallel, build_func) + if build_func == "default": + BuildFunc.name = "default" + BuildFunc.build_func = tar.tar + elif build_func == "ndk": + BuildFunc.name = "ndk" + BuildFunc.build_func = ndk.create_shared + elif callable(build_func): + BuildFunc.name = "custom" + BuildFunc.build_func = build_func + else: + raise ValueError("Invalid build_func" + build_func) + + self.__init_handle_by_constructor__( + _ffi_api.LocalBuilder, timeout, n_parallel, BuildFunc.name + ) @tvm._ffi.register_object("auto_scheduler.LocalRunner") @@ -624,12 +652,10 @@ def local_build_worker(args): The build result of this Builder thread. """ inp, build_func, timeout, verbose = args - if build_func == "default": - build_func = tar.tar - elif build_func == "ndk": - build_func = ndk.create_shared - else: - raise ValueError("Invalid build_func" + build_func) + assert build_func == BuildFunc.name, ( + "BuildFunc.name: " + BuildFunc.name + ", but args is: " + build_func + ) + build_func = BuildFunc.build_func res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func, verbose)) if isinstance(res, TimeoutError):