From c4110c60482f3f6bfa06a16c62dcc5df17a9652c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 21 Jan 2021 18:21:57 -0800 Subject: [PATCH] Refactor the compile engine into a cleaner interface. Duplicate the CompileEngine interface. Refactor the graph_runtime_codegen to invoke the new LowerTE pass More changes Things appear to be working Some tracing to get Relay code to flow through too. Disable some assertions as exp. Tweak printing for now Fix a few bugs: (#13) 1. Don't add relay main function to list of lowered TIR functions 2. Don't skip visiting call to relay function in graph runtime codegen Remove debug prints. Start refactoring Split out shared data structures Fix implicit duplicate decl of IsDynamic Clean up handling of name + global prim fn Clean up the code and debug issue introduced by previous hack Clean up the debugging Do C++ lint clean up --- src/driver/driver_api.cc | 9 +- src/relay/backend/compile_engine.cc | 652 +---------------- src/relay/backend/compile_engine.h | 201 +----- src/relay/backend/graph_executor_codegen.cc | 247 +++---- src/relay/backend/graph_plan_memory.cc | 28 +- src/relay/backend/interpreter.cc | 3 +- src/relay/backend/te_compiler.cc | 372 ++++++++++ src/relay/backend/te_compiler.h | 134 ++++ src/relay/backend/te_compiler_cache.cc | 681 ++++++++++++++++++ src/relay/backend/te_compiler_cache.h | 249 +++++++ src/relay/backend/vm/compiler.cc | 7 +- src/relay/ir/function.cc | 9 +- .../auto_scheduler_layout_rewrite.cc | 2 +- src/relay/transforms/memory_alloc.cc | 1 + src/relay/transforms/type_infer.cc | 9 +- src/runtime/graph_executor/graph_executor.cc | 2 + src/target/llvm/llvm_module.cc | 10 +- .../relay/test_backend_graph_executor.py | 18 +- 18 files changed, 1632 insertions(+), 1002 deletions(-) create mode 100644 src/relay/backend/te_compiler.cc create mode 100644 src/relay/backend/te_compiler.h create mode 100644 src/relay/backend/te_compiler_cache.cc create mode 100644 src/relay/backend/te_compiler_cache.h diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f30cecbf7f05..1d5364f658db 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -246,14 +246,17 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target } if (target->kind->device_type == kDLCPU && target_host == target) { - ICHECK(mdevice->functions.empty()) << "No device code should be generated when target " - << "and host_target are both llvm target." - << "\n"; + // We need to relax this check for just TIR functions. + // ICHECK(mdevice->functions.empty()) << "No device code should be generated when target " + // << "and host_target are both llvm target." + // << "\n"; } return {mhost, mdevice}; } +// Can we make this take one annotated IRModule? +// // Build for heterogeneous execution. runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { auto pass_ctx = transform::PassContext::Current(); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 5e3b66b3ae15..7a09d8182f23 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -46,569 +46,14 @@ #include #include "../transforms/pass_utils.h" +#include "te_compiler_cache.h" #include "utils.h" namespace tvm { namespace relay { -TVM_REGISTER_NODE_TYPE(LoweredOutputNode); -TVM_REGISTER_NODE_TYPE(CachedFuncNode); -TVM_REGISTER_NODE_TYPE(CCacheKeyNode); -TVM_REGISTER_NODE_TYPE(CCacheValueNode); TVM_REGISTER_OBJECT_TYPE(CompileEngineNode); -LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation impl) { - auto n = make_object(); - n->outputs = std::move(outputs); - n->implementation = std::move(impl); - data_ = std::move(n); -} - -CCacheKey::CCacheKey(Function source_func, Target target) { - auto n = make_object(); - n->source_func = std::move(source_func); - n->target = std::move(target); - data_ = std::move(n); -} - -Array GetShape(const Array& shape) { - // for now, we always use int32 shape when possible - // even if the result of shape inference becomes int64. - Array res; - for (IndexExpr val : shape) { - const int64_t* pval = tir::as_const_int(val); - if (pval != nullptr) { -#ifndef TVM_INDEX_DEFAULT_I64 - ICHECK_LE(pval[0], std::numeric_limits::max()); - ICHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(IntImm(DataType::Int(32), *pval)); -#else - res.push_back(val); -#endif // TVM_INDEX_DEFAULT_I64 - } else if (val->IsInstance()) { - res.push_back(val.as()->ToVar()); - } else { - res.push_back(val); - } - } - return res; -} - -// The getter to get schedule from compile engine. -// Get schedule from functor. -class ScheduleGetter : public backend::MemoizedExprTranslator> { - public: - explicit ScheduleGetter(Target target) - : target_(target), device_copy_op_(Op::Get("device_copy")) { - // Whether to use auto_scheduler schedule. - use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); - } - - CachedFunc Create(const Function& prim_func) { - auto cache_node = make_object(); - cache_node->target = target_; - for (Var param : prim_func->params) { - Array inputs; - if (const auto* ttype = param->checked_type().as()) { - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - cache_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - // TODO(@icemelon): Allow recursive tuple - ICHECK(ttype != nullptr); - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - cache_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } - } - memo_[param] = inputs; - } - readable_name_stream_ << "fused"; - cache_node->outputs = this->VisitExpr(prim_func->body); - auto candidate_name = readable_name_stream_.str(); - constexpr static size_t kMaxFuncNameLength = 80; - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - cache_node->func_name = candidate_name; - ICHECK(anchor_op_.defined()); - // Fusion over tupled results may leave identity relationships - // between inputs and outputs, and those should not be scheduled. - // Hence schedule only non PlaceholderOp outputs. - tvm::Array tensor_outs; - for (const auto& tensor : cache_node->outputs) { - if (!tensor->op.as()) { - tensor_outs.push_back(tensor); - } - } - - te::Schedule schedule; - // No need to register schedule for device copy op. - if (anchor_attrs_.as() == nullptr) { - if (use_auto_scheduler_) { - const auto* fauto_schedule = - runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); - ICHECK(fauto_schedule != nullptr) - << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; - ObjectRef obj = (*fauto_schedule)(String(cache_node->func_name), tensor_outs); - if (obj.defined()) { - schedule = Downcast(obj); - } - } - - // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule. - if (!schedule.defined()) { - ICHECK(anchor_implementation_.defined()); - schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); - } - for (const auto& scalar : scalars_) { - if (schedule->Contain(scalar)) { - schedule[scalar].compute_inline(); - } - } - } - cache_node->schedule = std::move(schedule); - return CachedFunc(cache_node); - } - - Array VisitExpr_(const VarNode* op) final { - LOG(FATAL) << "Free variable " << op->name_hint(); - return {}; - } - - Array VisitExpr_(const ConstantNode* op) final { - using tir::make_const; - ICHECK(op->is_scalar()); - void* data = op->data->data; - DataType dtype = DataType(op->data->dtype); - auto value = te::compute( - {}, - [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, - "compile_engine_const", topi::kBroadcast); - scalars_.push_back(value->op); - return {value}; - } - - Array VisitExpr_(const CallNode* call_node) final { - static auto fpattern = Op::GetAttrMap("TOpPattern"); - static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); - ICHECK(flower_call) << "relay.backend.lower_call is not registered."; - - Array inputs; - int count_tuple = 0; - for (Expr arg : call_node->args) { - if (arg->checked_type().as()) { - ++count_tuple; - } - for (te::Tensor tensor : VisitExpr(arg)) { - inputs.push_back(tensor); - } - } - if (count_tuple) { - ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; - } - - ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; - Op op = Downcast(call_node->op); - - Array outputs; - OpImplementation impl; - // Skip fcompute for device copy operators as it is not registered. - if (op == device_copy_op_) { - const auto* copy_input = inputs[0].operator->(); - outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); - } else { - LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); - outputs = lowered_out->outputs; - impl = lowered_out->implementation; - } - - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern > anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - anchor_implementation_ = impl; - } - if (outputs.size() != 1) { - const auto* tuple_type = call_node->checked_type().as(); - ICHECK(tuple_type) << "Expect output to be a tuple type"; - ICHECK_EQ(tuple_type->fields.size(), outputs.size()); - } - // Set the name to `__copy`. It will be detected in graph executor to perform - // data copy across devices. - if (op == device_copy_op_) { - readable_name_stream_.str(std::string()); - readable_name_stream_ << "__copy"; - } else { - readable_name_stream_ << '_' << op->name; - } - return outputs; - } - - Array VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "Do not support sub function"; - return Array(); - } - - Array VisitExpr_(const LetNode* op) final { - Array val = VisitExpr(op->value); - ICHECK(!memo_.count(op->var)); - memo_[op->var] = val; - return VisitExpr(op->body); - } - - Array VisitExpr_(const TupleNode* op) final { - Array fields; - for (Expr field : op->fields) { - ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; - Array res = VisitExpr(field); - ICHECK_EQ(res.size(), 1); - fields.push_back(res[0]); - } - return fields; - } - - Array VisitExpr_(const TupleGetItemNode* op) final { - const auto* tuple_type = op->tuple->type_as(); - Array tuple = VisitExpr(op->tuple); - ICHECK_EQ(tuple_type->fields.size(), tuple.size()); - ICHECK_GE(op->index, 0); - ICHECK_LT(static_cast(op->index), tuple.size()); - return {tuple[op->index]}; - } - - private: - tvm::Target target_; - Op anchor_op_; - Attrs anchor_attrs_; - int anchor_op_pattern_{-1}; - OpImplementation anchor_implementation_; - std::ostringstream readable_name_stream_; - Array scalars_; - bool use_auto_scheduler_; - // Cache device copy op for equivalence checking to reduce registry lookup - // overhead for each invocation of call node when retrieving schedules. - const Op& device_copy_op_; -}; - -/*! - * \brief Create schedule for target. - * \param source_func The primitive function to be lowered. - * \param target The target we want to create schedule for. - * \return Pair of schedule and cache. - * The funcs field in cache is not yet populated. - */ -CachedFunc CreateSchedule(const Function& source_func, const Target& target) { - return ScheduleGetter(target).Create(source_func); -} - -// Creates shape function from functor. -class MakeShapeFunc : public backend::MemoizedExprTranslator> { - public: - MakeShapeFunc() {} - - std::pair Create(const Function& prim_func) { - for (auto param : prim_func->params) { - param_states_[param] = kNoNeed; - Array data_inputs; - Array shape_inputs; - - auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) { - // Add data placeholder - Shape shape = GetShape(ttype->shape); - tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype); - data_inputs.push_back(data_tensor); - // Add shape placeholder - int64_t ndim = shape.size(); - Shape sshape; - if (ndim > 0) { - sshape.push_back(tvm::Integer(ndim)); - } - tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64)); - shape_inputs.push_back(shape_tensor); - }; - - if (const auto* ttype = param->checked_type().as()) { - add_placeholder(ttype); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - // TODO(@icemelon): Support recursive tuple - ICHECK(tuple_type); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - ICHECK(ttype); - add_placeholder(ttype); - } - } - param_data_[param] = data_inputs; - param_shapes_[param] = shape_inputs; - } - readable_name_stream_ << "shape_func"; - auto cache_node = make_object(); - cache_node->outputs = VisitExpr(prim_func->body); - auto candidate_name = readable_name_stream_.str(); - constexpr static size_t kMaxFuncNameLength = 80; - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - cache_node->func_name = candidate_name; - - // set inputs - for (auto param : prim_func->params) { - int state = param_states_[param]; - cache_node->shape_func_param_states.push_back(IntImm(DataType::Int(32), state)); - if (state & kNeedInputData) { - for (auto t : param_data_[param]) { - cache_node->inputs.push_back(t); - } - } - if (state & kNeedInputShape) { - for (auto t : param_shapes_[param]) { - cache_node->inputs.push_back(t); - } - } - } - - CachedFunc cfunc(cache_node); - // generate schedule for shape func - Array out_ops; - for (auto t : cache_node->outputs) { - out_ops.push_back(t->op); - } - auto schedule = te::create_schedule(out_ops); - tvm::te::AutoInlineInjective(schedule); - for (const auto& scalar : scalars_) { - auto scalar_op = scalar->op; - if (schedule->Contain(scalar_op)) { - schedule[scalar_op].compute_inline(); - } - } - return std::make_pair(schedule, cfunc); - } - - Array VisitExpr(const Expr& expr) final { - if (expr.as()) { - // Do not memoize vars because shape functions could use either the data - // or the shape of a var each time. - return ExprFunctor::VisitExpr(expr); - } - // For other case, do memoized visit - return backend::MemoizedExprTranslator>::VisitExpr(expr); - } - - Array VisitExpr_(const VarNode* var_node) final { - auto var = GetRef(var_node); - auto it = param_states_.find(var); - if (it == param_states_.end()) { - LOG(FATAL) << "Free variable " << var->name_hint(); - return {}; - } else { - ICHECK(data_dependents_per_input_.size()); - auto data_dependent = data_dependents_per_input_.back(); - if (data_dependent) { - param_states_[var] |= kNeedInputData; - return param_data_[var]; - } else { - param_states_[var] |= kNeedInputShape; - return param_shapes_[var]; - } - } - } - - Array VisitExpr_(const ConstantNode* op) final { - using tir::make_const; - ICHECK(data_dependents_per_input_.size()); - bool data_dependent = data_dependents_per_input_.back(); - if (!op->is_scalar()) { - // This is a constant weight, extract the shape of the weight tensor. - // This can not be data dependent. - CHECK(!data_dependent); - auto ttype = op->checked_type().as(); - int ndim = static_cast(ttype->shape.size()); - Array out_shape{ndim}; - te::Tensor value = tvm::te::compute( - out_shape, - [&](const Array& indices) { - auto idx = indices[0]; - PrimExpr ret = make_const(DataType::Int(64), 0); - for (int i = 0; i < ndim; i++) { - ret = tvm::if_then_else(idx == i, ttype->shape[i], ret); - } - return ret; - }, - "shape_const", topi::kBroadcast); - scalars_.push_back(value); - return {value}; - } - if (data_dependent) { - void* data = op->data->data; - DataType dtype = DataType(op->data->dtype); - auto value = tvm::te::compute( - {}, - [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, - "data_const", topi::kBroadcast); - scalars_.push_back(value); - return {value}; - } else { - auto value = tvm::te::compute( - {}, [&](const Array&) { return tir::make_const(DataType::Int(64), 0); }, - "shape_const", topi::kBroadcast); - scalars_.push_back(value); - return {value}; - } - } - - Array VisitExpr_(const CallNode* call_node) final { - static auto fshape_func = Op::GetAttrMap("FShapeFunc"); - static auto tshape_data_dependent = Op::GetAttrMap("TShapeDataDependent"); - ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; - Op op = Downcast(call_node->op); - ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back()) - << "Error in op fusion: output of the shape func is fed to a " - << "data-dependent shape func"; - ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name; - ICHECK_GT(tshape_data_dependent.count(op), 0) - << "Internal error, cannot find TShapeDataDependent for " << op->name; - - Array dep_spec = tshape_data_dependent[op]; - if (dep_spec.size() == 1) { - // This is for cases when data dependence is specified per op - // Replicate 0 or 1 flag to all arguments - for (size_t i = 1; i < call_node->args.size(); ++i) { - dep_spec.push_back(dep_spec[0]); - } - } - - // Visit all inputs - Array inputs; - int count_tuple = 0; - for (size_t i = 0; i < call_node->args.size(); ++i) { - Expr arg = call_node->args[i]; - if (arg->checked_type().as()) { - ++count_tuple; - } - data_dependents_per_input_.push_back(dep_spec[i]->value != 0); - for (te::Tensor tensor : VisitExpr(arg)) { - inputs.push_back(tensor); - } - data_dependents_per_input_.pop_back(); - } - if (count_tuple) { - ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; - } - // Get output ndims - auto ret_type = call_node->checked_type(); - Array out_ndims; - if (const auto* ttype = ret_type.as()) { - out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); - } else { - auto rtype = ret_type.as(); - // TODO(@icemelon): Allow recursive tuple - ICHECK(rtype); - for (size_t i = 0; i < rtype->fields.size(); ++i) { - auto ttype = rtype->fields[i].as(); - ICHECK(ttype); - out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); - } - } - // Call shape function - auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims); - readable_name_stream_ << "_" << op->name; - return outputs; - } - - Array VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "Do not support sub function"; - return Array(); - } - - Array VisitExpr_(const LetNode* op) final { - Array val = VisitExpr(op->value); - ICHECK(!memo_.count(op->var)); - memo_[op->var] = val; - return VisitExpr(op->body); - } - - Array VisitExpr_(const TupleNode* op) final { - Array fields; - for (Expr field : op->fields) { - ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; - Array res = VisitExpr(field); - ICHECK_EQ(res.size(), 1); - fields.push_back(res[0]); - } - return fields; - } - - Array VisitExpr_(const TupleGetItemNode* op) final { - Array input_shapes = VisitExpr(op->tuple); - Array out; - out.push_back(input_shapes[op->index]); - return out; - } - - private: - /*! \brief String stream for function name */ - std::ostringstream readable_name_stream_; - /*! \brief Map from parameter to its shape function usage state */ - std::unordered_map param_states_; - /*! \brief Map from parameter to list of data placeholder */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_data_; - /*! \brief Map from parameter to list of shape placeholder */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_shapes_; - /*! \brief Stack of data dependencies for shape function, specified per each op input */ - std::vector data_dependents_per_input_; - /*! \brief Scalars used in the shape function */ - Array scalars_; -}; - class CompileEngineImpl : public CompileEngineNode { public: // Lower the function. @@ -618,14 +63,8 @@ class CompileEngineImpl : public CompileEngineNode { PackedFunc JIT(const CCacheKey& key) final { CCacheValue value = LowerInternal(key); if (value->packed_func != nullptr) return value->packed_func; - // build the function. - tvm::runtime::Module m; - if (const auto* f = runtime::Registry::Get("relay.backend.build")) { - m = (*f)(value->cached_func->funcs, key->target); - } else { - m = build(value->cached_func->funcs, key->target, Target(nullptr)); - } - value->packed_func = m.GetFunction(value->cached_func->func_name); + auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); + value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); return value->packed_func; } @@ -731,50 +170,50 @@ class CompileEngineImpl : public CompileEngineNode { // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. if (key->source_func->GetAttr(attr::kCompiler).defined()) { - auto cache_node = make_object(); + auto ir_module = IRModule(); const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; - cache_node->func_name = std::string(name_node.value()); - cache_node->target = Target("ext_dev"); - cache_node->funcs->Add(GlobalVar(cache_node->func_name), key->source_func); - value->cached_func = CachedFunc(cache_node); + auto func_name = std::string(name_node.value()); + auto target = Target("ext_dev"); + auto global_var = GlobalVar(func_name); + global_var->checked_type_ = key->source_func->checked_type(); + ir_module->Add(global_var, key->source_func); + value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); return value; } + // Enforce use the target. With target_scope(key->target); ICHECK(!value->cached_func.defined()); - auto cfunc = CreateSchedule(key->source_func, key->target); - auto cache_node = make_object(*(cfunc.operator->())); + auto cfunc = PrimFuncFor(key->source_func, key->target, + [&](std::string name) { return GetUniqueName(name, name_map_); }); // Skip lowering for device copy node. const Expr body = (key->source_func)->body; if (const CallNode* call_node = body.as()) { if (call_node->attrs.as()) { - value->cached_func = CachedFunc(cache_node); + value->cached_func = cfunc; return value; } } - cache_node->func_name = GetUniqueName(cache_node->func_name); // NOTE: array will copy on write. - Array all_args = cache_node->inputs; - for (te::Tensor arg : cache_node->outputs) { + Array all_args = Array(cfunc->inputs); + for (te::Tensor arg : cfunc->outputs) { all_args.push_back(arg); } - // lower the function - if (const auto* f = runtime::Registry::Get("relay.backend.lower")) { - cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func); - } else { - using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); - std::unordered_map binds; - cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds); - } - value->cached_func = CachedFunc(cache_node); + using tvm::transform::PassContext; + With fresh_pass_ctx_scope(PassContext::Create()); + + std::unordered_map binds; + auto func_name = cfunc->prim_fn_var->name_hint; + cfunc->funcs->Update(tvm::lower(cfunc->schedule, all_args, func_name, binds)); + value->cached_func = cfunc; return value; } + // implement lowered shape func CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { std::lock_guard lock(mutex_); @@ -793,47 +232,14 @@ class CompileEngineImpl : public CompileEngineNode { With target_scope(key->target); ICHECK(!value->cached_func.defined()); - auto spair = MakeShapeFunc().Create(key->source_func); - auto cache_node = make_object(*(spair.second.operator->())); - cache_node->func_name = GetUniqueName(cache_node->func_name); - cache_node->target = key->target; - - Array all_args = cache_node->inputs; - for (te::Tensor arg : cache_node->outputs) { - all_args.push_back(arg); - } - - using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); + auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { + return GetUniqueName(name, name_map_); + }); - std::unordered_map binds; - cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds); - value->cached_func = CachedFunc(cache_node); + value->cached_func = cached_func; return value; } - /*! - * \brief Get unique name from name. - * \param name The orginal name. - * \return Updated name which is unique. - */ - std::string GetUniqueName(std::string name) { - for (size_t i = 0; i < name.length(); ++i) { - if (name[i] == '.') name[i] = '_'; - } - while (true) { - auto it = name_map_.find(name); - if (it == name_map_.end()) { - name_map_[name] = 1; - return name; - } else { - std::ostringstream os; - os << name << "_" << it->second; - ++(it->second); - name = os.str(); - } - } - return name; - } + /*! \brief compiler cache lock*/ std::mutex mutex_; /*! \brief internal name map to get an unique name */ diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index d7628e7a5bdf..cb3097b69a97 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -19,8 +19,12 @@ /*! * \file relay/backend/compile_engine.h - * \brief Internal compialtion engine handle function cache. - * and interface to low level code generation. + * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. + * + * This layer represents the older design of the Relay compilation flow and is being deprecated + * in favor of te_compiler.h which is a migration step towards a standard pass based lowering of + * Relay functions. + * */ #ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ @@ -36,157 +40,12 @@ #include #include +#include "te_compiler_cache.h" + namespace tvm { namespace relay { -/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */ -enum ShapeFuncParamState { - kNoNeed = 0, - kNeedInputData = 1, - kNeedInputShape = 2, - kNeedBoth = 3, -}; - -struct LoweredOutputNode : public Object { - /*! \brief The outputs to the function */ - tvm::Array outputs; - /*! \brief The implementation used to compute the output */ - OpImplementation implementation; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("outputs", &outputs); - v->Visit("implementation", &implementation); - } - - static constexpr const char* _type_key = "relay.LoweredOutput"; - TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object); -}; - -class LoweredOutput : public ObjectRef { - public: - TVM_DLL LoweredOutput(tvm::Array outputs, OpImplementation impl); - - TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode); -}; - -/*! \brief Node container to represent a cached function. */ -struct CachedFuncNode : public Object { - /* \brief compiled target */ - tvm::Target target; - /*! \brief Function name */ - std::string func_name; - /* \brief The inputs to the function */ - tvm::Array inputs; - /* \brief The outputs to the function */ - tvm::Array outputs; - /*! \brief The schedule to the function */ - te::Schedule schedule; - /*! \brief The lowered functions to support the function. */ - IRModule funcs = IRModule(Map({})); - - /*! \brief Parameter usage states in the shape function. */ - tvm::Array shape_func_param_states; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("target", &target); - v->Visit("func_name", &func_name); - v->Visit("inputs", &inputs); - v->Visit("outputs", &outputs); - v->Visit("schedule", &schedule); - v->Visit("funcs", &funcs); - v->Visit("shape_func_param_states", &shape_func_param_states); - } - - static constexpr const char* _type_key = "relay.CachedFunc"; - TVM_DECLARE_FINAL_OBJECT_INFO(CachedFuncNode, Object); -}; - -class CachedFunc : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode); -}; - -class CCacheKey; -/*! \brief Compile cache key */ -class CCacheKeyNode : public Object { - public: - /*! \brief The source function to be lowered. */ - Function source_func; - /*! \brief The hardware target.*/ - Target target; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("source_func", &source_func); - v->Visit("target", &target); - } - /*! \return The hash value of CCacheKey. */ - inline size_t Hash() const; - /*! - * \brief check content equality - * \param other The other value. - * \return The result of equality check. - */ - inline bool Equal(const CCacheKeyNode* other) const; - - static constexpr const char* _type_key = "relay.CCacheKey"; - TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object); - - private: - /*! - * \brief internal cached hash value. - */ - mutable size_t hash_{0}; -}; - -/*! \brief cache entry used in compile engine */ -class CCacheKey : public ObjectRef { - public: - CCacheKey() {} - explicit CCacheKey(ObjectPtr n) : ObjectRef(n) {} - - /*! - * \brief The constructor - * \param source_func The source function. - * \param target The target device. - */ - TVM_DLL CCacheKey(Function source_func, Target target); - - const CCacheKeyNode* operator->() const { return static_cast(get()); } - // comparator - inline bool operator==(const CCacheKey& other) const { - ICHECK(defined() && other.defined()); - return (*this)->Equal(other.operator->()); - } - using ContainerType = CCacheKeyNode; -}; - -/*! \brief Node container for compile cache. */ -class CCacheValueNode : public Object { - public: - /*! \brief The corresponding function */ - CachedFunc cached_func; - /*! \brief Result of Packed function generated by JIT */ - PackedFunc packed_func; - /*! \brief usage statistics */ - int use_count{0}; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("cached_func", &cached_func); - v->Visit("use_count", &use_count); - } - static constexpr const char* _type_key = "relay.CCacheValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(CCacheValueNode, tvm::Object); -}; - -/*! \brief cache entry used in compile engine */ -class CCacheValue : public ObjectRef { - public: - CCacheValue() {} - explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} - CCacheValueNode* operator->() { return static_cast(get_mutable()); } - const CCacheValueNode* operator->() const { return static_cast(get()); } - using ContainerType = CCacheValueNode; -}; +using namespace tvm::relay::tec; /*! * \brief Backend compilation engine for @@ -241,49 +100,7 @@ class CompileEngine : public ObjectRef { TVM_DLL static CompileEngine& Global(); }; -/*! - * \brief Create schedule for target. - * \param source_func The primitive function to be lowered. - * \param target The target we want to create schedule for. - * \return Pair of schedule and cache. - * The funcs field in cache is not yet populated. - */ -CachedFunc CreateSchedule(const Function& source_func, const Target& target); - -/*! - * \brief Check if the type is dynamic. - * \param ty The type to be checked. - * \return The result. - */ -bool IsDynamic(const Type& ty); - -// implementations -inline size_t CCacheKeyNode::Hash() const { - if (hash_ != 0) return hash_; - // do structral hash, avoid 0. - hash_ = tvm::StructuralHash()(this->source_func); - hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); - if (hash_ == 0) hash_ = 1; - return hash_; -} - -inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { - if (Hash() != other->Hash()) return false; - return this->target->str() == other->target->str() && - tvm::StructuralEqual()(this->source_func, other->source_func); -} - } // namespace relay } // namespace tvm -namespace std { -// overload hash -template <> -struct hash<::tvm::relay::CCacheKey> { - size_t operator()(const ::tvm::relay::CCacheKey& key) const { - ICHECK(key.defined()); - return key->Hash(); - } -}; -} // namespace std #endif // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index d92d4d2077f7..84b5cba23de6 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -36,10 +36,15 @@ #include #include "compile_engine.h" +#include "te_compiler.h" #include "utils.h" namespace tvm { namespace relay { + +/// TODO(@jroesch, @chris): declare directly elsewhere +Map>> GraphPlanMemory(const Function& func); + namespace backend { class GraphNode; @@ -176,11 +181,12 @@ class GraphOpNode : public GraphNode { const std::string op_type_name_{"tvm_op"}; }; -/*! \brief Code generator for graph executor */ +/*! \brief Code generator for the graph executor, produces a module containing the graph JSON, + * module, and parameters. + */ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> { public: GraphExecutorCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) { - compile_engine_ = CompileEngine::Global(); targets_ = targets; } @@ -271,16 +277,52 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator(device_type[0]->value); + device_context_map.insert({expr, ctx}); + } + + // todo map targets down + auto lowered_module = tec::LowerTE(mod, targets_, device_context_map); + + auto main_module = lowered_module.main_module; + main_module = relay::transform::InferType()(main_module); + relay::Function main_func = Downcast(main_module->Lookup("main")); + + // Now that we have lowered all operators to TIR code, we can proceed with compilation. + storage_device_map_ = GraphPlanMemory(main_func); + // First we convert all the parameters into input nodes. - for (auto param : func->params) { + for (auto param : main_func->params) { auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs()); var_map_[param.get()] = AddNode(node_ptr, param); } - heads_ = VisitExpr(func->body); + + heads_ = VisitExpr(main_func->body); std::ostringstream os; + dmlc::JSONWriter writer(&os); GetJSON(&writer); LoweredOutput ret; @@ -292,16 +334,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator(param_storage_ids_[param.first]), param.second))); } - for (auto& kv : lowered_funcs_) { - if (ret.lowered_funcs.count(kv.first) == 0) { - ret.lowered_funcs.Set(kv.first, IRModule(Map({}))); - } - auto& mod = ret.lowered_funcs[kv.first]; - mod->Update(kv.second); - ret.lowered_funcs.Set(kv.first, mod); - } - ret.external_mods = compile_engine_->LowerExternalFunctions(); ret.function_metadata = std::move(function_metadata_); + ret.lowered_funcs = lowered_module.per_target_module; + ret.external_mods = lowered_module.external_mods; return ret; } @@ -432,157 +467,47 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator(op)); } - bool ShareSameStorage(const Expr& lhs, const Expr& rhs) { - auto lit = storage_device_map_.find(lhs); - auto rit = storage_device_map_.find(rhs); - ICHECK(lit != storage_device_map_.end()); - ICHECK(rit != storage_device_map_.end()); - int64_t lhs_storage_id = ((*lit).second)[0][0]->value; - int64_t rhs_storage_id = ((*rit).second)[0][0]->value; - return lhs_storage_id == rhs_storage_id; - } + std::vector VisitExpr_(const CallNode* call_node) override { + relay::Call call = GetRef(call_node); + if (auto global_node = call->op.as()) { - /*! - * \brief Obtain the Target from the device type. - * If homogenous compilation, this will return the only target. - * If heteregenous compilation, this will select associated using the targets_ Map. - * - * \param dev_type - * \return Target - */ - Target GetTargetFromInteger(int64_t dev_type) { - if (targets_.size() == 1) { - // homogeneous execution. - const auto& it = targets_.begin(); - return (*it).second; - } else { - // heterogeneous execution. - std::string call_dev_name; - if (dev_type == 0) { - call_dev_name = "llvm"; - } else { - call_dev_name = runtime::DeviceName(dev_type); - } - if (targets_.count(dev_type) == 0) { - LOG(FATAL) << "No target is provided for device " << call_dev_name; - } - return targets_[dev_type]; - } - } + auto prim_fn_name = global_node->name_hint; - /*! - * \brief Update the function metadata for a given cached function and its relay - * primitive function. - * - * \param cfunc The cached function as provided the by the compile engine - * \param relay_func The source relay primitive function - * \param relay_target The target associated with relay primitive function - */ - void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func, - const Target& relay_target) { - auto fi_node = make_object(); - for (const auto& kv : cfunc->funcs->functions) { - auto primfunc = Downcast(kv.second); - auto workspace_byte_alignment = relay_target->GetAttr("workspace-byte-alignment") - .value_or(tvm::runtime::kDefaultWorkspaceAlignment); - Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment); - Target primfunc_target = relay_target; - if (primfunc->attrs->dict.count("target")) { - primfunc_target = Downcast(primfunc->attrs->dict["target"]); - } - fi_node->workspace_sizes.Set(primfunc_target, workspace_size); - // Calculating size for I/O - for (auto const& param : primfunc->params) { - auto p_shape = primfunc->buffer_map[param]->shape; - int num_of_elements = 1; - for (const auto& dim_index_expr : p_shape) { - if (dim_index_expr->IsInstance()) { - num_of_elements *= dim_index_expr.as()->value; - } else { - // If shape is dynamic, we cannot calculate workspace in compile time. - num_of_elements = 0; - } - } - int element_size = primfunc->buffer_map[param]->dtype.bytes(); - fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements); - } - fi_node->constant_sizes.Set(primfunc_target, 0); - fi_node->tir_primfuncs.Set(primfunc_target, primfunc); - fi_node->relay_primfuncs.Set(primfunc_target, relay_func); - } - function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node)); - } + Target target; - std::vector VisitExpr_(const CallNode* op) override { - Expr expr = GetRef(op); - Function func; - if (op->op.as()) { - LOG(FATAL) << "Operators should be transformed away; try applying" - << "the fuse_ops transformation to the expression."; - } else if (op->op.as()) { - LOG(FATAL) << "Not implemented"; - } else if (op->op.as()) { - func = GetRef(op->op.as()); - } else { - LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey(); - } - if (!func->HasNonzeroAttr(attr::kPrimitive)) { - LOG(FATAL) << "TVM only support calls to primitive functions " - << "(i.e functions composed of fusable operator invocations)"; - } + // // Handle external function + // if (func->GetAttr(attr::kCompiler).defined()) { + // UpdateConstants(func, ¶ms_); + // return GraphAddCallNode(call_node, prim_fn_name, prim_fn_name); + // } - // Copy attrs from function into the graph node - // For now we only handle strings - GraphAttrs attrs; - for (auto p : func->attrs->dict) { - if (p.second.as()) { - attrs[p.first] = std::string(Downcast(p.second)); + ICHECK_GE(storage_device_map_.count(call), 0); + auto& device_type = storage_device_map_[call][1]; + auto call_dev_type = device_type[0]->value; + // Normal Relay Function + if (targets_.size() == 1) { + // homogeneous execution. + const auto& it = targets_.begin(); + target = (*it).second; + } else { + // heterogeneous execution. + std::string call_dev_name; + if (call_dev_type == 0) { + call_dev_name = "llvm"; + } else { + call_dev_name = runtime::DeviceName(call_dev_type); + } + if (targets_.count(call_dev_type) == 0) { + LOG(FATAL) << "No target is provided for device " << call_dev_name; + } + target = targets_[call_dev_type]; } - } - auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); - auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); - Target target; - // Handle external function - if (func->GetAttr(attr::kCompiler).defined()) { - target = Target("ext_dev"); - CCacheKey key = (*pf0)(func, target); - CachedFunc ext_func = (*pf1)(compile_engine_, key); - ICHECK(ext_func.defined()) << "External function is not defined."; - UpdateConstants(func, ¶ms_); - return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name, attrs); - } - - // In the current flat memory allocation scenario - // the flat memory allocator can always allocate input - // and output of the reshape to the same memory, we can turn reshape only - // function to a nop. - // - // NOTE that for non-flat memory this is not necessarily true. - // - // TODO(tvm-team) Update checks of flat memory enablement when we support - // opaque-nd memory planning to skip this path. - if (func->HasNonzeroAttr(attr::kReshapeOnly) && ShareSameStorage(expr, op->args[0])) { - return GraphAddCallNode(op, "reshape_nop", "__nop", attrs); - } - - ICHECK_GE(storage_device_map_.count(expr), 0); - auto& device_type = storage_device_map_[expr][1]; - auto call_dev_type = device_type[0]->value; - target = GetTargetFromInteger(call_dev_type); - // Normal Relay Function - - CCacheKey key = (*pf0)(func, target); - CachedFunc lowered_func = (*pf1)(compile_engine_, key); - if (!lowered_funcs_.count(target->str())) { - lowered_funcs_[target->str()] = IRModule(Map({})); + return GraphAddCallNode(call_node, _GetUniqueName(prim_fn_name), prim_fn_name); + } else { + LOG(FATAL) << "BadCase: " << PrettyPrint(call) << std::endl; + return {}; } - lowered_funcs_[target->str()]->Update(lowered_func->funcs); - - // Update function metadata via looking at all primfuncs - UpdateFunctionMetadata(lowered_func, func, target); - return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name, - attrs); } std::vector VisitExpr_(const LetNode* op) override { @@ -730,8 +655,6 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator function_metadata_; /*! \brief name map */ std::unordered_map name_map_; - /*! \brief compile engine */ - CompileEngine compile_engine_; }; class GraphExecutorCodegenModule : public runtime::ModuleNode { diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 351469d6e1ca..b5a81e7ba05a 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -114,7 +114,8 @@ class StorageAllocaBaseVisitor : public ExprVisitor { const std::vector& GetToken(const Expr& expr) { this->VisitExpr(expr); auto it = token_map_.find(expr.operator->()); - ICHECK(it != token_map_.end()); + ICHECK(it != token_map_.end()) + << "Expression: `" << PrettyPrint(expr) << "` not found in storage map."; return it->second; } /*! @@ -166,10 +167,22 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { } void VisitExpr_(const CallNode* op) final { + // temporary hack for change of style. + bool skip_first = false; + if (auto op_node = op->op.as()) { + if (op_node->name == "prim_fn_call") { + skip_first = true; + } + } + // create token for the call node. CreateToken(op, true); // for each input, visit argument token. for (Expr arg : op->args) { + if (skip_first) { + skip_first = false; + continue; + } for (StorageToken* tok : GetToken(arg)) { tok->ref_counter += 1; } @@ -272,9 +285,22 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // The call map void VisitExpr_(const CallNode* op) final { + // temporary hack for change of style. + bool skip_first = false; + if (auto op_node = op->op.as()) { + if (op_node->name == "prim_fn_call") { + skip_first = true; + } + } + std::vector args; // for each input, visit argument token. for (Expr arg : op->args) { + if (skip_first) { + skip_first = false; + continue; + } + for (StorageToken* tok : GetToken(arg)) { args.push_back(tok); } diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index eeba010dc164..53985c78a33c 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -32,6 +32,7 @@ #include #include +#include "../transforms/pass_utils.h" #include "compile_engine.h" namespace tvm { @@ -381,7 +382,7 @@ class Interpreter : public ExprFunctor, } else { m = build(cfunc->funcs, cfunc->target, Target(nullptr)); } - shape_func = m.GetFunction(cfunc->func_name); + shape_func = m.GetFunction(cfunc->prim_fn_var->name_hint); shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); // Get output shapes diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc new file mode 100644 index 000000000000..bcf32f0112a3 --- /dev/null +++ b/src/relay/backend/te_compiler.cc @@ -0,0 +1,372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "te_compiler.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../transforms/pass_utils.h" +#include "te_compiler_cache.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace tec { + +using namespace tvm::relay::transform; + +TVM_REGISTER_OBJECT_TYPE(TECompilerNode); + +class TECompilerImpl : public TECompilerNode { + public: + // Lower the function. + CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; } + + // For now, build one module per function. + PackedFunc JIT(const CCacheKey& key) final { + CCacheValue value = LowerInternal(key); + if (value->packed_func != nullptr) { + return value->packed_func; + } + auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); + value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); + return value->packed_func; + } + + CachedFunc LowerShapeFunc(const CCacheKey& key) final { + return LowerShapeFuncInternal(key)->cached_func; + } + + Map GetLoweredFunctions() { + Map lowered_functions; + for (const auto& it : cache_) { + auto source_func = it.first; + auto lowered_func = it.second; + auto target = source_func->target; + + if (!lowered_functions.count(target->str())) { + lowered_functions.Set(target->str(), IRModule(Map({}))); + } + + lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + } + return lowered_functions; + } + + Array LowerExternalFunctions() { + Array ret; + std::unordered_map cached_symbol; + std::vector cached_ext_funcs; + for (const auto& it : cache_) { + auto src_func = it.first->source_func; + ICHECK(src_func.defined()); + if (src_func->GetAttr(attr::kCompiler).defined()) { + auto code_gen = src_func->GetAttr(attr::kCompiler); + ICHECK(code_gen.defined()) << "No external codegen is set"; + std::string code_gen_name = code_gen.value(); + cached_ext_funcs.push_back(it.first); + + auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" + << AsText(src_func, false); + + std::string sn = symbol_name.value(); + if (cached_symbol.count(sn)) { + cached_symbol[sn] = code_gen_name; + } else { + ICHECK_NE(sn, code_gen_name) + << "Found duplicated symbol: " << sn << " for: " << code_gen_name; + } + + std::string ext_name = "relay.ext." + code_gen_name; + auto pf = tvm::runtime::Registry::Get(ext_name); + ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; + // No need to keep compiler attribute at this point, functions have been + // extracted for specific codegen. + src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); + runtime::Module ext_mod = (*pf)(src_func); + + ICHECK(ext_mod.defined()) << "No external runtime is generated."; + ret.push_back(ext_mod); + } + } + + // No need to cache external functions as we collected them all to create + // external runtime modules. + for (const auto& it : cached_ext_funcs) { + cache_.erase(it); + } + return ret; + } + + void Clear() final { cache_.clear(); } + + // List all items in the cache. + Array ListItems() { + std::lock_guard lock(mutex_); + Array items; + for (auto& kv : cache_) { + items.push_back(kv.first); + items.push_back(kv.second); + } + return items; + } + + /*! + * \brief Get the cache key of the function that is being lowered currently + * \return the cache key + */ + CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; } + + private: + // implement lowered func + CCacheValue LowerInternal(const CCacheKey& key) { + std::lock_guard lock(mutex_); + CCacheValue value; + auto it = cache_.find(key); + if (it != cache_.end()) { + it->second->use_count += 1; + if (it->second->cached_func.defined()) return it->second; + value = it->second; + } else { + value = CCacheValue(make_object()); + value->use_count = 0; + if (!backend::IsCompileEngineCacheDisabled()) { + cache_[key] = value; + } + } + cur_ccache_key_ = key; + + // No need to lower external functions for now. We will invoke the external + // codegen tool once and lower all functions together. + if (key->source_func->GetAttr(attr::kCompiler).defined()) { + auto ir_module = IRModule(); + const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(name_node.defined()) << "External function has not been attached a name yet."; + auto func_name = std::string(name_node.value()); + auto target = Target("ext_dev"); + auto global_var = GlobalVar(func_name); + global_var->checked_type_ = key->source_func->checked_type(); + ir_module->Add(global_var, key->source_func); + value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); + return value; + } + // Enforce use the target. + With target_scope(key->target); + + ICHECK(!value->cached_func.defined()); + auto cfunc = PrimFuncFor(key->source_func, key->target, + [&](std::string name) { return GetUniqueName(name, name_map_); }); + + // Skip lowering for device copy node. + const Expr body = (key->source_func)->body; + if (const CallNode* call_node = body.as()) { + if (call_node->attrs.as()) { + value->cached_func = cfunc; + return value; + } + } + + std::cout << "Input Size: " << cfunc->inputs.size() << std::endl; + std::cout << "Output Size: " << cfunc->outputs.size() << std::endl; + // NOTE: array will copy on write. + Array all_args = Array(cfunc->inputs); + for (te::Tensor arg : cfunc->outputs) { + all_args.push_back(arg); + } + + std::cout << "Allargs Size: " << all_args.size() << std::endl; + + using tvm::transform::PassContext; + With fresh_pass_ctx_scope(PassContext::Create()); + + std::unordered_map binds; + auto func_name = cfunc->prim_fn_var->name_hint; + cfunc->funcs->Update(tvm::lower(cfunc->schedule, all_args, func_name, binds)); + value->cached_func = cfunc; + return value; + } + + // implement lowered shape func + CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { + std::lock_guard lock(mutex_); + CCacheValue value; + auto it = shape_func_cache_.find(key); + if (it != shape_func_cache_.end()) { + it->second->use_count += 1; + if (it->second->cached_func.defined()) return it->second; + value = it->second; + } else { + value = CCacheValue(make_object()); + value->use_count = 0; + shape_func_cache_[key] = value; + } + // Enforce use the target. + With target_scope(key->target); + + ICHECK(!value->cached_func.defined()); + auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { + return GetUniqueName(name, name_map_); + }); + + value->cached_func = cached_func; + return value; + } + + /*! \brief compiler cache lock*/ + std::mutex mutex_; + /*! \brief internal name map to get an unique name */ + std::unordered_map name_map_; + /*! \brief internal compiler cache */ + std::unordered_map cache_; + /*! \brief internal compiler cache for shape funcs */ + std::unordered_map shape_func_cache_; + /*! \brief the cache key of the function that is being lowered currently*/ + CCacheKey cur_ccache_key_; +}; + +TECompiler::TECompiler() { + auto object = make_object(); + data_ = object; +} + +class LowerTensorExpr : public ExprMutator { + public: + LowerTensorExpr(const IRModule& module, const TargetsMap& targets, + const DeviceContextMap& device_ctx_map, TECompiler compiler) + : module_(module), + targets_(targets), + device_context_map_(device_ctx_map), + compiler_(compiler) {} + + Expr VisitExpr_(const CallNode* call) override { + Call expr = GetRef(call); + Function func; + + if (call->op.as()) { + func = GetRef(call->op.as()); + } else { + return ExprMutator::VisitExpr_(call); + } + + if (!func->HasNonzeroAttr(attr::kPrimitive)) { + // LOG(FATAL) << "TVM only support calls to primitive functions " + // << "(i.e functions composed of fusable operator invocations)"; + return ExprMutator::VisitExpr_(call); + } + + // Process inputs. + Array args; + for (size_t i = 0; i < expr->args.size(); i++) { + args.push_back(VisitExpr(expr->args[i])); + } + + Target target; + + if (func->GetAttr(attr::kCompiler).defined()) { + target = Target("ext_dev"); + CCacheKey key = CCacheKey(func, target); + CachedFunc ext_func = compiler_->Lower(key); + ICHECK(ext_func.defined()) << "External function is not defined."; + return Call(ext_func->prim_fn_var, args, {}); + } + + ICHECK_GE(device_context_map_.count(expr), 0); + auto& device_context = this->device_context_map_[expr]; + auto call_dev_type = device_context.device_type; + + // Non-External Relay Function + if (targets_.size() == 1) { + // The homogeneous execution case, we should only have one target + // so we just grab it. + const auto& it = targets_.begin(); + target = (*it).second; + } else { + // The heterogeneous execution case we have multiple targets + // in this case. + // + // We need to identify the target and translate. + std::string call_dev_name; + if (call_dev_type == 0) { + call_dev_name = "llvm"; + } else { + call_dev_name = ::tvm::runtime::DeviceName(call_dev_type); + } + if (targets_.count(call_dev_type) == 0) { + LOG(FATAL) << "No target is provided for device " << call_dev_name; + } + target = targets_[call_dev_type]; + } + + CCacheKey key = CCacheKey(func, target); + CachedFunc lowered_func = compiler_->Lower(key); + + return Call(lowered_func->prim_fn_var, args, Attrs()); + } + + IRModule module_; + TargetsMap targets_; + DeviceContextMap device_context_map_; + TECompiler compiler_; +}; + +LoweredModule LowerTE(const IRModule& module, TargetsMap targets, + DeviceContextMap device_context_map) { + TECompiler compiler; + + auto pass = CreateFunctionPass( + [=](Function func, IRModule module, PassContext ctx) { + LowerTensorExpr lower_te(module, targets, device_context_map, compiler); + return Downcast(lower_te.VisitExpr(func)); + }, + 0, "LowerTensorExpr", {}); + + auto updated_module = pass(module); + + LoweredModule lowered_module; + lowered_module.main_module = updated_module; + lowered_module.per_target_module = compiler->GetLoweredFunctions(); + lowered_module.external_mods = compiler->LowerExternalFunctions(); + return lowered_module; +} + +} // namespace tec +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h new file mode 100644 index 000000000000..bf140187ae7e --- /dev/null +++ b/src/relay/backend/te_compiler.h @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/tir_compiler.h + * * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. + * + * + * This represents the new design of the Relay compilation flow and will replace the interface + * contained in compile_engine.h as we migrate towards a standard pass based lowering of + * Relay functions. + * + * This files provides an internal API which lowers Relay programs to components which + * can be combined with TVM produced kernels to compile an entire program. + * + * The result of lowering contains a combination of `runtime::Module`s produced by external + * compilers and a set of lowered PrimFns which can be code generated for targets. + */ +#ifndef TVM_RELAY_BACKEND_TE_COMPILER_H_ +#define TVM_RELAY_BACKEND_TE_COMPILER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../transforms/infer_layout_utils.h" +#include "../transforms/pass_utils.h" +#include "./te_compiler_cache.h" + +namespace tvm { +namespace relay { +namespace tec { + +// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake +// we should a version of context which works in Map +using TargetsMap = std::unordered_map; +using DeviceContextMap = + std::unordered_map; + +/*! + * \brief A compiler which lowers primitive Relay functions to tensor expressions + * and schdules them into TIR functions. + */ +class TECompilerNode : public Object { + public: + /*! \brief destructor */ + virtual ~TECompilerNode() {} + /*! + * \brief Get lowered result. + * \param key The key to the cached function. + * \return The result. + */ + virtual CachedFunc Lower(const CCacheKey& key) = 0; + + virtual Map GetLoweredFunctions() = 0; + /*! + * \brief Just in time compile to get a PackedFunc. + * \param key The key to the cached function. + * \return The result. + */ + virtual PackedFunc JIT(const CCacheKey& key) = 0; + /*! + * \brief Lower the shape function. + * \param key The key to the cached function. + * \return The result. + */ + virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; + /*! + * \brief Lower the external function using external codegen tools. + * \return The runtime moduels for each needed external codegen tool. + */ + virtual tvm::Array LowerExternalFunctions() = 0; + + /*! \brief clear the cache. */ + virtual void Clear() = 0; + + // VisitAttrs + void VisitAttrs(AttrVisitor*) {} + + static constexpr const char* _type_key = "relay.TECompiler"; + TVM_DECLARE_FINAL_OBJECT_INFO(TECompilerNode, Object); +}; + +/*! \brief cache entry used in compile engine */ +class TECompiler : public ObjectRef { + public: + TECompiler(); + explicit TECompiler(ObjectPtr n) : ObjectRef(n) {} + TECompilerNode* operator->() { return static_cast(get_mutable()); } + using ContainerType = TECompilerNode; + /*! \brief The global compile engine. */ + TVM_DLL static TECompiler& Global(); +}; + +struct LoweredModule { + IRModule main_module; + Map per_target_module; + Array external_mods; +}; + +LoweredModule LowerTE(const IRModule& module, TargetsMap targets, + DeviceContextMap device_context_map); + +} // namespace tec +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_TE_COMPILER_H_ diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc new file mode 100644 index 000000000000..dc5c1acee2be --- /dev/null +++ b/src/relay/backend/te_compiler_cache.cc @@ -0,0 +1,681 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./te_compiler_cache.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../transforms/pass_utils.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace tec { + +TVM_REGISTER_NODE_TYPE(LoweredOutputNode); +TVM_REGISTER_NODE_TYPE(CachedFuncNode); +TVM_REGISTER_NODE_TYPE(CCacheKeyNode); +TVM_REGISTER_NODE_TYPE(CCacheValueNode); + +LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation impl) { + auto n = make_object(); + n->outputs = std::move(outputs); + n->implementation = std::move(impl); + data_ = std::move(n); +} + +CCacheKey::CCacheKey(Function source_func, Target target) { + auto n = make_object(); + n->source_func = std::move(source_func); + n->target = std::move(target); + data_ = std::move(n); +} + +CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array inputs, + tvm::Array outputs, te::Schedule schedule, + tvm::Array shape_func_param_states, IRModule funcs) { + auto n = make_object(); + n->target = target; + n->prim_fn_var = prim_fn_var; + n->inputs = inputs; + n->outputs = outputs; + n->schedule = schedule; + n->shape_func_param_states = shape_func_param_states; + n->funcs = funcs; + data_ = std::move(n); +} + +Array GetShape(const Array& shape) { + // for now, we always use int32 shape when possible + // even if the result of shape inference becomes int64. + Array res; + for (IndexExpr val : shape) { + const int64_t* pval = tir::as_const_int(val); + if (pval != nullptr) { +#ifndef TVM_INDEX_DEFAULT_I64 + ICHECK_LE(pval[0], std::numeric_limits::max()); + ICHECK_GE(pval[0], std::numeric_limits::min()); + res.push_back(IntImm(DataType::Int(32), *pval)); +#else + res.push_back(val); +#endif // TVM_INDEX_DEFAULT_I64 + } else if (val->IsInstance()) { + res.push_back(val.as()->ToVar()); + } else { + res.push_back(val); + } + } + return res; +} + +// The getter to get schedule from compile engine. +// Get schedule from functor. +class ScheduleGetter : public backend::MemoizedExprTranslator> { + public: + explicit ScheduleGetter(Target target) + : target_(target), device_copy_op_(Op::Get("device_copy")) { + // Whether to use auto_scheduler schedule. + use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); + } + + CachedFunc Create(const Function& prim_func, std::function renamer) { + Array fn_inputs; + for (Var param : prim_func->params) { + Array inputs; + if (const auto* ttype = param->checked_type().as()) { + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); + fn_inputs.push_back(tensor); + inputs.push_back(tensor); + } else { + // flatten tuple of tensor type. + const auto* tuple_type = param->type_as(); + for (Type field : tuple_type->fields) { + const auto* ttype = field.as(); + // TODO(@icemelon): Allow recursive tuple + ICHECK(ttype != nullptr); + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); + fn_inputs.push_back(tensor); + inputs.push_back(tensor); + } + } + memo_[param] = inputs; + } + readable_name_stream_ << "fused"; + auto outputs = this->VisitExpr(prim_func->body); + auto candidate_name = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; + if (candidate_name.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hash{}(candidate_name) << "_"; + candidate_name = truncated_name.str(); + } + + ICHECK(anchor_op_.defined()); + // Fusion over tupled results may leave identity relationships + // between inputs and outputs, and those should not be scheduled. + // Hence schedule only non PlaceholderOp outputs. + tvm::Array tensor_outs; + for (const auto& tensor : outputs) { + if (!tensor->op.as()) { + tensor_outs.push_back(tensor); + } + } + + te::Schedule schedule; + // No need to register schedule for device copy op. + if (anchor_attrs_.as() == nullptr) { + if (use_auto_scheduler_) { + const auto* fauto_schedule = + runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); + ICHECK(fauto_schedule != nullptr) + << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; + ObjectRef obj = (*fauto_schedule)(tensor_outs); + if (obj.defined()) { + schedule = Downcast(obj); + } + } + + // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule. + if (!schedule.defined()) { + ICHECK(anchor_implementation_.defined()); + schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); + } + for (const auto& scalar : scalars_) { + if (schedule->Contain(scalar)) { + schedule[scalar].compute_inline(); + } + } + } + + auto prim_fn_var = GlobalVar(candidate_name); + prim_fn_var->checked_type_ = prim_func->checked_type(); + return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {}); + } + + Array VisitExpr_(const VarNode* op) final { + LOG(FATAL) << "Free variable " << op->name_hint(); + return {}; + } + + Array VisitExpr_(const ConstantNode* op) final { + using tir::make_const; + ICHECK(op->is_scalar()); + void* data = op->data->data; + DataType dtype = DataType(op->data->dtype); + auto value = te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "compile_engine_const", topi::kBroadcast); + scalars_.push_back(value->op); + return {value}; + } + + Array VisitExpr_(const CallNode* call_node) final { + static auto fpattern = Op::GetAttrMap("TOpPattern"); + static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); + ICHECK(flower_call) << "relay.backend.lower_call is not registered."; + + Array inputs; + int count_tuple = 0; + for (Expr arg : call_node->args) { + if (arg->checked_type().as()) { + ++count_tuple; + } + for (te::Tensor tensor : VisitExpr(arg)) { + inputs.push_back(tensor); + } + } + if (count_tuple) { + ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; + } + + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + + Array outputs; + OpImplementation impl; + // Skip fcompute for device copy operators as it is not registered. + if (op == device_copy_op_) { + const auto* copy_input = inputs[0].operator->(); + outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); + } else { + LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); + outputs = lowered_out->outputs; + impl = lowered_out->implementation; + } + + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; + anchor_implementation_ = impl; + } + if (outputs.size() != 1) { + const auto* tuple_type = call_node->checked_type().as(); + ICHECK(tuple_type) << "Expect output to be a tuple type"; + ICHECK_EQ(tuple_type->fields.size(), outputs.size()); + } + // Set the name to `__copy`. It will be detected in graph runtime to perform + // data copy across devices. + if (op == device_copy_op_) { + readable_name_stream_.str(std::string()); + readable_name_stream_ << "__copy"; + } else { + readable_name_stream_ << '_' << op->name; + } + return outputs; + } + + Array VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "Do not support sub function"; + return Array(); + } + + Array VisitExpr_(const LetNode* op) final { + Array val = VisitExpr(op->value); + ICHECK(!memo_.count(op->var)); + memo_[op->var] = val; + return VisitExpr(op->body); + } + + Array VisitExpr_(const TupleNode* op) final { + Array fields; + for (Expr field : op->fields) { + ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; + Array res = VisitExpr(field); + ICHECK_EQ(res.size(), 1); + fields.push_back(res[0]); + } + return fields; + } + + Array VisitExpr_(const TupleGetItemNode* op) final { + const auto* tuple_type = op->tuple->type_as(); + Array tuple = VisitExpr(op->tuple); + ICHECK_EQ(tuple_type->fields.size(), tuple.size()); + ICHECK_GE(op->index, 0); + ICHECK_LT(static_cast(op->index), tuple.size()); + return {tuple[op->index]}; + } + + private: + tvm::Target target_; + Op anchor_op_; + Attrs anchor_attrs_; + int anchor_op_pattern_{0}; + OpImplementation anchor_implementation_; + std::ostringstream readable_name_stream_; + Array scalars_; + bool use_auto_scheduler_; + // Cache device copy op for equivalence checking to reduce registry lookup + // overhead for each invocation of call node when retrieving schedules. + const Op& device_copy_op_; +}; + +/*! + * \brief Create schedule for target. + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \return Pair of schedule and cache. + * The funcs field in cache is not yet populated. + */ +CachedFunc PrimFuncFor(const Function& source_func, const Target& target, + std::function renamer) { + return ScheduleGetter(target).Create(source_func, renamer); +} + +// Creates shape function from functor. +class MakeShapeFunc : public backend::MemoizedExprTranslator> { + public: + MakeShapeFunc() {} + + CachedFunc Create(const Function& prim_func, const Target& target, + std::function renamer) { + Array inputs; + TShapeDataDependent shape_func_param_states; + + for (auto param : prim_func->params) { + param_states_[param] = kNoNeed; + Array data_inputs; + Array shape_inputs; + + auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) { + // Add data placeholder + Shape shape = GetShape(ttype->shape); + tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype); + data_inputs.push_back(data_tensor); + // Add shape placeholder + int64_t ndim = shape.size(); + Shape sshape; + if (ndim > 0) { + sshape.push_back(tvm::Integer(ndim)); + } + tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64)); + shape_inputs.push_back(shape_tensor); + }; + + if (const auto* ttype = param->checked_type().as()) { + add_placeholder(ttype); + } else { + // flatten tuple of tensor type. + const auto* tuple_type = param->type_as(); + // TODO(@icemelon): Support recursive tuple + ICHECK(tuple_type); + for (Type field : tuple_type->fields) { + const auto* ttype = field.as(); + ICHECK(ttype); + add_placeholder(ttype); + } + } + param_data_[param] = data_inputs; + param_shapes_[param] = shape_inputs; + } + + // Setup the name; + readable_name_stream_ << "shape_func"; + + // Create the `te::Tensor`s which represent the output. + auto outputs = VisitExpr(prim_func->body); + + // Generate a name. + auto candidate_name = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; + if (candidate_name.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hash{}(candidate_name) << "_"; + candidate_name = truncated_name.str(); + } + + auto func_name = renamer(candidate_name); + + // Set all the inputs correctly. + for (auto param : prim_func->params) { + int state = param_states_[param]; + shape_func_param_states.push_back(IntImm(DataType::Int(32), state)); + if (state & kNeedInputData) { + for (auto t : param_data_[param]) { + inputs.push_back(t); + } + } + if (state & kNeedInputShape) { + for (auto t : param_shapes_[param]) { + inputs.push_back(t); + } + } + } + + auto prim_fn_gvar = GlobalVar(func_name); + prim_fn_gvar->checked_type_ = prim_func->checked_type(); + + // generate schedule for shape func + Array out_ops; + for (auto t : outputs) { + out_ops.push_back(t->op); + } + auto schedule = te::create_schedule(out_ops); + tvm::te::AutoInlineInjective(schedule); + for (const auto& scalar : scalars_) { + auto scalar_op = scalar->op; + if (schedule->Contain(scalar_op)) { + schedule[scalar_op].compute_inline(); + } + } + + Array all_args = Array(inputs); + for (te::Tensor arg : outputs) { + all_args.push_back(arg); + } + + using tvm::transform::PassContext; + With fresh_pass_ctx_scope(PassContext::Create()); + + std::unordered_map binds; + auto ir_module = tvm::lower(schedule, all_args, func_name, binds); + + return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, shape_func_param_states, + ir_module); + } + + Array VisitExpr(const Expr& expr) final { + if (expr.as()) { + // Do not memoize vars because shape functions could use either the data + // or the shape of a var each time. + return ExprFunctor::VisitExpr(expr); + } + // For other case, do memoized visit + return backend::MemoizedExprTranslator>::VisitExpr(expr); + } + + Array VisitExpr_(const VarNode* var_node) final { + auto var = GetRef(var_node); + auto it = param_states_.find(var); + if (it == param_states_.end()) { + LOG(FATAL) << "Free variable " << var->name_hint(); + return {}; + } else { + ICHECK(data_dependents_per_input_.size()); + auto data_dependent = data_dependents_per_input_.back(); + if (data_dependent) { + param_states_[var] |= kNeedInputData; + return param_data_[var]; + } else { + param_states_[var] |= kNeedInputShape; + return param_shapes_[var]; + } + } + } + + Array VisitExpr_(const ConstantNode* op) final { + using tir::make_const; + ICHECK(data_dependents_per_input_.size()); + bool data_dependent = data_dependents_per_input_.back(); + if (!op->is_scalar()) { + // This is a constant weight, extract the shape of the weight tensor. + // This can not be data dependent. + CHECK(!data_dependent); + auto ttype = op->checked_type().as(); + int ndim = static_cast(ttype->shape.size()); + Array out_shape{ndim}; + te::Tensor value = tvm::te::compute( + out_shape, + [&](const Array& indices) { + auto idx = indices[0]; + PrimExpr ret = make_const(DataType::Int(64), 0); + for (int i = 0; i < ndim; i++) { + ret = tvm::if_then_else(idx == i, ttype->shape[i], ret); + } + return ret; + }, + "shape_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } + if (data_dependent) { + void* data = op->data->data; + DataType dtype = DataType(op->data->dtype); + auto value = tvm::te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "data_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } else { + auto value = tvm::te::compute( + {}, [&](const Array&) { return tir::make_const(DataType::Int(64), 0); }, + "shape_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } + } + + Array VisitExpr_(const CallNode* call_node) final { + static auto fshape_func = Op::GetAttrMap("FShapeFunc"); + static auto tshape_data_dependent = Op::GetAttrMap("TShapeDataDependent"); + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back()) + << "Error in op fusion: output of the shape func is fed to a " + << "data-dependent shape func"; + ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name; + ICHECK_GT(tshape_data_dependent.count(op), 0) + << "Internal error, cannot find TShapeDataDependent for " << op->name; + + Array dep_spec = tshape_data_dependent[op]; + if (dep_spec.size() == 1) { + // This is for cases when data dependence is specified per op + // Replicate 0 or 1 flag to all arguments + for (size_t i = 1; i < call_node->args.size(); ++i) { + dep_spec.push_back(dep_spec[0]); + } + } + + // Visit all inputs + Array inputs; + int count_tuple = 0; + for (size_t i = 0; i < call_node->args.size(); ++i) { + Expr arg = call_node->args[i]; + if (arg->checked_type().as()) { + ++count_tuple; + } + data_dependents_per_input_.push_back(dep_spec[i]->value != 0); + for (te::Tensor tensor : VisitExpr(arg)) { + inputs.push_back(tensor); + } + data_dependents_per_input_.pop_back(); + } + if (count_tuple) { + ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; + } + // Get output ndims + auto ret_type = call_node->checked_type(); + Array out_ndims; + if (const auto* ttype = ret_type.as()) { + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); + } else { + auto rtype = ret_type.as(); + // TODO(@icemelon): Allow recursive tuple + ICHECK(rtype); + for (size_t i = 0; i < rtype->fields.size(); ++i) { + auto ttype = rtype->fields[i].as(); + ICHECK(ttype); + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); + } + } + // Call shape function + auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims); + readable_name_stream_ << "_" << op->name; + return outputs; + } + + Array VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "Do not support sub function"; + return Array(); + } + + Array VisitExpr_(const LetNode* op) final { + Array val = VisitExpr(op->value); + ICHECK(!memo_.count(op->var)); + memo_[op->var] = val; + return VisitExpr(op->body); + } + + Array VisitExpr_(const TupleNode* op) final { + Array fields; + for (Expr field : op->fields) { + ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; + Array res = VisitExpr(field); + ICHECK_EQ(res.size(), 1); + fields.push_back(res[0]); + } + return fields; + } + + Array VisitExpr_(const TupleGetItemNode* op) final { + Array input_shapes = VisitExpr(op->tuple); + Array out; + out.push_back(input_shapes[op->index]); + return out; + } + + private: + /*! \brief String stream for function name */ + std::ostringstream readable_name_stream_; + /*! \brief Map from parameter to its shape function usage state */ + std::unordered_map param_states_; + /*! \brief Map from parameter to list of data placeholder */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_data_; + /*! \brief Map from parameter to list of shape placeholder */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_shapes_; + /*! \brief Stack of data dependencies for shape function, specified per each op input */ + std::vector data_dependents_per_input_; + /*! \brief Scalars used in the shape function */ + Array scalars_; +}; + +CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, + std::function renamer) { + return MakeShapeFunc().Create(prim_func, target, renamer); +} + +/*! + * \brief Get unique name from name. + * \param name The orginal name. + * \return Updated name which is unique. + */ +std::string GetUniqueName(std::string name, std::unordered_map name_map_) { + for (size_t i = 0; i < name.length(); ++i) { + if (name[i] == '.') name[i] = '_'; + } + while (true) { + auto it = name_map_.find(name); + if (it == name_map_.end()) { + name_map_[name] = 1; + return name; + } else { + std::ostringstream os; + os << name << "_" << it->second; + ++(it->second); + name = os.str(); + } + } + return name; +} + +} // namespace tec +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h new file mode 100644 index 000000000000..4f73fddd57c4 --- /dev/null +++ b/src/relay/backend/te_compiler_cache.h @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/tec_compiler_utils.h + * \brief Utilities for compiling tensor expressions inside of the Relay compiler. + */ +#ifndef TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ +#define TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../transforms/infer_layout_utils.h" + +namespace tvm { +namespace relay { +namespace tec { + +/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */ +enum ShapeFuncParamState { + kNoNeed = 0, + kNeedInputData = 1, + kNeedInputShape = 2, + kNeedBoth = 3, +}; + +struct LoweredOutputNode : public Object { + /*! \brief The outputs to the function */ + tvm::Array outputs; + /*! \brief The implementation used to compute the output */ + OpImplementation implementation; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("outputs", &outputs); + v->Visit("implementation", &implementation); + } + + static constexpr const char* _type_key = "relay.LoweredOutput"; + TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object); +}; + +class LoweredOutput : public ObjectRef { + public: + TVM_DLL LoweredOutput(tvm::Array outputs, OpImplementation impl); + + TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode); +}; + +class CCacheKey; +/*! \brief Compile cache key */ +class CCacheKeyNode : public Object { + public: + /*! \brief The source function to be lowered. */ + Function source_func; + /*! \brief The hardware target.*/ + Target target; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("source_func", &source_func); + v->Visit("target", &target); + } + /*! \return The hash value of CCacheKey. */ + inline size_t Hash() const; + /*! + * \brief check content equality + * \param other The other value. + * \return The result of equality check. + */ + inline bool Equal(const CCacheKeyNode* other) const; + + static constexpr const char* _type_key = "relay.CCacheKey"; + TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object); + + private: + /*! + * \brief internal cached hash value. + */ + mutable size_t hash_{0}; +}; + +/*! \brief cache entry used in compile engine */ +class CCacheKey : public ObjectRef { + public: + CCacheKey() {} + explicit CCacheKey(ObjectPtr n) : ObjectRef(n) {} + + /*! + * \brief The constructor + * \param source_func The source function. + * \param target The target device. + */ + TVM_DLL CCacheKey(Function source_func, Target target); + + const CCacheKeyNode* operator->() const { return static_cast(get()); } + // comparator + inline bool operator==(const CCacheKey& other) const { + ICHECK(defined() && other.defined()); + return (*this)->Equal(other.operator->()); + } + using ContainerType = CCacheKeyNode; +}; + +/*! \brief Node container to represent a cached function. */ +struct CachedFuncNode : public Object { + /* \brief compiled target */ + tvm::Target target; + /*! \brief Primitive Function Name */ + GlobalVar prim_fn_var; + /* \brief The inputs to the function */ + tvm::Array inputs; + /* \brief The outputs to the function */ + tvm::Array outputs; + /*! \brief The schedule to the function */ + te::Schedule schedule; + /*! \brief Parameter usage states in the shape function. */ + tvm::Array shape_func_param_states; + /*! \brief The lowered functions to support the function. */ + IRModule funcs = IRModule(Map({})); + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("target", &target); + v->Visit("prim_fn_var", &prim_fn_var); + v->Visit("inputs", &inputs); + v->Visit("outputs", &outputs); + v->Visit("schedule", &schedule); + v->Visit("funcs", &funcs); + v->Visit("shape_func_param_states", &shape_func_param_states); + } + + static constexpr const char* _type_key = "relay.CachedFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(CachedFuncNode, Object); +}; + +class CachedFunc : public ObjectRef { + public: + CachedFunc(tvm::Target target, GlobalVar prim_fn_name, tvm::Array inputs, + tvm::Array outputs, te::Schedule schedule, + tvm::Array shape_func_param_states, + IRModule funcs = IRModule(Map({}))); + + public: + TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode); +}; + +/*! \brief Node container for compile cache. */ +class CCacheValueNode : public Object { + public: + /*! \brief The corresponding function */ + CachedFunc cached_func; + /*! \brief Result of Packed function generated by JIT */ + PackedFunc packed_func; + /*! \brief usage statistics */ + int use_count{0}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("cached_func", &cached_func); + v->Visit("use_count", &use_count); + } + static constexpr const char* _type_key = "relay.CCacheValue"; + TVM_DECLARE_FINAL_OBJECT_INFO(CCacheValueNode, tvm::Object); +}; + +/*! \brief cache entry used in compile engine */ +class CCacheValue : public ObjectRef { + public: + CCacheValue() {} + explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} + CCacheValueNode* operator->() { return static_cast(get_mutable()); } + const CCacheValueNode* operator->() const { return static_cast(get()); } + using ContainerType = CCacheValueNode; +}; + +Array GetShape(const Array& shape); + +/*! + * \brief Create schedule for target. + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \return Pair of schedule and cache. + * The funcs field in cache is not yet populated. + */ +CachedFunc PrimFuncFor(const Function& source_func, const Target& target, + std::function renamer); + +CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, + std::function renamer); + +std::string GetUniqueName(std::string name, std::unordered_map name_map); + +// implementations +inline size_t CCacheKeyNode::Hash() const { + if (hash_ != 0) return hash_; + // do structral hash, avoid 0. + hash_ = tvm::StructuralHash()(this->source_func); + hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); + if (hash_ == 0) hash_ = 1; + return hash_; +} + +inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { + if (Hash() != other->Hash()) return false; + return this->target->str() == other->target->str() && + tvm::StructuralEqual()(this->source_func, other->source_func); +} + +} // namespace tec +} // namespace relay +} // namespace tvm + +namespace std { +// overload hash +template <> +struct hash<::tvm::relay::tec::CCacheKey> { + size_t operator()(const ::tvm::relay::tec::CCacheKey& key) const { + ICHECK(key.defined()); + return key->Hash(); + } +}; +} // namespace std + +#endif // TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index ad23e132c486..b00656001813 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -972,7 +972,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe // update primitive function map size_t primitive_index = 0; for (const auto& cfunc : context_.cached_funcs) { - exec_->primitive_map.insert({cfunc->func_name, primitive_index++}); + exec_->primitive_map.insert({cfunc->prim_fn_var->name_hint, primitive_index++}); } } @@ -1167,8 +1167,9 @@ void VMCompiler::Codegen() { if (target_str == "ext_dev") { // Collect metadata in functions that are handled by external codegen. - ICHECK(mod->ContainGlobalVar(cfunc->func_name)); - Function func = Downcast(mod->Lookup(cfunc->func_name)); + auto name = cfunc->prim_fn_var->name_hint; + ICHECK(mod->ContainGlobalVar(name)); + Function func = Downcast(mod->Lookup(name)); backend::UpdateConstants(func, ¶ms_); continue; } else if (funcs.count(target_str) == 0) { diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index c9920a621b56..d988ec847e75 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -62,9 +62,12 @@ TVM_REGISTER_GLOBAL("relay.ir.Function") TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body - << ", " << node->type_params << ", " << node->attrs << ")"; + // auto* node = static_cast(ref.get()); + // p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << + // node->body + // << ", " << node->type_params << ", " << node->attrs << ")"; + // we should standardize behavior + p->stream << PrettyPrint(ref); }); } // namespace relay diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index edc4119ce859..d5c03b113dc3 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -124,7 +124,7 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - CreateSchedule(GetRef(func), Target::Current()); + PrimFuncFor(GetRef(func), Target::Current(), [](std::string name) { return name; }); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 03473b7d7455..a4d26c2b7a4f 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -43,6 +43,7 @@ #include "../backend/compile_engine.h" #include "../op/memory/memory.h" #include "../op/vm/vm.h" +#include "./pass_utils.h" #include "let_list.h" #include "pattern_utils.h" diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 4c6013792426..f29087dcc049 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -205,8 +205,13 @@ class TypeInferencer : private ExprFunctor, this->EmitFatal(Diagnostic::Error(op->span) << "Cannot do type inference on global variables " << "without a module"); } - relay::Function e = Downcast(mod_->Lookup(var)); - return e->checked_type(); + + if (mod_->ContainGlobalVar(var->name_hint)) { + relay::Function e = Downcast(mod_->Lookup(var)); + return e->checked_type(); + } else { + return op->checked_type_; + } } Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); } diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index 584aafe3410b..2d6184f6d510 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -372,6 +372,7 @@ void GraphExecutor::SetupOpExecs() { uint32_t eid = this->entry_id(e); args.push_back(*(data_entry_[eid].operator->())); } + for (uint32_t index = 0; index < inode.param.num_outputs; ++index) { uint32_t eid = this->entry_id(nid, index); args.push_back(*(data_entry_[eid].operator->())); @@ -433,6 +434,7 @@ GraphExecutor::CreateTVMOp(const TVMOpParam& param, const std::vector& ICHECK(pf != nullptr) << "no such function in module: " << param.func_name; auto fexec = [arg_ptr, pf]() { + std::cout << "Number of args: " << static_cast(arg_ptr->arg_values.size()); TVMRetValue rv; TVMArgs targs(arg_ptr->arg_values.data(), arg_ptr->arg_tcodes.data(), static_cast(arg_ptr->arg_values.size())); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 24fb3dc95819..9ece234b4444 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -223,8 +223,12 @@ class LLVMModuleNode final : public runtime::ModuleNode { found_linked_params = true; continue; } - ICHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs, but got " << kv.second->GetTypeKey(); + if (!kv.second->IsInstance()) { + // (@jroesch): we relax constraints here, Relay functions will just be ignored. + DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got " + << kv.second->GetTypeKey(); + continue; + } auto f = Downcast(kv.second); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()); @@ -234,7 +238,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { } funcs.push_back(f); } - ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params)); + // ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params)); // TODO(tqchen): remove the entry function behavior as it does not // makes sense when we start to use multiple modules. cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py index 4ec1c21467fc..ee202dec1f7e 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -127,6 +127,7 @@ def test_plan_memory(): func = relay.Function([x, y], z) mod = tvm.IRModule.from_expr(func) mod = relay.transform.InferType()(mod) + print(mod) mod = relay.transform.FuseOps(0)(mod) func = mod["main"] mod = relay.transform.InferType()(mod) @@ -145,7 +146,7 @@ def test_plan_memory(): # Current rule requires vars have unique storage id # because we don't do inplace, we will need another # two alternating temporary space. - assert len(storage_ids) == 4 + assert len(storage_ids) == 4, f"found storaged_ids: {storage_ids}" assert len(device_types) == 1 assert len(storage_sizes) == 4 @@ -288,11 +289,12 @@ def test_graph_executor_nested_tuples(): if __name__ == "__main__": - test_reshape_nop() - test_plan_memory() - test_with_params() - test_add_op_scalar() test_add_op_tensor() - test_add_op_broadcast() - test_gru_like() - test_compile_nested_tuples() + # test_plan_memory() + # test_with_params() + # test_add_op_scalar() + # test_add_op_tensor() + # test_add_op_broadcast() + # test_gru_like() + # test_compile_nested_tuples() + # test_add_op_tensor()