From 7ffbe3121113dadccdfb2a0b953150750cf7ac3a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 6 Dec 2024 16:12:46 -0800 Subject: [PATCH] use print_readable for submodules (#74) Signed-off-by: youkaichao --- .github/workflows/test_decompile.yml | 2 +- .../explain/patched_lazy_format_graph_code.py | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test_decompile.yml b/.github/workflows/test_decompile.yml index ea975953..b3a526f3 100644 --- a/.github/workflows/test_decompile.yml +++ b/.github/workflows/test_decompile.yml @@ -12,7 +12,7 @@ on: jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 # ubuntu-latest does not support python 3.7 strategy: fail-fast: false matrix: diff --git a/depyf/explain/patched_lazy_format_graph_code.py b/depyf/explain/patched_lazy_format_graph_code.py index 2ab26ad0..28217d92 100644 --- a/depyf/explain/patched_lazy_format_graph_code.py +++ b/depyf/explain/patched_lazy_format_graph_code.py @@ -24,8 +24,12 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): # update file path filepath = inspect.getsourcefile(fn) # try to use verbose code with type and shape annotations - src = "from __future__ import annotations\n" + \ - gm._graph.python_code(root_module="self", verbose=True).src + use_gm = True + + # use `print_readable` because it can include submodules + src = "from __future__ import annotations\nimport torch\n" + \ + gm.print_readable(print_output=False) + src = src.replace("", "GraphModule") try: compile(src, "noname", "exec") except Exception as e: @@ -38,13 +42,21 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): commented_src += "".join(["# " + line + "\n" for line in src.splitlines()]) src = simple_code + commented_src + use_gm = False if filepath is not None: new_filepath = write_code_to_file_template( src, os.path.dirname(filepath) + "/" + file_name + "." + "%s" + ".py") scope = fn.__globals__ exec(compile(src, filename=new_filepath, mode="exec"), scope) - fn.__code__ = scope[fn.__name__].__code__ - del scope[fn.__name__] + if use_gm: + import torch + classes = [v for v in scope.values() if isinstance(v, type) and issubclass(v, torch.nn.Module)] + assert len(classes) == 1 + module_class = classes[0] + fn.__code__ = getattr(module_class, fn.__name__).__code__ + else: + fn.__code__ = scope[fn.__name__].__code__ + del scope[fn.__name__] # ========================================= # original code of `lazy_format_graph_code`