Skip to content

Commit

Permalink
green tests \o/
Browse files Browse the repository at this point in the history
  • Loading branch information
xrmx committed Oct 22, 2024
1 parent dd3d236 commit e7c8efe
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -733,8 +733,18 @@ def _instrument(self, **kwargs):
tracer_provider = kwargs.get("tracer_provider")
self._request_hook = kwargs.get("request_hook")
self._response_hook = kwargs.get("response_hook")
self._async_request_hook = kwargs.get("async_request_hook")
self._async_response_hook = kwargs.get("async_response_hook")
_async_request_hook = kwargs.get("async_request_hook")
self._async_request_hook = (
_async_request_hook
if iscoroutinefunction(_async_request_hook)
else None
)
_async_response_hook = kwargs.get("async_response_hook")
self._async_response_hook = (
_async_response_hook
if iscoroutinefunction(_async_response_hook)
else None
)

_OpenTelemetrySemanticConventionStability._initialize()
self._sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode(
Expand Down Expand Up @@ -826,7 +836,7 @@ def _handle_request_wrapper(self, wrapped, instance, args, kwargs):
span.set_attribute(
ERROR_TYPE, type(exception).__qualname__
)
raise exception.with_traceback(exception.__traceback__)
raise exception

return response

Expand Down Expand Up @@ -895,7 +905,7 @@ async def _handle_async_request_wrapper(
span.set_attribute(
ERROR_TYPE, type(exception).__qualname__
)
raise exception.with_traceback(exception.__traceback__)
raise exception

return response

Expand Down Expand Up @@ -927,6 +937,19 @@ def instrument_client(
)
return

# FIXME: sharing state in the instrumentor instance maybe it's not that great, need to pass tracer and semconv to each
# instance separately
_OpenTelemetrySemanticConventionStability._initialize()
self._sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode(
_OpenTelemetryStabilitySignalType.HTTP,
)
self._tracer = get_tracer(
__name__,
instrumenting_library_version=__version__,
tracer_provider=tracer_provider,
schema_url=_get_schema_url(self._sem_conv_opt_in_mode),
)

if iscoroutinefunction(request_hook):
self._async_request_hook = request_hook
self._request_hook = None
Expand All @@ -947,13 +970,27 @@ def instrument_client(
"handle_request",
self._handle_request_wrapper,
)
for transport in client._mounts.values():
# FIXME: check it's not wrapped already?
wrap_function_wrapper(
transport,
"handle_request",
self._handle_request_wrapper,
)
client._is_instrumented_by_opentelemetry = True
if hasattr(client._transport, "handle_async_request"):
wrap_function_wrapper(
client._transport,
"handle_async_request",
self._handle_async_request_wrapper,
)
for transport in client._mounts.values():
# FIXME: check it's not wrapped already?
wrap_function_wrapper(
transport,
"handle_async_request",
self._handle_async_request_wrapper,
)
client._is_instrumented_by_opentelemetry = True

@staticmethod
Expand All @@ -967,7 +1004,11 @@ def uninstrument_client(
"""
if hasattr(client._transport, "handle_request"):
unwrap(client._transport, "handle_request")
for transport in client._mounts.values():
unwrap(transport, "handle_request")
client._is_instrumented_by_opentelemetry = False
elif hasattr(client._transport, "handle_async_request"):
unwrap(client._transport, "handle_async_request")
for transport in client._mounts.values():
unwrap(transport, "handle_async_request")
client._is_instrumented_by_opentelemetry = False
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,6 @@ def setUp(self):
)
)

HTTPXClientInstrumentor().instrument()

def print_spans(self, spans):
for span in spans:
print(span.name, span.attributes)
Expand Down Expand Up @@ -751,8 +749,9 @@ def create_proxy_transport(self, url: str):

def setUp(self):
super().setUp()
HTTPXClientInstrumentor().instrument()
self.client = self.create_client()
# FIXME: calling instrument() instead fixes 13*2 tests :(
HTTPXClientInstrumentor().instrument_client(self.client)

def tearDown(self):
HTTPXClientInstrumentor().uninstrument()
Expand Down Expand Up @@ -792,7 +791,6 @@ def test_custom_tracer_provider(self):
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result

HTTPXClientInstrumentor().uninstrument()
HTTPXClientInstrumentor().instrument(
tracer_provider=tracer_provider
)
Expand All @@ -802,7 +800,6 @@ def test_custom_tracer_provider(self):
self.assertEqual(result.text, "Hello!")
span = self.assert_span(exporter=exporter)
self.assertIs(span.resource, resource)
HTTPXClientInstrumentor().uninstrument()

def test_response_hook(self):
response_hook_key = (
Expand All @@ -811,7 +808,6 @@ def test_response_hook(self):
else "response_hook"
)
response_hook_kwargs = {response_hook_key: self.response_hook}
HTTPXClientInstrumentor().uninstrument()
HTTPXClientInstrumentor().instrument(
tracer_provider=self.tracer_provider,
**response_hook_kwargs,
Expand All @@ -830,10 +826,8 @@ def test_response_hook(self):
HTTP_RESPONSE_BODY: "Hello!",
},
)
HTTPXClientInstrumentor().uninstrument()

def test_response_hook_sync_async_kwargs(self):
HTTPXClientInstrumentor().uninstrument()
HTTPXClientInstrumentor().instrument(
tracer_provider=self.tracer_provider,
response_hook=_response_hook,
Expand All @@ -845,15 +839,14 @@ def test_response_hook_sync_async_kwargs(self):
self.assertEqual(result.text, "Hello!")
span = self.assert_span()
self.assertEqual(
dict(span.attributes),
span.attributes,
{
SpanAttributes.HTTP_METHOD: "GET",
SpanAttributes.HTTP_URL: self.URL,
SpanAttributes.HTTP_STATUS_CODE: 200,
HTTP_RESPONSE_BODY: "Hello!",
},
)
HTTPXClientInstrumentor().uninstrument()

def test_request_hook(self):
request_hook_key = (
Expand All @@ -862,7 +855,6 @@ def test_request_hook(self):
else "request_hook"
)
request_hook_kwargs = {request_hook_key: self.request_hook}
HTTPXClientInstrumentor().uninstrument()
HTTPXClientInstrumentor().instrument(
tracer_provider=self.tracer_provider,
**request_hook_kwargs,
Expand All @@ -873,10 +865,8 @@ def test_request_hook(self):
self.assertEqual(result.text, "Hello!")
span = self.assert_span()
self.assertEqual(span.name, "GET" + self.URL)
HTTPXClientInstrumentor().uninstrument()

def test_request_hook_sync_async_kwargs(self):
HTTPXClientInstrumentor().uninstrument()
HTTPXClientInstrumentor().instrument(
tracer_provider=self.tracer_provider,
request_hook=_request_hook,
Expand All @@ -888,10 +878,8 @@ def test_request_hook_sync_async_kwargs(self):
self.assertEqual(result.text, "Hello!")
span = self.assert_span()
self.assertEqual(span.name, "GET" + self.URL)
HTTPXClientInstrumentor().uninstrument()

def test_request_hook_no_span_update(self):
HTTPXClientInstrumentor().uninstrument()
HTTPXClientInstrumentor().instrument(
tracer_provider=self.tracer_provider,
request_hook=self.no_update_request_hook,
Expand All @@ -902,10 +890,8 @@ def test_request_hook_no_span_update(self):
self.assertEqual(result.text, "Hello!")
span = self.assert_span()
self.assertEqual(span.name, "GET")
HTTPXClientInstrumentor().uninstrument()

def test_not_recording(self):
HTTPXClientInstrumentor().uninstrument()
with mock.patch("opentelemetry.trace.INVALID_SPAN") as mock_span:
HTTPXClientInstrumentor().instrument(
tracer_provider=trace.NoOpTracerProvider()
Expand All @@ -921,28 +907,26 @@ def test_not_recording(self):
self.assertTrue(mock_span.is_recording.called)
self.assertFalse(mock_span.set_attribute.called)
self.assertFalse(mock_span.set_status.called)
HTTPXClientInstrumentor().uninstrument()

def test_suppress_instrumentation_new_client(self):
HTTPXClientInstrumentor().uninstrument()
HTTPXClientInstrumentor().instrument()
with suppress_http_instrumentation():
client = self.create_client()
result = self.perform_request(self.URL, client=client)
self.assertEqual(result.text, "Hello!")

self.assert_span(num_spans=0)
HTTPXClientInstrumentor().uninstrument()

def test_instrument_client(self):
HTTPXClientInstrumentor().uninstrument()
client = self.create_client()
HTTPXClientInstrumentor().instrument_client(client)
result = self.perform_request(self.URL, client=client)
self.assertEqual(result.text, "Hello!")
self.assert_span(num_spans=1)

def test_instrumentation_without_client(self):

HTTPXClientInstrumentor().instrument()
results = [
httpx.get(self.URL),
httpx.request("GET", self.URL),
Expand All @@ -961,6 +945,7 @@ def test_instrumentation_without_client(self):
)

def test_uninstrument(self):
HTTPXClientInstrumentor().instrument()
HTTPXClientInstrumentor().uninstrument()
client = self.create_client()
result = self.perform_request(self.URL, client=client)
Expand All @@ -970,7 +955,6 @@ def test_uninstrument(self):
self.assert_span(num_spans=0)

def test_uninstrument_client(self):
HTTPXClientInstrumentor().uninstrument()
HTTPXClientInstrumentor().uninstrument_client(self.client)

result = self.perform_request(self.URL)
Expand All @@ -979,6 +963,7 @@ def test_uninstrument_client(self):
self.assert_span(num_spans=0)

def test_uninstrument_new_client(self):
HTTPXClientInstrumentor().instrument()
client1 = self.create_client()
HTTPXClientInstrumentor().uninstrument_client(client1)

Expand All @@ -1001,6 +986,7 @@ def test_uninstrument_new_client(self):

def test_instrument_proxy(self):
proxy_mounts = self.create_proxy_mounts()
HTTPXClientInstrumentor().instrument()
client = self.create_client(mounts=proxy_mounts)
self.perform_request(self.URL, client=client)
self.assert_span(num_spans=1)
Expand All @@ -1027,7 +1013,6 @@ def print_handler(self, client):
return handler

def test_instrument_client_with_proxy(self):
HTTPXClientInstrumentor().uninstrument()
proxy_mounts = self.create_proxy_mounts()
client = self.create_client(mounts=proxy_mounts)
self.assert_proxy_mounts(
Expand All @@ -1047,6 +1032,7 @@ def test_instrument_client_with_proxy(self):

def test_uninstrument_client_with_proxy(self):
proxy_mounts = self.create_proxy_mounts()
HTTPXClientInstrumentor().instrument()
client = self.create_client(mounts=proxy_mounts)
self.assert_proxy_mounts(
client._mounts.values(),
Expand Down Expand Up @@ -1109,7 +1095,7 @@ def create_client(
transport: typing.Optional[SyncOpenTelemetryTransport] = None,
**kwargs,
):
return httpx.Client(**kwargs)
return httpx.Client(transport=transport, **kwargs)

def perform_request(
self,
Expand Down Expand Up @@ -1230,6 +1216,7 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest):
def setUp(self):
super().setUp()
self.client2 = self.create_client()
HTTPXClientInstrumentor().instrument_client(self.client2)

def create_client(
self,
Expand Down Expand Up @@ -1283,7 +1270,6 @@ def test_async_response_hook_does_nothing_if_not_coroutine(self):
SpanAttributes.HTTP_STATUS_CODE: 200,
},
)
HTTPXClientInstrumentor().uninstrument()

def test_async_request_hook_does_nothing_if_not_coroutine(self):
HTTPXClientInstrumentor().instrument(
Expand All @@ -1296,4 +1282,3 @@ def test_async_request_hook_does_nothing_if_not_coroutine(self):
self.assertEqual(result.text, "Hello!")
span = self.assert_span()
self.assertEqual(span.name, "GET")
HTTPXClientInstrumentor().uninstrument()

0 comments on commit e7c8efe

Please sign in to comment.