Skip to content

Commit

Permalink
enable fallback from exception during compiling process (apache#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
superDong1998 authored Apr 23, 2024
2 parents db17f8f + 904977b commit 7cd47d6
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 15 deletions.
6 changes: 6 additions & 0 deletions frontend/c_api.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ def set_eval_frame(
pass


def set_fallback(
new_callback: Optional[Tuple[Callable[..., Any], Callable[..., Any]]]
) -> Optional[Tuple[Callable[..., Any], Callable[..., Any]]]:
pass


def set_skip_files(skip_file: set[str], end_file: set[str]) -> None:
pass

Expand Down
2 changes: 2 additions & 0 deletions frontend/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,5 @@ def reset() -> None:
fx_graph.reset()
from . import dynamic
dynamic.reset()
from . import tracer
tracer.reset()
1 change: 1 addition & 0 deletions frontend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"debug": True,
"miss_threshold": 3,
"dynshape": False,
"enable_fallback": False,
}


Expand Down
16 changes: 13 additions & 3 deletions frontend/csrc/frame_evaluation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,11 @@ inline static void enable_eval_frame_shim(PyThreadState *tstate) {
inline static void enable_eval_frame_default(PyThreadState *tstate) {
if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) !=
previous_eval_frame) {
_PyInterpreterState_SetEvalFrameFunc(tstate->interp,
previous_eval_frame);
previous_eval_frame = NULL;
if (previous_eval_frame != NULL) {
_PyInterpreterState_SetEvalFrameFunc(tstate->interp,
previous_eval_frame);
previous_eval_frame = NULL;
}
}
}

Expand Down Expand Up @@ -290,6 +292,13 @@ static PyObject *set_eval_frame(PyObject *self, PyObject *args) {
return old_callback;
}

static PyObject *set_fallback(PyObject *self, PyObject *args) {
PyThreadState *tstate = PyThreadState_GET();
fprintf(stderr, "Falling back\n");
decrese_working_threads(tstate);
Py_RETURN_NONE;
}

// TODO: in a more elegant way
static PyObject *set_skip_files(PyObject *self, PyObject *args) {
if (skip_files != Py_None) {
Expand Down Expand Up @@ -659,6 +668,7 @@ static PyObject *mark_need_postprocess(PyObject *self, PyObject *args) {

static PyMethodDef _methods[] = {
{"set_eval_frame", set_eval_frame, METH_VARARGS, NULL},
{"set_fallback", set_fallback, METH_VARARGS, NULL},
{"set_skip_files", set_skip_files, METH_VARARGS, NULL},
{"set_null_object", set_null_object, METH_VARARGS, NULL},
{"set_miss_threshold", set_miss_threshold, METH_VARARGS, NULL},
Expand Down
13 changes: 5 additions & 8 deletions frontend/guard_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .variables.const import ClsByNamedTupleVar
from .variables.base import Variable
from .control_flow import ControlFlowInfo, LoopModule, ForLoopInfo, LoopPosMap, if_stmt, IfStmtInfo
from .config import get_config

MAKE_VAR_FN_TYPE = Callable[[
Any, bool, vs.HelperFunctions, Optional[FxGraph], Optional[list[StorePos]]
Expand Down Expand Up @@ -1551,7 +1552,7 @@ def is_builtin_func(self, func: Callable[..., Any]) -> bool:
return func in (dict, tuple, set, list, hasattr, slice, range, len,
type, all, str.join, reversed, zip, iter, id, next,
collections.OrderedDict, str.format, any, str,
str.split)
str.split, sorted)

def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
print(dir(func))
Expand Down Expand Up @@ -2427,11 +2428,6 @@ def UNPACK_SEQUENCE(self, inst: Instruction) -> None:
# ]
# })
# pass
print("check data", seq, type(seq))
if self.state.objects.contains(seq):
print("jjjjjj")
for i in seq:
print(i)
raise NotImplementedError

def UNPACK_EX(self, inst: Instruction) -> None:
Expand Down Expand Up @@ -2705,8 +2701,9 @@ def pop_tracker(frame_id: int) -> None:
print("before pop_tracker", [t.frame_id for t in trackers], "frame_id",
frame_id)
to_pop = trackers.pop()
assert to_pop.frame_id == frame_id
assert to_pop.state.is_empty
if not get_config("enable_fallback"):
assert to_pop.frame_id == frame_id
assert to_pop.state.is_empty


def record(frame: FrameType, frame_id: int) -> None:
Expand Down
29 changes: 25 additions & 4 deletions frontend/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,24 @@
from types import FrameType, CodeType
from typing import Any, Callable, Tuple
import inspect
from .guard_tracker import push_tracker, pop_tracker, record
from .guard_tracker import push_tracker, pop_tracker, record, trackers
from .cache import enable_cache, check_cache_updated, get_frame_cache
from .fx_graph import set_frame_root
from .c_api import set_eval_frame, mark_need_postprocess
from .c_api import set_eval_frame, mark_need_postprocess, set_fallback
from .code import ProcessedCode
from .instruction import format_insts
from .config import get_config

run_trace_func: bool = True
fall_back_frames: list[int] = []


def get_trace_func(frame_id: int) -> Callable[[FrameType, str, Any], None]:

def trace_func(frame: FrameType, event: str, arg: Any) -> None:
global run_trace_func
if not run_trace_func and frame_id in fall_back_frames:
return None
try:
if event == "opcode":
opcode = frame.f_code.co_code[frame.f_lasti]
Expand All @@ -33,7 +39,17 @@ def trace_func(frame: FrameType, event: str, arg: Any) -> None:
except Exception as e:
print("exception in trace_func:", e, type(e))
print(traceback.format_exc())
raise e
if get_config("enable_fallback"):
run_trace_func = False
for i in trackers:
fall_back_frames.append(i.frame_id)
# if len(trackers) > 1:
# disable_trace(frame_id)
print("fallback frames", fall_back_frames)
set_fallback(None)
return None
else:
raise e
return None

return trace_func
Expand Down Expand Up @@ -115,4 +131,9 @@ def postprocess_frame(frame: FrameType, frame_id: int) -> None:
raise e
return

return (preprocess_frame, postprocess_frame)
return (preprocess_frame, postprocess_frame)


def reset() -> None:
run_trace_func = True
fall_back_frames.clear()

0 comments on commit 7cd47d6

Please sign in to comment.