From 3ac80a958638d89281a54eaeddbef28e9d2aee87 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 22 Nov 2023 14:25:46 -0800 Subject: [PATCH] feat: replace internal dictionaries with protos in gapic calls (#875) --- .../bigtable/data/_async/_mutate_rows.py | 22 +++- google/cloud/bigtable/data/_async/client.py | 59 ++++----- .../bigtable/data/_async/mutations_batcher.py | 3 - google/cloud/bigtable/data/mutations.py | 15 +++ .../bigtable/data/read_modify_write_rules.py | 11 +- tests/unit/data/_async/test__mutate_rows.py | 6 +- tests/unit/data/_async/test_client.py | 112 +++++++++--------- tests/unit/data/test_mutations.py | 84 +++++++++++++ .../unit/data/test_read_modify_write_rules.py | 42 +++++++ 9 files changed, 253 insertions(+), 101 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index be84fac17..5bf759151 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -16,10 +16,12 @@ from typing import TYPE_CHECKING import asyncio +from dataclasses import dataclass import functools from google.api_core import exceptions as core_exceptions from google.api_core import retry_async as retries +import google.cloud.bigtable_v2.types.bigtable as types_pb import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _convert_retry_deadline @@ -36,6 +38,16 @@ from google.cloud.bigtable.data._async.client import TableAsync +@dataclass +class _EntryWithProto: + """ + A dataclass to hold a RowMutationEntry and its corresponding proto representation. + """ + + entry: RowMutationEntry + proto: types_pb.MutateRowsRequest.Entry + + class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, @@ -105,7 +117,7 @@ def __init__( self.timeout_generator = _attempt_timeout_generator( attempt_timeout, operation_timeout ) - self.mutations = mutation_entries + self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries] self.remaining_indices = list(range(len(self.mutations))) self.errors: dict[int, list[Exception]] = {} @@ -136,7 +148,7 @@ async def start(self): cause_exc = exc_list[0] else: cause_exc = bt_exceptions.RetryExceptionGroup(exc_list) - entry = self.mutations[idx] + entry = self.mutations[idx].entry all_errors.append( bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc) ) @@ -154,9 +166,7 @@ async def _run_attempt(self): retry after the attempt is complete - GoogleAPICallError: if the gapic rpc fails """ - request_entries = [ - self.mutations[idx]._to_dict() for idx in self.remaining_indices - ] + request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] # track mutations in this request that have not been finalized yet active_request_indices = { req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices) @@ -214,7 +224,7 @@ def _handle_entry_error(self, idx: int, exc: Exception): - idx: the index of the mutation that failed - exc: the exception to add to the list """ - entry = self.mutations[idx] + entry = self.mutations[idx].entry self.errors.setdefault(idx, []).append(exc) if ( entry.is_idempotent() diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 90939927e..ab8cc48f8 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -924,22 +924,17 @@ async def mutate_row( GoogleAPIError exceptions from any retries that failed - GoogleAPIError: raised on non-idempotent operations that cannot be safely retried. + - ValueError if invalid arguments are provided """ operation_timeout, attempt_timeout = _get_timeouts( operation_timeout, attempt_timeout, self ) - if isinstance(row_key, str): - row_key = row_key.encode("utf-8") - request = {"table_name": self.table_name, "row_key": row_key} - if self.app_profile_id: - request["app_profile_id"] = self.app_profile_id + if not mutations: + raise ValueError("No mutations provided") + mutations_list = mutations if isinstance(mutations, list) else [mutations] - if isinstance(mutations, Mutation): - mutations = [mutations] - request["mutations"] = [mutation._to_dict() for mutation in mutations] - - if all(mutation.is_idempotent() for mutation in mutations): + if all(mutation.is_idempotent() for mutation in mutations_list): # mutations are all idempotent and safe to retry predicate = retries.if_exception_type( core_exceptions.DeadlineExceeded, @@ -972,7 +967,13 @@ def on_error_fn(exc): metadata = _make_metadata(self.table_name, self.app_profile_id) # trigger rpc await deadline_wrapped( - request, timeout=attempt_timeout, metadata=metadata, retry=None + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + mutations=[mutation._to_pb() for mutation in mutations_list], + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=attempt_timeout, + metadata=metadata, + retry=None, ) async def bulk_mutate_rows( @@ -1009,6 +1010,7 @@ async def bulk_mutate_rows( Raises: - MutationsExceptionGroup if one or more mutations fails Contains details about any failed entries in .exceptions + - ValueError if invalid arguments are provided """ operation_timeout, attempt_timeout = _get_timeouts( operation_timeout, attempt_timeout, self @@ -1065,29 +1067,24 @@ async def check_and_mutate_row( - GoogleAPIError exceptions from grpc call """ operation_timeout, _ = _get_timeouts(operation_timeout, None, self) - row_key = row_key.encode("utf-8") if isinstance(row_key, str) else row_key if true_case_mutations is not None and not isinstance( true_case_mutations, list ): true_case_mutations = [true_case_mutations] - true_case_dict = [m._to_dict() for m in true_case_mutations or []] + true_case_list = [m._to_pb() for m in true_case_mutations or []] if false_case_mutations is not None and not isinstance( false_case_mutations, list ): false_case_mutations = [false_case_mutations] - false_case_dict = [m._to_dict() for m in false_case_mutations or []] + false_case_list = [m._to_pb() for m in false_case_mutations or []] metadata = _make_metadata(self.table_name, self.app_profile_id) result = await self.client._gapic_client.check_and_mutate_row( - request={ - "predicate_filter": predicate._to_dict() - if predicate is not None - else None, - "true_mutations": true_case_dict, - "false_mutations": false_case_dict, - "table_name": self.table_name, - "row_key": row_key, - "app_profile_id": self.app_profile_id, - }, + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, metadata=metadata, timeout=operation_timeout, retry=None, @@ -1123,25 +1120,21 @@ async def read_modify_write_row( operation Raises: - GoogleAPIError exceptions from grpc call + - ValueError if invalid arguments are provided """ operation_timeout, _ = _get_timeouts(operation_timeout, None, self) - row_key = row_key.encode("utf-8") if isinstance(row_key, str) else row_key if operation_timeout <= 0: raise ValueError("operation_timeout must be greater than 0") if rules is not None and not isinstance(rules, list): rules = [rules] if not rules: raise ValueError("rules must contain at least one item") - # concert to dict representation - rules_dict = [rule._to_dict() for rule in rules] metadata = _make_metadata(self.table_name, self.app_profile_id) result = await self.client._gapic_client.read_modify_write_row( - request={ - "rules": rules_dict, - "table_name": self.table_name, - "row_key": row_key, - "app_profile_id": self.app_profile_id, - }, + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, metadata=metadata, timeout=operation_timeout, retry=None, diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 7ff5f9a0b..91d2b11e1 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -342,9 +342,6 @@ async def _execute_mutate_rows( - list of FailedMutationEntryError objects for mutations that failed. FailedMutationEntryError objects will not contain index information """ - request = {"table_name": self._table.table_name} - if self._table.app_profile_id: - request["app_profile_id"] = self._table.app_profile_id try: operation = _MutateRowsOperationAsync( self._table.client._gapic_client, diff --git a/google/cloud/bigtable/data/mutations.py b/google/cloud/bigtable/data/mutations.py index 06db21879..b5729d25e 100644 --- a/google/cloud/bigtable/data/mutations.py +++ b/google/cloud/bigtable/data/mutations.py @@ -19,9 +19,12 @@ from abc import ABC, abstractmethod from sys import getsizeof +import google.cloud.bigtable_v2.types.bigtable as types_pb +import google.cloud.bigtable_v2.types.data as data_pb from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE + # special value for SetCell mutation timestamps. If set, server will assign a timestamp _SERVER_SIDE_TIMESTAMP = -1 @@ -36,6 +39,12 @@ class Mutation(ABC): def _to_dict(self) -> dict[str, Any]: raise NotImplementedError + def _to_pb(self) -> data_pb.Mutation: + """ + Convert the mutation to protobuf + """ + return data_pb.Mutation(**self._to_dict()) + def is_idempotent(self) -> bool: """ Check if the mutation is idempotent @@ -221,6 +230,12 @@ def _to_dict(self) -> dict[str, Any]: "mutations": [mutation._to_dict() for mutation in self.mutations], } + def _to_pb(self) -> types_pb.MutateRowsRequest.Entry: + return types_pb.MutateRowsRequest.Entry( + row_key=self.row_key, + mutations=[mutation._to_pb() for mutation in self.mutations], + ) + def is_idempotent(self) -> bool: """Check if the mutation is idempotent""" return all(mutation.is_idempotent() for mutation in self.mutations) diff --git a/google/cloud/bigtable/data/read_modify_write_rules.py b/google/cloud/bigtable/data/read_modify_write_rules.py index 3a3eb3752..f43dbe79f 100644 --- a/google/cloud/bigtable/data/read_modify_write_rules.py +++ b/google/cloud/bigtable/data/read_modify_write_rules.py @@ -16,6 +16,8 @@ import abc +import google.cloud.bigtable_v2.types.data as data_pb + # value must fit in 64-bit signed integer _MAX_INCREMENT_VALUE = (1 << 63) - 1 @@ -29,9 +31,12 @@ def __init__(self, family: str, qualifier: bytes | str): self.qualifier = qualifier @abc.abstractmethod - def _to_dict(self): + def _to_dict(self) -> dict[str, str | bytes | int]: raise NotImplementedError + def _to_pb(self) -> data_pb.ReadModifyWriteRule: + return data_pb.ReadModifyWriteRule(**self._to_dict()) + class IncrementRule(ReadModifyWriteRule): def __init__(self, family: str, qualifier: bytes | str, increment_amount: int = 1): @@ -44,7 +49,7 @@ def __init__(self, family: str, qualifier: bytes | str, increment_amount: int = super().__init__(family, qualifier) self.increment_amount = increment_amount - def _to_dict(self): + def _to_dict(self) -> dict[str, str | bytes | int]: return { "family_name": self.family, "column_qualifier": self.qualifier, @@ -64,7 +69,7 @@ def __init__(self, family: str, qualifier: bytes | str, append_value: bytes | st super().__init__(family, qualifier) self.append_value = append_value - def _to_dict(self): + def _to_dict(self) -> dict[str, str | bytes | int]: return { "family_name": self.family, "column_qualifier": self.qualifier, diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index eae3483ed..89a153af2 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -75,6 +75,7 @@ def test_ctor(self): """ test that constructor sets all the attributes correctly """ + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import ServiceUnavailable @@ -103,7 +104,8 @@ def test_ctor(self): assert str(table.table_name) in metadata[0][1] assert str(table.app_profile_id) in metadata[0][1] # entries should be passed down - assert instance.mutations == entries + entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries] + assert instance.mutations == entries_w_pb # timeout_gen should generate per-attempt timeout assert next(instance.timeout_generator) == attempt_timeout # ensure predicate is set @@ -306,7 +308,7 @@ async def test_run_attempt_single_entry_success(self): assert mock_gapic_fn.call_count == 1 _, kwargs = mock_gapic_fn.call_args assert kwargs["timeout"] == expected_timeout - assert kwargs["entries"] == [mutation._to_dict()] + assert kwargs["entries"] == [mutation._to_pb()] @pytest.mark.asyncio async def test_run_attempt_empty_request(self): diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 4ae46da6e..7afecc5b0 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -2032,18 +2032,17 @@ async def test_mutate_row(self, mutation_arg): ) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0].kwargs - request = mock_gapic.call_args[0][0] assert ( - request["table_name"] + kwargs["table_name"] == "projects/project/instances/instance/tables/table" ) - assert request["row_key"] == b"row_key" + assert kwargs["row_key"] == b"row_key" formatted_mutations = ( - [mutation._to_dict() for mutation in mutation_arg] + [mutation._to_pb() for mutation in mutation_arg] if isinstance(mutation_arg, list) - else [mutation_arg._to_dict()] + else [mutation_arg._to_pb()] ) - assert request["mutations"] == formatted_mutations + assert kwargs["mutations"] == formatted_mutations assert kwargs["timeout"] == expected_attempt_timeout # make sure gapic layer is not retrying assert kwargs["retry"] is None @@ -2146,7 +2145,7 @@ async def test_mutate_row_metadata(self, include_app_profile): with mock.patch.object( client._gapic_client, "mutate_row", AsyncMock() ) as read_rows: - await table.mutate_row("rk", {}) + await table.mutate_row("rk", mock.Mock()) kwargs = read_rows.call_args_list[0].kwargs metadata = kwargs["metadata"] goog_metadata = None @@ -2160,6 +2159,15 @@ async def test_mutate_row_metadata(self, include_app_profile): else: assert "app_profile_id=" not in goog_metadata + @pytest.mark.parametrize("mutations", [[], None]) + @pytest.mark.asyncio + async def test_mutate_row_no_mutations(self, mutations): + async with self._make_client() as client: + async with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + await table.mutate_row("key", mutations=mutations) + assert e.value.args[0] == "No mutations provided" + class TestBulkMutateRows: def _make_client(self, *args, **kwargs): @@ -2232,7 +2240,7 @@ async def test_bulk_mutate_rows(self, mutation_arg): kwargs["table_name"] == "projects/project/instances/instance/tables/table" ) - assert kwargs["entries"] == [bulk_mutation._to_dict()] + assert kwargs["entries"] == [bulk_mutation._to_pb()] assert kwargs["timeout"] == expected_attempt_timeout assert kwargs["retry"] is None @@ -2257,8 +2265,8 @@ async def test_bulk_mutate_rows_multiple_entries(self): kwargs["table_name"] == "projects/project/instances/instance/tables/table" ) - assert kwargs["entries"][0] == entry_1._to_dict() - assert kwargs["entries"][1] == entry_2._to_dict() + assert kwargs["entries"][0] == entry_1._to_pb() + assert kwargs["entries"][1] == entry_2._to_pb() @pytest.mark.asyncio @pytest.mark.parametrize( @@ -2587,17 +2595,16 @@ async def test_check_and_mutate(self, gapic_result): ) assert found == gapic_result kwargs = mock_gapic.call_args[1] - request = kwargs["request"] - assert request["table_name"] == table.table_name - assert request["row_key"] == row_key - assert request["predicate_filter"] == predicate - assert request["true_mutations"] == [ - m._to_dict() for m in true_mutations + assert kwargs["table_name"] == table.table_name + assert kwargs["row_key"] == row_key + assert kwargs["predicate_filter"] == predicate + assert kwargs["true_mutations"] == [ + m._to_pb() for m in true_mutations ] - assert request["false_mutations"] == [ - m._to_dict() for m in false_mutations + assert kwargs["false_mutations"] == [ + m._to_pb() for m in false_mutations ] - assert request["app_profile_id"] == app_profile + assert kwargs["app_profile_id"] == app_profile assert kwargs["timeout"] == operation_timeout assert kwargs["retry"] is None @@ -2655,9 +2662,8 @@ async def test_check_and_mutate_single_mutations(self): false_case_mutations=false_mutation, ) kwargs = mock_gapic.call_args[1] - request = kwargs["request"] - assert request["true_mutations"] == [true_mutation._to_dict()] - assert request["false_mutations"] == [false_mutation._to_dict()] + assert kwargs["true_mutations"] == [true_mutation._to_pb()] + assert kwargs["false_mutations"] == [false_mutation._to_pb()] @pytest.mark.asyncio async def test_check_and_mutate_predicate_object(self): @@ -2665,8 +2671,8 @@ async def test_check_and_mutate_predicate_object(self): from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse mock_predicate = mock.Mock() - predicate_dict = {"predicate": "dict"} - mock_predicate._to_dict.return_value = predicate_dict + predicate_pb = {"predicate": "dict"} + mock_predicate._to_pb.return_value = predicate_pb async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( @@ -2681,19 +2687,19 @@ async def test_check_and_mutate_predicate_object(self): false_case_mutations=[mock.Mock()], ) kwargs = mock_gapic.call_args[1] - assert kwargs["request"]["predicate_filter"] == predicate_dict - assert mock_predicate._to_dict.call_count == 1 + assert kwargs["predicate_filter"] == predicate_pb + assert mock_predicate._to_pb.call_count == 1 assert kwargs["retry"] is None @pytest.mark.asyncio async def test_check_and_mutate_mutations_parsing(self): - """mutations objects should be converted to dicts""" + """mutations objects should be converted to protos""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse from google.cloud.bigtable.data.mutations import DeleteAllFromRow mutations = [mock.Mock() for _ in range(5)] for idx, mutation in enumerate(mutations): - mutation._to_dict.return_value = {"fake": idx} + mutation._to_pb.return_value = f"fake {idx}" mutations.append(DeleteAllFromRow()) async with self._make_client() as client: async with client.get_table("instance", "table") as table: @@ -2709,16 +2715,16 @@ async def test_check_and_mutate_mutations_parsing(self): true_case_mutations=mutations[0:2], false_case_mutations=mutations[2:], ) - kwargs = mock_gapic.call_args[1]["request"] - assert kwargs["true_mutations"] == [{"fake": 0}, {"fake": 1}] + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == ["fake 0", "fake 1"] assert kwargs["false_mutations"] == [ - {"fake": 2}, - {"fake": 3}, - {"fake": 4}, - {"delete_from_row": {}}, + "fake 2", + "fake 3", + "fake 4", + DeleteAllFromRow()._to_pb(), ] assert all( - mutation._to_dict.call_count == 1 for mutation in mutations[:5] + mutation._to_pb.call_count == 1 for mutation in mutations[:5] ) @pytest.mark.parametrize("include_app_profile", [True, False]) @@ -2757,18 +2763,18 @@ def _make_client(self, *args, **kwargs): [ ( AppendValueRule("f", "c", b"1"), - [AppendValueRule("f", "c", b"1")._to_dict()], + [AppendValueRule("f", "c", b"1")._to_pb()], ), ( [AppendValueRule("f", "c", b"1")], - [AppendValueRule("f", "c", b"1")._to_dict()], + [AppendValueRule("f", "c", b"1")._to_pb()], ), - (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_dict()]), + (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_pb()]), ( [AppendValueRule("f", "c", b"1"), IncrementRule("f", "c", 1)], [ - AppendValueRule("f", "c", b"1")._to_dict(), - IncrementRule("f", "c", 1)._to_dict(), + AppendValueRule("f", "c", b"1")._to_pb(), + IncrementRule("f", "c", 1)._to_pb(), ], ), ], @@ -2786,7 +2792,7 @@ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules await table.read_modify_write_row("key", call_rules) assert mock_gapic.call_count == 1 found_kwargs = mock_gapic.call_args_list[0][1] - assert found_kwargs["request"]["rules"] == expected_rules + assert found_kwargs["rules"] == expected_rules assert found_kwargs["retry"] is None @pytest.mark.parametrize("rules", [[], None]) @@ -2811,15 +2817,14 @@ async def test_read_modify_write_call_defaults(self): ) as mock_gapic: await table.read_modify_write_row(row_key, mock.Mock()) assert mock_gapic.call_count == 1 - found_kwargs = mock_gapic.call_args_list[0][1] - request = found_kwargs["request"] + kwargs = mock_gapic.call_args_list[0][1] assert ( - request["table_name"] + kwargs["table_name"] == f"projects/{project}/instances/{instance}/tables/{table_id}" ) - assert request["app_profile_id"] is None - assert request["row_key"] == row_key.encode() - assert found_kwargs["timeout"] > 1 + assert kwargs["app_profile_id"] is None + assert kwargs["row_key"] == row_key.encode() + assert kwargs["timeout"] > 1 @pytest.mark.asyncio async def test_read_modify_write_call_overrides(self): @@ -2839,11 +2844,10 @@ async def test_read_modify_write_call_overrides(self): operation_timeout=expected_timeout, ) assert mock_gapic.call_count == 1 - found_kwargs = mock_gapic.call_args_list[0][1] - request = found_kwargs["request"] - assert request["app_profile_id"] is profile_id - assert request["row_key"] == row_key - assert found_kwargs["timeout"] == expected_timeout + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["app_profile_id"] is profile_id + assert kwargs["row_key"] == row_key + assert kwargs["timeout"] == expected_timeout @pytest.mark.asyncio async def test_read_modify_write_string_key(self): @@ -2855,8 +2859,8 @@ async def test_read_modify_write_string_key(self): ) as mock_gapic: await table.read_modify_write_row(row_key, mock.Mock()) assert mock_gapic.call_count == 1 - found_kwargs = mock_gapic.call_args_list[0][1] - assert found_kwargs["request"]["row_key"] == row_key.encode() + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["row_key"] == row_key.encode() @pytest.mark.asyncio async def test_read_modify_write_row_building(self): diff --git a/tests/unit/data/test_mutations.py b/tests/unit/data/test_mutations.py index 8680a8da9..485c86e42 100644 --- a/tests/unit/data/test_mutations.py +++ b/tests/unit/data/test_mutations.py @@ -307,6 +307,42 @@ def test__to_dict_server_timestamp(self): assert got_inner_dict["value"] == expected_value assert len(got_inner_dict.keys()) == 4 + def test__to_pb(self): + """ensure proto representation is as expected""" + import google.cloud.bigtable_v2.types.data as data_pb + + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + expected_value = b"test-value" + expected_timestamp = 123456789 + instance = self._make_one( + expected_family, expected_qualifier, expected_value, expected_timestamp + ) + got_pb = instance._to_pb() + assert isinstance(got_pb, data_pb.Mutation) + assert got_pb.set_cell.family_name == expected_family + assert got_pb.set_cell.column_qualifier == expected_qualifier + assert got_pb.set_cell.timestamp_micros == expected_timestamp + assert got_pb.set_cell.value == expected_value + + def test__to_pb_server_timestamp(self): + """test with server side timestamp -1 value""" + import google.cloud.bigtable_v2.types.data as data_pb + + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + expected_value = b"test-value" + expected_timestamp = -1 + instance = self._make_one( + expected_family, expected_qualifier, expected_value, expected_timestamp + ) + got_pb = instance._to_pb() + assert isinstance(got_pb, data_pb.Mutation) + assert got_pb.set_cell.family_name == expected_family + assert got_pb.set_cell.column_qualifier == expected_qualifier + assert got_pb.set_cell.timestamp_micros == expected_timestamp + assert got_pb.set_cell.value == expected_value + @pytest.mark.parametrize( "timestamp,expected_value", [ @@ -406,6 +442,18 @@ def test__to_dict(self, start, end): if end is not None: assert time_range_dict["end_timestamp_micros"] == end + def test__to_pb(self): + """ensure proto representation is as expected""" + import google.cloud.bigtable_v2.types.data as data_pb + + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + instance = self._make_one(expected_family, expected_qualifier) + got_pb = instance._to_pb() + assert isinstance(got_pb, data_pb.Mutation) + assert got_pb.delete_from_column.family_name == expected_family + assert got_pb.delete_from_column.column_qualifier == expected_qualifier + def test_is_idempotent(self): """is_idempotent is always true""" instance = self._make_one( @@ -445,6 +493,16 @@ def test__to_dict(self): assert len(got_inner_dict.keys()) == 1 assert got_inner_dict["family_name"] == expected_family + def test__to_pb(self): + """ensure proto representation is as expected""" + import google.cloud.bigtable_v2.types.data as data_pb + + expected_family = "test-family" + instance = self._make_one(expected_family) + got_pb = instance._to_pb() + assert isinstance(got_pb, data_pb.Mutation) + assert got_pb.delete_from_family.family_name == expected_family + def test_is_idempotent(self): """is_idempotent is always true""" instance = self._make_one("test-family") @@ -477,6 +535,15 @@ def test__to_dict(self): assert list(got_dict.keys()) == ["delete_from_row"] assert len(got_dict["delete_from_row"].keys()) == 0 + def test__to_pb(self): + """ensure proto representation is as expected""" + import google.cloud.bigtable_v2.types.data as data_pb + + instance = self._make_one() + got_pb = instance._to_pb() + assert isinstance(got_pb, data_pb.Mutation) + assert "delete_from_row" in str(got_pb) + def test_is_idempotent(self): """is_idempotent is always true""" instance = self._make_one() @@ -550,6 +617,23 @@ def test__to_dict(self): assert instance._to_dict() == expected_result assert mutation_mock._to_dict.call_count == n_mutations + def test__to_pb(self): + from google.cloud.bigtable_v2.types.bigtable import MutateRowsRequest + from google.cloud.bigtable_v2.types.data import Mutation + + expected_key = "row_key" + mutation_mock = mock.Mock() + n_mutations = 3 + expected_mutations = [mutation_mock for i in range(n_mutations)] + for mock_mutations in expected_mutations: + mock_mutations._to_pb.return_value = Mutation() + instance = self._make_one(expected_key, expected_mutations) + pb_result = instance._to_pb() + assert isinstance(pb_result, MutateRowsRequest.Entry) + assert pb_result.row_key == b"row_key" + assert pb_result.mutations == [Mutation()] * n_mutations + assert mutation_mock._to_pb.call_count == n_mutations + @pytest.mark.parametrize( "mutations,result", [ diff --git a/tests/unit/data/test_read_modify_write_rules.py b/tests/unit/data/test_read_modify_write_rules.py index aeb41f19c..1f67da13b 100644 --- a/tests/unit/data/test_read_modify_write_rules.py +++ b/tests/unit/data/test_read_modify_write_rules.py @@ -36,6 +36,9 @@ def test_abstract(self): self._target_class()(family="foo", qualifier=b"bar") def test__to_dict(self): + """ + to_dict not implemented in base class + """ with pytest.raises(NotImplementedError): self._target_class()._to_dict(mock.Mock()) @@ -97,6 +100,27 @@ def test__to_dict(self, args, expected): } assert instance._to_dict() == expected + @pytest.mark.parametrize( + "args,expected", + [ + (("fam", b"qual", 1), ("fam", b"qual", 1)), + (("fam", b"qual", -12), ("fam", b"qual", -12)), + (("fam", "qual", 1), ("fam", b"qual", 1)), + (("fam", "qual", 0), ("fam", b"qual", 0)), + (("", "", 0), ("", b"", 0)), + (("f", b"q"), ("f", b"q", 1)), + ], + ) + def test__to_pb(self, args, expected): + import google.cloud.bigtable_v2.types.data as data_pb + + instance = self._target_class()(*args) + pb_result = instance._to_pb() + assert isinstance(pb_result, data_pb.ReadModifyWriteRule) + assert pb_result.family_name == expected[0] + assert pb_result.column_qualifier == expected[1] + assert pb_result.increment_amount == expected[2] + class TestAppendValueRule: def _target_class(self): @@ -142,3 +166,21 @@ def test__to_dict(self, args, expected): "append_value": expected[2], } assert instance._to_dict() == expected + + @pytest.mark.parametrize( + "args,expected", + [ + (("fam", b"qual", b"val"), ("fam", b"qual", b"val")), + (("fam", "qual", b"val"), ("fam", b"qual", b"val")), + (("", "", b""), ("", b"", b"")), + ], + ) + def test__to_pb(self, args, expected): + import google.cloud.bigtable_v2.types.data as data_pb + + instance = self._target_class()(*args) + pb_result = instance._to_pb() + assert isinstance(pb_result, data_pb.ReadModifyWriteRule) + assert pb_result.family_name == expected[0] + assert pb_result.column_qualifier == expected[1] + assert pb_result.append_value == expected[2]