diff --git a/google/cloud/bigtable/client.py b/google/cloud/bigtable/client.py
index 3921d6640..0544bcb78 100644
--- a/google/cloud/bigtable/client.py
+++ b/google/cloud/bigtable/client.py
@@ -55,6 +55,10 @@
 from google.cloud.bigtable._helpers import _make_metadata
 from google.cloud.bigtable._helpers import _convert_retry_deadline
 
+from google.cloud.bigtable.row_filters import StripValueTransformerFilter
+from google.cloud.bigtable.row_filters import CellsRowLimitFilter
+from google.cloud.bigtable.row_filters import RowFilterChain
+
 if TYPE_CHECKING:
     from google.cloud.bigtable.mutations_batcher import MutationsBatcher
     from google.cloud.bigtable import RowKeySamples
@@ -500,18 +504,31 @@ async def read_row(
         self,
         row_key: str | bytes,
         *,
+        row_filter: RowFilter | None = None,
         operation_timeout: int | float | None = 60,
         per_request_timeout: int | float | None = None,
-    ) -> Row:
+    ) -> Row | None:
         """
         Helper function to return a single row
 
         See read_rows_stream
 
+        Raises:
+            - ValueError: if row_key is not a string or bytes
         Returns:
-            - the individual row requested
+            - the individual row requested, or None if it does not exist
         """
-        raise NotImplementedError
+        if row_key is None:
+            raise ValueError("row_key must be string or bytes")
+        query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1)
+        results = await self.read_rows(
+            query,
+            operation_timeout=operation_timeout,
+            per_request_timeout=per_request_timeout,
+        )
+        if len(results) == 0:
+            return None
+        return results[0]
 
     async def read_rows_sharded(
         self,
@@ -547,7 +564,18 @@ async def row_exists(
         Returns:
             - a bool indicating whether the row exists
         """
-        raise NotImplementedError
+        if row_key is None:
+            raise ValueError("row_key must be string or bytes")
+        strip_filter = StripValueTransformerFilter(flag=True)
+        limit_filter = CellsRowLimitFilter(1)
+        chain_filter = RowFilterChain(filters=[limit_filter, strip_filter])
+        query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter)
+        results = await self.read_rows(
+            query,
+            operation_timeout=operation_timeout,
+            per_request_timeout=per_request_timeout,
+        )
+        return len(results) > 0
 
     async def sample_keys(
         self,
diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py
index e26f99d34..6de84e918 100644
--- a/google/cloud/bigtable/read_rows_query.py
+++ b/google/cloud/bigtable/read_rows_query.py
@@ -106,12 +106,12 @@ def __init__(
         """
         self.row_keys: set[bytes] = set()
         self.row_ranges: list[RowRange | dict[str, bytes]] = []
-        if row_ranges:
+        if row_ranges is not None:
             if isinstance(row_ranges, RowRange):
                 row_ranges = [row_ranges]
             for r in row_ranges:
                 self.add_range(r)
-        if row_keys:
+        if row_keys is not None:
             if not isinstance(row_keys, list):
                 row_keys = [row_keys]
             for k in row_keys:
@@ -221,7 +221,11 @@ def _to_dict(self) -> dict[str, Any]:
             row_ranges.append(dict_range)
         row_keys = list(self.row_keys)
         row_keys.sort()
-        row_set = {"row_keys": row_keys, "row_ranges": row_ranges}
+        row_set: dict[str, Any] = {}
+        if row_keys:
+            row_set["row_keys"] = row_keys
+        if row_ranges:
+            row_set["row_ranges"] = row_ranges
         final_dict: dict[str, Any] = {
             "rows": row_set,
         }
diff --git a/tests/system/test_system.py b/tests/system/test_system.py
index f0fab7d45..f6730576d 100644
--- a/tests/system/test_system.py
+++ b/tests/system/test_system.py
@@ -158,7 +158,7 @@ def __init__(self, table):
         self.table = table
 
     async def add_row(
-        self, row_key, family=TEST_FAMILY, qualifier=b"q", value=b"test-value"
+        self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value"
     ):
         if isinstance(value, str):
             value = value.encode("utf-8")
@@ -339,9 +339,9 @@ async def test_read_rows_range_query(table, temp_rows):
 
 @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
 @pytest.mark.asyncio
-async def test_read_rows_key_query(table, temp_rows):
+async def test_read_rows_single_key_query(table, temp_rows):
     """
-    Ensure that the read_rows method works
+    Ensure that the read_rows method works with a specified query
     """
     from google.cloud.bigtable import ReadRowsQuery
 
@@ -349,7 +349,7 @@ async def test_read_rows_single_key_query(table, temp_rows):
     await temp_rows.add_row(b"b")
     await temp_rows.add_row(b"c")
     await temp_rows.add_row(b"d")
-    # full table scan
+    # retrieve specific keys
     query = ReadRowsQuery(row_keys=[b"a", b"c"])
     row_list = await table.read_rows(query)
     assert len(row_list) == 2
@@ -357,6 +357,29 @@ async def test_read_rows_single_key_query(table, temp_rows):
     assert row_list[0].row_key == b"a"
     assert row_list[1].row_key == b"c"
 
 
+@retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+@pytest.mark.asyncio
+async def test_read_rows_with_filter(table, temp_rows):
+    """
+    Ensure that row filters are applied to the returned rows
+    """
+    from google.cloud.bigtable import ReadRowsQuery
+    from google.cloud.bigtable.row_filters import ApplyLabelFilter
+
+    await temp_rows.add_row(b"a")
+    await temp_rows.add_row(b"b")
+    await temp_rows.add_row(b"c")
+    await temp_rows.add_row(b"d")
+    # retrieve keys with filter
+    expected_label = "test-label"
+    row_filter = ApplyLabelFilter(expected_label)
+    query = ReadRowsQuery(row_filter=row_filter)
+    row_list = await table.read_rows(query)
+    assert len(row_list) == 4
+    for row in row_list:
+        assert row[0].labels == [expected_label]
+
+
 @pytest.mark.asyncio
 async def test_read_rows_stream_close(table, temp_rows):
@@ -397,6 +420,72 @@ async def test_read_rows_stream_inactive_timer(table, temp_rows):
     assert "idle_timeout=0.1" in str(e)
 
 
+@pytest.mark.asyncio
+async def test_read_row(table, temp_rows):
+    """
+    Test read_row (single row helper)
+    """
+    from google.cloud.bigtable import Row
+
+    await temp_rows.add_row(b"row_key_1", value=b"value")
+    row = await table.read_row(b"row_key_1")
+    assert isinstance(row, Row)
+    assert row.row_key == b"row_key_1"
+    assert row.cells[0].value == b"value"
+
+
+@pytest.mark.asyncio
+async def test_read_row_missing(table):
+    """
+    Test read_row when the row does not exist
+    """
+    from google.api_core import exceptions
+
+    row_key = "row_key_not_exist"
+    result = await table.read_row(row_key)
+    assert result is None
+    with pytest.raises(exceptions.InvalidArgument) as e:
+        await table.read_row("")
+    assert "Row key must be non-empty" in str(e)
+
+
+@pytest.mark.asyncio
+async def test_read_row_w_filter(table, temp_rows):
+    """
+    Test read_row with a row filter applied
+    """
+    from google.cloud.bigtable import Row
+    from google.cloud.bigtable.row_filters import ApplyLabelFilter
+
+    await temp_rows.add_row(b"row_key_1", value=b"value")
+    expected_label = "test-label"
+    label_filter = ApplyLabelFilter(expected_label)
+    row = await table.read_row(b"row_key_1", row_filter=label_filter)
+    assert isinstance(row, Row)
+    assert row.row_key == b"row_key_1"
+    assert row.cells[0].value == b"value"
+    assert row.cells[0].labels == [expected_label]
+
+
+@pytest.mark.asyncio
+async def test_row_exists(table, temp_rows):
+    """Test row_exists with rows that exist and don't exist"""
+    from google.api_core import exceptions
+
+    assert await table.row_exists(b"row_key_1") is False
+    await temp_rows.add_row(b"row_key_1")
+    assert await table.row_exists(b"row_key_1") is True
+    assert await table.row_exists("row_key_1") is True
+    assert await table.row_exists(b"row_key_2") is False
+    assert await table.row_exists("row_key_2") is False
+    assert await table.row_exists("3") is False
+    await temp_rows.add_row(b"3")
+    assert await table.row_exists(b"3") is True
+    with pytest.raises(exceptions.InvalidArgument) as e:
+        await table.row_exists("")
+    assert "Row key must be non-empty" in str(e)
+
+
 @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
 @pytest.mark.parametrize(
     "cell_value,filter_input,expect_match",
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index be3703a23..14da80dae 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -1296,6 +1296,158 @@ async def test_read_rows_default_timeout_override(self):
                     assert kwargs["operation_timeout"] == operation_timeout
                     assert kwargs["per_request_timeout"] == per_request_timeout
 
+    @pytest.mark.asyncio
+    async def test_read_row(self):
+        """Test reading a single row"""
+        async with self._make_client() as client:
+            table = client.get_table("instance", "table")
+            row_key = b"test_1"
+            with mock.patch.object(table, "read_rows") as read_rows:
+                expected_result = object()
+                read_rows.side_effect = lambda *args, **kwargs: [expected_result]
+                expected_op_timeout = 8
+                expected_req_timeout = 4
+                row = await table.read_row(
+                    row_key,
+                    operation_timeout=expected_op_timeout,
+                    per_request_timeout=expected_req_timeout,
+                )
+                assert row == expected_result
+                assert read_rows.call_count == 1
+                args, kwargs = read_rows.call_args_list[0]
+                assert kwargs["operation_timeout"] == expected_op_timeout
+                assert kwargs["per_request_timeout"] == expected_req_timeout
+                assert len(args) == 1
+                assert isinstance(args[0], ReadRowsQuery)
+                assert args[0]._to_dict() == {
+                    "rows": {"row_keys": [row_key]},
+                    "rows_limit": 1,
+                }
+
+    @pytest.mark.asyncio
+    async def test_read_row_w_filter(self):
+        """Test reading a single row with an added filter"""
+        async with self._make_client() as client:
+            table = client.get_table("instance", "table")
+            row_key = b"test_1"
+            with mock.patch.object(table, "read_rows") as read_rows:
+                expected_result = object()
+                read_rows.side_effect = lambda *args, **kwargs: [expected_result]
+                expected_op_timeout = 8
+                expected_req_timeout = 4
+                mock_filter = mock.Mock()
+                expected_filter = {"filter": "mock filter"}
+                mock_filter._to_dict.return_value = expected_filter
+                row = await table.read_row(
+                    row_key,
+                    operation_timeout=expected_op_timeout,
+                    per_request_timeout=expected_req_timeout,
+                    row_filter=expected_filter,
+                )
+                assert row == expected_result
+                assert read_rows.call_count == 1
+                args, kwargs = read_rows.call_args_list[0]
+                assert kwargs["operation_timeout"] == expected_op_timeout
+                assert kwargs["per_request_timeout"] == expected_req_timeout
+                assert len(args) == 1
+                assert isinstance(args[0], ReadRowsQuery)
+                assert args[0]._to_dict() == {
+                    "rows": {"row_keys": [row_key]},
+                    "rows_limit": 1,
+                    "filter": expected_filter,
+                }
+
+    @pytest.mark.asyncio
+    async def test_read_row_no_response(self):
+        """Should return None if the row does not exist"""
+        async with self._make_client() as client:
+            table = client.get_table("instance", "table")
+            row_key = b"test_1"
+            with mock.patch.object(table, "read_rows") as read_rows:
+                # return no rows
+                read_rows.side_effect = lambda *args, **kwargs: []
+                expected_op_timeout = 8
+                expected_req_timeout = 4
+                result = await table.read_row(
+                    row_key,
+                    operation_timeout=expected_op_timeout,
+                    per_request_timeout=expected_req_timeout,
+                )
+                assert result is None
+                assert read_rows.call_count == 1
+                args, kwargs = read_rows.call_args_list[0]
+                assert kwargs["operation_timeout"] == expected_op_timeout
+                assert kwargs["per_request_timeout"] == expected_req_timeout
+                assert isinstance(args[0], ReadRowsQuery)
+                assert args[0]._to_dict() == {
+                    "rows": {"row_keys": [row_key]},
+                    "rows_limit": 1,
+                }
+
+    @pytest.mark.parametrize("input_row", [None, 5, object()])
+    @pytest.mark.asyncio
+    async def test_read_row_w_invalid_input(self, input_row):
+        """Should raise error when passed an invalid row key"""
+        async with self._make_client() as client:
+            table = client.get_table("instance", "table")
+            with pytest.raises(ValueError) as e:
+                await table.read_row(input_row)
+            assert "must be string or bytes" in str(e)
+
+    @pytest.mark.parametrize(
+        "return_value,expected_result",
+        [
+            ([], False),
+            ([object()], True),
+            ([object(), object()], True),
+        ],
+    )
+    @pytest.mark.asyncio
+    async def test_row_exists(self, return_value, expected_result):
+        """Test checking for row existence"""
+        async with self._make_client() as client:
+            table = client.get_table("instance", "table")
+            row_key = b"test_1"
+            with mock.patch.object(table, "read_rows") as read_rows:
+                # return the parametrized row list
+                read_rows.side_effect = lambda *args, **kwargs: return_value
+                expected_op_timeout = 1
+                expected_req_timeout = 2
+                result = await table.row_exists(
+                    row_key,
+                    operation_timeout=expected_op_timeout,
+                    per_request_timeout=expected_req_timeout,
+                )
+                assert expected_result == result
+                assert read_rows.call_count == 1
+                args, kwargs = read_rows.call_args_list[0]
+                assert kwargs["operation_timeout"] == expected_op_timeout
+                assert kwargs["per_request_timeout"] == expected_req_timeout
+                assert isinstance(args[0], ReadRowsQuery)
+                expected_filter = {
+                    "chain": {
+                        "filters": [
+                            {"cells_per_row_limit_filter": 1},
+                            {"strip_value_transformer": True},
+                        ]
+                    }
+                }
+                assert args[0]._to_dict() == {
+                    "rows": {"row_keys": [row_key]},
+                    "rows_limit": 1,
+                    "filter": expected_filter,
+                }
+
+    @pytest.mark.parametrize("input_row", [None, 5, object()])
+    @pytest.mark.asyncio
+    async def test_row_exists_w_invalid_input(self, input_row):
+        """Should raise error when passed an invalid row key"""
+        async with self._make_client() as client:
+            table = client.get_table("instance", "table")
+            with pytest.raises(ValueError) as e:
+                await table.row_exists(input_row)
+            assert "must be string or bytes" in str(e)
+
     @pytest.mark.parametrize("include_app_profile", [True, False])
     @pytest.mark.asyncio
     async def test_read_rows_metadata(self, include_app_profile):
diff --git a/tests/unit/test_read_rows_query.py b/tests/unit/test_read_rows_query.py
index aa690bc86..f630f2eab 100644
--- a/tests/unit/test_read_rows_query.py
+++ b/tests/unit/test_read_rows_query.py
@@ -300,7 +300,7 @@ def test_to_dict_rows_default(self):
         output = query._to_dict()
         self.assertTrue(isinstance(output, dict))
         self.assertEqual(len(output.keys()), 1)
-        expected = {"rows": {"row_keys": [], "row_ranges": []}}
+        expected = {"rows": {}}
         self.assertEqual(output, expected)
 
         request_proto = ReadRowsRequest(**output)
@@ -355,5 +355,10 @@ def test_to_dict_rows_populated(self):
         filter_proto = request_proto.filter
         self.assertEqual(filter_proto, row_filter._to_pb())
 
+    def test_empty_row_set(self):
+        """Empty strings should be treated as row keys"""
+        query = self._make_one(row_keys="")
+        self.assertEqual(query.row_keys, {b""})
+
     def test_shard(self):
         pass
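Usage sketch for the two helpers added above. This is illustrative only: the BigtableDataClient name, its import path, and the project/instance/table identifiers are assumed placeholders; get_table, read_row, row_exists, and the Row/cell attributes follow the signatures exercised by the tests in this diff.

import asyncio

# placeholder import; the actual async client class and path may differ
from google.cloud.bigtable import BigtableDataClient


async def main():
    # "my-project", "my-instance", and "my-table" are placeholders
    async with BigtableDataClient(project="my-project") as client:
        table = client.get_table("my-instance", "my-table")

        # read_row returns the matching Row, or None when the key is absent
        row = await table.read_row(b"row_key_1", operation_timeout=60)
        if row is not None:
            print(row.row_key, row.cells[0].value)

        # row_exists sends a stripped, single-cell read and only reports presence
        if await table.row_exists(b"row_key_1"):
            print("row is present")


asyncio.run(main())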