observability: PDML + some batch write spans #1274

Merged (8 commits, Jan 10, 2025)
google/cloud/spanner_v1/batch.py (2 changes: 1 addition & 1 deletion)

@@ -344,7 +344,7 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=False
         )
         observability_options = getattr(database, "observability_options", None)
         with trace_call(
-            "CloudSpanner.BatchWrite",
+            "CloudSpanner.batch_write",
             self._session,
             trace_attributes,
             observability_options=observability_options,
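
The rename from CloudSpanner.BatchWrite to CloudSpanner.batch_write aligns the span name with the method it instruments, matching the other method-style span names in this PR. As background, the sketch below shows the plain-OpenTelemetry pattern that the library's internal trace_call helper wraps; the exporter setup, attribute, and event are illustrative assumptions, not part of this diff, and only the span name "CloudSpanner.batch_write" comes from the change above.

# Sketch: where a span name like "CloudSpanner.batch_write" surfaces.
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)
tracer = trace.get_tracer("spanner-tracing-example")

# trace_call(...) behaves like a named span opened around the RPC.
with tracer.start_as_current_span("CloudSpanner.batch_write") as span:
    span.set_attribute("db.system", "spanner")   # assumed attribute
    span.add_event("BatchWrite request issued")  # assumed event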
google/cloud/spanner_v1/database.py (222 changes: 129 additions & 93 deletions)

@@ -699,38 +699,43 @@ def execute_partitioned_dml(
         )

         def execute_pdml():
-            with SessionCheckout(self._pool) as session:
-                txn = api.begin_transaction(
-                    session=session.name, options=txn_options, metadata=metadata
-                )
-
-                txn_selector = TransactionSelector(id=txn.id)
-
-                request = ExecuteSqlRequest(
-                    session=session.name,
-                    sql=dml,
-                    params=params_pb,
-                    param_types=param_types,
-                    query_options=query_options,
-                    request_options=request_options,
-                )
-                method = functools.partial(
-                    api.execute_streaming_sql,
-                    metadata=metadata,
-                )
-
-                iterator = _restart_on_unavailable(
-                    method=method,
-                    trace_name="CloudSpanner.ExecuteStreamingSql",
-                    request=request,
-                    transaction_selector=txn_selector,
-                    observability_options=self.observability_options,
-                )
-
-                result_set = StreamedResultSet(iterator)
-                list(result_set)  # consume all partials
-
-                return result_set.stats.row_count_lower_bound
+            with trace_call(
+                "CloudSpanner.Database.execute_partitioned_pdml",
+                observability_options=self.observability_options,
+            ) as span:
+                with SessionCheckout(self._pool) as session:
+                    add_span_event(span, "Starting BeginTransaction")
+                    txn = api.begin_transaction(
+                        session=session.name, options=txn_options, metadata=metadata
+                    )
+
+                    txn_selector = TransactionSelector(id=txn.id)
+
+                    request = ExecuteSqlRequest(
+                        session=session.name,
+                        sql=dml,
+                        params=params_pb,
+                        param_types=param_types,
+                        query_options=query_options,
+                        request_options=request_options,
+                    )
+                    method = functools.partial(
+                        api.execute_streaming_sql,
+                        metadata=metadata,
+                    )
+
+                    iterator = _restart_on_unavailable(
+                        method=method,
+                        trace_name="CloudSpanner.ExecuteStreamingSql",
+                        request=request,
+                        transaction_selector=txn_selector,
+                        observability_options=self.observability_options,
+                    )
+
+                    result_set = StreamedResultSet(iterator)
+                    list(result_set)  # consume all partials
+
+                    return result_set.stats.row_count_lower_bound

         return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()

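With this change, the whole partitioned-DML flow runs under a single CloudSpanner.Database.execute_partitioned_pdml span, with a "Starting BeginTransaction" event recorded before the transaction is opened. A minimal verification sketch, assuming the client accepts an observability_options dict carrying a tracer_provider (as this library's observability support does) and treating the project, instance, and database names as placeholders:

# Sketch: capture the new PDML span with an in-memory exporter.
from google.cloud import spanner
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))

client = spanner.Client(
    project="my-project",  # placeholder
    observability_options=dict(tracer_provider=provider),
)
database = client.instance("my-instance").database("my-database")  # placeholders

database.execute_partitioned_dml("UPDATE Users SET active = FALSE WHERE expired = TRUE")

names = [s.name for s in exporter.get_finished_spans()]
assert "CloudSpanner.Database.execute_partitioned_pdml" in names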
@@ -1357,6 +1362,10 @@ def to_dict(self):
"transaction_id": snapshot._transaction_id,
}

@property
def observability_options(self):
return getattr(self._database, "observability_options", {})

def _get_session(self):
"""Create session as needed.

@@ -1476,27 +1485,32 @@ def generate_read_batches(
             mappings of information used perform actual partitioned reads via
             :meth:`process_read_batch`.
         """
-        partitions = self._get_snapshot().partition_read(
-            table=table,
-            columns=columns,
-            keyset=keyset,
-            index=index,
-            partition_size_bytes=partition_size_bytes,
-            max_partitions=max_partitions,
-            retry=retry,
-            timeout=timeout,
-        )
-
-        read_info = {
-            "table": table,
-            "columns": columns,
-            "keyset": keyset._to_dict(),
-            "index": index,
-            "data_boost_enabled": data_boost_enabled,
-            "directed_read_options": directed_read_options,
-        }
-        for partition in partitions:
-            yield {"partition": partition, "read": read_info.copy()}
+        with trace_call(
+            f"CloudSpanner.{type(self).__name__}.generate_read_batches",
+            extra_attributes=dict(table=table, columns=columns),
+            observability_options=self.observability_options,
+        ):
+            partitions = self._get_snapshot().partition_read(
+                table=table,
+                columns=columns,
+                keyset=keyset,
+                index=index,
+                partition_size_bytes=partition_size_bytes,
+                max_partitions=max_partitions,
+                retry=retry,
+                timeout=timeout,
+            )
+
+            read_info = {
+                "table": table,
+                "columns": columns,
+                "keyset": keyset._to_dict(),
+                "index": index,
+                "data_boost_enabled": data_boost_enabled,
+                "directed_read_options": directed_read_options,
+            }
+            for partition in partitions:
+                yield {"partition": partition, "read": read_info.copy()}

     def process_read_batch(
         self,
@@ -1522,12 +1536,17 @@ def process_read_batch(
         :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
         """
-        kwargs = copy.deepcopy(batch["read"])
-        keyset_dict = kwargs.pop("keyset")
-        kwargs["keyset"] = KeySet._from_dict(keyset_dict)
-        return self._get_snapshot().read(
-            partition=batch["partition"], **kwargs, retry=retry, timeout=timeout
-        )
+        observability_options = self.observability_options
+        with trace_call(
+            f"CloudSpanner.{type(self).__name__}.process_read_batch",
+            observability_options=observability_options,
+        ):
+            kwargs = copy.deepcopy(batch["read"])
+            keyset_dict = kwargs.pop("keyset")
+            kwargs["keyset"] = KeySet._from_dict(keyset_dict)
+            return self._get_snapshot().read(
+                partition=batch["partition"], **kwargs, retry=retry, timeout=timeout
+            )

     def generate_query_batches(
         self,
@@ -1602,34 +1621,39 @@ def generate_query_batches(
             mappings of information used perform actual partitioned reads via
             :meth:`process_read_batch`.
         """
-        partitions = self._get_snapshot().partition_query(
-            sql=sql,
-            params=params,
-            param_types=param_types,
-            partition_size_bytes=partition_size_bytes,
-            max_partitions=max_partitions,
-            retry=retry,
-            timeout=timeout,
-        )
-
-        query_info = {
-            "sql": sql,
-            "data_boost_enabled": data_boost_enabled,
-            "directed_read_options": directed_read_options,
-        }
-        if params:
-            query_info["params"] = params
-            query_info["param_types"] = param_types
-
-        # Query-level options have higher precedence than client-level and
-        # environment-level options
-        default_query_options = self._database._instance._client._query_options
-        query_info["query_options"] = _merge_query_options(
-            default_query_options, query_options
-        )
-
-        for partition in partitions:
-            yield {"partition": partition, "query": query_info}
+        with trace_call(
+            f"CloudSpanner.{type(self).__name__}.generate_query_batches",
+            extra_attributes=dict(sql=sql),
+            observability_options=self.observability_options,
+        ):
+            partitions = self._get_snapshot().partition_query(
+                sql=sql,
+                params=params,
+                param_types=param_types,
+                partition_size_bytes=partition_size_bytes,
+                max_partitions=max_partitions,
+                retry=retry,
+                timeout=timeout,
+            )
+
+            query_info = {
+                "sql": sql,
+                "data_boost_enabled": data_boost_enabled,
+                "directed_read_options": directed_read_options,
+            }
+            if params:
+                query_info["params"] = params
+                query_info["param_types"] = param_types
+
+            # Query-level options have higher precedence than client-level and
+            # environment-level options
+            default_query_options = self._database._instance._client._query_options
+            query_info["query_options"] = _merge_query_options(
+                default_query_options, query_options
+            )
+
+            for partition in partitions:
+                yield {"partition": partition, "query": query_info}

     def process_query_batch(
         self,
@@ -1654,9 +1678,16 @@ def process_query_batch(
         :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
        """
-        return self._get_snapshot().execute_sql(
-            partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout
-        )
+        with trace_call(
+            f"CloudSpanner.{type(self).__name__}.process_query_batch",
+            observability_options=self.observability_options,
+        ):
+            return self._get_snapshot().execute_sql(
+                partition=batch["partition"],
+                **batch["query"],
+                retry=retry,
+                timeout=timeout,
+            )

     def run_partitioned_query(
         self,
@@ -1711,18 +1742,23 @@ def run_partitioned_query(
         :rtype: :class:`~google.cloud.spanner_v1.merged_result_set.MergedResultSet`
         :returns: a result set instance which can be used to consume rows.
         """
-        partitions = list(
-            self.generate_query_batches(
-                sql,
-                params,
-                param_types,
-                partition_size_bytes,
-                max_partitions,
-                query_options,
-                data_boost_enabled,
-            )
-        )
-        return MergedResultSet(self, partitions, 0)
+        with trace_call(
+            f"CloudSpanner.{type(self).__name__}.run_partitioned_query",
+            extra_attributes=dict(sql=sql),
+            observability_options=self.observability_options,
+        ):
+            partitions = list(
+                self.generate_query_batches(
+                    sql,
+                    params,
+                    param_types,
+                    partition_size_bytes,
+                    max_partitions,
+                    query_options,
+                    data_boost_enabled,
+                )
+            )
+            return MergedResultSet(self, partitions, 0)

     def process(self, batch):
         """Process a single, partitioned query or read.
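
Taken together, the BatchSnapshot changes give each partitioned entry point its own span, named CloudSpanner.<ClassName>.<method>: generate_read_batches, process_read_batch, generate_query_batches, process_query_batch, and run_partitioned_query. A short usage sketch with placeholder resource and table names; note that process_query_batch spans are emitted from worker threads (see the PartitionExecutor change below), so their parent/child relationship to the root span depends on context propagation across threads:

# Sketch: exercising the newly traced batch-snapshot entry points.
from google.cloud import spanner

client = spanner.Client(project="my-project")  # placeholder
database = client.instance("my-instance").database("my-database")  # placeholders
snapshot = database.batch_snapshot()

# Opens CloudSpanner.BatchSnapshot.run_partitioned_query, which in turn
# runs CloudSpanner.BatchSnapshot.generate_query_batches and then fans
# the partitions out to threads that call process_query_batch.
rows = snapshot.run_partitioned_query(sql="SELECT id, name FROM Users")
for row in rows:
    print(row)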
google/cloud/spanner_v1/merged_result_set.py (12 changes: 12 additions & 0 deletions)

@@ -17,6 +17,8 @@
 from typing import Any, TYPE_CHECKING
 from threading import Lock, Event

+from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
+
 if TYPE_CHECKING:
     from google.cloud.spanner_v1.database import BatchSnapshot

@@ -37,6 +39,16 @@ def __init__(self, batch_snapshot, partition_id, merged_result_set):
         self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue

     def run(self):
+        observability_options = getattr(
+            self._batch_snapshot, "observability_options", {}
+        )
+        with trace_call(
+            "CloudSpanner.PartitionExecutor.run",
+            observability_options=observability_options,
+        ):
+            self.__run()
+
+    def __run(self):
         results = None
         try:
             results = self._batch_snapshot.process_query_batch(self._partition_id)
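
The run()/__run() split is a standard way to wrap an entire thread body in one span without disturbing the existing logic: the context manager ends the span whether __run returns normally or raises. A generic sketch of the same pattern, using plain OpenTelemetry rather than the library's trace_call helper:

# Sketch: wrapping a Thread's body in a single span.
from threading import Thread
from opentelemetry import trace

tracer = trace.get_tracer(__name__)

class Worker(Thread):
    def run(self):
        # One span covers the whole thread body, error paths included,
        # because start_as_current_span closes the span on exit or raise.
        with tracer.start_as_current_span("Worker.run"):
            self.__run()

    def __run(self):
        pass  # original thread body goes here, unchanged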
google/cloud/spanner_v1/pool.py (29 changes: 9 additions & 20 deletions)

@@ -523,12 +523,11 @@ def bind(self, database):
             metadata.append(
                 _metadata_with_leader_aware_routing(database._route_to_leader_enabled)
             )
-        created_session_count = 0
         self._database_role = self._database_role or self._database.database_role

         request = BatchCreateSessionsRequest(
             database=database.name,
-            session_count=self.size - created_session_count,
+            session_count=self.size,
             session_template=Session(creator_role=self.database_role),
         )

@@ -549,38 +548,28 @@ def bind(self, database):
                 span_event_attributes,
             )

-        if created_session_count >= self.size:
-            add_span_event(
-                current_span,
-                "Created no new sessions as sessionPool is full",
-                span_event_attributes,
-            )
-            return
-
         add_span_event(
             current_span,
             f"Creating {request.session_count} sessions",
             span_event_attributes,
         )

         observability_options = getattr(self._database, "observability_options", None)
         with trace_call(
             "CloudSpanner.PingingPool.BatchCreateSessions",
             observability_options=observability_options,
         ) as span:
             returned_session_count = 0
-            while created_session_count < self.size:
+            while returned_session_count < self.size:
                 resp = api.batch_create_sessions(
                     request=request,
                     metadata=metadata,
                 )

+                add_span_event(
+                    span,
+                    f"Created {len(resp.session)} sessions",
+                )
+
                 for session_pb in resp.session:
                     session = self._new_session()
+                    returned_session_count += 1
                     session._session_id = session_pb.name.split("/")[-1]
                     self.put(session)
-                    returned_session_count += 1
-
-                created_session_count += len(resp.session)

             add_span_event(
                 span,
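
The pool change also simplifies the bookkeeping: the loop now terminates on returned_session_count, i.e. sessions actually received and put into the pool, instead of tracking a separate created_session_count, and each BatchCreateSessions response is recorded as a span event. A behavior sketch, with api standing in for the generated Spanner client (an assumption for illustration):

# Sketch: the post-change loop shape, with one span event per batch.
from opentelemetry import trace

tracer = trace.get_tracer(__name__)

def fill_pool(api, request, size, metadata):
    with tracer.start_as_current_span(
        "CloudSpanner.PingingPool.BatchCreateSessions"
    ) as span:
        returned_session_count = 0
        # Keep requesting batches until `size` sessions have been handed
        # back; the server may return fewer sessions per call than asked.
        while returned_session_count < size:
            resp = api.batch_create_sessions(request=request, metadata=metadata)
            span.add_event(f"Created {len(resp.session)} sessions")
            returned_session_count += len(resp.session)
        return returned_session_count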