diff --git a/.gitignore b/.gitignore index d083ea1ddc..4797754726 100644 --- a/.gitignore +++ b/.gitignore @@ -62,3 +62,7 @@ system_tests/local_test_setup # Make sure a generated file isn't accidentally committed. pylintrc pylintrc.test + + +# Ignore coverage files +.coverage* diff --git a/google/cloud/spanner_dbapi/transaction_helper.py b/google/cloud/spanner_dbapi/transaction_helper.py index bc896009c7..f8f5bfa584 100644 --- a/google/cloud/spanner_dbapi/transaction_helper.py +++ b/google/cloud/spanner_dbapi/transaction_helper.py @@ -20,7 +20,7 @@ from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode from google.cloud.spanner_dbapi.exceptions import RetryAborted -from google.cloud.spanner_v1.session import _get_retry_delay +from google.cloud.spanner_v1._helpers import _get_retry_delay if TYPE_CHECKING: from google.cloud.spanner_dbapi import Connection, Cursor diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 1f4bf5b174..27e53200ed 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -27,11 +27,15 @@ from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper from google.api_core import datetime_helpers +from google.api_core.exceptions import Aborted from google.cloud._helpers import _date_from_iso8601_date from google.cloud.spanner_v1 import TypeCode from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import JsonObject from google.cloud.spanner_v1.request_id_header import with_request_id +from google.rpc.error_details_pb2 import RetryInfo + +import random # Validation error messages NUMERIC_MAX_SCALE_ERR_MSG = ( @@ -460,6 +464,23 @@ def _metadata_with_prefix(prefix, **kw): return [("google-cloud-resource-prefix", prefix)] +def _retry_on_aborted_exception( + func, + deadline, +): + """ + Handles retry logic for Aborted exceptions, considering the deadline. + """ + attempts = 0 + while True: + try: + attempts += 1 + return func() + except Aborted as exc: + _delay_until_retry(exc, deadline=deadline, attempts=attempts) + continue + + def _retry( func, retry_count=5, @@ -529,6 +550,60 @@ def _metadata_with_leader_aware_routing(value, **kw): return ("x-goog-spanner-route-to-leader", str(value).lower()) +def _delay_until_retry(exc, deadline, attempts): + """Helper for :meth:`Session.run_in_transaction`. + + Detect retryable abort, and impose server-supplied delay. + + :type exc: :class:`google.api_core.exceptions.Aborted` + :param exc: exception for aborted transaction + + :type deadline: float + :param deadline: maximum timestamp to continue retrying the transaction. + + :type attempts: int + :param attempts: number of call retries + """ + + cause = exc.errors[0] + now = time.time() + if now >= deadline: + raise + + delay = _get_retry_delay(cause, attempts) + if delay is not None: + if now + delay > deadline: + raise + + time.sleep(delay) + + +def _get_retry_delay(cause, attempts): + """Helper for :func:`_delay_until_retry`. + + :type exc: :class:`grpc.Call` + :param exc: exception for aborted transaction + + :rtype: float + :returns: seconds to wait before retrying the transaction. + + :type attempts: int + :param attempts: number of call retries + """ + if hasattr(cause, "trailing_metadata"): + metadata = dict(cause.trailing_metadata()) + else: + metadata = {} + retry_info_pb = metadata.get("google.rpc.retryinfo-bin") + if retry_info_pb is not None: + retry_info = RetryInfo() + retry_info.ParseFromString(retry_info_pb) + nanos = retry_info.retry_delay.nanos + return retry_info.retry_delay.seconds + nanos / 1.0e9 + + return 2**attempts + random.random() + + class AtomicCounter: def __init__(self, start_value=0): self.__lock = threading.Lock() diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index e62f6b690c..6a9f1f48f5 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -29,8 +29,12 @@ from google.cloud.spanner_v1._opentelemetry_tracing import trace_call from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1._helpers import _retry +from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception from google.cloud.spanner_v1._helpers import _check_rst_stream_error from google.api_core.exceptions import InternalServerError +import time + +DEFAULT_RETRY_TIMEOUT_SECS = 30 class _BatchBase(_SessionWrapper): @@ -162,6 +166,7 @@ def commit( request_options=None, max_commit_delay=None, exclude_txn_from_change_streams=False, + **kwargs, ): """Commit mutations to the database. @@ -227,9 +232,12 @@ def commit( request=request, metadata=metadata, ) - response = _retry( + deadline = time.time() + kwargs.get( + "timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS + ) + response = _retry_on_aborted_exception( method, - allowed_exceptions={InternalServerError: _check_rst_stream_error}, + deadline=deadline, ) self.committed = response.commit_timestamp self.commit_stats = response.commit_stats @@ -348,7 +356,9 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals ) response = _retry( method, - allowed_exceptions={InternalServerError: _check_rst_stream_error}, + allowed_exceptions={ + InternalServerError: _check_rst_stream_error, + }, ) self.committed = True return response diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 6bcd620cba..963debdab8 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -780,6 +780,7 @@ def batch( request_options=None, max_commit_delay=None, exclude_txn_from_change_streams=False, + **kw, ): """Return an object which wraps a batch. @@ -810,7 +811,11 @@ def batch( :returns: new wrapper """ return BatchCheckout( - self, request_options, max_commit_delay, exclude_txn_from_change_streams + self, + request_options, + max_commit_delay, + exclude_txn_from_change_streams, + **kw, ) def mutation_groups(self): @@ -1171,6 +1176,7 @@ def __init__( request_options=None, max_commit_delay=None, exclude_txn_from_change_streams=False, + **kw, ): self._database = database self._session = self._batch = None @@ -1182,6 +1188,7 @@ def __init__( self._request_options = request_options self._max_commit_delay = max_commit_delay self._exclude_txn_from_change_streams = exclude_txn_from_change_streams + self._kw = kw def __enter__(self): """Begin ``with`` block.""" @@ -1202,6 +1209,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): request_options=self._request_options, max_commit_delay=self._max_commit_delay, exclude_txn_from_change_streams=self._exclude_txn_from_change_streams, + **self._kw, ) finally: if self._database.log_commit_stats and self._batch.commit_stats: diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index d73a8cc2b5..ccc0c4ebdc 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -15,7 +15,6 @@ """Wrapper for Cloud Spanner Session objects.""" from functools import total_ordering -import random import time from datetime import datetime @@ -23,7 +22,8 @@ from google.api_core.exceptions import GoogleAPICallError from google.api_core.exceptions import NotFound from google.api_core.gapic_v1 import method -from google.rpc.error_details_pb2 import RetryInfo +from google.cloud.spanner_v1._helpers import _delay_until_retry +from google.cloud.spanner_v1._helpers import _get_retry_delay from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import CreateSessionRequest @@ -554,57 +554,3 @@ def run_in_transaction(self, func, *args, **kw): extra={"commit_stats": txn.commit_stats}, ) return return_value - - -# Rational: this function factors out complex shared deadline / retry -# handling from two `except:` clauses. -def _delay_until_retry(exc, deadline, attempts): - """Helper for :meth:`Session.run_in_transaction`. - - Detect retryable abort, and impose server-supplied delay. - - :type exc: :class:`google.api_core.exceptions.Aborted` - :param exc: exception for aborted transaction - - :type deadline: float - :param deadline: maximum timestamp to continue retrying the transaction. - - :type attempts: int - :param attempts: number of call retries - """ - cause = exc.errors[0] - - now = time.time() - - if now >= deadline: - raise - - delay = _get_retry_delay(cause, attempts) - if delay is not None: - if now + delay > deadline: - raise - - time.sleep(delay) - - -def _get_retry_delay(cause, attempts): - """Helper for :func:`_delay_until_retry`. - - :type exc: :class:`grpc.Call` - :param exc: exception for aborted transaction - - :rtype: float - :returns: seconds to wait before retrying the transaction. - - :type attempts: int - :param attempts: number of call retries - """ - metadata = dict(cause.trailing_metadata()) - retry_info_pb = metadata.get("google.rpc.retryinfo-bin") - if retry_info_pb is not None: - retry_info = RetryInfo() - retry_info.ParseFromString(retry_info_pb) - nanos = retry_info.retry_delay.nanos - return retry_info.retry_delay.seconds + nanos / 1.0e9 - - return 2**attempts + random.random() diff --git a/google/cloud/spanner_v1/testing/mock_spanner.py b/google/cloud/spanner_v1/testing/mock_spanner.py index 6b50d9a6d1..f60dbbe72a 100644 --- a/google/cloud/spanner_v1/testing/mock_spanner.py +++ b/google/cloud/spanner_v1/testing/mock_spanner.py @@ -213,10 +213,19 @@ def __create_transaction( def Commit(self, request, context): self._requests.append(request) self.mock_spanner.pop_error(context) - tx = self.transactions[request.transaction_id] - if tx is None: - raise ValueError(f"Transaction not found: {request.transaction_id}") - del self.transactions[request.transaction_id] + if not request.transaction_id == b"": + tx = self.transactions[request.transaction_id] + if tx is None: + raise ValueError(f"Transaction not found: {request.transaction_id}") + tx_id = request.transaction_id + elif not request.single_use_transaction == TransactionOptions(): + tx = self.__create_transaction( + request.session, request.single_use_transaction + ) + tx_id = tx.id + else: + raise ValueError("Unsupported transaction type") + del self.transactions[tx_id] return commit.CommitResponse() def Rollback(self, request, context): diff --git a/tests/mockserver_tests/test_aborted_transaction.py b/tests/mockserver_tests/test_aborted_transaction.py index 89b30a0875..93eb42fe39 100644 --- a/tests/mockserver_tests/test_aborted_transaction.py +++ b/tests/mockserver_tests/test_aborted_transaction.py @@ -95,6 +95,30 @@ def test_run_in_transaction_batch_dml_aborted(self): self.assertTrue(isinstance(requests[2], ExecuteBatchDmlRequest)) self.assertTrue(isinstance(requests[3], CommitRequest)) + def test_batch_commit_aborted(self): + # Add an Aborted error for the Commit method on the mock server. + add_error(SpannerServicer.Commit.__name__, aborted_status()) + with self.database.batch() as batch: + batch.insert( + table="Singers", + columns=("SingerId", "FirstName", "LastName"), + values=[ + (1, "Marc", "Richards"), + (2, "Catalina", "Smith"), + (3, "Alice", "Trentor"), + (4, "Lea", "Martin"), + (5, "David", "Lomond"), + ], + ) + + # Verify that the transaction was retried. + requests = self.spanner_service.requests + self.assertEqual(3, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], CommitRequest)) + # The transaction is aborted and retried. + self.assertTrue(isinstance(requests[2], CommitRequest)) + def _insert_mutations(transaction: Transaction): transaction.insert("my_table", ["col1", "col2"], ["value1", "value2"]) diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index e62bff2a2e..ecc8018648 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -882,6 +882,66 @@ def test_check_rst_stream_error(self): self.assertEqual(test_api.test_fxn.call_count, 3) + def test_retry_on_aborted_exception_with_success_after_first_aborted_retry(self): + from google.api_core.exceptions import Aborted + import time + from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception + import functools + + test_api = mock.create_autospec(self.test_class) + test_api.test_fxn.side_effect = [ + Aborted("aborted exception", errors=("Aborted error")), + "true", + ] + deadline = time.time() + 30 + result_after_retry = _retry_on_aborted_exception( + functools.partial(test_api.test_fxn), deadline + ) + + self.assertEqual(test_api.test_fxn.call_count, 2) + self.assertTrue(result_after_retry) + + def test_retry_on_aborted_exception_with_success_after_three_retries(self): + from google.api_core.exceptions import Aborted + import time + from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception + import functools + + test_api = mock.create_autospec(self.test_class) + # Case where aborted exception is thrown after other generic exceptions + test_api.test_fxn.side_effect = [ + Aborted("aborted exception", errors=("Aborted error")), + Aborted("aborted exception", errors=("Aborted error")), + Aborted("aborted exception", errors=("Aborted error")), + "true", + ] + deadline = time.time() + 30 + _retry_on_aborted_exception( + functools.partial(test_api.test_fxn), + deadline=deadline, + ) + + self.assertEqual(test_api.test_fxn.call_count, 4) + + def test_retry_on_aborted_exception_raises_aborted_if_deadline_expires(self): + from google.api_core.exceptions import Aborted + import time + from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception + import functools + + test_api = mock.create_autospec(self.test_class) + test_api.test_fxn.side_effect = [ + Aborted("aborted exception", errors=("Aborted error")), + "true", + ] + deadline = time.time() + 0.1 + with self.assertRaises(Aborted): + _retry_on_aborted_exception( + functools.partial(test_api.test_fxn), deadline=deadline + ) + + self.assertEqual(test_api.test_fxn.call_count, 1) + class Test_metadata_with_leader_aware_routing(unittest.TestCase): def _call_fut(self, *args, **kw): diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index ea744f8889..eb5069b497 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -14,6 +14,7 @@ import unittest +from unittest.mock import MagicMock from tests._helpers import ( OpenTelemetryBase, StatusCode, @@ -265,6 +266,37 @@ def test_commit_ok(self): attributes=dict(BASE_ATTRIBUTES, num_mutations=1), ) + def test_aborted_exception_on_commit_with_retries(self): + # Test case to verify that an Aborted exception is raised when + # batch.commit() is called and the transaction is aborted internally. + from google.api_core.exceptions import Aborted + + database = _Database() + # Setup the spanner API which throws Aborted exception when calling commit API. + api = database.spanner_api = _FauxSpannerAPI(_aborted_error=True) + api.commit = MagicMock( + side_effect=Aborted("Transaction was aborted", errors=("Aborted error")) + ) + + # Create mock session and batch objects + session = _Session(database) + batch = self._make_one(session) + batch.insert(TABLE_NAME, COLUMNS, VALUES) + + # Assertion: Ensure that calling batch.commit() raises the Aborted exception + with self.assertRaises(Aborted) as context: + batch.commit() + + # Verify additional details about the exception + self.assertEqual(str(context.exception), "409 Transaction was aborted") + self.assertGreater( + api.commit.call_count, 1, "commit should be called more than once" + ) + # Since we are using exponential backoff here and default timeout is set to 30 sec 2^x <= 30. So value for x will be 4 + self.assertEqual( + api.commit.call_count, 4, "commit should be called exactly 4 times" + ) + def _test_commit_with_options( self, request_options=None, @@ -630,6 +662,7 @@ class _FauxSpannerAPI: _committed = None _batch_request = None _rpc_error = False + _aborted_error = False def __init__(self, **kwargs): self.__dict__.update(**kwargs) @@ -640,6 +673,7 @@ def commit( metadata=None, ): from google.api_core.exceptions import Unknown + from google.api_core.exceptions import Aborted max_commit_delay = None if type(request).pb(request).HasField("max_commit_delay"): @@ -656,6 +690,8 @@ def commit( ) if self._rpc_error: raise Unknown("error") + if self._aborted_error: + raise Aborted("Transaction was aborted", errors=("Aborted error")) return self._commit_response def batch_write( diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 6e29255fb7..13a37f66fe 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -1899,8 +1899,8 @@ def test_context_mgr_w_commit_stats_success(self): "CommitStats: mutation_count: 4\n", extra={"commit_stats": commit_stats} ) - def test_context_mgr_w_commit_stats_error(self): - from google.api_core.exceptions import Unknown + def test_context_mgr_w_aborted_commit_status(self): + from google.api_core.exceptions import Aborted from google.cloud.spanner_v1 import CommitRequest from google.cloud.spanner_v1 import TransactionOptions from google.cloud.spanner_v1.batch import Batch @@ -1908,13 +1908,13 @@ def test_context_mgr_w_commit_stats_error(self): database = _Database(self.DATABASE_NAME) database.log_commit_stats = True api = database.spanner_api = self._make_spanner_client() - api.commit.side_effect = Unknown("testing") + api.commit.side_effect = Aborted("aborted exception", errors=("Aborted error")) pool = database._pool = _Pool() session = _Session(database) pool.put(session) checkout = self._make_one(database) - with self.assertRaises(Unknown): + with self.assertRaises(Aborted): with checkout as batch: self.assertIsNone(pool._session) self.assertIsInstance(batch, Batch) @@ -1931,7 +1931,10 @@ def test_context_mgr_w_commit_stats_error(self): return_commit_stats=True, request_options=RequestOptions(), ) - api.commit.assert_called_once_with( + # Asserts that the exponential backoff retry for aborted transactions with a 30-second deadline + # allows for a maximum of 4 retries (2^x <= 30) to stay within the time limit. + self.assertEqual(api.commit.call_count, 4) + api.commit.assert_any_call( request=request, metadata=[ ("google-cloud-resource-prefix", database.name), diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0d60e98cd0..55c91435f8 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1911,7 +1911,7 @@ def unit_of_work(txn, *args, **kw): ) def test_delay_helper_w_no_delay(self): - from google.cloud.spanner_v1.session import _delay_until_retry + from google.cloud.spanner_v1._helpers import _delay_until_retry metadata_mock = mock.Mock() metadata_mock.trailing_metadata.return_value = {} @@ -1928,7 +1928,7 @@ def _time_func(): with mock.patch("time.time", _time_func): with mock.patch( - "google.cloud.spanner_v1.session._get_retry_delay" + "google.cloud.spanner_v1._helpers._get_retry_delay" ) as get_retry_delay_mock: with mock.patch("time.sleep") as sleep_mock: get_retry_delay_mock.return_value = None