Skip to content

Commit

Permalink
feat: replace internal dictionaries with protos in gapic calls (#875)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche authored Nov 22, 2023
1 parent 94bfe66 commit 3ac80a9
Show file tree
Hide file tree
Showing 9 changed files with 253 additions and 101 deletions.
22 changes: 16 additions & 6 deletions google/cloud/bigtable/data/_async/_mutate_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

from typing import TYPE_CHECKING
import asyncio
from dataclasses import dataclass
import functools

from google.api_core import exceptions as core_exceptions
from google.api_core import retry_async as retries
import google.cloud.bigtable_v2.types.bigtable as types_pb
import google.cloud.bigtable.data.exceptions as bt_exceptions
from google.cloud.bigtable.data._helpers import _make_metadata
from google.cloud.bigtable.data._helpers import _convert_retry_deadline
Expand All @@ -36,6 +38,16 @@
from google.cloud.bigtable.data._async.client import TableAsync


@dataclass
class _EntryWithProto:
    """
    A dataclass to hold a RowMutationEntry and its corresponding proto representation.

    Built once per mutation in _MutateRowsOperationAsync.__init__, so the proto
    serialization is not repeated on each retry attempt; `entry` is kept for
    constructing FailedMutationEntryError and idempotency checks.
    """

    # the original user-supplied mutation entry
    entry: RowMutationEntry
    # pre-computed protobuf form of `entry`, sent in MutateRows requests
    proto: types_pb.MutateRowsRequest.Entry


class _MutateRowsOperationAsync:
"""
MutateRowsOperation manages the logic of sending a set of row mutations,
Expand Down Expand Up @@ -105,7 +117,7 @@ def __init__(
self.timeout_generator = _attempt_timeout_generator(
attempt_timeout, operation_timeout
)
self.mutations = mutation_entries
self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries]
self.remaining_indices = list(range(len(self.mutations)))
self.errors: dict[int, list[Exception]] = {}

Expand Down Expand Up @@ -136,7 +148,7 @@ async def start(self):
cause_exc = exc_list[0]
else:
cause_exc = bt_exceptions.RetryExceptionGroup(exc_list)
entry = self.mutations[idx]
entry = self.mutations[idx].entry
all_errors.append(
bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc)
)
Expand All @@ -154,9 +166,7 @@ async def _run_attempt(self):
retry after the attempt is complete
- GoogleAPICallError: if the gapic rpc fails
"""
request_entries = [
self.mutations[idx]._to_dict() for idx in self.remaining_indices
]
request_entries = [self.mutations[idx].proto for idx in self.remaining_indices]
# track mutations in this request that have not been finalized yet
active_request_indices = {
req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices)
Expand Down Expand Up @@ -214,7 +224,7 @@ def _handle_entry_error(self, idx: int, exc: Exception):
- idx: the index of the mutation that failed
- exc: the exception to add to the list
"""
entry = self.mutations[idx]
entry = self.mutations[idx].entry
self.errors.setdefault(idx, []).append(exc)
if (
entry.is_idempotent()
Expand Down
59 changes: 26 additions & 33 deletions google/cloud/bigtable/data/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,22 +924,17 @@ async def mutate_row(
GoogleAPIError exceptions from any retries that failed
- GoogleAPIError: raised on non-idempotent operations that cannot be
safely retried.
- ValueError if invalid arguments are provided
"""
operation_timeout, attempt_timeout = _get_timeouts(
operation_timeout, attempt_timeout, self
)

if isinstance(row_key, str):
row_key = row_key.encode("utf-8")
request = {"table_name": self.table_name, "row_key": row_key}
if self.app_profile_id:
request["app_profile_id"] = self.app_profile_id
if not mutations:
raise ValueError("No mutations provided")
mutations_list = mutations if isinstance(mutations, list) else [mutations]

if isinstance(mutations, Mutation):
mutations = [mutations]
request["mutations"] = [mutation._to_dict() for mutation in mutations]

if all(mutation.is_idempotent() for mutation in mutations):
if all(mutation.is_idempotent() for mutation in mutations_list):
# mutations are all idempotent and safe to retry
predicate = retries.if_exception_type(
core_exceptions.DeadlineExceeded,
Expand Down Expand Up @@ -972,7 +967,13 @@ def on_error_fn(exc):
metadata = _make_metadata(self.table_name, self.app_profile_id)
# trigger rpc
await deadline_wrapped(
request, timeout=attempt_timeout, metadata=metadata, retry=None
row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
mutations=[mutation._to_pb() for mutation in mutations_list],
table_name=self.table_name,
app_profile_id=self.app_profile_id,
timeout=attempt_timeout,
metadata=metadata,
retry=None,
)

async def bulk_mutate_rows(
Expand Down Expand Up @@ -1009,6 +1010,7 @@ async def bulk_mutate_rows(
Raises:
- MutationsExceptionGroup if one or more mutations fails
Contains details about any failed entries in .exceptions
- ValueError if invalid arguments are provided
"""
operation_timeout, attempt_timeout = _get_timeouts(
operation_timeout, attempt_timeout, self
Expand Down Expand Up @@ -1065,29 +1067,24 @@ async def check_and_mutate_row(
- GoogleAPIError exceptions from grpc call
"""
operation_timeout, _ = _get_timeouts(operation_timeout, None, self)
row_key = row_key.encode("utf-8") if isinstance(row_key, str) else row_key
if true_case_mutations is not None and not isinstance(
true_case_mutations, list
):
true_case_mutations = [true_case_mutations]
true_case_dict = [m._to_dict() for m in true_case_mutations or []]
true_case_list = [m._to_pb() for m in true_case_mutations or []]
if false_case_mutations is not None and not isinstance(
false_case_mutations, list
):
false_case_mutations = [false_case_mutations]
false_case_dict = [m._to_dict() for m in false_case_mutations or []]
false_case_list = [m._to_pb() for m in false_case_mutations or []]
metadata = _make_metadata(self.table_name, self.app_profile_id)
result = await self.client._gapic_client.check_and_mutate_row(
request={
"predicate_filter": predicate._to_dict()
if predicate is not None
else None,
"true_mutations": true_case_dict,
"false_mutations": false_case_dict,
"table_name": self.table_name,
"row_key": row_key,
"app_profile_id": self.app_profile_id,
},
true_mutations=true_case_list,
false_mutations=false_case_list,
predicate_filter=predicate._to_pb() if predicate is not None else None,
row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
table_name=self.table_name,
app_profile_id=self.app_profile_id,
metadata=metadata,
timeout=operation_timeout,
retry=None,
Expand Down Expand Up @@ -1123,25 +1120,21 @@ async def read_modify_write_row(
operation
Raises:
- GoogleAPIError exceptions from grpc call
- ValueError if invalid arguments are provided
"""
operation_timeout, _ = _get_timeouts(operation_timeout, None, self)
row_key = row_key.encode("utf-8") if isinstance(row_key, str) else row_key
if operation_timeout <= 0:
raise ValueError("operation_timeout must be greater than 0")
if rules is not None and not isinstance(rules, list):
rules = [rules]
if not rules:
raise ValueError("rules must contain at least one item")
# convert to dict representation
rules_dict = [rule._to_dict() for rule in rules]
metadata = _make_metadata(self.table_name, self.app_profile_id)
result = await self.client._gapic_client.read_modify_write_row(
request={
"rules": rules_dict,
"table_name": self.table_name,
"row_key": row_key,
"app_profile_id": self.app_profile_id,
},
rules=[rule._to_pb() for rule in rules],
row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
table_name=self.table_name,
app_profile_id=self.app_profile_id,
metadata=metadata,
timeout=operation_timeout,
retry=None,
Expand Down
3 changes: 0 additions & 3 deletions google/cloud/bigtable/data/_async/mutations_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,6 @@ async def _execute_mutate_rows(
- list of FailedMutationEntryError objects for mutations that failed.
FailedMutationEntryError objects will not contain index information
"""
request = {"table_name": self._table.table_name}
if self._table.app_profile_id:
request["app_profile_id"] = self._table.app_profile_id
try:
operation = _MutateRowsOperationAsync(
self._table.client._gapic_client,
Expand Down
15 changes: 15 additions & 0 deletions google/cloud/bigtable/data/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
from abc import ABC, abstractmethod
from sys import getsizeof

import google.cloud.bigtable_v2.types.bigtable as types_pb
import google.cloud.bigtable_v2.types.data as data_pb

from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE


# special value for SetCell mutation timestamps. If set, server will assign a timestamp
_SERVER_SIDE_TIMESTAMP = -1

Expand All @@ -36,6 +39,12 @@ class Mutation(ABC):
def _to_dict(self) -> dict[str, Any]:
raise NotImplementedError

def _to_pb(self) -> data_pb.Mutation:
    """
    Build the protobuf representation of this mutation.

    The dict produced by ``_to_dict`` uses keys matching the
    ``Mutation`` proto's field names, so it is expanded directly
    as constructor keyword arguments.
    """
    as_dict = self._to_dict()
    return data_pb.Mutation(**as_dict)

def is_idempotent(self) -> bool:
"""
Check if the mutation is idempotent
Expand Down Expand Up @@ -221,6 +230,12 @@ def _to_dict(self) -> dict[str, Any]:
"mutations": [mutation._to_dict() for mutation in self.mutations],
}

def _to_pb(self) -> types_pb.MutateRowsRequest.Entry:
    """Return this entry as a ``MutateRowsRequest.Entry`` proto."""
    mutation_protos = [mutation._to_pb() for mutation in self.mutations]
    return types_pb.MutateRowsRequest.Entry(
        row_key=self.row_key,
        mutations=mutation_protos,
    )

def is_idempotent(self) -> bool:
"""Check if the mutation is idempotent"""
return all(mutation.is_idempotent() for mutation in self.mutations)
Expand Down
11 changes: 8 additions & 3 deletions google/cloud/bigtable/data/read_modify_write_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import abc

import google.cloud.bigtable_v2.types.data as data_pb

# value must fit in 64-bit signed integer
_MAX_INCREMENT_VALUE = (1 << 63) - 1

Expand All @@ -29,9 +31,12 @@ def __init__(self, family: str, qualifier: bytes | str):
self.qualifier = qualifier

@abc.abstractmethod
def _to_dict(self):
def _to_dict(self) -> dict[str, str | bytes | int]:
raise NotImplementedError

def _to_pb(self) -> data_pb.ReadModifyWriteRule:
    """
    Build the ``ReadModifyWriteRule`` proto for this rule.

    Relies on subclasses' ``_to_dict`` keys matching the proto's
    field names.
    """
    fields = self._to_dict()
    return data_pb.ReadModifyWriteRule(**fields)


class IncrementRule(ReadModifyWriteRule):
def __init__(self, family: str, qualifier: bytes | str, increment_amount: int = 1):
Expand All @@ -44,7 +49,7 @@ def __init__(self, family: str, qualifier: bytes | str, increment_amount: int =
super().__init__(family, qualifier)
self.increment_amount = increment_amount

def _to_dict(self):
def _to_dict(self) -> dict[str, str | bytes | int]:
return {
"family_name": self.family,
"column_qualifier": self.qualifier,
Expand All @@ -64,7 +69,7 @@ def __init__(self, family: str, qualifier: bytes | str, append_value: bytes | st
super().__init__(family, qualifier)
self.append_value = append_value

def _to_dict(self):
def _to_dict(self) -> dict[str, str | bytes | int]:
return {
"family_name": self.family,
"column_qualifier": self.qualifier,
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/data/_async/test__mutate_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_ctor(self):
"""
test that constructor sets all the attributes correctly
"""
from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto
from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete
from google.api_core.exceptions import DeadlineExceeded
from google.api_core.exceptions import ServiceUnavailable
Expand Down Expand Up @@ -103,7 +104,8 @@ def test_ctor(self):
assert str(table.table_name) in metadata[0][1]
assert str(table.app_profile_id) in metadata[0][1]
# entries should be passed down
assert instance.mutations == entries
entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries]
assert instance.mutations == entries_w_pb
# timeout_gen should generate per-attempt timeout
assert next(instance.timeout_generator) == attempt_timeout
# ensure predicate is set
Expand Down Expand Up @@ -306,7 +308,7 @@ async def test_run_attempt_single_entry_success(self):
assert mock_gapic_fn.call_count == 1
_, kwargs = mock_gapic_fn.call_args
assert kwargs["timeout"] == expected_timeout
assert kwargs["entries"] == [mutation._to_dict()]
assert kwargs["entries"] == [mutation._to_pb()]

@pytest.mark.asyncio
async def test_run_attempt_empty_request(self):
Expand Down
Loading

0 comments on commit 3ac80a9

Please sign in to comment.