diff --git a/depyf/explain/patched_lazy_format_graph_code.py b/depyf/explain/patched_lazy_format_graph_code.py index eb5e3759..28217d92 100644 --- a/depyf/explain/patched_lazy_format_graph_code.py +++ b/depyf/explain/patched_lazy_format_graph_code.py @@ -30,7 +30,6 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): src = "from __future__ import annotations\nimport torch\n" + \ gm.print_readable(print_output=False) src = src.replace("", "GraphModule") - print(src) try: compile(src, "noname", "exec") except Exception as e: @@ -50,7 +49,11 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): scope = fn.__globals__ exec(compile(src, filename=new_filepath, mode="exec"), scope) if use_gm: - fn.__code__ = getattr(scope["GraphModule"], fn.__name__).__code__ + import torch + classes = [v for v in scope.values() if isinstance(v, type) and issubclass(v, torch.nn.Module)] + assert len(classes) == 1 + module_class = classes[0] + fn.__code__ = getattr(module_class, fn.__name__).__code__ else: fn.__code__ = scope[fn.__name__].__code__ del scope[fn.__name__]