Skip to content

Commit

Permalink
fix(tracing): ensure nesting of Transaction.begin under commit + fix …
Browse files Browse the repository at this point in the history
…suggestions from feature review

This change ensures that:
* If a transaction was not yet begin, that if .commit() is invoked
the resulting span hierarchy has .begin nested under .commit
* We use "CloudSpanner.Transaction.execute_sql" instead of
  "CloudSpanner.Transaction.execute_streaming_sql"
* If we have a tracer_provider that produces non-recordings spans,
that it won't crash due to lacking `span._status`

Fixes #1286
  • Loading branch information
odeke-em committed Jan 9, 2025
1 parent 04a11a6 commit e83b4af
Show file tree
Hide file tree
Showing 9 changed files with 405 additions and 52 deletions.
5 changes: 4 additions & 1 deletion google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ def trace_call(name, session=None, extra_attributes=None, observability_options=
# invoke .record_exception on our own else we shall have 2 exceptions.
raise
else:
if (not span._status) or span._status.status_code == StatusCode.UNSET:
# All spans still have set_status available even if for example
# NonRecordingSpan doesn't have "_status".
absent_span_status = getattr(span, "_status", None) is None
if absent_span_status or span._status.status_code == StatusCode.UNSET:
# OpenTelemetry-Python only allows a status change
# if the current code is UNSET or ERROR. At the end
# of the generator's consumption, only set it to OK
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def _get_streamed_result_set(
iterator = _restart_on_unavailable(
restart,
request,
f"CloudSpanner.{type(self).__name__}.execute_streaming_sql",
f"CloudSpanner.{type(self).__name__}.execute_sql",
self._session,
trace_attributes,
transaction=self,
Expand Down
66 changes: 34 additions & 32 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,39 +242,7 @@ def commit(
:returns: timestamp of the committed changes.
:raises ValueError: if there are no mutations to commit.
"""
self._check_state()
if self._transaction_id is None and len(self._mutations) > 0:
self.begin()
elif self._transaction_id is None and len(self._mutations) == 0:
raise ValueError("Transaction is not begun")

database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)

if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)
if self.transaction_tag is not None:
request_options.transaction_tag = self.transaction_tag

# Request tags are not supported for commit requests.
request_options.request_tag = None

request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
transaction_id=self._transaction_id,
return_commit_stats=return_commit_stats,
max_commit_delay=max_commit_delay,
request_options=request_options,
)

trace_attributes = {"num_mutations": len(self._mutations)}
observability_options = getattr(database, "observability_options", None)
with trace_call(
Expand All @@ -283,6 +251,40 @@ def commit(
trace_attributes,
observability_options,
) as span:
self._check_state()
if self._transaction_id is None and len(self._mutations) > 0:
self.begin()
elif self._transaction_id is None and len(self._mutations) == 0:
raise ValueError("Transaction is not begun")

api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(
database._route_to_leader_enabled
)
)

if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)
if self.transaction_tag is not None:
request_options.transaction_tag = self.transaction_tag

# Request tags are not supported for commit requests.
request_options.request_tag = None

request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
transaction_id=self._transaction_id,
return_commit_stats=return_commit_stats,
max_commit_delay=max_commit_delay,
request_options=request_options,
)

add_span_event(span, "Starting Commit")

method = functools.partial(
Expand Down
17 changes: 17 additions & 0 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,20 @@ def get_finished_spans(self):

def reset(self):
self.tearDown()

def finished_spans_events_statuses(self):
span_list = self.get_finished_spans()
# Some event attributes are noisy/highly ephemeral
# and can't be directly compared against.
got_all_events = []
imprecise_event_attributes = ["exception.stacktrace", "delay_seconds", "cause"]
for span in span_list:
for event in span.events:
evt_attributes = event.attributes.copy()
for attr_name in imprecise_event_attributes:
if attr_name in evt_attributes:
evt_attributes[attr_name] = "EPHEMERAL"

got_all_events.append((event.name, evt_attributes))

return got_all_events
204 changes: 199 additions & 5 deletions tests/system/test_observability_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_propagation(enable_extended_tracing):
gotNames = [span.name for span in from_inject_spans]
wantNames = [
"CloudSpanner.CreateSession",
"CloudSpanner.Snapshot.execute_streaming_sql",
"CloudSpanner.Snapshot.execute_sql",
]
assert gotNames == wantNames

Expand Down Expand Up @@ -216,8 +216,8 @@ def select_in_txn(txn):
"CloudSpanner.Database.run_in_transaction",
"CloudSpanner.CreateSession",
"CloudSpanner.Session.run_in_transaction",
"CloudSpanner.Transaction.execute_streaming_sql",
"CloudSpanner.Transaction.execute_streaming_sql",
"CloudSpanner.Transaction.execute_sql",
"CloudSpanner.Transaction.execute_sql",
"CloudSpanner.Transaction.commit",
]

Expand Down Expand Up @@ -262,13 +262,207 @@ def select_in_txn(txn):
("CloudSpanner.Database.run_in_transaction", codes.OK, None),
("CloudSpanner.CreateSession", codes.OK, None),
("CloudSpanner.Session.run_in_transaction", codes.OK, None),
("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None),
("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None),
("CloudSpanner.Transaction.execute_sql", codes.OK, None),
("CloudSpanner.Transaction.execute_sql", codes.OK, None),
("CloudSpanner.Transaction.commit", codes.OK, None),
]
assert got_statuses == want_statuses


@pytest.mark.skipif(
not _helpers.USE_EMULATOR,
reason="Emulator needed to run this tests",
)
@pytest.mark.skipif(
not HAS_OTEL_INSTALLED,
reason="Tracing requires OpenTelemetry",
)
def test_transaction_update_implicit_begin_nested_inside_commit():
# Tests to ensure that transaction.commit() without a began transaction
# has transaction.begin() inlined and nested under the commit span.
from google.auth.credentials import AnonymousCredentials
from google.api_core.exceptions import Aborted
from google.rpc import code_pb2
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.trace.status import StatusCode
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.sampling import ALWAYS_ON

PROJECT = _helpers.EMULATOR_PROJECT
CONFIGURATION_NAME = "config-name"
INSTANCE_ID = _helpers.INSTANCE_ID
DISPLAY_NAME = "display-name"
DATABASE_ID = _helpers.unique_id("temp_db")
NODE_COUNT = 5
LABELS = {"test": "true"}

counters = dict(aborted=0)

def tx_update(txn):
txn.update(
"Singers",
columns=["SingerId", "FirstName"],
values=[["1", "Bryan"], ["2", "Slash"]],
)

tracer_provider = TracerProvider(sampler=ALWAYS_ON)
trace_exporter = InMemorySpanExporter()
tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter))
observability_options = dict(
tracer_provider=tracer_provider,
enable_extended_tracing=True,
)

client = Client(
project=PROJECT,
observability_options=observability_options,
credentials=AnonymousCredentials(),
)

instance = client.instance(
INSTANCE_ID,
CONFIGURATION_NAME,
display_name=DISPLAY_NAME,
node_count=NODE_COUNT,
labels=LABELS,
)

try:
instance.create()
except Exception:
pass

db = instance.database(DATABASE_ID)
try:
db._ddl_statements = [
"""CREATE TABLE Singers (
SingerId INT64 NOT NULL,
FirstName STRING(1024),
LastName STRING(1024),
SingerInfo BYTES(MAX),
FullName STRING(2048) AS (
ARRAY_TO_STRING([FirstName, LastName], " ")
) STORED
) PRIMARY KEY (SingerId)""",
"""CREATE TABLE Albums (
SingerId INT64 NOT NULL,
AlbumId INT64 NOT NULL,
AlbumTitle STRING(MAX),
MarketingBudget INT64,
) PRIMARY KEY (SingerId, AlbumId),
INTERLEAVE IN PARENT Singers ON DELETE CASCADE""",
]
db.create()
except Exception:
pass

try:
db.run_in_transaction(tx_update)
except:
pass

span_list = trace_exporter.get_finished_spans()
# Sort the spans by their start time in the hierarchy.
span_list = sorted(span_list, key=lambda span: span.start_time)
got_span_names = [span.name for span in span_list]
want_span_names = [
"CloudSpanner.Database.run_in_transaction",
"CloudSpanner.CreateSession",
"CloudSpanner.Session.run_in_transaction",
"CloudSpanner.Transaction.commit",
"CloudSpanner.Transaction.begin",
]

assert got_span_names == want_span_names
span_tx_begin = span_list[-1]
span_tx_commit = span_list[-2]

got_events = []
got_statuses = []

# Some event attributes are noisy/highly ephemeral
# and can't be directly compared against.
imprecise_event_attributes = ["exception.stacktrace", "delay_seconds", "cause"]
for span in span_list:
got_statuses.append(
(span.name, span.status.status_code, span.status.description)
)
for event in span.events:
evt_attributes = event.attributes.copy()
for attr_name in imprecise_event_attributes:
if attr_name in evt_attributes:
evt_attributes[attr_name] = "EPHEMERAL"

got_events.append((event.name, evt_attributes))

# Check for the series of events
want_events = [
("Acquiring session", {"kind": "BurstyPool"}),
("Waiting for a session to become available", {"kind": "BurstyPool"}),
("No sessions available in pool. Creating session", {"kind": "BurstyPool"}),
("Creating Session", {}),
(
"exception",
{
"exception.type": "google.api_core.exceptions.NotFound",
"exception.message": "404 Table Singers: Row {Int64(1)} not found.",
"exception.stacktrace": "EPHEMERAL",
"exception.escaped": "False",
},
),
(
"Transaction.commit failed due to GoogleAPICallError, not retrying",
{"attempt": 1},
),
(
"exception",
{
"exception.type": "google.api_core.exceptions.NotFound",
"exception.message": "404 Table Singers: Row {Int64(1)} not found.",
"exception.stacktrace": "EPHEMERAL",
"exception.escaped": "False",
},
),
("Starting Commit", {}),
(
"exception",
{
"exception.type": "google.api_core.exceptions.NotFound",
"exception.message": "404 Table Singers: Row {Int64(1)} not found.",
"exception.stacktrace": "EPHEMERAL",
"exception.escaped": "False",
},
),
]
assert got_events == want_events

# Check for the statues.
codes = StatusCode
want_statuses = [
(
"CloudSpanner.Database.run_in_transaction",
codes.ERROR,
"NotFound: 404 Table Singers: Row {Int64(1)} not found.",
),
("CloudSpanner.CreateSession", codes.OK, None),
(
"CloudSpanner.Session.run_in_transaction",
codes.ERROR,
"NotFound: 404 Table Singers: Row {Int64(1)} not found.",
),
(
"CloudSpanner.Transaction.commit",
codes.ERROR,
"NotFound: 404 Table Singers: Row {Int64(1)} not found.",
),
("CloudSpanner.Transaction.begin", codes.OK, None),
]
assert got_statuses == want_statuses


def _make_credentials():
from google.auth.credentials import AnonymousCredentials

Expand Down
32 changes: 31 additions & 1 deletion tests/unit/test__opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_trace_codeless_error(self):
span = span_list[0]
self.assertEqual(span.status.status_code, StatusCode.ERROR)

def test_trace_call_terminal_span_status(self):
def test_trace_call_terminal_span_status_ALWAYS_ON_sampler(self):
# Verify that we don't unconditionally set the terminal span status to
# SpanStatus.OK per https://github.com/googleapis/python-spanner/issues/1246
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
Expand Down Expand Up @@ -195,3 +195,33 @@ def test_trace_call_terminal_span_status(self):
("VerifyTerminalSpanStatus", StatusCode.ERROR, "Our error exhibit"),
]
assert got_statuses == want_statuses

def test_trace_call_terminal_span_status_ALWAYS_OFF_sampler(self):
# Verify that we get the correct status even when using the ALWAYS_OFF
# sampler which produces the NonRecordingSpan per
# https://github.com/googleapis/python-spanner/issues/1286
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.sampling import ALWAYS_OFF

tracer_provider = TracerProvider(sampler=ALWAYS_OFF)
trace_exporter = InMemorySpanExporter()
tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter))
observability_options = dict(tracer_provider=tracer_provider)

session = _make_session()
used_span = None
with _opentelemetry_tracing.trace_call(
"VerifyWithNonRecordingSpan",
session,
observability_options=observability_options,
) as span:
used_span = span

assert type(used_span).__name__ == "NonRecordingSpan"
span_list = list(trace_exporter.get_finished_spans())
assert span_list == []
Loading

0 comments on commit e83b4af

Please sign in to comment.