Skip to content

Commit

Permalink
[Dy2static-Fallback] add set_eval_frame function in pybind. (#52006)
Browse files Browse the repository at this point in the history
* [Dy2static-Fallback] add set_eval_frame function in pybind.
1. add set_eval_frame function in pybind.

* add unittest for eval frame hooker.

* [support py38]

* fix-GeneratorExit error in eval frame hooker

* support python == 3.9

* support 3.10

* fix some comments
  • Loading branch information
2742195759 authored May 18, 2023
1 parent 2d0c694 commit 7b1695a
Show file tree
Hide file tree
Showing 4 changed files with 377 additions and 1 deletion.
227 changes: 227 additions & 0 deletions paddle/fluid/pybind/jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,234 @@ limitations under the License. */

#include "paddle/fluid/pybind/jit.h"

#include <Python.h>
#include <code.h>
#include <frameobject.h>
#include <object.h>
#include <pystate.h>

#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/place.h"

#include "glog/logging.h"
#include "paddle/fluid/jit/function.h"
#include "paddle/fluid/jit/function_schema.h"
#include "paddle/fluid/jit/layer.h"
#include "paddle/fluid/jit/serializer.h"
#include "paddle/utils/pybind.h"

namespace py = pybind11;

namespace paddle {
namespace pybind {

// Branch-prediction hint: tells the compiler that `x` is expected to evaluate
// to false. __builtin_expect is a GCC/Clang builtin, so on other compilers
// the macro simply evaluates `x` unchanged.
#if defined(__GNUC__) || defined(__clang__)
#define unlikely(x) __builtin_expect((x), 0)
#else
#define unlikely(x) (x)
#endif

// Thread-specific-storage key under which the current thread's custom
// eval-frame hook (a PyObject *) is stored. Py_tss_NEEDS_INIT is the
// documented initializer for Py_tss_t; the key itself is created in
// PyInit__eval_frame().
static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;

// Returns the eval-frame callback stored for the current thread, or Py_None
// when nothing has been stored on this thread.
//
// The result is always a BORROWED reference. The callers in this file only
// compare it against Py_None or wrap it with py::reinterpret_borrow, so the
// fallback must not hand out a new reference (the previous Py_RETURN_NONE
// leaked one reference to None per call on this path).
inline static PyObject *eval_frame_callback_get(void) {
  void *stored = PyThread_tss_get(&eval_frame_callback_key);
  if (unlikely(stored == NULL)) {
    return Py_None;
  }
  return reinterpret_cast<PyObject *>(stored);
}

// Stores `obj` as the current thread's eval-frame callback. The pointer is
// stored as-is: no reference is taken here, and the status returned by
// PyThread_tss_set is not inspected.
inline static void eval_frame_callback_set(PyObject *obj) {
  void *slot_value = obj;
  (void)PyThread_tss_set(&eval_frame_callback_key, slot_value);
}

// Evaluates `frame` with CPython's built-in evaluator
// (_PyEval_EvalFrameDefault), bypassing any custom hook.
//
// From Python 3.9 on the evaluator takes the thread state explicitly; a NULL
// `tstate` is replaced by the current thread state. Older versions do not
// take a thread state, so `tstate` is unused there.
inline static PyObject *eval_frame_default(PyThreadState *tstate,
                                           PyFrameObject *frame,
                                           int throw_flag) {
#if PY_VERSION_HEX >= 0x03090000
  PyThreadState *effective_tstate =
      (tstate != NULL) ? tstate : PyThreadState_GET();
  return _PyEval_EvalFrameDefault(effective_tstate, frame, throw_flag);
#else
  (void)tstate;
  return _PyEval_EvalFrameDefault(frame, throw_flag);
#endif
}

// Runs `code` in a freshly created frame and returns the evaluation result.
//
// The new ("shadow") frame shares `frame`'s globals and is pre-populated with
// `frame`'s fast locals followed by its cell/free variables; it is then
// evaluated by the default evaluator. Returns NULL if the shadow frame cannot
// be created.
//
// NOTE(review): the copies below are sized from the ORIGINAL frame's local
// count and from the NEW code's cell/free counts, without any bounds check.
// Presumably `code` is expected to have at least as many local slots as
// frame->f_code and the same cell/free layout -- confirm against whatever
// generates `code`.
inline static PyObject *eval_custom_code(PyThreadState *tstate,
                                         PyFrameObject *frame,
                                         PyCodeObject *code,
                                         int throw_flag) {
  const Py_ssize_t new_local_count = code->co_nlocals;
  const Py_ssize_t old_local_count = frame->f_code->co_nlocals;

  // Cell and free variables are only counted when CO_NOFREE is not set.
  Py_ssize_t cell_count = 0;
  Py_ssize_t free_count = 0;
  if (!(code->co_flags & CO_NOFREE)) {
    cell_count = PyTuple_GET_SIZE(code->co_cellvars);
    free_count = PyTuple_GET_SIZE(code->co_freevars);
  }
  const Py_ssize_t closure_slots = cell_count + free_count;

  PyFrameObject *shadow = PyFrame_New(tstate, code, frame->f_globals, NULL);
  if (shadow == NULL) {
    return NULL;
  }

  PyObject **src_slots = frame->f_localsplus;
  PyObject **dst_slots = shadow->f_localsplus;

  // Share the local variables; each copied slot gets its own reference.
  for (Py_ssize_t slot = 0; slot < old_local_count; ++slot) {
    PyObject *value = src_slots[slot];
    Py_XINCREF(value);
    dst_slots[slot] = value;
  }

  // Share the cell/free variables, which live right after the locals in each
  // frame (offsets differ because the local counts may differ).
  for (Py_ssize_t slot = 0; slot < closure_slots; ++slot) {
    PyObject *value = src_slots[old_local_count + slot];
    Py_XINCREF(value);
    dst_slots[new_local_count + slot] = value;
  }

  PyObject *outcome = eval_frame_default(tstate, shadow, throw_flag);
  Py_DECREF(shadow);
  return outcome;
}

// Runs the user callback for `frame` and then evaluates the frame.
//
// The callback is invoked as callback(frame):
//   - if it returns None, `frame` is evaluated by the default evaluator;
//   - otherwise the returned object's `code` attribute must be a code object,
//     which is executed in place of the frame's own code (eval_custom_code).
//
// Returns a new reference to the evaluation result, or NULL with a Python
// exception set. Every temporary reference created here (argument tuple,
// callback result, `code` attribute) is released before returning.
static PyObject *_custom_eval_frame(PyThreadState *tstate,
                                    PyFrameObject *frame,
                                    int throw_flag,
                                    PyObject *callback) {
  // https://peps.python.org/pep-0558/#fast-locals-proxy-implementation-details
  // https://devguide.python.org/internals/interpreter/#all-sorts-of-variables
  if (PyFrame_FastToLocalsWithError(frame) < 0) {
    return NULL;
  }

  // NOTE(xiongkun): Handle the GeneratorExit exception.
  // Closing a generator is also a Python function call, so this function can
  // be entered with GeneratorExit already set. Calling PyObject_CallObject in
  // that state raises SystemError, so the custom behavior is skipped and the
  // frame is evaluated by the default evaluator. Example:
  //   def func():
  //       iter = iter([1, 2, 3])
  //       for i in iter:
  //           return i  # <--- early return causes GeneratorExit to be
  //                     #      thrown, which would make PyObject_CallObject
  //                     #      raise SystemError.
  if (PyErr_ExceptionMatches(PyExc_GeneratorExit)) {
    return eval_frame_default(tstate, frame, throw_flag);
  }

  // Build the argument tuple before touching the callback slot, so that an
  // allocation failure leaves the hook state untouched.
  PyObject *args = Py_BuildValue("(O)", frame);
  if (args == NULL) {
    return NULL;
  }

  // We don't run the current custom_eval_frame behavior for guards.
  // So we temporarily set the callback to Py_None to drive the correct
  // behavior in the shim.
  eval_frame_callback_set(Py_None);

  PyObject *result = PyObject_CallObject(callback, args);
  // The tuple holds a reference to `frame`; release it as soon as the call is
  // done, otherwise the frame (and everything it references) is leaked.
  Py_DECREF(args);

  // result: GuardedCode
  if (result == NULL) {
    // Internal exception raised by the callback. As before, the callback slot
    // is left at Py_None on this path.
    return NULL;
  }

  if (result == Py_None) {
    Py_DECREF(result);
    // Re-enable custom behavior
    eval_frame_callback_set(callback);
    return eval_frame_default(tstate, frame, throw_flag);
  }

  // NOTE: Cache is not supported now
  PyObject *code = PyObject_GetAttrString(result, "code");
  Py_DECREF(result);
  // Re-enable custom behavior
  eval_frame_callback_set(callback);

  if (code == NULL) {
    // Attribute lookup failed; the exception is already set.
    return NULL;
  }
  if (!PyCode_Check(code)) {
    Py_DECREF(code);
    PyErr_SetString(PyExc_TypeError,
                    "the object returned by the eval frame callback must "
                    "expose a code object as its `code` attribute");
    return NULL;
  }

  PyObject *outcome = eval_custom_code(
      tstate, frame, reinterpret_cast<PyCodeObject *>(code), throw_flag);
  Py_DECREF(code);
  return outcome;
}

// Version-independent entry point of the custom evaluator: dispatches to the
// user callback path when a callback is set for the current thread, and to
// the default evaluator when the stored callback is Py_None.
static PyObject *_custom_eval_frame_shim(PyThreadState *tstate,
                                         PyFrameObject *frame,
                                         int throw_flag) {
  PyObject *hook = eval_frame_callback_get();
  const bool hook_active = (hook != Py_None);
  return hook_active ? _custom_eval_frame(tstate, frame, throw_flag, hook)
                     : eval_frame_default(tstate, frame, throw_flag);
}

// Frame-evaluation function installed into the interpreter. Its signature
// depends on the Python version: from 3.9 on the interpreter passes the thread
// state, while earlier versions pass only the frame and the throw flag, in
// which case the current thread state is looked up here. Both variants
// forward to _custom_eval_frame_shim.
#if PY_VERSION_HEX >= 0x03090000
static PyObject *custom_eval_frame_shim(PyThreadState *tstate,
                                        PyFrameObject *frame,
                                        int throw_flag) {
#else
static PyObject *custom_eval_frame_shim(PyFrameObject *frame, int throw_flag) {
  PyThreadState *tstate = PyThreadState_GET();
#endif
  return _custom_eval_frame_shim(tstate, frame, throw_flag);
}

// Replaces the eval-frame callback of the current thread and returns the
// previous one.
//   - new_callback == Py_None: disables the custom callback.
//   - new_callback is a Python callable: enables the custom callback.
// The interpreter's frame-evaluation function is switched only on a
// transition between "disabled" and "enabled", and only when it is not
// already the wanted one. A reference to `new_callback` is taken before it is
// stored.
// NOTE: Cache is not supported now
// NOTE: multi-threading is not supported now
static PyObject *set_eval_frame(PyObject *new_callback, PyThreadState *tstate) {
  PyObject *old_callback = eval_frame_callback_get();

#if PY_VERSION_HEX >= 0x03090000
  auto *installed = _PyInterpreterState_GetEvalFrameFunc(tstate->interp);
#else
  // Function pointer.
  _PyFrameEvalFunction installed = tstate->interp->eval_frame;
#endif

  const bool was_enabled = (old_callback != Py_None);
  const bool will_be_enabled = (new_callback != Py_None);

  // Decide which evaluator the interpreter should use after this call; by
  // default keep whatever is installed.
  auto wanted = installed;
  const char *log_message = NULL;
  if (was_enabled && !will_be_enabled) {
    wanted = &_PyEval_EvalFrameDefault;
    log_message = "set _PyEval_EvalFrameDefault";
  } else if (!was_enabled && will_be_enabled) {
    wanted = &custom_eval_frame_shim;
    log_message = "set custom_eval_frame_shim";
  }

  if (wanted != installed) {
    VLOG(7) << log_message;
#if PY_VERSION_HEX >= 0x03090000
    _PyInterpreterState_SetEvalFrameFunc(tstate->interp, wanted);
#else
    tstate->interp->eval_frame = wanted;
#endif
  }

  Py_INCREF(new_callback);
  eval_frame_callback_set(new_callback);

  return old_callback;
}

// Validating front end of set_eval_frame(): accepts Py_None or any callable
// and installs it for the current thread state, returning the previous
// callback. Any other argument is only logged and None is returned.
static PyObject *set_eval_frame_py(PyObject *callback) {
  const bool acceptable =
      (callback == Py_None) || PyCallable_Check(callback);
  if (acceptable) {
    return set_eval_frame(callback, PyThreadState_GET());
  }
  VLOG(7) << "callback is not a callable or none, invalid arguments.";
  RETURN_PY_NONE
}

// Initialises the eval-frame hook state: creates the thread-specific-storage
// key and stores Py_None (meaning "no callback") for the calling thread.
//
// Always returns NULL; despite the PyMODINIT_FUNC signature the only caller in
// this file (BindEvalFrame) ignores the return value. Failures of the TSS
// calls are reported through the log, and the key is not written to when it
// could not be created.
PyMODINIT_FUNC PyInit__eval_frame(void) {
  const int result = PyThread_tss_create(&eval_frame_callback_key);
  VLOG(7) << "Set PyThread_tss_create return: " << result;
  if (result != 0) {
    LOG(WARNING) << "PyThread_tss_create failed for the eval frame callback "
                    "key; the eval frame hook state is not initialised.";
    return NULL;
  }

  // The TSS slot owns one reference to the stored object.
  Py_INCREF(Py_None);
  if (PyThread_tss_set(&eval_frame_callback_key, Py_None) != 0) {
    Py_DECREF(Py_None);
    LOG(WARNING) << "PyThread_tss_set failed while initialising the eval "
                    "frame callback.";
  }

  return NULL;
}

// Global pointer to a Python type object, initialised to nullptr.
// NOTE(review): presumably assigned when the jit Function type is bound; the
// assignment is not visible in this excerpt -- confirm in BindJit.
PyTypeObject *g_jit_function_pytype = nullptr;
// Shorthand for paddle::framework::Variable.
using Variable = paddle::framework::Variable;

Expand Down Expand Up @@ -58,5 +272,18 @@ void BindJit(pybind11::module *m) {
});
}

// Initialises the eval-frame hook state and exposes `set_eval_frame(callback)`
// on module `m`. The Python-level function forwards to set_eval_frame_py()
// and returns whatever that returns, wrapped as a borrowed reference.
void BindEvalFrame(pybind11::module *m) {
  PyInit__eval_frame();

  auto set_hook = [](const py::object &py_func) {
    VLOG(5) << "start call set_eval_frame_py.";
    PyObject *previous = set_eval_frame_py(py_func.ptr());
    return py::reinterpret_borrow<py::object>(previous);
  };

  m->def("set_eval_frame", set_hook, py::arg("callback"));
}

} // namespace pybind
} // namespace paddle
94 changes: 93 additions & 1 deletion paddle/fluid/pybind/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,102 @@ limitations under the License. */
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

// See https://bugs.python.org/issue35886
// If py_version == 3.8.*, we need to redefine the _PyFrameEvalFunction
// function type and the related structs here.

#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x3090000

// Pointer to a frame-evaluation function as used on Python 3.8: it receives
// the frame and a throw flag (no thread-state argument) and returns the
// evaluation result.
typedef PyObject *(*_PyFrameEvalFunction)(struct _frame *, int);

// Redefinition of the warnings runtime state for Python 3.8; it is embedded
// by value in `struct _is` below.
// NOTE(review): the field order and types presumably have to match CPython
// 3.8's own definition exactly -- verify against the targeted 3.8 headers
// before changing anything here.
struct _warnings_runtime_state {
/* Both 'filters' and 'onceregistry' can be set in warnings.py;
get_warnings_attr() will reset these variables accordingly. */
PyObject *filters; /* List */
PyObject *once_registry; /* Dict */
PyObject *default_action; /* String */
long filters_version; // NOLINT
};

// Redefinition of the interpreter-state struct (`struct _is`) for Python 3.8,
// provided so that the pre-3.9 code path in jit.cc can access
// `tstate->interp->eval_frame` directly.
// NOTE(review): this is a layout mirror -- the field order, types and the
// HAVE_DLOPEN / HAVE_FORK conditionals presumably have to match CPython 3.8's
// own definition exactly, otherwise `eval_frame` is read at the wrong offset.
// Verify against the targeted 3.8 headers before editing.
struct _is {
struct _is *next;
struct _ts *tstate_head;

int64_t id;
int64_t id_refcount;
int requires_idref;
PyThread_type_lock id_mutex;

int finalizing;

PyObject *modules;
PyObject *modules_by_index;
PyObject *sysdict;
PyObject *builtins;
PyObject *importlib;

/* Used in Python/sysmodule.c. */
int check_interval;

/* Used in Modules/_threadmodule.c. */
long num_threads; // NOLINT
/* Support for runtime thread stack size tuning.
A value of 0 means using the platform's default stack size
or the size specified by the THREAD_STACK_SIZE macro. */
/* Used in Python/thread.c. */
size_t pythread_stacksize;

PyObject *codec_search_path;
PyObject *codec_search_cache;
PyObject *codec_error_registry;
int codecs_initialized;

/* fs_codec.encoding is initialized to NULL.
Later, it is set to a non-NULL string by _PyUnicode_InitEncodings(). */
struct {
char *encoding; /* Filesystem encoding (encoded to UTF-8) */
char *errors; /* Filesystem errors (encoded to UTF-8) */
_Py_error_handler error_handler;
} fs_codec;

PyConfig config;
#ifdef HAVE_DLOPEN
int dlopenflags;
#endif

PyObject *dict; /* Stores per-interpreter state */

PyObject *builtins_copy;
PyObject *import_func;
/* Initialized to PyEval_EvalFrameDefault(). */
// This is the field the pre-3.9 code path in jit.cc reads and overwrites to
// switch between the default evaluator and the custom shim.
_PyFrameEvalFunction eval_frame;

Py_ssize_t co_extra_user_count;
freefunc co_extra_freefuncs[MAX_CO_EXTRA_USERS];

#ifdef HAVE_FORK
PyObject *before_forkers;
PyObject *after_forkers_parent;
PyObject *after_forkers_child;
#endif
/* AtExit module */
void (*pyexitfunc)(PyObject *);
PyObject *pyexitmodule;

uint64_t tstate_next_unique_id;

struct _warnings_runtime_state warnings;

PyObject *audit_hooks;
};

#endif

namespace paddle {
namespace pybind {

// Registers bindings on module `m`.
// NOTE(review): the body of BindJit is not visible in this excerpt;
// presumably it binds the jit Layer/Function API -- confirm in jit.cc.
void BindJit(pybind11::module* m);
void BindJit(pybind11::module *m);
// Initialises the eval-frame hook state and registers the
// `set_eval_frame(callback)` function on module `m`.
void BindEvalFrame(pybind11::module *m);

} // namespace pybind
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,7 @@ PYBIND11_MODULE(libpaddle, m) {
BindCudaStream(&m);
BindXpuStream(&m);
BindJit(&m);
BindEvalFrame(&m);
BindCustomDevicePy(&m);

// Not used, just make sure cpu_info.cc is linked.
Expand Down
Loading

0 comments on commit 7b1695a

Please sign in to comment.