Skip to content

Commit

Permalink
Refactor inject and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ocelotl committed Dec 4, 2019
1 parent 698a68b commit 09f6c35
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class B3Format(HTTPTextFormat):
SINGLE_HEADER_KEY = "b3"
TRACE_ID_KEY = "x-b3-traceid"
SPAN_ID_KEY = "x-b3-spanid"
PARENT_SPAN_ID_KEY = "x-b3-parentspanid"
SAMPLED_KEY = "x-b3-sampled"
FLAGS_KEY = "x-b3-flags"
_SAMPLE_PROPAGATE_VALUES = set(["1", "True", "true", "d"])
Expand Down Expand Up @@ -55,7 +56,7 @@ def extract(cls, get_from_carrier, carrier):
elif len(fields) == 3:
trace_id, span_id, sampled = fields
elif len(fields) == 4:
trace_id, span_id, sampled, _parent_span_id = fields
trace_id, span_id, sampled, _ = fields
else:
return trace.INVALID_SPAN_CONTEXT
else:
Expand Down Expand Up @@ -100,13 +101,22 @@ def extract(cls, get_from_carrier, carrier):
)

@classmethod
def inject(cls, context, set_in_carrier, carrier):
sampled = (trace.TraceOptions.SAMPLED & context.trace_options) != 0
def inject(
cls, span, set_in_carrier, carrier
): # pylint: disable=arguments-differ
sampled = (
trace.TraceOptions.SAMPLED & span.context.trace_options
) != 0
set_in_carrier(
carrier, cls.TRACE_ID_KEY, format_trace_id(context.trace_id)
carrier, cls.TRACE_ID_KEY, format_trace_id(span.context.trace_id)
)
set_in_carrier(
carrier, cls.SPAN_ID_KEY, format_span_id(context.span_id)
carrier, cls.SPAN_ID_KEY, format_span_id(span.context.span_id)
)
set_in_carrier(
carrier,
cls.PARENT_SPAN_ID_KEY,
format_span_id(span.parent.context.span_id),
)
set_in_carrier(carrier, cls.SAMPLED_KEY, "1" if sampled else "0")

Expand Down
5 changes: 4 additions & 1 deletion opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ def __repr__(self):
)

def __str__(self):
return '{}(name="{}", context={}, kind={}, parent={}, start_time={}, end_time={})'.format(
return (
'{}(name="{}", context={}, kind={}, parent={}, '
"start_time={}, end_time={})"
).format(
type(self).__name__,
self.name,
self.context,
Expand Down
85 changes: 59 additions & 26 deletions opentelemetry-sdk/tests/context/propagation/test_b3_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,23 @@ def test_extract_multi_header(self):
FORMAT.PARENT_SPAN_ID_KEY: self.serialized_parent_span_id,
FORMAT.SAMPLED_KEY: "1",
}
span_context = FORMAT.extract(get_as_list, carrier)
parent_span = trace.Span(
"parent", FORMAT.extract(get_as_list, carrier)
)
child_span = trace.Tracer().start_span("child", parent=parent_span)
new_carrier = {}
FORMAT.inject(span_context, dict.__setitem__, new_carrier)
FORMAT.inject(child_span, dict.__setitem__, new_carrier)
self.assertEqual(
new_carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id
new_carrier[FORMAT.TRACE_ID_KEY],
b3_format.format_trace_id(child_span.context.trace_id),
)
self.assertEqual(
new_carrier[FORMAT.SPAN_ID_KEY], self.serialized_span_id
new_carrier[FORMAT.SPAN_ID_KEY],
b3_format.format_span_id(child_span.context.span_id),
)
self.assertEqual(
new_carrier[FORMAT.PARENT_SPAN_ID_KEY],
self.serialized_parent_span_id,
b3_format.format_span_id(parent_span.context.span_id),
)
self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1")

Expand All @@ -69,14 +74,19 @@ def test_extract_single_header(self):
self.serialized_trace_id, self.serialized_span_id
)
}
span_context = FORMAT.extract(get_as_list, carrier)
parent_span = trace.Span(
"parent", FORMAT.extract(get_as_list, carrier)
)
child_span = trace.Tracer().start_span("child", parent=parent_span)
new_carrier = {}
FORMAT.inject(span_context, dict.__setitem__, new_carrier)
FORMAT.inject(child_span, dict.__setitem__, new_carrier)
self.assertEqual(
new_carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id
new_carrier[FORMAT.TRACE_ID_KEY],
b3_format.format_trace_id(child_span.context.trace_id),
)
self.assertEqual(
new_carrier[FORMAT.SPAN_ID_KEY], self.serialized_span_id
new_carrier[FORMAT.SPAN_ID_KEY],
b3_format.format_span_id(child_span.context.span_id),
)
self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1")

Expand All @@ -87,18 +97,23 @@ def test_extract_single_header(self):
self.serialized_parent_span_id,
)
}
span_context = FORMAT.extract(get_as_list, carrier)
parent_span = trace.Span(
"parent", FORMAT.extract(get_as_list, carrier)
)
child_span = trace.Tracer().start_span("child", parent=parent_span)
new_carrier = {}
FORMAT.inject(span_context, dict.__setitem__, new_carrier)
FORMAT.inject(child_span, dict.__setitem__, new_carrier)
self.assertEqual(
new_carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id
new_carrier[FORMAT.TRACE_ID_KEY],
b3_format.format_trace_id(child_span.context.trace_id),
)
self.assertEqual(
new_carrier[FORMAT.SPAN_ID_KEY], self.serialized_span_id
new_carrier[FORMAT.SPAN_ID_KEY],
b3_format.format_span_id(child_span.context.span_id),
)
self.assertEqual(
new_carrier[FORMAT.PARENT_SPAN_ID_KEY],
self.serialized_parent_span_id,
b3_format.format_span_id(parent_span.context.span_id),
)
self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1")

Expand All @@ -115,9 +130,12 @@ def test_extract_header_precedence(self):
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
FORMAT.SAMPLED_KEY: "1",
}
span_context = FORMAT.extract(get_as_list, carrier)
parent_span = trace.Span(
"parent", FORMAT.extract(get_as_list, carrier)
)
child_span = trace.Tracer().start_span("child", parent=parent_span)
new_carrier = {}
FORMAT.inject(span_context, dict.__setitem__, new_carrier)
FORMAT.inject(child_span, dict.__setitem__, new_carrier)
self.assertEqual(
new_carrier[FORMAT.TRACE_ID_KEY], single_header_trace_id
)
Expand All @@ -130,9 +148,12 @@ def test_enabled_sampling(self):
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
FORMAT.SAMPLED_KEY: variant,
}
span_context = FORMAT.extract(get_as_list, carrier)
parent_span = trace.Span(
"parent", FORMAT.extract(get_as_list, carrier)
)
child_span = trace.Tracer().start_span("child", parent=parent_span)
new_carrier = {}
FORMAT.inject(span_context, dict.__setitem__, new_carrier)
FORMAT.inject(child_span, dict.__setitem__, new_carrier)
self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1")

def test_disabled_sampling(self):
Expand All @@ -143,9 +164,12 @@ def test_disabled_sampling(self):
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
FORMAT.SAMPLED_KEY: variant,
}
span_context = FORMAT.extract(get_as_list, carrier)
parent_span = trace.Span(
"parent", FORMAT.extract(get_as_list, carrier)
)
child_span = trace.Tracer().start_span("child", parent=parent_span)
new_carrier = {}
FORMAT.inject(span_context, dict.__setitem__, new_carrier)
FORMAT.inject(child_span, dict.__setitem__, new_carrier)
self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "0")

def test_flags(self):
Expand All @@ -155,9 +179,12 @@ def test_flags(self):
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
FORMAT.FLAGS_KEY: "1",
}
span_context = FORMAT.extract(get_as_list, carrier)
parent_span = trace.Span(
"parent", FORMAT.extract(get_as_list, carrier)
)
child_span = trace.Tracer().start_span("child", parent=parent_span)
new_carrier = {}
FORMAT.inject(span_context, dict.__setitem__, new_carrier)
FORMAT.inject(child_span, dict.__setitem__, new_carrier)
self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1")

def test_flags_and_sampling(self):
Expand All @@ -167,9 +194,12 @@ def test_flags_and_sampling(self):
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
FORMAT.FLAGS_KEY: "1",
}
span_context = FORMAT.extract(get_as_list, carrier)
parent_span = trace.Span(
"parent", FORMAT.extract(get_as_list, carrier)
)
child_span = trace.Tracer().start_span("child", parent=parent_span)
new_carrier = {}
FORMAT.inject(span_context, dict.__setitem__, new_carrier)
FORMAT.inject(child_span, dict.__setitem__, new_carrier)
self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1")

def test_64bit_trace_id(self):
Expand All @@ -180,9 +210,12 @@ def test_64bit_trace_id(self):
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
FORMAT.FLAGS_KEY: "1",
}
span_context = FORMAT.extract(get_as_list, carrier)
parent_span = trace.Span(
"parent", FORMAT.extract(get_as_list, carrier)
)
child_span = trace.Tracer().start_span("child", parent=parent_span)
new_carrier = {}
FORMAT.inject(span_context, dict.__setitem__, new_carrier)
FORMAT.inject(child_span, dict.__setitem__, new_carrier)
self.assertEqual(
new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit
)
Expand Down

0 comments on commit 09f6c35

Please sign in to comment.