Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring SQS Propagation in line with the spec #1673

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ classifiers = [
dependencies = [
"opentelemetry-api ~= 1.12",
"opentelemetry-instrumentation == 0.37b0.dev",
"opentelemetry-propagator-aws-xray == 1.0.1",
"opentelemetry-semantic-conventions == 0.37b0.dev",
"wrapt >= 1.0.0, < 2.0.0",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,16 @@
import botocore.client
from wrapt import wrap_function_wrapper

from opentelemetry import context, propagate, trace
from opentelemetry import context, trace
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import (
_SUPPRESS_INSTRUMENTATION_KEY,
unwrap,
)
from opentelemetry.propagators.aws.aws_xray_propagator import (
TRACE_HEADER_KEY,
AwsXRayPropagator,
)
from opentelemetry.propagators.textmap import CarrierT, Getter, Setter
from opentelemetry.semconv.trace import (
MessagingDestinationKindValues,
Expand All @@ -56,9 +60,14 @@

_IS_SQS_INSTRUMENTED_ATTRIBUTE = "_otel_boto3sqs_instrumented"

_AWS_TRACE_HEADER = "AWSTraceHeader"


class Boto3SQSGetter(Getter[CarrierT]):
def get(self, carrier: CarrierT, key: str) -> Optional[List[str]]:
if key == TRACE_HEADER_KEY:
key = _AWS_TRACE_HEADER

msg_attr = carrier.get(key)
if not isinstance(msg_attr, Mapping):
return None
Expand All @@ -77,6 +86,9 @@ class Boto3SQSSetter(Setter[CarrierT]):
def set(self, carrier: CarrierT, key: str, value: str) -> None:
# This is a limitation defined by AWS for SQS MessageAttributes size
if len(carrier.items()) < 10:
if key == TRACE_HEADER_KEY:
key = _AWS_TRACE_HEADER

carrier[key] = {
"StringValue": value,
"DataType": "String",
Expand Down Expand Up @@ -195,9 +207,12 @@ def _create_processing_span(
receipt_handle: str,
message: Dict[str, Any],
) -> None:
message_attributes = message.get("MessageAttributes", {})
message_system_attributes = message.get("MessageSystemAttributes", {})
links = []
ctx = propagate.extract(message_attributes, getter=boto3sqs_getter)
ctx = AwsXRayPropagator().extract(
message_system_attributes, getter=boto3sqs_getter
)

parent_span_ctx = trace.get_current_span(ctx).get_span_context()
if parent_span_ctx.is_valid:
links.append(Link(context=parent_span_ctx))
Expand Down Expand Up @@ -232,9 +247,11 @@ def send_wrapper(wrapped, instance, args, kwargs):
end_on_exit=True,
) as span:
Boto3SQSInstrumentor._enrich_span(span, queue_name, queue_url)
attributes = kwargs.pop("MessageAttributes", {})
propagate.inject(attributes, setter=boto3sqs_setter)
retval = wrapped(*args, MessageAttributes=attributes, **kwargs)
attributes = kwargs.pop("MessageSystemAttributes", {})
AwsXRayPropagator().inject(attributes, setter=boto3sqs_setter)
retval = wrapped(
*args, MessageSystemAttributes=attributes, **kwargs
)
message_id = retval.get("MessageId")
if message_id:
if span.is_recording():
Expand Down Expand Up @@ -271,10 +288,11 @@ def send_batch_wrapper(wrapped, instance, args, kwargs):
span, queue_name, queue_url, conversation_id=entry_id
)
with trace.use_span(span):
if "MessageAttributes" not in entry:
entry["MessageAttributes"] = {}
propagate.inject(
entry["MessageAttributes"], setter=boto3sqs_setter
if "MessageSystemAttributes" not in entry:
entry["MessageSystemAttributes"] = {}
AwsXRayPropagator().inject(
entry["MessageSystemAttributes"],
setter=boto3sqs_setter,
)
retval = wrapped(*args, **kwargs)
for successful_messages in retval["Successful"]:
Expand All @@ -297,10 +315,9 @@ def send_batch_wrapper(wrapped, instance, args, kwargs):
def _wrap_receive_message(self, sqs_class: type) -> None:
def receive_message_wrapper(wrapped, instance, args, kwargs):
queue_url = kwargs.get("QueueUrl")
message_attribute_names = kwargs.pop("MessageAttributeNames", [])
message_attribute_names.extend(
propagate.get_global_textmap().fields
)
attribute_names = kwargs.pop("AttributeNames", [])
attribute_names.extend(_AWS_TRACE_HEADER)

queue_name = Boto3SQSInstrumentor._extract_queue_name_from_url(
queue_url
)
Expand All @@ -317,7 +334,7 @@ def receive_message_wrapper(wrapped, instance, args, kwargs):
)
retval = wrapped(
*args,
MessageAttributeNames=message_attribute_names,
AttributeNames=attribute_names,
**kwargs,
)
messages = retval.get("Messages", [])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
Boto3SQSInstrumentor,
Boto3SQSSetter,
)
from opentelemetry.propagators.aws.aws_xray_propagator import (
TRACE_HEADER_KEY,
)
from opentelemetry.semconv.trace import (
MessagingDestinationKindValues,
MessagingOperationValues,
Expand All @@ -36,6 +39,8 @@
from opentelemetry.trace import SpanKind
from opentelemetry.trace.span import Span, format_span_id, format_trace_id

_AWS_TRACE_HEADER = "AWSTraceHeader"


def _make_sqs_client():
return boto3.client(
Expand Down Expand Up @@ -166,7 +171,7 @@ def _mocked_endpoint(self, response):
yield

def _assert_injected_span(self, msg_attrs: Dict[str, Any], span: Span):
trace_parent = msg_attrs["traceparent"]["StringValue"]
trace_parent = msg_attrs[_AWS_TRACE_HEADER]["StringValue"]
ctx = span.get_span_context()
self.assertEqual(
self._to_trace_parent(ctx.trace_id, ctx.span_id),
Expand All @@ -183,7 +188,9 @@ def _default_span_attrs(self):

@staticmethod
def _to_trace_parent(trace_id: int, span_id: int) -> str:
return f"00-{format_trace_id(trace_id)}-{format_span_id(span_id)}-01".lower()
formated_trace_id = format_trace_id(trace_id)
formated_trace_id = formated_trace_id[:8] + "-" + formated_trace_id[8:]
return f"root=1-{formated_trace_id};parent={format_span_id(span_id)};sampled=1".lower()

def _get_only_span(self):
spans = self.get_finished_spans()
Expand All @@ -199,13 +206,15 @@ def _make_message(message_id: str, body: str, receipt: str):
"Body": body,
"Attributes": {},
"MD5OfMessageAttributes": "111",
"MD5OfMessageSystemAttributes": "9012",
"MessageAttributes": {},
"MessageSystemAttributes": {},
}

def _add_trace_parent(
def _add_xray_parent(
self, message: Dict[str, Any], trace_id: int, span_id: int
):
message["MessageAttributes"]["traceparent"] = {
message["MessageSystemAttributes"][_AWS_TRACE_HEADER] = {
"StringValue": self._to_trace_parent(trace_id, span_id),
"DataType": "String",
}
Expand All @@ -221,12 +230,14 @@ def test_send_message(self):
}

message_attrs = {}
message_system_attrs = {}

with self._mocked_endpoint(mock_response):
self._client.send_message(
QueueUrl=self._queue_url,
MessageBody="hello msg",
MessageAttributes=message_attrs,
MessageSystemAttributes=message_system_attrs,
)

span = self._get_only_span()
Expand All @@ -239,7 +250,7 @@ def test_send_message(self):
},
span.attributes,
)
self._assert_injected_span(message_attrs, span)
self._assert_injected_span(message_system_attrs, span)

def test_receive_message(self):
msg_def = {
Expand All @@ -252,9 +263,7 @@ def test_receive_message(self):
message = self._make_message(
msg_id, f"hello {msg_id}", attrs["receipt"]
)
self._add_trace_parent(
message, attrs["trace_id"], attrs["span_id"]
)
self._add_xray_parent(message, attrs["trace_id"], attrs["span_id"])
mock_response["Messages"].append(message)

message_attr_names = []
Expand All @@ -265,8 +274,6 @@ def test_receive_message(self):
MessageAttributeNames=message_attr_names,
)

self.assertIn("traceparent", message_attr_names)

# receive span
span = self._get_only_span()
self.assertEqual(f"{self._queue_name} receive", span.name)
Expand Down