feat: Implementation for partitioned query in dbapi #1067

Merged · 7 commits · Jan 10, 2024
43 changes: 30 additions & 13 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
@@ -50,6 +50,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
:param parsed_statement: parsed_statement based on the sql query
"""
connection = cursor.connection
column_values = []
if connection.is_closed:
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
statement_type = parsed_statement.client_side_statement_type
@@ -63,24 +64,26 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
connection.rollback()
return None
if statement_type == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP:
if connection._transaction is None:
committed_timestamp = None
else:
committed_timestamp = connection._transaction.committed
if (
connection._transaction is not None
and connection._transaction.committed is not None
):
column_values.append(connection._transaction.committed)
return _get_streamed_result_set(
ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name,
TypeCode.TIMESTAMP,
committed_timestamp,
column_values,
)
if statement_type == ClientSideStatementType.SHOW_READ_TIMESTAMP:
if connection._snapshot is None:
read_timestamp = None
else:
read_timestamp = connection._snapshot._transaction_read_timestamp
if (
connection._snapshot is not None
and connection._snapshot._transaction_read_timestamp is not None
):
column_values.append(connection._snapshot._transaction_read_timestamp)
return _get_streamed_result_set(
ClientSideStatementType.SHOW_READ_TIMESTAMP.name,
TypeCode.TIMESTAMP,
read_timestamp,
column_values,
)
if statement_type == ClientSideStatementType.START_BATCH_DML:
connection.start_batch_dml(cursor)
@@ -89,14 +92,28 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
return connection.run_batch()
if statement_type == ClientSideStatementType.ABORT_BATCH:
return connection.abort_batch()
if statement_type == ClientSideStatementType.PARTITION_QUERY:
partition_ids = connection.partition_query(parsed_statement)
return _get_streamed_result_set(
"PARTITION",
TypeCode.STRING,
partition_ids,
)
if statement_type == ClientSideStatementType.RUN_PARTITION:
return connection.run_partition(
parsed_statement.client_side_statement_params[0]
)


def _get_streamed_result_set(column_name, type_code, column_value):
def _get_streamed_result_set(column_name, type_code, column_values):
struct_type_pb = StructType(
fields=[StructType.Field(name=column_name, type_=Type(code=type_code))]
)

result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb))
if column_value is not None:
result_set.values.extend([_make_value_pb(column_value)])
if len(column_values) > 0:
column_values_pb = []
for column_value in column_values:
column_values_pb.append(_make_value_pb(column_value))
result_set.values.extend(column_values_pb)
return StreamedResultSet(iter([result_set]))
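The helper now takes a list of values instead of a single optional value, so PARTITION can return one row per partition token while the SHOW ... TIMESTAMP statements return zero or one rows. A minimal sketch of the same pattern, assuming the `google-cloud-spanner` import paths this module already uses:

```python
from google.cloud.spanner_v1 import (
    PartialResultSet,
    ResultSetMetadata,
    StructType,
    Type,
    TypeCode,
)
from google.cloud.spanner_v1._helpers import _make_value_pb
from google.cloud.spanner_v1.streamed import StreamedResultSet


def get_streamed_result_set(column_name, type_code, column_values):
    # Single-column schema named after the client-side statement.
    struct_type_pb = StructType(
        fields=[StructType.Field(name=column_name, type_=Type(code=type_code))]
    )
    result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb))
    # Each value becomes one row; an empty list yields an empty result set,
    # which is what SHOW ... TIMESTAMP produces before any commit or read.
    result_set.values.extend([_make_value_pb(v) for v in column_values])
    return StreamedResultSet(iter([result_set]))


# One row per partition token:
rows = get_streamed_result_set("PARTITION", TypeCode.STRING, ["token-1", "token-2"])
```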
16 changes: 15 additions & 1 deletion google/cloud/spanner_dbapi/client_side_statement_parser.py
@@ -33,6 +33,8 @@
RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE)
RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE)
RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)
RE_PARTITION_QUERY = re.compile(r"^\s*(PARTITION)\s+(.+)", re.IGNORECASE)
RE_RUN_PARTITION = re.compile(r"^\s*(RUN)\s+(PARTITION)\s+(.+)", re.IGNORECASE)


def parse_stmt(query):
@@ -48,6 +50,7 @@ def parse_stmt(query):
:returns: ParsedStatement object.
"""
client_side_statement_type = None
client_side_statement_params = []
if RE_COMMIT.match(query):
client_side_statement_type = ClientSideStatementType.COMMIT
if RE_BEGIN.match(query):
@@ -64,8 +67,19 @@ def parse_stmt(query):
client_side_statement_type = ClientSideStatementType.RUN_BATCH
if RE_ABORT_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.ABORT_BATCH
if RE_PARTITION_QUERY.match(query):
match = re.search(RE_PARTITION_QUERY, query)
client_side_statement_params.append(match.group(2))
client_side_statement_type = ClientSideStatementType.PARTITION_QUERY
if RE_RUN_PARTITION.match(query):
match = re.search(RE_RUN_PARTITION, query)
client_side_statement_params.append(match.group(3))
client_side_statement_type = ClientSideStatementType.RUN_PARTITION
if client_side_statement_type is not None:
return ParsedStatement(
StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type
StatementType.CLIENT_SIDE,
Statement(query),
client_side_statement_type,
client_side_statement_params,
)
return None
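A standalone check (no Spanner dependency) of what the two new regexes capture; the patterns are copied verbatim from the diff above:

```python
import re

RE_PARTITION_QUERY = re.compile(r"^\s*(PARTITION)\s+(.+)", re.IGNORECASE)
RE_RUN_PARTITION = re.compile(r"^\s*(RUN)\s+(PARTITION)\s+(.+)", re.IGNORECASE)

# group(2) is the query to partition, passed along as a statement param.
m = RE_PARTITION_QUERY.match("PARTITION SELECT id FROM users")
assert m.group(2) == "SELECT id FROM users"

# group(3) is the opaque, encoded partition id produced by PARTITION.
m = RE_RUN_PARTITION.match("RUN PARTITION abc123")
assert m.group(3) == "abc123"
```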
57 changes: 56 additions & 1 deletion google/cloud/spanner_dbapi/connection.py
@@ -19,8 +19,15 @@
from google.api_core.exceptions import Aborted
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_dbapi import partition_helper
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
from google.cloud.spanner_dbapi.parse_utils import _get_statement_type
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
Statement,
StatementType,
)
from google.cloud.spanner_dbapi.partition_helper import PartitionId
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot
@@ -585,6 +592,54 @@ def abort_batch(self):
self._batch_dml_executor = None
self._batch_mode = BatchMode.NONE

@check_not_closed
def partition_query(
self,
parsed_statement: ParsedStatement,
query_options=None,
):
statement = parsed_statement.statement
partitioned_query = parsed_statement.client_side_statement_params[0]
if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
raise ProgrammingError(
"Only queries can be partitioned. Invalid statement: " + statement.sql
)
if self.read_only is not True and self._client_transaction_started is True:
raise ProgrammingError(
"Partitioned query not supported as the connection is not in "
"read only mode or ReadWrite transaction started"
)

batch_snapshot = self._database.batch_snapshot()
partition_ids = []
partitions = list(
batch_snapshot.generate_query_batches(
partitioned_query,
statement.params,
statement.param_types,
query_options=query_options,
)
)
for partition in partitions:
batch_transaction_id = batch_snapshot.get_batch_transaction_id()
partition_ids.append(
partition_helper.encode_to_string(batch_transaction_id, partition)
)
return partition_ids

@check_not_closed
def run_partition(self, batch_transaction_id):
partition_id: PartitionId = partition_helper.decode_from_string(
batch_transaction_id
)
batch_transaction_id = partition_id.batch_transaction_id
batch_snapshot = self._database.batch_snapshot(
read_timestamp=batch_transaction_id.read_timestamp,
session_id=batch_transaction_id.session_id,
transaction_id=batch_transaction_id.transaction_id,
)
return batch_snapshot.process(partition_id.partition_result)

def __enter__(self):
return self

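Taken together, the two methods split a query into independently executable units: `partition_query` encodes a `(BatchTransactionId, partition)` pair per partition, and `run_partition` rehydrates a batch snapshot from that token and processes the partition. A hypothetical end-to-end use through the DB-API (instance, database, and table names are placeholders):

```python
from google.cloud.spanner_dbapi import connect

connection = connect("instance-id", "database-id")
connection.read_only = True  # partitioned queries need a read-only context
cursor = connection.cursor()

# PARTITION returns one row per partition, each holding an opaque token.
cursor.execute("PARTITION SELECT id, name FROM users")
partition_ids = [row[0] for row in cursor.fetchall()]

# Tokens are self-contained, so each RUN PARTITION could also be executed
# by a different worker or process against the same database.
for partition_id in partition_ids:
    cursor.execute("RUN PARTITION " + partition_id)
    for row in cursor.fetchall():
        print(row)
```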
16 changes: 10 additions & 6 deletions google/cloud/spanner_dbapi/parse_utils.py
@@ -232,19 +232,23 @@ def classify_statement(query, args=None):
get_param_types(args or None),
ResultsChecksum(),
)
if RE_DDL.match(query):
return ParsedStatement(StatementType.DDL, statement)
statement_type = _get_statement_type(statement)
return ParsedStatement(statement_type, statement)

if RE_IS_INSERT.match(query):
return ParsedStatement(StatementType.INSERT, statement)

def _get_statement_type(statement):
query = statement.sql
if RE_DDL.match(query):
return StatementType.DDL
if RE_IS_INSERT.match(query):
return StatementType.INSERT
if RE_NON_UPDATE.match(query) or RE_WITH.match(query):
# As of 13-March-2020, Cloud Spanner only supports WITH for DQL
# statements and doesn't yet support WITH for DML statements.
return ParsedStatement(StatementType.QUERY, statement)
return StatementType.QUERY

statement.sql = ensure_where_clause(query)
return ParsedStatement(StatementType.UPDATE, statement)
return StatementType.UPDATE


def sql_pyformat_args_to_spanner(sql, params):
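Extracting `_get_statement_type` lets `partition_query` above reuse the same classification that `classify_statement` applies, e.g. to reject non-query statements. A quick sketch, assuming the package is importable:

```python
from google.cloud.spanner_dbapi.parse_utils import classify_statement
from google.cloud.spanner_dbapi.parsed_statement import StatementType

assert classify_statement("SELECT 1").statement_type == StatementType.QUERY
assert (
    classify_statement("INSERT INTO t (c) VALUES (1)").statement_type
    == StatementType.INSERT
)
```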
7 changes: 5 additions & 2 deletions google/cloud/spanner_dbapi/parsed_statement.py
@@ -1,4 +1,4 @@
# Copyright 20203 Google LLC All rights reserved.
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from enum import Enum
from typing import Any
from typing import Any, List

from google.cloud.spanner_dbapi.checksum import ResultsChecksum

@@ -35,6 +35,8 @@ class ClientSideStatementType(Enum):
START_BATCH_DML = 6
RUN_BATCH = 7
ABORT_BATCH = 8
PARTITION_QUERY = 9
RUN_PARTITION = 10


@dataclass
@@ -53,3 +55,4 @@ class ParsedStatement:
statement_type: StatementType
statement: Statement
client_side_statement_type: ClientSideStatementType = None
client_side_statement_params: List[Any] = None
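The new field lets client-side statements carry their parsed arguments. A hypothetical construction mirroring what the parser produces for RUN PARTITION:

```python
from google.cloud.spanner_dbapi.parsed_statement import (
    ClientSideStatementType,
    ParsedStatement,
    Statement,
    StatementType,
)

parsed = ParsedStatement(
    StatementType.CLIENT_SIDE,
    Statement("RUN PARTITION abc123"),
    ClientSideStatementType.RUN_PARTITION,
    ["abc123"],  # the encoded partition id captured by the parser
)
assert parsed.client_side_statement_params == ["abc123"]
```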
46 changes: 46 additions & 0 deletions google/cloud/spanner_dbapi/partition_helper.py
@@ -0,0 +1,46 @@
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any

import gzip
import pickle
import base64


def decode_from_string(encoded_partition_id):
gzip_bytes = base64.b64decode(bytes(encoded_partition_id, "utf-8"))
partition_id_bytes = gzip.decompress(gzip_bytes)
return pickle.loads(partition_id_bytes)


def encode_to_string(batch_transaction_id, partition_result):
partition_id = PartitionId(batch_transaction_id, partition_result)
partition_id_bytes = pickle.dumps(partition_id)
gzip_bytes = gzip.compress(partition_id_bytes)
return str(base64.b64encode(gzip_bytes), "utf-8")


@dataclass
class BatchTransactionId:
transaction_id: str
session_id: str
read_timestamp: Any


@dataclass
class PartitionId:
batch_transaction_id: BatchTransactionId
partition_result: Any
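The token format is pickle inside gzip inside base64, so any picklable pair survives the round trip as an opaque ASCII string. A self-contained sketch of the same scheme (simplified signature; the real helpers wrap the arguments in a `PartitionId`):

```python
import base64
import gzip
import pickle


def encode_to_string(obj):
    return str(base64.b64encode(gzip.compress(pickle.dumps(obj))), "utf-8")


def decode_from_string(encoded):
    return pickle.loads(gzip.decompress(base64.b64decode(bytes(encoded, "utf-8"))))


token = encode_to_string(("transaction-id", "partition-result"))
assert decode_from_string(token) == ("transaction-id", "partition-result")
# Caveat: pickle will deserialize arbitrary objects, so tokens should only
# be accepted from trusted sources.
```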