Skip to content

Commit

Permalink
Merge pull request #978 from tseaver/944-batch_transaction_query_hold…
Browse files Browse the repository at this point in the history
…_client

Update batch/transaction/query to hold client.
  • Loading branch information
tseaver committed Jul 11, 2015
2 parents 681105d + d556b63 commit 3c76dbf
Show file tree
Hide file tree
Showing 8 changed files with 337 additions and 419 deletions.
55 changes: 22 additions & 33 deletions gcloud/datastore/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,11 @@
https://cloud.google.com/datastore/docs/concepts/entities#Datastore_Batch_operations
"""

from gcloud._helpers import _LocalStack
from gcloud.datastore import _implicit_environ
from gcloud.datastore import helpers
from gcloud.datastore.key import _dataset_ids_equal
from gcloud.datastore import _datastore_v1_pb2 as datastore_pb


_BATCHES = _LocalStack()


class Batch(object):
"""An abstraction representing a collected group of updates / deletes.
Expand Down Expand Up @@ -62,34 +57,19 @@ class Batch(object):
... do_some_work(batch)
... raise Exception() # rolls back
:type dataset_id: :class:`str`.
:param dataset_id: The ID of the dataset.
:type connection: :class:`gcloud.datastore.connection.Connection`
:param connection: The connection used to connect to datastore.
:raises: :class:`ValueError` if either a connection or dataset ID
are not set.
:type client: :class:`gcloud.datastore.client.Client`
:param client: The client used to connect to datastore.
"""
_id = None # "protected" attribute, always None for non-transactions

def __init__(self, dataset_id=None, connection=None):
self._connection = (connection or
_implicit_environ.get_default_connection())
self._dataset_id = (dataset_id or
_implicit_environ.get_default_dataset_id())

if self._connection is None or self._dataset_id is None:
raise ValueError('A batch must have a connection and '
'a dataset ID set.')

def __init__(self, client):
self._client = client
self._mutation = datastore_pb.Mutation()
self._auto_id_entities = []

@staticmethod
def current():
def current(self):
"""Return the topmost batch / transaction, or None."""
return _BATCHES.top
return self._client.current_batch

@property
def dataset_id(self):
Expand All @@ -98,7 +78,16 @@ def dataset_id(self):
:rtype: :class:`str`
:returns: The dataset ID in which the batch will run.
"""
return self._dataset_id
return self._client.dataset_id

@property
def namespace(self):
"""Getter for namespace in which the batch will run.
:rtype: :class:`str`
:returns: The namespace in which the batch will run.
"""
return self._client.namespace

@property
def connection(self):
Expand All @@ -107,7 +96,7 @@ def connection(self):
:rtype: :class:`gcloud.datastore.connection.Connection`
:returns: The connection over which the batch will run.
"""
return self._connection
return self._client.connection

@property
def mutation(self):
Expand Down Expand Up @@ -172,7 +161,7 @@ def put(self, entity):
if entity.key is None:
raise ValueError("Entity must have a key")

if not _dataset_ids_equal(self._dataset_id, entity.key.dataset_id):
if not _dataset_ids_equal(self.dataset_id, entity.key.dataset_id):
raise ValueError("Key must be from same dataset as batch")

_assign_entity_to_mutation(
Expand All @@ -190,7 +179,7 @@ def delete(self, key):
if key.is_partial:
raise ValueError("Key must be complete")

if not _dataset_ids_equal(self._dataset_id, key.dataset_id):
if not _dataset_ids_equal(self.dataset_id, key.dataset_id):
raise ValueError("Key must be from same dataset as batch")

key_pb = helpers._prepare_key_for_request(key.to_protobuf())
Expand All @@ -211,7 +200,7 @@ def commit(self):
context manager.
"""
response = self.connection.commit(
self._dataset_id, self.mutation, self._id)
self.dataset_id, self.mutation, self._id)
# If the back-end returns without error, we are guaranteed that
# the response's 'insert_auto_id_key' will match (length and order)
# the request's 'insert_auto_id` entities, which are derived from
Expand All @@ -229,7 +218,7 @@ def rollback(self):
pass

def __enter__(self):
_BATCHES.push(self)
self._client._push_batch(self)
self.begin()
return self

Expand All @@ -240,7 +229,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
else:
self.rollback()
finally:
_BATCHES.pop()
self._client._pop_batch()


def _assign_entity_to_mutation(mutation_pb, entity, auto_id_entities):
Expand Down
32 changes: 22 additions & 10 deletions gcloud/datastore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,19 @@ def current_batch(self):
"""
return self._batch_stack.top

@property
def current_transaction(self):
"""Currently-active transaction.
:rtype: :class:`gcloud.datastore.transaction.Transaction`, or an object
implementing its API, or ``NoneType`` (if no transaction is
active).
:returns: The transaction at the toop of the batch stack.
"""
transaction = self.current_batch
if isinstance(transaction, Transaction):
return transaction

def get(self, key, missing=None, deferred=None):
"""Retrieve an entity from a single key (if it exists).
Expand Down Expand Up @@ -222,7 +235,7 @@ def get_multi(self, keys, missing=None, deferred=None):
if ids != [self.dataset_id]:
raise ValueError('Keys do not match dataset ID')

transaction = Transaction.current()
transaction = self.current_transaction

entity_pbs = _extended_lookup(
connection=self.connection,
Expand Down Expand Up @@ -274,12 +287,12 @@ def put_multi(self, entities):
if not entities:
return

current = Batch.current()
current = self.current_batch
in_batch = current is not None

if not in_batch:
current = Batch(dataset_id=self.dataset_id,
connection=self.connection)
current = self.batch()

for entity in entities:
current.put(entity)

Expand Down Expand Up @@ -310,12 +323,12 @@ def delete_multi(self, keys):
return

# We allow partial keys to attempt a delete, the backend will fail.
current = Batch.current()
current = self.current_batch
in_batch = current is not None

if not in_batch:
current = Batch(dataset_id=self.dataset_id,
connection=self.connection)
current = self.batch()

for key in keys:
current.delete(key)

Expand Down Expand Up @@ -368,15 +381,14 @@ def batch(self):
Passes our ``dataset_id``.
"""
return Batch(dataset_id=self.dataset_id, connection=self.connection)
return Batch(self)

def transaction(self):
"""Proxy to :class:`gcloud.datastore.transaction.Transaction`.
Passes our ``dataset_id``.
"""
return Transaction(dataset_id=self.dataset_id,
connection=self.connection)
return Transaction(self)

def query(self, **kwargs):
"""Proxy to :class:`gcloud.datastore.query.Query`.
Expand Down
47 changes: 21 additions & 26 deletions gcloud/datastore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
import base64

from gcloud._helpers import _ensure_tuple_or_list
from gcloud.datastore import _implicit_environ
from gcloud.datastore import _datastore_v1_pb2 as datastore_pb
from gcloud.datastore import helpers
from gcloud.datastore.key import Key
from gcloud.datastore.transaction import Transaction


class Query(object):
Expand All @@ -30,15 +28,19 @@ class Query(object):
This class serves as an abstraction for creating a query over data
stored in the Cloud Datastore.
:type client: :class:`gcloud.datastore.client.Client`
:param client: The client used to connect to datastore.
:type kind: string
:param kind: The kind to query.
:type dataset_id: string
:param dataset_id: The ID of the dataset to query. If not passed,
uses the implicit default.
uses the client's value.
:type namespace: string or None
:param namespace: The namespace to which to restrict results.
:param namespace: The namespace to which to restrict results. If not
passed, uses the client's value.
:type ancestor: :class:`gcloud.datastore.key.Key` or None
:param ancestor: key of the ancestor to which this query's results are
Expand Down Expand Up @@ -71,6 +73,7 @@ class Query(object):
"""Mapping of operator strings and their protobuf equivalents."""

def __init__(self,
client,
kind=None,
dataset_id=None,
namespace=None,
Expand All @@ -80,15 +83,10 @@ def __init__(self,
order=(),
group_by=()):

if dataset_id is None:
dataset_id = _implicit_environ.get_default_dataset_id()

if dataset_id is None:
raise ValueError("No dataset ID supplied, and no default set.")

self._dataset_id = dataset_id
self._client = client
self._kind = kind
self._namespace = namespace
self._dataset_id = dataset_id or client.dataset_id
self._namespace = namespace or client.namespace
self._ancestor = ancestor
self._filters = []
# Verify filters passed in.
Expand Down Expand Up @@ -294,7 +292,7 @@ def group_by(self, value):
self._group_by[:] = value

def fetch(self, limit=None, offset=0, start_cursor=None, end_cursor=None,
connection=None):
client=None):
"""Execute the Query; return an iterator for the matching entities.
For example::
Expand All @@ -319,22 +317,19 @@ def fetch(self, limit=None, offset=0, start_cursor=None, end_cursor=None,
:type end_cursor: bytes
:param end_cursor: An optional cursor passed through to the iterator.
:type connection: :class:`gcloud.datastore.connection.Connection`
:param connection: An optional cursor passed through to the iterator.
If not supplied, uses the implicit default.
:type client: :class:`gcloud.datastore.client.Client`
:param client: client used to connect to datastore.
If not supplied, uses the query's value.
:rtype: :class:`Iterator`
:raises: ValueError if ``connection`` is not passed and no implicit
default has been set.
"""
if connection is None:
connection = _implicit_environ.get_default_connection()

if connection is None:
raise ValueError("No connection passed, and no default set")
if client is None:
client = self._client

return Iterator(
self, connection, limit, offset, start_cursor, end_cursor)
self, client, limit, offset, start_cursor, end_cursor)


class Iterator(object):
Expand All @@ -347,10 +342,10 @@ class Iterator(object):
datastore_pb.QueryResultBatch.MORE_RESULTS_AFTER_LIMIT,
)

def __init__(self, query, connection, limit=None, offset=0,
def __init__(self, query, client, limit=None, offset=0,
start_cursor=None, end_cursor=None):
self._query = query
self._connection = connection
self._client = client
self._limit = limit
self._offset = offset
self._start_cursor = start_cursor
Expand Down Expand Up @@ -380,9 +375,9 @@ def next_page(self):

pb.offset = self._offset

transaction = Transaction.current()
transaction = self._client.current_transaction

query_results = self._connection.run_query(
query_results = self._client.connection.run_query(
query_pb=pb,
dataset_id=self._query.dataset_id,
namespace=self._query.namespace,
Expand Down
Loading

0 comments on commit 3c76dbf

Please sign in to comment.