From 29ef6a945550c09a83003577f4d8037b22322038 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 10 Jan 2025 19:03:09 +0100 Subject: [PATCH] Add type hints to dbapi (#3068) --- CHANGELOG.md | 2 + docs/nitpick-exceptions.ini | 3 + .../instrumentation/dbapi/__init__.py | 245 ++++++++++-------- .../instrumentation/dbapi/py.typed | 0 4 files changed, 136 insertions(+), 114 deletions(-) create mode 100644 instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/py.typed diff --git a/CHANGELOG.md b/CHANGELOG.md index a6d7340c8c..438f9c787e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `opentelemetry-instrumentation-httpx` Fix `RequestInfo`/`ResponseInfo` type hints ([#3105](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3105)) +- `opentelemetry-instrumentation-dbapi` Move `TracedCursorProxy` and `TracedConnectionProxy` to the module level + ([#3068](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3068)) - `opentelemetry-instrumentation-click` Disable tracing of well-known server click commands ([#3174](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3174)) - `opentelemetry-instrumentation` Fix `get_dist_dependency_conflicts` if no distribution requires diff --git a/docs/nitpick-exceptions.ini b/docs/nitpick-exceptions.ini index e27bee26bb..b1fcdd5342 100644 --- a/docs/nitpick-exceptions.ini +++ b/docs/nitpick-exceptions.ini @@ -41,6 +41,7 @@ py-class= callable Consumer confluent_kafka.Message + ObjectProxy any= ; API @@ -68,6 +69,8 @@ any= py-obj= opentelemetry.propagators.textmap.CarrierT + opentelemetry.instrumentation.dbapi.ConnectionT + opentelemetry.instrumentation.dbapi.CursorT py-func= poll diff --git a/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/__init__.py b/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/__init__.py index 9b709c584c..27aafc7308 100644 --- a/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/__init__.py @@ -37,12 +37,15 @@ --- """ +from __future__ import annotations + import functools import logging import re -import typing +from typing import Any, Callable, Generic, TypeVar import wrapt +from wrapt import wrap_function_wrapper from opentelemetry import trace as trace_api from opentelemetry.instrumentation.dbapi.version import __version__ @@ -61,16 +64,19 @@ _logger = logging.getLogger(__name__) +ConnectionT = TypeVar("ConnectionT") +CursorT = TypeVar("CursorT") + def trace_integration( - connect_module: typing.Callable[..., typing.Any], + connect_module: Callable[..., Any], connect_method_name: str, database_system: str, - connection_attributes: typing.Dict = None, - tracer_provider: typing.Optional[TracerProvider] = None, + connection_attributes: dict[str, Any] | None = None, + tracer_provider: TracerProvider | None = None, capture_parameters: bool = False, enable_commenter: bool = False, - db_api_integration_factory=None, + db_api_integration_factory: type[DatabaseApiIntegration] | None = None, enable_attribute_commenter: bool = False, ): """Integrate with DB API library. @@ -108,16 +114,16 @@ def trace_integration( def wrap_connect( name: str, - connect_module: typing.Callable[..., typing.Any], + connect_module: Callable[..., Any], connect_method_name: str, database_system: str, - connection_attributes: typing.Dict = None, + connection_attributes: dict[str, Any] | None = None, version: str = "", - tracer_provider: typing.Optional[TracerProvider] = None, + tracer_provider: TracerProvider | None = None, capture_parameters: bool = False, enable_commenter: bool = False, - db_api_integration_factory=None, - commenter_options: dict = None, + db_api_integration_factory: type[DatabaseApiIntegration] | None = None, + commenter_options: dict[str, Any] | None = None, enable_attribute_commenter: bool = False, ): """Integrate with DB API library. @@ -146,10 +152,10 @@ def wrap_connect( # pylint: disable=unused-argument def wrap_connect_( - wrapped: typing.Callable[..., typing.Any], - instance: typing.Any, - args: typing.Tuple[typing.Any, typing.Any], - kwargs: typing.Dict[typing.Any, typing.Any], + wrapped: Callable[..., Any], + instance: Any, + args: tuple[Any, Any], + kwargs: dict[Any, Any], ): db_integration = db_api_integration_factory( name, @@ -166,7 +172,7 @@ def wrap_connect_( return db_integration.wrapped_connection(wrapped, args, kwargs) try: - wrapt.wrap_function_wrapper( + wrap_function_wrapper( connect_module, connect_method_name, wrap_connect_ ) except Exception as ex: # pylint: disable=broad-except @@ -174,7 +180,7 @@ def wrap_connect_( def unwrap_connect( - connect_module: typing.Callable[..., typing.Any], connect_method_name: str + connect_module: Callable[..., Any], connect_method_name: str ): """Disable integration with DB API library. https://www.python.org/dev/peps/pep-0249/ @@ -188,17 +194,17 @@ def unwrap_connect( def instrument_connection( name: str, - connection, + connection: ConnectionT | TracedConnectionProxy[ConnectionT], database_system: str, - connection_attributes: typing.Dict = None, + connection_attributes: dict[str, Any] | None = None, version: str = "", - tracer_provider: typing.Optional[TracerProvider] = None, + tracer_provider: TracerProvider | None = None, capture_parameters: bool = False, enable_commenter: bool = False, - commenter_options: dict = None, - connect_module: typing.Callable[..., typing.Any] = None, + commenter_options: dict[str, Any] | None = None, + connect_module: Callable[..., Any] | None = None, enable_attribute_commenter: bool = False, -): +) -> TracedConnectionProxy[ConnectionT]: """Enable instrumentation in a database connection. Args: @@ -238,7 +244,9 @@ def instrument_connection( return get_traced_connection_proxy(connection, db_integration) -def uninstrument_connection(connection): +def uninstrument_connection( + connection: ConnectionT | TracedConnectionProxy[ConnectionT], +) -> ConnectionT: """Disable instrumentation in a database connection. Args: @@ -259,23 +267,24 @@ def __init__( self, name: str, database_system: str, - connection_attributes=None, + connection_attributes: dict[str, Any] | None = None, version: str = "", - tracer_provider: typing.Optional[TracerProvider] = None, + tracer_provider: TracerProvider | None = None, capture_parameters: bool = False, enable_commenter: bool = False, - commenter_options: dict = None, - connect_module: typing.Callable[..., typing.Any] = None, + commenter_options: dict[str, Any] | None = None, + connect_module: Callable[..., Any] | None = None, enable_attribute_commenter: bool = False, ): - self.connection_attributes = connection_attributes - if self.connection_attributes is None: + if connection_attributes is None: self.connection_attributes = { "database": "database", "port": "port", "host": "host", "user": "user", } + else: + self.connection_attributes = connection_attributes self._name = name self._version = version self._tracer = get_tracer( @@ -289,17 +298,14 @@ def __init__( self.commenter_options = commenter_options self.enable_attribute_commenter = enable_attribute_commenter self.database_system = database_system - self.connection_props = {} - self.span_attributes = {} + self.connection_props: dict[str, Any] = {} + self.span_attributes: dict[str, Any] = {} self.name = "" self.database = "" self.connect_module = connect_module self.commenter_data = self.calculate_commenter_data() - def _get_db_version( - self, - db_driver, - ): + def _get_db_version(self, db_driver: str) -> str: if db_driver in _DB_DRIVER_ALIASES: return util_version(_DB_DRIVER_ALIASES[db_driver]) db_version = "" @@ -309,10 +315,8 @@ def _get_db_version( db_version = "unknown" return db_version - def calculate_commenter_data( - self, - ): - commenter_data = {} + def calculate_commenter_data(self) -> dict[str, Any]: + commenter_data: dict[str, Any] = {} if not self.enable_commenter: return commenter_data @@ -339,11 +343,7 @@ def calculate_commenter_data( libpq_version = self.connect_module.__libpq_version__ else: libpq_version = self.connect_module.pq.__build_version__ - commenter_data.update( - { - "libpq_version": libpq_version, - } - ) + commenter_data.update({"libpq_version": libpq_version}) elif self.database_system == "mysql": mysqlc_version = "" if db_driver == "MySQLdb": @@ -351,26 +351,22 @@ def calculate_commenter_data( elif db_driver == "pymysql": mysqlc_version = self.connect_module.get_client_info() - commenter_data.update( - { - "mysql_client_version": mysqlc_version, - } - ) + commenter_data.update({"mysql_client_version": mysqlc_version}) return commenter_data def wrapped_connection( self, - connect_method: typing.Callable[..., typing.Any], - args: typing.Tuple[typing.Any, typing.Any], - kwargs: typing.Dict[typing.Any, typing.Any], - ): + connect_method: Callable[..., ConnectionT], + args: tuple[Any, ...], + kwargs: dict[Any, Any], + ) -> TracedConnectionProxy[ConnectionT]: """Add object proxy to connection object.""" connection = connect_method(*args, **kwargs) self.get_connection_attributes(connection) return get_traced_connection_proxy(connection, self) - def get_connection_attributes(self, connection): + def get_connection_attributes(self, connection: object) -> None: # Populate span fields using connection for key, value in self.connection_attributes.items(): # Allow attributes nested in connection object @@ -404,39 +400,49 @@ def get_connection_attributes(self, connection): self.span_attributes[SpanAttributes.NET_PEER_PORT] = port -def get_traced_connection_proxy( - connection, db_api_integration, *args, **kwargs -): - # pylint: disable=abstract-method - class TracedConnectionProxy(wrapt.ObjectProxy): - # pylint: disable=unused-argument - def __init__(self, connection, *args, **kwargs): - wrapt.ObjectProxy.__init__(self, connection) - - def __getattribute__(self, name): - if object.__getattribute__(self, name): - return object.__getattribute__(self, name) - - return object.__getattribute__( - object.__getattribute__(self, "_connection"), name - ) +# pylint: disable=abstract-method +class TracedConnectionProxy(wrapt.ObjectProxy, Generic[ConnectionT]): + # pylint: disable=unused-argument + def __init__( + self, + connection: ConnectionT, + db_api_integration: DatabaseApiIntegration | None = None, + ): + wrapt.ObjectProxy.__init__(self, connection) + self._self_db_api_integration = db_api_integration - def cursor(self, *args, **kwargs): - return get_traced_cursor_proxy( - self.__wrapped__.cursor(*args, **kwargs), db_api_integration - ) + def __getattribute__(self, name: str): + if object.__getattribute__(self, name): + return object.__getattribute__(self, name) - def __enter__(self): - self.__wrapped__.__enter__() - return self + return object.__getattribute__( + object.__getattribute__(self, "_connection"), name + ) - def __exit__(self, *args, **kwargs): - self.__wrapped__.__exit__(*args, **kwargs) + def cursor(self, *args: Any, **kwargs: Any): + return get_traced_cursor_proxy( + self.__wrapped__.cursor(*args, **kwargs), + self._self_db_api_integration, + ) - return TracedConnectionProxy(connection, *args, **kwargs) + def __enter__(self): + self.__wrapped__.__enter__() + return self + def __exit__(self, *args: Any, **kwargs: Any): + self.__wrapped__.__exit__(*args, **kwargs) -class CursorTracer: + +def get_traced_connection_proxy( + connection: ConnectionT, + db_api_integration: DatabaseApiIntegration | None, + *args: Any, + **kwargs: Any, +) -> TracedConnectionProxy[ConnectionT]: + return TracedConnectionProxy(connection, db_api_integration) + + +class CursorTracer(Generic[CursorT]): def __init__(self, db_api_integration: DatabaseApiIntegration) -> None: self._db_api_integration = db_api_integration self._commenter_enabled = self._db_api_integration.enable_commenter @@ -494,8 +500,8 @@ def _update_args_with_added_sql_comment(self, args, cursor) -> tuple: def _populate_span( self, span: trace_api.Span, - cursor, - *args: typing.Tuple[typing.Any, typing.Any], + cursor: CursorT, + *args: tuple[Any, ...], ): if not span.is_recording(): return @@ -517,13 +523,15 @@ def _populate_span( if self._db_api_integration.capture_parameters and len(args) > 1: span.set_attribute("db.statement.parameters", str(args[1])) - def get_operation_name(self, cursor, args): # pylint: disable=no-self-use + def get_operation_name( + self, cursor: CursorT, args: tuple[Any, ...] + ) -> str: # pylint: disable=no-self-use if args and isinstance(args[0], str): # Strip leading comments so we get the operation name. return self._leading_comment_remover.sub("", args[0]).split()[0] return "" - def get_statement(self, cursor, args): # pylint: disable=no-self-use + def get_statement(self, cursor: CursorT, args: tuple[Any, ...]): # pylint: disable=no-self-use if not args: return "" statement = args[0] @@ -533,10 +541,10 @@ def get_statement(self, cursor, args): # pylint: disable=no-self-use def traced_execution( self, - cursor, - query_method: typing.Callable[..., typing.Any], - *args: typing.Tuple[typing.Any, typing.Any], - **kwargs: typing.Dict[typing.Any, typing.Any], + cursor: CursorT, + query_method: Callable[..., Any], + *args: tuple[Any, ...], + **kwargs: dict[Any, Any], ): name = self.get_operation_name(cursor, args) if not name: @@ -570,35 +578,44 @@ def traced_execution( return query_method(*args, **kwargs) -def get_traced_cursor_proxy(cursor, db_api_integration, *args, **kwargs): - _cursor_tracer = CursorTracer(db_api_integration) +# pylint: disable=abstract-method +class TracedCursorProxy(wrapt.ObjectProxy, Generic[CursorT]): + # pylint: disable=unused-argument + def __init__( + self, + cursor: CursorT, + db_api_integration: DatabaseApiIntegration, + ): + wrapt.ObjectProxy.__init__(self, cursor) + self._self_cursor_tracer = CursorTracer[CursorT](db_api_integration) - # pylint: disable=abstract-method - class TracedCursorProxy(wrapt.ObjectProxy): - # pylint: disable=unused-argument - def __init__(self, cursor, *args, **kwargs): - wrapt.ObjectProxy.__init__(self, cursor) + def execute(self, *args: Any, **kwargs: Any): + return self._self_cursor_tracer.traced_execution( + self.__wrapped__, self.__wrapped__.execute, *args, **kwargs + ) - def execute(self, *args, **kwargs): - return _cursor_tracer.traced_execution( - self.__wrapped__, self.__wrapped__.execute, *args, **kwargs - ) + def executemany(self, *args: Any, **kwargs: Any): + return self._self_cursor_tracer.traced_execution( + self.__wrapped__, self.__wrapped__.executemany, *args, **kwargs + ) - def executemany(self, *args, **kwargs): - return _cursor_tracer.traced_execution( - self.__wrapped__, self.__wrapped__.executemany, *args, **kwargs - ) + def callproc(self, *args: Any, **kwargs: Any): + return self._self_cursor_tracer.traced_execution( + self.__wrapped__, self.__wrapped__.callproc, *args, **kwargs + ) - def callproc(self, *args, **kwargs): - return _cursor_tracer.traced_execution( - self.__wrapped__, self.__wrapped__.callproc, *args, **kwargs - ) + def __enter__(self): + self.__wrapped__.__enter__() + return self - def __enter__(self): - self.__wrapped__.__enter__() - return self + def __exit__(self, *args, **kwargs): + self.__wrapped__.__exit__(*args, **kwargs) - def __exit__(self, *args, **kwargs): - self.__wrapped__.__exit__(*args, **kwargs) - return TracedCursorProxy(cursor, *args, **kwargs) +def get_traced_cursor_proxy( + cursor: CursorT, + db_api_integration: DatabaseApiIntegration, + *args: Any, + **kwargs: Any, +) -> TracedCursorProxy[CursorT]: + return TracedCursorProxy(cursor, db_api_integration) diff --git a/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/py.typed b/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/py.typed new file mode 100644 index 0000000000..e69de29bb2