diff --git a/.coveragerc b/.coveragerc index 5a845fb4..ed687661 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,5 +1,7 @@ [run] source = depyf +# omit patched file from pytorch +omit = depyf/explain/patched* [report] include = depyf/* diff --git a/.github/workflows/test_decompile.yml b/.github/workflows/test_decompile.yml index 4edfb167..ea975953 100644 --- a/.github/workflows/test_decompile.yml +++ b/.github/workflows/test_decompile.yml @@ -45,6 +45,8 @@ jobs: run: | pytest --cov=depyf tests/test.py coverage run --append python_coverage.py + coverage run --append tests/test_code_owner.py + coverage run --append tests/test_ensure.py python tests/assert.py - name: Upload results to Codecov diff --git a/depyf/code_transform.py b/depyf/code_transform.py index 6b0b9602..20985aec 100644 --- a/depyf/code_transform.py +++ b/depyf/code_transform.py @@ -420,17 +420,6 @@ def visit_FunctionDef(self, node): # return self.generic_visit(node) -def structure_hash(source_code: str) -> str: - """Compute the hash of code structure, ignore the function name difference. - This is because PyTorch dynamically generates function names. - """ - tree = ast.parse(source_code) - tree = IdentifierReplacer().visit(tree) - modified_code = astor.to_source(tree) - hash_value = hashlib.md5(modified_code.encode()).hexdigest() - return hash_value - - def fix_irregular_code( old_bytecode: CodeType, src_code: str, diff --git a/depyf/decompiler.py b/depyf/decompiler.py index 2d550cd7..cdb0b78a 100644 --- a/depyf/decompiler.py +++ b/depyf/decompiler.py @@ -1133,7 +1133,8 @@ def cleanup_instructions(code, instructions: List[Instruction]): def __init__(self, code: Union[CodeType, Callable]): if callable(code): - code = code.__code__ + from depyf.utils import get_code_owner + code = get_code_owner(code).__code__ self.code = code instructions = list(convert_instruction(_) for _ in dis.get_instructions(code)) diff --git a/depyf/explain/__init__.py b/depyf/explain/__init__.py index e74676ca..31feb23a 100644 --- a/depyf/explain/__init__.py +++ b/depyf/explain/__init__.py @@ -10,16 +10,6 @@ def _extract_artifacts(original_code: CodeType, module): result = DynamoOptimizationResult(original_code, None, module) return result - -def _collect_compiled_subgraphs(result: DynamoOptimizationResult): - compiled_subgraphs = { - entry.compiled_subgraph_proxy.name: entry.compiled_subgraph for entry in result.compiled_code_entries} - for entry in result.compiled_code_entries: - for func in entry.referenced_global_functions.values(): - ans = _collect_compiled_subgraphs(func) - compiled_subgraphs.update(ans) - return compiled_subgraphs - def dump_src(original_code: CodeType, module): from depyf.explain.global_variables import data assert data["is_inside_prepare_debug"], "`dump_src` must be used inside `depyf.prepare_debug`." diff --git a/depyf/explain/enable_debugging.py b/depyf/explain/enable_debugging.py index 25db3a6a..ef98baa6 100644 --- a/depyf/explain/enable_debugging.py +++ b/depyf/explain/enable_debugging.py @@ -54,18 +54,14 @@ def __call__(self, code, new_code): import dill # code object, especially `new_code` constructed by Dynamo, may not be able to be dumped using `marshal`. # see https://github.com/pytorch/pytorch/issues/116013 for more details. - try: + with contextlib.suppress(Exception): dill.dump(code, open(filename + ".original_bytecode", "wb")) - except: - pass - try: + + with contextlib.suppress(Exception): dill.dump(new_code, open(filename + ".transformed_bytecode", "wb")) - except: - pass - try: + + with contextlib.suppress(Exception): dill.dump(decompiled_and_compiled_back_code, open(filename + ".decompiled_and_compiled_back_bytecode", "wb")) - except: - pass # this fix is used for PyTorch prior to PR https://github.com/pytorch/pytorch/pull/114487 from torch._dynamo.utils import orig_code_map diff --git a/depyf/explain/patched___call__.py b/depyf/explain/patched___call__.py index d90e5356..a5d033c7 100644 --- a/depyf/explain/patched___call__.py +++ b/depyf/explain/patched___call__.py @@ -1,6 +1,6 @@ def patched___call__(self, code, check_fn): from depyf.explain.global_variables import data - from depyf.explain.utils import get_code_owner + from depyf.utils import get_code_owner import torch unpatched___call__ = data["unpatched___call__"] optimized_functions = data["optimized_functions"] diff --git a/depyf/explain/patched_lazy_format_graph_code.py b/depyf/explain/patched_lazy_format_graph_code.py index 15d80df0..8cfa35da 100644 --- a/depyf/explain/patched_lazy_format_graph_code.py +++ b/depyf/explain/patched_lazy_format_graph_code.py @@ -1,5 +1,6 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): - from depyf.explain.utils import get_current_compiled_fn_name, get_code_owner, write_code_to_file_template + from depyf.explain.utils import get_current_compiled_fn_name, write_code_to_file_template + from depyf.utils import get_code_owner func_name = get_current_compiled_fn_name() file_name = name if name != func_name else "Captured Graph" file_name = func_name + " " + file_name diff --git a/depyf/explain/utils.py b/depyf/explain/utils.py index 2c98ce0f..1399455c 100644 --- a/depyf/explain/utils.py +++ b/depyf/explain/utils.py @@ -9,21 +9,6 @@ from dataclasses import dataclass import contextlib -import depyf -from depyf.decompiler import DecompilationError -from depyf.utils import get_function_signature - - -def decompile_ensure(fn, overwite_fn_name=None): - try: - decompiled_source_code = depyf.Decompiler( - fn).decompile(overwite_fn_name=overwite_fn_name) - except DecompilationError as e: - header = get_function_signature(fn, overwite_fn_name=overwite_fn_name) - decompiled_source_code = header + " 'Failed to decompile.'\n" - return decompiled_source_code - - class CodeProxy: instances: Dict[str, "CodeProxy"] = {} used_instances: Set[str] = set() @@ -49,6 +34,7 @@ def consume_new_name(name: str): @staticmethod def decompile_with_name(code: CodeType, name: str, skip_decompile=False): + from depyf.utils import decompile_ensure if hasattr(code, "__code__"): code = code.__code__ if code.co_name.startswith("transformed_code_") or code.co_name.startswith("__transformed_code_"): @@ -320,37 +306,6 @@ def write_code_to_file_template(src, path_template): return new_filepath -def get_code_owner(fn): - """A callable object `fn` might have a __code__ attribute, which is a code object. - However, `fn` might not be the owner of the code object. Only the code owner can change the code object. - This function returns the owner of the code object. - An example: - class A: - def func(self): - return 1 - a = A() - `a.func.__code__` is read-only. `A.func.__code__` is writable. - We can change the code object via `a.func.__func__.__code__`. - """ - import functools - while True: - if hasattr(fn, "__func__"): - # deal with bounded function - fn = fn.__func__ - elif hasattr(fn, "__wrapped__"): - # deal with lru_cache or other decorators - fn = fn.__wrapped__ - elif isinstance(fn, functools.partial): - # deal with partial function - fn = fn.func - elif hasattr(fn, "__call__") and hasattr(fn.__call__, "__func__"): - # deal with callable object - fn = fn.__call__.__func__ - else: - break - return fn - - def get_current_compiled_fn_name(): import torch from torch._dynamo.bytecode_transformation import _unique_id_counter diff --git a/depyf/utils.py b/depyf/utils.py index 88d0c7a1..92a224f5 100644 --- a/depyf/utils.py +++ b/depyf/utils.py @@ -43,3 +43,48 @@ def safe_create_directory(path): except OSError as e: if not os.path.isdir(path): raise + + + +def get_code_owner(fn): + """A callable object `fn` might have a __code__ attribute, which is a code object. + However, `fn` might not be the owner of the code object. Only the code owner can change the code object. + This function returns the owner of the code object. + An example: + class A: + def func(self): + return 1 + a = A() + `a.func.__code__` is read-only. `A.func.__code__` is writable. + We can change the code object via `a.func.__func__.__code__`. + """ + import functools + while True: + if hasattr(fn, "__func__"): + # deal with bounded function + fn = fn.__func__ + elif hasattr(fn, "__wrapped__"): + # deal with lru_cache or other decorators + fn = fn.__wrapped__ + elif isinstance(fn, functools.partial): + # deal with partial function + fn = fn.func + elif hasattr(fn, "__call__") and hasattr(fn.__call__, "__func__"): + # deal with callable object + fn = fn.__call__.__func__ + else: + break + return fn + + + +def decompile_ensure(fn: CodeType, overwite_fn_name=None): + import depyf + from depyf.decompiler import DecompilationError + try: + decompiled_source_code = depyf.Decompiler( + fn).decompile(overwite_fn_name=overwite_fn_name) + except DecompilationError as e: + header = get_function_signature(fn, overwite_fn_name=overwite_fn_name) + decompiled_source_code = header + " 'Failed to decompile.'\n" + return decompiled_source_code diff --git a/tests/test_code_owner.py b/tests/test_code_owner.py new file mode 100644 index 00000000..354a08ca --- /dev/null +++ b/tests/test_code_owner.py @@ -0,0 +1,16 @@ +from functools import partial, lru_cache + +def f(a, b): + return a + b + +class A: + def __call__(self, a, b): + return a + b + +import depyf + +print(depyf.decompile(partial(f, 1))) + +print(depyf.decompile(lru_cache(None)(f))) + +print(depyf.decompile(A())) diff --git a/tests/test_ensure.py b/tests/test_ensure.py new file mode 100644 index 00000000..a7bbb76b --- /dev/null +++ b/tests/test_ensure.py @@ -0,0 +1,11 @@ +from depyf.utils import decompile_ensure + +import asyncio + +def f(a, b): + try: + return a + b + finally: + return a - b + +print(decompile_ensure(f.__code__)) diff --git a/tests/test_pytorch/test_simple_graph.py b/tests/test_pytorch/test_simple_graph.py index c0d46f1c..675d9806 100644 --- a/tests/test_pytorch/test_simple_graph.py +++ b/tests/test_pytorch/test_simple_graph.py @@ -9,5 +9,5 @@ def fn(): return x.grad import depyf -with depyf.prepare_debug("./simple_output"): +with depyf.prepare_debug("./simple_output", log_bytecode=True, clean_wild_fx_code=False): fn()