Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed May 3, 2021
1 parent a229041 commit a561718
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 31 deletions.
1 change: 0 additions & 1 deletion python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def compile_cuda(code, target="ptx", arch=None, options=None, path_target=None):

cmd += ["-o", file_target]
cmd += [temp_code]

cxx_compiler_path = tvm.support.libinfo().get("TVM_CXX_COMPILER_PATH")
if cxx_compiler_path != "":
# This tells nvcc where to find the c++ compiler just in case it is not in the path.
Expand Down
30 changes: 3 additions & 27 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
auto storage_and_device = it.second;
ICHECK_EQ(storage_and_device.size(), 2u);
auto device_type = storage_and_device[1];
std::cout << PrettyPrint(expr) << std::endl;
std::cout << device_type << std::endl;
tvm::Device dev;
dev.device_id = 0;
dev.device_type = static_cast<DLDeviceType>(device_type[0]->value);
Expand All @@ -226,6 +228,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
});

auto main_module = lowered_module.main_module;
std::cout << "MainModule: " << main_module << std::endl;
main_module = relay::transform::InferType()(main_module);
relay::Function main_func = Downcast<relay::Function>(main_module->Lookup("main"));

Expand Down Expand Up @@ -388,33 +391,6 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
if (auto global_node = call->op.as<GlobalVarNode>()) {
auto prim_fn_name = global_node->name_hint;

Target target;

ICHECK_GE(storage_device_map_.count(call), 0)
<< "Could not find a storage device for " << prim_fn_name
<< "The memory planning was either not performed for this precise node, or there is a bug "
"in the memory planner.";

auto& device_type = storage_device_map_[call][1];
auto call_dev_type = device_type[0]->value;
// Normal Relay Function
if (targets_.size() == 1) {
// homogeneous execution.
const auto& it = targets_.begin();
target = (*it).second;
} else {
// heterogeneous execution.
std::string call_dev_name;
if (call_dev_type == 0) {
call_dev_name = "llvm";
} else {
call_dev_name = runtime::DeviceName(call_dev_type);
}
if (targets_.count(call_dev_type) == 0) {
LOG(FATAL) << "No target is provided for device " << call_dev_name;
}
target = targets_[call_dev_type];
}

return GraphAddCallNode(call_node, _GetUniqueName(prim_fn_name), prim_fn_name);
} else {
Expand Down
21 changes: 19 additions & 2 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,11 @@ class LowerTensorExpr : public ExprMutator {
return Call(ext_func->prim_fn_var, args, {});
}

ICHECK_GE(device_context_map_.count(expr), 0);
ICHECK_GE(device_context_map_.count(expr), 0)
<< "Could not find an entry in the device context map for " << PrettyPrint(expr)
<< "The memory planning was either not performed for this precise node, or there is a bug "
"in the memory planner.";

auto& device_context = this->device_context_map_[expr];
auto call_dev_type = device_context.device_type;

Expand All @@ -317,20 +321,33 @@ class LowerTensorExpr : public ExprMutator {
const auto& it = targets_.begin();
target = (*it).second;
} else {
std::cout << "DeviceType: " << call_dev_type << std::endl;
// The heterogeneous execution case we have multiple targets
// in this case.
//
// We need to identify the target and translate.
std::string call_dev_name;
if (call_dev_type == 0) {
call_dev_name = "llvm";
call_dev_type = kDLCPU;
} else {
call_dev_name = ::tvm::runtime::DeviceName(call_dev_type);
}

if (targets_.count(call_dev_type) == 0) {
LOG(FATAL) << "No target is provided for device " << call_dev_name;
std::stringstream msg;
msg << "No target is specified for provided device name: `" << call_dev_name << "`\n\n";
msg << call_dev_name << " mapped to device type (" << call_dev_type << ") which was not found in the target map.\n";
msg << "Available targets: \n";
for (auto target : targets_) {
msg << " " << target.first << "-> " << target.second << "\n";
}
LOG(FATAL) << msg.str();
}

std::cout << "DeviceName: " << call_dev_name << std::endl;
target = targets_[call_dev_type];
std::cout << "Target: " << target << std::endl;
}

CCacheKey key = CCacheKey(func, target);
Expand Down
8 changes: 7 additions & 1 deletion src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
candidate_name = truncated_name.str();
}

auto prim_fn_name = renamer(candidate_name);
// NB(@jroesch): unfortunately the graph runtime deals with copy in
// a totally hacky way, we really need to rectify this but this will
// have to work for now.
std::string prim_fn_name = candidate_name;
if (prim_fn_name != "__copy") {
prim_fn_name = renamer(prim_fn_name);
}
auto prim_fn_var = GlobalVar(prim_fn_name);
prim_fn_var->checked_type_ = prim_func->checked_type();

Expand Down
2 changes: 2 additions & 0 deletions src/runtime/graph_executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ GraphExecutor::CreateTVMOp(const TVMOpParam& param, const std::vector<DLTensor>&
}
}

std::cout << "Executing: " << param.func_name << std::endl;
if (param.func_name == "__nop") {
return {[]() {}, arg_ptr};
} else if (param.func_name == "__copy") {
Expand All @@ -422,6 +423,7 @@ GraphExecutor::CreateTVMOp(const TVMOpParam& param, const std::vector<DLTensor>&
auto fexec = [arg_ptr]() {
DLTensor* from = static_cast<DLTensor*>(arg_ptr->arg_values[0].v_handle);
DLTensor* to = static_cast<DLTensor*>(arg_ptr->arg_values[1].v_handle);
std::cout << "from: " << from->device.device_type << "to: " << to->device.device_type << std::endl;
TVM_CCALL(TVMArrayCopyFromTo(from, to, nullptr));
};
return {fexec, arg_ptr};
Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/test_pass_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def check_graph_executor(
device_index = graph_json["attrs"]["device_index"][1]
assert device_index == expected_index
mod = graph_executor.create(graph, lib, contexts)
import pdb; pdb.set_trace()
mod.set_input(**new_params)
mod.run()
res = mod.get_output(0).asnumpy()
Expand Down

0 comments on commit a561718

Please sign in to comment.