diff --git a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py index 9659e61038..df16268f51 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py @@ -338,6 +338,8 @@ async def __call__(self, scope, receive, send): if callable(self.server_request_hook): self.server_request_hook(span, scope) + server_span_context = trace.context_api.get_current() + @wraps(receive) async def wrapped_receive(): with self.tracer.start_as_current_span( @@ -371,7 +373,11 @@ async def wrapped_send(message): propagator = get_global_response_propagator() if propagator: - propagator.inject(message, setter=asgi_setter) + propagator.inject( + message, + context=server_span_context, + setter=asgi_setter, + ) await send(message) diff --git a/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py b/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py index e7eb418632..aa33c34894 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py @@ -303,15 +303,8 @@ def test_traceresponse_header(self): self.seed_app(app) self.send_default_request() - # traceresponse header corresponds to http.response.start span - span = self.memory_exporter.get_finished_spans()[1] - self.assertDictEqual( - dict(span.attributes), - { - SpanAttributes.HTTP_STATUS_CODE: 200, - "type": "http.response.start", - }, - ) + span = self.memory_exporter.get_finished_spans()[-1] + self.assertEqual(trace_api.SpanKind.SERVER, span.kind) response_start, response_body, *_ = self.get_all_output() self.assertEqual(response_body["body"], b"*") @@ -427,12 +420,8 @@ def test_websocket_traceresponse_header(self): self.send_input({"type": "websocket.disconnect"}) _, socket_send, *_ = self.get_all_output() - # traceresponse header corresponds to the 2nd websocket.send span - span = self.memory_exporter.get_finished_spans()[3] - self.assertDictEqual( - dict(span.attributes), - {SpanAttributes.HTTP_STATUS_CODE: 200, "type": "websocket.send"}, - ) + span = self.memory_exporter.get_finished_spans()[-1] + self.assertEqual(trace_api.SpanKind.SERVER, span.kind) traceresponse = "00-{0}-{1}-01".format( format_trace_id(span.get_span_context().trace_id),