Skip to content

Commit

Permalink
same logic for instrument()
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored and ocelotl committed Mar 22, 2024
1 parent 2abcb45 commit 52e0851
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def client_response_hook(span: Span, message: dict):
if span and span.is_recording():
span.set_attribute("custom_user_attribute_from_response_hook", "some-value")
FastAPIInstrumentor().instrument_app(server_request_hook=server_request_hook, client_request_hook=client_request_hook, client_response_hook=client_response_hook)
FastAPIInstrumentor().instrument(server_request_hook=server_request_hook, client_request_hook=client_request_hook, client_response_hook=client_response_hook)
Capture HTTP request and response headers
*****************************************
Expand Down Expand Up @@ -285,6 +285,15 @@ def _instrument(self, **kwargs):
_InstrumentedFastAPI._client_response_hook = kwargs.get(
"client_response_hook"
)
_InstrumentedFastAPI._http_capture_headers_server_request = kwargs.get(
"http_capture_headers_server_request"
)
_InstrumentedFastAPI._http_capture_headers_server_response = (
kwargs.get("http_capture_headers_server_response")
)
_InstrumentedFastAPI._http_capture_headers_sanitize_fields = (
kwargs.get("http_capture_headers_sanitize_fields")
)
_excluded_urls = kwargs.get("excluded_urls")
_InstrumentedFastAPI._excluded_urls = (
_excluded_urls_from_env
Expand Down Expand Up @@ -327,6 +336,9 @@ def __init__(self, *args, **kwargs):
client_response_hook=_InstrumentedFastAPI._client_response_hook,
tracer_provider=_InstrumentedFastAPI._tracer_provider,
meter=meter,
http_capture_headers_server_request=_InstrumentedFastAPI._http_capture_headers_server_request,
http_capture_headers_server_response=_InstrumentedFastAPI._http_capture_headers_server_response,
http_capture_headers_sanitize_fields=_InstrumentedFastAPI._http_capture_headers_sanitize_fields,
)
self._is_instrumented_by_opentelemetry = True
_InstrumentedFastAPI._instrumented_fastapi_apps.add(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,19 +698,21 @@ class TestHTTPAppWithCustomHeadersParameters(TestBase):

def setUp(self):
super().setUp()
self.app = self._create_app()
otel_fastapi.FastAPIInstrumentor().instrument_app(
self.app,
self.instrumentor = otel_fastapi.FastAPIInstrumentor()
self.kwargs = dict(
http_capture_headers_server_request=["a.*", "b.*"],
http_capture_headers_server_response=["c.*", "d.*"],
http_capture_headers_sanitize_fields=[".*secret.*"],
)
self.client = TestClient(self.app)
self.app = None

def tearDown(self) -> None:
super().tearDown()
with self.disable_logging():
otel_fastapi.FastAPIInstrumentor().uninstrument_app(self.app)
if self.app:
self.instrumentor.uninstrument_app(self.app)
else:
self.instrumentor.uninstrument()

@staticmethod
def _create_app():
Expand All @@ -728,8 +730,11 @@ async def _():

return app

def test_http_custom_request_headers_in_span_attributes(self):
resp = self.client.get(
def test_http_custom_request_headers_in_span_attributes_app(self):
self.app = self._create_app()
self.instrumentor.instrument_app(self.app, **self.kwargs)

resp = TestClient(self.app).get(
"/foobar",
headers={
"apple": "red",
Expand All @@ -755,8 +760,60 @@ def test_http_custom_request_headers_in_span_attributes(self):
self.assertSpanHasAttributes(server_span, expected)
self.assertNotIn("http.request.header.fig", server_span.attributes)

def test_http_custom_response_headers_in_span_attributes(self):
resp = self.client.get("/foobar")
def test_http_custom_request_headers_in_span_attributes_instr(self):
"""As above, but use instrument(), not instrument_app()."""
self.instrumentor.instrument(**self.kwargs)

resp = TestClient(self._create_app()).get(
"/foobar",
headers={
"apple": "red",
"banana-secret": "yellow",
"fig": "green",
},
)
self.assertEqual(200, resp.status_code)
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 3)

server_span = [
span for span in span_list if span.kind == trace.SpanKind.SERVER
][0]

expected = {
# apple should be included because it starts with a
"http.request.header.apple": ("red",),
# same with banana because it starts with b,
# redacted because it contains "secret"
"http.request.header.banana_secret": ("[REDACTED]",),
}
self.assertSpanHasAttributes(server_span, expected)
self.assertNotIn("http.request.header.fig", server_span.attributes)

def test_http_custom_response_headers_in_span_attributes_app(self):
self.app = self._create_app()
self.instrumentor.instrument_app(self.app, **self.kwargs)
resp = TestClient(self.app).get("/foobar")
self.assertEqual(200, resp.status_code)
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 3)

server_span = [
span for span in span_list if span.kind == trace.SpanKind.SERVER
][0]

expected = {
"http.response.header.carrot": ("bar",),
"http.response.header.date_secret": ("[REDACTED]",),
}
self.assertSpanHasAttributes(server_span, expected)
self.assertNotIn("http.response.header.egg", server_span.attributes)

def test_http_custom_response_headers_in_span_attributes_inst(self):
"""As above, but use instrument(), not instrument_app()."""
self.instrumentor.instrument(**self.kwargs)

resp = TestClient(self._create_app()).get("/foobar")
self.assertEqual(200, resp.status_code)
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 3)
Expand Down

0 comments on commit 52e0851

Please sign in to comment.