Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2static-Fallback] add set_eval_frame function in pybind. #52006

Merged
merged 10 commits into from
May 18, 2023
227 changes: 227 additions & 0 deletions paddle/fluid/pybind/jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,234 @@ limitations under the License. */

#include "paddle/fluid/pybind/jit.h"

#include <Python.h>
#include <code.h>
#include <frameobject.h>
#include <object.h>
#include <pystate.h>

#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/place.h"

#include "glog/logging.h"
#include "paddle/fluid/jit/function.h"
#include "paddle/fluid/jit/function_schema.h"
#include "paddle/fluid/jit/layer.h"
#include "paddle/fluid/jit/serializer.h"
#include "paddle/utils/pybind.h"

namespace py = pybind11;

namespace paddle {
namespace pybind {

// Branch-prediction hint: marks the condition as expected to be false.
// Built on the GCC/Clang builtin __builtin_expect.
#define unlikely(x) __builtin_expect((x), 0)

// Thread-specific-storage key under which the custom eval-frame callback is
// stored; read by eval_frame_callback_get() and written by
// eval_frame_callback_set().
static Py_tss_t eval_frame_callback_key = {0, 0};

// Return the eval-frame callback stored under eval_frame_callback_key for the
// calling thread.
//
// The result is always a *borrowed* reference: callers in this file only
// compare it against Py_None or hand it on without releasing it. When nothing
// has been stored yet (PyThread_tss_get returns NULL) Py_None is returned,
// also borrowed, so the reference count of None is not bumped on every call.
inline static PyObject *eval_frame_callback_get(void) {
  void *stored = PyThread_tss_get(&eval_frame_callback_key);
  if (unlikely(stored == NULL)) {
    // No value stored for this thread: behave as if the callback were None.
    return Py_None;
  }
  return reinterpret_cast<PyObject *>(stored);
}

// Store `obj` under eval_frame_callback_key for the calling thread.
// The pointer is stored as-is: no reference count is adjusted here and the
// status returned by PyThread_tss_set is not inspected.
inline static void eval_frame_callback_set(PyObject *obj) {
  void *slot_value = obj;
  PyThread_tss_set(&eval_frame_callback_key, slot_value);
}

// Evaluate `frame` with CPython's stock evaluator, _PyEval_EvalFrameDefault.
// From Python 3.9 on the evaluator takes the thread state as first argument;
// a NULL `tstate` is replaced by PyThreadState_GET(). Older versions ignore
// `tstate` entirely.
inline static PyObject *eval_frame_default(PyThreadState *tstate,
                                           PyFrameObject *frame,
                                           int throw_flag) {
#if PY_VERSION_HEX < 0x03090000
  return _PyEval_EvalFrameDefault(frame, throw_flag);
#else
  PyThreadState *active_state = tstate;
  if (active_state == NULL) {
    active_state = PyThreadState_GET();
  }
  return _PyEval_EvalFrameDefault(active_state, frame, throw_flag);
#endif
}

// Run `code` in a freshly created "shadow" frame instead of `frame`.
//
// The shadow frame is built with PyFrame_New using `frame`'s globals. The
// fast locals of `frame` (as many as frame->f_code->co_nlocals) are copied
// into the same slots of the shadow frame, followed by the cell/free variable
// slots, whose count is taken from `code` (zero when CO_NOFREE is set). Every
// copied object gets an extra reference. The shadow frame is then evaluated
// with the default evaluator and released.
//
// Returns the evaluation result, or NULL if the shadow frame could not be
// created.
// NOTE(review): assumes `code` has at least as many local slots as the
// original code object and the same cell/free layout - not checked here.
inline static PyObject *eval_custom_code(PyThreadState *tstate,
                                         PyFrameObject *frame,
                                         PyCodeObject *code,
                                         int throw_flag) {
  const Py_ssize_t new_local_count = code->co_nlocals;
  const Py_ssize_t old_local_count = frame->f_code->co_nlocals;

  // Combined number of cell and free variable slots declared by `code`.
  Py_ssize_t closure_slot_count = 0;
  if ((code->co_flags & CO_NOFREE) == 0) {
    closure_slot_count = PyTuple_GET_SIZE(code->co_cellvars) +
                         PyTuple_GET_SIZE(code->co_freevars);
  }

  PyFrameObject *shadow = PyFrame_New(tstate, code, frame->f_globals, NULL);
  if (shadow == NULL) {
    return NULL;
  }

  PyObject **src_slots = frame->f_localsplus;
  PyObject **dst_slots = shadow->f_localsplus;

  // Plain local variables keep their slot index.
  for (Py_ssize_t idx = 0; idx < old_local_count; ++idx) {
    PyObject *value = src_slots[idx];
    Py_XINCREF(value);
    dst_slots[idx] = value;
  }

  // Cell/free variables sit right after the locals of their own code object,
  // so the source and destination offsets differ.
  PyObject **src_closure = src_slots + old_local_count;
  PyObject **dst_closure = dst_slots + new_local_count;
  for (Py_ssize_t idx = 0; idx < closure_slot_count; ++idx) {
    PyObject *value = src_closure[idx];
    Py_XINCREF(value);
    dst_closure[idx] = value;
  }

  PyObject *outcome = eval_frame_default(tstate, shadow, throw_flag);
  Py_DECREF(shadow);
  return outcome;
}

// Evaluate `frame` through the user supplied `callback`.
//
// The callback is called with the frame as its single argument:
//   - if it returns None, the frame is evaluated by the default evaluator;
//   - otherwise the returned object must expose a `code` attribute holding a
//     code object, which is executed in place of the frame's own code
//     (see eval_custom_code).
//
// While the callback runs, the stored callback is temporarily replaced by
// Py_None so that frames executed by the callback itself are not intercepted.
//
// Returns a new reference to the evaluation result, or NULL with a Python
// exception set.
static PyObject *_custom_eval_frame(PyThreadState *tstate,
                                    PyFrameObject *frame,
                                    int throw_flag,
                                    PyObject *callback) {
  // https://peps.python.org/pep-0558/#fast-locals-proxy-implementation-details
  // https://devguide.python.org/internals/interpreter/#all-sorts-of-variables
  if (PyFrame_FastToLocalsWithError(frame) < 0) {
    return NULL;
  }

  // NOTE:(xiongkun): Handle GeneratorExit exception: (Spend a day)
  // In Python, gen close is also a Python function call that will enter this
  // function with GeneratorExit set, which will cause the PyObject_CallObject
  // raise SystemError. So we disable the custom behavior for GeneratorExit.
  //
  //   def func():
  //       it = iter([1, 2, 3])
  //       for i in it:
  //           return i  # <--- Early return, causes a GeneratorExit to be
  //                     #      thrown, which makes PyObject_CallObject raise
  //                     #      SystemError.
  if (PyErr_ExceptionMatches(PyExc_GeneratorExit)) {
    return eval_frame_default(tstate, frame, throw_flag);
  }

  // Build the argument tuple before touching the stored callback so that a
  // failure here leaves the callback state untouched.
  PyObject *args = Py_BuildValue("(O)", frame);
  if (args == NULL) {
    return NULL;
  }

  // We don't run the current custom_eval_frame behavior for guards.
  // So we temporarily set the callback to Py_None to drive the correct behavior
  // in the shim.
  eval_frame_callback_set(Py_None);

  PyObject *result = PyObject_CallObject(callback, args);
  Py_DECREF(args);  // the tuple was only needed for the call

  // result: GuardedCode
  if (result == NULL) {
    // Internal exception raised by the callback. As before, the stored
    // callback is left as Py_None on this path.
    return NULL;
  }

  if (result == Py_None) {
    Py_DECREF(result);
    // Re-enable custom behavior
    eval_frame_callback_set(callback);
    return eval_frame_default(tstate, frame, throw_flag);
  }

  // NOTE: Cache is not supported now
  PyObject *code = PyObject_GetAttrString(result, "code");
  // Release the callback result while the hook is still disabled, so any
  // Python code run by its finalizer is not intercepted.
  Py_DECREF(result);
  // Re-enable custom behavior
  eval_frame_callback_set(callback);

  if (code == NULL) {
    // Attribute lookup failed; the exception is already set.
    return NULL;
  }
  if (!PyCode_Check(code)) {
    Py_DECREF(code);
    PyErr_SetString(PyExc_TypeError,
                    "eval frame callback result must have a code object in "
                    "its 'code' attribute");
    return NULL;
  }

  PyObject *evaluated = eval_custom_code(
      tstate, frame, reinterpret_cast<PyCodeObject *>(code), throw_flag);
  Py_DECREF(code);  // drop the reference obtained from PyObject_GetAttrString
  return evaluated;
}

// Version-independent entry point of the eval-frame hook: dispatch to the
// custom path when a callback other than Py_None is stored, otherwise fall
// back to the default evaluator.
static PyObject *_custom_eval_frame_shim(PyThreadState *tstate,
                                         PyFrameObject *frame,
                                         int throw_flag) {
  PyObject *hook = eval_frame_callback_get();
  const bool hook_enabled = (hook != Py_None);

  return hook_enabled ? _custom_eval_frame(tstate, frame, throw_flag, hook)
                      : eval_frame_default(tstate, frame, throw_flag);
}

// Function installed as the interpreter's frame evaluator. Its signature
// depends on the Python version: before 3.9 the evaluator receives no thread
// state, so the current one is looked up with PyThreadState_GET().
#if PY_VERSION_HEX < 0x03090000
static PyObject *custom_eval_frame_shim(PyFrameObject *frame, int throw_flag) {
  return _custom_eval_frame_shim(PyThreadState_GET(), frame, throw_flag);
}
#else
static PyObject *custom_eval_frame_shim(PyThreadState *tstate,
                                        PyFrameObject *frame,
                                        int throw_flag) {
  return _custom_eval_frame_shim(tstate, frame, throw_flag);
}
#endif

// Replace the stored eval-frame callback and return the previous one.
//   - new_callback == None     : disable the custom callback.
//   - new_callback is callable : enable the custom callback.
//
// The interpreter's frame evaluator is only swapped on a transition between
// "disabled" (None) and "enabled" (non-None), and only if it is not already
// the desired function. `new_callback` receives an extra reference before it
// is stored; the previous callback pointer is returned without touching its
// reference count.
// NOTE: Cache is not supported now
static PyObject *set_eval_frame(PyObject *new_callback, PyThreadState *tstate) {
  PyObject *previous = eval_frame_callback_get();

#if PY_VERSION_HEX >= 0x03090000
  auto *installed = _PyInterpreterState_GetEvalFrameFunc(tstate->interp);
#else
  // Function pointer.
  _PyFrameEvalFunction installed = tstate->interp->eval_frame;
#endif

  const bool was_enabled = (previous != Py_None);
  const bool will_be_enabled = (new_callback != Py_None);

  // NOTE: multi-threading is not supported now
  if (was_enabled && !will_be_enabled &&
      installed != &_PyEval_EvalFrameDefault) {
    // Enabled -> disabled: restore CPython's own evaluator.
    VLOG(7) << "set _PyEval_EvalFrameDefault";
#if PY_VERSION_HEX >= 0x03090000
    _PyInterpreterState_SetEvalFrameFunc(tstate->interp,
                                         &_PyEval_EvalFrameDefault);
#else
    tstate->interp->eval_frame = &_PyEval_EvalFrameDefault;
#endif
  } else if (!was_enabled && will_be_enabled &&
             installed != &custom_eval_frame_shim) {
    // Disabled -> enabled: install our shim.
    VLOG(7) << "set custom_eval_frame_shim";
#if PY_VERSION_HEX >= 0x03090000
    _PyInterpreterState_SetEvalFrameFunc(tstate->interp,
                                         &custom_eval_frame_shim);
#else
    tstate->interp->eval_frame = &custom_eval_frame_shim;
#endif
  }

  Py_INCREF(new_callback);
  eval_frame_callback_set(new_callback);

  return previous;
}

// Validating wrapper around set_eval_frame() for the current thread state.
// Accepts None or any callable; anything else is logged and answered with
// None without changing the stored callback.
static PyObject *set_eval_frame_py(PyObject *callback) {
  const bool acceptable =
      (callback == Py_None) || PyCallable_Check(callback);
  if (acceptable) {
    return set_eval_frame(callback, PyThreadState_GET());
  }

  VLOG(7) << "callback is not a callable or none, invalid arguments.";
  RETURN_PY_NONE
}

// One-time setup of the eval-frame machinery: create the thread-specific
// storage key and store Py_None (with its own reference) as the initial
// callback. The status of PyThread_tss_create is only logged. Always returns
// NULL.
PyMODINIT_FUNC PyInit__eval_frame(void) {
  const int tss_status = PyThread_tss_create(&eval_frame_callback_key);
  VLOG(7) << "Set PyThread_tss_create return: " << tss_status;

  // The stored slot keeps its own reference to None.
  Py_INCREF(Py_None);
  eval_frame_callback_set(Py_None);

  return NULL;
}

// Global pointer to a Python type object; starts out as nullptr.
// NOTE(review): presumably filled in when the jit Function type is bound -
// confirm against BindJit.
PyTypeObject *g_jit_function_pytype = nullptr;
// Short alias for the framework's Variable type.
using Variable = paddle::framework::Variable;

Expand Down Expand Up @@ -58,5 +272,18 @@ void BindJit(pybind11::module *m) {
});
}

// Initialise the eval-frame machinery and expose `set_eval_frame(callback)`
// on module `m`. The Python-level function forwards to set_eval_frame_py()
// and returns its result wrapped as a borrowed py::object.
void BindEvalFrame(pybind11::module *m) {
  PyInit__eval_frame();

  auto set_hook = [](const py::object &py_func) -> py::object {
    VLOG(5) << "start call set_eval_frame_py.";
    PyObject *previous = set_eval_frame_py(py_func.ptr());
    return py::reinterpret_borrow<py::object>(previous);
  };

  m->def("set_eval_frame", set_hook, py::arg("callback"));
}

} // namespace pybind
} // namespace paddle
94 changes: 93 additions & 1 deletion paddle/fluid/pybind/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,102 @@ limitations under the License. */
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

// see https://bugs.python.org/issue35886
// If py_version==3.8.*, we need to redefine _PyEvalFrameFunc and the
// related functions and structs.

#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x3090000

// Pointer to a frame-evaluation function in its Python 3.8 form: it takes the
// frame and an int flag and returns a PyObject pointer.
typedef PyObject *(*_PyFrameEvalFunction)(struct _frame *, int);

// Per-interpreter state of the warnings machinery; embedded by value in
// struct _is below.
// NOTE(review): presumably has to match the CPython 3.8 definition field for
// field - do not reorder or resize members without checking.
struct _warnings_runtime_state {
/* Both 'filters' and 'onceregistry' can be set in warnings.py;
get_warnings_attr() will reset these variables accordingly. */
PyObject *filters; /* List */
PyObject *once_registry; /* Dict */
PyObject *default_action; /* String */
long filters_version; // NOLINT
};

// Local definition of the interpreter state structure for Python 3.8, where
// it has to be redefined here (see https://bugs.python.org/issue35886).
// This file needs it to reach the `eval_frame` member through
// `tstate->interp`.
// NOTE(review): presumably the member order, types and #ifdef guards must
// match CPython 3.8's own definition exactly, otherwise `eval_frame` is read
// at the wrong offset - verify against the CPython 3.8 sources before editing.
struct _is {
struct _is *next;
struct _ts *tstate_head;

int64_t id;
int64_t id_refcount;
int requires_idref;
PyThread_type_lock id_mutex;

int finalizing;

PyObject *modules;
PyObject *modules_by_index;
PyObject *sysdict;
PyObject *builtins;
PyObject *importlib;

/* Used in Python/sysmodule.c. */
int check_interval;

/* Used in Modules/_threadmodule.c. */
long num_threads; // NOLINT
/* Support for runtime thread stack size tuning.
A value of 0 means using the platform's default stack size
or the size specified by the THREAD_STACK_SIZE macro. */
/* Used in Python/thread.c. */
size_t pythread_stacksize;

PyObject *codec_search_path;
PyObject *codec_search_cache;
PyObject *codec_error_registry;
int codecs_initialized;

/* fs_codec.encoding is initialized to NULL.
Later, it is set to a non-NULL string by _PyUnicode_InitEncodings(). */
struct {
char *encoding; /* Filesystem encoding (encoded to UTF-8) */
char *errors; /* Filesystem errors (encoded to UTF-8) */
_Py_error_handler error_handler;
} fs_codec;

PyConfig config;
#ifdef HAVE_DLOPEN
int dlopenflags;
#endif

PyObject *dict; /* Stores per-interpreter state */

PyObject *builtins_copy;
PyObject *import_func;
/* Initialized to PyEval_EvalFrameDefault(). */
// The frame evaluator slot: the only member this project reads and writes.
_PyFrameEvalFunction eval_frame;

Py_ssize_t co_extra_user_count;
freefunc co_extra_freefuncs[MAX_CO_EXTRA_USERS];

#ifdef HAVE_FORK
PyObject *before_forkers;
PyObject *after_forkers_parent;
PyObject *after_forkers_child;
#endif
/* AtExit module */
void (*pyexitfunc)(PyObject *);
PyObject *pyexitmodule;

uint64_t tstate_next_unique_id;

struct _warnings_runtime_state warnings;

PyObject *audit_hooks;
};

#endif

namespace paddle {
namespace pybind {

// Registers the jit-related bindings on module `m`.
void BindJit(pybind11::module *m);

// Initialises the eval-frame hook machinery and registers the
// `set_eval_frame(callback)` function on module `m`.
void BindEvalFrame(pybind11::module *m);

} // namespace pybind
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,7 @@ PYBIND11_MODULE(libpaddle, m) {
BindCudaStream(&m);
BindXpuStream(&m);
BindJit(&m);
BindEvalFrame(&m);
BindCustomDevicePy(&m);

// Not used, just make sure cpu_info.cc is linked.
Expand Down
Loading