Skip to content

Commit

Permalink
Fix auto scheduler code
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Jul 6, 2021
1 parent 2fd991e commit 7e68f87
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -726,19 +726,23 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic

auto updated_module = pass(module);

const auto* te_compiler_update_weights =
runtime::Registry::Get("auto_scheduler.relay_integration.te_compiler_update_weights");
// A temporary solution until we can rewrite the auto-scheduler task extraction code to work
// in a more reasonable way.
if (backend::IsAutoSchedulerEnabled()) {
const auto* te_compiler_update_weights =
runtime::Registry::Get("auto_scheduler.relay_integration.te_compiler_update_weights");

ICHECK(te_compiler_update_weights != nullptr)
<< "auto_scheduler.relay_integration.te_compiler_update_weights";
ICHECK(te_compiler_update_weights != nullptr)
<< "auto_scheduler.relay_integration.te_compiler_update_weights";

Map<String, tvm::Integer> weight_map;
Map<String, tvm::Integer> weight_map;

for (auto pair : compiler->GetOpWeights()) {
weight_map.Set(pair.first, pair.second);
}
for (auto pair : compiler->GetOpWeights()) {
weight_map.Set(pair.first, pair.second);
}

(*te_compiler_update_weights)(weight_map);
(*te_compiler_update_weights)(weight_map);
}

LoweredModule lowered_module;
lowered_module.main_module = updated_module;
Expand Down

0 comments on commit 7e68f87

Please sign in to comment.