diff --git a/frontend/bytecode_writter.py b/frontend/bytecode_writter.py index 1f669944ec6c..3669cae08bb2 100644 --- a/frontend/bytecode_writter.py +++ b/frontend/bytecode_writter.py @@ -441,7 +441,7 @@ def rewrite_bytecode(code: types.CodeType, frame_id: int, for original_inst, inst in zip(original_instructions, instructions): inst.original_inst = original_inst instructions[0].is_start = True - print(format_insts(instructions)) + # print(format_insts(instructions)) frame_cache = get_frame_cache(frame_id) # list of (start_pc, traced_instructions) run_traced_insts: list[tuple[int, list[Instruction]]] = [] @@ -506,8 +506,8 @@ def rewrite_bytecode(code: types.CodeType, frame_id: int, list(new_names_all["names"])) strip_extended_args(instructions) fix_instructions_for_assemble(instructions, code_options) - print("guarded code") - print(format_insts(instructions)) + # print("guarded code") + # print(format_insts(instructions)) code_map = generate_code_map(original_instructions, instructions, in_trace_insts, next_original_pc) new_code = assemble_instructions(instructions, code_options)[1] diff --git a/frontend/cache.py b/frontend/cache.py index 5984e3664b96..a15efb166ec3 100644 --- a/frontend/cache.py +++ b/frontend/cache.py @@ -1,5 +1,5 @@ from types import CodeType -from typing import Callable, Any, Optional +from typing import Callable, Any, Optional, Tuple from dataclasses import dataclass from frontend.code import ProcessedCode @@ -30,9 +30,9 @@ class FrameCache: list[CachedGraph]] # start_pc -> list of cached graph callsite_id: dict[int, int] # start_pc -> callsite_id pre_cache_size: int - new_code: Optional[CodeType] - code_map: Optional[ProcessedCode] updated: bool + # 0 for root, 1 for callee + code: list[Optional[Tuple[CodeType, ProcessedCode]]] def __init__(self, frame_id: int) -> None: self.frame_id = frame_id @@ -40,6 +40,7 @@ def __init__(self, frame_id: int) -> None: self.callsite_id = {0: 0} self.new_code = None self.code_map = None + self.code = [None, None] self.updated = True # rewrite bytecode for the first time def add(self, traced_code: CachedGraph) -> None: @@ -58,9 +59,27 @@ def add(self, traced_code: CachedGraph) -> None: TOTAL_SIZE += 1 self.updated = True - def set_new_code(self, new_code: CodeType, code_map: ProcessedCode) -> None: - self.new_code = new_code - self.code_map = code_map + def set_new_code(self, new_code: CodeType, code_map: ProcessedCode, + is_callee: bool) -> None: + self.code[is_callee] = (new_code, code_map) + + def get_new_code(self, is_callee: bool) -> Tuple[CodeType, ProcessedCode]: + code = self.code[is_callee] + assert code is not None + return code + + def is_valid(self, is_callee: bool) -> bool: + return not self.updated and self.code[is_callee] is not None + + def update_code(self, f_code: CodeType, frame_id: int, + is_callee: bool) -> None: + if not self.is_valid(is_callee): + from .bytecode_writter import rewrite_bytecode + for i in (False, True): + if i == is_callee or self.code[i] is not None: + print("new_code for is_callee =", i) + new_code, code_map = rewrite_bytecode(f_code, frame_id, i) + self.set_new_code(new_code, code_map, i) self.updated = False diff --git a/frontend/tracer.py b/frontend/tracer.py index 34aac2661811..e9ecac95572b 100644 --- a/frontend/tracer.py +++ b/frontend/tracer.py @@ -8,7 +8,6 @@ from .cache import enable_cache, check_cache_updated, get_frame_cache from .fx_graph import set_frame_root from .c_api import set_eval_frame, mark_need_postprocess -from .bytecode_writter import rewrite_bytecode from .code import ProcessedCode from .instruction import format_insts from .config import get_config @@ -85,25 +84,12 @@ def preprocess_frame( hex(id(frame)), frame.f_code.co_name) enable_cache(frame_id) set_frame_root(frame_id, f) - if check_cache_updated(frame_id): - print("new bytecode: \n") - new_code, code_map = rewrite_bytecode(frame.f_code, frame_id, - is_callee) - get_frame_cache(frame_id).set_new_code(new_code, code_map) - trace_func = get_trace_func(frame_id) - - else: - old_frame = get_frame_cache(frame_id) - assert old_frame.code_map is not None, "Code map doesn't exist for frame id {}".format( - frame_id) - assert old_frame.new_code is not None, "New code doesn't exist for frame id {}".format( - frame_id) - if is_debug: - print("old bytecode: \n") - print(format_insts(old_frame.code_map.guard_insts)) - new_code = old_frame.new_code - code_map = old_frame.code_map - trace_func = get_trace_func(frame_id) + frame_cache = get_frame_cache(frame_id) + frame_cache.update_code(frame.f_code, frame_id, is_callee) + new_code, code_map = frame_cache.get_new_code(is_callee) + print("bytecode to run:") + print(format_insts(code_map.guard_insts)) + trace_func = get_trace_func(frame_id) except Exception as e: print("exception in preprocess:", e, type(e)) @@ -117,13 +103,9 @@ def postprocess_frame(frame: FrameType, frame_id: int) -> None: if SHOULD_NOT_CALL_REWRITE: raise ValueError("should not call postprocess") print(f"postprocess frame {frame.f_code.co_filename}") - if check_cache_updated(frame_id): - print("new bytecode: \n") - set_frame_root(frame_id, f) - new_code, code_map = rewrite_bytecode(frame.f_code, frame_id, - is_callee) - get_frame_cache(frame_id).set_new_code(new_code, code_map) - + set_frame_root(frame_id, f) + frame_cache = get_frame_cache(frame_id) + frame_cache.update_code(frame.f_code, frame_id, is_callee) except Exception as e: print("exception in postprocess:", e, type(e)) print(traceback.format_exc()) diff --git a/test/test_call_udf.py b/test/test_call_udf.py index 5cc94ed2fd6a..e7c8164b2a9f 100644 --- a/test/test_call_udf.py +++ b/test/test_call_udf.py @@ -320,4 +320,23 @@ def test_call_run_udf(caplog): result = call_run_udf(x) compiled_model = compile(call_run_udf) run_and_check(compiled_model, [MISS, MISS], 1, caplog, result, x) - run_and_check(compiled_model, [HIT], 1, caplog, result, x) \ No newline at end of file + run_and_check(compiled_model, [HIT], 1, caplog, result, x) + + +def b_aba(a): + return a_aba(a + 2.0, 0) + + +def a_aba(a, b): + if b == 0: + return a + 3.0 + else: + return b_aba(a + 1.0) + + +def test_call_aba(caplog): + reset() + compiled = compile(a_aba) + expect = a_aba(1.0, 1) + run_and_check(compiled, [MISS, MISS, MISS], 1, caplog, expect, 1.0, 1) + run_and_check(compiled, [HIT], 1, caplog, expect, 1.0, 1)