diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index 1caac59ecd..3bfd3a0b05 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -55,15 +55,11 @@ def get_tracer(tracer_provider=None): return tracer_provider.get_tracer(TRACER_NAME, TRACER_VERSION) -@contextmanager -def trace_call(name, session, extra_attributes=None, observability_options=None): - if session: - session._last_use_time = datetime.now() - - if not HAS_OPENTELEMETRY_INSTALLED or not session: - # Empty context manager. Users will have to check if the generated value is None or a span - yield None - return +def _make_tracer_and_span_attributes( + session=None, extra_attributes=None, observability_options=None +): + if not HAS_OPENTELEMETRY_INSTALLED: + return None, None tracer_provider = None @@ -72,20 +68,24 @@ def trace_call(name, session, extra_attributes=None, observability_options=None) # on by default. enable_extended_tracing = True + db_name = "" + if session and getattr(session, "_database", None): + db_name = session._database.name + if isinstance(observability_options, dict): # Avoid false positives with mock.Mock tracer_provider = observability_options.get("tracer_provider", None) enable_extended_tracing = observability_options.get( "enable_extended_tracing", enable_extended_tracing ) + db_name = observability_options.get("db_name", db_name) tracer = get_tracer(tracer_provider) # Set base attributes that we know for every trace created - db = session._database attributes = { "db.type": "spanner", "db.url": SpannerClient.DEFAULT_ENDPOINT, - "db.instance": "" if not db else db.name, + "db.instance": db_name, "net.host.name": SpannerClient.DEFAULT_ENDPOINT, OTEL_SCOPE_NAME: TRACER_NAME, OTEL_SCOPE_VERSION: TRACER_VERSION, @@ -99,9 +99,77 @@ def trace_call(name, session, extra_attributes=None, observability_options=None) if not enable_extended_tracing: attributes.pop("db.statement", False) + attributes.pop("sql", False) + else: + # Otherwise there are places where the annotated sql was inserted + # directly from the arguments as "sql", and transform those into "db.statement". + db_statement = attributes.get("db.statement", None) + if not db_statement: + sql = attributes.get("sql", None) + if sql: + attributes = attributes.copy() + attributes.pop("sql", False) + attributes["db.statement"] = sql + + return tracer, attributes + + +def trace_call_end_lazily( + name, session=None, extra_attributes=None, observability_options=None +): + """ + trace_call_end_lazily is used in situations where you don't want a context managed + span in a with statement to end as soon as a block exits. This is useful for example + after a Database.batch or Database.snapshot but without a context manager. + If you need to directly invoke tracing with a context manager, please invoke + `trace_call` with which you can invoke +  `with trace_call(...) as span:` + It is the caller's responsibility to explicitly invoke the returned ending function. + """ + if not name: + return None + + tracer, span_attributes = _make_tracer_and_span_attributes( + session, extra_attributes, observability_options + ) + if not tracer: + return None + + span = tracer.start_span( + name, kind=trace.SpanKind.CLIENT, attributes=span_attributes + ) + ctx_manager = trace.use_span(span, end_on_exit=True, record_exception=True) + ctx_manager.__enter__() + + def discard(exc_type=None, exc_value=None, exc_traceback=None): + if not exc_type: + span.set_status(Status(StatusCode.OK)) + + ctx_manager.__exit__(exc_type, exc_value, exc_traceback) + + return discard + + +@contextmanager +def trace_call(name, session=None, extra_attributes=None, observability_options=None): + """ +  trace_call is used in situations where you need to end a span with a context manager +  or after a scope is exited. If you need to keep a span alive and lazily end it, please +  invoke `trace_call_end_lazily`. + """ + if not name: + yield None + return + + tracer, span_attributes = _make_tracer_and_span_attributes( + session, extra_attributes, observability_options + ) + if not tracer: + yield None + return with tracer.start_as_current_span( - name, kind=trace.SpanKind.CLIENT, attributes=attributes + name, kind=trace.SpanKind.CLIENT, attributes=span_attributes ) as span: try: yield span @@ -128,6 +196,15 @@ def get_current_span(): return trace.get_current_span() -def add_span_event(span, event_name, event_attributes=None): +def add_event_on_current_span(event_name, attributes=None, span=None): + if not span: + span = get_current_span() + + if span: + span.add_event(event_name, attributes) + + +def record_span_exception_and_status(span, exc): if span: - span.add_event(event_name, event_attributes) + span.set_status(Status(StatusCode.ERROR, str(exc))) + span.record_exception(exc) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 948740d7d4..b4a250c6ec 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -26,7 +26,11 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, ) -from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1._opentelemetry_tracing import ( + add_event_on_current_span, + trace_call, + trace_call_end_lazily, +) from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1._helpers import _retry from google.cloud.spanner_v1._helpers import _check_rst_stream_error @@ -46,6 +50,12 @@ class _BatchBase(_SessionWrapper): def __init__(self, session): super(_BatchBase, self).__init__(session) self._mutations = [] + self.__base_discard_span = trace_call_end_lazily( + f"CloudSpanner.{type(self).__name__}", + self._session, + None, + getattr(self._session._database, "observability_options", None), + ) def _check_state(self): """Helper for :meth:`commit` et al. @@ -69,6 +79,10 @@ def insert(self, table, columns, values): :type values: list of lists :param values: Values to be modified. """ + add_event_on_current_span( + "insert mutations added", + dict(table=table, columns=columns), + ) self._mutations.append(Mutation(insert=_make_write_pb(table, columns, values))) def update(self, table, columns, values): @@ -84,6 +98,10 @@ def update(self, table, columns, values): :param values: Values to be modified. """ self._mutations.append(Mutation(update=_make_write_pb(table, columns, values))) + add_event_on_current_span( + "update mutations added", + dict(table=table, columns=columns), + ) def insert_or_update(self, table, columns, values): """Insert/update one or more table rows. @@ -100,6 +118,10 @@ def insert_or_update(self, table, columns, values): self._mutations.append( Mutation(insert_or_update=_make_write_pb(table, columns, values)) ) + add_event_on_current_span( + "insert_or_update mutations added", + dict(table=table, columns=columns), + ) def replace(self, table, columns, values): """Replace one or more table rows. @@ -114,6 +136,10 @@ def replace(self, table, columns, values): :param values: Values to be modified. """ self._mutations.append(Mutation(replace=_make_write_pb(table, columns, values))) + add_event_on_current_span( + "replace mutations added", + dict(table=table, columns=columns), + ) def delete(self, table, keyset): """Delete one or more table rows. @@ -126,6 +152,21 @@ def delete(self, table, keyset): """ delete = Mutation.Delete(table=table, key_set=keyset._to_pb()) self._mutations.append(Mutation(delete=delete)) + add_event_on_current_span( + "delete mutations added", + dict(table=table), + ) + + def _discard_on_end(self, exc_type=None, exc_val=None, exc_traceback=None): + if self.__base_discard_span: + self.__base_discard_span(exc_type, exc_val, exc_traceback) + self.__base_discard_span = None + + def __exit__(self, exc_type=None, exc_value=None, exc_traceback=None): + self._discard_on_end(exc_type, exc_val, exc_traceback) + + def __enter__(self): + return self class Batch(_BatchBase): @@ -207,7 +248,7 @@ def commit( ) observability_options = getattr(database, "observability_options", None) with trace_call( - "CloudSpanner.Commit", + "CloudSpanner.Batch.commit", self._session, trace_attributes, observability_options=observability_options, @@ -223,11 +264,20 @@ def commit( ) self.committed = response.commit_timestamp self.commit_stats = response.commit_stats + self._discard_on_end() return self.committed def __enter__(self): """Begin ``with`` block.""" self._check_state() + observability_options = getattr( + self._session._database, "observability_options", None + ) + self.__discard_span = trace_call_end_lazily( + "CloudSpanner.Batch", + self._session, + observability_options=observability_options, + ) return self @@ -235,6 +285,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" if exc_type is None: self.commit() + if self.__discard_span: + self.__discard_span(exc_type, exc_val, exc_tb) + self.__discard_span = None + self._discard_on_end() class MutationGroup(_BatchBase): @@ -326,7 +380,7 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals ) observability_options = getattr(database, "observability_options", None) with trace_call( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", self._session, trace_attributes, observability_options=observability_options, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index c8230ab503..7d8d75f285 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -53,6 +53,12 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, ) +from google.cloud.spanner_v1._opentelemetry_tracing import ( + add_event_on_current_span, + get_current_span, + trace_call, + trace_call_end_lazily, +) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.batch import MutationGroups from google.cloud.spanner_v1.keyset import KeySet @@ -67,10 +73,6 @@ SpannerGrpcTransport, ) from google.cloud.spanner_v1.table import Table -from google.cloud.spanner_v1._opentelemetry_tracing import ( - add_span_event, - get_current_span, -) SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data" @@ -698,11 +700,16 @@ def execute_partitioned_dml( ) def execute_pdml(): - with SessionCheckout(self._pool) as session: + def do_execute_pdml(session, span): + add_event_on_current_span("Starting BeginTransaction", span=span) txn = api.begin_transaction( session=session.name, options=txn_options, metadata=metadata ) - + add_event_on_current_span( + "Completed BeginTransaction", + {"transaction.id": txn.id}, + span=span, + ) txn_selector = TransactionSelector(id=txn.id) request = ExecuteSqlRequest( @@ -721,6 +728,7 @@ def execute_pdml(): iterator = _restart_on_unavailable( method=method, request=request, + span_name="CloudSpanner.ExecuteStreamingSql", transaction_selector=txn_selector, observability_options=self.observability_options, ) @@ -730,6 +738,13 @@ def execute_pdml(): return result_set.stats.row_count_lower_bound + with trace_call( + "CloudSpanner.Database.execute_partitioned_pdml", + observability_options=self.observability_options, + ) as span: + with SessionCheckout(self._pool) as session: + return do_execute_pdml(session, span) + return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)() def session(self, labels=None, database_role=None): @@ -891,8 +906,13 @@ def run_in_transaction(self, func, *args, **kw): # Check out a session and run the function in a transaction; once # done, flip the sanity check bit back. try: - with SessionCheckout(self._pool) as session: - return session.run_in_transaction(func, *args, **kw) + observability_options = getattr(self, "observability_options", None) + with trace_call( + "CloudSpanner.Database.run_in_transaction", + observability_options=observability_options, + ): + with SessionCheckout(self._pool) as session: + return session.run_in_transaction(func, *args, **kw) finally: self._local.transaction_running = False @@ -1120,7 +1140,12 @@ def observability_options(self): if not (self._instance and self._instance._client): return None - return getattr(self._instance._client, "observability_options", None) + opts = getattr(self._instance._client, "observability_options", None) + if not opts: + opts = dict() + + opts["db_name"] = self.name + return opts class BatchCheckout(object): @@ -1165,12 +1190,20 @@ def __init__( self._request_options = request_options self._max_commit_delay = max_commit_delay self._exclude_txn_from_change_streams = exclude_txn_from_change_streams + self.__span_ctx_manager = None def __enter__(self): """Begin ``with`` block.""" + observability_options = getattr(self._database, "observability_options", None) + self.__span_ctx_manager = trace_call_end_lazily( + "CloudSpanner.Database.batch", + observability_options=observability_options, + ) current_span = get_current_span() session = self._session = self._database._pool.get() - add_span_event(current_span, "Using session", {"id": session.session_id}) + add_event_on_current_span( + "Using session", {"id": session.session_id}, current_span + ) batch = self._batch = Batch(session) if self._request_options.transaction_tag: batch.transaction_tag = self._request_options.transaction_tag @@ -1192,12 +1225,17 @@ def __exit__(self, exc_type, exc_val, exc_tb): "CommitStats: {}".format(self._batch.commit_stats), extra={"commit_stats": self._batch.commit_stats}, ) + + if self.__span_ctx_manager: + self.__span_ctx_manager(exc_type, exc_val, exc_tb) + self.__span_ctx_manager = None + self._database._pool.put(self._session) current_span = get_current_span() - add_span_event( - current_span, + add_event_on_current_span( "Returned session to pool", {"id": self._session.session_id}, + current_span, ) @@ -1256,9 +1294,19 @@ def __init__(self, database, **kw): self._database = database self._session = None self._kw = kw + self.__span_ctx_manager = None def __enter__(self): """Begin ``with`` block.""" + observability_options = getattr(self._database, "observability_options", {}) + attributes = None + if self._kw: + attributes = dict(multi_use=self._kw.get("multi_use", False)) + self.__span_ctx_manager = trace_call_end_lazily( + "CloudSpanner.Database.snapshot", + extra_attributes=attributes, + observability_options=observability_options, + ) session = self._session = self._database._pool.get() return Snapshot(session, **self._kw) @@ -1270,6 +1318,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): if not self._session.exists(): self._session = self._database._pool._new_session() self._session.create() + + if self.__span_ctx_manager: + self.__span_ctx_manager(exc_type, exc_val, exc_tb) + self.__span_ctx_manager = None + self._database._pool.put(self._session) @@ -1302,6 +1355,13 @@ def __init__( self._transaction_id = transaction_id self._read_timestamp = read_timestamp self._exact_staleness = exact_staleness + observability_options = getattr(self._database, "observability_options", {}) + self.__observability_options = observability_options + self.__span_ctx_manager = trace_call_end_lazily( + "CloudSpanner.BatchSnapshot", + self._session, + observability_options=observability_options, + ) @classmethod def from_dict(cls, database, mapping): @@ -1337,6 +1397,10 @@ def to_dict(self): "transaction_id": snapshot._transaction_id, } + @property + def observability_options(self): + return self.__observability_options + def _get_session(self): """Create session as needed. @@ -1456,27 +1520,32 @@ def generate_read_batches( mappings of information used perform actual partitioned reads via :meth:`process_read_batch`. """ - partitions = self._get_snapshot().partition_read( - table=table, - columns=columns, - keyset=keyset, - index=index, - partition_size_bytes=partition_size_bytes, - max_partitions=max_partitions, - retry=retry, - timeout=timeout, - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.generate_read_batches", + extra_attributes=dict(table=table, columns=columns), + observability_options=self.observability_options, + ): + partitions = self._get_snapshot().partition_read( + table=table, + columns=columns, + keyset=keyset, + index=index, + partition_size_bytes=partition_size_bytes, + max_partitions=max_partitions, + retry=retry, + timeout=timeout, + ) - read_info = { - "table": table, - "columns": columns, - "keyset": keyset._to_dict(), - "index": index, - "data_boost_enabled": data_boost_enabled, - "directed_read_options": directed_read_options, - } - for partition in partitions: - yield {"partition": partition, "read": read_info.copy()} + read_info = { + "table": table, + "columns": columns, + "keyset": keyset._to_dict(), + "index": index, + "data_boost_enabled": data_boost_enabled, + "directed_read_options": directed_read_options, + } + for partition in partitions: + yield {"partition": partition, "read": read_info.copy()} def process_read_batch( self, @@ -1502,12 +1571,17 @@ def process_read_batch( :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ - kwargs = copy.deepcopy(batch["read"]) - keyset_dict = kwargs.pop("keyset") - kwargs["keyset"] = KeySet._from_dict(keyset_dict) - return self._get_snapshot().read( - partition=batch["partition"], **kwargs, retry=retry, timeout=timeout - ) + observability_options = self.observability_options or {} + with trace_call( + f"CloudSpanner.{type(self).__name__}.process_read_batch", + observability_options=observability_options, + ): + kwargs = copy.deepcopy(batch["read"]) + keyset_dict = kwargs.pop("keyset") + kwargs["keyset"] = KeySet._from_dict(keyset_dict) + return self._get_snapshot().read( + partition=batch["partition"], **kwargs, retry=retry, timeout=timeout + ) def generate_query_batches( self, @@ -1582,34 +1656,39 @@ def generate_query_batches( mappings of information used perform actual partitioned reads via :meth:`process_read_batch`. """ - partitions = self._get_snapshot().partition_query( - sql=sql, - params=params, - param_types=param_types, - partition_size_bytes=partition_size_bytes, - max_partitions=max_partitions, - retry=retry, - timeout=timeout, - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.generate_query_batches", + extra_attributes=dict(sql=sql), + observability_options=self.observability_options, + ): + partitions = self._get_snapshot().partition_query( + sql=sql, + params=params, + param_types=param_types, + partition_size_bytes=partition_size_bytes, + max_partitions=max_partitions, + retry=retry, + timeout=timeout, + ) - query_info = { - "sql": sql, - "data_boost_enabled": data_boost_enabled, - "directed_read_options": directed_read_options, - } - if params: - query_info["params"] = params - query_info["param_types"] = param_types - - # Query-level options have higher precedence than client-level and - # environment-level options - default_query_options = self._database._instance._client._query_options - query_info["query_options"] = _merge_query_options( - default_query_options, query_options - ) + query_info = { + "sql": sql, + "data_boost_enabled": data_boost_enabled, + "directed_read_options": directed_read_options, + } + if params: + query_info["params"] = params + query_info["param_types"] = param_types + + # Query-level options have higher precedence than client-level and + # environment-level options + default_query_options = self._database._instance._client._query_options + query_info["query_options"] = _merge_query_options( + default_query_options, query_options + ) - for partition in partitions: - yield {"partition": partition, "query": query_info} + for partition in partitions: + yield {"partition": partition, "query": query_info} def process_query_batch( self, @@ -1634,9 +1713,16 @@ def process_query_batch( :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ - return self._get_snapshot().execute_sql( - partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.process_query_batch", + observability_options=self.observability_options, + ): + return self._get_snapshot().execute_sql( + partition=batch["partition"], + **batch["query"], + retry=retry, + timeout=timeout, + ) def run_partitioned_query( self, @@ -1691,18 +1777,23 @@ def run_partitioned_query( :rtype: :class:`~google.cloud.spanner_v1.merged_result_set.MergedResultSet` :returns: a result set instance which can be used to consume rows. """ - partitions = list( - self.generate_query_batches( - sql, - params, - param_types, - partition_size_bytes, - max_partitions, - query_options, - data_boost_enabled, + with trace_call( + f"CloudSpanner.${type(self).__name__}.run_partitioned_query", + extra_attributes=dict(sql=sql), + observability_options=self.observability_options, + ): + partitions = list( + self.generate_query_batches( + sql, + params, + param_types, + partition_size_bytes, + max_partitions, + query_options, + data_boost_enabled, + ) ) - ) - return MergedResultSet(self, partitions, 0) + return MergedResultSet(self, partitions, 0) def process(self, batch): """Process a single, partitioned query or read. @@ -1735,6 +1826,10 @@ def close(self): if self._session is not None: self._session.delete() + if self.__span_ctx_manager: + self.__span_ctx_manager() + self.__span_ctx_manager = None + def _check_ddl_statements(value): """Validate DDL Statements used to define database schema. diff --git a/google/cloud/spanner_v1/merged_result_set.py b/google/cloud/spanner_v1/merged_result_set.py index 9165af9ee3..9eb05cca0f 100644 --- a/google/cloud/spanner_v1/merged_result_set.py +++ b/google/cloud/spanner_v1/merged_result_set.py @@ -19,6 +19,9 @@ if TYPE_CHECKING: from google.cloud.spanner_v1.database import BatchSnapshot +from google.cloud.spanner_v1._opentelemetry_tracing import ( + trace_call, +) QUEUE_SIZE_PER_WORKER = 32 MAX_PARALLELISM = 16 @@ -37,6 +40,17 @@ def __init__(self, batch_snapshot, partition_id, merged_result_set): self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue def run(self): + observability_options = getattr( + self._batch_snapshot, "observability_options", {} + ) + with trace_call( + "CloudSpanner.PartitionExecutor.run", + None, + observability_options=observability_options, + ): + return self.__run() + + def __run(self): results = None try: results = self._batch_snapshot.process_query_batch(self._partition_id) diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 4f90196b4a..45b99582d8 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -26,8 +26,9 @@ _metadata_with_leader_aware_routing, ) from google.cloud.spanner_v1._opentelemetry_tracing import ( - add_span_event, get_current_span, + add_event_on_current_span, + trace_call, ) from warnings import warn @@ -206,10 +207,10 @@ def bind(self, database): span_event_attributes = {"kind": type(self).__name__} if requested_session_count <= 0: - add_span_event( - span, + add_event_on_current_span( f"Invalid session pool size({requested_session_count}) <= 0", span_event_attributes, + span, ) return @@ -221,14 +222,16 @@ def bind(self, database): ) self._database_role = self._database_role or self._database.database_role if requested_session_count > 0: - add_span_event( - span, + add_event_on_current_span( f"Requesting {requested_session_count} sessions", span_event_attributes, + span, ) if self._sessions.full(): - add_span_event(span, "Session pool is already full", span_event_attributes) + add_event_on_current_span( + "Session pool is already full", span_event_attributes, span + ) return request = BatchCreateSessionsRequest( @@ -237,29 +240,33 @@ def bind(self, database): session_template=Session(creator_role=self.database_role), ) - returned_session_count = 0 - while not self._sessions.full(): - request.session_count = requested_session_count - self._sessions.qsize() - add_span_event( + observability_options = getattr(self._database, "observability_options", None) + with trace_call( + "Cloudspanner.FixedPool.BatchCreateSessions", + observability_options=observability_options, + ) as span: + n_created = 0 + while not self._sessions.full(): + resp = api.batch_create_sessions( + request=request, + metadata=metadata, + ) + + add_event_on_current_span( + "Created sessions", dict(count=len(resp.session)), span + ) + + for session_pb in resp.session: + session = self._new_session() + session._session_id = session_pb.name.split("/")[-1] + self._sessions.put(session) + n_created += 1 + + add_event_on_current_span( + "Finished creating sessions", + dict(requested_count=request.session_count, created_count=n_created), span, - f"Creating {request.session_count} sessions", - span_event_attributes, - ) - resp = api.batch_create_sessions( - request=request, - metadata=metadata, ) - for session_pb in resp.session: - session = self._new_session() - session._session_id = session_pb.name.split("/")[-1] - self._sessions.put(session) - returned_session_count += 1 - - add_span_event( - span, - f"Requested for {requested_session_count} sessions, returned {returned_session_count}", - span_event_attributes, - ) def get(self, timeout=None): """Check a session out from the pool. @@ -278,14 +285,16 @@ def get(self, timeout=None): start_time = time.time() current_span = get_current_span() span_event_attributes = {"kind": type(self).__name__} - add_span_event(current_span, "Acquiring session", span_event_attributes) + add_event_on_current_span( + "Acquiring session", span_event_attributes, current_span + ) session = None try: - add_span_event( - current_span, + add_event_on_current_span( "Waiting for a session to become available", span_event_attributes, + current_span, ) session = self._sessions.get(block=True, timeout=timeout) @@ -293,10 +302,10 @@ def get(self, timeout=None): if age >= self._max_age and not session.exists(): if not session.exists(): - add_span_event( - current_span, + add_event_on_current_span( "Session is not valid, recreating it", span_event_attributes, + current_span, ) session = self._database.session() session.create() @@ -305,11 +314,13 @@ def get(self, timeout=None): span_event_attributes["session.id"] = session._session_id span_event_attributes["time.elapsed"] = time.time() - start_time - add_span_event(current_span, "Acquired session", span_event_attributes) + add_event_on_current_span( + "Acquired session", span_event_attributes, current_span + ) except queue.Empty as e: - add_span_event( - current_span, "No sessions available in the pool", span_event_attributes + add_event_on_current_span( + "No sessions available in the pool", span_event_attributes, current_span ) raise e @@ -387,29 +398,31 @@ def get(self): """ current_span = get_current_span() span_event_attributes = {"kind": type(self).__name__} - add_span_event(current_span, "Acquiring session", span_event_attributes) + add_event_on_current_span( + "Acquiring session", span_event_attributes, current_span + ) try: - add_span_event( - current_span, + add_event_on_current_span( "Waiting for a session to become available", span_event_attributes, + current_span, ) session = self._sessions.get_nowait() except queue.Empty: - add_span_event( - current_span, + add_event_on_current_span( "No sessions available in pool. Creating session", span_event_attributes, + current_span, ) session = self._new_session() session.create() else: if not session.exists(): - add_span_event( - current_span, + add_event_on_current_span( "Session is not valid, recreating it", span_event_attributes, + current_span, ) session = self._new_session() session.create() @@ -523,52 +536,64 @@ def bind(self, database): current_span = get_current_span() requested_session_count = request.session_count if requested_session_count <= 0: - add_span_event( - current_span, + add_event_on_current_span( f"Invalid session pool size({requested_session_count}) <= 0", span_event_attributes, + current_span, ) return - add_span_event( - current_span, - f"Requesting {requested_session_count} sessions", + add_event_on_current_span( + f"Requesting for {requested_session_count} sessions", span_event_attributes, + current_span, ) - if created_session_count >= self.size: - add_span_event( - current_span, - "Created no new sessions as sessionPool is full", - span_event_attributes, - ) - return + observability_options = getattr(self._database, "observability_options", None) + with trace_call( + "Cloudspanner.PingingPool.BatchCreateSessions", + observability_options=observability_options, + ) as span: + while created_session_count < self.size: + resp = api.batch_create_sessions( + request=request, + metadata=metadata, + ) - add_span_event( - current_span, - f"Creating {request.session_count} sessions", - span_event_attributes, - ) + add_event_on_current_span( + f"Created {len(resp.session)} sessions", + span=span, + ) - returned_session_count = 0 - while created_session_count < self.size: - resp = api.batch_create_sessions( - request=request, - metadata=metadata, - ) - for session_pb in resp.session: - session = self._new_session() - session._session_id = session_pb.name.split("/")[-1] - self.put(session) - returned_session_count += 1 + for session_pb in resp.session: + session = self._new_session() + session._session_id = session_pb.name.split("/")[-1] + self.put(session) - created_session_count += len(resp.session) + created_session_count += len(resp.session) - add_span_event( - current_span, - f"Requested for {requested_session_count} sessions, return {returned_session_count}", - span_event_attributes, - ) + if created_session_count >= self.size: + add_event_on_current_span( + "Created no new sessions as sessionPool is full", + span_event_attributes, + current_span, + ) + return + + add_event_on_current_span( + f"Requested for {requested_session_count} sessions, return {returned_session_count}", + span_event_attributes, + span, + ) + + add_event_on_current_span( + f"Finished creating sessions", + dict( + requested_count=request.session_count, + created_count=created_session_count, + ), + span, + ) def get(self, timeout=None): """Check a session out from the pool. @@ -587,10 +612,10 @@ def get(self, timeout=None): start_time = time.time() span_event_attributes = {"kind": type(self).__name__} current_span = get_current_span() - add_span_event( - current_span, + add_event_on_current_span( "Waiting for a session to become available", span_event_attributes, + current_span, ) ping_after = None @@ -598,10 +623,10 @@ def get(self, timeout=None): try: ping_after, session = self._sessions.get(block=True, timeout=timeout) except queue.Empty as e: - add_span_event( - current_span, + add_event_on_current_span( "No sessions available in the pool within the specified timeout", span_event_attributes, + current_span, ) raise e @@ -620,7 +645,9 @@ def get(self, timeout=None): "kind": "pinging_pool", } ) - add_span_event(current_span, "Acquired session", span_event_attributes) + add_event_on_current_span( + "Acquired session", span_event_attributes, current_span + ) return session def put(self, session): diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 166d5488c6..cb859c9d9d 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -32,8 +32,9 @@ _metadata_with_leader_aware_routing, ) from google.cloud.spanner_v1._opentelemetry_tracing import ( - add_span_event, get_current_span, + add_event_on_current_span, + record_span_exception_and_status, trace_call, ) from google.cloud.spanner_v1.batch import Batch @@ -139,7 +140,7 @@ def create(self): :raises ValueError: if :attr:`session_id` is already set. """ current_span = get_current_span() - add_span_event(current_span, "Creating Session") + add_event_on_current_span("Creating Session", span=current_span) if self._session_id is not None: raise ValueError("Session ID already set by back-end") @@ -183,14 +184,16 @@ def exists(self): """ current_span = get_current_span() if self._session_id is None: - add_span_event( - current_span, + add_event_on_current_span( "Checking session existence: Session does not exist as it has not been created yet", + span=current_span, ) return False - add_span_event( - current_span, "Checking if Session exists", {"session.id": self._session_id} + add_event_on_current_span( + "Checking if Session exists", + {"session.id": self._session_id}, + current_span, ) api = self._database.spanner_api @@ -228,13 +231,16 @@ def delete(self): """ current_span = get_current_span() if self._session_id is None: - add_span_event( - current_span, "Deleting Session failed due to unset session_id" + add_event_on_current_span( + "Deleting Session failed due to unset session_id", + current_span, ) raise ValueError("Session ID not set by back-end") - add_span_event( - current_span, "Deleting Session", {"session.id": self._session_id} + add_event_on_current_span( + "Deleting Session", + {"session.id": self._session_id}, + current_span, ) api = self._database.spanner_api @@ -243,6 +249,10 @@ def delete(self): with trace_call( "CloudSpanner.DeleteSession", self, + extra_attributes={ + "session.id": self._session_id, + "session.name": self.name, + }, observability_options=observability_options, ): api.delete_session(name=self.name, metadata=metadata) @@ -458,47 +468,99 @@ def run_in_transaction(self, func, *args, **kw): ) attempts = 0 - while True: - if self._transaction is None: - txn = self.transaction() - txn.transaction_tag = transaction_tag - txn.exclude_txn_from_change_streams = exclude_txn_from_change_streams - else: - txn = self._transaction + observability_options = getattr(self._database, "observability_options", None) + with trace_call( + "CloudSpanner.session.run_in_transaction", + self, + observability_options=observability_options, + ) as span: + while True: + if self._transaction is None: + add_event_on_current_span("Creating Transaction", span=span) + txn = self.transaction() + txn.transaction_tag = transaction_tag + txn.exclude_txn_from_change_streams = ( + exclude_txn_from_change_streams + ) + else: + txn = self._transaction - try: attempts += 1 - return_value = func(txn, *args, **kw) - except Aborted as exc: - del self._transaction - _delay_until_retry(exc, deadline, attempts) - continue - except GoogleAPICallError: - del self._transaction - raise - except Exception: - txn.rollback() - raise - try: - txn.commit( - return_commit_stats=self._database.log_commit_stats, - request_options=commit_request_options, - max_commit_delay=max_commit_delay, - ) - except Aborted as exc: - del self._transaction - _delay_until_retry(exc, deadline, attempts) - except GoogleAPICallError: - del self._transaction - raise - else: - if self._database.log_commit_stats and txn.commit_stats: - self._database.logger.info( - "CommitStats: {}".format(txn.commit_stats), - extra={"commit_stats": txn.commit_stats}, + txn_id = getattr(txn, "_transaction_id", "") or "" + span_attributes = {"attempt": attempts} + if txn_id: + span_attributes["transaction.id"] = txn_id + + add_event_on_current_span("Using Transaction", span_attributes, span) + + try: + return_value = func(txn, *args, **kw) + except Aborted as exc: + del self._transaction + if span: + delay_seconds = _get_retry_delay(exc.errors[0], attempts) + attributes = dict(delay_seconds=delay_seconds) + attributes.update(span_attributes) + record_span_exception_and_status(span, exc) + add_event_on_current_span( + "Transaction was aborted, retrying", attributes, span + ) + + _delay_until_retry(exc, deadline, attempts) + continue + except GoogleAPICallError: + del self._transaction + add_event_on_current_span( + "Transaction.commit failed due to GoogleAPICallError, not retrying", + span_attributes, + span, + ) + raise + except Exception: + add_event_on_current_span( + "Invoking Transaction.rollback(), not retrying", + span_attributes, + span, + ) + txn.rollback() + raise + + try: + txn.commit( + return_commit_stats=self._database.log_commit_stats, + request_options=commit_request_options, + max_commit_delay=max_commit_delay, ) - return return_value + except Aborted as exc: + del self._transaction + if span: + delay_seconds = _get_retry_delay(exc.errors[0], attempts) + attributes = dict(delay_seconds=delay_seconds) + attributes.update(span_attributes) + add_event_on_current_span( + "Transaction.commit was aborted, retrying afresh", + attributes, + span, + ) + + _delay_until_retry(exc, deadline, attempts) + except GoogleAPICallError: + del self._transaction + if span: + add_event_on_current_span( + "Transaction.commit failed due to GoogleAPICallError, not retrying", + span_attributes, + span, + ) + raise + else: + if self._database.log_commit_stats and txn.commit_stats: + self._database.logger.info( + "CommitStats: {}".format(txn.commit_stats), + extra={"commit_stats": txn.commit_stats}, + ) + return return_value # Rational: this function factors out complex shared deadline / retry diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 89b5094706..239db8211e 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -52,7 +52,7 @@ def _restart_on_unavailable( method, request, - trace_name=None, + span_name=None, session=None, attributes=None, transaction=None, @@ -88,9 +88,10 @@ def _restart_on_unavailable( request.transaction = transaction_selector with trace_call( - trace_name, session, attributes, observability_options=observability_options + span_name, session, attributes, observability_options=observability_options ): iterator = method(request=request) + while True: try: for item in iterator: @@ -110,7 +111,7 @@ def _restart_on_unavailable( except ServiceUnavailable: del item_buffer[:] with trace_call( - trace_name, + span_name, session, attributes, observability_options=observability_options, @@ -130,7 +131,7 @@ def _restart_on_unavailable( raise del item_buffer[:] with trace_call( - trace_name, + span_name, session, attributes, observability_options=observability_options, @@ -329,13 +330,14 @@ def read( trace_attributes = {"table_id": table, "columns": columns} observability_options = getattr(database, "observability_options", None) + span_name = f"CloudSpanner.{type(self).__name__}.read" if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: iterator = _restart_on_unavailable( restart, request, - "CloudSpanner.ReadOnlyTransaction", + span_name, self._session, trace_attributes, transaction=self, @@ -357,7 +359,7 @@ def read( iterator = _restart_on_unavailable( restart, request, - "CloudSpanner.ReadOnlyTransaction", + span_name, self._session, trace_attributes, transaction=self, @@ -578,7 +580,7 @@ def _get_streamed_result_set( iterator = _restart_on_unavailable( restart, request, - "CloudSpanner.ReadWriteTransaction", + f"CloudSpanner.{type(self).__name__}.execute_streaming_sql", self._session, trace_attributes, transaction=self, @@ -675,8 +677,12 @@ def partition_read( ) trace_attributes = {"table_id": table, "columns": columns} + can_include_index = (index != "") and (index is not None) + if can_include_index: + trace_attributes["index"] = index + with trace_call( - "CloudSpanner.PartitionReadOnlyTransaction", + f"CloudSpanner.{type(self).__name__}.partition_read", self._session, trace_attributes, observability_options=getattr(database, "observability_options", None), @@ -779,7 +785,7 @@ def partition_query( trace_attributes = {"db.statement": sql} with trace_call( - "CloudSpanner.PartitionReadWriteTransaction", + f"CloudSpanner.{type(self).__name__}.partition_query", self._session, trace_attributes, observability_options=getattr(database, "observability_options", None), @@ -926,7 +932,7 @@ def begin(self): ) txn_selector = self._make_txn_selector() with trace_call( - "CloudSpanner.BeginTransaction", + f"CloudSpanner.{type(self).__name__}.begin", self._session, observability_options=getattr(database, "observability_options", None), ): diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index fa8e5121ff..da35666c7a 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -32,7 +32,10 @@ from google.cloud.spanner_v1 import TransactionOptions from google.cloud.spanner_v1.snapshot import _SnapshotBase from google.cloud.spanner_v1.batch import _BatchBase -from google.cloud.spanner_v1._opentelemetry_tracing import add_span_event, trace_call +from google.cloud.spanner_v1._opentelemetry_tracing import ( + add_event_on_current_span, + trace_call, +) from google.cloud.spanner_v1 import RequestOptions from google.api_core import gapic_v1 from google.api_core.exceptions import InternalServerError @@ -157,7 +160,7 @@ def begin(self): ) observability_options = getattr(database, "observability_options", None) with trace_call( - "CloudSpanner.BeginTransaction", + f"CloudSpanner.{type(self).__name__}.begin", self._session, observability_options=observability_options, ) as span: @@ -169,10 +172,10 @@ def begin(self): ) def beforeNextRetry(nthRetry, delayInSeconds): - add_span_event( - span, + add_event_on_current_span( "Transaction Begin Attempt Failed. Retrying", {"attempt": nthRetry, "sleep_seconds": delayInSeconds}, + span, ) response = _retry( @@ -199,7 +202,7 @@ def rollback(self): ) observability_options = getattr(database, "observability_options", None) with trace_call( - "CloudSpanner.Rollback", + f"CloudSpanner.{type(self).__name__}.rollback", self._session, observability_options=observability_options, ): @@ -215,6 +218,7 @@ def rollback(self): ) self.rolled_back = True del self._session._transaction + self._discard_on_end() def commit( self, return_commit_stats=False, request_options=None, max_commit_delay=None @@ -278,12 +282,12 @@ def commit( trace_attributes = {"num_mutations": len(self._mutations)} observability_options = getattr(database, "observability_options", None) with trace_call( - "CloudSpanner.Commit", + f"CloudSpanner.{type(self).__name__}.commit", self._session, trace_attributes, observability_options, ) as span: - add_span_event(span, "Starting Commit") + add_event_on_current_span("Starting Commit", span=span) method = functools.partial( api.commit, @@ -292,10 +296,10 @@ def commit( ) def beforeNextRetry(nthRetry, delayInSeconds): - add_span_event( - span, + add_event_on_current_span( "Transaction Commit Attempt Failed. Retrying", {"attempt": nthRetry, "sleep_seconds": delayInSeconds}, + span, ) response = _retry( @@ -304,13 +308,13 @@ def beforeNextRetry(nthRetry, delayInSeconds): beforeNextRetry=beforeNextRetry, ) - add_span_event(span, "Commit Done") - - self.committed = response.commit_timestamp - if return_commit_stats: - self.commit_stats = response.commit_stats - del self._session._transaction - return self.committed + add_event_on_current_span("Commit Done", span=span) + self.committed = response.commit_timestamp + if return_commit_stats: + self.commit_stats = response.commit_stats + del self._session._transaction + self._discard_on_end() + return self.committed @staticmethod def _make_params_pb(params, param_types): @@ -447,7 +451,7 @@ def execute_update( response = self._execute_request( method, request, - "CloudSpanner.ReadWriteTransaction", + f"CloudSpanner.{type(self).__name__}.execute_update", self._session, trace_attributes, observability_options=observability_options, @@ -464,7 +468,7 @@ def execute_update( response = self._execute_request( method, request, - "CloudSpanner.ReadWriteTransaction", + f"CloudSpanner.{type(self).__name__}.execute_update", self._session, trace_attributes, observability_options=observability_options, diff --git a/tests/_helpers.py b/tests/_helpers.py index 81787c5a86..0c35c1c30c 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -78,7 +78,7 @@ def tearDown(self): def assertNoSpans(self): if HAS_OPENTELEMETRY_INSTALLED: - span_list = self.ot_exporter.get_finished_spans() + span_list = self.get_finished_spans() self.assertEqual(len(span_list), 0) def assertSpanAttributes( @@ -86,7 +86,7 @@ def assertSpanAttributes( ): if HAS_OPENTELEMETRY_INSTALLED: if not span: - span_list = self.ot_exporter.get_finished_spans() + span_list = self.get_finished_spans() self.assertEqual(len(span_list) > 0, True) span = span_list[0] @@ -118,12 +118,13 @@ def assertSpanNames(self, want_span_names): self.assertEqual(got_span_names, want_span_names) def get_finished_spans(self): - if HAS_OPENTELEMETRY_INSTALLED: - return list( - filter( - lambda span: span and span.name, - self.ot_exporter.get_finished_spans(), - ) - ) - else: + if not HAS_OPENTELEMETRY_INSTALLED: return [] + + spans = self.ot_exporter.get_finished_spans() + # A span with name=None is the result from invoking trace_call without + # intention to trace, hence these have to be filtered out. + return list(filter(lambda span: span.name, spans)) + + def reset(self): + self.tearDown() diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index 8382255c15..009f4eb01e 100644 --- a/tests/system/test_observability_options.py +++ b/tests/system/test_observability_options.py @@ -37,7 +37,7 @@ not HAS_OTEL_INSTALLED, reason="OpenTelemetry is necessary to test traces." ) @pytest.mark.skipif( - not _helpers.USE_EMULATOR, reason="mulator is necessary to test traces." + not _helpers.USE_EMULATOR, reason="emulator is necessary to test traces." ) def test_observability_options_propagation(): PROJECT = _helpers.EMULATOR_PROJECT @@ -105,16 +105,21 @@ def test_propagation(enable_extended_tracing): len(from_inject_spans) >= 2 ) # "Expecting at least 2 spans from the injected trace exporter" gotNames = [span.name for span in from_inject_spans] - wantNames = ["CloudSpanner.CreateSession", "CloudSpanner.ReadWriteTransaction"] + wantNames = [ + "CloudSpanner.CreateSession", + "CloudSpanner.Snapshot.execute_streaming_sql", + "CloudSpanner.Database.snapshot", + ] assert gotNames == wantNames # Check for conformance of enable_extended_tracing - lastSpan = from_inject_spans[len(from_inject_spans) - 1] + snapshot_execute_span = from_inject_spans[len(from_inject_spans) - 2] wantAnnotatedSQL = "SELECT 1" if not enable_extended_tracing: wantAnnotatedSQL = None assert ( - lastSpan.attributes.get("db.statement", None) == wantAnnotatedSQL + snapshot_execute_span.attributes.get("db.statement", None) + == wantAnnotatedSQL ) # "Mismatch in annotated sql" try: @@ -132,3 +137,48 @@ def _make_credentials(): from google.auth.credentials import AnonymousCredentials return AnonymousCredentials() + + +from tests import _helpers as ot_helpers + + +@pytest.mark.skipif( + not ot_helpers.HAS_OPENTELEMETRY_INSTALLED, + reason="Tracing requires OpenTelemetry", +) +def test_trace_call_keeps_span_error_status(): + # Verifies that after our span's status was set to ERROR + # that it doesn't unconditionally get changed to OK + # per https://github.com/googleapis/python-spanner/issues/1246 + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + from google.cloud.spanner_v1._opentelemetry_tracing import trace_call + from opentelemetry.trace.status import Status, StatusCode + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.sampling import ALWAYS_ON + from opentelemetry import trace + + tracer_provider = TracerProvider(sampler=ALWAYS_ON) + trace_exporter = InMemorySpanExporter() + tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter)) + observability_options = dict(tracer_provider=tracer_provider) + + with trace_call( + "VerifyBehavior", observability_options=observability_options + ) as span: + span.set_status(Status(StatusCode.ERROR, "Our error exhibit")) + + span_list = trace_exporter.get_finished_spans() + got_statuses = [] + + for span in span_list: + got_statuses.append( + (span.name, span.status.status_code, span.status.description) + ) + + want_statuses = [ + ("VerifyBehavior", StatusCode.ERROR, "Our error exhibit"), + ] + assert got_statuses == want_statuses diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index b7337cb258..bf9dc67d6f 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -437,8 +437,6 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): if ot_exporter is not None: span_list = ot_exporter.get_finished_spans() - assert len(span_list) == 4 - assert_span_attributes( ot_exporter, "CloudSpanner.GetSession", @@ -447,21 +445,40 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): ) assert_span_attributes( ot_exporter, - "CloudSpanner.Commit", + "CloudSpanner.Batch.commit", attributes=_make_attributes(db_name, num_mutations=2), span=span_list[1], ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Batch", + attributes=_make_attributes(db_name), + span=span_list[2], + ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Database.batch", + attributes=_make_attributes(db_name), + span=span_list[3], + ) + assert_span_attributes( ot_exporter, "CloudSpanner.GetSession", attributes=_make_attributes(db_name, session_found=True), - span=span_list[2], + span=span_list[4], ) assert_span_attributes( ot_exporter, - "CloudSpanner.ReadOnlyTransaction", + "CloudSpanner.Snapshot.read", attributes=_make_attributes(db_name, columns=sd.COLUMNS, table_id=sd.TABLE), - span=span_list[3], + span=span_list[5], + ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Database.snapshot", + attributes=_make_attributes(db_name, multi_use=False), + span=span_list[6], ) @@ -608,7 +625,6 @@ def test_transaction_read_and_insert_then_rollback( if ot_exporter is not None: span_list = ot_exporter.get_finished_spans() - assert len(span_list) == 8 assert_span_attributes( ot_exporter, @@ -624,51 +640,70 @@ def test_transaction_read_and_insert_then_rollback( ) assert_span_attributes( ot_exporter, - "CloudSpanner.Commit", + "CloudSpanner.Batch.commit", attributes=_make_attributes(db_name, num_mutations=1), span=span_list[2], ) assert_span_attributes( ot_exporter, - "CloudSpanner.BeginTransaction", + "CloudSpanner.Batch", attributes=_make_attributes(db_name), span=span_list[3], ) assert_span_attributes( ot_exporter, - "CloudSpanner.ReadOnlyTransaction", + "CloudSpanner.Database.batch", + attributes=_make_attributes(db_name), + span=span_list[4], + ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Transaction.begin", + attributes=_make_attributes(db_name), + span=span_list[5], + ) + + assert_span_attributes( + ot_exporter, + "CloudSpanner.Transaction.read", attributes=_make_attributes( db_name, table_id=sd.TABLE, columns=sd.COLUMNS, ), - span=span_list[4], + span=span_list[6], ) assert_span_attributes( ot_exporter, - "CloudSpanner.ReadOnlyTransaction", + "CloudSpanner.Transaction.read", attributes=_make_attributes( db_name, table_id=sd.TABLE, columns=sd.COLUMNS, ), - span=span_list[5], + span=span_list[7], ) assert_span_attributes( ot_exporter, - "CloudSpanner.Rollback", + "CloudSpanner.Transaction.rollback", attributes=_make_attributes(db_name), - span=span_list[6], + span=span_list[8], ) assert_span_attributes( ot_exporter, - "CloudSpanner.ReadOnlyTransaction", + "CloudSpanner.Transaction", + attributes=_make_attributes(db_name), + span=span_list[9], + ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Snapshot.read", attributes=_make_attributes( db_name, table_id=sd.TABLE, columns=sd.COLUMNS, ), - span=span_list[7], + span=span_list[10], ) @@ -699,6 +734,159 @@ def _transaction_read_then_raise(transaction): assert rows == [] +@pytest.mark.skipif( + not _helpers.USE_EMULATOR, + reason="Emulator needed to run this tests", +) +@pytest.mark.skipif( + not ot_helpers.HAS_OPENTELEMETRY_INSTALLED, + reason="Tracing requires OpenTelemetry", +) +def test_transaction_abort_then_retry_spans(sessions_database, ot_exporter): + from google.auth.credentials import AnonymousCredentials + from google.api_core.exceptions import Aborted + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + from opentelemetry.trace.status import StatusCode + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.sampling import ALWAYS_ON + from opentelemetry import trace + + PROJECT = _helpers.EMULATOR_PROJECT + CONFIGURATION_NAME = "config-name" + INSTANCE_ID = _helpers.INSTANCE_ID + DISPLAY_NAME = "display-name" + DATABASE_ID = _helpers.unique_id("temp_db") + NODE_COUNT = 5 + LABELS = {"test": "true"} + + counters = dict(aborted=0) + already_aborted = False + + def select_in_txn(txn): + from google.rpc import error_details_pb2 + + results = txn.execute_sql("SELECT 1") + for row in results: + _ = row + + if counters["aborted"] == 0: + counters["aborted"] = 1 + raise Aborted( + "Thrown from ClientInterceptor for testing", + errors=[FauxCall(code_pb2.ABORTED)], + ) + + tracer_provider = TracerProvider(sampler=ALWAYS_ON) + trace_exporter = InMemorySpanExporter() + tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter)) + observability_options = dict( + tracer_provider=tracer_provider, + enable_extended_tracing=True, + ) + + client = spanner_v1.Client( + project=PROJECT, + observability_options=observability_options, + credentials=AnonymousCredentials(), + ) + + instance = client.instance( + INSTANCE_ID, + CONFIGURATION_NAME, + display_name=DISPLAY_NAME, + node_count=NODE_COUNT, + labels=LABELS, + ) + + try: + instance.create() + except Exception: + pass + + db = instance.database(DATABASE_ID) + try: + db.create() + except Exception: + pass + + db.run_in_transaction(select_in_txn) + + span_list = trace_exporter.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = [ + "CloudSpanner.CreateSession", + "CloudSpanner.Transaction.execute_streaming_sql", + "CloudSpanner.Transaction", + "CloudSpanner.Transaction.execute_streaming_sql", + "CloudSpanner.Transaction.commit", + "CloudSpanner.Transaction", + "CloudSpanner.ReadWriteTransaction", + "CloudSpanner.Database.run_in_transaction", + ] + + assert got_span_names == want_span_names + + # Let's check for the series of events + want_events = [ + ("Creating Transaction", {}), + ("Using Transaction", {"attempt": 1}), + ( + "exception", + { + "exception.type": "google.api_core.exceptions.Aborted", + "exception.message": "409 Thrown from ClientInterceptor for testing", + "exception.stacktrace": "EPHEMERAL", + "exception.escaped": "False", + }, + ), + ( + "Transaction was aborted, retrying", + {"delay_seconds": "EPHEMERAL", "attempt": 1}, + ), + ("Creating Transaction", {}), + ("Using Transaction", {"attempt": 2}), + ] + got_events = [] + got_statuses = [] + + # Some event attributes are noisy/highly ephemeral + # and can't be directly compared against. + imprecise_event_attributes = ["exception.stacktrace", "delay_seconds"] + for span in span_list: + got_statuses.append( + (span.name, span.status.status_code, span.status.description) + ) + for event in span.events: + evt_attributes = event.attributes.copy() + for attr_name in imprecise_event_attributes: + if attr_name in evt_attributes: + evt_attributes[attr_name] = "EPHEMERAL" + + got_events.append((event.name, evt_attributes)) + + assert got_events == want_events + + codes = StatusCode + want_statuses = [ + ("CloudSpanner.CreateSession", codes.OK, None), + ("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None), + ("CloudSpanner.Transaction", codes.UNSET, None), + ("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None), + ("CloudSpanner.Transaction.commit", codes.OK, None), + ("CloudSpanner.Transaction", codes.OK, None), + ( + "CloudSpanner.ReadWriteTransaction", + codes.ERROR, + "409 Thrown from ClientInterceptor for testing", + ), + ("CloudSpanner.Database.run_in_transaction", codes.OK, None), + ] + assert got_statuses == want_statuses + + @_helpers.retry_mabye_conflict def test_transaction_read_and_insert_or_update_then_commit( sessions_database, @@ -1182,19 +1370,67 @@ def unit_of_work(transaction): with tracer.start_as_current_span("Test Span"): session.run_in_transaction(unit_of_work) - span_list = ot_exporter.get_finished_spans() - assert len(span_list) == 5 + span_list = [] + for span in ot_exporter.get_finished_spans(): + if span and span.name: + span_list.append(span) + + span_list = sorted(span_list, key=lambda v1: v1.start_time) + expected_span_names = [ "CloudSpanner.CreateSession", - "CloudSpanner.Commit", - "CloudSpanner.DMLTransaction", - "CloudSpanner.Commit", + "CloudSpanner.Batch", + "CloudSpanner.Batch", + "CloudSpanner.Batch.commit", "Test Span", + "CloudSpanner.ReadWriteTransaction", + "CloudSpanner.Transaction", + "CloudSpanner.DMLTransaction", + "CloudSpanner.Transaction.commit", ] - assert [span.name for span in span_list] == expected_span_names - for span in span_list[2:-1]: - assert span.context.trace_id == span_list[-1].context.trace_id - assert span.parent.span_id == span_list[-1].context.span_id + + got_span_names = [span.name for span in span_list] + assert got_span_names == expected_span_names + + # We expect: + # |------CloudSpanner.CreateSession-------- + # + # |---Test Span----------------------------| + # |>--ReadWriteTransaction----------------- + # |>-Transaction------------------------- + # |--------------DMLTransaction-------- + # + # |>---Batch------------------------------- + # + # |>----------Batch------------------------- + # |>------------Batch.commit--------------- + + # CreateSession should have a trace of its own, with no children + # nor being a child of any other span. + session_span = span_list[0] + test_span = span_list[4] + # assert session_span.context.trace_id != test_span.context.trace_id + for span in span_list[1:]: + if span.parent: + assert span.parent.span_id != session_span.context.span_id + + def assert_parent_and_children(parent_span, children): + for span in children: + assert span.context.trace_id == parent_span.context.trace_id + assert span.parent.span_id == parent_span.context.span_id + + # [CreateSession --> Batch] should have their own trace. + rw_txn_span = span_list[5] + children_of_test_span = [rw_txn_span] + assert_parent_and_children(test_span, children_of_test_span) + + children_of_rw_txn_span = [span_list[6]] + assert_parent_and_children(rw_txn_span, children_of_rw_txn_span) + + # Batch_first should have no parent, should be in its own trace. + batch_0_span = span_list[2] + children_of_batch_0 = [span_list[1]] + assert_parent_and_children(rw_txn_span, children_of_rw_txn_span) def test_execute_partitioned_dml( diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index a7f7a6f970..ea744f8889 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -212,7 +212,7 @@ def test_commit_grpc_error(self): batch.commit() self.assertSpanAttributes( - "CloudSpanner.Commit", + "CloudSpanner.Batch.commit", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, num_mutations=1), ) @@ -261,7 +261,8 @@ def test_commit_ok(self): self.assertEqual(max_commit_delay, None) self.assertSpanAttributes( - "CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1) + "CloudSpanner.Batch.commit", + attributes=dict(BASE_ATTRIBUTES, num_mutations=1), ) def _test_commit_with_options( @@ -327,7 +328,8 @@ def _test_commit_with_options( self.assertEqual(actual_request_options, expected_request_options) self.assertSpanAttributes( - "CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1) + "CloudSpanner.Batch.commit", + attributes=dict(BASE_ATTRIBUTES, num_mutations=1), ) self.assertEqual(max_commit_delay_in, max_commit_delay) @@ -438,7 +440,8 @@ def test_context_mgr_success(self): self.assertEqual(request_options, RequestOptions()) self.assertSpanAttributes( - "CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1) + "CloudSpanner.Batch.commit", + attributes=dict(BASE_ATTRIBUTES, num_mutations=1), ) def test_context_mgr_failure(self): @@ -492,7 +495,7 @@ def test_batch_write_already_committed(self): group.delete(TABLE_NAME, keyset=keyset) groups.batch_write() self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) @@ -518,7 +521,7 @@ def test_batch_write_grpc_error(self): groups.batch_write() self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) @@ -580,7 +583,7 @@ def _test_batch_write_with_request_options( ) self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 479a0d62e9..4b62c475bf 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -616,7 +616,7 @@ def test_read_other_error(self): list(derived.read(TABLE_NAME, COLUMNS, keyset)) self.assertSpanAttributes( - "CloudSpanner.ReadOnlyTransaction", + "CloudSpanner._Derived.read", status=StatusCode.ERROR, attributes=dict( BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) @@ -773,7 +773,7 @@ def _read_helper( ) self.assertSpanAttributes( - "CloudSpanner.ReadOnlyTransaction", + "CloudSpanner._Derived.read", attributes=dict( BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) ), @@ -868,7 +868,7 @@ def test_execute_sql_other_error(self): self.assertEqual(derived._execute_sql_count, 1) self.assertSpanAttributes( - "CloudSpanner.ReadWriteTransaction", + "CloudSpanner._Derived.execute_streaming_sql", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) @@ -1024,7 +1024,7 @@ def _execute_sql_helper( self.assertEqual(derived._execute_sql_count, sql_count + 1) self.assertSpanAttributes( - "CloudSpanner.ReadWriteTransaction", + "CloudSpanner._Derived.execute_streaming_sql", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), ) @@ -1194,12 +1194,17 @@ def _partition_read_helper( timeout=timeout, ) + want_span_attributes = dict( + BASE_ATTRIBUTES, + table_id=TABLE_NAME, + columns=tuple(COLUMNS), + ) + if index: + want_span_attributes["index"] = index self.assertSpanAttributes( - "CloudSpanner.PartitionReadOnlyTransaction", + "CloudSpanner._Derived.partition_read", status=StatusCode.OK, - attributes=dict( - BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) - ), + attributes=want_span_attributes, ) def test_partition_read_single_use_raises(self): @@ -1226,7 +1231,7 @@ def test_partition_read_other_error(self): list(derived.partition_read(TABLE_NAME, COLUMNS, keyset)) self.assertSpanAttributes( - "CloudSpanner.PartitionReadOnlyTransaction", + "CloudSpanner._Derived.partition_read", status=StatusCode.ERROR, attributes=dict( BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) @@ -1369,7 +1374,7 @@ def _partition_query_helper( ) self.assertSpanAttributes( - "CloudSpanner.PartitionReadWriteTransaction", + "CloudSpanner._Derived.partition_query", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), ) @@ -1387,7 +1392,7 @@ def test_partition_query_other_error(self): list(derived.partition_query(SQL_QUERY)) self.assertSpanAttributes( - "CloudSpanner.PartitionReadWriteTransaction", + "CloudSpanner._Derived.partition_query", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) @@ -1696,8 +1701,13 @@ def test_begin_w_other_error(self): with self.assertRaises(RuntimeError): snapshot.begin() + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = ["CloudSpanner.Snapshot.begin"] + assert got_span_names == want_span_names + self.assertSpanAttributes( - "CloudSpanner.BeginTransaction", + "CloudSpanner.Snapshot.begin", status=StatusCode.ERROR, attributes=BASE_ATTRIBUTES, ) @@ -1755,7 +1765,7 @@ def test_begin_ok_exact_staleness(self): ) self.assertSpanAttributes( - "CloudSpanner.BeginTransaction", + "CloudSpanner.Snapshot.begin", status=StatusCode.OK, attributes=BASE_ATTRIBUTES, ) @@ -1791,7 +1801,7 @@ def test_begin_ok_exact_strong(self): ) self.assertSpanAttributes( - "CloudSpanner.BeginTransaction", + "CloudSpanner.Snapshot.begin", status=StatusCode.OK, attributes=BASE_ATTRIBUTES, ) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index e426f912b2..7a1c512ec5 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -125,6 +125,7 @@ def test__make_txn_selector(self): self.assertEqual(selector.id, self.TRANSACTION_ID) def test_begin_already_begun(self): + self.reset() session = _Session() transaction = self._make_one(session) transaction._transaction_id = self.TRANSACTION_ID @@ -134,6 +135,7 @@ def test_begin_already_begun(self): self.assertNoSpans() def test_begin_already_rolled_back(self): + self.reset() session = _Session() transaction = self._make_one(session) transaction.rolled_back = True @@ -143,6 +145,7 @@ def test_begin_already_rolled_back(self): self.assertNoSpans() def test_begin_already_committed(self): + self.reset() session = _Session() transaction = self._make_one(session) transaction.committed = object() @@ -152,6 +155,7 @@ def test_begin_already_committed(self): self.assertNoSpans() def test_begin_w_other_error(self): + self.reset() database = _Database() database.spanner_api = self._make_spanner_api() database.spanner_api.begin_transaction.side_effect = RuntimeError() @@ -161,13 +165,19 @@ def test_begin_w_other_error(self): with self.assertRaises(RuntimeError): transaction.begin() + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = ["CloudSpanner.Transaction.begin"] + assert got_span_names == want_span_names + self.assertSpanAttributes( - "CloudSpanner.BeginTransaction", + "CloudSpanner.Transaction.begin", status=StatusCode.ERROR, attributes=TestTransaction.BASE_ATTRIBUTES, ) def test_begin_ok(self): + self.reset() from google.cloud.spanner_v1 import Transaction as TransactionPB transaction_pb = TransactionPB(id=self.TRANSACTION_ID) @@ -195,10 +205,11 @@ def test_begin_ok(self): ) self.assertSpanAttributes( - "CloudSpanner.BeginTransaction", attributes=TestTransaction.BASE_ATTRIBUTES + "CloudSpanner.Transaction.begin", attributes=TestTransaction.BASE_ATTRIBUTES ) def test_begin_w_retry(self): + self.reset() from google.cloud.spanner_v1 import ( Transaction as TransactionPB, ) @@ -266,7 +277,7 @@ def test_rollback_w_other_error(self): self.assertFalse(transaction.rolled_back) self.assertSpanAttributes( - "CloudSpanner.Rollback", + "CloudSpanner.Transaction.rollback", status=StatusCode.ERROR, attributes=TestTransaction.BASE_ATTRIBUTES, ) @@ -299,7 +310,8 @@ def test_rollback_ok(self): ) self.assertSpanAttributes( - "CloudSpanner.Rollback", attributes=TestTransaction.BASE_ATTRIBUTES + "CloudSpanner.Transaction.rollback", + attributes=TestTransaction.BASE_ATTRIBUTES, ) def test_commit_not_begun(self): @@ -344,10 +356,25 @@ def test_commit_w_other_error(self): self.assertIsNone(transaction.committed) + span_list = sorted(self.get_finished_spans(), key=lambda v: v.start_time) + + got_span_names = [span.name for span in span_list] + want_span_names = [ + "CloudSpanner.Transaction", + "CloudSpanner.Transaction", + "CloudSpanner.Transaction", + "CloudSpanner.Transaction", + "CloudSpanner.Transaction.commit", + ] + print("got_names", got_span_names) + assert got_span_names == want_span_names + + txn_commit_span = span_list[-1] self.assertSpanAttributes( - "CloudSpanner.Commit", + "CloudSpanner.Transaction.commit", status=StatusCode.ERROR, attributes=dict(TestTransaction.BASE_ATTRIBUTES, num_mutations=1), + span=txn_commit_span, ) def _commit_helper( @@ -426,12 +453,15 @@ def _commit_helper( if return_commit_stats: self.assertEqual(transaction.commit_stats.mutation_count, 4) + span_list = sorted(self.get_finished_spans(), key=lambda v: v.start_time) + txn_commit_span = span_list[-1] self.assertSpanAttributes( - "CloudSpanner.Commit", + "CloudSpanner.Transaction.commit", attributes=dict( TestTransaction.BASE_ATTRIBUTES, num_mutations=len(transaction._mutations), ), + span=txn_commit_span, ) def test_commit_no_mutations(self):