Skip to content

Commit

Permalink
Patch loops to copy context on task creation.
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix committed Sep 11, 2018
1 parent 278ad10 commit ab53215
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 0 deletions.
37 changes: 37 additions & 0 deletions contextvars/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import collections.abc
import threading
import types

import immutables

Expand Down Expand Up @@ -209,3 +210,39 @@ def _get_state():


_state = threading.local()


def create_task(loop, coro):
task = loop._orig_create_task(coro)
if task._source_traceback:
del task._source_traceback[-1]
task.context = copy_context()
return task


def _patch_loop(loop):
if loop and not hasattr(loop, '_orig_create_task'):
loop._orig_create_task = loop.create_task
loop.create_task = types.MethodType(create_task, loop)
return loop


def get_event_loop():
return _patch_loop(_get_event_loop())


def set_event_loop(loop):
return _set_event_loop(_patch_loop(loop))


def new_event_loop():
return _patch_loop(_new_event_loop())


_get_event_loop = asyncio.get_event_loop
_set_event_loop = asyncio.set_event_loop
_new_event_loop = asyncio.new_event_loop

asyncio.get_event_loop = asyncio.events.get_event_loop = get_event_loop
asyncio.set_event_loop = asyncio.events.set_event_loop = set_event_loop
asyncio.new_event_loop = asyncio.events.new_event_loop = new_event_loop
94 changes: 94 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copied from https://git.io/fAGgA with small updates

import asyncio
import contextvars
import random
import unittest


class TaskTests(unittest.TestCase):
def test_context_1(self):
cvar = contextvars.ContextVar('cvar')

async def sub():
await asyncio.sleep(0.01, loop=loop)
self.assertEqual(cvar.get(), 'nope')
cvar.set('something else')

async def main():
cvar.set('nope')
self.assertEqual(cvar.get(), 'nope')
subtask = loop.create_task(sub())
cvar.set('yes')
self.assertEqual(cvar.get(), 'yes')
await subtask
self.assertEqual(cvar.get(), 'yes')

loop = asyncio.new_event_loop()
try:
loop.run_until_complete(main())
finally:
loop.close()

def test_context_2(self):
cvar = contextvars.ContextVar('cvar', default='nope')

async def main():
def fut_on_done(fut):
# This change must not pollute the context
# of the "main()" task.
cvar.set('something else')

self.assertEqual(cvar.get(), 'nope')

for j in range(2):
fut = loop.create_future()
ctx = contextvars.copy_context()
fut.add_done_callback(lambda f: ctx.run(fut_on_done, f))
cvar.set('yes{}'.format(j))
loop.call_soon(fut.set_result, None)
await fut
self.assertEqual(cvar.get(), 'yes{}'.format(j))

for i in range(3):
# Test that task passed its context to add_done_callback:
cvar.set('yes{}-{}'.format(i, j))
await asyncio.sleep(0.001, loop=loop)
self.assertEqual(cvar.get(), 'yes{}-{}'.format(i, j))

loop = asyncio.new_event_loop()
try:
task = loop.create_task(main())
loop.run_until_complete(task)
finally:
loop.close()

self.assertEqual(cvar.get(), 'nope')

def test_context_3(self):
# Run 100 Tasks in parallel, each modifying cvar.

cvar = contextvars.ContextVar('cvar', default=-1)

async def sub(num):
for i in range(10):
cvar.set(num + i)
await asyncio.sleep(
random.uniform(0.001, 0.05), loop=loop)
self.assertEqual(cvar.get(), num + i)

async def main():
tasks = []
for i in range(100):
task = loop.create_task(sub(random.randint(0, 10)))
tasks.append(task)

await asyncio.gather(*tasks, loop=loop)

loop = asyncio.new_event_loop()
try:
loop.run_until_complete(main())
finally:
loop.close()

self.assertEqual(cvar.get(), -1)

0 comments on commit ab53215

Please sign in to comment.