diff --git a/CMakeLists.txt b/CMakeLists.txt index c1c068cffa68..e091dbdc4721 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -257,6 +257,7 @@ assign_source_group("Include" ${GROUP_INCLUDE}) # Source file lists file(GLOB_RECURSE COMPILER_SRCS src/auto_scheduler/*.cc + src/meta_schedule/*.cc src/node/*.cc src/ir/*.cc src/arith/*.cc diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h new file mode 100644 index 000000000000..d0985071b773 --- /dev/null +++ b/include/tvm/meta_schedule/builder.h @@ -0,0 +1,151 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_BUILDER_H_ +#define TVM_META_SCHEDULE_BUILDER_H_ + +#include +#include + +namespace tvm { +namespace meta_schedule { + +/*! \brief The builder's input. */ +class BuilderInputNode : public runtime::Object { + public: + /*! \brief The IRModule to be built. */ + IRModule mod; + /*! \brief The target to be built for. */ + Target target; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("mod", &mod); + v->Visit("target", &target); + } + + static constexpr const char* _type_key = "meta_schedule.BuilderInput"; + TVM_DECLARE_FINAL_OBJECT_INFO(BuilderInputNode, runtime::Object); +}; + +/*! + * \brief Managed reference to BuilderInputNode + * \sa BuilderInputNode + */ +class BuilderInput : public runtime::ObjectRef { + public: + /*! + * \brief Constructor of BuilderInput. + * \param mod The IRModule to be built. + * \param target The target to be built for. + */ + TVM_DLL explicit BuilderInput(IRModule mod, Target target); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderInput, runtime::ObjectRef, BuilderInputNode); +}; + +/*! \brief The builder's output. */ +class BuilderResultNode : public runtime::Object { + public: + /*! \brief The path to the built artifact. */ + Optional artifact_path; + /*! \brief The error message if any. */ + Optional error_msg; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("artifact_path", &artifact_path); + v->Visit("error_msg", &error_msg); + } + + static constexpr const char* _type_key = "meta_schedule.BuilderResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(BuilderResultNode, runtime::Object); +}; + +/*! + * \brief Managed reference to BuilderResultNode + * \sa BuilderResultNode + */ +class BuilderResult : public runtime::ObjectRef { + public: + /*! + * \brief Constructor of BuilderResult. + * \param artifact_path The path to the built artifact. + * \param error_msg The error message if any. + */ + TVM_DLL explicit BuilderResult(Optional artifact_path, Optional error_msg); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderResult, runtime::ObjectRef, BuilderResultNode); +}; + +/*! \brief The abstract builder interface. */ +class BuilderNode : public runtime::Object { + public: + /*! \brief Default destructor */ + virtual ~BuilderNode() = default; + /*! + * \brief Generate the build results from build inputs. + * \param build_inputs The inputs to be built. + * \return The build results. + */ + virtual Array Build(const Array& build_inputs) = 0; + /*! + * \brief The function type of `Build` method. + * \param build_inputs The inputs to be built. + * \return The build results. + */ + using FBuild = runtime::TypedPackedFunc(const Array&)>; + + static constexpr const char* _type_key = "meta_schedule.Builder"; + TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, runtime::Object); +}; + +/*! + * \brief Managed reference to BuilderNode + * \sa BuilderNode + */ +class Builder : public runtime::ObjectRef { + public: + /*! + * \brief Create a builder with customized build method on the python-side. + * \param f_build The packed function to the `Build` function.. + * \return The Builder created. + */ + static Builder PyBuilder(BuilderNode::FBuild f_build); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Builder, runtime::ObjectRef, BuilderNode); +}; + +/*! \brief An abstract builder with customized build method on the python-side. */ +class PyBuilderNode : public BuilderNode { + public: + /*! \brief The packed function to the `Build` function. */ + FBuild f_build; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_build` is not visited + } + + Array Build(const Array& build_inputs) final { + return f_build(build_inputs); + } + + static constexpr const char* _type_key = "meta_schedule.PyBuilder"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyBuilderNode, BuilderNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_BUILDER_H_ diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py new file mode 100644 index 000000000000..b12194e7e009 --- /dev/null +++ b/python/tvm/meta_schedule/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Package `tvm.meta_schedule`. The meta schedule infrastructure.""" +from . import builder diff --git a/python/tvm/meta_schedule/_ffi_api.py b/python/tvm/meta_schedule/_ffi_api.py new file mode 100644 index 000000000000..24022191a8b4 --- /dev/null +++ b/python/tvm/meta_schedule/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.meta_schedule""" +from .._ffi import _init_api + +_init_api("meta_schedule", __name__) # pylint: disable=protected-access diff --git a/python/tvm/meta_schedule/builder/__init__.py b/python/tvm/meta_schedule/builder/__init__.py new file mode 100644 index 000000000000..859c74d75622 --- /dev/null +++ b/python/tvm/meta_schedule/builder/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.builder package. +Meta Schedule builders that translate IRModule to runtime.Module, +and then export +""" +from .builder import Builder, BuilderInput, BuilderResult, PyBuilder +from .local_builder import LocalBuilder diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py new file mode 100644 index 000000000000..ed81f4c0d3f9 --- /dev/null +++ b/python/tvm/meta_schedule/builder/builder.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule builders that translate IRModule to runtime.Module, and then export""" +from typing import List, Optional + +from tvm._ffi import register_object +from tvm.ir import IRModule +from tvm.runtime import Object +from tvm.target import Target + +from .. import _ffi_api + + +@register_object("meta_schedule.BuilderInput") +class BuilderInput(Object): + """The builder's input. + + Parameters + ---------- + mod : IRModule + The IRModule to be built. + target : Target + The target to be built for. + """ + + mod: IRModule + target: Target + + def __init__(self, mod: IRModule, target: Target) -> None: + """Constructor. + + Parameters + ---------- + mod : IRModule + The IRModule to be built. + target : Target + The target to be built for. + """ + self.__init_handle_by_constructor__( + _ffi_api.BuilderInput, # type: ignore # pylint: disable=no-member + mod, + target, + ) + + +@register_object("meta_schedule.BuilderResult") +class BuilderResult(Object): + """The builder's result. + + Parameters + ---------- + artifact_path : Optional[str] + The path to the artifact. + error_msg : Optional[str] + The error message. + """ + + artifact_path: Optional[str] + error_msg: Optional[str] + + def __init__( + self, + artifact_path: Optional[str], + error_msg: Optional[str], + ) -> None: + """Constructor. + + Parameters + ---------- + artifact_path : Optional[str] + The path to the artifact. + error_msg : Optional[str] + The error message. + """ + self.__init_handle_by_constructor__( + _ffi_api.BuilderResult, # type: ignore # pylint: disable=no-member + artifact_path, + error_msg, + ) + + +@register_object("meta_schedule.Builder") +class Builder(Object): + """The abstract builder interface.""" + + def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: + """Build the given inputs. + + Parameters + ---------- + build_inputs : List[BuilderInput] + The inputs to be built. + Returns + ------- + build_results : List[BuilderResult] + The results of building the given inputs. + """ + return _ffi_api.BuilderBuild(self, build_inputs) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PyBuilder") +class PyBuilder(Builder): + """An abstract builder with customized build method on the python-side.""" + + def __init__(self): + """Constructor.""" + + def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]: + return self.build(build_inputs) + + self.__init_handle_by_constructor__( + _ffi_api.BuilderPyBuilder, # type: ignore # pylint: disable=no-member + f_build, + ) + + def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: + raise NotImplementedError diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py new file mode 100644 index 000000000000..cefe5ec50cad --- /dev/null +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -0,0 +1,229 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Local builder that compile on the local host""" +import os +import tempfile +from typing import Callable, List, Optional, Union + +from tvm._ffi import register_func +from tvm.ir import IRModule +from tvm.runtime import Module +from tvm.target import Target + +from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind +from ..utils import cpu_count, get_global_func_with_default_on_worker +from .builder import BuilderInput, BuilderResult, PyBuilder + + +class LocalBuilder(PyBuilder): + """A builder that builds the given input on local host. + + Parameters + ---------- + pool : PopenPoolExecutor + The process pool to run the build. + timeout_sec : float + The timeout in seconds for the build. + f_build : Union[None, str, LocalBuilder.T_BUILD] + Name of the build function to be used. + Defaults to `meta_schedule.builder.default_build`. + f_export : Union[None, str, LocalBuilder.T_EXPORT] + Name of the export function to be used. + Defaults to `meta_schedule.builder.default_export`. + + Attributes + ---------- + T_BUILD : typing._GenericAlias + The signature of the build function `f_build`, which is + `Callable[[IRModule, Target], Module]` + T_EXPORT : typing._GenericAlias + The signature of the build function `f_export`, which is + `Callable[[Module], str]` + + Note + ---- + The build function and export function should be registered in the worker process. + The worker process is only aware of functions registered in TVM package, + if there are extra functions to be registered, + please send the registration logic via initializer. + """ + + T_BUILD = Callable[[IRModule, Target], Module] + T_EXPORT = Callable[[Module], str] + + pool: PopenPoolExecutor + timeout_sec: float + f_build: Union[None, str, T_BUILD] + f_export: Union[None, str, T_EXPORT] + + def __init__( + self, + *, + max_workers: Optional[int] = None, + timeout_sec: float = 30.0, + f_build: Union[None, str, T_BUILD] = None, + f_export: Union[None, str, T_EXPORT] = None, + initializer: Optional[Callable[[], None]] = None, + ) -> None: + """Constructor. + + Parameters + ---------- + max_workers : Optional[int] + The maximum number of worker processes to be used. + Defaults to number of CPUs. + timeout_sec : float + The timeout in seconds for the build. + f_build : LocalBuilder.T_BUILD + Name of the build function to be used. + Defaults to `meta_schedule.builder.default_build`. + f_export : LocalBuilder.T_EXPORT + Name of the export function to be used. + Defaults to `meta_schedule.builder.default_export`. + initializer : Optional[Callable[[], None]] + The initializer to be used for the worker processes. + """ + super().__init__() + + if max_workers is None: + max_workers = cpu_count() + + self.pool = PopenPoolExecutor( + max_workers=max_workers, + timeout=timeout_sec, + initializer=initializer, + ) + self.timeout_sec = timeout_sec + self.f_build = f_build + self.f_export = f_export + self._sanity_check() + + def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: + results: List[BuilderResult] = [] + map_result: MapResult + + # Dispatch the build inputs to the worker processes. + for map_result in self.pool.map_with_error_catching( + lambda x: LocalBuilder._worker_func(*x), + [ + ( + self.f_build, + self.f_export, + build_input.mod, + build_input.target, + ) + for build_input in build_inputs + ], + ): + if map_result.status == StatusKind.COMPLETE: + results.append(BuilderResult(map_result.value, None)) + elif map_result.status == StatusKind.TIMEOUT: + results.append( + BuilderResult( + None, + f"LocalBuilder: Timeout, killed after {self.timeout_sec} seconds", + ) + ) + elif map_result.status == StatusKind.EXCEPTION: + results.append( + BuilderResult( + None, + "LocalBuilder: An exception occurred\n" + str(map_result.value), + ) + ) + else: + raise ValueError("Unreachable: unexpected result: {map_result}") + return results + + def _sanity_check(self) -> None: + def _check(f_build, f_export) -> None: + get_global_func_with_default_on_worker(name=f_build, default=None) + get_global_func_with_default_on_worker(name=f_export, default=None) + + value = self.pool.submit(_check, self.f_build, self.f_export) + value.result() + + @staticmethod + def _worker_func( + _f_build: Union[None, str, T_BUILD], + _f_export: Union[None, str, T_EXPORT], + mod: IRModule, + target: Target, + ) -> str: + # Step 0. Get the registered functions + f_build: LocalBuilder.T_BUILD = get_global_func_with_default_on_worker( + _f_build, + default_build, + ) + f_export: LocalBuilder.T_EXPORT = get_global_func_with_default_on_worker( + _f_export, + default_export, + ) + # Step 1. Build the IRModule + rt_mod: Module = f_build(mod, target) + # Step 2. Export the Module + artifact_path: str = f_export(rt_mod) + return artifact_path + + +@register_func("meta_schedule.builder.default_build") +def default_build(mod: IRModule, target: Target) -> Module: + """Default build function. + + Parameters + ---------- + mod : IRModule + The IRModule to be built. + target : Target + The target to be built. + + Returns + ------- + rt_mod : Module + The built Module. + """ + # pylint: disable=import-outside-toplevel + from tvm.autotvm.measure.measure_methods import set_cuda_target_arch + from tvm.driver import build as tvm_build + + # pylint: enable=import-outside-toplevel + + if target.kind.name == "cuda": + set_cuda_target_arch(target.attrs["arch"]) + + return tvm_build(mod, target=target) + + +@register_func("meta_schedule.builder.default_export") +def default_export(mod: Module) -> str: + """Default export function. + + Parameters + ---------- + mod : Module + The Module to be exported. + + Returns + ------- + artifact_path : str + The path to the exported Module. + """ + from tvm.contrib.tar import tar # pylint: disable=import-outside-toplevel + + artifact_path = os.path.join(tempfile.mkdtemp(), "tvm_tmp_mod." + tar.output_format) + mod.export_library(artifact_path, tar) + return artifact_path diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py new file mode 100644 index 000000000000..74f93e86f506 --- /dev/null +++ b/python/tvm/meta_schedule/utils.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utilities for meta schedule""" +import os +import shutil +from typing import Callable, Union + +import psutil + +from tvm._ffi import get_global_func, register_func +from tvm.error import TVMError + + +@register_func("meta_schedule.cpu_count") +def cpu_count(logical: bool = True) -> int: + """Return the number of logical or physical CPUs in the system + + Parameters + ---------- + logical : bool = True + If True, return the number of logical CPUs, otherwise return the number of physical CPUs + + Returns + ------- + cpu_count : int + The number of logical or physical CPUs in the system + + Note + ---- + The meta schedule search infra intentionally does not adopt the following convention in TVM: + - C++ API `tvm::runtime::threading::MaxConcurrency()` + - Environment variable `TVM_NUM_THREADS` or + - Environment variable `OMP_NUM_THREADS` + + This is because these variables are dedicated to controlling + the runtime behavior of generated kernels, instead of the host-side search. + Setting these variables may interfere the host-side search with profiling of generated kernels + when measuring locally. + """ + return psutil.cpu_count(logical=logical) or 1 + + +def get_global_func_with_default_on_worker( + name: Union[None, str, Callable], + default: Callable, +) -> Callable: + """Get the registered global function on the worker process. + + Parameters + ---------- + name : Union[None, str, Callable] + If given a string, retrieve the function in TVM's global registry; + If given a python function, return it as it is; + Otherwise, return `default`. + + default : Callable + The function to be returned if `name` is None. + + Returns + ------- + result : Callable + The retrieved global function or `default` if `name` is None + """ + if name is None: + return default + if callable(name): + return name + try: + return get_global_func(name) + except TVMError as error: + raise ValueError( + "Function '{name}' is not registered on the worker process. " + "The build function and export function should be registered in the worker process. " + "Note that the worker process is only aware of functions registered in TVM package, " + "if there are extra functions to be registered, " + "please send the registration logic via initializer." + ) from error + + +@register_func("meta_schedule.remove_build_dir") +def remove_build_dir(artifact_path: str) -> None: + """Clean up the build directory""" + shutil.rmtree(os.path.dirname(artifact_path)) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 9a7857a01fe6..688422284c0f 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1204,7 +1204,7 @@ def FakeQuantizationToInteger(): x w | | dq dq - \ / + \\ / op1 | op2 diff --git a/src/meta_schedule/builder/builder.cc b/src/meta_schedule/builder/builder.cc new file mode 100644 index 000000000000..fb63b7e65332 --- /dev/null +++ b/src/meta_schedule/builder/builder.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/******** Constructors ********/ + +BuilderInput::BuilderInput(IRModule mod, Target target) { + ObjectPtr n = make_object(); + n->mod = std::move(mod); + n->target = std::move(target); + data_ = std::move(n); +} + +BuilderResult::BuilderResult(Optional artifact_path, Optional error_msg) { + ObjectPtr n = make_object(); + n->artifact_path = std::move(artifact_path); + n->error_msg = std::move(error_msg); + data_ = std::move(n); +} + +Builder Builder::PyBuilder(BuilderNode::FBuild f_build) { + ObjectPtr n = make_object(); + n->f_build = std::move(f_build); + return Builder(std::move(n)); +} + +/******** FFI ********/ + +TVM_REGISTER_NODE_TYPE(BuilderInputNode); +TVM_REGISTER_NODE_TYPE(BuilderResultNode); +TVM_REGISTER_OBJECT_TYPE(BuilderNode); +TVM_REGISTER_NODE_TYPE(PyBuilderNode); + +TVM_REGISTER_GLOBAL("meta_schedule.BuilderInput") + .set_body_typed([](IRModule mod, Target target) -> BuilderInput { + return BuilderInput(mod, target); + }); + +TVM_REGISTER_GLOBAL("meta_schedule.BuilderResult") + .set_body_typed([](Optional artifact_path, + Optional error_msg) -> BuilderResult { + return BuilderResult(artifact_path, error_msg); + }); + +TVM_REGISTER_GLOBAL("meta_schedule.BuilderBuild").set_body_method(&BuilderNode::Build); + +TVM_REGISTER_GLOBAL("meta_schedule.BuilderPyBuilder").set_body_typed(Builder::PyBuilder); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h new file mode 100644 index 000000000000..47331203a25a --- /dev/null +++ b/src/meta_schedule/utils.h @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_UTILS_H_ +#define TVM_META_SCHEDULE_UTILS_H_ + +#include + +namespace tvm { +namespace meta_schedule {} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_UTILS_H_ diff --git a/tests/python/unittest/test_meta_schedule_builder.py b/tests/python/unittest/test_meta_schedule_builder.py new file mode 100644 index 000000000000..f97ede881330 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_builder.py @@ -0,0 +1,219 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule Builder """ + +import os +import sys +import time +from typing import List + +import pytest + +from tvm import tir, script +from tvm._ffi import register_func +from tvm.meta_schedule.builder import ( + BuilderInput, + BuilderResult, + LocalBuilder, + PyBuilder, +) +from tvm.runtime import Module +from tvm.script import ty +from tvm.target import Target + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring + + +@script.tir +class MatmulModule: + def matmul( # pylint: disable=no-self-argument + a: ty.handle, b: ty.handle, c: ty.handle + ) -> None: + tir.func_attr({"global_symbol": "matmul", "tir.noalias": True}) + A = tir.match_buffer(a, (1024, 1024), "float32") + B = tir.match_buffer(b, (1024, 1024), "float32") + C = tir.match_buffer(c, (1024, 1024), "float32") + with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@script.tir +class MatmulReluModule: + def matmul_relu( # pylint: disable=no-self-argument + a: ty.handle, b: ty.handle, d: ty.handle + ) -> None: + tir.func_attr({"global_symbol": "matmul_relu", "tir.noalias": True}) + A = tir.match_buffer(a, (1024, 1024), "float32") + B = tir.match_buffer(b, (1024, 1024), "float32") + D = tir.match_buffer(d, (1024, 1024), "float32") + C = tir.alloc_buffer((1024, 1024), "float32") + with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + with tir.block([1024, 1024], "relu") as [vi, vj]: + D[vi, vj] = tir.max(C[vi, vj], 0.0) + + +@script.tir +class BatchMatmulModule: + def batch_matmul( # pylint: disable=no-self-argument + a: ty.handle, b: ty.handle, c: ty.handle + ) -> None: + tir.func_attr({"global_symbol": "batch_matmul", "tir.noalias": True}) + A = tir.match_buffer(a, [16, 128, 128]) + B = tir.match_buffer(b, [16, 128, 128]) + C = tir.match_buffer(c, [16, 128, 128]) + with tir.block([16, 128, 128, tir.reduce_axis(0, 128)], "update") as [vn, vi, vj, vk]: + with tir.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring + + +def _check_build_results(builder_results: List[BuilderResult]): + """Simple check whether the build is successful""" + for result in builder_results: + artifact_path = result.artifact_path + error_msg = result.error_msg + assert artifact_path is not None + assert error_msg is None + os.remove(artifact_path) + os.rmdir(os.path.dirname(artifact_path)) + + +def test_meta_schedule_single_build(): + """Test meta schedule builder for a single build""" + mod = MatmulModule() + builder = LocalBuilder() + builder_inputs = [BuilderInput(mod, Target("llvm"))] + builder_results = builder.build(builder_inputs) + assert len(builder_results) == len(builder_inputs) + _check_build_results(builder_results) + + +def test_meta_schedule_multiple_build(): + """Test meta schedule builder for multiple builds""" + builder = LocalBuilder() + builder_inputs = [ + BuilderInput(MatmulModule(), Target("llvm")), + BuilderInput(MatmulReluModule(), Target("llvm")), + BuilderInput(BatchMatmulModule(), Target("llvm")), + ] + builder_results = builder.build(builder_inputs) + assert len(builder_results) == len(builder_inputs) + _check_build_results(builder_results) + + +def test_meta_schedule_error_handle_test_builder(): + """Test the error handing during building""" + + class TestBuilder(PyBuilder): + def build( # pylint: disable=no-self-use + self, + build_inputs: List[BuilderInput], + ) -> List[BuilderResult]: + return [BuilderResult(None, "error") for w in build_inputs] + + builder = TestBuilder() + builder_inputs = [ + BuilderInput(MatmulModule(), Target("llvm")), + BuilderInput(MatmulReluModule(), Target("llvm")), + BuilderInput(BatchMatmulModule(), Target("llvm")), + ] + builder_results = builder.build(builder_inputs) + assert len(builder_results) == len(builder_inputs) + for result in builder_results: + artifact_path = result.artifact_path + error_msg = result.error_msg + assert artifact_path is None + assert error_msg == "error" + + +def test_meta_schedule_error_handle_build_func(): + """Test the error handing during building""" + + def initializer(): + @register_func("meta_schedule.builder.test_build") + def test_build(mod: Module, target: Target) -> None: # pylint: disable=unused-variable + raise ValueError("Builder intended Test Error (build func).") + + builder = LocalBuilder(f_build="meta_schedule.builder.test_build", initializer=initializer) + builder_inputs = [BuilderInput(MatmulModule(), Target("llvm"))] + builder_results = builder.build(builder_inputs) + assert len(builder_results) == len(builder_inputs) + for result in builder_results: + artifact_path = result.artifact_path + error_msg = result.error_msg + assert artifact_path is None + assert error_msg.startswith("LocalBuilder: An exception occurred") + + +def test_meta_schedule_error_handle_export_func(): + """Test the error handing during building""" + + def initializer(): + @register_func("meta_schedule.builder.test_export") + def test_build(mod: Module) -> str: # pylint: disable=unused-variable + raise ValueError("Builder intended Test Error (export func).") + + builder = LocalBuilder(f_export="meta_schedule.builder.test_export", initializer=initializer) + builder_inputs = [BuilderInput(MatmulModule(), Target("llvm"))] + builder_results = builder.build(builder_inputs) + assert len(builder_results) == len(builder_inputs) + for result in builder_results: + artifact_path = result.artifact_path + error_msg = result.error_msg + assert artifact_path is None + assert error_msg.startswith("LocalBuilder: An exception occurred") + + +def test_meta_schedule_error_handle_time_out(): + """Test the error handing time out during building""" + + def initializer(): + @register_func("meta_schedule.builder.test_time_out") + def timeout_build(mod, target): # pylint: disable=unused-argument, unused-variable + time.sleep(2) + + builder = LocalBuilder( + timeout_sec=1, + f_build="meta_schedule.builder.test_time_out", + initializer=initializer, + ) + builder_inputs = [BuilderInput(MatmulModule(), Target("llvm"))] + builder_results = builder.build(builder_inputs) + assert len(builder_results) == len(builder_inputs) + for result in builder_results: + artifact_path = result.artifact_path + error_msg = result.error_msg + assert artifact_path is None + assert error_msg.startswith("LocalBuilder: Timeout") + + +def test_meta_schedule_missing_build_func(): + with pytest.raises(ValueError): + LocalBuilder(f_build="wrong-name") + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index 8507f311e9da..05d1c238b64f 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -17,13 +17,16 @@ # under the License. set -o pipefail -echo "Checking MyPy Type defs in the schedule package." +echo "Checking MyPy Type defs in the TensorIR schedule package." mypy --check-untyped-defs python/tvm/tir/schedule +echo "Checking MyPy Type defs in the meta schedule package." +mypy --check-untyped-defs python/tvm/meta_schedule + echo "Checking MyPy Type defs in the analysis package." mypy --check-untyped-defs python/tvm/tir/analysis/ -echo "Checking MyPy Type defs in the transofrm package." +echo "Checking MyPy Type defs in the transform package." mypy --check-untyped-defs python/tvm/tir/transform/ echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package."