diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 506a4934fe..4be3e403aa 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -37,7 +37,9 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { torch::jit::EliminateCommonSubexpression(g); } torch::jit::EliminateDeadCode(g); - passes::MarkNodesForFallback(g, true); + if (lower_info.forced_fallback_modules.size() > 0) { + passes::MarkNodesForFallback(g, true); + } passes::UnpackHardSwish(g); passes::EliminateExceptionOrPassPattern(g); passes::ReduceToOperation(g); @@ -60,12 +62,13 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { LOG_GRAPH(*g); } -torch::jit::Module LowerModule( - const torch::jit::Module& mod, - std::string method_name, - std::unordered_set forced_fallback_modules) { - passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules); - LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph()); +torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) { + std::unordered_set forced_fallback_modules( + lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end()); + if (forced_fallback_modules.size() > 0) { + passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules); + LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph()); + } auto mod_ = torch::jit::freeze_module(mod); LOG_GRAPH("After freeze: " << *mod_.get_method(method_name).graph()); return mod_; @@ -77,9 +80,7 @@ std::pair, std::vector> L const LowerInfo& lower_info) { LOG_DEBUG(lower_info); LOG_GRAPH("Before lowering: " << *mod.get_method(method_name).graph()); - std::unordered_set forced_fallback_modules( - lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end()); - auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, forced_fallback_modules); + auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, lower_info); auto g = lowered_mod.get_method(method_name).graph(); LOG_GRAPH("LibTorch Lowering"); diff --git a/core/lowering/passes/module_fallback.cpp b/core/lowering/passes/module_fallback.cpp index 9061130f4e..be7f7497b5 100644 --- a/core/lowering/passes/module_fallback.cpp +++ b/core/lowering/passes/module_fallback.cpp @@ -39,7 +39,7 @@ void NotateModuleForFallback( if (n->kind() == torch::jit::prim::GetAttr) { auto out_type = unmangle_cls_name(c10::toString(n->output(0)->type())); if (forced_fallback_modules.find(out_type) != forced_fallback_modules.end()) { - LOG_DEBUG( + LOG_GRAPH( "Notating module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name << " (" << cls_name << ")]"); auto uses = n->output(0)->uses(); @@ -58,11 +58,32 @@ void NotateModuleForFallback( } if (changed_mod) { - LOG_DEBUG("Notated graph: " << *g); + LOG_GRAPH("Notated graph: " << *g); } - for (const auto sub_mod : mod.named_children()) { - NotateModuleForFallback(sub_mod.value, sub_mod.name, method_name, forced_fallback_modules); + if (mod.named_children().size() > 0) { + for (const auto n : nodes) { + std::string sub_method_name = ""; + if (n->kind() == torch::jit::prim::CallMethod) { + sub_method_name = n->s(c10::Symbol::attr("name")); + auto sub_mod_val = n->input(0); + auto sub_mod_src_n = sub_mod_val->node(); + if (!sub_mod_src_n->hasAttributeS("name")) { + LOG_GRAPH("Node: " << util::node_info(sub_mod_src_n) << " manages a module with no name, skipping"); + break; + } + auto sub_mod_name = sub_mod_src_n->s(c10::Symbol::attr("name")); + for (const auto sub_mod : mod.named_children()) { + // Theres probably a way to directly access the module we care about + if (sub_mod.name == sub_mod_name) { + LOG_GRAPH( + "Looking at .() next: " << sub_mod_name << "." << sub_method_name + << "() (lowering.passes.NotateModuleForFallback)"); + NotateModuleForFallback(sub_mod.value, sub_mod.name, sub_method_name, forced_fallback_modules); + } + } + } + } } } @@ -74,7 +95,7 @@ void MarkNodesForFallback(std::shared_ptr& g, bool delete_del auto n = *it; if (!mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) { if (n->s(c10::Symbol::attr("compilation_edge")) == "start") { - LOG_DEBUG("Starting to mark new segmented block targeted for torch"); + LOG_GRAPH("Starting to mark new segmented block targeted for torch"); mark.push(true); if (delete_delims) { it.destroyCurrent(); @@ -82,7 +103,7 @@ void MarkNodesForFallback(std::shared_ptr& g, bool delete_del } } else if (mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) { if (n->s(c10::Symbol::attr("compilation_edge")) == "start") { - LOG_DEBUG("Found the start of another segmented block targeted for torch while actively marking a block"); + LOG_GRAPH("Found the start of another segmented block targeted for torch while actively marking a block"); mark.push(true); if (delete_delims) { it.destroyCurrent(); @@ -90,7 +111,7 @@ void MarkNodesForFallback(std::shared_ptr& g, bool delete_del } } else if (mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) { if (n->s(c10::Symbol::attr("compilation_edge")) == "end") { - LOG_DEBUG("Found the end of segmented block targeted for torch while actively marking a block"); + LOG_GRAPH("Found the end of segmented block targeted for torch while actively marking a block"); mark.pop(); if (delete_delims) { it.destroyCurrent(); @@ -106,7 +127,7 @@ void MarkNodesForFallback(std::shared_ptr& g, bool delete_del } } - LOG_DEBUG("After marking operations for torch fallback: " << *g); + LOG_GRAPH("After marking operations for torch fallback: " << *g); } } // namespace passes