Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race in set_tracer_provider() #2182

Merged
merged 8 commits into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.5.0-0.24b0...HEAD)
- Fix race in `set_tracer_provider()`
([#2182](https://github.com/open-telemetry/opentelemetry-python/pull/2182))
- Automatically load OTEL environment variables as options for `opentelemetry-instrument`
([#1969](https://github.com/open-telemetry/opentelemetry-python/pull/1969))
- `opentelemetry-semantic-conventions` Update to semantic conventions v1.6.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import unittest
from unittest import mock
from unittest.mock import patch

# pylint:disable=no-name-in-module
# pylint:disable=import-error
Expand All @@ -38,6 +37,7 @@
from opentelemetry.sdk.resources import SERVICE_NAME
from opentelemetry.sdk.trace import Resource, TracerProvider
from opentelemetry.sdk.util.instrumentation import InstrumentationInfo
from opentelemetry.test.globals_test import TraceGlobalsTest
from opentelemetry.test.spantestutil import (
get_span_with_dropped_attributes_events_links,
)
Expand All @@ -53,7 +53,7 @@ def _translate_spans_with_dropped_attributes():
return translate._translate(ThriftTranslator(max_tag_value_length=5))


class TestJaegerExporter(unittest.TestCase):
class TestJaegerExporter(TraceGlobalsTest, unittest.TestCase):
def setUp(self):
# create and save span to be used in tests
self.context = trace_api.SpanContext(
Expand All @@ -73,7 +73,6 @@ def setUp(self):
self._test_span.end(end_time=3)
# pylint: disable=protected-access

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_constructor_default(self):
# pylint: disable=protected-access
"""Test the default values assigned by constructor."""
Expand All @@ -98,7 +97,6 @@ def test_constructor_default(self):
self.assertTrue(exporter._agent_client is not None)
self.assertIsNone(exporter._max_tag_value_length)

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_constructor_explicit(self):
# pylint: disable=protected-access
"""Test the constructor passing all the options."""
Expand Down Expand Up @@ -143,7 +141,6 @@ def test_constructor_explicit(self):
self.assertTrue(exporter._collector_http_client.auth is None)
self.assertEqual(exporter._max_tag_value_length, 42)

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_constructor_by_environment_variables(self):
# pylint: disable=protected-access
"""Test the constructor using Environment Variables."""
Expand Down Expand Up @@ -198,7 +195,6 @@ def test_constructor_by_environment_variables(self):
self.assertTrue(exporter._collector_http_client.auth is None)
environ_patcher.stop()

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_constructor_with_no_traceprovider_resource(self):

"""Test the constructor when there is no resource attached to trace_provider"""
Expand Down Expand Up @@ -480,7 +476,6 @@ def test_translate_to_jaeger(self):

self.assertEqual(spans, expected_spans)

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_export(self):

"""Test that agent and/or collector are invoked"""
Expand Down Expand Up @@ -511,9 +506,7 @@ def test_export(self):
exporter.export((self._test_span,))
self.assertEqual(agent_client_mock.emit.call_count, 1)
self.assertEqual(collector_mock.submit.call_count, 1)
# trace_api._TRACER_PROVIDER = None

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_export_span_service_name(self):
trace_api.set_tracer_provider(
TracerProvider(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,12 @@
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SpanExportResult
from opentelemetry.test.globals_test import TraceGlobalsTest
from opentelemetry.trace import TraceFlags


# pylint: disable=no-member
class TestCollectorSpanExporter(unittest.TestCase):
@mock.patch(
"opentelemetry.exporter.opencensus.trace_exporter.trace._TRACER_PROVIDER",
None,
)
class TestCollectorSpanExporter(TraceGlobalsTest, unittest.TestCase):
def test_constructor(self):
mock_get_node = mock.Mock()
patch = mock.patch(
Expand Down Expand Up @@ -329,10 +326,6 @@ def test_export(self):
getattr(output_identifier, "host_name"), "testHostName"
)

@mock.patch(
"opentelemetry.exporter.opencensus.trace_exporter.trace._TRACER_PROVIDER",
None,
)
def test_export_service_name(self):
trace_api.set_tracer_provider(
TracerProvider(
Expand Down
40 changes: 21 additions & 19 deletions opentelemetry-api/src/opentelemetry/trace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
)
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.util import types
from opentelemetry.util._once import Once
from opentelemetry.util._providers import _load_provider

logger = getLogger(__name__)
Expand Down Expand Up @@ -452,8 +453,9 @@ def start_as_current_span(
yield INVALID_SPAN


_TRACER_PROVIDER = None
_PROXY_TRACER_PROVIDER = None
_TRACER_PROVIDER_SET_ONCE = Once()
_TRACER_PROVIDER: Optional[TracerProvider] = None
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()


def get_tracer(
Expand All @@ -476,40 +478,40 @@ def get_tracer(
)


def _set_tracer_provider(tracer_provider: TracerProvider, log: bool) -> None:
aabmass marked this conversation as resolved.
Show resolved Hide resolved
def set_tp() -> None:
global _TRACER_PROVIDER # pylint: disable=global-statement
_TRACER_PROVIDER = tracer_provider

did_set = _TRACER_PROVIDER_SET_ONCE.do_once(set_tp)

if log and not did_set:
logger.warning("Overriding of current TracerProvider is not allowed")


def set_tracer_provider(tracer_provider: TracerProvider) -> None:
"""Sets the current global :class:`~.TracerProvider` object.

This can only be done once, a warning will be logged if any furter attempt
is made.
"""
global _TRACER_PROVIDER # pylint: disable=global-statement

if _TRACER_PROVIDER is not None:
logger.warning("Overriding of current TracerProvider is not allowed")
return

_TRACER_PROVIDER = tracer_provider
_set_tracer_provider(tracer_provider, log=True)


def get_tracer_provider() -> TracerProvider:
"""Gets the current global :class:`~.TracerProvider` object."""
# pylint: disable=global-statement
global _TRACER_PROVIDER
global _PROXY_TRACER_PROVIDER

if _TRACER_PROVIDER is None:
# if a global tracer provider has not been set either via code or env
# vars, return a proxy tracer provider
if OTEL_PYTHON_TRACER_PROVIDER not in os.environ:
if not _PROXY_TRACER_PROVIDER:
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()
return _PROXY_TRACER_PROVIDER

_TRACER_PROVIDER = cast( # type: ignore
"TracerProvider",
_load_provider(OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"),
tracer_provider: TracerProvider = _load_provider(
OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"
aabmass marked this conversation as resolved.
Show resolved Hide resolved
)
return _TRACER_PROVIDER
_set_tracer_provider(tracer_provider, log=False)
# _TRACER_PROVIDER will have been set by one thread
return cast("TracerProvider", _TRACER_PROVIDER)


@contextmanager # type: ignore
Expand Down
47 changes: 47 additions & 0 deletions opentelemetry-api/src/opentelemetry/util/_once.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from threading import Lock
from typing import Callable


class Once:
owais marked this conversation as resolved.
Show resolved Hide resolved
owais marked this conversation as resolved.
Show resolved Hide resolved
"""Execute a function exactly once and block all callers until the function returns

Same as golang's `sync.Once <https://pkg.go.dev/sync#Once>`_
"""

def __init__(self) -> None:
self._lock = Lock()
self._done = False

def do_once(self, func: Callable[[], None]) -> bool:
"""Execute ``func`` if it hasn't been executed or return.

Will block until ``func`` has been called by one thread.

Returns:
Whether or not ``func`` was executed in this call
"""

# fast path, try to avoid locking
if self._done:
return False

with self._lock:
if not self._done:
func()
self._done = True
return True
return False
65 changes: 51 additions & 14 deletions opentelemetry-api/tests/trace/test_globals.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest
from unittest.mock import patch
from unittest.mock import Mock, patch

from opentelemetry import context, trace
from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc
from opentelemetry.test.globals_test import TraceGlobalsTest
from opentelemetry.trace.status import Status, StatusCode


Expand All @@ -25,25 +27,60 @@ def record_exception(
self.recorded_exception = exception


class TestGlobals(unittest.TestCase):
def setUp(self):
self._patcher = patch("opentelemetry.trace._TRACER_PROVIDER")
self._mock_tracer_provider = self._patcher.start()

def tearDown(self) -> None:
self._patcher.stop()

def test_get_tracer(self):
class TestGlobals(TraceGlobalsTest, unittest.TestCase):
@staticmethod
@patch("opentelemetry.trace._TRACER_PROVIDER")
def test_get_tracer(mock_tracer_provider): # type: ignore
"""trace.get_tracer should proxy to the global tracer provider."""
trace.get_tracer("foo", "var")
self._mock_tracer_provider.get_tracer.assert_called_with(
"foo", "var", None
)
mock_provider = unittest.mock.Mock()
mock_tracer_provider.get_tracer.assert_called_with("foo", "var", None)
mock_provider = Mock()
trace.get_tracer("foo", "var", mock_provider)
mock_provider.get_tracer.assert_called_with("foo", "var", None)


class TestGlobalsConcurrency(TraceGlobalsTest, ConcurrencyTestBase):
@patch("opentelemetry.trace.logger")
def test_set_tracer_provider_many_threads(self, mock_logger) -> None: # type: ignore
mock_logger.warning = MockFunc()

def do_concurrently() -> Mock:
# first get a proxy tracer
proxy_tracer = trace.ProxyTracerProvider().get_tracer("foo")

# try to set the global tracer provider
mock_tracer_provider = Mock(get_tracer=MockFunc())
trace.set_tracer_provider(mock_tracer_provider)

# start a span through the proxy which will call through to the mock provider
proxy_tracer.start_span("foo")

return mock_tracer_provider

num_threads = 100
mock_tracer_providers = self.run_with_many_threads(
do_concurrently,
num_threads=num_threads,
)

# despite trying to set tracer provider many times, only one of the
aabmass marked this conversation as resolved.
Show resolved Hide resolved
# mock_tracer_providers should have stuck and been called from
# proxy_tracer.start_span()
mock_tps_with_any_call = [
mock
for mock in mock_tracer_providers
if mock.get_tracer.call_count > 0
]

self.assertEqual(len(mock_tps_with_any_call), 1)
self.assertEqual(
mock_tps_with_any_call[0].get_tracer.call_count, num_threads
)

# should have warned everytime except for the successful set
self.assertEqual(mock_logger.warning.call_count, num_threads - 1)


class TestTracer(unittest.TestCase):
def setUp(self):
# pylint: disable=protected-access
Expand Down
10 changes: 5 additions & 5 deletions opentelemetry-api/tests/trace/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest

from opentelemetry import trace
from opentelemetry.test.globals_test import TraceGlobalsTest
from opentelemetry.trace.span import INVALID_SPAN_CONTEXT, NonRecordingSpan


Expand All @@ -39,10 +40,8 @@ class TestSpan(NonRecordingSpan):
pass


class TestProxy(unittest.TestCase):
class TestProxy(TraceGlobalsTest, unittest.TestCase):
def test_proxy_tracer(self):
original_provider = trace._TRACER_PROVIDER

provider = trace.get_tracer_provider()
# proxy provider
self.assertIsInstance(provider, trace.ProxyTracerProvider)
Expand All @@ -60,6 +59,9 @@ def test_proxy_tracer(self):
# set a real provider
trace.set_tracer_provider(TestProvider())

# get_tracer_provider() now returns the real provider
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)

# tracer provider now returns real instance
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)

Expand All @@ -71,5 +73,3 @@ def test_proxy_tracer(self):
# creates real spans
with tracer.start_span("") as span:
self.assertIsInstance(span, TestSpan)

trace._TRACER_PROVIDER = original_provider
Loading