diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index be74f27d37..72ac6e7229 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -98,6 +98,17 @@ def _make_tracer_and_span_attributes( if not enable_extended_tracing: attributes.pop("db.statement", False) + attributes.pop("sql", False) + else: + # Otherwise there are places where the annotated sql was inserted + # directly from the arguments as "sql", and transform those into "db.statement". + db_statement = attributes.get("db.statement", None) + if not db_statement: + sql = attributes.get("sql", None) + if sql: + attributes = attributes.copy() + attributes.pop("sql", False) + attributes["db.statement"] = sql return tracer, attributes @@ -111,7 +122,10 @@ def trace_call_end_lazily(  context manager, please invoke `trace_call` with which you can invoke  `with trace_call(...) as span:`  It is the caller's responsibility to explicitly invoke span.end() - """ + """ + if not name: + return None + tracer, span_attributes = _make_tracer_and_span_attributes( session, extra_attributes, observability_options ) @@ -128,7 +142,11 @@ def trace_call(name, session=None, extra_attributes=None, observability_options=  trace_call is used in situations where you need to end a span with a context manager  or after a scope is exited. If you need to keep a span alive and lazily end it, please  invoke `trace_call_end_lazily`. - """ + """ + if not name: + yield None + return + tracer, span_attributes = _make_tracer_and_span_attributes( session, extra_attributes, observability_options ) @@ -165,9 +183,9 @@ def get_current_span(): return trace.get_current_span() -def add_event_on_current_span(self, event_name, attributes=None, current_span=None): - if not current_span: - current_span = get_current_span() +def add_event_on_current_span(event_name, attributes=None, span=None): + if not span: + span = get_current_span() - if current_span: - current_span.add_event(event_name, attributes) + if span: + span.add_event(event_name, attributes) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 0c134cbc32..3bbe126682 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -51,7 +51,7 @@ def __init__(self, session): super(_BatchBase, self).__init__(session) self._mutations = [] self.__span = trace_call_end_lazily( - f"CloudSpanner.{type(self).__name}", + f"CloudSpanner.{type(self).__name__}", self._session, None, getattr(self._session._database, "observability_options", None), diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 6453ffd328..a179731603 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -702,16 +702,14 @@ def execute_partitioned_dml( def execute_pdml(): def do_execute_pdml(session, span): - add_event_on_current_span( - "Starting BeginTransaction", current_span=span - ) + add_event_on_current_span("Starting BeginTransaction", span=span) txn = api.begin_transaction( session=session.name, options=txn_options, metadata=metadata ) add_event_on_current_span( "Completed BeginTransaction", {"transaction.id": txn.id}, - current_span=span, + span=span, ) txn_selector = TransactionSelector(id=txn.id) @@ -731,7 +729,7 @@ def do_execute_pdml(session, span): iterator = _restart_on_unavailable( method=method, request=request, - trace_name="CloudSpannerOperation.ExecuteStreamingSql", + span_name="CloudSpanner.ExecuteStreamingSql", transaction_selector=txn_selector, observability_options=self.observability_options, ) @@ -741,11 +739,9 @@ def do_execute_pdml(session, span): return result_set.stats.row_count_lower_bound - observability_options = getattr(self, "observability_options", {}) with trace_call( - "CloudSpanner.execute_partitioned_pdml", - None, - observability_options=observability_options, + "CloudSpanner.Database.execute_partitioned_pdml", + observability_options=self.observability_options, ) as span: with SessionCheckout(self._pool) as session: return do_execute_pdml(session, span) @@ -1232,8 +1228,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): if not exc_type: set_span_status_ok(self.__span) else: - set_span_status_error(self.__span, exc_type) - self.__span.record_exception(exc_type) + set_span_status_error(self.__span, exc_val) + self.__span.record_exception(exc_val) self.__span.end() self.__span = None @@ -1324,8 +1320,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): if not exc_type: set_span_status_ok(self.__span) else: - set_span_status_error(self.__span, exc_type) - self.__span.record_exception(exc_type) + set_span_status_error(self.__span, exc_val) + self.__span.record_exception(exc_val) self.__span.end() self.__span = None @@ -1527,7 +1523,7 @@ def generate_read_batches( :meth:`process_read_batch`. """ with trace_call( - f"CloudSpanner.{type(self).__name__}.generate_read_partitions", + f"CloudSpanner.{type(self).__name__}.generate_read_batches", extra_attributes=dict(table=table, columns=columns), observability_options=self.observability_options, ): @@ -1578,10 +1574,8 @@ def process_read_batch( :returns: a result set instance which can be used to consume rows. """ observability_options = self.observability_options or {} - session = self._get_session() - klassname = type(self).__name__ with trace_call( - "CloudSpanner." + klassname + ".process_read_batch", + f"CloudSpanner.{type(self).__name__}.process_read_batch", session, observability_options=observability_options, ): @@ -1665,34 +1659,39 @@ def generate_query_batches( mappings of information used perform actual partitioned reads via :meth:`process_read_batch`. """ - partitions = self._get_snapshot().partition_query( - sql=sql, - params=params, - param_types=param_types, - partition_size_bytes=partition_size_bytes, - max_partitions=max_partitions, - retry=retry, - timeout=timeout, - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.generate_query_batches", + extra_attributes=dict(sql=sql), + observability_options=self.observability_options, + ): + partitions = self._get_snapshot().partition_query( + sql=sql, + params=params, + param_types=param_types, + partition_size_bytes=partition_size_bytes, + max_partitions=max_partitions, + retry=retry, + timeout=timeout, + ) - query_info = { - "sql": sql, - "data_boost_enabled": data_boost_enabled, - "directed_read_options": directed_read_options, - } - if params: - query_info["params"] = params - query_info["param_types"] = param_types - - # Query-level options have higher precedence than client-level and - # environment-level options - default_query_options = self._database._instance._client._query_options - query_info["query_options"] = _merge_query_options( - default_query_options, query_options - ) + query_info = { + "sql": sql, + "data_boost_enabled": data_boost_enabled, + "directed_read_options": directed_read_options, + } + if params: + query_info["params"] = params + query_info["param_types"] = param_types + + # Query-level options have higher precedence than client-level and + # environment-level options + default_query_options = self._database._instance._client._query_options + query_info["query_options"] = _merge_query_options( + default_query_options, query_options + ) - for partition in partitions: - yield {"partition": partition, "query": query_info} + for partition in partitions: + yield {"partition": partition, "query": query_info} def process_query_batch( self, @@ -1717,9 +1716,16 @@ def process_query_batch( :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ - return self._get_snapshot().execute_sql( - partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.process_query_batch", + observability_options=self.observability_options, + ): + return self._get_snapshot().execute_sql( + partition=batch["partition"], + **batch["query"], + retry=retry, + timeout=timeout, + ) def run_partitioned_query( self, diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 87e62fd6fe..3e950332a3 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -51,7 +51,7 @@ def _restart_on_unavailable( method, request, - trace_name=None, + span_name=None, session=None, attributes=None, transaction=None, @@ -87,7 +87,7 @@ def _restart_on_unavailable( request.transaction = transaction_selector with trace_call( - trace_name, session, attributes, observability_options=observability_options + span_name, session, attributes, observability_options=observability_options ): iterator = method(request=request) while True: @@ -109,7 +109,7 @@ def _restart_on_unavailable( except ServiceUnavailable: del item_buffer[:] with trace_call( - trace_name, + span_name, session, attributes, observability_options=observability_options, @@ -129,7 +129,7 @@ def _restart_on_unavailable( raise del item_buffer[:] with trace_call( - trace_name, + span_name, session, attributes, observability_options=observability_options, diff --git a/tests/_helpers.py b/tests/_helpers.py index b822b950a2..c1b7da10ee 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -77,7 +77,8 @@ def tearDown(self): def assertNoSpans(self): if HAS_OPENTELEMETRY_INSTALLED: - span_list = self.ot_exporter.get_finished_spans() + span_list = self.get_finished_spans() + print("got_span_list", [span.name for span in span_list]) self.assertEqual(len(span_list), 0) def assertSpanAttributes( @@ -85,7 +86,7 @@ def assertSpanAttributes( ): if HAS_OPENTELEMETRY_INSTALLED: if not span: - span_list = self.ot_exporter.get_finished_spans() + span_list = self.get_finished_spans() self.assertEqual(len(span_list), 1) span = span_list[0] @@ -94,3 +95,12 @@ def assertSpanAttributes( print("got_span_attributes ", dict(span.attributes)) print("want_span_attributes", attributes) self.assertEqual(dict(span.attributes), attributes) + + def get_finished_spans(self): + if not HAS_OPENTELEMETRY_INSTALLED: + return [] + + spans = self.ot_exporter.get_finished_spans() + # A span with name=None is the result from invoking trace_call without + # intention to trace, hence these have to be filtered out. + return list(filter(lambda span: span.name, spans)) diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 2df73dd0e4..47f3d6d445 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -1214,7 +1214,6 @@ def unit_of_work(transaction): "Test Span", ] got_spans = [span.name for span in span_list] - print("got_spans", got_spans) assert got_spans == expected_span_names # [CreateSession --> Batch] should have their own trace. diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 18fca1f643..f4e2b0ef59 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -1374,7 +1374,7 @@ def _partition_query_helper( ) self.assertSpanAttributes( - "CloudSpanner.PartitionReadWriteTransaction", + "CloudSpanner._Derived.partition_query", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), ) @@ -1392,7 +1392,7 @@ def test_partition_query_other_error(self): list(derived.partition_query(SQL_QUERY)) self.assertSpanAttributes( - "CloudSpanner.PartitionReadWriteTransaction", + "CloudSpanner._Derived.partition_query", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index a16ceda109..89c28132b6 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -266,7 +266,7 @@ def test_rollback_w_other_error(self): self.assertFalse(transaction.rolled_back) self.assertSpanAttributes( - "CloudSpanner.Rollback", + "CloudSpanner.Transaction.rollback", status=StatusCode.ERROR, attributes=TestTransaction.BASE_ATTRIBUTES, ) @@ -299,7 +299,8 @@ def test_rollback_ok(self): ) self.assertSpanAttributes( - "CloudSpanner.Rollback", attributes=TestTransaction.BASE_ATTRIBUTES + "CloudSpanner.Transaction.rollback", + attributes=TestTransaction.BASE_ATTRIBUTES, ) def test_commit_not_begun(self):