Skip to content

Commit

Permalink
Add ThreadPoolTaskExecutor for non running loop context
Browse files Browse the repository at this point in the history
  • Loading branch information
ynkdir committed Dec 27, 2024
1 parent f1c1b90 commit 7406324
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 28 deletions.
61 changes: 61 additions & 0 deletions tests/test_asyncui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import asyncio
import sys
import threading
import unittest

from win32more.asyncui import async_callback


class TestAsyncui(unittest.TestCase):
@unittest.skipIf(sys.version_info < (3, 12), "eager task is not supported")
def test_task_will_start_in_current_running_loop(self):
@async_callback
async def f():
trace.append(2)
await asyncio.sleep(0)
trace.append(4)

async def main():
trace.append(1)
f()
trace.append(3)
await asyncio.sleep(0)
trace.append(5)

trace = []
asyncio.run(main())

self.assertEqual(trace, [1, 2, 3, 4, 5])

def test_task_will_start_eagerly_and_continue_in_background_thread(self):
@async_callback
async def f():
asyncio.current_task().add_done_callback(lambda _: event.set())
trace.append(threading.get_native_id()) # <- caller's thread
await asyncio.sleep(0)
trace.append(threading.get_native_id()) # <- background thread

trace = []
event = threading.Event()
f()
self.assertTrue(event.wait(5))
self.assertEqual(trace[0], threading.get_native_id())
self.assertNotEqual(trace[1], threading.get_native_id())

def test_task_will_start_eagerly_and_dont_start_thread_when_task_was_done(self):
@async_callback
async def f():
asyncio.current_task().add_done_callback(lambda _: event.set())
asyncio.current_task().add_done_callback(lambda _: trace.append((2, threading.get_native_id())))
trace.append((1, threading.get_native_id()))

trace = []
event = threading.Event()
f()
self.assertTrue(event.wait(5))
self.assertEqual(trace[0], (1, threading.get_native_id()))
self.assertEqual(trace[1], (2, threading.get_native_id()))


if __name__ == "__main__":
unittest.main()
90 changes: 62 additions & 28 deletions win32more/asyncui.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,36 @@
import asyncio
import sys
import threading
from concurrent.futures import Future, ThreadPoolExecutor

from win32more.Windows.Win32.System.Com import IUnknown
from win32more.Windows.Win32.UI.WindowsAndMessaging import SetTimer

running_loop = None


# Asyncio runner for Windows message loop.
def async_start_runner(delay_ms=100):
global running_loop

def timer_proc(*args):
running_loop._run_once()

running_loop = asyncio.new_event_loop()
running_loop.stop()
running_loop._run_forever_setup()
loop._run_once()

loop = asyncio.new_event_loop()
loop.stop()
loop._run_forever_setup()
SetTimer(0, 0, delay_ms, timer_proc)


_tasks_keep = set()


def async_callback(coroutine_function):
def wrapper(*args):
_addref(args)
if sys.version_info < (3, 12):
task = loop.create_task(coroutine_function(*args))
else:
# Start task eagerly.
# Some method can not be called after returned.
# (e.g. CoreWebView2NewWindowRequestedEventArgs.GetDeferral())
task = asyncio.eager_task_factory(loop, coroutine_function(*args))
_tasks_keep.add(task)
task.add_done_callback(_tasks_keep.remove)
task.add_done_callback(lambda _: _release(args))

loop = _get_running_loop()
try:
executor = RunningLoopTaskExecutor()
except RuntimeError:
executor = ThreadPoolTaskExecutor()
future = executor.submit(coroutine_function(*args))
future.add_done_callback(lambda _: _release(args))

return wrapper


def _get_running_loop():
return running_loop or asyncio.get_running_loop()


def _addref(args):
for obj in args:
if isinstance(obj, IUnknown) and obj.value:
Expand All @@ -57,3 +41,53 @@ def _release(args):
for obj in args:
if isinstance(obj, IUnknown) and obj.value:
obj.Release()


class RunningLoopTaskExecutor:
def __init__(self):
self._loop = asyncio.get_running_loop()

def submit(self, coro):
task = self._create_task(self._loop, coro)
return asyncio.run_coroutine_threadsafe(self._await_task(task), self._loop)

async def _await_task(self, task):
return await task

def _create_task(self, loop, coro):
if sys.version_info < (3, 12):
return loop.create_task(coro)
# Start task eagerly.
# Some method can not be called after returned.
# (e.g. CoreWebView2NewWindowRequestedEventArgs.GetDeferral())
return asyncio.eager_task_factory(loop, coro)


class ThreadPoolTaskExecutor:
_thread_pool = None
_lock = threading.Lock()

def __init__(self):
self._init_thread_pool()

@classmethod
def _init_thread_pool(cls):
if cls._thread_pool is None:
with cls._lock:
if cls._thread_pool is None:
cls._thread_pool = ThreadPoolExecutor()

def submit(self, coro):
loop = asyncio.new_event_loop()
# start task eagerly in current thread context.
# cannot use eager_task_factory because it requires running loop.
task = loop.create_task(coro)
loop.stop()
loop.run_forever()
if task.done():
loop.run_until_complete(task) # ensure calling done callback
future = Future()
future.set_result(task.result())
return future
# continue in background thread
return self._thread_pool.submit(loop.run_until_complete, task)

0 comments on commit 7406324

Please sign in to comment.