From 904977bb8bf2f4fad66f4644e154329920cc5f21 Mon Sep 17 00:00:00 2001 From: rongchaodong <16302010007@fudan.edu.cn> Date: Thu, 18 Apr 2024 15:23:07 +0800 Subject: [PATCH] enable fallback from exception during compiling process --- frontend/c_api.pyi | 6 ++++++ frontend/compile.py | 2 ++ frontend/config.py | 1 + frontend/csrc/frame_evaluation.cpp | 16 +++++++++++++--- frontend/guard_tracker.py | 13 +++++-------- frontend/tracer.py | 29 +++++++++++++++++++++++++---- 6 files changed, 52 insertions(+), 15 deletions(-) diff --git a/frontend/c_api.pyi b/frontend/c_api.pyi index f23ab8c63ab1..e2c928c7847c 100644 --- a/frontend/c_api.pyi +++ b/frontend/c_api.pyi @@ -11,6 +11,12 @@ def set_eval_frame( pass +def set_fallback( + new_callback: Optional[Tuple[Callable[..., Any], Callable[..., Any]]] +) -> Optional[Tuple[Callable[..., Any], Callable[..., Any]]]: + pass + + def set_skip_files(skip_file: set[str], end_file: set[str]) -> None: pass diff --git a/frontend/compile.py b/frontend/compile.py index 6b7e38a6f365..3ce024e75a29 100644 --- a/frontend/compile.py +++ b/frontend/compile.py @@ -93,3 +93,5 @@ def reset() -> None: fx_graph.reset() from . import dynamic dynamic.reset() + from . import tracer + tracer.reset() diff --git a/frontend/config.py b/frontend/config.py index 390ed060de8b..fb259a03ed9d 100644 --- a/frontend/config.py +++ b/frontend/config.py @@ -5,6 +5,7 @@ "debug": True, "miss_threshold": 3, "dynshape": False, + "enable_fallback": False, } diff --git a/frontend/csrc/frame_evaluation.cpp b/frontend/csrc/frame_evaluation.cpp index 9d3d7e93e3e5..268bdd1c8d4e 100644 --- a/frontend/csrc/frame_evaluation.cpp +++ b/frontend/csrc/frame_evaluation.cpp @@ -242,9 +242,11 @@ inline static void enable_eval_frame_shim(PyThreadState *tstate) { inline static void enable_eval_frame_default(PyThreadState *tstate) { if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) != previous_eval_frame) { - _PyInterpreterState_SetEvalFrameFunc(tstate->interp, - previous_eval_frame); - previous_eval_frame = NULL; + if (previous_eval_frame != NULL) { + _PyInterpreterState_SetEvalFrameFunc(tstate->interp, + previous_eval_frame); + previous_eval_frame = NULL; + } } } @@ -290,6 +292,13 @@ static PyObject *set_eval_frame(PyObject *self, PyObject *args) { return old_callback; } +static PyObject *set_fallback(PyObject *self, PyObject *args) { + PyThreadState *tstate = PyThreadState_GET(); + fprintf(stderr, "Falling back\n"); + decrese_working_threads(tstate); + Py_RETURN_NONE; +} + // TODO: in a more elegant way static PyObject *set_skip_files(PyObject *self, PyObject *args) { if (skip_files != Py_None) { @@ -659,6 +668,7 @@ static PyObject *mark_need_postprocess(PyObject *self, PyObject *args) { static PyMethodDef _methods[] = { {"set_eval_frame", set_eval_frame, METH_VARARGS, NULL}, + {"set_fallback", set_fallback, METH_VARARGS, NULL}, {"set_skip_files", set_skip_files, METH_VARARGS, NULL}, {"set_null_object", set_null_object, METH_VARARGS, NULL}, {"set_miss_threshold", set_miss_threshold, METH_VARARGS, NULL}, diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index b7c5b4666ead..85a21caff799 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -31,6 +31,7 @@ from .variables.const import ClsByNamedTupleVar from .variables.base import Variable from .control_flow import ControlFlowInfo, LoopModule, ForLoopInfo, LoopPosMap, if_stmt, IfStmtInfo +from .config import get_config MAKE_VAR_FN_TYPE = Callable[[ Any, bool, vs.HelperFunctions, Optional[FxGraph], Optional[list[StorePos]] @@ -1551,7 +1552,7 @@ def is_builtin_func(self, func: Callable[..., Any]) -> bool: return func in (dict, tuple, set, list, hasattr, slice, range, len, type, all, str.join, reversed, zip, iter, id, next, collections.OrderedDict, str.format, any, str, - str.split) + str.split, sorted) def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool: print(dir(func)) @@ -2415,11 +2416,6 @@ def UNPACK_SEQUENCE(self, inst: Instruction) -> None: # ] # }) # pass - print("check data", seq, type(seq)) - if self.state.objects.contains(seq): - print("jjjjjj") - for i in seq: - print(i) raise NotImplementedError def UNPACK_EX(self, inst: Instruction) -> None: @@ -2693,8 +2689,9 @@ def pop_tracker(frame_id: int) -> None: print("before pop_tracker", [t.frame_id for t in trackers], "frame_id", frame_id) to_pop = trackers.pop() - assert to_pop.frame_id == frame_id - assert to_pop.state.is_empty + if not get_config("enable_fallback"): + assert to_pop.frame_id == frame_id + assert to_pop.state.is_empty def record(frame: FrameType, frame_id: int) -> None: diff --git a/frontend/tracer.py b/frontend/tracer.py index 4ad43340db74..e552f6d13551 100644 --- a/frontend/tracer.py +++ b/frontend/tracer.py @@ -4,18 +4,24 @@ from types import FrameType, CodeType from typing import Any, Callable, Tuple import inspect -from .guard_tracker import push_tracker, pop_tracker, record +from .guard_tracker import push_tracker, pop_tracker, record, trackers from .cache import enable_cache, check_cache_updated, get_frame_cache from .fx_graph import set_frame_root -from .c_api import set_eval_frame, mark_need_postprocess +from .c_api import set_eval_frame, mark_need_postprocess, set_fallback from .code import ProcessedCode from .instruction import format_insts from .config import get_config +run_trace_func: bool = True +fall_back_frames: list[int] = [] + def get_trace_func(frame_id: int) -> Callable[[FrameType, str, Any], None]: def trace_func(frame: FrameType, event: str, arg: Any) -> None: + global run_trace_func + if not run_trace_func and frame_id in fall_back_frames: + return None try: if event == "opcode": opcode = frame.f_code.co_code[frame.f_lasti] @@ -33,7 +39,17 @@ def trace_func(frame: FrameType, event: str, arg: Any) -> None: except Exception as e: print("exception in trace_func:", e, type(e)) print(traceback.format_exc()) - raise e + if get_config("enable_fallback"): + run_trace_func = False + for i in trackers: + fall_back_frames.append(i.frame_id) + # if len(trackers) > 1: + # disable_trace(frame_id) + print("fallback frames", fall_back_frames) + set_fallback(None) + return None + else: + raise e return None return trace_func @@ -115,4 +131,9 @@ def postprocess_frame(frame: FrameType, frame_id: int) -> None: raise e return - return (preprocess_frame, postprocess_frame) \ No newline at end of file + return (preprocess_frame, postprocess_frame) + + +def reset() -> None: + run_trace_func = True + fall_back_frames.clear()