From 3e185602281c8594ff53f9d667e3813b93ba79c4 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 1 Aug 2024 23:15:34 +0100 Subject: [PATCH] feat: add header extraction parameters to `FastAPIInstrumentor().instrument_app` (#2241) * add header parameters to FastAPIInstrumentor().instrument_app * add changelog * move #2072 changelog to unreleased * remove accidental pprint * linting * future annotations for fastapi * same logic for instrument() * Fix lint * Fix test case --------- Co-authored-by: Diego Hurtado --- CHANGELOG.md | 8 +- .../instrumentation/fastapi/__init__.py | 39 +- .../tests/test_fastapi_instrumentation.py | 534 +++++++++++++++++- 3 files changed, 571 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0984efa3de..f3d9b5e40a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -216,6 +216,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#2367](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2367)) +### Added +- `opentelemetry-instrumentation-fastapi` Add support for configuring header extraction via runtime constructor parameters + ([#2241](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2241)) + ## Version 1.23.0/0.44b0 (2024-02-23) - Drop support for 3.7 @@ -236,6 +240,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `opentelemetry-instrumentation-psycopg` Initial release for psycopg 3.x +- `opentelemetry-instrumentation-asgi` Add support for configuring ASGI middleware header extraction via runtime constructor parameters + ([#2026](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2026)) ## Version 1.22.0/0.43b0 (2023-12-14) @@ -275,8 +281,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#1948](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1948)) - Added schema_url (`"https://opentelemetry.io/schemas/1.11.0"`) to all metrics and traces ([#1977](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1977)) -- Add support for configuring ASGI middleware header extraction via runtime constructor parameters - ([#2026](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2026)) ### Fixed diff --git a/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py b/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py index bfb4f4a682..fdb035baa8 100644 --- a/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py @@ -86,9 +86,10 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A Request headers *************** To capture HTTP request headers as span attributes, set the environment variable -``OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST`` to a comma delimited list of HTTP header names. +``OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST`` to a comma delimited list of HTTP header names, +or pass the ``http_capture_headers_server_request`` keyword argument to the ``instrument_app`` method. -For example, +For example using the environment variable, :: export OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST="content-type,custom_request_header" @@ -120,9 +121,10 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A Response headers **************** To capture HTTP response headers as span attributes, set the environment variable -``OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE`` to a comma delimited list of HTTP header names. +``OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE`` to a comma delimited list of HTTP header names, +or pass the ``http_capture_headers_server_response`` keyword argument to the ``instrument_app`` method. -For example, +For example using the environment variable, :: export OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE="content-type,custom_response_header" @@ -155,10 +157,12 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A ****************** In order to prevent storing sensitive data such as personally identifiable information (PII), session keys, passwords, etc, set the environment variable ``OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS`` -to a comma delimited list of HTTP header names to be sanitized. Regexes may be used, and all header names will be -matched in a case-insensitive manner. +to a comma delimited list of HTTP header names to be sanitized, or pass the ``http_capture_headers_sanitize_fields`` +keyword argument to the ``instrument_app`` method. -For example, +Regexes may be used, and all header names will be matched in a case-insensitive manner. + +For example using the environment variable, :: export OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS=".*session.*,set-cookie" @@ -171,6 +175,9 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A API --- """ + +from __future__ import annotations + import logging from importlib.metadata import PackageNotFoundError, distribution from typing import Collection @@ -227,6 +234,9 @@ def instrument_app( tracer_provider=None, meter_provider=None, excluded_urls=None, + http_capture_headers_server_request: list[str] | None = None, + http_capture_headers_server_response: list[str] | None = None, + http_capture_headers_sanitize_fields: list[str] | None = None, ): """Instrument an uninstrumented FastAPI application.""" if not hasattr(app, "_is_instrumented_by_opentelemetry"): @@ -265,6 +275,9 @@ def instrument_app( # Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation tracer=tracer, meter=meter, + http_capture_headers_server_request=http_capture_headers_server_request, + http_capture_headers_server_response=http_capture_headers_server_response, + http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields, ) app._is_instrumented_by_opentelemetry = True if app not in _InstrumentedFastAPI._instrumented_fastapi_apps: @@ -314,6 +327,15 @@ def _instrument(self, **kwargs): _InstrumentedFastAPI._client_response_hook = kwargs.get( "client_response_hook" ) + _InstrumentedFastAPI._http_capture_headers_server_request = kwargs.get( + "http_capture_headers_server_request" + ) + _InstrumentedFastAPI._http_capture_headers_server_response = ( + kwargs.get("http_capture_headers_server_response") + ) + _InstrumentedFastAPI._http_capture_headers_sanitize_fields = ( + kwargs.get("http_capture_headers_sanitize_fields") + ) _excluded_urls = kwargs.get("excluded_urls") _InstrumentedFastAPI._excluded_urls = ( _excluded_urls_from_env @@ -368,6 +390,9 @@ def __init__(self, *args, **kwargs): # Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation tracer=tracer, meter=meter, + http_capture_headers_server_request=_InstrumentedFastAPI._http_capture_headers_server_request, + http_capture_headers_server_response=_InstrumentedFastAPI._http_capture_headers_server_response, + http_capture_headers_sanitize_fields=_InstrumentedFastAPI._http_capture_headers_sanitize_fields, ) self._is_instrumented_by_opentelemetry = True _InstrumentedFastAPI._instrumented_fastapi_apps.add(self) diff --git a/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py b/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py index d233331283..03fdd6749d 100644 --- a/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py +++ b/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py @@ -20,9 +20,11 @@ import fastapi from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware +from fastapi.responses import JSONResponse from fastapi.testclient import TestClient import opentelemetry.instrumentation.fastapi as otel_fastapi +from opentelemetry import trace from opentelemetry.instrumentation._semconv import ( OTEL_SEMCONV_STABILITY_OPT_IN, _OpenTelemetrySemanticConventionStability, @@ -47,8 +49,14 @@ ) from opentelemetry.semconv.attributes.url_attributes import URL_SCHEME from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.test.globals_test import reset_trace_globals from opentelemetry.test.test_base import TestBase -from opentelemetry.util.http import get_excluded_urls +from opentelemetry.util.http import ( + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS, + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST, + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE, + get_excluded_urls, +) _expected_metric_names_old = [ "http.server.active_requests", @@ -1227,3 +1235,527 @@ def test_instrumentation(self): should_be_original = fastapi.FastAPI self.assertIs(original, should_be_original) + + +class TestWrappedApplication(TestBase): + def setUp(self): + super().setUp() + + self.app = fastapi.FastAPI() + + @self.app.get("/foobar") + async def _(): + return {"message": "hello world"} + + otel_fastapi.FastAPIInstrumentor().instrument_app(self.app) + self.client = TestClient(self.app) + self.tracer = self.tracer_provider.get_tracer(__name__) + + def tearDown(self) -> None: + super().tearDown() + with self.disable_logging(): + otel_fastapi.FastAPIInstrumentor().uninstrument_app(self.app) + + def test_mark_span_internal_in_presence_of_span_from_other_framework(self): + with self.tracer.start_as_current_span( + "test", kind=trace.SpanKind.SERVER + ) as parent_span: + resp = self.client.get("/foobar") + self.assertEqual(200, resp.status_code) + + span_list = self.memory_exporter.get_finished_spans() + for span in span_list: + print(str(span.__class__) + ": " + str(span.__dict__)) + + # there should be 4 spans - single SERVER "test" and three INTERNAL "FastAPI" + self.assertEqual(trace.SpanKind.INTERNAL, span_list[0].kind) + self.assertEqual(trace.SpanKind.INTERNAL, span_list[1].kind) + # main INTERNAL span - child of test + self.assertEqual(trace.SpanKind.INTERNAL, span_list[2].kind) + self.assertEqual( + parent_span.context.span_id, span_list[2].parent.span_id + ) + # SERVER "test" + self.assertEqual(trace.SpanKind.SERVER, span_list[3].kind) + self.assertEqual( + parent_span.context.span_id, span_list[3].context.span_id + ) + + +@patch.dict( + "os.environ", + { + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS: ".*my-secret.*", + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,Regex-Test-Header-.*,Regex-Invalid-Test-Header-.*,.*my-secret.*", + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,my-custom-regex-header-.*,invalid-regex-header-.*,.*my-secret.*", + }, +) +class TestHTTPAppWithCustomHeaders(TestBase): + def setUp(self): + super().setUp() + self.app = self._create_app() + otel_fastapi.FastAPIInstrumentor().instrument_app(self.app) + self.client = TestClient(self.app) + + def tearDown(self) -> None: + super().tearDown() + with self.disable_logging(): + otel_fastapi.FastAPIInstrumentor().uninstrument_app(self.app) + + @staticmethod + def _create_app(): + app = fastapi.FastAPI() + + @app.get("/foobar") + async def _(): + headers = { + "custom-test-header-1": "test-header-value-1", + "custom-test-header-2": "test-header-value-2", + "my-custom-regex-header-1": "my-custom-regex-value-1,my-custom-regex-value-2", + "My-Custom-Regex-Header-2": "my-custom-regex-value-3,my-custom-regex-value-4", + "My-Secret-Header": "My Secret Value", + } + content = {"message": "hello world"} + return JSONResponse(content=content, headers=headers) + + return app + + def test_http_custom_request_headers_in_span_attributes(self): + expected = { + "http.request.header.custom_test_header_1": ( + "test-header-value-1", + ), + "http.request.header.custom_test_header_2": ( + "test-header-value-2", + ), + "http.request.header.regex_test_header_1": ("Regex Test Value 1",), + "http.request.header.regex_test_header_2": ( + "RegexTestValue2,RegexTestValue3", + ), + "http.request.header.my_secret_header": ("[REDACTED]",), + } + resp = self.client.get( + "/foobar", + headers={ + "custom-test-header-1": "test-header-value-1", + "custom-test-header-2": "test-header-value-2", + "Regex-Test-Header-1": "Regex Test Value 1", + "regex-test-header-2": "RegexTestValue2,RegexTestValue3", + "My-Secret-Header": "My Secret Value", + }, + ) + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 3) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + self.assertSpanHasAttributes(server_span, expected) + + def test_http_custom_request_headers_not_in_span_attributes(self): + not_expected = { + "http.request.header.custom_test_header_3": ( + "test-header-value-3", + ), + } + resp = self.client.get( + "/foobar", + headers={ + "custom-test-header-1": "test-header-value-1", + "custom-test-header-2": "test-header-value-2", + "Regex-Test-Header-1": "Regex Test Value 1", + "regex-test-header-2": "RegexTestValue2,RegexTestValue3", + "My-Secret-Header": "My Secret Value", + }, + ) + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 3) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + for key, _ in not_expected.items(): + self.assertNotIn(key, server_span.attributes) + + def test_http_custom_response_headers_in_span_attributes(self): + expected = { + "http.response.header.custom_test_header_1": ( + "test-header-value-1", + ), + "http.response.header.custom_test_header_2": ( + "test-header-value-2", + ), + "http.response.header.my_custom_regex_header_1": ( + "my-custom-regex-value-1,my-custom-regex-value-2", + ), + "http.response.header.my_custom_regex_header_2": ( + "my-custom-regex-value-3,my-custom-regex-value-4", + ), + "http.response.header.my_secret_header": ("[REDACTED]",), + } + resp = self.client.get("/foobar") + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 3) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + self.assertSpanHasAttributes(server_span, expected) + + def test_http_custom_response_headers_not_in_span_attributes(self): + not_expected = { + "http.response.header.custom_test_header_3": ( + "test-header-value-3", + ), + } + resp = self.client.get("/foobar") + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 3) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + for key, _ in not_expected.items(): + self.assertNotIn(key, server_span.attributes) + + +class TestHTTPAppWithCustomHeadersParameters(TestBase): + """Minimal tests here since the behavior of this logic is tested above and in the ASGI tests.""" + + def setUp(self): + super().setUp() + self.instrumentor = otel_fastapi.FastAPIInstrumentor() + self.kwargs = { + "http_capture_headers_server_request": ["a.*", "b.*"], + "http_capture_headers_server_response": ["c.*", "d.*"], + "http_capture_headers_sanitize_fields": [".*secret.*"], + } + self.app = None + + def tearDown(self) -> None: + super().tearDown() + with self.disable_logging(): + if self.app: + self.instrumentor.uninstrument_app(self.app) + else: + self.instrumentor.uninstrument() + + @staticmethod + def _create_app(): + app = fastapi.FastAPI() + + @app.get("/foobar") + async def _(): + headers = { + "carrot": "bar", + "date-secret": "yellow", + "egg": "ham", + } + content = {"message": "hello world"} + return JSONResponse(content=content, headers=headers) + + return app + + def test_http_custom_request_headers_in_span_attributes_app(self): + self.app = self._create_app() + self.instrumentor.instrument_app(self.app, **self.kwargs) + + resp = TestClient(self.app).get( + "/foobar", + headers={ + "apple": "red", + "banana-secret": "yellow", + "fig": "green", + }, + ) + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 3) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + expected = { + # apple should be included because it starts with a + "http.request.header.apple": ("red",), + # same with banana because it starts with b, + # redacted because it contains "secret" + "http.request.header.banana_secret": ("[REDACTED]",), + } + self.assertSpanHasAttributes(server_span, expected) + self.assertNotIn("http.request.header.fig", server_span.attributes) + + def test_http_custom_request_headers_in_span_attributes_instr(self): + """As above, but use instrument(), not instrument_app().""" + self.instrumentor.instrument(**self.kwargs) + + resp = TestClient(self._create_app()).get( + "/foobar", + headers={ + "apple": "red", + "banana-secret": "yellow", + "fig": "green", + }, + ) + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 3) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + expected = { + # apple should be included because it starts with a + "http.request.header.apple": ("red",), + # same with banana because it starts with b, + # redacted because it contains "secret" + "http.request.header.banana_secret": ("[REDACTED]",), + } + self.assertSpanHasAttributes(server_span, expected) + self.assertNotIn("http.request.header.fig", server_span.attributes) + + def test_http_custom_response_headers_in_span_attributes_app(self): + self.app = self._create_app() + self.instrumentor.instrument_app(self.app, **self.kwargs) + resp = TestClient(self.app).get("/foobar") + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 3) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + expected = { + "http.response.header.carrot": ("bar",), + "http.response.header.date_secret": ("[REDACTED]",), + } + self.assertSpanHasAttributes(server_span, expected) + self.assertNotIn("http.response.header.egg", server_span.attributes) + + def test_http_custom_response_headers_in_span_attributes_inst(self): + """As above, but use instrument(), not instrument_app().""" + self.instrumentor.instrument(**self.kwargs) + + resp = TestClient(self._create_app()).get("/foobar") + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 3) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + expected = { + "http.response.header.carrot": ("bar",), + "http.response.header.date_secret": ("[REDACTED]",), + } + self.assertSpanHasAttributes(server_span, expected) + self.assertNotIn("http.response.header.egg", server_span.attributes) + + +@patch.dict( + "os.environ", + { + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS: ".*my-secret.*", + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,Regex-Test-Header-.*,Regex-Invalid-Test-Header-.*,.*my-secret.*", + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,my-custom-regex-header-.*,invalid-regex-header-.*,.*my-secret.*", + }, +) +class TestWebSocketAppWithCustomHeaders(TestBase): + def setUp(self): + super().setUp() + self.app = self._create_app() + otel_fastapi.FastAPIInstrumentor().instrument_app(self.app) + self.client = TestClient(self.app) + + def tearDown(self) -> None: + super().tearDown() + with self.disable_logging(): + otel_fastapi.FastAPIInstrumentor().uninstrument_app(self.app) + + @staticmethod + def _create_app(): + app = fastapi.FastAPI() + + @app.websocket("/foobar_web") + async def _(websocket: fastapi.WebSocket): + message = await websocket.receive() + if message.get("type") == "websocket.connect": + await websocket.send( + { + "type": "websocket.accept", + "headers": [ + (b"custom-test-header-1", b"test-header-value-1"), + (b"custom-test-header-2", b"test-header-value-2"), + (b"Regex-Test-Header-1", b"Regex Test Value 1"), + ( + b"regex-test-header-2", + b"RegexTestValue2,RegexTestValue3", + ), + (b"My-Secret-Header", b"My Secret Value"), + ], + } + ) + await websocket.send_json({"message": "hello world"}) + await websocket.close() + if message.get("type") == "websocket.disconnect": + pass + + return app + + def test_web_socket_custom_request_headers_in_span_attributes(self): + expected = { + "http.request.header.custom_test_header_1": ( + "test-header-value-1", + ), + "http.request.header.custom_test_header_2": ( + "test-header-value-2", + ), + } + + with self.client.websocket_connect( + "/foobar_web", + headers={ + "custom-test-header-1": "test-header-value-1", + "custom-test-header-2": "test-header-value-2", + }, + ) as websocket: + data = websocket.receive_json() + self.assertEqual(data, {"message": "hello world"}) + + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 5) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + self.assertSpanHasAttributes(server_span, expected) + + @patch.dict( + "os.environ", + { + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS: ".*my-secret.*", + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,Regex-Test-Header-.*,Regex-Invalid-Test-Header-.*,.*my-secret.*", + }, + ) + def test_web_socket_custom_request_headers_not_in_span_attributes(self): + not_expected = { + "http.request.header.custom_test_header_3": ( + "test-header-value-3", + ), + } + + with self.client.websocket_connect( + "/foobar_web", + headers={ + "custom-test-header-1": "test-header-value-1", + "custom-test-header-2": "test-header-value-2", + }, + ) as websocket: + data = websocket.receive_json() + self.assertEqual(data, {"message": "hello world"}) + + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 5) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + for key, _ in not_expected.items(): + self.assertNotIn(key, server_span.attributes) + + def test_web_socket_custom_response_headers_in_span_attributes(self): + expected = { + "http.response.header.custom_test_header_1": ( + "test-header-value-1", + ), + "http.response.header.custom_test_header_2": ( + "test-header-value-2", + ), + } + + with self.client.websocket_connect("/foobar_web") as websocket: + data = websocket.receive_json() + self.assertEqual(data, {"message": "hello world"}) + + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 5) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + self.assertSpanHasAttributes(server_span, expected) + + def test_web_socket_custom_response_headers_not_in_span_attributes(self): + not_expected = { + "http.response.header.custom_test_header_3": ( + "test-header-value-3", + ), + } + + with self.client.websocket_connect("/foobar_web") as websocket: + data = websocket.receive_json() + self.assertEqual(data, {"message": "hello world"}) + + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 5) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + for key, _ in not_expected.items(): + self.assertNotIn(key, server_span.attributes) + + +@patch.dict( + "os.environ", + { + OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3", + }, +) +class TestNonRecordingSpanWithCustomHeaders(TestBase): + def setUp(self): + super().setUp() + self.app = fastapi.FastAPI() + + @self.app.get("/foobar") + async def _(): + return {"message": "hello world"} + + reset_trace_globals() + tracer_provider = trace.NoOpTracerProvider() + trace.set_tracer_provider(tracer_provider=tracer_provider) + + self._instrumentor = otel_fastapi.FastAPIInstrumentor() + self._instrumentor.instrument_app(self.app) + self.client = TestClient(self.app) + + def tearDown(self) -> None: + super().tearDown() + with self.disable_logging(): + self._instrumentor.uninstrument_app(self.app) + + def test_custom_header_not_present_in_non_recording_span(self): + resp = self.client.get( + "/foobar", + headers={ + "custom-test-header-1": "test-header-value-1", + }, + ) + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 0)