Skip to content

Commit

Permalink
Remove extra call of PlanDevice in OptimizeImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
elvin-n committed Jul 22, 2022
1 parent 1b03e8d commit 9a46ad2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 14 deletions.
15 changes: 2 additions & 13 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ class RelayBuildModule : public runtime::ModuleNode {

// Fuse the operations if it is needed.
pass_seqs.push_back(transform::FuseOps());
pass_seqs.push_back(transform::PlanDevices(config_));

// Create a sequential pass and perform optimizations.
transform::Pass seq = transform::Sequential(pass_seqs);
Expand Down Expand Up @@ -396,20 +397,8 @@ class RelayBuildModule : public runtime::ModuleNode {
relay_module = transform::Inline()(relay_module);
relay_module = transform::InferType()(relay_module);
relay_module = transform::LabelOps()(relay_module);

relay_module = transform::AnnotateMemoryScope(config_)(relay_module);
pass_seqs = GetPassPrefix(
/*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false);
pass_seqs.push_back(transform::PlanDevices(config_));
// Create a sequential pass and perform optimizations.
seq = transform::Sequential(pass_seqs);
if (config_->optional_homogeneous_target.defined()) {
With<Target> tctx(config_->optional_homogeneous_target);
relay_module = seq(relay_module);
} else {
relay_module = seq(relay_module);
}
relay_module = transform::InferType()(relay_module);

ICHECK(relay_module.defined());

return relay_module;
Expand Down
6 changes: 5 additions & 1 deletion src/relay/transforms/annotate_texture_storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,11 @@ Map<Expr, Array<String>> CollectStorageInfo(const Expr& expr) {

Expr AnnotateMemoryScopeExpr(const Expr& expr, const IRModule& mod, CompilationConfig config) {
auto storage_scope = CollectStorageInfo(expr);
return VDRewriter(storage_scope).Rewrite(expr);
if (storage_scope.size()) {
return VDRewriter(storage_scope).Rewrite(expr);
} else {
return expr;
}
}

namespace transform {
Expand Down

0 comments on commit 9a46ad2

Please sign in to comment.