diff --git a/docs/opentelemetry-tracing.rst b/docs/opentelemetry-tracing.rst
index cb9a2b1350..c715ad58ad 100644
--- a/docs/opentelemetry-tracing.rst
+++ b/docs/opentelemetry-tracing.rst
@@ -25,12 +25,21 @@ We also need to tell OpenTelemetry which exporter to use. To export Spanner trac
# Create and export one trace every 1000 requests
sampler = TraceIdRatioBased(1/1000)
- # Use the default tracer provider
- trace.set_tracer_provider(TracerProvider(sampler=sampler))
- trace.get_tracer_provider().add_span_processor(
+ tracer_provider = TracerProvider(sampler=sampler)
+ tracer_provider.add_span_processor(
# Initialize the cloud tracing exporter
BatchSpanProcessor(CloudTraceSpanExporter())
)
+ observability_options = dict(
+ tracer_provider=tracer_provider,
+
+ # enable_extended_tracing defaults to True to preserve
+ # legacy behavior and avoid breaking changes. Set it to
+ # False here, or disable it globally with the environment
+ # variable SPANNER_ENABLE_EXTENDED_TRACING=false.
+ enable_extended_tracing=False,
+ )
+ spanner_client = spanner.Client(project_id, observability_options=observability_options)
To get more fine-grained traces from gRPC, you can enable the gRPC instrumentation by the following
@@ -52,3 +61,13 @@ Generated spanner traces should now be available on `Cloud Trace `_
+
+Annotating spans with SQL
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+By default, your spans are annotated with SQL statements where appropriate, which can leak PII (Personally Identifiable Information).
+To preserve legacy behavior, this annotation cannot simply be turned off by default. However, you can disable it globally by setting the environment variable
+
+ SPANNER_ENABLE_EXTENDED_TRACING=false
+
+or per client by passing ``observability_options=dict(enable_extended_tracing=False)`` when creating each ``spanner.Client``.
diff --git a/examples/trace.py b/examples/trace.py
index 791b6cd20b..e7659e13e2 100644
--- a/examples/trace.py
+++ b/examples/trace.py
@@ -32,15 +32,18 @@ def main():
tracer_provider = TracerProvider(sampler=ALWAYS_ON)
trace_exporter = CloudTraceSpanExporter(project_id=project_id)
tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
- trace.set_tracer_provider(tracer_provider)
- # Retrieve a tracer from the global tracer provider.
- tracer = tracer_provider.get_tracer('MyApp')
# Setup the Cloud Spanner Client.
- spanner_client = spanner.Client(project_id)
+ spanner_client = spanner.Client(
+ project_id,
+ observability_options=dict(tracer_provider=tracer_provider, enable_extended_tracing=True),
+ )
instance = spanner_client.instance('test-instance')
database = instance.database('test-db')
+ # Retrieve a tracer from our custom tracer provider.
+ tracer = tracer_provider.get_tracer('MyApp')
+
# Now run our queries
with tracer.start_as_current_span('QueryInformationSchema'):
with database.snapshot() as snapshot:
diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py
index 51501a07a3..feb3b92756 100644
--- a/google/cloud/spanner_v1/_opentelemetry_tracing.py
+++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py
@@ -15,6 +15,7 @@
"""Manages OpenTelemetry trace creation and handling"""
from contextlib import contextmanager
+import os
from google.cloud.spanner_v1 import SpannerClient
from google.cloud.spanner_v1 import gapic_version
@@ -33,6 +34,9 @@
TRACER_NAME = "cloud.google.com/python/spanner"
TRACER_VERSION = gapic_version.__version__
+extended_tracing_globally_disabled = (
+ os.getenv("SPANNER_ENABLE_EXTENDED_TRACING", "").lower() == "false"
+)
def get_tracer(tracer_provider=None):
@@ -51,13 +55,26 @@ def get_tracer(tracer_provider=None):
@contextmanager
-def trace_call(name, session, extra_attributes=None):
+def trace_call(name, session, extra_attributes=None, observability_options=None):
if not HAS_OPENTELEMETRY_INSTALLED or not session:
# Empty context manager. Users will have to check if the generated value is None or a span
yield None
return
- tracer = get_tracer()
+ tracer_provider = None
+
+ # enable_extended_tracing defaults to True: to minimize breaking
+ # changes and preserve legacy behavior, SQL annotation of spans stays
+ # on unless it is explicitly disabled.
+ enable_extended_tracing = True
+
+ if isinstance(observability_options, dict): # Avoid false positives with mock.Mock
+ tracer_provider = observability_options.get("tracer_provider", None)
+ enable_extended_tracing = observability_options.get(
+ "enable_extended_tracing", enable_extended_tracing
+ )
+
+ tracer = get_tracer(tracer_provider)
# Set base attributes that we know for every trace created
attributes = {
@@ -72,6 +89,12 @@ def trace_call(name, session, extra_attributes=None):
if extra_attributes:
attributes.update(extra_attributes)
+ if extended_tracing_globally_disabled:
+ enable_extended_tracing = False
+
+ if not enable_extended_tracing:
+ attributes.pop("db.statement", False)
+
with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT, attributes=attributes
) as span:
diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py
index e3d681189c..948740d7d4 100644
--- a/google/cloud/spanner_v1/batch.py
+++ b/google/cloud/spanner_v1/batch.py
@@ -205,7 +205,13 @@ def commit(
max_commit_delay=max_commit_delay,
request_options=request_options,
)
- with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
+ observability_options = getattr(database, "observability_options", None)
+ with trace_call(
+ "CloudSpanner.Commit",
+ self._session,
+ trace_attributes,
+ observability_options=observability_options,
+ ):
method = functools.partial(
api.commit,
request=request,
@@ -318,7 +324,13 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
request_options=request_options,
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)
- with trace_call("CloudSpanner.BatchWrite", self._session, trace_attributes):
+ observability_options = getattr(database, "observability_options", None)
+ with trace_call(
+ "CloudSpanner.BatchWrite",
+ self._session,
+ trace_attributes,
+ observability_options=observability_options,
+ ):
method = functools.partial(
api.batch_write,
request=request,
diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py
index f8f3fdb72c..afe6264717 100644
--- a/google/cloud/spanner_v1/client.py
+++ b/google/cloud/spanner_v1/client.py
@@ -126,6 +126,16 @@ class Client(ClientWithProject):
for all ReadRequests and ExecuteSqlRequests that indicates which replicas
or regions should be used for non-transactional reads or queries.
+ :type observability_options: dict (str -> any) or None
+ :param observability_options: (Optional) the configuration to control
+ the tracer's behavior. Recognized keys:
+ tracer_provider: the tracer provider used to create spans.
+ enable_extended_tracing: (bool) when True, spans that issue SQL
+ statements are annotated with the SQL text.
+ Defaults to `True`; set it to `False` to turn it off, or set the
+ environment variable `SPANNER_ENABLE_EXTENDED_TRACING=false`
+ to turn it off globally.
+
:raises: :class:`ValueError ` if both ``read_only``
and ``admin`` are :data:`True`
"""
@@ -146,6 +156,7 @@ def __init__(
query_options=None,
route_to_leader_enabled=True,
directed_read_options=None,
+ observability_options=None,
):
self._emulator_host = _get_spanner_emulator_host()
@@ -187,6 +198,7 @@ def __init__(
self._route_to_leader_enabled = route_to_leader_enabled
self._directed_read_options = directed_read_options
+ self._observability_options = observability_options
@property
def credentials(self):
@@ -268,6 +280,15 @@ def route_to_leader_enabled(self):
"""
return self._route_to_leader_enabled
+ @property
+ def observability_options(self):
+ """Getter for observability_options.
+
+ :rtype: dict or None
+ :returns: The configured observability_options, or None if not set.
+ """
+ return self._observability_options
+
@property
def directed_read_options(self):
"""Getter for directed_read_options.
diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py
index f6c4ceb667..abddd5d97d 100644
--- a/google/cloud/spanner_v1/database.py
+++ b/google/cloud/spanner_v1/database.py
@@ -718,6 +718,7 @@ def execute_pdml():
method=method,
request=request,
transaction_selector=txn_selector,
+ observability_options=self.observability_options,
)
result_set = StreamedResultSet(iterator)
@@ -1106,6 +1107,17 @@ def set_iam_policy(self, policy):
response = api.set_iam_policy(request=request, metadata=metadata)
return response
+ @property
+ def observability_options(self):
+ """
+ Returns the observability options set on the client that owns
+ this database, or None if it has no instance or client.
+ """
+ if not (self._instance and self._instance._client):
+ return None
+
+ return getattr(self._instance._client, "observability_options", None)
+
class BatchCheckout(object):
"""Context manager for using a batch from a database.
diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py
index 28280282f4..6281148590 100644
--- a/google/cloud/spanner_v1/session.py
+++ b/google/cloud/spanner_v1/session.py
@@ -142,7 +142,13 @@ def create(self):
if self._labels:
request.session.labels = self._labels
- with trace_call("CloudSpanner.CreateSession", self, self._labels):
+ observability_options = getattr(self._database, "observability_options", None)
+ with trace_call(
+ "CloudSpanner.CreateSession",
+ self,
+ self._labels,
+ observability_options=observability_options,
+ ):
session_pb = api.create_session(
request=request,
metadata=metadata,
@@ -169,7 +175,10 @@ def exists(self):
)
)
- with trace_call("CloudSpanner.GetSession", self) as span:
+ observability_options = getattr(self._database, "observability_options", None)
+ with trace_call(
+ "CloudSpanner.GetSession", self, observability_options=observability_options
+ ) as span:
try:
api.get_session(name=self.name, metadata=metadata)
if span:
@@ -194,7 +203,12 @@ def delete(self):
raise ValueError("Session ID not set by back-end")
api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
- with trace_call("CloudSpanner.DeleteSession", self):
+ observability_options = getattr(self._database, "observability_options", None)
+ with trace_call(
+ "CloudSpanner.DeleteSession",
+ self,
+ observability_options=observability_options,
+ ):
api.delete_session(name=self.name, metadata=metadata)
def ping(self):
diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py
index 3bc1a746bd..a02776b27c 100644
--- a/google/cloud/spanner_v1/snapshot.py
+++ b/google/cloud/spanner_v1/snapshot.py
@@ -56,6 +56,7 @@ def _restart_on_unavailable(
attributes=None,
transaction=None,
transaction_selector=None,
+ observability_options=None,
):
"""Restart iteration after :exc:`.ServiceUnavailable`.
@@ -84,7 +85,10 @@ def _restart_on_unavailable(
)
request.transaction = transaction_selector
- with trace_call(trace_name, session, attributes):
+
+ with trace_call(
+ trace_name, session, attributes, observability_options=observability_options
+ ):
iterator = method(request=request)
while True:
try:
@@ -104,7 +108,12 @@ def _restart_on_unavailable(
break
except ServiceUnavailable:
del item_buffer[:]
- with trace_call(trace_name, session, attributes):
+ with trace_call(
+ trace_name,
+ session,
+ attributes,
+ observability_options=observability_options,
+ ):
request.resume_token = resume_token
if transaction is not None:
transaction_selector = transaction._make_txn_selector()
@@ -119,7 +128,12 @@ def _restart_on_unavailable(
if not resumable_error:
raise
del item_buffer[:]
- with trace_call(trace_name, session, attributes):
+ with trace_call(
+ trace_name,
+ session,
+ attributes,
+ observability_options=observability_options,
+ ):
request.resume_token = resume_token
if transaction is not None:
transaction_selector = transaction._make_txn_selector()
@@ -299,6 +313,7 @@ def read(
)
trace_attributes = {"table_id": table, "columns": columns}
+ observability_options = getattr(database, "observability_options", None)
if self._transaction_id is None:
# lock is added to handle the inline begin for first rpc
@@ -310,6 +325,7 @@ def read(
self._session,
trace_attributes,
transaction=self,
+ observability_options=observability_options,
)
self._read_request_count += 1
if self._multi_use:
@@ -326,6 +342,7 @@ def read(
self._session,
trace_attributes,
transaction=self,
+ observability_options=observability_options,
)
self._read_request_count += 1
@@ -489,19 +506,35 @@ def execute_sql(
)
trace_attributes = {"db.statement": sql}
+ observability_options = getattr(database, "observability_options", None)
if self._transaction_id is None:
# lock is added to handle the inline begin for first rpc
with self._lock:
return self._get_streamed_result_set(
- restart, request, trace_attributes, column_info
+ restart,
+ request,
+ trace_attributes,
+ column_info,
+ observability_options,
)
else:
return self._get_streamed_result_set(
- restart, request, trace_attributes, column_info
+ restart,
+ request,
+ trace_attributes,
+ column_info,
+ observability_options,
)
- def _get_streamed_result_set(self, restart, request, trace_attributes, column_info):
+ def _get_streamed_result_set(
+ self,
+ restart,
+ request,
+ trace_attributes,
+ column_info,
+ observability_options=None,
+ ):
iterator = _restart_on_unavailable(
restart,
request,
@@ -509,6 +542,7 @@ def _get_streamed_result_set(self, restart, request, trace_attributes, column_in
self._session,
trace_attributes,
transaction=self,
+ observability_options=observability_options,
)
self._read_request_count += 1
self._execute_sql_count += 1
@@ -598,7 +632,10 @@ def partition_read(
trace_attributes = {"table_id": table, "columns": columns}
with trace_call(
- "CloudSpanner.PartitionReadOnlyTransaction", self._session, trace_attributes
+ "CloudSpanner.PartitionReadOnlyTransaction",
+ self._session,
+ trace_attributes,
+ observability_options=getattr(database, "observability_options", None),
):
method = functools.partial(
api.partition_read,
@@ -701,6 +738,7 @@ def partition_query(
"CloudSpanner.PartitionReadWriteTransaction",
self._session,
trace_attributes,
+ observability_options=getattr(database, "observability_options", None),
):
method = functools.partial(
api.partition_query,
@@ -843,7 +881,11 @@ def begin(self):
(_metadata_with_leader_aware_routing(database._route_to_leader_enabled))
)
txn_selector = self._make_txn_selector()
- with trace_call("CloudSpanner.BeginTransaction", self._session):
+ with trace_call(
+ "CloudSpanner.BeginTransaction",
+ self._session,
+ observability_options=getattr(database, "observability_options", None),
+ ):
method = functools.partial(
api.begin_transaction,
session=self._session.name,
diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py
index c872cc380d..beb3e46edb 100644
--- a/google/cloud/spanner_v1/transaction.py
+++ b/google/cloud/spanner_v1/transaction.py
@@ -98,7 +98,13 @@ def _make_txn_selector(self):
return TransactionSelector(id=self._transaction_id)
def _execute_request(
- self, method, request, trace_name=None, session=None, attributes=None
+ self,
+ method,
+ request,
+ trace_name=None,
+ session=None,
+ attributes=None,
+ observability_options=None,
):
"""Helper method to execute request after fetching transaction selector.
@@ -110,7 +116,9 @@ def _execute_request(
"""
transaction = self._make_txn_selector()
request.transaction = transaction
- with trace_call(trace_name, session, attributes):
+ with trace_call(
+ trace_name, session, attributes, observability_options=observability_options
+ ):
method = functools.partial(method, request=request)
response = _retry(
method,
@@ -147,7 +155,12 @@ def begin(self):
read_write=TransactionOptions.ReadWrite(),
exclude_txn_from_change_streams=self.exclude_txn_from_change_streams,
)
- with trace_call("CloudSpanner.BeginTransaction", self._session):
+ observability_options = getattr(database, "observability_options", None)
+ with trace_call(
+ "CloudSpanner.BeginTransaction",
+ self._session,
+ observability_options=observability_options,
+ ):
method = functools.partial(
api.begin_transaction,
session=self._session.name,
@@ -175,7 +188,12 @@ def rollback(self):
database._route_to_leader_enabled
)
)
- with trace_call("CloudSpanner.Rollback", self._session):
+ observability_options = getattr(database, "observability_options", None)
+ with trace_call(
+ "CloudSpanner.Rollback",
+ self._session,
+ observability_options=observability_options,
+ ):
method = functools.partial(
api.rollback,
session=self._session.name,
@@ -248,7 +266,13 @@ def commit(
max_commit_delay=max_commit_delay,
request_options=request_options,
)
- with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
+ observability_options = getattr(database, "observability_options", None)
+ with trace_call(
+ "CloudSpanner.Commit",
+ self._session,
+ trace_attributes,
+ observability_options,
+ ):
method = functools.partial(
api.commit,
request=request,
@@ -362,6 +386,9 @@ def execute_update(
# environment-level options
default_query_options = database._instance._client._query_options
query_options = _merge_query_options(default_query_options, query_options)
+ observability_options = getattr(
+ database._instance._client, "observability_options", None
+ )
if request_options is None:
request_options = RequestOptions()
@@ -399,6 +426,7 @@ def execute_update(
"CloudSpanner.ReadWriteTransaction",
self._session,
trace_attributes,
+ observability_options=observability_options,
)
# Setting the transaction id because the transaction begin was inlined for first rpc.
if (
@@ -415,6 +443,7 @@ def execute_update(
"CloudSpanner.ReadWriteTransaction",
self._session,
trace_attributes,
+ observability_options=observability_options,
)
return response.stats.row_count_exact
@@ -481,6 +510,7 @@ def batch_update(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
api = database.spanner_api
+ observability_options = getattr(database, "observability_options", None)
seqno, self._execute_sql_count = (
self._execute_sql_count,
@@ -521,6 +551,7 @@ def batch_update(
"CloudSpanner.DMLTransaction",
self._session,
trace_attributes,
+ observability_options=observability_options,
)
# Setting the transaction id because the transaction begin was inlined for first rpc.
for result_set in response.result_sets:
@@ -538,6 +569,7 @@ def batch_update(
"CloudSpanner.DMLTransaction",
self._session,
trace_attributes,
+ observability_options=observability_options,
)
row_counts = [
diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py
new file mode 100644
index 0000000000..8382255c15
--- /dev/null
+++ b/tests/system/test_observability_options.py
@@ -0,0 +1,134 @@
+# Copyright 2024 Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from . import _helpers
+from google.cloud.spanner_v1 import Client
+
+HAS_OTEL_INSTALLED = False
+
+try:
+ from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+ from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
+ InMemorySpanExporter,
+ )
+ from opentelemetry.sdk.trace import TracerProvider
+ from opentelemetry.sdk.trace.sampling import ALWAYS_ON
+ from opentelemetry import trace
+
+ HAS_OTEL_INSTALLED = True
+except ImportError:
+ pass
+
+
+@pytest.mark.skipif(
+ not HAS_OTEL_INSTALLED, reason="OpenTelemetry is necessary to test traces."
+)
+@pytest.mark.skipif(
+ not _helpers.USE_EMULATOR, reason="Emulator is necessary to test traces."
+)
+def test_observability_options_propagation():
+ PROJECT = _helpers.EMULATOR_PROJECT
+ CONFIGURATION_NAME = "config-name"
+ INSTANCE_ID = _helpers.INSTANCE_ID
+ DISPLAY_NAME = "display-name"
+ DATABASE_ID = _helpers.unique_id("temp_db")
+ NODE_COUNT = 5
+ LABELS = {"test": "true"}
+
+ def test_propagation(enable_extended_tracing):
+ global_tracer_provider = TracerProvider(sampler=ALWAYS_ON)
+ trace.set_tracer_provider(global_tracer_provider)
+ global_trace_exporter = InMemorySpanExporter()
+ global_tracer_provider.add_span_processor(
+ SimpleSpanProcessor(global_trace_exporter)
+ )
+
+ inject_tracer_provider = TracerProvider(sampler=ALWAYS_ON)
+ inject_trace_exporter = InMemorySpanExporter()
+ inject_tracer_provider.add_span_processor(
+ SimpleSpanProcessor(inject_trace_exporter)
+ )
+ observability_options = dict(
+ tracer_provider=inject_tracer_provider,
+ enable_extended_tracing=enable_extended_tracing,
+ )
+ client = Client(
+ project=PROJECT,
+ observability_options=observability_options,
+ credentials=_make_credentials(),
+ )
+
+ instance = client.instance(
+ INSTANCE_ID,
+ CONFIGURATION_NAME,
+ display_name=DISPLAY_NAME,
+ node_count=NODE_COUNT,
+ labels=LABELS,
+ )
+
+ try:
+ instance.create()
+ except Exception:
+ pass
+
+ db = instance.database(DATABASE_ID)
+ try:
+ db.create()
+ except Exception:
+ pass
+
+ assert db.observability_options == observability_options
+ with db.snapshot() as snapshot:
+ res = snapshot.execute_sql("SELECT 1")
+ for val in res:
+ _ = val
+
+ from_global_spans = global_trace_exporter.get_finished_spans()
+ from_inject_spans = inject_trace_exporter.get_finished_spans()
+ assert (
+ len(from_global_spans) == 0
+ ) # "Expecting no spans from the global trace exporter"
+ assert (
+ len(from_inject_spans) >= 2
+ ) # "Expecting at least 2 spans from the injected trace exporter"
+ gotNames = [span.name for span in from_inject_spans]
+ wantNames = ["CloudSpanner.CreateSession", "CloudSpanner.ReadWriteTransaction"]
+ assert gotNames == wantNames
+
+ # Check for conformance of enable_extended_tracing
+ lastSpan = from_inject_spans[len(from_inject_spans) - 1]
+ wantAnnotatedSQL = "SELECT 1"
+ if not enable_extended_tracing:
+ wantAnnotatedSQL = None
+ assert (
+ lastSpan.attributes.get("db.statement", None) == wantAnnotatedSQL
+ ) # "Mismatch in annotated sql"
+
+ try:
+ db.delete()
+ instance.delete()
+ except Exception:
+ pass
+
+ # Test the respective options for enable_extended_tracing
+ test_propagation(True)
+ test_propagation(False)
+
+
+def _make_credentials():
+ from google.auth.credentials import AnonymousCredentials
+
+ return AnonymousCredentials()