Skip to content

Commit

Permalink
pythongh-76785: Minor Improvements to "interpreters" Module (pythongh…
Browse files Browse the repository at this point in the history
…-116328)

This includes adding pickle support to various classes, and small changes to improve the maintainability of the low-level _xxinterpqueues module.
  • Loading branch information
ericsnowcurrently authored Mar 5, 2024
1 parent bdba8ef commit 4402b3c
Show file tree
Hide file tree
Showing 9 changed files with 337 additions and 88 deletions.
8 changes: 8 additions & 0 deletions Lib/test/support/interpreters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@ def __hash__(self):
def __del__(self):
self._decref()

# for pickling:
def __getnewargs__(self):
return (self._id,)

# for pickling:
def __getstate__(self):
return None

def _decref(self):
if not self._ownsref:
return
Expand Down
12 changes: 11 additions & 1 deletion Lib/test/support/interpreters/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ class _ChannelEnd:

_end = None

def __init__(self, cid):
def __new__(cls, cid):
self = super().__new__(cls)
if self._end == 'send':
cid = _channels._channel_id(cid, send=True, force=True)
elif self._end == 'recv':
cid = _channels._channel_id(cid, recv=True, force=True)
else:
raise NotImplementedError(self._end)
self._id = cid
return self

def __repr__(self):
return f'{type(self).__name__}(id={int(self._id)})'
Expand All @@ -61,6 +63,14 @@ def __eq__(self, other):
return NotImplemented
return other._id == self._id

# for pickling:
def __getnewargs__(self):
return (int(self._id),)

# for pickling:
def __getstate__(self):
return None

@property
def id(self):
return self._id
Expand Down
31 changes: 16 additions & 15 deletions Lib/test/support/interpreters/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
]


class QueueEmpty(_queues.QueueEmpty, queue.Empty):
class QueueEmpty(QueueError, queue.Empty):
"""Raised from get_nowait() when the queue is empty.
It is also raised from get() if it times out.
"""


class QueueFull(_queues.QueueFull, queue.Full):
class QueueFull(QueueError, queue.Full):
"""Raised from put_nowait() when the queue is full.
It is also raised from put() if it times out.
Expand Down Expand Up @@ -66,7 +66,7 @@ def __new__(cls, id, /, *, _fmt=None):
else:
raise TypeError(f'id must be an int, got {id!r}')
if _fmt is None:
_fmt = _queues.get_default_fmt(id)
_fmt, = _queues.get_queue_defaults(id)
try:
self = _known_queues[id]
except KeyError:
Expand All @@ -93,6 +93,14 @@ def __repr__(self):
def __hash__(self):
return hash(self._id)

# for pickling:
def __getnewargs__(self):
return (self._id,)

# for pickling:
def __getstate__(self):
return None

@property
def id(self):
return self._id
Expand Down Expand Up @@ -159,9 +167,8 @@ def put(self, obj, timeout=None, *,
while True:
try:
_queues.put(self._id, obj, fmt)
except _queues.QueueFull as exc:
except QueueFull as exc:
if timeout is not None and time.time() >= end:
exc.__class__ = QueueFull
raise # re-raise
time.sleep(_delay)
else:
Expand All @@ -174,11 +181,7 @@ def put_nowait(self, obj, *, syncobj=None):
fmt = _SHARED_ONLY if syncobj else _PICKLED
if fmt is _PICKLED:
obj = pickle.dumps(obj)
try:
_queues.put(self._id, obj, fmt)
except _queues.QueueFull as exc:
exc.__class__ = QueueFull
raise # re-raise
_queues.put(self._id, obj, fmt)

def get(self, timeout=None, *,
_delay=10 / 1000, # 10 milliseconds
Expand All @@ -195,9 +198,8 @@ def get(self, timeout=None, *,
while True:
try:
obj, fmt = _queues.get(self._id)
except _queues.QueueEmpty as exc:
except QueueEmpty as exc:
if timeout is not None and time.time() >= end:
exc.__class__ = QueueEmpty
raise # re-raise
time.sleep(_delay)
else:
Expand All @@ -216,8 +218,7 @@ def get_nowait(self):
"""
try:
obj, fmt = _queues.get(self._id)
except _queues.QueueEmpty as exc:
exc.__class__ = QueueEmpty
except QueueEmpty as exc:
raise # re-raise
if fmt == _PICKLED:
obj = pickle.loads(obj)
Expand All @@ -226,4 +227,4 @@ def get_nowait(self):
return obj


_queues._register_queue_type(Queue)
_queues._register_heap_types(Queue, QueueEmpty, QueueFull)
7 changes: 7 additions & 0 deletions Lib/test/test_interpreters/test_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pickle
import threading
from textwrap import dedent
import unittest
Expand Down Expand Up @@ -261,6 +262,12 @@ def test_equality(self):
self.assertEqual(interp1, interp1)
self.assertNotEqual(interp1, interp2)

def test_pickle(self):
interp = interpreters.create()
data = pickle.dumps(interp)
unpickled = pickle.loads(data)
self.assertEqual(unpickled, interp)


class TestInterpreterIsRunning(TestBase):

Expand Down
13 changes: 13 additions & 0 deletions Lib/test/test_interpreters/test_channels.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import pickle
import threading
from textwrap import dedent
import unittest
Expand Down Expand Up @@ -100,6 +101,12 @@ def test_equality(self):
self.assertEqual(ch1, ch1)
self.assertNotEqual(ch1, ch2)

def test_pickle(self):
ch, _ = channels.create()
data = pickle.dumps(ch)
unpickled = pickle.loads(data)
self.assertEqual(unpickled, ch)


class TestSendChannelAttrs(TestBase):

Expand All @@ -125,6 +132,12 @@ def test_equality(self):
self.assertEqual(ch1, ch1)
self.assertNotEqual(ch1, ch2)

def test_pickle(self):
_, ch = channels.create()
data = pickle.dumps(ch)
unpickled = pickle.loads(data)
self.assertEqual(unpickled, ch)


class TestSendRecv(TestBase):

Expand Down
71 changes: 67 additions & 4 deletions Lib/test/test_interpreters/test_queues.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import importlib
import pickle
import threading
from textwrap import dedent
import unittest
import time

from test.support import import_helper
from test.support import import_helper, Py_DEBUG
# Raise SkipTest if subinterpreters not supported.
_queues = import_helper.import_module('_xxinterpqueues')
from test.support import interpreters
from test.support.interpreters import queues
from .utils import _run_output, TestBase
from .utils import _run_output, TestBase as _TestBase


class TestBase(TestBase):
def get_num_queues():
return len(_queues.list_all())


class TestBase(_TestBase):
def tearDown(self):
for qid in _queues.list_all():
for qid, _ in _queues.list_all():
try:
_queues.destroy(qid)
except Exception:
Expand All @@ -34,6 +39,58 @@ def test_highlevel_reloaded(self):
# See gh-115490 (https://github.com/python/cpython/issues/115490).
importlib.reload(queues)

def test_create_destroy(self):
qid = _queues.create(2, 0)
_queues.destroy(qid)
self.assertEqual(get_num_queues(), 0)
with self.assertRaises(queues.QueueNotFoundError):
_queues.get(qid)
with self.assertRaises(queues.QueueNotFoundError):
_queues.destroy(qid)

def test_not_destroyed(self):
# It should have cleaned up any remaining queues.
stdout, stderr = self.assert_python_ok(
'-c',
dedent(f"""
import {_queues.__name__} as _queues
_queues.create(2, 0)
"""),
)
self.assertEqual(stdout, '')
if Py_DEBUG:
self.assertNotEqual(stderr, '')
else:
self.assertEqual(stderr, '')

def test_bind_release(self):
with self.subTest('typical'):
qid = _queues.create(2, 0)
_queues.bind(qid)
_queues.release(qid)
self.assertEqual(get_num_queues(), 0)

with self.subTest('bind too much'):
qid = _queues.create(2, 0)
_queues.bind(qid)
_queues.bind(qid)
_queues.release(qid)
_queues.destroy(qid)
self.assertEqual(get_num_queues(), 0)

with self.subTest('nested'):
qid = _queues.create(2, 0)
_queues.bind(qid)
_queues.bind(qid)
_queues.release(qid)
_queues.release(qid)
self.assertEqual(get_num_queues(), 0)

with self.subTest('release without binding'):
qid = _queues.create(2, 0)
with self.assertRaises(queues.QueueError):
_queues.release(qid)


class QueueTests(TestBase):

Expand Down Expand Up @@ -127,6 +184,12 @@ def test_equality(self):
self.assertEqual(queue1, queue1)
self.assertNotEqual(queue1, queue2)

def test_pickle(self):
queue = queues.create()
data = pickle.dumps(queue)
unpickled = pickle.loads(data)
self.assertEqual(unpickled, queue)


class TestQueueOps(TestBase):

Expand Down
8 changes: 8 additions & 0 deletions Modules/_interpreters_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,11 @@ ensure_xid_class(PyTypeObject *cls, crossinterpdatafunc getdata)
//assert(cls->tp_flags & Py_TPFLAGS_HEAPTYPE);
return _PyCrossInterpreterData_RegisterClass(cls, getdata);
}

#ifdef REGISTERS_HEAP_TYPES
static int
clear_xid_class(PyTypeObject *cls)
{
return _PyCrossInterpreterData_UnregisterClass(cls);
}
#endif
14 changes: 8 additions & 6 deletions Modules/_xxinterpchannelsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
#include <sched.h> // sched_yield()
#endif

#define REGISTERS_HEAP_TYPES
#include "_interpreters_common.h"
#undef REGISTERS_HEAP_TYPES


/*
Expand Down Expand Up @@ -281,17 +283,17 @@ clear_xid_types(module_state *state)
{
/* external types */
if (state->send_channel_type != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type);
(void)clear_xid_class(state->send_channel_type);
Py_CLEAR(state->send_channel_type);
}
if (state->recv_channel_type != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->recv_channel_type);
(void)clear_xid_class(state->recv_channel_type);
Py_CLEAR(state->recv_channel_type);
}

/* heap types */
if (state->ChannelIDType != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
(void)clear_xid_class(state->ChannelIDType);
Py_CLEAR(state->ChannelIDType);
}
}
Expand Down Expand Up @@ -2677,11 +2679,11 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)

// Clear the old values if the .py module was reloaded.
if (state->send_channel_type != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type);
(void)clear_xid_class(state->send_channel_type);
Py_CLEAR(state->send_channel_type);
}
if (state->recv_channel_type != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->recv_channel_type);
(void)clear_xid_class(state->recv_channel_type);
Py_CLEAR(state->recv_channel_type);
}

Expand All @@ -2694,7 +2696,7 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
return -1;
}
if (ensure_xid_class(recv, _channelend_shared) < 0) {
(void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type);
(void)clear_xid_class(state->send_channel_type);
Py_CLEAR(state->send_channel_type);
Py_CLEAR(state->recv_channel_type);
return -1;
Expand Down
Loading

0 comments on commit 4402b3c

Please sign in to comment.