diff --git a/docs/opentelemetry-tracing.rst b/docs/opentelemetry-tracing.rst index cb9a2b1350..c715ad58ad 100644 --- a/docs/opentelemetry-tracing.rst +++ b/docs/opentelemetry-tracing.rst @@ -25,12 +25,21 @@ We also need to tell OpenTelemetry which exporter to use. To export Spanner trac # Create and export one trace every 1000 requests sampler = TraceIdRatioBased(1/1000) - # Use the default tracer provider - trace.set_tracer_provider(TracerProvider(sampler=sampler)) - trace.get_tracer_provider().add_span_processor( + tracer_provider = TracerProvider(sampler=sampler) + tracer_provider.add_span_processor( # Initialize the cloud tracing exporter BatchSpanProcessor(CloudTraceSpanExporter()) ) + observability_options = dict( + tracer_provider=tracer_provider, + + # By default extended_tracing is set to True due + # to legacy reasons to avoid breaking changes, you + # can modify it though using the environment variable + # SPANNER_ENABLE_EXTENDED_TRACING=false. + enable_extended_tracing=False, + ) + spanner = spanner.NewClient(project_id, observability_options=observability_options) To get more fine-grained traces from gRPC, you can enable the gRPC instrumentation by the following @@ -52,3 +61,13 @@ Generated spanner traces should now be available on `Cloud Trace `_ + +Annotating spans with SQL +~~~~~~~~~~~~~~~~~~~~~~~~~ + +By default your spans will be annotated with SQL statements where appropriate, but that can be a PII (Personally Identifiable Information) +leak. Sadly due to legacy behavior, we cannot simply turn off this behavior by default. However you can control this behavior by setting + + SPANNER_ENABLE_EXTENDED_TRACING=false + +to turn it off globally or when creating each SpannerClient, please set `observability_options.enable_extended_tracing=false` diff --git a/examples/trace.py b/examples/trace.py index 791b6cd20b..e7659e13e2 100644 --- a/examples/trace.py +++ b/examples/trace.py @@ -32,15 +32,18 @@ def main(): tracer_provider = TracerProvider(sampler=ALWAYS_ON) trace_exporter = CloudTraceSpanExporter(project_id=project_id) tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) - trace.set_tracer_provider(tracer_provider) - # Retrieve a tracer from the global tracer provider. - tracer = tracer_provider.get_tracer('MyApp') # Setup the Cloud Spanner Client. - spanner_client = spanner.Client(project_id) + spanner_client = spanner.Client( + project_id, + observability_options=dict(tracer_provider=tracer_provider, enable_extended_tracing=True), + ) instance = spanner_client.instance('test-instance') database = instance.database('test-db') + # Retrieve a tracer from our custom tracer provider. + tracer = tracer_provider.get_tracer('MyApp') + # Now run our queries with tracer.start_as_current_span('QueryInformationSchema'): with database.snapshot() as snapshot: diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index 51501a07a3..feb3b92756 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -15,6 +15,7 @@ """Manages OpenTelemetry trace creation and handling""" from contextlib import contextmanager +import os from google.cloud.spanner_v1 import SpannerClient from google.cloud.spanner_v1 import gapic_version @@ -33,6 +34,9 @@ TRACER_NAME = "cloud.google.com/python/spanner" TRACER_VERSION = gapic_version.__version__ +extended_tracing_globally_disabled = ( + os.getenv("SPANNER_ENABLE_EXTENDED_TRACING", "").lower() == "false" +) def get_tracer(tracer_provider=None): @@ -51,13 +55,26 @@ def get_tracer(tracer_provider=None): @contextmanager -def trace_call(name, session, extra_attributes=None): +def trace_call(name, session, extra_attributes=None, observability_options=None): if not HAS_OPENTELEMETRY_INSTALLED or not session: # Empty context manager. Users will have to check if the generated value is None or a span yield None return - tracer = get_tracer() + tracer_provider = None + + # By default enable_extended_tracing=True because in a bid to minimize + # breaking changes and preserve legacy behavior, we are keeping it turned + # on by default. + enable_extended_tracing = True + + if isinstance(observability_options, dict): # Avoid false positives with mock.Mock + tracer_provider = observability_options.get("tracer_provider", None) + enable_extended_tracing = observability_options.get( + "enable_extended_tracing", enable_extended_tracing + ) + + tracer = get_tracer(tracer_provider) # Set base attributes that we know for every trace created attributes = { @@ -72,6 +89,12 @@ def trace_call(name, session, extra_attributes=None): if extra_attributes: attributes.update(extra_attributes) + if extended_tracing_globally_disabled: + enable_extended_tracing = False + + if not enable_extended_tracing: + attributes.pop("db.statement", False) + with tracer.start_as_current_span( name, kind=trace.SpanKind.CLIENT, attributes=attributes ) as span: diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index e3d681189c..948740d7d4 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -205,7 +205,13 @@ def commit( max_commit_delay=max_commit_delay, request_options=request_options, ) - with trace_call("CloudSpanner.Commit", self._session, trace_attributes): + observability_options = getattr(database, "observability_options", None) + with trace_call( + "CloudSpanner.Commit", + self._session, + trace_attributes, + observability_options=observability_options, + ): method = functools.partial( api.commit, request=request, @@ -318,7 +324,13 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals request_options=request_options, exclude_txn_from_change_streams=exclude_txn_from_change_streams, ) - with trace_call("CloudSpanner.BatchWrite", self._session, trace_attributes): + observability_options = getattr(database, "observability_options", None) + with trace_call( + "CloudSpanner.BatchWrite", + self._session, + trace_attributes, + observability_options=observability_options, + ): method = functools.partial( api.batch_write, request=request, diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index f8f3fdb72c..afe6264717 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -126,6 +126,16 @@ class Client(ClientWithProject): for all ReadRequests and ExecuteSqlRequests that indicates which replicas or regions should be used for non-transactional reads or queries. + :type observability_options: dict (str -> any) or None + :param observability_options: (Optional) the configuration to control + the tracer's behavior. + tracer_provider is the injected tracer provider + enable_extended_tracing: :type:boolean when set to true will allow for + spans that issue SQL statements to be annotated with SQL. + Default `True`, please set it to `False` to turn it off + or you can use the environment variable `SPANNER_ENABLE_EXTENDED_TRACING=` + to control it. + :raises: :class:`ValueError ` if both ``read_only`` and ``admin`` are :data:`True` """ @@ -146,6 +156,7 @@ def __init__( query_options=None, route_to_leader_enabled=True, directed_read_options=None, + observability_options=None, ): self._emulator_host = _get_spanner_emulator_host() @@ -187,6 +198,7 @@ def __init__( self._route_to_leader_enabled = route_to_leader_enabled self._directed_read_options = directed_read_options + self._observability_options = observability_options @property def credentials(self): @@ -268,6 +280,15 @@ def route_to_leader_enabled(self): """ return self._route_to_leader_enabled + @property + def observability_options(self): + """Getter for observability_options. + + :rtype: dict + :returns: The configured observability_options if set. + """ + return self._observability_options + @property def directed_read_options(self): """Getter for directed_read_options. diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index f6c4ceb667..abddd5d97d 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -718,6 +718,7 @@ def execute_pdml(): method=method, request=request, transaction_selector=txn_selector, + observability_options=self.observability_options, ) result_set = StreamedResultSet(iterator) @@ -1106,6 +1107,17 @@ def set_iam_policy(self, policy): response = api.set_iam_policy(request=request, metadata=metadata) return response + @property + def observability_options(self): + """ + Returns the observability options that you set when creating + the SpannerClient. + """ + if not (self._instance and self._instance._client): + return None + + return getattr(self._instance._client, "observability_options", None) + class BatchCheckout(object): """Context manager for using a batch from a database. diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 28280282f4..6281148590 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -142,7 +142,13 @@ def create(self): if self._labels: request.session.labels = self._labels - with trace_call("CloudSpanner.CreateSession", self, self._labels): + observability_options = getattr(self._database, "observability_options", None) + with trace_call( + "CloudSpanner.CreateSession", + self, + self._labels, + observability_options=observability_options, + ): session_pb = api.create_session( request=request, metadata=metadata, @@ -169,7 +175,10 @@ def exists(self): ) ) - with trace_call("CloudSpanner.GetSession", self) as span: + observability_options = getattr(self._database, "observability_options", None) + with trace_call( + "CloudSpanner.GetSession", self, observability_options=observability_options + ) as span: try: api.get_session(name=self.name, metadata=metadata) if span: @@ -194,7 +203,12 @@ def delete(self): raise ValueError("Session ID not set by back-end") api = self._database.spanner_api metadata = _metadata_with_prefix(self._database.name) - with trace_call("CloudSpanner.DeleteSession", self): + observability_options = getattr(self._database, "observability_options", None) + with trace_call( + "CloudSpanner.DeleteSession", + self, + observability_options=observability_options, + ): api.delete_session(name=self.name, metadata=metadata) def ping(self): diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 3bc1a746bd..a02776b27c 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -56,6 +56,7 @@ def _restart_on_unavailable( attributes=None, transaction=None, transaction_selector=None, + observability_options=None, ): """Restart iteration after :exc:`.ServiceUnavailable`. @@ -84,7 +85,10 @@ def _restart_on_unavailable( ) request.transaction = transaction_selector - with trace_call(trace_name, session, attributes): + + with trace_call( + trace_name, session, attributes, observability_options=observability_options + ): iterator = method(request=request) while True: try: @@ -104,7 +108,12 @@ def _restart_on_unavailable( break except ServiceUnavailable: del item_buffer[:] - with trace_call(trace_name, session, attributes): + with trace_call( + trace_name, + session, + attributes, + observability_options=observability_options, + ): request.resume_token = resume_token if transaction is not None: transaction_selector = transaction._make_txn_selector() @@ -119,7 +128,12 @@ def _restart_on_unavailable( if not resumable_error: raise del item_buffer[:] - with trace_call(trace_name, session, attributes): + with trace_call( + trace_name, + session, + attributes, + observability_options=observability_options, + ): request.resume_token = resume_token if transaction is not None: transaction_selector = transaction._make_txn_selector() @@ -299,6 +313,7 @@ def read( ) trace_attributes = {"table_id": table, "columns": columns} + observability_options = getattr(database, "observability_options", None) if self._transaction_id is None: # lock is added to handle the inline begin for first rpc @@ -310,6 +325,7 @@ def read( self._session, trace_attributes, transaction=self, + observability_options=observability_options, ) self._read_request_count += 1 if self._multi_use: @@ -326,6 +342,7 @@ def read( self._session, trace_attributes, transaction=self, + observability_options=observability_options, ) self._read_request_count += 1 @@ -489,19 +506,35 @@ def execute_sql( ) trace_attributes = {"db.statement": sql} + observability_options = getattr(database, "observability_options", None) if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: return self._get_streamed_result_set( - restart, request, trace_attributes, column_info + restart, + request, + trace_attributes, + column_info, + observability_options, ) else: return self._get_streamed_result_set( - restart, request, trace_attributes, column_info + restart, + request, + trace_attributes, + column_info, + observability_options, ) - def _get_streamed_result_set(self, restart, request, trace_attributes, column_info): + def _get_streamed_result_set( + self, + restart, + request, + trace_attributes, + column_info, + observability_options=None, + ): iterator = _restart_on_unavailable( restart, request, @@ -509,6 +542,7 @@ def _get_streamed_result_set(self, restart, request, trace_attributes, column_in self._session, trace_attributes, transaction=self, + observability_options=observability_options, ) self._read_request_count += 1 self._execute_sql_count += 1 @@ -598,7 +632,10 @@ def partition_read( trace_attributes = {"table_id": table, "columns": columns} with trace_call( - "CloudSpanner.PartitionReadOnlyTransaction", self._session, trace_attributes + "CloudSpanner.PartitionReadOnlyTransaction", + self._session, + trace_attributes, + observability_options=getattr(database, "observability_options", None), ): method = functools.partial( api.partition_read, @@ -701,6 +738,7 @@ def partition_query( "CloudSpanner.PartitionReadWriteTransaction", self._session, trace_attributes, + observability_options=getattr(database, "observability_options", None), ): method = functools.partial( api.partition_query, @@ -843,7 +881,11 @@ def begin(self): (_metadata_with_leader_aware_routing(database._route_to_leader_enabled)) ) txn_selector = self._make_txn_selector() - with trace_call("CloudSpanner.BeginTransaction", self._session): + with trace_call( + "CloudSpanner.BeginTransaction", + self._session, + observability_options=getattr(database, "observability_options", None), + ): method = functools.partial( api.begin_transaction, session=self._session.name, diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index c872cc380d..beb3e46edb 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -98,7 +98,13 @@ def _make_txn_selector(self): return TransactionSelector(id=self._transaction_id) def _execute_request( - self, method, request, trace_name=None, session=None, attributes=None + self, + method, + request, + trace_name=None, + session=None, + attributes=None, + observability_options=None, ): """Helper method to execute request after fetching transaction selector. @@ -110,7 +116,9 @@ def _execute_request( """ transaction = self._make_txn_selector() request.transaction = transaction - with trace_call(trace_name, session, attributes): + with trace_call( + trace_name, session, attributes, observability_options=observability_options + ): method = functools.partial(method, request=request) response = _retry( method, @@ -147,7 +155,12 @@ def begin(self): read_write=TransactionOptions.ReadWrite(), exclude_txn_from_change_streams=self.exclude_txn_from_change_streams, ) - with trace_call("CloudSpanner.BeginTransaction", self._session): + observability_options = getattr(database, "observability_options", None) + with trace_call( + "CloudSpanner.BeginTransaction", + self._session, + observability_options=observability_options, + ): method = functools.partial( api.begin_transaction, session=self._session.name, @@ -175,7 +188,12 @@ def rollback(self): database._route_to_leader_enabled ) ) - with trace_call("CloudSpanner.Rollback", self._session): + observability_options = getattr(database, "observability_options", None) + with trace_call( + "CloudSpanner.Rollback", + self._session, + observability_options=observability_options, + ): method = functools.partial( api.rollback, session=self._session.name, @@ -248,7 +266,13 @@ def commit( max_commit_delay=max_commit_delay, request_options=request_options, ) - with trace_call("CloudSpanner.Commit", self._session, trace_attributes): + observability_options = getattr(database, "observability_options", None) + with trace_call( + "CloudSpanner.Commit", + self._session, + trace_attributes, + observability_options, + ): method = functools.partial( api.commit, request=request, @@ -362,6 +386,9 @@ def execute_update( # environment-level options default_query_options = database._instance._client._query_options query_options = _merge_query_options(default_query_options, query_options) + observability_options = getattr( + database._instance._client, "observability_options", None + ) if request_options is None: request_options = RequestOptions() @@ -399,6 +426,7 @@ def execute_update( "CloudSpanner.ReadWriteTransaction", self._session, trace_attributes, + observability_options=observability_options, ) # Setting the transaction id because the transaction begin was inlined for first rpc. if ( @@ -415,6 +443,7 @@ def execute_update( "CloudSpanner.ReadWriteTransaction", self._session, trace_attributes, + observability_options=observability_options, ) return response.stats.row_count_exact @@ -481,6 +510,7 @@ def batch_update( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) api = database.spanner_api + observability_options = getattr(database, "observability_options", None) seqno, self._execute_sql_count = ( self._execute_sql_count, @@ -521,6 +551,7 @@ def batch_update( "CloudSpanner.DMLTransaction", self._session, trace_attributes, + observability_options=observability_options, ) # Setting the transaction id because the transaction begin was inlined for first rpc. for result_set in response.result_sets: @@ -538,6 +569,7 @@ def batch_update( "CloudSpanner.DMLTransaction", self._session, trace_attributes, + observability_options=observability_options, ) row_counts = [ diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py new file mode 100644 index 0000000000..8382255c15 --- /dev/null +++ b/tests/system/test_observability_options.py @@ -0,0 +1,134 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from . import _helpers +from google.cloud.spanner_v1 import Client + +HAS_OTEL_INSTALLED = False + +try: + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.sampling import ALWAYS_ON + from opentelemetry import trace + + HAS_OTEL_INSTALLED = True +except ImportError: + pass + + +@pytest.mark.skipif( + not HAS_OTEL_INSTALLED, reason="OpenTelemetry is necessary to test traces." +) +@pytest.mark.skipif( + not _helpers.USE_EMULATOR, reason="mulator is necessary to test traces." +) +def test_observability_options_propagation(): + PROJECT = _helpers.EMULATOR_PROJECT + CONFIGURATION_NAME = "config-name" + INSTANCE_ID = _helpers.INSTANCE_ID + DISPLAY_NAME = "display-name" + DATABASE_ID = _helpers.unique_id("temp_db") + NODE_COUNT = 5 + LABELS = {"test": "true"} + + def test_propagation(enable_extended_tracing): + global_tracer_provider = TracerProvider(sampler=ALWAYS_ON) + trace.set_tracer_provider(global_tracer_provider) + global_trace_exporter = InMemorySpanExporter() + global_tracer_provider.add_span_processor( + SimpleSpanProcessor(global_trace_exporter) + ) + + inject_tracer_provider = TracerProvider(sampler=ALWAYS_ON) + inject_trace_exporter = InMemorySpanExporter() + inject_tracer_provider.add_span_processor( + SimpleSpanProcessor(inject_trace_exporter) + ) + observability_options = dict( + tracer_provider=inject_tracer_provider, + enable_extended_tracing=enable_extended_tracing, + ) + client = Client( + project=PROJECT, + observability_options=observability_options, + credentials=_make_credentials(), + ) + + instance = client.instance( + INSTANCE_ID, + CONFIGURATION_NAME, + display_name=DISPLAY_NAME, + node_count=NODE_COUNT, + labels=LABELS, + ) + + try: + instance.create() + except Exception: + pass + + db = instance.database(DATABASE_ID) + try: + db.create() + except Exception: + pass + + assert db.observability_options == observability_options + with db.snapshot() as snapshot: + res = snapshot.execute_sql("SELECT 1") + for val in res: + _ = val + + from_global_spans = global_trace_exporter.get_finished_spans() + from_inject_spans = inject_trace_exporter.get_finished_spans() + assert ( + len(from_global_spans) == 0 + ) # "Expecting no spans from the global trace exporter" + assert ( + len(from_inject_spans) >= 2 + ) # "Expecting at least 2 spans from the injected trace exporter" + gotNames = [span.name for span in from_inject_spans] + wantNames = ["CloudSpanner.CreateSession", "CloudSpanner.ReadWriteTransaction"] + assert gotNames == wantNames + + # Check for conformance of enable_extended_tracing + lastSpan = from_inject_spans[len(from_inject_spans) - 1] + wantAnnotatedSQL = "SELECT 1" + if not enable_extended_tracing: + wantAnnotatedSQL = None + assert ( + lastSpan.attributes.get("db.statement", None) == wantAnnotatedSQL + ) # "Mismatch in annotated sql" + + try: + db.delete() + instance.delete() + except Exception: + pass + + # Test the respective options for enable_extended_tracing + test_propagation(True) + test_propagation(False) + + +def _make_credentials(): + from google.auth.credentials import AnonymousCredentials + + return AnonymousCredentials()