From f62eeacfa181c6e7d17ba7f7468bc2b05ace8a40 Mon Sep 17 00:00:00 2001 From: Yusuke Tsutsumi Date: Sat, 3 Aug 2019 15:29:56 -0700 Subject: [PATCH 1/5] Adding propagators API and b3 SDK implementation (#51, #52) Specification: https://github.com/open-telemetry/opentelemetry-specification/blob/master/specification/api-propagators.md. --- .../context/propagation/__init__.py | 2 + .../context/propagation/binaryformat.py | 29 +++++++ .../context/propagation/httptextformat.py | 79 +++++++++++++++++++ .../src/opentelemetry/sdk/context/__init__.py | 0 .../sdk/context/propagation/__init__.py | 0 .../sdk/context/propagation/b3_format.py | 74 +++++++++++++++++ opentelemetry-sdk/tests/context/__init__.py | 0 .../tests/context/propagation/__init__.py | 0 .../context/propagation/test_b3_format.py | 57 +++++++++++++ 9 files changed, 241 insertions(+) create mode 100644 opentelemetry-api/src/opentelemetry/context/propagation/__init__.py create mode 100644 opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py create mode 100644 opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py create mode 100644 opentelemetry-sdk/src/opentelemetry/sdk/context/__init__.py create mode 100644 opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/__init__.py create mode 100644 opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py create mode 100644 opentelemetry-sdk/tests/context/__init__.py create mode 100644 opentelemetry-sdk/tests/context/propagation/__init__.py create mode 100644 opentelemetry-sdk/tests/context/propagation/test_b3_format.py diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py new file mode 100644 index 0000000000..c00a800b69 --- /dev/null +++ b/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py @@ -0,0 +1,2 @@ +from .httptextformat import HTTPTextFormat +from .binaryformat import BinaryFormat diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py b/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py new file mode 100644 index 0000000000..66dbd277f8 --- /dev/null +++ b/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py @@ -0,0 +1,29 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import typing +from opentelemetry.trace import SpanContext + + +class BinaryFormat(abc.ABC): + @staticmethod + @abc.abstractmethod + def to_bytes(context: SpanContext) -> bytes: + pass + + @staticmethod + @abc.abstractmethod + def from_bytes(byte_representation: bytes) -> typing.Optional[SpanContext]: + pass diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py b/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py new file mode 100644 index 0000000000..395539bbc7 --- /dev/null +++ b/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py @@ -0,0 +1,79 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import typing +from opentelemetry.trace import SpanContext + +Setter = typing.Callable[[object, str, str], None] +Getter = typing.Callable[[object, str], str] + + +class HTTPTextFormat(abc.ABC): + """API for propagation of spans via headers. + + This class provides an interface that enables extracting and injecting + trace state into headers of HTTP requests. Http frameworks and client + can integrate with HTTPTextFormat by providing the object containing the + headers, and a getter and setter function for the extraction and + injection of values, respectively. + + Example:: + + import flask + import requests + from opentelemetry.context.propagation import HTTPTextFormat + + PROPAGATOR = HTTPTextFormat() + + + + def get_header_from_flask_request(request, key): + return request.headers[key] + + def set_header_into_requests_request(request: requests.Request, + key: str, value: str): + request.headers[key] = value + + def example_route(): + span_context = PROPAGATOR.extract( + get_header_from_flask_request, + flask.request + ) + request_to_downstream = requests.Request( + "GET", "http://httpbin.org/get" + ) + PROPAGATOR.inject( + span_context, + set_header_into_requests_request, + request_to_downstream + ) + session = requests.Session() + session.send(request_to_downstream.prepare()) + + + .. _Propagation API Specification: + https://github.com/open-telemetry/opentelemetry-specification/blob/master/specification/api-propagators.md + + Enabling this flexi + """ + @abc.abstractmethod + def extract(self, get_from_carrier: Getter, + carrier: object) -> SpanContext: + pass + + @abc.abstractmethod + def inject(self, context: SpanContext, set_in_carrier: Setter, + carrier: object): + pass diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/context/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/context/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py new file mode 100644 index 0000000000..031f87c75f --- /dev/null +++ b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py @@ -0,0 +1,74 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from opentelemetry.context.propagation.httptextformat import HTTPTextFormat +import opentelemetry.trace as trace + + +class B3Format(HTTPTextFormat): + """Propagator for the B3 HTTP header format. + + See: https://github.com/openzipkin/b3-propagation + """ + + SINGLE_HEADER_KEY = "b3" + TRACE_ID_KEY = "x-b3-traceid" + SPAN_ID_KEY = "x-b3-spanid" + SAMPLED_KEY = "x-b3-sampled" + + @classmethod + def extract(cls, get_from_carrier, carrier): + trace_id = trace.INVALID_TRACE_ID + span_id = trace.INVALID_SPAN_ID + sampled = 1 + + single_header = get_from_carrier(carrier, cls.SINGLE_HEADER_KEY) + if single_header: + # b3-propagation spec calls for the sampling state to be + # "deferred", which is unspecified. This concept does not + # translate to SpanContext, so we set it as recorded. + sampled = "1" + fields = single_header.split("-", 4) + + if len(fields) == 1: + sampled = fields[0] + elif len(fields) == 2: + trace_id, span_id = fields + elif len(fields) == 3: + trace_id, span_id, sampled = fields + elif len(fields) == 4: + trace_id, span_id, sampled, _parent_span_id = fields + else: + return trace.INVALID_SPAN_CONTEXT + else: + trace_id = get_from_carrier(carrier, cls.TRACE_ID_KEY) + span_id = get_from_carrier(carrier, cls.SPAN_ID_KEY) + sampled = get_from_carrier(carrier, cls.SAMPLED_KEY) + + options = 0 + if sampled == "1": + options |= trace.TraceOptions.RECORDED + return trace.SpanContext( + trace_id=int(trace_id), + span_id=int(span_id), + trace_options=options, + trace_state={}, + ) + + @classmethod + def inject(cls, context, set_in_carrier, carrier): + sampled = (trace.TraceOptions.RECORDED & context.trace_options) != 0 + set_in_carrier(carrier, cls.TRACE_ID_KEY, str(context.trace_id)) + set_in_carrier(carrier, cls.SPAN_ID_KEY, str(context.span_id)) + set_in_carrier(carrier, cls.SAMPLED_KEY, "1" if sampled else "0") diff --git a/opentelemetry-sdk/tests/context/__init__.py b/opentelemetry-sdk/tests/context/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/opentelemetry-sdk/tests/context/propagation/__init__.py b/opentelemetry-sdk/tests/context/propagation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py new file mode 100644 index 0000000000..79b1f82726 --- /dev/null +++ b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py @@ -0,0 +1,57 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import opentelemetry.sdk.context.propagation.b3_format as b3_format +import opentelemetry.sdk.trace as trace + +FORMAT = b3_format.B3Format() + + +def _get_from_dict(carrier: dict, key: str) -> str: + return carrier.get(key) + + +def _set_into_dict(carrier: dict, key: str, value: str): + carrier[key] = value + + +class TestB3Format(unittest.TestCase): + def test_extract_multi_header(self): + """Test the extraction of B3 headers """ + trace_id = str(trace.generate_trace_id()) + span_id = str(trace.generate_span_id()) + carrier = { + FORMAT.TRACE_ID_KEY: trace_id, + FORMAT.SPAN_ID_KEY: span_id, + FORMAT.SAMPLED_KEY: "1", + } + span_context = FORMAT.extract(_get_from_dict, carrier) + new_carrier = {} + FORMAT.inject(span_context, _set_into_dict, new_carrier) + self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], trace_id) + self.assertEqual(new_carrier[FORMAT.SPAN_ID_KEY], span_id) + self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + + def test_extract_single_headder(self): + """Test the extraction from a single b3 header""" + trace_id = str(trace.generate_trace_id()) + span_id = str(trace.generate_span_id()) + carrier = {FORMAT.SINGLE_HEADER_KEY: "{}-{}".format(trace_id, span_id)} + span_context = FORMAT.extract(_get_from_dict, carrier) + new_carrier = {} + FORMAT.inject(span_context, _set_into_dict, new_carrier) + self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], trace_id) + self.assertEqual(new_carrier[FORMAT.SPAN_ID_KEY], span_id) + self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") From 4a3d6179e8bc1c5771972b990f23d71e852edbc2 Mon Sep 17 00:00:00 2001 From: Yusuke Tsutsumi Date: Thu, 8 Aug 2019 08:45:42 -0700 Subject: [PATCH 2/5] Addressing lint errors --- .../src/opentelemetry/context/propagation/__init__.py | 2 ++ .../src/opentelemetry/context/propagation/httptextformat.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py index c00a800b69..5b88cfc908 100644 --- a/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py +++ b/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py @@ -1,2 +1,4 @@ from .httptextformat import HTTPTextFormat from .binaryformat import BinaryFormat + +__all__ = ["HTTPTextFormat", "BinaryFormat"] diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py b/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py index 395539bbc7..c9b283a140 100644 --- a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py +++ b/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py @@ -75,5 +75,5 @@ def extract(self, get_from_carrier: Getter, @abc.abstractmethod def inject(self, context: SpanContext, set_in_carrier: Setter, - carrier: object): + carrier: object) -> None: pass From ce65a7382ab32c49fa42ee9ba80d7879881855dc Mon Sep 17 00:00:00 2001 From: Yusuke Tsutsumi Date: Sun, 11 Aug 2019 21:32:37 -0700 Subject: [PATCH 3/5] Adding more flexible b3 format, fiixing linting incorporating unit tests and a more lenient implementation of the b3 propagator. Fixing a bug with b3 propagators consuming and producing integers rather than hex-encoded values. --- .../context/propagation/__init__.py | 4 +- .../context/propagation/binaryformat.py | 33 +++- .../context/propagation/httptextformat.py | 45 ++++- .../sdk/context/propagation/b3_format.py | 39 ++++- .../context/propagation/test_b3_format.py | 160 +++++++++++++++--- 5 files changed, 237 insertions(+), 44 deletions(-) diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py index 5b88cfc908..b964c2a968 100644 --- a/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py +++ b/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py @@ -1,4 +1,4 @@ -from .httptextformat import HTTPTextFormat from .binaryformat import BinaryFormat +from .httptextformat import HTTPTextFormat -__all__ = ["HTTPTextFormat", "BinaryFormat"] +__all__ = ["BinaryFormat", "HTTPTextFormat"] diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py b/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py index 66dbd277f8..f05ef69972 100644 --- a/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py +++ b/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py @@ -14,16 +14,45 @@ import abc import typing + from opentelemetry.trace import SpanContext class BinaryFormat(abc.ABC): + """API for serialization of span context into binary formats. + + This class provides an interface that enables converting span contexts + to and from a binary format. + """ @staticmethod @abc.abstractmethod def to_bytes(context: SpanContext) -> bytes: - pass + """Creates a byte representation of a SpanContext. + + to_bytes should read values from a SpanContext and return a data + format to represent it, in bytes. + + Args: + context: the SpanContext to serialize + + Returns: + A bytes representation of the SpanContext. + """ @staticmethod @abc.abstractmethod def from_bytes(byte_representation: bytes) -> typing.Optional[SpanContext]: - pass + """Return a SpanContext that was represented by bytes. + + from_bytes should return back a SpanContext that was constructed from + the data serialized in the byte_representation passed. If it is not + possible to read in a proper SpanContext, return None. + + Args: + byte_representation: the bytes to deserialize + + Returns: + A bytes representation of the SpanContext if it is valid. + Otherwise return None. + + """ diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py b/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py index c9b283a140..2e7862c70a 100644 --- a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py +++ b/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py @@ -14,17 +14,18 @@ import abc import typing + from opentelemetry.trace import SpanContext Setter = typing.Callable[[object, str, str], None] -Getter = typing.Callable[[object, str], str] +Getter = typing.Callable[[object, str], typing.Optional[str]] class HTTPTextFormat(abc.ABC): - """API for propagation of spans via headers. + """API for propagation of span context via headers. This class provides an interface that enables extracting and injecting - trace state into headers of HTTP requests. Http frameworks and client + span context into headers of HTTP requests. HTTP frameworks and clients can integrate with HTTPTextFormat by providing the object containing the headers, and a getter and setter function for the extraction and injection of values, respectively. @@ -65,15 +66,43 @@ def example_route(): .. _Propagation API Specification: https://github.com/open-telemetry/opentelemetry-specification/blob/master/specification/api-propagators.md - - Enabling this flexi """ @abc.abstractmethod def extract(self, get_from_carrier: Getter, carrier: object) -> SpanContext: - pass - + """Create a SpanContext from values in the carrier. + + The extract function should retrieve values from the carrier + object using get_from_carrier, and use values to populate a + SpanContext value and return it. + + Args: + get_from_carrier: a function that can retrieve a value + in the carrier, or return None if not + carrier: and object which contains values that are + used to construct a SpanContext. This object + must be paired with an appropriate get_from_carrier + which understands how to extract a value from it + Returns: + A SpanContext with configuration found in the carrier. + + """ @abc.abstractmethod def inject(self, context: SpanContext, set_in_carrier: Setter, carrier: object) -> None: - pass + """Inject values from a SpanContext into a carrier. + + inject enables the propagation of values into HTTP clients or + other objects which perform an HTTP request. Implementations + should use the set_in_carrier method to set values on the + carrier. + + Args: + context: The SpanContext to read values from + set_in_carrier: A setter function that can set values + on the carrier + carrier: An object that a place to define HTTP headers. + Should be paired with set_in_carrier, which should + know how to set header values on the carrier + + """ diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py index 031f87c75f..e658632148 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py @@ -26,16 +26,19 @@ class B3Format(HTTPTextFormat): TRACE_ID_KEY = "x-b3-traceid" SPAN_ID_KEY = "x-b3-spanid" SAMPLED_KEY = "x-b3-sampled" + FLAGS_KEY = "x-b3-flags" + _SAMPLE_PROPAGATE_VALUES = set(["1", "True", "true", "d"]) @classmethod def extract(cls, get_from_carrier, carrier): trace_id = trace.INVALID_TRACE_ID span_id = trace.INVALID_SPAN_ID - sampled = 1 + sampled = 0 + flags = None single_header = get_from_carrier(carrier, cls.SINGLE_HEADER_KEY) if single_header: - # b3-propagation spec calls for the sampling state to be + # The b3 spec calls for the sampling state to be # "deferred", which is unspecified. This concept does not # translate to SpanContext, so we set it as recorded. sampled = "1" @@ -52,13 +55,29 @@ def extract(cls, get_from_carrier, carrier): else: return trace.INVALID_SPAN_CONTEXT else: - trace_id = get_from_carrier(carrier, cls.TRACE_ID_KEY) - span_id = get_from_carrier(carrier, cls.SPAN_ID_KEY) - sampled = get_from_carrier(carrier, cls.SAMPLED_KEY) + trace_id = get_from_carrier(carrier, cls.TRACE_ID_KEY) or trace_id + span_id = get_from_carrier(carrier, cls.SPAN_ID_KEY) or span_id + sampled = get_from_carrier(carrier, cls.SAMPLED_KEY) or sampled + flags = get_from_carrier(carrier, cls.FLAGS_KEY) or flags options = 0 - if sampled == "1": + # The b3 spec provides no defined behavior for both sample and + # flag values set. Since the setting of at least one implies + # the desire for some form of sampling, propagate if either + # header is set to allow. + if sampled in cls._SAMPLE_PROPAGATE_VALUES or flags == "1": options |= trace.TraceOptions.RECORDED + + # trace an span ids are encoded in hex, so must be converted + if trace_id != trace.INVALID_TRACE_ID: + # Convert 64-bit trace ids to 128-bit + if len(trace_id) == 16: + trace_id = "0" * 16 + trace_id + trace_id = int(trace_id, 16) + + if span_id != trace.INVALID_SPAN_ID: + span_id = int(span_id, 16) + return trace.SpanContext( trace_id=int(trace_id), span_id=int(span_id), @@ -69,6 +88,10 @@ def extract(cls, get_from_carrier, carrier): @classmethod def inject(cls, context, set_in_carrier, carrier): sampled = (trace.TraceOptions.RECORDED & context.trace_options) != 0 - set_in_carrier(carrier, cls.TRACE_ID_KEY, str(context.trace_id)) - set_in_carrier(carrier, cls.SPAN_ID_KEY, str(context.span_id)) + set_in_carrier( + carrier, cls.TRACE_ID_KEY, "{:032x}".format(context.trace_id) + ) + set_in_carrier( + carrier, cls.SPAN_ID_KEY, "{:016x}".format(context.span_id) + ) set_in_carrier(carrier, cls.SAMPLED_KEY, "1" if sampled else "0") diff --git a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py index 79b1f82726..d9f12919a7 100644 --- a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py +++ b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py @@ -13,45 +13,157 @@ # limitations under the License. import unittest +import opentelemetry.trace as api_trace import opentelemetry.sdk.context.propagation.b3_format as b3_format import opentelemetry.sdk.trace as trace FORMAT = b3_format.B3Format() -def _get_from_dict(carrier: dict, key: str) -> str: - return carrier.get(key) - - -def _set_into_dict(carrier: dict, key: str, value: str): - carrier[key] = value - - class TestB3Format(unittest.TestCase): + @classmethod + def setUpClass(cls): + trace_id = trace.generate_trace_id() + cls.trace_id = "{:032x}".format(trace_id) + cls.trace_id_internal = str(trace_id) + span_id = trace.generate_span_id() + cls.span_id = "{:016x}".format(span_id) + cls.span_id_internal = str(span_id) + def test_extract_multi_header(self): """Test the extraction of B3 headers """ - trace_id = str(trace.generate_trace_id()) - span_id = str(trace.generate_span_id()) carrier = { - FORMAT.TRACE_ID_KEY: trace_id, - FORMAT.SPAN_ID_KEY: span_id, + FORMAT.TRACE_ID_KEY: self.trace_id, + FORMAT.SPAN_ID_KEY: self.span_id, FORMAT.SAMPLED_KEY: "1", } - span_context = FORMAT.extract(_get_from_dict, carrier) + span_context = FORMAT.extract(dict.get, carrier) new_carrier = {} - FORMAT.inject(span_context, _set_into_dict, new_carrier) - self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], trace_id) - self.assertEqual(new_carrier[FORMAT.SPAN_ID_KEY], span_id) + FORMAT.inject(span_context, dict.__setitem__, new_carrier) + self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], self.trace_id) + self.assertEqual(new_carrier[FORMAT.SPAN_ID_KEY], self.span_id) self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") - def test_extract_single_headder(self): + def test_extract_single_header(self): """Test the extraction from a single b3 header""" - trace_id = str(trace.generate_trace_id()) - span_id = str(trace.generate_span_id()) - carrier = {FORMAT.SINGLE_HEADER_KEY: "{}-{}".format(trace_id, span_id)} - span_context = FORMAT.extract(_get_from_dict, carrier) + carrier = { + FORMAT.SINGLE_HEADER_KEY: "{}-{}".format( + self.trace_id, self.span_id + ) + } + span_context = FORMAT.extract(dict.get, carrier) new_carrier = {} - FORMAT.inject(span_context, _set_into_dict, new_carrier) - self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], trace_id) - self.assertEqual(new_carrier[FORMAT.SPAN_ID_KEY], span_id) + FORMAT.inject(span_context, dict.__setitem__, new_carrier) + self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], self.trace_id) + self.assertEqual(new_carrier[FORMAT.SPAN_ID_KEY], self.span_id) self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + + def test_extract_header_precedence(self): + """A single b3 header should take precedence over multiple + headers. + """ + single_header_trace_id = self.trace_id[:-3] + "123" + carrier = { + FORMAT.SINGLE_HEADER_KEY: "{}-{}".format( + single_header_trace_id, self.span_id + ), + FORMAT.TRACE_ID_KEY: self.trace_id, + FORMAT.SPAN_ID_KEY: self.span_id, + FORMAT.SAMPLED_KEY: "1", + } + span_context = FORMAT.extract(dict.get, carrier) + new_carrier = {} + FORMAT.inject(span_context, dict.__setitem__, new_carrier) + self.assertEqual( + new_carrier[FORMAT.TRACE_ID_KEY], single_header_trace_id + ) + + def test_enabled_sampling(self): + """Test b3 sample key variants that turn on sampling. """ + for variant in ["1", "True", "true", "d"]: + carrier = { + FORMAT.TRACE_ID_KEY: self.trace_id, + FORMAT.SPAN_ID_KEY: self.span_id, + FORMAT.SAMPLED_KEY: variant, + } + span_context = FORMAT.extract(dict.get, carrier) + new_carrier = {} + FORMAT.inject(span_context, dict.__setitem__, new_carrier) + self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + + def test_disabled_sampling(self): + """Test b3 sample key variants that turn off sampling. """ + for variant in ["0", "False", "false", None]: + carrier = { + FORMAT.TRACE_ID_KEY: self.trace_id, + FORMAT.SPAN_ID_KEY: self.span_id, + FORMAT.SAMPLED_KEY: variant, + } + span_context = FORMAT.extract(dict.get, carrier) + new_carrier = {} + FORMAT.inject(span_context, dict.__setitem__, new_carrier) + self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "0") + + def test_flags(self): + """x-b3-flags set to "1" should result in propagation. """ + carrier = { + FORMAT.TRACE_ID_KEY: self.trace_id, + FORMAT.SPAN_ID_KEY: self.span_id, + FORMAT.FLAGS_KEY: "1", + } + span_context = FORMAT.extract(dict.get, carrier) + new_carrier = {} + FORMAT.inject(span_context, dict.__setitem__, new_carrier) + self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + + def test_flags_and_sampling(self): + """Propagate if b3 flags and sampling are set. + """ + carrier = { + FORMAT.TRACE_ID_KEY: self.trace_id, + FORMAT.SPAN_ID_KEY: self.span_id, + FORMAT.FLAGS_KEY: "1", + } + span_context = FORMAT.extract(dict.get, carrier) + new_carrier = {} + FORMAT.inject(span_context, dict.__setitem__, new_carrier) + self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + + def test_64bit_trace_id(self): + """64 bit trace ids should be padded to 128 bit + trace ids.""" + trace_id_64_bit = self.trace_id[:16] + carrier = { + FORMAT.TRACE_ID_KEY: trace_id_64_bit, + FORMAT.SPAN_ID_KEY: self.span_id, + FORMAT.FLAGS_KEY: "1", + } + span_context = FORMAT.extract(dict.get, carrier) + new_carrier = {} + FORMAT.inject(span_context, dict.__setitem__, new_carrier) + self.assertEqual( + new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit + ) + + def test_invalid_single_header(self): + """If an invalid single header is passed, return an + invalid SpanContext. + """ + carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} + span_context = FORMAT.extract(dict.get, carrier) + self.assertEqual(span_context.trace_id, api_trace.INVALID_TRACE_ID) + self.assertEqual(span_context.span_id, api_trace.INVALID_SPAN_ID) + + def test_missing_trace_id(self): + """If a trace id is missing, populate an invalid trace + id.""" + carrier = {FORMAT.SPAN_ID_KEY: self.span_id, FORMAT.FLAGS_KEY: "1"} + span_context = FORMAT.extract(dict.get, carrier) + self.assertEqual(span_context.trace_id, api_trace.INVALID_TRACE_ID) + + def test_missing_span_id(self): + """If a trace id is missing, populate an invalid trace + id.""" + carrier = {FORMAT.TRACE_ID_KEY: self.trace_id, FORMAT.FLAGS_KEY: "1"} + span_context = FORMAT.extract(dict.get, carrier) + self.assertEqual(span_context.span_id, api_trace.INVALID_SPAN_ID) From 90080f2765bcd222aff9511061f389d92c35b39f Mon Sep 17 00:00:00 2001 From: Yusuke Tsutsumi Date: Mon, 12 Aug 2019 21:30:32 -0700 Subject: [PATCH 4/5] Addressing comments refactoring formatting of trace / spans for b3 as it's used frequently. renaming trace_id / span_id in b3 unit tests for clarity. removing unneeded int casts. --- .../sdk/context/propagation/b3_format.py | 38 ++++---- .../context/propagation/test_b3_format.py | 87 ++++++++++--------- 2 files changed, 67 insertions(+), 58 deletions(-) diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py index e658632148..69db7e5f5d 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py @@ -31,8 +31,8 @@ class B3Format(HTTPTextFormat): @classmethod def extract(cls, get_from_carrier, carrier): - trace_id = trace.INVALID_TRACE_ID - span_id = trace.INVALID_SPAN_ID + trace_id = format_trace_id(trace.INVALID_TRACE_ID) + span_id = format_span_id(trace.INVALID_SPAN_ID) sampled = 0 flags = None @@ -69,18 +69,12 @@ def extract(cls, get_from_carrier, carrier): options |= trace.TraceOptions.RECORDED # trace an span ids are encoded in hex, so must be converted - if trace_id != trace.INVALID_TRACE_ID: - # Convert 64-bit trace ids to 128-bit - if len(trace_id) == 16: - trace_id = "0" * 16 + trace_id - trace_id = int(trace_id, 16) - - if span_id != trace.INVALID_SPAN_ID: - span_id = int(span_id, 16) + trace_id_as_int = int(trace_id, 16) + span_id_as_int = int(span_id, 16) return trace.SpanContext( - trace_id=int(trace_id), - span_id=int(span_id), + trace_id=trace_id_as_int, + span_id=span_id_as_int, trace_options=options, trace_state={}, ) @@ -88,10 +82,18 @@ def extract(cls, get_from_carrier, carrier): @classmethod def inject(cls, context, set_in_carrier, carrier): sampled = (trace.TraceOptions.RECORDED & context.trace_options) != 0 - set_in_carrier( - carrier, cls.TRACE_ID_KEY, "{:032x}".format(context.trace_id) - ) - set_in_carrier( - carrier, cls.SPAN_ID_KEY, "{:016x}".format(context.span_id) - ) + set_in_carrier(carrier, cls.TRACE_ID_KEY, + format_trace_id(context.trace_id)) + set_in_carrier(carrier, cls.SPAN_ID_KEY, + format_span_id(context.span_id)) set_in_carrier(carrier, cls.SAMPLED_KEY, "1" if sampled else "0") + + +def format_trace_id(trace_id: int): + """Format the trace id according to b3 specification.""" + return format(trace_id, "032x") + + +def format_span_id(span_id: int): + """Format the span id according to b3 specification.""" + return format(span_id, "016x") diff --git a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py index d9f12919a7..8357450f7c 100644 --- a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py +++ b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py @@ -23,67 +23,69 @@ class TestB3Format(unittest.TestCase): @classmethod def setUpClass(cls): - trace_id = trace.generate_trace_id() - cls.trace_id = "{:032x}".format(trace_id) - cls.trace_id_internal = str(trace_id) - span_id = trace.generate_span_id() - cls.span_id = "{:016x}".format(span_id) - cls.span_id_internal = str(span_id) + cls.serialized_trace_id = b3_format.format_trace_id( + trace.generate_trace_id()) + cls.serialized_span_id = b3_format.format_span_id( + trace.generate_span_id()) def test_extract_multi_header(self): """Test the extraction of B3 headers """ carrier = { - FORMAT.TRACE_ID_KEY: self.trace_id, - FORMAT.SPAN_ID_KEY: self.span_id, + FORMAT.TRACE_ID_KEY: self.serialized_trace_id, + FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.SAMPLED_KEY: "1", } span_context = FORMAT.extract(dict.get, carrier) new_carrier = {} FORMAT.inject(span_context, dict.__setitem__, new_carrier) - self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], self.trace_id) - self.assertEqual(new_carrier[FORMAT.SPAN_ID_KEY], self.span_id) + self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], + self.serialized_trace_id) + self.assertEqual(new_carrier[FORMAT.SPAN_ID_KEY], + self.serialized_span_id) self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") def test_extract_single_header(self): """Test the extraction from a single b3 header""" carrier = { - FORMAT.SINGLE_HEADER_KEY: "{}-{}".format( - self.trace_id, self.span_id - ) + FORMAT.SINGLE_HEADER_KEY: + "{}-{}".format(self.serialized_trace_id, self.serialized_span_id) } span_context = FORMAT.extract(dict.get, carrier) new_carrier = {} FORMAT.inject(span_context, dict.__setitem__, new_carrier) - self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], self.trace_id) - self.assertEqual(new_carrier[FORMAT.SPAN_ID_KEY], self.span_id) + self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], + self.serialized_trace_id) + self.assertEqual(new_carrier[FORMAT.SPAN_ID_KEY], + self.serialized_span_id) self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") def test_extract_header_precedence(self): """A single b3 header should take precedence over multiple headers. """ - single_header_trace_id = self.trace_id[:-3] + "123" + single_header_trace_id = self.serialized_trace_id[:-3] + "123" carrier = { - FORMAT.SINGLE_HEADER_KEY: "{}-{}".format( - single_header_trace_id, self.span_id - ), - FORMAT.TRACE_ID_KEY: self.trace_id, - FORMAT.SPAN_ID_KEY: self.span_id, - FORMAT.SAMPLED_KEY: "1", + FORMAT.SINGLE_HEADER_KEY: + "{}-{}".format(single_header_trace_id, self.serialized_span_id), + FORMAT.TRACE_ID_KEY: + self.serialized_trace_id, + FORMAT.SPAN_ID_KEY: + self.serialized_span_id, + FORMAT.SAMPLED_KEY: + "1", } span_context = FORMAT.extract(dict.get, carrier) new_carrier = {} FORMAT.inject(span_context, dict.__setitem__, new_carrier) - self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], single_header_trace_id - ) + self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], + single_header_trace_id) def test_enabled_sampling(self): """Test b3 sample key variants that turn on sampling. """ for variant in ["1", "True", "true", "d"]: carrier = { - FORMAT.TRACE_ID_KEY: self.trace_id, - FORMAT.SPAN_ID_KEY: self.span_id, + FORMAT.TRACE_ID_KEY: self.serialized_trace_id, + FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.SAMPLED_KEY: variant, } span_context = FORMAT.extract(dict.get, carrier) @@ -95,8 +97,8 @@ def test_disabled_sampling(self): """Test b3 sample key variants that turn off sampling. """ for variant in ["0", "False", "false", None]: carrier = { - FORMAT.TRACE_ID_KEY: self.trace_id, - FORMAT.SPAN_ID_KEY: self.span_id, + FORMAT.TRACE_ID_KEY: self.serialized_trace_id, + FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.SAMPLED_KEY: variant, } span_context = FORMAT.extract(dict.get, carrier) @@ -107,8 +109,8 @@ def test_disabled_sampling(self): def test_flags(self): """x-b3-flags set to "1" should result in propagation. """ carrier = { - FORMAT.TRACE_ID_KEY: self.trace_id, - FORMAT.SPAN_ID_KEY: self.span_id, + FORMAT.TRACE_ID_KEY: self.serialized_trace_id, + FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } span_context = FORMAT.extract(dict.get, carrier) @@ -120,8 +122,8 @@ def test_flags_and_sampling(self): """Propagate if b3 flags and sampling are set. """ carrier = { - FORMAT.TRACE_ID_KEY: self.trace_id, - FORMAT.SPAN_ID_KEY: self.span_id, + FORMAT.TRACE_ID_KEY: self.serialized_trace_id, + FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } span_context = FORMAT.extract(dict.get, carrier) @@ -132,18 +134,17 @@ def test_flags_and_sampling(self): def test_64bit_trace_id(self): """64 bit trace ids should be padded to 128 bit trace ids.""" - trace_id_64_bit = self.trace_id[:16] + trace_id_64_bit = self.serialized_trace_id[:16] carrier = { FORMAT.TRACE_ID_KEY: trace_id_64_bit, - FORMAT.SPAN_ID_KEY: self.span_id, + FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } span_context = FORMAT.extract(dict.get, carrier) new_carrier = {} FORMAT.inject(span_context, dict.__setitem__, new_carrier) - self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit - ) + self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], + "0" * 16 + trace_id_64_bit) def test_invalid_single_header(self): """If an invalid single header is passed, return an @@ -157,13 +158,19 @@ def test_invalid_single_header(self): def test_missing_trace_id(self): """If a trace id is missing, populate an invalid trace id.""" - carrier = {FORMAT.SPAN_ID_KEY: self.span_id, FORMAT.FLAGS_KEY: "1"} + carrier = { + FORMAT.SPAN_ID_KEY: self.serialized_span_id, + FORMAT.FLAGS_KEY: "1" + } span_context = FORMAT.extract(dict.get, carrier) self.assertEqual(span_context.trace_id, api_trace.INVALID_TRACE_ID) def test_missing_span_id(self): """If a trace id is missing, populate an invalid trace id.""" - carrier = {FORMAT.TRACE_ID_KEY: self.trace_id, FORMAT.FLAGS_KEY: "1"} + carrier = { + FORMAT.TRACE_ID_KEY: self.serialized_trace_id, + FORMAT.FLAGS_KEY: "1" + } span_context = FORMAT.extract(dict.get, carrier) self.assertEqual(span_context.span_id, api_trace.INVALID_SPAN_ID) From 529e1845c5787b95caad18e2ffc2b75ff0d1bb18 Mon Sep 17 00:00:00 2001 From: Yusuke Tsutsumi Date: Tue, 13 Aug 2019 20:37:35 -0700 Subject: [PATCH 5/5] Modifying HTTPTextFormat getter to retrieve multiple values HTTP Headers can contains multiple values for the same key. This is important to support for formats such as w3c tracestate. --- .../context/propagation/httptextformat.py | 17 ++++--- .../sdk/context/propagation/b3_format.py | 32 +++++++----- .../context/propagation/test_b3_format.py | 49 ++++++++++--------- 3 files changed, 55 insertions(+), 43 deletions(-) diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py b/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py index 2e7862c70a..860498fe35 100644 --- a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py +++ b/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py @@ -18,7 +18,7 @@ from opentelemetry.trace import SpanContext Setter = typing.Callable[[object, str, str], None] -Getter = typing.Callable[[object, str], typing.Optional[str]] +Getter = typing.Callable[[object, str], typing.List[str]] class HTTPTextFormat(abc.ABC): @@ -41,7 +41,7 @@ class HTTPTextFormat(abc.ABC): def get_header_from_flask_request(request, key): - return request.headers[key] + return request.headers.get_all(key) def set_header_into_requests_request(request: requests.Request, key: str, value: str): @@ -77,12 +77,13 @@ def extract(self, get_from_carrier: Getter, SpanContext value and return it. Args: - get_from_carrier: a function that can retrieve a value - in the carrier, or return None if not + get_from_carrier: a function that can retrieve zero + or more values from the carrier. In the case that + the value does not exist, return an empty list. carrier: and object which contains values that are used to construct a SpanContext. This object must be paired with an appropriate get_from_carrier - which understands how to extract a value from it + which understands how to extract a value from it. Returns: A SpanContext with configuration found in the carrier. @@ -98,11 +99,11 @@ def inject(self, context: SpanContext, set_in_carrier: Setter, carrier. Args: - context: The SpanContext to read values from + context: The SpanContext to read values from. set_in_carrier: A setter function that can set values - on the carrier + on the carrier. carrier: An object that a place to define HTTP headers. Should be paired with set_in_carrier, which should - know how to set header values on the carrier + know how to set header values on the carrier. """ diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py index 69db7e5f5d..eaeeb577d2 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from opentelemetry.context.propagation.httptextformat import HTTPTextFormat import opentelemetry.trace as trace @@ -36,7 +38,8 @@ def extract(cls, get_from_carrier, carrier): sampled = 0 flags = None - single_header = get_from_carrier(carrier, cls.SINGLE_HEADER_KEY) + single_header = _extract_first_element( + get_from_carrier(carrier, cls.SINGLE_HEADER_KEY)) if single_header: # The b3 spec calls for the sampling state to be # "deferred", which is unspecified. This concept does not @@ -55,10 +58,14 @@ def extract(cls, get_from_carrier, carrier): else: return trace.INVALID_SPAN_CONTEXT else: - trace_id = get_from_carrier(carrier, cls.TRACE_ID_KEY) or trace_id - span_id = get_from_carrier(carrier, cls.SPAN_ID_KEY) or span_id - sampled = get_from_carrier(carrier, cls.SAMPLED_KEY) or sampled - flags = get_from_carrier(carrier, cls.FLAGS_KEY) or flags + trace_id = _extract_first_element( + get_from_carrier(carrier, cls.TRACE_ID_KEY)) or trace_id + span_id = _extract_first_element( + get_from_carrier(carrier, cls.SPAN_ID_KEY)) or span_id + sampled = _extract_first_element( + get_from_carrier(carrier, cls.SAMPLED_KEY)) or sampled + flags = _extract_first_element( + get_from_carrier(carrier, cls.FLAGS_KEY)) or flags options = 0 # The b3 spec provides no defined behavior for both sample and @@ -68,13 +75,10 @@ def extract(cls, get_from_carrier, carrier): if sampled in cls._SAMPLE_PROPAGATE_VALUES or flags == "1": options |= trace.TraceOptions.RECORDED - # trace an span ids are encoded in hex, so must be converted - trace_id_as_int = int(trace_id, 16) - span_id_as_int = int(span_id, 16) - return trace.SpanContext( - trace_id=trace_id_as_int, - span_id=span_id_as_int, + # trace an span ids are encoded in hex, so must be converted + trace_id=int(trace_id, 16), + span_id=int(span_id, 16), trace_options=options, trace_state={}, ) @@ -97,3 +101,9 @@ def format_trace_id(trace_id: int): def format_span_id(span_id: int): """Format the span id according to b3 specification.""" return format(span_id, "016x") + + +def _extract_first_element(list_object: list) -> typing.Optional[object]: + if list_object: + return list_object[0] + return None diff --git a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py index 8357450f7c..a24dd01c66 100644 --- a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py +++ b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py @@ -20,6 +20,11 @@ FORMAT = b3_format.B3Format() +def get_as_list(dict_object, key): + value = dict_object.get(key) + return [value] if value is not None else [] + + class TestB3Format(unittest.TestCase): @classmethod def setUpClass(cls): @@ -29,13 +34,13 @@ def setUpClass(cls): trace.generate_span_id()) def test_extract_multi_header(self): - """Test the extraction of B3 headers """ + """Test the extraction of B3 headers.""" carrier = { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.SAMPLED_KEY: "1", } - span_context = FORMAT.extract(dict.get, carrier) + span_context = FORMAT.extract(get_as_list, carrier) new_carrier = {} FORMAT.inject(span_context, dict.__setitem__, new_carrier) self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], @@ -45,12 +50,12 @@ def test_extract_multi_header(self): self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") def test_extract_single_header(self): - """Test the extraction from a single b3 header""" + """Test the extraction from a single b3 header.""" carrier = { FORMAT.SINGLE_HEADER_KEY: "{}-{}".format(self.serialized_trace_id, self.serialized_span_id) } - span_context = FORMAT.extract(dict.get, carrier) + span_context = FORMAT.extract(get_as_list, carrier) new_carrier = {} FORMAT.inject(span_context, dict.__setitem__, new_carrier) self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], @@ -74,73 +79,71 @@ def test_extract_header_precedence(self): FORMAT.SAMPLED_KEY: "1", } - span_context = FORMAT.extract(dict.get, carrier) + span_context = FORMAT.extract(get_as_list, carrier) new_carrier = {} FORMAT.inject(span_context, dict.__setitem__, new_carrier) self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], single_header_trace_id) def test_enabled_sampling(self): - """Test b3 sample key variants that turn on sampling. """ + """Test b3 sample key variants that turn on sampling.""" for variant in ["1", "True", "true", "d"]: carrier = { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.SAMPLED_KEY: variant, } - span_context = FORMAT.extract(dict.get, carrier) + span_context = FORMAT.extract(get_as_list, carrier) new_carrier = {} FORMAT.inject(span_context, dict.__setitem__, new_carrier) self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") def test_disabled_sampling(self): - """Test b3 sample key variants that turn off sampling. """ + """Test b3 sample key variants that turn off sampling.""" for variant in ["0", "False", "false", None]: carrier = { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.SAMPLED_KEY: variant, } - span_context = FORMAT.extract(dict.get, carrier) + span_context = FORMAT.extract(get_as_list, carrier) new_carrier = {} FORMAT.inject(span_context, dict.__setitem__, new_carrier) self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "0") def test_flags(self): - """x-b3-flags set to "1" should result in propagation. """ + """x-b3-flags set to "1" should result in propagation.""" carrier = { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } - span_context = FORMAT.extract(dict.get, carrier) + span_context = FORMAT.extract(get_as_list, carrier) new_carrier = {} FORMAT.inject(span_context, dict.__setitem__, new_carrier) self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") def test_flags_and_sampling(self): - """Propagate if b3 flags and sampling are set. - """ + """Propagate if b3 flags and sampling are set.""" carrier = { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } - span_context = FORMAT.extract(dict.get, carrier) + span_context = FORMAT.extract(get_as_list, carrier) new_carrier = {} FORMAT.inject(span_context, dict.__setitem__, new_carrier) self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") def test_64bit_trace_id(self): - """64 bit trace ids should be padded to 128 bit - trace ids.""" + """64 bit trace ids should be padded to 128 bit trace ids.""" trace_id_64_bit = self.serialized_trace_id[:16] carrier = { FORMAT.TRACE_ID_KEY: trace_id_64_bit, FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } - span_context = FORMAT.extract(dict.get, carrier) + span_context = FORMAT.extract(get_as_list, carrier) new_carrier = {} FORMAT.inject(span_context, dict.__setitem__, new_carrier) self.assertEqual(new_carrier[FORMAT.TRACE_ID_KEY], @@ -151,26 +154,24 @@ def test_invalid_single_header(self): invalid SpanContext. """ carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} - span_context = FORMAT.extract(dict.get, carrier) + span_context = FORMAT.extract(get_as_list, carrier) self.assertEqual(span_context.trace_id, api_trace.INVALID_TRACE_ID) self.assertEqual(span_context.span_id, api_trace.INVALID_SPAN_ID) def test_missing_trace_id(self): - """If a trace id is missing, populate an invalid trace - id.""" + """If a trace id is missing, populate an invalid trace id.""" carrier = { FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1" } - span_context = FORMAT.extract(dict.get, carrier) + span_context = FORMAT.extract(get_as_list, carrier) self.assertEqual(span_context.trace_id, api_trace.INVALID_TRACE_ID) def test_missing_span_id(self): - """If a trace id is missing, populate an invalid trace - id.""" + """If a trace id is missing, populate an invalid trace id.""" carrier = { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, FORMAT.FLAGS_KEY: "1" } - span_context = FORMAT.extract(dict.get, carrier) + span_context = FORMAT.extract(get_as_list, carrier) self.assertEqual(span_context.span_id, api_trace.INVALID_SPAN_ID)