From 50856f1e9616ebc3dac194f13b174e5e53591b87 Mon Sep 17 00:00:00 2001 From: Diego Hurtado Date: Thu, 11 Mar 2021 15:58:17 -0600 Subject: [PATCH] Remove setters and getters Fixes #1644 --- .../baggage/propagation/__init__.py | 92 +++----- .../src/opentelemetry/propagate/__init__.py | 56 ++--- .../opentelemetry/propagators/composite.py | 49 ++-- .../src/opentelemetry/propagators/textmap.py | 127 ++-------- .../trace/propagation/tracecontext.py | 132 ++++++----- .../tests/baggage/test_baggage_propagation.py | 150 ++++++------ .../tests/propagators/test_composite.py | 7 +- .../test_tracecontexthttptextformat.py | 72 +++--- .../opentelemetry/propagators/b3/__init__.py | 62 ++--- .../tests/test_b3_format.py | 216 +++++++++--------- .../propagators/jaeger/__init__.py | 142 +++++------- .../src/opentelemetry/test/mock_textmap.py | 60 ++--- 12 files changed, 492 insertions(+), 673 deletions(-) diff --git a/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py index e6d1c4207bc..d66d2f68c88 100644 --- a/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py +++ b/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py @@ -11,29 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -import typing -import urllib.parse + +from typing import Optional, Set +from urllib.parse import quote_plus, unquote from opentelemetry import baggage from opentelemetry.context import get_current from opentelemetry.context.context import Context -from opentelemetry.propagators import textmap +from opentelemetry.propagators.textmap import ( + TextMapPropagator, + TextMapPropagatorT, +) -class W3CBaggagePropagator(textmap.TextMapPropagator): +class W3CBaggagePropagator(TextMapPropagator): """Extracts and injects Baggage which is used to annotate telemetry.""" - _MAX_HEADER_LENGTH = 8192 - _MAX_PAIR_LENGTH = 4096 - _MAX_PAIRS = 180 - _BAGGAGE_HEADER_NAME = "baggage" + _baggage_header_name = "baggage" def extract( - self, - getter: textmap.Getter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: """Extract Baggage from the carrier. @@ -44,38 +41,39 @@ def extract( if context is None: context = get_current() - header = _extract_first_element( - getter.get(carrier, self._BAGGAGE_HEADER_NAME) - ) + value = carrier.get(self._baggage_header_name) + + if value is None: + header = None + else: + header = next(iter(value), None) - if not header or len(header) > self._MAX_HEADER_LENGTH: + # 8192 is the maximum header length + if header is None or len(header) > 8192: return context - baggage_entries = header.split(",") - total_baggage_entries = self._MAX_PAIRS - for entry in baggage_entries: + # 180 is the maximum amount of pairs + total_baggage_entries = 180 + + for entry in header.split(","): if total_baggage_entries <= 0: return context total_baggage_entries -= 1 - if len(entry) > self._MAX_PAIR_LENGTH: + # 4096 is the maximum pair length + if len(entry) > 4096: continue - try: + if "=" in entry: name, value = entry.split("=", 1) - except Exception: # pylint: disable=broad-except - continue - context = baggage.set_baggage( - urllib.parse.unquote(name).strip(), - urllib.parse.unquote(value).strip(), - context=context, - ) + context = baggage.set_baggage( + unquote(name).strip(), + unquote(value).strip(), + context=context, + ) return context def inject( - self, - set_in_carrier: textmap.Setter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: """Injects Baggage into the carrier. @@ -83,28 +81,14 @@ def inject( `opentelemetry.propagators.textmap.TextMapPropagator.inject` """ baggage_entries = baggage.get_all(context=context) - if not baggage_entries: - return - baggage_string = _format_baggage(baggage_entries) - set_in_carrier(carrier, self._BAGGAGE_HEADER_NAME, baggage_string) + if baggage_entries: + carrier[self._baggage_header_name] = ",".join( + key + "=" + quote_plus(str(value)) + for key, value in baggage_entries.items() + ) @property - def fields(self) -> typing.Set[str]: + def fields(self) -> Set[str]: """Returns a set with the fields set in `inject`.""" - return {self._BAGGAGE_HEADER_NAME} - - -def _format_baggage(baggage_entries: typing.Mapping[str, object]) -> str: - return ",".join( - key + "=" + urllib.parse.quote_plus(str(value)) - for key, value in baggage_entries.items() - ) - - -def _extract_first_element( - items: typing.Optional[typing.Iterable[textmap.TextMapPropagatorT]], -) -> typing.Optional[textmap.TextMapPropagatorT]: - if items is None: - return None - return next(iter(items), None) + return {self._baggage_header_name} diff --git a/opentelemetry-api/src/opentelemetry/propagate/__init__.py b/opentelemetry-api/src/opentelemetry/propagate/__init__.py index 44f9897a532..d23a0fcd239 100644 --- a/opentelemetry-api/src/opentelemetry/propagate/__init__.py +++ b/opentelemetry-api/src/opentelemetry/propagate/__init__.py @@ -40,23 +40,12 @@ PROPAGATOR = propagators.get_global_textmap() - def get_header_from_flask_request(request, key): - return request.headers.get_all(key) - - def set_header_into_requests_request(request: requests.Request, - key: str, value: str): - request.headers[key] = value - def example_route(): - context = PROPAGATOR.extract( - get_header_from_flask_request, - flask.request - ) + context = PROPAGATOR.extract(flask.request) request_to_downstream = requests.Request( "GET", "http://httpbin.org/get" ) PROPAGATOR.inject( - set_header_into_requests_request, request_to_downstream, context=context ) @@ -68,23 +57,25 @@ def example_route(): https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/context/api-propagators.md """ -import typing from logging import getLogger from os import environ +from typing import Optional from pkg_resources import iter_entry_points from opentelemetry.context.context import Context from opentelemetry.environment_variables import OTEL_PROPAGATORS -from opentelemetry.propagators import composite, textmap +from opentelemetry.propagators import composite +from opentelemetry.propagators.textmap import ( + TextMapPropagator, + TextMapPropagatorT, +) -logger = getLogger(__name__) +_logger = getLogger(__name__) def extract( - getter: textmap.Getter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: """Uses the configured propagator to extract a Context from the carrier. @@ -99,26 +90,21 @@ def extract( context: an optional Context to use. Defaults to current context if not set. """ - return get_global_textmap().extract(getter, carrier, context) + return get_global_textmap().extract(carrier, context) def inject( - set_in_carrier: textmap.Setter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: """Uses the configured propagator to inject a Context into the carrier. Args: - set_in_carrier: A setter function that can set values - on the carrier. - carrier: An object that contains a representation of HTTP - headers. Should be paired with set_in_carrier, which - should know how to set header values on the carrier. + carrier: A dict-like object that contains a representation of HTTP + headers. context: an optional Context to use. Defaults to current context if not set. """ - get_global_textmap().inject(set_in_carrier, carrier, context) + get_global_textmap().inject(carrier, context) try: @@ -138,16 +124,16 @@ def inject( ) except Exception: # pylint: disable=broad-except - logger.exception("Failed to load configured propagators") + _logger.error("Failed to load configured propagators") raise -_HTTP_TEXT_FORMAT = composite.CompositeHTTPPropagator(propagators) # type: ignore +_textmap_propagator = composite.CompositeHTTPPropagator(propagators) # type: ignore -def get_global_textmap() -> textmap.TextMapPropagator: - return _HTTP_TEXT_FORMAT +def get_global_textmap() -> TextMapPropagator: + return _textmap_propagator -def set_global_textmap(http_text_format: textmap.TextMapPropagator,) -> None: - global _HTTP_TEXT_FORMAT # pylint:disable=global-statement - _HTTP_TEXT_FORMAT = http_text_format # type: ignore +def set_global_textmap(http_text_format: TextMapPropagator,) -> None: + global _textmap_propagator # pylint:disable=global-statement + _textmap_propagator = http_text_format # type: ignore diff --git a/opentelemetry-api/src/opentelemetry/propagators/composite.py b/opentelemetry-api/src/opentelemetry/propagators/composite.py index 92dc6b8a380..811934b5164 100644 --- a/opentelemetry-api/src/opentelemetry/propagators/composite.py +++ b/opentelemetry-api/src/opentelemetry/propagators/composite.py @@ -11,16 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import typing + +from logging import getLogger +from typing import Optional, Sequence, Set from opentelemetry.context.context import Context -from opentelemetry.propagators import textmap +from opentelemetry.propagators.textmap import ( + TextMapPropagator, + TextMapPropagatorT, +) -logger = logging.getLogger(__name__) +_logger = getLogger(__name__) -class CompositeHTTPPropagator(textmap.TextMapPropagator): +class CompositeHTTPPropagator(TextMapPropagator): """CompositeHTTPPropagator provides a mechanism for combining multiple propagators into a single one. @@ -28,46 +32,39 @@ class CompositeHTTPPropagator(textmap.TextMapPropagator): propagators: the list of propagators to use """ - def __init__( - self, propagators: typing.Sequence[textmap.TextMapPropagator] - ) -> None: + def __init__(self, propagators: Sequence[TextMapPropagator]) -> None: self._propagators = propagators def extract( - self, - getter: textmap.Getter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: - """Run each of the configured propagators with the given context and carrier. + """Run each of the configured propagators with the given context and + carrier. Propagators are run in the order they are configured, if multiple - propagators write the same context key, the propagator later in the list - will override previous propagators. + propagators write the same context key, the last propagator that writes + the context key will override previous propagators. See `opentelemetry.propagators.textmap.TextMapPropagator.extract` """ for propagator in self._propagators: - context = propagator.extract(getter, carrier, context) + context = propagator.extract(carrier, context) return context # type: ignore def inject( - self, - set_in_carrier: textmap.Setter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: - """Run each of the configured propagators with the given context and carrier. - Propagators are run in the order they are configured, if multiple - propagators write the same carrier key, the propagator later in the list - will override previous propagators. + """Run each of the configured propagators with the given context and + carrier. Propagators are run in the order they are configured, if + multiple propagators write the same carrier key, the last propagator + that writes the carrier key will override previous propagators. See `opentelemetry.propagators.textmap.TextMapPropagator.inject` """ for propagator in self._propagators: - propagator.inject(set_in_carrier, carrier, context) + propagator.inject(carrier, context) @property - def fields(self) -> typing.Set[str]: + def fields(self) -> Set[str]: """Returns a set with the fields set in `inject`. See diff --git a/opentelemetry-api/src/opentelemetry/propagators/textmap.py b/opentelemetry-api/src/opentelemetry/propagators/textmap.py index cf93d1d6319..af6d3a49595 100644 --- a/opentelemetry-api/src/opentelemetry/propagators/textmap.py +++ b/opentelemetry-api/src/opentelemetry/propagators/textmap.py @@ -12,139 +12,58 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc -import typing +from abc import ABC, abstractmethod +from typing import Optional, Set, TypeVar from opentelemetry.context.context import Context -TextMapPropagatorT = typing.TypeVar("TextMapPropagatorT") -CarrierValT = typing.Union[typing.List[str], str] +TextMapPropagatorT = TypeVar("TextMapPropagatorT") -Setter = typing.Callable[[TextMapPropagatorT, str, str], None] - -class Getter(typing.Generic[TextMapPropagatorT]): - """This class implements a Getter that enables extracting propagated - fields from a carrier. - """ - - def get( - self, carrier: TextMapPropagatorT, key: str - ) -> typing.Optional[typing.List[str]]: - """Function that can retrieve zero - or more values from the carrier. In the case that - the value does not exist, returns None. - - Args: - carrier: An object which contains values that are used to - construct a Context. - key: key of a field in carrier. - Returns: first value of the propagation key or None if the key doesn't - exist. - """ - raise NotImplementedError() - - def keys(self, carrier: TextMapPropagatorT) -> typing.List[str]: - """Function that can retrieve all the keys in a carrier object. - - Args: - carrier: An object which contains values that are - used to construct a Context. - Returns: - list of keys from the carrier. - """ - raise NotImplementedError() - - -class DictGetter(Getter[typing.Dict[str, CarrierValT]]): - def get( - self, carrier: typing.Dict[str, CarrierValT], key: str - ) -> typing.Optional[typing.List[str]]: - """Getter implementation to retrieve a value from a dictionary. - - Args: - carrier: dictionary in which header - key: the key used to get the value - Returns: - A list with a single string with the value if it exists, else None. - """ - val = carrier.get(key, None) - if val is None: - return None - if isinstance(val, typing.Iterable) and not isinstance(val, str): - return list(val) - return [val] - - def keys(self, carrier: typing.Dict[str, CarrierValT]) -> typing.List[str]: - """Keys implementation that returns all keys from a dictionary.""" - return list(carrier.keys()) - - -class TextMapPropagator(abc.ABC): +class TextMapPropagator(ABC): """This class provides an interface that enables extracting and injecting - context into headers of HTTP requests. HTTP frameworks and clients - can integrate with TextMapPropagator by providing the object containing the - headers, and a getter and setter function for the extraction and - injection of values, respectively. - + context into headers of HTTP requests. HTTP frameworks and clients can + integrate with TextMapPropagator by providing the object containing the + headers. """ - @abc.abstractmethod + @abstractmethod def extract( - self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: """Create a Context from values in the carrier. - The extract function should retrieve values from the carrier - object using getter, and use values to populate a - Context value and return it. + Retrieves values from the carrier object and uses them to populate a + context and returns it afterwards. Args: - getter: a function that can retrieve zero - or more values from the carrier. In the case that - the value does not exist, return an empty list. - carrier: and object which contains values that are - used to construct a Context. This object - must be paired with an appropriate getter - which understands how to extract a value from it. - context: an optional Context to use. Defaults to current - context if not set. + carrier: and object which contains values that are used to + construct a Context. + context: an optional Context to use. Defaults to current context if + not set. Returns: - A Context with configuration found in the carrier. - + A Context with the configuration found in the carrier. """ - @abc.abstractmethod + @abstractmethod def inject( - self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: """Inject values from a Context into a carrier. - inject enables the propagation of values into HTTP clients or - other objects which perform an HTTP request. Implementations - should use the set_in_carrier method to set values on the - carrier. + Enables the propagation of values into HTTP clients or other objects + which perform an HTTP request. Args: - set_in_carrier: A setter function that can set values - on the carrier. - carrier: An object that a place to define HTTP headers. - Should be paired with set_in_carrier, which should - know how to set header values on the carrier. + carrier: An dict-like object where to store HTTP headers. context: an optional Context to use. Defaults to current context if not set. """ @property - @abc.abstractmethod - def fields(self) -> typing.Set[str]: + @abstractmethod + def fields(self) -> Set[str]: """ Gets the fields set in the carrier by the `inject` method. diff --git a/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py index 480e716bf78..3e12a1797fd 100644 --- a/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py @@ -11,111 +11,119 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -import re -import typing -import opentelemetry.trace as trace +from re import compile as compile_ +from re import search +from typing import Optional, Set + from opentelemetry.context.context import Context -from opentelemetry.propagators import textmap -from opentelemetry.trace import format_span_id, format_trace_id +from opentelemetry.propagators.textmap import ( + TextMapPropagator, + TextMapPropagatorT, +) +from opentelemetry.trace import ( + INVALID_SPAN, + INVALID_SPAN_CONTEXT, + NonRecordingSpan, + SpanContext, + TraceFlags, + format_span_id, + format_trace_id, + get_current_span, + set_span_in_context, +) from opentelemetry.trace.span import TraceState -class TraceContextTextMapPropagator(textmap.TextMapPropagator): +class TraceContextTextMapPropagator(TextMapPropagator): """Extracts and injects using w3c TraceContext's headers.""" - _TRACEPARENT_HEADER_NAME = "traceparent" - _TRACESTATE_HEADER_NAME = "tracestate" - _TRACEPARENT_HEADER_FORMAT = ( - "^[ \t]*([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})" - + "(-.*)?[ \t]*$" + _traceparent_header_name = "traceparent" + _tracestate_header_name = "tracestate" + _traceparent_header_format_re = compile_( + r"^\s*(?P[0-9a-f]{2})-" + r"(?P[0-9a-f]{32})-" + r"(?P[0-9a-f]{16})-" + r"(?P[0-9a-f]{2})" + r"(?P.*)?\s*$" ) - _TRACEPARENT_HEADER_FORMAT_RE = re.compile(_TRACEPARENT_HEADER_FORMAT) def extract( - self, - getter: textmap.Getter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: """Extracts SpanContext from the carrier. See `opentelemetry.propagators.textmap.TextMapPropagator.extract` """ - header = getter.get(carrier, self._TRACEPARENT_HEADER_NAME) + header = carrier.get(self._traceparent_header_name) + + if header is None: + return set_span_in_context(INVALID_SPAN, context) - if not header: - return trace.set_span_in_context(trace.INVALID_SPAN, context) + match = search(self._traceparent_header_format_re, header[0]) + if match is None: + return set_span_in_context(INVALID_SPAN, context) - match = re.search(self._TRACEPARENT_HEADER_FORMAT_RE, header[0]) - if not match: - return trace.set_span_in_context(trace.INVALID_SPAN, context) + version = match.group("version") + trace_id = match.group("trace_id") + span_id = match.group("span_id") - version = match.group(1) - trace_id = match.group(2) - span_id = match.group(3) - trace_flags = match.group(4) + if ( + (version == "ff") + or (version == "00" and match.group("remainder") is not None) + or (trace_id == "0" * 32 or span_id == "0" * 16) + ): - if trace_id == "0" * 32 or span_id == "0" * 16: - return trace.set_span_in_context(trace.INVALID_SPAN, context) + return set_span_in_context(INVALID_SPAN, context) - if version == "00": - if match.group(5): - return trace.set_span_in_context(trace.INVALID_SPAN, context) - if version == "ff": - return trace.set_span_in_context(trace.INVALID_SPAN, context) + tracestate_headers = carrier.get(self._tracestate_header_name) - tracestate_headers = getter.get(carrier, self._TRACESTATE_HEADER_NAME) if tracestate_headers is None: tracestate = None else: tracestate = TraceState.from_header(tracestate_headers) - span_context = trace.SpanContext( - trace_id=int(trace_id, 16), - span_id=int(span_id, 16), - is_remote=True, - trace_flags=trace.TraceFlags(trace_flags), - trace_state=tracestate, - ) - return trace.set_span_in_context( - trace.NonRecordingSpan(span_context), context + return set_span_in_context( + NonRecordingSpan( + SpanContext( + trace_id=int(trace_id, 16), + span_id=int(span_id, 16), + is_remote=True, + trace_flags=TraceFlags(match.group("trace_flags")), + trace_state=tracestate, + ) + ), + context, ) def inject( - self, - set_in_carrier: textmap.Setter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: """Injects SpanContext into the carrier. See `opentelemetry.propagators.textmap.TextMapPropagator.inject` """ - span = trace.get_current_span(context) - span_context = span.get_span_context() - if span_context == trace.INVALID_SPAN_CONTEXT: + span_context = get_current_span(context).get_span_context() + if span_context == INVALID_SPAN_CONTEXT: return - traceparent_string = "00-{trace_id}-{span_id}-{:02x}".format( + carrier[ + self._traceparent_header_name + ] = "00-{trace_id}-{span_id}-{:02x}".format( span_context.trace_flags, trace_id=format_trace_id(span_context.trace_id), span_id=format_span_id(span_context.span_id), ) - set_in_carrier( - carrier, self._TRACEPARENT_HEADER_NAME, traceparent_string - ) - if span_context.trace_state: - tracestate_string = span_context.trace_state.to_header() - set_in_carrier( - carrier, self._TRACESTATE_HEADER_NAME, tracestate_string - ) + + if span_context.trace_state is not None: + carrier[ + self._tracestate_header_name + ] = span_context.trace_state.to_header() @property - def fields(self) -> typing.Set[str]: + def fields(self) -> Set[str]: """Returns a set with the fields set in `inject`. See `opentelemetry.propagators.textmap.TextMapPropagator.fields` """ - return {self._TRACEPARENT_HEADER_NAME, self._TRACESTATE_HEADER_NAME} + return {self._traceparent_header_name, self._tracestate_header_name} diff --git a/opentelemetry-api/tests/baggage/test_baggage_propagation.py b/opentelemetry-api/tests/baggage/test_baggage_propagation.py index a928a2fc8cb..846298b6479 100644 --- a/opentelemetry-api/tests/baggage/test_baggage_propagation.py +++ b/opentelemetry-api/tests/baggage/test_baggage_propagation.py @@ -11,26 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -import unittest -from unittest.mock import Mock, patch + +from unittest import TestCase +from unittest.mock import patch from opentelemetry import baggage from opentelemetry.baggage.propagation import W3CBaggagePropagator from opentelemetry.context import get_current -from opentelemetry.propagators.textmap import DictGetter - -carrier_getter = DictGetter() -class TestBaggagePropagation(unittest.TestCase): +class TestBaggagePropagation(TestCase): def setUp(self): self.propagator = W3CBaggagePropagator() def _extract(self, header_value): """Test helper""" - header = {"baggage": [header_value]} - return baggage.get_all(self.propagator.extract(carrier_getter, header)) + return baggage.get_all( + self.propagator.extract({"baggage": [header_value]}) + ) def _inject(self, values): """Test helper""" @@ -38,122 +36,114 @@ def _inject(self, values): for k, v in values.items(): ctx = baggage.set_baggage(k, v, context=ctx) output = {} - self.propagator.inject(dict.__setitem__, output, context=ctx) + self.propagator.inject(output, context=ctx) return output.get("baggage") def test_no_context_header(self): - baggage_entries = baggage.get_all( - self.propagator.extract(carrier_getter, {}) - ) - self.assertEqual(baggage_entries, {}) + self.assertEqual(baggage.get_all(self.propagator.extract({})), {}) def test_empty_context_header(self): - header = "" - self.assertEqual(self._extract(header), {}) + self.assertEqual(self._extract(""), {}) def test_valid_header(self): - header = "key1=val1,key2=val2" - expected = {"key1": "val1", "key2": "val2"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract("key1=val1,key2=val2"), + {"key1": "val1", "key2": "val2"}, + ) def test_valid_header_with_space(self): - header = "key1 = val1, key2 =val2 " - expected = {"key1": "val1", "key2": "val2"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract("key1 = val1, key2 =val2 "), + {"key1": "val1", "key2": "val2"}, + ) def test_valid_header_with_properties(self): - header = "key1=val1,key2=val2;prop=1" - expected = {"key1": "val1", "key2": "val2;prop=1"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract("key1=val1,key2=val2;prop=1"), + {"key1": "val1", "key2": "val2;prop=1"}, + ) def test_valid_header_with_url_escaped_comma(self): - header = "key%2C1=val1,key2=val2%2Cval3" - expected = {"key,1": "val1", "key2": "val2,val3"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract("key%2C1=val1,key2=val2%2Cval3"), + {"key,1": "val1", "key2": "val2,val3"}, + ) def test_valid_header_with_invalid_value(self): - header = "key1=val1,key2=val2,a,val3" - expected = {"key1": "val1", "key2": "val2"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract("key1=val1,key2=val2,a,val3"), + {"key1": "val1", "key2": "val2"}, + ) def test_valid_header_with_empty_value(self): - header = "key1=,key2=val2" - expected = {"key1": "", "key2": "val2"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract("key1=,key2=val2"), {"key1": "", "key2": "val2"} + ) def test_invalid_header(self): - header = "header1" - expected = {} - self.assertEqual(self._extract(header), expected) + self.assertEqual(self._extract("header1"), {}) def test_header_too_long(self): - long_value = "s" * (W3CBaggagePropagator._MAX_HEADER_LENGTH + 1) - header = "key1={}".format(long_value) - expected = {} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract( + "key1={}".format( + "s" * (W3CBaggagePropagator._MAX_HEADER_LENGTH + 1) + ) + ), + {}, + ) def test_header_contains_too_many_entries(self): - header = ",".join( - [ - "key{}=val".format(k) - for k in range(W3CBaggagePropagator._MAX_PAIRS + 1) - ] - ) self.assertEqual( - len(self._extract(header)), W3CBaggagePropagator._MAX_PAIRS + len( + self._extract( + ",".join( + "key{}=val".format(k) + for k in range(W3CBaggagePropagator._MAX_PAIRS + 1) + ) + ) + ), + W3CBaggagePropagator._MAX_PAIRS, ) def test_header_contains_pair_too_long(self): - long_value = "s" * (W3CBaggagePropagator._MAX_PAIR_LENGTH + 1) - header = "key1=value1,key2={},key3=value3".format(long_value) - expected = {"key1": "value1", "key3": "value3"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract( + "key1=value1,key2={},key3=value3".format( + "s" * (W3CBaggagePropagator._MAX_PAIR_LENGTH + 1) + ) + ), + {"key1": "value1", "key3": "value3"}, + ) def test_inject_no_baggage_entries(self): - values = {} - output = self._inject(values) - self.assertEqual(None, output) + self.assertEqual(None, self._inject({})) def test_inject(self): - values = { - "key1": "val1", - "key2": "val2", - } - output = self._inject(values) + output = self._inject({"key1": "val1", "key2": "val2"}) self.assertIn("key1=val1", output) self.assertIn("key2=val2", output) def test_inject_escaped_values(self): - values = { - "key1": "val1,val2", - "key2": "val3=4", - } - output = self._inject(values) + output = self._inject({"key1": "val1,val2", "key2": "val3=4"}) self.assertIn("key1=val1%2Cval2", output) self.assertIn("key2=val3%3D4", output) def test_inject_non_string_values(self): - values = { - "key1": True, - "key2": 123, - "key3": 123.567, - } - output = self._inject(values) + output = self._inject({"key1": True, "key2": 123, "key3": 123.567}) self.assertIn("key1=True", output) self.assertIn("key2=123", output) self.assertIn("key3=123.567", output) @patch("opentelemetry.baggage.propagation.baggage") - @patch("opentelemetry.baggage.propagation._format_baggage") - def test_fields(self, mock_format_baggage, mock_baggage): - - mock_set_in_carrier = Mock() + def test_fields(self, mock_baggage): - self.propagator.inject(mock_set_in_carrier, {}) + mock_baggage.configure_mock( + **{"get_all.return_value": {"a": "b", "c": "d"}} + ) - inject_fields = set() + carrier = {} - for mock_call in mock_set_in_carrier.mock_calls: - inject_fields.add(mock_call[1][1]) + self.propagator.inject(carrier) - self.assertEqual(inject_fields, self.propagator.fields) + self.assertEqual(carrier.keys(), self.propagator.fields) diff --git a/opentelemetry-api/tests/propagators/test_composite.py b/opentelemetry-api/tests/propagators/test_composite.py index 232e177d3d0..9975926a84b 100644 --- a/opentelemetry-api/tests/propagators/test_composite.py +++ b/opentelemetry-api/tests/propagators/test_composite.py @@ -126,13 +126,8 @@ def test_fields(self): ] ) - mock_set_in_carrier = Mock() - - propagator.inject(mock_set_in_carrier, {}) + propagator.inject({}) inject_fields = set() - for mock_call in mock_set_in_carrier.mock_calls: - inject_fields.add(mock_call[1][1]) - self.assertEqual(inject_fields, propagator.fields) diff --git a/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py index cff30b7c9b8..66923326a5b 100644 --- a/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py +++ b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py @@ -13,11 +13,18 @@ # limitations under the License. import typing -import unittest +from unittest import TestCase from unittest.mock import Mock, patch -from opentelemetry import trace from opentelemetry.propagators.textmap import DictGetter +from opentelemetry.trace import ( + INVALID_SPAN, + INVALID_SPAN_CONTEXT, + NonRecordingSpan, + SpanContext, + get_current_span, + set_span_in_context, +) from opentelemetry.trace.propagation import tracecontext from opentelemetry.trace.span import TraceState @@ -26,9 +33,9 @@ carrier_getter = DictGetter() -class TestTraceContextFormat(unittest.TestCase): - TRACE_ID = int("12345678901234567890123456789012", 16) # type:int - SPAN_ID = int("1234567890123456", 16) # type:int +class TestTraceContextFormat(TestCase): + trace_id = int("12345678901234567890123456789012", 16) # type:int + span_id = int("1234567890123456", 16) # type:int def test_no_traceparent_header(self): """When tracecontext headers are not present, a new SpanContext @@ -40,19 +47,19 @@ def test_no_traceparent_header(self): trace-id and parent-id that represents the current request. """ output = {} # type:typing.Dict[str, typing.List[str]] - span = trace.get_current_span(FORMAT.extract(carrier_getter, output)) - self.assertIsInstance(span.get_span_context(), trace.SpanContext) + span = get_current_span(FORMAT.extract(carrier_getter, output)) + self.assertIsInstance(span.get_span_context(), SpanContext) def test_headers_with_tracestate(self): """When there is a traceparent and tracestate header, data from both should be addded to the SpanContext. """ traceparent_value = "00-{trace_id}-{span_id}-00".format( - trace_id=format(self.TRACE_ID, "032x"), - span_id=format(self.SPAN_ID, "016x"), + trace_id=format(self.trace_id, "032x"), + span_id=format(self.span_id, "016x"), ) tracestate_value = "foo=1,bar=2,baz=3" - span_context = trace.get_current_span( + span_context = get_current_span( FORMAT.extract( carrier_getter, { @@ -61,16 +68,16 @@ def test_headers_with_tracestate(self): }, ) ).get_span_context() - self.assertEqual(span_context.trace_id, self.TRACE_ID) - self.assertEqual(span_context.span_id, self.SPAN_ID) + self.assertEqual(span_context.trace_id, self.trace_id) + self.assertEqual(span_context.span_id, self.span_id) self.assertEqual( span_context.trace_state, {"foo": "1", "bar": "2", "baz": "3"} ) self.assertTrue(span_context.is_remote) output = {} # type:typing.Dict[str, str] - span = trace.NonRecordingSpan(span_context) + span = NonRecordingSpan(span_context) - ctx = trace.set_span_in_context(span) + ctx = set_span_in_context(span) FORMAT.inject(dict.__setitem__, output, ctx) self.assertEqual(output["traceparent"], traceparent_value) for pair in ["foo=1", "bar=2", "baz=3"]: @@ -96,7 +103,7 @@ def test_invalid_trace_id(self): Note that the opposite is not true: failure to parse tracestate MUST NOT affect the parsing of traceparent. """ - span = trace.get_current_span( + span = get_current_span( FORMAT.extract( carrier_getter, { @@ -107,7 +114,7 @@ def test_invalid_trace_id(self): }, ) ) - self.assertEqual(span.get_span_context(), trace.INVALID_SPAN_CONTEXT) + self.assertEqual(span.get_span_context(), INVALID_SPAN_CONTEXT) def test_invalid_parent_id(self): """If the parent id is invalid, we must ignore the full traceparent @@ -127,7 +134,7 @@ def test_invalid_parent_id(self): Note that the opposite is not true: failure to parse tracestate MUST NOT affect the parsing of traceparent. """ - span = trace.get_current_span( + span = get_current_span( FORMAT.extract( carrier_getter, { @@ -138,7 +145,7 @@ def test_invalid_parent_id(self): }, ) ) - self.assertEqual(span.get_span_context(), trace.INVALID_SPAN_CONTEXT) + self.assertEqual(span.get_span_context(), INVALID_SPAN_CONTEXT) def test_no_send_empty_tracestate(self): """If the tracestate is empty, do not set the header. @@ -149,10 +156,10 @@ def test_no_send_empty_tracestate(self): empty tracestate headers but SHOULD avoid sending them. """ output = {} # type:typing.Dict[str, str] - span = trace.NonRecordingSpan( - trace.SpanContext(self.TRACE_ID, self.SPAN_ID, is_remote=False) + span = NonRecordingSpan( + SpanContext(self.trace_id, self.span_id, is_remote=False) ) - ctx = trace.set_span_in_context(span) + ctx = set_span_in_context(span) FORMAT.inject(dict.__setitem__, output, ctx) self.assertTrue("traceparent" in output) self.assertFalse("tracestate" in output) @@ -165,7 +172,7 @@ def test_format_not_supported(self): If the version cannot be parsed, return an invalid trace header. """ - span = trace.get_current_span( + span = get_current_span( FORMAT.extract( carrier_getter, { @@ -177,18 +184,18 @@ def test_format_not_supported(self): }, ) ) - self.assertEqual(span.get_span_context(), trace.INVALID_SPAN_CONTEXT) + self.assertEqual(span.get_span_context(), INVALID_SPAN_CONTEXT) def test_propagate_invalid_context(self): """Do not propagate invalid trace context.""" output = {} # type:typing.Dict[str, str] - ctx = trace.set_span_in_context(trace.INVALID_SPAN) + ctx = set_span_in_context(INVALID_SPAN) FORMAT.inject(dict.__setitem__, output, context=ctx) self.assertFalse("traceparent" in output) def test_tracestate_empty_header(self): """Test tracestate with an additional empty header (should be ignored)""" - span = trace.get_current_span( + span = get_current_span( FORMAT.extract( carrier_getter, { @@ -203,7 +210,7 @@ def test_tracestate_empty_header(self): def test_tracestate_header_with_trailing_comma(self): """Do not propagate invalid trace context.""" - span = trace.get_current_span( + span = get_current_span( FORMAT.extract( carrier_getter, { @@ -226,7 +233,7 @@ def test_tracestate_keys(self): "foo-_*/bar=bar4", ] ) - span = trace.get_current_span( + span = get_current_span( FORMAT.extract( carrier_getter, { @@ -268,13 +275,8 @@ def test_fields(self, mock_get_current_span, mock_invalid_span_context): ) ) - mock_set_in_carrier = Mock() + carrier = {} - FORMAT.inject(mock_set_in_carrier, {}) + FORMAT.inject(carrier) - inject_fields = set() - - for mock_call in mock_set_in_carrier.mock_calls: - inject_fields.add(mock_call[1][1]) - - self.assertEqual(inject_fields, FORMAT.fields) + self.assertEqual(carrier.keys(), {"traceparent", "tracestate"}) diff --git a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py index 01abcc7c879..7906b049243 100644 --- a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py +++ b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import typing -from re import compile as re_compile +from re import compile as compile_ +from typing import Optional, Set import opentelemetry.trace as trace from opentelemetry.context import Context from opentelemetry.propagators.textmap import ( - Getter, - Setter, TextMapPropagator, TextMapPropagatorT, ) @@ -39,23 +37,19 @@ class B3Format(TextMapPropagator): SAMPLED_KEY = "x-b3-sampled" FLAGS_KEY = "x-b3-flags" _SAMPLE_PROPAGATE_VALUES = set(["1", "True", "true", "d"]) - _trace_id_regex = re_compile(r"[\da-fA-F]{16}|[\da-fA-F]{32}") - _span_id_regex = re_compile(r"[\da-fA-F]{16}") + _trace_id_regex = compile_(r"[\da-fA-F]{16}|[\da-fA-F]{32}") + _span_id_regex = compile_(r"[\da-fA-F]{16}") def extract( - self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: trace_id = format_trace_id(trace.INVALID_TRACE_ID) span_id = format_span_id(trace.INVALID_SPAN_ID) sampled = "0" flags = None - single_header = _extract_first_element( - getter.get(carrier, self.SINGLE_HEADER_KEY) - ) + single_header = next(iter(carrier.get(self.TRACE_ID_KEY, [])), False) + if single_header: # The b3 spec calls for the sampling state to be # "deferred", which is unspecified. This concept does not @@ -75,21 +69,16 @@ def extract( return trace.set_span_in_context(trace.INVALID_SPAN) else: trace_id = ( - _extract_first_element(getter.get(carrier, self.TRACE_ID_KEY)) + next(iter(carrier.get(self.TRACE_ID_KEY, [])), False) or trace_id ) span_id = ( - _extract_first_element(getter.get(carrier, self.SPAN_ID_KEY)) - or span_id + next(iter(carrier.get(self.SPAN_ID_KEY, [])), False) or span_id ) sampled = ( - _extract_first_element(getter.get(carrier, self.SAMPLED_KEY)) - or sampled - ) - flags = ( - _extract_first_element(getter.get(carrier, self.FLAGS_KEY)) - or flags + next(iter(carrier.get(self.SAMPLED_KEY, [])), False) or sampled ) + flags = next(iter(carrier.get(self.FLAGS_KEY, [])), False) or flags if ( self._trace_id_regex.fullmatch(trace_id) is None @@ -126,10 +115,7 @@ def extract( ) def inject( - self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: span = trace.get_current_span(context=context) @@ -138,34 +124,20 @@ def inject( return sampled = (trace.TraceFlags.SAMPLED & span_context.trace_flags) != 0 - set_in_carrier( - carrier, self.TRACE_ID_KEY, format_trace_id(span_context.trace_id), - ) - set_in_carrier( - carrier, self.SPAN_ID_KEY, format_span_id(span_context.span_id) - ) + carrier[self.TRACE_ID_KEY] = format_trace_id(span_context.trace_id) + carrier[self.SPAN_ID_KEY] = format_span_id(span_context.span_id) span_parent = getattr(span, "parent", None) if span_parent is not None: - set_in_carrier( - carrier, - self.PARENT_SPAN_ID_KEY, + carrier[self.PARENT_SPAN_ID_KEY] = ( format_span_id(span_parent.span_id), ) - set_in_carrier(carrier, self.SAMPLED_KEY, "1" if sampled else "0") + carrier[self.SAMPLED_KEY] = "1" if sampled else "0" @property - def fields(self) -> typing.Set[str]: + def fields(self) -> Set[str]: return { self.TRACE_ID_KEY, self.SPAN_ID_KEY, self.PARENT_SPAN_ID_KEY, self.SAMPLED_KEY, } - - -def _extract_first_element( - items: typing.Iterable[TextMapPropagatorT], -) -> typing.Optional[TextMapPropagatorT]: - if items is None: - return None - return next(iter(items), None) diff --git a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py index f9d3bce1adb..55726110a03 100644 --- a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py +++ b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py @@ -12,17 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +from unittest import TestCase from unittest.mock import Mock, patch -import opentelemetry.propagators.b3 as b3_format # pylint: disable=no-name-in-module,import-error -import opentelemetry.sdk.trace as trace -import opentelemetry.sdk.trace.id_generator as id_generator -import opentelemetry.trace as trace_api +from opentelemetry import trace from opentelemetry.context import get_current +from opentelemetry.propagators.b3 import ( + B3Format, + format_span_id, + format_trace_id, +) from opentelemetry.propagators.textmap import DictGetter +from opentelemetry.sdk.trace import TracerProvider, _Span, id_generator +from opentelemetry.trace import ( + INVALID_SPAN_ID, + INVALID_TRACE_ID, + SpanContext, + get_current_span, + set_span_in_context, +) -FORMAT = b3_format.B3Format() +format_ = B3Format() carrier_getter = DictGetter() @@ -30,13 +40,13 @@ def get_child_parent_new_carrier(old_carrier): - ctx = FORMAT.extract(carrier_getter, old_carrier) - parent_span_context = trace_api.get_current_span(ctx).get_span_context() + ctx = format_.extract(carrier_getter, old_carrier) + parent_span_context = get_current_span(ctx).get_span_context() - parent = trace._Span("parent", parent_span_context) - child = trace._Span( + parent = _Span("parent", parent_span_context) + child = _Span( "child", - trace_api.SpanContext( + SpanContext( parent_span_context.trace_id, id_generator.RandomIdGenerator().generate_span_id(), is_remote=False, @@ -47,30 +57,26 @@ def get_child_parent_new_carrier(old_carrier): ) new_carrier = {} - ctx = trace_api.set_span_in_context(child) - FORMAT.inject(dict.__setitem__, new_carrier, context=ctx) + ctx = set_span_in_context(child) + format_.inject(dict.__setitem__, new_carrier, context=ctx) return child, parent, new_carrier -class TestB3Format(unittest.TestCase): +class TestB3Format(TestCase): @classmethod def setUpClass(cls): generator = id_generator.RandomIdGenerator() - cls.serialized_trace_id = b3_format.format_trace_id( + cls.serialized_trace_id = format_trace_id( generator.generate_trace_id() ) - cls.serialized_span_id = b3_format.format_span_id( - generator.generate_span_id() - ) - cls.serialized_parent_id = b3_format.format_span_id( - generator.generate_span_id() - ) + cls.serialized_span_id = format_span_id(generator.generate_span_id()) + cls.serialized_parent_id = format_span_id(generator.generate_span_id()) def setUp(self) -> None: - tracer_provider = trace.TracerProvider() - patcher = unittest.mock.patch.object( - trace_api, "get_tracer_provider", return_value=tracer_provider + tracer_provider = TracerProvider() + patcher = patch.object( + trace, "get_tracer_provider", return_value=tracer_provider ) patcher.start() self.addCleanup(patcher.stop) @@ -79,52 +85,52 @@ def test_extract_multi_header(self): """Test the extraction of B3 headers.""" child, parent, new_carrier = get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.PARENT_SPAN_ID_KEY: self.serialized_parent_id, - FORMAT.SAMPLED_KEY: "1", + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.PARENT_SPAN_ID_KEY: self.serialized_parent_id, + format_.SAMPLED_KEY: "1", } ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], - b3_format.format_trace_id(child.context.trace_id), + new_carrier[format_.TRACE_ID_KEY], + format_trace_id(child.context.trace_id), ) self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], - b3_format.format_span_id(child.context.span_id), + new_carrier[format_.SPAN_ID_KEY], + format_span_id(child.context.span_id), ) self.assertEqual( - new_carrier[FORMAT.PARENT_SPAN_ID_KEY], - b3_format.format_span_id(parent.context.span_id), + new_carrier[format_.PARENT_SPAN_ID_KEY], + format_span_id(parent.context.span_id), ) self.assertTrue(parent.context.is_remote) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "1") def test_extract_single_header(self): """Test the extraction from a single b3 header.""" child, parent, new_carrier = get_child_parent_new_carrier( { - FORMAT.SINGLE_HEADER_KEY: "{}-{}".format( + format_.SINGLE_HEADER_KEY: "{}-{}".format( self.serialized_trace_id, self.serialized_span_id ) } ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], - b3_format.format_trace_id(child.context.trace_id), + new_carrier[format_.TRACE_ID_KEY], + format_trace_id(child.context.trace_id), ) self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], - b3_format.format_span_id(child.context.span_id), + new_carrier[format_.SPAN_ID_KEY], + format_span_id(child.context.span_id), ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "1") self.assertTrue(parent.context.is_remote) child, parent, new_carrier = get_child_parent_new_carrier( { - FORMAT.SINGLE_HEADER_KEY: "{}-{}-1-{}".format( + format_.SINGLE_HEADER_KEY: "{}-{}-1-{}".format( self.serialized_trace_id, self.serialized_span_id, self.serialized_parent_id, @@ -133,19 +139,19 @@ def test_extract_single_header(self): ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], - b3_format.format_trace_id(child.context.trace_id), + new_carrier[format_.TRACE_ID_KEY], + format_trace_id(child.context.trace_id), ) self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], - b3_format.format_span_id(child.context.span_id), + new_carrier[format_.SPAN_ID_KEY], + format_span_id(child.context.span_id), ) self.assertEqual( - new_carrier[FORMAT.PARENT_SPAN_ID_KEY], - b3_format.format_span_id(parent.context.span_id), + new_carrier[format_.PARENT_SPAN_ID_KEY], + format_span_id(parent.context.span_id), ) self.assertTrue(parent.context.is_remote) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "1") def test_extract_header_precedence(self): """A single b3 header should take precedence over multiple @@ -155,17 +161,17 @@ def test_extract_header_precedence(self): _, _, new_carrier = get_child_parent_new_carrier( { - FORMAT.SINGLE_HEADER_KEY: "{}-{}".format( + format_.SINGLE_HEADER_KEY: "{}-{}".format( single_header_trace_id, self.serialized_span_id ), - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: "1", + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.SAMPLED_KEY: "1", } ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], single_header_trace_id + new_carrier[format_.TRACE_ID_KEY], single_header_trace_id ) def test_enabled_sampling(self): @@ -173,50 +179,50 @@ def test_enabled_sampling(self): for variant in ["1", "True", "true", "d"]: _, _, new_carrier = get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: variant, + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.SAMPLED_KEY: variant, } ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "1") def test_disabled_sampling(self): """Test b3 sample key variants that turn off sampling.""" for variant in ["0", "False", "false", None]: _, _, new_carrier = get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: variant, + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.SAMPLED_KEY: variant, } ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "0") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "0") def test_flags(self): """x-b3-flags set to "1" should result in propagation.""" _, _, new_carrier = get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.FLAGS_KEY: "1", } ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "1") def test_flags_and_sampling(self): """Propagate if b3 flags and sampling are set.""" _, _, new_carrier = get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.FLAGS_KEY: "1", } ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "1") def test_64bit_trace_id(self): """64 bit trace ids should be padded to 128 bit trace ids.""" @@ -224,36 +230,36 @@ def test_64bit_trace_id(self): _, _, new_carrier = get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: trace_id_64_bit, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + format_.TRACE_ID_KEY: trace_id_64_bit, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.FLAGS_KEY: "1", } ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit + new_carrier[format_.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit ) def test_invalid_single_header(self): """If an invalid single header is passed, return an invalid SpanContext. """ - carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} - ctx = FORMAT.extract(carrier_getter, carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() - self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) - self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) + carrier = {format_.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} + ctx = format_.extract(carrier_getter, carrier) + span_context = get_current_span(ctx).get_span_context() + self.assertEqual(span_context.trace_id, INVALID_TRACE_ID) + self.assertEqual(span_context.span_id, INVALID_SPAN_ID) def test_missing_trace_id(self): """If a trace id is missing, populate an invalid trace id.""" carrier = { - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.FLAGS_KEY: "1", } - ctx = FORMAT.extract(carrier_getter, carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() - self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) + ctx = format_.extract(carrier_getter, carrier) + span_context = get_current_span(ctx).get_span_context() + self.assertEqual(span_context.trace_id, INVALID_TRACE_ID) @patch( "opentelemetry.sdk.trace.id_generator.RandomIdGenerator.generate_trace_id" @@ -270,13 +276,13 @@ def test_invalid_trace_id( mock_generate_span_id.configure_mock(return_value=2) carrier = { - FORMAT.TRACE_ID_KEY: "abc123", - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + format_.TRACE_ID_KEY: "abc123", + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.FLAGS_KEY: "1", } - ctx = FORMAT.extract(carrier_getter, carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() + ctx = format_.extract(carrier_getter, carrier) + span_context = get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, 1) self.assertEqual(span_context.span_id, 2) @@ -296,13 +302,13 @@ def test_invalid_span_id( mock_generate_span_id.configure_mock(return_value=2) carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: "abc123", - FORMAT.FLAGS_KEY: "1", + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: "abc123", + format_.FLAGS_KEY: "1", } - ctx = FORMAT.extract(carrier_getter, carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() + ctx = format_.extract(carrier_getter, carrier) + span_context = get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, 1) self.assertEqual(span_context.span_id, 2) @@ -310,19 +316,19 @@ def test_invalid_span_id( def test_missing_span_id(self): """If a trace id is missing, populate an invalid trace id.""" carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.FLAGS_KEY: "1", + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.FLAGS_KEY: "1", } - ctx = FORMAT.extract(carrier_getter, carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() - self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) + ctx = format_.extract(carrier_getter, carrier) + span_context = get_current_span(ctx).get_span_context() + self.assertEqual(span_context.span_id, INVALID_SPAN_ID) @staticmethod def test_inject_empty_context(): """If the current context has no span, don't add headers""" new_carrier = {} - FORMAT.inject(dict.__setitem__, new_carrier, get_current()) + format_.inject(dict.__setitem__, new_carrier, get_current()) assert len(new_carrier) == 0 @staticmethod @@ -336,23 +342,23 @@ def get(self, carrier, key): def setter(carrier, key, value): carrier[key] = value - ctx = FORMAT.extract(CarrierGetter(), {}) - FORMAT.inject(setter, {}, ctx) + ctx = format_.extract(CarrierGetter(), {}) + format_.inject(setter, {}, ctx) def test_fields(self): """Make sure the fields attribute returns the fields used in inject""" - tracer = trace.TracerProvider().get_tracer("sdk_tracer_provider") + tracer = TracerProvider().get_tracer("sdk_tracer_provider") mock_set_in_carrier = Mock() with tracer.start_as_current_span("parent"): with tracer.start_as_current_span("child"): - FORMAT.inject(mock_set_in_carrier, {}) + format_.inject(mock_set_in_carrier, {}) inject_fields = set() for call in mock_set_in_carrier.mock_calls: inject_fields.add(call[1][1]) - self.assertEqual(FORMAT.fields, inject_fields) + self.assertEqual(format_.fields, inject_fields) diff --git a/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py b/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py index 8e7fe5f69ff..183302c9d87 100644 --- a/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py +++ b/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py @@ -12,19 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import typing -import urllib.parse +from typing import Optional, Set +from urllib.parse import quote, unquote -import opentelemetry.trace as trace from opentelemetry import baggage from opentelemetry.context import Context, get_current from opentelemetry.propagators.textmap import ( - Getter, - Setter, TextMapPropagator, TextMapPropagatorT, ) -from opentelemetry.trace import format_span_id, format_trace_id +from opentelemetry.trace import ( + INVALID_SPAN, + INVALID_SPAN_CONTEXT, + INVALID_SPAN_ID, + INVALID_TRACE_ID, + NonRecordingSpan, + SpanContext, + TraceFlags, + format_span_id, + format_trace_id, + get_current_span, + set_span_in_context, +) class JaegerPropagator(TextMapPropagator): @@ -33,73 +42,70 @@ class JaegerPropagator(TextMapPropagator): See: https://www.jaegertracing.io/docs/1.19/client-libraries/#propagation-format """ - TRACE_ID_KEY = "uber-trace-id" - BAGGAGE_PREFIX = "uberctx-" - DEBUG_FLAG = 0x02 + _trace_id_key = "uber-trace-id" + _baggage_prefix = "uberctx-" + _debug_flag = 0x02 def extract( - self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: if context is None: context = get_current() - header = getter.get(carrier, self.TRACE_ID_KEY) - if not header: - return trace.set_span_in_context(trace.INVALID_SPAN, context) - fields = _extract_first_element(header).split(":") + header = carrier.get(self._trace_id_key) + if header is None: + return set_span_in_context(INVALID_SPAN, context) + fields = next(iter(header)).split(":") + + for key in [ + key + for key in carrier.keys() + if key.startswith(self._baggage_prefix) + ]: + context = baggage.set_baggage( + key.replace(self._baggage_prefix, ""), + unquote(next(iter(carrier[key]))).strip(), + context=context, + ) - context = self._extract_baggage(getter, carrier, context) if len(fields) != 4: - return trace.set_span_in_context(trace.INVALID_SPAN, context) + return set_span_in_context(INVALID_SPAN, context) trace_id, span_id, _parent_id, flags = fields - if ( - trace_id == trace.INVALID_TRACE_ID - or span_id == trace.INVALID_SPAN_ID - ): - return trace.set_span_in_context(trace.INVALID_SPAN, context) - - span = trace.NonRecordingSpan( - trace.SpanContext( + if trace_id == INVALID_TRACE_ID or span_id == INVALID_SPAN_ID: + return set_span_in_context(INVALID_SPAN, context) + + span = NonRecordingSpan( + SpanContext( trace_id=int(trace_id, 16), span_id=int(span_id, 16), is_remote=True, - trace_flags=trace.TraceFlags( - int(flags, 16) & trace.TraceFlags.SAMPLED - ), + trace_flags=TraceFlags(int(flags, 16) & TraceFlags.SAMPLED), ) ) - return trace.set_span_in_context(span, context) + return set_span_in_context(span, context) def inject( - self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: - span = trace.get_current_span(context=context) + span = get_current_span(context=context) span_context = span.get_span_context() - if span_context == trace.INVALID_SPAN_CONTEXT: + if span_context == INVALID_SPAN_CONTEXT: return span_parent_id = span.parent.span_id if span.parent else 0 trace_flags = span_context.trace_flags if trace_flags.sampled: - trace_flags |= self.DEBUG_FLAG + trace_flags |= self._debug_flag # set span identity - set_in_carrier( - carrier, - self.TRACE_ID_KEY, - _format_uber_trace_id( - span_context.trace_id, - span_context.span_id, - span_parent_id, - trace_flags, - ), + carrier[ + self._trace_id_key + ] = "{trace_id}:{span_id}:{parent_id}:{:02x}".format( + trace_flags, + trace_id=format_trace_id(span_context.trace_id), + span_id=format_span_id(span_context.span_id), + parent_id=format_span_id(span_parent_id), ) # set span baggage, if any @@ -107,43 +113,9 @@ def inject( if not baggage_entries: return for key, value in baggage_entries.items(): - baggage_key = self.BAGGAGE_PREFIX + key - set_in_carrier( - carrier, baggage_key, urllib.parse.quote(str(value)) - ) + baggage_key = self._baggage_prefix + key + carrier[baggage_key] = quote(str(value)) @property - def fields(self) -> typing.Set[str]: - return {self.TRACE_ID_KEY} - - def _extract_baggage(self, getter, carrier, context): - baggage_keys = [ - key - for key in getter.keys(carrier) - if key.startswith(self.BAGGAGE_PREFIX) - ] - for key in baggage_keys: - value = _extract_first_element(getter.get(carrier, key)) - context = baggage.set_baggage( - key.replace(self.BAGGAGE_PREFIX, ""), - urllib.parse.unquote(value).strip(), - context=context, - ) - return context - - -def _format_uber_trace_id(trace_id, span_id, parent_span_id, flags): - return "{trace_id}:{span_id}:{parent_id}:{:02x}".format( - flags, - trace_id=format_trace_id(trace_id), - span_id=format_span_id(span_id), - parent_id=format_span_id(parent_span_id), - ) - - -def _extract_first_element( - items: typing.Iterable[TextMapPropagatorT], -) -> typing.Optional[TextMapPropagatorT]: - if items is None: - return None - return next(iter(items), None) + def fields(self) -> Set[str]: + return {self._trace_id_key} diff --git a/tests/util/src/opentelemetry/test/mock_textmap.py b/tests/util/src/opentelemetry/test/mock_textmap.py index 1edd079042f..88a48717de1 100644 --- a/tests/util/src/opentelemetry/test/mock_textmap.py +++ b/tests/util/src/opentelemetry/test/mock_textmap.py @@ -12,16 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import typing +from typing import Optional -from opentelemetry import trace from opentelemetry.context import Context, get_current from opentelemetry.propagators.textmap import ( - Getter, - Setter, TextMapPropagator, TextMapPropagatorT, ) +from opentelemetry.trace import ( + INVALID_SPAN, + NonRecordingSpan, + SpanContext, + get_current_span, + set_span_in_context, +) class NOOPTextMapPropagator(TextMapPropagator): @@ -32,18 +36,12 @@ class NOOPTextMapPropagator(TextMapPropagator): """ def extract( - self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: return get_current() def inject( - self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: return None @@ -55,24 +53,21 @@ def fields(self): class MockTextMapPropagator(TextMapPropagator): """Mock propagator for testing purposes.""" - TRACE_ID_KEY = "mock-traceid" - SPAN_ID_KEY = "mock-spanid" + trace_id_key = "mock-traceid" + span_id_key = "mock-spanid" def extract( - self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: - trace_id_list = getter.get(carrier, self.TRACE_ID_KEY) - span_id_list = getter.get(carrier, self.SPAN_ID_KEY) + trace_id_list = carrier.get(self.trace_id_key) + span_id_list = carrier.get(self.span_id_key) if not trace_id_list or not span_id_list: - return trace.set_span_in_context(trace.INVALID_SPAN) + return set_span_in_context(INVALID_SPAN) - return trace.set_span_in_context( - trace.NonRecordingSpan( - trace.SpanContext( + return set_span_in_context( + NonRecordingSpan( + SpanContext( trace_id=int(trace_id_list[0]), span_id=int(span_id_list[0]), is_remote=True, @@ -81,19 +76,12 @@ def extract( ) def inject( - self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: - span = trace.get_current_span(context) - set_in_carrier( - carrier, self.TRACE_ID_KEY, str(span.get_span_context().trace_id) - ) - set_in_carrier( - carrier, self.SPAN_ID_KEY, str(span.get_span_context().span_id) - ) + span = get_current_span(context) + carrier[self.trace_id_key] = str(span.get_span_context().trace_id) + carrier[self.span_id_key] = str(span.get_span_context().span_id) @property def fields(self): - return {self.TRACE_ID_KEY, self.SPAN_ID_KEY} + return {self.trace_id_key, self.span_id_key}