diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index ecf7c1978d07..dc7a2b218568 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -46,6 +46,13 @@ enum class CallingConv : int { * - Implementation: specified by the native target. */ kDefault = 0, + /*! + * \brief PackedFunc that exposes a CPackedFunc signature. + * + * - Calling by PackedFunc calling convention. + * - Implementation: Expose a function with the CPackedFunc signature. + */ + kCPackedFunc = 1, /*! * \brief Device kernel launch * @@ -53,13 +60,6 @@ enum class CallingConv : int { * - Implementation: defined by device runtime(e.g. runtime/cuda) */ kDeviceKernelLaunch = 2, - /*! - * \brief PackedFunc that exposes a CPackedFunc signature. - * - * - Calling by PackedFunc calling convention. - * - Implementation: Expose a function with the CPackedFunc signature. - */ - kCPackedFunc = 3, }; /*! diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index f6ea9185567a..f63bf96ef2ab 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -324,6 +324,8 @@ class IRModule : public ObjectRef { /*! \brief Declare the container type. */ using ContainerType = IRModuleNode; + // allow copy on write. + TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode); }; /*! diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 6bab44e2355b..fe74a96ae118 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -49,6 +49,16 @@ struct ExprDeepEqual { public: TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const; }; + + +/*! + * \brief Find undefined vars in the statment. + * \param stmt The function to be checked. + * \param defs The vars that is defined. + * \return Array of undefined vars. + */ +Array UndefinedVars(const Stmt& stmt, const Array& defs); + } // namespace tir } // namespace tvm #endif // TVM_TIR_ANALYSIS_H_ diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index 6a1a1788c312..8ba008bf024d 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -406,56 +406,6 @@ LoweredFunc MakeAPI(Stmt body, int num_unpacked_args, bool is_restricted); -/*! - * \brief Bind the device type of host function to be device_type. - * \param func The function to be binded. - * \param device_type The device type to be binded. - * \return The binded function. - */ -LoweredFunc BindDeviceType(LoweredFunc func, - int device_type); -/*! - * \brief Find undefined vars in the statment. - * \param stmt The function to be checked. - * \param defs The vars that is defined. - * \return Array of undefined vars. - */ -Array UndefinedVars(const Stmt& stmt, const Array& defs); - -/*! - * \brief Split the function into a host function and device functions. - * \param func The function to be splitted. - * - * \return Array of functions, the first one is host function, - * the others are device functions. - */ -Array SplitHostDevice(LoweredFunc func); - -/*! - * \brief Insert sync between parallel read/write of shared buffers. - * - * \param stmt The stmt to be trasnformed. - * \param storage_scope The storage scope considered. - */ -LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope); - -/*! - * \brief Lower cross thread alleduce in the stmt. - * \param f The device function to be lowered. - * \param warp_size the size of warp where no sync is needed. - * \return Transformed function. - */ -LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size); - -/*! - * \brief Lower warp memory in stmt. - * \param f The device function to be lowered. - * \param warp_size the size of warp where no sync is needed. - * this function will only take in effect if warp_size is bigger than one. - * \return Transformed function. - */ -LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size); - /*! * \brief Remap the thread axis * @@ -470,26 +420,6 @@ LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size); */ LoweredFunc RemapThreadAxis(LoweredFunc f, Map axis_map); -/*! - * \brief Lower packed function call. - * \param f The function to be lowered. - * \return Transformed function. - */ -LoweredFunc LowerTVMBuiltin(LoweredFunc f); - - -/*! - * \brief Rewrite the pointer content type of arguments, - * as well as Alloc internal to the function to use - * the most frequently accessed type for load/store - * to avoid pointer casting in backend when possible. - * - * \note implemeneted in storage_rewrite.cc - * \param f The function to be trasnformed - * \return Transformed function. - */ -LoweredFunc PointerValueTypeRewrite(LoweredFunc f); - /*! * \brief Rewrite the pointer content type of arguments, * as well as Alloc internal to the function to use @@ -513,14 +443,6 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f); */ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); -/*! - * \brief Infer the TensorCore fragment infomation using tensor intrinsics - * - * \param f The device function to be lowered. - * \return Transformed function. - */ -LoweredFunc InferFragment(LoweredFunc f); - /*! * \brief Verify if memory accesses are legal for a specific target device type. * diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index d809e07ad6db..211e344fa1d8 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -58,6 +58,21 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< const std::string& name, const tvm::Array& required); +/*! + * \brief Bind the device type ofthe function to be + * the device_type specified in the target attribute. + * + * \return The pass. + */ +TVM_DLL Pass BindDeviceType(); + +/*! + * \brief Split the function into a host function and device functions. + * + * \return The pass. + */ +TVM_DLL Pass SplitHostDevice(); + /*! * \brief skip assert stmt. * diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 7eda40de7215..e4bd2009841f 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +# pylint: disable=invalid-name """The build utils in python. This module provides the functions to transform schedule to @@ -25,6 +27,7 @@ from tvm.runtime import ndarray from tvm.ir import container +from tvm.ir import CallingConv from tvm.target import codegen, BuildConfig from tvm.tir import ir_pass from tvm.tir.stmt import LoweredFunc @@ -222,75 +225,59 @@ def _build_for_device(flist, target, target_host): mdev : tvm.module A module that contains device code. """ - @tvm.tir.transform.prim_func_pass(opt_level=0) - class BindTarget: - def __init__(self, target): - self.target = target - - # pylint: disable=unused-argument - def transform_function(self, func, mod, ctx): - return func.with_attr("target", self.target) - target = _target.create(target) + target_host = _target.create(target_host) device_type = ndarray.context(target.target_name, 0).device_type - fhost = [] - fdevice = [] + for func in flist: if not ir_pass.VerifyMemory(func, device_type): raise ValueError( "Direct host side access to device memory is detected in %s. " "Did you forget to bind?" % func.name) - if func.func_type == LoweredFunc.MixedFunc: - if BuildConfig.current().detect_global_barrier: - func = ir_pass.ThreadSync(func, "global") - func = ir_pass.ThreadSync(func, "shared") - func = ir_pass.ThreadSync(func, "warp") - func = ir_pass.InferFragment(func) - warp_size = target.thread_warp_size - func = ir_pass.LowerThreadAllreduce(func, warp_size) - fsplits = list(ir_pass.SplitHostDevice(func)) - fhost.append(fsplits[0]) - for x in fsplits[1:]: - fdevice.append(x) - elif func.func_type == LoweredFunc.HostFunc: - fhost.append(func) - elif func.func_type == LoweredFunc.DeviceFunc: - fdevice.append(func) - else: - raise ValueError("unknown function type %d" % func.func_type) - - if "gpu" in target.keys and not fdevice: - warnings.warn( - "Specified target %s, but cannot find device code, did you do " - "bind?" % target) - fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost] + mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist) + opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))] + if BuildConfig.current().detect_global_barrier: + opt_mixed += [tvm.tir.transform.ThreadSync("global")] + opt_mixed += [tvm.tir.transform.ThreadSync("shared"), + tvm.tir.transform.ThreadSync("warp"), + tvm.tir.transform.InferFragment(), + tvm.tir.transform.LowerThreadAllreduce(), + tvm.tir.transform.BindDeviceType(), + tvm.tir.transform.SplitHostDevice()] + mod_mixed = tvm.ir.transform.Sequential(opt_mixed)(mod_mixed) - if device_type == ndarray.cpu(0).device_type and target_host == target: - assert not fdevice - - target_host = _target.create(target_host) # device optimizations - mod_dev = tvm.testing.LoweredFuncsToIRModule(fdevice) opt_device = tvm.ir.transform.Sequential( - [BindTarget(target), + [tvm.tir.transform.Filter( + lambda f: "calling_conv" in f.attrs and + f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH), tvm.tir.transform.LowerWarpMemory(), tvm.tir.transform.LowerDeviceStorageAccessInfo(), tvm.tir.transform.LowerIntrin()]) - mod_dev = opt_device(mod_dev) + mod_dev = opt_device(mod_mixed) # host optimizations - mod_host = tvm.testing.LoweredFuncsToIRModule(fhost) opt_host = tvm.ir.transform.Sequential( - [BindTarget(target_host), + [tvm.tir.transform.Filter( + lambda f: "calling_conv" not in f.attrs or + f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH), + tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)), tvm.tir.transform.LowerTVMBuiltin(), tvm.tir.transform.LowerDeviceStorageAccessInfo(), tvm.tir.transform.LowerIntrin(), tvm.tir.transform.CombineContextCall()]) - mod_host = opt_host(mod_host) + mod_host = opt_host(mod_mixed) + + if device_type == ndarray.cpu(0).device_type and target_host == target: + assert len(mod_dev.functions) == 0 + if "gpu" in target.keys and len(mod_dev.functions) == 0: + warnings.warn( + "Specified target %s, but cannot find device code, did you do " + "bind?" % target) - rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None + rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None return mod_host, rt_mod_dev diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index b3efd6bb259d..1aabf3e5bca7 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -23,7 +23,7 @@ from .tensor_type import TensorType from .type_relation import TypeCall, TypeRelation from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range -from .function import BaseFunc +from .function import CallingConv, BaseFunc from .adt import Constructor, TypeData from .module import IRModule from .attrs import Attrs, DictAttrs, make_node diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index 70eb51a093d3..afc8c1066b1c 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -15,10 +15,18 @@ # specific language governing permissions and limitations # under the License. """Function defintiions.""" +from enum import IntEnum from .expr import RelayExpr from . import _ffi_api +class CallingConv(IntEnum): + """Possible kinds of calling conventions.""" + DEFAULT = 0 + C_PACKED_FUNC = 1 + DEVICE_KERNEL_LAUNCH = 2 + + class BaseFunc(RelayExpr): """Base class of all functions.""" @property diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 24f5211909ac..8d75d8e8ee21 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -60,7 +60,6 @@ def __init__(self, functions=None, type_definitions=None): type_definitions = mapped_type_defs self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions) - def __setitem__(self, var, val): """Add a mapping to the module. diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py index 93bb996084f4..a19cc2fa08c1 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tir/transform/function_pass.py @@ -16,6 +16,7 @@ # under the License. """TIR specific function pass support.""" import inspect +import types import functools import tvm._ffi @@ -142,7 +143,7 @@ def create_function_pass(pass_arg): return _wrap_class_function_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): raise TypeError("pass_func must be a callable for Module pass") - return _ffi_api.MakeFunctionPass(pass_arg, info) + return _ffi_api.CreatePrimFuncPass(pass_arg, info) if pass_func: return create_function_pass(pass_func) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 6be4a38fec03..c823c1af5baa 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -17,6 +17,70 @@ """Wrapping existing transformations.""" # pylint: disable=invalid-name from . import _ffi_api +from . import function_pass as _fpass + + +def Apply(ftransform): + """Apply ftransform to each function in the Module. + + This function is a thin wrapper around tvm.tir.transform.prim_func_pass + + Parameters + ---------- + ftransform: tvm.tir.PrimFunc -> tvm.tir.PrimFunc + The transformation pass. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + # pylint: disable=unused-argument + def _transform(func, mod, ctx): + return ftransform(func) + return _fpass.prim_func_pass(_transform, opt_level=0) + + +def Filter(fcond): + """Filter functions by the calling convention attribute. + + Parameters + ---------- + fcond : tvm.tir.PrimFunc -> bool + The condition of the filtering. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + # pylint: disable=unused-argument + def _transform(func, mod, ctx): + return func if fcond(func) else None + return _fpass.prim_func_pass(_transform, opt_level=0) + + +def BindDeviceType(): + """Bind the device type of the function to be + the device_type specified in the target attribute. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + return _ffi_api.BindDeviceType() + + +def SplitHostDevice(): + """Split the function into a host function and device functions. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + return _ffi_api.SplitHostDevice() def SkipAssert(): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f59e7646a2ac..d54d6f8773ce 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -185,75 +185,50 @@ transform::Pass BindTarget(Target target) { } +template +transform::Pass FilterBy(FCond fcond) { + auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + if (fcond(f)) { + return f; + } else { + return tir::PrimFunc(nullptr); + } + }; + return tir::transform::CreatePrimFuncPass(fpass, 0, "FilterBy", {}); +} + + std::pair split_dev_host_funcs(const Array& funcs, const Target& target, const Target& target_host, const BuildConfig& config) { - std::unordered_set all_names; - for (const auto& x : funcs) { - CHECK(all_names.count(x->name) == 0) - << "Duplicate function name " << x->name; - all_names.insert(x->name); - } - - Array fhost; - Array fdevice; - for (const auto& x : funcs) { CHECK(tir::VerifyMemory(x, target->device_type)) << "Direct host side access to device memory is detected in " << x->func_name() << ". Did you forget to bind?"; - - if (x->func_type == tir::kMixedFunc) { - auto func = x; - if (config->detect_global_barrier) { - func = tir::ThreadSync(func, "global"); - } - - func = tir::ThreadSync(func, "shared"); - func = tir::ThreadSync(func, "warp"); - func = tir::InferFragment(func); - func = tir::LowerThreadAllreduce(func, target->thread_warp_size); - auto fsplits = tir::SplitHostDevice(func); - fhost.push_back(fsplits[0]); - for (auto f = fsplits.begin() + 1; f != fsplits.end(); ++f) { - fdevice.push_back(*f); - } - } else if (x->func_type == tir::kHostFunc) { - fhost.push_back(x); - } else if (x->func_type == tir::kDeviceFunc) { - fdevice.push_back(x); - } else { - LOG(FATAL) << "unknown function type " << x->func_type; - } - } - - auto keys = target->keys(); - bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); - if (target_is_gpu && fdevice.size() == 0) { - LOG(WARNING) << "Specified target " - << target->str() - << " but cannot find device code. Did you forget to bind?"; } + IRModule mod_mixed = codegen::ToIRModule(funcs); - if (target->device_type == target::llvm()->device_type && - target_host == target) { - CHECK(fdevice.empty()) << "No device code should be generated when target " - << "and host_target are both llvm target." - << "\n"; - } - - for (size_t i = 0; i < fhost.size(); ++i) { - auto func = fhost[i]; - func = tir::BindDeviceType(func, target->device_type); - fhost.Set(i, func); + Array mixed_pass_list = {BindTarget(target)}; + if (config->detect_global_barrier) { + mixed_pass_list.push_back(tir::transform::ThreadSync("global")); } + mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); + mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); + mixed_pass_list.push_back(tir::transform::InferFragment()); + mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); + mixed_pass_list.push_back(tir::transform::BindDeviceType()); + mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + auto opt_mixed = transform::Sequential(mixed_pass_list); + mod_mixed = opt_mixed(std::move(mod_mixed)); - // host pipeline - auto mhost = codegen::ToIRModule(fhost); auto host_pass_list = { + FilterBy([](const tir::PrimFunc& f) { + int64_t value = f->GetAttr(tvm::attr::kCallingConv, 0)->value; + return value != static_cast(CallingConv::kDeviceKernelLaunch); + }), BindTarget(target_host), tir::transform::LowerTVMBuiltin(), tir::transform::LowerIntrin(), @@ -261,18 +236,38 @@ split_dev_host_funcs(const Array& funcs, tir::transform::CombineContextCall(), }; auto opt_host = transform::Sequential(host_pass_list); - mhost = opt_host(mhost); + auto mhost = opt_host(mod_mixed); // device pipeline - auto mdevice = codegen::ToIRModule(fdevice); auto device_pass_list = { + FilterBy([](const tir::PrimFunc& f) { + int64_t value = f->GetAttr(tvm::attr::kCallingConv, 0)->value; + return value == static_cast(CallingConv::kDeviceKernelLaunch); + }), BindTarget(target), tir::transform::LowerWarpMemory(), tir::transform::LowerIntrin(), tir::transform::LowerDeviceStorageAccessInfo(), }; auto opt_device = transform::Sequential(device_pass_list); - mdevice = opt_device(mdevice); + auto mdevice = opt_device(mod_mixed); + + // some final misc checks. + auto keys = target->keys(); + bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); + if (target_is_gpu && mdevice->functions.size() == 0) { + LOG(WARNING) << "Specified target " + << target->str() + << " but cannot find device code. Did you forget to bind?"; + } + + if (target->device_type == target::llvm()->device_type && + target_host == target) { + CHECK(mdevice->functions.empty()) + << "No device code should be generated when target " + << "and host_target are both llvm target." + << "\n"; + } return {mhost, mdevice}; } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 56e77b72ed8e..bda997a59d4d 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -34,6 +34,7 @@ */ #include #include +#include #include #include #include "doc.h" @@ -434,6 +435,10 @@ class RelayTextPrinter : Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) { if (auto* n = base_func.as()) { return PrintFunc(prefix, GetRef(n)); + } else if (auto* n = base_func.as()) { + std::ostringstream os; + os << GetRef(n); + return Doc::RawText(os.str()); } else { // def @xyz = meta['ExternalFunc'][id] Doc doc; @@ -455,8 +460,9 @@ class RelayTextPrinter : } // functions for (const auto& kv : mod->functions) { - dg_ = DependencyGraph::Create(&arena_, kv.second); - + if (kv.second.as()) { + dg_ = DependencyGraph::Create(&arena_, kv.second); + } if (counter++ != 0) { doc << Doc::NewLine(); } diff --git a/src/target/codegen.cc b/src/target/codegen.cc index a977d35b2198..703328f8761f 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -50,9 +50,10 @@ tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) { Map remap_vars; for (auto var : from->args) { - if (from->handle_data_type.count(var)) { + auto it = from->handle_data_type.find(var); + if (it != from->handle_data_type.end()) { tir::Var new_var(var->name_hint, - PointerType(PrimType(var->dtype))); + PointerType(PrimType((*it).second->dtype))); args.push_back(new_var); remap_vars.Set(var, new_var); } else { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 70bcfe88c30e..33a3e17c939a 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include "codegen_cpu.h" diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index f991e908ca02..773c67d79269 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -108,8 +108,13 @@ IRModule PrimFuncPassNode::operator()(const IRModule& mod, updates.push_back({it.first, updated_func}); } } + // automatic removal of None for (const auto& pair : updates) { - updated_mod->Add(pair.first, pair.second, true); + if (pair.second.defined()) { + updated_mod->Add(pair.first, pair.second, true); + } else { + updated_mod->Remove(pair.first); + } } pass_ctx.Trace(updated_mod, pass_info, false); return updated_mod; diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index ff821fe48517..83db1a900fc6 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -128,10 +128,7 @@ REGISTER_PASS(VectorizeLoop); REGISTER_PASS(SkipVectorize); REGISTER_PASS(UnrollLoop); REGISTER_PASS(InjectCopyIntrin); -REGISTER_PASS(ThreadSync); REGISTER_PASS(MakeAPI); -REGISTER_PASS(BindDeviceType); -REGISTER_PASS(SplitHostDevice); REGISTER_PASS(StorageRewrite); REGISTER_PASS(CoProcSync); REGISTER_PASS(LowerStorageAccessInfo); @@ -141,7 +138,6 @@ REGISTER_PASS(InjectDoubleBuffer); REGISTER_PASS(LoopPartition); REGISTER_PASS(RemoveNoOp); REGISTER_PASS(LiftAttrScope); -REGISTER_PASS(LowerThreadAllreduce); REGISTER_PASS(RemapThreadAxis); REGISTER_PASS(LowerCustomDatatypes); REGISTER_PASS(VerifyMemory); @@ -150,7 +146,6 @@ REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(HoistIfThenElse); -REGISTER_PASS(InferFragment) REGISTER_PASS(NarrowDataType); } // namespace tir } // namespace tvm diff --git a/src/tir/pass/make_api.cc b/src/tir/pass/make_api.cc index f8eae645a044..861cd43e5376 100644 --- a/src/tir/pass/make_api.cc +++ b/src/tir/pass/make_api.cc @@ -218,69 +218,6 @@ LoweredFunc MakeAPI(Stmt body, return f; } -class DeviceTypeBinder: public StmtExprMutator { - public: - explicit DeviceTypeBinder(int device_type) - : device_type_(device_type) {} - - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::device_context_type) { - if (const VarNode* var = op->value.as()) { - var_ = var; - PrimExpr value = make_const(op->value.dtype(), device_type_); - Stmt body = StmtExprMutator::VisitStmt_(op); - var_ = nullptr; - std::ostringstream os; - os << "device_type need to be " << device_type_; - return AssertStmtNode::make(op->value == value, os.str(), body); - } - } - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const IfThenElseNode* op) final { - // eager simplify if guard. - Stmt res = StmtExprMutator::VisitStmt_(op); - op = res.as(); - if (is_zero(op->condition)) { - if (op->else_case.defined()) return op->else_case; - return EvaluateNode::make(0); - } - if (is_one(op->condition)) { - return op->then_case; - } - return res; - } - - PrimExpr VisitExpr_(const NENode* op) final { - // eager check NE for device check - PrimExpr res = StmtExprMutator::VisitExpr_(op); - op = res.as(); - if (tir::ExprDeepEqual()(op->a, op->b)) { - return make_const(op->dtype, false); - } - return res; - } - - PrimExpr VisitExpr_(const VarNode* op) final { - if (op == var_) { - return make_const(op->dtype, device_type_); - } else { - return GetRef(op); - } - } - - public: - const VarNode* var_{nullptr}; - int device_type_; -}; - -LoweredFunc BindDeviceType(LoweredFunc f, - int device_type) { - auto n = make_object(*f.operator->()); - n->body = DeviceTypeBinder(device_type)(n->body); - return LoweredFunc(n); -} } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/bind_device_type.cc b/src/tir/transforms/bind_device_type.cc new file mode 100644 index 000000000000..486f21c907f9 --- /dev/null +++ b/src/tir/transforms/bind_device_type.cc @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file bind_device_type.cc + * \brief Bind the device type according to the target field. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class DeviceTypeBinder: public StmtExprMutator { + public: + explicit DeviceTypeBinder(int device_type) + : device_type_(device_type) {} + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::device_context_type) { + if (const VarNode* var = op->value.as()) { + var_ = var; + PrimExpr value = make_const(op->value.dtype(), device_type_); + Stmt body = StmtExprMutator::VisitStmt_(op); + var_ = nullptr; + std::ostringstream os; + os << "device_type need to be " << device_type_; + return AssertStmtNode::make(op->value == value, os.str(), body); + } + } + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const IfThenElseNode* op) final { + // eager simplify if guard. + Stmt res = StmtExprMutator::VisitStmt_(op); + op = res.as(); + if (is_zero(op->condition)) { + if (op->else_case.defined()) return op->else_case; + return EvaluateNode::make(0); + } + if (is_one(op->condition)) { + return op->then_case; + } + return res; + } + + PrimExpr VisitExpr_(const NENode* op) final { + // eager check NE for device check + PrimExpr res = StmtExprMutator::VisitExpr_(op); + op = res.as(); + if (tir::ExprDeepEqual()(op->a, op->b)) { + return make_const(op->dtype, false); + } + return res; + } + + PrimExpr VisitExpr_(const VarNode* op) final { + if (op == var_) { + return make_const(op->dtype, device_type_); + } else { + return GetRef(op); + } + } + + public: + const VarNode* var_{nullptr}; + int device_type_; +}; + +namespace transform { + +Pass BindDeviceType() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + auto target = f->GetAttr(tvm::attr::kTarget); + CHECK(target.defined()) + << "BindDeviceType: Require the target attribute"; + n->body = DeviceTypeBinder(target->device_type)(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.BindDeviceType") +.set_body_typed(BindDeviceType); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index e7e89f899d4f..c4df2dcdb868 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -340,14 +340,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::unordered_map alloc_remap_; }; -LoweredFunc -LowerThreadAllreduce(LoweredFunc f, int warp_size) { - CHECK_NE(f->func_type, kHostFunc); - auto n = make_object(*f.operator->()); - n->body = ThreadAllreduceBuilder(warp_size)(n->body); - return LoweredFunc(n); -} - namespace transform { Pass LowerThreadAllreduce() { @@ -356,10 +348,6 @@ Pass LowerThreadAllreduce() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) - << "LowerThreadAllreeduce: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body); return f; }; diff --git a/src/tir/pass/split_host_device.cc b/src/tir/transforms/split_host_device.cc similarity index 61% rename from src/tir/pass/split_host_device.cc rename to src/tir/transforms/split_host_device.cc index 519101fe49ac..838ad82d974f 100644 --- a/src/tir/pass/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -21,18 +21,22 @@ * \file split_host_device.cc * \brief Split device function from host. */ +#include #include -#include +#include #include #include -#include +#include +#include +#include + #include namespace tvm { namespace tir { // use/def analysis, also delete unreferenced lets -class IRUseDefAnalysis : public StmtExprMutator { +class VarUseDefAnalysis : public StmtExprMutator { public: Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { @@ -156,8 +160,27 @@ class IRUseDefAnalysis : public StmtExprMutator { std::unordered_map def_count_; }; + +Array UndefinedVars(const Stmt& stmt, const Array& args) { + VarUseDefAnalysis m; + for (Var arg : args) { + m.use_count_[arg.get()] = 0; + } + m(stmt); + return m.undefined_; +} + + class HostDeviceSplitter : public StmtMutator { public: + explicit HostDeviceSplitter(IRModuleNode* device_mod, + Target device_target, + std::string name_prefix) + : device_mod_(device_mod), + device_target_(device_target), + name_prefix_(name_prefix) { + } + Stmt VisitStmt_(const AllocateNode* op) final { handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0); return StmtMutator::VisitStmt_(op); @@ -172,86 +195,128 @@ class HostDeviceSplitter : public StmtMutator { return StmtMutator::VisitStmt_(op); } - Array Split(LoweredFunc f) { - CHECK_EQ(f->func_type, kMixedFunc); - for (auto kv : f->handle_data_type) { - handle_data_type_[kv.first.get()] = kv.second; - } - name_ = f->name; - ObjectPtr n = - make_object(*f.operator->()); - n->body = operator()(f->body); - n->func_type = kHostFunc; - Array ret{LoweredFunc(n)}; - for (LoweredFunc x : device_funcs_) { - ret.push_back(x); - } - return ret; - } - private: Stmt SplitDeviceFunc(Stmt body) { std::ostringstream os; - os << name_ << "_kernel" << device_funcs_.size(); - ObjectPtr n = make_object(); + os << name_prefix_ << "_kernel" << device_func_counter_++; + std::string kernel_symbol = os.str(); // isolate the device function. - IRUseDefAnalysis m; + VarUseDefAnalysis m; m.visit_thread_extent_ = false; - n->body = m(std::move(body)); - n->name = os.str(); - n->func_type = kDeviceFunc; - n->thread_axis = m.thread_axis_; + body = m(std::move(body)); + + Array params; + Array arguments; + Map remap_vars; + // Strictly order the arguments: Var pointers, positional arguments. - for (Var v : m.undefined_) { - if (v.dtype().is_handle()) { - n->args.push_back(v); - // mark handle data type. - auto it = handle_data_type_.find(v.get()); + for (Var var : m.undefined_) { + if (var.dtype().is_handle()) { + // Create a new version of v. + auto it = handle_data_type_.find(var.get()); if (it != handle_data_type_.end()) { - n->handle_data_type.Set(v, it->second); + tir::Var new_var(var->name_hint, + PointerType(PrimType((*it).second->dtype))); + params.push_back(new_var); + remap_vars.Set(var, new_var); + } else { + params.push_back(var); } + arguments.push_back(var); } } - for (Var v : m.undefined_) { - if (!v.dtype().is_handle()) { - n->args.push_back(v); + // positional arguments + for (Var var : m.undefined_) { + if (!var.dtype().is_handle()) { + params.push_back(var); + arguments.push_back(var); } } - LoweredFunc f_device(n); + PrimFunc device_func(params, Substitute(body, remap_vars)); + device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_); + device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, + Integer(CallingConv::kDeviceKernelLaunch)); + device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, + runtime::String(kernel_symbol)); + device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); + device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); + device_mod_->Add(GlobalVar(kernel_symbol), device_func); + + // generate calls to the device function Array call_args; - call_args.push_back(StringImmNode::make(f_device->name)); - for (Var arg : n->args) { + call_args.push_back(StringImmNode::make(kernel_symbol)); + for (PrimExpr arg : arguments) { call_args.push_back(arg); } for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } - device_funcs_.emplace_back(f_device); return EvaluateNode::make(CallNode::make( DataType::Int(32), intrinsic::tvm_call_packed, call_args, CallNode::Intrinsic)); } - // function name - std::string name_; - // the device functions + // target ir module + IRModuleNode* device_mod_; + // Device target + Target device_target_; + // function name hint + std::string name_prefix_; + // Number of device functions. + int device_func_counter_{0}; std::vector device_funcs_; std::unordered_map handle_data_type_; }; -Array UndefinedVars(const Stmt& stmt, const Array& args) { - IRUseDefAnalysis m; - for (Var arg : args) { - m.use_count_[arg.get()] = 0; - } - m(stmt); - return m.undefined_; +PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) { + auto target = func->GetAttr(tvm::attr::kTarget); + CHECK(target.defined()) + << "SplitHostDevice: Require the target attribute"; + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; + + HostDeviceSplitter splitter( + device_mod, target, static_cast(global_symbol)); + + auto* n = func.CopyOnWrite(); + n->body = splitter(std::move(n->body)); + // set the host target to None. + func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr)); + return std::move(func); } -Array SplitHostDevice(LoweredFunc func) { - return HostDeviceSplitter().Split(func); + + +namespace transform { + +Pass SplitHostDevice() { + auto pass_func = [](IRModule m, PassContext ctx) { + IRModuleNode* mptr = m.CopyOnWrite(); + std::vector > updates; + + for (const auto& kv : mptr->functions) { + if (auto* n = kv.second.as()) { + PrimFunc func = GetRef(n); + auto updated_func = SplitHostDevice(std::move(func), mptr); + updates.push_back({kv.first, updated_func}); + } + } + + for (const auto& pair : updates) { + mptr->Add(pair.first, pair.second, true); + } + return m; + }; + + return tvm::transform::CreateModulePass( + pass_func, 0, "tir.SplitHostDevice", {}); } +TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice") +.set_body_typed(SplitHostDevice); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index fad423392937..1ece078e6c3c 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -218,26 +218,19 @@ Stmt InferFragment(Stmt stmt) { return stmt; } -LoweredFunc InferFragment(LoweredFunc f) { - CHECK_NE(f->func_type, kHostFunc); - auto n = make_object(*f.operator->()); - n->body = InferFragment(f->body); - return LoweredFunc(n); -} - namespace transform { -Pass InferFragement() { +Pass InferFragment() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = InferFragment(std::move(n->body)); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.InferFragement", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InferFragement") -.set_body_typed(InferFragement); +TVM_REGISTER_GLOBAL("tir.transform.InferFragment") +.set_body_typed(InferFragment); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index b631a6200d47..f464af655a15 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -374,13 +374,6 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) { return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt)); } -LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) { - CHECK_NE(f->func_type, kHostFunc); - auto n = make_object(*f.operator->()); - n->body = ThreadSync(f->body, storage_scope); - return LoweredFunc(n); -} - namespace transform { Pass ThreadSync(std::string storage_scope) { diff --git a/tests/python/unittest/test_tir_pass_split_host_device.py b/tests/python/unittest/test_tir_analysis_usedef.py similarity index 98% rename from tests/python/unittest/test_tir_pass_split_host_device.py rename to tests/python/unittest/test_tir_analysis_usedef.py index 09f7740df9c9..449a4626926d 100644 --- a/tests/python/unittest/test_tir_pass_split_host_device.py +++ b/tests/python/unittest/test_tir_analysis_usedef.py @@ -28,7 +28,7 @@ def test_loop_dependent_allocate(): s[AA].compute_at(s[C], s[C].op.axis[0]) # this line should fail due to IRUseDefAnalysis sees an allocate statement # referencing undefined variable - tvm.lower(s, [A,C]) + tvm.lower(s, [A, C]) if __name__ == "__main__": test_loop_dependent_allocate() diff --git a/tests/python/unittest/test_tir_pass_inject_double_buffer.py b/tests/python/unittest/test_tir_pass_inject_double_buffer.py index 0fe3f614796b..94e29c68d930 100644 --- a/tests/python/unittest/test_tir_pass_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_pass_inject_double_buffer.py @@ -41,7 +41,9 @@ def test_double_buffer(): assert isinstance(stmt.body.body, tvm.tir.Allocate) assert stmt.body.body.extents[0].value == 2 f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) - f = tvm.tir.ir_pass.ThreadSync(f, "shared") + mod = tvm.testing.LoweredFuncsToIRModule([f]) + f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] + count = [0] def count_sync(op): if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": diff --git a/tests/python/unittest/test_tir_pass_storage_flatten.py b/tests/python/unittest/test_tir_pass_storage_flatten.py index e8a78cbc5209..dbfcd20f0843 100644 --- a/tests/python/unittest/test_tir_pass_storage_flatten.py +++ b/tests/python/unittest/test_tir_pass_storage_flatten.py @@ -93,7 +93,10 @@ def test_flatten_double_buffer(): assert isinstance(stmt.body.body, tvm.tir.Allocate) assert stmt.body.body.extents[0].value == 2 f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) - f = tvm.tir.ir_pass.ThreadSync(f, "shared") + f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) + mod = tvm.testing.LoweredFuncsToIRModule([f]) + f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] + count = [0] def count_sync(op): if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index 66d3cfb1875b..167899a46838 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -33,16 +33,15 @@ def test_lower_warp_mem(): xo, xi = s[AA].split(s[AA].op.axis[0], 32) s[AA].bind(xi, tx) - f = tvm.lower(s, [A, B]) - fhost, fdevice = tvm.tir.ir_pass.SplitHostDevice(f) - - # temp adapter to convert loweredFunc to IRModule - # to test passes in the new style. - fname = fdevice.name - mod = tvm.testing.LoweredFuncsToIRModule([fdevice]) cuda_target = tvm.target.create("cuda") assert cuda_target.thread_warp_size == 32 - mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target)) + f = tvm.lower(s, [A, B], name="f") + + + mod = tvm.testing.LoweredFuncsToIRModule([f]) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) + fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] + mod = tvm.IRModule.from_expr(fdevice) fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"] assert(fdevice.body.body.value.value == "local") assert(fdevice.body.body.body.extents[0].value == 2) diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index e692e23b0878..6c9e7f9b76b7 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -38,13 +38,13 @@ def test_thread_storage_sync(): A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True) - flist = tvm.tir.ir_pass.SplitHostDevice(f) - f = flist[1] - fname = f.name - mod = tvm.testing.LoweredFuncsToIRModule([f]) + cuda_target = tvm.target.create("cuda") + mod = tvm.testing.LoweredFuncsToIRModule([f]) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) + fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"] + mod = tvm.IRModule.from_expr(fdevice) cuda_target = tvm.target.create("cuda") - mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target)) f = tvm.tir.transform.ThreadSync("shared")(mod)["main"] body_list = tvm.tir.stmt_list(f.body.body.body.body) assert(body_list[1].value.name == "tvm_storage_sync")