From b02ff47a6f0048efa35d8a0083ed4e1f7dc1ff9b Mon Sep 17 00:00:00 2001 From: Jeremy Voss Date: Mon, 21 Nov 2022 07:43:27 -0800 Subject: [PATCH] Custom sampler fix (#3026) * Fixed circular dependency that can happen when injecting custom samplers * lint * Deleted duplicate tests * lint * lint * lint * lint * lint * Update opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py Co-authored-by: Leighton Chen * Update opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py Co-authored-by: Leighton Chen * typing * Update opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py Co-authored-by: Leighton Chen * Update opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py Co-authored-by: Srikanth Chekuri * Retry tests * Fixed circular dependency that can happen when injecting custom samplers * lint * Deleted duplicate tests * lint * lint * lint * lint * lint * Update opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py Co-authored-by: Leighton Chen * Update opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py Co-authored-by: Leighton Chen * typing * Retry tests * Update opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py Co-authored-by: Leighton Chen * Update opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py Co-authored-by: Srikanth Chekuri * Updated contrib sha Co-authored-by: Srikanth Chekuri Co-authored-by: Leighton Chen Co-authored-by: Diego Hurtado --- .github/workflows/test.yml | 2 +- CHANGELOG.md | 2 + .../sdk/_configuration/__init__.py | 78 ++++- .../src/opentelemetry/sdk/trace/__init__.py | 37 ++- .../src/opentelemetry/sdk/trace/sampling.py | 64 ++-- .../src/opentelemetry/sdk/util/__init__.py | 30 +- opentelemetry-sdk/tests/test_configurator.py | 288 +++++++++++++++++- opentelemetry-sdk/tests/trace/test_trace.py | 240 +-------------- 8 files changed, 412 insertions(+), 329 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index aedcce541ae..a75f3576232 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,7 @@ env: # Otherwise, set variable to the commit of your branch on # opentelemetry-python-contrib which is compatible with these Core repo # changes. - CONTRIB_REPO_SHA: 66edf69811e142c397d8500cafe6eddeb5565d6e + CONTRIB_REPO_SHA: c6134843900e2eeb1b8b3383a897b38cc0905c38 # This is needed because we do not clone the core repo in contrib builds anymore. # When running contrib builds as part of core builds, we use actions/checkout@v2 which # does not set an environment variable (simply just runs tox), which is different when diff --git a/CHANGELOG.md b/CHANGELOG.md index 339cdb62460..030a8a37452 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- Fixed circular dependency issue with custom samplers + ([#3026](https://github.com/open-telemetry/opentelemetry-python/pull/3026)) - Add missing entry points for OTLP/HTTP exporter ([#3027](https://github.com/open-telemetry/opentelemetry-python/pull/3027)) diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py index c2280e1b27e..27f3a334c7a 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/_configuration/__init__.py @@ -21,8 +21,9 @@ import os from abc import ABC, abstractmethod from os import environ -from typing import Dict, Optional, Sequence, Tuple, Type +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type +from pkg_resources import iter_entry_points from typing_extensions import Literal from opentelemetry.environment_variables import ( @@ -44,6 +45,8 @@ OTEL_EXPORTER_OTLP_METRICS_PROTOCOL, OTEL_EXPORTER_OTLP_PROTOCOL, OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, + OTEL_TRACES_SAMPLER, + OTEL_TRACES_SAMPLER_ARG, ) from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics.export import ( @@ -54,7 +57,7 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter from opentelemetry.sdk.trace.id_generator import IdGenerator -from opentelemetry.sdk.util import _import_config_components +from opentelemetry.sdk.trace.sampling import Sampler from opentelemetry.semconv.resource import ResourceAttributes from opentelemetry.trace import set_tracer_provider @@ -82,9 +85,35 @@ _RANDOM_ID_GENERATOR = "random" _DEFAULT_ID_GENERATOR = _RANDOM_ID_GENERATOR +_OTEL_SAMPLER_ENTRY_POINT_GROUP = "opentelemetry_traces_sampler" + _logger = logging.getLogger(__name__) +def _import_config_components( + selected_components: List[str], entry_point_name: str +) -> Sequence[Tuple[str, object]]: + component_entry_points = { + ep.name: ep for ep in iter_entry_points(entry_point_name) + } + component_impls = [] + for selected_component in selected_components: + entry_point = component_entry_points.get(selected_component, None) + if not entry_point: + raise RuntimeError( + f"Requested component '{selected_component}' not found in entry points for '{entry_point_name}'" + ) + + component_impl = entry_point.load() + component_impls.append((selected_component, component_impl)) + + return component_impls + + +def _get_sampler() -> Optional[str]: + return environ.get(OTEL_TRACES_SAMPLER, None) + + def _get_id_generator() -> str: return environ.get(OTEL_PYTHON_ID_GENERATOR, _DEFAULT_ID_GENERATOR) @@ -149,7 +178,8 @@ def _get_exporter_names( def _init_tracing( exporters: Dict[str, Type[SpanExporter]], - id_generator: IdGenerator, + id_generator: IdGenerator = None, + sampler: Sampler = None, auto_instrumentation_version: Optional[str] = None, ): # if env var OTEL_RESOURCE_ATTRIBUTES is given, it will read the service_name @@ -161,7 +191,8 @@ def _init_tracing( ResourceAttributes.TELEMETRY_AUTO_VERSION ] = auto_instrumentation_version provider = TracerProvider( - id_generator=id_generator(), + id_generator=id_generator, + sampler=sampler, resource=Resource.create(auto_resource), ) set_tracer_provider(provider) @@ -266,13 +297,41 @@ def _import_exporters( return trace_exporters, metric_exporters, log_exporters +def _import_sampler_factory(sampler_name: str) -> Callable[[str], Sampler]: + _, sampler_impl = _import_config_components( + [sampler_name.strip()], _OTEL_SAMPLER_ENTRY_POINT_GROUP + )[0] + return sampler_impl + + +def _import_sampler(sampler_name: str) -> Optional[Sampler]: + if not sampler_name: + return None + try: + sampler_factory = _import_sampler_factory(sampler_name) + sampler_arg = os.getenv(OTEL_TRACES_SAMPLER_ARG, "") + sampler = sampler_factory(sampler_arg) + if not isinstance(sampler, Sampler): + message = f"Sampler factory, {sampler_factory}, produced output, {sampler}, which is not a Sampler." + _logger.warning(message) + raise ValueError(message) + return sampler + except Exception as exc: # pylint: disable=broad-except + _logger.warning( + "Using default sampler. Failed to initialize custom sampler, %s: %s", + sampler_name, + exc, + ) + return None + + def _import_id_generator(id_generator_name: str) -> IdGenerator: id_generator_name, id_generator_impl = _import_config_components( [id_generator_name.strip()], "opentelemetry_id_generator" )[0] if issubclass(id_generator_impl, IdGenerator): - return id_generator_impl + return id_generator_impl() raise RuntimeError(f"{id_generator_name} is not an IdGenerator") @@ -283,9 +342,16 @@ def _initialize_components(auto_instrumentation_version): _get_exporter_names("metrics"), _get_exporter_names("logs"), ) + sampler_name = _get_sampler() + sampler = _import_sampler(sampler_name) id_generator_name = _get_id_generator() id_generator = _import_id_generator(id_generator_name) - _init_tracing(trace_exporters, id_generator, auto_instrumentation_version) + _init_tracing( + exporters=trace_exporters, + id_generator=id_generator, + sampler=sampler, + auto_instrumentation_version=auto_instrumentation_version, + ) _init_metrics(metric_exporters, auto_instrumentation_version) logging_enabled = os.getenv( _OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED, "false" diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py index 51ede5e1211..0ddf531a4d5 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py @@ -78,9 +78,6 @@ _ENV_VALUE_UNSET = "" -# pylint: disable=protected-access -_TRACE_SAMPLER = sampling._get_from_env_or_default() - class SpanProcessor: """Interface which allows hooks for SDK's `Span` start and end method @@ -334,7 +331,7 @@ def _check_span_ended(func): def wrapper(self, *args, **kwargs): already_ended = False with self._lock: # pylint: disable=protected-access - if self._end_time is None: + if self._end_time is None: # pylint: disable=protected-access func(self, *args, **kwargs) else: already_ended = True @@ -519,7 +516,11 @@ def _format_events(events): f_event = OrderedDict() f_event["name"] = event.name f_event["timestamp"] = util.ns_to_iso_str(event.timestamp) - f_event["attributes"] = Span._format_attributes(event.attributes) + f_event[ + "attributes" + ] = Span._format_attributes( # pylint: disable=protected-access + event.attributes + ) f_events.append(f_event) return f_events @@ -528,8 +529,16 @@ def _format_links(links): f_links = [] for link in links: f_link = OrderedDict() - f_link["context"] = Span._format_context(link.context) - f_link["attributes"] = Span._format_attributes(link.attributes) + f_link[ + "context" + ] = Span._format_context( # pylint: disable=protected-access + link.context + ) + f_link[ + "attributes" + ] = Span._format_attributes( # pylint: disable=protected-access + link.attributes + ) f_links.append(f_link) return f_links @@ -691,10 +700,12 @@ def _from_env_if_absent( ) # not removed for backward compat. please use SpanLimits instead. -SPAN_ATTRIBUTE_COUNT_LIMIT = SpanLimits._from_env_if_absent( - None, - OTEL_SPAN_ATTRIBUTE_COUNT_LIMIT, - _DEFAULT_OTEL_SPAN_ATTRIBUTE_COUNT_LIMIT, +SPAN_ATTRIBUTE_COUNT_LIMIT = ( + SpanLimits._from_env_if_absent( # pylint: disable=protected-access + None, + OTEL_SPAN_ATTRIBUTE_COUNT_LIMIT, + _DEFAULT_OTEL_SPAN_ATTRIBUTE_COUNT_LIMIT, + ) ) @@ -1115,7 +1126,7 @@ class TracerProvider(trace_api.TracerProvider): def __init__( self, - sampler: sampling.Sampler = _TRACE_SAMPLER, + sampler: sampling.Sampler = None, resource: Resource = Resource.create({}), shutdown_on_exit: bool = True, active_span_processor: Union[ @@ -1132,6 +1143,8 @@ def __init__( else: self.id_generator = id_generator self._resource = resource + if not sampler: + sampler = sampling._get_from_env_or_default() self.sampler = sampler self._span_limits = span_limits or SpanLimits() self._atexit_handler = None diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py index 38a3338b02f..8af41f3d66e 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py @@ -73,7 +73,8 @@ * parentbased_always_off - Sampler that respects its parent span's sampling decision, but otherwise never samples. * parentbased_traceidratio - Sampler that respects its parent span's sampling decision, but otherwise samples probabalistically based on rate. -Sampling probability can be set with ``OTEL_TRACES_SAMPLER_ARG`` if the sampler is traceidratio or parentbased_traceidratio. Rate must be in the range [0.0,1.0]. When not provided rate will be set to 1.0 (maximum rate possible). +Sampling probability can be set with ``OTEL_TRACES_SAMPLER_ARG`` if the sampler is traceidratio or parentbased_traceidratio. Rate must be in the range [0.0,1.0]. When not provided rate will be set to +1.0 (maximum rate possible). Prev example but with environment variables. Please make sure to set the env ``OTEL_TRACES_SAMPLER=traceidratio`` and ``OTEL_TRACES_SAMPLER_ARG=0.001``. @@ -97,9 +98,10 @@ with trace.get_tracer(__name__).start_as_current_span("Test Span"): ... -In order to create a configurable custom sampler, create an entry point for the custom sampler factory method under the entry point group, ``opentelemetry_traces_sampler``. The custom sampler factory -method must be of type ``Callable[[str], Sampler]``, taking a single string argument and returning a Sampler object. The single input will come from the string value of the -``OTEL_TRACES_SAMPLER_ARG`` environment variable. If ``OTEL_TRACES_SAMPLER_ARG`` is not configured, the input will be an empty string. For example: +When utilizing a configurator, you can configure a custom sampler. In order to create a configurable custom sampler, create an entry point for the custom sampler +factory method or function under the entry point group, ``opentelemetry_traces_sampler``. The custom sampler factory method must be of type ``Callable[[str], Sampler]``, taking a single string argument and +returning a Sampler object. The single input will come from the string value of the ``OTEL_TRACES_SAMPLER_ARG`` environment variable. If ``OTEL_TRACES_SAMPLER_ARG`` is not configured, the input will +be an empty string. For example: .. code:: python @@ -134,7 +136,7 @@ class CustomSamplerFactory: import os from logging import getLogger from types import MappingProxyType -from typing import Callable, Optional, Sequence +from typing import Optional, Sequence # pylint: disable=unused-import from opentelemetry.context import Context @@ -142,7 +144,6 @@ class CustomSamplerFactory: OTEL_TRACES_SAMPLER, OTEL_TRACES_SAMPLER_ARG, ) -from opentelemetry.sdk.util import _import_config_components from opentelemetry.trace import Link, SpanKind, get_current_span from opentelemetry.trace.span import TraceState from opentelemetry.util.types import Attributes @@ -193,9 +194,6 @@ def __init__( self.trace_state = trace_state -_OTEL_SAMPLER_ENTRY_POINT_GROUP = "opentelemetry_traces_sampler" - - class Sampler(abc.ABC): @abc.abstractmethod def should_sample( @@ -407,37 +405,22 @@ def __init__(self, rate: float): def _get_from_env_or_default() -> Sampler: - traces_sampler_name = os.getenv( + trace_sampler = os.getenv( OTEL_TRACES_SAMPLER, "parentbased_always_on" ).lower() + if trace_sampler not in _KNOWN_SAMPLERS: + _logger.warning("Couldn't recognize sampler %s.", trace_sampler) + trace_sampler = "parentbased_always_on" - if traces_sampler_name in _KNOWN_SAMPLERS: - if traces_sampler_name in ("traceidratio", "parentbased_traceidratio"): - try: - rate = float(os.getenv(OTEL_TRACES_SAMPLER_ARG)) - except ValueError: - _logger.warning( - "Could not convert TRACES_SAMPLER_ARG to float." - ) - rate = 1.0 - return _KNOWN_SAMPLERS[traces_sampler_name](rate) - return _KNOWN_SAMPLERS[traces_sampler_name] - try: - traces_sampler_factory = _import_sampler_factory(traces_sampler_name) - sampler_arg = os.getenv(OTEL_TRACES_SAMPLER_ARG, "") - traces_sampler = traces_sampler_factory(sampler_arg) - if not isinstance(traces_sampler, Sampler): - message = f"Traces sampler factory, {traces_sampler_factory}, produced output, {traces_sampler}, which is not a Sampler object." - _logger.warning(message) - raise ValueError(message) - return traces_sampler - except Exception as exc: # pylint: disable=broad-except - _logger.warning( - "Using default sampler. Failed to initialize custom sampler, %s: %s", - traces_sampler_name, - exc, - ) - return _KNOWN_SAMPLERS["parentbased_always_on"] + if trace_sampler in ("traceidratio", "parentbased_traceidratio"): + try: + rate = float(os.getenv(OTEL_TRACES_SAMPLER_ARG)) + except ValueError: + _logger.warning("Could not convert TRACES_SAMPLER_ARG to float.") + rate = 1.0 + return _KNOWN_SAMPLERS[trace_sampler](rate) + + return _KNOWN_SAMPLERS[trace_sampler] def _get_parent_trace_state(parent_context) -> Optional["TraceState"]: @@ -445,10 +428,3 @@ def _get_parent_trace_state(parent_context) -> Optional["TraceState"]: if parent_span_context is None or not parent_span_context.is_valid: return None return parent_span_context.trace_state - - -def _import_sampler_factory(sampler_name: str) -> Callable[[str], Sampler]: - _, sampler_impl = _import_config_components( - [sampler_name.strip()], _OTEL_SAMPLER_ENTRY_POINT_GROUP - )[0] - return sampler_impl diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/util/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/util/__init__.py index 52104243532..e1857d8e62d 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/util/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/util/__init__.py @@ -14,11 +14,11 @@ import datetime import threading -from collections import OrderedDict, abc, deque -from typing import List, Optional, Sequence, Tuple +from collections import OrderedDict, deque +from collections.abc import MutableMapping, Sequence +from typing import Optional from deprecated import deprecated -from pkg_resources import iter_entry_points def ns_to_iso_str(nanoseconds): @@ -41,27 +41,7 @@ def get_dict_as_key(labels): ) -def _import_config_components( - selected_components: List[str], entry_point_name: str -) -> Sequence[Tuple[str, object]]: - component_entry_points = { - ep.name: ep for ep in iter_entry_points(entry_point_name) - } - component_impls = [] - for selected_component in selected_components: - entry_point = component_entry_points.get(selected_component, None) - if not entry_point: - raise RuntimeError( - f"Requested component '{selected_component}' not found in entry points for '{entry_point_name}'" - ) - - component_impl = entry_point.load() - component_impls.append((selected_component, component_impl)) - - return component_impls - - -class BoundedList(abc.Sequence): +class BoundedList(Sequence): """An append only list with a fixed max size. Calls to `append` and `extend` will drop the oldest elements if there is @@ -112,7 +92,7 @@ def from_seq(cls, maxlen, seq): @deprecated(version="1.4.0") # type: ignore -class BoundedDict(abc.MutableMapping): +class BoundedDict(MutableMapping): """An ordered dict with a fixed max capacity. Oldest elements are dropped when the dict is full and a new element is diff --git a/opentelemetry-sdk/tests/test_configurator.py b/opentelemetry-sdk/tests/test_configurator.py index 947ae623bc8..a27c7a49a1d 100644 --- a/opentelemetry-sdk/tests/test_configurator.py +++ b/opentelemetry-sdk/tests/test_configurator.py @@ -16,11 +16,12 @@ import logging from os import environ -from typing import Dict, Iterable, Optional +from typing import Dict, Iterable, Optional, Sequence from unittest import TestCase from unittest.mock import patch from opentelemetry import trace +from opentelemetry.context import Context from opentelemetry.environment_variables import OTEL_PYTHON_ID_GENERATOR from opentelemetry.sdk._configuration import ( _EXPORTER_OTLP, @@ -28,8 +29,10 @@ _EXPORTER_OTLP_PROTO_HTTP, _get_exporter_names, _get_id_generator, + _get_sampler, _import_exporters, _import_id_generator, + _import_sampler, _init_logging, _init_metrics, _init_tracing, @@ -37,6 +40,10 @@ ) from opentelemetry.sdk._logs import LoggingHandler from opentelemetry.sdk._logs.export import ConsoleLogExporter +from opentelemetry.sdk.environment_variables import ( + OTEL_TRACES_SAMPLER, + OTEL_TRACES_SAMPLER_ARG, +) from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics.export import ( AggregationTemporality, @@ -49,10 +56,22 @@ from opentelemetry.sdk.resources import SERVICE_NAME, Resource from opentelemetry.sdk.trace.export import ConsoleSpanExporter from opentelemetry.sdk.trace.id_generator import IdGenerator, RandomIdGenerator +from opentelemetry.sdk.trace.sampling import ( + ALWAYS_ON, + Decision, + ParentBased, + Sampler, + SamplingResult, + TraceIdRatioBased, +) +from opentelemetry.trace import Link, SpanKind +from opentelemetry.trace.span import TraceState +from opentelemetry.util.types import Attributes class Provider: - def __init__(self, resource=None, id_generator=None): + def __init__(self, resource=None, sampler=None, id_generator=None): + self.sampler = sampler self.id_generator = id_generator self.processor = None self.resource = resource or Resource.create({}) @@ -175,6 +194,73 @@ def shutdown(self): pass +class CustomSampler(Sampler): + def __init__(self) -> None: + pass + + def get_description(self) -> str: + return "CustomSampler" + + def should_sample( + self, + parent_context: Optional["Context"], + trace_id: int, + name: str, + kind: SpanKind = None, + attributes: Attributes = None, + links: Sequence[Link] = None, + trace_state: TraceState = None, + ) -> "SamplingResult": + return SamplingResult( + Decision.RECORD_AND_SAMPLE, + None, + None, + ) + + +class CustomRatioSampler(TraceIdRatioBased): + def __init__(self, ratio): + if not isinstance(ratio, float): + raise ValueError( + "CustomRatioSampler ratio argument is not a float." + ) + self.ratio = ratio + super().__init__(ratio) + + def get_description(self) -> str: + return "CustomSampler" + + def should_sample( + self, + parent_context: Optional["Context"], + trace_id: int, + name: str, + kind: SpanKind = None, + attributes: Attributes = None, + links: Sequence[Link] = None, + trace_state: TraceState = None, + ) -> "SamplingResult": + return SamplingResult( + Decision.RECORD_AND_SAMPLE, + None, + None, + ) + + +class CustomSamplerFactory: + @staticmethod + def get_custom_sampler(unused_sampler_arg): + return CustomSampler() + + @staticmethod + def get_custom_ratio_sampler(sampler_arg): + return CustomRatioSampler(float(sampler_arg)) + + @staticmethod + def empty_get_custom_sampler(sampler_arg): + return + + class CustomIdGenerator(IdGenerator): def generate_span_id(self): pass @@ -220,7 +306,11 @@ def tearDown(self): environ, {"OTEL_RESOURCE_ATTRIBUTES": "service.name=my-test-service"} ) def test_trace_init_default(self): - _init_tracing({"zipkin": Exporter}, RandomIdGenerator, "test-version") + _init_tracing( + {"zipkin": Exporter}, + id_generator=RandomIdGenerator(), + auto_instrumentation_version="test-version", + ) self.assertEqual(self.set_provider_mock.call_count, 1) provider = self.set_provider_mock.call_args[0][0] @@ -241,7 +331,9 @@ def test_trace_init_default(self): {"OTEL_RESOURCE_ATTRIBUTES": "service.name=my-otlp-test-service"}, ) def test_trace_init_otlp(self): - _init_tracing({"otlp": OTLPSpanExporter}, RandomIdGenerator) + _init_tracing( + {"otlp": OTLPSpanExporter}, id_generator=RandomIdGenerator() + ) self.assertEqual(self.set_provider_mock.call_count, 1) provider = self.set_provider_mock.call_args[0][0] @@ -257,7 +349,7 @@ def test_trace_init_otlp(self): @patch.dict(environ, {OTEL_PYTHON_ID_GENERATOR: "custom_id_generator"}) @patch("opentelemetry.sdk._configuration.IdGenerator", new=IdGenerator) - @patch("opentelemetry.sdk.util.iter_entry_points") + @patch("opentelemetry.sdk._configuration.iter_entry_points") def test_trace_init_custom_id_generator(self, mock_iter_entry_points): mock_iter_entry_points.configure_mock( return_value=[ @@ -266,10 +358,194 @@ def test_trace_init_custom_id_generator(self, mock_iter_entry_points): ) id_generator_name = _get_id_generator() id_generator = _import_id_generator(id_generator_name) - _init_tracing({}, id_generator) + _init_tracing({}, id_generator=id_generator) provider = self.set_provider_mock.call_args[0][0] self.assertIsInstance(provider.id_generator, CustomIdGenerator) + @patch.dict( + "os.environ", {OTEL_TRACES_SAMPLER: "non_existent_entry_point"} + ) + def test_trace_init_custom_sampler_with_env_non_existent_entry_point(self): + sampler_name = _get_sampler() + sampler = _import_sampler(sampler_name) + _init_tracing({}, sampler=sampler) + provider = self.set_provider_mock.call_args[0][0] + self.assertIsNone(provider.sampler) + + @patch("opentelemetry.sdk._configuration.iter_entry_points") + @patch.dict("os.environ", {OTEL_TRACES_SAMPLER: "custom_sampler_factory"}) + def test_trace_init_custom_sampler_with_env(self, mock_iter_entry_points): + mock_iter_entry_points.configure_mock( + return_value=[ + IterEntryPoint( + "custom_sampler_factory", + CustomSamplerFactory.get_custom_sampler, + ) + ] + ) + sampler_name = _get_sampler() + sampler = _import_sampler(sampler_name) + _init_tracing({}, sampler=sampler) + provider = self.set_provider_mock.call_args[0][0] + self.assertIsInstance(provider.sampler, CustomSampler) + + @patch("opentelemetry.sdk._configuration.iter_entry_points") + @patch.dict("os.environ", {OTEL_TRACES_SAMPLER: "custom_sampler_factory"}) + def test_trace_init_custom_sampler_with_env_bad_factory( + self, mock_iter_entry_points + ): + mock_iter_entry_points.configure_mock( + return_value=[ + IterEntryPoint( + "custom_sampler_factory", + CustomSamplerFactory.empty_get_custom_sampler, + ) + ] + ) + sampler_name = _get_sampler() + sampler = _import_sampler(sampler_name) + _init_tracing({}, sampler=sampler) + provider = self.set_provider_mock.call_args[0][0] + self.assertIsNone(provider.sampler) + + @patch("opentelemetry.sdk._configuration.iter_entry_points") + @patch.dict( + "os.environ", + { + OTEL_TRACES_SAMPLER: "custom_sampler_factory", + OTEL_TRACES_SAMPLER_ARG: "0.5", + }, + ) + def test_trace_init_custom_sampler_with_env_unused_arg( + self, mock_iter_entry_points + ): + mock_iter_entry_points.configure_mock( + return_value=[ + IterEntryPoint( + "custom_sampler_factory", + CustomSamplerFactory.get_custom_sampler, + ) + ] + ) + sampler_name = _get_sampler() + sampler = _import_sampler(sampler_name) + _init_tracing({}, sampler=sampler) + provider = self.set_provider_mock.call_args[0][0] + self.assertIsInstance(provider.sampler, CustomSampler) + + @patch("opentelemetry.sdk._configuration.iter_entry_points") + @patch.dict( + "os.environ", + { + OTEL_TRACES_SAMPLER: "custom_ratio_sampler_factory", + OTEL_TRACES_SAMPLER_ARG: "0.5", + }, + ) + def test_trace_init_custom_ratio_sampler_with_env( + self, mock_iter_entry_points + ): + mock_iter_entry_points.configure_mock( + return_value=[ + IterEntryPoint( + "custom_ratio_sampler_factory", + CustomSamplerFactory.get_custom_ratio_sampler, + ) + ] + ) + sampler_name = _get_sampler() + sampler = _import_sampler(sampler_name) + _init_tracing({}, sampler=sampler) + provider = self.set_provider_mock.call_args[0][0] + self.assertIsInstance(provider.sampler, CustomRatioSampler) + self.assertEqual(provider.sampler.ratio, 0.5) + + @patch("opentelemetry.sdk._configuration.iter_entry_points") + @patch.dict( + "os.environ", + { + OTEL_TRACES_SAMPLER: "custom_ratio_sampler_factory", + OTEL_TRACES_SAMPLER_ARG: "foobar", + }, + ) + def test_trace_init_custom_ratio_sampler_with_env_bad_arg( + self, mock_iter_entry_points + ): + mock_iter_entry_points.configure_mock( + return_value=[ + IterEntryPoint( + "custom_ratio_sampler_factory", + CustomSamplerFactory.get_custom_ratio_sampler, + ) + ] + ) + sampler_name = _get_sampler() + sampler = _import_sampler(sampler_name) + _init_tracing({}, sampler=sampler) + provider = self.set_provider_mock.call_args[0][0] + self.assertIsNone(provider.sampler) + + @patch("opentelemetry.sdk._configuration.iter_entry_points") + @patch.dict( + "os.environ", + { + OTEL_TRACES_SAMPLER: "custom_ratio_sampler_factory", + }, + ) + def test_trace_init_custom_ratio_sampler_with_env_missing_arg( + self, mock_iter_entry_points + ): + mock_iter_entry_points.configure_mock( + return_value=[ + IterEntryPoint( + "custom_ratio_sampler_factory", + CustomSamplerFactory.get_custom_ratio_sampler, + ) + ] + ) + sampler_name = _get_sampler() + sampler = _import_sampler(sampler_name) + _init_tracing({}, sampler=sampler) + provider = self.set_provider_mock.call_args[0][0] + self.assertIsNone(provider.sampler) + + @patch("opentelemetry.sdk._configuration.iter_entry_points") + @patch.dict( + "os.environ", + { + OTEL_TRACES_SAMPLER: "custom_sampler_factory", + OTEL_TRACES_SAMPLER_ARG: "0.5", + }, + ) + def test_trace_init_custom_ratio_sampler_with_env_multiple_entry_points( + self, mock_iter_entry_points + ): + mock_iter_entry_points.configure_mock( + return_value=[ + IterEntryPoint( + "custom_ratio_sampler_factory", + CustomSamplerFactory.get_custom_ratio_sampler, + ), + IterEntryPoint( + "custom_sampler_factory", + CustomSamplerFactory.get_custom_sampler, + ), + IterEntryPoint( + "custom_z_sampler_factory", + CustomSamplerFactory.empty_get_custom_sampler, + ), + ] + ) + sampler_name = _get_sampler() + sampler = _import_sampler(sampler_name) + _init_tracing({}, sampler=sampler) + provider = self.set_provider_mock.call_args[0][0] + self.assertIsInstance(provider.sampler, CustomSampler) + + def verify_default_sampler(self, tracer_provider): + self.assertIsInstance(tracer_provider.sampler, ParentBased) + # pylint: disable=protected-access + self.assertEqual(tracer_provider.sampler._root, ALWAYS_ON) + class TestLoggingInit(TestCase): def setUp(self): diff --git a/opentelemetry-sdk/tests/trace/test_trace.py b/opentelemetry-sdk/tests/trace/test_trace.py index 3f4d0d0da1c..5bb8d87c658 100644 --- a/opentelemetry-sdk/tests/trace/test_trace.py +++ b/opentelemetry-sdk/tests/trace/test_trace.py @@ -21,7 +21,7 @@ from logging import ERROR, WARNING from random import randint from time import time_ns -from typing import Optional, Sequence +from typing import Optional from unittest import mock from opentelemetry import trace as trace_api @@ -46,10 +46,7 @@ ALWAYS_ON, Decision, ParentBased, - Sampler, - SamplingResult, StaticSampler, - TraceIdRatioBased, ) from opentelemetry.sdk.util import ns_to_iso_str from opentelemetry.sdk.util.instrumentation import InstrumentationInfo @@ -57,9 +54,7 @@ get_span_with_dropped_attributes_events_links, new_tracer, ) -from opentelemetry.trace import Link, SpanKind, Status, StatusCode -from opentelemetry.trace.span import TraceState -from opentelemetry.util.types import Attributes +from opentelemetry.trace import Status, StatusCode class TestTracer(unittest.TestCase): @@ -151,78 +146,6 @@ def test_tracer_provider_accepts_concurrent_multi_span_processor(self): ) -class CustomSampler(Sampler): - def __init__(self) -> None: - pass - - def get_description(self) -> str: - return "CustomSampler" - - def should_sample( - self, - parent_context: Optional["Context"], - trace_id: int, - name: str, - kind: SpanKind = None, - attributes: Attributes = None, - links: Sequence[Link] = None, - trace_state: TraceState = None, - ) -> "SamplingResult": - return SamplingResult( - Decision.RECORD_AND_SAMPLE, - None, - None, - ) - - -class CustomRatioSampler(TraceIdRatioBased): - def __init__(self, ratio): - self.ratio = ratio - super().__init__(ratio) - - def get_description(self) -> str: - return "CustomSampler" - - def should_sample( - self, - parent_context: Optional["Context"], - trace_id: int, - name: str, - kind: SpanKind = None, - attributes: Attributes = None, - links: Sequence[Link] = None, - trace_state: TraceState = None, - ) -> "SamplingResult": - return SamplingResult( - Decision.RECORD_AND_SAMPLE, - None, - None, - ) - - -class CustomSamplerFactory: - @staticmethod - def get_custom_sampler(unused_sampler_arg): - return CustomSampler() - - @staticmethod - def get_custom_ratio_sampler(sampler_arg): - return CustomRatioSampler(float(sampler_arg)) - - @staticmethod - def empty_get_custom_sampler(sampler_arg): - return - - -class IterEntryPoint: - def __init__(self, name, class_type): - self.name = name - self.class_type = class_type - - def load(self): - return self.class_type - - class TestTracerSampling(unittest.TestCase): def tearDown(self): reload(trace) @@ -251,7 +174,8 @@ def test_default_sampler_type(self): tracer_provider = trace.TracerProvider() self.verify_default_sampler(tracer_provider) - def test_sampler_no_sampling(self): + @mock.patch("opentelemetry.sdk.trace.sampling._get_from_env_or_default") + def test_sampler_no_sampling(self, _get_from_env_or_default): tracer_provider = trace.TracerProvider(ALWAYS_OFF) tracer = tracer_provider.get_tracer(__name__) @@ -270,6 +194,7 @@ def test_sampler_no_sampling(self): child_span.get_span_context().trace_flags, trace_api.TraceFlags.DEFAULT, ) + self.assertFalse(_get_from_env_or_default.called) @mock.patch.dict("os.environ", {OTEL_TRACES_SAMPLER: "always_off"}) def test_sampler_with_env(self): @@ -299,161 +224,6 @@ def test_ratio_sampler_with_env(self): self.assertIsInstance(tracer_provider.sampler, ParentBased) self.assertEqual(tracer_provider.sampler._root.rate, 0.25) - @mock.patch.dict( - "os.environ", {OTEL_TRACES_SAMPLER: "non_existent_entry_point"} - ) - def test_sampler_with_env_non_existent_entry_point(self): - # pylint: disable=protected-access - reload(trace) - tracer_provider = trace.TracerProvider() - self.verify_default_sampler(tracer_provider) - - @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") - @mock.patch.dict( - "os.environ", {OTEL_TRACES_SAMPLER: "custom_sampler_factory"} - ) - def test_custom_sampler_with_env(self, mock_iter_entry_points): - mock_iter_entry_points.return_value = [ - IterEntryPoint( - "custom_sampler_factory", - CustomSamplerFactory.get_custom_sampler, - ) - ] - # pylint: disable=protected-access - reload(trace) - tracer_provider = trace.TracerProvider() - self.assertIsInstance(tracer_provider.sampler, CustomSampler) - - @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") - @mock.patch.dict( - "os.environ", {OTEL_TRACES_SAMPLER: "custom_sampler_factory"} - ) - def test_custom_sampler_with_env_bad_factory(self, mock_iter_entry_points): - mock_iter_entry_points.return_value = [ - IterEntryPoint( - "custom_sampler_factory", - CustomSamplerFactory.empty_get_custom_sampler, - ) - ] - # pylint: disable=protected-access - reload(trace) - tracer_provider = trace.TracerProvider() - self.verify_default_sampler(tracer_provider) - - @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") - @mock.patch.dict( - "os.environ", - { - OTEL_TRACES_SAMPLER: "custom_sampler_factory", - OTEL_TRACES_SAMPLER_ARG: "0.5", - }, - ) - def test_custom_sampler_with_env_unused_arg(self, mock_iter_entry_points): - mock_iter_entry_points.return_value = [ - IterEntryPoint( - "custom_sampler_factory", - CustomSamplerFactory.get_custom_sampler, - ) - ] - # pylint: disable=protected-access - reload(trace) - tracer_provider = trace.TracerProvider() - self.assertIsInstance(tracer_provider.sampler, CustomSampler) - - @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") - @mock.patch.dict( - "os.environ", - { - OTEL_TRACES_SAMPLER: "custom_ratio_sampler_factory", - OTEL_TRACES_SAMPLER_ARG: "0.5", - }, - ) - def test_custom_ratio_sampler_with_env(self, mock_iter_entry_points): - mock_iter_entry_points.return_value = [ - IterEntryPoint( - "custom_ratio_sampler_factory", - CustomSamplerFactory.get_custom_ratio_sampler, - ) - ] - # pylint: disable=protected-access - reload(trace) - tracer_provider = trace.TracerProvider() - self.assertIsInstance(tracer_provider.sampler, CustomRatioSampler) - self.assertEqual(tracer_provider.sampler.ratio, 0.5) - - @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") - @mock.patch.dict( - "os.environ", - { - OTEL_TRACES_SAMPLER: "custom_ratio_sampler_factory", - OTEL_TRACES_SAMPLER_ARG: "foobar", - }, - ) - def test_custom_ratio_sampler_with_env_bad_arg( - self, mock_iter_entry_points - ): - mock_iter_entry_points.return_value = [ - IterEntryPoint( - "custom_ratio_sampler_factory", - CustomSamplerFactory.get_custom_ratio_sampler, - ) - ] - # pylint: disable=protected-access - reload(trace) - tracer_provider = trace.TracerProvider() - self.verify_default_sampler(tracer_provider) - - @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") - @mock.patch.dict( - "os.environ", - { - OTEL_TRACES_SAMPLER: "custom_ratio_sampler_factory", - }, - ) - def test_custom_ratio_sampler_with_env_no_arg( - self, mock_iter_entry_points - ): - mock_iter_entry_points.return_value = [ - IterEntryPoint( - "custom_ratio_sampler_factory", - CustomSamplerFactory.get_custom_ratio_sampler, - ) - ] - # pylint: disable=protected-access - reload(trace) - tracer_provider = trace.TracerProvider() - self.verify_default_sampler(tracer_provider) - - @mock.patch("opentelemetry.sdk.trace.util.iter_entry_points") - @mock.patch.dict( - "os.environ", - { - OTEL_TRACES_SAMPLER: "custom_sampler_factory", - OTEL_TRACES_SAMPLER_ARG: "0.5", - }, - ) - def test_custom_ratio_sampler_with_env_multiple_entry_points( - self, mock_iter_entry_points - ): - mock_iter_entry_points.return_value = [ - IterEntryPoint( - "custom_ratio_sampler_factory", - CustomSamplerFactory.get_custom_ratio_sampler, - ), - IterEntryPoint( - "custom_sampler_factory", - CustomSamplerFactory.get_custom_sampler, - ), - IterEntryPoint( - "custom_z_sampler_factory", - CustomSamplerFactory.empty_get_custom_sampler, - ), - ] - # pylint: disable=protected-access - reload(trace) - tracer_provider = trace.TracerProvider() - self.assertIsInstance(tracer_provider.sampler, CustomSampler) - def verify_default_sampler(self, tracer_provider): self.assertIsInstance(tracer_provider.sampler, ParentBased) # pylint: disable=protected-access