From 1e904daaaff179d60449a29ff9a7013091c1fcfc Mon Sep 17 00:00:00 2001 From: Erle Carrara Date: Wed, 16 Oct 2024 23:21:19 -0300 Subject: [PATCH] Support functools.partial functions in AsyncioInstrumentor.trace_to_thread Change `trace_to_thread` method to retrieve the function name from the `func` attribute of the `functools.partial` instance. > partial objects are like function objects in that they are callable, weak > referenceable, and can have attributes. There are some important differences. > For instance, the __name__ and function.__doc__ attributes are not created > automatically. Also, partial objects defined in classes behave like static > methods and do not transform into bound methods during instance attribute > look-up. > > Reference: https://docs.python.org/3.12/library/functools.html#partial-objects --- .../instrumentation/asyncio/__init__.py | 10 ++++-- .../tests/test_asyncio_to_thread.py | 35 +++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-asyncio/src/opentelemetry/instrumentation/asyncio/__init__.py b/instrumentation/opentelemetry-instrumentation-asyncio/src/opentelemetry/instrumentation/asyncio/__init__.py index e83f384a8c..eafc1133c1 100644 --- a/instrumentation/opentelemetry-instrumentation-asyncio/src/opentelemetry/instrumentation/asyncio/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asyncio/src/opentelemetry/instrumentation/asyncio/__init__.py @@ -77,6 +77,7 @@ def func(): --- """ import asyncio +import functools import sys from asyncio import futures from timeit import default_timer @@ -231,14 +232,17 @@ def wrap_taskgroup_create_task(method, instance, args, kwargs) -> None: def trace_to_thread(self, func: callable): """Trace a function.""" start = default_timer() + func_name = getattr(func, '__name__', None) + if func_name is None and isinstance(func, functools.partial): + func_name = func.func.__name__ span = ( self._tracer.start_span( - f"{ASYNCIO_PREFIX} to_thread-" + func.__name__ + f"{ASYNCIO_PREFIX} to_thread-" + func_name ) - if func.__name__ in self._to_thread_name_to_trace + if func_name in self._to_thread_name_to_trace else None ) - attr = {"type": "to_thread", "name": func.__name__} + attr = {"type": "to_thread", "name": func_name} exception = None try: attr["state"] = "finished" diff --git a/instrumentation/opentelemetry-instrumentation-asyncio/tests/test_asyncio_to_thread.py b/instrumentation/opentelemetry-instrumentation-asyncio/tests/test_asyncio_to_thread.py index 3d795d8ae7..35191d3d03 100644 --- a/instrumentation/opentelemetry-instrumentation-asyncio/tests/test_asyncio_to_thread.py +++ b/instrumentation/opentelemetry-instrumentation-asyncio/tests/test_asyncio_to_thread.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +import functools import sys from unittest import skipIf from unittest.mock import patch @@ -72,3 +73,37 @@ async def to_thread(): for point in metric.data.data_points: self.assertEqual(point.attributes["type"], "to_thread") self.assertEqual(point.attributes["name"], "multiply") + + @skipIf( + sys.version_info < (3, 9), "to_thread is only available in Python 3.9+" + ) + def test_to_thread_partial_func(self): + def multiply(x, y): + return x * y + + double = functools.partial(multiply, 2) + + async def to_thread(): + result = await asyncio.to_thread(double, 3) + assert result == 6 + + with self._tracer.start_as_current_span("root"): + asyncio.run(to_thread()) + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 2) + assert spans[0].name == "asyncio to_thread-multiply" + for metric in ( + self.memory_metrics_reader.get_metrics_data() + .resource_metrics[0] + .scope_metrics[0] + .metrics + ): + if metric.name == "asyncio.process.duration": + for point in metric.data.data_points: + self.assertEqual(point.attributes["type"], "to_thread") + self.assertEqual(point.attributes["name"], "multiply") + if metric.name == "asyncio.process.created": + for point in metric.data.data_points: + self.assertEqual(point.attributes["type"], "to_thread") + self.assertEqual(point.attributes["name"], "multiply")