Skip to content

Commit

Permalink
Add another form of tracing with explicit .close() invocation
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Nov 22, 2024
1 parent 081ae17 commit 51d3e91
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 24 deletions.
20 changes: 12 additions & 8 deletions google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,11 @@ def get_tracer(tracer_provider=None):
return tracer_provider.get_tracer(TRACER_NAME, TRACER_VERSION)


@contextmanager
def trace_call(name, session, extra_attributes=None, observability_options=None):
if not HAS_OPENTELEMETRY_INSTALLED or not session:
# Empty context manager. Users will have to check if the generated value is None or a span
yield None
return
def trace_end_explicitly(
name, session=None, extra_attributes=None, observability_options=None
):
if not HAS_OPENTELEMETRY_INSTALLED:
return None

tracer_provider = None

Expand Down Expand Up @@ -100,8 +99,13 @@ def trace_call(name, session, extra_attributes=None, observability_options=None)
if not enable_extended_tracing:
attributes.pop("db.statement", False)

with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT, attributes=attributes
return tracer.start_span(name, kind=trace.SpanKind.CLIENT, attributes=attributes)


@contextmanager
def trace_call(name, session=None, extra_attributes=None, observability_options=None):
with trace_end_explicitly(
name, session, extra_attributes, observability_options
) as span:
try:
yield span
Expand Down
49 changes: 46 additions & 3 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1._opentelemetry_tracing import (
trace_call,
trace_end_explicitly,
)
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1._helpers import _retry
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
Expand All @@ -46,6 +49,14 @@ class _BatchBase(_SessionWrapper):
def __init__(self, session):
super(_BatchBase, self).__init__(session)
self._mutations = []
observability_options = getattr(
self._session.database, "observability_options", None
)
self.__span = trace_end_explicitly(
"CloudSpannerX." + type(self).__name__,
self._session,
observability_options=observability_options,
)

def _check_state(self):
"""Helper for :meth:`commit` et al.
Expand All @@ -69,6 +80,10 @@ def insert(self, table, columns, values):
:type values: list of lists
:param values: Values to be modified.
"""
if self.__span:
self.__span.add_event(
"insert mutations inserted", dict(table=table, columns=columns)
)
self._mutations.append(Mutation(insert=_make_write_pb(table, columns, values)))

def update(self, table, columns, values):
Expand All @@ -83,6 +98,10 @@ def update(self, table, columns, values):
:type values: list of lists
:param values: Values to be modified.
"""
if self.__span:
self.__span.add_event(
"update mutations inserted", dict(table=table, columns=columns)
)
self._mutations.append(Mutation(update=_make_write_pb(table, columns, values)))

def insert_or_update(self, table, columns, values):
Expand All @@ -97,6 +116,11 @@ def insert_or_update(self, table, columns, values):
:type values: list of lists
:param values: Values to be modified.
"""
if self.__span:
self.__span.add_event(
"insert_or_update mutations inserted",
dict(table=table, columns=columns),
)
self._mutations.append(
Mutation(insert_or_update=_make_write_pb(table, columns, values))
)
Expand All @@ -113,6 +137,10 @@ def replace(self, table, columns, values):
:type values: list of lists
:param values: Values to be modified.
"""
if self.__span:
self.__span.add_event(
"replace mutations inserted", dict(table=table, columns=columns)
)
self._mutations.append(Mutation(replace=_make_write_pb(table, columns, values)))

def delete(self, table, keyset):
Expand All @@ -126,6 +154,10 @@ def delete(self, table, keyset):
"""
delete = Mutation.Delete(table=table, key_set=keyset._to_pb())
self._mutations.append(Mutation(delete=delete))
if self.__span:
self.__span.add_event(
"delete mutations inserted", dict(table=table, columns=columns)
)


class Batch(_BatchBase):
Expand Down Expand Up @@ -207,7 +239,7 @@ def commit(
)
observability_options = getattr(database, "observability_options", None)
with trace_call(
"CloudSpanner.Commit",
"CloudSpanner.Batch.commit",
self._session,
trace_attributes,
observability_options=observability_options,
Expand All @@ -228,13 +260,24 @@ def commit(
def __enter__(self):
"""Begin ``with`` block."""
self._check_state()
observability_options = getattr(
self._session.database, "observability_options", None
)
self.__span = trace_end_explicitly(
"CloudSpanner.Batch",
self._session,
observability_options=observability_options,
)

return self

def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
if exc_type is None:
self.commit()
if self.__span:
self.__span.end()
self.__span = None


class MutationGroup(_BatchBase):
Expand Down Expand Up @@ -326,7 +369,7 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
)
observability_options = getattr(database, "observability_options", None)
with trace_call(
"CloudSpanner.BatchWrite",
"CloudSpanner.batch_write",
self._session,
trace_attributes,
observability_options=observability_options,
Expand Down
67 changes: 60 additions & 7 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1._opentelemetry_tracing import (
trace_call,
trace_end_explicitly,
)
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.batch import MutationGroups
from google.cloud.spanner_v1.keyset import KeySet
Expand Down Expand Up @@ -1188,9 +1191,16 @@ def __init__(
self._request_options = request_options
self._max_commit_delay = max_commit_delay
self._exclude_txn_from_change_streams = exclude_txn_from_change_streams
self.__span = None

def __enter__(self):
"""Begin ``with`` block."""
observability_options = self._database.observability_options
self.__span = trace_end_explicitly(
"CloudSpanner.Database.batch",
None,
observability_options=observability_options,
)
session = self._session = self._database._pool.get()
batch = self._batch = Batch(session)
if self._request_options.transaction_tag:
Expand All @@ -1214,6 +1224,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
extra={"commit_stats": self._batch.commit_stats},
)
self._database._pool.put(self._session)
if self.__span:
self.__span.end()
self.__span = None


class MutationGroupsCheckout(object):
Expand Down Expand Up @@ -1271,9 +1284,20 @@ def __init__(self, database, **kw):
self._database = database
self._session = None
self._kw = kw
self.__span = None

def __enter__(self):
"""Begin ``with`` block."""
observability_options = self._database.observability_options
attributes = dict()
if self._kw:
attributes["multi_use"] = self._kw["multi_use"]
self.__span = trace_end_explicitly(
"CloudSpanner.Database.snapshot",
None,
attributes,
observability_options=observability_options,
)
session = self._session = self._database._pool.get()
return Snapshot(session, **self._kw)

Expand All @@ -1287,6 +1311,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._session.create()
self._database._pool.put(self._session)

if self.__span:
self.__span.end()
self.__span = None


class BatchSnapshot(object):
"""Wrapper for generating and processing read / query batches.
Expand Down Expand Up @@ -1317,6 +1345,15 @@ def __init__(
self._transaction_id = transaction_id
self._read_timestamp = read_timestamp
self._exact_staleness = exact_staleness
observability_options = getattr(self._database, "observability_options", {})
if isinstance(observability_options, dict) and self._database:
observability_options["db_name"] = self._database.name
self.__observability_options = observability_options
self.__span = trace_end_explicitly(
"CloudSpanner.BatchSnapshot",
self._session,
observability_options=observability_options,
)

@classmethod
def from_dict(cls, database, mapping):
Expand Down Expand Up @@ -1352,6 +1389,10 @@ def to_dict(self):
"transaction_id": snapshot._transaction_id,
}

@property
def observability_options(self):
return self.__observability_options

def _get_session(self):
"""Create session as needed.
Expand Down Expand Up @@ -1517,12 +1558,20 @@ def process_read_batch(
:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
kwargs = copy.deepcopy(batch["read"])
keyset_dict = kwargs.pop("keyset")
kwargs["keyset"] = KeySet._from_dict(keyset_dict)
return self._get_snapshot().read(
partition=batch["partition"], **kwargs, retry=retry, timeout=timeout
)
observability_options = self.observability_options or {}
session = self._get_session()
klassname = type(self).__name__
with trace_call(
"CloudSpanner." + klassname + ".process_read_batch",
session,
observability_options=observability_options,
):
kwargs = copy.deepcopy(batch["read"])
keyset_dict = kwargs.pop("keyset")
kwargs["keyset"] = KeySet._from_dict(keyset_dict)
return self._get_snapshot().read(
partition=batch["partition"], **kwargs, retry=retry, timeout=timeout
)

def generate_query_batches(
self,
Expand Down Expand Up @@ -1750,6 +1799,10 @@ def close(self):
if self._session is not None:
self._session.delete()

if self.__span:
self.__span.end()
self.__span = None


def _check_ddl_statements(value):
"""Validate DDL Statements used to define database schema.
Expand Down
11 changes: 11 additions & 0 deletions google/cloud/spanner_v1/merged_result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ def __init__(self, batch_snapshot, partition_id, merged_result_set):
self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue

def run(self):
observability_options = getattr(
self._batch_snapshot, "observability_options", {}
)
with trace_call(
"CloudSpanner.PartitionExecutor.run",
None,
observability_options=observability_options,
):
return self.__run()

def __run(self):
results = None
try:
results = self._batch_snapshot.process_query_batch(self._partition_id)
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,10 @@ def delete(self):
with trace_call(
"CloudSpanner.DeleteSession",
self,
extra_attributes={
"session.id": self._session_id,
"session.name": self.name,
},
observability_options=observability_options,
):
api.delete_session(name=self.name, metadata=metadata)
Expand Down
12 changes: 6 additions & 6 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def read(
iterator = _restart_on_unavailable(
restart,
request,
"CloudSpanner.ReadOnlyTransaction.read",
"CloudSpanner.read",
self._session,
trace_attributes,
transaction=self,
Expand All @@ -338,7 +338,7 @@ def read(
iterator = _restart_on_unavailable(
restart,
request,
"CloudSpanner.ReadOnlyTransaction.read",
"CloudSpanner.read",
self._session,
trace_attributes,
transaction=self,
Expand Down Expand Up @@ -630,9 +630,9 @@ def partition_read(
partition_options=partition_options,
)

trace_attributes = {"table_id": table, "columns": columns}
trace_attributes = {"table_id": table, "columns": columns, "index": index}
with trace_call(
"CloudSpanner.PartitionReadOnlyTransaction",
"CloudSpanner.partition_read",
self._session,
trace_attributes,
observability_options=getattr(database, "observability_options", None),
Expand Down Expand Up @@ -735,7 +735,7 @@ def partition_query(

trace_attributes = {"db.statement": sql}
with trace_call(
"CloudSpanner.PartitionReadWriteTransaction",
"CloudSpanner.partition_query",
self._session,
trace_attributes,
observability_options=getattr(database, "observability_options", None),
Expand Down Expand Up @@ -882,7 +882,7 @@ def begin(self):
)
txn_selector = self._make_txn_selector()
with trace_call(
"CloudSpanner.BeginTransaction",
"CloudSpanner.begin",
self._session,
observability_options=getattr(database, "observability_options", None),
):
Expand Down

0 comments on commit 51d3e91

Please sign in to comment.