diff --git a/Include/regions.h b/Include/regions.h index 52a7ec3e9d15a5..1cfc49ce0984b6 100644 --- a/Include/regions.h +++ b/Include/regions.h @@ -15,6 +15,8 @@ PyAPI_FUNC(int) _Py_IsLocal(PyObject *op); PyAPI_FUNC(int) _Py_IsCown(PyObject *op); #define Py_IsCown(op) _Py_IsCown(_PyObject_CAST(op)) +int Py_is_invariant_enabled(void); + #ifdef __cplusplus } #endif diff --git a/Lib/test/test_using.py b/Lib/test/test_using.py index ef005b34c4a2a6..8a1fd9a74a7b5b 100644 --- a/Lib/test/test_using.py +++ b/Lib/test/test_using.py @@ -166,3 +166,25 @@ def _(): result = c.get().value.value if result != 200: self.fail() + + def test_thread_creation(self): + from using import PyronaThread as T + + class Mutable: pass + self.assertRaises(RuntimeError, lambda x: T(target=print, args=(Mutable(),)), None) + self.assertRaises(RuntimeError, lambda x: T(target=print, kwargs={'a' : Mutable()}), None) + self.assertRaises(RuntimeError, lambda x: T(target=print, args=(Mutable(),), kwargs={'a' : Mutable()}), None) + self.assertRaises(RuntimeError, lambda x: T(target=print, args=(Mutable(), 42)), None) + self.assertRaises(RuntimeError, lambda x: T(target=print, args=(Mutable(), Cown())), None) + self.assertRaises(RuntimeError, lambda x: T(target=print, args=(Mutable(), Region())), None) + + T(target=print, kwargs={'imm' : 42, 'cown' : Cown(), 'region' : Region()}) + T(target=print, kwargs={'a': 42}) + T(target=print, kwargs={'a': Cown()}) + T(target=print, kwargs={'a': Region()}) + + T(target=print, args=(42, Cown(), Region())) + T(target=print, args=(42,)) + T(target=print, args=(Cown(),)) + T(target=print, args=(Region(),)) + self.assertTrue(True) # To make sure we got here correctly diff --git a/Lib/using.py b/Lib/using.py index 30ae93a3bbb1cf..aafc4e30b1e343 100644 --- a/Lib/using.py +++ b/Lib/using.py @@ -45,3 +45,58 @@ def decorator(func): with CS(cowns, *args): return func() return decorator + +# TODO: this creates a normal Python thread and ensures that all its +# arguments are moved to the new thread. Eventually we should revisit +# this behaviour as we go multiple interpreters / multicore. +# TODO: require RC to be one less when move is upstreamed +def PyronaThread(group=None, target=None, name=None, + args=(), kwargs=None, *, daemon=None): + # Only check when a program uses pyrona + from sys import getrefcount as rc + from threading import Thread + # TODO: improve this check for final version of phase 3 + # - Revisit the rc checks + # - Consider throwing a different kind of error (e.g. RegionError) + # - Improve error messages + def ok_share(o): + if isimmutable(o): + return True + if isinstance(o, Cown): + return True + return False + def ok_move(o): + if isinstance(o, Region): + if rc(o) != 5: + # rc = 4 because: + # 1. ref to o in rc + # 2. ref to o on this frame (ok_move) + # 3. ref to o on the calling frame (check) + # 4. ref to o from iteration over kwargs dictionary or args tuple/list + # 5. ref to o from kwargs dictionary or args tuple/list + raise RuntimeError("Region passed to thread was not moved into thread") + if o.is_open(): + raise RuntimeError("Region passed to thread was open") + return True + return False + + def check(a, args): + # rc(args) == 4 because we need to know that the args list is moved into the thread too + # rc = 4 because: + # 1. ref to args in rc + # 2. ref to args on this frame + # 3. ref to args on the calling framedef check(a, args): + # 4. ref from frame calling PyronaThread -- FIXME: not valid; revisit after #45 + if not (ok_share(a) or (ok_move(a) and rc(args) == 4)): + raise RuntimeError("Thread was passed an object which was neither immutable, a cown, or a unique region") + + if kwargs is None: + for a in args: + check(a, args) + return Thread(group, target, name, args, daemon) + else: + for k in kwargs: + # Important to get matching RCs in both paths + v = kwargs[k] + check(v, kwargs) + return Thread(group, target, name, kwargs, daemon) diff --git a/Objects/regions.c b/Objects/regions.c index 7f138973965bd3..6cf0ad06a6814e 100644 --- a/Objects/regions.c +++ b/Objects/regions.c @@ -50,6 +50,12 @@ static void _PyErr_Region(PyObject *src, PyObject *tgt, const char *msg); * Global status for performing the region check. */ bool invariant_do_region_check = false; +/** + * TODO: revisit the definition of this builting function + */ +int Py_is_invariant_enabled(void) { + return invariant_do_region_check; +} // The src object for an edge that invalidated the invariant. PyObject* invariant_error_src = Py_None;