diff --git a/ext/opentelemetry-ext-requests/src/opentelemetry/ext/requests/__init__.py b/ext/opentelemetry-ext-requests/src/opentelemetry/ext/requests/__init__.py index c98a24cc885..2a82e9820cb 100644 --- a/ext/opentelemetry-ext-requests/src/opentelemetry/ext/requests/__init__.py +++ b/ext/opentelemetry-ext-requests/src/opentelemetry/ext/requests/__init__.py @@ -55,7 +55,7 @@ # pylint: disable=unused-argument -def _instrument(tracer_provider=None): +def _instrument(tracer_provider=None, span_callback=None): """Enables tracing of all requests calls that go through :code:`requests.session.Session.request` (this includes :code:`requests.get`, etc.).""" @@ -101,6 +101,8 @@ def instrumented_request(self, method, url, *args, **kwargs): span.set_status( Status(_http_status_to_canonical_code(result.status_code)) ) + if span_callback is not None: + span_callback(span, result) return result @@ -156,8 +158,22 @@ def _http_status_to_canonical_code(code: int, allow_redirect: bool = True): class RequestsInstrumentor(BaseInstrumentor): + """An instrumentor for requests + See `BaseInstrumentor` + """ + def _instrument(self, **kwargs): - _instrument(tracer_provider=kwargs.get("tracer_provider")) + """Instruments requests module + + Args: + **kwargs: Optional arguments + ``tracer_provider``: a TracerProvider, defaults to global + ``span_callback``: An optional callback invoked before returning the http response. Invoked with Span and requests.Response + """ + _instrument( + tracer_provider=kwargs.get("tracer_provider"), + span_callback=kwargs.get("span_callback"), + ) def _uninstrument(self, **kwargs): _uninstrument() diff --git a/ext/opentelemetry-ext-requests/tests/test_requests_integration.py b/ext/opentelemetry-ext-requests/tests/test_requests_integration.py index 7764aad3ec5..28359d8f38a 100644 --- a/ext/opentelemetry-ext-requests/tests/test_requests_integration.py +++ b/ext/opentelemetry-ext-requests/tests/test_requests_integration.py @@ -183,6 +183,37 @@ def test_distributed_context(self): finally: propagators.set_global_httptextformat(previous_propagator) + def test_span_callback(self): + RequestsInstrumentor().uninstrument() + + def span_callback(span, result: requests.Response): + span.set_attribute( + "http.response.body", result.content.decode("utf-8") + ) + + RequestsInstrumentor().instrument( + tracer_provider=self.tracer_provider, span_callback=span_callback, + ) + + result = requests.get(self.URL) + self.assertEqual(result.text, "Hello!") + + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 1) + span = span_list[0] + + self.assertEqual( + span.attributes, + { + "component": "http", + "http.method": "GET", + "http.url": self.URL, + "http.status_code": 200, + "http.status_text": "OK", + "http.response.body": "Hello!", + }, + ) + def test_custom_tracer_provider(self): resource = resources.Resource.create({}) result = self.create_tracer_provider(resource=resource)