diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index beb3e46edb..d99c4fde2f 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -308,7 +308,7 @@ def _make_params_pb(params, param_types): :raises ValueError: If ``params`` is None but ``param_types`` is not None. """ - if params is not None: + if params: return Struct( fields={key: _make_value_pb(value) for key, value in params.items()} ) diff --git a/tests/mockserver_tests/test_basics.py b/tests/mockserver_tests/test_basics.py index 12a224314f..9d6dad095e 100644 --- a/tests/mockserver_tests/test_basics.py +++ b/tests/mockserver_tests/test_basics.py @@ -15,6 +15,8 @@ import unittest from google.cloud.spanner_admin_database_v1.types import spanner_database_admin +from google.cloud.spanner_dbapi import Connection +from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer from google.cloud.spanner_v1.testing.mock_spanner import ( start_mock_server, @@ -29,6 +31,8 @@ FixedSizePool, BatchCreateSessionsRequest, ExecuteSqlRequest, + BeginTransactionRequest, + TransactionOptions, ) from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.instance import Instance @@ -62,6 +66,10 @@ def tearDownClass(cls): TestBasics.server.stop(grace=None) TestBasics.server = None + def teardown_method(self, *args, **kwargs): + TestBasics.spanner_service.clear_requests() + TestBasics.database_admin_service.clear_requests() + def _add_select1_result(self): result = result_set.ResultSet( dict( @@ -88,6 +96,19 @@ def _add_select1_result(self): result.rows.extend(["1"]) TestBasics.spanner_service.mock_spanner.add_result("select 1", result) + def add_update_count( + self, + sql: str, + count: int, + dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL, + ): + if dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC: + stats = dict(row_count_lower_bound=count) + else: + stats = dict(row_count_exact=count) + result = result_set.ResultSet(dict(stats=result_set.ResultSetStats(stats))) + TestBasics.spanner_service.mock_spanner.add_result(sql, result) + @property def client(self) -> Client: if self._client is None: @@ -145,3 +166,27 @@ def test_create_table(self): ) operation = database_admin_api.update_database_ddl(request) operation.result(1) + + # TODO: Move this to a separate class once the mock server test setup has + # been re-factored to use a base class for the boiler plate code. + def test_dbapi_partitioned_dml(self): + sql = "UPDATE singers SET foo='bar' WHERE active = true" + self.add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC) + connection = Connection(self.instance, self.database) + connection.autocommit = True + connection.set_autocommit_dml_mode(AutocommitDmlMode.PARTITIONED_NON_ATOMIC) + with connection.cursor() as cursor: + # Note: SQLAlchemy uses [] as the list of parameters for statements + # with no parameters. + cursor.execute(sql, []) + self.assertEqual(100, cursor.rowcount) + + requests = self.spanner_service.requests + self.assertEqual(3, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], BeginTransactionRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + begin_request: BeginTransactionRequest = requests[1] + self.assertEqual( + TransactionOptions(dict(partitioned_dml={})), begin_request.options + )