diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index 3b6ddebd3c..08a5aa7016 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -54,12 +54,11 @@ def get_tracer(tracer_provider=None): return tracer_provider.get_tracer(TRACER_NAME, TRACER_VERSION) -@contextmanager -def trace_call(name, session, extra_attributes=None, observability_options=None): - if not HAS_OPENTELEMETRY_INSTALLED or not session: - # Empty context manager. Users will have to check if the generated value is None or a span - yield None - return +def trace_end_explicitly( + name, session=None, extra_attributes=None, observability_options=None +): + if not HAS_OPENTELEMETRY_INSTALLED: + return None tracer_provider = None @@ -100,8 +99,13 @@ def trace_call(name, session, extra_attributes=None, observability_options=None) if not enable_extended_tracing: attributes.pop("db.statement", False) - with tracer.start_as_current_span( - name, kind=trace.SpanKind.CLIENT, attributes=attributes + return tracer.start_span(name, kind=trace.SpanKind.CLIENT, attributes=attributes) + + +@contextmanager +def trace_call(name, session=None, extra_attributes=None, observability_options=None): + with trace_end_explicitly( + name, session, extra_attributes, observability_options ) as span: try: yield span diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 948740d7d4..39e10f3d3f 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -26,7 +26,10 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, ) -from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1._opentelemetry_tracing import ( + trace_call, + trace_end_explicitly, +) from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1._helpers import _retry from google.cloud.spanner_v1._helpers import _check_rst_stream_error @@ -46,6 +49,14 @@ class _BatchBase(_SessionWrapper): def __init__(self, session): super(_BatchBase, self).__init__(session) self._mutations = [] + observability_options = getattr( + self._session.database, "observability_options", None + ) + self.__span = trace_end_explicitly( + "CloudSpannerX." + type(self).__name__, + self._session, + observability_options=observability_options, + ) def _check_state(self): """Helper for :meth:`commit` et al. @@ -69,6 +80,10 @@ def insert(self, table, columns, values): :type values: list of lists :param values: Values to be modified. """ + if self.__span: + self.__span.add_event( + "insert mutations inserted", dict(table=table, columns=columns) + ) self._mutations.append(Mutation(insert=_make_write_pb(table, columns, values))) def update(self, table, columns, values): @@ -83,6 +98,10 @@ def update(self, table, columns, values): :type values: list of lists :param values: Values to be modified. """ + if self.__span: + self.__span.add_event( + "update mutations inserted", dict(table=table, columns=columns) + ) self._mutations.append(Mutation(update=_make_write_pb(table, columns, values))) def insert_or_update(self, table, columns, values): @@ -97,6 +116,11 @@ def insert_or_update(self, table, columns, values): :type values: list of lists :param values: Values to be modified. """ + if self.__span: + self.__span.add_event( + "insert_or_update mutations inserted", + dict(table=table, columns=columns), + ) self._mutations.append( Mutation(insert_or_update=_make_write_pb(table, columns, values)) ) @@ -113,6 +137,10 @@ def replace(self, table, columns, values): :type values: list of lists :param values: Values to be modified. """ + if self.__span: + self.__span.add_event( + "replace mutations inserted", dict(table=table, columns=columns) + ) self._mutations.append(Mutation(replace=_make_write_pb(table, columns, values))) def delete(self, table, keyset): @@ -126,6 +154,10 @@ def delete(self, table, keyset): """ delete = Mutation.Delete(table=table, key_set=keyset._to_pb()) self._mutations.append(Mutation(delete=delete)) + if self.__span: + self.__span.add_event( + "delete mutations inserted", dict(table=table, columns=columns) + ) class Batch(_BatchBase): @@ -207,7 +239,7 @@ def commit( ) observability_options = getattr(database, "observability_options", None) with trace_call( - "CloudSpanner.Commit", + "CloudSpanner.Batch.commit", self._session, trace_attributes, observability_options=observability_options, @@ -228,6 +260,14 @@ def commit( def __enter__(self): """Begin ``with`` block.""" self._check_state() + observability_options = getattr( + self._session.database, "observability_options", None + ) + self.__span = trace_end_explicitly( + "CloudSpanner.Batch", + self._session, + observability_options=observability_options, + ) return self @@ -235,6 +275,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" if exc_type is None: self.commit() + if self.__span: + self.__span.end() + self.__span = None class MutationGroup(_BatchBase): @@ -326,7 +369,7 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals ) observability_options = getattr(database, "observability_options", None) with trace_call( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", self._session, trace_attributes, observability_options=observability_options, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 38ac469fe3..23ed1b0ed0 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -53,7 +53,10 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, ) -from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1._opentelemetry_tracing import ( + trace_call, + trace_end_explicitly, +) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.batch import MutationGroups from google.cloud.spanner_v1.keyset import KeySet @@ -1188,9 +1191,16 @@ def __init__( self._request_options = request_options self._max_commit_delay = max_commit_delay self._exclude_txn_from_change_streams = exclude_txn_from_change_streams + self.__span = None def __enter__(self): """Begin ``with`` block.""" + observability_options = self._database.observability_options + self.__span = trace_end_explicitly( + "CloudSpanner.Database.batch", + None, + observability_options=observability_options, + ) session = self._session = self._database._pool.get() batch = self._batch = Batch(session) if self._request_options.transaction_tag: @@ -1214,6 +1224,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): extra={"commit_stats": self._batch.commit_stats}, ) self._database._pool.put(self._session) + if self.__span: + self.__span.end() + self.__span = None class MutationGroupsCheckout(object): @@ -1271,9 +1284,20 @@ def __init__(self, database, **kw): self._database = database self._session = None self._kw = kw + self.__span = None def __enter__(self): """Begin ``with`` block.""" + observability_options = self._database.observability_options + attributes = dict() + if self._kw: + attributes["multi_use"] = self._kw["multi_use"] + self.__span = trace_end_explicitly( + "CloudSpanner.Database.snapshot", + None, + attributes, + observability_options=observability_options, + ) session = self._session = self._database._pool.get() return Snapshot(session, **self._kw) @@ -1287,6 +1311,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._session.create() self._database._pool.put(self._session) + if self.__span: + self.__span.end() + self.__span = None + class BatchSnapshot(object): """Wrapper for generating and processing read / query batches. @@ -1317,6 +1345,15 @@ def __init__( self._transaction_id = transaction_id self._read_timestamp = read_timestamp self._exact_staleness = exact_staleness + observability_options = getattr(self._database, "observability_options", {}) + if isinstance(observability_options, dict) and self._database: + observability_options["db_name"] = self._database.name + self.__observability_options = observability_options + self.__span = trace_end_explicitly( + "CloudSpanner.BatchSnapshot", + self._session, + observability_options=observability_options, + ) @classmethod def from_dict(cls, database, mapping): @@ -1352,6 +1389,10 @@ def to_dict(self): "transaction_id": snapshot._transaction_id, } + @property + def observability_options(self): + return self.__observability_options + def _get_session(self): """Create session as needed. @@ -1517,12 +1558,20 @@ def process_read_batch( :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ - kwargs = copy.deepcopy(batch["read"]) - keyset_dict = kwargs.pop("keyset") - kwargs["keyset"] = KeySet._from_dict(keyset_dict) - return self._get_snapshot().read( - partition=batch["partition"], **kwargs, retry=retry, timeout=timeout - ) + observability_options = self.observability_options or {} + session = self._get_session() + klassname = type(self).__name__ + with trace_call( + "CloudSpanner." + klassname + ".process_read_batch", + session, + observability_options=observability_options, + ): + kwargs = copy.deepcopy(batch["read"]) + keyset_dict = kwargs.pop("keyset") + kwargs["keyset"] = KeySet._from_dict(keyset_dict) + return self._get_snapshot().read( + partition=batch["partition"], **kwargs, retry=retry, timeout=timeout + ) def generate_query_batches( self, @@ -1750,6 +1799,10 @@ def close(self): if self._session is not None: self._session.delete() + if self.__span: + self.__span.end() + self.__span = None + def _check_ddl_statements(value): """Validate DDL Statements used to define database schema. diff --git a/google/cloud/spanner_v1/merged_result_set.py b/google/cloud/spanner_v1/merged_result_set.py index 9165af9ee3..19cd8950f1 100644 --- a/google/cloud/spanner_v1/merged_result_set.py +++ b/google/cloud/spanner_v1/merged_result_set.py @@ -37,6 +37,17 @@ def __init__(self, batch_snapshot, partition_id, merged_result_set): self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue def run(self): + observability_options = getattr( + self._batch_snapshot, "observability_options", {} + ) + with trace_call( + "CloudSpanner.PartitionExecutor.run", + None, + observability_options=observability_options, + ): + return self.__run() + + def __run(self): results = None try: results = self._batch_snapshot.process_query_batch(self._partition_id) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 71b36619a7..56173438c5 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -207,6 +207,10 @@ def delete(self): with trace_call( "CloudSpanner.DeleteSession", self, + extra_attributes={ + "session.id": self._session_id, + "session.name": self.name, + }, observability_options=observability_options, ): api.delete_session(name=self.name, metadata=metadata) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index c30fb5a225..ec739865e5 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -321,7 +321,7 @@ def read( iterator = _restart_on_unavailable( restart, request, - "CloudSpanner.ReadOnlyTransaction.read", + "CloudSpanner.read", self._session, trace_attributes, transaction=self, @@ -338,7 +338,7 @@ def read( iterator = _restart_on_unavailable( restart, request, - "CloudSpanner.ReadOnlyTransaction.read", + "CloudSpanner.read", self._session, trace_attributes, transaction=self, @@ -630,9 +630,9 @@ def partition_read( partition_options=partition_options, ) - trace_attributes = {"table_id": table, "columns": columns} + trace_attributes = {"table_id": table, "columns": columns, "index": index} with trace_call( - "CloudSpanner.PartitionReadOnlyTransaction", + "CloudSpanner.partition_read", self._session, trace_attributes, observability_options=getattr(database, "observability_options", None), @@ -735,7 +735,7 @@ def partition_query( trace_attributes = {"db.statement": sql} with trace_call( - "CloudSpanner.PartitionReadWriteTransaction", + "CloudSpanner.partition_query", self._session, trace_attributes, observability_options=getattr(database, "observability_options", None), @@ -882,7 +882,7 @@ def begin(self): ) txn_selector = self._make_txn_selector() with trace_call( - "CloudSpanner.BeginTransaction", + "CloudSpanner.begin", self._session, observability_options=getattr(database, "observability_options", None), ):