Skip to content

Commit

Permalink
Propagate Botocore context to child threads
Browse files Browse the repository at this point in the history
  • Loading branch information
hssyoo committed Feb 13, 2025
1 parent bc23c92 commit 927fb5b
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 26 deletions.
6 changes: 5 additions & 1 deletion s3transfer/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from collections import namedtuple
from concurrent import futures

from botocore.context import get_context

from s3transfer.compat import MAXINT
from s3transfer.exceptions import CancelledError, TransferNotDoneError
from s3transfer.utils import FunctionContainer, TaskSemaphore
Expand Down Expand Up @@ -467,7 +469,9 @@ def submit(self, task, tag=None, block=True):
semaphore.release, task.transfer_id, acquire_token
)
# Submit the task to the underlying executor.
future = ExecutorFuture(self._executor.submit(task))
# Pass the current context to ensure child threads persist the
# parent thread's context.
future = ExecutorFuture(self._executor.submit(task, get_context()))
# Add the Semaphore.release() callback to the future such that
# it is invoked once the future completes.
future.add_done_callback(release_callback)
Expand Down
53 changes: 28 additions & 25 deletions s3transfer/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import copy
import logging

from botocore.context import start_as_current_context

from s3transfer.utils import get_callbacks

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -118,32 +120,33 @@ def _get_kwargs_with_params_to_exclude(self, kwargs, exclude):
filtered_kwargs[param] = value
return filtered_kwargs

def __call__(self):
def __call__(self, ctx=None):
"""The callable to use when submitting a Task to an executor"""
try:
# Wait for all of futures this task depends on.
self._wait_on_dependent_futures()
# Gather up all of the main keyword arguments for main().
# This includes the immediately provided main_kwargs and
# the values for pending_main_kwargs that source from the return
# values from the task's dependent futures.
kwargs = self._get_all_main_kwargs()
# If the task is not done (really only if some other related
# task to the TransferFuture had failed) then execute the task's
# main() method.
if not self._transfer_coordinator.done():
return self._execute_main(kwargs)
except Exception as e:
self._log_and_set_exception(e)
finally:
# Run any done callbacks associated to the task no matter what.
for done_callback in self._done_callbacks:
done_callback()

if self._is_final:
# If this is the final task announce that it is done if results
# are waiting on its completion.
self._transfer_coordinator.announce_done()
with start_as_current_context(ctx):
try:
# Wait for all of futures this task depends on.
self._wait_on_dependent_futures()
# Gather up all of the main keyword arguments for main().
# This includes the immediately provided main_kwargs and
# the values for pending_main_kwargs that source from the return
# values from the task's dependent futures.
kwargs = self._get_all_main_kwargs()
# If the task is not done (really only if some other related
# task to the TransferFuture had failed) then execute the task's
# main() method.
if not self._transfer_coordinator.done():
return self._execute_main(kwargs)
except Exception as e:
self._log_and_set_exception(e)
finally:
# Run any done callbacks associated to the task no matter what.
for done_callback in self._done_callbacks:
done_callback()

if self._is_final:
# If this is the final task announce that it is done if results
# are waiting on its completion.
self._transfer_coordinator.announce_done()

def _execute_main(self, kwargs):
# Do not display keyword args that should not be printed, especially
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from functools import partial
from threading import Event

from botocore.context import ClientContext, get_context

from s3transfer.futures import BoundedExecutor, TransferCoordinator
from s3transfer.subscribers import BaseSubscriber
from s3transfer.tasks import (
Expand Down Expand Up @@ -69,6 +71,11 @@ def _submit(self, transfer_future, **kwargs):
pass


class ReturnContextTask(Task):
def _main(self):
return get_context()


class ExceptionSubmissionTask(SubmissionTask):
def _submit(
self,
Expand Down Expand Up @@ -723,6 +730,15 @@ def test_single_failed_pending_future_in_list(self):
with self.assertRaises(TaskFailureException):
self.transfer_coordinator.result()

def test_passing_context_to_task_call(self):
ctx = ClientContext()
ctx.features.add('FOO')
task = ReturnContextTask(self.transfer_coordinator)
self.assertEqual(task(ctx).features, {'FOO'})
# `task(ctx)` returned, so the current context should be reset to None.
current_ctx = get_context()
self.assertEqual(current_ctx, None)


class BaseMultipartTaskTest(BaseTaskTest):
def setUp(self):
Expand Down

0 comments on commit 927fb5b

Please sign in to comment.