From beed3b278c0b28eb9310a8e64e2396c9228c7bd1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 6 Dec 2024 15:31:36 -0800 Subject: [PATCH 1/6] use submodules Signed-off-by: youkaichao --- depyf/explain/patched_lazy_format_graph_code.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/depyf/explain/patched_lazy_format_graph_code.py b/depyf/explain/patched_lazy_format_graph_code.py index 2ab26ad0..e7b27ed4 100644 --- a/depyf/explain/patched_lazy_format_graph_code.py +++ b/depyf/explain/patched_lazy_format_graph_code.py @@ -24,8 +24,12 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): # update file path filepath = inspect.getsourcefile(fn) # try to use verbose code with type and shape annotations + use_gm = True + + # use `print_readable` because it can include submodules src = "from __future__ import annotations\n" + \ - gm._graph.python_code(root_module="self", verbose=True).src + gm.print_readable(print_output=False) + src = src.replace("", "GraphModule") try: compile(src, "noname", "exec") except Exception as e: @@ -38,13 +42,17 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): commented_src += "".join(["# " + line + "\n" for line in src.splitlines()]) src = simple_code + commented_src + use_gm = False if filepath is not None: new_filepath = write_code_to_file_template( src, os.path.dirname(filepath) + "/" + file_name + "." + "%s" + ".py") scope = fn.__globals__ exec(compile(src, filename=new_filepath, mode="exec"), scope) - fn.__code__ = scope[fn.__name__].__code__ - del scope[fn.__name__] + if use_gm: + fn.__code__ = getattr(scope["GraphModule"], fn.__name__).__code__ + else: + fn.__code__ = scope[fn.__name__].__code__ + del scope[fn.__name__] # ========================================= # original code of `lazy_format_graph_code` From 87f5c38bfaf9ef84b7c9284ad3ec6d1d7cbf5f3b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 6 Dec 2024 15:37:46 -0800 Subject: [PATCH 2/6] fix python 3.7 Signed-off-by: youkaichao --- .github/workflows/test_decompile.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_decompile.yml b/.github/workflows/test_decompile.yml index ea975953..697d1a7f 100644 --- a/.github/workflows/test_decompile.yml +++ b/.github/workflows/test_decompile.yml @@ -12,7 +12,7 @@ on: jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 strategy: fail-fast: false matrix: From 810a8fb599471ac1a9c6ed2b5d4b558def3bd1ce Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 6 Dec 2024 15:39:36 -0800 Subject: [PATCH 3/6] add import torch Signed-off-by: youkaichao --- depyf/explain/patched_lazy_format_graph_code.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/depyf/explain/patched_lazy_format_graph_code.py b/depyf/explain/patched_lazy_format_graph_code.py index e7b27ed4..af238c9c 100644 --- a/depyf/explain/patched_lazy_format_graph_code.py +++ b/depyf/explain/patched_lazy_format_graph_code.py @@ -27,7 +27,7 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): use_gm = True # use `print_readable` because it can include submodules - src = "from __future__ import annotations\n" + \ + src = "from __future__ import annotations\nimport torch\n" + \ gm.print_readable(print_output=False) src = src.replace("", "GraphModule") try: From d1c71d69e7b9bb22cb350827cc4d43a639543255 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 6 Dec 2024 15:42:46 -0800 Subject: [PATCH 4/6] live debugging Signed-off-by: youkaichao --- depyf/explain/patched_lazy_format_graph_code.py | 1 + 1 file changed, 1 insertion(+) diff --git a/depyf/explain/patched_lazy_format_graph_code.py b/depyf/explain/patched_lazy_format_graph_code.py index af238c9c..eb5e3759 100644 --- a/depyf/explain/patched_lazy_format_graph_code.py +++ b/depyf/explain/patched_lazy_format_graph_code.py @@ -30,6 +30,7 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): src = "from __future__ import annotations\nimport torch\n" + \ gm.print_readable(print_output=False) src = src.replace("", "GraphModule") + print(src) try: compile(src, "noname", "exec") except Exception as e: From f03fd0e226123a3013a9134f8f4ea119e38c1f16 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 6 Dec 2024 15:47:56 -0800 Subject: [PATCH 5/6] fix export Signed-off-by: youkaichao --- depyf/explain/patched_lazy_format_graph_code.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/depyf/explain/patched_lazy_format_graph_code.py b/depyf/explain/patched_lazy_format_graph_code.py index eb5e3759..28217d92 100644 --- a/depyf/explain/patched_lazy_format_graph_code.py +++ b/depyf/explain/patched_lazy_format_graph_code.py @@ -30,7 +30,6 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): src = "from __future__ import annotations\nimport torch\n" + \ gm.print_readable(print_output=False) src = src.replace("", "GraphModule") - print(src) try: compile(src, "noname", "exec") except Exception as e: @@ -50,7 +49,11 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): scope = fn.__globals__ exec(compile(src, filename=new_filepath, mode="exec"), scope) if use_gm: - fn.__code__ = getattr(scope["GraphModule"], fn.__name__).__code__ + import torch + classes = [v for v in scope.values() if isinstance(v, type) and issubclass(v, torch.nn.Module)] + assert len(classes) == 1 + module_class = classes[0] + fn.__code__ = getattr(module_class, fn.__name__).__code__ else: fn.__code__ = scope[fn.__name__].__code__ del scope[fn.__name__] From 614f4e5149db705a3fb819bd7725de08fcf0109d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 6 Dec 2024 16:06:55 -0800 Subject: [PATCH 6/6] comment Signed-off-by: youkaichao --- .github/workflows/test_decompile.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_decompile.yml b/.github/workflows/test_decompile.yml index 697d1a7f..b3a526f3 100644 --- a/.github/workflows/test_decompile.yml +++ b/.github/workflows/test_decompile.yml @@ -12,7 +12,7 @@ on: jobs: build: - runs-on: ubuntu-22.04 + runs-on: ubuntu-22.04 # ubuntu-latest does not support python 3.7 strategy: fail-fast: false matrix: