Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

we don't need thread locks for SDK tracing context #6551

Merged
merged 1 commit into from
Jul 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 10 additions & 28 deletions sdk/core/azure-core/azure/core/tracing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class ContextProtocol(Protocol):
Implements set and get variables in a thread safe way.
"""

def __init__(self, name, default, lock):
# type: (string, Any, threading.Lock) -> None
def __init__(self, name, default):
# type: (string, Any) -> None
pass

def clear(self):
Expand All @@ -54,11 +54,10 @@ class _AsyncContext(object):
Uses contextvars to set and get variables globally in a thread safe way.
"""

def __init__(self, name, default, lock):
def __init__(self, name, default):
self.name = name
self.contextvar = contextvars.ContextVar(name)
self.default = default if callable(default) else (lambda: default)
self.lock = lock

def clear(self):
# type: () -> None
Expand All @@ -78,8 +77,7 @@ def get(self):
def set(self, value):
# type: (Any) -> None
"""Set the value in the context."""
with self.lock:
self.contextvar.set(value)
self.contextvar.set(value)


class _ThreadLocalContext(object):
Expand All @@ -88,11 +86,10 @@ class _ThreadLocalContext(object):
"""
_thread_local = threading.local()

def __init__(self, name, default, lock):
# type: (str, Any, threading.Lock) -> None
def __init__(self, name, default):
# type: (str, Any) -> None
self.name = name
self.default = default if callable(default) else (lambda: default)
self.lock = lock

def clear(self):
# type: () -> None
Expand All @@ -112,16 +109,14 @@ def get(self):
def set(self, value):
# type: (Any) -> None
"""Set the value in the context."""
with self.lock:
setattr(self._thread_local, self.name, value)
setattr(self._thread_local, self.name, value)


class TracingContext:
_lock = threading.Lock()

class TracingContext(object):
def __init__(self):
# type: () -> None
self.current_span = TracingContext._get_context_class("current_span", None)
context_class = _AsyncContext if contextvars else _ThreadLocalContext
self.current_span = context_class("current_span", None)

def with_current_context(self, func):
# type: (Callable[[Any], Any]) -> Any
Expand All @@ -146,17 +141,4 @@ def call_with_current_context(*args, **kwargs):

return call_with_current_context

@classmethod
def _get_context_class(cls, name, default_val):
# type: (str, Any) -> ContextProtocol
"""
Returns an instance of the the context class that stores the variable.
:param name: The key to store the variable in the context class
:param default_val: The default value of the variable if unset
:return: An instance that implements the context protocol class
"""
context_class = _AsyncContext if contextvars else _ThreadLocalContext
return context_class(name, default_val, cls._lock)


tracing_context = TracingContext()
11 changes: 3 additions & 8 deletions sdk/core/azure-core/tests/test_tracing_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):


class TestContext(unittest.TestCase):
def test_get_context_class(self):
with ContextHelper():
slot = tracing_context._get_context_class("temp", 1)
assert slot.get() == 1
slot.set(2)
assert slot.get() == 2

def test_current_span(self):
with ContextHelper():
assert tracing_context.current_span.get() is None
assert not tracing_context.current_span.get()
val = mock.Mock(spec=AbstractSpan)
tracing_context.current_span.set(val)
assert tracing_context.current_span.get() == val
tracing_context.current_span.clear()
assert not tracing_context.current_span.get()

def test_with_current_context(self):
with ContextHelper(tracer_to_use=mock.Mock(AbstractSpan)):
Expand Down