Skip to content

Commit

Permalink
support call frame0 (apache#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
heheda12345 authored Nov 27, 2023
2 parents da71124 + 4389f52 commit b8e6464
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 37 deletions.
6 changes: 3 additions & 3 deletions frontend/bytecode_writter.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def rewrite_bytecode(code: types.CodeType, frame_id: int,
for original_inst, inst in zip(original_instructions, instructions):
inst.original_inst = original_inst
instructions[0].is_start = True
print(format_insts(instructions))
# print(format_insts(instructions))
frame_cache = get_frame_cache(frame_id)
# list of (start_pc, traced_instructions)
run_traced_insts: list[tuple[int, list[Instruction]]] = []
Expand Down Expand Up @@ -506,8 +506,8 @@ def rewrite_bytecode(code: types.CodeType, frame_id: int,
list(new_names_all["names"]))
strip_extended_args(instructions)
fix_instructions_for_assemble(instructions, code_options)
print("guarded code")
print(format_insts(instructions))
# print("guarded code")
# print(format_insts(instructions))
code_map = generate_code_map(original_instructions, instructions,
in_trace_insts, next_original_pc)
new_code = assemble_instructions(instructions, code_options)[1]
Expand Down
31 changes: 25 additions & 6 deletions frontend/cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from types import CodeType
from typing import Callable, Any, Optional
from typing import Callable, Any, Optional, Tuple
from dataclasses import dataclass

from frontend.code import ProcessedCode
Expand Down Expand Up @@ -30,16 +30,17 @@ class FrameCache:
list[CachedGraph]] # start_pc -> list of cached graph
callsite_id: dict[int, int] # start_pc -> callsite_id
pre_cache_size: int
new_code: Optional[CodeType]
code_map: Optional[ProcessedCode]
updated: bool
# 0 for root, 1 for callee
code: list[Optional[Tuple[CodeType, ProcessedCode]]]

def __init__(self, frame_id: int) -> None:
self.frame_id = frame_id
self.cached_graphs = {0: []}
self.callsite_id = {0: 0}
self.new_code = None
self.code_map = None
self.code = [None, None]
self.updated = True # rewrite bytecode for the first time

def add(self, traced_code: CachedGraph) -> None:
Expand All @@ -58,9 +59,27 @@ def add(self, traced_code: CachedGraph) -> None:
TOTAL_SIZE += 1
self.updated = True

def set_new_code(self, new_code: CodeType, code_map: ProcessedCode) -> None:
self.new_code = new_code
self.code_map = code_map
def set_new_code(self, new_code: CodeType, code_map: ProcessedCode,
is_callee: bool) -> None:
self.code[is_callee] = (new_code, code_map)

def get_new_code(self, is_callee: bool) -> Tuple[CodeType, ProcessedCode]:
code = self.code[is_callee]
assert code is not None
return code

def is_valid(self, is_callee: bool) -> bool:
return not self.updated and self.code[is_callee] is not None

def update_code(self, f_code: CodeType, frame_id: int,
is_callee: bool) -> None:
if not self.is_valid(is_callee):
from .bytecode_writter import rewrite_bytecode
for i in (False, True):
if i == is_callee or self.code[i] is not None:
print("new_code for is_callee =", i)
new_code, code_map = rewrite_bytecode(f_code, frame_id, i)
self.set_new_code(new_code, code_map, i)
self.updated = False


Expand Down
36 changes: 9 additions & 27 deletions frontend/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from .cache import enable_cache, check_cache_updated, get_frame_cache
from .fx_graph import set_frame_root
from .c_api import set_eval_frame, mark_need_postprocess
from .bytecode_writter import rewrite_bytecode
from .code import ProcessedCode
from .instruction import format_insts
from .config import get_config
Expand Down Expand Up @@ -85,25 +84,12 @@ def preprocess_frame(
hex(id(frame)), frame.f_code.co_name)
enable_cache(frame_id)
set_frame_root(frame_id, f)
if check_cache_updated(frame_id):
print("new bytecode: \n")
new_code, code_map = rewrite_bytecode(frame.f_code, frame_id,
is_callee)
get_frame_cache(frame_id).set_new_code(new_code, code_map)
trace_func = get_trace_func(frame_id)

else:
old_frame = get_frame_cache(frame_id)
assert old_frame.code_map is not None, "Code map doesn't exist for frame id {}".format(
frame_id)
assert old_frame.new_code is not None, "New code doesn't exist for frame id {}".format(
frame_id)
if is_debug:
print("old bytecode: \n")
print(format_insts(old_frame.code_map.guard_insts))
new_code = old_frame.new_code
code_map = old_frame.code_map
trace_func = get_trace_func(frame_id)
frame_cache = get_frame_cache(frame_id)
frame_cache.update_code(frame.f_code, frame_id, is_callee)
new_code, code_map = frame_cache.get_new_code(is_callee)
print("bytecode to run:")
print(format_insts(code_map.guard_insts))
trace_func = get_trace_func(frame_id)

except Exception as e:
print("exception in preprocess:", e, type(e))
Expand All @@ -117,13 +103,9 @@ def postprocess_frame(frame: FrameType, frame_id: int) -> None:
if SHOULD_NOT_CALL_REWRITE:
raise ValueError("should not call postprocess")
print(f"postprocess frame {frame.f_code.co_filename}")
if check_cache_updated(frame_id):
print("new bytecode: \n")
set_frame_root(frame_id, f)
new_code, code_map = rewrite_bytecode(frame.f_code, frame_id,
is_callee)
get_frame_cache(frame_id).set_new_code(new_code, code_map)

set_frame_root(frame_id, f)
frame_cache = get_frame_cache(frame_id)
frame_cache.update_code(frame.f_code, frame_id, is_callee)
except Exception as e:
print("exception in postprocess:", e, type(e))
print(traceback.format_exc())
Expand Down
21 changes: 20 additions & 1 deletion test/test_call_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,4 +320,23 @@ def test_call_run_udf(caplog):
result = call_run_udf(x)
compiled_model = compile(call_run_udf)
run_and_check(compiled_model, [MISS, MISS], 1, caplog, result, x)
run_and_check(compiled_model, [HIT], 1, caplog, result, x)
run_and_check(compiled_model, [HIT], 1, caplog, result, x)


def b_aba(a):
return a_aba(a + 2.0, 0)


def a_aba(a, b):
if b == 0:
return a + 3.0
else:
return b_aba(a + 1.0)


def test_call_aba(caplog):
reset()
compiled = compile(a_aba)
expect = a_aba(1.0, 1)
run_and_check(compiled, [MISS, MISS, MISS], 1, caplog, expect, 1.0, 1)
run_and_check(compiled, [HIT], 1, caplog, expect, 1.0, 1)

0 comments on commit b8e6464

Please sign in to comment.