Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: replace internal dictionaries with protos in gapic calls #875

Merged
merged 12 commits into from
Nov 22, 2023
22 changes: 16 additions & 6 deletions google/cloud/bigtable/data/_async/_mutate_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

from typing import TYPE_CHECKING
import asyncio
from dataclasses import dataclass
import functools

from google.api_core import exceptions as core_exceptions
from google.api_core import retry_async as retries
import google.cloud.bigtable_v2.types.bigtable as types_pb
import google.cloud.bigtable.data.exceptions as bt_exceptions
from google.cloud.bigtable.data._helpers import _make_metadata
from google.cloud.bigtable.data._helpers import _convert_retry_deadline
Expand All @@ -36,6 +38,16 @@
from google.cloud.bigtable.data._async.client import TableAsync


@dataclass
class _EntryWithProto:
"""
A dataclass to hold a RowMutationEntry and its corresponding proto representation.
"""

entry: RowMutationEntry
proto: types_pb.MutateRowsRequest.Entry


class _MutateRowsOperationAsync:
"""
MutateRowsOperation manages the logic of sending a set of row mutations,
Expand Down Expand Up @@ -104,7 +116,7 @@ def __init__(
self.timeout_generator = _attempt_timeout_generator(
attempt_timeout, operation_timeout
)
self.mutations = mutation_entries
self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we make the mutation classes immutable, we can simplify this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think making each individual mutation (ie SetCell, DeleteCell, etc) is a great idea
Howevr I dont think you can make the collection of mutations immutable (ie RowMutationEntry)

self.remaining_indices = list(range(len(self.mutations)))
self.errors: dict[int, list[Exception]] = {}

Expand Down Expand Up @@ -135,7 +147,7 @@ async def start(self):
cause_exc = exc_list[0]
else:
cause_exc = bt_exceptions.RetryExceptionGroup(exc_list)
entry = self.mutations[idx]
entry = self.mutations[idx].entry
all_errors.append(
bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc)
)
Expand All @@ -153,9 +165,7 @@ async def _run_attempt(self):
retry after the attempt is complete
- GoogleAPICallError: if the gapic rpc fails
"""
request_entries = [
self.mutations[idx]._to_dict() for idx in self.remaining_indices
]
request_entries = [self.mutations[idx].proto for idx in self.remaining_indices]
# track mutations in this request that have not been finalized yet
active_request_indices = {
req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices)
Expand Down Expand Up @@ -213,7 +223,7 @@ def _handle_entry_error(self, idx: int, exc: Exception):
- idx: the index of the mutation that failed
- exc: the exception to add to the list
"""
entry = self.mutations[idx]
entry = self.mutations[idx].entry
self.errors.setdefault(idx, []).append(exc)
if (
entry.is_idempotent()
Expand Down
59 changes: 26 additions & 33 deletions google/cloud/bigtable/data/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,24 +943,19 @@ async def mutate_row(
GoogleAPIError exceptions from any retries that failed
- GoogleAPIError: raised on non-idempotent operations that cannot be
safely retried.
- ValueError if invalid arguments are provided
"""
operation_timeout = operation_timeout or self.default_operation_timeout
attempt_timeout = (
attempt_timeout or self.default_attempt_timeout or operation_timeout
)
_validate_timeouts(operation_timeout, attempt_timeout)

if isinstance(row_key, str):
row_key = row_key.encode("utf-8")
request = {"table_name": self.table_name, "row_key": row_key}
if self.app_profile_id:
request["app_profile_id"] = self.app_profile_id
if not mutations:
raise ValueError("No mutations provided")
mutations_list = mutations if isinstance(mutations, list) else [mutations]

if isinstance(mutations, Mutation):
mutations = [mutations]
request["mutations"] = [mutation._to_dict() for mutation in mutations]

if all(mutation.is_idempotent() for mutation in mutations):
if all(mutation.is_idempotent() for mutation in mutations_list):
igorbernstein2 marked this conversation as resolved.
Show resolved Hide resolved
# mutations are all idempotent and safe to retry
predicate = retries.if_exception_type(
core_exceptions.DeadlineExceeded,
Expand Down Expand Up @@ -993,7 +988,13 @@ def on_error_fn(exc):
metadata = _make_metadata(self.table_name, self.app_profile_id)
# trigger rpc
await deadline_wrapped(
request, timeout=attempt_timeout, metadata=metadata, retry=None
row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
mutations=[mutation._to_pb() for mutation in mutations_list],
table_name=self.table_name,
app_profile_id=self.app_profile_id,
timeout=attempt_timeout,
metadata=metadata,
retry=None,
)

async def bulk_mutate_rows(
Expand Down Expand Up @@ -1030,6 +1031,7 @@ async def bulk_mutate_rows(
Raises:
- MutationsExceptionGroup if one or more mutations fails
Contains details about any failed entries in .exceptions
- ValueError if invalid arguments are provided
"""
operation_timeout = (
operation_timeout or self.default_mutate_rows_operation_timeout
Expand Down Expand Up @@ -1095,29 +1097,24 @@ async def check_and_mutate_row(
operation_timeout = operation_timeout or self.default_operation_timeout
if operation_timeout <= 0:
raise ValueError("operation_timeout must be greater than 0")
row_key = row_key.encode("utf-8") if isinstance(row_key, str) else row_key
if true_case_mutations is not None and not isinstance(
true_case_mutations, list
):
true_case_mutations = [true_case_mutations]
true_case_dict = [m._to_dict() for m in true_case_mutations or []]
true_case_list = [m._to_pb() for m in true_case_mutations or []]
if false_case_mutations is not None and not isinstance(
false_case_mutations, list
):
false_case_mutations = [false_case_mutations]
false_case_dict = [m._to_dict() for m in false_case_mutations or []]
false_case_list = [m._to_pb() for m in false_case_mutations or []]
metadata = _make_metadata(self.table_name, self.app_profile_id)
result = await self.client._gapic_client.check_and_mutate_row(
request={
"predicate_filter": predicate._to_dict()
if predicate is not None
else None,
"true_mutations": true_case_dict,
"false_mutations": false_case_dict,
"table_name": self.table_name,
"row_key": row_key,
"app_profile_id": self.app_profile_id,
},
true_mutations=true_case_list,
false_mutations=false_case_list,
predicate_filter=predicate._to_pb() if predicate is not None else None,
row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
table_name=self.table_name,
app_profile_id=self.app_profile_id,
metadata=metadata,
timeout=operation_timeout,
)
Expand Down Expand Up @@ -1152,25 +1149,21 @@ async def read_modify_write_row(
operation
Raises:
- GoogleAPIError exceptions from grpc call
- ValueError if invalid arguments are provided
"""
operation_timeout = operation_timeout or self.default_operation_timeout
row_key = row_key.encode("utf-8") if isinstance(row_key, str) else row_key
if operation_timeout <= 0:
raise ValueError("operation_timeout must be greater than 0")
if rules is not None and not isinstance(rules, list):
rules = [rules]
if not rules:
raise ValueError("rules must contain at least one item")
# concert to dict representation
rules_dict = [rule._to_dict() for rule in rules]
metadata = _make_metadata(self.table_name, self.app_profile_id)
result = await self.client._gapic_client.read_modify_write_row(
request={
"rules": rules_dict,
"table_name": self.table_name,
"row_key": row_key,
"app_profile_id": self.app_profile_id,
},
rules=[rule._to_pb() for rule in rules],
row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
table_name=self.table_name,
app_profile_id=self.app_profile_id,
metadata=metadata,
timeout=operation_timeout,
)
Expand Down
3 changes: 0 additions & 3 deletions google/cloud/bigtable/data/_async/mutations_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,6 @@ async def _execute_mutate_rows(
- list of FailedMutationEntryError objects for mutations that failed.
FailedMutationEntryError objects will not contain index information
"""
request = {"table_name": self._table.table_name}
if self._table.app_profile_id:
request["app_profile_id"] = self._table.app_profile_id
try:
operation = _MutateRowsOperationAsync(
self._table.client._gapic_client,
Expand Down
15 changes: 15 additions & 0 deletions google/cloud/bigtable/data/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
from abc import ABC, abstractmethod
from sys import getsizeof

import google.cloud.bigtable_v2.types.bigtable as types_pb
import google.cloud.bigtable_v2.types.data as data_pb

from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE


# special value for SetCell mutation timestamps. If set, server will assign a timestamp
_SERVER_SIDE_TIMESTAMP = -1

Expand All @@ -36,6 +39,12 @@ class Mutation(ABC):
def _to_dict(self) -> dict[str, Any]:
raise NotImplementedError

def _to_pb(self) -> data_pb.Mutation:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to double store all of the attributes? Can't we store all of the attributes in the proto directly?
Also can we drop _to_dict altogether?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the main downside is we'd need a lot more boilerplate setters/getters (we lose the simplicity of the dataclasses). And marshaling to/from the protos is more expensive, but I guess that should'nt be too much of an issue here. I see this is how we handled it in the ReadRowsQuery class, which is more complicated than these would be.

What do you think of making these immutable? That would simplify the setters/getters, and then we wouldn't have to worry about making a static copy before starting a mutate_rows operation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also can we drop _to_dict altogether?

I think so. We do use row._to_dict in the test proxy, but I don't think we'd need it for mutations

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideally the logic for converting from models to protos for the test proxy is self contained in the proxy. A side goal of test proxy was to be a rosetta stone for expressing the same concepts across every cleint impl

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I dont think you need to worry about someone mutating the model after passing it to the proxy...I dont think you need to be that defensive in your code

"""
Convert the mutation to protobuf
"""
return data_pb.Mutation(**self._to_dict())

def is_idempotent(self) -> bool:
"""
Check if the mutation is idempotent
Expand Down Expand Up @@ -221,6 +230,12 @@ def _to_dict(self) -> dict[str, Any]:
"mutations": [mutation._to_dict() for mutation in self.mutations],
}

def _to_pb(self) -> types_pb.MutateRowsRequest.Entry:
return types_pb.MutateRowsRequest.Entry(
row_key=self.row_key,
mutations=[mutation._to_pb() for mutation in self.mutations],
)

def is_idempotent(self) -> bool:
"""Check if the mutation is idempotent"""
return all(mutation.is_idempotent() for mutation in self.mutations)
Expand Down
11 changes: 8 additions & 3 deletions google/cloud/bigtable/data/read_modify_write_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import abc

import google.cloud.bigtable_v2.types.data as data_pb

# value must fit in 64-bit signed integer
_MAX_INCREMENT_VALUE = (1 << 63) - 1

Expand All @@ -29,9 +31,12 @@ def __init__(self, family: str, qualifier: bytes | str):
self.qualifier = qualifier

@abc.abstractmethod
def _to_dict(self):
def _to_dict(self) -> dict[str, str | bytes | int]:
raise NotImplementedError

def _to_pb(self) -> data_pb.ReadModifyWriteRule:
return data_pb.ReadModifyWriteRule(**self._to_dict())


class IncrementRule(ReadModifyWriteRule):
def __init__(self, family: str, qualifier: bytes | str, increment_amount: int = 1):
Expand All @@ -44,7 +49,7 @@ def __init__(self, family: str, qualifier: bytes | str, increment_amount: int =
super().__init__(family, qualifier)
self.increment_amount = increment_amount

def _to_dict(self):
def _to_dict(self) -> dict[str, str | bytes | int]:
return {
"family_name": self.family,
"column_qualifier": self.qualifier,
Expand All @@ -64,7 +69,7 @@ def __init__(self, family: str, qualifier: bytes | str, append_value: bytes | st
super().__init__(family, qualifier)
self.append_value = append_value

def _to_dict(self):
def _to_dict(self) -> dict[str, str | bytes | int]:
return {
"family_name": self.family,
"column_qualifier": self.qualifier,
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/data/_async/test__mutate_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_ctor(self):
"""
test that constructor sets all the attributes correctly
"""
from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto
from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete
from google.api_core.exceptions import DeadlineExceeded
from google.api_core.exceptions import ServiceUnavailable
Expand Down Expand Up @@ -102,7 +103,8 @@ def test_ctor(self):
assert str(table.table_name) in metadata[0][1]
assert str(table.app_profile_id) in metadata[0][1]
# entries should be passed down
assert instance.mutations == entries
entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries]
assert instance.mutations == entries_w_pb
# timeout_gen should generate per-attempt timeout
assert next(instance.timeout_generator) == attempt_timeout
# ensure predicate is set
Expand Down Expand Up @@ -305,7 +307,7 @@ async def test_run_attempt_single_entry_success(self):
assert mock_gapic_fn.call_count == 1
_, kwargs = mock_gapic_fn.call_args
assert kwargs["timeout"] == expected_timeout
assert kwargs["entries"] == [mutation._to_dict()]
assert kwargs["entries"] == [mutation._to_pb()]

@pytest.mark.asyncio
async def test_run_attempt_empty_request(self):
Expand Down
Loading
Loading