Skip to content

Commit

Permalink
[SOT] clean code in eval frame and fix dy2st uts (#57662)
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Sep 23, 2023
1 parent 9269321 commit 832309b
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 98 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/pybind/eval_frame.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ limitations under the License. */
#include <Python.h>
#include <frameobject.h>

#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x3090000
#define Py_BUILD_CORE // internal/pycore_pymem.h need this macro
#include <internal/pycore_pystate.h>
#undef Py_BUILD_CORE
#endif
#if PY_VERSION_HEX < 0x030b0000
#include <code.h>
#endif
Expand Down
91 changes: 0 additions & 91 deletions paddle/fluid/pybind/eval_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,97 +19,6 @@ extern "C" {

#include <Python.h>

// see https://bugs.python.org/issue35886
// If py_version==3.8.*, we need to redefine _PyEvalFrameFunc and the
// related functions and structs.

#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x3090000

typedef PyObject *(*_PyFrameEvalFunction)(struct _frame *, int);

struct _warnings_runtime_state {
/* Both 'filters' and 'onceregistry' can be set in warnings.py;
get_warnings_attr() will reset these variables accordingly. */
PyObject *filters; /* List */
PyObject *once_registry; /* Dict */
PyObject *default_action; /* String */
long filters_version; // NOLINT
};

struct _is {
struct _is *next;
struct _ts *tstate_head;

int64_t id;
int64_t id_refcount;
int requires_idref;
PyThread_type_lock id_mutex;

int finalizing;

PyObject *modules;
PyObject *modules_by_index;
PyObject *sysdict;
PyObject *builtins;
PyObject *importlib;

/* Used in Python/sysmodule.c. */
int check_interval;

/* Used in Modules/_threadmodule.c. */
long num_threads; // NOLINT
/* Support for runtime thread stack size tuning.
A value of 0 means using the platform's default stack size
or the size specified by the THREAD_STACK_SIZE macro. */
/* Used in Python/thread.c. */
size_t pythread_stacksize;

PyObject *codec_search_path;
PyObject *codec_search_cache;
PyObject *codec_error_registry;
int codecs_initialized;

/* fs_codec.encoding is initialized to NULL.
Later, it is set to a non-NULL string by _PyUnicode_InitEncodings(). */
struct {
char *encoding; /* Filesystem encoding (encoded to UTF-8) */
char *errors; /* Filesystem errors (encoded to UTF-8) */
_Py_error_handler error_handler;
} fs_codec;

PyConfig config;
#ifdef HAVE_DLOPEN
int dlopenflags;
#endif

PyObject *dict; /* Stores per-interpreter state */

PyObject *builtins_copy;
PyObject *import_func;
/* Initialized to PyEval_EvalFrameDefault(). */
_PyFrameEvalFunction eval_frame;

Py_ssize_t co_extra_user_count;
freefunc co_extra_freefuncs[MAX_CO_EXTRA_USERS];

#ifdef HAVE_FORK
PyObject *before_forkers;
PyObject *after_forkers_parent;
PyObject *after_forkers_child;
#endif
/* AtExit module */
void (*pyexitfunc)(PyObject *);
PyObject *pyexitmodule;

uint64_t tstate_next_unique_id;

struct _warnings_runtime_state warnings;

PyObject *audit_hooks;
};

#endif

PyObject *set_eval_frame_py(PyObject *callback);
PyMODINIT_FUNC PyInit__eval_frame();

Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/tools/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cc_library(
test_dialect
SRCS test_dialect.cc test_op.cc test_trait.cc test_interface.cc
DEPS pir)
DEPS pir gtest)
13 changes: 7 additions & 6 deletions test/dygraph_to_static/test_new_ir_selectedrows.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
import random
import unittest

from dygraph_to_static_util import test_and_compare_with_new_ir
from dygraph_to_static_util import (
enable_fallback_guard,
test_and_compare_with_new_ir,
)

import paddle
from paddle.jit.api import to_static
Expand Down Expand Up @@ -53,7 +56,6 @@ def forward(self, x):
return x


@to_static
def train(net, adam, x):
loss_data = []
for i in range(10):
Expand All @@ -75,7 +77,6 @@ def train_dygraph():
parameters=net.parameters(), learning_rate=0.01, grad_clip=clip
)

paddle.jit.enable_to_static(False)
return train(net, adam, x)


Expand All @@ -89,8 +90,7 @@ def train_static():
parameters=net.parameters(), learning_rate=0.01, grad_clip=clip
)

paddle.jit.enable_to_static(True)
return train(net, adam, x)
return to_static(train)(net, adam, x)


class TestSimnet(unittest.TestCase):
Expand All @@ -104,4 +104,5 @@ def test_dygraph_static_same_loss(self):


if __name__ == '__main__':
unittest.main()
with enable_fallback_guard("False"):
unittest.main()

0 comments on commit 832309b

Please sign in to comment.