From 729aed6aa5009a6629cad31d3ff7efd6004cbd32 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 22 Sep 2021 08:59:54 -0700 Subject: [PATCH] [Meta Schedule][M3c] Argument Info (#9059) This PR is part of the meta schedule project (#8473) that adds metadata of each PrimFunc's argument. This feature is necessary for dynamic shape auto-tuning. Co-authored-by: Xiyou Zhou Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng Co-authored-by: Xiyou Zhou Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- include/tvm/meta_schedule/arg_info.h | 111 ++++++++++++++++ include/tvm/runtime/container/map.h | 10 +- python/tvm/meta_schedule/__init__.py | 1 + python/tvm/meta_schedule/arg_info.py | 106 +++++++++++++++ python/tvm/meta_schedule/utils.py | 33 ++++- python/tvm/runtime/__init__.py | 2 +- src/meta_schedule/arg_info.cc | 122 ++++++++++++++++++ src/meta_schedule/utils.h | 3 + src/support/array.h | 83 +++++++++++- .../unittest/test_meta_schedule_arg_info.py | 71 ++++++++++ 10 files changed, 538 insertions(+), 4 deletions(-) create mode 100644 include/tvm/meta_schedule/arg_info.h create mode 100644 python/tvm/meta_schedule/arg_info.py create mode 100644 src/meta_schedule/arg_info.cc create mode 100644 tests/python/unittest/test_meta_schedule_arg_info.py diff --git a/include/tvm/meta_schedule/arg_info.h b/include/tvm/meta_schedule/arg_info.h new file mode 100644 index 000000000000..08553a001374 --- /dev/null +++ b/include/tvm/meta_schedule/arg_info.h @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_ARG_INFO_H_ +#define TVM_META_SCHEDULE_ARG_INFO_H_ + +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +/*! \brief The argument information. */ +class ArgInfoNode : public runtime::Object { + public: + static constexpr const char* _type_key = "meta_schedule.ArgInfo"; + TVM_DECLARE_BASE_OBJECT_INFO(ArgInfoNode, runtime::Object); + + public: + /*! \brief Default destructor. */ + virtual ~ArgInfoNode() = default; + /*! \brief Converts the ArgInfo to its corresponding JSON representation. */ + virtual ObjectRef AsJSON() const = 0; +}; + +/*! + * \brief Managed reference to ArgInfoNode + * \sa ArgInfoNode + */ +class ArgInfo : public runtime::ObjectRef { + public: + /*! + * \brief Parse the argument information from a JSON object. + * \param json_obj The json object to parse. + * \return The argument information parsed. + */ + TVM_DLL static ArgInfo FromJSON(const ObjectRef& json_obj); + /*! + * \brief Extract a list of the argument information from PrimFunc. + * \param func The PrimFunc to get argument information from. + * \return An array of the argument information derived. + */ + TVM_DLL static Array FromPrimFunc(const tir::PrimFunc& func); + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ArgInfo, runtime::ObjectRef, ArgInfoNode); + + protected: + ArgInfo() = default; +}; + +/*! \brief The tensor argument information. */ +class TensorInfoNode : public ArgInfoNode { + public: + /*! \brief The data type of the tensor. */ + runtime::DataType dtype; + /*! \brief The shape of the tensor. */ + runtime::ShapeTuple shape; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("shape", &shape); + } + + static constexpr const char* _type_key = "meta_schedule.TensorInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, ArgInfoNode); + + public: + ObjectRef AsJSON() const; +}; + +/*! + * \brief Managed reference to TensorInfoNode + * \sa TensorInfoNode + */ +class TensorInfo : public ArgInfo { + public: + /*! + * \brief Constructor of TensorInfo. + * \param dtype The data type of the tensor argument. + * \param shape The shape tuple of the tensor argument. + */ + TVM_DLL explicit TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape); + /*! + * \brief Parse the argument information from a JSON object. + * \param json_obj The json object to parse. + * \return The argument information parsed. + */ + TVM_DLL static TensorInfo FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorInfo, ArgInfo, TensorInfoNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_ARG_INFO_H_ diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h index 3fe4f697bb9e..977dbfbaaaa1 100644 --- a/include/tvm/runtime/container/map.h +++ b/include/tvm/runtime/container/map.h @@ -33,6 +33,7 @@ #include #include "./base.h" +#include "./optional.h" namespace tvm { namespace runtime { @@ -1344,7 +1345,14 @@ class Map : public ObjectRef { iterator end() const { return iterator(GetMapNode()->end()); } /*! \return find the key and returns the associated iterator */ iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); } - + /*! \return The value associated with the key, NullOpt if not found */ + Optional Get(const K& key) const { + MapNode::iterator iter = GetMapNode()->find(key); + if (iter == GetMapNode()->end()) { + return NullOptType{}; + } + return DowncastNoCheck(iter->second); + } void erase(const K& key) { CopyOnWrite()->erase(key); } /*! diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 8e788e798e70..f0e8af223511 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -16,4 +16,5 @@ # under the License. """Package `tvm.meta_schedule`. The meta schedule infrastructure.""" from . import builder +from . import arg_info from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/arg_info.py b/python/tvm/meta_schedule/arg_info.py new file mode 100644 index 000000000000..a56ca86e8cb7 --- /dev/null +++ b/python/tvm/meta_schedule/arg_info.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The argument information""" +from typing import Any, List, Union + +from tvm._ffi import register_object +from tvm.runtime import DataType, Object, ShapeTuple +from tvm.tir import PrimFunc + +from . import _ffi_api +from .utils import _json_de_tvm + + +@register_object("meta_schedule.ArgInfo") +class ArgInfo(Object): + """Argument information""" + + def as_json(self) -> Any: + """Converts the ArgInfo to its corresponding JSON representation.""" + return _json_de_tvm(_ffi_api.ArgInfoAsJSON(self)) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: Any) -> "ArgInfo": + """Parse the argument information from a JSON object. + + Parameters + ---------- + json_obj : Any + The json object to parse. + + Returns + ------- + parsed : ArgInfo + The argument information parsed. + """ + return _ffi_api.ArgInfoFromJSON(json_obj) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_prim_func(func: PrimFunc) -> List["ArgInfo"]: + """Extract a list of the argument information from PrimFunc. + + Parameters + ---------- + func : PrimFunc + The PrimFunc to get argument information from. + + Returns + ------- + extracted : List[ArgInfo] + An array of the argument information derived. + """ + return _ffi_api.ArgInfoFromPrimFunc(func) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.TensorInfo") +class TensorInfo(ArgInfo): + """Tensor argument information + + Parameters + ---------- + dtype : DataType + The data type of the tensor. + shape : ShapeTuple + The shape of the tensor. + """ + + dtype: DataType + shape: ShapeTuple + + def __init__( + self, + dtype: DataType, + shape: Union[ShapeTuple, List[int]], + ) -> None: + """Constructor + + Parameters + ---------- + dtype : DataType + The data type of the tensor. + shape : ShapeTuple + The shape of the tensor. + """ + if isinstance(shape, ShapeTuple): + shape_tuple = shape + else: + shape_tuple = ShapeTuple(shape) + self.__init_handle_by_constructor__( + _ffi_api.TensorInfo, # type: ignore # pylint: disable=no-member + dtype, + shape_tuple, + ) diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 74f93e86f506..abde198cf6ec 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -17,12 +17,15 @@ """Utilities for meta schedule""" import os import shutil -from typing import Callable, Union +from typing import Any, Callable, Union import psutil from tvm._ffi import get_global_func, register_func from tvm.error import TVMError +from tvm.ir import Array, Map +from tvm.runtime import String +from tvm.tir import FloatImm, IntImm @register_func("meta_schedule.cpu_count") @@ -95,3 +98,31 @@ def get_global_func_with_default_on_worker( def remove_build_dir(artifact_path: str) -> None: """Clean up the build directory""" shutil.rmtree(os.path.dirname(artifact_path)) + + +def _json_de_tvm(obj: Any) -> Any: + """Unpack a TVM nested container to a JSON object in python. + + Parameters + ---------- + obj : Any + The TVM nested container to be unpacked. + + Returns + ------- + result : Any + The unpacked json object. + """ + if obj is None: + return None + if isinstance(obj, (int, float)): + return obj + if isinstance(obj, (IntImm, FloatImm)): + return obj.value + if isinstance(obj, (str, String)): + return str(obj) + if isinstance(obj, Array): + return [_json_de_tvm(i) for i in obj] + if isinstance(obj, Map): + return {_json_de_tvm(k): _json_de_tvm(v) for k, v in obj.items()} + raise TypeError("Not supported type: " + str(type(obj))) diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 71563b508290..b3504dbac506 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -29,5 +29,5 @@ from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib -from .container import String +from .container import String, ShapeTuple from .params import save_param_dict, load_param_dict diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc new file mode 100644 index 000000000000..104662b6aad0 --- /dev/null +++ b/src/meta_schedule/arg_info.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace meta_schedule { + +/******** ArgInfo ********/ + +ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { + // The JSON object is always an array whose first element is a tag. For example: + // `['TENSOR', 'float32', [1, 224, 224, 3]] + // Step 1. Extract the tag + String tag{runtime::ObjectPtr(nullptr)}; + try { + const ArrayNode* json_array = json_obj.as(); + CHECK(json_array && json_array->size() >= 1); + tag = Downcast(json_array->at(0)); + } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + // Step 2. Dispatch the tag to corresponding subclass of ArgInfo + if (tag == "TENSOR") { + return TensorInfo::FromJSON(json_obj); + } + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj; + throw; +} + +Array ArgInfo::FromPrimFunc(const tir::PrimFunc& func) { + using support::AsVector; + Array result; + result.reserve(func->params.size()); + for (const tir::Var& arg : func->params) { + if (Optional _buffer = func->buffer_map.Get(arg)) { + tir::Buffer buffer = _buffer.value(); + result.push_back(TensorInfo(/*dtype=*/buffer->dtype, + /*shape=*/AsVector(buffer->shape))); + } else { + LOG(FATAL) << "ValueError: Unsupported argument type: " << arg; + } + } + return result; +} + +/******** TensorInfo ********/ + +TensorInfo::TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape) { + ObjectPtr n = make_object(); + n->dtype = dtype; + n->shape = shape; + this->data_ = std::move(n); +} + +ObjectRef TensorInfoNode::AsJSON() const { + static String tag = "TENSOR"; + String dtype = DLDataType2String(this->dtype); + Array shape = support::AsArray(this->shape); + return Array{tag, dtype, shape}; +} + +TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { + DLDataType dtype; + Array shape; + try { + const ArrayNode* json_array = json_obj.as(); + CHECK(json_array && json_array->size() == 3); + // Load json[1] => dtype + { + String dtype_str = Downcast(json_array->at(1)); + dtype = runtime::String2DLDataType(dtype_str); + } + // Load json[2] => shape + shape = Downcast>(json_array->at(2)); + } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + return TensorInfo(DataType(dtype), ShapeTuple(shape.begin(), shape.end())); +} + +/******** Repr ********/ + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + p->stream << "TensorInfo(\"" << self->dtype << "\", " << self->shape << ")"; + }); + +/******** FFI ********/ + +TVM_REGISTER_OBJECT_TYPE(ArgInfoNode); +TVM_REGISTER_NODE_TYPE(TensorInfoNode); + +TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoAsJSON").set_body_method(&ArgInfoNode::AsJSON); +TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromPrimFunc").set_body_typed(ArgInfo::FromPrimFunc); +TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromJSON").set_body_typed(ArgInfo::FromJSON); +TVM_REGISTER_GLOBAL("meta_schedule.TensorInfo") + .set_body_typed([](runtime::DataType dtype, runtime::ShapeTuple shape) -> TensorInfo { + return TensorInfo(dtype, shape); + }); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 47331203a25a..e6eae4d0d915 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -19,8 +19,11 @@ #ifndef TVM_META_SCHEDULE_UTILS_H_ #define TVM_META_SCHEDULE_UTILS_H_ +#include #include +#include "../src/support/array.h" + namespace tvm { namespace meta_schedule {} // namespace meta_schedule } // namespace tvm diff --git a/src/support/array.h b/src/support/array.h index 89e17433344b..95b4f58a2e22 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -75,9 +75,33 @@ inline bool ArrayWithSameContent(const std::vector& a, const std::vector * \return The result vector */ template -std::vector AsVector(const Array& vec); +inline std::vector AsVector(const Array& vec); + +/*! + * \brief Convert a std::vector to tvm::runtime::Array + * \tparam TSrc The type of elements in the source vector + * \tparam TDst The type of elements in the result Array + * \return The result vector + */ +template +inline Array AsArray(const std::vector& vec); + +/*! + * \brief Get the shape tuple as array + * \param shape The shape tuple + * \return An array of the shape tuple + */ +inline Array AsArray(const ShapeTuple& shape) { + Array result; + result.reserve(shape->size); + for (ShapeTuple::index_type i : shape) { + result.push_back(Integer(i)); + } + return result; +} /********** Implementation details of AsVector **********/ + namespace details { template @@ -130,11 +154,68 @@ struct AsVectorImpl { }; } // namespace details +/********** Implementation details of AsArray **********/ + +namespace details { + +template +struct AsArrayImpl {}; + +template +struct AsArrayImpl { + inline Array operator()(const std::vector& vec) const { + return Array(vec.begin(), vec.end()); + } +}; + +template +struct AsArrayImpl { + inline Array operator()(const std::vector& vec) const { + Array result; + result.reserve(vec.size()); + for (int x : vec) { + result.push_back(Integer(x)); + } + return result; + } +}; + +template +struct AsArrayImpl { + inline Array operator()(const std::vector& vec) const { + Array result; + result.reserve(vec.size()); + for (int64_t x : vec) { + result.push_back(Integer(x)); + } + return result; + } +}; + +template +struct AsArrayImpl { + inline Array operator()(const std::vector& vec) const { + Array result; + result.reserve(vec.size()); + for (double x : vec) { + result.push_back(FloatImm(tvm::DataType::Float(64), x)); + } + return result; + } +}; + +} // namespace details + template inline std::vector AsVector(const Array& vec) { return details::AsVectorImpl()(vec); } +template +inline Array AsArray(const std::vector& vec) { + return details::AsArrayImpl()(vec); +} + } // namespace support } // namespace tvm #endif // TVM_SUPPORT_ARRAY_H_ diff --git a/tests/python/unittest/test_meta_schedule_arg_info.py b/tests/python/unittest/test_meta_schedule_arg_info.py new file mode 100644 index 000000000000..51ec9ea87ed3 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_arg_info.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule.arg_info import ArgInfo, TensorInfo +from tvm.script import ty + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# fmt: off + +@tvm.script.tir +def Matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + tir.func_attr({"global_symbol": "main"}) + A = tir.match_buffer(a, (128, 256), "float32") + B = tir.match_buffer(b, (256, 512), "float32") + C = tir.match_buffer(c, (128, 512), "float32") + with tir.block([128, 256, tir.reduce_axis(0, 512)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_meta_schedule_tensor_info_creation(): + info = TensorInfo("float32", [1, 224, 224, 3]) + info = str(info) + assert info == 'TensorInfo("float32", [1, 224, 224, 3])' + + +def test_meta_schedule_tensor_info_as_json(): + info = TensorInfo("float32", [1, 224, 224, 3]) + info = info.as_json() + assert info == ["TENSOR", "float32", [1, 224, 224, 3]] + + +def test_meta_schedule_tensor_info_from_json(): + info = ["TENSOR", "float32", [1, 224, 224, 3]] + info = TensorInfo.from_json(info) + assert str(info) == 'TensorInfo("float32", [1, 224, 224, 3])' + + +def test_meta_schedule_arg_info_from_prim_func(): + a_info, b_info, c_info = ArgInfo.from_prim_func(Matmul) + assert str(a_info) == 'TensorInfo("float32", [128, 256])' + assert str(b_info) == 'TensorInfo("float32", [256, 512])' + assert str(c_info) == 'TensorInfo("float32", [128, 512])' + + +if __name__ == "__main__": + test_meta_schedule_tensor_info_creation() + test_meta_schedule_tensor_info_as_json() + test_meta_schedule_tensor_info_from_json() + test_meta_schedule_arg_info_from_prim_func()