Skip to content

Commit

Permalink
Add 'Snapshot.begin' API method.
Browse files Browse the repository at this point in the history
- Valid only for multi-use snapshots.
- Raises if the snapshot already has a transaction ID.
  • Loading branch information
tseaver committed Jul 24, 2017
1 parent a5219a5 commit 230d715
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 4 deletions.
23 changes: 23 additions & 0 deletions spanner/google/cloud/spanner/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,26 @@ def _make_txn_selector(self):
return TransactionSelector(begin=options)
else:
return TransactionSelector(single_use=options)

def begin(self):
"""Begin a transaction on the database.
:rtype: bytes
:returns: the ID for the newly-begun transaction.
:raises: ValueError if the transaction is already begun, committed,
or rolled back.
"""
if not self._multi_use:
raise ValueError("Cannot call 'begin' single-use snapshots")

if self._transaction_id is not None:
raise ValueError("Transaction already begun")

database = self._session._database
api = database.spanner_api
options = _options_with_prefix(database.name)
txn_selector = self._make_txn_selector()
response = api.begin_transaction(
self._session.name, txn_selector.begin, options=options)
self._transaction_id = response.id
return self._transaction_id
100 changes: 96 additions & 4 deletions spanner/tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ class TestSnapshot(unittest.TestCase):
DATABASE_NAME = INSTANCE_NAME + '/databases/' + DATABASE_ID
SESSION_ID = 'session-id'
SESSION_NAME = DATABASE_NAME + '/sessions/' + SESSION_ID
TRANSACTION_ID = b'DEADBEEF'

def _getTargetClass(self):
from google.cloud.spanner.snapshot import Snapshot
Expand Down Expand Up @@ -493,12 +494,11 @@ def test_ctor_w_multi_use_and_exact_staleness(self):
self.assertTrue(snapshot._multi_use)

def test__make_txn_selector_w_transaction_id(self):
TXN_ID = b'DEADBEEF'
session = _Session()
snapshot = self._make_one(session)
snapshot._transaction_id = TXN_ID
snapshot._transaction_id = self.TRANSACTION_ID
selector = snapshot._make_txn_selector()
self.assertEqual(selector.id, TXN_ID)
self.assertEqual(selector.id, self.TRANSACTION_ID)

def test__make_txn_selector_strong(self):
session = _Session()
Expand Down Expand Up @@ -579,6 +579,90 @@ def test__make_txn_selector_w_exact_staleness_w_multi_use(self):
self.assertEqual(options.read_only.exact_staleness.seconds, 3)
self.assertEqual(options.read_only.exact_staleness.nanos, 123456000)

def test_begin_wo_multi_use(self):
session = _Session()
snapshot = self._make_one(session)
with self.assertRaises(ValueError):
snapshot.begin()

def test_begin_w_existing_txn_id(self):
session = _Session()
snapshot = self._make_one(session, multi_use=True)
snapshot._transaction_id = self.TRANSACTION_ID
with self.assertRaises(ValueError):
snapshot.begin()

def test_begin_w_gax_error(self):
from google.gax.errors import GaxError
from google.cloud._helpers import _pb_timestamp_to_datetime

database = _Database()
api = database.spanner_api = _FauxSpannerAPI(
_random_gax_error=True)
timestamp = self._makeTimestamp()
session = _Session(database)
snapshot = self._make_one(
session, read_timestamp=timestamp, multi_use=True)

with self.assertRaises(GaxError):
snapshot.begin()

session_id, txn_options, options = api._begun
self.assertEqual(session_id, session.name)
self.assertEqual(
_pb_timestamp_to_datetime(txn_options.read_only.read_timestamp),
timestamp)
self.assertEqual(options.kwargs['metadata'],
[('google-cloud-resource-prefix', database.name)])

def test_begin_ok_exact_staleness(self):
from google.cloud.proto.spanner.v1.transaction_pb2 import (
Transaction as TransactionPB)

transaction_pb = TransactionPB(id=self.TRANSACTION_ID)
database = _Database()
api = database.spanner_api = _FauxSpannerAPI(
_begin_transaction_response=transaction_pb)
duration = self._makeDuration(seconds=3, microseconds=123456)
session = _Session(database)
snapshot = self._make_one(
session, exact_staleness=duration, multi_use=True)

txn_id = snapshot.begin()

self.assertEqual(txn_id, self.TRANSACTION_ID)
self.assertEqual(snapshot._transaction_id, self.TRANSACTION_ID)

session_id, txn_options, options = api._begun
self.assertEqual(session_id, session.name)
read_only = txn_options.read_only
self.assertEqual(read_only.exact_staleness.seconds, 3)
self.assertEqual(read_only.exact_staleness.nanos, 123456000)
self.assertEqual(options.kwargs['metadata'],
[('google-cloud-resource-prefix', database.name)])

def test_begin_ok_exact_strong(self):
from google.cloud.proto.spanner.v1.transaction_pb2 import (
Transaction as TransactionPB)

transaction_pb = TransactionPB(id=self.TRANSACTION_ID)
database = _Database()
api = database.spanner_api = _FauxSpannerAPI(
_begin_transaction_response=transaction_pb)
session = _Session(database)
snapshot = self._make_one(session, multi_use=True)

txn_id = snapshot.begin()

self.assertEqual(txn_id, self.TRANSACTION_ID)
self.assertEqual(snapshot._transaction_id, self.TRANSACTION_ID)

session_id, txn_options, options = api._begun
self.assertEqual(session_id, session.name)
self.assertTrue(txn_options.read_only.strong)
self.assertEqual(options.kwargs['metadata'],
[('google-cloud-resource-prefix', database.name)])


class _Session(object):

Expand All @@ -593,7 +677,15 @@ class _Database(object):

class _FauxSpannerAPI(_GAXBaseAPI):

_read_with = None
_read_with = _begin = None

def begin_transaction(self, session, options_, options=None):
from google.gax.errors import GaxError

self._begun = (session, options_, options)
if self._random_gax_error:
raise GaxError('error')
return self._begin_transaction_response

# pylint: disable=too-many-arguments
def streaming_read(self, session, table, columns, key_set,
Expand Down

0 comments on commit 230d715

Please sign in to comment.