Skip to content

Commit

Permalink
sqs: use AWSTraceHeader for xray propagation in MessageSystemAttributes
Browse files Browse the repository at this point in the history
  • Loading branch information
tsloughter committed Feb 14, 2023
1 parent 8d968dd commit 5b86396
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,14 @@

_IS_SQS_INSTRUMENTED_ATTRIBUTE = "_otel_boto3sqs_instrumented"

_AWS_TRACE_HEADER = "AWSTraceHeader"


class Boto3SQSGetter(Getter[CarrierT]):
def get(self, carrier: CarrierT, key: str) -> Optional[List[str]]:
if key == TRACE_HEADER_KEY:
key = _AWS_TRACE_HEADER

msg_attr = carrier.get(key)
if not isinstance(msg_attr, Mapping):
return None
Expand All @@ -81,6 +86,9 @@ class Boto3SQSSetter(Setter[CarrierT]):
def set(self, carrier: CarrierT, key: str, value: str) -> None:
# This is a limitation defined by AWS for SQS MessageAttributes size
if len(carrier.items()) < 10:
if key == TRACE_HEADER_KEY:
key = _AWS_TRACE_HEADER

carrier[key] = {
"StringValue": value,
"DataType": "String",
Expand Down Expand Up @@ -199,13 +207,12 @@ def _create_processing_span(
receipt_handle: str,
message: Dict[str, Any],
) -> None:
message_attributes = message.get("MessageAttributes", {})
message_system_attributes = message.get("MessageSystemAttributes", {})
links = []
print(message_attributes)
ctx = AwsXRayPropagator().extract(
message_attributes, getter=boto3sqs_getter
message_system_attributes, getter=boto3sqs_getter
)
print(ctx)

parent_span_ctx = trace.get_current_span(ctx).get_span_context()
if parent_span_ctx.is_valid:
links.append(Link(context=parent_span_ctx))
Expand Down Expand Up @@ -240,9 +247,11 @@ def send_wrapper(wrapped, instance, args, kwargs):
end_on_exit=True,
) as span:
Boto3SQSInstrumentor._enrich_span(span, queue_name, queue_url)
attributes = kwargs.pop("MessageAttributes", {})
attributes = kwargs.pop("MessageSystemAttributes", {})
AwsXRayPropagator().inject(attributes, setter=boto3sqs_setter)
retval = wrapped(*args, MessageAttributes=attributes, **kwargs)
retval = wrapped(
*args, MessageSystemAttributes=attributes, **kwargs
)
message_id = retval.get("MessageId")
if message_id:
if span.is_recording():
Expand Down Expand Up @@ -279,10 +288,11 @@ def send_batch_wrapper(wrapped, instance, args, kwargs):
span, queue_name, queue_url, conversation_id=entry_id
)
with trace.use_span(span):
if "MessageAttributes" not in entry:
entry["MessageAttributes"] = {}
if "MessageSystemAttributes" not in entry:
entry["MessageSystemAttributes"] = {}
AwsXRayPropagator().inject(
entry["MessageAttributes"], setter=boto3sqs_setter
entry["MessageSystemAttributes"],
setter=boto3sqs_setter,
)
retval = wrapped(*args, **kwargs)
for successful_messages in retval["Successful"]:
Expand All @@ -305,10 +315,9 @@ def send_batch_wrapper(wrapped, instance, args, kwargs):
def _wrap_receive_message(self, sqs_class: type) -> None:
def receive_message_wrapper(wrapped, instance, args, kwargs):
queue_url = kwargs.get("QueueUrl")
message_attribute_names = kwargs.pop("MessageAttributeNames", [])
message_attribute_names.extend(
[TRACE_HEADER_KEY]
)
attribute_names = kwargs.pop("AttributeNames", [])
attribute_names.extend(_AWS_TRACE_HEADER)

queue_name = Boto3SQSInstrumentor._extract_queue_name_from_url(
queue_url
)
Expand All @@ -325,7 +334,7 @@ def receive_message_wrapper(wrapped, instance, args, kwargs):
)
retval = wrapped(
*args,
MessageAttributeNames=message_attribute_names,
AttributeNames=attribute_names,
**kwargs,
)
messages = retval.get("Messages", [])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from opentelemetry.trace import SpanKind
from opentelemetry.trace.span import Span, format_span_id, format_trace_id

_AWS_TRACE_HEADER = "AWSTraceHeader"


def _make_sqs_client():
return boto3.client(
Expand Down Expand Up @@ -169,7 +171,7 @@ def _mocked_endpoint(self, response):
yield

def _assert_injected_span(self, msg_attrs: Dict[str, Any], span: Span):
trace_parent = msg_attrs[TRACE_HEADER_KEY]["StringValue"]
trace_parent = msg_attrs[_AWS_TRACE_HEADER]["StringValue"]
ctx = span.get_span_context()
self.assertEqual(
self._to_trace_parent(ctx.trace_id, ctx.span_id),
Expand All @@ -187,7 +189,7 @@ def _default_span_attrs(self):
@staticmethod
def _to_trace_parent(trace_id: int, span_id: int) -> str:
formated_trace_id = format_trace_id(trace_id)
formated_trace_id = formated_trace_id[:8] + '-' + formated_trace_id[8:]
formated_trace_id = formated_trace_id[:8] + "-" + formated_trace_id[8:]
return f"root=1-{formated_trace_id};parent={format_span_id(span_id)};sampled=1".lower()

def _get_only_span(self):
Expand All @@ -204,13 +206,15 @@ def _make_message(message_id: str, body: str, receipt: str):
"Body": body,
"Attributes": {},
"MD5OfMessageAttributes": "111",
"MD5OfMessageSystemAttributes": "9012",
"MessageAttributes": {},
"MessageSystemAttributes": {},
}

def _add_xray_parent(
self, message: Dict[str, Any], trace_id: int, span_id: int
):
message["MessageAttributes"][TRACE_HEADER_KEY] = {
message["MessageSystemAttributes"][_AWS_TRACE_HEADER] = {
"StringValue": self._to_trace_parent(trace_id, span_id),
"DataType": "String",
}
Expand All @@ -226,12 +230,14 @@ def test_send_message(self):
}

message_attrs = {}
message_system_attrs = {}

with self._mocked_endpoint(mock_response):
self._client.send_message(
QueueUrl=self._queue_url,
MessageBody="hello msg",
MessageAttributes=message_attrs,
MessageSystemAttributes=message_system_attrs,
)

span = self._get_only_span()
Expand All @@ -244,7 +250,7 @@ def test_send_message(self):
},
span.attributes,
)
self._assert_injected_span(message_attrs, span)
self._assert_injected_span(message_system_attrs, span)

def test_receive_message(self):
msg_def = {
Expand All @@ -257,9 +263,7 @@ def test_receive_message(self):
message = self._make_message(
msg_id, f"hello {msg_id}", attrs["receipt"]
)
self._add_xray_parent(
message, attrs["trace_id"], attrs["span_id"]
)
self._add_xray_parent(message, attrs["trace_id"], attrs["span_id"])
mock_response["Messages"].append(message)

message_attr_names = []
Expand All @@ -270,8 +274,6 @@ def test_receive_message(self):
MessageAttributeNames=message_attr_names,
)

self.assertIn(TRACE_HEADER_KEY, message_attr_names)

# receive span
span = self._get_only_span()
self.assertEqual(f"{self._queue_name} receive", span.name)
Expand Down

0 comments on commit 5b86396

Please sign in to comment.