diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h new file mode 100644 index 000000000000..b05dc3c11802 --- /dev/null +++ b/include/tvm/meta_schedule/cost_model.h @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_COST_MODEL_H_ +#define TVM_META_SCHEDULE_COST_MODEL_H_ + +#include + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Cost model. */ +class CostModelNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~CostModelNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Load the cost model from given file location. + * \param path The file path. + */ + virtual void Load(const String& path) = 0; + + /*! + * \brief Save the cost model to given file location. + * \param path The file path. + */ + virtual void Save(const String& path) = 0; + + /*! + * \brief Update the cost model given running results. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \param results The running results of the measure candidates. + */ + virtual void Update(const TuneContext& tune_context, const Array& candidates, + const Array& results) = 0; + + /*! + * \brief Predict the normalized score (the larger the better) of given measure candidates. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \return The predicted normalized score. + */ + virtual std::vector Predict(const TuneContext& tune_context, + const Array& candidates) = 0; + + static constexpr const char* _type_key = "meta_schedule.CostModel"; + TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); +}; + +/*! \brief The cost model with customized methods on the python-side. */ +class PyCostModelNode : public CostModelNode { + public: + /*! + * \brief Load the cost model from given file location. + * \param path The file path. + */ + using FLoad = runtime::TypedPackedFunc; + /*! + * \brief Save the cost model to given file location. + * \param path The file path. + */ + using FSave = runtime::TypedPackedFunc; + /*! + * \brief Update the cost model given running results. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \param results The running results of the measure candidates. + * \return Whether cost model was updated successfully. + */ + using FUpdate = runtime::TypedPackedFunc&, + const Array&)>; + /*! + * \brief Predict the running results of given measure candidates. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \param p_addr The address to save the the estimated running results. + */ + using FPredict = runtime::TypedPackedFunc&, + void* p_addr)>; + /*! + * \brief Get the cost model as string with name. + * \return The string representation of the cost model. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `Load` function. */ + FLoad f_load; + /*! \brief The packed function to the `Save` function. */ + FSave f_save; + /*! \brief The packed function to the `Update` function. */ + FUpdate f_update; + /*! \brief The packed function to the `Predict` function. */ + FPredict f_predict; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_load` is not visited + // `f_save` is not visited + // `f_update` is not visited + // `f_predict` is not visited + // `f_as_string` is not visited + } + + void Load(const String& path) { + ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!"; + f_load(path); + } + + void Save(const String& path) { + ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!"; + f_save(path); + } + void Update(const TuneContext& tune_context, const Array& candidates, + const Array& results) { + ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!"; + f_update(tune_context, candidates, results); + } + + std::vector Predict(const TuneContext& tune_context, + const Array& candidates) { + ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!"; + std::vector result(candidates.size(), 0.0); + f_predict(tune_context, candidates, result.data()); + return result; + } + + static constexpr const char* _type_key = "meta_schedule.PyCostModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyCostModelNode, CostModelNode); +}; + +/*! + * \brief Managed reference to CostModelNode + * \sa CostModelNode + */ +class CostModel : public runtime::ObjectRef { + public: + /*! + * \brief Create a feature extractor with customized methods on the python-side. + * \param f_load The packed function of `Load`. + * \param f_save The packed function of `Save`. + * \param f_update The packed function of `Update`. + * \param f_predict The packed function of `Predict`. + * \param f_as_string The packed function of `AsString`. + * \return The feature extractor created. + */ + TVM_DLL static CostModel PyCostModel(PyCostModelNode::FLoad f_load, // + PyCostModelNode::FSave f_save, // + PyCostModelNode::FUpdate f_update, // + PyCostModelNode::FPredict f_predict, // + PyCostModelNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_COST_MODEL_H_ diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h new file mode 100644 index 000000000000..ee5d94c13c98 --- /dev/null +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ +#define TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Extractor for features from measure candidates for use in cost model. */ +class FeatureExtractorNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~FeatureExtractorNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Extract features from the given measure candidate. + * \param tune_context The tuning context for feature extraction. + * \param candidates The measure candidates to extract features from. + * \return The feature ndarray extracted. + */ + virtual Array ExtractFrom(const TuneContext& tune_context, + const Array& candidates) = 0; + + static constexpr const char* _type_key = "meta_schedule.FeatureExtractor"; + TVM_DECLARE_BASE_OBJECT_INFO(FeatureExtractorNode, Object); +}; + +/*! \brief The feature extractor with customized methods on the python-side. */ +class PyFeatureExtractorNode : public FeatureExtractorNode { + public: + /*! + * \brief Extract features from the given measure candidate. + * \param tune_context The tuning context for feature extraction. + * \param candidates The measure candidates to extract features from. + * \return The feature ndarray extracted. + */ + using FExtractFrom = runtime::TypedPackedFunc( + const TuneContext& tune_context, const Array& candidates)>; + /*! + * \brief Get the feature extractor as string with name. + * \return The string of the feature extractor. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `ExtractFrom` function. */ + FExtractFrom f_extract_from; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_extract_from` is not visited + // `f_as_string` is not visited + } + + Array ExtractFrom(const TuneContext& tune_context, + const Array& candidates) { + ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!"; + return f_extract_from(tune_context, candidates); + } + + static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyFeatureExtractorNode, FeatureExtractorNode); +}; + +/*! + * \brief Managed reference to FeatureExtractorNode + * \sa FeatureExtractorNode + */ +class FeatureExtractor : public runtime::ObjectRef { + public: + /*! + * \brief Create a feature extractor that extracts features from each BufferStore + * \param buffers_per_store The number of buffers in each BufferStore; Pad or truncate if + * necessary. + * \param arith_intensity_curve_num_samples The number of samples used in the arithmetic intensity + * curve. + * \param cache_line_bytes The number of bytes in a cache line. + * \return The feature extractor created. + */ + TVM_DLL static FeatureExtractor PerStoreFeature(int buffers_per_store = 5, + int arith_intensity_curve_num_samples = 10, + int cache_line_bytes = 64); + /*! + * \brief Create a feature extractor with customized methods on the python-side. + * \param f_extract_from The packed function of `ExtractFrom`. + * \param f_as_string The packed function of `AsString`. + * \return The feature extractor created. + */ + TVM_DLL static FeatureExtractor PyFeatureExtractor( + PyFeatureExtractorNode::FExtractFrom f_extract_from, + PyFeatureExtractorNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FeatureExtractor, ObjectRef, FeatureExtractorNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ diff --git a/python/tvm/meta_schedule/cost_model/__init__.py b/python/tvm/meta_schedule/cost_model/__init__.py new file mode 100644 index 000000000000..3d4a81e1222f --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.cost_model package. +""" +from .cost_model import CostModel, PyCostModel +from .random_model import RandomModel diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py new file mode 100644 index 000000000000..f5bd60162ec5 --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule CostModel.""" +import ctypes +from typing import List + +import numpy as np # type: ignore +from tvm._ffi import register_object +from tvm.runtime import Object + +from .. import _ffi_api +from ..runner import RunnerResult +from ..search_strategy import MeasureCandidate +from ..tune_context import TuneContext +from ..utils import _get_hex_address, check_override + + +@register_object("meta_schedule.CostModel") +class CostModel(Object): + """Cost model.""" + + def load(self, path: str) -> None: + """Load the cost model from given file location. + + Parameters + ---------- + path : str + The file path. + """ + _ffi_api.CostModelLoad(self, path) # type: ignore # pylint: disable=no-member + + def save(self, path: str) -> None: + """Save the cost model to given file location. + + Parameters + ---------- + path : str + The file path. + """ + _ffi_api.CostModelSave(self, path) # type: ignore # pylint: disable=no-member + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + results : List[RunnerResult] + The running results of the measure candidates. + """ + _ffi_api.CostModelUpdate(self, tune_context, candidates, results) # type: ignore # pylint: disable=no-member + + def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + + Return + ------ + result : np.ndarray + The predicted normalized score. + """ + n = len(candidates) + results = np.zeros(shape=(n,), dtype="float64") + _ffi_api.CostModelPredict( # type: ignore # pylint: disable=no-member + self, + tune_context, + candidates, + results.ctypes.data_as(ctypes.c_void_p), + ) + return results + + +@register_object("meta_schedule.PyCostModel") +class PyCostModel(CostModel): + """An abstract CostModel with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, CostModel) + def f_load(path: str) -> None: + self.load(path) + + @check_override(self.__class__, CostModel) + def f_save(path: str) -> None: + self.save(path) + + @check_override(self.__class__, CostModel) + def f_update( + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + self.update(tune_context, candidates, results) + + @check_override(self.__class__, CostModel) + def f_predict( + tune_context: TuneContext, candidates: List[MeasureCandidate], return_ptr + ) -> None: + n = len(candidates) + return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_double)) + array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) + array_wrapper[:] = self.predict(tune_context, candidates) + assert ( + array_wrapper.dtype == "float64" + ), "ValueError: Invalid data type returned from CostModel Predict!" + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.CostModelPyCostModel, # type: ignore # pylint: disable=no-member + f_load, + f_save, + f_update, + f_predict, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/cost_model/metric.py b/python/tvm/meta_schedule/cost_model/metric.py new file mode 100644 index 000000000000..efd8dc68ac0d --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/metric.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Cost model metrics for meta schedule""" +import numpy as np # type: ignore + + +def max_curve(trial_scores: np.ndarray) -> np.ndarray: + """f(n) = max([s[i] fo i < n]) + + Parameters + ---------- + trial_scores : List[float] + the score of i-th trial + + Returns + ------- + curve : np.ndarray + A vector, the max-curve function values + """ + ret = np.empty(len(trial_scores)) + keep = -1e9 + for i, score in enumerate(trial_scores): + keep = max(keep, score) + ret[i] = keep + return ret diff --git a/python/tvm/meta_schedule/cost_model/random_model.py b/python/tvm/meta_schedule/cost_model/random_model.py new file mode 100644 index 000000000000..23238d25797c --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/random_model.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Random cost model +""" +from typing import List, Optional, Tuple, Union + +import numpy as np # type: ignore + +from ..cost_model import PyCostModel +from ..runner import RunnerResult +from ..search_strategy import MeasureCandidate +from ..tune_context import TuneContext + + +class RandomModel(PyCostModel): + """Random cost model + + Parameters + ---------- + random_state : Union[Tuple[str, np.ndarray, int, int, float], dict] + The random state of the random number generator. + path : Optional[str] + The path of the random cost model. + max_range : Optional[int] + The maximum range of random results, [0, max_range]. + + Reference + --------- + https://numpy.org/doc/stable/reference/random/generated/numpy.random.get_state.html + """ + + random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] + path: Optional[str] + + def __init__( + self, + *, + seed: Optional[int] = None, + path: Optional[str] = None, + max_range: Optional[int] = 100, + ): + super().__init__() + if path is not None: + self.load(path) + else: + np.random.seed(seed) + self.random_state = np.random.get_state() + self.max_range = max_range + + def load(self, path: str) -> None: + """Load the cost model from given file location. + + Parameters + ---------- + path : str + The file path. + """ + self.random_state = tuple(np.load(path, allow_pickle=True)) # type: ignore + + def save(self, path: str) -> None: + """Save the cost model to given file location. + + Parameters + ---------- + path : str + The file path. + """ + np.save(path, np.array(self.random_state, dtype=object), allow_pickle=True) + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + results : List[RunnerResult] + The running results of the measure candidates. + """ + + def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + + Return + ------ + result : np.ndarray + The predicted running results. + """ + np.random.set_state(self.random_state) + # TODO(@zxybazh): Use numpy's RandState object: + # https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html#numpy.random.RandomState + result = np.random.rand(len(candidates)) * self.max_range + self.random_state = np.random.get_state() + return result diff --git a/python/tvm/meta_schedule/feature_extractor/__init__.py b/python/tvm/meta_schedule/feature_extractor/__init__.py new file mode 100644 index 000000000000..f29c44bd1efd --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.feature_extractor package. +Meta Schedule feature extractors that extracts features from +measure candidates for use in cost model. +""" +from .feature_extractor import FeatureExtractor, PyFeatureExtractor +from .random_feature_extractor import RandomFeatureExtractor diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py new file mode 100644 index 000000000000..bd7656e5bef1 --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule FeatureExtractor.""" +from typing import List + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.runtime.ndarray import NDArray + +from .. import _ffi_api +from ..utils import _get_hex_address, check_override +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate + + +@register_object("meta_schedule.FeatureExtractor") +class FeatureExtractor(Object): + """Extractor for features from measure candidates for use in cost model.""" + + def extract_from( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + """Extract features from the given measure candidate. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for feature extraction. + candidates : List[MeasureCandidate] + The measure candidates to extract features from. + + Returns + ------- + features : List[NDArray] + The feature numpy ndarray extracted. + """ + result = _ffi_api.FeatureExtractorExtractFrom( # type: ignore # pylint: disable=no-member + self, tune_context, candidates + ) + return result + + +@register_object("meta_schedule.PyFeatureExtractor") +class PyFeatureExtractor(FeatureExtractor): + """An abstract feature extractor with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, FeatureExtractor) + def f_extract_from( + tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + features = self.extract_from(tune_context, candidates) + return features + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.FeatureExtractorPyFeatureExtractor, # type: ignore # pylint: disable=no-member + f_extract_from, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py new file mode 100644 index 000000000000..7c72a25b2378 --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Random Feature Extractor.""" +from typing import List, Union, Tuple + +import numpy as np # type: ignore +from tvm.runtime.ndarray import NDArray, array + +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..feature_extractor import PyFeatureExtractor + + +class RandomFeatureExtractor(PyFeatureExtractor): + """Random Feature Extractor + + Parameters + ---------- + feature_size : int + The size of each block's feature vector. + max_block_num : int + The maximum number of blocks in each schedule. + random_state : Union[Tuple[str, np.ndarray, int, int, float], dict] + The current random state of the f + """ + + feature_size: int + max_block_num: int + random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] + + def __init__(self, *, feature_size: int = 30, max_block_num: int = 5, seed=0): + super().__init__() + assert max_block_num >= 1, "Max block number must be greater or equal to one!" + self.max_block_num = max_block_num + self.feature_size = feature_size + np.random.seed(seed) + self.random_state = np.random.get_state() + + def extract_from( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + np.random.set_state(self.random_state) + result = [ + np.random.rand(np.random.randint(1, self.max_block_num + 1), self.feature_size) + for candidate in candidates + ] + self.random_state = np.random.get_state() + return [array(x) for x in result] diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py index 609baa267786..298cdae4283a 100644 --- a/python/tvm/meta_schedule/search_strategy/__init__.py +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -19,6 +19,5 @@ Meta Schedule search strategy utilizes the design spaces given to generate measure candidates. """ - -from .search_strategy import SearchStrategy, PySearchStrategy +from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy from .replay_trace import ReplayTrace diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index a9ef514543f8..aaaa956140ab 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Utilities for meta schedule""" +import ctypes import json import os import shutil @@ -24,7 +25,7 @@ import tvm from tvm._ffi import get_global_func, register_func from tvm.error import TVMError -from tvm.ir import Array, Map, IRModule +from tvm.ir import Array, IRModule, Map from tvm.rpc import RPCSession from tvm.runtime import PackedFunc, String from tvm.tir import FloatImm, IntImm @@ -245,3 +246,17 @@ def inner(func: Callable): return func return inner + + +def _get_hex_address(handle: ctypes.c_void_p) -> str: + """Get the hexadecimal address of a handle. + Parameters + ---------- + handle : ctypes.c_void_p + The handle to be converted. + Returns + ------- + result : str + The hexadecimal address of the handle. + """ + return hex(ctypes.cast(handle, ctypes.c_void_p).value) diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc new file mode 100644 index 000000000000..5cd32b097caa --- /dev/null +++ b/src/meta_schedule/cost_model/cost_model.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +CostModel CostModel::PyCostModel(PyCostModelNode::FLoad f_load, // + PyCostModelNode::FSave f_save, // + PyCostModelNode::FUpdate f_update, // + PyCostModelNode::FPredict f_predict, // + PyCostModelNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_load = std::move(f_load); + n->f_save = std::move(f_save); + n->f_update = std::move(f_update); + n->f_predict = std::move(f_predict); + n->f_as_string = std::move(f_as_string); + return CostModel(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyCostModelNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyCostModel's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(CostModelNode); +TVM_REGISTER_NODE_TYPE(PyCostModelNode); + +TVM_REGISTER_GLOBAL("meta_schedule.CostModelLoad").set_body_method(&CostModelNode::Load); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelSave").set_body_method(&CostModelNode::Save); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate") + .set_body_method(&CostModelNode::Update); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict") + .set_body_typed([](CostModel model, // + const TuneContext& tune_context, // + Array candidates, // + void* p_addr) -> void { + std::vector result = model->Predict(tune_context, candidates); + std::copy(result.begin(), result.end(), static_cast(p_addr)); + }); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel").set_body_typed(CostModel::PyCostModel); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/feature_extractor/feature_extractor.cc b/src/meta_schedule/feature_extractor/feature_extractor.cc new file mode 100644 index 000000000000..84d22493aaa6 --- /dev/null +++ b/src/meta_schedule/feature_extractor/feature_extractor.cc @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +FeatureExtractor FeatureExtractor::PyFeatureExtractor( + PyFeatureExtractorNode::FExtractFrom f_extract_from, // + PyFeatureExtractorNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_extract_from = std::move(f_extract_from); + n->f_as_string = std::move(f_as_string); + return FeatureExtractor(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyFeatureExtractorNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyFeatureExtractor's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(FeatureExtractorNode); +TVM_REGISTER_NODE_TYPE(PyFeatureExtractorNode); + +TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorExtractFrom") + .set_body_method(&FeatureExtractorNode::ExtractFrom); +TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPyFeatureExtractor") + .set_body_typed(FeatureExtractor::PyFeatureExtractor); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 83e65a5ced44..9b0a37160a13 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -22,7 +22,9 @@ #include #include #include +#include #include +#include #include #include #include diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py new file mode 100644 index 000000000000..3f98d711ea61 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import os +import re +import shutil +import sys +import tempfile +from typing import List + +import numpy as np +import pytest +import tvm +from tvm.meta_schedule.cost_model import PyCostModel, RandomModel +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T +from tvm.tir.schedule.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,disable=unused-argument + + +def test_meta_schedule_cost_model(): + class FancyCostModel(PyCostModel): + def load(self, path: str) -> None: + pass + + def save(self, path: str) -> None: + pass + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + pass + + def predict( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> np.ndarray: + return np.random.rand(10) + + model = FancyCostModel() + model.save("fancy_test_location") + model.load("fancy_test_location") + model.update(TuneContext(), [], []) + results = model.predict(TuneContext, [MeasureCandidate(Schedule(mod=Matmul), [])]) + assert results.shape == (10,) + + +def test_meta_schedule_cost_model_as_string(): + class NotSoFancyCostModel(PyCostModel): + def load(self, path: str) -> None: + pass + + def save(self, path: str) -> None: + pass + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + pass + + def predict( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> np.ndarray: + return np.random.rand(10) + + cost_model = NotSoFancyCostModel() + pattern = re.compile(r"NotSoFancyCostModel\(0x[a-f|0-9]*\)") + assert pattern.match(str(cost_model)) + + +def test_meta_schedule_random_model(): + model = RandomModel() + model.update(TuneContext(), [], []) + res = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(10)]) + assert len(res) == 10 + assert min(res) >= 0 and max(res) <= model.max_range + + +def test_meta_schedule_random_model_reseed(): + model = RandomModel(seed=100) + res = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(20)]) + new_model = RandomModel(seed=100) + new_res = new_model.predict( + TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(20)] + ) + assert (res == new_res).all() + + +def test_meta_schedule_random_model_reload(): + model = RandomModel(seed=25973) + model.predict( + TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(30)] + ) # change state + path = os.path.join(tempfile.mkdtemp(), "test_output_meta_schedule_random_model.npy") + model.save(path) + res1 = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(70)]) + model.load(path) + res2 = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(70)]) + shutil.rmtree(os.path.dirname(path)) + assert (res1 == res2).all() + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor.py b/tests/python/unittest/test_meta_schedule_feature_extractor.py new file mode 100644 index 000000000000..143d446f48fd --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_feature_extractor.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import re +from typing import List + +import numpy as np +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.feature_extractor import PyFeatureExtractor +from tvm.meta_schedule.search_strategy import MeasureCandidate + + +def test_meta_schedule_feature_extractor(): + class FancyFeatureExtractor(PyFeatureExtractor): + def extract_from( + self, + tune_context: TuneContext, # pylint: disable = unused-argument + candidates: List[MeasureCandidate], # pylint: disable = unused-argument + ) -> List[np.ndarray]: + return [np.random.rand(4, 5)] + + extractor = FancyFeatureExtractor() + features = extractor.extract_from(TuneContext(), []) + assert len(features) == 1 + assert features[0].shape == (4, 5) + + +def test_meta_schedule_feature_extractor_as_string(): + class NotSoFancyFeatureExtractor(PyFeatureExtractor): + def extract_from( + self, + tune_context: TuneContext, # pylint: disable = unused-argument + candidates: List[MeasureCandidate], # pylint: disable = unused-argument + ) -> List[np.ndarray]: + return [] + + feature_extractor = NotSoFancyFeatureExtractor() + pattern = re.compile(r"NotSoFancyFeatureExtractor\(0x[a-f|0-9]*\)") + assert pattern.match(str(feature_extractor)) + + +if __name__ == "__main__": + test_meta_schedule_feature_extractor() + test_meta_schedule_feature_extractor_as_string()