Skip to content

Commit

Permalink
GH-91079: Decouple C stack overflow checks from Python recursion chec…
Browse files Browse the repository at this point in the history
…ks. (GH-96510)
  • Loading branch information
markshannon committed Oct 5, 2022
1 parent 0ff8fd6 commit 7644935
Show file tree
Hide file tree
Showing 22 changed files with 165 additions and 99 deletions.
16 changes: 14 additions & 2 deletions Include/cpython/pystate.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ struct _ts {
/* Was this thread state statically allocated? */
int _static;

int recursion_remaining;
int recursion_limit;
int py_recursion_remaining;
int py_recursion_limit;

int c_recursion_remaining;
int recursion_headroom; /* Allow 50 more calls to handle any errors. */

/* 'tracing' keeps track of the execution depth when tracing/profiling.
Expand Down Expand Up @@ -202,6 +204,16 @@ struct _ts {
_PyCFrame root_cframe;
};

/* WASI has limited call stack. Python's recursion limit depends on code
layout, optimization, and WASI runtime. Wasmtime can handle about 700
recursions, sometimes less. 500 is a more conservative limit. */
#ifndef C_RECURSION_LIMIT
# ifdef __wasi__
# define C_RECURSION_LIMIT 500
# else
# define C_RECURSION_LIMIT 800
# endif
#endif

/* other API */

Expand Down
21 changes: 9 additions & 12 deletions Include/internal/pycore_ceval.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,8 @@ extern "C" {
struct pyruntimestate;
struct _ceval_runtime_state;

/* WASI has limited call stack. Python's recursion limit depends on code
layout, optimization, and WASI runtime. Wasmtime can handle about 700-750
recursions, sometimes less. 600 is a more conservative limit. */
#ifndef Py_DEFAULT_RECURSION_LIMIT
# ifdef __wasi__
# define Py_DEFAULT_RECURSION_LIMIT 600
# else
# define Py_DEFAULT_RECURSION_LIMIT 1000
# endif
# define Py_DEFAULT_RECURSION_LIMIT 1000
#endif

#include "pycore_interp.h" // PyInterpreterState.eval_frame
Expand Down Expand Up @@ -118,19 +111,22 @@ extern void _PyEval_DeactivateOpCache(void);
/* With USE_STACKCHECK macro defined, trigger stack checks in
_Py_CheckRecursiveCall() on every 64th call to _Py_EnterRecursiveCall. */
static inline int _Py_MakeRecCheck(PyThreadState *tstate) {
return (tstate->recursion_remaining-- <= 0
|| (tstate->recursion_remaining & 63) == 0);
return (tstate->c_recursion_remaining-- <= 0
|| (tstate->c_recursion_remaining & 63) == 0);
}
#else
static inline int _Py_MakeRecCheck(PyThreadState *tstate) {
return tstate->recursion_remaining-- <= 0;
return tstate->c_recursion_remaining-- <= 0;
}
#endif

PyAPI_FUNC(int) _Py_CheckRecursiveCall(
PyThreadState *tstate,
const char *where);

int _Py_CheckRecursiveCallPy(
PyThreadState *tstate);

static inline int _Py_EnterRecursiveCallTstate(PyThreadState *tstate,
const char *where) {
return (_Py_MakeRecCheck(tstate) && _Py_CheckRecursiveCall(tstate, where));
Expand All @@ -142,7 +138,7 @@ static inline int _Py_EnterRecursiveCall(const char *where) {
}

static inline void _Py_LeaveRecursiveCallTstate(PyThreadState *tstate) {
tstate->recursion_remaining++;
tstate->c_recursion_remaining++;
}

static inline void _Py_LeaveRecursiveCall(void) {
Expand All @@ -157,6 +153,7 @@ extern PyObject* _Py_MakeCoro(PyFunctionObject *func);
extern int _Py_HandlePending(PyThreadState *tstate);



#ifdef __cplusplus
}
#endif
Expand Down
2 changes: 1 addition & 1 deletion Include/internal/pycore_runtime_init.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ extern "C" {
#define _PyThreadState_INIT \
{ \
._static = 1, \
.recursion_limit = Py_DEFAULT_RECURSION_LIMIT, \
.py_recursion_limit = Py_DEFAULT_RECURSION_LIMIT, \
.context_ver = 1, \
}

Expand Down
5 changes: 4 additions & 1 deletion Lib/test/support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"run_with_tz", "PGO", "missing_compiler_executable",
"ALWAYS_EQ", "NEVER_EQ", "LARGEST", "SMALLEST",
"LOOPBACK_TIMEOUT", "INTERNET_TIMEOUT", "SHORT_TIMEOUT", "LONG_TIMEOUT",
"Py_DEBUG",
"Py_DEBUG", "EXCEEDS_RECURSION_LIMIT",
]


Expand Down Expand Up @@ -2352,3 +2352,6 @@ def adjust_int_max_str_digits(max_digits):
yield
finally:
sys.set_int_max_str_digits(current)

#For recursion tests, easily exceeds default recursion limit
EXCEEDS_RECURSION_LIMIT = 5000
6 changes: 3 additions & 3 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,9 +825,9 @@ def next(self):

@support.cpython_only
def test_ast_recursion_limit(self):
fail_depth = sys.getrecursionlimit() * 3
crash_depth = sys.getrecursionlimit() * 300
success_depth = int(fail_depth * 0.75)
fail_depth = support.EXCEEDS_RECURSION_LIMIT
crash_depth = 100_000
success_depth = 1200

def check_limit(prefix, repeated):
expect_ok = prefix + repeated * success_depth
Expand Down
38 changes: 38 additions & 0 deletions Lib/test/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,44 @@ def test_multiple_values(self):
with self.check_raises_type_error(msg):
A().method_two_args("x", "y", x="oops")

@cpython_only
class TestRecursion(unittest.TestCase):

def test_super_deep(self):

def recurse(n):
if n:
recurse(n-1)

def py_recurse(n, m):
if n:
py_recurse(n-1, m)
else:
c_py_recurse(m-1)

def c_recurse(n):
if n:
_testcapi.pyobject_fastcall(c_recurse, (n-1,))

def c_py_recurse(m):
if m:
_testcapi.pyobject_fastcall(py_recurse, (1000, m))

depth = sys.getrecursionlimit()
sys.setrecursionlimit(100_000)
try:
recurse(90_000)
with self.assertRaises(RecursionError):
recurse(101_000)
c_recurse(100)
with self.assertRaises(RecursionError):
c_recurse(90_000)
c_py_recurse(90)
with self.assertRaises(RecursionError):
c_py_recurse(100_000)
finally:
sys.setrecursionlimit(depth)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion Lib/test/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def test_odd_sizes(self):
self.assertEqual(Dot(1)._replace(d=999), (999,))
self.assertEqual(Dot(1)._fields, ('d',))

n = 5000
n = support.EXCEEDS_RECURSION_LIMIT
names = list(set(''.join([choice(string.ascii_letters)
for j in range(10)]) for i in range(n)))
n = len(names)
Expand Down
3 changes: 1 addition & 2 deletions Lib/test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ def __getitem__(self, key):

@unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI")
def test_extended_arg(self):
# default: 1000 * 2.5 = 2500 repetitions
repeat = int(sys.getrecursionlimit() * 2.5)
repeat = 2000
longexpr = 'x = x or ' + '-x' * repeat
g = {}
code = '''
Expand Down
8 changes: 4 additions & 4 deletions Lib/test/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ class MyGlobals(dict):
def __missing__(self, key):
return int(key.removeprefix("_number_"))

# 1,000 on most systems
limit = sys.getrecursionlimit()
code = "lambda: " + "+".join(f"_number_{i}" for i in range(limit))
# Need more than 256 variables to use EXTENDED_ARGS
variables = 400
code = "lambda: " + "+".join(f"_number_{i}" for i in range(variables))
sum_func = eval(code, MyGlobals())
expected = sum(range(limit))
expected = sum(range(variables))
# Warm up the the function for quickening (PEP 659)
for _ in range(30):
self.assertEqual(sum_func(), expected)
Expand Down
8 changes: 2 additions & 6 deletions Lib/test/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,7 @@ def test_recursion_normalizing_exception(self):
code = """if 1:
import sys
from _testinternalcapi import get_recursion_depth
from test import support
class MyException(Exception): pass
Expand Down Expand Up @@ -1399,13 +1400,8 @@ def gen():
generator = gen()
next(generator)
recursionlimit = sys.getrecursionlimit()
depth = get_recursion_depth()
try:
# Upon the last recursive invocation of recurse(),
# tstate->recursion_depth is equal to (recursion_limit - 1)
# and is equal to recursion_limit when _gen_throw() calls
# PyErr_NormalizeException().
recurse(setrecursionlimit(depth + 2) - depth)
recurse(support.EXCEEDS_RECURSION_LIMIT)
finally:
sys.setrecursionlimit(recursionlimit)
print('Done.')
Expand Down
12 changes: 6 additions & 6 deletions Lib/test/test_isinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from test import support




class TestIsInstanceExceptions(unittest.TestCase):
# Test to make sure that an AttributeError when accessing the instance's
# class's bases is masked. This was actually a bug in Python 2.2 and
Expand Down Expand Up @@ -97,7 +97,7 @@ def getclass(self):
class D: pass
self.assertRaises(RuntimeError, isinstance, c, D)



# These tests are similar to above, but tickle certain code paths in
# issubclass() instead of isinstance() -- really PyObject_IsSubclass()
# vs. PyObject_IsInstance().
Expand Down Expand Up @@ -147,7 +147,7 @@ def getbases(self):
self.assertRaises(TypeError, issubclass, B, C())




# meta classes for creating abstract classes and instances
class AbstractClass(object):
def __init__(self, bases):
Expand Down Expand Up @@ -179,7 +179,7 @@ class Super:

class Child(Super):
pass


class TestIsInstanceIsSubclass(unittest.TestCase):
# Tests to ensure that isinstance and issubclass work on abstract
# classes and instances. Before the 2.2 release, TypeErrors were
Expand Down Expand Up @@ -353,10 +353,10 @@ def blowstack(fxn, arg, compare_to):
# Make sure that calling isinstance with a deeply nested tuple for its
# argument will raise RecursionError eventually.
tuple_arg = (compare_to,)
for cnt in range(sys.getrecursionlimit()+5):
for cnt in range(support.EXCEEDS_RECURSION_LIMIT):
tuple_arg = (tuple_arg,)
fxn(arg, tuple_arg)



if __name__ == '__main__':
unittest.main()
3 changes: 2 additions & 1 deletion Lib/test/test_marshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def test_code(self):

def test_many_codeobjects(self):
# Issue2957: bad recursion count on code objects
count = 5000 # more than MAX_MARSHAL_STACK_DEPTH
# more than MAX_MARSHAL_STACK_DEPTH
count = support.EXCEEDS_RECURSION_LIMIT
codes = (ExceptionTestCase.test_exceptions.__code__,) * count
marshal.loads(marshal.dumps(codes))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Separate Python recursion checking from C recursion checking which reduces
the chance of C stack overflow and allows the recursion limit to be
increased safely.
4 changes: 1 addition & 3 deletions Modules/_testinternalcapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ get_recursion_depth(PyObject *self, PyObject *Py_UNUSED(args))
{
PyThreadState *tstate = _PyThreadState_GET();

/* subtract one to ignore the frame of the get_recursion_depth() call */

return PyLong_FromLong(tstate->recursion_limit - tstate->recursion_remaining - 1);
return PyLong_FromLong(tstate->py_recursion_limit - tstate->py_recursion_remaining);
}


Expand Down
9 changes: 3 additions & 6 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,19 +1380,16 @@ class PartingShots(StaticVisitor):
return NULL;
}
int recursion_limit = Py_GetRecursionLimit();
int starting_recursion_depth;
/* Be careful here to prevent overflow. */
int COMPILER_STACK_FRAME_SCALE = 3;
PyThreadState *tstate = _PyThreadState_GET();
if (!tstate) {
return 0;
}
state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state->recursion_depth = starting_recursion_depth;
PyObject *result = ast2obj_mod(state, t);
Expand Down
9 changes: 3 additions & 6 deletions Python/Python-ast.c

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 3 additions & 6 deletions Python/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,6 @@ _PyAST_Validate(mod_ty mod)
int res = -1;
struct validator state;
PyThreadState *tstate;
int recursion_limit = Py_GetRecursionLimit();
int starting_recursion_depth;

/* Setup recursion depth check counters */
Expand All @@ -984,12 +983,10 @@ _PyAST_Validate(mod_ty mod)
return 0;
}
/* Be careful here to prevent overflow. */
int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
starting_recursion_depth = (recursion_depth< INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state.recursion_depth = starting_recursion_depth;
state.recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
state.recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;

switch (mod->kind) {
case Module_kind:
Expand Down
9 changes: 3 additions & 6 deletions Python/ast_opt.c
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,6 @@ int
_PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
{
PyThreadState *tstate;
int recursion_limit = Py_GetRecursionLimit();
int starting_recursion_depth;

/* Setup recursion depth check counters */
Expand All @@ -1089,12 +1088,10 @@ _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
return 0;
}
/* Be careful here to prevent overflow. */
int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state->recursion_depth = starting_recursion_depth;
state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;

int ret = astfold_mod(mod, arena, state);
assert(ret || PyErr_Occurred());
Expand Down
Loading

0 comments on commit 7644935

Please sign in to comment.