diff --git a/.coveragerc b/.coveragerc index 24e7b7e4d..f12d4dc21 100644 --- a/.coveragerc +++ b/.coveragerc @@ -22,7 +22,7 @@ omit = google/cloud/bigtable_admin/gapic_version.py [report] -fail_under = 100 +fail_under = 99 show_missing = True exclude_lines = # Re-enable the standard pragma diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml index a0d3362c9..a8cc5b33b 100644 --- a/.github/sync-repo-settings.yaml +++ b/.github/sync-repo-settings.yaml @@ -31,6 +31,24 @@ branchProtectionRules: - 'Kokoro' - 'Kokoro system-3.8' - 'cla/google' +- pattern: experimental_v3 + # Can admins overwrite branch protection. + # Defaults to `true` + isAdminEnforced: false + # Number of approving reviews required to update matching branches. + # Defaults to `1` + requiredApprovingReviewCount: 1 + # Are reviews from code owners required to update matching branches. + # Defaults to `false` + requiresCodeOwnerReviews: false + # Require up to date branches + requiresStrictStatusChecks: false + # List of required status check contexts that must pass for commits to be accepted to matching branches. + requiredStatusCheckContexts: + - 'Kokoro' + - 'Kokoro system-3.8' + - 'cla/google' + - 'Conformance / Async v3 Client / Python 3.8' # List of explicit permissions to add (additive only) permissionRules: # Team slug to add to repository permissions diff --git a/.github/workflows/conformance.yaml b/.github/workflows/conformance.yaml new file mode 100644 index 000000000..63023d162 --- /dev/null +++ b/.github/workflows/conformance.yaml @@ -0,0 +1,56 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Github action job to test core java library features on +# downstream client libraries before they are released. +on: + push: + branches: + - main + pull_request: +name: Conformance +jobs: + conformance: + runs-on: ubuntu-latest + strategy: + matrix: + test-version: [ "v0.0.2" ] + py-version: [ 3.8 ] + client-type: [ "Async v3", "Legacy" ] + fail-fast: false + name: "${{ matrix.client-type }} Client / Python ${{ matrix.py-version }} / Test Tag ${{ matrix.test-version }}" + steps: + - uses: actions/checkout@v3 + name: "Checkout python-bigtable" + - uses: actions/checkout@v3 + name: "Checkout conformance tests" + with: + repository: googleapis/cloud-bigtable-clients-test + ref: ${{ matrix.test-version }} + path: cloud-bigtable-clients-test + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.py-version }} + - uses: actions/setup-go@v4 + with: + go-version: '>=1.20.2' + - run: chmod +x .kokoro/conformance.sh + - run: pip install -e . 
+ name: "Install python-bigtable from HEAD" + - run: go version + - run: .kokoro/conformance.sh + name: "Run tests" + env: + CLIENT_TYPE: ${{ matrix.client-type }} + PYTHONUNBUFFERED: 1 + diff --git a/.github/workflows/system_emulated.yml b/.github/workflows/system_emulated.yml index ceb4e0c4d..7669901c9 100644 --- a/.github/workflows/system_emulated.yml +++ b/.github/workflows/system_emulated.yml @@ -20,7 +20,7 @@ jobs: python-version: '3.8' - name: Setup GCloud SDK - uses: google-github-actions/setup-gcloud@v2.0.1 + uses: google-github-actions/setup-gcloud@v2.0.0 - name: Install / run Nox run: | diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index f4a337c49..87d08602f 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -54,4 +54,4 @@ jobs: run: | find .coverage-results -type f -name '*.zip' -exec unzip {} \; coverage combine .coverage-results/**/.coverage* - coverage report --show-missing --fail-under=100 + coverage report --show-missing --fail-under=99 diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..5fa9b1ed5 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "python-api-core"] + path = python-api-core + url = git@github.com:googleapis/python-api-core.git +[submodule "gapic-generator-fork"] + path = gapic-generator-fork + url = git@github.com:googleapis/gapic-generator-python.git diff --git a/.kokoro/conformance.sh b/.kokoro/conformance.sh new file mode 100644 index 000000000..1c0b3ee0d --- /dev/null +++ b/.kokoro/conformance.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -eo pipefail + +## cd to the parent directory, i.e. the root of the git repo +cd $(dirname $0)/.. + +PROXY_ARGS="" +TEST_ARGS="" +if [[ "${CLIENT_TYPE^^}" == "LEGACY" ]]; then + echo "Using legacy client" + PROXY_ARGS="--legacy-client" + # legacy client does not expose mutate_row. Disable those tests + TEST_ARGS="-skip TestMutateRow_" +fi + +# Build and start the proxy in a separate process +PROXY_PORT=9999 +pushd test_proxy +nohup python test_proxy.py --port $PROXY_PORT $PROXY_ARGS & +proxyPID=$! +popd + +# Kill proxy on exit +function cleanup() { + echo "Cleanup testbench"; + kill $proxyPID +} +trap cleanup EXIT + +# Run the conformance test +pushd cloud-bigtable-clients-test/tests +eval "go test -v -proxy_addr=:$PROXY_PORT $TEST_ARGS" +RETURN_CODE=$? 
+popd + +echo "exiting with ${RETURN_CODE}" +exit ${RETURN_CODE} diff --git a/.kokoro/presubmit/conformance.cfg b/.kokoro/presubmit/conformance.cfg new file mode 100644 index 000000000..4f44e8a78 --- /dev/null +++ b/.kokoro/presubmit/conformance.cfg @@ -0,0 +1,6 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "NOX_SESSION" + value: "conformance" +} diff --git a/gapic-generator-fork b/gapic-generator-fork new file mode 160000 index 000000000..b26cda7d1 --- /dev/null +++ b/gapic-generator-fork @@ -0,0 +1 @@ +Subproject commit b26cda7d163d6e0d45c9684f328ca32fb49b799a diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py new file mode 100644 index 000000000..5229f8021 --- /dev/null +++ b/google/cloud/bigtable/data/__init__.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from google.cloud.bigtable import gapic_version as package_version + +from google.cloud.bigtable.data._async.client import BigtableDataClientAsync +from google.cloud.bigtable.data._async.client import TableAsync + +from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync + +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.read_rows_query import RowRange +from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable.data.row import Cell + +from google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data.mutations import RowMutationEntry +from google.cloud.bigtable.data.mutations import SetCell +from google.cloud.bigtable.data.mutations import DeleteRangeFromColumn +from google.cloud.bigtable.data.mutations import DeleteAllFromFamily +from google.cloud.bigtable.data.mutations import DeleteAllFromRow + +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.exceptions import FailedMutationEntryError +from google.cloud.bigtable.data.exceptions import FailedQueryShardError + +from google.cloud.bigtable.data.exceptions import RetryExceptionGroup +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import RowKeySamples +from google.cloud.bigtable.data._helpers import ShardedQuery + + +__version__: str = package_version.__version__ + +__all__ = ( + "BigtableDataClientAsync", + "TableAsync", + "RowKeySamples", + "ReadRowsQuery", + "RowRange", + "MutationsBatcherAsync", + "Mutation", + "RowMutationEntry", + "SetCell", + "DeleteRangeFromColumn", + "DeleteAllFromFamily", + "DeleteAllFromRow", + "Row", + "Cell", + "InvalidChunk", + "FailedMutationEntryError", + "FailedQueryShardError", + "RetryExceptionGroup", + "MutationsExceptionGroup", + "ShardedReadRowsExceptionGroup", + "ShardedQuery", + "TABLE_DEFAULT", +) diff --git 
a/google/cloud/bigtable/data/_async/__init__.py b/google/cloud/bigtable/data/_async/__init__.py new file mode 100644 index 000000000..e13c9acb7 --- /dev/null +++ b/google/cloud/bigtable/data/_async/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.bigtable.data._async.client import BigtableDataClientAsync +from google.cloud.bigtable.data._async.client import TableAsync + +from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync + + +__all__ = [ + "BigtableDataClientAsync", + "TableAsync", + "MutationsBatcherAsync", +] diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py new file mode 100644 index 000000000..7d1144553 --- /dev/null +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -0,0 +1,226 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import Sequence, TYPE_CHECKING +from dataclasses import dataclass +import functools + +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +import google.cloud.bigtable_v2.types.bigtable as types_pb +import google.cloud.bigtable.data.exceptions as bt_exceptions +from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._helpers import _retry_exception_factory + +# mutate_rows requests are limited to this number of mutations +from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT + +if TYPE_CHECKING: + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) + from google.cloud.bigtable.data.mutations import RowMutationEntry + from google.cloud.bigtable.data._async.client import TableAsync + + +@dataclass +class _EntryWithProto: + """ + A dataclass to hold a RowMutationEntry and its corresponding proto representation. + """ + + entry: RowMutationEntry + proto: types_pb.MutateRowsRequest.Entry + + +class _MutateRowsOperationAsync: + """ + MutateRowsOperation manages the logic of sending a set of row mutations, + and retrying on failed entries. It manages this using the _run_attempt + function, which attempts to mutate all outstanding entries, and raises + _MutateRowsIncomplete if any retryable errors are encountered. 
+ + Errors are exposed as a MutationsExceptionGroup, which contains a list of + exceptions organized by the related failed mutation entries. + """ + + def __init__( + self, + gapic_client: "BigtableAsyncClient", + table: "TableAsync", + mutation_entries: list["RowMutationEntry"], + operation_timeout: float, + attempt_timeout: float | None, + retryable_exceptions: Sequence[type[Exception]] = (), + ): + """ + Args: + - gapic_client: the client to use for the mutate_rows call + - table: the table associated with the request + - mutation_entries: a list of RowMutationEntry objects to send to the server + - operation_timeout: the timeout to use for the entire operation, in seconds. + - attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. + If not specified, the request will run until operation_timeout is reached. + """ + # check that mutations are within limits + total_mutations = sum(len(entry.mutations) for entry in mutation_entries) + if total_mutations > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT: + raise ValueError( + "mutate_rows requests can contain at most " + f"{_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across " + f"all entries. Found {total_mutations}." + ) + # create partial function to pass to trigger rpc call + metadata = _make_metadata(table.table_name, table.app_profile_id) + self._gapic_fn = functools.partial( + gapic_client.mutate_rows, + table_name=table.table_name, + app_profile_id=table.app_profile_id, + metadata=metadata, + retry=None, + ) + # create predicate for determining which errors are retryable + self.is_retryable = retries.if_exception_type( + # RPC level errors + *retryable_exceptions, + # Entry level errors + bt_exceptions._MutateRowsIncomplete, + ) + sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + self._operation = retries.retry_target_async( + self._run_attempt, + self.is_retryable, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) + # initialize state + self.timeout_generator = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries] + self.remaining_indices = list(range(len(self.mutations))) + self.errors: dict[int, list[Exception]] = {} + + async def start(self): + """ + Start the operation, and run until completion + + Raises: + - MutationsExceptionGroup: if any mutations failed + """ + try: + # trigger mutate_rows + await self._operation + except Exception as exc: + # exceptions raised by retryable are added to the list of exceptions for all unfinalized mutations + incomplete_indices = self.remaining_indices.copy() + for idx in incomplete_indices: + self._handle_entry_error(idx, exc) + finally: + # raise exception detailing incomplete mutations + all_errors: list[Exception] = [] + for idx, exc_list in self.errors.items(): + if len(exc_list) == 0: + raise core_exceptions.ClientError( + f"Mutation {idx} failed with no associated errors" + ) + elif len(exc_list) == 1: + cause_exc = exc_list[0] + else: + cause_exc = bt_exceptions.RetryExceptionGroup(exc_list) + entry = self.mutations[idx].entry + all_errors.append( + bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc) + ) + if all_errors: + raise bt_exceptions.MutationsExceptionGroup( + all_errors, len(self.mutations) + ) + + async def _run_attempt(self): + """ + Run a single attempt of the mutate_rows rpc. 
+ + Raises: + - _MutateRowsIncomplete: if there are failed mutations eligible for + retry after the attempt is complete + - GoogleAPICallError: if the gapic rpc fails + """ + request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] + # track mutations in this request that have not been finalized yet + active_request_indices = { + req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices) + } + self.remaining_indices = [] + if not request_entries: + # no more mutations. return early + return + # make gapic request + try: + result_generator = await self._gapic_fn( + timeout=next(self.timeout_generator), + entries=request_entries, + retry=None, + ) + async for result_list in result_generator: + for result in result_list.entries: + # convert sub-request index to global index + orig_idx = active_request_indices[result.index] + entry_error = core_exceptions.from_grpc_status( + result.status.code, + result.status.message, + details=result.status.details, + ) + if result.status.code != 0: + # mutation failed; update error list (and remaining_indices if retryable) + self._handle_entry_error(orig_idx, entry_error) + elif orig_idx in self.errors: + # mutation succeeded; remove from error list + del self.errors[orig_idx] + # remove processed entry from active list + del active_request_indices[result.index] + except Exception as exc: + # add this exception to list for each mutation that wasn't + # already handled, and update remaining_indices if mutation is retryable + for idx in active_request_indices.values(): + self._handle_entry_error(idx, exc) + # bubble up exception to be handled by retry wrapper + raise + # check if attempt succeeded, or needs to be retried + if self.remaining_indices: + # unfinished work; raise exception to trigger retry + raise bt_exceptions._MutateRowsIncomplete + + def _handle_entry_error(self, idx: int, exc: Exception): + """ + Add an exception to the list of exceptions for a given mutation index, + and add the index to the list of remaining indices if the exception is + retryable. + + Args: + - idx: the index of the mutation that failed + - exc: the exception to add to the list + """ + entry = self.mutations[idx].entry + self.errors.setdefault(idx, []).append(exc) + if ( + entry.is_idempotent() + and self.is_retryable(exc) + and idx not in self.remaining_indices + ): + self.remaining_indices.append(idx) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py new file mode 100644 index 000000000..9e0fd78e1 --- /dev/null +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -0,0 +1,343 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
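# Illustrative sketch of how calling code might consume the mutate-rows
# operation defined above (hypothetical caller; `table` and `entries` are
# assumed to exist, and in practice the operation is driven by the
# higher-level TableAsync surface rather than invoked directly):
#
#     try:
#         await _MutateRowsOperationAsync(
#             table.client._gapic_client, table, entries,
#             operation_timeout=60, attempt_timeout=20,
#         ).start()
#     except MutationsExceptionGroup as eg:
#         # each sub-exception is a FailedMutationEntryError carrying the
#         # failed entry and its cause (a RetryExceptionGroup if retried)
#         for failed in eg.exceptions:
#             print(failed)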
+# + +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + AsyncGenerator, + AsyncIterable, + Awaitable, + Sequence, +) + +from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB +from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB +from google.cloud.bigtable_v2.types import RowSet as RowSetPB +from google.cloud.bigtable_v2.types import RowRange as RowRangePB + +from google.cloud.bigtable.data.row import Row, Cell +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.exceptions import _RowSetComplete +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import _retry_exception_factory + +from google.api_core import retry as retries +from google.api_core.retry import exponential_sleep_generator + +if TYPE_CHECKING: + from google.cloud.bigtable.data._async.client import TableAsync + + +class _ResetRow(Exception): + def __init__(self, chunk): + self.chunk = chunk + + +class _ReadRowsOperationAsync: + """ + ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream + into a stream of Row objects. + + ReadRowsOperation.merge_row_response_stream takes in a stream of ReadRowsResponse + and turns them into a stream of Row objects using an internal + StateMachine. + + ReadRowsOperation(request, client) handles row merging logic end-to-end, including + performing retries on stream errors. + """ + + __slots__ = ( + "attempt_timeout_gen", + "operation_timeout", + "request", + "table", + "_predicate", + "_metadata", + "_last_yielded_row_key", + "_remaining_count", + ) + + def __init__( + self, + query: ReadRowsQuery, + table: "TableAsync", + operation_timeout: float, + attempt_timeout: float, + retryable_exceptions: Sequence[type[Exception]] = (), + ): + self.attempt_timeout_gen = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + self.operation_timeout = operation_timeout + if isinstance(query, dict): + self.request = ReadRowsRequestPB( + **query, + table_name=table.table_name, + app_profile_id=table.app_profile_id, + ) + else: + self.request = query._to_pb(table) + self.table = table + self._predicate = retries.if_exception_type(*retryable_exceptions) + self._metadata = _make_metadata( + table.table_name, + table.app_profile_id, + ) + self._last_yielded_row_key: bytes | None = None + self._remaining_count: int | None = self.request.rows_limit or None + + def start_operation(self) -> AsyncGenerator[Row, None]: + """ + Start the read_rows operation, retrying on retryable errors. + """ + return retries.retry_target_stream_async( + self._read_rows_attempt, + self._predicate, + exponential_sleep_generator(0.01, 60, multiplier=2), + self.operation_timeout, + exception_factory=_retry_exception_factory, + ) + + def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: + """ + Attempt a single read_rows rpc call. + This function is intended to be wrapped by retry logic, + which will call this function until it succeeds or + a non-retryable error is raised. 
+ """ + # revise request keys and ranges between attempts + if self._last_yielded_row_key is not None: + # if this is a retry, try to trim down the request to avoid ones we've already processed + try: + self.request.rows = self._revise_request_rowset( + row_set=self.request.rows, + last_seen_row_key=self._last_yielded_row_key, + ) + except _RowSetComplete: + # if we've already seen all the rows, we're done + return self.merge_rows(None) + # revise the limit based on number of rows already yielded + if self._remaining_count is not None: + self.request.rows_limit = self._remaining_count + if self._remaining_count == 0: + return self.merge_rows(None) + # create and return a new row merger + gapic_stream = self.table.client._gapic_client.read_rows( + self.request, + timeout=next(self.attempt_timeout_gen), + metadata=self._metadata, + retry=None, + ) + chunked_stream = self.chunk_stream(gapic_stream) + return self.merge_rows(chunked_stream) + + async def chunk_stream( + self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]] + ) -> AsyncGenerator[ReadRowsResponsePB.CellChunk, None]: + """ + process chunks out of raw read_rows stream + """ + async for resp in await stream: + # extract proto from proto-plus wrapper + resp = resp._pb + + # handle last_scanned_row_key packets, sent when server + # has scanned past the end of the row range + if resp.last_scanned_row_key: + if ( + self._last_yielded_row_key is not None + and resp.last_scanned_row_key <= self._last_yielded_row_key + ): + raise InvalidChunk("last scanned out of order") + self._last_yielded_row_key = resp.last_scanned_row_key + + current_key = None + # process each chunk in the response + for c in resp.chunks: + if current_key is None: + current_key = c.row_key + if current_key is None: + raise InvalidChunk("first chunk is missing a row key") + elif ( + self._last_yielded_row_key + and current_key <= self._last_yielded_row_key + ): + raise InvalidChunk("row keys should be strictly increasing") + + yield c + + if c.reset_row: + current_key = None + elif c.commit_row: + # update row state after each commit + self._last_yielded_row_key = current_key + if self._remaining_count is not None: + self._remaining_count -= 1 + if self._remaining_count < 0: + raise InvalidChunk("emit count exceeds row limit") + current_key = None + + @staticmethod + async def merge_rows( + chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None + ): + """ + Merge chunks into rows + """ + if chunks is None: + return + it = chunks.__aiter__() + # For each row + while True: + try: + c = await it.__anext__() + except StopAsyncIteration: + # stream complete + return + row_key = c.row_key + + if not row_key: + raise InvalidChunk("first row chunk is missing key") + + cells = [] + + # shared per cell storage + family: str | None = None + qualifier: bytes | None = None + + try: + # for each cell + while True: + if c.reset_row: + raise _ResetRow(c) + k = c.row_key + f = c.family_name.value + q = c.qualifier.value if c.HasField("qualifier") else None + if k and k != row_key: + raise InvalidChunk("unexpected new row key") + if f: + family = f + if q is not None: + qualifier = q + else: + raise InvalidChunk("new family without qualifier") + elif family is None: + raise InvalidChunk("missing family") + elif q is not None: + if family is None: + raise InvalidChunk("new qualifier without family") + qualifier = q + elif qualifier is None: + raise InvalidChunk("missing qualifier") + + ts = c.timestamp_micros + labels = c.labels if c.labels else [] + value = c.value + + 
# merge split cells + if c.value_size > 0: + buffer = [value] + while c.value_size > 0: + # throws when premature end + c = await it.__anext__() + + t = c.timestamp_micros + cl = c.labels + k = c.row_key + if ( + c.HasField("family_name") + and c.family_name.value != family + ): + raise InvalidChunk("family changed mid cell") + if ( + c.HasField("qualifier") + and c.qualifier.value != qualifier + ): + raise InvalidChunk("qualifier changed mid cell") + if t and t != ts: + raise InvalidChunk("timestamp changed mid cell") + if cl and cl != labels: + raise InvalidChunk("labels changed mid cell") + if k and k != row_key: + raise InvalidChunk("row key changed mid cell") + + if c.reset_row: + raise _ResetRow(c) + buffer.append(c.value) + value = b"".join(buffer) + cells.append( + Cell(value, row_key, family, qualifier, ts, list(labels)) + ) + if c.commit_row: + yield Row(row_key, cells) + break + c = await it.__anext__() + except _ResetRow as e: + c = e.chunk + if ( + c.row_key + or c.HasField("family_name") + or c.HasField("qualifier") + or c.timestamp_micros + or c.labels + or c.value + ): + raise InvalidChunk("reset row with data") + continue + except StopAsyncIteration: + raise InvalidChunk("premature end of stream") + + @staticmethod + def _revise_request_rowset( + row_set: RowSetPB, + last_seen_row_key: bytes, + ) -> RowSetPB: + """ + Revise the rows in the request to avoid ones we've already processed. + + Args: + - row_set: the row set from the request + - last_seen_row_key: the last row key encountered + Raises: + - _RowSetComplete: if there are no rows left to process after the revision + """ + # if user is doing a whole table scan, start a new one with the last seen key + if row_set is None or (not row_set.row_ranges and row_set.row_keys is not None): + last_seen = last_seen_row_key + return RowSetPB(row_ranges=[RowRangePB(start_key_open=last_seen)]) + # remove seen keys from user-specific key list + adjusted_keys: list[bytes] = [ + k for k in row_set.row_keys if k > last_seen_row_key + ] + # adjust ranges to ignore keys before last seen + adjusted_ranges: list[RowRangePB] = [] + for row_range in row_set.row_ranges: + end_key = row_range.end_key_closed or row_range.end_key_open or None + if end_key is None or end_key > last_seen_row_key: + # end range is after last seen key + new_range = RowRangePB(row_range) + start_key = row_range.start_key_closed or row_range.start_key_open + if start_key is None or start_key <= last_seen_row_key: + # replace start key with last seen + new_range.start_key_open = last_seen_row_key + adjusted_ranges.append(new_range) + if len(adjusted_keys) == 0 and len(adjusted_ranges) == 0: + # if the query is empty after revision, raise an exception + # this will avoid an unwanted full table scan + raise _RowSetComplete() + return RowSetPB(row_keys=adjusted_keys, row_ranges=adjusted_ranges) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py new file mode 100644 index 000000000..da54b37cb --- /dev/null +++ b/google/cloud/bigtable/data/_async/client.py @@ -0,0 +1,1228 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
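# Worked example of the row-set revision above (illustrative values): if the
# last yielded key was b"b", row_keys [b"a", b"b", b"c"] shrink to [b"c"], and
# a closed range [b"a", b"z"] is reopened as (b"b", b"z"] by setting
# start_key_open=b"b", so already-delivered rows are not requested again.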
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +from typing import ( + cast, + Any, + AsyncIterable, + Optional, + Set, + Sequence, + TYPE_CHECKING, +) + +import asyncio +import grpc +import time +import warnings +import sys +import random +import os + +from functools import partial + +from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta +from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient +from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, + PooledChannel, +) +from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest +from google.cloud.client import ClientWithProject +from google.cloud.environment_vars import BIGTABLE_EMULATOR # type: ignore +from google.api_core import retry as retries +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import ServiceUnavailable +from google.api_core.exceptions import Aborted +from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + +import google.auth.credentials +import google.auth._default +from google.api_core import client_options as client_options_lib +from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT +from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.exceptions import FailedQueryShardError +from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + +from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry +from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _WarmedInstanceKey +from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT +from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data._helpers import _validate_timeouts +from google.cloud.bigtable.data._helpers import _get_retryable_errors +from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync +from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE +from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule +from google.cloud.bigtable.data.row_filters import RowFilter +from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter +from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter +from google.cloud.bigtable.data.row_filters import RowFilterChain + + +if TYPE_CHECKING: + from google.cloud.bigtable.data._helpers import RowKeySamples + from google.cloud.bigtable.data._helpers import ShardedQuery + + +class 
BigtableDataClientAsync(ClientWithProject): + def __init__( + self, + *, + project: str | None = None, + pool_size: int = 3, + credentials: google.auth.credentials.Credentials | None = None, + client_options: dict[str, Any] + | "google.api_core.client_options.ClientOptions" + | None = None, + ): + """ + Create a client instance for the Bigtable Data API + + Client should be created within an async context (running event loop) + + Args: + project: the project which the client acts on behalf of. + If not passed, falls back to the default inferred + from the environment. + pool_size: The number of grpc channels to maintain + in the internal channel pool. + credentials: + Thehe OAuth2 Credentials to use for this + client. If not passed (and if no ``_http`` object is + passed), falls back to the default inferred from the + environment. + client_options (Optional[Union[dict, google.api_core.client_options.ClientOptions]]): + Client options used to set user options + on the client. API Endpoint should be set through client_options. + Raises: + - RuntimeError if called outside of an async context (no running event loop) + - ValueError if pool_size is less than 1 + """ + # set up transport in registry + transport_str = f"pooled_grpc_asyncio_{pool_size}" + transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) + BigtableClientMeta._transport_registry[transport_str] = transport + # set up client info headers for veneer library + client_info = DEFAULT_CLIENT_INFO + client_info.client_library_version = self._client_version() + # parse client options + if type(client_options) is dict: + client_options = client_options_lib.from_dict(client_options) + client_options = cast( + Optional[client_options_lib.ClientOptions], client_options + ) + self._emulator_host = os.getenv(BIGTABLE_EMULATOR) + if self._emulator_host is not None: + # use insecure channel if emulator is set + if credentials is None: + credentials = google.auth.credentials.AnonymousCredentials() + if project is None: + project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT + # initialize client + ClientWithProject.__init__( + self, + credentials=credentials, + project=project, + client_options=client_options, + ) + self._gapic_client = BigtableAsyncClient( + transport=transport_str, + credentials=credentials, + client_options=client_options, + client_info=client_info, + ) + self.transport = cast( + PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport + ) + # keep track of active instances to for warmup on channel refresh + self._active_instances: Set[_WarmedInstanceKey] = set() + # keep track of table objects associated with each instance + # only remove instance from _active_instances when all associated tables remove it + self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} + self._channel_init_time = time.monotonic() + self._channel_refresh_tasks: list[asyncio.Task[None]] = [] + if self._emulator_host is not None: + # connect to an emulator host + warnings.warn( + "Connecting to Bigtable emulator at {}".format(self._emulator_host), + RuntimeWarning, + stacklevel=2, + ) + self.transport._grpc_channel = PooledChannel( + pool_size=pool_size, + host=self._emulator_host, + insecure=True, + ) + # refresh cached stubs to use emulator pool + self.transport._stubs = {} + self.transport._prep_wrapped_messages(client_info) + else: + # attempt to start background channel refresh tasks + try: + self._start_background_channel_refresh() + except RuntimeError: + warnings.warn( + f"{self.__class__.__name__} should be started in 
an " + "asyncio event loop. Channel refresh will not be started", + RuntimeWarning, + stacklevel=2, + ) + + @staticmethod + def _client_version() -> str: + """ + Helper function to return the client version string for this client + """ + return f"{google.cloud.bigtable.__version__}-data-async" + + def _start_background_channel_refresh(self) -> None: + """ + Starts a background task to ping and warm each channel in the pool + Raises: + - RuntimeError if not called in an asyncio event loop + """ + if not self._channel_refresh_tasks and not self._emulator_host: + # raise RuntimeError if there is no event loop + asyncio.get_running_loop() + for channel_idx in range(self.transport.pool_size): + refresh_task = asyncio.create_task(self._manage_channel(channel_idx)) + if sys.version_info >= (3, 8): + # task names supported in Python 3.8+ + refresh_task.set_name( + f"{self.__class__.__name__} channel refresh {channel_idx}" + ) + self._channel_refresh_tasks.append(refresh_task) + + async def close(self, timeout: float = 2.0): + """ + Cancel all background tasks + """ + for task in self._channel_refresh_tasks: + task.cancel() + group = asyncio.gather(*self._channel_refresh_tasks, return_exceptions=True) + await asyncio.wait_for(group, timeout=timeout) + await self.transport.close() + self._channel_refresh_tasks = [] + + async def _ping_and_warm_instances( + self, channel: grpc.aio.Channel, instance_key: _WarmedInstanceKey | None = None + ) -> list[BaseException | None]: + """ + Prepares the backend for requests on a channel + + Pings each Bigtable instance registered in `_active_instances` on the client + + Args: + - channel: grpc channel to warm + - instance_key: if provided, only warm the instance associated with the key + Returns: + - sequence of results or exceptions from the ping requests + """ + instance_list = ( + [instance_key] if instance_key is not None else self._active_instances + ) + ping_rpc = channel.unary_unary( + "/google.bigtable.v2.Bigtable/PingAndWarm", + request_serializer=PingAndWarmRequest.serialize, + ) + # prepare list of coroutines to run + tasks = [ + ping_rpc( + request={"name": instance_name, "app_profile_id": app_profile_id}, + metadata=[ + ( + "x-goog-request-params", + f"name={instance_name}&app_profile_id={app_profile_id}", + ) + ], + wait_for_ready=True, + ) + for (instance_name, table_name, app_profile_id) in instance_list + ] + # execute coroutines in parallel + result_list = await asyncio.gather(*tasks, return_exceptions=True) + # return None in place of empty successful responses + return [r or None for r in result_list] + + async def _manage_channel( + self, + channel_idx: int, + refresh_interval_min: float = 60 * 35, + refresh_interval_max: float = 60 * 45, + grace_period: float = 60 * 10, + ) -> None: + """ + Background coroutine that periodically refreshes and warms a grpc channel + + The backend will automatically close channels after 60 minutes, so + `refresh_interval` + `grace_period` should be < 60 minutes + + Runs continuously until the client is closed + + Args: + channel_idx: index of the channel in the transport's channel pool + refresh_interval_min: minimum interval before initiating refresh + process in seconds. Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + refresh_interval_max: maximum interval before initiating refresh + process in seconds. 
Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + grace_period: time to allow previous channel to serve existing + requests before closing, in seconds + """ + first_refresh = self._channel_init_time + random.uniform( + refresh_interval_min, refresh_interval_max + ) + next_sleep = max(first_refresh - time.monotonic(), 0) + if next_sleep > 0: + # warm the current channel immediately + channel = self.transport.channels[channel_idx] + await self._ping_and_warm_instances(channel) + # continuously refresh the channel every `refresh_interval` seconds + while True: + await asyncio.sleep(next_sleep) + # prepare new channel for use + new_channel = self.transport.grpc_channel._create_channel() + await self._ping_and_warm_instances(new_channel) + # cycle channel out of use, with long grace window before closure + start_timestamp = time.time() + await self.transport.replace_channel( + channel_idx, grace=grace_period, swap_sleep=10, new_channel=new_channel + ) + # subtract the time spent waiting for the channel to be replaced + next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) + next_sleep = next_refresh - (time.time() - start_timestamp) + + async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: + """ + Registers an instance with the client, and warms the channel pool + for the instance + The client will periodically refresh grpc channel pool used to make + requests, and new channels will be warmed for each registered instance + Channels will not be refreshed unless at least one instance is registered + + Args: + - instance_id: id of the instance to register. + - owner: table that owns the instance. Owners will be tracked in + _instance_owners, and instances will only be unregistered when all + owners call _remove_instance_registration + """ + instance_name = self._gapic_client.instance_path(self.project, instance_id) + instance_key = _WarmedInstanceKey( + instance_name, owner.table_name, owner.app_profile_id + ) + self._instance_owners.setdefault(instance_key, set()).add(id(owner)) + if instance_name not in self._active_instances: + self._active_instances.add(instance_key) + if self._channel_refresh_tasks: + # refresh tasks already running + # call ping and warm on all existing channels + for channel in self.transport.channels: + await self._ping_and_warm_instances(channel, instance_key) + else: + # refresh tasks aren't active. start them as background tasks + self._start_background_channel_refresh() + + async def _remove_instance_registration( + self, instance_id: str, owner: TableAsync + ) -> bool: + """ + Removes an instance from the client's registered instances, to prevent + warming new channels for the instance + + If instance_id is not registered, or is still in use by other tables, returns False + + Args: + - instance_id: id of the instance to remove + - owner: table that owns the instance. 
Owners will be tracked in + _instance_owners, and instances will only be unregistered when all + owners call _remove_instance_registration + Returns: + - True if instance was removed + """ + instance_name = self._gapic_client.instance_path(self.project, instance_id) + instance_key = _WarmedInstanceKey( + instance_name, owner.table_name, owner.app_profile_id + ) + owner_list = self._instance_owners.get(instance_key, set()) + try: + owner_list.remove(id(owner)) + if len(owner_list) == 0: + self._active_instances.remove(instance_key) + return True + except KeyError: + return False + + def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: + """ + Returns a table instance for making data API requests. All arguments are passed + directly to the TableAsync constructor. + + Args: + instance_id: The Bigtable instance ID to associate with this client. + instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. table_id is combined with the + instance_id and the client's project to fully specify the table + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to 20 seconds + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to 60 seconds + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. 
+ Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + """ + return TableAsync(self, instance_id, table_id, *args, **kwargs) + + async def __aenter__(self): + self._start_background_channel_refresh() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) + + +class TableAsync: + """ + Main Data API surface + + Table object maintains table_id, and app_profile_id context, and passes them with + each call + """ + + def __init__( + self, + client: BigtableDataClientAsync, + instance_id: str, + table_id: str, + app_profile_id: str | None = None, + *, + default_read_rows_operation_timeout: float = 600, + default_read_rows_attempt_timeout: float | None = 20, + default_mutate_rows_operation_timeout: float = 600, + default_mutate_rows_attempt_timeout: float | None = 60, + default_operation_timeout: float = 60, + default_attempt_timeout: float | None = 20, + default_read_rows_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + Aborted, + ), + default_mutate_rows_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + ), + default_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + ), + ): + """ + Initialize a Table instance + + Must be created within an async context (running event loop) + + Args: + instance_id: The Bigtable instance ID to associate with this client. + instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. table_id is combined with the + instance_id and the client's project to fully specify the table + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to 20 seconds + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to 60 seconds + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. 
+ Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + Raises: + - RuntimeError if called outside of an async context (no running event loop) + """ + # NOTE: any changes to the signature of this method should also be reflected + # in client.get_table() + # validate timeouts + _validate_timeouts( + default_operation_timeout, default_attempt_timeout, allow_none=True + ) + _validate_timeouts( + default_read_rows_operation_timeout, + default_read_rows_attempt_timeout, + allow_none=True, + ) + _validate_timeouts( + default_mutate_rows_operation_timeout, + default_mutate_rows_attempt_timeout, + allow_none=True, + ) + + self.client = client + self.instance_id = instance_id + self.instance_name = self.client._gapic_client.instance_path( + self.client.project, instance_id + ) + self.table_id = table_id + self.table_name = self.client._gapic_client.table_path( + self.client.project, instance_id, table_id + ) + self.app_profile_id = app_profile_id + + self.default_operation_timeout = default_operation_timeout + self.default_attempt_timeout = default_attempt_timeout + self.default_read_rows_operation_timeout = default_read_rows_operation_timeout + self.default_read_rows_attempt_timeout = default_read_rows_attempt_timeout + self.default_mutate_rows_operation_timeout = ( + default_mutate_rows_operation_timeout + ) + self.default_mutate_rows_attempt_timeout = default_mutate_rows_attempt_timeout + + self.default_read_rows_retryable_errors = ( + default_read_rows_retryable_errors or () + ) + self.default_mutate_rows_retryable_errors = ( + default_mutate_rows_retryable_errors or () + ) + self.default_retryable_errors = default_retryable_errors or () + + # raises RuntimeError if called outside of an async context (no running event loop) + try: + self._register_instance_task = asyncio.create_task( + self.client._register_instance(instance_id, self) + ) + except RuntimeError as e: + raise RuntimeError( + f"{self.__class__.__name__} must be created within an async event loop context." + ) from e + + async def read_rows_stream( + self, + query: ReadRowsQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> AsyncIterable[Row]: + """ + Read a set of rows from the table, based on the specified query. + Returns an iterator to asynchronously stream back row data. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Args: + - query: contains details about which rows to return + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. 
+ Defaults to the Table's default_read_rows_retryable_errors + Returns: + - an asynchronous iterator that yields rows returned by the query + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + retryable_excs = _get_retryable_errors(retryable_errors, self) + + row_merger = _ReadRowsOperationAsync( + query, + self, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_exceptions=retryable_excs, + ) + return row_merger.start_operation() + + async def read_rows( + self, + query: ReadRowsQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> list[Row]: + """ + Read a set of rows from the table, based on the specified query. + Retruns results as a list of Row objects when the request is complete. + For streamed results, use read_rows_stream. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Args: + - query: contains details about which rows to return + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + If None, defaults to the Table's default_read_rows_attempt_timeout, + or the operation_timeout if that is also None. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. + Returns: + - a list of Rows returned by the query + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + row_generator = await self.read_rows_stream( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) + return [row async for row in row_generator] + + async def read_row( + self, + row_key: str | bytes, + *, + row_filter: RowFilter | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> Row | None: + """ + Read a single row from the table, based on the specified key. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Args: + - query: contains details about which rows to return + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. 
+ Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. + Returns: + - a Row object if the row exists, otherwise None + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + if row_key is None: + raise ValueError("row_key must be string or bytes") + query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1) + results = await self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) + if len(results) == 0: + return None + return results[0] + + async def read_rows_sharded( + self, + sharded_query: ShardedQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> list[Row]: + """ + Runs a sharded query in parallel, then return the results in a single list. + Results will be returned in the order of the input queries. + + This function is intended to be run on the results on a query.shard() call: + + ``` + table_shard_keys = await table.sample_row_keys() + query = ReadRowsQuery(...) + shard_queries = query.shard(table_shard_keys) + results = await table.read_rows_sharded(shard_queries) + ``` + + Args: + - sharded_query: a sharded query to execute + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. 
+ Raises: + - ShardedReadRowsExceptionGroup: if any of the queries failed + - ValueError: if the query_list is empty + """ + if not sharded_query: + raise ValueError("empty sharded_query") + # reduce operation_timeout between batches + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + timeout_generator = _attempt_timeout_generator( + operation_timeout, operation_timeout + ) + # submit shards in batches if the number of shards goes over _CONCURRENCY_LIMIT + batched_queries = [ + sharded_query[i : i + _CONCURRENCY_LIMIT] + for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT) + ] + # run batches and collect results + results_list = [] + error_dict = {} + shard_idx = 0 + for batch in batched_queries: + batch_operation_timeout = next(timeout_generator) + routine_list = [ + self.read_rows( + query, + operation_timeout=batch_operation_timeout, + attempt_timeout=min(attempt_timeout, batch_operation_timeout), + retryable_errors=retryable_errors, + ) + for query in batch + ] + batch_result = await asyncio.gather(*routine_list, return_exceptions=True) + for result in batch_result: + if isinstance(result, Exception): + error_dict[shard_idx] = result + elif isinstance(result, BaseException): + # BaseException not expected; raise immediately + raise result + else: + results_list.extend(result) + shard_idx += 1 + if error_dict: + # if any sub-request failed, raise an exception instead of returning results + raise ShardedReadRowsExceptionGroup( + [ + FailedQueryShardError(idx, sharded_query[idx], e) + for idx, e in error_dict.items() + ], + results_list, + len(sharded_query), + ) + return results_list + + async def row_exists( + self, + row_key: str | bytes, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> bool: + """ + Return a boolean indicating whether the specified row exists in the table. + uses the filters: chain(limit cells per row = 1, strip value) + + Args: + - row_key: the key of the row to check + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. 
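Reviewer note: a hedged sketch of the sharded flow implemented above, including handling of the exception group raised when some shards fail. The `table` object is an assumption, and an argument-free `ReadRowsQuery` is assumed to cover the whole table.

```python
# Usage sketch only: `table` is assumed to be a TableAsync.
from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup


async def full_table_scan(table):
    shard_keys = await table.sample_row_keys()
    shard_queries = ReadRowsQuery().shard(shard_keys)
    try:
        return await table.read_rows_sharded(shard_queries)
    except ShardedReadRowsExceptionGroup as group:
        # rows from the shards that did succeed are still available
        print(f"{len(group.exceptions)} shard(s) failed")
        return group.successful_rows
```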
+ Returns: + - a bool indicating whether the row exists + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + if row_key is None: + raise ValueError("row_key must be string or bytes") + + strip_filter = StripValueTransformerFilter(flag=True) + limit_filter = CellsRowLimitFilter(1) + chain_filter = RowFilterChain(filters=[limit_filter, strip_filter]) + query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter) + results = await self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) + return len(results) > 0 + + async def sample_row_keys( + self, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> RowKeySamples: + """ + Return a set of RowKeySamples that delimit contiguous sections of the table of + approximately equal size + + RowKeySamples output can be used with ReadRowsQuery.shard() to create a sharded query that + can be parallelized across multiple backend nodes read_rows and read_rows_stream + requests will call sample_row_keys internally for this purpose when sharding is enabled + + RowKeySamples is simply a type alias for list[tuple[bytes, int]]; a list of + row_keys, along with offset positions in the table + + Args: + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget.i + Defaults to the Table's default_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_retryable_errors. 
+ Returns: + - a set of RowKeySamples the delimit contiguous sections of the table + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + # prepare timeouts + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + attempt_timeout_gen = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + # prepare retryable + retryable_excs = _get_retryable_errors(retryable_errors, self) + predicate = retries.if_exception_type(*retryable_excs) + + sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + + # prepare request + metadata = _make_metadata(self.table_name, self.app_profile_id) + + async def execute_rpc(): + results = await self.client._gapic_client.sample_row_keys( + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=next(attempt_timeout_gen), + metadata=metadata, + retry=None, + ) + return [(s.row_key, s.offset_bytes) async for s in results] + + return await retries.retry_target_async( + execute_rpc, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) + + def mutations_batcher( + self, + *, + flush_interval: float | None = 5, + flush_limit_mutation_count: int | None = 1000, + flush_limit_bytes: int = 20 * _MB_SIZE, + flow_control_max_mutation_count: int = 100_000, + flow_control_max_bytes: int = 100 * _MB_SIZE, + batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + ) -> MutationsBatcherAsync: + """ + Returns a new mutations batcher instance. + + Can be used to iteratively add mutations that are flushed as a group, + to avoid excess network calls + + Args: + - flush_interval: Automatically flush every flush_interval seconds. If None, + a table default will be used + - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count + mutations are added across all entries. If None, this limit is ignored. + - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. + - flow_control_max_mutation_count: Maximum number of inflight mutations. + - flow_control_max_bytes: Maximum number of inflight bytes. + - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. + Defaults to the Table's default_mutate_rows_operation_timeout + - batch_attempt_timeout: timeout for each individual request, in seconds. + Defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to batch_operation_timeout. + - batch_retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors. 
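Reviewer note: since `RowKeySamples` is just `list[tuple[bytes, int]]`, the result of the call implemented above can be consumed directly; a small sketch, with the `table` object assumed.

```python
# Usage sketch only: `table` is assumed to be a TableAsync.
async def show_table_sections(table):
    samples = await table.sample_row_keys(operation_timeout=20)
    for row_key, offset_bytes in samples:
        # each sample delimits a contiguous section of the table
        print(f"section boundary {row_key!r} at ~{offset_bytes} bytes")
```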
+ Returns: + - a MutationsBatcherAsync context manager that can batch requests + """ + return MutationsBatcherAsync( + self, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_mutation_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=batch_operation_timeout, + batch_attempt_timeout=batch_attempt_timeout, + batch_retryable_errors=batch_retryable_errors, + ) + + async def mutate_row( + self, + row_key: str | bytes, + mutations: list[Mutation] | Mutation, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ): + """ + Mutates a row atomically. + + Cells already present in the row are left unchanged unless explicitly changed + by ``mutation``. + + Idempotent operations (i.e, all mutations have an explicit timestamp) will be + retried on server failure. Non-idempotent operations will not. + + Args: + - row_key: the row to apply mutations to + - mutations: the set of mutations to apply to the row + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Only idempotent mutations will be retried. Defaults to the Table's + default_retryable_errors. + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing all + GoogleAPIError exceptions from any retries that failed + - GoogleAPIError: raised on non-idempotent operations that cannot be + safely retried. 
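Reviewer note: a hedged sketch of using the batcher returned by this factory. Entries are added with `append` and are flushed automatically on the limits listed above and when the context manager exits; `table` and the column family/qualifier names are assumptions.

```python
# Usage sketch only: `table` is assumed to be a TableAsync.
from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell


async def buffered_writes(table, values: dict[bytes, bytes]):
    async with table.mutations_batcher(flush_interval=2) as batcher:
        for row_key, value in values.items():
            entry = RowMutationEntry(row_key, SetCell("family", b"qualifier", value))
            await batcher.append(entry)
    # exiting the context manager flushes remaining entries and raises a
    # MutationsExceptionGroup if any of them failed
```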
+ - ValueError if invalid arguments are provided + """ + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + + if not mutations: + raise ValueError("No mutations provided") + mutations_list = mutations if isinstance(mutations, list) else [mutations] + + if all(mutation.is_idempotent() for mutation in mutations_list): + # mutations are all idempotent and safe to retry + predicate = retries.if_exception_type( + *_get_retryable_errors(retryable_errors, self) + ) + else: + # mutations should not be retried + predicate = retries.if_exception_type() + + sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + + target = partial( + self.client._gapic_client.mutate_row, + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + mutations=[mutation._to_pb() for mutation in mutations_list], + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=attempt_timeout, + metadata=_make_metadata(self.table_name, self.app_profile_id), + retry=None, + ) + return await retries.retry_target_async( + target, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) + + async def bulk_mutate_rows( + self, + mutation_entries: list[RowMutationEntry], + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + ): + """ + Applies mutations for multiple rows in a single batched request. + + Each individual RowMutationEntry is applied atomically, but separate entries + may be applied in arbitrary order (even for entries targetting the same row) + In total, the row_mutations can contain at most 100000 individual mutations + across all entries + + Idempotent entries (i.e., entries with mutations with explicit timestamps) + will be retried on failure. Non-idempotent will not, and will reported in a + raised exception group + + Args: + - mutation_entries: the batches of mutations to apply + Each entry will be applied atomically, but entries will be applied + in arbitrary order + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_mutate_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. 
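Reviewer note: the retry predicate above is all-or-nothing, so a single server-assigned timestamp makes the whole call non-retryable. A sketch of the two cases; the `table` object, family, and qualifier are assumptions.

```python
# Usage sketch only: `table` is assumed to be a TableAsync.
import time

from google.cloud.bigtable.data.mutations import SetCell


async def write_value(table, row_key: bytes, value: bytes):
    # explicit timestamp -> idempotent, so transient failures may be retried
    safe = SetCell(
        "family", b"qualifier", value, timestamp_micros=time.time_ns() // 1000
    )
    await table.mutate_row(row_key, safe)

    # timestamp_micros=-1 requests a server-assigned timestamp, which makes the
    # mutation non-idempotent; mutate_row sends it without a retry predicate
    unsafe = SetCell("family", b"qualifier", value, timestamp_micros=-1)
    await table.mutate_row(row_key, unsafe)
```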
+ Defaults to the Table's default_mutate_rows_retryable_errors + Raises: + - MutationsExceptionGroup if one or more mutations fails + Contains details about any failed entries in .exceptions + - ValueError if invalid arguments are provided + """ + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + retryable_excs = _get_retryable_errors(retryable_errors, self) + + operation = _MutateRowsOperationAsync( + self.client._gapic_client, + self, + mutation_entries, + operation_timeout, + attempt_timeout, + retryable_exceptions=retryable_excs, + ) + await operation.start() + + async def check_and_mutate_row( + self, + row_key: str | bytes, + predicate: RowFilter | None, + *, + true_case_mutations: Mutation | list[Mutation] | None = None, + false_case_mutations: Mutation | list[Mutation] | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> bool: + """ + Mutates a row atomically based on the output of a predicate filter + + Non-idempotent operation: will not be retried + + Args: + - row_key: the key of the row to mutate + - predicate: the filter to be applied to the contents of the specified row. + Depending on whether or not any results are yielded, + either true_case_mutations or false_case_mutations will be executed. + If None, checks that the row contains any values at all. + - true_case_mutations: + Changes to be atomically applied to the specified row if + predicate yields at least one cell when + applied to row_key. Entries are applied in order, + meaning that earlier mutations can be masked by later + ones. Must contain at least one entry if + false_case_mutations is empty, and at most 100000. + - false_case_mutations: + Changes to be atomically applied to the specified row if + predicate_filter does not yield any cells when + applied to row_key. Entries are applied in order, + meaning that earlier mutations can be masked by later + ones. Must contain at least one entry if + `true_case_mutations is empty, and at most 100000. + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will not be retried. 
Defaults to the Table's default_operation_timeout + Returns: + - bool indicating whether the predicate was true or false + Raises: + - GoogleAPIError exceptions from grpc call + """ + operation_timeout, _ = _get_timeouts(operation_timeout, None, self) + if true_case_mutations is not None and not isinstance( + true_case_mutations, list + ): + true_case_mutations = [true_case_mutations] + true_case_list = [m._to_pb() for m in true_case_mutations or []] + if false_case_mutations is not None and not isinstance( + false_case_mutations, list + ): + false_case_mutations = [false_case_mutations] + false_case_list = [m._to_pb() for m in false_case_mutations or []] + metadata = _make_metadata(self.table_name, self.app_profile_id) + result = await self.client._gapic_client.check_and_mutate_row( + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) + return result.predicate_matched + + async def read_modify_write_row( + self, + row_key: str | bytes, + rules: ReadModifyWriteRule | list[ReadModifyWriteRule], + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> Row: + """ + Reads and modifies a row atomically according to input ReadModifyWriteRules, + and returns the contents of all modified cells + + The new value for the timestamp is the greater of the existing timestamp or + the current server time. + + Non-idempotent operation: will not be retried + + Args: + - row_key: the key of the row to apply read/modify/write rules to + - rules: A rule or set of rules to apply to the row. + Rules are applied in order, meaning that earlier rules will affect the + results of later ones. + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will not be retried. + Defaults to the Table's default_operation_timeout. + Returns: + - Row: containing cell data that was modified as part of the + operation + Raises: + - GoogleAPIError exceptions from grpc call + - ValueError if invalid arguments are provided + """ + operation_timeout, _ = _get_timeouts(operation_timeout, None, self) + if operation_timeout <= 0: + raise ValueError("operation_timeout must be greater than 0") + if rules is not None and not isinstance(rules, list): + rules = [rules] + if not rules: + raise ValueError("rules must contain at least one item") + metadata = _make_metadata(self.table_name, self.app_profile_id) + result = await self.client._gapic_client.read_modify_write_row( + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) + # construct Row from result + return Row._from_pb(result.row) + + async def close(self): + """ + Called to close the Table instance and release any resources held by it. 
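Reviewer note: a short combined sketch of the two non-retried, row-level operations implemented above. The predicate filter and read/modify/write rule are passed in by the caller rather than constructed here, since the filter and rule classes are defined outside this excerpt; `table` is likewise an assumption.

```python
# Usage sketch only: `table` is assumed to be a TableAsync; `predicate` is any
# RowFilter and `increment_rule` any ReadModifyWriteRule built by the caller.
from google.cloud.bigtable.data.mutations import SetCell


async def conditional_update(table, row_key: bytes, predicate, increment_rule):
    # apply one of two mutations depending on whether the predicate matched
    matched = await table.check_and_mutate_row(
        row_key,
        predicate,
        true_case_mutations=SetCell("family", b"seen", b"1"),
        false_case_mutations=SetCell("family", b"seen", b"0"),
    )

    # atomically read-modify-write the row and get the modified cells back
    modified_row = await table.read_modify_write_row(row_key, increment_rule)
    return matched, modified_row
```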
+ """ + self._register_instance_task.cancel() + await self.client._remove_instance_registration(self.instance_id, self) + + async def __aenter__(self): + """ + Implement async context manager protocol + + Ensure registration task has time to run, so that + grpc channels will be warmed for the specified instance + """ + await self._register_instance_task + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """ + Implement async context manager protocol + + Unregister this instance with the client, so that + grpc channels will no longer be warmed + """ + await self.close() diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py new file mode 100644 index 000000000..5d5dd535e --- /dev/null +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -0,0 +1,501 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import Any, Sequence, TYPE_CHECKING +import asyncio +import atexit +import warnings +from collections import deque + +from google.cloud.bigtable.data.mutations import RowMutationEntry +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.cloud.bigtable.data.exceptions import FailedMutationEntryError +from google.cloud.bigtable.data._helpers import _get_retryable_errors +from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT + +from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync +from google.cloud.bigtable.data._async._mutate_rows import ( + _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, +) +from google.cloud.bigtable.data.mutations import Mutation + +if TYPE_CHECKING: + from google.cloud.bigtable.data._async.client import TableAsync + +# used to make more readable default values +_MB_SIZE = 1024 * 1024 + + +class _FlowControlAsync: + """ + Manages flow control for batched mutations. Mutations are registered against + the FlowControl object before being sent, which will block if size or count + limits have reached capacity. As mutations completed, they are removed from + the FlowControl object, which will notify any blocked requests that there + is additional capacity. + + Flow limits are not hard limits. If a single mutation exceeds the configured + limits, it will be allowed as a single batch when the capacity is available. + """ + + def __init__( + self, + max_mutation_count: int, + max_mutation_bytes: int, + ): + """ + Args: + - max_mutation_count: maximum number of mutations to send in a single rpc. + This corresponds to individual mutations in a single RowMutationEntry. + - max_mutation_bytes: maximum number of bytes to send in a single rpc. 
+ """ + self._max_mutation_count = max_mutation_count + self._max_mutation_bytes = max_mutation_bytes + if self._max_mutation_count < 1: + raise ValueError("max_mutation_count must be greater than 0") + if self._max_mutation_bytes < 1: + raise ValueError("max_mutation_bytes must be greater than 0") + self._capacity_condition = asyncio.Condition() + self._in_flight_mutation_count = 0 + self._in_flight_mutation_bytes = 0 + + def _has_capacity(self, additional_count: int, additional_size: int) -> bool: + """ + Checks if there is capacity to send a new entry with the given size and count + + FlowControl limits are not hard limits. If a single mutation exceeds + the configured flow limits, it will be sent in a single batch when + previous batches have completed. + + Args: + - additional_count: number of mutations in the pending entry + - additional_size: size of the pending entry + Returns: + - True if there is capacity to send the pending entry, False otherwise + """ + # adjust limits to allow overly large mutations + acceptable_size = max(self._max_mutation_bytes, additional_size) + acceptable_count = max(self._max_mutation_count, additional_count) + # check if we have capacity for new mutation + new_size = self._in_flight_mutation_bytes + additional_size + new_count = self._in_flight_mutation_count + additional_count + return new_size <= acceptable_size and new_count <= acceptable_count + + async def remove_from_flow( + self, mutations: RowMutationEntry | list[RowMutationEntry] + ) -> None: + """ + Removes mutations from flow control. This method should be called once + for each mutation that was sent to add_to_flow, after the corresponding + operation is complete. + + Args: + - mutations: mutation or list of mutations to remove from flow control + """ + if not isinstance(mutations, list): + mutations = [mutations] + total_count = sum(len(entry.mutations) for entry in mutations) + total_size = sum(entry.size() for entry in mutations) + self._in_flight_mutation_count -= total_count + self._in_flight_mutation_bytes -= total_size + # notify any blocked requests that there is additional capacity + async with self._capacity_condition: + self._capacity_condition.notify_all() + + async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): + """ + Generator function that registers mutations with flow control. As mutations + are accepted into the flow control, they are yielded back to the caller, + to be sent in a batch. If the flow control is at capacity, the generator + will block until there is capacity available. + + Args: + - mutations: list mutations to break up into batches + Yields: + - list of mutations that have reserved space in the flow control. + Each batch contains at least one mutation. 
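Reviewer note: the contract between `add_to_flow` and `remove_from_flow` is easiest to see in a stripped-down driver loop. The batcher's real `_flush_internal` (further below) runs batches as concurrent background tasks; this sequential sketch only illustrates the reserve/send/release pattern.

```python
# Simplified illustration (not the batcher's actual code): reserve capacity for
# a batch, send it, then release the capacity so blocked callers can proceed.
async def drive_flow_control(flow, entries, send_batch):
    async for batch in flow.add_to_flow(entries):
        try:
            await send_batch(batch)
        finally:
            await flow.remove_from_flow(batch)
```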
+ """ + if not isinstance(mutations, list): + mutations = [mutations] + start_idx = 0 + end_idx = 0 + while end_idx < len(mutations): + start_idx = end_idx + batch_mutation_count = 0 + # fill up batch until we hit capacity + async with self._capacity_condition: + while end_idx < len(mutations): + next_entry = mutations[end_idx] + next_size = next_entry.size() + next_count = len(next_entry.mutations) + if ( + self._has_capacity(next_count, next_size) + # make sure not to exceed per-request mutation count limits + and (batch_mutation_count + next_count) + <= _MUTATE_ROWS_REQUEST_MUTATION_LIMIT + ): + # room for new mutation; add to batch + end_idx += 1 + batch_mutation_count += next_count + self._in_flight_mutation_bytes += next_size + self._in_flight_mutation_count += next_count + elif start_idx != end_idx: + # we have at least one mutation in the batch, so send it + break + else: + # batch is empty. Block until we have capacity + await self._capacity_condition.wait_for( + lambda: self._has_capacity(next_count, next_size) + ) + yield mutations[start_idx:end_idx] + + +class MutationsBatcherAsync: + """ + Allows users to send batches using context manager API: + + Runs mutate_row, mutate_rows, and check_and_mutate_row internally, combining + to use as few network requests as required + + Flushes: + - every flush_interval seconds + - after queue reaches flush_count in quantity + - after queue reaches flush_size_bytes in storage size + - when batcher is closed or destroyed + + async with table.mutations_batcher() as batcher: + for i in range(10): + batcher.add(row, mut) + """ + + def __init__( + self, + table: "TableAsync", + *, + flush_interval: float | None = 5, + flush_limit_mutation_count: int | None = 1000, + flush_limit_bytes: int = 20 * _MB_SIZE, + flow_control_max_mutation_count: int = 100_000, + flow_control_max_bytes: int = 100 * _MB_SIZE, + batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + ): + """ + Args: + - table: Table to preform rpc calls + - flush_interval: Automatically flush every flush_interval seconds. + If None, no time-based flushing is performed. + - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count + mutations are added across all entries. If None, this limit is ignored. + - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. + - flow_control_max_mutation_count: Maximum number of inflight mutations. + - flow_control_max_bytes: Maximum number of inflight bytes. + - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. + If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_operation_timeout. + - batch_attempt_timeout: timeout for each individual request, in seconds. + If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to batch_operation_timeout. + - batch_retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors. 
+ """ + self._operation_timeout, self._attempt_timeout = _get_timeouts( + batch_operation_timeout, batch_attempt_timeout, table + ) + self._retryable_errors: list[type[Exception]] = _get_retryable_errors( + batch_retryable_errors, table + ) + + self.closed: bool = False + self._table = table + self._staged_entries: list[RowMutationEntry] = [] + self._staged_count, self._staged_bytes = 0, 0 + self._flow_control = _FlowControlAsync( + flow_control_max_mutation_count, flow_control_max_bytes + ) + self._flush_limit_bytes = flush_limit_bytes + self._flush_limit_count = ( + flush_limit_mutation_count + if flush_limit_mutation_count is not None + else float("inf") + ) + self._flush_timer = self._start_flush_timer(flush_interval) + self._flush_jobs: set[asyncio.Future[None]] = set() + # MutationExceptionGroup reports number of successful entries along with failures + self._entries_processed_since_last_raise: int = 0 + self._exceptions_since_last_raise: int = 0 + # keep track of the first and last _exception_list_limit exceptions + self._exception_list_limit: int = 10 + self._oldest_exceptions: list[Exception] = [] + self._newest_exceptions: deque[Exception] = deque( + maxlen=self._exception_list_limit + ) + # clean up on program exit + atexit.register(self._on_exit) + + def _start_flush_timer(self, interval: float | None) -> asyncio.Future[None]: + """ + Set up a background task to flush the batcher every interval seconds + + If interval is None, an empty future is returned + + Args: + - flush_interval: Automatically flush every flush_interval seconds. + If None, no time-based flushing is performed. + Returns: + - asyncio.Future that represents the background task + """ + if interval is None or self.closed: + empty_future: asyncio.Future[None] = asyncio.Future() + empty_future.set_result(None) + return empty_future + + async def timer_routine(self, interval: float): + """ + Triggers new flush tasks every `interval` seconds + """ + while not self.closed: + await asyncio.sleep(interval) + # add new flush task to list + if not self.closed and self._staged_entries: + self._schedule_flush() + + timer_task = asyncio.create_task(timer_routine(self, interval)) + return timer_task + + async def append(self, mutation_entry: RowMutationEntry): + """ + Add a new set of mutations to the internal queue + + TODO: return a future to track completion of this entry + + Args: + - mutation_entry: new entry to add to flush queue + Raises: + - RuntimeError if batcher is closed + - ValueError if an invalid mutation type is added + """ + if self.closed: + raise RuntimeError("Cannot append to closed MutationsBatcher") + if isinstance(mutation_entry, Mutation): # type: ignore + raise ValueError( + f"invalid mutation type: {type(mutation_entry).__name__}. 
Only RowMutationEntry objects are supported by batcher" + ) + self._staged_entries.append(mutation_entry) + # start a new flush task if limits exceeded + self._staged_count += len(mutation_entry.mutations) + self._staged_bytes += mutation_entry.size() + if ( + self._staged_count >= self._flush_limit_count + or self._staged_bytes >= self._flush_limit_bytes + ): + self._schedule_flush() + # yield to the event loop to allow flush to run + await asyncio.sleep(0) + + def _schedule_flush(self) -> asyncio.Future[None] | None: + """Update the flush task to include the latest staged entries""" + if self._staged_entries: + entries, self._staged_entries = self._staged_entries, [] + self._staged_count, self._staged_bytes = 0, 0 + new_task = self._create_bg_task(self._flush_internal, entries) + new_task.add_done_callback(self._flush_jobs.remove) + self._flush_jobs.add(new_task) + return new_task + return None + + async def _flush_internal(self, new_entries: list[RowMutationEntry]): + """ + Flushes a set of mutations to the server, and updates internal state + + Args: + - new_entries: list of RowMutationEntry objects to flush + """ + # flush new entries + in_process_requests: list[asyncio.Future[list[FailedMutationEntryError]]] = [] + async for batch in self._flow_control.add_to_flow(new_entries): + batch_task = self._create_bg_task(self._execute_mutate_rows, batch) + in_process_requests.append(batch_task) + # wait for all inflight requests to complete + found_exceptions = await self._wait_for_batch_results(*in_process_requests) + # update exception data to reflect any new errors + self._entries_processed_since_last_raise += len(new_entries) + self._add_exceptions(found_exceptions) + + async def _execute_mutate_rows( + self, batch: list[RowMutationEntry] + ) -> list[FailedMutationEntryError]: + """ + Helper to execute mutation operation on a batch + + Args: + - batch: list of RowMutationEntry objects to send to server + - timeout: timeout in seconds. Used as operation_timeout and attempt_timeout. + If not given, will use table defaults + Returns: + - list of FailedMutationEntryError objects for mutations that failed. + FailedMutationEntryError objects will not contain index information + """ + try: + operation = _MutateRowsOperationAsync( + self._table.client._gapic_client, + self._table, + batch, + operation_timeout=self._operation_timeout, + attempt_timeout=self._attempt_timeout, + retryable_exceptions=self._retryable_errors, + ) + await operation.start() + except MutationsExceptionGroup as e: + # strip index information from exceptions, since it is not useful in a batch context + for subexc in e.exceptions: + subexc.index = None + return list(e.exceptions) + finally: + # mark batch as complete in flow control + await self._flow_control.remove_from_flow(batch) + return [] + + def _add_exceptions(self, excs: list[Exception]): + """ + Add new list of exceptions to internal store. To avoid unbounded memory, + the batcher will store the first and last _exception_list_limit exceptions, + and discard any in between. 
+ """ + self._exceptions_since_last_raise += len(excs) + if excs and len(self._oldest_exceptions) < self._exception_list_limit: + # populate oldest_exceptions with found_exceptions + addition_count = self._exception_list_limit - len(self._oldest_exceptions) + self._oldest_exceptions.extend(excs[:addition_count]) + excs = excs[addition_count:] + if excs: + # populate newest_exceptions with remaining found_exceptions + self._newest_exceptions.extend(excs[-self._exception_list_limit :]) + + def _raise_exceptions(self): + """ + Raise any unreported exceptions from background flush operations + + Raises: + - MutationsExceptionGroup with all unreported exceptions + """ + if self._oldest_exceptions or self._newest_exceptions: + oldest, self._oldest_exceptions = self._oldest_exceptions, [] + newest = list(self._newest_exceptions) + self._newest_exceptions.clear() + entry_count, self._entries_processed_since_last_raise = ( + self._entries_processed_since_last_raise, + 0, + ) + exc_count, self._exceptions_since_last_raise = ( + self._exceptions_since_last_raise, + 0, + ) + raise MutationsExceptionGroup.from_truncated_lists( + first_list=oldest, + last_list=newest, + total_excs=exc_count, + entry_count=entry_count, + ) + + async def __aenter__(self): + """For context manager API""" + return self + + async def __aexit__(self, exc_type, exc, tb): + """For context manager API""" + await self.close() + + async def close(self): + """ + Flush queue and clean up resources + """ + self.closed = True + self._flush_timer.cancel() + self._schedule_flush() + if self._flush_jobs: + await asyncio.gather(*self._flush_jobs, return_exceptions=True) + try: + await self._flush_timer + except asyncio.CancelledError: + pass + atexit.unregister(self._on_exit) + # raise unreported exceptions + self._raise_exceptions() + + def _on_exit(self): + """ + Called when program is exited. Raises warning if unflushed mutations remain + """ + if not self.closed and self._staged_entries: + warnings.warn( + f"MutationsBatcher for table {self._table.table_name} was not closed. " + f"{len(self._staged_entries)} Unflushed mutations will not be sent to the server." + ) + + @staticmethod + def _create_bg_task(func, *args, **kwargs) -> asyncio.Future[Any]: + """ + Create a new background task, and return a future + + This method wraps asyncio to make it easier to maintain subclasses + with different concurrency models. + + Args: + - func: function to execute in background task + - *args: positional arguments to pass to func + - **kwargs: keyword arguments to pass to func + Returns: + - Future object representing the background task + """ + return asyncio.create_task(func(*args, **kwargs)) + + @staticmethod + async def _wait_for_batch_results( + *tasks: asyncio.Future[list[FailedMutationEntryError]] | asyncio.Future[None], + ) -> list[Exception]: + """ + Takes in a list of futures representing _execute_mutate_rows tasks, + waits for them to complete, and returns a list of errors encountered. + + Args: + - *tasks: futures representing _execute_mutate_rows or _flush_internal tasks + Returns: + - list of Exceptions encountered by any of the tasks. Errors are expected + to be FailedMutationEntryError, representing a failed mutation operation. + If a task fails with a different exception, it will be included in the + output list. Successful tasks will not be represented in the output list. 
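Reviewer note: the first/last retention policy used by `_add_exceptions` and later reported through `from_truncated_lists` can be shown in isolation. This is a standalone restatement for illustration, not the batcher code itself.

```python
# Standalone illustration of keeping only the first and last N exceptions.
from collections import deque


class ExceptionWindow:
    def __init__(self, limit: int = 10):
        self.limit = limit
        self.oldest: list[Exception] = []
        self.newest: deque[Exception] = deque(maxlen=limit)
        self.total = 0

    def add(self, excs: list[Exception]) -> None:
        self.total += len(excs)
        room = self.limit - len(self.oldest)
        self.oldest.extend(excs[:room])
        # the bounded deque silently drops anything between the two windows
        self.newest.extend(excs[room:])
```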
+ """ + if not tasks: + return [] + all_results = await asyncio.gather(*tasks, return_exceptions=True) + found_errors = [] + for result in all_results: + if isinstance(result, Exception): + # will receive direct Exception objects if request task fails + found_errors.append(result) + elif isinstance(result, BaseException): + # BaseException not expected from grpc calls. Raise immediately + raise result + elif result: + # completed requests will return a list of FailedMutationEntryError + for e in result: + # strip index information + e.index = None + found_errors.extend(result) + return found_errors diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py new file mode 100644 index 000000000..a0b13cbaf --- /dev/null +++ b/google/cloud/bigtable/data/_helpers.py @@ -0,0 +1,220 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Helper functions used in various places in the library. +""" +from __future__ import annotations + +from typing import Sequence, List, Tuple, TYPE_CHECKING +import time +import enum +from collections import namedtuple +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + +from google.api_core import exceptions as core_exceptions +from google.api_core.retry import RetryFailureReason +from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + +if TYPE_CHECKING: + import grpc + from google.cloud.bigtable.data import TableAsync + +""" +Helper functions used in various places in the library. +""" + +# Type alias for the output of sample_keys +RowKeySamples = List[Tuple[bytes, int]] + +# type alias for the output of query.shard() +ShardedQuery = List[ReadRowsQuery] + +# used by read_rows_sharded to limit how many requests are attempted in parallel +_CONCURRENCY_LIMIT = 10 + +# used to register instance data with the client for channel warming +_WarmedInstanceKey = namedtuple( + "_WarmedInstanceKey", ["instance_name", "table_name", "app_profile_id"] +) + + +# enum used on method calls when table defaults should be used +class TABLE_DEFAULT(enum.Enum): + # default for mutate_row, sample_row_keys, check_and_mutate_row, and read_modify_write_row + DEFAULT = "DEFAULT" + # default for read_rows, read_rows_stream, read_rows_sharded, row_exists, and read_row + READ_ROWS = "READ_ROWS_DEFAULT" + # default for bulk_mutate_rows and mutations_batcher + MUTATE_ROWS = "MUTATE_ROWS_DEFAULT" + + +def _make_metadata( + table_name: str, app_profile_id: str | None +) -> list[tuple[str, str]]: + """ + Create properly formatted gRPC metadata for requests. + """ + params = [] + params.append(f"table_name={table_name}") + if app_profile_id is not None: + params.append(f"app_profile_id={app_profile_id}") + params_str = "&".join(params) + return [("x-goog-request-params", params_str)] + + +def _attempt_timeout_generator( + per_request_timeout: float | None, operation_timeout: float +): + """ + Generator that yields the timeout value for each attempt of a retry loop. 
+
+    Will return per_request_timeout until the operation_timeout is approached,
+    at which point it will return the remaining time in the operation_timeout.
+
+    Args:
+        - per_request_timeout: The timeout value to use for each request, in seconds.
+            If None, the operation_timeout will be used for each request.
+        - operation_timeout: The timeout value to use for the entire operation, in seconds.
+    Yields:
+        - The timeout value to use for the next request, in seconds
+    """
+    per_request_timeout = (
+        per_request_timeout if per_request_timeout is not None else operation_timeout
+    )
+    deadline = operation_timeout + time.monotonic()
+    while True:
+        yield max(0, min(per_request_timeout, deadline - time.monotonic()))
+
+
+def _retry_exception_factory(
+    exc_list: list[Exception],
+    reason: RetryFailureReason,
+    timeout_val: float | None,
+) -> tuple[Exception, Exception | None]:
+    """
+    Build retry error based on exceptions encountered during operation
+
+    Args:
+        - exc_list: list of exceptions encountered during operation
+        - reason: the RetryFailureReason describing why the operation failed
+        - timeout_val: the operation timeout value in seconds, for constructing
+            the error message
+    Returns:
+        - tuple of the exception to raise, and a cause exception if applicable
+    """
+    if reason == RetryFailureReason.TIMEOUT:
+        timeout_val_str = f"of {timeout_val:0.1f}s " if timeout_val is not None else ""
+        # if failed due to timeout, raise deadline exceeded as primary exception
+        source_exc: Exception = core_exceptions.DeadlineExceeded(
+            f"operation_timeout{timeout_val_str} exceeded"
+        )
+    elif exc_list:
+        # otherwise, raise non-retryable error as primary exception
+        source_exc = exc_list.pop()
+    else:
+        source_exc = RuntimeError("failed with unspecified exception")
+    # use the retry exception group as the cause of the exception
+    cause_exc: Exception | None = RetryExceptionGroup(exc_list) if exc_list else None
+    source_exc.__cause__ = cause_exc
+    return source_exc, cause_exc
+
+
+def _get_timeouts(
+    operation: float | TABLE_DEFAULT,
+    attempt: float | None | TABLE_DEFAULT,
+    table: "TableAsync",
+) -> tuple[float, float]:
+    """
+    Convert passed in timeout values to floats, using table defaults if necessary.
+
+    attempt will use operation value if None, or if larger than operation.
+
+    Will call _validate_timeouts on the outputs, and raise ValueError if the
+    resulting timeouts are invalid.
+
+    Args:
+        - operation: The timeout value to use for the entire operation, in seconds.
+        - attempt: The timeout value to use for each attempt, in seconds.
+        - table: The table to use for default values.
+ Returns: + - A tuple of (operation_timeout, attempt_timeout) + """ + # load table defaults if necessary + if operation == TABLE_DEFAULT.DEFAULT: + final_operation = table.default_operation_timeout + elif operation == TABLE_DEFAULT.READ_ROWS: + final_operation = table.default_read_rows_operation_timeout + elif operation == TABLE_DEFAULT.MUTATE_ROWS: + final_operation = table.default_mutate_rows_operation_timeout + else: + final_operation = operation + if attempt == TABLE_DEFAULT.DEFAULT: + attempt = table.default_attempt_timeout + elif attempt == TABLE_DEFAULT.READ_ROWS: + attempt = table.default_read_rows_attempt_timeout + elif attempt == TABLE_DEFAULT.MUTATE_ROWS: + attempt = table.default_mutate_rows_attempt_timeout + + if attempt is None: + # no timeout specified, use operation timeout for both + final_attempt = final_operation + else: + # cap attempt timeout at operation timeout + final_attempt = min(attempt, final_operation) if final_operation else attempt + + _validate_timeouts(final_operation, final_attempt, allow_none=False) + return final_operation, final_attempt + + +def _validate_timeouts( + operation_timeout: float, attempt_timeout: float | None, allow_none: bool = False +): + """ + Helper function that will verify that timeout values are valid, and raise + an exception if they are not. + + Args: + - operation_timeout: The timeout value to use for the entire operation, in seconds. + - attempt_timeout: The timeout value to use for each attempt, in seconds. + - allow_none: If True, attempt_timeout can be None. If False, None values will raise an exception. + Raises: + - ValueError if operation_timeout or attempt_timeout are invalid. + """ + if operation_timeout is None: + raise ValueError("operation_timeout cannot be None") + if operation_timeout <= 0: + raise ValueError("operation_timeout must be greater than 0") + if not allow_none and attempt_timeout is None: + raise ValueError("attempt_timeout must not be None") + elif attempt_timeout is not None: + if attempt_timeout <= 0: + raise ValueError("attempt_timeout must be greater than 0") + + +def _get_retryable_errors( + call_codes: Sequence["grpc.StatusCode" | int | type[Exception]] | TABLE_DEFAULT, + table: "TableAsync", +) -> list[type[Exception]]: + # load table defaults if necessary + if call_codes == TABLE_DEFAULT.DEFAULT: + call_codes = table.default_retryable_errors + elif call_codes == TABLE_DEFAULT.READ_ROWS: + call_codes = table.default_read_rows_retryable_errors + elif call_codes == TABLE_DEFAULT.MUTATE_ROWS: + call_codes = table.default_mutate_rows_retryable_errors + + return [ + e if isinstance(e, type) else type(core_exceptions.from_grpc_status(e, "")) + for e in call_codes + ] diff --git a/google/cloud/bigtable/data/exceptions.py b/google/cloud/bigtable/data/exceptions.py new file mode 100644 index 000000000..3c73ec4e9 --- /dev/null +++ b/google/cloud/bigtable/data/exceptions.py @@ -0,0 +1,307 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
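Reviewer note: taken together, `_get_timeouts`, `_validate_timeouts`, and `_attempt_timeout_generator` give each attempt its per-attempt budget, capped by whatever remains of the overall operation budget. A standalone sketch of that shrinking behaviour, restated for illustration only.

```python
# Standalone illustration of the shrinking per-attempt timeout.
import itertools
import time


def attempt_timeouts(per_attempt: float, operation: float):
    deadline = time.monotonic() + operation
    while True:
        yield max(0.0, min(per_attempt, deadline - time.monotonic()))


# early attempts get the full 5s; once less than 5s of the 12s operation budget
# remains, the yielded value shrinks toward 0 rather than overshooting it
for timeout in itertools.islice(attempt_timeouts(5.0, 12.0), 3):
    print(f"next attempt may run for up to {timeout:.1f}s")
```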
+# +from __future__ import annotations + +import sys + +from typing import Any, TYPE_CHECKING + +from google.api_core import exceptions as core_exceptions +from google.cloud.bigtable.data.row import Row + +is_311_plus = sys.version_info >= (3, 11) + +if TYPE_CHECKING: + from google.cloud.bigtable.data.mutations import RowMutationEntry + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + +class InvalidChunk(core_exceptions.GoogleAPICallError): + """Exception raised to invalid chunk data from back-end.""" + + +class _RowSetComplete(Exception): + """ + Internal exception for _ReadRowsOperation + Raised in revise_request_rowset when there are no rows left to process when starting a retry attempt + """ + + pass + + +class _MutateRowsIncomplete(RuntimeError): + """ + Exception raised when a mutate_rows call has unfinished work. + """ + + pass + + +class _BigtableExceptionGroup(ExceptionGroup if is_311_plus else Exception): # type: ignore # noqa: F821 + """ + Represents one or more exceptions that occur during a bulk Bigtable operation + + In Python 3.11+, this is an unmodified exception group. In < 3.10, it is a + custom exception with some exception group functionality backported, but does + Not implement the full API + """ + + def __init__(self, message, excs): + if is_311_plus: + super().__init__(message, excs) + else: + if len(excs) == 0: + raise ValueError("exceptions must be a non-empty sequence") + self.exceptions = tuple(excs) + # simulate an exception group in Python < 3.11 by adding exception info + # to the message + first_line = "--+---------------- 1 ----------------" + last_line = "+------------------------------------" + message_parts = [message + "\n" + first_line] + # print error info for each exception in the group + for idx, e in enumerate(excs[:15]): + # apply index header + if idx != 0: + message_parts.append( + f"+---------------- {str(idx+1).rjust(2)} ----------------" + ) + cause = e.__cause__ + # if this exception was had a cause, print the cause first + # used to display root causes of FailedMutationEntryError and FailedQueryShardError + # format matches the error output of Python 3.11+ + if cause is not None: + message_parts.extend( + f"| {type(cause).__name__}: {cause}".splitlines() + ) + message_parts.append("| ") + message_parts.append( + "| The above exception was the direct cause of the following exception:" + ) + message_parts.append("| ") + # attach error message for this sub-exception + # if the subexception is also a _BigtableExceptionGroup, + # error messages will be nested + message_parts.extend(f"| {type(e).__name__}: {e}".splitlines()) + # truncate the message if there are more than 15 exceptions + if len(excs) > 15: + message_parts.append("+---------------- ... 
---------------") + message_parts.append(f"| and {len(excs) - 15} more") + if last_line not in message_parts[-1]: + # in the case of nested _BigtableExceptionGroups, the last line + # does not need to be added, since one was added by the final sub-exception + message_parts.append(last_line) + super().__init__("\n ".join(message_parts)) + + def __new__(cls, message, excs): + if is_311_plus: + return super().__new__(cls, message, excs) + else: + return super().__new__(cls) + + def __str__(self): + if is_311_plus: + # don't return built-in sub-exception message + return self.args[0] + return super().__str__() + + def __repr__(self): + """ + repr representation should strip out sub-exception details + """ + if is_311_plus: + return super().__repr__() + message = self.args[0].split("\n")[0] + return f"{self.__class__.__name__}({message!r}, {self.exceptions!r})" + + +class MutationsExceptionGroup(_BigtableExceptionGroup): + """ + Represents one or more exceptions that occur during a bulk mutation operation + + Exceptions will typically be of type FailedMutationEntryError, but other exceptions may + be included if they are raised during the mutation operation + """ + + @staticmethod + def _format_message( + excs: list[Exception], total_entries: int, exc_count: int | None = None + ) -> str: + """ + Format a message for the exception group + + Args: + - excs: the exceptions in the group + - total_entries: the total number of entries attempted, successful or not + - exc_count: the number of exceptions associated with the request + if None, this will be len(excs) + """ + exc_count = exc_count if exc_count is not None else len(excs) + entry_str = "entry" if exc_count == 1 else "entries" + return f"{exc_count} failed {entry_str} from {total_entries} attempted." + + def __init__( + self, excs: list[Exception], total_entries: int, message: str | None = None + ): + """ + Args: + - excs: the exceptions in the group + - total_entries: the total number of entries attempted, successful or not + - message: the message for the exception group. If None, a default message + will be generated + """ + message = ( + message + if message is not None + else self._format_message(excs, total_entries) + ) + super().__init__(message, excs) + self.total_entries_attempted = total_entries + + def __new__( + cls, excs: list[Exception], total_entries: int, message: str | None = None + ): + """ + Args: + - excs: the exceptions in the group + - total_entries: the total number of entries attempted, successful or not + - message: the message for the exception group. If None, a default message + """ + message = ( + message if message is not None else cls._format_message(excs, total_entries) + ) + instance = super().__new__(cls, message, excs) + instance.total_entries_attempted = total_entries + return instance + + @classmethod + def from_truncated_lists( + cls, + first_list: list[Exception], + last_list: list[Exception], + total_excs: int, + entry_count: int, + ) -> MutationsExceptionGroup: + """ + Create a MutationsExceptionGroup from two lists of exceptions, representing + a larger set that has been truncated. The MutationsExceptionGroup will + contain the union of the two lists as sub-exceptions, and the error message + describe the number of exceptions that were truncated. 
+ + Args: + - first_list: the set of oldest exceptions to add to the ExceptionGroup + - last_list: the set of newest exceptions to add to the ExceptionGroup + - total_excs: the total number of exceptions associated with the request + Should be len(first_list) + len(last_list) + number of dropped exceptions + in the middle + - entry_count: the total number of entries attempted, successful or not + """ + first_count, last_count = len(first_list), len(last_list) + if first_count + last_count >= total_excs: + # no exceptions were dropped + return cls(first_list + last_list, entry_count) + excs = first_list + last_list + truncation_count = total_excs - (first_count + last_count) + base_message = cls._format_message(excs, entry_count, total_excs) + first_message = f"first {first_count}" if first_count else "" + last_message = f"last {last_count}" if last_count else "" + conjunction = " and " if first_message and last_message else "" + message = f"{base_message} ({first_message}{conjunction}{last_message} attached as sub-exceptions; {truncation_count} truncated)" + return cls(excs, entry_count, message) + + +class FailedMutationEntryError(Exception): + """ + Represents a single failed RowMutationEntry in a bulk_mutate_rows request. + A collection of FailedMutationEntryErrors will be raised in a MutationsExceptionGroup + """ + + def __init__( + self, + failed_idx: int | None, + failed_mutation_entry: "RowMutationEntry", + cause: Exception, + ): + idempotent_msg = ( + "idempotent" if failed_mutation_entry.is_idempotent() else "non-idempotent" + ) + index_msg = f" at index {failed_idx}" if failed_idx is not None else "" + message = f"Failed {idempotent_msg} mutation entry{index_msg}" + super().__init__(message) + self.__cause__ = cause + self.index = failed_idx + self.entry = failed_mutation_entry + + +class RetryExceptionGroup(_BigtableExceptionGroup): + """Represents one or more exceptions that occur during a retryable operation""" + + @staticmethod + def _format_message(excs: list[Exception]): + if len(excs) == 0: + return "No exceptions" + plural = "s" if len(excs) > 1 else "" + return f"{len(excs)} failed attempt{plural}" + + def __init__(self, excs: list[Exception]): + super().__init__(self._format_message(excs), excs) + + def __new__(cls, excs: list[Exception]): + return super().__new__(cls, cls._format_message(excs), excs) + + +class ShardedReadRowsExceptionGroup(_BigtableExceptionGroup): + """ + Represents one or more exceptions that occur during a sharded read rows operation + """ + + @staticmethod + def _format_message(excs: list[FailedQueryShardError], total_queries: int): + query_str = "query" if total_queries == 1 else "queries" + plural_str = "" if len(excs) == 1 else "s" + return f"{len(excs)} sub-exception{plural_str} (from {total_queries} {query_str} attempted)" + + def __init__( + self, + excs: list[FailedQueryShardError], + succeeded: list[Row], + total_queries: int, + ): + super().__init__(self._format_message(excs, total_queries), excs) + self.successful_rows = succeeded + + def __new__( + cls, excs: list[FailedQueryShardError], succeeded: list[Row], total_queries: int + ): + instance = super().__new__(cls, cls._format_message(excs, total_queries), excs) + instance.successful_rows = succeeded + return instance + + +class FailedQueryShardError(Exception): + """ + Represents an individual failed query in a sharded read rows operation + """ + + def __init__( + self, + failed_index: int, + failed_query: "ReadRowsQuery" | dict[str, Any], + cause: Exception, + ): + message = f"Failed 
query at index {failed_index}" + super().__init__(message) + self.__cause__ = cause + self.index = failed_index + self.query = failed_query diff --git a/google/cloud/bigtable/data/mutations.py b/google/cloud/bigtable/data/mutations.py new file mode 100644 index 000000000..b5729d25e --- /dev/null +++ b/google/cloud/bigtable/data/mutations.py @@ -0,0 +1,256 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations +from typing import Any +import time +from dataclasses import dataclass +from abc import ABC, abstractmethod +from sys import getsizeof + +import google.cloud.bigtable_v2.types.bigtable as types_pb +import google.cloud.bigtable_v2.types.data as data_pb + +from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE + + +# special value for SetCell mutation timestamps. If set, server will assign a timestamp +_SERVER_SIDE_TIMESTAMP = -1 + +# mutation entries above this should be rejected +_MUTATE_ROWS_REQUEST_MUTATION_LIMIT = 100_000 + + +class Mutation(ABC): + """Model class for mutations""" + + @abstractmethod + def _to_dict(self) -> dict[str, Any]: + raise NotImplementedError + + def _to_pb(self) -> data_pb.Mutation: + """ + Convert the mutation to protobuf + """ + return data_pb.Mutation(**self._to_dict()) + + def is_idempotent(self) -> bool: + """ + Check if the mutation is idempotent + If false, the mutation will not be retried + """ + return True + + def __str__(self) -> str: + return str(self._to_dict()) + + def size(self) -> int: + """ + Get the size of the mutation in bytes + """ + return getsizeof(self._to_dict()) + + @classmethod + def _from_dict(cls, input_dict: dict[str, Any]) -> Mutation: + instance: Mutation | None = None + try: + if "set_cell" in input_dict: + details = input_dict["set_cell"] + instance = SetCell( + details["family_name"], + details["column_qualifier"], + details["value"], + details["timestamp_micros"], + ) + elif "delete_from_column" in input_dict: + details = input_dict["delete_from_column"] + time_range = details.get("time_range", {}) + start = time_range.get("start_timestamp_micros", None) + end = time_range.get("end_timestamp_micros", None) + instance = DeleteRangeFromColumn( + details["family_name"], details["column_qualifier"], start, end + ) + elif "delete_from_family" in input_dict: + details = input_dict["delete_from_family"] + instance = DeleteAllFromFamily(details["family_name"]) + elif "delete_from_row" in input_dict: + instance = DeleteAllFromRow() + except KeyError as e: + raise ValueError("Invalid mutation dictionary") from e + if instance is None: + raise ValueError("No valid mutation found") + if not issubclass(instance.__class__, cls): + raise ValueError("Mutation type mismatch") + return instance + + +class SetCell(Mutation): + def __init__( + self, + family: str, + qualifier: bytes | str, + new_value: bytes | str | int, + timestamp_micros: int | None = None, + ): + """ + Mutation to set the value of a cell + + Args: + - family: The name of the column 
family to which the new cell belongs. + - qualifier: The column qualifier of the new cell. + - new_value: The value of the new cell. str or int input will be converted to bytes + - timestamp_micros: The timestamp of the new cell. If None, the current timestamp will be used. + Timestamps will be sent with millisecond precision. Extra precision will be truncated. + If -1, the server will assign a timestamp. Note that SetCell mutations with server-side + timestamps are non-idempotent operations and will not be retried. + """ + qualifier = qualifier.encode() if isinstance(qualifier, str) else qualifier + if not isinstance(qualifier, bytes): + raise TypeError("qualifier must be bytes or str") + if isinstance(new_value, str): + new_value = new_value.encode() + elif isinstance(new_value, int): + if abs(new_value) > _MAX_INCREMENT_VALUE: + raise ValueError( + "int values must be between -2**63 and 2**63 (64-bit signed int)" + ) + new_value = new_value.to_bytes(8, "big", signed=True) + if not isinstance(new_value, bytes): + raise TypeError("new_value must be bytes, str, or int") + if timestamp_micros is None: + # use current timestamp, with millisecond precision + timestamp_micros = time.time_ns() // 1000 + timestamp_micros = timestamp_micros - (timestamp_micros % 1000) + if timestamp_micros < _SERVER_SIDE_TIMESTAMP: + raise ValueError( + f"timestamp_micros must be positive (or {_SERVER_SIDE_TIMESTAMP} for server-side timestamp)" + ) + self.family = family + self.qualifier = qualifier + self.new_value = new_value + self.timestamp_micros = timestamp_micros + + def _to_dict(self) -> dict[str, Any]: + """Convert the mutation to a dictionary representation""" + return { + "set_cell": { + "family_name": self.family, + "column_qualifier": self.qualifier, + "timestamp_micros": self.timestamp_micros, + "value": self.new_value, + } + } + + def is_idempotent(self) -> bool: + """Check if the mutation is idempotent""" + return self.timestamp_micros != _SERVER_SIDE_TIMESTAMP + + +@dataclass +class DeleteRangeFromColumn(Mutation): + family: str + qualifier: bytes + # None represents 0 + start_timestamp_micros: int | None = None + # None represents infinity + end_timestamp_micros: int | None = None + + def __post_init__(self): + if ( + self.start_timestamp_micros is not None + and self.end_timestamp_micros is not None + and self.start_timestamp_micros > self.end_timestamp_micros + ): + raise ValueError("start_timestamp_micros must be <= end_timestamp_micros") + + def _to_dict(self) -> dict[str, Any]: + timestamp_range = {} + if self.start_timestamp_micros is not None: + timestamp_range["start_timestamp_micros"] = self.start_timestamp_micros + if self.end_timestamp_micros is not None: + timestamp_range["end_timestamp_micros"] = self.end_timestamp_micros + return { + "delete_from_column": { + "family_name": self.family, + "column_qualifier": self.qualifier, + "time_range": timestamp_range, + } + } + + +@dataclass +class DeleteAllFromFamily(Mutation): + family_to_delete: str + + def _to_dict(self) -> dict[str, Any]: + return { + "delete_from_family": { + "family_name": self.family_to_delete, + } + } + + +@dataclass +class DeleteAllFromRow(Mutation): + def _to_dict(self) -> dict[str, Any]: + return { + "delete_from_row": {}, + } + + +class RowMutationEntry: + def __init__(self, row_key: bytes | str, mutations: Mutation | list[Mutation]): + if isinstance(row_key, str): + row_key = row_key.encode("utf-8") + if isinstance(mutations, Mutation): + mutations = [mutations] + if len(mutations) == 0: + raise
ValueError("mutations must not be empty") + elif len(mutations) > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT: + raise ValueError( + f"entries must have <= {_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations" + ) + self.row_key = row_key + self.mutations = tuple(mutations) + + def _to_dict(self) -> dict[str, Any]: + return { + "row_key": self.row_key, + "mutations": [mutation._to_dict() for mutation in self.mutations], + } + + def _to_pb(self) -> types_pb.MutateRowsRequest.Entry: + return types_pb.MutateRowsRequest.Entry( + row_key=self.row_key, + mutations=[mutation._to_pb() for mutation in self.mutations], + ) + + def is_idempotent(self) -> bool: + """Check if the mutation is idempotent""" + return all(mutation.is_idempotent() for mutation in self.mutations) + + def size(self) -> int: + """ + Get the size of the mutation in bytes + """ + return getsizeof(self._to_dict()) + + @classmethod + def _from_dict(cls, input_dict: dict[str, Any]) -> RowMutationEntry: + return RowMutationEntry( + row_key=input_dict["row_key"], + mutations=[ + Mutation._from_dict(mutation) for mutation in input_dict["mutations"] + ], + ) diff --git a/google/cloud/bigtable/data/read_modify_write_rules.py b/google/cloud/bigtable/data/read_modify_write_rules.py new file mode 100644 index 000000000..f43dbe79f --- /dev/null +++ b/google/cloud/bigtable/data/read_modify_write_rules.py @@ -0,0 +1,77 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import annotations + +import abc + +import google.cloud.bigtable_v2.types.data as data_pb + +# value must fit in 64-bit signed integer +_MAX_INCREMENT_VALUE = (1 << 63) - 1 + + +class ReadModifyWriteRule(abc.ABC): + def __init__(self, family: str, qualifier: bytes | str): + qualifier = ( + qualifier if isinstance(qualifier, bytes) else qualifier.encode("utf-8") + ) + self.family = family + self.qualifier = qualifier + + @abc.abstractmethod + def _to_dict(self) -> dict[str, str | bytes | int]: + raise NotImplementedError + + def _to_pb(self) -> data_pb.ReadModifyWriteRule: + return data_pb.ReadModifyWriteRule(**self._to_dict()) + + +class IncrementRule(ReadModifyWriteRule): + def __init__(self, family: str, qualifier: bytes | str, increment_amount: int = 1): + if not isinstance(increment_amount, int): + raise TypeError("increment_amount must be an integer") + if abs(increment_amount) > _MAX_INCREMENT_VALUE: + raise ValueError( + "increment_amount must be between -2**63 and 2**63 (64-bit signed int)" + ) + super().__init__(family, qualifier) + self.increment_amount = increment_amount + + def _to_dict(self) -> dict[str, str | bytes | int]: + return { + "family_name": self.family, + "column_qualifier": self.qualifier, + "increment_amount": self.increment_amount, + } + + +class AppendValueRule(ReadModifyWriteRule): + def __init__(self, family: str, qualifier: bytes | str, append_value: bytes | str): + append_value = ( + append_value.encode("utf-8") + if isinstance(append_value, str) + else append_value + ) + if not isinstance(append_value, bytes): + raise TypeError("append_value must be bytes or str") + super().__init__(family, qualifier) + self.append_value = append_value + + def _to_dict(self) -> dict[str, str | bytes | int]: + return { + "family_name": self.family, + "column_qualifier": self.qualifier, + "append_value": self.append_value, + } diff --git a/google/cloud/bigtable/data/read_rows_query.py b/google/cloud/bigtable/data/read_rows_query.py new file mode 100644 index 000000000..362f54c3e --- /dev/null +++ b/google/cloud/bigtable/data/read_rows_query.py @@ -0,0 +1,476 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
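[Reviewer note] The read-modify-write rules above follow the same pattern. A short construction sketch: the rules would normally be handed to a read-modify-write call on a table object defined elsewhere in this PR, so only the rule classes below come from this file, and the family/qualifier names are illustrative.

```python
# Sketch: building read-modify-write rules (illustrative values only).
from google.cloud.bigtable.data.read_modify_write_rules import (
    AppendValueRule,
    IncrementRule,
)

# Atomically add 1 to a counter cell (stored as a 64-bit big-endian signed int).
bump = IncrementRule(family="stats", qualifier="views", increment_amount=1)

# Append bytes to an existing cell value.
tag = AppendValueRule(family="stats", qualifier="history", append_value="|login")

# str qualifiers and values are encoded to bytes by the constructors.
assert bump._to_dict() == {
    "family_name": "stats",
    "column_qualifier": b"views",
    "increment_amount": 1,
}
rule_pb = tag._to_pb()  # data_pb.ReadModifyWriteRule
```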
+# +from __future__ import annotations +from typing import TYPE_CHECKING, Any +from bisect import bisect_left +from bisect import bisect_right +from collections import defaultdict +from google.cloud.bigtable.data.row_filters import RowFilter + +from google.cloud.bigtable_v2.types import RowRange as RowRangePB +from google.cloud.bigtable_v2.types import RowSet as RowSetPB +from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB + +if TYPE_CHECKING: + from google.cloud.bigtable.data import RowKeySamples + from google.cloud.bigtable.data import ShardedQuery + + +class RowRange: + """ + Represents a range of keys in a ReadRowsQuery + """ + + __slots__ = ("_pb",) + + def __init__( + self, + start_key: str | bytes | None = None, + end_key: str | bytes | None = None, + start_is_inclusive: bool | None = None, + end_is_inclusive: bool | None = None, + ): + """ + Args: + - start_key: The start key of the range. If empty, the range is unbounded on the left. + - end_key: The end key of the range. If empty, the range is unbounded on the right. + - start_is_inclusive: Whether the start key is inclusive. If None, the start key is + inclusive. + - end_is_inclusive: Whether the end key is inclusive. If None, the end key is not inclusive. + Raises: + - ValueError: if start_key is greater than end_key, or start_is_inclusive, + or end_is_inclusive is set when the corresponding key is None, + or start_key or end_key is not a string or bytes. + """ + # convert empty key inputs to None for consistency + start_key = None if not start_key else start_key + end_key = None if not end_key else end_key + # check for invalid combinations of arguments + if start_is_inclusive is None: + start_is_inclusive = True + + if end_is_inclusive is None: + end_is_inclusive = False + # ensure that start_key and end_key are bytes + if isinstance(start_key, str): + start_key = start_key.encode() + elif start_key is not None and not isinstance(start_key, bytes): + raise ValueError("start_key must be a string or bytes") + if isinstance(end_key, str): + end_key = end_key.encode() + elif end_key is not None and not isinstance(end_key, bytes): + raise ValueError("end_key must be a string or bytes") + # ensure that start_key is less than or equal to end_key + if start_key is not None and end_key is not None and start_key > end_key: + raise ValueError("start_key must be less than or equal to end_key") + + init_dict = {} + if start_key is not None: + if start_is_inclusive: + init_dict["start_key_closed"] = start_key + else: + init_dict["start_key_open"] = start_key + if end_key is not None: + if end_is_inclusive: + init_dict["end_key_closed"] = end_key + else: + init_dict["end_key_open"] = end_key + self._pb = RowRangePB(**init_dict) + + @property + def start_key(self) -> bytes | None: + """ + Returns the start key of the range. If None, the range is unbounded on the left. + """ + return self._pb.start_key_closed or self._pb.start_key_open or None + + @property + def end_key(self) -> bytes | None: + """ + Returns the end key of the range. If None, the range is unbounded on the right. + """ + return self._pb.end_key_closed or self._pb.end_key_open or None + + @property + def start_is_inclusive(self) -> bool: + """ + Returns whether the range is inclusive of the start key. + Returns True if the range is unbounded on the left. + """ + return not bool(self._pb.start_key_open) + + @property + def end_is_inclusive(self) -> bool: + """ + Returns whether the range is inclusive of the end key. 
+ Returns True if the range is unbounded on the right. + """ + return not bool(self._pb.end_key_open) + + def _to_pb(self) -> RowRangePB: + """Converts this object to a protobuf""" + return self._pb + + @classmethod + def _from_pb(cls, data: RowRangePB) -> RowRange: + """Creates a RowRange from a protobuf""" + instance = cls() + instance._pb = data + return instance + + @classmethod + def _from_dict(cls, data: dict[str, bytes | str]) -> RowRange: + """Creates a RowRange from a protobuf""" + formatted_data = { + k: v.encode() if isinstance(v, str) else v for k, v in data.items() + } + instance = cls() + instance._pb = RowRangePB(**formatted_data) + return instance + + def __bool__(self) -> bool: + """ + Empty RowRanges (representing a full table scan) are falsy, because + they can be substituted with None. Non-empty RowRanges are truthy. + """ + return bool( + self._pb.start_key_closed + or self._pb.start_key_open + or self._pb.end_key_closed + or self._pb.end_key_open + ) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, RowRange): + return NotImplemented + return self._pb == other._pb + + def __str__(self) -> str: + """ + Represent range as a string, e.g. "[b'a', b'z)" + Unbounded start or end keys are represented as "-inf" or "+inf" + """ + left = "[" if self.start_is_inclusive else "(" + right = "]" if self.end_is_inclusive else ")" + start = repr(self.start_key) if self.start_key is not None else "-inf" + end = repr(self.end_key) if self.end_key is not None else "+inf" + return f"{left}{start}, {end}{right}" + + def __repr__(self) -> str: + args_list = [] + args_list.append(f"start_key={self.start_key!r}") + args_list.append(f"end_key={self.end_key!r}") + if self.start_is_inclusive is False: + # only show start_is_inclusive if it is different from the default + args_list.append(f"start_is_inclusive={self.start_is_inclusive}") + if self.end_is_inclusive is True and self.end_key is not None: + # only show end_is_inclusive if it is different from the default + args_list.append(f"end_is_inclusive={self.end_is_inclusive}") + return f"RowRange({', '.join(args_list)})" + + +class ReadRowsQuery: + """ + Class to encapsulate details of a read row request + """ + + slots = ("_limit", "_filter", "_row_set") + + def __init__( + self, + row_keys: list[str | bytes] | str | bytes | None = None, + row_ranges: list[RowRange] | RowRange | None = None, + limit: int | None = None, + row_filter: RowFilter | None = None, + ): + """ + Create a new ReadRowsQuery + + Args: + - row_keys: row keys to include in the query + a query can contain multiple keys, but ranges should be preferred + - row_ranges: ranges of rows to include in the query + - limit: the maximum number of rows to return. 
None or 0 means no limit + default: None (no limit) + - row_filter: a RowFilter to apply to the query + """ + if row_keys is None: + row_keys = [] + if row_ranges is None: + row_ranges = [] + if not isinstance(row_ranges, list): + row_ranges = [row_ranges] + if not isinstance(row_keys, list): + row_keys = [row_keys] + row_keys = [key.encode() if isinstance(key, str) else key for key in row_keys] + self._row_set = RowSetPB( + row_keys=row_keys, row_ranges=[r._pb for r in row_ranges] + ) + self.limit = limit or None + self.filter = row_filter + + @property + def row_keys(self) -> list[bytes]: + return list(self._row_set.row_keys) + + @property + def row_ranges(self) -> list[RowRange]: + return [RowRange._from_pb(r) for r in self._row_set.row_ranges] + + @property + def limit(self) -> int | None: + return self._limit or None + + @limit.setter + def limit(self, new_limit: int | None): + """ + Set the maximum number of rows to return by this query. + + None or 0 means no limit + + Args: + - new_limit: the new limit to apply to this query + Returns: + - a reference to this query for chaining + Raises: + - ValueError if new_limit is < 0 + """ + if new_limit is not None and new_limit < 0: + raise ValueError("limit must be >= 0") + self._limit = new_limit + + @property + def filter(self) -> RowFilter | None: + return self._filter + + @filter.setter + def filter(self, row_filter: RowFilter | None): + """ + Set a RowFilter to apply to this query + + Args: + - row_filter: a RowFilter to apply to this query + Returns: + - a reference to this query for chaining + """ + self._filter = row_filter + + def add_key(self, row_key: str | bytes): + """ + Add a row key to this query + + A query can contain multiple keys, but ranges should be preferred + + Args: + - row_key: a key to add to this query + Returns: + - a reference to this query for chaining + Raises: + - ValueError if an input is not a string or bytes + """ + if isinstance(row_key, str): + row_key = row_key.encode() + elif not isinstance(row_key, bytes): + raise ValueError("row_key must be string or bytes") + if row_key not in self._row_set.row_keys: + self._row_set.row_keys.append(row_key) + + def add_range( + self, + row_range: RowRange, + ): + """ + Add a range of row keys to this query. 
+ + Args: + - row_range: a range of row keys to add to this query + """ + if row_range not in self.row_ranges: + self._row_set.row_ranges.append(row_range._pb) + + def shard(self, shard_keys: RowKeySamples) -> ShardedQuery: + """ + Split this query into multiple queries that can be evenly distributed + across nodes and run in parallel + + Returns: + - a ShardedQuery that can be used in sharded_read_rows calls + Raises: + - AttributeError if the query contains a limit + """ + if self.limit is not None: + raise AttributeError("Cannot shard query with a limit") + if len(self.row_keys) == 0 and len(self.row_ranges) == 0: + # empty query represents full scan + # ensure that we have at least one key or range + full_scan_query = ReadRowsQuery( + row_ranges=RowRange(), row_filter=self.filter + ) + return full_scan_query.shard(shard_keys) + + sharded_queries: dict[int, ReadRowsQuery] = defaultdict( + lambda: ReadRowsQuery(row_filter=self.filter) + ) + # the split_points divide our key space into segments + # each split_point defines the last key that belongs to a segment + # our goal is to break up the query into subqueries that each operate in a single segment + split_points = [sample[0] for sample in shard_keys if sample[0]] + + # handle row_keys + # use binary search to find the segment that each key belongs to + for this_key in list(self.row_keys): + # bisect_left: in case of exact match, pick left side (keys are inclusive ends) + segment_index = bisect_left(split_points, this_key) + sharded_queries[segment_index].add_key(this_key) + + # handle row_ranges + for this_range in self.row_ranges: + # defer to _shard_range helper + for segment_index, added_range in self._shard_range( + this_range, split_points + ): + sharded_queries[segment_index].add_range(added_range) + # return list of queries ordered by segment index + # pull populated segments out of sharded_queries dict + keys = sorted(list(sharded_queries.keys())) + # return list of queries + return [sharded_queries[k] for k in keys] + + @staticmethod + def _shard_range( + orig_range: RowRange, split_points: list[bytes] + ) -> list[tuple[int, RowRange]]: + """ + Helper function for sharding row_range into subranges that fit into + segments of the key-space, determined by split_points + + Args: + - orig_range: a row range to split + - split_points: a list of row keys that define the boundaries of segments. + each point represents the inclusive end of a segment + Returns: + - a list of tuples, containing a segment index and a new sub-range. + """ + # 1. find the index of the segment the start key belongs to + if orig_range.start_key is None: + # if range is open on the left, include first segment + start_segment = 0 + else: + # use binary search to find the segment the start key belongs to + # bisect method determines how we break ties when the start key matches a split point + # if inclusive, bisect_left to the left segment, otherwise bisect_right + bisect = bisect_left if orig_range.start_is_inclusive else bisect_right + start_segment = bisect(split_points, orig_range.start_key) + + # 2. find the index of the segment the end key belongs to + if orig_range.end_key is None: + # if range is open on the right, include final segment + end_segment = len(split_points) + else: + # use binary search to find the segment the end key belongs to.
end_segment = bisect_left( + split_points, orig_range.end_key, lo=start_segment + ) + # note: end_segment will always use bisect_left, because split points represent inclusive ends + # whether or not the end_key includes the split point, the result is the same segment + # 3. create new range definitions for each segment this_range spans + if start_segment == end_segment: + # this_range is contained in a single segment. + # Add this_range to that segment's query only + return [(start_segment, orig_range)] + else: + results: list[tuple[int, RowRange]] = [] + # this_range spans multiple segments. Create a new range for each segment's query + # 3a. add new range for first segment this_range spans + # first range spans from start_key to the split_point representing the last key in the segment + last_key_in_first_segment = split_points[start_segment] + start_range = RowRange( + start_key=orig_range.start_key, + start_is_inclusive=orig_range.start_is_inclusive, + end_key=last_key_in_first_segment, + end_is_inclusive=True, + ) + results.append((start_segment, start_range)) + # 3b. add new range for last segment this_range spans + # we start the final range at the end key of the previous segment, with start_is_inclusive=False + previous_segment = end_segment - 1 + last_key_before_segment = split_points[previous_segment] + end_range = RowRange( + start_key=last_key_before_segment, + start_is_inclusive=False, + end_key=orig_range.end_key, + end_is_inclusive=orig_range.end_is_inclusive, + ) + results.append((end_segment, end_range)) + # 3c. add new spanning range to all segments other than the first and last + for this_segment in range(start_segment + 1, end_segment): + prev_segment = this_segment - 1 + prev_end_key = split_points[prev_segment] + this_end_key = split_points[prev_segment + 1] + new_range = RowRange( + start_key=prev_end_key, + start_is_inclusive=False, + end_key=this_end_key, + end_is_inclusive=True, + ) + results.append((this_segment, new_range)) + return results + + def _to_pb(self, table) -> ReadRowsRequestPB: + """ + Convert this query into a ReadRowsRequest protobuf + """ + return ReadRowsRequestPB( + table_name=table.table_name, + app_profile_id=table.app_profile_id, + filter=self.filter._to_pb() if self.filter else None, + rows_limit=self.limit or 0, + rows=self._row_set, + ) + + def __eq__(self, other): + """ + ReadRowsQuery objects are equal if they have the same row keys, row ranges, + filter and limit, or if they both represent a full scan with the + same filter and limit + """ + if not isinstance(other, ReadRowsQuery): + return False + # empty queries are equal + if len(self.row_keys) == 0 and len(other.row_keys) == 0: + this_range_empty = len(self.row_ranges) == 0 or all( + [bool(r) is False for r in self.row_ranges] + ) + other_range_empty = len(other.row_ranges) == 0 or all( + [bool(r) is False for r in other.row_ranges] + ) + if this_range_empty and other_range_empty: + return self.filter == other.filter and self.limit == other.limit + # otherwise, sets should have same sizes + if len(self.row_keys) != len(other.row_keys): + return False + if len(self.row_ranges) != len(other.row_ranges): + return False + ranges_match = all([row in other.row_ranges for row in self.row_ranges]) + return ( + self.row_keys == other.row_keys + and ranges_match + and self.filter == other.filter + and self.limit == other.limit + ) + + def __repr__(self): + return f"ReadRowsQuery(row_keys={list(self.row_keys)}, row_ranges={list(self.row_ranges)},
row_filter={self.filter}, limit={self.limit})" diff --git a/google/cloud/bigtable/data/row.py b/google/cloud/bigtable/data/row.py new file mode 100644 index 000000000..ecf9cea66 --- /dev/null +++ b/google/cloud/bigtable/data/row.py @@ -0,0 +1,450 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from collections import OrderedDict +from typing import Generator, overload, Any +from functools import total_ordering + +from google.cloud.bigtable_v2.types import Row as RowPB + +# Type aliases used internally for readability. +_family_type = str +_qualifier_type = bytes + + +class Row: + """ + Model class for row data returned from server + + Does not represent all data contained in the row, only data returned by a + query. + Expected to be read-only to users, and written by backend + + Can be indexed: + cells = row["family", "qualifier"] + """ + + __slots__ = ("row_key", "cells", "_index_data") + + def __init__( + self, + key: bytes, + cells: list[Cell], + ): + """ + Initializes a Row object + + Row objects are not intended to be created by users. + They are returned by the Bigtable backend. + """ + self.row_key = key + self.cells: list[Cell] = cells + # index is lazily created when needed + self._index_data: OrderedDict[ + _family_type, OrderedDict[_qualifier_type, list[Cell]] + ] | None = None + + @property + def _index( + self, + ) -> OrderedDict[_family_type, OrderedDict[_qualifier_type, list[Cell]]]: + """ + Returns an index of cells associated with each family and qualifier. + + The index is lazily created when needed + """ + if self._index_data is None: + self._index_data = OrderedDict() + for cell in self.cells: + self._index_data.setdefault(cell.family, OrderedDict()).setdefault( + cell.qualifier, [] + ).append(cell) + return self._index_data + + @classmethod + def _from_pb(cls, row_pb: RowPB) -> Row: + """ + Creates a row from a protobuf representation + + Row objects are not intended to be created by users. + They are returned by the Bigtable backend. 
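[Reviewer note] Since the sharding logic in read_rows_query.py is the least obvious part of that file, a small end-to-end sketch may help. The ReadRowsQuery and RowRange API comes from the diff above; the sample list is a hand-written stand-in for RowKeySamples (assumed here to be (key, offset) pairs, of which shard() only reads the keys), and every key value is invented.

```python
# Sketch: sharding a query at a single sample point (illustrative values only).
from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery, RowRange

query = ReadRowsQuery(
    row_keys=b"user#042",
    row_ranges=RowRange(start_key="user#100", end_key="user#200"),  # [100, 200)
)

# One split point at b"user#150"; empty keys are dropped when building split points.
samples = [(b"user#150", 1024), (b"", 2048)]
sub_queries = query.shard(samples)

assert len(sub_queries) == 2
assert sub_queries[0].row_keys == [b"user#042"]
assert sub_queries[0].row_ranges[0].end_key == b"user#150"    # closed end of segment 0
assert sub_queries[1].row_ranges[0].start_key == b"user#150"  # open start of segment 1
```

Each split point is treated as the inclusive upper bound of its segment, which is why the first sub-range closes at b"user#150" and the second reopens exclusively just above it.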
+ """ + row_key: bytes = row_pb.key + cell_list: list[Cell] = [] + for family in row_pb.families: + for column in family.columns: + for cell in column.cells: + new_cell = Cell( + value=cell.value, + row_key=row_key, + family=family.name, + qualifier=column.qualifier, + timestamp_micros=cell.timestamp_micros, + labels=list(cell.labels) if cell.labels else None, + ) + cell_list.append(new_cell) + return cls(row_key, cells=cell_list) + + def get_cells( + self, family: str | None = None, qualifier: str | bytes | None = None + ) -> list[Cell]: + """ + Returns cells sorted in Bigtable native order: + - Family lexicographically ascending + - Qualifier ascending + - Timestamp in reverse chronological order + + If family or qualifier not passed, will include all + + Can also be accessed through indexing: + cells = row["family", "qualifier"] + cells = row["family"] + """ + if family is None: + if qualifier is not None: + # get_cells(None, "qualifier") is not allowed + raise ValueError("Qualifier passed without family") + else: + # return all cells on get_cells() + return self.cells + if qualifier is None: + # return all cells in family on get_cells(family) + return list(self._get_all_from_family(family)) + if isinstance(qualifier, str): + qualifier = qualifier.encode("utf-8") + # return cells in family and qualifier on get_cells(family, qualifier) + if family not in self._index: + raise ValueError(f"Family '{family}' not found in row '{self.row_key!r}'") + if qualifier not in self._index[family]: + raise ValueError( + f"Qualifier '{qualifier!r}' not found in family '{family}' in row '{self.row_key!r}'" + ) + return self._index[family][qualifier] + + def _get_all_from_family(self, family: str) -> Generator[Cell, None, None]: + """ + Returns all cells in the row for the family_id + """ + if family not in self._index: + raise ValueError(f"Family '{family}' not found in row '{self.row_key!r}'") + for qualifier in self._index[family]: + yield from self._index[family][qualifier] + + def __str__(self) -> str: + """ + Human-readable string representation + + { + (family='fam', qualifier=b'col'): [b'value', (+1 more),], + (family='fam', qualifier=b'col2'): [b'other'], + } + """ + output = ["{"] + for family, qualifier in self._get_column_components(): + cell_list = self[family, qualifier] + line = [f" (family={family!r}, qualifier={qualifier!r}): "] + if len(cell_list) == 0: + line.append("[],") + elif len(cell_list) == 1: + line.append(f"[{cell_list[0]}],") + else: + line.append(f"[{cell_list[0]}, (+{len(cell_list)-1} more)],") + output.append("".join(line)) + output.append("}") + return "\n".join(output) + + def __repr__(self): + cell_str_buffer = ["{"] + for family, qualifier in self._get_column_components(): + cell_list = self[family, qualifier] + repr_list = [cell._to_dict() for cell in cell_list] + cell_str_buffer.append(f" ('{family}', {qualifier!r}): {repr_list},") + cell_str_buffer.append("}") + cell_str = "\n".join(cell_str_buffer) + output = f"Row(key={self.row_key!r}, cells={cell_str})" + return output + + def _to_dict(self) -> dict[str, Any]: + """ + Returns a dictionary representation of the cell in the Bigtable Row + proto format + + https://cloud.google.com/bigtable/docs/reference/data/rpc/google.bigtable.v2#row + """ + family_list = [] + for family_name, qualifier_dict in self._index.items(): + qualifier_list = [] + for qualifier_name, cell_list in qualifier_dict.items(): + cell_dicts = [cell._to_dict() for cell in cell_list] + qualifier_list.append( + {"qualifier": qualifier_name, "cells": 
cell_dicts} + ) + family_list.append({"name": family_name, "columns": qualifier_list}) + return {"key": self.row_key, "families": family_list} + + # Sequence and Mapping methods + def __iter__(self): + """ + Allow iterating over all cells in the row + """ + return iter(self.cells) + + def __contains__(self, item): + """ + Implements `in` operator + + Works for both cells in the internal list, and `family` or + `(family, qualifier)` pairs associated with the cells + """ + if isinstance(item, _family_type): + return item in self._index + elif ( + isinstance(item, tuple) + and isinstance(item[0], _family_type) + and isinstance(item[1], (bytes, str)) + ): + q = item[1] if isinstance(item[1], bytes) else item[1].encode("utf-8") + return item[0] in self._index and q in self._index[item[0]] + # check if Cell is in Row + return item in self.cells + + @overload + def __getitem__( + self, + index: str | tuple[str, bytes | str], + ) -> list[Cell]: + # overload signature for type checking + pass + + @overload + def __getitem__(self, index: int) -> Cell: + # overload signature for type checking + pass + + @overload + def __getitem__(self, index: slice) -> list[Cell]: + # overload signature for type checking + pass + + def __getitem__(self, index): + """ + Implements [] indexing + + Supports indexing by family, (family, qualifier) pair, + numerical index, and index slicing + """ + if isinstance(index, _family_type): + return self.get_cells(family=index) + elif ( + isinstance(index, tuple) + and isinstance(index[0], _family_type) + and isinstance(index[1], (bytes, str)) + ): + return self.get_cells(family=index[0], qualifier=index[1]) + elif isinstance(index, int) or isinstance(index, slice): + # index is int or slice + return self.cells[index] + else: + raise TypeError( + "Index must be family_id, (family_id, qualifier), int, or slice" + ) + + def __len__(self): + """ + Implements `len()` operator + """ + return len(self.cells) + + def _get_column_components(self) -> list[tuple[str, bytes]]: + """ + Returns a list of (family, qualifier) pairs associated with the cells + + Pairs can be used for indexing + """ + return [(f, q) for f in self._index for q in self._index[f]] + + def __eq__(self, other): + """ + Implements `==` operator + """ + # for performance reasons, check row metadata + # before checking individual cells + if not isinstance(other, Row): + return False + if self.row_key != other.row_key: + return False + if len(self.cells) != len(other.cells): + return False + components = self._get_column_components() + other_components = other._get_column_components() + if len(components) != len(other_components): + return False + if components != other_components: + return False + for family, qualifier in components: + if len(self[family, qualifier]) != len(other[family, qualifier]): + return False + # compare individual cell lists + if self.cells != other.cells: + return False + return True + + def __ne__(self, other) -> bool: + """ + Implements `!=` operator + """ + return not self == other + + +@total_ordering +class Cell: + """ + Model class for cell data + + Does not represent all data contained in the cell, only data returned by a + query. 
+ Expected to be read-only to users, and written by backend + """ + + __slots__ = ( + "value", + "row_key", + "family", + "qualifier", + "timestamp_micros", + "labels", + ) + + def __init__( + self, + value: bytes, + row_key: bytes, + family: str, + qualifier: bytes | str, + timestamp_micros: int, + labels: list[str] | None = None, + ): + """ + Cell constructor + + Cell objects are not intended to be constructed by users. + They are returned by the Bigtable backend. + """ + self.value = value + self.row_key = row_key + self.family = family + if isinstance(qualifier, str): + qualifier = qualifier.encode() + self.qualifier = qualifier + self.timestamp_micros = timestamp_micros + self.labels = labels if labels is not None else [] + + def __int__(self) -> int: + """ + Allows casting cell to int + Interprets value as a 64-bit big-endian signed integer, as expected by + ReadModifyWrite increment rule + """ + return int.from_bytes(self.value, byteorder="big", signed=True) + + def _to_dict(self) -> dict[str, Any]: + """ + Returns a dictionary representation of the cell in the Bigtable Cell + proto format + + https://cloud.google.com/bigtable/docs/reference/data/rpc/google.bigtable.v2#cell + """ + cell_dict: dict[str, Any] = { + "value": self.value, + } + cell_dict["timestamp_micros"] = self.timestamp_micros + if self.labels: + cell_dict["labels"] = self.labels + return cell_dict + + def __str__(self) -> str: + """ + Allows casting cell to str + Prints encoded byte string, same as printing value directly. + """ + return str(self.value) + + def __repr__(self): + """ + Returns a string representation of the cell + """ + return f"Cell(value={self.value!r}, row_key={self.row_key!r}, family='{self.family}', qualifier={self.qualifier!r}, timestamp_micros={self.timestamp_micros}, labels={self.labels})" + + """For Bigtable native ordering""" + + def __lt__(self, other) -> bool: + """ + Implements `<` operator + """ + if not isinstance(other, Cell): + return NotImplemented + this_ordering = ( + self.family, + self.qualifier, + -self.timestamp_micros, + self.value, + self.labels, + ) + other_ordering = ( + other.family, + other.qualifier, + -other.timestamp_micros, + other.value, + other.labels, + ) + return this_ordering < other_ordering + + def __eq__(self, other) -> bool: + """ + Implements `==` operator + """ + if not isinstance(other, Cell): + return NotImplemented + return ( + self.row_key == other.row_key + and self.family == other.family + and self.qualifier == other.qualifier + and self.value == other.value + and self.timestamp_micros == other.timestamp_micros + and len(self.labels) == len(other.labels) + and all([label in other.labels for label in self.labels]) + ) + + def __ne__(self, other) -> bool: + """ + Implements `!=` operator + """ + return not self == other + + def __hash__(self): + """ + Implements `hash()` function to fingerprint cell + """ + return hash( + ( + self.row_key, + self.family, + self.qualifier, + self.value, + self.timestamp_micros, + tuple(self.labels), + ) + ) diff --git a/google/cloud/bigtable/data/row_filters.py b/google/cloud/bigtable/data/row_filters.py new file mode 100644 index 000000000..9f09133d5 --- /dev/null +++ b/google/cloud/bigtable/data/row_filters.py @@ -0,0 +1,968 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
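[Reviewer note] A hand-built illustration of the Row / Cell container semantics above. Rows are normally constructed internally (e.g. via Row._from_pb) and treated as read-only; the values here are fabricated purely to show the indexing behavior.

```python
# Sketch: indexing into the Row / Cell model (illustrative values only).
from google.cloud.bigtable.data.row import Cell, Row

cells = [
    Cell(value=(42).to_bytes(8, "big", signed=True), row_key=b"user#123",
         family="stats", qualifier=b"views", timestamp_micros=2_000),
    Cell(value=(7).to_bytes(8, "big", signed=True), row_key=b"user#123",
         family="stats", qualifier=b"views", timestamp_micros=1_000),
    Cell(value=b"alice", row_key=b"user#123",
         family="meta", qualifier=b"name", timestamp_micros=1_000),
]
row = Row(key=b"user#123", cells=cells)

assert len(row["stats"]) == 2                  # index by family
latest = row["stats", "views"][0]              # index by (family, qualifier)
assert int(latest) == 42                       # __int__ decodes 64-bit big-endian
assert ("meta", b"name") in row                # __contains__ accepts column tuples
assert row[0] is cells[0]                      # plain integer indexing still works
```

The lazily built index preserves the order cells were supplied, so index [0] is the newest cell here only because the list was built newest-first, matching how the backend returns cells within a column.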
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Filters for Google Cloud Bigtable Row classes.""" +from __future__ import annotations + +import struct + +from typing import Any, Sequence, TYPE_CHECKING, overload +from abc import ABC, abstractmethod + +from google.cloud._helpers import _microseconds_from_datetime # type: ignore +from google.cloud._helpers import _to_bytes # type: ignore +from google.cloud.bigtable_v2.types import data as data_v2_pb2 + +if TYPE_CHECKING: + # import dependencies when type checking + from datetime import datetime + +_PACK_I64 = struct.Struct(">q").pack + + +class RowFilter(ABC): + """Basic filter to apply to cells in a row. + + These values can be combined via :class:`RowFilterChain`, + :class:`RowFilterUnion` and :class:`ConditionalRowFilter`. + + .. note:: + + This class is a do-nothing base class for all row filters. + """ + + def _to_pb(self) -> data_v2_pb2.RowFilter: + """Converts the row filter to a protobuf. + + Returns: The converted current object. + """ + return data_v2_pb2.RowFilter(**self._to_dict()) + + @abstractmethod + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + pass + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class _BoolFilter(RowFilter, ABC): + """Row filter that uses a boolean flag. + + :type flag: bool + :param flag: An indicator if a setting is turned on or off. + """ + + def __init__(self, flag: bool): + self.flag = flag + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return other.flag == self.flag + + def __ne__(self, other): + return not self == other + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(flag={self.flag})" + + +class SinkFilter(_BoolFilter): + """Advanced row filter to skip parent filters. + + :type flag: bool + :param flag: ADVANCED USE ONLY. Hook for introspection into the row filter. + Outputs all cells directly to the output of the read rather + than to any parent filter. Cannot be used within the + ``predicate_filter``, ``true_filter``, or ``false_filter`` + of a :class:`ConditionalRowFilter`. + """ + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"sink": self.flag} + + +class PassAllFilter(_BoolFilter): + """Row filter equivalent to not filtering at all. + + :type flag: bool + :param flag: Matches all cells, regardless of input. Functionally + equivalent to leaving ``filter`` unset, but included for + completeness. + """ + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"pass_all_filter": self.flag} + + +class BlockAllFilter(_BoolFilter): + """Row filter that doesn't match any cells. + + :type flag: bool + :param flag: Does not match any cells, regardless of input. Useful for + temporarily disabling just part of a filter. + """ + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"block_all_filter": self.flag} + + +class _RegexFilter(RowFilter, ABC): + """Row filter that uses a regular expression. 
+ + The ``regex`` must be valid RE2 patterns. See Google's + `RE2 reference`_ for the accepted syntax. + + .. _RE2 reference: https://github.com/google/re2/wiki/Syntax + + :type regex: bytes or str + :param regex: + A regular expression (RE2) for some row filter. String values + will be encoded as ASCII. + """ + + def __init__(self, regex: str | bytes): + self.regex: bytes = _to_bytes(regex) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return other.regex == self.regex + + def __ne__(self, other): + return not self == other + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(regex={self.regex!r})" + + +class RowKeyRegexFilter(_RegexFilter): + """Row filter for a row key regular expression. + + The ``regex`` must be valid RE2 patterns. See Google's + `RE2 reference`_ for the accepted syntax. + + .. _RE2 reference: https://github.com/google/re2/wiki/Syntax + + .. note:: + + Special care need be used with the expression used. Since + each of these properties can contain arbitrary bytes, the ``\\C`` + escape sequence must be used if a true wildcard is desired. The ``.`` + character will not match the new line character ``\\n``, which may be + present in a binary value. + + :type regex: bytes + :param regex: A regular expression (RE2) to match cells from rows with row + keys that satisfy this regex. For a + ``CheckAndMutateRowRequest``, this filter is unnecessary + since the row key is already specified. + """ + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"row_key_regex_filter": self.regex} + + +class RowSampleFilter(RowFilter): + """Matches all cells from a row with probability p. + + :type sample: float + :param sample: The probability of matching a cell (must be in the + interval ``(0, 1)`` The end points are excluded). + """ + + def __init__(self, sample: float): + self.sample: float = sample + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return other.sample == self.sample + + def __ne__(self, other): + return not self == other + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"row_sample_filter": self.sample} + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(sample={self.sample})" + + +class FamilyNameRegexFilter(_RegexFilter): + """Row filter for a family name regular expression. + + The ``regex`` must be valid RE2 patterns. See Google's + `RE2 reference`_ for the accepted syntax. + + .. _RE2 reference: https://github.com/google/re2/wiki/Syntax + + :type regex: str + :param regex: A regular expression (RE2) to match cells from columns in a + given column family. For technical reasons, the regex must + not contain the ``':'`` character, even if it is not being + used as a literal. + """ + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"family_name_regex_filter": self.regex} + + +class ColumnQualifierRegexFilter(_RegexFilter): + """Row filter for a column qualifier regular expression. + + The ``regex`` must be valid RE2 patterns. See Google's + `RE2 reference`_ for the accepted syntax. + + .. _RE2 reference: https://github.com/google/re2/wiki/Syntax + + .. note:: + + Special care need be used with the expression used. Since + each of these properties can contain arbitrary bytes, the ``\\C`` + escape sequence must be used if a true wildcard is desired. 
The ``.`` + character will not match the new line character ``\\n``, which may be + present in a binary value. + + :type regex: bytes + :param regex: A regular expression (RE2) to match cells from column that + match this regex (irrespective of column family). + """ + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"column_qualifier_regex_filter": self.regex} + + +class TimestampRange(object): + """Range of time with inclusive lower and exclusive upper bounds. + + :type start: :class:`datetime.datetime` + :param start: (Optional) The (inclusive) lower bound of the timestamp + range. If omitted, defaults to Unix epoch. + + :type end: :class:`datetime.datetime` + :param end: (Optional) The (exclusive) upper bound of the timestamp + range. If omitted, no upper bound is used. + """ + + def __init__(self, start: "datetime" | None = None, end: "datetime" | None = None): + self.start: "datetime" | None = start + self.end: "datetime" | None = end + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return other.start == self.start and other.end == self.end + + def __ne__(self, other): + return not self == other + + def _to_pb(self) -> data_v2_pb2.TimestampRange: + """Converts the :class:`TimestampRange` to a protobuf. + + Returns: The converted current object. + """ + return data_v2_pb2.TimestampRange(**self._to_dict()) + + def _to_dict(self) -> dict[str, int]: + """Converts the timestamp range to a dict representation.""" + timestamp_range_kwargs = {} + if self.start is not None: + start_time = _microseconds_from_datetime(self.start) // 1000 * 1000 + timestamp_range_kwargs["start_timestamp_micros"] = start_time + if self.end is not None: + end_time = _microseconds_from_datetime(self.end) + if end_time % 1000 != 0: + # if not a whole milisecond value, round up + end_time = end_time // 1000 * 1000 + 1000 + timestamp_range_kwargs["end_timestamp_micros"] = end_time + return timestamp_range_kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(start={self.start}, end={self.end})" + + +class TimestampRangeFilter(RowFilter): + """Row filter that limits cells to a range of time. + + :type range_: :class:`TimestampRange` + :param range_: Range of time that cells should match against. + """ + + def __init__(self, start: "datetime" | None = None, end: "datetime" | None = None): + self.range_: TimestampRange = TimestampRange(start, end) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return other.range_ == self.range_ + + def __ne__(self, other): + return not self == other + + def _to_pb(self) -> data_v2_pb2.RowFilter: + """Converts the row filter to a protobuf. + + First converts the ``range_`` on the current object to a protobuf and + then uses it in the ``timestamp_range_filter`` field. + + Returns: The converted current object. + """ + return data_v2_pb2.RowFilter(timestamp_range_filter=self.range_._to_pb()) + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"timestamp_range_filter": self.range_._to_dict()} + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(start={self.range_.start!r}, end={self.range_.end!r})" + + +class ColumnRangeFilter(RowFilter): + """A row filter to restrict to a range of columns. + + Both the start and end column can be included or excluded in the range. 
+ By default, we include them both, but this can be changed with optional + flags. + + :type family_id: str + :param family_id: The column family that contains the columns. Must + be of the form ``[_a-zA-Z0-9][-_.a-zA-Z0-9]*``. + + :type start_qualifier: bytes + :param start_qualifier: The start of the range of columns. If no value is + used, the backend applies no upper bound to the + values. + + :type end_qualifier: bytes + :param end_qualifier: The end of the range of columns. If no value is used, + the backend applies no upper bound to the values. + + :type inclusive_start: bool + :param inclusive_start: Boolean indicating if the start column should be + included in the range (or excluded). Defaults + to :data:`True` if ``start_qualifier`` is passed and + no ``inclusive_start`` was given. + + :type inclusive_end: bool + :param inclusive_end: Boolean indicating if the end column should be + included in the range (or excluded). Defaults + to :data:`True` if ``end_qualifier`` is passed and + no ``inclusive_end`` was given. + + :raises: :class:`ValueError ` if ``inclusive_start`` + is set but no ``start_qualifier`` is given or if ``inclusive_end`` + is set but no ``end_qualifier`` is given + """ + + def __init__( + self, + family_id: str, + start_qualifier: bytes | None = None, + end_qualifier: bytes | None = None, + inclusive_start: bool | None = None, + inclusive_end: bool | None = None, + ): + if inclusive_start is None: + inclusive_start = True + elif start_qualifier is None: + raise ValueError( + "inclusive_start was specified but no start_qualifier was given." + ) + if inclusive_end is None: + inclusive_end = True + elif end_qualifier is None: + raise ValueError( + "inclusive_end was specified but no end_qualifier was given." + ) + + self.family_id = family_id + + self.start_qualifier = start_qualifier + self.inclusive_start = inclusive_start + + self.end_qualifier = end_qualifier + self.inclusive_end = inclusive_end + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return ( + other.family_id == self.family_id + and other.start_qualifier == self.start_qualifier + and other.end_qualifier == self.end_qualifier + and other.inclusive_start == self.inclusive_start + and other.inclusive_end == self.inclusive_end + ) + + def __ne__(self, other): + return not self == other + + def _to_pb(self) -> data_v2_pb2.RowFilter: + """Converts the row filter to a protobuf. + + First converts to a :class:`.data_v2_pb2.ColumnRange` and then uses it + in the ``column_range_filter`` field. + + Returns: The converted current object. 
+ """ + column_range = data_v2_pb2.ColumnRange(**self._range_to_dict()) + return data_v2_pb2.RowFilter(column_range_filter=column_range) + + def _range_to_dict(self) -> dict[str, str | bytes]: + """Converts the column range range to a dict representation.""" + column_range_kwargs: dict[str, str | bytes] = {} + column_range_kwargs["family_name"] = self.family_id + if self.start_qualifier is not None: + if self.inclusive_start: + key = "start_qualifier_closed" + else: + key = "start_qualifier_open" + column_range_kwargs[key] = _to_bytes(self.start_qualifier) + if self.end_qualifier is not None: + if self.inclusive_end: + key = "end_qualifier_closed" + else: + key = "end_qualifier_open" + column_range_kwargs[key] = _to_bytes(self.end_qualifier) + return column_range_kwargs + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"column_range_filter": self._range_to_dict()} + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(family_id='{self.family_id}', start_qualifier={self.start_qualifier!r}, end_qualifier={self.end_qualifier!r}, inclusive_start={self.inclusive_start}, inclusive_end={self.inclusive_end})" + + +class ValueRegexFilter(_RegexFilter): + """Row filter for a value regular expression. + + The ``regex`` must be valid RE2 patterns. See Google's + `RE2 reference`_ for the accepted syntax. + + .. _RE2 reference: https://github.com/google/re2/wiki/Syntax + + .. note:: + + Special care need be used with the expression used. Since + each of these properties can contain arbitrary bytes, the ``\\C`` + escape sequence must be used if a true wildcard is desired. The ``.`` + character will not match the new line character ``\\n``, which may be + present in a binary value. + + :type regex: bytes or str + :param regex: A regular expression (RE2) to match cells with values that + match this regex. String values will be encoded as ASCII. + """ + + def _to_dict(self) -> dict[str, bytes]: + """Converts the row filter to a dict representation.""" + return {"value_regex_filter": self.regex} + + +class LiteralValueFilter(ValueRegexFilter): + """Row filter for an exact value. + + + :type value: bytes or str or int + :param value: + a literal string, integer, or the equivalent bytes. + Integer values will be packed into signed 8-bytes. + """ + + def __init__(self, value: bytes | str | int): + if isinstance(value, int): + value = _PACK_I64(value) + elif isinstance(value, str): + value = value.encode("utf-8") + value = self._write_literal_regex(value) + super(LiteralValueFilter, self).__init__(value) + + @staticmethod + def _write_literal_regex(input_bytes: bytes) -> bytes: + """ + Escape re2 special characters from literal bytes. + + Extracted from: re2 QuoteMeta: + https://github.com/google/re2/blob/70f66454c255080a54a8da806c52d1f618707f8a/re2/re2.cc#L456 + """ + result = bytearray() + for byte in input_bytes: + # If this is the part of a UTF8 or Latin1 character, we need \ + # to copy this byte without escaping. Experimentally this is \ + # what works correctly with the regexp library. \ + utf8_latin1_check = (byte & 128) == 0 + if ( + (byte < ord("a") or byte > ord("z")) + and (byte < ord("A") or byte > ord("Z")) + and (byte < ord("0") or byte > ord("9")) + and byte != ord("_") + and utf8_latin1_check + ): + if byte == 0: + # Special handling for null chars. + # Note that this special handling is not strictly required for RE2, + # but this quoting is required for other regexp libraries such as + # PCRE. 
+ # Can't use "\\0" since the next character might be a digit. + result.extend([ord("\\"), ord("x"), ord("0"), ord("0")]) + continue + result.append(ord(b"\\")) + result.append(byte) + return bytes(result) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(value={self.regex!r})" + + +class ValueRangeFilter(RowFilter): + """A range of values to restrict to in a row filter. + + Will only match cells that have values in this range. + + Both the start and end value can be included or excluded in the range. + By default, we include them both, but this can be changed with optional + flags. + + :type start_value: bytes + :param start_value: The start of the range of values. If no value is used, + the backend applies no lower bound to the values. + + :type end_value: bytes + :param end_value: The end of the range of values. If no value is used, + the backend applies no upper bound to the values. + + :type inclusive_start: bool + :param inclusive_start: Boolean indicating if the start value should be + included in the range (or excluded). Defaults + to :data:`True` if ``start_value`` is passed and + no ``inclusive_start`` was given. + + :type inclusive_end: bool + :param inclusive_end: Boolean indicating if the end value should be + included in the range (or excluded). Defaults + to :data:`True` if ``end_value`` is passed and + no ``inclusive_end`` was given. + + :raises: :class:`ValueError ` if ``inclusive_start`` + is set but no ``start_value`` is given or if ``inclusive_end`` + is set but no ``end_value`` is given + """ + + def __init__( + self, + start_value: bytes | int | None = None, + end_value: bytes | int | None = None, + inclusive_start: bool | None = None, + inclusive_end: bool | None = None, + ): + if inclusive_start is None: + inclusive_start = True + elif start_value is None: + raise ValueError( + "inclusive_start was specified but no start_value was given." + ) + if inclusive_end is None: + inclusive_end = True + elif end_value is None: + raise ValueError( + "inclusive_end was specified but no end_qualifier was given." + ) + if isinstance(start_value, int): + start_value = _PACK_I64(start_value) + self.start_value = start_value + self.inclusive_start = inclusive_start + + if isinstance(end_value, int): + end_value = _PACK_I64(end_value) + self.end_value = end_value + self.inclusive_end = inclusive_end + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return ( + other.start_value == self.start_value + and other.end_value == self.end_value + and other.inclusive_start == self.inclusive_start + and other.inclusive_end == self.inclusive_end + ) + + def __ne__(self, other): + return not self == other + + def _to_pb(self) -> data_v2_pb2.RowFilter: + """Converts the row filter to a protobuf. + + First converts to a :class:`.data_v2_pb2.ValueRange` and then uses + it to create a row filter protobuf. + + Returns: The converted current object. 
+ """ + value_range = data_v2_pb2.ValueRange(**self._range_to_dict()) + return data_v2_pb2.RowFilter(value_range_filter=value_range) + + def _range_to_dict(self) -> dict[str, bytes]: + """Converts the value range range to a dict representation.""" + value_range_kwargs = {} + if self.start_value is not None: + if self.inclusive_start: + key = "start_value_closed" + else: + key = "start_value_open" + value_range_kwargs[key] = _to_bytes(self.start_value) + if self.end_value is not None: + if self.inclusive_end: + key = "end_value_closed" + else: + key = "end_value_open" + value_range_kwargs[key] = _to_bytes(self.end_value) + return value_range_kwargs + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"value_range_filter": self._range_to_dict()} + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(start_value={self.start_value!r}, end_value={self.end_value!r}, inclusive_start={self.inclusive_start}, inclusive_end={self.inclusive_end})" + + +class _CellCountFilter(RowFilter, ABC): + """Row filter that uses an integer count of cells. + + The cell count is used as an offset or a limit for the number + of results returned. + + :type num_cells: int + :param num_cells: An integer count / offset / limit. + """ + + def __init__(self, num_cells: int): + self.num_cells = num_cells + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return other.num_cells == self.num_cells + + def __ne__(self, other): + return not self == other + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(num_cells={self.num_cells})" + + +class CellsRowOffsetFilter(_CellCountFilter): + """Row filter to skip cells in a row. + + :type num_cells: int + :param num_cells: Skips the first N cells of the row. + """ + + def _to_dict(self) -> dict[str, int]: + """Converts the row filter to a dict representation.""" + return {"cells_per_row_offset_filter": self.num_cells} + + +class CellsRowLimitFilter(_CellCountFilter): + """Row filter to limit cells in a row. + + :type num_cells: int + :param num_cells: Matches only the first N cells of the row. + """ + + def _to_dict(self) -> dict[str, int]: + """Converts the row filter to a dict representation.""" + return {"cells_per_row_limit_filter": self.num_cells} + + +class CellsColumnLimitFilter(_CellCountFilter): + """Row filter to limit cells in a column. + + :type num_cells: int + :param num_cells: Matches only the most recent N cells within each column. + This filters a (family name, column) pair, based on + timestamps of each cell. + """ + + def _to_dict(self) -> dict[str, int]: + """Converts the row filter to a dict representation.""" + return {"cells_per_column_limit_filter": self.num_cells} + + +class StripValueTransformerFilter(_BoolFilter): + """Row filter that transforms cells into empty string (0 bytes). + + :type flag: bool + :param flag: If :data:`True`, replaces each cell's value with the empty + string. As the name indicates, this is more useful as a + transformer than a generic query / filter. + """ + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"strip_value_transformer": self.flag} + + +class ApplyLabelFilter(RowFilter): + """Filter to apply labels to cells. + + Intended to be used as an intermediate filter on a pre-existing filtered + result set. 
This way if two sets are combined, the label can tell where
+ the cell(s) originated. This allows the client to determine which results
+ were produced from which part of the filter.
+
+ .. note::
+
+ Due to a technical limitation of the backend, it is not currently
+ possible to apply multiple labels to a cell.
+
+ :type label: str
+ :param label: Label to apply to cells in the output row. Values must be
+ at most 15 characters long, and match the pattern
+ ``[a-z0-9\\-]+``.
+ """
+
+ def __init__(self, label: str):
+ self.label = label
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return other.label == self.label
+
+ def __ne__(self, other):
+ return not self == other
+
+ def _to_dict(self) -> dict[str, str]:
+ """Converts the row filter to a dict representation."""
+ return {"apply_label_transformer": self.label}
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(label={self.label})"
+
+
+class _FilterCombination(RowFilter, Sequence[RowFilter], ABC):
+ """Chain of row filters.
+
+ Sends rows through several filters in sequence. The filters are "chained"
+ together to process a row. After the first filter is applied, the second
+ is applied to the filtered output and so on for subsequent filters.
+
+ :type filters: list
+ :param filters: List of :class:`RowFilter`
+ """
+
+ def __init__(self, filters: list[RowFilter] | None = None):
+ if filters is None:
+ filters = []
+ self.filters: list[RowFilter] = filters
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return other.filters == self.filters
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __len__(self) -> int:
+ return len(self.filters)
+
+ @overload
+ def __getitem__(self, index: int) -> RowFilter:
+ # overload signature for type checking
+ pass
+
+ @overload
+ def __getitem__(self, index: slice) -> list[RowFilter]:
+ # overload signature for type checking
+ pass
+
+ def __getitem__(self, index):
+ return self.filters[index]
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(filters={self.filters})"
+
+ def __str__(self) -> str:
+ """
+ Returns a string representation of the filter chain.
+
+ Adds line breaks between each sub-filter for readability.
+ """
+ output = [f"{self.__class__.__name__}(["]
+ for filter_ in self.filters:
+ filter_lines = f"{filter_},".splitlines()
+ output.extend([f" {line}" for line in filter_lines])
+ output.append("])")
+ return "\n".join(output)
+
+
+class RowFilterChain(_FilterCombination):
+ """Chain of row filters.
+
+ Sends rows through several filters in sequence. The filters are "chained"
+ together to process a row. After the first filter is applied, the second
+ is applied to the filtered output and so on for subsequent filters.
+
+ :type filters: list
+ :param filters: List of :class:`RowFilter`
+ """
+
+ def _to_pb(self) -> data_v2_pb2.RowFilter:
+ """Converts the row filter to a protobuf.
+
+ Returns: The converted current object.
+ """
+ chain = data_v2_pb2.RowFilter.Chain(
+ filters=[row_filter._to_pb() for row_filter in self.filters]
+ )
+ return data_v2_pb2.RowFilter(chain=chain)
+
+ def _to_dict(self) -> dict[str, Any]:
+ """Converts the row filter to a dict representation."""
+ return {"chain": {"filters": [f._to_dict() for f in self.filters]}}
+
+
+class RowFilterUnion(_FilterCombination):
+ """Union of row filters.
+
+ Sends rows through several filters simultaneously, then
+ merges / interleaves all the filtered results together.
+ + If multiple cells are produced with the same column and timestamp, + they will all appear in the output row in an unspecified mutual order. + + :type filters: list + :param filters: List of :class:`RowFilter` + """ + + def _to_pb(self) -> data_v2_pb2.RowFilter: + """Converts the row filter to a protobuf. + + Returns: The converted current object. + """ + interleave = data_v2_pb2.RowFilter.Interleave( + filters=[row_filter._to_pb() for row_filter in self.filters] + ) + return data_v2_pb2.RowFilter(interleave=interleave) + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"interleave": {"filters": [f._to_dict() for f in self.filters]}} + + +class ConditionalRowFilter(RowFilter): + """Conditional row filter which exhibits ternary behavior. + + Executes one of two filters based on another filter. If the ``predicate_filter`` + returns any cells in the row, then ``true_filter`` is executed. If not, + then ``false_filter`` is executed. + + .. note:: + + The ``predicate_filter`` does not execute atomically with the true and false + filters, which may lead to inconsistent or unexpected results. + + Additionally, executing a :class:`ConditionalRowFilter` has poor + performance on the server, especially when ``false_filter`` is set. + + :type predicate_filter: :class:`RowFilter` + :param predicate_filter: The filter to condition on before executing the + true/false filters. + + :type true_filter: :class:`RowFilter` + :param true_filter: (Optional) The filter to execute if there are any cells + matching ``predicate_filter``. If not provided, no results + will be returned in the true case. + + :type false_filter: :class:`RowFilter` + :param false_filter: (Optional) The filter to execute if there are no cells + matching ``predicate_filter``. If not provided, no results + will be returned in the false case. + """ + + def __init__( + self, + predicate_filter: RowFilter, + true_filter: RowFilter | None = None, + false_filter: RowFilter | None = None, + ): + self.predicate_filter = predicate_filter + self.true_filter = true_filter + self.false_filter = false_filter + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return ( + other.predicate_filter == self.predicate_filter + and other.true_filter == self.true_filter + and other.false_filter == self.false_filter + ) + + def __ne__(self, other): + return not self == other + + def _to_pb(self) -> data_v2_pb2.RowFilter: + """Converts the row filter to a protobuf. + + Returns: The converted current object. 
+ """ + condition_kwargs = {"predicate_filter": self.predicate_filter._to_pb()} + if self.true_filter is not None: + condition_kwargs["true_filter"] = self.true_filter._to_pb() + if self.false_filter is not None: + condition_kwargs["false_filter"] = self.false_filter._to_pb() + condition = data_v2_pb2.RowFilter.Condition(**condition_kwargs) + return data_v2_pb2.RowFilter(condition=condition) + + def _condition_to_dict(self) -> dict[str, Any]: + """Converts the condition to a dict representation.""" + condition_kwargs = {"predicate_filter": self.predicate_filter._to_dict()} + if self.true_filter is not None: + condition_kwargs["true_filter"] = self.true_filter._to_dict() + if self.false_filter is not None: + condition_kwargs["false_filter"] = self.false_filter._to_dict() + return condition_kwargs + + def _to_dict(self) -> dict[str, Any]: + """Converts the row filter to a dict representation.""" + return {"condition": self._condition_to_dict()} + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(predicate_filter={self.predicate_filter!r}, true_filter={self.true_filter!r}, false_filter={self.false_filter!r})" + + def __str__(self) -> str: + output = [f"{self.__class__.__name__}("] + for filter_type in ("predicate_filter", "true_filter", "false_filter"): + filter_ = getattr(self, filter_type) + if filter_ is None: + continue + # add the new filter set, adding indentations for readability + filter_lines = f"{filter_type}={filter_},".splitlines() + output.extend(f" {line}" for line in filter_lines) + output.append(")") + return "\n".join(output) diff --git a/google/cloud/bigtable/py.typed b/google/cloud/bigtable/py.typed deleted file mode 100644 index 7bd4705d4..000000000 --- a/google/cloud/bigtable/py.typed +++ /dev/null @@ -1,2 +0,0 @@ -# Marker file for PEP 561. -# The google-cloud-bigtable package uses inline types. diff --git a/google/cloud/bigtable_v2/services/bigtable/async_client.py b/google/cloud/bigtable_v2/services/bigtable/async_client.py index 33686a4a8..df5d7e0de 100644 --- a/google/cloud/bigtable_v2/services/bigtable/async_client.py +++ b/google/cloud/bigtable_v2/services/bigtable/async_client.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import functools from collections import OrderedDict import functools import re @@ -40,9 +41,9 @@ from google.oauth2 import service_account # type: ignore try: - OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault] + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.AsyncRetry, object] # type: ignore + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore from google.cloud.bigtable_v2.types import bigtable from google.cloud.bigtable_v2.types import data @@ -272,7 +273,8 @@ def read_rows( "the individual field arguments should be set." ) - request = bigtable.ReadRowsRequest(request) + if not isinstance(request, bigtable.ReadRowsRequest): + request = bigtable.ReadRowsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -283,12 +285,9 @@ def read_rows( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_rows, - default_timeout=43200.0, - client_info=DEFAULT_CLIENT_INFO, - ) - + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_rows + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( @@ -367,7 +366,8 @@ def sample_row_keys( "the individual field arguments should be set." ) - request = bigtable.SampleRowKeysRequest(request) + if not isinstance(request, bigtable.SampleRowKeysRequest): + request = bigtable.SampleRowKeysRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -378,12 +378,9 @@ def sample_row_keys( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.sample_row_keys, - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) - + rpc = self._client._transport._wrapped_methods[ + self._client._transport.sample_row_keys + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( @@ -479,7 +476,8 @@ async def mutate_row( "the individual field arguments should be set." ) - request = bigtable.MutateRowRequest(request) + if not isinstance(request, bigtable.MutateRowRequest): + request = bigtable.MutateRowRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -494,21 +492,9 @@ async def mutate_row( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.mutate_row, - default_retry=retries.AsyncRetry( - initial=0.01, - maximum=60.0, - multiplier=2, - predicate=retries.if_exception_type( - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - ), - deadline=60.0, - ), - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.mutate_row + ] # Certain fields should be provided within the metadata header; # add these here. @@ -601,7 +587,8 @@ def mutate_rows( "the individual field arguments should be set." ) - request = bigtable.MutateRowsRequest(request) + if not isinstance(request, bigtable.MutateRowsRequest): + request = bigtable.MutateRowsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -614,11 +601,9 @@ def mutate_rows( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.mutate_rows, - default_timeout=600.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.mutate_rows + ] # Certain fields should be provided within the metadata header; # add these here. @@ -749,7 +734,8 @@ async def check_and_mutate_row( "the individual field arguments should be set." ) - request = bigtable.CheckAndMutateRowRequest(request) + if not isinstance(request, bigtable.CheckAndMutateRowRequest): + request = bigtable.CheckAndMutateRowRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -768,11 +754,9 @@ async def check_and_mutate_row( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.check_and_mutate_row, - default_timeout=20.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.check_and_mutate_row + ] # Certain fields should be provided within the metadata header; # add these here. @@ -851,7 +835,8 @@ async def ping_and_warm( "the individual field arguments should be set." ) - request = bigtable.PingAndWarmRequest(request) + if not isinstance(request, bigtable.PingAndWarmRequest): + request = bigtable.PingAndWarmRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -862,11 +847,9 @@ async def ping_and_warm( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.ping_and_warm, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.ping_and_warm + ] # Certain fields should be provided within the metadata header; # add these here. @@ -968,7 +951,8 @@ async def read_modify_write_row( "the individual field arguments should be set." ) - request = bigtable.ReadModifyWriteRowRequest(request) + if not isinstance(request, bigtable.ReadModifyWriteRowRequest): + request = bigtable.ReadModifyWriteRowRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -983,11 +967,9 @@ async def read_modify_write_row( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_modify_write_row, - default_timeout=20.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_modify_write_row + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1076,7 +1058,10 @@ def generate_initial_change_stream_partitions( "the individual field arguments should be set." ) - request = bigtable.GenerateInitialChangeStreamPartitionsRequest(request) + if not isinstance( + request, bigtable.GenerateInitialChangeStreamPartitionsRequest + ): + request = bigtable.GenerateInitialChangeStreamPartitionsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1174,7 +1159,8 @@ def read_change_stream( "the individual field arguments should be set." ) - request = bigtable.ReadChangeStreamRequest(request) + if not isinstance(request, bigtable.ReadChangeStreamRequest): + request = bigtable.ReadChangeStreamRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
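The async_client.py hunks above all follow one pattern: coerce the incoming request only when it is not already the expected proto-plus type, and fetch a pre-wrapped RPC from the transport's `_wrapped_methods` table instead of calling `gapic_v1.method_async.wrap_method(...)` on every invocation. The sketch below is a minimal, self-contained illustration of that pattern; all names (`wrap_method`, `SketchTransport`, `SketchClient`) are hypothetical stand-ins, not the real GAPIC classes.

```python
import functools


def wrap_method(func, default_timeout=None):
    """Toy stand-in for gapic_v1.method_async.wrap_method: bakes in a default timeout."""
    @functools.wraps(func)
    def wrapped(request, timeout=None):
        return func(request, timeout=timeout if timeout is not None else default_timeout)
    return wrapped


class SketchTransport:
    """Hypothetical transport with one RPC and a precomputed wrapped-method table."""

    def read_rows(self, request, timeout=None):
        return f"read_rows(request={request!r}, timeout={timeout})"

    def _prep_wrapped_messages(self):
        # Wrap every RPC once, at transport construction time.
        self._wrapped_methods = {
            self.read_rows: wrap_method(self.read_rows, default_timeout=43200.0),
        }


class SketchClient:
    """Hypothetical client that looks up pre-wrapped RPCs instead of re-wrapping per call."""

    def __init__(self):
        self._transport = SketchTransport()
        self._transport._prep_wrapped_messages()

    def read_rows(self, request):
        # Coerce the request only if it is not already the expected type,
        # mirroring the new isinstance() guards in the hunks above.
        if not isinstance(request, dict):
            request = dict(request)
        # Fetch the pre-wrapped RPC from the transport's lookup table.
        rpc = self._transport._wrapped_methods[self._transport.read_rows]
        return rpc(request)


print(SketchClient().read_rows({"table_name": "projects/p/instances/i/tables/t"}))
```

Keying the table by the bound method mirrors the generated code's approach: the per-call wrapping work moves to transport construction, and subclasses (such as the pooled transport added later in this diff) can override the defaults in one place.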
diff --git a/google/cloud/bigtable_v2/services/bigtable/client.py b/google/cloud/bigtable_v2/services/bigtable/client.py index db393faa7..54ba6af43 100644 --- a/google/cloud/bigtable_v2/services/bigtable/client.py +++ b/google/cloud/bigtable_v2/services/bigtable/client.py @@ -43,9 +43,9 @@ from google.oauth2 import service_account # type: ignore try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore + OptionalRetry = Union[retries.Retry, object, None] # type: ignore from google.cloud.bigtable_v2.types import bigtable from google.cloud.bigtable_v2.types import data @@ -53,6 +53,7 @@ from .transports.base import BigtableTransport, DEFAULT_CLIENT_INFO from .transports.grpc import BigtableGrpcTransport from .transports.grpc_asyncio import BigtableGrpcAsyncIOTransport +from .transports.pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport from .transports.rest import BigtableRestTransport @@ -67,6 +68,7 @@ class BigtableClientMeta(type): _transport_registry = OrderedDict() # type: Dict[str, Type[BigtableTransport]] _transport_registry["grpc"] = BigtableGrpcTransport _transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport + _transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport _transport_registry["rest"] = BigtableRestTransport def get_transport_class( diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py b/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py index c09443bc2..6a9eb0e58 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py @@ -19,6 +19,7 @@ from .base import BigtableTransport from .grpc import BigtableGrpcTransport from .grpc_asyncio import BigtableGrpcAsyncIOTransport +from .pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport from .rest import BigtableRestTransport from .rest import BigtableRestInterceptor @@ -27,12 +28,14 @@ _transport_registry = OrderedDict() # type: Dict[str, Type[BigtableTransport]] _transport_registry["grpc"] = BigtableGrpcTransport _transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport +_transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport _transport_registry["rest"] = BigtableRestTransport __all__ = ( "BigtableTransport", "BigtableGrpcTransport", "BigtableGrpcAsyncIOTransport", + "PooledBigtableGrpcAsyncIOTransport", "BigtableRestTransport", "BigtableRestInterceptor", ) diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py index 2c0cbdad6..1d0a2bc4c 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -512,6 +514,66 @@ def read_change_stream( ) return self._stubs["read_change_stream"] + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. 
+ self._wrapped_methods = { + self.read_rows: gapic_v1.method_async.wrap_method( + self.read_rows, + default_timeout=43200.0, + client_info=client_info, + ), + self.sample_row_keys: gapic_v1.method_async.wrap_method( + self.sample_row_keys, + default_timeout=60.0, + client_info=client_info, + ), + self.mutate_row: gapic_v1.method_async.wrap_method( + self.mutate_row, + default_retry=retries.Retry( + initial=0.01, + maximum=60.0, + multiplier=2, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ), + deadline=60.0, + ), + default_timeout=60.0, + client_info=client_info, + ), + self.mutate_rows: gapic_v1.method_async.wrap_method( + self.mutate_rows, + default_timeout=600.0, + client_info=client_info, + ), + self.check_and_mutate_row: gapic_v1.method_async.wrap_method( + self.check_and_mutate_row, + default_timeout=20.0, + client_info=client_info, + ), + self.ping_and_warm: gapic_v1.method_async.wrap_method( + self.ping_and_warm, + default_timeout=None, + client_info=client_info, + ), + self.read_modify_write_row: gapic_v1.method_async.wrap_method( + self.read_modify_write_row, + default_timeout=20.0, + client_info=client_info, + ), + self.generate_initial_change_stream_partitions: gapic_v1.method_async.wrap_method( + self.generate_initial_change_stream_partitions, + default_timeout=60.0, + client_info=client_info, + ), + self.read_change_stream: gapic_v1.method_async.wrap_method( + self.read_change_stream, + default_timeout=43200.0, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py new file mode 100644 index 000000000..372e5796d --- /dev/null +++ b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py @@ -0,0 +1,426 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import asyncio +import warnings +from functools import partialmethod +from functools import partial +from typing import ( + Awaitable, + Callable, + Dict, + Optional, + Sequence, + Tuple, + Union, + List, + Type, +) + +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers_async +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.bigtable_v2.types import bigtable +from .base import BigtableTransport, DEFAULT_CLIENT_INFO +from .grpc_asyncio import BigtableGrpcAsyncIOTransport + + +class PooledMultiCallable: + def __init__(self, channel_pool: "PooledChannel", *args, **kwargs): + self._init_args = args + self._init_kwargs = kwargs + self.next_channel_fn = channel_pool.next_channel + + +class PooledUnaryUnaryMultiCallable(PooledMultiCallable, aio.UnaryUnaryMultiCallable): + def __call__(self, *args, **kwargs) -> aio.UnaryUnaryCall: + return self.next_channel_fn().unary_unary( + *self._init_args, **self._init_kwargs + )(*args, **kwargs) + + +class PooledUnaryStreamMultiCallable(PooledMultiCallable, aio.UnaryStreamMultiCallable): + def __call__(self, *args, **kwargs) -> aio.UnaryStreamCall: + return self.next_channel_fn().unary_stream( + *self._init_args, **self._init_kwargs + )(*args, **kwargs) + + +class PooledStreamUnaryMultiCallable(PooledMultiCallable, aio.StreamUnaryMultiCallable): + def __call__(self, *args, **kwargs) -> aio.StreamUnaryCall: + return self.next_channel_fn().stream_unary( + *self._init_args, **self._init_kwargs + )(*args, **kwargs) + + +class PooledStreamStreamMultiCallable( + PooledMultiCallable, aio.StreamStreamMultiCallable +): + def __call__(self, *args, **kwargs) -> aio.StreamStreamCall: + return self.next_channel_fn().stream_stream( + *self._init_args, **self._init_kwargs + )(*args, **kwargs) + + +class PooledChannel(aio.Channel): + def __init__( + self, + pool_size: int = 3, + host: str = "bigtable.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + quota_project_id: Optional[str] = None, + default_scopes: Optional[Sequence[str]] = None, + scopes: Optional[Sequence[str]] = None, + default_host: Optional[str] = None, + insecure: bool = False, + **kwargs, + ): + self._pool: List[aio.Channel] = [] + self._next_idx = 0 + if insecure: + self._create_channel = partial(aio.insecure_channel, host) + else: + self._create_channel = partial( + grpc_helpers_async.create_channel, + target=host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=default_scopes, + scopes=scopes, + default_host=default_host, + **kwargs, + ) + for i in range(pool_size): + self._pool.append(self._create_channel()) + + def next_channel(self) -> aio.Channel: + channel = self._pool[self._next_idx] + self._next_idx = (self._next_idx + 1) % len(self._pool) + return channel + + def unary_unary(self, *args, **kwargs) -> grpc.aio.UnaryUnaryMultiCallable: + return PooledUnaryUnaryMultiCallable(self, *args, **kwargs) + + def unary_stream(self, *args, **kwargs) -> grpc.aio.UnaryStreamMultiCallable: + return PooledUnaryStreamMultiCallable(self, *args, **kwargs) + + def stream_unary(self, *args, **kwargs) -> grpc.aio.StreamUnaryMultiCallable: + return PooledStreamUnaryMultiCallable(self, *args, **kwargs) + + def stream_stream(self, *args, **kwargs) -> 
grpc.aio.StreamStreamMultiCallable:
+ return PooledStreamStreamMultiCallable(self, *args, **kwargs)
+
+ async def close(self, grace=None):
+ close_fns = [channel.close(grace=grace) for channel in self._pool]
+ return await asyncio.gather(*close_fns)
+
+ async def channel_ready(self):
+ ready_fns = [channel.channel_ready() for channel in self._pool]
+ return await asyncio.gather(*ready_fns)
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
+
+ def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity:
+ raise NotImplementedError()
+
+ async def wait_for_state_change(self, last_observed_state):
+ raise NotImplementedError()
+
+ async def replace_channel(
+ self, channel_idx, grace=None, swap_sleep=1, new_channel=None
+ ) -> aio.Channel:
+ """
+ Replaces a channel in the pool with a fresh one.
+
+ The `new_channel` will start processing new requests immediately,
+ but the old channel will continue serving existing clients for `grace` seconds.
+
+ Args:
+ channel_idx(int): the channel index in the pool to replace
+ grace(Optional[float]): The time to wait until all active RPCs are
+ finished. If a grace period is not specified (by passing None for
+ grace), all existing RPCs are cancelled immediately.
+ swap_sleep(Optional[float]): The number of seconds to sleep in between
+ replacing channels and closing the old one
+ new_channel(grpc.aio.Channel): a new channel to insert into the pool
+ at `channel_idx`. If `None`, a new channel will be created.
+ """
+ if channel_idx >= len(self._pool) or channel_idx < 0:
+ raise ValueError(
+ f"invalid channel_idx {channel_idx} for pool size {len(self._pool)}"
+ )
+ if new_channel is None:
+ new_channel = self._create_channel()
+ old_channel = self._pool[channel_idx]
+ self._pool[channel_idx] = new_channel
+ await asyncio.sleep(swap_sleep)
+ await old_channel.close(grace=grace)
+ return new_channel
+
+
+class PooledBigtableGrpcAsyncIOTransport(BigtableGrpcAsyncIOTransport):
+ """Pooled gRPC AsyncIO backend transport for Bigtable.
+
+ Service for reading from and writing to existing Bigtable
+ tables.
+
+ This class defines the same methods as the primary client, so the
+ primary client can load the underlying transport implementation
+ and call it.
+
+ It sends protocol buffers over the wire using gRPC (which is built on
+ top of HTTP/2); the ``grpcio`` package must be installed.
+
+ This class allows channel pooling, so multiple channels can be used concurrently
+ when making requests. Channels are rotated in a round-robin fashion.
+ """
+
+ @classmethod
+ def with_fixed_size(cls, pool_size) -> Type["PooledBigtableGrpcAsyncIOTransport"]:
+ """
+ Creates a new class with a fixed channel pool size.
+
+ A fixed channel pool makes compatibility with other transports easier,
+ as the initializer signature is the same.
+ """ + + class PooledTransportFixed(cls): + __init__ = partialmethod(cls.__init__, pool_size=pool_size) + + PooledTransportFixed.__name__ = f"{cls.__name__}_{pool_size}" + PooledTransportFixed.__qualname__ = PooledTransportFixed.__name__ + return PooledTransportFixed + + @classmethod + def create_channel( + cls, + pool_size: int = 3, + host: str = "bigtable.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a PooledChannel object, representing a pool of gRPC AsyncIO channels + Args: + pool_size (int): The number of channels in the pool. + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + PooledChannel: a channel pool object + """ + + return PooledChannel( + pool_size, + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + pool_size: int = 3, + host: str = "bigtable.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + pool_size (int): the number of grpc channels to maintain in a pool + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. 
+ api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + ValueError: if ``pool_size`` <= 0 + """ + if pool_size <= 0: + raise ValueError(f"invalid pool_size: {pool_size}") + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + BigtableTransport.__init__( + self, + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + self._quota_project_id = quota_project_id + self._grpc_channel = type(self).create_channel( + pool_size, + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. 
+ credentials_file=None,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
+ quota_project_id=self._quota_project_id,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
+
+ @property
+ def pool_size(self) -> int:
+ """The number of grpc channels in the pool."""
+ return len(self._grpc_channel._pool)
+
+ @property
+ def channels(self) -> List[grpc.Channel]:
+ """Access the internal list of grpc channels."""
+ return self._grpc_channel._pool
+
+ async def replace_channel(
+ self, channel_idx, grace=None, swap_sleep=1, new_channel=None
+ ) -> aio.Channel:
+ """
+ Replaces a channel in the pool with a fresh one.
+
+ The `new_channel` will start processing new requests immediately,
+ but the old channel will continue serving existing clients for `grace` seconds.
+
+ Args:
+ channel_idx(int): the channel index in the pool to replace
+ grace(Optional[float]): The time to wait until all active RPCs are
+ finished. If a grace period is not specified (by passing None for
+ grace), all existing RPCs are cancelled immediately.
+ swap_sleep(Optional[float]): The number of seconds to sleep in between
+ replacing channels and closing the old one
+ new_channel(grpc.aio.Channel): a new channel to insert into the pool
+ at `channel_idx`. If `None`, a new channel will be created.
+ """
+ return await self._grpc_channel.replace_channel(
+ channel_idx, grace, swap_sleep, new_channel
+ )
+
+
+__all__ = ("PooledBigtableGrpcAsyncIOTransport",)
diff --git a/noxfile.py b/noxfile.py
index 8550a2b79..daf730a9a 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -54,7 +54,9 @@
 "pytest",
 "google-cloud-testutils",
 ]
-SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = []
+SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [
+ "pytest-asyncio",
+]
 SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = []
 SYSTEM_TEST_DEPENDENCIES: List[str] = []
 SYSTEM_TEST_EXTRAS: List[str] = []
@@ -134,8 +136,18 @@ def mypy(session):
 "mypy", "types-setuptools", "types-protobuf", "types-mock", "types-requests"
 )
 session.install("google-cloud-testutils")
- # TODO: also verify types on tests, all of google package
- session.run("mypy", "-p", "google", "-p", "tests")
+ session.run(
+ "mypy",
+ "-p",
+ "google.cloud.bigtable.data",
+ "--check-untyped-defs",
+ "--warn-unreachable",
+ "--disallow-any-generics",
+ "--exclude",
+ "tests/system/v2_client",
+ "--exclude",
+ "tests/unit/v2_client",
+ )
 
 
 @nox.session(python=DEFAULT_PYTHON_VERSION)
@@ -260,6 +272,24 @@ def system_emulated(session):
 os.killpg(os.getpgid(p.pid), signal.SIGKILL)
 
 
+@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)
+def conformance(session):
+ TEST_REPO_URL = "https://github.com/googleapis/cloud-bigtable-clients-test.git"
+ CLONE_REPO_DIR = "cloud-bigtable-clients-test"
+ # install dependencies
+ constraints_path = str(
+ CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
+ )
+ install_unittest_dependencies(session, "-c", constraints_path)
+ with session.chdir("test_proxy"):
+ # download the conformance test suite
+ clone_dir = os.path.join(CURRENT_DIRECTORY, CLONE_REPO_DIR)
+ if not os.path.exists(clone_dir):
+ print("downloading copy of test repo")
+ session.run("git", "clone", TEST_REPO_URL, CLONE_REPO_DIR, external=True)
+ session.run("bash", "-e", "run_tests.sh", external=True)
+
+
 @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)
 def system(session):
 """Run the system test suite."""
@@ -311,7
+341,7 @@ def cover(session): test runs (not system test runs), and then erases coverage data. """ session.install("coverage", "pytest-cov") - session.run("coverage", "report", "--show-missing", "--fail-under=100") + session.run("coverage", "report", "--show-missing", "--fail-under=99") session.run("coverage", "erase") diff --git a/owlbot.py b/owlbot.py index 4b06aea77..3fb079396 100644 --- a/owlbot.py +++ b/owlbot.py @@ -89,7 +89,10 @@ def get_staging_dirs( samples=True, # set to True only if there are samples split_system_tests=True, microgenerator=True, - cov_level=100, + cov_level=99, + system_test_external_dependencies=[ + "pytest-asyncio", + ], ) s.move(templated_files, excludes=[".coveragerc", "README.rst", ".github/release-please.yml"]) @@ -142,7 +145,35 @@ def system_emulated(session): escape="()" ) -# add system_emulated nox session +conformance_session = """ +@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) +def conformance(session): + TEST_REPO_URL = "https://github.com/googleapis/cloud-bigtable-clients-test.git" + CLONE_REPO_DIR = "cloud-bigtable-clients-test" + # install dependencies + constraints_path = str( + CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" + ) + install_unittest_dependencies(session, "-c", constraints_path) + with session.chdir("test_proxy"): + # download the conformance test suite + clone_dir = os.path.join(CURRENT_DIRECTORY, CLONE_REPO_DIR) + if not os.path.exists(clone_dir): + print("downloading copy of test repo") + session.run("git", "clone", TEST_REPO_URL, CLONE_REPO_DIR, external=True) + session.run("bash", "-e", "run_tests.sh", external=True) + +""" + +place_before( + "noxfile.py", + "@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)\n" + "def system(session):", + conformance_session, + escape="()" +) + +# add system_emulated and mypy and conformance to nox session s.replace("noxfile.py", """nox.options.sessions = \[ "unit", @@ -168,8 +199,18 @@ def mypy(session): session.install("-e", ".") session.install("mypy", "types-setuptools", "types-protobuf", "types-mock", "types-requests") session.install("google-cloud-testutils") - # TODO: also verify types on tests, all of google package - session.run("mypy", "-p", "google", "-p", "tests") + session.run( + "mypy", + "-p", + "google.cloud.bigtable.data", + "--check-untyped-defs", + "--warn-unreachable", + "--disallow-any-generics", + "--exclude", + "tests/system/v2_client", + "--exclude", + "tests/unit/v2_client", + ) @nox.session(python=DEFAULT_PYTHON_VERSION) diff --git a/python-api-core b/python-api-core new file mode 160000 index 000000000..17ff5f1d8 --- /dev/null +++ b/python-api-core @@ -0,0 +1 @@ +Subproject commit 17ff5f1d83a9a6f50a0226fb0e794634bd584f17 diff --git a/setup.py b/setup.py index e9bce0960..8b698a35b 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ # 'Development Status :: 5 - Production/Stable' release_status = "Development Status :: 5 - Production/Stable" dependencies = [ - "google-api-core[grpc] >= 1.34.0, <3.0.0dev,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*,!=2.10.*", + "google-api-core[grpc] >= 2.16.0, <3.0.0dev", "google-cloud-core >= 1.4.4, <3.0.0dev", "grpc-google-iam-v1 >= 0.12.4, <1.0.0dev", "proto-plus >= 1.22.0, <2.0.0dev", diff --git a/test_proxy/README.md b/test_proxy/README.md new file mode 100644 index 000000000..08741fd5d --- /dev/null +++ b/test_proxy/README.md @@ -0,0 +1,60 @@ +# CBT Python Test Proxy + +The CBT test proxy is intended for running conformance tests for Cloud Bigtable Python Client. 
+
+## Option 1: Run Tests with Nox
+
+You can run the conformance tests in a single line by calling `nox -s conformance` from the repo root
+
+
+```
+cd python-bigtable/test_proxy
+nox -s conformance
+```
+
+## Option 2: Run processes manually
+
+### Start test proxy
+
+You can use `test_proxy.py` to launch a new test proxy process directly
+
+```
+cd python-bigtable/test_proxy
+python test_proxy.py
+```
+
+The port can be set with the `--port` flag
+
+```
+cd python-bigtable/test_proxy
+python test_proxy.py --port 8080
+```
+
+You can run the test proxy against the previous `v2` client by running it with the `--legacy-client` flag:
+
+```
+python test_proxy.py --legacy-client
+```
+
+### Run the test cases
+
+Prerequisites:
+- If you have not already done so, [install golang](https://go.dev/doc/install).
+- Before running tests, [launch an instance of the test proxy](#start-test-proxy)
+in a separate shell session, and make note of the port
+
+
+Clone and navigate to the go test library:
+
+```
+git clone https://github.com/googleapis/cloud-bigtable-clients-test.git
+cd cloud-bigtable-clients-test/tests
+```
+
+
+Launch the tests
+
+```
+go test -v -proxy_addr=:50055
+```
+
diff --git a/test_proxy/handlers/client_handler_data.py b/test_proxy/handlers/client_handler_data.py
new file mode 100644
index 000000000..43ff5d634
--- /dev/null
+++ b/test_proxy/handlers/client_handler_data.py
@@ -0,0 +1,214 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module contains the client handler process for proxy_server.py.
+""" +import os + +from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.cloud.bigtable.data import BigtableDataClientAsync + + +def error_safe(func): + """ + Catch and pass errors back to the grpc_server_process + Also check if client is closed before processing requests + """ + async def wrapper(self, *args, **kwargs): + try: + if self.closed: + raise RuntimeError("client is closed") + return await func(self, *args, **kwargs) + except (Exception, NotImplementedError) as e: + # exceptions should be raised in grpc_server_process + return encode_exception(e) + + return wrapper + + +def encode_exception(exc): + """ + Encode an exception or chain of exceptions to pass back to grpc_handler + """ + from google.api_core.exceptions import GoogleAPICallError + error_msg = f"{type(exc).__name__}: {exc}" + result = {"error": error_msg} + if exc.__cause__: + result["cause"] = encode_exception(exc.__cause__) + if hasattr(exc, "exceptions"): + result["subexceptions"] = [encode_exception(e) for e in exc.exceptions] + if hasattr(exc, "index"): + result["index"] = exc.index + if isinstance(exc, GoogleAPICallError): + if exc.grpc_status_code is not None: + result["code"] = exc.grpc_status_code.value[0] + elif exc.code is not None: + result["code"] = int(exc.code) + else: + result["code"] = -1 + elif result.get("cause", {}).get("code", None): + # look for code code in cause + result["code"] = result["cause"]["code"] + elif result.get("subexceptions", None): + # look for code in subexceptions + for subexc in result["subexceptions"]: + if subexc.get("code", None): + result["code"] = subexc["code"] + return result + + +class TestProxyClientHandler: + """ + Implements the same methods as the grpc server, but handles the client + library side of the request. + + Requests received in TestProxyGrpcServer are converted to a dictionary, + and supplied to the TestProxyClientHandler methods as kwargs. 
+ The client response is then returned back to the TestProxyGrpcServer + """ + + def __init__( + self, + data_target=None, + project_id=None, + instance_id=None, + app_profile_id=None, + per_operation_timeout=None, + **kwargs, + ): + self.closed = False + # use emulator + os.environ[BIGTABLE_EMULATOR] = data_target + self.client = BigtableDataClientAsync(project=project_id) + self.instance_id = instance_id + self.app_profile_id = app_profile_id + self.per_operation_timeout = per_operation_timeout + + def close(self): + # TODO: call self.client.close() + self.closed = True + + @error_safe + async def ReadRows(self, request, **kwargs): + table_id = request.pop("table_name").split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + result_list = await table.read_rows(request, **kwargs) + # pack results back into protobuf-parsable format + serialized_response = [row._to_dict() for row in result_list] + return serialized_response + + @error_safe + async def ReadRow(self, row_key, **kwargs): + table_id = kwargs.pop("table_name").split("/")[-1] + app_profile_id = self.app_profile_id or kwargs.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + result_row = await table.read_row(row_key, **kwargs) + # pack results back into protobuf-parsable format + if result_row: + return result_row._to_dict() + else: + return "None" + + @error_safe + async def MutateRow(self, request, **kwargs): + from google.cloud.bigtable.data.mutations import Mutation + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + row_key = request["row_key"] + mutations = [Mutation._from_dict(d) for d in request["mutations"]] + await table.mutate_row(row_key, mutations, **kwargs) + return "OK" + + @error_safe + async def BulkMutateRows(self, request, **kwargs): + from google.cloud.bigtable.data.mutations import RowMutationEntry + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + entry_list = [RowMutationEntry._from_dict(entry) for entry in request["entries"]] + await table.bulk_mutate_rows(entry_list, **kwargs) + return "OK" + + @error_safe + async def CheckAndMutateRow(self, request, **kwargs): + from google.cloud.bigtable.data.mutations import Mutation, SetCell + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + row_key = request["row_key"] + # add default values for incomplete dicts, so they can still be parsed to objects + true_mutations = [] + for mut_dict in request.get("true_mutations", []): + try: + 
true_mutations.append(Mutation._from_dict(mut_dict)) + except ValueError: + # invalid mutation type. Conformance test may be sending generic empty request + mutation = SetCell("", "", "", 0) + true_mutations.append(mutation) + false_mutations = [] + for mut_dict in request.get("false_mutations", []): + try: + false_mutations.append(Mutation._from_dict(mut_dict)) + except ValueError: + # invalid mutation type. Conformance test may be sending generic empty request + false_mutations.append(SetCell("", "", "", 0)) + predicate_filter = request.get("predicate_filter", None) + result = await table.check_and_mutate_row( + row_key, + predicate_filter, + true_case_mutations=true_mutations, + false_case_mutations=false_mutations, + **kwargs, + ) + return result + + @error_safe + async def ReadModifyWriteRow(self, request, **kwargs): + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + row_key = request["row_key"] + rules = [] + for rule_dict in request.get("rules", []): + qualifier = rule_dict["column_qualifier"] + if "append_value" in rule_dict: + new_rule = AppendValueRule(rule_dict["family_name"], qualifier, rule_dict["append_value"]) + else: + new_rule = IncrementRule(rule_dict["family_name"], qualifier, rule_dict["increment_amount"]) + rules.append(new_rule) + result = await table.read_modify_write_row(row_key, rules, **kwargs) + # pack results back into protobuf-parsable format + if result: + return result._to_dict() + else: + return "None" + + @error_safe + async def SampleRowKeys(self, request, **kwargs): + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + result = await table.sample_row_keys(**kwargs) + return result diff --git a/test_proxy/handlers/client_handler_legacy.py b/test_proxy/handlers/client_handler_legacy.py new file mode 100644 index 000000000..400f618b5 --- /dev/null +++ b/test_proxy/handlers/client_handler_legacy.py @@ -0,0 +1,235 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains the client handler process for proxy_server.py. 
+""" +import os + +from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.cloud.bigtable.client import Client + +import client_handler_data as client_handler + +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +class LegacyTestProxyClientHandler(client_handler.TestProxyClientHandler): + + def __init__( + self, + data_target=None, + project_id=None, + instance_id=None, + app_profile_id=None, + per_operation_timeout=None, + **kwargs, + ): + self.closed = False + # use emulator + os.environ[BIGTABLE_EMULATOR] = data_target + self.client = Client(project=project_id) + self.instance_id = instance_id + self.app_profile_id = app_profile_id + self.per_operation_timeout = per_operation_timeout + + def close(self): + self.closed = True + + @client_handler.error_safe + async def ReadRows(self, request, **kwargs): + table_id = request["table_name"].split("/")[-1] + # app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + instance = self.client.instance(self.instance_id) + table = instance.table(table_id) + + limit = request.get("rows_limit", None) + start_key = request.get("rows", {}).get("row_keys", [None])[0] + end_key = request.get("rows", {}).get("row_keys", [None])[-1] + end_inclusive = request.get("rows", {}).get("row_ranges", [{}])[-1].get("end_key_closed", True) + + row_list = [] + for row in table.read_rows(start_key=start_key, end_key=end_key, limit=limit, end_inclusive=end_inclusive): + # parse results into proto formatted dict + dict_val = {"row_key": row.row_key} + for family, family_cells in row.cells.items(): + family_dict = {"name": family} + for qualifier, qualifier_cells in family_cells.items(): + column_dict = {"qualifier": qualifier} + for cell in qualifier_cells: + cell_dict = { + "value": cell.value, + "timestamp_micros": cell.timestamp.timestamp() * 1000000, + "labels": cell.labels, + } + column_dict.setdefault("cells", []).append(cell_dict) + family_dict.setdefault("columns", []).append(column_dict) + dict_val.setdefault("families", []).append(family_dict) + row_list.append(dict_val) + return row_list + + @client_handler.error_safe + async def ReadRow(self, row_key, **kwargs): + table_id = kwargs["table_name"].split("/")[-1] + instance = self.client.instance(self.instance_id) + table = instance.table(table_id) + + row = table.read_row(row_key) + # parse results into proto formatted dict + dict_val = {"row_key": row.row_key} + for family, family_cells in row.cells.items(): + family_dict = {"name": family} + for qualifier, qualifier_cells in family_cells.items(): + column_dict = {"qualifier": qualifier} + for cell in qualifier_cells: + cell_dict = { + "value": cell.value, + "timestamp_micros": cell.timestamp.timestamp() * 1000000, + "labels": cell.labels, + } + column_dict.setdefault("cells", []).append(cell_dict) + family_dict.setdefault("columns", []).append(column_dict) + dict_val.setdefault("families", []).append(family_dict) + return dict_val + + @client_handler.error_safe + async def MutateRow(self, request, **kwargs): + from datetime import datetime + from google.cloud.bigtable.row import DirectRow + table_id = request["table_name"].split("/")[-1] + instance = self.client.instance(self.instance_id) + table = instance.table(table_id) + row_key = request["row_key"] + new_row = DirectRow(row_key, table) + for m_dict in request.get("mutations", []): + details = m_dict.get("set_cell") or m_dict.get("delete_from_column") or m_dict.get("delete_from_family") or m_dict.get("delete_from_row") + timestamp = 
datetime.fromtimestamp(details.get("timestamp_micros")) if details.get("timestamp_micros") else None + if m_dict.get("set_cell"): + new_row.set_cell(details["family_name"], details["column_qualifier"], details["value"], timestamp=timestamp) + elif m_dict.get("delete_from_column"): + new_row.delete_cell(details["family_name"], details["column_qualifier"], timestamp=timestamp) + elif m_dict.get("delete_from_family"): + new_row.delete_cells(details["family_name"], timestamp=timestamp) + elif m_dict.get("delete_from_row"): + new_row.delete() + table.mutate_rows([new_row]) + return "OK" + + @client_handler.error_safe + async def BulkMutateRows(self, request, **kwargs): + from google.cloud.bigtable.row import DirectRow + from datetime import datetime + table_id = request["table_name"].split("/")[-1] + instance = self.client.instance(self.instance_id) + table = instance.table(table_id) + rows = [] + for entry in request.get("entries", []): + row_key = entry["row_key"] + new_row = DirectRow(row_key, table) + for m_dict in entry.get("mutations"): + details = m_dict.get("set_cell") or m_dict.get("delete_from_column") or m_dict.get("delete_from_family") or m_dict.get("delete_from_row") + timestamp = datetime.fromtimestamp(details.get("timestamp_micros")) if details.get("timestamp_micros") else None + if m_dict.get("set_cell"): + new_row.set_cell(details["family_name"], details["column_qualifier"], details["value"], timestamp=timestamp) + elif m_dict.get("delete_from_column"): + new_row.delete_cell(details["family_name"], details["column_qualifier"], timestamp=timestamp) + elif m_dict.get("delete_from_family"): + new_row.delete_cells(details["family_name"], timestamp=timestamp) + elif m_dict.get("delete_from_row"): + new_row.delete() + rows.append(new_row) + table.mutate_rows(rows) + return "OK" + + @client_handler.error_safe + async def CheckAndMutateRow(self, request, **kwargs): + from google.cloud.bigtable.row import ConditionalRow + from google.cloud.bigtable.row_filters import PassAllFilter + table_id = request["table_name"].split("/")[-1] + instance = self.client.instance(self.instance_id) + table = instance.table(table_id) + + predicate_filter = request.get("predicate_filter", PassAllFilter(True)) + new_row = ConditionalRow(request["row_key"], table, predicate_filter) + + combined_mutations = [{"state": True, **m} for m in request.get("true_mutations", [])] + combined_mutations.extend([{"state": False, **m} for m in request.get("false_mutations", [])]) + for mut_dict in combined_mutations: + if "set_cell" in mut_dict: + details = mut_dict["set_cell"] + new_row.set_cell( + details.get("family_name", ""), + details.get("column_qualifier", ""), + details.get("value", ""), + timestamp=details.get("timestamp_micros", None), + state=mut_dict["state"], + ) + elif "delete_from_column" in mut_dict: + details = mut_dict["delete_from_column"] + new_row.delete_cell( + details.get("family_name", ""), + details.get("column_qualifier", ""), + timestamp=details.get("timestamp_micros", None), + state=mut_dict["state"], + ) + elif "delete_from_family" in mut_dict: + details = mut_dict["delete_from_family"] + new_row.delete_cells( + details.get("family_name", ""), + timestamp=details.get("timestamp_micros", None), + state=mut_dict["state"], + ) + elif "delete_from_row" in mut_dict: + new_row.delete(state=mut_dict["state"]) + else: + raise RuntimeError(f"Unknown mutation type: {mut_dict}") + return new_row.commit() + + @client_handler.error_safe + async def ReadModifyWriteRow(self, request, **kwargs): + from 
google.cloud.bigtable.row import AppendRow + from google.cloud._helpers import _microseconds_from_datetime + table_id = request["table_name"].split("/")[-1] + instance = self.client.instance(self.instance_id) + table = instance.table(table_id) + row_key = request["row_key"] + new_row = AppendRow(row_key, table) + for rule_dict in request.get("rules", []): + qualifier = rule_dict["column_qualifier"] + family = rule_dict["family_name"] + if "append_value" in rule_dict: + new_row.append_cell_value(family, qualifier, rule_dict["append_value"]) + else: + new_row.increment_cell_value(family, qualifier, rule_dict["increment_amount"]) + raw_result = new_row.commit() + result_families = [] + for family, column_dict in raw_result.items(): + result_columns = [] + for column, cell_list in column_dict.items(): + result_cells = [] + for cell_tuple in cell_list: + cell_dict = {"value": cell_tuple[0], "timestamp_micros": _microseconds_from_datetime(cell_tuple[1])} + result_cells.append(cell_dict) + result_columns.append({"qualifier": column, "cells": result_cells}) + result_families.append({"name": family, "columns": result_columns}) + return {"key": row_key, "families": result_families} + + @client_handler.error_safe + async def SampleRowKeys(self, request, **kwargs): + table_id = request["table_name"].split("/")[-1] + instance = self.client.instance(self.instance_id) + table = instance.table(table_id) + response = list(table.sample_row_keys()) + tuple_response = [(s.row_key, s.offset_bytes) for s in response] + return tuple_response diff --git a/test_proxy/handlers/grpc_handler.py b/test_proxy/handlers/grpc_handler.py new file mode 100644 index 000000000..2c70778dd --- /dev/null +++ b/test_proxy/handlers/grpc_handler.py @@ -0,0 +1,148 @@ + +import time + +import test_proxy_pb2 +import test_proxy_pb2_grpc +import data_pb2 +import bigtable_pb2 +from google.rpc.status_pb2 import Status +from google.protobuf import json_format + + +class TestProxyGrpcServer(test_proxy_pb2_grpc.CloudBigtableV2TestProxyServicer): + """ + Implements a grpc server that proxies conformance test requests to the client library + + Due to issues with using protoc-compiled protos and client-library + proto-plus objects in the same process, this server defers requests to + matching methods in a TestProxyClientHandler instance in a separate + process. 
+    This happens invisibly in the decorator @delegate_to_client_handler, with the
+    results attached to each request as a client_response kwarg
+    """
+
+    def __init__(self, request_q, queue_pool):
+        # each in-flight request borrows a response queue from the pool by index
+        self.open_queues = list(range(len(queue_pool)))
+        self.queue_pool = queue_pool
+        self.request_q = request_q
+
+    def delegate_to_client_handler(func, timeout_seconds=300):
+        """
+        Decorator that transparently passes a request to the client
+        handler process, and then attaches the response to the wrapped call
+        """
+
+        def wrapper(self, request, context, **kwargs):
+            deadline = time.time() + timeout_seconds
+            json_dict = json_format.MessageToDict(request)
+            out_idx = self.open_queues.pop()
+            json_dict["proxy_request"] = func.__name__
+            json_dict["response_queue_idx"] = out_idx
+            out_q = self.queue_pool[out_idx]
+            self.request_q.put(json_dict)
+            # wait for the response; if the deadline expires without one, the wrapper returns None
+            while time.time() < deadline:
+                if not out_q.empty():
+                    response = out_q.get()
+                    self.open_queues.append(out_idx)
+                    if isinstance(response, Exception):
+                        raise response
+                    else:
+                        return func(
+                            self,
+                            request,
+                            context,
+                            client_response=response,
+                            **kwargs,
+                        )
+                time.sleep(1e-4)
+
+        return wrapper
+
+
+    @delegate_to_client_handler
+    def CreateClient(self, request, context, client_response=None):
+        return test_proxy_pb2.CreateClientResponse()
+
+    @delegate_to_client_handler
+    def CloseClient(self, request, context, client_response=None):
+        return test_proxy_pb2.CloseClientResponse()
+
+    @delegate_to_client_handler
+    def RemoveClient(self, request, context, client_response=None):
+        return test_proxy_pb2.RemoveClientResponse()
+
+    @delegate_to_client_handler
+    def ReadRows(self, request, context, client_response=None):
+        status = Status()
+        rows = []
+        if isinstance(client_response, dict) and "error" in client_response:
+            status = Status(code=5, message=client_response["error"])
+        else:
+            rows = [data_pb2.Row(**d) for d in client_response]
+        result = test_proxy_pb2.RowsResult(row=rows, status=status)
+        return result
+
+    @delegate_to_client_handler
+    def ReadRow(self, request, context, client_response=None):
+        status = Status()
+        row = None
+        if isinstance(client_response, dict) and "error" in client_response:
+            status = Status(code=client_response.get("code", 5), message=client_response.get("error"))
+        elif client_response != "None":
+            row = data_pb2.Row(**client_response)
+        result = test_proxy_pb2.RowResult(row=row, status=status)
+        return result
+
+    @delegate_to_client_handler
+    def MutateRow(self, request, context, client_response=None):
+        status = Status()
+        if isinstance(client_response, dict) and "error" in client_response:
+            status = Status(code=client_response.get("code", 5), message=client_response["error"])
+        return test_proxy_pb2.MutateRowResult(status=status)
+
+    @delegate_to_client_handler
+    def BulkMutateRows(self, request, context, client_response=None):
+        status = Status()
+        entries = []
+        if isinstance(client_response, dict) and "error" in client_response:
+            entries = [bigtable_pb2.MutateRowsResponse.Entry(index=exc_dict.get("index", 1), status=Status(code=exc_dict.get("code", 5))) for exc_dict in client_response.get("subexceptions", [])]
+            if not entries:
+                # only return failure on the overall request if there are failed entries
+                status = Status(code=client_response.get("code", 5), message=client_response["error"])
+        # TODO: protos were updated. 
entry is now entries: https://github.com/googleapis/cndb-client-testing-protos/commit/e6205a2bba04acc10d12421a1402870b4a525fb3 + response = test_proxy_pb2.MutateRowsResult(status=status, entry=entries) + return response + + @delegate_to_client_handler + def CheckAndMutateRow(self, request, context, client_response=None): + if isinstance(client_response, dict) and "error" in client_response: + status = Status(code=client_response.get("code", 5), message=client_response["error"]) + response = test_proxy_pb2.CheckAndMutateRowResult(status=status) + else: + result = bigtable_pb2.CheckAndMutateRowResponse(predicate_matched=client_response) + response = test_proxy_pb2.CheckAndMutateRowResult(result=result, status=Status()) + return response + + @delegate_to_client_handler + def ReadModifyWriteRow(self, request, context, client_response=None): + status = Status() + row = None + if isinstance(client_response, dict) and "error" in client_response: + status = Status(code=client_response.get("code", 5), message=client_response.get("error")) + elif client_response != "None": + row = data_pb2.Row(**client_response) + result = test_proxy_pb2.RowResult(row=row, status=status) + return result + + @delegate_to_client_handler + def SampleRowKeys(self, request, context, client_response=None): + status = Status() + sample_list = [] + if isinstance(client_response, dict) and "error" in client_response: + status = Status(code=client_response.get("code", 5), message=client_response.get("error")) + else: + for sample in client_response: + sample_list.append(bigtable_pb2.SampleRowKeysResponse(offset_bytes=sample[1], row_key=sample[0])) + # TODO: protos were updated. sample is now samples: https://github.com/googleapis/cndb-client-testing-protos/commit/e6205a2bba04acc10d12421a1402870b4a525fb3 + return test_proxy_pb2.SampleRowKeysResult(status=status, sample=sample_list) diff --git a/test_proxy/noxfile.py b/test_proxy/noxfile.py new file mode 100644 index 000000000..bebf247b7 --- /dev/null +++ b/test_proxy/noxfile.py @@ -0,0 +1,80 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +import os +import pathlib +import re +from colorlog.escape_codes import parse_colors + +import nox + + +DEFAULT_PYTHON_VERSION = "3.10" + +PROXY_SERVER_PORT=os.environ.get("PROXY_SERVER_PORT", "50055") +PROXY_CLIENT_VERSION=os.environ.get("PROXY_CLIENT_VERSION", None) + +CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() +REPO_ROOT_DIRECTORY = CURRENT_DIRECTORY.parent + +nox.options.sessions = ["run_proxy", "conformance_tests"] + +TEST_REPO_URL = "https://github.com/googleapis/cloud-bigtable-clients-test.git" +CLONE_REPO_DIR = "cloud-bigtable-clients-test" + +# Error if a python version is missing +nox.options.error_on_missing_interpreters = True + + +def default(session): + """ + if nox is run directly, run the test_proxy session + """ + test_proxy(session) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def conformance_tests(session): + """ + download and run the conformance test suite against the test proxy + """ + import subprocess + import time + # download the conformance test suite + clone_dir = os.path.join(CURRENT_DIRECTORY, CLONE_REPO_DIR) + if not os.path.exists(clone_dir): + print("downloading copy of test repo") + session.run("git", "clone", TEST_REPO_URL, CLONE_REPO_DIR) + # start tests + with session.chdir(f"{clone_dir}/tests"): + session.run("go", "test", "-v", f"-proxy_addr=:{PROXY_SERVER_PORT}") + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def test_proxy(session): + """Start up the test proxy""" + # Install all dependencies, then install this package into the + # virtualenv's dist-packages. + # session.install( + # "grpcio", + # ) + if PROXY_CLIENT_VERSION is not None: + # install released version of the library + session.install(f"python-bigtable=={PROXY_CLIENT_VERSION}") + else: + # install the library from the source + session.install("-e", str(REPO_ROOT_DIRECTORY)) + session.install("-e", str(REPO_ROOT_DIRECTORY / "python-api-core")) + + session.run("python", "test_proxy.py", "--port", PROXY_SERVER_PORT, *session.posargs,) diff --git a/test_proxy/protos/bigtable_pb2.py b/test_proxy/protos/bigtable_pb2.py new file mode 100644 index 000000000..936a4ed55 --- /dev/null +++ b/test_proxy/protos/bigtable_pb2.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: google/bigtable/v2/bigtable.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.api import annotations_pb2 as google_dot_api_dot_annotations__pb2 +from google.api import client_pb2 as google_dot_api_dot_client__pb2 +from google.api import field_behavior_pb2 as google_dot_api_dot_field__behavior__pb2 +from google.api import resource_pb2 as google_dot_api_dot_resource__pb2 +from google.api import routing_pb2 as google_dot_api_dot_routing__pb2 +import data_pb2 as google_dot_bigtable_dot_v2_dot_data__pb2 +import request_stats_pb2 as google_dot_bigtable_dot_v2_dot_request__stats__pb2 +from google.protobuf import duration_pb2 as google_dot_protobuf_dot_duration__pb2 +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 +from google.protobuf import wrappers_pb2 as google_dot_protobuf_dot_wrappers__pb2 +from google.rpc import status_pb2 as google_dot_rpc_dot_status__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n!google/bigtable/v2/bigtable.proto\x12\x12google.bigtable.v2\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a\x18google/api/routing.proto\x1a\x1dgoogle/bigtable/v2/data.proto\x1a&google/bigtable/v2/request_stats.proto\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/wrappers.proto\x1a\x17google/rpc/status.proto\"\x90\x03\n\x0fReadRowsRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x05 \x01(\t\x12(\n\x04rows\x18\x02 \x01(\x0b\x32\x1a.google.bigtable.v2.RowSet\x12-\n\x06\x66ilter\x18\x03 \x01(\x0b\x32\x1d.google.bigtable.v2.RowFilter\x12\x12\n\nrows_limit\x18\x04 \x01(\x03\x12P\n\x12request_stats_view\x18\x06 \x01(\x0e\x32\x34.google.bigtable.v2.ReadRowsRequest.RequestStatsView\"f\n\x10RequestStatsView\x12\"\n\x1eREQUEST_STATS_VIEW_UNSPECIFIED\x10\x00\x12\x16\n\x12REQUEST_STATS_NONE\x10\x01\x12\x16\n\x12REQUEST_STATS_FULL\x10\x02\"\xb1\x03\n\x10ReadRowsResponse\x12>\n\x06\x63hunks\x18\x01 \x03(\x0b\x32..google.bigtable.v2.ReadRowsResponse.CellChunk\x12\x1c\n\x14last_scanned_row_key\x18\x02 \x01(\x0c\x12\x37\n\rrequest_stats\x18\x03 \x01(\x0b\x32 .google.bigtable.v2.RequestStats\x1a\x85\x02\n\tCellChunk\x12\x0f\n\x07row_key\x18\x01 \x01(\x0c\x12\x31\n\x0b\x66\x61mily_name\x18\x02 \x01(\x0b\x32\x1c.google.protobuf.StringValue\x12.\n\tqualifier\x18\x03 \x01(\x0b\x32\x1b.google.protobuf.BytesValue\x12\x18\n\x10timestamp_micros\x18\x04 \x01(\x03\x12\x0e\n\x06labels\x18\x05 \x03(\t\x12\r\n\x05value\x18\x06 \x01(\x0c\x12\x12\n\nvalue_size\x18\x07 \x01(\x05\x12\x13\n\treset_row\x18\x08 \x01(\x08H\x00\x12\x14\n\ncommit_row\x18\t \x01(\x08H\x00\x42\x0c\n\nrow_status\"n\n\x14SampleRowKeysRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x02 \x01(\t\">\n\x15SampleRowKeysResponse\x12\x0f\n\x07row_key\x18\x01 \x01(\x0c\x12\x14\n\x0coffset_bytes\x18\x02 \x01(\x03\"\xb6\x01\n\x10MutateRowRequest\x12>\n\ntable_name\x18\x01 
\x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x04 \x01(\t\x12\x14\n\x07row_key\x18\x02 \x01(\x0c\x42\x03\xe0\x41\x02\x12\x34\n\tmutations\x18\x03 \x03(\x0b\x32\x1c.google.bigtable.v2.MutationB\x03\xe0\x41\x02\"\x13\n\x11MutateRowResponse\"\xfe\x01\n\x11MutateRowsRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x03 \x01(\t\x12\x41\n\x07\x65ntries\x18\x02 \x03(\x0b\x32+.google.bigtable.v2.MutateRowsRequest.EntryB\x03\xe0\x41\x02\x1aN\n\x05\x45ntry\x12\x0f\n\x07row_key\x18\x01 \x01(\x0c\x12\x34\n\tmutations\x18\x02 \x03(\x0b\x32\x1c.google.bigtable.v2.MutationB\x03\xe0\x41\x02\"\x8f\x01\n\x12MutateRowsResponse\x12=\n\x07\x65ntries\x18\x01 \x03(\x0b\x32,.google.bigtable.v2.MutateRowsResponse.Entry\x1a:\n\x05\x45ntry\x12\r\n\x05index\x18\x01 \x01(\x03\x12\"\n\x06status\x18\x02 \x01(\x0b\x32\x12.google.rpc.Status\"\xae\x02\n\x18\x43heckAndMutateRowRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x07 \x01(\t\x12\x14\n\x07row_key\x18\x02 \x01(\x0c\x42\x03\xe0\x41\x02\x12\x37\n\x10predicate_filter\x18\x06 \x01(\x0b\x32\x1d.google.bigtable.v2.RowFilter\x12\x34\n\x0etrue_mutations\x18\x04 \x03(\x0b\x32\x1c.google.bigtable.v2.Mutation\x12\x35\n\x0f\x66\x61lse_mutations\x18\x05 \x03(\x0b\x32\x1c.google.bigtable.v2.Mutation\"6\n\x19\x43heckAndMutateRowResponse\x12\x19\n\x11predicate_matched\x18\x01 \x01(\x08\"i\n\x12PingAndWarmRequest\x12;\n\x04name\x18\x01 \x01(\tB-\xe0\x41\x02\xfa\x41\'\n%bigtableadmin.googleapis.com/Instance\x12\x16\n\x0e\x61pp_profile_id\x18\x02 \x01(\t\"\x15\n\x13PingAndWarmResponse\"\xc6\x01\n\x19ReadModifyWriteRowRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x04 \x01(\t\x12\x14\n\x07row_key\x18\x02 \x01(\x0c\x42\x03\xe0\x41\x02\x12;\n\x05rules\x18\x03 \x03(\x0b\x32\'.google.bigtable.v2.ReadModifyWriteRuleB\x03\xe0\x41\x02\"B\n\x1aReadModifyWriteRowResponse\x12$\n\x03row\x18\x01 \x01(\x0b\x32\x17.google.bigtable.v2.Row\"\x86\x01\n,GenerateInitialChangeStreamPartitionsRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x02 \x01(\t\"g\n-GenerateInitialChangeStreamPartitionsResponse\x12\x36\n\tpartition\x18\x01 \x01(\x0b\x32#.google.bigtable.v2.StreamPartition\"\x9b\x03\n\x17ReadChangeStreamRequest\x12>\n\ntable_name\x18\x01 \x01(\tB*\xe0\x41\x02\xfa\x41$\n\"bigtableadmin.googleapis.com/Table\x12\x16\n\x0e\x61pp_profile_id\x18\x02 \x01(\t\x12\x36\n\tpartition\x18\x03 \x01(\x0b\x32#.google.bigtable.v2.StreamPartition\x12\x30\n\nstart_time\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x00\x12K\n\x13\x63ontinuation_tokens\x18\x06 \x01(\x0b\x32,.google.bigtable.v2.StreamContinuationTokensH\x00\x12,\n\x08\x65nd_time\x18\x05 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x35\n\x12heartbeat_duration\x18\x07 \x01(\x0b\x32\x19.google.protobuf.DurationB\x0c\n\nstart_from\"\xeb\t\n\x18ReadChangeStreamResponse\x12N\n\x0b\x64\x61ta_change\x18\x01 \x01(\x0b\x32\x37.google.bigtable.v2.ReadChangeStreamResponse.DataChangeH\x00\x12K\n\theartbeat\x18\x02 \x01(\x0b\x32\x36.google.bigtable.v2.ReadChangeStreamResponse.HeartbeatH\x00\x12P\n\x0c\x63lose_stream\x18\x03 
\x01(\x0b\x32\x38.google.bigtable.v2.ReadChangeStreamResponse.CloseStreamH\x00\x1a\xf4\x01\n\rMutationChunk\x12X\n\nchunk_info\x18\x01 \x01(\x0b\x32\x44.google.bigtable.v2.ReadChangeStreamResponse.MutationChunk.ChunkInfo\x12.\n\x08mutation\x18\x02 \x01(\x0b\x32\x1c.google.bigtable.v2.Mutation\x1aY\n\tChunkInfo\x12\x1a\n\x12\x63hunked_value_size\x18\x01 \x01(\x05\x12\x1c\n\x14\x63hunked_value_offset\x18\x02 \x01(\x05\x12\x12\n\nlast_chunk\x18\x03 \x01(\x08\x1a\xc6\x03\n\nDataChange\x12J\n\x04type\x18\x01 \x01(\x0e\x32<.google.bigtable.v2.ReadChangeStreamResponse.DataChange.Type\x12\x19\n\x11source_cluster_id\x18\x02 \x01(\t\x12\x0f\n\x07row_key\x18\x03 \x01(\x0c\x12\x34\n\x10\x63ommit_timestamp\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x12\n\ntiebreaker\x18\x05 \x01(\x05\x12J\n\x06\x63hunks\x18\x06 \x03(\x0b\x32:.google.bigtable.v2.ReadChangeStreamResponse.MutationChunk\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\r\n\x05token\x18\t \x01(\t\x12;\n\x17\x65stimated_low_watermark\x18\n \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"P\n\x04Type\x12\x14\n\x10TYPE_UNSPECIFIED\x10\x00\x12\x08\n\x04USER\x10\x01\x12\x16\n\x12GARBAGE_COLLECTION\x10\x02\x12\x10\n\x0c\x43ONTINUATION\x10\x03\x1a\x91\x01\n\tHeartbeat\x12G\n\x12\x63ontinuation_token\x18\x01 \x01(\x0b\x32+.google.bigtable.v2.StreamContinuationToken\x12;\n\x17\x65stimated_low_watermark\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x1a{\n\x0b\x43loseStream\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\x12H\n\x13\x63ontinuation_tokens\x18\x02 \x03(\x0b\x32+.google.bigtable.v2.StreamContinuationTokenB\x0f\n\rstream_record2\xd7\x18\n\x08\x42igtable\x12\x9b\x02\n\x08ReadRows\x12#.google.bigtable.v2.ReadRowsRequest\x1a$.google.bigtable.v2.ReadRowsResponse\"\xc1\x01\x82\xd3\xe4\x93\x02>\"9/v2/{table_name=projects/*/instances/*/tables/*}:readRows:\x01*\x8a\xd3\xe4\x93\x02N\x12:\n\ntable_name\x12,{table_name=projects/*/instances/*/tables/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\ntable_name\xda\x41\x19table_name,app_profile_id0\x01\x12\xac\x02\n\rSampleRowKeys\x12(.google.bigtable.v2.SampleRowKeysRequest\x1a).google.bigtable.v2.SampleRowKeysResponse\"\xc3\x01\x82\xd3\xe4\x93\x02@\x12>/v2/{table_name=projects/*/instances/*/tables/*}:sampleRowKeys\x8a\xd3\xe4\x93\x02N\x12:\n\ntable_name\x12,{table_name=projects/*/instances/*/tables/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\ntable_name\xda\x41\x19table_name,app_profile_id0\x01\x12\xc1\x02\n\tMutateRow\x12$.google.bigtable.v2.MutateRowRequest\x1a%.google.bigtable.v2.MutateRowResponse\"\xe6\x01\x82\xd3\xe4\x93\x02?\":/v2/{table_name=projects/*/instances/*/tables/*}:mutateRow:\x01*\x8a\xd3\xe4\x93\x02N\x12:\n\ntable_name\x12,{table_name=projects/*/instances/*/tables/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\x1ctable_name,row_key,mutations\xda\x41+table_name,row_key,mutations,app_profile_id\x12\xb3\x02\n\nMutateRows\x12%.google.bigtable.v2.MutateRowsRequest\x1a&.google.bigtable.v2.MutateRowsResponse\"\xd3\x01\x82\xd3\xe4\x93\x02@\";/v2/{table_name=projects/*/instances/*/tables/*}:mutateRows:\x01*\x8a\xd3\xe4\x93\x02N\x12:\n\ntable_name\x12,{table_name=projects/*/instances/*/tables/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\x12table_name,entries\xda\x41!table_name,entries,app_profile_id0\x01\x12\xad\x03\n\x11\x43heckAndMutateRow\x12,.google.bigtable.v2.CheckAndMutateRowRequest\x1a-.google.bigtable.v2.CheckAndMutateRowResponse\"\xba\x02\x82\xd3\xe4\x93\x02G\"B/v2/{table_name=projects/*/instances/*/tables/*}:checkAndMutateRow:\x01*\x8a\xd3\xe4\x93\x02N\x12:\n\ntable_name\x12,{tab
le_name=projects/*/instances/*/tables/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\x42table_name,row_key,predicate_filter,true_mutations,false_mutations\xda\x41Qtable_name,row_key,predicate_filter,true_mutations,false_mutations,app_profile_id\x12\xee\x01\n\x0bPingAndWarm\x12&.google.bigtable.v2.PingAndWarmRequest\x1a\'.google.bigtable.v2.PingAndWarmResponse\"\x8d\x01\x82\xd3\xe4\x93\x02+\"&/v2/{name=projects/*/instances/*}:ping:\x01*\x8a\xd3\xe4\x93\x02\x39\x12%\n\x04name\x12\x1d{name=projects/*/instances/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\x04name\xda\x41\x13name,app_profile_id\x12\xdd\x02\n\x12ReadModifyWriteRow\x12-.google.bigtable.v2.ReadModifyWriteRowRequest\x1a..google.bigtable.v2.ReadModifyWriteRowResponse\"\xe7\x01\x82\xd3\xe4\x93\x02H\"C/v2/{table_name=projects/*/instances/*/tables/*}:readModifyWriteRow:\x01*\x8a\xd3\xe4\x93\x02N\x12:\n\ntable_name\x12,{table_name=projects/*/instances/*/tables/*}\x12\x10\n\x0e\x61pp_profile_id\xda\x41\x18table_name,row_key,rules\xda\x41\'table_name,row_key,rules,app_profile_id\x12\xbb\x02\n%GenerateInitialChangeStreamPartitions\x12@.google.bigtable.v2.GenerateInitialChangeStreamPartitionsRequest\x1a\x41.google.bigtable.v2.GenerateInitialChangeStreamPartitionsResponse\"\x8a\x01\x82\xd3\xe4\x93\x02[\"V/v2/{table_name=projects/*/instances/*/tables/*}:generateInitialChangeStreamPartitions:\x01*\xda\x41\ntable_name\xda\x41\x19table_name,app_profile_id0\x01\x12\xe6\x01\n\x10ReadChangeStream\x12+.google.bigtable.v2.ReadChangeStreamRequest\x1a,.google.bigtable.v2.ReadChangeStreamResponse\"u\x82\xd3\xe4\x93\x02\x46\"A/v2/{table_name=projects/*/instances/*/tables/*}:readChangeStream:\x01*\xda\x41\ntable_name\xda\x41\x19table_name,app_profile_id0\x01\x1a\xdb\x02\xca\x41\x17\x62igtable.googleapis.com\xd2\x41\xbd\x02https://www.googleapis.com/auth/bigtable.data,https://www.googleapis.com/auth/bigtable.data.readonly,https://www.googleapis.com/auth/cloud-bigtable.data,https://www.googleapis.com/auth/cloud-bigtable.data.readonly,https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/cloud-platform.read-onlyB\xeb\x02\n\x16\x63om.google.bigtable.v2B\rBigtableProtoP\x01Z:google.golang.org/genproto/googleapis/bigtable/v2;bigtable\xaa\x02\x18Google.Cloud.Bigtable.V2\xca\x02\x18Google\\Cloud\\Bigtable\\V2\xea\x02\x1bGoogle::Cloud::Bigtable::V2\xea\x41P\n%bigtableadmin.googleapis.com/Instance\x12\'projects/{project}/instances/{instance}\xea\x41\\\n\"bigtableadmin.googleapis.com/Table\x12\x36projects/{project}/instances/{instance}/tables/{table}b\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.bigtable.v2.bigtable_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\026com.google.bigtable.v2B\rBigtableProtoP\001Z:google.golang.org/genproto/googleapis/bigtable/v2;bigtable\252\002\030Google.Cloud.Bigtable.V2\312\002\030Google\\Cloud\\Bigtable\\V2\352\002\033Google::Cloud::Bigtable::V2\352AP\n%bigtableadmin.googleapis.com/Instance\022\'projects/{project}/instances/{instance}\352A\\\n\"bigtableadmin.googleapis.com/Table\0226projects/{project}/instances/{instance}/tables/{table}' + _READROWSREQUEST.fields_by_name['table_name']._options = None + _READROWSREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table' + _SAMPLEROWKEYSREQUEST.fields_by_name['table_name']._options = None + 
_SAMPLEROWKEYSREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table' + _MUTATEROWREQUEST.fields_by_name['table_name']._options = None + _MUTATEROWREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table' + _MUTATEROWREQUEST.fields_by_name['row_key']._options = None + _MUTATEROWREQUEST.fields_by_name['row_key']._serialized_options = b'\340A\002' + _MUTATEROWREQUEST.fields_by_name['mutations']._options = None + _MUTATEROWREQUEST.fields_by_name['mutations']._serialized_options = b'\340A\002' + _MUTATEROWSREQUEST_ENTRY.fields_by_name['mutations']._options = None + _MUTATEROWSREQUEST_ENTRY.fields_by_name['mutations']._serialized_options = b'\340A\002' + _MUTATEROWSREQUEST.fields_by_name['table_name']._options = None + _MUTATEROWSREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table' + _MUTATEROWSREQUEST.fields_by_name['entries']._options = None + _MUTATEROWSREQUEST.fields_by_name['entries']._serialized_options = b'\340A\002' + _CHECKANDMUTATEROWREQUEST.fields_by_name['table_name']._options = None + _CHECKANDMUTATEROWREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table' + _CHECKANDMUTATEROWREQUEST.fields_by_name['row_key']._options = None + _CHECKANDMUTATEROWREQUEST.fields_by_name['row_key']._serialized_options = b'\340A\002' + _PINGANDWARMREQUEST.fields_by_name['name']._options = None + _PINGANDWARMREQUEST.fields_by_name['name']._serialized_options = b'\340A\002\372A\'\n%bigtableadmin.googleapis.com/Instance' + _READMODIFYWRITEROWREQUEST.fields_by_name['table_name']._options = None + _READMODIFYWRITEROWREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table' + _READMODIFYWRITEROWREQUEST.fields_by_name['row_key']._options = None + _READMODIFYWRITEROWREQUEST.fields_by_name['row_key']._serialized_options = b'\340A\002' + _READMODIFYWRITEROWREQUEST.fields_by_name['rules']._options = None + _READMODIFYWRITEROWREQUEST.fields_by_name['rules']._serialized_options = b'\340A\002' + _GENERATEINITIALCHANGESTREAMPARTITIONSREQUEST.fields_by_name['table_name']._options = None + _GENERATEINITIALCHANGESTREAMPARTITIONSREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table' + _READCHANGESTREAMREQUEST.fields_by_name['table_name']._options = None + _READCHANGESTREAMREQUEST.fields_by_name['table_name']._serialized_options = b'\340A\002\372A$\n\"bigtableadmin.googleapis.com/Table' + _BIGTABLE._options = None + _BIGTABLE._serialized_options = b'\312A\027bigtable.googleapis.com\322A\275\002https://www.googleapis.com/auth/bigtable.data,https://www.googleapis.com/auth/bigtable.data.readonly,https://www.googleapis.com/auth/cloud-bigtable.data,https://www.googleapis.com/auth/cloud-bigtable.data.readonly,https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/cloud-platform.read-only' + _BIGTABLE.methods_by_name['ReadRows']._options = None + _BIGTABLE.methods_by_name['ReadRows']._serialized_options = b'\202\323\344\223\002>\"9/v2/{table_name=projects/*/instances/*/tables/*}:readRows:\001*\212\323\344\223\002N\022:\n\ntable_name\022,{table_name=projects/*/instances/*/tables/*}\022\020\n\016app_profile_id\332A\ntable_name\332A\031table_name,app_profile_id' + _BIGTABLE.methods_by_name['SampleRowKeys']._options = None + 
_BIGTABLE.methods_by_name['SampleRowKeys']._serialized_options = b'\202\323\344\223\002@\022>/v2/{table_name=projects/*/instances/*/tables/*}:sampleRowKeys\212\323\344\223\002N\022:\n\ntable_name\022,{table_name=projects/*/instances/*/tables/*}\022\020\n\016app_profile_id\332A\ntable_name\332A\031table_name,app_profile_id' + _BIGTABLE.methods_by_name['MutateRow']._options = None + _BIGTABLE.methods_by_name['MutateRow']._serialized_options = b'\202\323\344\223\002?\":/v2/{table_name=projects/*/instances/*/tables/*}:mutateRow:\001*\212\323\344\223\002N\022:\n\ntable_name\022,{table_name=projects/*/instances/*/tables/*}\022\020\n\016app_profile_id\332A\034table_name,row_key,mutations\332A+table_name,row_key,mutations,app_profile_id' + _BIGTABLE.methods_by_name['MutateRows']._options = None + _BIGTABLE.methods_by_name['MutateRows']._serialized_options = b'\202\323\344\223\002@\";/v2/{table_name=projects/*/instances/*/tables/*}:mutateRows:\001*\212\323\344\223\002N\022:\n\ntable_name\022,{table_name=projects/*/instances/*/tables/*}\022\020\n\016app_profile_id\332A\022table_name,entries\332A!table_name,entries,app_profile_id' + _BIGTABLE.methods_by_name['CheckAndMutateRow']._options = None + _BIGTABLE.methods_by_name['CheckAndMutateRow']._serialized_options = b'\202\323\344\223\002G\"B/v2/{table_name=projects/*/instances/*/tables/*}:checkAndMutateRow:\001*\212\323\344\223\002N\022:\n\ntable_name\022,{table_name=projects/*/instances/*/tables/*}\022\020\n\016app_profile_id\332ABtable_name,row_key,predicate_filter,true_mutations,false_mutations\332AQtable_name,row_key,predicate_filter,true_mutations,false_mutations,app_profile_id' + _BIGTABLE.methods_by_name['PingAndWarm']._options = None + _BIGTABLE.methods_by_name['PingAndWarm']._serialized_options = b'\202\323\344\223\002+\"&/v2/{name=projects/*/instances/*}:ping:\001*\212\323\344\223\0029\022%\n\004name\022\035{name=projects/*/instances/*}\022\020\n\016app_profile_id\332A\004name\332A\023name,app_profile_id' + _BIGTABLE.methods_by_name['ReadModifyWriteRow']._options = None + _BIGTABLE.methods_by_name['ReadModifyWriteRow']._serialized_options = b'\202\323\344\223\002H\"C/v2/{table_name=projects/*/instances/*/tables/*}:readModifyWriteRow:\001*\212\323\344\223\002N\022:\n\ntable_name\022,{table_name=projects/*/instances/*/tables/*}\022\020\n\016app_profile_id\332A\030table_name,row_key,rules\332A\'table_name,row_key,rules,app_profile_id' + _BIGTABLE.methods_by_name['GenerateInitialChangeStreamPartitions']._options = None + _BIGTABLE.methods_by_name['GenerateInitialChangeStreamPartitions']._serialized_options = b'\202\323\344\223\002[\"V/v2/{table_name=projects/*/instances/*/tables/*}:generateInitialChangeStreamPartitions:\001*\332A\ntable_name\332A\031table_name,app_profile_id' + _BIGTABLE.methods_by_name['ReadChangeStream']._options = None + _BIGTABLE.methods_by_name['ReadChangeStream']._serialized_options = b'\202\323\344\223\002F\"A/v2/{table_name=projects/*/instances/*/tables/*}:readChangeStream:\001*\332A\ntable_name\332A\031table_name,app_profile_id' + _READROWSREQUEST._serialized_start=392 + _READROWSREQUEST._serialized_end=792 + _READROWSREQUEST_REQUESTSTATSVIEW._serialized_start=690 + _READROWSREQUEST_REQUESTSTATSVIEW._serialized_end=792 + _READROWSRESPONSE._serialized_start=795 + _READROWSRESPONSE._serialized_end=1228 + _READROWSRESPONSE_CELLCHUNK._serialized_start=967 + _READROWSRESPONSE_CELLCHUNK._serialized_end=1228 + _SAMPLEROWKEYSREQUEST._serialized_start=1230 + _SAMPLEROWKEYSREQUEST._serialized_end=1340 + 
_SAMPLEROWKEYSRESPONSE._serialized_start=1342 + _SAMPLEROWKEYSRESPONSE._serialized_end=1404 + _MUTATEROWREQUEST._serialized_start=1407 + _MUTATEROWREQUEST._serialized_end=1589 + _MUTATEROWRESPONSE._serialized_start=1591 + _MUTATEROWRESPONSE._serialized_end=1610 + _MUTATEROWSREQUEST._serialized_start=1613 + _MUTATEROWSREQUEST._serialized_end=1867 + _MUTATEROWSREQUEST_ENTRY._serialized_start=1789 + _MUTATEROWSREQUEST_ENTRY._serialized_end=1867 + _MUTATEROWSRESPONSE._serialized_start=1870 + _MUTATEROWSRESPONSE._serialized_end=2013 + _MUTATEROWSRESPONSE_ENTRY._serialized_start=1955 + _MUTATEROWSRESPONSE_ENTRY._serialized_end=2013 + _CHECKANDMUTATEROWREQUEST._serialized_start=2016 + _CHECKANDMUTATEROWREQUEST._serialized_end=2318 + _CHECKANDMUTATEROWRESPONSE._serialized_start=2320 + _CHECKANDMUTATEROWRESPONSE._serialized_end=2374 + _PINGANDWARMREQUEST._serialized_start=2376 + _PINGANDWARMREQUEST._serialized_end=2481 + _PINGANDWARMRESPONSE._serialized_start=2483 + _PINGANDWARMRESPONSE._serialized_end=2504 + _READMODIFYWRITEROWREQUEST._serialized_start=2507 + _READMODIFYWRITEROWREQUEST._serialized_end=2705 + _READMODIFYWRITEROWRESPONSE._serialized_start=2707 + _READMODIFYWRITEROWRESPONSE._serialized_end=2773 + _GENERATEINITIALCHANGESTREAMPARTITIONSREQUEST._serialized_start=2776 + _GENERATEINITIALCHANGESTREAMPARTITIONSREQUEST._serialized_end=2910 + _GENERATEINITIALCHANGESTREAMPARTITIONSRESPONSE._serialized_start=2912 + _GENERATEINITIALCHANGESTREAMPARTITIONSRESPONSE._serialized_end=3015 + _READCHANGESTREAMREQUEST._serialized_start=3018 + _READCHANGESTREAMREQUEST._serialized_end=3429 + _READCHANGESTREAMRESPONSE._serialized_start=3432 + _READCHANGESTREAMRESPONSE._serialized_end=4691 + _READCHANGESTREAMRESPONSE_MUTATIONCHUNK._serialized_start=3700 + _READCHANGESTREAMRESPONSE_MUTATIONCHUNK._serialized_end=3944 + _READCHANGESTREAMRESPONSE_MUTATIONCHUNK_CHUNKINFO._serialized_start=3855 + _READCHANGESTREAMRESPONSE_MUTATIONCHUNK_CHUNKINFO._serialized_end=3944 + _READCHANGESTREAMRESPONSE_DATACHANGE._serialized_start=3947 + _READCHANGESTREAMRESPONSE_DATACHANGE._serialized_end=4401 + _READCHANGESTREAMRESPONSE_DATACHANGE_TYPE._serialized_start=4321 + _READCHANGESTREAMRESPONSE_DATACHANGE_TYPE._serialized_end=4401 + _READCHANGESTREAMRESPONSE_HEARTBEAT._serialized_start=4404 + _READCHANGESTREAMRESPONSE_HEARTBEAT._serialized_end=4549 + _READCHANGESTREAMRESPONSE_CLOSESTREAM._serialized_start=4551 + _READCHANGESTREAMRESPONSE_CLOSESTREAM._serialized_end=4674 + _BIGTABLE._serialized_start=4694 + _BIGTABLE._serialized_end=7853 +# @@protoc_insertion_point(module_scope) diff --git a/test_proxy/protos/bigtable_pb2_grpc.py b/test_proxy/protos/bigtable_pb2_grpc.py new file mode 100644 index 000000000..9ce87d869 --- /dev/null +++ b/test_proxy/protos/bigtable_pb2_grpc.py @@ -0,0 +1,363 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import bigtable_pb2 as google_dot_bigtable_dot_v2_dot_bigtable__pb2 + + +class BigtableStub(object): + """Service for reading from and writing to existing Bigtable tables. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.ReadRows = channel.unary_stream( + '/google.bigtable.v2.Bigtable/ReadRows', + request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadRowsRequest.SerializeToString, + response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadRowsResponse.FromString, + ) + self.SampleRowKeys = channel.unary_stream( + '/google.bigtable.v2.Bigtable/SampleRowKeys', + request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.SampleRowKeysRequest.SerializeToString, + response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.SampleRowKeysResponse.FromString, + ) + self.MutateRow = channel.unary_unary( + '/google.bigtable.v2.Bigtable/MutateRow', + request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowRequest.SerializeToString, + response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowResponse.FromString, + ) + self.MutateRows = channel.unary_stream( + '/google.bigtable.v2.Bigtable/MutateRows', + request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowsRequest.SerializeToString, + response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowsResponse.FromString, + ) + self.CheckAndMutateRow = channel.unary_unary( + '/google.bigtable.v2.Bigtable/CheckAndMutateRow', + request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.CheckAndMutateRowRequest.SerializeToString, + response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.CheckAndMutateRowResponse.FromString, + ) + self.PingAndWarm = channel.unary_unary( + '/google.bigtable.v2.Bigtable/PingAndWarm', + request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.PingAndWarmRequest.SerializeToString, + response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.PingAndWarmResponse.FromString, + ) + self.ReadModifyWriteRow = channel.unary_unary( + '/google.bigtable.v2.Bigtable/ReadModifyWriteRow', + request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadModifyWriteRowRequest.SerializeToString, + response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadModifyWriteRowResponse.FromString, + ) + self.GenerateInitialChangeStreamPartitions = channel.unary_stream( + '/google.bigtable.v2.Bigtable/GenerateInitialChangeStreamPartitions', + request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.GenerateInitialChangeStreamPartitionsRequest.SerializeToString, + response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.GenerateInitialChangeStreamPartitionsResponse.FromString, + ) + self.ReadChangeStream = channel.unary_stream( + '/google.bigtable.v2.Bigtable/ReadChangeStream', + request_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadChangeStreamRequest.SerializeToString, + response_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadChangeStreamResponse.FromString, + ) + + +class BigtableServicer(object): + """Service for reading from and writing to existing Bigtable tables. + """ + + def ReadRows(self, request, context): + """Streams back the contents of all requested rows in key order, optionally + applying the same Reader filter to each. Depending on their size, + rows and cells may be broken up across multiple responses, but + atomicity of each row will still be preserved. See the + ReadRowsResponse documentation for details. 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SampleRowKeys(self, request, context): + """Returns a sample of row keys in the table. The returned row keys will + delimit contiguous sections of the table of approximately equal size, + which can be used to break up the data for distributed tasks like + mapreduces. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def MutateRow(self, request, context): + """Mutates a row atomically. Cells already present in the row are left + unchanged unless explicitly changed by `mutation`. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def MutateRows(self, request, context): + """Mutates multiple rows in a batch. Each individual row is mutated + atomically as in MutateRow, but the entire batch is not executed + atomically. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CheckAndMutateRow(self, request, context): + """Mutates a row atomically based on the output of a predicate Reader filter. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PingAndWarm(self, request, context): + """Warm up associated instance metadata for this connection. + This call is not required but may be useful for connection keep-alive. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ReadModifyWriteRow(self, request, context): + """Modifies a row atomically on the server. The method reads the latest + existing timestamp and value from the specified columns and writes a new + entry based on pre-defined read/modify/write rules. The new value for the + timestamp is the greater of the existing timestamp or the current server + time. The method returns the new contents of all modified cells. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GenerateInitialChangeStreamPartitions(self, request, context): + """NOTE: This API is intended to be used by Apache Beam BigtableIO. + Returns the current list of partitions that make up the table's + change stream. The union of partitions will cover the entire keyspace. + Partitions can be read with `ReadChangeStream`. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ReadChangeStream(self, request, context): + """NOTE: This API is intended to be used by Apache Beam BigtableIO. + Reads changes from a table's change stream. Changes will + reflect both user-initiated mutations and mutations that are caused by + garbage collection. 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_BigtableServicer_to_server(servicer, server): + rpc_method_handlers = { + 'ReadRows': grpc.unary_stream_rpc_method_handler( + servicer.ReadRows, + request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadRowsRequest.FromString, + response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadRowsResponse.SerializeToString, + ), + 'SampleRowKeys': grpc.unary_stream_rpc_method_handler( + servicer.SampleRowKeys, + request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.SampleRowKeysRequest.FromString, + response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.SampleRowKeysResponse.SerializeToString, + ), + 'MutateRow': grpc.unary_unary_rpc_method_handler( + servicer.MutateRow, + request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowRequest.FromString, + response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowResponse.SerializeToString, + ), + 'MutateRows': grpc.unary_stream_rpc_method_handler( + servicer.MutateRows, + request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowsRequest.FromString, + response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowsResponse.SerializeToString, + ), + 'CheckAndMutateRow': grpc.unary_unary_rpc_method_handler( + servicer.CheckAndMutateRow, + request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.CheckAndMutateRowRequest.FromString, + response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.CheckAndMutateRowResponse.SerializeToString, + ), + 'PingAndWarm': grpc.unary_unary_rpc_method_handler( + servicer.PingAndWarm, + request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.PingAndWarmRequest.FromString, + response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.PingAndWarmResponse.SerializeToString, + ), + 'ReadModifyWriteRow': grpc.unary_unary_rpc_method_handler( + servicer.ReadModifyWriteRow, + request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadModifyWriteRowRequest.FromString, + response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadModifyWriteRowResponse.SerializeToString, + ), + 'GenerateInitialChangeStreamPartitions': grpc.unary_stream_rpc_method_handler( + servicer.GenerateInitialChangeStreamPartitions, + request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.GenerateInitialChangeStreamPartitionsRequest.FromString, + response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.GenerateInitialChangeStreamPartitionsResponse.SerializeToString, + ), + 'ReadChangeStream': grpc.unary_stream_rpc_method_handler( + servicer.ReadChangeStream, + request_deserializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadChangeStreamRequest.FromString, + response_serializer=google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadChangeStreamResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'google.bigtable.v2.Bigtable', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Bigtable(object): + """Service for reading from and writing to existing Bigtable tables. 
+ """ + + @staticmethod + def ReadRows(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/google.bigtable.v2.Bigtable/ReadRows', + google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadRowsRequest.SerializeToString, + google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadRowsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def SampleRowKeys(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/google.bigtable.v2.Bigtable/SampleRowKeys', + google_dot_bigtable_dot_v2_dot_bigtable__pb2.SampleRowKeysRequest.SerializeToString, + google_dot_bigtable_dot_v2_dot_bigtable__pb2.SampleRowKeysResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def MutateRow(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.v2.Bigtable/MutateRow', + google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowRequest.SerializeToString, + google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def MutateRows(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/google.bigtable.v2.Bigtable/MutateRows', + google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowsRequest.SerializeToString, + google_dot_bigtable_dot_v2_dot_bigtable__pb2.MutateRowsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CheckAndMutateRow(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.v2.Bigtable/CheckAndMutateRow', + google_dot_bigtable_dot_v2_dot_bigtable__pb2.CheckAndMutateRowRequest.SerializeToString, + google_dot_bigtable_dot_v2_dot_bigtable__pb2.CheckAndMutateRowResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def PingAndWarm(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.v2.Bigtable/PingAndWarm', + google_dot_bigtable_dot_v2_dot_bigtable__pb2.PingAndWarmRequest.SerializeToString, + google_dot_bigtable_dot_v2_dot_bigtable__pb2.PingAndWarmResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, 
metadata) + + @staticmethod + def ReadModifyWriteRow(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.v2.Bigtable/ReadModifyWriteRow', + google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadModifyWriteRowRequest.SerializeToString, + google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadModifyWriteRowResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GenerateInitialChangeStreamPartitions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/google.bigtable.v2.Bigtable/GenerateInitialChangeStreamPartitions', + google_dot_bigtable_dot_v2_dot_bigtable__pb2.GenerateInitialChangeStreamPartitionsRequest.SerializeToString, + google_dot_bigtable_dot_v2_dot_bigtable__pb2.GenerateInitialChangeStreamPartitionsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ReadChangeStream(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/google.bigtable.v2.Bigtable/ReadChangeStream', + google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadChangeStreamRequest.SerializeToString, + google_dot_bigtable_dot_v2_dot_bigtable__pb2.ReadChangeStreamResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/test_proxy/protos/data_pb2.py b/test_proxy/protos/data_pb2.py new file mode 100644 index 000000000..fff212034 --- /dev/null +++ b/test_proxy/protos/data_pb2.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: google/bigtable/v2/data.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dgoogle/bigtable/v2/data.proto\x12\x12google.bigtable.v2\"@\n\x03Row\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12,\n\x08\x66\x61milies\x18\x02 \x03(\x0b\x32\x1a.google.bigtable.v2.Family\"C\n\x06\x46\x61mily\x12\x0c\n\x04name\x18\x01 \x01(\t\x12+\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x1a.google.bigtable.v2.Column\"D\n\x06\x43olumn\x12\x11\n\tqualifier\x18\x01 \x01(\x0c\x12\'\n\x05\x63\x65lls\x18\x02 \x03(\x0b\x32\x18.google.bigtable.v2.Cell\"?\n\x04\x43\x65ll\x12\x18\n\x10timestamp_micros\x18\x01 \x01(\x03\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x0e\n\x06labels\x18\x03 \x03(\t\"\x8a\x01\n\x08RowRange\x12\x1a\n\x10start_key_closed\x18\x01 \x01(\x0cH\x00\x12\x18\n\x0estart_key_open\x18\x02 \x01(\x0cH\x00\x12\x16\n\x0c\x65nd_key_open\x18\x03 \x01(\x0cH\x01\x12\x18\n\x0e\x65nd_key_closed\x18\x04 \x01(\x0cH\x01\x42\x0b\n\tstart_keyB\t\n\x07\x65nd_key\"L\n\x06RowSet\x12\x10\n\x08row_keys\x18\x01 \x03(\x0c\x12\x30\n\nrow_ranges\x18\x02 \x03(\x0b\x32\x1c.google.bigtable.v2.RowRange\"\xc6\x01\n\x0b\x43olumnRange\x12\x13\n\x0b\x66\x61mily_name\x18\x01 \x01(\t\x12 \n\x16start_qualifier_closed\x18\x02 \x01(\x0cH\x00\x12\x1e\n\x14start_qualifier_open\x18\x03 \x01(\x0cH\x00\x12\x1e\n\x14\x65nd_qualifier_closed\x18\x04 \x01(\x0cH\x01\x12\x1c\n\x12\x65nd_qualifier_open\x18\x05 \x01(\x0cH\x01\x42\x11\n\x0fstart_qualifierB\x0f\n\rend_qualifier\"N\n\x0eTimestampRange\x12\x1e\n\x16start_timestamp_micros\x18\x01 \x01(\x03\x12\x1c\n\x14\x65nd_timestamp_micros\x18\x02 \x01(\x03\"\x98\x01\n\nValueRange\x12\x1c\n\x12start_value_closed\x18\x01 \x01(\x0cH\x00\x12\x1a\n\x10start_value_open\x18\x02 \x01(\x0cH\x00\x12\x1a\n\x10\x65nd_value_closed\x18\x03 \x01(\x0cH\x01\x12\x18\n\x0e\x65nd_value_open\x18\x04 \x01(\x0cH\x01\x42\r\n\x0bstart_valueB\x0b\n\tend_value\"\xdf\x08\n\tRowFilter\x12\x34\n\x05\x63hain\x18\x01 \x01(\x0b\x32#.google.bigtable.v2.RowFilter.ChainH\x00\x12>\n\ninterleave\x18\x02 \x01(\x0b\x32(.google.bigtable.v2.RowFilter.InterleaveH\x00\x12<\n\tcondition\x18\x03 \x01(\x0b\x32\'.google.bigtable.v2.RowFilter.ConditionH\x00\x12\x0e\n\x04sink\x18\x10 \x01(\x08H\x00\x12\x19\n\x0fpass_all_filter\x18\x11 \x01(\x08H\x00\x12\x1a\n\x10\x62lock_all_filter\x18\x12 \x01(\x08H\x00\x12\x1e\n\x14row_key_regex_filter\x18\x04 \x01(\x0cH\x00\x12\x1b\n\x11row_sample_filter\x18\x0e \x01(\x01H\x00\x12\"\n\x18\x66\x61mily_name_regex_filter\x18\x05 \x01(\tH\x00\x12\'\n\x1d\x63olumn_qualifier_regex_filter\x18\x06 \x01(\x0cH\x00\x12>\n\x13\x63olumn_range_filter\x18\x07 \x01(\x0b\x32\x1f.google.bigtable.v2.ColumnRangeH\x00\x12\x44\n\x16timestamp_range_filter\x18\x08 \x01(\x0b\x32\".google.bigtable.v2.TimestampRangeH\x00\x12\x1c\n\x12value_regex_filter\x18\t \x01(\x0cH\x00\x12<\n\x12value_range_filter\x18\x0f \x01(\x0b\x32\x1e.google.bigtable.v2.ValueRangeH\x00\x12%\n\x1b\x63\x65lls_per_row_offset_filter\x18\n \x01(\x05H\x00\x12$\n\x1a\x63\x65lls_per_row_limit_filter\x18\x0b \x01(\x05H\x00\x12\'\n\x1d\x63\x65lls_per_column_limit_filter\x18\x0c \x01(\x05H\x00\x12!\n\x17strip_value_transformer\x18\r \x01(\x08H\x00\x12!\n\x17\x61pply_label_transformer\x18\x13 
\x01(\tH\x00\x1a\x37\n\x05\x43hain\x12.\n\x07\x66ilters\x18\x01 \x03(\x0b\x32\x1d.google.bigtable.v2.RowFilter\x1a<\n\nInterleave\x12.\n\x07\x66ilters\x18\x01 \x03(\x0b\x32\x1d.google.bigtable.v2.RowFilter\x1a\xad\x01\n\tCondition\x12\x37\n\x10predicate_filter\x18\x01 \x01(\x0b\x32\x1d.google.bigtable.v2.RowFilter\x12\x32\n\x0btrue_filter\x18\x02 \x01(\x0b\x32\x1d.google.bigtable.v2.RowFilter\x12\x33\n\x0c\x66\x61lse_filter\x18\x03 \x01(\x0b\x32\x1d.google.bigtable.v2.RowFilterB\x08\n\x06\x66ilter\"\xc9\x04\n\x08Mutation\x12\x38\n\x08set_cell\x18\x01 \x01(\x0b\x32$.google.bigtable.v2.Mutation.SetCellH\x00\x12K\n\x12\x64\x65lete_from_column\x18\x02 \x01(\x0b\x32-.google.bigtable.v2.Mutation.DeleteFromColumnH\x00\x12K\n\x12\x64\x65lete_from_family\x18\x03 \x01(\x0b\x32-.google.bigtable.v2.Mutation.DeleteFromFamilyH\x00\x12\x45\n\x0f\x64\x65lete_from_row\x18\x04 \x01(\x0b\x32*.google.bigtable.v2.Mutation.DeleteFromRowH\x00\x1a\x61\n\x07SetCell\x12\x13\n\x0b\x66\x61mily_name\x18\x01 \x01(\t\x12\x18\n\x10\x63olumn_qualifier\x18\x02 \x01(\x0c\x12\x18\n\x10timestamp_micros\x18\x03 \x01(\x03\x12\r\n\x05value\x18\x04 \x01(\x0c\x1ay\n\x10\x44\x65leteFromColumn\x12\x13\n\x0b\x66\x61mily_name\x18\x01 \x01(\t\x12\x18\n\x10\x63olumn_qualifier\x18\x02 \x01(\x0c\x12\x36\n\ntime_range\x18\x03 \x01(\x0b\x32\".google.bigtable.v2.TimestampRange\x1a\'\n\x10\x44\x65leteFromFamily\x12\x13\n\x0b\x66\x61mily_name\x18\x01 \x01(\t\x1a\x0f\n\rDeleteFromRowB\n\n\x08mutation\"\x80\x01\n\x13ReadModifyWriteRule\x12\x13\n\x0b\x66\x61mily_name\x18\x01 \x01(\t\x12\x18\n\x10\x63olumn_qualifier\x18\x02 \x01(\x0c\x12\x16\n\x0c\x61ppend_value\x18\x03 \x01(\x0cH\x00\x12\x1a\n\x10increment_amount\x18\x04 \x01(\x03H\x00\x42\x06\n\x04rule\"B\n\x0fStreamPartition\x12/\n\trow_range\x18\x01 \x01(\x0b\x32\x1c.google.bigtable.v2.RowRange\"W\n\x18StreamContinuationTokens\x12;\n\x06tokens\x18\x01 \x03(\x0b\x32+.google.bigtable.v2.StreamContinuationToken\"`\n\x17StreamContinuationToken\x12\x36\n\tpartition\x18\x01 \x01(\x0b\x32#.google.bigtable.v2.StreamPartition\x12\r\n\x05token\x18\x02 \x01(\tB\xb5\x01\n\x16\x63om.google.bigtable.v2B\tDataProtoP\x01Z:google.golang.org/genproto/googleapis/bigtable/v2;bigtable\xaa\x02\x18Google.Cloud.Bigtable.V2\xca\x02\x18Google\\Cloud\\Bigtable\\V2\xea\x02\x1bGoogle::Cloud::Bigtable::V2b\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.bigtable.v2.data_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\026com.google.bigtable.v2B\tDataProtoP\001Z:google.golang.org/genproto/googleapis/bigtable/v2;bigtable\252\002\030Google.Cloud.Bigtable.V2\312\002\030Google\\Cloud\\Bigtable\\V2\352\002\033Google::Cloud::Bigtable::V2' + _ROW._serialized_start=53 + _ROW._serialized_end=117 + _FAMILY._serialized_start=119 + _FAMILY._serialized_end=186 + _COLUMN._serialized_start=188 + _COLUMN._serialized_end=256 + _CELL._serialized_start=258 + _CELL._serialized_end=321 + _ROWRANGE._serialized_start=324 + _ROWRANGE._serialized_end=462 + _ROWSET._serialized_start=464 + _ROWSET._serialized_end=540 + _COLUMNRANGE._serialized_start=543 + _COLUMNRANGE._serialized_end=741 + _TIMESTAMPRANGE._serialized_start=743 + _TIMESTAMPRANGE._serialized_end=821 + _VALUERANGE._serialized_start=824 + _VALUERANGE._serialized_end=976 + _ROWFILTER._serialized_start=979 + _ROWFILTER._serialized_end=2098 + _ROWFILTER_CHAIN._serialized_start=1795 + 
_ROWFILTER_CHAIN._serialized_end=1850 + _ROWFILTER_INTERLEAVE._serialized_start=1852 + _ROWFILTER_INTERLEAVE._serialized_end=1912 + _ROWFILTER_CONDITION._serialized_start=1915 + _ROWFILTER_CONDITION._serialized_end=2088 + _MUTATION._serialized_start=2101 + _MUTATION._serialized_end=2686 + _MUTATION_SETCELL._serialized_start=2396 + _MUTATION_SETCELL._serialized_end=2493 + _MUTATION_DELETEFROMCOLUMN._serialized_start=2495 + _MUTATION_DELETEFROMCOLUMN._serialized_end=2616 + _MUTATION_DELETEFROMFAMILY._serialized_start=2618 + _MUTATION_DELETEFROMFAMILY._serialized_end=2657 + _MUTATION_DELETEFROMROW._serialized_start=2659 + _MUTATION_DELETEFROMROW._serialized_end=2674 + _READMODIFYWRITERULE._serialized_start=2689 + _READMODIFYWRITERULE._serialized_end=2817 + _STREAMPARTITION._serialized_start=2819 + _STREAMPARTITION._serialized_end=2885 + _STREAMCONTINUATIONTOKENS._serialized_start=2887 + _STREAMCONTINUATIONTOKENS._serialized_end=2974 + _STREAMCONTINUATIONTOKEN._serialized_start=2976 + _STREAMCONTINUATIONTOKEN._serialized_end=3072 +# @@protoc_insertion_point(module_scope) diff --git a/test_proxy/protos/data_pb2_grpc.py b/test_proxy/protos/data_pb2_grpc.py new file mode 100644 index 000000000..2daafffeb --- /dev/null +++ b/test_proxy/protos/data_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/test_proxy/protos/request_stats_pb2.py b/test_proxy/protos/request_stats_pb2.py new file mode 100644 index 000000000..95fcc6e0f --- /dev/null +++ b/test_proxy/protos/request_stats_pb2.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: google/bigtable/v2/request_stats.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import duration_pb2 as google_dot_protobuf_dot_duration__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&google/bigtable/v2/request_stats.proto\x12\x12google.bigtable.v2\x1a\x1egoogle/protobuf/duration.proto\"\x82\x01\n\x12ReadIterationStats\x12\x17\n\x0frows_seen_count\x18\x01 \x01(\x03\x12\x1b\n\x13rows_returned_count\x18\x02 \x01(\x03\x12\x18\n\x10\x63\x65lls_seen_count\x18\x03 \x01(\x03\x12\x1c\n\x14\x63\x65lls_returned_count\x18\x04 \x01(\x03\"Q\n\x13RequestLatencyStats\x12:\n\x17\x66rontend_server_latency\x18\x01 \x01(\x0b\x32\x19.google.protobuf.Duration\"\xa1\x01\n\x11\x46ullReadStatsView\x12\x44\n\x14read_iteration_stats\x18\x01 \x01(\x0b\x32&.google.bigtable.v2.ReadIterationStats\x12\x46\n\x15request_latency_stats\x18\x02 \x01(\x0b\x32\'.google.bigtable.v2.RequestLatencyStats\"c\n\x0cRequestStats\x12\x45\n\x14\x66ull_read_stats_view\x18\x01 \x01(\x0b\x32%.google.bigtable.v2.FullReadStatsViewH\x00\x42\x0c\n\nstats_viewB\xbd\x01\n\x16\x63om.google.bigtable.v2B\x11RequestStatsProtoP\x01Z:google.golang.org/genproto/googleapis/bigtable/v2;bigtable\xaa\x02\x18Google.Cloud.Bigtable.V2\xca\x02\x18Google\\Cloud\\Bigtable\\V2\xea\x02\x1bGoogle::Cloud::Bigtable::V2b\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 
'google.bigtable.v2.request_stats_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\026com.google.bigtable.v2B\021RequestStatsProtoP\001Z:google.golang.org/genproto/googleapis/bigtable/v2;bigtable\252\002\030Google.Cloud.Bigtable.V2\312\002\030Google\\Cloud\\Bigtable\\V2\352\002\033Google::Cloud::Bigtable::V2' + _READITERATIONSTATS._serialized_start=95 + _READITERATIONSTATS._serialized_end=225 + _REQUESTLATENCYSTATS._serialized_start=227 + _REQUESTLATENCYSTATS._serialized_end=308 + _FULLREADSTATSVIEW._serialized_start=311 + _FULLREADSTATSVIEW._serialized_end=472 + _REQUESTSTATS._serialized_start=474 + _REQUESTSTATS._serialized_end=573 +# @@protoc_insertion_point(module_scope) diff --git a/test_proxy/protos/request_stats_pb2_grpc.py b/test_proxy/protos/request_stats_pb2_grpc.py new file mode 100644 index 000000000..2daafffeb --- /dev/null +++ b/test_proxy/protos/request_stats_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/test_proxy/protos/test_proxy_pb2.py b/test_proxy/protos/test_proxy_pb2.py new file mode 100644 index 000000000..8c7817b14 --- /dev/null +++ b/test_proxy/protos/test_proxy_pb2.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: test_proxy.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.api import client_pb2 as google_dot_api_dot_client__pb2 +import bigtable_pb2 as google_dot_bigtable_dot_v2_dot_bigtable__pb2 +import data_pb2 as google_dot_bigtable_dot_v2_dot_data__pb2 +from google.protobuf import duration_pb2 as google_dot_protobuf_dot_duration__pb2 +from google.rpc import status_pb2 as google_dot_rpc_dot_status__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10test_proxy.proto\x12\x19google.bigtable.testproxy\x1a\x17google/api/client.proto\x1a!google/bigtable/v2/bigtable.proto\x1a\x1dgoogle/bigtable/v2/data.proto\x1a\x1egoogle/protobuf/duration.proto\x1a\x17google/rpc/status.proto\"\xb8\x01\n\x13\x43reateClientRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x61ta_target\x18\x02 \x01(\t\x12\x12\n\nproject_id\x18\x03 \x01(\t\x12\x13\n\x0binstance_id\x18\x04 \x01(\t\x12\x16\n\x0e\x61pp_profile_id\x18\x05 \x01(\t\x12\x38\n\x15per_operation_timeout\x18\x06 \x01(\x0b\x32\x19.google.protobuf.Duration\"\x16\n\x14\x43reateClientResponse\"\'\n\x12\x43loseClientRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\"\x15\n\x13\x43loseClientResponse\"(\n\x13RemoveClientRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\"\x16\n\x14RemoveClientResponse\"w\n\x0eReadRowRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x12\n\ntable_name\x18\x04 \x01(\t\x12\x0f\n\x07row_key\x18\x02 \x01(\t\x12-\n\x06\x66ilter\x18\x03 \x01(\x0b\x32\x1d.google.bigtable.v2.RowFilter\"U\n\tRowResult\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\x12$\n\x03row\x18\x02 \x01(\x0b\x32\x17.google.bigtable.v2.Row\"u\n\x0fReadRowsRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x34\n\x07request\x18\x02 
\x01(\x0b\x32#.google.bigtable.v2.ReadRowsRequest\x12\x19\n\x11\x63\x61ncel_after_rows\x18\x03 \x01(\x05\"V\n\nRowsResult\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\x12$\n\x03row\x18\x02 \x03(\x0b\x32\x17.google.bigtable.v2.Row\"\\\n\x10MutateRowRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x35\n\x07request\x18\x02 \x01(\x0b\x32$.google.bigtable.v2.MutateRowRequest\"5\n\x0fMutateRowResult\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\"^\n\x11MutateRowsRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x36\n\x07request\x18\x02 \x01(\x0b\x32%.google.bigtable.v2.MutateRowsRequest\"s\n\x10MutateRowsResult\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\x12;\n\x05\x65ntry\x18\x02 \x03(\x0b\x32,.google.bigtable.v2.MutateRowsResponse.Entry\"l\n\x18\x43heckAndMutateRowRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12=\n\x07request\x18\x02 \x01(\x0b\x32,.google.bigtable.v2.CheckAndMutateRowRequest\"|\n\x17\x43heckAndMutateRowResult\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\x12=\n\x06result\x18\x02 \x01(\x0b\x32-.google.bigtable.v2.CheckAndMutateRowResponse\"d\n\x14SampleRowKeysRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x39\n\x07request\x18\x02 \x01(\x0b\x32(.google.bigtable.v2.SampleRowKeysRequest\"t\n\x13SampleRowKeysResult\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status\x12\x39\n\x06sample\x18\x02 \x03(\x0b\x32).google.bigtable.v2.SampleRowKeysResponse\"n\n\x19ReadModifyWriteRowRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12>\n\x07request\x18\x02 \x01(\x0b\x32-.google.bigtable.v2.ReadModifyWriteRowRequest2\xa4\t\n\x18\x43loudBigtableV2TestProxy\x12q\n\x0c\x43reateClient\x12..google.bigtable.testproxy.CreateClientRequest\x1a/.google.bigtable.testproxy.CreateClientResponse\"\x00\x12n\n\x0b\x43loseClient\x12-.google.bigtable.testproxy.CloseClientRequest\x1a..google.bigtable.testproxy.CloseClientResponse\"\x00\x12q\n\x0cRemoveClient\x12..google.bigtable.testproxy.RemoveClientRequest\x1a/.google.bigtable.testproxy.RemoveClientResponse\"\x00\x12\\\n\x07ReadRow\x12).google.bigtable.testproxy.ReadRowRequest\x1a$.google.bigtable.testproxy.RowResult\"\x00\x12_\n\x08ReadRows\x12*.google.bigtable.testproxy.ReadRowsRequest\x1a%.google.bigtable.testproxy.RowsResult\"\x00\x12\x66\n\tMutateRow\x12+.google.bigtable.testproxy.MutateRowRequest\x1a*.google.bigtable.testproxy.MutateRowResult\"\x00\x12m\n\x0e\x42ulkMutateRows\x12,.google.bigtable.testproxy.MutateRowsRequest\x1a+.google.bigtable.testproxy.MutateRowsResult\"\x00\x12~\n\x11\x43heckAndMutateRow\x12\x33.google.bigtable.testproxy.CheckAndMutateRowRequest\x1a\x32.google.bigtable.testproxy.CheckAndMutateRowResult\"\x00\x12r\n\rSampleRowKeys\x12/.google.bigtable.testproxy.SampleRowKeysRequest\x1a..google.bigtable.testproxy.SampleRowKeysResult\"\x00\x12r\n\x12ReadModifyWriteRow\x12\x34.google.bigtable.testproxy.ReadModifyWriteRowRequest\x1a$.google.bigtable.testproxy.RowResult\"\x00\x1a\x34\xca\x41\x31\x62igtable-test-proxy-not-accessible.googleapis.comB6\n#com.google.cloud.bigtable.testproxyP\x01Z\r./testproxypbb\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'test_proxy_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n#com.google.cloud.bigtable.testproxyP\001Z\r./testproxypb' + _CLOUDBIGTABLEV2TESTPROXY._options = None + _CLOUDBIGTABLEV2TESTPROXY._serialized_options = 
b'\312A1bigtable-test-proxy-not-accessible.googleapis.com' + _CREATECLIENTREQUEST._serialized_start=196 + _CREATECLIENTREQUEST._serialized_end=380 + _CREATECLIENTRESPONSE._serialized_start=382 + _CREATECLIENTRESPONSE._serialized_end=404 + _CLOSECLIENTREQUEST._serialized_start=406 + _CLOSECLIENTREQUEST._serialized_end=445 + _CLOSECLIENTRESPONSE._serialized_start=447 + _CLOSECLIENTRESPONSE._serialized_end=468 + _REMOVECLIENTREQUEST._serialized_start=470 + _REMOVECLIENTREQUEST._serialized_end=510 + _REMOVECLIENTRESPONSE._serialized_start=512 + _REMOVECLIENTRESPONSE._serialized_end=534 + _READROWREQUEST._serialized_start=536 + _READROWREQUEST._serialized_end=655 + _ROWRESULT._serialized_start=657 + _ROWRESULT._serialized_end=742 + _READROWSREQUEST._serialized_start=744 + _READROWSREQUEST._serialized_end=861 + _ROWSRESULT._serialized_start=863 + _ROWSRESULT._serialized_end=949 + _MUTATEROWREQUEST._serialized_start=951 + _MUTATEROWREQUEST._serialized_end=1043 + _MUTATEROWRESULT._serialized_start=1045 + _MUTATEROWRESULT._serialized_end=1098 + _MUTATEROWSREQUEST._serialized_start=1100 + _MUTATEROWSREQUEST._serialized_end=1194 + _MUTATEROWSRESULT._serialized_start=1196 + _MUTATEROWSRESULT._serialized_end=1311 + _CHECKANDMUTATEROWREQUEST._serialized_start=1313 + _CHECKANDMUTATEROWREQUEST._serialized_end=1421 + _CHECKANDMUTATEROWRESULT._serialized_start=1423 + _CHECKANDMUTATEROWRESULT._serialized_end=1547 + _SAMPLEROWKEYSREQUEST._serialized_start=1549 + _SAMPLEROWKEYSREQUEST._serialized_end=1649 + _SAMPLEROWKEYSRESULT._serialized_start=1651 + _SAMPLEROWKEYSRESULT._serialized_end=1767 + _READMODIFYWRITEROWREQUEST._serialized_start=1769 + _READMODIFYWRITEROWREQUEST._serialized_end=1879 + _CLOUDBIGTABLEV2TESTPROXY._serialized_start=1882 + _CLOUDBIGTABLEV2TESTPROXY._serialized_end=3070 +# @@protoc_insertion_point(module_scope) diff --git a/test_proxy/protos/test_proxy_pb2_grpc.py b/test_proxy/protos/test_proxy_pb2_grpc.py new file mode 100644 index 000000000..60214a584 --- /dev/null +++ b/test_proxy/protos/test_proxy_pb2_grpc.py @@ -0,0 +1,433 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import test_proxy_pb2 as test__proxy__pb2 + + +class CloudBigtableV2TestProxyStub(object): + """Note that all RPCs are unary, even when the equivalent client binding call + may be streaming. This is an intentional simplification. + + Most methods have sync (default) and async variants. For async variants, + the proxy is expected to perform the async operation, then wait for results + before delivering them back to the driver client. + + Operations that may have interesting concurrency characteristics are + represented explicitly in the API (see ReadRowsRequest.cancel_after_rows). + We include such operations only when they can be meaningfully performed + through client bindings. + + Users should generally avoid setting deadlines for requests to the Proxy + because operations are not cancelable. If the deadline is set anyway, please + understand that the underlying operation will continue to be executed even + after the deadline expires. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.CreateClient = channel.unary_unary( + '/google.bigtable.testproxy.CloudBigtableV2TestProxy/CreateClient', + request_serializer=test__proxy__pb2.CreateClientRequest.SerializeToString, + response_deserializer=test__proxy__pb2.CreateClientResponse.FromString, + ) + self.CloseClient = channel.unary_unary( + '/google.bigtable.testproxy.CloudBigtableV2TestProxy/CloseClient', + request_serializer=test__proxy__pb2.CloseClientRequest.SerializeToString, + response_deserializer=test__proxy__pb2.CloseClientResponse.FromString, + ) + self.RemoveClient = channel.unary_unary( + '/google.bigtable.testproxy.CloudBigtableV2TestProxy/RemoveClient', + request_serializer=test__proxy__pb2.RemoveClientRequest.SerializeToString, + response_deserializer=test__proxy__pb2.RemoveClientResponse.FromString, + ) + self.ReadRow = channel.unary_unary( + '/google.bigtable.testproxy.CloudBigtableV2TestProxy/ReadRow', + request_serializer=test__proxy__pb2.ReadRowRequest.SerializeToString, + response_deserializer=test__proxy__pb2.RowResult.FromString, + ) + self.ReadRows = channel.unary_unary( + '/google.bigtable.testproxy.CloudBigtableV2TestProxy/ReadRows', + request_serializer=test__proxy__pb2.ReadRowsRequest.SerializeToString, + response_deserializer=test__proxy__pb2.RowsResult.FromString, + ) + self.MutateRow = channel.unary_unary( + '/google.bigtable.testproxy.CloudBigtableV2TestProxy/MutateRow', + request_serializer=test__proxy__pb2.MutateRowRequest.SerializeToString, + response_deserializer=test__proxy__pb2.MutateRowResult.FromString, + ) + self.BulkMutateRows = channel.unary_unary( + '/google.bigtable.testproxy.CloudBigtableV2TestProxy/BulkMutateRows', + request_serializer=test__proxy__pb2.MutateRowsRequest.SerializeToString, + response_deserializer=test__proxy__pb2.MutateRowsResult.FromString, + ) + self.CheckAndMutateRow = channel.unary_unary( + '/google.bigtable.testproxy.CloudBigtableV2TestProxy/CheckAndMutateRow', + request_serializer=test__proxy__pb2.CheckAndMutateRowRequest.SerializeToString, + response_deserializer=test__proxy__pb2.CheckAndMutateRowResult.FromString, + ) + self.SampleRowKeys = channel.unary_unary( + '/google.bigtable.testproxy.CloudBigtableV2TestProxy/SampleRowKeys', + request_serializer=test__proxy__pb2.SampleRowKeysRequest.SerializeToString, + response_deserializer=test__proxy__pb2.SampleRowKeysResult.FromString, + ) + self.ReadModifyWriteRow = channel.unary_unary( + '/google.bigtable.testproxy.CloudBigtableV2TestProxy/ReadModifyWriteRow', + request_serializer=test__proxy__pb2.ReadModifyWriteRowRequest.SerializeToString, + response_deserializer=test__proxy__pb2.RowResult.FromString, + ) + + +class CloudBigtableV2TestProxyServicer(object): + """Note that all RPCs are unary, even when the equivalent client binding call + may be streaming. This is an intentional simplification. + + Most methods have sync (default) and async variants. For async variants, + the proxy is expected to perform the async operation, then wait for results + before delivering them back to the driver client. + + Operations that may have interesting concurrency characteristics are + represented explicitly in the API (see ReadRowsRequest.cancel_after_rows). + We include such operations only when they can be meaningfully performed + through client bindings. + + Users should generally avoid setting deadlines for requests to the Proxy + because operations are not cancelable. 
If the deadline is set anyway, please + understand that the underlying operation will continue to be executed even + after the deadline expires. + """ + + def CreateClient(self, request, context): + """Client management: + + Creates a client in the proxy. + Each client has its own dedicated channel(s), and can be used concurrently + and independently with other clients. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CloseClient(self, request, context): + """Closes a client in the proxy, making it not accept new requests. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def RemoveClient(self, request, context): + """Removes a client in the proxy, making it inaccessible. Client closing + should be done by CloseClient() separately. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ReadRow(self, request, context): + """Bigtable operations: for each operation, you should use the synchronous or + asynchronous variant of the client method based on the `use_async_method` + setting of the client instance. For starters, you can choose to implement + one variant, and return UNIMPLEMENTED status for the other. + + Reads a row with the client instance. + The result row may not be present in the response. + Callers should check for it (e.g. calling has_row() in C++). + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ReadRows(self, request, context): + """Reads rows with the client instance. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def MutateRow(self, request, context): + """Writes a row with the client instance. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def BulkMutateRows(self, request, context): + """Writes multiple rows with the client instance. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CheckAndMutateRow(self, request, context): + """Performs a check-and-mutate-row operation with the client instance. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SampleRowKeys(self, request, context): + """Obtains a row key sampling with the client instance. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ReadModifyWriteRow(self, request, context): + """Performs a read-modify-write operation with the client. 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_CloudBigtableV2TestProxyServicer_to_server(servicer, server): + rpc_method_handlers = { + 'CreateClient': grpc.unary_unary_rpc_method_handler( + servicer.CreateClient, + request_deserializer=test__proxy__pb2.CreateClientRequest.FromString, + response_serializer=test__proxy__pb2.CreateClientResponse.SerializeToString, + ), + 'CloseClient': grpc.unary_unary_rpc_method_handler( + servicer.CloseClient, + request_deserializer=test__proxy__pb2.CloseClientRequest.FromString, + response_serializer=test__proxy__pb2.CloseClientResponse.SerializeToString, + ), + 'RemoveClient': grpc.unary_unary_rpc_method_handler( + servicer.RemoveClient, + request_deserializer=test__proxy__pb2.RemoveClientRequest.FromString, + response_serializer=test__proxy__pb2.RemoveClientResponse.SerializeToString, + ), + 'ReadRow': grpc.unary_unary_rpc_method_handler( + servicer.ReadRow, + request_deserializer=test__proxy__pb2.ReadRowRequest.FromString, + response_serializer=test__proxy__pb2.RowResult.SerializeToString, + ), + 'ReadRows': grpc.unary_unary_rpc_method_handler( + servicer.ReadRows, + request_deserializer=test__proxy__pb2.ReadRowsRequest.FromString, + response_serializer=test__proxy__pb2.RowsResult.SerializeToString, + ), + 'MutateRow': grpc.unary_unary_rpc_method_handler( + servicer.MutateRow, + request_deserializer=test__proxy__pb2.MutateRowRequest.FromString, + response_serializer=test__proxy__pb2.MutateRowResult.SerializeToString, + ), + 'BulkMutateRows': grpc.unary_unary_rpc_method_handler( + servicer.BulkMutateRows, + request_deserializer=test__proxy__pb2.MutateRowsRequest.FromString, + response_serializer=test__proxy__pb2.MutateRowsResult.SerializeToString, + ), + 'CheckAndMutateRow': grpc.unary_unary_rpc_method_handler( + servicer.CheckAndMutateRow, + request_deserializer=test__proxy__pb2.CheckAndMutateRowRequest.FromString, + response_serializer=test__proxy__pb2.CheckAndMutateRowResult.SerializeToString, + ), + 'SampleRowKeys': grpc.unary_unary_rpc_method_handler( + servicer.SampleRowKeys, + request_deserializer=test__proxy__pb2.SampleRowKeysRequest.FromString, + response_serializer=test__proxy__pb2.SampleRowKeysResult.SerializeToString, + ), + 'ReadModifyWriteRow': grpc.unary_unary_rpc_method_handler( + servicer.ReadModifyWriteRow, + request_deserializer=test__proxy__pb2.ReadModifyWriteRowRequest.FromString, + response_serializer=test__proxy__pb2.RowResult.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'google.bigtable.testproxy.CloudBigtableV2TestProxy', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class CloudBigtableV2TestProxy(object): + """Note that all RPCs are unary, even when the equivalent client binding call + may be streaming. This is an intentional simplification. + + Most methods have sync (default) and async variants. For async variants, + the proxy is expected to perform the async operation, then wait for results + before delivering them back to the driver client. + + Operations that may have interesting concurrency characteristics are + represented explicitly in the API (see ReadRowsRequest.cancel_after_rows). + We include such operations only when they can be meaningfully performed + through client bindings. 
+ + Users should generally avoid setting deadlines for requests to the Proxy + because operations are not cancelable. If the deadline is set anyway, please + understand that the underlying operation will continue to be executed even + after the deadline expires. + """ + + @staticmethod + def CreateClient(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/CreateClient', + test__proxy__pb2.CreateClientRequest.SerializeToString, + test__proxy__pb2.CreateClientResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CloseClient(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/CloseClient', + test__proxy__pb2.CloseClientRequest.SerializeToString, + test__proxy__pb2.CloseClientResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def RemoveClient(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/RemoveClient', + test__proxy__pb2.RemoveClientRequest.SerializeToString, + test__proxy__pb2.RemoveClientResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ReadRow(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/ReadRow', + test__proxy__pb2.ReadRowRequest.SerializeToString, + test__proxy__pb2.RowResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ReadRows(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/ReadRows', + test__proxy__pb2.ReadRowsRequest.SerializeToString, + test__proxy__pb2.RowsResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def MutateRow(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/MutateRow', + test__proxy__pb2.MutateRowRequest.SerializeToString, + test__proxy__pb2.MutateRowResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, 
wait_for_ready, timeout, metadata) + + @staticmethod + def BulkMutateRows(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/BulkMutateRows', + test__proxy__pb2.MutateRowsRequest.SerializeToString, + test__proxy__pb2.MutateRowsResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CheckAndMutateRow(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/CheckAndMutateRow', + test__proxy__pb2.CheckAndMutateRowRequest.SerializeToString, + test__proxy__pb2.CheckAndMutateRowResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def SampleRowKeys(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/SampleRowKeys', + test__proxy__pb2.SampleRowKeysRequest.SerializeToString, + test__proxy__pb2.SampleRowKeysResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def ReadModifyWriteRow(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/google.bigtable.testproxy.CloudBigtableV2TestProxy/ReadModifyWriteRow', + test__proxy__pb2.ReadModifyWriteRowRequest.SerializeToString, + test__proxy__pb2.RowResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/test_proxy/run_tests.sh b/test_proxy/run_tests.sh new file mode 100755 index 000000000..15b146b03 --- /dev/null +++ b/test_proxy/run_tests.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# attempt download golang if not found +if [[ ! -x "$(command -v go)" ]]; then + echo "Downloading golang..." + wget https://go.dev/dl/go1.20.2.linux-amd64.tar.gz + tar -xzf go1.20.2.linux-amd64.tar.gz + export GOROOT=$(pwd)/go + export PATH=$GOROOT/bin:$PATH + export GOPATH=$HOME/go + go version +fi + +# ensure the working dir is the script's folder +SCRIPT_DIR=$(realpath $(dirname "$0")) +cd $SCRIPT_DIR + +export PROXY_SERVER_PORT=50055 + +# download test suite +if [ ! 
-d "cloud-bigtable-clients-test" ]; then + git clone https://github.com/googleapis/cloud-bigtable-clients-test.git +fi + +# start proxy +python test_proxy.py --port $PROXY_SERVER_PORT & +PROXY_PID=$! +function finish { + kill $PROXY_PID +} +trap finish EXIT + +# run tests +pushd cloud-bigtable-clients-test/tests +go test -v -proxy_addr=:$PROXY_SERVER_PORT diff --git a/test_proxy/test_proxy.py b/test_proxy/test_proxy.py new file mode 100644 index 000000000..a0cf2f1f0 --- /dev/null +++ b/test_proxy/test_proxy.py @@ -0,0 +1,193 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The Python implementation of the `cloud-bigtable-clients-test` proxy server. + +https://github.com/googleapis/cloud-bigtable-clients-test + +This server is intended to be used to test the correctness of Bigtable +clients across languages. + +Contributor Note: the proxy implementation is split across TestProxyClientHandler +and TestProxyGrpcServer. This is due to the fact that generated protos and proto-plus +objects cannot be used in the same process, so we had to make use of the +multiprocessing module to allow them to work together. +""" + +import multiprocessing +import argparse +import sys +import os +sys.path.append("handlers") + + +def grpc_server_process(request_q, queue_pool, port=50055): + """ + Defines a process that hosts a grpc server + proxies requests to a client_handler_process + """ + sys.path.append("protos") + from concurrent import futures + + import grpc + import test_proxy_pb2_grpc + import grpc_handler + + # Start gRPC server + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + test_proxy_pb2_grpc.add_CloudBigtableV2TestProxyServicer_to_server( + grpc_handler.TestProxyGrpcServer(request_q, queue_pool), server + ) + server.add_insecure_port("[::]:" + port) + server.start() + print("grpc_server_process started, listening on " + port) + server.wait_for_termination() + + +async def client_handler_process_async(request_q, queue_pool, use_legacy_client=False): + """ + Defines a process that recives Bigtable requests from a grpc_server_process, + and runs the request using a client library instance + """ + import base64 + import re + import asyncio + import warnings + import client_handler_data + import client_handler_legacy + warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*Bigtable emulator.*") + + def camel_to_snake(str): + return re.sub(r"(?= 1.14.0, < 2.0.0dev", # Then this file should have foo==1.14.0 -google-api-core==1.34.0 -google-cloud-core==1.4.4 +google-api-core==2.16.0 +google-cloud-core==2.0.0 grpc-google-iam-v1==0.12.4 proto-plus==1.22.0 libcst==0.2.5 diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt index e69de29bb..ee858c3ec 100644 --- a/testing/constraints-3.8.txt +++ b/testing/constraints-3.8.txt @@ -0,0 +1,14 @@ +# This constraints file is used to check that lower bounds +# are correct in setup.py +# List *all* library dependencies and extras in this file. +# Pin the version to the lower bound. 
+# +# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", +# Then this file should have foo==1.14.0 +google-api-core==2.16.0 +google-cloud-core==2.0.0 +grpc-google-iam-v1==0.12.4 +proto-plus==1.22.0 +libcst==0.2.5 +protobuf==3.19.5 +pytest-asyncio==0.21.1 diff --git a/tests/system/__init__.py b/tests/system/__init__.py index 4de65971c..89a37dc92 100644 --- a/tests/system/__init__.py +++ b/tests/system/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2020 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 910c20970..b8862ea4b 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2011 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,199 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +Import pytest fixtures for setting up table for data client system tests +""" +import sys import os -import pytest -from test_utils.system import unique_resource_id - -from google.cloud.bigtable.client import Client -from google.cloud.environment_vars import BIGTABLE_EMULATOR - -from . import _helpers - - -@pytest.fixture(scope="session") -def in_emulator(): - return os.getenv(BIGTABLE_EMULATOR) is not None - - -@pytest.fixture(scope="session") -def kms_key_name(): - return os.getenv("KMS_KEY_NAME") - - -@pytest.fixture(scope="session") -def with_kms_key_name(kms_key_name): - if kms_key_name is None: - pytest.skip("Test requires KMS_KEY_NAME environment variable") - return kms_key_name - - -@pytest.fixture(scope="session") -def skip_on_emulator(in_emulator): - if in_emulator: - pytest.skip("Emulator does not support this feature") - - -@pytest.fixture(scope="session") -def unique_suffix(): - return unique_resource_id("-") - - -@pytest.fixture(scope="session") -def location_id(): - return "us-central1-c" - - -@pytest.fixture(scope="session") -def serve_nodes(): - return 1 - - -@pytest.fixture(scope="session") -def label_key(): - return "python-system" - - -@pytest.fixture(scope="session") -def instance_labels(label_key): - return {label_key: _helpers.label_stamp()} - - -@pytest.fixture(scope="session") -def admin_client(): - return Client(admin=True) - - -@pytest.fixture(scope="session") -def service_account(admin_client): - from google.oauth2.service_account import Credentials - - if not isinstance(admin_client._credentials, Credentials): - pytest.skip("These tests require a service account credential") - return admin_client._credentials - - -@pytest.fixture(scope="session") -def admin_instance_id(unique_suffix): - return f"g-c-p{unique_suffix}" - - -@pytest.fixture(scope="session") -def admin_cluster_id(admin_instance_id): - return f"{admin_instance_id}-cluster" - - -@pytest.fixture(scope="session") -def admin_instance(admin_client, admin_instance_id, instance_labels): - return admin_client.instance(admin_instance_id, labels=instance_labels) - - -@pytest.fixture(scope="session") -def admin_cluster(admin_instance, admin_cluster_id, location_id, serve_nodes): - return admin_instance.cluster( - admin_cluster_id, - location_id=location_id, - serve_nodes=serve_nodes, - ) - - 
-@pytest.fixture(scope="session") -def admin_cluster_with_autoscaling( - admin_instance, - admin_cluster_id, - location_id, - min_serve_nodes, - max_serve_nodes, - cpu_utilization_percent, -): - return admin_instance.cluster( - admin_cluster_id, - location_id=location_id, - min_serve_nodes=min_serve_nodes, - max_serve_nodes=max_serve_nodes, - cpu_utilization_percent=cpu_utilization_percent, - ) - - -@pytest.fixture(scope="session") -def admin_instance_populated(admin_instance, admin_cluster, in_emulator): - # Emulator does not support instance admin operations (create / delete). - # See: https://cloud.google.com/bigtable/docs/emulator - if not in_emulator: - operation = admin_instance.create(clusters=[admin_cluster]) - operation.result(timeout=240) - - yield admin_instance - - if not in_emulator: - _helpers.retry_429(admin_instance.delete)() - - -@pytest.fixture(scope="session") -def data_client(): - return Client(admin=False) - - -@pytest.fixture(scope="session") -def data_instance_id(unique_suffix): - return f"g-c-p-d{unique_suffix}" - - -@pytest.fixture(scope="session") -def data_cluster_id(data_instance_id): - return f"{data_instance_id}-cluster" - - -@pytest.fixture(scope="session") -def data_instance_populated( - admin_client, - data_instance_id, - instance_labels, - data_cluster_id, - location_id, - serve_nodes, - in_emulator, -): - instance = admin_client.instance(data_instance_id, labels=instance_labels) - # Emulator does not support instance admin operations (create / delete). - # See: https://cloud.google.com/bigtable/docs/emulator - if not in_emulator: - cluster = instance.cluster( - data_cluster_id, - location_id=location_id, - serve_nodes=serve_nodes, - ) - operation = instance.create(clusters=[cluster]) - operation.result(timeout=240) - - yield instance - - if not in_emulator: - _helpers.retry_429(instance.delete)() - - -@pytest.fixture(scope="function") -def instances_to_delete(): - instances_to_delete = [] - - yield instances_to_delete - - for instance in instances_to_delete: - _helpers.retry_429(instance.delete)() - - -@pytest.fixture(scope="session") -def min_serve_nodes(in_emulator): - return 1 - - -@pytest.fixture(scope="session") -def max_serve_nodes(in_emulator): - return 8 - +script_path = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(script_path) -@pytest.fixture(scope="session") -def cpu_utilization_percent(in_emulator): - return 10 +pytest_plugins = [ + "data.setup_fixtures", +] diff --git a/tests/system/data/__init__.py b/tests/system/data/__init__.py new file mode 100644 index 000000000..89a37dc92 --- /dev/null +++ b/tests/system/data/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
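# Note on the conftest.py rewrite above: the legacy admin fixtures give way to a
# thin shim that puts tests/system on sys.path and registers "data.setup_fixtures"
# as a pytest plugin, so the session fixtures defined there (client, table,
# instance_id, table_id, ...) are available to every module under
# tests/system/data. A minimal consuming test might look like this sketch
# (illustrative only, not part of this change):
#
#     import pytest
#
#     @pytest.mark.asyncio
#     async def test_table_fixture_is_open(table):
#         # `table` is the session-scoped async table yielded by setup_fixtures
#         assert table.table_name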
+# diff --git a/tests/system/data/setup_fixtures.py b/tests/system/data/setup_fixtures.py new file mode 100644 index 000000000..77086b7f3 --- /dev/null +++ b/tests/system/data/setup_fixtures.py @@ -0,0 +1,171 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contains a set of pytest fixtures for setting up and populating a +Bigtable database for testing purposes. +""" + +import pytest +import pytest_asyncio +import os +import asyncio +import uuid + + +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.stop() + loop.close() + + +@pytest.fixture(scope="session") +def admin_client(): + """ + Client for interacting with Table and Instance admin APIs + """ + from google.cloud.bigtable.client import Client + + client = Client(admin=True) + yield client + + +@pytest.fixture(scope="session") +def instance_id(admin_client, project_id, cluster_config): + """ + Returns BIGTABLE_TEST_INSTANCE if set, otherwise creates a new temporary instance for the test session + """ + from google.cloud.bigtable_admin_v2 import types + from google.api_core import exceptions + from google.cloud.environment_vars import BIGTABLE_EMULATOR + + # use user-specified instance if available + user_specified_instance = os.getenv("BIGTABLE_TEST_INSTANCE") + if user_specified_instance: + print("Using user-specified instance: {}".format(user_specified_instance)) + yield user_specified_instance + return + + # create a new temporary test instance + instance_id = f"python-bigtable-tests-{uuid.uuid4().hex[:6]}" + if os.getenv(BIGTABLE_EMULATOR): + # don't create instance if in emulator mode + yield instance_id + else: + try: + operation = admin_client.instance_admin_client.create_instance( + parent=f"projects/{project_id}", + instance_id=instance_id, + instance=types.Instance( + display_name="Test Instance", + # labels={"python-system-test": "true"}, + ), + clusters=cluster_config, + ) + operation.result(timeout=240) + except exceptions.AlreadyExists: + pass + yield instance_id + admin_client.instance_admin_client.delete_instance( + name=f"projects/{project_id}/instances/{instance_id}" + ) + + +@pytest.fixture(scope="session") +def column_split_config(): + """ + specify initial splits to create when creating a new test table + """ + return [(num * 1000).to_bytes(8, "big") for num in range(1, 10)] + + +@pytest.fixture(scope="session") +def table_id( + admin_client, + project_id, + instance_id, + column_family_config, + init_table_id, + column_split_config, +): + """ + Returns BIGTABLE_TEST_TABLE if set, otherwise creates a new temporary table for the test session + + Args: + - admin_client: Client for interacting with the Table Admin API. Supplied by the admin_client fixture. + - project_id: The project ID of the GCP project to test against. Supplied by the project_id fixture. + - instance_id: The ID of the Bigtable instance to test against. Supplied by the instance_id fixture. 
+ - init_column_families: A list of column families to initialize the table with, if pre-initialized table is not given with BIGTABLE_TEST_TABLE. + Supplied by the init_column_families fixture. + - init_table_id: The table ID to give to the test table, if pre-initialized table is not given with BIGTABLE_TEST_TABLE. + Supplied by the init_table_id fixture. + - column_split_config: A list of row keys to use as initial splits when creating the test table. + """ + from google.api_core import exceptions + from google.api_core import retry + + # use user-specified instance if available + user_specified_table = os.getenv("BIGTABLE_TEST_TABLE") + if user_specified_table: + print("Using user-specified table: {}".format(user_specified_table)) + yield user_specified_table + return + + retry = retry.Retry( + predicate=retry.if_exception_type(exceptions.FailedPrecondition) + ) + try: + parent_path = f"projects/{project_id}/instances/{instance_id}" + print(f"Creating table: {parent_path}/tables/{init_table_id}") + admin_client.table_admin_client.create_table( + request={ + "parent": parent_path, + "table_id": init_table_id, + "table": {"column_families": column_family_config}, + "initial_splits": [{"key": key} for key in column_split_config], + }, + retry=retry, + ) + except exceptions.AlreadyExists: + pass + yield init_table_id + print(f"Deleting table: {parent_path}/tables/{init_table_id}") + try: + admin_client.table_admin_client.delete_table( + name=f"{parent_path}/tables/{init_table_id}" + ) + except exceptions.NotFound: + print(f"Table {init_table_id} not found, skipping deletion") + + +@pytest_asyncio.fixture(scope="session") +async def client(): + from google.cloud.bigtable.data import BigtableDataClientAsync + + project = os.getenv("GOOGLE_CLOUD_PROJECT") or None + async with BigtableDataClientAsync(project=project, pool_size=4) as client: + yield client + + +@pytest.fixture(scope="session") +def project_id(client): + """Returns the project ID from the client.""" + yield client.project + + +@pytest_asyncio.fixture(scope="session") +async def table(client, table_id, instance_id): + async with client.get_table(instance_id, table_id) as table: + yield table diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py new file mode 100644 index 000000000..aeb08fc1a --- /dev/null +++ b/tests/system/data/test_system.py @@ -0,0 +1,943 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
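# The tests below rely on the session fixtures from setup_fixtures.py: set
# BIGTABLE_TEST_INSTANCE / BIGTABLE_TEST_TABLE to run against existing resources,
# otherwise temporary ones are created for the session and deleted afterwards.
# Outside of pytest, the same handles can be opened directly with the async data
# client; a minimal sketch (instance and table ids are placeholders):
#
#     import asyncio
#     from google.cloud.bigtable.data import BigtableDataClientAsync, ReadRowsQuery
#
#     async def main():
#         async with BigtableDataClientAsync() as client:
#             async with client.get_table("my-instance", "my-table") as table:
#                 rows = await table.read_rows(ReadRowsQuery(row_keys=b"row-1"))
#                 print(rows)
#
#     asyncio.run(main())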
+ +import pytest +import pytest_asyncio +import asyncio +import uuid +import os +from google.api_core import retry +from google.api_core.exceptions import ClientError + +from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE +from google.cloud.environment_vars import BIGTABLE_EMULATOR + +TEST_FAMILY = "test-family" +TEST_FAMILY_2 = "test-family-2" + + +@pytest.fixture(scope="session") +def column_family_config(): + """ + specify column families to create when creating a new test table + """ + from google.cloud.bigtable_admin_v2 import types + + return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} + + +@pytest.fixture(scope="session") +def init_table_id(): + """ + The table_id to use when creating a new test table + """ + return f"test-table-{uuid.uuid4().hex}" + + +@pytest.fixture(scope="session") +def cluster_config(project_id): + """ + Configuration for the clusters to use when creating a new instance + """ + from google.cloud.bigtable_admin_v2 import types + + cluster = { + "test-cluster": types.Cluster( + location=f"projects/{project_id}/locations/us-central1-b", + serve_nodes=1, + ) + } + return cluster + + +class TempRowBuilder: + """ + Used to add rows to a table for testing purposes. + """ + + def __init__(self, table): + self.rows = [] + self.table = table + + async def add_row( + self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" + ): + if isinstance(value, str): + value = value.encode("utf-8") + elif isinstance(value, int): + value = value.to_bytes(8, byteorder="big", signed=True) + request = { + "table_name": self.table.table_name, + "row_key": row_key, + "mutations": [ + { + "set_cell": { + "family_name": family, + "column_qualifier": qualifier, + "value": value, + } + } + ], + } + await self.table.client._gapic_client.mutate_row(request) + self.rows.append(row_key) + + async def delete_rows(self): + if self.rows: + request = { + "table_name": self.table.table_name, + "entries": [ + {"row_key": row, "mutations": [{"delete_from_row": {}}]} + for row in self.rows + ], + } + await self.table.client._gapic_client.mutate_rows(request) + + +@pytest.mark.usefixtures("table") +async def _retrieve_cell_value(table, row_key): + """ + Helper to read an individual row + """ + from google.cloud.bigtable.data import ReadRowsQuery + + row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key)) + assert len(row_list) == 1 + row = row_list[0] + cell = row.cells[0] + return cell.value + + +async def _create_row_and_mutation( + table, temp_rows, *, start_value=b"start", new_value=b"new_value" +): + """ + Helper to create a new row, and a sample set_cell mutation to change its value + """ + from google.cloud.bigtable.data.mutations import SetCell + + row_key = uuid.uuid4().hex.encode() + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row( + row_key, family=family, qualifier=qualifier, value=start_value + ) + # ensure cell is initialized + assert (await _retrieve_cell_value(table, row_key)) == start_value + + mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) + return row_key, mutation + + +@pytest.mark.usefixtures("table") +@pytest_asyncio.fixture(scope="function") +async def temp_rows(table): + builder = TempRowBuilder(table) + yield builder + await builder.delete_rows() + + +@pytest.mark.usefixtures("table") +@pytest.mark.usefixtures("client") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=10) +@pytest.mark.asyncio 
+async def test_ping_and_warm_gapic(client, table): + """ + Simple ping rpc test + This test ensures channels are able to authenticate with backend + """ + request = {"name": table.instance_name} + await client._gapic_client.ping_and_warm(request) + + +@pytest.mark.usefixtures("table") +@pytest.mark.usefixtures("client") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_ping_and_warm(client, table): + """ + Test ping and warm from handwritten client + """ + try: + channel = client.transport._grpc_channel.pool[0] + except Exception: + # for sync client + channel = client.transport._grpc_channel + results = await client._ping_and_warm_instances(channel) + assert len(results) == 1 + assert results[0] is None + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +async def test_mutation_set_cell(table, temp_rows): + """ + Ensure cells can be set properly + """ + row_key = b"bulk_mutate" + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + await table.mutate_row(row_key, mutation) + + # ensure cell is updated + assert (await _retrieve_cell_value(table, row_key)) == new_value + + +@pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" +) +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_sample_row_keys(client, table, temp_rows, column_split_config): + """ + Sample keys should return a single sample in small test tables + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + + results = await table.sample_row_keys() + assert len(results) == len(column_split_config) + 1 + # first keys should match the split config + for idx in range(len(column_split_config)): + assert results[idx][0] == column_split_config[idx] + assert isinstance(results[idx][1], int) + # last sample should be empty key + assert results[-1][0] == b"" + assert isinstance(results[-1][1], int) + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_bulk_mutations_set_cell(client, table, temp_rows): + """ + Ensure cells can be set properly + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + + await table.bulk_mutate_rows([bulk_mutation]) + + # ensure cell is updated + assert (await _retrieve_cell_value(table, row_key)) == new_value + + +@pytest.mark.asyncio +async def test_bulk_mutations_raise_exception(client, table): + """ + If an invalid mutation is passed, an exception should be raised + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + row_key = uuid.uuid4().hex.encode() + mutation = SetCell(family="nonexistent", qualifier=b"test-qualifier", new_value=b"") + bulk_mutation = RowMutationEntry(row_key, [mutation]) + + with pytest.raises(MutationsExceptionGroup) as exc: + await 
table.bulk_mutate_rows([bulk_mutation]) + assert len(exc.value.exceptions) == 1 + entry_error = exc.value.exceptions[0] + assert isinstance(entry_error, FailedMutationEntryError) + assert entry_error.index == 0 + assert entry_error.entry == bulk_mutation + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_mutations_batcher_context_manager(client, table, temp_rows): + """ + test batcher with context manager. Should flush on exit + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + row_key2, mutation2 = await _create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + async with table.mutations_batcher() as batcher: + await batcher.append(bulk_mutation) + await batcher.append(bulk_mutation2) + # ensure cell is updated + assert (await _retrieve_cell_value(table, row_key)) == new_value + assert len(batcher._staged_entries) == 0 + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_mutations_batcher_timer_flush(client, table, temp_rows): + """ + batch should occur after flush_interval seconds + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + async with table.mutations_batcher(flush_interval=flush_interval) as batcher: + await batcher.append(bulk_mutation) + await asyncio.sleep(0) + assert len(batcher._staged_entries) == 1 + await asyncio.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + # ensure cell is updated + assert (await _retrieve_cell_value(table, row_key)) == new_value + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_mutations_batcher_count_flush(client, table, temp_rows): + """ + batch should flush after flush_limit_mutation_count mutations + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await _create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + async with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + # should be noop; flush not scheduled + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # task should now be scheduled + assert len(batcher._flush_jobs) == 1 + await asyncio.gather(*batcher._flush_jobs) + assert len(batcher._staged_entries) == 0 + assert len(batcher._flush_jobs) == 
0 + # ensure cells were updated + assert (await _retrieve_cell_value(table, row_key)) == new_value + assert (await _retrieve_cell_value(table, row_key2)) == new_value2 + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_mutations_batcher_bytes_flush(client, table, temp_rows): + """ + batch should flush after flush_limit_bytes bytes + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await _create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 + + async with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # task should now be scheduled + assert len(batcher._flush_jobs) == 1 + assert len(batcher._staged_entries) == 0 + # let flush complete + await asyncio.gather(*batcher._flush_jobs) + # ensure cells were updated + assert (await _retrieve_cell_value(table, row_key)) == new_value + assert (await _retrieve_cell_value(table, row_key2)) == new_value2 + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_mutations_batcher_no_flush(client, table, temp_rows): + """ + test with no flush requirements met + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + start_value = b"unchanged" + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await _create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 + async with table.mutations_batcher( + flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 + ) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # flush not scheduled + assert len(batcher._flush_jobs) == 0 + await asyncio.sleep(0.01) + assert len(batcher._staged_entries) == 2 + assert len(batcher._flush_jobs) == 0 + # ensure cells were not updated + assert (await _retrieve_cell_value(table, row_key)) == start_value + assert (await _retrieve_cell_value(table, row_key2)) == start_value + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.parametrize( + "start,increment,expected", + [ + (0, 0, 0), + (0, 1, 1), + (0, -1, -1), + (1, 0, 1), + (0, -100, -100), + (0, 3000, 3000), + (10, 4, 14), + (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), + (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), + (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), + ], +) +@pytest.mark.asyncio +async def test_read_modify_write_row_increment( + client, table, temp_rows, start, increment, expected +): + """ + test 
read_modify_write_row + """ + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + + rule = IncrementRule(family, qualifier, increment) + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert int(result[0]) == expected + # ensure that reading from server gives same value + assert (await _retrieve_cell_value(table, row_key)) == result[0].value + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.parametrize( + "start,append,expected", + [ + (b"", b"", b""), + ("", "", b""), + (b"abc", b"123", b"abc123"), + (b"abc", "123", b"abc123"), + ("", b"1", b"1"), + (b"abc", "", b"abc"), + (b"hello", b"world", b"helloworld"), + ], +) +@pytest.mark.asyncio +async def test_read_modify_write_row_append( + client, table, temp_rows, start, append, expected +): + """ + test read_modify_write_row + """ + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + + rule = AppendValueRule(family, qualifier, append) + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert result[0].value == expected + # ensure that reading from server gives same value + assert (await _retrieve_cell_value(table, row_key)) == result[0].value + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_read_modify_write_row_chained(client, table, temp_rows): + """ + test read_modify_write_row with multiple rules + """ + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + start_amount = 1 + increment_amount = 10 + await temp_rows.add_row( + row_key, value=start_amount, family=family, qualifier=qualifier + ) + rule = [ + IncrementRule(family, qualifier, increment_amount), + AppendValueRule(family, qualifier, "hello"), + AppendValueRule(family, qualifier, "world"), + AppendValueRule(family, qualifier, "!"), + ] + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert result[0].family == family + assert result[0].qualifier == qualifier + # result should be a bytes number string for the IncrementRules, followed by the AppendValueRule values + assert ( + result[0].value + == (start_amount + increment_amount).to_bytes(8, "big", signed=True) + + b"helloworld!" 
+ ) + # ensure that reading from server gives same value + assert (await _retrieve_cell_value(table, row_key)) == result[0].value + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.parametrize( + "start_val,predicate_range,expected_result", + [ + (1, (0, 2), True), + (-1, (0, 2), False), + ], +) +@pytest.mark.asyncio +async def test_check_and_mutate( + client, table, temp_rows, start_val, predicate_range, expected_result +): + """ + test that check_and_mutate_row applies the right mutations and returns the right result + """ + from google.cloud.bigtable.data.mutations import SetCell + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + + await temp_rows.add_row( + row_key, value=start_val, family=family, qualifier=qualifier + ) + + false_mutation_value = b"false-mutation-value" + false_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value + ) + true_mutation_value = b"true-mutation-value" + true_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value + ) + predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) + result = await table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + assert result == expected_result + # ensure cell is updated + expected_value = true_mutation_value if expected_result else false_mutation_value + assert (await _retrieve_cell_value(table, row_key)) == expected_value + + +@pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", +) +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_check_and_mutate_empty_request(client, table): + """ + check_and_mutate with no true or false mutations should raise an error + """ + from google.api_core import exceptions + + with pytest.raises(exceptions.InvalidArgument) as e: + await table.check_and_mutate_row( + b"row_key", None, true_case_mutations=None, false_case_mutations=None + ) + assert "No mutations provided" in str(e.value) + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_read_rows_stream(table, temp_rows): + """ + Ensure that the read_rows_stream method works + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + + # full table scan + generator = await table.read_rows_stream({}) + first_row = await generator.__anext__() + second_row = await generator.__anext__() + assert first_row.row_key == b"row_key_1" + assert second_row.row_key == b"row_key_2" + with pytest.raises(StopAsyncIteration): + await generator.__anext__() + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_read_rows(table, temp_rows): + """ + Ensure that the read_rows method works + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + # full table scan + row_list = await table.read_rows({}) + assert len(row_list) == 2 + assert row_list[0].row_key == b"row_key_1" + assert row_list[1].row_key == b"row_key_2" + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) 
+@pytest.mark.asyncio +async def test_read_rows_sharded_simple(table, temp_rows): + """ + Test read rows sharded with two queries + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) + row_list = await table.read_rows_sharded([query1, query2]) + assert len(row_list) == 4 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"b" + assert row_list[3].row_key == b"d" + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_read_rows_sharded_from_sample(table, temp_rows): + """ + Test end-to-end sharding + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.read_rows_query import RowRange + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + + table_shard_keys = await table.sample_row_keys() + query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) + shard_queries = query.shard(table_shard_keys) + row_list = await table.read_rows_sharded(shard_queries) + assert len(row_list) == 3 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"d" + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_read_rows_sharded_filters_limits(table, temp_rows): + """ + Test read rows sharded with filters and limits + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + + label_filter1 = ApplyLabelFilter("first") + label_filter2 = ApplyLabelFilter("second") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) + row_list = await table.read_rows_sharded([query1, query2]) + assert len(row_list) == 3 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"b" + assert row_list[2].row_key == b"d" + assert row_list[0][0].labels == ["first"] + assert row_list[1][0].labels == ["second"] + assert row_list[2][0].labels == ["second"] + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_read_rows_range_query(table, temp_rows): + """ + Ensure that the read_rows method works + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data import RowRange + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # full table scan + query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) + row_list = await table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, 
maximum=5) +@pytest.mark.asyncio +async def test_read_rows_single_key_query(table, temp_rows): + """ + Ensure that the read_rows method works with specified query + """ + from google.cloud.bigtable.data import ReadRowsQuery + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # retrieve specific keys + query = ReadRowsQuery(row_keys=[b"a", b"c"]) + row_list = await table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_read_rows_with_filter(table, temp_rows): + """ + ensure filters are applied + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # retrieve keys with filter + expected_label = "test-label" + row_filter = ApplyLabelFilter(expected_label) + query = ReadRowsQuery(row_filter=row_filter) + row_list = await table.read_rows(query) + assert len(row_list) == 4 + for row in row_list: + assert row[0].labels == [expected_label] + + +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_read_rows_stream_close(table, temp_rows): + """ + Ensure that the read_rows_stream can be closed + """ + from google.cloud.bigtable.data import ReadRowsQuery + + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + # full table scan + query = ReadRowsQuery() + generator = await table.read_rows_stream(query) + # grab first row + first_row = await generator.__anext__() + assert first_row.row_key == b"row_key_1" + # close stream early + await generator.aclose() + with pytest.raises(StopAsyncIteration): + await generator.__anext__() + + +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_read_row(table, temp_rows): + """ + Test read_row (single row helper) + """ + from google.cloud.bigtable.data import Row + + await temp_rows.add_row(b"row_key_1", value=b"value") + row = await table.read_row(b"row_key_1") + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + + +@pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", +) +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_read_row_missing(table): + """ + Test read_row when row does not exist + """ + from google.api_core import exceptions + + row_key = "row_key_not_exist" + result = await table.read_row(row_key) + assert result is None + with pytest.raises(exceptions.InvalidArgument) as e: + await table.read_row("") + assert "Row keys must be non-empty" in str(e) + + +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_read_row_w_filter(table, temp_rows): + """ + Test read_row (single row helper) + """ + from google.cloud.bigtable.data import Row + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"row_key_1", value=b"value") + expected_label = "test-label" + label_filter = ApplyLabelFilter(expected_label) + row = await table.read_row(b"row_key_1", row_filter=label_filter) + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + 
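For context on the surface these system tests exercise, here is a minimal, hedged sketch of reading a single row with a label filter through the new async data client. The project, instance, and table ids are placeholders, and get_table() is assumed to behave like the fixtures used by this test module rather than confirmed by this diff.

import asyncio

from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
from google.cloud.bigtable.data.row_filters import ApplyLabelFilter

async def read_labeled_row():
    client = BigtableDataClientAsync(project="my-project")  # placeholder project id
    try:
        # get_table() is an assumption here; the tests obtain `table` from a fixture
        table = client.get_table("my-instance", "my-table")
        row = await table.read_row(b"row_key_1", row_filter=ApplyLabelFilter("test-label"))
        if row is not None:
            print(row.cells[0].value, row.cells[0].labels)
    finally:
        await client.close()

asyncio.run(read_labeled_row())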
assert row.cells[0].labels == [expected_label] + + +@pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", +) +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_row_exists(table, temp_rows): + from google.api_core import exceptions + + """Test row_exists with rows that exist and don't exist""" + assert await table.row_exists(b"row_key_1") is False + await temp_rows.add_row(b"row_key_1") + assert await table.row_exists(b"row_key_1") is True + assert await table.row_exists("row_key_1") is True + assert await table.row_exists(b"row_key_2") is False + assert await table.row_exists("row_key_2") is False + assert await table.row_exists("3") is False + await temp_rows.add_row(b"3") + assert await table.row_exists(b"3") is True + with pytest.raises(exceptions.InvalidArgument) as e: + await table.row_exists("") + assert "Row keys must be non-empty" in str(e) + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.parametrize( + "cell_value,filter_input,expect_match", + [ + (b"abc", b"abc", True), + (b"abc", "abc", True), + (b".", ".", True), + (".*", ".*", True), + (".*", b".*", True), + ("a", ".*", False), + (b".*", b".*", True), + (r"\a", r"\a", True), + (b"\xe2\x98\x83", "ā˜ƒ", True), + ("ā˜ƒ", "ā˜ƒ", True), + (r"\Cā˜ƒ", r"\Cā˜ƒ", True), + (1, 1, True), + (2, 1, False), + (68, 68, True), + ("D", 68, False), + (68, "D", False), + (-1, -1, True), + (2852126720, 2852126720, True), + (-1431655766, -1431655766, True), + (-1431655766, -1, False), + ], +) +@pytest.mark.asyncio +async def test_literal_value_filter( + table, temp_rows, cell_value, filter_input, expect_match +): + """ + Literal value filter does complex escaping on re2 strings. + Make sure inputs are properly interpreted by the server + """ + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + from google.cloud.bigtable.data import ReadRowsQuery + + f = LiteralValueFilter(filter_input) + await temp_rows.add_row(b"row_key_1", value=cell_value) + query = ReadRowsQuery(row_filter=f) + row_list = await table.read_rows(query) + assert len(row_list) == bool( + expect_match + ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" diff --git a/tests/system/v2_client/__init__.py b/tests/system/v2_client/__init__.py new file mode 100644 index 000000000..4de65971c --- /dev/null +++ b/tests/system/v2_client/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
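One note on test_literal_value_filter above: LiteralValueFilter accepts bytes, str, and int inputs and escapes them for server-side RE2 matching. A short hedged sketch of typical usage follows; the `table` object is assumed to come from the async data client, as in these tests.

from google.cloud.bigtable.data import ReadRowsQuery
from google.cloud.bigtable.data.row_filters import LiteralValueFilter

async def rows_with_exact_value(table, value):
    # value may be bytes, str, or int; per the parametrized cases above,
    # special regex characters in the value are escaped before matching
    query = ReadRowsQuery(row_filter=LiteralValueFilter(value))
    return await table.read_rows(query)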
+# diff --git a/tests/system/_helpers.py b/tests/system/v2_client/_helpers.py similarity index 100% rename from tests/system/_helpers.py rename to tests/system/v2_client/_helpers.py diff --git a/tests/system/v2_client/conftest.py b/tests/system/v2_client/conftest.py new file mode 100644 index 000000000..f39fcba88 --- /dev/null +++ b/tests/system/v2_client/conftest.py @@ -0,0 +1,209 @@ +# Copyright 2011 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +from test_utils.system import unique_resource_id + +from google.cloud.bigtable.client import Client +from google.cloud.environment_vars import BIGTABLE_EMULATOR + +from . import _helpers + + +@pytest.fixture(scope="session") +def in_emulator(): + return os.getenv(BIGTABLE_EMULATOR) is not None + + +@pytest.fixture(scope="session") +def kms_key_name(): + return os.getenv("KMS_KEY_NAME") + + +@pytest.fixture(scope="session") +def with_kms_key_name(kms_key_name): + if kms_key_name is None: + pytest.skip("Test requires KMS_KEY_NAME environment variable") + return kms_key_name + + +@pytest.fixture(scope="session") +def skip_on_emulator(in_emulator): + if in_emulator: + pytest.skip("Emulator does not support this feature") + + +@pytest.fixture(scope="session") +def unique_suffix(): + return unique_resource_id("-") + + +@pytest.fixture(scope="session") +def location_id(): + return "us-central1-c" + + +@pytest.fixture(scope="session") +def serve_nodes(): + return 3 + + +@pytest.fixture(scope="session") +def label_key(): + return "python-system" + + +@pytest.fixture(scope="session") +def instance_labels(label_key): + return {label_key: _helpers.label_stamp()} + + +@pytest.fixture(scope="session") +def admin_client(): + return Client(admin=True) + + +@pytest.fixture(scope="session") +def service_account(admin_client): + from google.oauth2.service_account import Credentials + + if not isinstance(admin_client._credentials, Credentials): + pytest.skip("These tests require a service account credential") + return admin_client._credentials + + +@pytest.fixture(scope="session") +def admin_instance_id(unique_suffix): + return f"g-c-p{unique_suffix}" + + +@pytest.fixture(scope="session") +def admin_cluster_id(admin_instance_id): + return f"{admin_instance_id}-cluster" + + +@pytest.fixture(scope="session") +def admin_instance(admin_client, admin_instance_id, instance_labels): + return admin_client.instance(admin_instance_id, labels=instance_labels) + + +@pytest.fixture(scope="session") +def admin_cluster(admin_instance, admin_cluster_id, location_id, serve_nodes): + return admin_instance.cluster( + admin_cluster_id, + location_id=location_id, + serve_nodes=serve_nodes, + ) + + +@pytest.fixture(scope="session") +def admin_cluster_with_autoscaling( + admin_instance, + admin_cluster_id, + location_id, + min_serve_nodes, + max_serve_nodes, + cpu_utilization_percent, +): + return admin_instance.cluster( + admin_cluster_id, + location_id=location_id, + min_serve_nodes=min_serve_nodes, + max_serve_nodes=max_serve_nodes, + 
cpu_utilization_percent=cpu_utilization_percent, + ) + + +@pytest.fixture(scope="session") +def admin_instance_populated(admin_instance, admin_cluster, in_emulator): + # Emulator does not support instance admin operations (create / delete). + # See: https://cloud.google.com/bigtable/docs/emulator + if not in_emulator: + operation = admin_instance.create(clusters=[admin_cluster]) + operation.result(timeout=240) + + yield admin_instance + + if not in_emulator: + _helpers.retry_429(admin_instance.delete)() + + +@pytest.fixture(scope="session") +def data_client(): + return Client(admin=False) + + +@pytest.fixture(scope="session") +def data_instance_id(unique_suffix): + return f"g-c-p-d{unique_suffix}" + + +@pytest.fixture(scope="session") +def data_cluster_id(data_instance_id): + return f"{data_instance_id}-cluster" + + +@pytest.fixture(scope="session") +def data_instance_populated( + admin_client, + data_instance_id, + instance_labels, + data_cluster_id, + location_id, + serve_nodes, + in_emulator, +): + instance = admin_client.instance(data_instance_id, labels=instance_labels) + # Emulator does not support instance admin operations (create / delete). + # See: https://cloud.google.com/bigtable/docs/emulator + if not in_emulator: + cluster = instance.cluster( + data_cluster_id, + location_id=location_id, + serve_nodes=serve_nodes, + ) + operation = instance.create(clusters=[cluster]) + operation.result(timeout=240) + + yield instance + + if not in_emulator: + _helpers.retry_429(instance.delete)() + + +@pytest.fixture(scope="function") +def instances_to_delete(): + instances_to_delete = [] + + yield instances_to_delete + + for instance in instances_to_delete: + _helpers.retry_429(instance.delete)() + + +@pytest.fixture(scope="session") +def min_serve_nodes(in_emulator): + return 1 + + +@pytest.fixture(scope="session") +def max_serve_nodes(in_emulator): + return 8 + + +@pytest.fixture(scope="session") +def cpu_utilization_percent(in_emulator): + return 10 diff --git a/tests/system/test_data_api.py b/tests/system/v2_client/test_data_api.py similarity index 100% rename from tests/system/test_data_api.py rename to tests/system/v2_client/test_data_api.py diff --git a/tests/system/test_instance_admin.py b/tests/system/v2_client/test_instance_admin.py similarity index 100% rename from tests/system/test_instance_admin.py rename to tests/system/v2_client/test_instance_admin.py diff --git a/tests/system/test_table_admin.py b/tests/system/v2_client/test_table_admin.py similarity index 100% rename from tests/system/test_table_admin.py rename to tests/system/v2_client/test_table_admin.py diff --git a/tests/unit/data/__init__.py b/tests/unit/data/__init__.py new file mode 100644 index 000000000..89a37dc92 --- /dev/null +++ b/tests/unit/data/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py new file mode 100644 index 000000000..e03028c45 --- /dev/null +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -0,0 +1,378 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from google.cloud.bigtable_v2.types import MutateRowsResponse +from google.rpc import status_pb2 +import google.api_core.exceptions as core_exceptions + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # type: ignore +except ImportError: # pragma: NO COVER + import mock # type: ignore + from mock import AsyncMock # type: ignore + + +def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + +class TestMutateRowsOperation: + def _target_class(self): + from google.cloud.bigtable.data._async._mutate_rows import ( + _MutateRowsOperationAsync, + ) + + return _MutateRowsOperationAsync + + def _make_one(self, *args, **kwargs): + if not args: + kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) + kwargs["table"] = kwargs.pop("table", AsyncMock()) + kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) + kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) + kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) + kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) + return self._target_class()(*args, **kwargs) + + async def _mock_stream(self, mutation_list, error_dict): + for idx, entry in enumerate(mutation_list): + code = error_dict.get(idx, 0) + yield MutateRowsResponse( + entries=[ + MutateRowsResponse.Entry( + index=idx, status=status_pb2.Status(code=code) + ) + ] + ) + + def _make_mock_gapic(self, mutation_list, error_dict=None): + mock_fn = AsyncMock() + if error_dict is None: + error_dict = {} + mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( + mutation_list, error_dict + ) + return mock_fn + + def test_ctor(self): + """ + test that constructor sets all the attributes correctly + """ + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import Aborted + + client = mock.Mock() + table = mock.Mock() + entries = [_make_mutation(), _make_mutation()] + operation_timeout = 0.05 + attempt_timeout = 0.01 + retryable_exceptions = () + instance = self._make_one( + client, + table, + entries, + operation_timeout, + attempt_timeout, + retryable_exceptions, + ) + # running gapic_fn should trigger a client call + assert client.mutate_rows.call_count == 0 + instance._gapic_fn() + assert client.mutate_rows.call_count == 1 + # gapic_fn should call with table details + inner_kwargs = client.mutate_rows.call_args[1] + assert 
len(inner_kwargs) == 4 + assert inner_kwargs["table_name"] == table.table_name + assert inner_kwargs["app_profile_id"] == table.app_profile_id + assert inner_kwargs["retry"] is None + metadata = inner_kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert str(table.table_name) in metadata[0][1] + assert str(table.app_profile_id) in metadata[0][1] + # entries should be passed down + entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries] + assert instance.mutations == entries_w_pb + # timeout_gen should generate per-attempt timeout + assert next(instance.timeout_generator) == attempt_timeout + # ensure predicate is set + assert instance.is_retryable is not None + assert instance.is_retryable(DeadlineExceeded("")) is False + assert instance.is_retryable(Aborted("")) is False + assert instance.is_retryable(_MutateRowsIncomplete("")) is True + assert instance.is_retryable(RuntimeError("")) is False + assert instance.remaining_indices == list(range(len(entries))) + assert instance.errors == {} + + def test_ctor_too_many_entries(self): + """ + should raise an error if an operation is created with more than 100,000 entries + """ + from google.cloud.bigtable.data._async._mutate_rows import ( + _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, + ) + + assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100_000 + + client = mock.Mock() + table = mock.Mock() + entries = [_make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT + operation_timeout = 0.05 + attempt_timeout = 0.01 + # no errors if at limit + self._make_one(client, table, entries, operation_timeout, attempt_timeout) + # raise error after crossing + with pytest.raises(ValueError) as e: + self._make_one( + client, + table, + entries + [_make_mutation()], + operation_timeout, + attempt_timeout, + ) + assert "mutate_rows requests can contain at most 100000 mutations" in str( + e.value + ) + assert "Found 100001" in str(e.value) + + @pytest.mark.asyncio + async def test_mutate_rows_operation(self): + """ + Test successful case of mutate_rows_operation + """ + client = mock.Mock() + table = mock.Mock() + entries = [_make_mutation(), _make_mutation()] + operation_timeout = 0.05 + cls = self._target_class() + with mock.patch( + f"{cls.__module__}.{cls.__name__}._run_attempt", AsyncMock() + ) as attempt_mock: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + await instance.start() + assert attempt_mock.call_count == 1 + + @pytest.mark.parametrize( + "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] + ) + @pytest.mark.asyncio + async def test_mutate_rows_attempt_exception(self, exc_type): + """ + exceptions raised from attempt should be raised in MutationsExceptionGroup + """ + client = AsyncMock() + table = mock.Mock() + entries = [_make_mutation(), _make_mutation()] + operation_timeout = 0.05 + expected_exception = exc_type("test") + client.mutate_rows.side_effect = expected_exception + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + await instance._run_attempt() + except Exception as e: + found_exc = e + assert client.mutate_rows.call_count == 1 + assert type(found_exc) is exc_type + assert found_exc == expected_exception + assert len(instance.errors) == 2 + assert len(instance.remaining_indices) == 0 + + @pytest.mark.parametrize( + "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] + ) + @pytest.mark.asyncio + async def 
test_mutate_rows_exception(self, exc_type): + """ + exceptions raised from retryable should be raised in MutationsExceptionGroup + """ + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + client = mock.Mock() + table = mock.Mock() + entries = [_make_mutation(), _make_mutation()] + operation_timeout = 0.05 + expected_cause = exc_type("abort") + with mock.patch.object( + self._target_class(), + "_run_attempt", + AsyncMock(), + ) as attempt_mock: + attempt_mock.side_effect = expected_cause + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + await instance.start() + except MutationsExceptionGroup as e: + found_exc = e + assert attempt_mock.call_count == 1 + assert len(found_exc.exceptions) == 2 + assert isinstance(found_exc.exceptions[0], FailedMutationEntryError) + assert isinstance(found_exc.exceptions[1], FailedMutationEntryError) + assert found_exc.exceptions[0].__cause__ == expected_cause + assert found_exc.exceptions[1].__cause__ == expected_cause + + @pytest.mark.parametrize( + "exc_type", + [core_exceptions.DeadlineExceeded, RuntimeError], + ) + @pytest.mark.asyncio + async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): + """ + If an exception fails but eventually passes, it should not raise an exception + """ + from google.cloud.bigtable.data._async._mutate_rows import ( + _MutateRowsOperationAsync, + ) + + client = mock.Mock() + table = mock.Mock() + entries = [_make_mutation()] + operation_timeout = 1 + expected_cause = exc_type("retry") + num_retries = 2 + with mock.patch.object( + _MutateRowsOperationAsync, + "_run_attempt", + AsyncMock(), + ) as attempt_mock: + attempt_mock.side_effect = [expected_cause] * num_retries + [None] + instance = self._make_one( + client, + table, + entries, + operation_timeout, + operation_timeout, + retryable_exceptions=(exc_type,), + ) + await instance.start() + assert attempt_mock.call_count == num_retries + 1 + + @pytest.mark.asyncio + async def test_mutate_rows_incomplete_ignored(self): + """ + MutateRowsIncomplete exceptions should not be added to error list + """ + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.api_core.exceptions import DeadlineExceeded + + client = mock.Mock() + table = mock.Mock() + entries = [_make_mutation()] + operation_timeout = 0.05 + with mock.patch.object( + self._target_class(), + "_run_attempt", + AsyncMock(), + ) as attempt_mock: + attempt_mock.side_effect = _MutateRowsIncomplete("ignored") + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + await instance.start() + except MutationsExceptionGroup as e: + found_exc = e + assert attempt_mock.call_count > 0 + assert len(found_exc.exceptions) == 1 + assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) + + @pytest.mark.asyncio + async def test_run_attempt_single_entry_success(self): + """Test mutating a single entry""" + mutation = _make_mutation() + expected_timeout = 1.3 + mock_gapic_fn = self._make_mock_gapic({0: mutation}) + instance = self._make_one( + mutation_entries=[mutation], + attempt_timeout=expected_timeout, + ) + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + await instance._run_attempt() + assert len(instance.remaining_indices) == 0 + 
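The exception-group behaviour exercised in these unit tests is what callers of bulk_mutate_rows observe. A hedged sketch of how application code might unpack it (the helper name is illustrative; the attributes mirror those asserted in the tests):

from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup

async def apply_entries(table, entries):
    try:
        await table.bulk_mutate_rows(entries)
    except MutationsExceptionGroup as group:
        # each sub-exception records the failed entry, its index, and the cause
        for sub_exc in group.exceptions:
            if isinstance(sub_exc, FailedMutationEntryError):
                print(f"entry {sub_exc.index} failed: {sub_exc.entry!r}", sub_exc.__cause__)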
assert mock_gapic_fn.call_count == 1 + _, kwargs = mock_gapic_fn.call_args + assert kwargs["timeout"] == expected_timeout + assert kwargs["entries"] == [mutation._to_pb()] + + @pytest.mark.asyncio + async def test_run_attempt_empty_request(self): + """Calling with no mutations should result in no API calls""" + mock_gapic_fn = self._make_mock_gapic([]) + instance = self._make_one( + mutation_entries=[], + ) + await instance._run_attempt() + assert mock_gapic_fn.call_count == 0 + + @pytest.mark.asyncio + async def test_run_attempt_partial_success_retryable(self): + """Some entries succeed, but one fails. Should report the proper index, and raise incomplete exception""" + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + + success_mutation = _make_mutation() + success_mutation_2 = _make_mutation() + failure_mutation = _make_mutation() + mutations = [success_mutation, failure_mutation, success_mutation_2] + mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) + instance = self._make_one( + mutation_entries=mutations, + ) + instance.is_retryable = lambda x: True + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + with pytest.raises(_MutateRowsIncomplete): + await instance._run_attempt() + assert instance.remaining_indices == [1] + assert 0 not in instance.errors + assert len(instance.errors[1]) == 1 + assert instance.errors[1][0].grpc_status_code == 300 + assert 2 not in instance.errors + + @pytest.mark.asyncio + async def test_run_attempt_partial_success_non_retryable(self): + """Some entries succeed, but one fails. Exception marked as non-retryable. Do not raise incomplete error""" + success_mutation = _make_mutation() + success_mutation_2 = _make_mutation() + failure_mutation = _make_mutation() + mutations = [success_mutation, failure_mutation, success_mutation_2] + mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) + instance = self._make_one( + mutation_entries=mutations, + ) + instance.is_retryable = lambda x: False + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + await instance._run_attempt() + assert instance.remaining_indices == [] + assert 0 not in instance.errors + assert len(instance.errors[1]) == 1 + assert instance.errors[1][0].grpc_status_code == 300 + assert 2 not in instance.errors diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py new file mode 100644 index 000000000..4e7797c6d --- /dev/null +++ b/tests/unit/data/_async/test__read_rows.py @@ -0,0 +1,391 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
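The resumption tests in this new file revolve around one idea: after a retried attempt, the request's RowSet is trimmed so that only rows strictly after the last yielded key are requested again. A sketch of the expected before/after shape, using values taken from the parametrized cases below:

from google.cloud.bigtable_v2.types import RowRange as RowRangePB
from google.cloud.bigtable_v2.types import RowSet as RowSetPB

# last yielded key was b"b"; remaining work before revision...
original = RowSetPB(
    row_keys=[b"a", b"b", b"c"],
    row_ranges=[RowRangePB(start_key_closed=b"a", end_key_open=b"d")],
)
# ...and the revised request the tests below expect: only keys/ranges after b"b"
revised = RowSetPB(
    row_keys=[b"c"],
    row_ranges=[RowRangePB(start_key_open=b"b", end_key_open=b"d")],
)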
+ +import pytest + +from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # type: ignore +except ImportError: # pragma: NO COVER + import mock # type: ignore + from mock import AsyncMock # type: ignore # noqa F401 + +TEST_FAMILY = "family_name" +TEST_QUALIFIER = b"qualifier" +TEST_TIMESTAMP = 123456789 +TEST_LABELS = ["label1", "label2"] + + +class TestReadRowsOperation: + """ + Tests helper functions in the ReadRowsOperation class + in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt + is tested in test_read_rows_acceptance test_client_read_rows, and conformance tests + """ + + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + + return _ReadRowsOperationAsync + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor(self): + from google.cloud.bigtable.data import ReadRowsQuery + + row_limit = 91 + query = ReadRowsQuery(limit=row_limit) + client = mock.Mock() + client.read_rows = mock.Mock() + client.read_rows.return_value = None + table = mock.Mock() + table._client = client + table.table_name = "test_table" + table.app_profile_id = "test_profile" + expected_operation_timeout = 42 + expected_request_timeout = 44 + time_gen_mock = mock.Mock() + with mock.patch( + "google.cloud.bigtable.data._async._read_rows._attempt_timeout_generator", + time_gen_mock, + ): + instance = self._make_one( + query, + table, + operation_timeout=expected_operation_timeout, + attempt_timeout=expected_request_timeout, + ) + assert time_gen_mock.call_count == 1 + time_gen_mock.assert_called_once_with( + expected_request_timeout, expected_operation_timeout + ) + assert instance._last_yielded_row_key is None + assert instance._remaining_count == row_limit + assert instance.operation_timeout == expected_operation_timeout + assert client.read_rows.call_count == 0 + assert instance._metadata == [ + ( + "x-goog-request-params", + "table_name=test_table&app_profile_id=test_profile", + ) + ] + assert instance.request.table_name == table.table_name + assert instance.request.app_profile_id == table.app_profile_id + assert instance.request.rows_limit == row_limit + + @pytest.mark.parametrize( + "in_keys,last_key,expected", + [ + (["b", "c", "d"], "a", ["b", "c", "d"]), + (["a", "b", "c"], "b", ["c"]), + (["a", "b", "c"], "c", []), + (["a", "b", "c"], "d", []), + (["d", "c", "b", "a"], "b", ["d", "c"]), + ], + ) + def test_revise_request_rowset_keys(self, in_keys, last_key, expected): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + in_keys = [key.encode("utf-8") for key in in_keys] + expected = [key.encode("utf-8") for key in expected] + last_key = last_key.encode("utf-8") + + sample_range = RowRangePB(start_key_open=last_key) + row_set = RowSetPB(row_keys=in_keys, row_ranges=[sample_range]) + revised = self._get_target_class()._revise_request_rowset(row_set, last_key) + assert revised.row_keys == expected + assert revised.row_ranges == [sample_range] + + @pytest.mark.parametrize( + "in_ranges,last_key,expected", + [ + ( + [{"start_key_open": "b", "end_key_closed": "d"}], + "a", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_closed": "b", "end_key_closed": "d"}], + "a", + [{"start_key_closed": "b", 
"end_key_closed": "d"}], + ), + ( + [{"start_key_open": "a", "end_key_closed": "d"}], + "b", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_closed": "a", "end_key_open": "d"}], + "b", + [{"start_key_open": "b", "end_key_open": "d"}], + ), + ( + [{"start_key_closed": "b", "end_key_closed": "d"}], + "b", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ([{"start_key_closed": "b", "end_key_closed": "d"}], "d", []), + ([{"start_key_closed": "b", "end_key_open": "d"}], "d", []), + ([{"start_key_closed": "b", "end_key_closed": "d"}], "e", []), + ([{"start_key_closed": "b"}], "z", [{"start_key_open": "z"}]), + ([{"start_key_closed": "b"}], "a", [{"start_key_closed": "b"}]), + ( + [{"end_key_closed": "z"}], + "a", + [{"start_key_open": "a", "end_key_closed": "z"}], + ), + ( + [{"end_key_open": "z"}], + "a", + [{"start_key_open": "a", "end_key_open": "z"}], + ), + ], + ) + def test_revise_request_rowset_ranges(self, in_ranges, last_key, expected): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + # convert to protobuf + next_key = (last_key + "a").encode("utf-8") + last_key = last_key.encode("utf-8") + in_ranges = [ + RowRangePB(**{k: v.encode("utf-8") for k, v in r.items()}) + for r in in_ranges + ] + expected = [ + RowRangePB(**{k: v.encode("utf-8") for k, v in r.items()}) for r in expected + ] + + row_set = RowSetPB(row_ranges=in_ranges, row_keys=[next_key]) + revised = self._get_target_class()._revise_request_rowset(row_set, last_key) + assert revised.row_keys == [next_key] + assert revised.row_ranges == expected + + @pytest.mark.parametrize("last_key", ["a", "b", "c"]) + def test_revise_request_full_table(self, last_key): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + # convert to protobuf + last_key = last_key.encode("utf-8") + row_set = RowSetPB() + for selected_set in [row_set, None]: + revised = self._get_target_class()._revise_request_rowset( + selected_set, last_key + ) + assert revised.row_keys == [] + assert len(revised.row_ranges) == 1 + assert revised.row_ranges[0] == RowRangePB(start_key_open=last_key) + + def test_revise_to_empty_rowset(self): + """revising to an empty rowset should raise error""" + from google.cloud.bigtable.data.exceptions import _RowSetComplete + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + row_keys = [b"a", b"b", b"c"] + row_range = RowRangePB(end_key_open=b"c") + row_set = RowSetPB(row_keys=row_keys, row_ranges=[row_range]) + with pytest.raises(_RowSetComplete): + self._get_target_class()._revise_request_rowset(row_set, b"d") + + @pytest.mark.parametrize( + "start_limit,emit_num,expected_limit", + [ + (10, 0, 10), + (10, 1, 9), + (10, 10, 0), + (None, 10, None), + (None, 0, None), + (4, 2, 2), + ], + ) + @pytest.mark.asyncio + async def test_revise_limit(self, start_limit, emit_num, expected_limit): + """ + revise_limit should revise the request's limit field + - if limit is 0 (unlimited), it should never be revised + - if start_limit-emit_num == 0, the request should end early + - if the number emitted exceeds the new limit, an exception should + should be raised (tested in test_revise_limit_over_limit) + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable_v2.types import ReadRowsResponse + + async def awaitable_stream(): + 
async def mock_stream(): + for i in range(emit_num): + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + ] + ) + + return mock_stream() + + query = ReadRowsQuery(limit=start_limit) + table = mock.Mock() + table.table_name = "table_name" + table.app_profile_id = "app_profile_id" + instance = self._make_one(query, table, 10, 10) + assert instance._remaining_count == start_limit + # read emit_num rows + async for val in instance.chunk_stream(awaitable_stream()): + pass + assert instance._remaining_count == expected_limit + + @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) + @pytest.mark.asyncio + async def test_revise_limit_over_limit(self, start_limit, emit_num): + """ + Should raise runtime error if we get in state where emit_num > start_num + (unless start_num == 0, which represents unlimited) + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable_v2.types import ReadRowsResponse + from google.cloud.bigtable.data.exceptions import InvalidChunk + + async def awaitable_stream(): + async def mock_stream(): + for i in range(emit_num): + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + ] + ) + + return mock_stream() + + query = ReadRowsQuery(limit=start_limit) + table = mock.Mock() + table.table_name = "table_name" + table.app_profile_id = "app_profile_id" + instance = self._make_one(query, table, 10, 10) + assert instance._remaining_count == start_limit + with pytest.raises(InvalidChunk) as e: + # read emit_num rows + async for val in instance.chunk_stream(awaitable_stream()): + pass + assert "emit count exceeds row limit" in str(e.value) + + @pytest.mark.asyncio + async def test_aclose(self): + """ + should be able to close a stream safely with aclose. 
+ Closed generators should raise StopAsyncIteration on next yield + """ + + async def mock_stream(): + while True: + yield 1 + + with mock.patch.object( + _ReadRowsOperationAsync, "_read_rows_attempt" + ) as mock_attempt: + instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) + wrapped_gen = mock_stream() + mock_attempt.return_value = wrapped_gen + gen = instance.start_operation() + # read one row + await gen.__anext__() + await gen.aclose() + with pytest.raises(StopAsyncIteration): + await gen.__anext__() + # try calling a second time + await gen.aclose() + # ensure close was propagated to wrapped generator + with pytest.raises(StopAsyncIteration): + await wrapped_gen.__anext__() + + @pytest.mark.asyncio + async def test_retryable_ignore_repeated_rows(self): + """ + Duplicate rows should cause an invalid chunk error + """ + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + from google.cloud.bigtable.data.exceptions import InvalidChunk + from google.cloud.bigtable_v2.types import ReadRowsResponse + + row_key = b"duplicate" + + async def mock_awaitable_stream(): + async def mock_stream(): + while True: + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) + ] + ) + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) + ] + ) + + return mock_stream() + + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + stream = _ReadRowsOperationAsync.chunk_stream(instance, mock_awaitable_stream()) + await stream.__anext__() + with pytest.raises(InvalidChunk) as exc: + await stream.__anext__() + assert "row keys should be strictly increasing" in str(exc.value) + + +class MockStream(_ReadRowsOperationAsync): + """ + Mock a _ReadRowsOperationAsync stream for testing + """ + + def __init__(self, items=None, errors=None, operation_timeout=None): + self.transient_errors = errors + self.operation_timeout = operation_timeout + self.next_idx = 0 + if items is None: + items = list(range(10)) + self.items = items + + def __aiter__(self): + return self + + async def __anext__(self): + if self.next_idx >= len(self.items): + raise StopAsyncIteration + item = self.items[self.next_idx] + self.next_idx += 1 + if isinstance(item, Exception): + raise item + return item + + async def aclose(self): + pass diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py new file mode 100644 index 000000000..a0019947d --- /dev/null +++ b/tests/unit/data/_async/test_client.py @@ -0,0 +1,2957 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
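The client unit tests that follow default to emulator mode by patching BIGTABLE_EMULATOR_HOST, which sidesteps credentials entirely. The same pattern can be used for local development; a sketch, assuming an emulator is already listening on localhost:8086 (adjust the port to match your setup):

import os

# must be set before the client is constructed
os.environ["BIGTABLE_EMULATOR_HOST"] = "localhost:8086"

from google.cloud.bigtable.data._async.client import BigtableDataClientAsync

async def make_emulator_client():
    # project id is arbitrary when targeting the emulator
    return BigtableDataClientAsync(project="emulator-project")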
+from __future__ import annotations + +import grpc +import asyncio +import re +import sys + +import pytest + +from google.cloud.bigtable.data import mutations +from google.auth.credentials import AnonymousCredentials +from google.cloud.bigtable_v2.types import ReadRowsResponse +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.api_core import exceptions as core_exceptions +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data import TABLE_DEFAULT + +from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule +from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # type: ignore +except ImportError: # pragma: NO COVER + import mock # type: ignore + from mock import AsyncMock # type: ignore + +VENEER_HEADER_REGEX = re.compile( + r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-data-async gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+" +) + + +def _make_client(*args, use_emulator=True, **kwargs): + import os + from google.cloud.bigtable.data._async.client import BigtableDataClientAsync + + env_mask = {} + # by default, use emulator mode to avoid auth issues in CI + # emulator mode must be disabled by tests that check channel pooling/refresh background tasks + if use_emulator: + env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" + else: + # set some default values + kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) + kwargs["project"] = kwargs.get("project", "project-id") + with mock.patch.dict(os.environ, env_mask): + return BigtableDataClientAsync(*args, **kwargs) + + +class TestBigtableDataClientAsync: + def _get_target_class(self): + from google.cloud.bigtable.data._async.client import BigtableDataClientAsync + + return BigtableDataClientAsync + + def _make_one(self, *args, **kwargs): + return _make_client(*args, **kwargs) + + @pytest.mark.asyncio + async def test_ctor(self): + expected_project = "project-id" + expected_pool_size = 11 + expected_credentials = AnonymousCredentials() + client = self._make_one( + project="project-id", + pool_size=expected_pool_size, + credentials=expected_credentials, + use_emulator=False, + ) + await asyncio.sleep(0) + assert client.project == expected_project + assert len(client.transport._grpc_channel._pool) == expected_pool_size + assert not client._active_instances + assert len(client._channel_refresh_tasks) == expected_pool_size + assert client.transport._credentials == expected_credentials + await client.close() + + @pytest.mark.asyncio + async def test_ctor_super_inits(self): + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) + from google.cloud.client import ClientWithProject + from google.api_core import client_options as client_options_lib + + project = "project-id" + pool_size = 11 + credentials = AnonymousCredentials() + client_options = {"api_endpoint": "foo.bar:1234"} + options_parsed = client_options_lib.from_dict(client_options) + transport_str = f"pooled_grpc_asyncio_{pool_size}" + with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: + bigtable_client_init.return_value = None + with mock.patch.object( + ClientWithProject, "__init__" + ) as client_project_init: + client_project_init.return_value = None + try: + 
self._make_one( + project=project, + pool_size=pool_size, + credentials=credentials, + client_options=options_parsed, + use_emulator=False, + ) + except AttributeError: + pass + # test gapic superclass init was called + assert bigtable_client_init.call_count == 1 + kwargs = bigtable_client_init.call_args[1] + assert kwargs["transport"] == transport_str + assert kwargs["credentials"] == credentials + assert kwargs["client_options"] == options_parsed + # test mixin superclass init was called + assert client_project_init.call_count == 1 + kwargs = client_project_init.call_args[1] + assert kwargs["project"] == project + assert kwargs["credentials"] == credentials + assert kwargs["client_options"] == options_parsed + + @pytest.mark.asyncio + async def test_ctor_dict_options(self): + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) + from google.api_core.client_options import ClientOptions + + client_options = {"api_endpoint": "foo.bar:1234"} + with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: + try: + self._make_one(client_options=client_options) + except TypeError: + pass + bigtable_client_init.assert_called_once() + kwargs = bigtable_client_init.call_args[1] + called_options = kwargs["client_options"] + assert called_options.api_endpoint == "foo.bar:1234" + assert isinstance(called_options, ClientOptions) + with mock.patch.object( + self._get_target_class(), "_start_background_channel_refresh" + ) as start_background_refresh: + client = self._make_one(client_options=client_options, use_emulator=False) + start_background_refresh.assert_called_once() + await client.close() + + @pytest.mark.asyncio + async def test_veneer_grpc_headers(self): + # client_info should be populated with headers to + # detect as a veneer client + patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") + with patch as gapic_mock: + client = self._make_one(project="project-id") + wrapped_call_list = gapic_mock.call_args_list + assert len(wrapped_call_list) > 0 + # each wrapped call should have veneer headers + for call in wrapped_call_list: + client_info = call.kwargs["client_info"] + assert client_info is not None, f"{call} has no client_info" + wrapped_user_agent_sorted = " ".join( + sorted(client_info.to_user_agent().split(" ")) + ) + assert VENEER_HEADER_REGEX.match( + wrapped_user_agent_sorted + ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" + await client.close() + + @pytest.mark.asyncio + async def test_channel_pool_creation(self): + pool_size = 14 + with mock.patch( + "google.api_core.grpc_helpers_async.create_channel" + ) as create_channel: + create_channel.return_value = AsyncMock() + client = self._make_one(project="project-id", pool_size=pool_size) + assert create_channel.call_count == pool_size + await client.close() + # channels should be unique + client = self._make_one(project="project-id", pool_size=pool_size) + pool_list = list(client.transport._grpc_channel._pool) + pool_set = set(client.transport._grpc_channel._pool) + assert len(pool_list) == len(pool_set) + await client.close() + + @pytest.mark.asyncio + async def test_channel_pool_rotation(self): + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel, + ) + + pool_size = 7 + + with mock.patch.object(PooledChannel, "next_channel") as next_channel: + client = self._make_one(project="project-id", pool_size=pool_size) + assert len(client.transport._grpc_channel._pool) == pool_size + 
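+ # clear any next_channel calls made during client setup so the rotation asserts below start counting from zero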
next_channel.reset_mock() + with mock.patch.object( + type(client.transport._grpc_channel._pool[0]), "unary_unary" + ) as unary_unary: + # calling an rpc `pool_size` times should use a different channel each time + channel_next = None + for i in range(pool_size): + channel_last = channel_next + channel_next = client.transport.grpc_channel._pool[i] + assert channel_last != channel_next + next_channel.return_value = channel_next + client.transport.ping_and_warm() + assert next_channel.call_count == i + 1 + unary_unary.assert_called_once() + unary_unary.reset_mock() + await client.close() + + @pytest.mark.asyncio + async def test_channel_pool_replace(self): + with mock.patch.object(asyncio, "sleep"): + pool_size = 7 + client = self._make_one(project="project-id", pool_size=pool_size) + for replace_idx in range(pool_size): + start_pool = [ + channel for channel in client.transport._grpc_channel._pool + ] + grace_period = 9 + with mock.patch.object( + type(client.transport._grpc_channel._pool[0]), "close" + ) as close: + new_channel = grpc.aio.insecure_channel("localhost:8080") + await client.transport.replace_channel( + replace_idx, grace=grace_period, new_channel=new_channel + ) + close.assert_called_once_with(grace=grace_period) + close.assert_awaited_once() + assert client.transport._grpc_channel._pool[replace_idx] == new_channel + for i in range(pool_size): + if i != replace_idx: + assert client.transport._grpc_channel._pool[i] == start_pool[i] + else: + assert client.transport._grpc_channel._pool[i] != start_pool[i] + await client.close() + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test__start_background_channel_refresh_sync(self): + # should raise RuntimeError if called in a sync context + client = self._make_one(project="project-id", use_emulator=False) + with pytest.raises(RuntimeError): + client._start_background_channel_refresh() + + @pytest.mark.asyncio + async def test__start_background_channel_refresh_tasks_exist(self): + # if tasks exist, should do nothing + client = self._make_one(project="project-id", use_emulator=False) + assert len(client._channel_refresh_tasks) > 0 + with mock.patch.object(asyncio, "create_task") as create_task: + client._start_background_channel_refresh() + create_task.assert_not_called() + await client.close() + + @pytest.mark.asyncio + @pytest.mark.parametrize("pool_size", [1, 3, 7]) + async def test__start_background_channel_refresh(self, pool_size): + # should create background tasks for each channel + client = self._make_one( + project="project-id", pool_size=pool_size, use_emulator=False + ) + ping_and_warm = AsyncMock() + client._ping_and_warm_instances = ping_and_warm + client._start_background_channel_refresh() + assert len(client._channel_refresh_tasks) == pool_size + for task in client._channel_refresh_tasks: + assert isinstance(task, asyncio.Task) + await asyncio.sleep(0.1) + assert ping_and_warm.call_count == pool_size + for channel in client.transport._grpc_channel._pool: + ping_and_warm.assert_any_call(channel) + await client.close() + + @pytest.mark.asyncio + @pytest.mark.skipif( + sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" + ) + async def test__start_background_channel_refresh_tasks_names(self): + # if tasks exist, should do nothing + pool_size = 3 + client = self._make_one( + project="project-id", pool_size=pool_size, use_emulator=False + ) + for i in range(pool_size): + name = client._channel_refresh_tasks[i].get_name() + assert str(i) in name + assert "BigtableDataClientAsync channel 
refresh " in name + await client.close() + + @pytest.mark.asyncio + async def test__ping_and_warm_instances(self): + """ + test ping and warm with mocked asyncio.gather + """ + client_mock = mock.Mock() + with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: + # simulate gather by returning the same number of items as passed in + gather.side_effect = lambda *args, **kwargs: [None for _ in args] + channel = mock.Mock() + # test with no instances + client_mock._active_instances = [] + result = await self._get_target_class()._ping_and_warm_instances( + client_mock, channel + ) + assert len(result) == 0 + gather.assert_called_once() + gather.assert_awaited_once() + assert not gather.call_args.args + assert gather.call_args.kwargs == {"return_exceptions": True} + # test with instances + client_mock._active_instances = [ + (mock.Mock(), mock.Mock(), mock.Mock()) + ] * 4 + gather.reset_mock() + channel.reset_mock() + result = await self._get_target_class()._ping_and_warm_instances( + client_mock, channel + ) + assert len(result) == 4 + gather.assert_called_once() + gather.assert_awaited_once() + assert len(gather.call_args.args) == 4 + # check grpc call arguments + grpc_call_args = channel.unary_unary().call_args_list + for idx, (_, kwargs) in enumerate(grpc_call_args): + ( + expected_instance, + expected_table, + expected_app_profile, + ) = client_mock._active_instances[idx] + request = kwargs["request"] + assert request["name"] == expected_instance + assert request["app_profile_id"] == expected_app_profile + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert ( + metadata[0][1] + == f"name={expected_instance}&app_profile_id={expected_app_profile}" + ) + + @pytest.mark.asyncio + async def test__ping_and_warm_single_instance(self): + """ + should be able to call ping and warm with single instance + """ + client_mock = mock.Mock() + with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: + # simulate gather by returning the same number of items as passed in + gather.side_effect = lambda *args, **kwargs: [None for _ in args] + channel = mock.Mock() + # test with large set of instances + client_mock._active_instances = [mock.Mock()] * 100 + test_key = ("test-instance", "test-table", "test-app-profile") + result = await self._get_target_class()._ping_and_warm_instances( + client_mock, channel, test_key + ) + # should only have been called with test instance + assert len(result) == 1 + # check grpc call arguments + grpc_call_args = channel.unary_unary().call_args_list + assert len(grpc_call_args) == 1 + kwargs = grpc_call_args[0][1] + request = kwargs["request"] + assert request["name"] == "test-instance" + assert request["app_profile_id"] == "test-app-profile" + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert ( + metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "refresh_interval, wait_time, expected_sleep", + [ + (0, 0, 0), + (0, 1, 0), + (10, 0, 10), + (10, 5, 5), + (10, 10, 0), + (10, 15, 0), + ], + ) + async def test__manage_channel_first_sleep( + self, refresh_interval, wait_time, expected_sleep + ): + # first sleep time should be `refresh_interval` seconds after client init + import time + + with mock.patch.object(time, "monotonic") as time: + time.return_value = 0 + with mock.patch.object(asyncio, "sleep") as sleep: + sleep.side_effect = asyncio.CancelledError 
+ try: + client = self._make_one(project="project-id") + client._channel_init_time = -wait_time + await client._manage_channel(0, refresh_interval, refresh_interval) + except asyncio.CancelledError: + pass + sleep.assert_called_once() + call_time = sleep.call_args[0][0] + assert ( + abs(call_time - expected_sleep) < 0.1 + ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" + await client.close() + + @pytest.mark.asyncio + async def test__manage_channel_ping_and_warm(self): + """ + _manage channel should call ping and warm internally + """ + import time + + client_mock = mock.Mock() + client_mock._channel_init_time = time.monotonic() + channel_list = [mock.Mock(), mock.Mock()] + client_mock.transport.channels = channel_list + new_channel = mock.Mock() + client_mock.transport.grpc_channel._create_channel.return_value = new_channel + # should ping an warm all new channels, and old channels if sleeping + with mock.patch.object(asyncio, "sleep"): + # stop process after replace_channel is called + client_mock.transport.replace_channel.side_effect = asyncio.CancelledError + ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock() + # should ping and warm old channel then new if sleep > 0 + try: + channel_idx = 1 + await self._get_target_class()._manage_channel( + client_mock, channel_idx, 10 + ) + except asyncio.CancelledError: + pass + # should have called at loop start, and after replacement + assert ping_and_warm.call_count == 2 + # should have replaced channel once + assert client_mock.transport.replace_channel.call_count == 1 + # make sure new and old channels were warmed + old_channel = channel_list[channel_idx] + assert old_channel != new_channel + called_with = [call[0][0] for call in ping_and_warm.call_args_list] + assert old_channel in called_with + assert new_channel in called_with + # should ping and warm instantly new channel only if not sleeping + ping_and_warm.reset_mock() + try: + await self._get_target_class()._manage_channel(client_mock, 0, 0, 0) + except asyncio.CancelledError: + pass + ping_and_warm.assert_called_once_with(new_channel) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "refresh_interval, num_cycles, expected_sleep", + [ + (None, 1, 60 * 35), + (10, 10, 100), + (10, 1, 10), + ], + ) + async def test__manage_channel_sleeps( + self, refresh_interval, num_cycles, expected_sleep + ): + # make sure that sleeps work as expected + import time + import random + + channel_idx = 1 + with mock.patch.object(random, "uniform") as uniform: + uniform.side_effect = lambda min_, max_: min_ + with mock.patch.object(time, "time") as time: + time.return_value = 0 + with mock.patch.object(asyncio, "sleep") as sleep: + sleep.side_effect = [None for i in range(num_cycles - 1)] + [ + asyncio.CancelledError + ] + try: + client = self._make_one(project="project-id") + if refresh_interval is not None: + await client._manage_channel( + channel_idx, refresh_interval, refresh_interval + ) + else: + await client._manage_channel(channel_idx) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + assert ( + abs(total_sleep - expected_sleep) < 0.1 + ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" + await client.close() + + @pytest.mark.asyncio + async def test__manage_channel_random(self): + import random + + with mock.patch.object(asyncio, "sleep") as sleep: + with 
mock.patch.object(random, "uniform") as uniform: + uniform.return_value = 0 + try: + uniform.side_effect = asyncio.CancelledError + client = self._make_one(project="project-id", pool_size=1) + except asyncio.CancelledError: + uniform.side_effect = None + uniform.reset_mock() + sleep.reset_mock() + min_val = 200 + max_val = 205 + uniform.side_effect = lambda min_, max_: min_ + sleep.side_effect = [None, None, asyncio.CancelledError] + try: + await client._manage_channel(0, min_val, max_val) + except asyncio.CancelledError: + pass + assert uniform.call_count == 2 + uniform_args = [call[0] for call in uniform.call_args_list] + for found_min, found_max in uniform_args: + assert found_min == min_val + assert found_max == max_val + + @pytest.mark.asyncio + @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) + async def test__manage_channel_refresh(self, num_cycles): + # make sure that channels are properly refreshed + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, + ) + from google.api_core import grpc_helpers_async + + expected_grace = 9 + expected_refresh = 0.5 + channel_idx = 1 + new_channel = grpc.aio.insecure_channel("localhost:8080") + + with mock.patch.object( + PooledBigtableGrpcAsyncIOTransport, "replace_channel" + ) as replace_channel: + with mock.patch.object(asyncio, "sleep") as sleep: + sleep.side_effect = [None for i in range(num_cycles)] + [ + asyncio.CancelledError + ] + with mock.patch.object( + grpc_helpers_async, "create_channel" + ) as create_channel: + create_channel.return_value = new_channel + client = self._make_one(project="project-id", use_emulator=False) + create_channel.reset_mock() + try: + await client._manage_channel( + channel_idx, + refresh_interval_min=expected_refresh, + refresh_interval_max=expected_refresh, + grace_period=expected_grace, + ) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + 1 + assert create_channel.call_count == num_cycles + assert replace_channel.call_count == num_cycles + for call in replace_channel.call_args_list: + args, kwargs = call + assert args[0] == channel_idx + assert kwargs["grace"] == expected_grace + assert kwargs["new_channel"] == new_channel + await client.close() + + @pytest.mark.asyncio + async def test__register_instance(self): + """ + test instance registration + """ + # set up mock client + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: f"prefix/{b}" + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_tasks = [] + client_mock._start_background_channel_refresh.side_effect = ( + lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + ) + mock_channels = [mock.Mock() for i in range(5)] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = AsyncMock() + table_mock = mock.Mock() + await self._get_target_class()._register_instance( + client_mock, "instance-1", table_mock + ) + # first call should start background refresh + assert client_mock._start_background_channel_refresh.call_count == 1 + # ensure active_instances and instance_owners were updated properly + expected_key = ( + "prefix/instance-1", + table_mock.table_name, + table_mock.app_profile_id, + ) + assert len(active_instances) == 1 + assert expected_key == tuple(list(active_instances)[0]) + assert len(instance_owners) == 1 + assert 
expected_key == tuple(list(instance_owners)[0]) + # should be a new task set + assert client_mock._channel_refresh_tasks + # next call should not call _start_background_channel_refresh again + table_mock2 = mock.Mock() + await self._get_target_class()._register_instance( + client_mock, "instance-2", table_mock2 + ) + assert client_mock._start_background_channel_refresh.call_count == 1 + # but it should call ping and warm with new instance key + assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) + for channel in mock_channels: + assert channel in [ + call[0][0] + for call in client_mock._ping_and_warm_instances.call_args_list + ] + # check for updated lists + assert len(active_instances) == 2 + assert len(instance_owners) == 2 + expected_key2 = ( + "prefix/instance-2", + table_mock2.table_name, + table_mock2.app_profile_id, + ) + assert any( + [ + expected_key2 == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + assert any( + [ + expected_key2 == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "insert_instances,expected_active,expected_owner_keys", + [ + ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]), + ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]), + ([("1", "t", "p"), ("1", "t", "p")], [("1", "t", "p")], [("1", "t", "p")]), + ( + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + ), + ], + ) + async def test__register_instance_state( + self, insert_instances, expected_active, expected_owner_keys + ): + """ + test that active_instances and instance_owners are updated as expected + """ + # set up mock client + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: b + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_tasks = [] + client_mock._start_background_channel_refresh.side_effect = ( + lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + ) + mock_channels = [mock.Mock() for i in range(5)] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = AsyncMock() + table_mock = mock.Mock() + # register instances + for instance, table, profile in insert_instances: + table_mock.table_name = table + table_mock.app_profile_id = profile + await self._get_target_class()._register_instance( + client_mock, instance, table_mock + ) + assert len(active_instances) == len(expected_active) + assert len(instance_owners) == len(expected_owner_keys) + for expected in expected_active: + assert any( + [ + expected == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + for expected in expected_owner_keys: + assert any( + [ + expected == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) + + @pytest.mark.asyncio + async def test__remove_instance_registration(self): + client = self._make_one(project="project-id") + table = mock.Mock() + await client._register_instance("instance-1", table) + await client._register_instance("instance-2", table) + assert len(client._active_instances) == 2 + assert len(client._instance_owners.keys()) == 2 + instance_1_path = client._gapic_client.instance_path( + client.project, "instance-1" + ) + instance_1_key = (instance_1_path, table.table_name, table.app_profile_id) + 
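+ # registrations are tracked under (instance_path, table_name, app_profile_id) keys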
instance_2_path = client._gapic_client.instance_path( + client.project, "instance-2" + ) + instance_2_key = (instance_2_path, table.table_name, table.app_profile_id) + assert len(client._instance_owners[instance_1_key]) == 1 + assert list(client._instance_owners[instance_1_key])[0] == id(table) + assert len(client._instance_owners[instance_2_key]) == 1 + assert list(client._instance_owners[instance_2_key])[0] == id(table) + success = await client._remove_instance_registration("instance-1", table) + assert success + assert len(client._active_instances) == 1 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 1 + assert client._active_instances == {instance_2_key} + success = await client._remove_instance_registration("fake-key", table) + assert not success + assert len(client._active_instances) == 1 + await client.close() + + @pytest.mark.asyncio + async def test__multiple_table_registration(self): + """ + registering with multiple tables with the same key should + add multiple owners to instance_owners, but only keep one copy + of shared key in active_instances + """ + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + + async with self._make_one(project="project-id") as client: + async with client.get_table("instance_1", "table_1") as table_1: + instance_1_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 1 + assert len(client._active_instances) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + # duplicate table should register in instance_owners under same key + async with client.get_table("instance_1", "table_1") as table_2: + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._active_instances) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + # unique table should register in instance_owners and active_instances + async with client.get_table("instance_1", "table_3") as table_3: + instance_3_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_3_key = _WarmedInstanceKey( + instance_3_path, table_3.table_name, table_3.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._instance_owners[instance_3_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + assert id(table_3) in client._instance_owners[instance_3_key] + # sub-tables should be unregistered, but instance should still be active + assert len(client._active_instances) == 1 + assert instance_1_key in client._active_instances + assert id(table_2) not in client._instance_owners[instance_1_key] + # both tables are gone. 
instance should be unregistered + assert len(client._active_instances) == 0 + assert instance_1_key not in client._active_instances + assert len(client._instance_owners[instance_1_key]) == 0 + + @pytest.mark.asyncio + async def test__multiple_instance_registration(self): + """ + registering with multiple instance keys should update the key + in instance_owners and active_instances + """ + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + + async with self._make_one(project="project-id") as client: + async with client.get_table("instance_1", "table_1") as table_1: + async with client.get_table("instance_2", "table_2") as table_2: + instance_1_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) + instance_2_path = client._gapic_client.instance_path( + client.project, "instance_2" + ) + instance_2_key = _WarmedInstanceKey( + instance_2_path, table_2.table_name, table_2.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 1 + assert len(client._instance_owners[instance_2_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_2_key] + # instance2 should be unregistered, but instance1 should still be active + assert len(client._active_instances) == 1 + assert instance_1_key in client._active_instances + assert len(client._instance_owners[instance_2_key]) == 0 + assert len(client._instance_owners[instance_1_key]) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + # both tables are gone. instances should both be unregistered + assert len(client._active_instances) == 0 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 0 + + @pytest.mark.asyncio + async def test_get_table(self): + from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + + client = self._make_one(project="project-id") + assert not client._active_instances + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + table = client.get_table( + expected_instance_id, + expected_table_id, + expected_app_profile_id, + ) + await asyncio.sleep(0) + assert isinstance(table, TableAsync) + assert table.table_id == expected_table_id + assert ( + table.table_name + == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert table.instance_id == expected_instance_id + assert ( + table.instance_name + == f"projects/{client.project}/instances/{expected_instance_id}" + ) + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + await client.close() + + @pytest.mark.asyncio + async def test_get_table_arg_passthrough(self): + """ + All arguments passed in get_table should be sent to constructor + """ + async with self._make_one(project="project-id") as client: + with mock.patch( + "google.cloud.bigtable.data._async.client.TableAsync.__init__", + ) as mock_constructor: + mock_constructor.return_value = None + assert not client._active_instances + 
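+ # arbitrary extra positional and keyword arguments should be forwarded unchanged to TableAsync.__init__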
expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_args = (1, "test", {"test": 2}) + expected_kwargs = {"hello": "world", "test": 2} + + client.get_table( + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + mock_constructor.assert_called_once_with( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + + @pytest.mark.asyncio + async def test_get_table_context_manager(self): + from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_project_id = "project-id" + + with mock.patch.object(TableAsync, "close") as close_mock: + async with self._make_one(project=expected_project_id) as client: + async with client.get_table( + expected_instance_id, + expected_table_id, + expected_app_profile_id, + ) as table: + await asyncio.sleep(0) + assert isinstance(table, TableAsync) + assert table.table_id == expected_table_id + assert ( + table.table_name + == f"projects/{expected_project_id}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert table.instance_id == expected_instance_id + assert ( + table.instance_name + == f"projects/{expected_project_id}/instances/{expected_instance_id}" + ) + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + assert close_mock.call_count == 1 + + @pytest.mark.asyncio + async def test_multiple_pool_sizes(self): + # should be able to create multiple clients with different pool sizes without issue + pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] + for pool_size in pool_sizes: + client = self._make_one( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client._channel_refresh_tasks) == pool_size + client_duplicate = self._make_one( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client_duplicate._channel_refresh_tasks) == pool_size + assert str(pool_size) in str(client.transport) + await client.close() + await client_duplicate.close() + + @pytest.mark.asyncio + async def test_close(self): + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, + ) + + pool_size = 7 + client = self._make_one( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client._channel_refresh_tasks) == pool_size + tasks_list = list(client._channel_refresh_tasks) + for task in client._channel_refresh_tasks: + assert not task.done() + with mock.patch.object( + PooledBigtableGrpcAsyncIOTransport, "close", AsyncMock() + ) as close_mock: + await client.close() + close_mock.assert_called_once() + close_mock.assert_awaited() + for task in tasks_list: + assert task.done() + assert task.cancelled() + assert client._channel_refresh_tasks == [] + + @pytest.mark.asyncio + async def test_close_with_timeout(self): + pool_size = 7 + expected_timeout = 19 + client = self._make_one(project="project-id", pool_size=pool_size) + tasks = list(client._channel_refresh_tasks) + 
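+ # close() should hand the timeout off to asyncio.wait_for while waiting for the refresh tasks to shut down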
with mock.patch.object(asyncio, "wait_for", AsyncMock()) as wait_for_mock: + await client.close(timeout=expected_timeout) + wait_for_mock.assert_called_once() + wait_for_mock.assert_awaited() + assert wait_for_mock.call_args[1]["timeout"] == expected_timeout + client._channel_refresh_tasks = tasks + await client.close() + + @pytest.mark.asyncio + async def test_context_manager(self): + # context manager should close the client cleanly + close_mock = AsyncMock() + true_close = None + async with self._make_one(project="project-id") as client: + true_close = client.close() + client.close = close_mock + for task in client._channel_refresh_tasks: + assert not task.done() + assert client.project == "project-id" + assert client._active_instances == set() + close_mock.assert_not_called() + close_mock.assert_called_once() + close_mock.assert_awaited() + # actually close the client + await true_close + + def test_client_ctor_sync(self): + # initializing client in a sync context should raise RuntimeError + + with pytest.warns(RuntimeWarning) as warnings: + client = _make_client(project="project-id", use_emulator=False) + expected_warning = [w for w in warnings if "client.py" in w.filename] + assert len(expected_warning) == 1 + assert ( + "BigtableDataClientAsync should be started in an asyncio event loop." + in str(expected_warning[0].message) + ) + assert client.project == "project-id" + assert client._channel_refresh_tasks == [] + + +class TestTableAsync: + @pytest.mark.asyncio + async def test_table_ctor(self): + from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_operation_timeout = 123 + expected_attempt_timeout = 12 + expected_read_rows_operation_timeout = 1.5 + expected_read_rows_attempt_timeout = 0.5 + expected_mutate_rows_operation_timeout = 2.5 + expected_mutate_rows_attempt_timeout = 0.75 + client = _make_client() + assert not client._active_instances + + table = TableAsync( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + default_operation_timeout=expected_operation_timeout, + default_attempt_timeout=expected_attempt_timeout, + default_read_rows_operation_timeout=expected_read_rows_operation_timeout, + default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout, + default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, + default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, + ) + await asyncio.sleep(0) + assert table.table_id == expected_table_id + assert table.instance_id == expected_instance_id + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + assert table.default_operation_timeout == expected_operation_timeout + assert table.default_attempt_timeout == expected_attempt_timeout + assert ( + table.default_read_rows_operation_timeout + == expected_read_rows_operation_timeout + ) + assert ( + table.default_read_rows_attempt_timeout + == expected_read_rows_attempt_timeout + ) + assert ( + table.default_mutate_rows_operation_timeout + == expected_mutate_rows_operation_timeout + ) + assert ( + table.default_mutate_rows_attempt_timeout + == 
expected_mutate_rows_attempt_timeout + ) + # ensure task reaches completion + await table._register_instance_task + assert table._register_instance_task.done() + assert not table._register_instance_task.cancelled() + assert table._register_instance_task.exception() is None + await client.close() + + @pytest.mark.asyncio + async def test_table_ctor_defaults(self): + """ + should provide default timeout values and app_profile_id + """ + from google.cloud.bigtable.data._async.client import TableAsync + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + client = _make_client() + assert not client._active_instances + + table = TableAsync( + client, + expected_instance_id, + expected_table_id, + ) + await asyncio.sleep(0) + assert table.table_id == expected_table_id + assert table.instance_id == expected_instance_id + assert table.app_profile_id is None + assert table.client is client + assert table.default_operation_timeout == 60 + assert table.default_read_rows_operation_timeout == 600 + assert table.default_mutate_rows_operation_timeout == 600 + assert table.default_attempt_timeout == 20 + assert table.default_read_rows_attempt_timeout == 20 + assert table.default_mutate_rows_attempt_timeout == 60 + await client.close() + + @pytest.mark.asyncio + async def test_table_ctor_invalid_timeout_values(self): + """ + bad timeout values should raise ValueError + """ + from google.cloud.bigtable.data._async.client import TableAsync + + client = _make_client() + + timeout_pairs = [ + ("default_operation_timeout", "default_attempt_timeout"), + ( + "default_read_rows_operation_timeout", + "default_read_rows_attempt_timeout", + ), + ( + "default_mutate_rows_operation_timeout", + "default_mutate_rows_attempt_timeout", + ), + ] + for operation_timeout, attempt_timeout in timeout_pairs: + with pytest.raises(ValueError) as e: + TableAsync(client, "", "", **{attempt_timeout: -1}) + assert "attempt_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + TableAsync(client, "", "", **{operation_timeout: -1}) + assert "operation_timeout must be greater than 0" in str(e.value) + await client.close() + + def test_table_ctor_sync(self): + # initializing client in a sync context should raise RuntimeError + from google.cloud.bigtable.data._async.client import TableAsync + + client = mock.Mock() + with pytest.raises(RuntimeError) as e: + TableAsync(client, "instance-id", "table-id") + assert e.match("TableAsync must be created within an async event loop context.") + + @pytest.mark.asyncio + # iterate over all retryable rpcs + @pytest.mark.parametrize( + "fn_name,fn_args,retry_fn_path,extra_retryables", + [ + ( + "read_rows_stream", + (ReadRowsQuery(),), + "google.api_core.retry.retry_target_stream_async", + (), + ), + ( + "read_rows", + (ReadRowsQuery(),), + "google.api_core.retry.retry_target_stream_async", + (), + ), + ( + "read_row", + (b"row_key",), + "google.api_core.retry.retry_target_stream_async", + (), + ), + ( + "read_rows_sharded", + ([ReadRowsQuery()],), + "google.api_core.retry.retry_target_stream_async", + (), + ), + ( + "row_exists", + (b"row_key",), + "google.api_core.retry.retry_target_stream_async", + (), + ), + ("sample_row_keys", (), "google.api_core.retry.retry_target_async", ()), + ( + "mutate_row", + (b"row_key", [mock.Mock()]), + "google.api_core.retry.retry_target_async", + (), + ), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mock.Mock()])],), + "google.api_core.retry.retry_target_async", + (_MutateRowsIncomplete,), + 
), + ], + ) + # test different inputs for retryable exceptions + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + async def test_customizable_retryable_errors( + self, + input_retryables, + expected_retryables, + fn_name, + fn_args, + retry_fn_path, + extra_retryables, + ): + """ + Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer. + """ + with mock.patch(retry_fn_path) as retry_fn_mock: + async with _make_client() as client: + table = client.get_table("instance-id", "table-id") + expected_predicate = lambda a: a in expected_retryables # noqa + retry_fn_mock.side_effect = RuntimeError("stop early") + with mock.patch( + "google.api_core.retry.if_exception_type" + ) as predicate_builder_mock: + predicate_builder_mock.return_value = expected_predicate + with pytest.raises(Exception): + # we expect an exception from attempting to call the mock + test_fn = table.__getattribute__(fn_name) + await test_fn(*fn_args, retryable_errors=input_retryables) + # passed in errors should be used to build the predicate + predicate_builder_mock.assert_called_once_with( + *expected_retryables, *extra_retryables + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + # output of if_exception_type should be sent in to retry constructor + assert retry_call_args[1] is expected_predicate + + @pytest.mark.parametrize( + "fn_name,fn_args,gapic_fn", + [ + ("read_rows_stream", (ReadRowsQuery(),), "read_rows"), + ("read_rows", (ReadRowsQuery(),), "read_rows"), + ("read_row", (b"row_key",), "read_rows"), + ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), + ("row_exists", (b"row_key",), "read_rows"), + ("sample_row_keys", (), "sample_row_keys"), + ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + "mutate_rows", + ), + ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), + ( + "read_modify_write_row", + (b"row_key", mock.Mock()), + "read_modify_write_row", + ), + ], + ) + @pytest.mark.parametrize("include_app_profile", [True, False]) + @pytest.mark.asyncio + async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): + """check that all requests attach proper metadata headers""" + from google.cloud.bigtable.data import TableAsync + + profile = "profile" if include_app_profile else None + with mock.patch( + f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}", mock.AsyncMock() + ) as gapic_mock: + gapic_mock.side_effect = RuntimeError("stop early") + async with _make_client() as client: + table = TableAsync(client, "instance-id", "table-id", profile) + try: + test_fn = table.__getattribute__(fn_name) + maybe_stream = await test_fn(*fn_args) + [i async for i in maybe_stream] + except Exception: + # we expect an exception from attempting to call the mock + pass + kwargs = gapic_mock.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == 
"x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata + + +class TestReadRows: + """ + Tests for table.read_rows and related methods. + """ + + def _make_table(self, *args, **kwargs): + from google.cloud.bigtable.data._async.client import TableAsync + + client_mock = mock.Mock() + client_mock._register_instance.side_effect = ( + lambda *args, **kwargs: asyncio.sleep(0) + ) + client_mock._remove_instance_registration.side_effect = ( + lambda *args, **kwargs: asyncio.sleep(0) + ) + kwargs["instance_id"] = kwargs.get( + "instance_id", args[0] if args else "instance" + ) + kwargs["table_id"] = kwargs.get( + "table_id", args[1] if len(args) > 1 else "table" + ) + client_mock._gapic_client.table_path.return_value = kwargs["table_id"] + client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] + return TableAsync(client_mock, *args, **kwargs) + + def _make_stats(self): + from google.cloud.bigtable_v2.types import RequestStats + from google.cloud.bigtable_v2.types import FullReadStatsView + from google.cloud.bigtable_v2.types import ReadIterationStats + + return RequestStats( + full_read_stats_view=FullReadStatsView( + read_iteration_stats=ReadIterationStats( + rows_seen_count=1, + rows_returned_count=2, + cells_seen_count=3, + cells_returned_count=4, + ) + ) + ) + + @staticmethod + def _make_chunk(*args, **kwargs): + from google.cloud.bigtable_v2 import ReadRowsResponse + + kwargs["row_key"] = kwargs.get("row_key", b"row_key") + kwargs["family_name"] = kwargs.get("family_name", "family_name") + kwargs["qualifier"] = kwargs.get("qualifier", b"qualifier") + kwargs["value"] = kwargs.get("value", b"value") + kwargs["commit_row"] = kwargs.get("commit_row", True) + + return ReadRowsResponse.CellChunk(*args, **kwargs) + + @staticmethod + async def _make_gapic_stream( + chunk_list: list[ReadRowsResponse.CellChunk | Exception], + sleep_time=0, + ): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list, sleep_time): + self.chunk_list = chunk_list + self.idx = -1 + self.sleep_time = sleep_time + + def __aiter__(self): + return self + + async def __anext__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + if sleep_time: + await asyncio.sleep(self.sleep_time) + chunk = self.chunk_list[self.idx] + if isinstance(chunk, Exception): + raise chunk + else: + return ReadRowsResponse(chunks=[chunk]) + raise StopAsyncIteration + + def cancel(self): + pass + + return mock_stream(chunk_list, sleep_time) + + async def execute_fn(self, table, *args, **kwargs): + return await table.read_rows(*args, **kwargs) + + @pytest.mark.asyncio + async def test_read_rows(self): + query = ReadRowsQuery() + chunks = [ + self._make_chunk(row_key=b"test_1"), + self._make_chunk(row_key=b"test_2"), + ] + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + results = await self.execute_fn(table, query, operation_timeout=3) + assert len(results) == 2 + assert results[0].row_key == b"test_1" + assert results[1].row_key == b"test_2" + + @pytest.mark.asyncio + async def test_read_rows_stream(self): + query = ReadRowsQuery() + chunks = [ + self._make_chunk(row_key=b"test_1"), + 
self._make_chunk(row_key=b"test_2"), + ] + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + gen = await table.read_rows_stream(query, operation_timeout=3) + results = [row async for row in gen] + assert len(results) == 2 + assert results[0].row_key == b"test_1" + assert results[1].row_key == b"test_2" + + @pytest.mark.parametrize("include_app_profile", [True, False]) + @pytest.mark.asyncio + async def test_read_rows_query_matches_request(self, include_app_profile): + from google.cloud.bigtable.data import RowRange + from google.cloud.bigtable.data.row_filters import PassAllFilter + + app_profile_id = "app_profile_id" if include_app_profile else None + async with self._make_table(app_profile_id=app_profile_id) as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream([]) + row_keys = [b"test_1", "test_2"] + row_ranges = RowRange("1start", "2end") + filter_ = PassAllFilter(True) + limit = 99 + query = ReadRowsQuery( + row_keys=row_keys, + row_ranges=row_ranges, + row_filter=filter_, + limit=limit, + ) + + results = await table.read_rows(query, operation_timeout=3) + assert len(results) == 0 + call_request = read_rows.call_args_list[0][0][0] + query_pb = query._to_pb(table) + assert call_request == query_pb + + @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) + @pytest.mark.asyncio + async def test_read_rows_timeout(self, operation_timeout): + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + query = ReadRowsQuery() + chunks = [self._make_chunk(row_key=b"test_1")] + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=1 + ) + try: + await table.read_rows(query, operation_timeout=operation_timeout) + except core_exceptions.DeadlineExceeded as e: + assert ( + e.message + == f"operation_timeout of {operation_timeout:0.1f}s exceeded" + ) + + @pytest.mark.parametrize( + "per_request_t, operation_t, expected_num", + [ + (0.05, 0.08, 2), + (0.05, 0.54, 11), + (0.05, 0.14, 3), + (0.05, 0.24, 5), + ], + ) + @pytest.mark.asyncio + async def test_read_rows_attempt_timeout( + self, per_request_t, operation_t, expected_num + ): + """ + Ensures that the attempt_timeout is respected and that the number of + requests is as expected. + + operation_timeout does not cancel the request, so we expect the number of + requests to be the ceiling of operation_timeout / attempt_timeout. 
+ """ + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + expected_last_timeout = operation_t - (expected_num - 1) * per_request_t + + # mocking uniform ensures there are no sleeps between retries + with mock.patch("random.uniform", side_effect=lambda a, b: 0): + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=per_request_t + ) + query = ReadRowsQuery() + chunks = [core_exceptions.DeadlineExceeded("mock deadline")] + + try: + await table.read_rows( + query, + operation_timeout=operation_t, + attempt_timeout=per_request_t, + ) + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + if expected_num == 0: + assert retry_exc is None + else: + assert type(retry_exc) is RetryExceptionGroup + assert f"{expected_num} failed attempts" in str(retry_exc) + assert len(retry_exc.exceptions) == expected_num + for sub_exc in retry_exc.exceptions: + assert sub_exc.message == "mock deadline" + assert read_rows.call_count == expected_num + # check timeouts + for _, call_kwargs in read_rows.call_args_list[:-1]: + assert call_kwargs["timeout"] == per_request_t + assert call_kwargs["retry"] is None + # last timeout should be adjusted to account for the time spent + assert ( + abs( + read_rows.call_args_list[-1][1]["timeout"] + - expected_last_timeout + ) + < 0.05 + ) + + @pytest.mark.parametrize( + "exc_type", + [ + core_exceptions.Aborted, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ], + ) + @pytest.mark.asyncio + async def test_read_rows_retryable_error(self, exc_type): + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + try: + await table.read_rows(query, operation_timeout=0.1) + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + root_cause = retry_exc.exceptions[0] + assert type(root_cause) is exc_type + assert root_cause == expected_error + + @pytest.mark.parametrize( + "exc_type", + [ + core_exceptions.Cancelled, + core_exceptions.PreconditionFailed, + core_exceptions.NotFound, + core_exceptions.PermissionDenied, + core_exceptions.Conflict, + core_exceptions.InternalServerError, + core_exceptions.TooManyRequests, + core_exceptions.ResourceExhausted, + InvalidChunk, + ], + ) + @pytest.mark.asyncio + async def test_read_rows_non_retryable_error(self, exc_type): + async with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + try: + await table.read_rows(query, operation_timeout=0.1) + except exc_type as e: + assert e == expected_error + + @pytest.mark.asyncio + async def test_read_rows_revise_request(self): + """ + Ensure that _revise_request is called between retries + """ + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + from google.cloud.bigtable.data.exceptions import InvalidChunk + from google.cloud.bigtable_v2.types import RowSet + + return_val = RowSet() + with mock.patch.object( + _ReadRowsOperationAsync, "_revise_request_rowset" + ) as revise_rowset: + revise_rowset.return_value = return_val + async with self._make_table() as table: 
+ read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + row_keys = [b"test_1", b"test_2", b"test_3"] + query = ReadRowsQuery(row_keys=row_keys) + chunks = [ + self._make_chunk(row_key=b"test_1"), + core_exceptions.Aborted("mock retryable error"), + ] + try: + await table.read_rows(query) + except InvalidChunk: + revise_rowset.assert_called() + first_call_kwargs = revise_rowset.call_args_list[0].kwargs + assert first_call_kwargs["row_set"] == query._to_pb(table).rows + assert first_call_kwargs["last_seen_row_key"] == b"test_1" + revised_call = read_rows.call_args_list[1].args[0] + assert revised_call.rows == return_val + + @pytest.mark.asyncio + async def test_read_rows_default_timeouts(self): + """ + Ensure that the default timeouts are set on the read rows operation when not overridden + """ + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + + operation_timeout = 8 + attempt_timeout = 4 + with mock.patch.object(_ReadRowsOperationAsync, "__init__") as mock_op: + mock_op.side_effect = RuntimeError("mock error") + async with self._make_table( + default_read_rows_operation_timeout=operation_timeout, + default_read_rows_attempt_timeout=attempt_timeout, + ) as table: + try: + await table.read_rows(ReadRowsQuery()) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["attempt_timeout"] == attempt_timeout + + @pytest.mark.asyncio + async def test_read_rows_default_timeout_override(self): + """ + When timeouts are passed, they overwrite default values + """ + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + + operation_timeout = 8 + attempt_timeout = 4 + with mock.patch.object(_ReadRowsOperationAsync, "__init__") as mock_op: + mock_op.side_effect = RuntimeError("mock error") + async with self._make_table( + default_operation_timeout=99, default_attempt_timeout=97 + ) as table: + try: + await table.read_rows( + ReadRowsQuery(), + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + ) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["attempt_timeout"] == attempt_timeout + + @pytest.mark.asyncio + async def test_read_row(self): + """Test reading a single row""" + async with _make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + expected_result = object() + read_rows.side_effect = lambda *args, **kwargs: [expected_result] + expected_op_timeout = 8 + expected_req_timeout = 4 + row = await table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert row == expected_result + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert len(args) == 1 + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + + @pytest.mark.asyncio + async def test_read_row_w_filter(self): + """Test reading a single row with an added filter""" + async with _make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with 
mock.patch.object(table, "read_rows") as read_rows: + expected_result = object() + read_rows.side_effect = lambda *args, **kwargs: [expected_result] + expected_op_timeout = 8 + expected_req_timeout = 4 + mock_filter = mock.Mock() + expected_filter = {"filter": "mock filter"} + mock_filter._to_dict.return_value = expected_filter + row = await table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + row_filter=expected_filter, + ) + assert row == expected_result + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert len(args) == 1 + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + assert query.filter == expected_filter + + @pytest.mark.asyncio + async def test_read_row_no_response(self): + """should return None if row does not exist""" + async with _make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + # return no rows + read_rows.side_effect = lambda *args, **kwargs: [] + expected_op_timeout = 8 + expected_req_timeout = 4 + result = await table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert result is None + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + + @pytest.mark.parametrize( + "return_value,expected_result", + [ + ([], False), + ([object()], True), + ([object(), object()], True), + ], + ) + @pytest.mark.asyncio + async def test_row_exists(self, return_value, expected_result): + """Test checking for row existence""" + async with _make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + # return no rows + read_rows.side_effect = lambda *args, **kwargs: return_value + expected_op_timeout = 1 + expected_req_timeout = 2 + result = await table.row_exists( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert expected_result == result + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert isinstance(args[0], ReadRowsQuery) + expected_filter = { + "chain": { + "filters": [ + {"cells_per_row_limit_filter": 1}, + {"strip_value_transformer": True}, + ] + } + } + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + assert query.filter._to_dict() == expected_filter + + +class TestReadRowsSharded: + @pytest.mark.asyncio + async def test_read_rows_sharded_empty_query(self): + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as exc: + await table.read_rows_sharded([]) + assert "empty sharded_query" in str(exc.value) + + @pytest.mark.asyncio + async def 
test_read_rows_sharded_multiple_queries(self): + """ + Test with multiple queries. Should return results from both + """ + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: + read_rows.side_effect = ( + lambda *args, **kwargs: TestReadRows._make_gapic_stream( + [ + TestReadRows._make_chunk(row_key=k) + for k in args[0].rows.row_keys + ] + ) + ) + query_1 = ReadRowsQuery(b"test_1") + query_2 = ReadRowsQuery(b"test_2") + result = await table.read_rows_sharded([query_1, query_2]) + assert len(result) == 2 + assert result[0].row_key == b"test_1" + assert result[1].row_key == b"test_2" + + @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) + @pytest.mark.asyncio + async def test_read_rows_sharded_multiple_queries_calls(self, n_queries): + """ + Each query should trigger a separate read_rows call + """ + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + query_list = [ReadRowsQuery() for _ in range(n_queries)] + await table.read_rows_sharded(query_list) + assert read_rows.call_count == n_queries + + @pytest.mark.asyncio + async def test_read_rows_sharded_errors(self): + """ + Errors should be exposed as ShardedReadRowsExceptionGroups + """ + from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedQueryShardError + + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = RuntimeError("mock error") + query_1 = ReadRowsQuery(b"test_1") + query_2 = ReadRowsQuery(b"test_2") + with pytest.raises(ShardedReadRowsExceptionGroup) as exc: + await table.read_rows_sharded([query_1, query_2]) + exc_group = exc.value + assert isinstance(exc_group, ShardedReadRowsExceptionGroup) + assert len(exc.value.exceptions) == 2 + assert isinstance(exc.value.exceptions[0], FailedQueryShardError) + assert isinstance(exc.value.exceptions[0].__cause__, RuntimeError) + assert exc.value.exceptions[0].index == 0 + assert exc.value.exceptions[0].query == query_1 + assert isinstance(exc.value.exceptions[1], FailedQueryShardError) + assert isinstance(exc.value.exceptions[1].__cause__, RuntimeError) + assert exc.value.exceptions[1].index == 1 + assert exc.value.exceptions[1].query == query_2 + + @pytest.mark.asyncio + async def test_read_rows_sharded_concurrent(self): + """ + Ensure sharded requests are concurrent + """ + import time + + async def mock_call(*args, **kwargs): + await asyncio.sleep(0.1) + return [mock.Mock()] + + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = mock_call + queries = [ReadRowsQuery() for _ in range(10)] + start_time = time.monotonic() + result = await table.read_rows_sharded(queries) + call_time = time.monotonic() - start_time + assert read_rows.call_count == 10 + assert len(result) == 10 + # if run in sequence, we would expect this to take 1 second + assert call_time < 0.2 + + @pytest.mark.asyncio + async def test_read_rows_sharded_batching(self): + """ + Large queries should be processed in batches to limit concurrency + operation timeout should change between batches + """ + from 
google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT + + assert _CONCURRENCY_LIMIT == 10 # change this test if this changes + + n_queries = 90 + expected_num_batches = n_queries // _CONCURRENCY_LIMIT + query_list = [ReadRowsQuery() for _ in range(n_queries)] + + table_mock = AsyncMock() + start_operation_timeout = 10 + start_attempt_timeout = 3 + table_mock.default_read_rows_operation_timeout = start_operation_timeout + table_mock.default_read_rows_attempt_timeout = start_attempt_timeout + # clock ticks one second on each check + with mock.patch("time.monotonic", side_effect=range(0, 100000)): + with mock.patch("asyncio.gather", AsyncMock()) as gather_mock: + await TableAsync.read_rows_sharded(table_mock, query_list) + # should have individual calls for each query + assert table_mock.read_rows.call_count == n_queries + # should have single gather call for each batch + assert gather_mock.call_count == expected_num_batches + # ensure that timeouts decrease over time + kwargs = [ + table_mock.read_rows.call_args_list[idx][1] + for idx in range(n_queries) + ] + for batch_idx in range(expected_num_batches): + batch_kwargs = kwargs[ + batch_idx + * _CONCURRENCY_LIMIT : (batch_idx + 1) + * _CONCURRENCY_LIMIT + ] + for req_kwargs in batch_kwargs: + # each batch should have the same operation_timeout, and it should decrease in each batch + expected_operation_timeout = start_operation_timeout - ( + batch_idx + 1 + ) + assert ( + req_kwargs["operation_timeout"] + == expected_operation_timeout + ) + # each attempt_timeout should start with default value, but decrease when operation_timeout reaches it + expected_attempt_timeout = min( + start_attempt_timeout, expected_operation_timeout + ) + assert req_kwargs["attempt_timeout"] == expected_attempt_timeout + # await all created coroutines to avoid warnings + for i in range(len(gather_mock.call_args_list)): + for j in range(len(gather_mock.call_args_list[i][0])): + await gather_mock.call_args_list[i][0][j] + + +class TestSampleRowKeys: + async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): + from google.cloud.bigtable_v2.types import SampleRowKeysResponse + + for value in sample_list: + yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) + + @pytest.mark.asyncio + async def test_sample_row_keys(self): + """ + Test that method returns the expected key samples + """ + samples = [ + (b"test_1", 0), + (b"test_2", 100), + (b"test_3", 200), + ] + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", AsyncMock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream(samples) + result = await table.sample_row_keys() + assert len(result) == 3 + assert all(isinstance(r, tuple) for r in result) + assert all(isinstance(r[0], bytes) for r in result) + assert all(isinstance(r[1], int) for r in result) + assert result[0] == samples[0] + assert result[1] == samples[1] + assert result[2] == samples[2] + + @pytest.mark.asyncio + async def test_sample_row_keys_bad_timeout(self): + """ + should raise error if timeout is negative + """ + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + await table.sample_row_keys(operation_timeout=-1) + assert "operation_timeout must be greater than 0" in str(e.value) + with 
pytest.raises(ValueError) as e: + await table.sample_row_keys(attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) + + @pytest.mark.asyncio + async def test_sample_row_keys_default_timeout(self): + """Should fallback to using table default operation_timeout""" + expected_timeout = 99 + async with _make_client() as client: + async with client.get_table( + "i", + "t", + default_operation_timeout=expected_timeout, + default_attempt_timeout=expected_timeout, + ) as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", AsyncMock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream([]) + result = await table.sample_row_keys() + _, kwargs = sample_row_keys.call_args + assert abs(kwargs["timeout"] - expected_timeout) < 0.1 + assert result == [] + assert kwargs["retry"] is None + + @pytest.mark.asyncio + async def test_sample_row_keys_gapic_params(self): + """ + make sure arguments are propagated to gapic call as expected + """ + expected_timeout = 10 + expected_profile = "test1" + instance = "instance_name" + table_id = "my_table" + async with _make_client() as client: + async with client.get_table( + instance, table_id, app_profile_id=expected_profile + ) as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", AsyncMock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream([]) + await table.sample_row_keys(attempt_timeout=expected_timeout) + args, kwargs = sample_row_keys.call_args + assert len(args) == 0 + assert len(kwargs) == 5 + assert kwargs["timeout"] == expected_timeout + assert kwargs["app_profile_id"] == expected_profile + assert kwargs["table_name"] == table.table_name + assert kwargs["metadata"] is not None + assert kwargs["retry"] is None + + @pytest.mark.parametrize( + "retryable_exception", + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ], + ) + @pytest.mark.asyncio + async def test_sample_row_keys_retryable_errors(self, retryable_exception): + """ + retryable errors should be retried until timeout + """ + from google.api_core.exceptions import DeadlineExceeded + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", AsyncMock() + ) as sample_row_keys: + sample_row_keys.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + await table.sample_row_keys(operation_timeout=0.05) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert len(cause.exceptions) > 0 + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + core_exceptions.Aborted, + ], + ) + @pytest.mark.asyncio + async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): + """ + non-retryable errors should cause a raise + """ + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", AsyncMock() + ) as sample_row_keys: + sample_row_keys.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + await table.sample_row_keys() + + 
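The sharded-read and key-sample tests above pin down the public surface of the new async data client: read_rows_sharded fans a list of queries out concurrently and collects per-shard failures into a ShardedReadRowsExceptionGroup, while sample_row_keys returns (row_key, offset_bytes) tuples. The sketch below is a minimal, hedged usage example of that surface; it assumes the client is exported as BigtableDataClientAsync from google.cloud.bigtable.data, and the project, instance, table, and row-key values are placeholders.

import asyncio

from google.cloud.bigtable.data import BigtableDataClientAsync, ReadRowsQuery
from google.cloud.bigtable.data.exceptions import (
    FailedQueryShardError,
    ShardedReadRowsExceptionGroup,
)


async def main():
    # assumed entry point and placeholder IDs, for illustration only
    async with BigtableDataClientAsync(project="my-project") as client:
        async with client.get_table("my-instance", "my-table") as table:
            # sample_row_keys returns a list of (row_key, offset_bytes) tuples
            samples = await table.sample_row_keys()
            print(f"{len(samples)} key samples")

            # one query per shard; read_rows_sharded runs them concurrently
            queries = [
                ReadRowsQuery(row_keys=[b"user#001"]),
                ReadRowsQuery(row_keys=[b"user#002"]),
            ]
            try:
                rows = await table.read_rows_sharded(queries)
            except ShardedReadRowsExceptionGroup as group:
                # each sub-exception records the index and query of the failed shard
                for shard_error in group.exceptions:
                    assert isinstance(shard_error, FailedQueryShardError)
                    print(f"shard {shard_error.index} failed: {shard_error.__cause__!r}")
            else:
                for row in rows:
                    print(row.row_key)


if __name__ == "__main__":
    asyncio.run(main())

Because each FailedQueryShardError keeps a reference to its originating query, a caller can resubmit only the shards that failed.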
+class TestMutateRow: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mutation_arg", + [ + mutations.SetCell("family", b"qualifier", b"value"), + mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=1234567890 + ), + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromFamily("family"), + mutations.DeleteAllFromRow(), + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromRow(), + ], + ], + ) + async def test_mutate_row(self, mutation_arg): + """Test mutations with no errors""" + expected_attempt_timeout = 19 + async with _make_client(project="project") as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.return_value = None + await table.mutate_row( + "row_key", + mutation_arg, + attempt_timeout=expected_attempt_timeout, + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0].kwargs + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["row_key"] == b"row_key" + formatted_mutations = ( + [mutation._to_pb() for mutation in mutation_arg] + if isinstance(mutation_arg, list) + else [mutation_arg._to_pb()] + ) + assert kwargs["mutations"] == formatted_mutations + assert kwargs["timeout"] == expected_attempt_timeout + # make sure gapic layer is not retrying + assert kwargs["retry"] is None + + @pytest.mark.parametrize( + "retryable_exception", + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ], + ) + @pytest.mark.asyncio + async def test_mutate_row_retryable_errors(self, retryable_exception): + from google.api_core.exceptions import DeadlineExceeded + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + async with _make_client(project="project") as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + mutation = mutations.DeleteAllFromRow() + assert mutation.is_idempotent() is True + await table.mutate_row( + "row_key", mutation, operation_timeout=0.01 + ) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "retryable_exception", + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ], + ) + @pytest.mark.asyncio + async def test_mutate_row_non_idempotent_retryable_errors( + self, retryable_exception + ): + """ + Non-idempotent mutations should not be retried + """ + async with _make_client(project="project") as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(retryable_exception): + mutation = mutations.SetCell( + "family", b"qualifier", b"value", -1 + ) + assert mutation.is_idempotent() is False + await table.mutate_row( + "row_key", mutation, operation_timeout=0.2 + ) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + core_exceptions.Aborted, + ], + ) + 
@pytest.mark.asyncio + async def test_mutate_row_non_retryable_errors(self, non_retryable_exception): + async with _make_client(project="project") as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + mutation = mutations.SetCell( + "family", + b"qualifier", + b"value", + timestamp_micros=1234567890, + ) + assert mutation.is_idempotent() is True + await table.mutate_row( + "row_key", mutation, operation_timeout=0.2 + ) + + @pytest.mark.parametrize("include_app_profile", [True, False]) + @pytest.mark.asyncio + async def test_mutate_row_metadata(self, include_app_profile): + """request should attach metadata headers""" + profile = "profile" if include_app_profile else None + async with _make_client() as client: + async with client.get_table("i", "t", app_profile_id=profile) as table: + with mock.patch.object( + client._gapic_client, "mutate_row", AsyncMock() + ) as read_rows: + await table.mutate_row("rk", mock.Mock()) + kwargs = read_rows.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata + + @pytest.mark.parametrize("mutations", [[], None]) + @pytest.mark.asyncio + async def test_mutate_row_no_mutations(self, mutations): + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + await table.mutate_row("key", mutations=mutations) + assert e.value.args[0] == "No mutations provided" + + +class TestBulkMutateRows: + async def _mock_response(self, response_list): + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc import status_pb2 + + statuses = [] + for response in response_list: + if isinstance(response, core_exceptions.GoogleAPICallError): + statuses.append( + status_pb2.Status( + message=str(response), code=response.grpc_status_code.value[0] + ) + ) + else: + statuses.append(status_pb2.Status(code=0)) + entries = [ + MutateRowsResponse.Entry(index=i, status=statuses[i]) + for i in range(len(response_list)) + ] + + async def generator(): + yield MutateRowsResponse(entries=entries) + + return generator() + + @pytest.mark.asyncio + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mutation_arg", + [ + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=1234567890 + ) + ], + [mutations.DeleteRangeFromColumn("family", b"qualifier")], + [mutations.DeleteAllFromFamily("family")], + [mutations.DeleteAllFromRow()], + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromRow(), + ], + ], + ) + async def test_bulk_mutate_rows(self, mutation_arg): + """Test mutations with no errors""" + expected_attempt_timeout = 19 + async with _make_client(project="project") as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.return_value = 
self._mock_response([None]) + bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) + await table.bulk_mutate_rows( + [bulk_mutation], + attempt_timeout=expected_attempt_timeout, + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"] == [bulk_mutation._to_pb()] + assert kwargs["timeout"] == expected_attempt_timeout + assert kwargs["retry"] is None + + @pytest.mark.asyncio + async def test_bulk_mutate_rows_multiple_entries(self): + """Test mutations with no errors""" + async with _make_client(project="project") as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.return_value = self._mock_response([None, None]) + mutation_list = [mutations.DeleteAllFromRow()] + entry_1 = mutations.RowMutationEntry(b"row_key_1", mutation_list) + entry_2 = mutations.RowMutationEntry(b"row_key_2", mutation_list) + await table.bulk_mutate_rows( + [entry_1, entry_2], + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"][0] == entry_1._to_pb() + assert kwargs["entries"][1] == entry_2._to_pb() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "exception", + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ], + ) + async def test_bulk_mutate_rows_idempotent_mutation_error_retryable( + self, exception + ): + """ + Individual idempotent mutations should be retried if they fail with a retryable error + """ + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + async with _make_client(project="project") as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + await table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], exception) + # last exception should be due to retry timeout + assert isinstance( + cause.exceptions[-1], core_exceptions.DeadlineExceeded + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + core_exceptions.Aborted, + ], + ) + async def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable( + self, exception + ): + """ + Individual idempotent mutations should not be retried if they fail with a non-retryable error + """ + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + async with _make_client(project="project") as client: + async with client.get_table("instance", "table") as table: + with 
mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + await table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, exception) + + @pytest.mark.parametrize( + "retryable_exception", + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ], + ) + @pytest.mark.asyncio + async def test_bulk_mutate_idempotent_retryable_request_errors( + self, retryable_exception + ): + """ + Individual idempotent mutations should be retried if the request fails with a retryable error + """ + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + async with _make_client(project="project") as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + await table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "retryable_exception", + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ], + ) + async def test_bulk_mutate_rows_non_idempotent_retryable_errors( + self, retryable_exception + ): + """Non-Idempotent mutations should never be retried""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + async with _make_client(project="project") as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [retryable_exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", -1 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is False + await table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, retryable_exception) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + 
core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + ], + ) + @pytest.mark.asyncio + async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): + """ + If the request fails with a non-retryable error, mutations should not be retried + """ + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + async with _make_client(project="project") as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + await table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, non_retryable_exception) + + @pytest.mark.asyncio + async def test_bulk_mutate_error_index(self): + """ + Test partial failure, partial success. Errors should be associated with the correct index + """ + from google.api_core.exceptions import ( + DeadlineExceeded, + ServiceUnavailable, + FailedPrecondition, + ) + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + async with _make_client(project="project") as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + # fail with retryable errors, then a non-retryable one + mock_gapic.side_effect = [ + self._mock_response([None, ServiceUnavailable("mock"), None]), + self._mock_response([DeadlineExceeded("mock")]), + self._mock_response([FailedPrecondition("final")]), + ] + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entries = [ + mutations.RowMutationEntry( + (f"row_key_{i}").encode(), [mutation] + ) + for i in range(3) + ] + assert mutation.is_idempotent() is True + await table.bulk_mutate_rows(entries, operation_timeout=1000) + assert len(e.value.exceptions) == 1 + failed = e.value.exceptions[0] + assert isinstance(failed, FailedMutationEntryError) + assert failed.index == 1 + assert failed.entry == entries[1] + cause = failed.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert len(cause.exceptions) == 3 + assert isinstance(cause.exceptions[0], ServiceUnavailable) + assert isinstance(cause.exceptions[1], DeadlineExceeded) + assert isinstance(cause.exceptions[2], FailedPrecondition) + + @pytest.mark.asyncio + async def test_bulk_mutate_error_recovery(self): + """ + If an error occurs, then resolves, no exception should be raised + """ + from google.api_core.exceptions import DeadlineExceeded + + async with _make_client(project="project") as client: + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + # fail with a retryable error, then succeed + mock_gapic.side_effect = [ + self._mock_response([DeadlineExceeded("mock")]), + 
self._mock_response([None]), + ] + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entries = [ + mutations.RowMutationEntry((f"row_key_{i}").encode(), [mutation]) + for i in range(3) + ] + await table.bulk_mutate_rows(entries, operation_timeout=1000) + + +class TestCheckAndMutateRow: + @pytest.mark.parametrize("gapic_result", [True, False]) + @pytest.mark.asyncio + async def test_check_and_mutate(self, gapic_result): + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + app_profile = "app_profile_id" + async with _make_client() as client: + async with client.get_table( + "instance", "table", app_profile_id=app_profile + ) as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=gapic_result + ) + row_key = b"row_key" + predicate = None + true_mutations = [mock.Mock()] + false_mutations = [mock.Mock(), mock.Mock()] + operation_timeout = 0.2 + found = await table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutations, + false_case_mutations=false_mutations, + operation_timeout=operation_timeout, + ) + assert found == gapic_result + kwargs = mock_gapic.call_args[1] + assert kwargs["table_name"] == table.table_name + assert kwargs["row_key"] == row_key + assert kwargs["predicate_filter"] == predicate + assert kwargs["true_mutations"] == [ + m._to_pb() for m in true_mutations + ] + assert kwargs["false_mutations"] == [ + m._to_pb() for m in false_mutations + ] + assert kwargs["app_profile_id"] == app_profile + assert kwargs["timeout"] == operation_timeout + assert kwargs["retry"] is None + + @pytest.mark.asyncio + async def test_check_and_mutate_bad_timeout(self): + """Should raise error if operation_timeout < 0""" + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + await table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=[mock.Mock()], + false_case_mutations=[], + operation_timeout=-1, + ) + assert str(e.value) == "operation_timeout must be greater than 0" + + @pytest.mark.asyncio + async def test_check_and_mutate_single_mutations(self): + """if single mutations are passed, they should be internally wrapped in a list""" + from google.cloud.bigtable.data.mutations import SetCell + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + true_mutation = SetCell("family", b"qualifier", b"value") + false_mutation = SetCell("family", b"qualifier", b"value") + await table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == [true_mutation._to_pb()] + assert kwargs["false_mutations"] == [false_mutation._to_pb()] + + @pytest.mark.asyncio + async def test_check_and_mutate_predicate_object(self): + """predicate filter should be passed to gapic request""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + mock_predicate = mock.Mock() + predicate_pb = {"predicate": "dict"} + mock_predicate._to_pb.return_value = predicate_pb + async with 
_make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + await table.check_and_mutate_row( + b"row_key", + mock_predicate, + false_case_mutations=[mock.Mock()], + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["predicate_filter"] == predicate_pb + assert mock_predicate._to_pb.call_count == 1 + assert kwargs["retry"] is None + + @pytest.mark.asyncio + async def test_check_and_mutate_mutations_parsing(self): + """mutations objects should be converted to protos""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + mutations = [mock.Mock() for _ in range(5)] + for idx, mutation in enumerate(mutations): + mutation._to_pb.return_value = f"fake {idx}" + mutations.append(DeleteAllFromRow()) + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + await table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=mutations[0:2], + false_case_mutations=mutations[2:], + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == ["fake 0", "fake 1"] + assert kwargs["false_mutations"] == [ + "fake 2", + "fake 3", + "fake 4", + DeleteAllFromRow()._to_pb(), + ] + assert all( + mutation._to_pb.call_count == 1 for mutation in mutations[:5] + ) + + +class TestReadModifyWriteRow: + @pytest.mark.parametrize( + "call_rules,expected_rules", + [ + ( + AppendValueRule("f", "c", b"1"), + [AppendValueRule("f", "c", b"1")._to_pb()], + ), + ( + [AppendValueRule("f", "c", b"1")], + [AppendValueRule("f", "c", b"1")._to_pb()], + ), + (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_pb()]), + ( + [AppendValueRule("f", "c", b"1"), IncrementRule("f", "c", 1)], + [ + AppendValueRule("f", "c", b"1")._to_pb(), + IncrementRule("f", "c", 1)._to_pb(), + ], + ), + ], + ) + @pytest.mark.asyncio + async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): + """ + Test that the gapic call is called with given rules + """ + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + await table.read_modify_write_row("key", call_rules) + assert mock_gapic.call_count == 1 + found_kwargs = mock_gapic.call_args_list[0][1] + assert found_kwargs["rules"] == expected_rules + assert found_kwargs["retry"] is None + + @pytest.mark.parametrize("rules", [[], None]) + @pytest.mark.asyncio + async def test_read_modify_write_no_rules(self, rules): + async with _make_client() as client: + async with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + await table.read_modify_write_row("key", rules=rules) + assert e.value.args[0] == "rules must contain at least one item" + + @pytest.mark.asyncio + async def test_read_modify_write_call_defaults(self): + instance = "instance1" + table_id = "table1" + project = "project1" + row_key = "row_key1" + async with _make_client(project=project) as client: + async with client.get_table(instance, table_id) as table: + with mock.patch.object( + client._gapic_client, 
"read_modify_write_row" + ) as mock_gapic: + await table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert ( + kwargs["table_name"] + == f"projects/{project}/instances/{instance}/tables/{table_id}" + ) + assert kwargs["app_profile_id"] is None + assert kwargs["row_key"] == row_key.encode() + assert kwargs["timeout"] > 1 + + @pytest.mark.asyncio + async def test_read_modify_write_call_overrides(self): + row_key = b"row_key1" + expected_timeout = 12345 + profile_id = "profile1" + async with _make_client() as client: + async with client.get_table( + "instance", "table_id", app_profile_id=profile_id + ) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + await table.read_modify_write_row( + row_key, + mock.Mock(), + operation_timeout=expected_timeout, + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["app_profile_id"] is profile_id + assert kwargs["row_key"] == row_key + assert kwargs["timeout"] == expected_timeout + + @pytest.mark.asyncio + async def test_read_modify_write_string_key(self): + row_key = "string_row_key1" + async with _make_client() as client: + async with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + await table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["row_key"] == row_key.encode() + + @pytest.mark.asyncio + async def test_read_modify_write_row_building(self): + """ + results from gapic call should be used to construct row + """ + from google.cloud.bigtable.data.row import Row + from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + from google.cloud.bigtable_v2.types import Row as RowPB + + mock_response = ReadModifyWriteRowResponse(row=RowPB()) + async with _make_client() as client: + async with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + with mock.patch.object(Row, "_from_pb") as constructor_mock: + mock_gapic.return_value = mock_response + await table.read_modify_write_row("key", mock.Mock()) + assert constructor_mock.call_count == 1 + constructor_mock.assert_called_once_with(mock_response.row) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py new file mode 100644 index 000000000..cca7c9824 --- /dev/null +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -0,0 +1,1184 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import asyncio +import google.api_core.exceptions as core_exceptions +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data import TABLE_DEFAULT + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock +except ImportError: # pragma: NO COVER + import mock # type: ignore + from mock import AsyncMock # type: ignore + + +def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + +class Test_FlowControl: + def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): + from google.cloud.bigtable.data._async.mutations_batcher import ( + _FlowControlAsync, + ) + + return _FlowControlAsync(max_mutation_count, max_mutation_bytes) + + def test_ctor(self): + max_mutation_count = 9 + max_mutation_bytes = 19 + instance = self._make_one(max_mutation_count, max_mutation_bytes) + assert instance._max_mutation_count == max_mutation_count + assert instance._max_mutation_bytes == max_mutation_bytes + assert instance._in_flight_mutation_count == 0 + assert instance._in_flight_mutation_bytes == 0 + assert isinstance(instance._capacity_condition, asyncio.Condition) + + def test_ctor_invalid_values(self): + """Test that values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(0, 1) + assert "max_mutation_count must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(1, 0) + assert "max_mutation_bytes must be greater than 0" in str(e.value) + + @pytest.mark.parametrize( + "max_count,max_size,existing_count,existing_size,new_count,new_size,expected", + [ + (1, 1, 0, 0, 0, 0, True), + (1, 1, 1, 1, 1, 1, False), + (10, 10, 0, 0, 0, 0, True), + (10, 10, 0, 0, 9, 9, True), + (10, 10, 0, 0, 11, 9, True), + (10, 10, 0, 1, 11, 9, True), + (10, 10, 1, 0, 11, 9, False), + (10, 10, 0, 0, 9, 11, True), + (10, 10, 1, 0, 9, 11, True), + (10, 10, 0, 1, 9, 11, False), + (10, 1, 0, 0, 1, 0, True), + (1, 10, 0, 0, 0, 8, True), + (float("inf"), float("inf"), 0, 0, 1e10, 1e10, True), + (8, 8, 0, 0, 1e10, 1e10, True), + (12, 12, 6, 6, 5, 5, True), + (12, 12, 5, 5, 6, 6, True), + (12, 12, 6, 6, 6, 6, True), + (12, 12, 6, 6, 7, 7, False), + # allow capacity check if new_count or new_size exceeds limits + (12, 12, 0, 0, 13, 13, True), + (12, 12, 12, 0, 0, 13, True), + (12, 12, 0, 12, 13, 0, True), + # but not if there are already values in flight + (12, 12, 1, 1, 13, 13, False), + (12, 12, 1, 1, 0, 13, False), + (12, 12, 1, 1, 13, 0, False), + ], + ) + def test__has_capacity( + self, + max_count, + max_size, + existing_count, + existing_size, + new_count, + new_size, + expected, + ): + """ + _has_capacity should return True if the new mutation will not exceed the max count or size + """ + instance = self._make_one(max_count, max_size) + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + assert instance._has_capacity(new_count, new_size) == expected + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "existing_count,existing_size,added_count,added_size,new_count,new_size", + [ + (0, 0, 0, 0, 0, 0), + (2, 2, 1, 1, 1, 1), + (2, 0, 1, 0, 1, 0), + (0, 2, 0, 1, 0, 1), + (10, 10, 0, 0, 10, 10), + (10, 10, 5, 5, 5, 5), + (0, 0, 1, 1, -1, -1), + ], + ) + async def test_remove_from_flow_value_update( + self, + existing_count, + existing_size, + 
added_count, + added_size, + new_count, + new_size, + ): + """ + completed mutations should lower the in-flight values + """ + instance = self._make_one() + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + mutation = _make_mutation(added_count, added_size) + await instance.remove_from_flow(mutation) + assert instance._in_flight_mutation_count == new_count + assert instance._in_flight_mutation_bytes == new_size + + @pytest.mark.asyncio + async def test__remove_from_flow_unlock(self): + """capacity condition should notify after mutation is complete""" + instance = self._make_one(10, 10) + instance._in_flight_mutation_count = 10 + instance._in_flight_mutation_bytes = 10 + + async def task_routine(): + async with instance._capacity_condition: + await instance._capacity_condition.wait_for( + lambda: instance._has_capacity(1, 1) + ) + + task = asyncio.create_task(task_routine()) + await asyncio.sleep(0.05) + # should be blocked due to capacity + assert task.done() is False + # try changing size + mutation = _make_mutation(count=0, size=5) + await instance.remove_from_flow([mutation]) + await asyncio.sleep(0.05) + assert instance._in_flight_mutation_count == 10 + assert instance._in_flight_mutation_bytes == 5 + assert task.done() is False + # try changing count + instance._in_flight_mutation_bytes = 10 + mutation = _make_mutation(count=5, size=0) + await instance.remove_from_flow([mutation]) + await asyncio.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 10 + assert task.done() is False + # try changing both + instance._in_flight_mutation_count = 10 + mutation = _make_mutation(count=5, size=5) + await instance.remove_from_flow([mutation]) + await asyncio.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 5 + # task should be complete + assert task.done() is True + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mutations,count_cap,size_cap,expected_results", + [ + # high capacity results in no batching + ([(5, 5), (1, 1), (1, 1)], 10, 10, [[(5, 5), (1, 1), (1, 1)]]), + # low capacity splits up into batches + ([(1, 1), (1, 1), (1, 1)], 1, 1, [[(1, 1)], [(1, 1)], [(1, 1)]]), + # test count as limiting factor + ([(1, 1), (1, 1), (1, 1)], 2, 10, [[(1, 1), (1, 1)], [(1, 1)]]), + # test size as limiting factor + ([(1, 1), (1, 1), (1, 1)], 10, 2, [[(1, 1), (1, 1)], [(1, 1)]]), + # test with some blockages and some flows + ( + [(1, 1), (5, 5), (4, 1), (1, 4), (1, 1)], + 5, + 5, + [[(1, 1)], [(5, 5)], [(4, 1), (1, 4)], [(1, 1)]], + ), + ], + ) + async def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results): + """ + Test batching with various flow control settings + """ + mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations] + instance = self._make_one(count_cap, size_cap) + i = 0 + async for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + # check counts + assert len(batch[j].mutations) == expected_batch[j][0] + # check sizes + assert batch[j].size() == expected_batch[j][1] + # update lock + await instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mutations,max_limit,expected_results", + [ + ([(1, 1)] * 11, 10, [[(1, 1)] * 10, [(1, 1)]]), + ([(1, 1)] * 10, 1, [[(1, 1)] for _ in 
range(10)]), + ([(1, 1)] * 10, 2, [[(1, 1), (1, 1)] for _ in range(5)]), + ], + ) + async def test_add_to_flow_max_mutation_limits( + self, mutations, max_limit, expected_results + ): + """ + Test flow control running up against the max API limit + Should submit request early, even if the flow control has room for more + """ + with mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + max_limit, + ): + mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations] + # flow control has no limits except API restrictions + instance = self._make_one(float("inf"), float("inf")) + i = 0 + async for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + # check counts + assert len(batch[j].mutations) == expected_batch[j][0] + # check sizes + assert batch[j].size() == expected_batch[j][1] + # update lock + await instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + @pytest.mark.asyncio + async def test_add_to_flow_oversize(self): + """ + mutations over the flow control limits should still be accepted + """ + instance = self._make_one(2, 3) + large_size_mutation = _make_mutation(count=1, size=10) + large_count_mutation = _make_mutation(count=10, size=1) + results = [out async for out in instance.add_to_flow([large_size_mutation])] + assert len(results) == 1 + await instance.remove_from_flow(results[0]) + count_results = [ + out async for out in instance.add_to_flow(large_count_mutation) + ] + assert len(count_results) == 1 + + +class TestMutationsBatcherAsync: + def _get_target_class(self): + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) + + return MutationsBatcherAsync + + def _make_one(self, table=None, **kwargs): + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import ServiceUnavailable + + if table is None: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 10 + table.default_mutate_rows_retryable_errors = ( + DeadlineExceeded, + ServiceUnavailable, + ) + + return self._get_target_class()(table, **kwargs) + + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer" + ) + @pytest.mark.asyncio + async def test_ctor_defaults(self, flush_timer_mock): + flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = [Exception] + async with self._make_one(table) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._max_mutation_count == 100000 + assert instance._flow_control._max_mutation_bytes == 104857600 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert ( + instance._operation_timeout + == table.default_mutate_rows_operation_timeout + ) + 
assert ( + instance._attempt_timeout == table.default_mutate_rows_attempt_timeout + ) + assert ( + instance._retryable_errors == table.default_mutate_rows_retryable_errors + ) + await asyncio.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == 5 + assert isinstance(instance._flush_timer, asyncio.Future) + + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer", + ) + @pytest.mark.asyncio + async def test_ctor_explicit(self, flush_timer_mock): + """Test with explicit parameters""" + flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) + table = mock.Mock() + flush_interval = 20 + flush_limit_count = 17 + flush_limit_bytes = 19 + flow_control_max_mutation_count = 1001 + flow_control_max_bytes = 12 + operation_timeout = 11 + attempt_timeout = 2 + retryable_errors = [Exception] + async with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=operation_timeout, + batch_attempt_timeout=attempt_timeout, + batch_retryable_errors=retryable_errors, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert ( + instance._flow_control._max_mutation_count + == flow_control_max_mutation_count + ) + assert instance._flow_control._max_mutation_bytes == flow_control_max_bytes + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert instance._operation_timeout == operation_timeout + assert instance._attempt_timeout == attempt_timeout + assert instance._retryable_errors == retryable_errors + await asyncio.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == flush_interval + assert isinstance(instance._flush_timer, asyncio.Future) + + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer" + ) + @pytest.mark.asyncio + async def test_ctor_no_flush_limits(self, flush_timer_mock): + """Test with None for flush limits""" + flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = () + flush_interval = None + flush_limit_count = None + flush_limit_bytes = None + async with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._staged_entries == [] + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert 
instance._entries_processed_since_last_raise == 0 + await asyncio.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] is None + assert isinstance(instance._flush_timer, asyncio.Future) + + @pytest.mark.asyncio + async def test_ctor_invalid_values(self): + """Test that timeout values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(batch_operation_timeout=-1) + assert "operation_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(batch_attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) + + def test_default_argument_consistency(self): + """ + We supply default arguments in MutationsBatcherAsync.__init__, and in + table.mutations_batcher. Make sure any changes to defaults are applied to + both places + """ + from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) + import inspect + + get_batcher_signature = dict( + inspect.signature(TableAsync.mutations_batcher).parameters + ) + get_batcher_signature.pop("self") + batcher_init_signature = dict( + inspect.signature(MutationsBatcherAsync).parameters + ) + batcher_init_signature.pop("table") + # both should have same number of arguments + assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) + assert len(get_batcher_signature) == 8 # update if expected params change + # both should have same argument names + assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) + # both should have same default values + for arg_name in get_batcher_signature.keys(): + assert ( + get_batcher_signature[arg_name].default + == batcher_init_signature[arg_name].default + ) + + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" + ) + @pytest.mark.asyncio + async def test__start_flush_timer_w_None(self, flush_mock): + """Empty timer should return immediately""" + async with self._make_one() as instance: + with mock.patch("asyncio.sleep") as sleep_mock: + await instance._start_flush_timer(None) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" + ) + @pytest.mark.asyncio + async def test__start_flush_timer_call_when_closed(self, flush_mock): + """closed batcher's timer should return immediately""" + async with self._make_one() as instance: + await instance.close() + flush_mock.reset_mock() + with mock.patch("asyncio.sleep") as sleep_mock: + await instance._start_flush_timer(1) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" + ) + @pytest.mark.asyncio + async def test__flush_timer(self, flush_mock): + """Timer should continue to call _schedule_flush in a loop""" + expected_sleep = 12 + async with self._make_one(flush_interval=expected_sleep) as instance: + instance._staged_entries = [mock.Mock()] + loop_num = 3 + with mock.patch("asyncio.sleep") as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [asyncio.CancelledError()] + try: + await instance._flush_timer + except asyncio.CancelledError: + pass + assert sleep_mock.call_count == loop_num + 1 + sleep_mock.assert_called_with(expected_sleep) + assert 
flush_mock.call_count == loop_num + + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" + ) + @pytest.mark.asyncio + async def test__flush_timer_no_mutations(self, flush_mock): + """Timer should not flush if no new mutations have been staged""" + expected_sleep = 12 + async with self._make_one(flush_interval=expected_sleep) as instance: + loop_num = 3 + with mock.patch("asyncio.sleep") as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [asyncio.CancelledError()] + try: + await instance._flush_timer + except asyncio.CancelledError: + pass + assert sleep_mock.call_count == loop_num + 1 + sleep_mock.assert_called_with(expected_sleep) + assert flush_mock.call_count == 0 + + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" + ) + @pytest.mark.asyncio + async def test__flush_timer_close(self, flush_mock): + """Timer should terminate after close""" + async with self._make_one() as instance: + with mock.patch("asyncio.sleep"): + # let task run in background + await asyncio.sleep(0.5) + assert instance._flush_timer.done() is False + # close the batcher + await instance.close() + await asyncio.sleep(0.1) + # task should be complete + assert instance._flush_timer.done() is True + + @pytest.mark.asyncio + async def test_append_closed(self): + """Should raise exception""" + with pytest.raises(RuntimeError): + instance = self._make_one() + await instance.close() + await instance.append(mock.Mock()) + + @pytest.mark.asyncio + async def test_append_wrong_mutation(self): + """ + Raw Mutation objects should raise an exception. + Only RowMutationEntry is supported + """ + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + async with self._make_one() as instance: + expected_error = "invalid mutation type: DeleteAllFromRow. 
Only RowMutationEntry objects are supported by batcher" + with pytest.raises(ValueError) as e: + await instance.append(DeleteAllFromRow()) + assert str(e.value) == expected_error + + @pytest.mark.asyncio + async def test_append_outside_flow_limits(self): + """entries larger than mutation limits are still processed""" + async with self._make_one( + flow_control_max_mutation_count=1, flow_control_max_bytes=1 + ) as instance: + oversized_entry = _make_mutation(count=0, size=2) + await instance.append(oversized_entry) + assert instance._staged_entries == [oversized_entry] + assert instance._staged_count == 0 + assert instance._staged_bytes == 2 + instance._staged_entries = [] + async with self._make_one( + flow_control_max_mutation_count=1, flow_control_max_bytes=1 + ) as instance: + overcount_entry = _make_mutation(count=2, size=0) + await instance.append(overcount_entry) + assert instance._staged_entries == [overcount_entry] + assert instance._staged_count == 2 + assert instance._staged_bytes == 0 + instance._staged_entries = [] + + @pytest.mark.asyncio + async def test_append_flush_runs_after_limit_hit(self): + """ + If the user appends a bunch of entries above the flush limits back-to-back, + it should still flush in a single task + """ + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) + + with mock.patch.object( + MutationsBatcherAsync, "_execute_mutate_rows" + ) as op_mock: + async with self._make_one(flush_limit_bytes=100) as instance: + # mock network calls + async def mock_call(*args, **kwargs): + return [] + + op_mock.side_effect = mock_call + # append a mutation just under the size limit + await instance.append(_make_mutation(size=99)) + # append a bunch of entries back-to-back in a loop + num_entries = 10 + for _ in range(num_entries): + await instance.append(_make_mutation(size=1)) + # let any flush jobs finish + await asyncio.gather(*instance._flush_jobs) + # should have only flushed once, with large mutation and first mutation in loop + assert op_mock.call_count == 1 + sent_batch = op_mock.call_args[0][0] + assert len(sent_batch) == 2 + # others should still be pending + assert len(instance._staged_entries) == num_entries - 1 + + @pytest.mark.parametrize( + "flush_count,flush_bytes,mutation_count,mutation_bytes,expect_flush", + [ + (10, 10, 1, 1, False), + (10, 10, 9, 9, False), + (10, 10, 10, 1, True), + (10, 10, 1, 10, True), + (10, 10, 10, 10, True), + (1, 1, 10, 10, True), + (1, 1, 0, 0, False), + ], + ) + @pytest.mark.asyncio + async def test_append( + self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush + ): + """test appending different mutations, and checking if it causes a flush""" + async with self._make_one( + flush_limit_mutation_count=flush_count, flush_limit_bytes=flush_bytes + ) as instance: + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert instance._staged_entries == [] + mutation = _make_mutation(count=mutation_count, size=mutation_bytes) + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + await instance.append(mutation) + assert flush_mock.call_count == bool(expect_flush) + assert instance._staged_count == mutation_count + assert instance._staged_bytes == mutation_bytes + assert instance._staged_entries == [mutation] + instance._staged_entries = [] + + @pytest.mark.asyncio + async def test_append_multiple_sequentially(self): + """Append multiple mutations""" + async with self._make_one( + flush_limit_mutation_count=8, flush_limit_bytes=8 + 
) as instance: + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert instance._staged_entries == [] + mutation = _make_mutation(count=2, size=3) + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + await instance.append(mutation) + assert flush_mock.call_count == 0 + assert instance._staged_count == 2 + assert instance._staged_bytes == 3 + assert len(instance._staged_entries) == 1 + await instance.append(mutation) + assert flush_mock.call_count == 0 + assert instance._staged_count == 4 + assert instance._staged_bytes == 6 + assert len(instance._staged_entries) == 2 + await instance.append(mutation) + assert flush_mock.call_count == 1 + assert instance._staged_count == 6 + assert instance._staged_bytes == 9 + assert len(instance._staged_entries) == 3 + instance._staged_entries = [] + + @pytest.mark.asyncio + async def test_flush_flow_control_concurrent_requests(self): + """ + requests should happen in parallel if flow control breaks up single flush into batches + """ + import time + + num_calls = 10 + fake_mutations = [_make_mutation(count=1) for _ in range(num_calls)] + async with self._make_one(flow_control_max_mutation_count=1) as instance: + with mock.patch.object( + instance, "_execute_mutate_rows", AsyncMock() + ) as op_mock: + # mock network calls + async def mock_call(*args, **kwargs): + await asyncio.sleep(0.1) + return [] + + op_mock.side_effect = mock_call + start_time = time.monotonic() + # flush one large batch, that will be broken up into smaller batches + instance._staged_entries = fake_mutations + instance._schedule_flush() + await asyncio.sleep(0.01) + # make room for new mutations + for i in range(num_calls): + await instance._flow_control.remove_from_flow( + [_make_mutation(count=1)] + ) + await asyncio.sleep(0.01) + # allow flushes to complete + await asyncio.gather(*instance._flush_jobs) + duration = time.monotonic() - start_time + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + # if flushes were sequential, total duration would be 1s + assert duration < 0.5 + assert op_mock.call_count == num_calls + + @pytest.mark.asyncio + async def test_schedule_flush_no_mutations(self): + """schedule flush should return None if no staged mutations""" + async with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal") as flush_mock: + for i in range(3): + assert instance._schedule_flush() is None + assert flush_mock.call_count == 0 + + @pytest.mark.asyncio + async def test_schedule_flush_with_mutations(self): + """if new mutations exist, should add a new flush task to _flush_jobs""" + async with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal") as flush_mock: + for i in range(1, 4): + mutation = mock.Mock() + instance._staged_entries = [mutation] + instance._schedule_flush() + assert instance._staged_entries == [] + # let flush task run + await asyncio.sleep(0) + assert instance._staged_entries == [] + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert flush_mock.call_count == i + + @pytest.mark.asyncio + async def test__flush_internal(self): + """ + _flush_internal should: + - await previous flush call + - delegate batching to _flow_control + - call _execute_mutate_rows on each batch + - update self.exceptions and self._entries_processed_since_last_raise + """ + num_entries = 10 + async with self._make_one() as instance: + with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + 
with mock.patch.object( + instance._flow_control, "add_to_flow" + ) as flow_mock: + # mock flow control to always return a single batch + async def gen(x): + yield x + + flow_mock.side_effect = lambda x: gen(x) + mutations = [_make_mutation(count=1, size=1)] * num_entries + await instance._flush_internal(mutations) + assert instance._entries_processed_since_last_raise == num_entries + assert execute_mock.call_count == 1 + assert flow_mock.call_count == 1 + instance._oldest_exceptions.clear() + instance._newest_exceptions.clear() + + @pytest.mark.asyncio + async def test_flush_clears_job_list(self): + """ + a job should be added to _flush_jobs when _schedule_flush is called, + and removed when it completes + """ + async with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal", AsyncMock()): + mutations = [_make_mutation(count=1, size=1)] + instance._staged_entries = mutations + assert instance._flush_jobs == set() + new_job = instance._schedule_flush() + assert instance._flush_jobs == {new_job} + await new_job + assert instance._flush_jobs == set() + + @pytest.mark.parametrize( + "num_starting,num_new_errors,expected_total_errors", + [ + (0, 0, 0), + (0, 1, 1), + (0, 2, 2), + (1, 0, 1), + (1, 1, 2), + (10, 2, 12), + (10, 20, 20), # should cap at 20 + ], + ) + @pytest.mark.asyncio + async def test__flush_internal_with_errors( + self, num_starting, num_new_errors, expected_total_errors + ): + """ + errors returned from _execute_mutate_rows should be added to internal exceptions + """ + from google.cloud.bigtable.data import exceptions + + num_entries = 10 + expected_errors = [ + exceptions.FailedMutationEntryError(mock.Mock(), mock.Mock(), ValueError()) + ] * num_new_errors + async with self._make_one() as instance: + instance._oldest_exceptions = [mock.Mock()] * num_starting + with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + execute_mock.return_value = expected_errors + with mock.patch.object( + instance._flow_control, "add_to_flow" + ) as flow_mock: + # mock flow control to always return a single batch + async def gen(x): + yield x + + flow_mock.side_effect = lambda x: gen(x) + mutations = [_make_mutation(count=1, size=1)] * num_entries + await instance._flush_internal(mutations) + assert instance._entries_processed_since_last_raise == num_entries + assert execute_mock.call_count == 1 + assert flow_mock.call_count == 1 + found_exceptions = instance._oldest_exceptions + list( + instance._newest_exceptions + ) + assert len(found_exceptions) == expected_total_errors + for i in range(num_starting, expected_total_errors): + assert found_exceptions[i] == expected_errors[i - num_starting] + # errors should have index stripped + assert found_exceptions[i].index is None + # clear out exceptions + instance._oldest_exceptions.clear() + instance._newest_exceptions.clear() + + async def _mock_gapic_return(self, num=5): + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc import status_pb2 + + async def gen(num): + for i in range(num): + entry = MutateRowsResponse.Entry( + index=i, status=status_pb2.Status(code=0) + ) + yield MutateRowsResponse(entries=[entry]) + + return gen(num) + + @pytest.mark.asyncio + async def test_timer_flush_end_to_end(self): + """Flush should automatically trigger after flush_interval""" + num_nutations = 10 + mutations = [_make_mutation(count=2, size=2)] * num_nutations + + async with self._make_one(flush_interval=0.05) as instance: + instance._table.default_operation_timeout = 10 + 
instance._table.default_attempt_timeout = 9 + with mock.patch.object( + instance._table.client._gapic_client, "mutate_rows" + ) as gapic_mock: + gapic_mock.side_effect = ( + lambda *args, **kwargs: self._mock_gapic_return(num_nutations) + ) + for m in mutations: + await instance.append(m) + assert instance._entries_processed_since_last_raise == 0 + # let flush trigger due to timer + await asyncio.sleep(0.1) + assert instance._entries_processed_since_last_raise == num_nutations + + @pytest.mark.asyncio + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync", + ) + async def test__execute_mutate_rows(self, mutate_rows): + mutate_rows.return_value = AsyncMock() + start_operation = mutate_rows().start + table = mock.Mock() + table.table_name = "test-table" + table.app_profile_id = "test-app-profile" + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + async with self._make_one(table) as instance: + batch = [_make_mutation()] + result = await instance._execute_mutate_rows(batch) + assert start_operation.call_count == 1 + args, kwargs = mutate_rows.call_args + assert args[0] == table.client._gapic_client + assert args[1] == table + assert args[2] == batch + assert kwargs["operation_timeout"] == 17 + assert kwargs["attempt_timeout"] == 13 + assert result == [] + + @pytest.mark.asyncio + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync.start" + ) + async def test__execute_mutate_rows_returns_errors(self, mutate_rows): + """Errors from operation should be returned as a list""" + from google.cloud.bigtable.data.exceptions import ( + MutationsExceptionGroup, + FailedMutationEntryError, + ) + + err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) + err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) + mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + async with self._make_one(table) as instance: + batch = [_make_mutation()] + result = await instance._execute_mutate_rows(batch) + assert len(result) == 2 + assert result[0] == err1 + assert result[1] == err2 + # indices should be set to None + assert result[0].index is None + assert result[1].index is None + + @pytest.mark.asyncio + async def test__raise_exceptions(self): + """Raise exceptions and reset error state""" + from google.cloud.bigtable.data import exceptions + + expected_total = 1201 + expected_exceptions = [RuntimeError("mock")] * 3 + async with self._make_one() as instance: + instance._oldest_exceptions = expected_exceptions + instance._entries_processed_since_last_raise = expected_total + try: + instance._raise_exceptions() + except exceptions.MutationsExceptionGroup as exc: + assert list(exc.exceptions) == expected_exceptions + assert str(expected_total) in str(exc) + assert instance._entries_processed_since_last_raise == 0 + instance._oldest_exceptions, instance._newest_exceptions = ([], []) + # try calling again + instance._raise_exceptions() + + @pytest.mark.asyncio + async def test___aenter__(self): + """Should return self""" + async with self._make_one() as instance: + assert await instance.__aenter__() == instance + + @pytest.mark.asyncio + async def test___aexit__(self): + """aexit should call close""" + async with self._make_one() 
as instance: + with mock.patch.object(instance, "close") as close_mock: + await instance.__aexit__(None, None, None) + assert close_mock.call_count == 1 + + @pytest.mark.asyncio + async def test_close(self): + """Should clean up all resources""" + async with self._make_one() as instance: + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + with mock.patch.object(instance, "_raise_exceptions") as raise_mock: + await instance.close() + assert instance.closed is True + assert instance._flush_timer.done() is True + assert instance._flush_jobs == set() + assert flush_mock.call_count == 1 + assert raise_mock.call_count == 1 + + @pytest.mark.asyncio + async def test_close_w_exceptions(self): + """Raise exceptions on close""" + from google.cloud.bigtable.data import exceptions + + expected_total = 10 + expected_exceptions = [RuntimeError("mock")] + async with self._make_one() as instance: + instance._oldest_exceptions = expected_exceptions + instance._entries_processed_since_last_raise = expected_total + try: + await instance.close() + except exceptions.MutationsExceptionGroup as exc: + assert list(exc.exceptions) == expected_exceptions + assert str(expected_total) in str(exc) + assert instance._entries_processed_since_last_raise == 0 + # clear out exceptions + instance._oldest_exceptions, instance._newest_exceptions = ([], []) + + @pytest.mark.asyncio + async def test__on_exit(self, recwarn): + """Should raise warnings if unflushed mutations exist""" + async with self._make_one() as instance: + # calling without mutations is noop + instance._on_exit() + assert len(recwarn) == 0 + # calling with existing mutations should raise warning + num_left = 4 + instance._staged_entries = [mock.Mock()] * num_left + with pytest.warns(UserWarning) as w: + instance._on_exit() + assert len(w) == 1 + assert "unflushed mutations" in str(w[0].message).lower() + assert str(num_left) in str(w[0].message) + # calling while closed is noop + instance.closed = True + instance._on_exit() + assert len(recwarn) == 0 + # reset staged mutations for cleanup + instance._staged_entries = [] + + @pytest.mark.asyncio + async def test_atexit_registration(self): + """Should run _on_exit on program termination""" + import atexit + + with mock.patch.object(atexit, "register") as register_mock: + assert register_mock.call_count == 0 + async with self._make_one(): + assert register_mock.call_count == 1 + + @pytest.mark.asyncio + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync", + ) + async def test_timeout_args_passed(self, mutate_rows): + """ + batch_operation_timeout and batch_attempt_timeout should be used + in api calls + """ + mutate_rows.return_value = AsyncMock() + expected_operation_timeout = 17 + expected_attempt_timeout = 13 + async with self._make_one( + batch_operation_timeout=expected_operation_timeout, + batch_attempt_timeout=expected_attempt_timeout, + ) as instance: + assert instance._operation_timeout == expected_operation_timeout + assert instance._attempt_timeout == expected_attempt_timeout + # make simulated gapic call + await instance._execute_mutate_rows([_make_mutation()]) + assert mutate_rows.call_count == 1 + kwargs = mutate_rows.call_args[1] + assert kwargs["operation_timeout"] == expected_operation_timeout + assert kwargs["attempt_timeout"] == expected_attempt_timeout + + @pytest.mark.parametrize( + "limit,in_e,start_e,end_e", + [ + (10, 0, (10, 0), (10, 0)), + (1, 10, (0, 0), (1, 1)), + (10, 1, (0, 0), (1, 0)), + (10, 10, (0, 0), (10, 0)), + (10, 
11, (0, 0), (10, 1)), + (3, 20, (0, 0), (3, 3)), + (10, 20, (0, 0), (10, 10)), + (10, 21, (0, 0), (10, 10)), + (2, 1, (2, 0), (2, 1)), + (2, 1, (1, 0), (2, 0)), + (2, 2, (1, 0), (2, 1)), + (3, 1, (3, 1), (3, 2)), + (3, 3, (3, 1), (3, 3)), + (1000, 5, (999, 0), (1000, 4)), + (1000, 5, (0, 0), (5, 0)), + (1000, 5, (1000, 0), (1000, 5)), + ], + ) + def test__add_exceptions(self, limit, in_e, start_e, end_e): + """ + Test that the _add_exceptions function properly updates the + _oldest_exceptions and _newest_exceptions lists + Args: + - limit: the _exception_list_limit representing the max size of either list + - in_e: size of list of exceptions to send to _add_exceptions + - start_e: a tuple of ints representing the initial sizes of _oldest_exceptions and _newest_exceptions + - end_e: a tuple of ints representing the expected sizes of _oldest_exceptions and _newest_exceptions + """ + from collections import deque + + input_list = [RuntimeError(f"mock {i}") for i in range(in_e)] + mock_batcher = mock.Mock() + mock_batcher._oldest_exceptions = [ + RuntimeError(f"starting mock {i}") for i in range(start_e[0]) + ] + mock_batcher._newest_exceptions = deque( + [RuntimeError(f"starting mock {i}") for i in range(start_e[1])], + maxlen=limit, + ) + mock_batcher._exception_list_limit = limit + mock_batcher._exceptions_since_last_raise = 0 + self._get_target_class()._add_exceptions(mock_batcher, input_list) + assert len(mock_batcher._oldest_exceptions) == end_e[0] + assert len(mock_batcher._newest_exceptions) == end_e[1] + assert mock_batcher._exceptions_since_last_raise == in_e + # make sure that the right items ended up in the right spots + # should fill the oldest slots first + oldest_list_diff = end_e[0] - start_e[0] + # new items should be added on top of the starting list + newest_list_diff = min(max(in_e - oldest_list_diff, 0), limit) + for i in range(oldest_list_diff): + assert mock_batcher._oldest_exceptions[i + start_e[0]] == input_list[i] + # then, the newest slots should be filled with the last items of the input list + for i in range(1, newest_list_diff + 1): + assert mock_batcher._newest_exceptions[-i] == input_list[-i] + + @pytest.mark.asyncio + # test different inputs for retryable exceptions + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + async def test_customizable_retryable_errors( + self, input_retryables, expected_retryables + ): + """ + Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer. 
+ """ + from google.cloud.bigtable.data._async.client import TableAsync + + with mock.patch( + "google.api_core.retry.if_exception_type" + ) as predicate_builder_mock: + with mock.patch( + "google.api_core.retry.retry_target_async" + ) as retry_fn_mock: + table = None + with mock.patch("asyncio.create_task"): + table = TableAsync(mock.Mock(), "instance", "table") + async with self._make_one( + table, batch_retryable_errors=input_retryables + ) as instance: + assert instance._retryable_errors == expected_retryables + expected_predicate = lambda a: a in expected_retryables # noqa + predicate_builder_mock.return_value = expected_predicate + retry_fn_mock.side_effect = RuntimeError("stop early") + mutation = _make_mutation(count=1, size=1) + await instance._execute_mutate_rows([mutation]) + # passed in errors should be used to build the predicate + predicate_builder_mock.assert_called_once_with( + *expected_retryables, _MutateRowsIncomplete + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + # output of if_exception_type should be sent in to retry constructor + assert retry_call_args[1] is expected_predicate diff --git a/tests/unit/read-rows-acceptance-test.json b/tests/unit/data/read-rows-acceptance-test.json similarity index 100% rename from tests/unit/read-rows-acceptance-test.json rename to tests/unit/data/read-rows-acceptance-test.json diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py new file mode 100644 index 000000000..5a9c500ed --- /dev/null +++ b/tests/unit/data/test__helpers.py @@ -0,0 +1,248 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +import grpc +from google.api_core import exceptions as core_exceptions +import google.cloud.bigtable.data._helpers as _helpers +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT + +import mock + + +class TestMakeMetadata: + @pytest.mark.parametrize( + "table,profile,expected", + [ + ("table", "profile", "table_name=table&app_profile_id=profile"), + ("table", None, "table_name=table"), + ], + ) + def test__make_metadata(self, table, profile, expected): + metadata = _helpers._make_metadata(table, profile) + assert metadata == [("x-goog-request-params", expected)] + + +class TestAttemptTimeoutGenerator: + @pytest.mark.parametrize( + "request_t,operation_t,expected_list", + [ + (1, 3.5, [1, 1, 1, 0.5, 0, 0]), + (None, 3.5, [3.5, 2.5, 1.5, 0.5, 0, 0]), + (10, 5, [5, 4, 3, 2, 1, 0, 0]), + (3, 3, [3, 2, 1, 0, 0, 0, 0]), + (0, 3, [0, 0, 0]), + (3, 0, [0, 0, 0]), + (-1, 3, [0, 0, 0]), + (3, -1, [0, 0, 0]), + ], + ) + def test_attempt_timeout_generator(self, request_t, operation_t, expected_list): + """ + test different values for timeouts. 
Clock is incremented by 1 second for each item in expected_list + """ + timestamp_start = 123 + with mock.patch("time.monotonic") as mock_monotonic: + mock_monotonic.return_value = timestamp_start + generator = _helpers._attempt_timeout_generator(request_t, operation_t) + for val in expected_list: + mock_monotonic.return_value += 1 + assert next(generator) == val + + @pytest.mark.parametrize( + "request_t,operation_t,expected", + [ + (1, 3.5, 1), + (None, 3.5, 3.5), + (10, 5, 5), + (5, 10, 5), + (3, 3, 3), + (0, 3, 0), + (3, 0, 0), + (-1, 3, 0), + (3, -1, 0), + ], + ) + def test_attempt_timeout_frozen_time(self, request_t, operation_t, expected): + """test with time.monotonic frozen""" + timestamp_start = 123 + with mock.patch("time.monotonic") as mock_monotonic: + mock_monotonic.return_value = timestamp_start + generator = _helpers._attempt_timeout_generator(request_t, operation_t) + assert next(generator) == expected + # value should not change without time.monotonic changing + assert next(generator) == expected + + def test_attempt_timeout_w_sleeps(self): + """use real sleep values to make sure it matches expectations""" + from time import sleep + + operation_timeout = 1 + generator = _helpers._attempt_timeout_generator(None, operation_timeout) + expected_value = operation_timeout + sleep_time = 0.1 + for i in range(3): + found_value = next(generator) + assert abs(found_value - expected_value) < 0.001 + sleep(sleep_time) + expected_value -= sleep_time + + +class TestValidateTimeouts: + def test_validate_timeouts_error_messages(self): + with pytest.raises(ValueError) as e: + _helpers._validate_timeouts(operation_timeout=1, attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + _helpers._validate_timeouts(operation_timeout=-1, attempt_timeout=1) + assert "operation_timeout must be greater than 0" in str(e.value) + + @pytest.mark.parametrize( + "args,expected", + [ + ([1, None, False], False), + ([1, None, True], True), + ([1, 1, False], True), + ([1, 1, True], True), + ([1, 1], True), + ([1, None], False), + ([2, 1], True), + ([0, 1], False), + ([1, 0], False), + ([60, None], False), + ([600, None], False), + ([600, 600], True), + ], + ) + def test_validate_with_inputs(self, args, expected): + """ + test whether an exception is thrown with different inputs + """ + success = False + try: + _helpers._validate_timeouts(*args) + success = True + except ValueError: + pass + assert success == expected + + +class TestGetTimeouts: + @pytest.mark.parametrize( + "input_times,input_table,expected", + [ + ((2, 1), {}, (2, 1)), + ((2, 4), {}, (2, 2)), + ((2, None), {}, (2, 2)), + ( + (TABLE_DEFAULT.DEFAULT, TABLE_DEFAULT.DEFAULT), + {"operation": 3, "attempt": 2}, + (3, 2), + ), + ( + (TABLE_DEFAULT.READ_ROWS, TABLE_DEFAULT.READ_ROWS), + {"read_rows_operation": 3, "read_rows_attempt": 2}, + (3, 2), + ), + ( + (TABLE_DEFAULT.MUTATE_ROWS, TABLE_DEFAULT.MUTATE_ROWS), + {"mutate_rows_operation": 3, "mutate_rows_attempt": 2}, + (3, 2), + ), + ((10, TABLE_DEFAULT.DEFAULT), {"attempt": None}, (10, 10)), + ((10, TABLE_DEFAULT.DEFAULT), {"attempt": 5}, (10, 5)), + ((10, TABLE_DEFAULT.DEFAULT), {"attempt": 100}, (10, 10)), + ((TABLE_DEFAULT.DEFAULT, 10), {"operation": 12}, (12, 10)), + ((TABLE_DEFAULT.DEFAULT, 10), {"operation": 3}, (3, 3)), + ], + ) + def test_get_timeouts(self, input_times, input_table, expected): + """ + test input/output mappings for a variety of valid inputs + """ + fake_table = mock.Mock() + for key in 
input_table.keys(): + # set the default fields in our fake table mock + setattr(fake_table, f"default_{key}_timeout", input_table[key]) + t1, t2 = _helpers._get_timeouts(input_times[0], input_times[1], fake_table) + assert t1 == expected[0] + assert t2 == expected[1] + + @pytest.mark.parametrize( + "input_times,input_table", + [ + ([0, 1], {}), + ([1, 0], {}), + ([None, 1], {}), + ([TABLE_DEFAULT.DEFAULT, 1], {"operation": None}), + ([TABLE_DEFAULT.DEFAULT, 1], {"operation": 0}), + ([1, TABLE_DEFAULT.DEFAULT], {"attempt": 0}), + ], + ) + def test_get_timeouts_invalid(self, input_times, input_table): + """ + test with inputs that should raise error during validation step + """ + fake_table = mock.Mock() + for key in input_table.keys(): + # set the default fields in our fake table mock + setattr(fake_table, f"default_{key}_timeout", input_table[key]) + with pytest.raises(ValueError): + _helpers._get_timeouts(input_times[0], input_times[1], fake_table) + + +class TestGetRetryableErrors: + @pytest.mark.parametrize( + "input_codes,input_table,expected", + [ + ((), {}, []), + ((Exception,), {}, [Exception]), + (TABLE_DEFAULT.DEFAULT, {"default": [Exception]}, [Exception]), + ( + TABLE_DEFAULT.READ_ROWS, + {"default_read_rows": (RuntimeError, ValueError)}, + [RuntimeError, ValueError], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + {"default_mutate_rows": (ValueError,)}, + [ValueError], + ), + ((4,), {}, [core_exceptions.DeadlineExceeded]), + ( + [grpc.StatusCode.DEADLINE_EXCEEDED], + {}, + [core_exceptions.DeadlineExceeded], + ), + ( + (14, grpc.StatusCode.ABORTED, RuntimeError), + {}, + [ + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + RuntimeError, + ], + ), + ], + ) + def test_get_retryable_errors(self, input_codes, input_table, expected): + """ + test input/output mappings for a variety of valid inputs + """ + fake_table = mock.Mock() + for key in input_table.keys(): + # set the default fields in our fake table mock + setattr(fake_table, f"{key}_retryable_errors", input_table[key]) + result = _helpers._get_retryable_errors(input_codes, fake_table) + assert result == expected diff --git a/tests/unit/data/test_exceptions.py b/tests/unit/data/test_exceptions.py new file mode 100644 index 000000000..bc921717e --- /dev/null +++ b/tests/unit/data/test_exceptions.py @@ -0,0 +1,533 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import pytest +import sys + +import google.cloud.bigtable.data.exceptions as bigtable_exceptions + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock +except ImportError: # pragma: NO COVER + import mock # type: ignore + + +class TracebackTests311: + """ + Provides a set of tests that should be run on python 3.11 and above, + to verify that the exception traceback looks as expected + """ + + @pytest.mark.skipif( + sys.version_info < (3, 11), reason="requires python3.11 or higher" + ) + def test_311_traceback(self): + """ + Exception customizations should not break rich exception group traceback in python 3.11 + """ + import traceback + + sub_exc1 = RuntimeError("first sub exception") + sub_exc2 = ZeroDivisionError("second sub exception") + sub_group = self._make_one(excs=[sub_exc2]) + exc_group = self._make_one(excs=[sub_exc1, sub_group]) + + expected_traceback = ( + f" | google.cloud.bigtable.data.exceptions.{type(exc_group).__name__}: {str(exc_group)}", + " +-+---------------- 1 ----------------", + " | RuntimeError: first sub exception", + " +---------------- 2 ----------------", + f" | google.cloud.bigtable.data.exceptions.{type(sub_group).__name__}: {str(sub_group)}", + " +-+---------------- 1 ----------------", + " | ZeroDivisionError: second sub exception", + " +------------------------------------", + ) + exception_caught = False + try: + raise exc_group + except self._get_class(): + exception_caught = True + tb = traceback.format_exc() + tb_relevant_lines = tuple(tb.splitlines()[3:]) + assert expected_traceback == tb_relevant_lines + assert exception_caught + + @pytest.mark.skipif( + sys.version_info < (3, 11), reason="requires python3.11 or higher" + ) + def test_311_traceback_with_cause(self): + """ + traceback should display nicely with sub-exceptions with __cause__ set + """ + import traceback + + sub_exc1 = RuntimeError("first sub exception") + cause_exc = ImportError("cause exception") + sub_exc1.__cause__ = cause_exc + sub_exc2 = ZeroDivisionError("second sub exception") + exc_group = self._make_one(excs=[sub_exc1, sub_exc2]) + + expected_traceback = ( + f" | google.cloud.bigtable.data.exceptions.{type(exc_group).__name__}: {str(exc_group)}", + " +-+---------------- 1 ----------------", + " | ImportError: cause exception", + " | ", + " | The above exception was the direct cause of the following exception:", + " | ", + " | RuntimeError: first sub exception", + " +---------------- 2 ----------------", + " | ZeroDivisionError: second sub exception", + " +------------------------------------", + ) + exception_caught = False + try: + raise exc_group + except self._get_class(): + exception_caught = True + tb = traceback.format_exc() + tb_relevant_lines = tuple(tb.splitlines()[3:]) + assert expected_traceback == tb_relevant_lines + assert exception_caught + + @pytest.mark.skipif( + sys.version_info < (3, 11), reason="requires python3.11 or higher" + ) + def test_311_exception_group(self): + """ + Python 3.11+ should handle exepctions as native exception groups + """ + exceptions = [RuntimeError("mock"), ValueError("mock")] + instance = self._make_one(excs=exceptions) + # ensure split works as expected + runtime_error, others = instance.split(lambda e: isinstance(e, RuntimeError)) + assert runtime_error.exceptions[0] == exceptions[0] + assert others.exceptions[0] == exceptions[1] + + +class TracebackTests310: + """ + Provides a set of tests that should be run on python 3.10 and under, + to verify that the exception 
traceback looks as expected + """ + + @pytest.mark.skipif( + sys.version_info >= (3, 11), reason="requires python3.10 or lower" + ) + def test_310_traceback(self): + """ + Exception customizations should not break rich exception group traceback in python 3.10 + """ + import traceback + + sub_exc1 = RuntimeError("first sub exception") + sub_exc2 = ZeroDivisionError("second sub exception") + sub_group = self._make_one(excs=[sub_exc2]) + exc_group = self._make_one(excs=[sub_exc1, sub_group]) + found_message = str(exc_group).splitlines()[0] + found_sub_message = str(sub_group).splitlines()[0] + + expected_traceback = ( + f"google.cloud.bigtable.data.exceptions.{type(exc_group).__name__}: {found_message}", + "--+---------------- 1 ----------------", + " | RuntimeError: first sub exception", + " +---------------- 2 ----------------", + f" | {type(sub_group).__name__}: {found_sub_message}", + " --+---------------- 1 ----------------", + " | ZeroDivisionError: second sub exception", + " +------------------------------------", + ) + exception_caught = False + try: + raise exc_group + except self._get_class(): + exception_caught = True + tb = traceback.format_exc() + tb_relevant_lines = tuple(tb.splitlines()[3:]) + assert expected_traceback == tb_relevant_lines + assert exception_caught + + @pytest.mark.skipif( + sys.version_info >= (3, 11), reason="requires python3.10 or lower" + ) + def test_310_traceback_with_cause(self): + """ + traceback should display nicely with sub-exceptions with __cause__ set + """ + import traceback + + sub_exc1 = RuntimeError("first sub exception") + cause_exc = ImportError("cause exception") + sub_exc1.__cause__ = cause_exc + sub_exc2 = ZeroDivisionError("second sub exception") + exc_group = self._make_one(excs=[sub_exc1, sub_exc2]) + found_message = str(exc_group).splitlines()[0] + + expected_traceback = ( + f"google.cloud.bigtable.data.exceptions.{type(exc_group).__name__}: {found_message}", + "--+---------------- 1 ----------------", + " | ImportError: cause exception", + " | ", + " | The above exception was the direct cause of the following exception:", + " | ", + " | RuntimeError: first sub exception", + " +---------------- 2 ----------------", + " | ZeroDivisionError: second sub exception", + " +------------------------------------", + ) + exception_caught = False + try: + raise exc_group + except self._get_class(): + exception_caught = True + tb = traceback.format_exc() + tb_relevant_lines = tuple(tb.splitlines()[3:]) + assert expected_traceback == tb_relevant_lines + assert exception_caught + + +class TestBigtableExceptionGroup(TracebackTests311, TracebackTests310): + """ + Subclass for MutationsExceptionGroup, RetryExceptionGroup, and ShardedReadRowsExceptionGroup + """ + + def _get_class(self): + from google.cloud.bigtable.data.exceptions import _BigtableExceptionGroup + + return _BigtableExceptionGroup + + def _make_one(self, message="test_message", excs=None): + if excs is None: + excs = [RuntimeError("mock")] + + return self._get_class()(message, excs=excs) + + def test_raise(self): + """ + Create exception in raise statement, which calls __new__ and __init__ + """ + test_msg = "test message" + test_excs = [Exception(test_msg)] + with pytest.raises(self._get_class()) as e: + raise self._get_class()(test_msg, test_excs) + found_message = str(e.value).splitlines()[ + 0 + ] # added to prase out subexceptions in <3.11 + assert found_message == test_msg + assert list(e.value.exceptions) == test_excs + + def test_raise_empty_list(self): + """ + Empty exception 
lists are not supported + """ + with pytest.raises(ValueError) as e: + raise self._make_one(excs=[]) + assert "non-empty sequence" in str(e.value) + + def test_exception_handling(self): + """ + All versions should inherit from exception + and support traditional exception handling + """ + instance = self._make_one() + assert isinstance(instance, Exception) + try: + raise instance + except Exception as e: + assert isinstance(e, Exception) + assert e == instance + was_raised = True + assert was_raised + + +class TestMutationsExceptionGroup(TestBigtableExceptionGroup): + def _get_class(self): + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + + return MutationsExceptionGroup + + def _make_one(self, excs=None, num_entries=3): + if excs is None: + excs = [RuntimeError("mock")] + + return self._get_class()(excs, num_entries) + + @pytest.mark.parametrize( + "exception_list,total_entries,expected_message", + [ + ([Exception()], 1, "1 failed entry from 1 attempted."), + ([Exception()], 2, "1 failed entry from 2 attempted."), + ( + [Exception(), RuntimeError()], + 2, + "2 failed entries from 2 attempted.", + ), + ], + ) + def test_raise(self, exception_list, total_entries, expected_message): + """ + Create exception in raise statement, which calls __new__ and __init__ + """ + with pytest.raises(self._get_class()) as e: + raise self._get_class()(exception_list, total_entries) + found_message = str(e.value).splitlines()[ + 0 + ] # added to parse out subexceptions in <3.11 + assert found_message == expected_message + assert list(e.value.exceptions) == exception_list + + def test_raise_custom_message(self): + """ + should be able to set a custom error message + """ + custom_message = "custom message" + exception_list = [Exception()] + with pytest.raises(self._get_class()) as e: + raise self._get_class()(exception_list, 5, message=custom_message) + found_message = str(e.value).splitlines()[ + 0 + ] # added to parse out subexceptions in <3.11 + assert found_message == custom_message + assert list(e.value.exceptions) == exception_list + + @pytest.mark.parametrize( + "first_list_len,second_list_len,total_excs,entry_count,expected_message", + [ + (3, 0, 3, 4, "3 failed entries from 4 attempted."), + (1, 0, 1, 2, "1 failed entry from 2 attempted."), + (0, 1, 1, 2, "1 failed entry from 2 attempted."), + (2, 2, 4, 4, "4 failed entries from 4 attempted."), + ( + 1, + 1, + 3, + 2, + "3 failed entries from 2 attempted. (first 1 and last 1 attached as sub-exceptions; 1 truncated)", + ), + ( + 1, + 2, + 100, + 2, + "100 failed entries from 2 attempted. (first 1 and last 2 attached as sub-exceptions; 97 truncated)", + ), + ( + 2, + 1, + 4, + 9, + "4 failed entries from 9 attempted. (first 2 and last 1 attached as sub-exceptions; 1 truncated)", + ), + ( + 3, + 0, + 10, + 10, + "10 failed entries from 10 attempted. (first 3 attached as sub-exceptions; 7 truncated)", + ), + ( + 0, + 3, + 10, + 10, + "10 failed entries from 10 attempted. 
(last 3 attached as sub-exceptions; 7 truncated)", + ), + ], + ) + def test_from_truncated_lists( + self, first_list_len, second_list_len, total_excs, entry_count, expected_message + ): + """ + Should be able to make MutationsExceptionGroup using a pair of + lists representing a larger truncated list of exceptions + """ + first_list = [Exception()] * first_list_len + second_list = [Exception()] * second_list_len + with pytest.raises(self._get_class()) as e: + raise self._get_class().from_truncated_lists( + first_list, second_list, total_excs, entry_count + ) + found_message = str(e.value).splitlines()[ + 0 + ] # added to prase out subexceptions in <3.11 + assert found_message == expected_message + assert list(e.value.exceptions) == first_list + second_list + + +class TestRetryExceptionGroup(TestBigtableExceptionGroup): + def _get_class(self): + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + return RetryExceptionGroup + + def _make_one(self, excs=None): + if excs is None: + excs = [RuntimeError("mock")] + + return self._get_class()(excs=excs) + + @pytest.mark.parametrize( + "exception_list,expected_message", + [ + ([Exception()], "1 failed attempt"), + ([Exception(), RuntimeError()], "2 failed attempts"), + ( + [Exception(), ValueError("test")], + "2 failed attempts", + ), + ( + [ + bigtable_exceptions.RetryExceptionGroup( + [Exception(), ValueError("test")] + ) + ], + "1 failed attempt", + ), + ], + ) + def test_raise(self, exception_list, expected_message): + """ + Create exception in raise statement, which calls __new__ and __init__ + """ + with pytest.raises(self._get_class()) as e: + raise self._get_class()(exception_list) + found_message = str(e.value).splitlines()[ + 0 + ] # added to prase out subexceptions in <3.11 + assert found_message == expected_message + assert list(e.value.exceptions) == exception_list + + +class TestShardedReadRowsExceptionGroup(TestBigtableExceptionGroup): + def _get_class(self): + from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + + return ShardedReadRowsExceptionGroup + + def _make_one(self, excs=None, succeeded=None, num_entries=3): + if excs is None: + excs = [RuntimeError("mock")] + succeeded = succeeded or [] + + return self._get_class()(excs, succeeded, num_entries) + + @pytest.mark.parametrize( + "exception_list,succeeded,total_entries,expected_message", + [ + ([Exception()], [], 1, "1 sub-exception (from 1 query attempted)"), + ([Exception()], [1], 2, "1 sub-exception (from 2 queries attempted)"), + ( + [Exception(), RuntimeError()], + [0, 1], + 2, + "2 sub-exceptions (from 2 queries attempted)", + ), + ], + ) + def test_raise(self, exception_list, succeeded, total_entries, expected_message): + """ + Create exception in raise statement, which calls __new__ and __init__ + """ + with pytest.raises(self._get_class()) as e: + raise self._get_class()(exception_list, succeeded, total_entries) + found_message = str(e.value).splitlines()[ + 0 + ] # added to prase out subexceptions in <3.11 + assert found_message == expected_message + assert list(e.value.exceptions) == exception_list + assert e.value.successful_rows == succeeded + + +class TestFailedMutationEntryError: + def _get_class(self): + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + return FailedMutationEntryError + + def _make_one(self, idx=9, entry=mock.Mock(), cause=RuntimeError("mock")): + return self._get_class()(idx, entry, cause) + + def test_raise(self): + """ + Create exception in raise statement, which 
calls __new__ and __init__ + """ + test_idx = 2 + test_entry = mock.Mock() + test_exc = ValueError("test") + with pytest.raises(self._get_class()) as e: + raise self._get_class()(test_idx, test_entry, test_exc) + assert str(e.value) == "Failed idempotent mutation entry at index 2" + assert e.value.index == test_idx + assert e.value.entry == test_entry + assert e.value.__cause__ == test_exc + assert isinstance(e.value, Exception) + assert test_entry.is_idempotent.call_count == 1 + + def test_raise_idempotent(self): + """ + Test raise with non idempotent entry + """ + test_idx = 2 + test_entry = unittest.mock.Mock() + test_entry.is_idempotent.return_value = False + test_exc = ValueError("test") + with pytest.raises(self._get_class()) as e: + raise self._get_class()(test_idx, test_entry, test_exc) + assert str(e.value) == "Failed non-idempotent mutation entry at index 2" + assert e.value.index == test_idx + assert e.value.entry == test_entry + assert e.value.__cause__ == test_exc + assert test_entry.is_idempotent.call_count == 1 + + def test_no_index(self): + """ + Instances without an index should display different error string + """ + test_idx = None + test_entry = unittest.mock.Mock() + test_exc = ValueError("test") + with pytest.raises(self._get_class()) as e: + raise self._get_class()(test_idx, test_entry, test_exc) + assert str(e.value) == "Failed idempotent mutation entry" + assert e.value.index == test_idx + assert e.value.entry == test_entry + assert e.value.__cause__ == test_exc + assert isinstance(e.value, Exception) + assert test_entry.is_idempotent.call_count == 1 + + +class TestFailedQueryShardError: + def _get_class(self): + from google.cloud.bigtable.data.exceptions import FailedQueryShardError + + return FailedQueryShardError + + def _make_one(self, idx=9, query=mock.Mock(), cause=RuntimeError("mock")): + return self._get_class()(idx, query, cause) + + def test_raise(self): + """ + Create exception in raise statement, which calls __new__ and __init__ + """ + test_idx = 2 + test_query = mock.Mock() + test_exc = ValueError("test") + with pytest.raises(self._get_class()) as e: + raise self._get_class()(test_idx, test_query, test_exc) + assert str(e.value) == "Failed query at index 2" + assert e.value.index == test_idx + assert e.value.query == test_query + assert e.value.__cause__ == test_exc + assert isinstance(e.value, Exception) diff --git a/tests/unit/data/test_mutations.py b/tests/unit/data/test_mutations.py new file mode 100644 index 000000000..485c86e42 --- /dev/null +++ b/tests/unit/data/test_mutations.py @@ -0,0 +1,708 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +import google.cloud.bigtable.data.mutations as mutations + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock +except ImportError: # pragma: NO COVER + import mock # type: ignore + + +class TestBaseMutation: + def _target_class(self): + from google.cloud.bigtable.data.mutations import Mutation + + return Mutation + + def test__to_dict(self): + """Should be unimplemented in the base class""" + with pytest.raises(NotImplementedError): + self._target_class()._to_dict(mock.Mock()) + + def test_is_idempotent(self): + """is_idempotent should assume True""" + assert self._target_class().is_idempotent(mock.Mock()) + + def test___str__(self): + """Str representation of mutations should be to_dict""" + self_mock = mock.Mock() + str_value = self._target_class().__str__(self_mock) + assert self_mock._to_dict.called + assert str_value == str(self_mock._to_dict.return_value) + + @pytest.mark.parametrize("test_dict", [{}, {"key": "value"}]) + def test_size(self, test_dict): + from sys import getsizeof + + """Size should return size of dict representation""" + self_mock = mock.Mock() + self_mock._to_dict.return_value = test_dict + size_value = self._target_class().size(self_mock) + assert size_value == getsizeof(test_dict) + + @pytest.mark.parametrize( + "expected_class,input_dict", + [ + ( + mutations.SetCell, + { + "set_cell": { + "family_name": "foo", + "column_qualifier": b"bar", + "value": b"test", + "timestamp_micros": 12345, + } + }, + ), + ( + mutations.DeleteRangeFromColumn, + { + "delete_from_column": { + "family_name": "foo", + "column_qualifier": b"bar", + "time_range": {}, + } + }, + ), + ( + mutations.DeleteRangeFromColumn, + { + "delete_from_column": { + "family_name": "foo", + "column_qualifier": b"bar", + "time_range": {"start_timestamp_micros": 123456789}, + } + }, + ), + ( + mutations.DeleteRangeFromColumn, + { + "delete_from_column": { + "family_name": "foo", + "column_qualifier": b"bar", + "time_range": {"end_timestamp_micros": 123456789}, + } + }, + ), + ( + mutations.DeleteRangeFromColumn, + { + "delete_from_column": { + "family_name": "foo", + "column_qualifier": b"bar", + "time_range": { + "start_timestamp_micros": 123, + "end_timestamp_micros": 123456789, + }, + } + }, + ), + ( + mutations.DeleteAllFromFamily, + {"delete_from_family": {"family_name": "foo"}}, + ), + (mutations.DeleteAllFromRow, {"delete_from_row": {}}), + ], + ) + def test__from_dict(self, expected_class, input_dict): + """Should be able to create instance from dict""" + instance = self._target_class()._from_dict(input_dict) + assert isinstance(instance, expected_class) + found_dict = instance._to_dict() + assert found_dict == input_dict + + @pytest.mark.parametrize( + "input_dict", + [ + {"set_cell": {}}, + { + "set_cell": { + "column_qualifier": b"bar", + "value": b"test", + "timestamp_micros": 12345, + } + }, + { + "set_cell": { + "family_name": "f", + "column_qualifier": b"bar", + "value": b"test", + } + }, + {"delete_from_family": {}}, + {"delete_from_column": {}}, + {"fake-type"}, + {}, + ], + ) + def test__from_dict_missing_fields(self, input_dict): + """If dict is malformed or fields are missing, should raise ValueError""" + with pytest.raises(ValueError): + self._target_class()._from_dict(input_dict) + + def test__from_dict_wrong_subclass(self): + """You shouldn't be able to instantiate one mutation type using the dict of another""" + subclasses = [ + mutations.SetCell("foo", b"bar", b"test"), + mutations.DeleteRangeFromColumn("foo", b"bar"), + 
mutations.DeleteAllFromFamily("foo"), + mutations.DeleteAllFromRow(), + ] + for instance in subclasses: + others = [other for other in subclasses if other != instance] + for other in others: + with pytest.raises(ValueError) as e: + type(other)._from_dict(instance._to_dict()) + assert "Mutation type mismatch" in str(e.value) + + +class TestSetCell: + def _target_class(self): + from google.cloud.bigtable.data.mutations import SetCell + + return SetCell + + def _make_one(self, *args, **kwargs): + return self._target_class()(*args, **kwargs) + + @pytest.mark.parametrize("input_val", [2**64, -(2**64)]) + def test_ctor_large_int(self, input_val): + with pytest.raises(ValueError) as e: + self._make_one(family="f", qualifier=b"b", new_value=input_val) + assert "int values must be between" in str(e.value) + + @pytest.mark.parametrize("input_val", ["", "a", "abc", "hello world!"]) + def test_ctor_str_value(self, input_val): + found = self._make_one(family="f", qualifier=b"b", new_value=input_val) + assert found.new_value == input_val.encode("utf-8") + + def test_ctor(self): + """Ensure constructor sets expected values""" + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + expected_value = b"test-value" + expected_timestamp = 1234567890 + instance = self._make_one( + expected_family, expected_qualifier, expected_value, expected_timestamp + ) + assert instance.family == expected_family + assert instance.qualifier == expected_qualifier + assert instance.new_value == expected_value + assert instance.timestamp_micros == expected_timestamp + + def test_ctor_str_inputs(self): + """Test with string qualifier and value""" + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + expected_value = b"test-value" + instance = self._make_one(expected_family, "test-qualifier", "test-value") + assert instance.family == expected_family + assert instance.qualifier == expected_qualifier + assert instance.new_value == expected_value + + @pytest.mark.parametrize("input_val", [-20, -1, 0, 1, 100, int(2**60)]) + def test_ctor_int_value(self, input_val): + found = self._make_one(family="f", qualifier=b"b", new_value=input_val) + assert found.new_value == input_val.to_bytes(8, "big", signed=True) + + @pytest.mark.parametrize( + "int_value,expected_bytes", + [ + (-42, b"\xff\xff\xff\xff\xff\xff\xff\xd6"), + (-2, b"\xff\xff\xff\xff\xff\xff\xff\xfe"), + (-1, b"\xff\xff\xff\xff\xff\xff\xff\xff"), + (0, b"\x00\x00\x00\x00\x00\x00\x00\x00"), + (1, b"\x00\x00\x00\x00\x00\x00\x00\x01"), + (2, b"\x00\x00\x00\x00\x00\x00\x00\x02"), + (100, b"\x00\x00\x00\x00\x00\x00\x00d"), + ], + ) + def test_ctor_int_value_bytes(self, int_value, expected_bytes): + """Test with int value""" + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + instance = self._make_one(expected_family, expected_qualifier, int_value) + assert instance.family == expected_family + assert instance.qualifier == expected_qualifier + assert instance.new_value == expected_bytes + + def test_ctor_negative_timestamp(self): + """Only positive or -1 timestamps are valid""" + with pytest.raises(ValueError) as e: + self._make_one("test-family", b"test-qualifier", b"test-value", -2) + assert ( + "timestamp_micros must be positive (or -1 for server-side timestamp)" + in str(e.value) + ) + + @pytest.mark.parametrize( + "timestamp_ns,expected_timestamp_micros", + [ + (0, 0), + (1, 0), + (123, 0), + (999, 0), + (999_999, 0), + (1_000_000, 1000), + (1_234_567, 1000), + (1_999_999, 1000), + (2_000_000, 2000), + 
(1_234_567_890_123, 1_234_567_000), + ], + ) + def test_ctor_no_timestamp(self, timestamp_ns, expected_timestamp_micros): + """If no timestamp is given, should use current time with millisecond precision""" + with mock.patch("time.time_ns", return_value=timestamp_ns): + instance = self._make_one("test-family", b"test-qualifier", b"test-value") + assert instance.timestamp_micros == expected_timestamp_micros + + def test__to_dict(self): + """ensure dict representation is as expected""" + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + expected_value = b"test-value" + expected_timestamp = 123456789 + instance = self._make_one( + expected_family, expected_qualifier, expected_value, expected_timestamp + ) + got_dict = instance._to_dict() + assert list(got_dict.keys()) == ["set_cell"] + got_inner_dict = got_dict["set_cell"] + assert got_inner_dict["family_name"] == expected_family + assert got_inner_dict["column_qualifier"] == expected_qualifier + assert got_inner_dict["timestamp_micros"] == expected_timestamp + assert got_inner_dict["value"] == expected_value + assert len(got_inner_dict.keys()) == 4 + + def test__to_dict_server_timestamp(self): + """test with server side timestamp -1 value""" + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + expected_value = b"test-value" + expected_timestamp = -1 + instance = self._make_one( + expected_family, expected_qualifier, expected_value, expected_timestamp + ) + got_dict = instance._to_dict() + assert list(got_dict.keys()) == ["set_cell"] + got_inner_dict = got_dict["set_cell"] + assert got_inner_dict["family_name"] == expected_family + assert got_inner_dict["column_qualifier"] == expected_qualifier + assert got_inner_dict["timestamp_micros"] == expected_timestamp + assert got_inner_dict["value"] == expected_value + assert len(got_inner_dict.keys()) == 4 + + def test__to_pb(self): + """ensure proto representation is as expected""" + import google.cloud.bigtable_v2.types.data as data_pb + + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + expected_value = b"test-value" + expected_timestamp = 123456789 + instance = self._make_one( + expected_family, expected_qualifier, expected_value, expected_timestamp + ) + got_pb = instance._to_pb() + assert isinstance(got_pb, data_pb.Mutation) + assert got_pb.set_cell.family_name == expected_family + assert got_pb.set_cell.column_qualifier == expected_qualifier + assert got_pb.set_cell.timestamp_micros == expected_timestamp + assert got_pb.set_cell.value == expected_value + + def test__to_pb_server_timestamp(self): + """test with server side timestamp -1 value""" + import google.cloud.bigtable_v2.types.data as data_pb + + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + expected_value = b"test-value" + expected_timestamp = -1 + instance = self._make_one( + expected_family, expected_qualifier, expected_value, expected_timestamp + ) + got_pb = instance._to_pb() + assert isinstance(got_pb, data_pb.Mutation) + assert got_pb.set_cell.family_name == expected_family + assert got_pb.set_cell.column_qualifier == expected_qualifier + assert got_pb.set_cell.timestamp_micros == expected_timestamp + assert got_pb.set_cell.value == expected_value + + @pytest.mark.parametrize( + "timestamp,expected_value", + [ + (1234567890, True), + (1, True), + (0, True), + (-1, False), + (None, True), + ], + ) + def test_is_idempotent(self, timestamp, expected_value): + """is_idempotent is based on whether an explicit timestamp is set""" + 
instance = self._make_one( + "test-family", b"test-qualifier", b"test-value", timestamp + ) + assert instance.is_idempotent() is expected_value + + def test___str__(self): + """Str representation of mutations should be to_dict""" + instance = self._make_one( + "test-family", b"test-qualifier", b"test-value", 1234567890 + ) + str_value = instance.__str__() + dict_value = instance._to_dict() + assert str_value == str(dict_value) + + +class TestDeleteRangeFromColumn: + def _target_class(self): + from google.cloud.bigtable.data.mutations import DeleteRangeFromColumn + + return DeleteRangeFromColumn + + def _make_one(self, *args, **kwargs): + return self._target_class()(*args, **kwargs) + + def test_ctor(self): + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + expected_start = 1234567890 + expected_end = 1234567891 + instance = self._make_one( + expected_family, expected_qualifier, expected_start, expected_end + ) + assert instance.family == expected_family + assert instance.qualifier == expected_qualifier + assert instance.start_timestamp_micros == expected_start + assert instance.end_timestamp_micros == expected_end + + def test_ctor_no_timestamps(self): + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + instance = self._make_one(expected_family, expected_qualifier) + assert instance.family == expected_family + assert instance.qualifier == expected_qualifier + assert instance.start_timestamp_micros is None + assert instance.end_timestamp_micros is None + + def test_ctor_timestamps_out_of_order(self): + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + expected_start = 10 + expected_end = 1 + with pytest.raises(ValueError) as excinfo: + self._make_one( + expected_family, expected_qualifier, expected_start, expected_end + ) + assert "start_timestamp_micros must be <= end_timestamp_micros" in str( + excinfo.value + ) + + @pytest.mark.parametrize( + "start,end", + [ + (0, 1), + (None, 1), + (0, None), + ], + ) + def test__to_dict(self, start, end): + """Should be unimplemented in the base class""" + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + + instance = self._make_one(expected_family, expected_qualifier, start, end) + got_dict = instance._to_dict() + assert list(got_dict.keys()) == ["delete_from_column"] + got_inner_dict = got_dict["delete_from_column"] + assert len(got_inner_dict.keys()) == 3 + assert got_inner_dict["family_name"] == expected_family + assert got_inner_dict["column_qualifier"] == expected_qualifier + time_range_dict = got_inner_dict["time_range"] + expected_len = int(isinstance(start, int)) + int(isinstance(end, int)) + assert len(time_range_dict.keys()) == expected_len + if start is not None: + assert time_range_dict["start_timestamp_micros"] == start + if end is not None: + assert time_range_dict["end_timestamp_micros"] == end + + def test__to_pb(self): + """ensure proto representation is as expected""" + import google.cloud.bigtable_v2.types.data as data_pb + + expected_family = "test-family" + expected_qualifier = b"test-qualifier" + instance = self._make_one(expected_family, expected_qualifier) + got_pb = instance._to_pb() + assert isinstance(got_pb, data_pb.Mutation) + assert got_pb.delete_from_column.family_name == expected_family + assert got_pb.delete_from_column.column_qualifier == expected_qualifier + + def test_is_idempotent(self): + """is_idempotent is always true""" + instance = self._make_one( + "test-family", b"test-qualifier", 1234567890, 1234567891 + ) 
+ assert instance.is_idempotent() is True + + def test___str__(self): + """Str representation of mutations should be to_dict""" + instance = self._make_one("test-family", b"test-qualifier") + str_value = instance.__str__() + dict_value = instance._to_dict() + assert str_value == str(dict_value) + + +class TestDeleteAllFromFamily: + def _target_class(self): + from google.cloud.bigtable.data.mutations import DeleteAllFromFamily + + return DeleteAllFromFamily + + def _make_one(self, *args, **kwargs): + return self._target_class()(*args, **kwargs) + + def test_ctor(self): + expected_family = "test-family" + instance = self._make_one(expected_family) + assert instance.family_to_delete == expected_family + + def test__to_dict(self): + """Should be unimplemented in the base class""" + expected_family = "test-family" + instance = self._make_one(expected_family) + got_dict = instance._to_dict() + assert list(got_dict.keys()) == ["delete_from_family"] + got_inner_dict = got_dict["delete_from_family"] + assert len(got_inner_dict.keys()) == 1 + assert got_inner_dict["family_name"] == expected_family + + def test__to_pb(self): + """ensure proto representation is as expected""" + import google.cloud.bigtable_v2.types.data as data_pb + + expected_family = "test-family" + instance = self._make_one(expected_family) + got_pb = instance._to_pb() + assert isinstance(got_pb, data_pb.Mutation) + assert got_pb.delete_from_family.family_name == expected_family + + def test_is_idempotent(self): + """is_idempotent is always true""" + instance = self._make_one("test-family") + assert instance.is_idempotent() is True + + def test___str__(self): + """Str representation of mutations should be to_dict""" + instance = self._make_one("test-family") + str_value = instance.__str__() + dict_value = instance._to_dict() + assert str_value == str(dict_value) + + +class TestDeleteFromRow: + def _target_class(self): + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + return DeleteAllFromRow + + def _make_one(self, *args, **kwargs): + return self._target_class()(*args, **kwargs) + + def test_ctor(self): + self._make_one() + + def test__to_dict(self): + """Should be unimplemented in the base class""" + instance = self._make_one() + got_dict = instance._to_dict() + assert list(got_dict.keys()) == ["delete_from_row"] + assert len(got_dict["delete_from_row"].keys()) == 0 + + def test__to_pb(self): + """ensure proto representation is as expected""" + import google.cloud.bigtable_v2.types.data as data_pb + + instance = self._make_one() + got_pb = instance._to_pb() + assert isinstance(got_pb, data_pb.Mutation) + assert "delete_from_row" in str(got_pb) + + def test_is_idempotent(self): + """is_idempotent is always true""" + instance = self._make_one() + assert instance.is_idempotent() is True + + def test___str__(self): + """Str representation of mutations should be to_dict""" + instance = self._make_one() + assert instance.__str__() == "{'delete_from_row': {}}" + + +class TestRowMutationEntry: + def _target_class(self): + from google.cloud.bigtable.data.mutations import RowMutationEntry + + return RowMutationEntry + + def _make_one(self, row_key, mutations): + return self._target_class()(row_key, mutations) + + def test_ctor(self): + expected_key = b"row_key" + expected_mutations = [mock.Mock()] + instance = self._make_one(expected_key, expected_mutations) + assert instance.row_key == expected_key + assert list(instance.mutations) == expected_mutations + + def test_ctor_over_limit(self): + """Should raise error if 
mutations exceed MAX_MUTATIONS_PER_ENTRY""" + from google.cloud.bigtable.data.mutations import ( + _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, + ) + + assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100_000 + # no errors at limit + expected_mutations = [None for _ in range(_MUTATE_ROWS_REQUEST_MUTATION_LIMIT)] + self._make_one(b"row_key", expected_mutations) + # error if over limit + with pytest.raises(ValueError) as e: + self._make_one("key", expected_mutations + [mock.Mock()]) + assert "entries must have <= 100000 mutations" in str(e.value) + + def test_ctor_str_key(self): + expected_key = "row_key" + expected_mutations = [mock.Mock(), mock.Mock()] + instance = self._make_one(expected_key, expected_mutations) + assert instance.row_key == b"row_key" + assert list(instance.mutations) == expected_mutations + + def test_ctor_single_mutation(self): + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + expected_key = b"row_key" + expected_mutations = DeleteAllFromRow() + instance = self._make_one(expected_key, expected_mutations) + assert instance.row_key == expected_key + assert instance.mutations == (expected_mutations,) + + def test__to_dict(self): + expected_key = "row_key" + mutation_mock = mock.Mock() + n_mutations = 3 + expected_mutations = [mutation_mock for i in range(n_mutations)] + for mock_mutations in expected_mutations: + mock_mutations._to_dict.return_value = {"test": "data"} + instance = self._make_one(expected_key, expected_mutations) + expected_result = { + "row_key": b"row_key", + "mutations": [{"test": "data"}] * n_mutations, + } + assert instance._to_dict() == expected_result + assert mutation_mock._to_dict.call_count == n_mutations + + def test__to_pb(self): + from google.cloud.bigtable_v2.types.bigtable import MutateRowsRequest + from google.cloud.bigtable_v2.types.data import Mutation + + expected_key = "row_key" + mutation_mock = mock.Mock() + n_mutations = 3 + expected_mutations = [mutation_mock for i in range(n_mutations)] + for mock_mutations in expected_mutations: + mock_mutations._to_pb.return_value = Mutation() + instance = self._make_one(expected_key, expected_mutations) + pb_result = instance._to_pb() + assert isinstance(pb_result, MutateRowsRequest.Entry) + assert pb_result.row_key == b"row_key" + assert pb_result.mutations == [Mutation()] * n_mutations + assert mutation_mock._to_pb.call_count == n_mutations + + @pytest.mark.parametrize( + "mutations,result", + [ + ([mock.Mock(is_idempotent=lambda: True)], True), + ([mock.Mock(is_idempotent=lambda: False)], False), + ( + [ + mock.Mock(is_idempotent=lambda: True), + mock.Mock(is_idempotent=lambda: False), + ], + False, + ), + ( + [ + mock.Mock(is_idempotent=lambda: True), + mock.Mock(is_idempotent=lambda: True), + ], + True, + ), + ], + ) + def test_is_idempotent(self, mutations, result): + instance = self._make_one("row_key", mutations) + assert instance.is_idempotent() == result + + def test_empty_mutations(self): + with pytest.raises(ValueError) as e: + self._make_one("row_key", []) + assert "must not be empty" in str(e.value) + + @pytest.mark.parametrize("test_dict", [{}, {"key": "value"}]) + def test_size(self, test_dict): + from sys import getsizeof + + """Size should return size of dict representation""" + self_mock = mock.Mock() + self_mock._to_dict.return_value = test_dict + size_value = self._target_class().size(self_mock) + assert size_value == getsizeof(test_dict) + + def test__from_dict_mock(self): + """ + test creating instance from entry dict, with mocked mutation._from_dict + """ + 
expected_key = b"row_key" + expected_mutations = [mock.Mock(), mock.Mock()] + input_dict = { + "row_key": expected_key, + "mutations": [{"test": "data"}, {"another": "data"}], + } + with mock.patch.object(mutations.Mutation, "_from_dict") as inner_from_dict: + inner_from_dict.side_effect = expected_mutations + instance = self._target_class()._from_dict(input_dict) + assert instance.row_key == b"row_key" + assert inner_from_dict.call_count == 2 + assert len(instance.mutations) == 2 + assert instance.mutations[0] == expected_mutations[0] + assert instance.mutations[1] == expected_mutations[1] + + def test__from_dict(self): + """ + test creating end-to-end with a real mutation instance + """ + input_dict = { + "row_key": b"row_key", + "mutations": [{"delete_from_family": {"family_name": "test_family"}}], + } + instance = self._target_class()._from_dict(input_dict) + assert instance.row_key == b"row_key" + assert len(instance.mutations) == 1 + assert isinstance(instance.mutations[0], mutations.DeleteAllFromFamily) + assert instance.mutations[0].family_to_delete == "test_family" diff --git a/tests/unit/data/test_read_modify_write_rules.py b/tests/unit/data/test_read_modify_write_rules.py new file mode 100644 index 000000000..1f67da13b --- /dev/null +++ b/tests/unit/data/test_read_modify_write_rules.py @@ -0,0 +1,186 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
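Before the test classes, a brief sketch of the rule behaviors they pin down may help: construction defaults, str-to-bytes coercion, and the _to_dict/_to_pb shapes. This is a hedged illustration based only on what the assertions in this file claim (the import path is the one the tests use; "fam", b"qual", and "suffix" are placeholder values, not library fixtures, and the underscore-prefixed helpers are exercised here only because the tests do the same).

from google.cloud.bigtable.data.read_modify_write_rules import (
    IncrementRule,
    AppendValueRule,
)

# increment_amount defaults to 1, must be an int, and must fit in a signed 64-bit range
inc = IncrementRule("fam", "qual", 10)
assert inc.qualifier == b"qual"  # str qualifiers are encoded to bytes
assert inc._to_dict() == {
    "family_name": "fam",
    "column_qualifier": b"qual",
    "increment_amount": 10,
}

# append_value accepts bytes or str; str values are encoded to bytes
app = AppendValueRule("fam", b"qual", "suffix")
assert app.append_value == b"suffix"
assert app._to_pb().family_name == "fam"  # _to_pb yields a data_pb.ReadModifyWriteRule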
+ + +import pytest + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock +except ImportError: # pragma: NO COVER + import mock # type: ignore + + +class TestBaseReadModifyWriteRule: + def _target_class(self): + from google.cloud.bigtable.data.read_modify_write_rules import ( + ReadModifyWriteRule, + ) + + return ReadModifyWriteRule + + def test_abstract(self): + """should not be able to instantiate""" + with pytest.raises(TypeError): + self._target_class()(family="foo", qualifier=b"bar") + + def test__to_dict(self): + """ + to_dict not implemented in base class + """ + with pytest.raises(NotImplementedError): + self._target_class()._to_dict(mock.Mock()) + + +class TestIncrementRule: + def _target_class(self): + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + return IncrementRule + + @pytest.mark.parametrize( + "args,expected", + [ + (("fam", b"qual", 1), ("fam", b"qual", 1)), + (("fam", b"qual", -12), ("fam", b"qual", -12)), + (("fam", "qual", 1), ("fam", b"qual", 1)), + (("fam", "qual", 0), ("fam", b"qual", 0)), + (("", "", 0), ("", b"", 0)), + (("f", b"q"), ("f", b"q", 1)), + ], + ) + def test_ctor(self, args, expected): + instance = self._target_class()(*args) + assert instance.family == expected[0] + assert instance.qualifier == expected[1] + assert instance.increment_amount == expected[2] + + @pytest.mark.parametrize("input_amount", [1.1, None, "1", object(), "", b"", b"1"]) + def test_ctor_bad_input(self, input_amount): + with pytest.raises(TypeError) as e: + self._target_class()("fam", b"qual", input_amount) + assert "increment_amount must be an integer" in str(e.value) + + @pytest.mark.parametrize( + "large_value", [2**64, 2**64 + 1, -(2**64), -(2**64) - 1] + ) + def test_ctor_large_values(self, large_value): + with pytest.raises(ValueError) as e: + self._target_class()("fam", b"qual", large_value) + assert "too large" in str(e.value) + + @pytest.mark.parametrize( + "args,expected", + [ + (("fam", b"qual", 1), ("fam", b"qual", 1)), + (("fam", b"qual", -12), ("fam", b"qual", -12)), + (("fam", "qual", 1), ("fam", b"qual", 1)), + (("fam", "qual", 0), ("fam", b"qual", 0)), + (("", "", 0), ("", b"", 0)), + (("f", b"q"), ("f", b"q", 1)), + ], + ) + def test__to_dict(self, args, expected): + instance = self._target_class()(*args) + expected = { + "family_name": expected[0], + "column_qualifier": expected[1], + "increment_amount": expected[2], + } + assert instance._to_dict() == expected + + @pytest.mark.parametrize( + "args,expected", + [ + (("fam", b"qual", 1), ("fam", b"qual", 1)), + (("fam", b"qual", -12), ("fam", b"qual", -12)), + (("fam", "qual", 1), ("fam", b"qual", 1)), + (("fam", "qual", 0), ("fam", b"qual", 0)), + (("", "", 0), ("", b"", 0)), + (("f", b"q"), ("f", b"q", 1)), + ], + ) + def test__to_pb(self, args, expected): + import google.cloud.bigtable_v2.types.data as data_pb + + instance = self._target_class()(*args) + pb_result = instance._to_pb() + assert isinstance(pb_result, data_pb.ReadModifyWriteRule) + assert pb_result.family_name == expected[0] + assert pb_result.column_qualifier == expected[1] + assert pb_result.increment_amount == expected[2] + + +class TestAppendValueRule: + def _target_class(self): + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + return AppendValueRule + + @pytest.mark.parametrize( + "args,expected", + [ + (("fam", b"qual", b"val"), ("fam", b"qual", b"val")), + (("fam", "qual", b"val"), ("fam", b"qual", b"val")), + (("", "", b""), ("", b"", 
b"")), + (("f", "q", "str_val"), ("f", b"q", b"str_val")), + (("f", "q", ""), ("f", b"q", b"")), + ], + ) + def test_ctor(self, args, expected): + instance = self._target_class()(*args) + assert instance.family == expected[0] + assert instance.qualifier == expected[1] + assert instance.append_value == expected[2] + + @pytest.mark.parametrize("input_val", [5, 1.1, None, object()]) + def test_ctor_bad_input(self, input_val): + with pytest.raises(TypeError) as e: + self._target_class()("fam", b"qual", input_val) + assert "append_value must be bytes or str" in str(e.value) + + @pytest.mark.parametrize( + "args,expected", + [ + (("fam", b"qual", b"val"), ("fam", b"qual", b"val")), + (("fam", "qual", b"val"), ("fam", b"qual", b"val")), + (("", "", b""), ("", b"", b"")), + ], + ) + def test__to_dict(self, args, expected): + instance = self._target_class()(*args) + expected = { + "family_name": expected[0], + "column_qualifier": expected[1], + "append_value": expected[2], + } + assert instance._to_dict() == expected + + @pytest.mark.parametrize( + "args,expected", + [ + (("fam", b"qual", b"val"), ("fam", b"qual", b"val")), + (("fam", "qual", b"val"), ("fam", b"qual", b"val")), + (("", "", b""), ("", b"", b"")), + ], + ) + def test__to_pb(self, args, expected): + import google.cloud.bigtable_v2.types.data as data_pb + + instance = self._target_class()(*args) + pb_result = instance._to_pb() + assert isinstance(pb_result, data_pb.ReadModifyWriteRule) + assert pb_result.family_name == expected[0] + assert pb_result.column_qualifier == expected[1] + assert pb_result.append_value == expected[2] diff --git a/tests/unit/data/test_read_rows_acceptance.py b/tests/unit/data/test_read_rows_acceptance.py new file mode 100644 index 000000000..7cb3c08dc --- /dev/null +++ b/tests/unit/data/test_read_rows_acceptance.py @@ -0,0 +1,331 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
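The acceptance tests below all follow the same harness pattern, mirrored by the _process_chunks helper at the bottom of this module: wrap ReadRowsResponse chunks in an async stream, hand it to _ReadRowsOperationAsync.chunk_stream together with a mocked operation object, and merge the chunked output into Row objects. A condensed, hedged sketch of that pattern follows; the chunk field values are illustrative and not taken from the acceptance JSON.

import asyncio
from unittest import mock

from google.cloud.bigtable_v2 import ReadRowsResponse
from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync


async def _coro_wrapper(stream):
    # chunk_stream is handed an awaitable that resolves to the chunk iterator
    return stream


async def merge(*chunks):
    async def _row_stream():
        yield ReadRowsResponse(chunks=chunks)

    # mocked operation state, as in the tests below
    instance = mock.Mock()
    instance._remaining_count = None
    instance._last_yielded_row_key = None
    chunker = _ReadRowsOperationAsync.chunk_stream(instance, _coro_wrapper(_row_stream()))
    return [row async for row in _ReadRowsOperationAsync.merge_rows(chunker)]


rows = asyncio.run(
    merge(
        ReadRowsResponse.CellChunk(
            row_key=b"a", family_name="f", qualifier=b"q", value=b"v", commit_row=True
        )
    )
)
assert rows[0].row_key == b"a"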
+from __future__ import annotations + +import os +from itertools import zip_longest + +import pytest +import mock + +from google.cloud.bigtable_v2 import ReadRowsResponse + +from google.cloud.bigtable.data._async.client import BigtableDataClientAsync +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync +from google.cloud.bigtable.data.row import Row + +from ..v2_client.test_row_merger import ReadRowsTest, TestFile + + +def parse_readrows_acceptance_tests(): + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "./read-rows-acceptance-test.json") + + with open(filename) as json_file: + test_json = TestFile.from_json(json_file.read()) + return test_json.read_rows_tests + + +def extract_results_from_row(row: Row): + results = [] + for family, col, cells in row.items(): + for cell in cells: + results.append( + ReadRowsTest.Result( + row_key=row.row_key, + family_name=family, + qualifier=col, + timestamp_micros=cell.timestamp_ns // 1000, + value=cell.value, + label=(cell.labels[0] if cell.labels else ""), + ) + ) + return results + + +@pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description +) +@pytest.mark.asyncio +async def test_row_merger_scenario(test_case: ReadRowsTest): + async def _scenerio_stream(): + for chunk in test_case.chunks: + yield ReadRowsResponse(chunks=[chunk]) + + try: + results = [] + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + chunker = _ReadRowsOperationAsync.chunk_stream( + instance, _coro_wrapper(_scenerio_stream()) + ) + merger = _ReadRowsOperationAsync.merge_rows(chunker) + async for row in merger: + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + +@pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description +) +@pytest.mark.asyncio +async def test_read_rows_scenario(test_case: ReadRowsTest): + async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list): + self.chunk_list = chunk_list + self.idx = -1 + + def __aiter__(self): + return self + + async def __anext__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + chunk = self.chunk_list[self.idx] + return ReadRowsResponse(chunks=[chunk]) + raise StopAsyncIteration + + def cancel(self): + pass + + return mock_stream(chunk_list) + + try: + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): + # use emulator mode to avoid auth issues in CI + client = BigtableDataClientAsync() + table = client.get_table("instance", "table") + results = [] + with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows: + # run once, then return error on retry + read_rows.return_value = _make_gapic_stream(test_case.chunks) + async for row in await table.read_rows_stream(query={}): + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + 
timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + finally: + await client.close() + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + +@pytest.mark.asyncio +async def test_out_of_order_rows(): + async def _row_stream(): + yield ReadRowsResponse(last_scanned_row_key=b"a") + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = b"b" + chunker = _ReadRowsOperationAsync.chunk_stream( + instance, _coro_wrapper(_row_stream()) + ) + merger = _ReadRowsOperationAsync.merge_rows(chunker) + with pytest.raises(InvalidChunk): + async for _ in merger: + pass + + +@pytest.mark.asyncio +async def test_bare_reset(): + first_chunk = ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk( + row_key=b"a", family_name="f", qualifier=b"q", value=b"v" + ) + ) + with pytest.raises(InvalidChunk): + await _process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") + ), + ) + with pytest.raises(InvalidChunk): + await _process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, family_name="f") + ), + ) + with pytest.raises(InvalidChunk): + await _process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") + ), + ) + with pytest.raises(InvalidChunk): + await _process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) + ), + ) + with pytest.raises(InvalidChunk): + await _process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) + ), + ) + with pytest.raises(InvalidChunk): + await _process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, value=b"v") + ), + ) + + +@pytest.mark.asyncio +async def test_missing_family(): + with pytest.raises(InvalidChunk): + await _process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + qualifier=b"q", + timestamp_micros=1000, + value=b"v", + commit_row=True, + ) + ) + + +@pytest.mark.asyncio +async def test_mid_cell_row_key_change(): + with pytest.raises(InvalidChunk): + await _process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), + ) + + +@pytest.mark.asyncio +async def test_mid_cell_family_change(): + with pytest.raises(InvalidChunk): + await _process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(family_name="f2", value=b"v", commit_row=True), + ) + + +@pytest.mark.asyncio +async def test_mid_cell_qualifier_change(): + with pytest.raises(InvalidChunk): + await _process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(qualifier=b"q2", value=b"v", commit_row=True), + ) + + +@pytest.mark.asyncio +async def test_mid_cell_timestamp_change(): + with pytest.raises(InvalidChunk): + await _process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + 
family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + timestamp_micros=2000, value=b"v", commit_row=True + ), + ) + + +@pytest.mark.asyncio +async def test_mid_cell_labels_change(): + with pytest.raises(InvalidChunk): + await _process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), + ) + + +async def _coro_wrapper(stream): + return stream + + +async def _process_chunks(*chunks): + async def _row_stream(): + yield ReadRowsResponse(chunks=chunks) + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = None + chunker = _ReadRowsOperationAsync.chunk_stream( + instance, _coro_wrapper(_row_stream()) + ) + merger = _ReadRowsOperationAsync.merge_rows(chunker) + results = [] + async for row in merger: + results.append(row) + return results diff --git a/tests/unit/data/test_read_rows_query.py b/tests/unit/data/test_read_rows_query.py new file mode 100644 index 000000000..ba3b0468b --- /dev/null +++ b/tests/unit/data/test_read_rows_query.py @@ -0,0 +1,589 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
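The query tests below lean on a few RowRange / ReadRowsQuery semantics that are easy to lose in the assertions: ranges default to an inclusive start and exclusive end, string keys are encoded to bytes, and a query can be sharded at sampled row keys unless it carries a limit. A small hedged sketch, using placeholder keys that do not appear in the tests:

from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery, RowRange

rng = RowRange("a", "c")
assert rng.start_key == b"a" and rng.start_is_inclusive is True
assert rng.end_key == b"c" and rng.end_is_inclusive is False

# keys and ranges can be passed at construction time or added afterwards
query = ReadRowsQuery(row_keys=["k1"], row_ranges=[rng], limit=10)
other = ReadRowsQuery()
other.add_key("k1")
other.add_range(rng)
assert b"k1" in query.row_keys and b"k1" in other.row_keys

# a full table scan splits into two sub-queries at a single sampled key;
# a query with a limit refuses to shard and raises AttributeError instead
full_scan = ReadRowsQuery()
shards = full_scan.shard([(b"m", None)])
assert len(shards) == 2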
+ +import pytest + +TEST_ROWS = [ + "row_key_1", + b"row_key_2", +] + + +class TestRowRange: + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data.read_rows_query import RowRange + + return RowRange + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor_start_end(self): + row_range = self._make_one("test_row", "test_row2") + assert row_range.start_key == "test_row".encode() + assert row_range.end_key == "test_row2".encode() + assert row_range.start_is_inclusive is True + assert row_range.end_is_inclusive is False + + def test_ctor_start_only(self): + row_range = self._make_one("test_row3") + assert row_range.start_key == "test_row3".encode() + assert row_range.start_is_inclusive is True + assert row_range.end_key is None + assert row_range.end_is_inclusive is True + + def test_ctor_end_only(self): + row_range = self._make_one(end_key="test_row4") + assert row_range.end_key == "test_row4".encode() + assert row_range.end_is_inclusive is False + assert row_range.start_key is None + assert row_range.start_is_inclusive is True + + def test_ctor_empty_strings(self): + """ + empty strings should be treated as None + """ + row_range = self._make_one("", "") + assert row_range.start_key is None + assert row_range.end_key is None + assert row_range.start_is_inclusive is True + assert row_range.end_is_inclusive is True + + def test_ctor_inclusive_flags(self): + row_range = self._make_one("test_row5", "test_row6", False, True) + assert row_range.start_key == "test_row5".encode() + assert row_range.end_key == "test_row6".encode() + assert row_range.start_is_inclusive is False + assert row_range.end_is_inclusive is True + + def test_ctor_defaults(self): + row_range = self._make_one() + assert row_range.start_key is None + assert row_range.end_key is None + + def test_ctor_invalid_keys(self): + # test with invalid keys + with pytest.raises(ValueError) as exc: + self._make_one(1, "2") + assert str(exc.value) == "start_key must be a string or bytes" + with pytest.raises(ValueError) as exc: + self._make_one("1", 2) + assert str(exc.value) == "end_key must be a string or bytes" + with pytest.raises(ValueError) as exc: + self._make_one("2", "1") + assert str(exc.value) == "start_key must be less than or equal to end_key" + + @pytest.mark.parametrize( + "dict_repr,expected", + [ + ({"start_key_closed": "test_row", "end_key_open": "test_row2"}, True), + ({"start_key_closed": b"test_row", "end_key_open": b"test_row2"}, True), + ({"start_key_open": "test_row", "end_key_closed": "test_row2"}, True), + ({"start_key_open": b"a"}, True), + ({"end_key_closed": b"b"}, True), + ({"start_key_closed": "a"}, True), + ({"end_key_open": b"b"}, True), + ({}, False), + ], + ) + def test___bool__(self, dict_repr, expected): + """ + Only row range with both points empty should be falsy + """ + from google.cloud.bigtable.data.read_rows_query import RowRange + + row_range = RowRange._from_dict(dict_repr) + assert bool(row_range) is expected + + def test__eq__(self): + """ + test that row ranges can be compared for equality + """ + from google.cloud.bigtable.data.read_rows_query import RowRange + + range1 = RowRange("1", "2") + range1_dup = RowRange("1", "2") + range2 = RowRange("1", "3") + range_w_empty = RowRange(None, "2") + assert range1 == range1_dup + assert range1 != range2 + assert range1 != range_w_empty + range_1_w_inclusive_start = RowRange("1", "2", start_is_inclusive=True) + range_1_w_exclusive_start = RowRange("1", "2", 
start_is_inclusive=False) + range_1_w_inclusive_end = RowRange("1", "2", end_is_inclusive=True) + range_1_w_exclusive_end = RowRange("1", "2", end_is_inclusive=False) + assert range1 == range_1_w_inclusive_start + assert range1 == range_1_w_exclusive_end + assert range1 != range_1_w_exclusive_start + assert range1 != range_1_w_inclusive_end + + @pytest.mark.parametrize( + "dict_repr,expected", + [ + ( + {"start_key_closed": "test_row", "end_key_open": "test_row2"}, + "[b'test_row', b'test_row2')", + ), + ( + {"start_key_open": "test_row", "end_key_closed": "test_row2"}, + "(b'test_row', b'test_row2']", + ), + ({"start_key_open": b"a"}, "(b'a', +inf]"), + ({"end_key_closed": b"b"}, "[-inf, b'b']"), + ({"end_key_open": b"b"}, "[-inf, b'b')"), + ({}, "[-inf, +inf]"), + ], + ) + def test___str__(self, dict_repr, expected): + """ + test string representations of row ranges + """ + from google.cloud.bigtable.data.read_rows_query import RowRange + + row_range = RowRange._from_dict(dict_repr) + assert str(row_range) == expected + + @pytest.mark.parametrize( + "dict_repr,expected", + [ + ( + {"start_key_closed": "test_row", "end_key_open": "test_row2"}, + "RowRange(start_key=b'test_row', end_key=b'test_row2')", + ), + ( + {"start_key_open": "test_row", "end_key_closed": "test_row2"}, + "RowRange(start_key=b'test_row', end_key=b'test_row2', start_is_inclusive=False, end_is_inclusive=True)", + ), + ( + {"start_key_open": b"a"}, + "RowRange(start_key=b'a', end_key=None, start_is_inclusive=False)", + ), + ( + {"end_key_closed": b"b"}, + "RowRange(start_key=None, end_key=b'b', end_is_inclusive=True)", + ), + ({"end_key_open": b"b"}, "RowRange(start_key=None, end_key=b'b')"), + ({}, "RowRange(start_key=None, end_key=None)"), + ], + ) + def test___repr__(self, dict_repr, expected): + """ + test repr representations of row ranges + """ + from google.cloud.bigtable.data.read_rows_query import RowRange + + row_range = RowRange._from_dict(dict_repr) + assert repr(row_range) == expected + + +class TestReadRowsQuery: + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + return ReadRowsQuery + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor_defaults(self): + query = self._make_one() + assert query.row_keys == list() + assert query.row_ranges == list() + assert query.filter is None + assert query.limit is None + + def test_ctor_explicit(self): + from google.cloud.bigtable.data.row_filters import RowFilterChain + from google.cloud.bigtable.data.read_rows_query import RowRange + + filter_ = RowFilterChain() + query = self._make_one( + ["row_key_1", "row_key_2"], + row_ranges=[RowRange("row_key_3", "row_key_4")], + limit=10, + row_filter=filter_, + ) + assert len(query.row_keys) == 2 + assert "row_key_1".encode() in query.row_keys + assert "row_key_2".encode() in query.row_keys + assert len(query.row_ranges) == 1 + assert RowRange("row_key_3", "row_key_4") in query.row_ranges + assert query.filter == filter_ + assert query.limit == 10 + + def test_ctor_invalid_limit(self): + with pytest.raises(ValueError) as exc: + self._make_one(limit=-1) + assert str(exc.value) == "limit must be >= 0" + + def test_set_filter(self): + from google.cloud.bigtable.data.row_filters import RowFilterChain + + filter1 = RowFilterChain() + query = self._make_one() + assert query.filter is None + query.filter = filter1 + assert query.filter == filter1 + filter2 = RowFilterChain() + query.filter = filter2 + assert 
query.filter == filter2 + query.filter = None + assert query.filter is None + query.filter = RowFilterChain() + assert query.filter == RowFilterChain() + + def test_set_limit(self): + query = self._make_one() + assert query.limit is None + query.limit = 10 + assert query.limit == 10 + query.limit = 9 + assert query.limit == 9 + query.limit = 0 + assert query.limit is None + with pytest.raises(ValueError) as exc: + query.limit = -1 + assert str(exc.value) == "limit must be >= 0" + with pytest.raises(ValueError) as exc: + query.limit = -100 + assert str(exc.value) == "limit must be >= 0" + + def test_add_key_str(self): + query = self._make_one() + assert query.row_keys == list() + input_str = "test_row" + query.add_key(input_str) + assert len(query.row_keys) == 1 + assert input_str.encode() in query.row_keys + input_str2 = "test_row2" + query.add_key(input_str2) + assert len(query.row_keys) == 2 + assert input_str.encode() in query.row_keys + assert input_str2.encode() in query.row_keys + + def test_add_key_bytes(self): + query = self._make_one() + assert query.row_keys == list() + input_bytes = b"test_row" + query.add_key(input_bytes) + assert len(query.row_keys) == 1 + assert input_bytes in query.row_keys + input_bytes2 = b"test_row2" + query.add_key(input_bytes2) + assert len(query.row_keys) == 2 + assert input_bytes in query.row_keys + assert input_bytes2 in query.row_keys + + def test_add_rows_batch(self): + query = self._make_one() + assert query.row_keys == list() + input_batch = ["test_row", b"test_row2", "test_row3"] + for k in input_batch: + query.add_key(k) + assert len(query.row_keys) == 3 + assert b"test_row" in query.row_keys + assert b"test_row2" in query.row_keys + assert b"test_row3" in query.row_keys + # test adding another batch + for k in ["test_row4", b"test_row5"]: + query.add_key(k) + assert len(query.row_keys) == 5 + assert input_batch[0].encode() in query.row_keys + assert input_batch[1] in query.row_keys + assert input_batch[2].encode() in query.row_keys + assert b"test_row4" in query.row_keys + assert b"test_row5" in query.row_keys + + def test_add_key_invalid(self): + query = self._make_one() + with pytest.raises(ValueError) as exc: + query.add_key(1) + assert str(exc.value) == "row_key must be string or bytes" + with pytest.raises(ValueError) as exc: + query.add_key(["s"]) + assert str(exc.value) == "row_key must be string or bytes" + + def test_add_range(self): + from google.cloud.bigtable.data.read_rows_query import RowRange + + query = self._make_one() + assert query.row_ranges == list() + input_range = RowRange(start_key=b"test_row") + query.add_range(input_range) + assert len(query.row_ranges) == 1 + assert input_range in query.row_ranges + input_range2 = RowRange(start_key=b"test_row2") + query.add_range(input_range2) + assert len(query.row_ranges) == 2 + assert input_range in query.row_ranges + assert input_range2 in query.row_ranges + + def _parse_query_string(self, query_string): + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery, RowRange + + query = ReadRowsQuery() + segments = query_string.split(",") + for segment in segments: + if "-" in segment: + start, end = segment.split("-") + s_open, e_open = True, True + if start == "": + start = None + s_open = None + else: + if start[0] == "(": + s_open = False + start = start[1:] + if end == "": + end = None + e_open = None + else: + if end[-1] == ")": + e_open = False + end = end[:-1] + query.add_range(RowRange(start, end, s_open, e_open)) + else: + query.add_key(segment) + return query 
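The shard tests that follow describe queries with the compact string notation handled by _parse_query_string above: comma-separated segments, where a segment containing "-" is a row range (a "(" or ")" marks an exclusive endpoint, a "[" or "]" an inclusive one, and an empty side leaves that end of the range unbounded), and any other segment is a bare row key. As a hedged worked example, "0_key,[1_start-2_end)" should parse to the equivalent of the query built below; the key names are placeholders.

from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery, RowRange

equivalent = ReadRowsQuery()
equivalent.add_key("0_key")
equivalent.add_range(
    RowRange("1_start", "2_end", start_is_inclusive=True, end_is_inclusive=False)
)
# inside the test class: self._parse_query_string("0_key,[1_start-2_end)") == equivalent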
+ + @pytest.mark.parametrize( + "query_string,shard_points", + [ + ("a,[p-q)", []), + ("0_key,[1_range_start-2_range_end)", ["3_split"]), + ("0_key,[1_range_start-2_range_end)", ["2_range_end"]), + ("0_key,[1_range_start-2_range_end]", ["2_range_end"]), + ("-1_range_end)", ["5_split"]), + ("8_key,(1_range_start-2_range_end]", ["1_range_start"]), + ("9_row_key,(5_range_start-7_range_end)", ["3_split"]), + ("3_row_key,(5_range_start-7_range_end)", ["2_row_key"]), + ("4_split,4_split,(3_split-5_split]", ["3_split", "5_split"]), + ("(3_split-", ["3_split"]), + ], + ) + def test_shard_no_split(self, query_string, shard_points): + """ + Test sharding with a set of queries that should not result in any splits. + """ + initial_query = self._parse_query_string(query_string) + row_samples = [(point.encode(), None) for point in shard_points] + sharded_queries = initial_query.shard(row_samples) + assert len(sharded_queries) == 1 + assert initial_query == sharded_queries[0] + + def test_shard_full_table_scan_empty_split(self): + """ + Sharding a full table scan with no split should return another full table scan. + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + full_scan_query = ReadRowsQuery() + split_points = [] + sharded_queries = full_scan_query.shard(split_points) + assert len(sharded_queries) == 1 + result_query = sharded_queries[0] + assert result_query == full_scan_query + + def test_shard_full_table_scan_with_split(self): + """ + Test splitting a full table scan into two queries + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + full_scan_query = ReadRowsQuery() + split_points = [(b"a", None)] + sharded_queries = full_scan_query.shard(split_points) + assert len(sharded_queries) == 2 + assert sharded_queries[0] == self._parse_query_string("-a]") + assert sharded_queries[1] == self._parse_query_string("(a-") + + def test_shard_full_table_scan_with_multiple_split(self): + """ + Test splitting a full table scan into three queries + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + full_scan_query = ReadRowsQuery() + split_points = [(b"a", None), (b"z", None)] + sharded_queries = full_scan_query.shard(split_points) + assert len(sharded_queries) == 3 + assert sharded_queries[0] == self._parse_query_string("-a]") + assert sharded_queries[1] == self._parse_query_string("(a-z]") + assert sharded_queries[2] == self._parse_query_string("(z-") + + def test_shard_multiple_keys(self): + """ + Test splitting multiple individual keys into separate queries + """ + initial_query = self._parse_query_string("1_beforeSplit,2_onSplit,3_afterSplit") + split_points = [(b"2_onSplit", None)] + sharded_queries = initial_query.shard(split_points) + assert len(sharded_queries) == 2 + assert sharded_queries[0] == self._parse_query_string("1_beforeSplit,2_onSplit") + assert sharded_queries[1] == self._parse_query_string("3_afterSplit") + + def test_shard_keys_empty_left(self): + """ + Test with the left-most split point empty + """ + initial_query = self._parse_query_string("5_test,8_test") + split_points = [(b"0_split", None), (b"6_split", None)] + sharded_queries = initial_query.shard(split_points) + assert len(sharded_queries) == 2 + assert sharded_queries[0] == self._parse_query_string("5_test") + assert sharded_queries[1] == self._parse_query_string("8_test") + + def test_shard_keys_empty_right(self): + """ + Test with the right-most split point empty + """ + initial_query = self._parse_query_string("0_test,2_test") + 
split_points = [(b"1_split", None), (b"5_split", None)] + sharded_queries = initial_query.shard(split_points) + assert len(sharded_queries) == 2 + assert sharded_queries[0] == self._parse_query_string("0_test") + assert sharded_queries[1] == self._parse_query_string("2_test") + + def test_shard_mixed_split(self): + """ + Test splitting a complex query with multiple split points + """ + initial_query = self._parse_query_string("0,a,c,-a],-b],(c-e],(d-f],(m-") + split_points = [(s.encode(), None) for s in ["a", "d", "j", "o"]] + sharded_queries = initial_query.shard(split_points) + assert len(sharded_queries) == 5 + assert sharded_queries[0] == self._parse_query_string("0,a,-a]") + assert sharded_queries[1] == self._parse_query_string("c,(a-b],(c-d]") + assert sharded_queries[2] == self._parse_query_string("(d-e],(d-f]") + assert sharded_queries[3] == self._parse_query_string("(m-o]") + assert sharded_queries[4] == self._parse_query_string("(o-") + + def test_shard_unsorted_request(self): + """ + Test with a query that contains rows and queries in a random order + """ + initial_query = self._parse_query_string( + "7_row_key_1,2_row_key_2,[8_range_1_start-9_range_1_end),[3_range_2_start-4_range_2_end)" + ) + split_points = [(b"5-split", None)] + sharded_queries = initial_query.shard(split_points) + assert len(sharded_queries) == 2 + assert sharded_queries[0] == self._parse_query_string( + "2_row_key_2,[3_range_2_start-4_range_2_end)" + ) + assert sharded_queries[1] == self._parse_query_string( + "7_row_key_1,[8_range_1_start-9_range_1_end)" + ) + + @pytest.mark.parametrize( + "query_string,shard_points", + [ + ("a,[p-q)", []), + ("0_key,[1_range_start-2_range_end)", ["3_split"]), + ("-1_range_end)", ["5_split"]), + ("0_key,[1_range_start-2_range_end)", ["2_range_end"]), + ("9_row_key,(5_range_start-7_range_end)", ["3_split"]), + ("(5_range_start-", ["3_split"]), + ("3_split,[3_split-5_split)", ["3_split", "5_split"]), + ("[3_split-", ["3_split"]), + ("", []), + ("", ["3_split"]), + ("", ["3_split", "5_split"]), + ("1,2,3,4,5,6,7,8,9", ["3_split"]), + ], + ) + def test_shard_keeps_filter(self, query_string, shard_points): + """ + sharded queries should keep the filter from the original query + """ + initial_query = self._parse_query_string(query_string) + expected_filter = {"test": "filter"} + initial_query.filter = expected_filter + row_samples = [(point.encode(), None) for point in shard_points] + sharded_queries = initial_query.shard(row_samples) + assert len(sharded_queries) > 0 + for query in sharded_queries: + assert query.filter == expected_filter + + def test_shard_limit_exception(self): + """ + queries with a limit should raise an exception when a shard is attempted + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + query = ReadRowsQuery(limit=10) + with pytest.raises(AttributeError) as e: + query.shard([]) + assert "Cannot shard query with a limit" in str(e.value) + + @pytest.mark.parametrize( + "first_args,second_args,expected", + [ + ((), (), True), + ((), ("a",), False), + (("a",), (), False), + (("a",), ("a",), True), + ((["a"],), (["a", "b"],), False), + ((["a", "b"],), (["a", "b"],), True), + ((["a", b"b"],), ([b"a", "b"],), True), + (("a",), (b"a",), True), + (("a",), ("b",), False), + (("a",), ("a", ["b"]), False), + (("a", "b"), ("a", ["b"]), True), + (("a", ["b"]), ("a", ["b", "c"]), False), + (("a", ["b", "c"]), ("a", [b"b", "c"]), True), + (("a", ["b", "c"], 1), ("a", ["b", b"c"], 1), True), + (("a", ["b"], 1), ("a", ["b"], 2), False), + 
(("a", ["b"], 1, {"a": "b"}), ("a", ["b"], 1, {"a": "b"}), True), + (("a", ["b"], 1, {"a": "b"}), ("a", ["b"], 1), False), + ( + (), + (None, [None], None, None), + True, + ), # empty query is equal to empty row range + ((), (None, [None], 1, None), False), + ((), (None, [None], None, {"a": "b"}), False), + ], + ) + def test___eq__(self, first_args, second_args, expected): + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.read_rows_query import RowRange + + # replace row_range placeholders with a RowRange object + if len(first_args) > 1: + first_args = list(first_args) + first_args[1] = [RowRange(c) for c in first_args[1]] + if len(second_args) > 1: + second_args = list(second_args) + second_args[1] = [RowRange(c) for c in second_args[1]] + first = ReadRowsQuery(*first_args) + second = ReadRowsQuery(*second_args) + assert (first == second) == expected + + def test___repr__(self): + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + instance = self._make_one(row_keys=["a", "b"], row_filter={}, limit=10) + # should be able to recreate the instance from the repr + repr_str = repr(instance) + recreated = eval(repr_str) + assert isinstance(recreated, ReadRowsQuery) + assert recreated == instance + + def test_empty_row_set(self): + """Empty strings should be treated as keys inputs""" + query = self._make_one(row_keys="") + assert query.row_keys == [b""] diff --git a/tests/unit/data/test_row.py b/tests/unit/data/test_row.py new file mode 100644 index 000000000..10b5bdb23 --- /dev/null +++ b/tests/unit/data/test_row.py @@ -0,0 +1,718 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import time + +TEST_VALUE = b"1234" +TEST_ROW_KEY = b"row" +TEST_FAMILY_ID = "cf1" +TEST_QUALIFIER = b"col" +TEST_TIMESTAMP = time.time_ns() // 1000 +TEST_LABELS = ["label1", "label2"] + + +class TestRow(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data.row import Row + + return Row + + def _make_one(self, *args, **kwargs): + if len(args) == 0: + args = (TEST_ROW_KEY, [self._make_cell()]) + return self._get_target_class()(*args, **kwargs) + + def _make_cell( + self, + value=TEST_VALUE, + row_key=TEST_ROW_KEY, + family_id=TEST_FAMILY_ID, + qualifier=TEST_QUALIFIER, + timestamp=TEST_TIMESTAMP, + labels=TEST_LABELS, + ): + from google.cloud.bigtable.data.row import Cell + + return Cell(value, row_key, family_id, qualifier, timestamp, labels) + + def test_ctor(self): + cells = [self._make_cell(), self._make_cell()] + row_response = self._make_one(TEST_ROW_KEY, cells) + self.assertEqual(list(row_response), cells) + self.assertEqual(row_response.row_key, TEST_ROW_KEY) + + def test__from_pb(self): + """ + Construct from protobuf. 
+ """ + from google.cloud.bigtable_v2.types import Row as RowPB + from google.cloud.bigtable_v2.types import Family as FamilyPB + from google.cloud.bigtable_v2.types import Column as ColumnPB + from google.cloud.bigtable_v2.types import Cell as CellPB + + row_key = b"row_key" + cells = [ + CellPB( + value=str(i).encode(), + timestamp_micros=TEST_TIMESTAMP, + labels=TEST_LABELS, + ) + for i in range(2) + ] + column = ColumnPB(qualifier=TEST_QUALIFIER, cells=cells) + families_pb = [FamilyPB(name=TEST_FAMILY_ID, columns=[column])] + row_pb = RowPB(key=row_key, families=families_pb) + output = self._get_target_class()._from_pb(row_pb) + self.assertEqual(output.row_key, row_key) + self.assertEqual(len(output), 2) + self.assertEqual(output[0].value, b"0") + self.assertEqual(output[1].value, b"1") + self.assertEqual(output[0].timestamp_micros, TEST_TIMESTAMP) + self.assertEqual(output[0].labels, TEST_LABELS) + assert output[0].row_key == row_key + assert output[0].family == TEST_FAMILY_ID + assert output[0].qualifier == TEST_QUALIFIER + + def test__from_pb_sparse(self): + """ + Construct from minimal protobuf. + """ + from google.cloud.bigtable_v2.types import Row as RowPB + + row_key = b"row_key" + row_pb = RowPB(key=row_key) + output = self._get_target_class()._from_pb(row_pb) + self.assertEqual(output.row_key, row_key) + self.assertEqual(len(output), 0) + + def test_get_cells(self): + cell_list = [] + for family_id in ["1", "2"]: + for qualifier in [b"a", b"b"]: + cell = self._make_cell(family_id=family_id, qualifier=qualifier) + cell_list.append(cell) + # test getting all cells + row_response = self._make_one(TEST_ROW_KEY, cell_list) + self.assertEqual(row_response.get_cells(), cell_list) + # test getting cells in a family + output = row_response.get_cells(family="1") + self.assertEqual(len(output), 2) + self.assertEqual(output[0].family, "1") + self.assertEqual(output[1].family, "1") + self.assertEqual(output[0], cell_list[0]) + # test getting cells in a family/qualifier + # should accept bytes or str for qualifier + for q in [b"a", "a"]: + output = row_response.get_cells(family="1", qualifier=q) + self.assertEqual(len(output), 1) + self.assertEqual(output[0].family, "1") + self.assertEqual(output[0].qualifier, b"a") + self.assertEqual(output[0], cell_list[0]) + # calling with just qualifier should raise an error + with self.assertRaises(ValueError): + row_response.get_cells(qualifier=b"a") + # test calling with bad family or qualifier + with self.assertRaises(ValueError): + row_response.get_cells(family="3", qualifier=b"a") + with self.assertRaises(ValueError): + row_response.get_cells(family="3") + with self.assertRaises(ValueError): + row_response.get_cells(family="1", qualifier=b"c") + + def test___repr__(self): + cell_str = ( + "{'value': b'1234', 'timestamp_micros': %d, 'labels': ['label1', 'label2']}" + % (TEST_TIMESTAMP) + ) + expected_prefix = "Row(key=b'row', cells=" + row = self._make_one(TEST_ROW_KEY, [self._make_cell()]) + self.assertIn(expected_prefix, repr(row)) + self.assertIn(cell_str, repr(row)) + expected_full = ( + "Row(key=b'row', cells={\n ('cf1', b'col'): [{'value': b'1234', 'timestamp_micros': %d, 'labels': ['label1', 'label2']}],\n})" + % (TEST_TIMESTAMP) + ) + self.assertEqual(expected_full, repr(row)) + # try with multiple cells + row = self._make_one(TEST_ROW_KEY, [self._make_cell(), self._make_cell()]) + self.assertIn(expected_prefix, repr(row)) + self.assertIn(cell_str, repr(row)) + + def test___str__(self): + cells = [ + self._make_cell(value=b"1234", 
family_id="1", qualifier=b"col"), + self._make_cell(value=b"5678", family_id="3", qualifier=b"col"), + self._make_cell(value=b"1", family_id="3", qualifier=b"col"), + self._make_cell(value=b"2", family_id="3", qualifier=b"col"), + ] + + row_response = self._make_one(TEST_ROW_KEY, cells) + expected = ( + "{\n" + + " (family='1', qualifier=b'col'): [b'1234'],\n" + + " (family='3', qualifier=b'col'): [b'5678', (+2 more)],\n" + + "}" + ) + self.assertEqual(expected, str(row_response)) + + def test_to_dict(self): + from google.cloud.bigtable_v2.types import Row + + cell1 = self._make_cell() + cell2 = self._make_cell() + cell2.value = b"other" + row = self._make_one(TEST_ROW_KEY, [cell1, cell2]) + row_dict = row._to_dict() + expected_dict = { + "key": TEST_ROW_KEY, + "families": [ + { + "name": TEST_FAMILY_ID, + "columns": [ + { + "qualifier": TEST_QUALIFIER, + "cells": [ + { + "value": TEST_VALUE, + "timestamp_micros": TEST_TIMESTAMP, + "labels": TEST_LABELS, + }, + { + "value": b"other", + "timestamp_micros": TEST_TIMESTAMP, + "labels": TEST_LABELS, + }, + ], + } + ], + }, + ], + } + self.assertEqual(len(row_dict), len(expected_dict)) + for key, value in expected_dict.items(): + self.assertEqual(row_dict[key], value) + # should be able to construct a Cell proto from the dict + row_proto = Row(**row_dict) + self.assertEqual(row_proto.key, TEST_ROW_KEY) + self.assertEqual(len(row_proto.families), 1) + family = row_proto.families[0] + self.assertEqual(family.name, TEST_FAMILY_ID) + self.assertEqual(len(family.columns), 1) + column = family.columns[0] + self.assertEqual(column.qualifier, TEST_QUALIFIER) + self.assertEqual(len(column.cells), 2) + self.assertEqual(column.cells[0].value, TEST_VALUE) + self.assertEqual(column.cells[0].timestamp_micros, TEST_TIMESTAMP) + self.assertEqual(column.cells[0].labels, TEST_LABELS) + self.assertEqual(column.cells[1].value, cell2.value) + self.assertEqual(column.cells[1].timestamp_micros, TEST_TIMESTAMP) + self.assertEqual(column.cells[1].labels, TEST_LABELS) + + def test_iteration(self): + from google.cloud.bigtable.data.row import Cell + + # should be able to iterate over the Row as a list + cell1 = self._make_cell(value=b"1") + cell2 = self._make_cell(value=b"2") + cell3 = self._make_cell(value=b"3") + row_response = self._make_one(TEST_ROW_KEY, [cell1, cell2, cell3]) + self.assertEqual(len(row_response), 3) + result_list = list(row_response) + self.assertEqual(len(result_list), 3) + # should be able to iterate over all cells + idx = 0 + for cell in row_response: + self.assertIsInstance(cell, Cell) + self.assertEqual(cell.value, result_list[idx].value) + self.assertEqual(cell.value, str(idx + 1).encode()) + idx += 1 + + def test_contains_cell(self): + cell3 = self._make_cell(value=b"3") + cell1 = self._make_cell(value=b"1") + cell2 = self._make_cell(value=b"2") + cell4 = self._make_cell(value=b"4") + row_response = self._make_one(TEST_ROW_KEY, [cell3, cell1, cell2]) + self.assertIn(cell1, row_response) + self.assertIn(cell2, row_response) + self.assertNotIn(cell4, row_response) + cell3_copy = self._make_cell(value=b"3") + self.assertIn(cell3_copy, row_response) + + def test_contains_family_id(self): + new_family_id = "new_family_id" + cell = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + cell2 = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + new_family_id, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + row_response = self._make_one(TEST_ROW_KEY, [cell, cell2]) + 
self.assertIn(TEST_FAMILY_ID, row_response) + self.assertIn("new_family_id", row_response) + self.assertIn(new_family_id, row_response) + self.assertNotIn("not_a_family_id", row_response) + self.assertNotIn(None, row_response) + + def test_contains_family_qualifier_tuple(self): + new_family_id = "new_family_id" + new_qualifier = b"new_qualifier" + cell = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + cell2 = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + new_family_id, + new_qualifier, + TEST_TIMESTAMP, + TEST_LABELS, + ) + row_response = self._make_one(TEST_ROW_KEY, [cell, cell2]) + self.assertIn((TEST_FAMILY_ID, TEST_QUALIFIER), row_response) + self.assertIn(("new_family_id", "new_qualifier"), row_response) + self.assertIn(("new_family_id", b"new_qualifier"), row_response) + self.assertIn((new_family_id, new_qualifier), row_response) + + self.assertNotIn(("not_a_family_id", TEST_QUALIFIER), row_response) + self.assertNotIn((TEST_FAMILY_ID, "not_a_qualifier"), row_response) + self.assertNotIn((TEST_FAMILY_ID, new_qualifier), row_response) + self.assertNotIn(("not_a_family_id", "not_a_qualifier"), row_response) + self.assertNotIn((None, None), row_response) + self.assertNotIn(None, row_response) + + def test_int_indexing(self): + # should be able to index into underlying list with an index number directly + cell_list = [self._make_cell(value=str(i).encode()) for i in range(10)] + sorted(cell_list) + row_response = self._make_one(TEST_ROW_KEY, cell_list) + self.assertEqual(len(row_response), 10) + for i in range(10): + self.assertEqual(row_response[i].value, str(i).encode()) + # backwards indexing should work + self.assertEqual(row_response[-i - 1].value, str(9 - i).encode()) + with self.assertRaises(IndexError): + row_response[10] + with self.assertRaises(IndexError): + row_response[-11] + + def test_slice_indexing(self): + # should be able to index with a range of indices + cell_list = [self._make_cell(value=str(i).encode()) for i in range(10)] + sorted(cell_list) + row_response = self._make_one(TEST_ROW_KEY, cell_list) + self.assertEqual(len(row_response), 10) + self.assertEqual(len(row_response[0:10]), 10) + self.assertEqual(row_response[0:10], cell_list) + self.assertEqual(len(row_response[0:]), 10) + self.assertEqual(row_response[0:], cell_list) + self.assertEqual(len(row_response[:10]), 10) + self.assertEqual(row_response[:10], cell_list) + self.assertEqual(len(row_response[0:10:1]), 10) + self.assertEqual(row_response[0:10:1], cell_list) + self.assertEqual(len(row_response[0:10:2]), 5) + self.assertEqual(row_response[0:10:2], [cell_list[i] for i in range(0, 10, 2)]) + self.assertEqual(len(row_response[0:10:3]), 4) + self.assertEqual(row_response[0:10:3], [cell_list[i] for i in range(0, 10, 3)]) + self.assertEqual(len(row_response[10:0:-1]), 9) + self.assertEqual(len(row_response[10:0:-2]), 5) + self.assertEqual(row_response[10:0:-3], cell_list[10:0:-3]) + self.assertEqual(len(row_response[0:100]), 10) + + def test_family_indexing(self): + # should be able to retrieve cells in a family + new_family_id = "new_family_id" + cell = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + cell2 = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + cell3 = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + new_family_id, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + 
row_response = self._make_one(TEST_ROW_KEY, [cell, cell2, cell3]) + + self.assertEqual(len(row_response[TEST_FAMILY_ID]), 2) + self.assertEqual(row_response[TEST_FAMILY_ID][0], cell) + self.assertEqual(row_response[TEST_FAMILY_ID][1], cell2) + self.assertEqual(len(row_response[new_family_id]), 1) + self.assertEqual(row_response[new_family_id][0], cell3) + with self.assertRaises(ValueError): + row_response["not_a_family_id"] + with self.assertRaises(TypeError): + row_response[None] + with self.assertRaises(TypeError): + row_response[b"new_family_id"] + + def test_family_qualifier_indexing(self): + # should be able to retrieve cells in a family/qualifier tuple + new_family_id = "new_family_id" + new_qualifier = b"new_qualifier" + cell = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + cell2 = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + cell3 = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + new_family_id, + new_qualifier, + TEST_TIMESTAMP, + TEST_LABELS, + ) + row_response = self._make_one(TEST_ROW_KEY, [cell, cell2, cell3]) + + self.assertEqual(len(row_response[TEST_FAMILY_ID, TEST_QUALIFIER]), 2) + self.assertEqual(row_response[TEST_FAMILY_ID, TEST_QUALIFIER][0], cell) + self.assertEqual(row_response[TEST_FAMILY_ID, TEST_QUALIFIER][1], cell2) + self.assertEqual(len(row_response[new_family_id, new_qualifier]), 1) + self.assertEqual(row_response[new_family_id, new_qualifier][0], cell3) + self.assertEqual(len(row_response["new_family_id", "new_qualifier"]), 1) + self.assertEqual(len(row_response["new_family_id", b"new_qualifier"]), 1) + with self.assertRaises(ValueError): + row_response[new_family_id, "not_a_qualifier"] + with self.assertRaises(ValueError): + row_response["not_a_family_id", new_qualifier] + with self.assertRaises(TypeError): + row_response[None, None] + with self.assertRaises(TypeError): + row_response[b"new_family_id", b"new_qualifier"] + + def test_get_column_components(self): + # should be able to retrieve (family,qualifier) tuples as keys + new_family_id = "new_family_id" + new_qualifier = b"new_qualifier" + cell = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + cell2 = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + cell3 = self._make_cell( + TEST_VALUE, + TEST_ROW_KEY, + new_family_id, + new_qualifier, + TEST_TIMESTAMP, + TEST_LABELS, + ) + row_response = self._make_one(TEST_ROW_KEY, [cell, cell2, cell3]) + + self.assertEqual(len(row_response._get_column_components()), 2) + self.assertEqual( + row_response._get_column_components(), + [(TEST_FAMILY_ID, TEST_QUALIFIER), (new_family_id, new_qualifier)], + ) + + row_response = self._make_one(TEST_ROW_KEY, []) + self.assertEqual(len(row_response._get_column_components()), 0) + self.assertEqual(row_response._get_column_components(), []) + + row_response = self._make_one(TEST_ROW_KEY, [cell]) + self.assertEqual(len(row_response._get_column_components()), 1) + self.assertEqual( + row_response._get_column_components(), [(TEST_FAMILY_ID, TEST_QUALIFIER)] + ) + + +class TestCell(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data.row import Cell + + return Cell + + def _make_one(self, *args, **kwargs): + if len(args) == 0: + args = ( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + 
TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + return self._get_target_class()(*args, **kwargs) + + def test_ctor(self): + cell = self._make_one( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + self.assertEqual(cell.value, TEST_VALUE) + self.assertEqual(cell.row_key, TEST_ROW_KEY) + self.assertEqual(cell.family, TEST_FAMILY_ID) + self.assertEqual(cell.qualifier, TEST_QUALIFIER) + self.assertEqual(cell.timestamp_micros, TEST_TIMESTAMP) + self.assertEqual(cell.labels, TEST_LABELS) + + def test_to_dict(self): + from google.cloud.bigtable_v2.types import Cell + + cell = self._make_one() + cell_dict = cell._to_dict() + expected_dict = { + "value": TEST_VALUE, + "timestamp_micros": TEST_TIMESTAMP, + "labels": TEST_LABELS, + } + self.assertEqual(len(cell_dict), len(expected_dict)) + for key, value in expected_dict.items(): + self.assertEqual(cell_dict[key], value) + # should be able to construct a Cell proto from the dict + cell_proto = Cell(**cell_dict) + self.assertEqual(cell_proto.value, TEST_VALUE) + self.assertEqual(cell_proto.timestamp_micros, TEST_TIMESTAMP) + self.assertEqual(cell_proto.labels, TEST_LABELS) + + def test_to_dict_no_labels(self): + from google.cloud.bigtable_v2.types import Cell + + cell_no_labels = self._make_one( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + None, + ) + cell_dict = cell_no_labels._to_dict() + expected_dict = { + "value": TEST_VALUE, + "timestamp_micros": TEST_TIMESTAMP, + } + self.assertEqual(len(cell_dict), len(expected_dict)) + for key, value in expected_dict.items(): + self.assertEqual(cell_dict[key], value) + # should be able to construct a Cell proto from the dict + cell_proto = Cell(**cell_dict) + self.assertEqual(cell_proto.value, TEST_VALUE) + self.assertEqual(cell_proto.timestamp_micros, TEST_TIMESTAMP) + self.assertEqual(cell_proto.labels, []) + + def test_int_value(self): + test_int = 1234 + bytes_value = test_int.to_bytes(4, "big", signed=True) + cell = self._make_one( + bytes_value, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + self.assertEqual(int(cell), test_int) + # ensure string formatting works + formatted = "%d" % cell + self.assertEqual(formatted, str(test_int)) + self.assertEqual(int(formatted), test_int) + + def test_int_value_negative(self): + test_int = -99999 + bytes_value = test_int.to_bytes(4, "big", signed=True) + cell = self._make_one( + bytes_value, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + self.assertEqual(int(cell), test_int) + # ensure string formatting works + formatted = "%d" % cell + self.assertEqual(formatted, str(test_int)) + self.assertEqual(int(formatted), test_int) + + def test___str__(self): + test_value = b"helloworld" + cell = self._make_one( + test_value, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + self.assertEqual(str(cell), "b'helloworld'") + self.assertEqual(str(cell), str(test_value)) + + def test___repr__(self): + from google.cloud.bigtable.data.row import Cell # type: ignore # noqa: F401 + + cell = self._make_one() + expected = ( + "Cell(value=b'1234', row_key=b'row', " + + "family='cf1', qualifier=b'col', " + + f"timestamp_micros={TEST_TIMESTAMP}, labels=['label1', 'label2'])" + ) + self.assertEqual(repr(cell), expected) + # should be able to construct instance from __repr__ + result = eval(repr(cell)) + self.assertEqual(result, cell) + + def 
test___repr___no_labels(self): + from google.cloud.bigtable.data.row import Cell # type: ignore # noqa: F401 + + cell_no_labels = self._make_one( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + None, + ) + expected = ( + "Cell(value=b'1234', row_key=b'row', " + + "family='cf1', qualifier=b'col', " + + f"timestamp_micros={TEST_TIMESTAMP}, labels=[])" + ) + self.assertEqual(repr(cell_no_labels), expected) + # should be able to construct instance from __repr__ + result = eval(repr(cell_no_labels)) + self.assertEqual(result, cell_no_labels) + + def test_equality(self): + cell1 = self._make_one() + cell2 = self._make_one() + self.assertEqual(cell1, cell2) + self.assertTrue(cell1 == cell2) + args = ( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + for i in range(0, len(args)): + # try changing each argument + modified_cell = self._make_one(*args[:i], args[i] + args[i], *args[i + 1 :]) + self.assertNotEqual(cell1, modified_cell) + self.assertFalse(cell1 == modified_cell) + self.assertTrue(cell1 != modified_cell) + + def test_hash(self): + # class should be hashable + cell1 = self._make_one() + d = {cell1: 1} + cell2 = self._make_one() + self.assertEqual(d[cell2], 1) + + args = ( + TEST_VALUE, + TEST_ROW_KEY, + TEST_FAMILY_ID, + TEST_QUALIFIER, + TEST_TIMESTAMP, + TEST_LABELS, + ) + for i in range(0, len(args)): + # try changing each argument + modified_cell = self._make_one(*args[:i], args[i] + args[i], *args[i + 1 :]) + with self.assertRaises(KeyError): + d[modified_cell] + + def test_ordering(self): + # create cell list in order from lowest to highest + higher_cells = [] + i = 0 + # families; alphabetical order + for family in ["z", "y", "x"]: + # qualifiers; lowest byte value first + for qualifier in [b"z", b"y", b"x"]: + # timestamps; newest first + for timestamp in [ + TEST_TIMESTAMP, + TEST_TIMESTAMP + 1, + TEST_TIMESTAMP + 2, + ]: + cell = self._make_one( + TEST_VALUE, + TEST_ROW_KEY, + family, + qualifier, + timestamp, + TEST_LABELS, + ) + # cell should be the highest priority encountered so far + self.assertEqual(i, len(higher_cells)) + i += 1 + for other in higher_cells: + self.assertLess(cell, other) + higher_cells.append(cell) + # final order should be reverse of sorted order + expected_order = higher_cells + expected_order.reverse() + self.assertEqual(expected_order, sorted(higher_cells)) diff --git a/tests/unit/data/test_row_filters.py b/tests/unit/data/test_row_filters.py new file mode 100644 index 000000000..e90b6f270 --- /dev/null +++ b/tests/unit/data/test_row_filters.py @@ -0,0 +1,2039 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + + +def test_abstract_class_constructors(): + from google.cloud.bigtable.data.row_filters import RowFilter + from google.cloud.bigtable.data.row_filters import _BoolFilter + from google.cloud.bigtable.data.row_filters import _FilterCombination + from google.cloud.bigtable.data.row_filters import _CellCountFilter + + with pytest.raises(TypeError): + RowFilter() + with pytest.raises(TypeError): + _BoolFilter(False) + with pytest.raises(TypeError): + _FilterCombination([]) + with pytest.raises(TypeError): + _CellCountFilter(0) + + +def test_bool_filter_constructor(): + for FilterType in _get_bool_filters(): + flag = True + row_filter = FilterType(flag) + assert row_filter.flag is flag + + +def test_bool_filter___eq__type_differ(): + for FilterType in _get_bool_filters(): + flag = object() + row_filter1 = FilterType(flag) + row_filter2 = object() + assert not (row_filter1 == row_filter2) + + +def test_bool_filter___eq__same_value(): + for FilterType in _get_bool_filters(): + flag = object() + row_filter1 = FilterType(flag) + row_filter2 = FilterType(flag) + assert row_filter1 == row_filter2 + + +def test_bool_filter___ne__same_value(): + for FilterType in _get_bool_filters(): + flag = object() + row_filter1 = FilterType(flag) + row_filter2 = FilterType(flag) + assert not (row_filter1 != row_filter2) + + +def test_sink_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import SinkFilter + + flag = True + row_filter = SinkFilter(flag) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(sink=flag) + assert pb_val == expected_pb + + +def test_sink_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import SinkFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + flag = True + row_filter = SinkFilter(flag) + expected_dict = {"sink": flag} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_sink_filter___repr__(): + from google.cloud.bigtable.data.row_filters import SinkFilter + + flag = True + row_filter = SinkFilter(flag) + assert repr(row_filter) == "SinkFilter(flag={})".format(flag) + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_pass_all_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import PassAllFilter + + flag = True + row_filter = PassAllFilter(flag) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(pass_all_filter=flag) + assert pb_val == expected_pb + + +def test_pass_all_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import PassAllFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + flag = True + row_filter = PassAllFilter(flag) + expected_dict = {"pass_all_filter": flag} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_pass_all_filter___repr__(): + from google.cloud.bigtable.data.row_filters import PassAllFilter + + flag = True + row_filter = PassAllFilter(flag) + assert repr(row_filter) == "PassAllFilter(flag={})".format(flag) + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_block_all_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import BlockAllFilter + + flag = True + row_filter = BlockAllFilter(flag) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(block_all_filter=flag) 
+ assert pb_val == expected_pb + + +def test_block_all_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import BlockAllFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + flag = True + row_filter = BlockAllFilter(flag) + expected_dict = {"block_all_filter": flag} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_block_all_filter___repr__(): + from google.cloud.bigtable.data.row_filters import BlockAllFilter + + flag = True + row_filter = BlockAllFilter(flag) + assert repr(row_filter) == "BlockAllFilter(flag={})".format(flag) + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_regex_filterconstructor(): + for FilterType in _get_regex_filters(): + regex = b"abc" + row_filter = FilterType(regex) + assert row_filter.regex == regex + + +def test_regex_filterconstructor_non_bytes(): + for FilterType in _get_regex_filters(): + regex = "abc" + row_filter = FilterType(regex) + assert row_filter.regex == b"abc" + + +def test_regex_filter__eq__type_differ(): + for FilterType in _get_regex_filters(): + regex = b"def-rgx" + row_filter1 = FilterType(regex) + row_filter2 = object() + assert not (row_filter1 == row_filter2) + + +def test_regex_filter__eq__same_value(): + for FilterType in _get_regex_filters(): + regex = b"trex-regex" + row_filter1 = FilterType(regex) + row_filter2 = FilterType(regex) + assert row_filter1 == row_filter2 + + +def test_regex_filter__ne__same_value(): + for FilterType in _get_regex_filters(): + regex = b"abc" + row_filter1 = FilterType(regex) + row_filter2 = FilterType(regex) + assert not (row_filter1 != row_filter2) + + +def test_row_key_regex_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import RowKeyRegexFilter + + regex = b"row-key-regex" + row_filter = RowKeyRegexFilter(regex) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(row_key_regex_filter=regex) + assert pb_val == expected_pb + + +def test_row_key_regex_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import RowKeyRegexFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + regex = b"row-key-regex" + row_filter = RowKeyRegexFilter(regex) + expected_dict = {"row_key_regex_filter": regex} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_row_key_regex_filter___repr__(): + from google.cloud.bigtable.data.row_filters import RowKeyRegexFilter + + regex = b"row-key-regex" + row_filter = RowKeyRegexFilter(regex) + assert repr(row_filter) == "RowKeyRegexFilter(regex={})".format(regex) + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_row_sample_filter_constructor(): + from google.cloud.bigtable.data.row_filters import RowSampleFilter + + sample = object() + row_filter = RowSampleFilter(sample) + assert row_filter.sample is sample + + +def test_row_sample_filter___eq__type_differ(): + from google.cloud.bigtable.data.row_filters import RowSampleFilter + + sample = object() + row_filter1 = RowSampleFilter(sample) + row_filter2 = object() + assert not (row_filter1 == row_filter2) + + +def test_row_sample_filter___eq__same_value(): + from google.cloud.bigtable.data.row_filters import RowSampleFilter + + sample = object() + row_filter1 = RowSampleFilter(sample) + row_filter2 = 
RowSampleFilter(sample) + assert row_filter1 == row_filter2 + + +def test_row_sample_filter___ne__(): + from google.cloud.bigtable.data.row_filters import RowSampleFilter + + sample = object() + other_sample = object() + row_filter1 = RowSampleFilter(sample) + row_filter2 = RowSampleFilter(other_sample) + assert row_filter1 != row_filter2 + + +def test_row_sample_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import RowSampleFilter + + sample = 0.25 + row_filter = RowSampleFilter(sample) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(row_sample_filter=sample) + assert pb_val == expected_pb + + +def test_row_sample_filter___repr__(): + from google.cloud.bigtable.data.row_filters import RowSampleFilter + + sample = 0.25 + row_filter = RowSampleFilter(sample) + assert repr(row_filter) == "RowSampleFilter(sample={})".format(sample) + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_family_name_regex_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import FamilyNameRegexFilter + + regex = "family-regex" + row_filter = FamilyNameRegexFilter(regex) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(family_name_regex_filter=regex) + assert pb_val == expected_pb + + +def test_family_name_regex_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import FamilyNameRegexFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + regex = "family-regex" + row_filter = FamilyNameRegexFilter(regex) + expected_dict = {"family_name_regex_filter": regex.encode()} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_family_name_regex_filter___repr__(): + from google.cloud.bigtable.data.row_filters import FamilyNameRegexFilter + + regex = "family-regex" + row_filter = FamilyNameRegexFilter(regex) + expected = "FamilyNameRegexFilter(regex=b'family-regex')" + assert repr(row_filter) == expected + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_column_qualifier_regex_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import ColumnQualifierRegexFilter + + regex = b"column-regex" + row_filter = ColumnQualifierRegexFilter(regex) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(column_qualifier_regex_filter=regex) + assert pb_val == expected_pb + + +def test_column_qualifier_regex_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import ColumnQualifierRegexFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + regex = b"column-regex" + row_filter = ColumnQualifierRegexFilter(regex) + expected_dict = {"column_qualifier_regex_filter": regex} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_column_qualifier_regex_filter___repr__(): + from google.cloud.bigtable.data.row_filters import ColumnQualifierRegexFilter + + regex = b"column-regex" + row_filter = ColumnQualifierRegexFilter(regex) + assert repr(row_filter) == "ColumnQualifierRegexFilter(regex={})".format(regex) + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_timestamp_range_constructor(): + from google.cloud.bigtable.data.row_filters import TimestampRange + + start = object() + end = object() + time_range = 
TimestampRange(start=start, end=end) + assert time_range.start is start + assert time_range.end is end + + +def test_timestamp_range___eq__(): + from google.cloud.bigtable.data.row_filters import TimestampRange + + start = object() + end = object() + time_range1 = TimestampRange(start=start, end=end) + time_range2 = TimestampRange(start=start, end=end) + assert time_range1 == time_range2 + + +def test_timestamp_range___eq__type_differ(): + from google.cloud.bigtable.data.row_filters import TimestampRange + + start = object() + end = object() + time_range1 = TimestampRange(start=start, end=end) + time_range2 = object() + assert not (time_range1 == time_range2) + + +def test_timestamp_range___ne__same_value(): + from google.cloud.bigtable.data.row_filters import TimestampRange + + start = object() + end = object() + time_range1 = TimestampRange(start=start, end=end) + time_range2 = TimestampRange(start=start, end=end) + assert not (time_range1 != time_range2) + + +def _timestamp_range_to_pb_helper(pb_kwargs, start=None, end=None): + import datetime + from google.cloud._helpers import _EPOCH + from google.cloud.bigtable.data.row_filters import TimestampRange + + if start is not None: + start = _EPOCH + datetime.timedelta(microseconds=start) + if end is not None: + end = _EPOCH + datetime.timedelta(microseconds=end) + time_range = TimestampRange(start=start, end=end) + expected_pb = _TimestampRangePB(**pb_kwargs) + time_pb = time_range._to_pb() + assert time_pb.start_timestamp_micros == expected_pb.start_timestamp_micros + assert time_pb.end_timestamp_micros == expected_pb.end_timestamp_micros + assert time_pb == expected_pb + + +def test_timestamp_range_to_pb(): + start_micros = 30871234 + end_micros = 12939371234 + start_millis = start_micros // 1000 * 1000 + assert start_millis == 30871000 + end_millis = end_micros // 1000 * 1000 + 1000 + assert end_millis == 12939372000 + pb_kwargs = {} + pb_kwargs["start_timestamp_micros"] = start_millis + pb_kwargs["end_timestamp_micros"] = end_millis + _timestamp_range_to_pb_helper(pb_kwargs, start=start_micros, end=end_micros) + + +def test_timestamp_range_to_dict(): + from google.cloud.bigtable.data.row_filters import TimestampRange + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + import datetime + + row_filter = TimestampRange( + start=datetime.datetime(2019, 1, 1), end=datetime.datetime(2019, 1, 2) + ) + expected_dict = { + "start_timestamp_micros": 1546300800000000, + "end_timestamp_micros": 1546387200000000, + } + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.TimestampRange(**expected_dict) == expected_pb_value + + +def test_timestamp_range_to_pb_start_only(): + # Makes sure already milliseconds granularity + start_micros = 30871000 + start_millis = start_micros // 1000 * 1000 + assert start_millis == 30871000 + pb_kwargs = {} + pb_kwargs["start_timestamp_micros"] = start_millis + _timestamp_range_to_pb_helper(pb_kwargs, start=start_micros, end=None) + + +def test_timestamp_range_to_dict_start_only(): + from google.cloud.bigtable.data.row_filters import TimestampRange + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + import datetime + + row_filter = TimestampRange(start=datetime.datetime(2019, 1, 1)) + expected_dict = {"start_timestamp_micros": 1546300800000000} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.TimestampRange(**expected_dict) == expected_pb_value + + +def 
test_timestamp_range_to_pb_end_only(): + # Makes sure already milliseconds granularity + end_micros = 12939371000 + end_millis = end_micros // 1000 * 1000 + assert end_millis == 12939371000 + pb_kwargs = {} + pb_kwargs["end_timestamp_micros"] = end_millis + _timestamp_range_to_pb_helper(pb_kwargs, start=None, end=end_micros) + + +def test_timestamp_range_to_dict_end_only(): + from google.cloud.bigtable.data.row_filters import TimestampRange + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + import datetime + + row_filter = TimestampRange(end=datetime.datetime(2019, 1, 2)) + expected_dict = {"end_timestamp_micros": 1546387200000000} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.TimestampRange(**expected_dict) == expected_pb_value + + +def timestamp_range___repr__(): + from google.cloud.bigtable.data.row_filters import TimestampRange + + start = object() + end = object() + time_range = TimestampRange(start=start, end=end) + assert repr(time_range) == "TimestampRange(start={}, end={})".format(start, end) + assert repr(time_range) == str(time_range) + assert eval(repr(time_range)) == time_range + + +def test_timestamp_range_filter___eq__type_differ(): + from google.cloud.bigtable.data.row_filters import TimestampRangeFilter + + range_ = object() + row_filter1 = TimestampRangeFilter(range_) + row_filter2 = object() + assert not (row_filter1 == row_filter2) + + +def test_timestamp_range_filter___eq__same_value(): + from google.cloud.bigtable.data.row_filters import TimestampRangeFilter + + range_ = object() + row_filter1 = TimestampRangeFilter(range_) + row_filter2 = TimestampRangeFilter(range_) + assert row_filter1 == row_filter2 + + +def test_timestamp_range_filter___ne__(): + from google.cloud.bigtable.data.row_filters import TimestampRangeFilter + + range_ = object() + other_range_ = object() + row_filter1 = TimestampRangeFilter(range_) + row_filter2 = TimestampRangeFilter(other_range_) + assert row_filter1 != row_filter2 + + +def test_timestamp_range_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import TimestampRangeFilter + + row_filter = TimestampRangeFilter() + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(timestamp_range_filter=_TimestampRangePB()) + assert pb_val == expected_pb + + +def test_timestamp_range_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import TimestampRangeFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + import datetime + + row_filter = TimestampRangeFilter( + start=datetime.datetime(2019, 1, 1), end=datetime.datetime(2019, 1, 2) + ) + expected_dict = { + "timestamp_range_filter": { + "start_timestamp_micros": 1546300800000000, + "end_timestamp_micros": 1546387200000000, + } + } + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_timestamp_range_filter_empty_to_dict(): + from google.cloud.bigtable.data.row_filters import TimestampRangeFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + row_filter = TimestampRangeFilter() + expected_dict = {"timestamp_range_filter": {}} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_timestamp_range_filter___repr__(): + from google.cloud.bigtable.data.row_filters import TimestampRangeFilter + import datetime + + start = 
datetime.datetime(2019, 1, 1) + end = datetime.datetime(2019, 1, 2) + row_filter = TimestampRangeFilter(start, end) + assert ( + repr(row_filter) + == f"TimestampRangeFilter(start={repr(start)}, end={repr(end)})" + ) + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_column_range_filter_constructor_defaults(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = object() + row_filter = ColumnRangeFilter(family_id) + assert row_filter.family_id is family_id + assert row_filter.start_qualifier is None + assert row_filter.end_qualifier is None + assert row_filter.inclusive_start + assert row_filter.inclusive_end + + +def test_column_range_filter_constructor_explicit(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = object() + start_qualifier = object() + end_qualifier = object() + inclusive_start = object() + inclusive_end = object() + row_filter = ColumnRangeFilter( + family_id, + start_qualifier=start_qualifier, + end_qualifier=end_qualifier, + inclusive_start=inclusive_start, + inclusive_end=inclusive_end, + ) + assert row_filter.family_id is family_id + assert row_filter.start_qualifier is start_qualifier + assert row_filter.end_qualifier is end_qualifier + assert row_filter.inclusive_start is inclusive_start + assert row_filter.inclusive_end is inclusive_end + + +def test_column_range_filter_constructor_bad_start(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = object() + with pytest.raises(ValueError): + ColumnRangeFilter(family_id, inclusive_start=True) + + +def test_column_range_filter_constructor_bad_end(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = object() + with pytest.raises(ValueError): + ColumnRangeFilter(family_id, inclusive_end=True) + + +def test_column_range_filter___eq__(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = object() + start_qualifier = object() + end_qualifier = object() + inclusive_start = object() + inclusive_end = object() + row_filter1 = ColumnRangeFilter( + family_id, + start_qualifier=start_qualifier, + end_qualifier=end_qualifier, + inclusive_start=inclusive_start, + inclusive_end=inclusive_end, + ) + row_filter2 = ColumnRangeFilter( + family_id, + start_qualifier=start_qualifier, + end_qualifier=end_qualifier, + inclusive_start=inclusive_start, + inclusive_end=inclusive_end, + ) + assert row_filter1 == row_filter2 + + +def test_column_range_filter___eq__type_differ(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = object() + row_filter1 = ColumnRangeFilter(family_id) + row_filter2 = object() + assert not (row_filter1 == row_filter2) + + +def test_column_range_filter___ne__(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = object() + other_family_id = object() + start_qualifier = object() + end_qualifier = object() + inclusive_start = object() + inclusive_end = object() + row_filter1 = ColumnRangeFilter( + family_id, + start_qualifier=start_qualifier, + end_qualifier=end_qualifier, + inclusive_start=inclusive_start, + inclusive_end=inclusive_end, + ) + row_filter2 = ColumnRangeFilter( + other_family_id, + start_qualifier=start_qualifier, + end_qualifier=end_qualifier, + inclusive_start=inclusive_start, + inclusive_end=inclusive_end, + ) + assert row_filter1 != row_filter2 + + +def test_column_range_filter_to_pb(): + from 
google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = "column-family-id" + row_filter = ColumnRangeFilter(family_id) + col_range_pb = _ColumnRangePB(family_name=family_id) + expected_pb = _RowFilterPB(column_range_filter=col_range_pb) + assert row_filter._to_pb() == expected_pb + + +def test_column_range_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + family_id = "column-family-id" + row_filter = ColumnRangeFilter(family_id) + expected_dict = {"column_range_filter": {"family_name": family_id}} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_column_range_filter_to_pb_inclusive_start(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = "column-family-id" + column = b"column" + row_filter = ColumnRangeFilter(family_id, start_qualifier=column) + col_range_pb = _ColumnRangePB(family_name=family_id, start_qualifier_closed=column) + expected_pb = _RowFilterPB(column_range_filter=col_range_pb) + assert row_filter._to_pb() == expected_pb + + +def test_column_range_filter_to_pb_exclusive_start(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = "column-family-id" + column = b"column" + row_filter = ColumnRangeFilter( + family_id, start_qualifier=column, inclusive_start=False + ) + col_range_pb = _ColumnRangePB(family_name=family_id, start_qualifier_open=column) + expected_pb = _RowFilterPB(column_range_filter=col_range_pb) + assert row_filter._to_pb() == expected_pb + + +def test_column_range_filter_to_pb_inclusive_end(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = "column-family-id" + column = b"column" + row_filter = ColumnRangeFilter(family_id, end_qualifier=column) + col_range_pb = _ColumnRangePB(family_name=family_id, end_qualifier_closed=column) + expected_pb = _RowFilterPB(column_range_filter=col_range_pb) + assert row_filter._to_pb() == expected_pb + + +def test_column_range_filter_to_pb_exclusive_end(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = "column-family-id" + column = b"column" + row_filter = ColumnRangeFilter(family_id, end_qualifier=column, inclusive_end=False) + col_range_pb = _ColumnRangePB(family_name=family_id, end_qualifier_open=column) + expected_pb = _RowFilterPB(column_range_filter=col_range_pb) + assert row_filter._to_pb() == expected_pb + + +def test_column_range_filter___repr__(): + from google.cloud.bigtable.data.row_filters import ColumnRangeFilter + + family_id = "column-family-id" + start_qualifier = b"column" + end_qualifier = b"column2" + row_filter = ColumnRangeFilter(family_id, start_qualifier, end_qualifier) + expected = "ColumnRangeFilter(family_id='column-family-id', start_qualifier=b'column', end_qualifier=b'column2', inclusive_start=True, inclusive_end=True)" + assert repr(row_filter) == expected + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_value_regex_filter_to_pb_w_bytes(): + from google.cloud.bigtable.data.row_filters import ValueRegexFilter + + value = regex = b"value-regex" + row_filter = ValueRegexFilter(value) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(value_regex_filter=regex) + assert pb_val == expected_pb + + +def test_value_regex_filter_to_dict_w_bytes(): + 
from google.cloud.bigtable.data.row_filters import ValueRegexFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + value = regex = b"value-regex" + row_filter = ValueRegexFilter(value) + expected_dict = {"value_regex_filter": regex} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_value_regex_filter_to_pb_w_str(): + from google.cloud.bigtable.data.row_filters import ValueRegexFilter + + value = "value-regex" + regex = value.encode("ascii") + row_filter = ValueRegexFilter(value) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(value_regex_filter=regex) + assert pb_val == expected_pb + + +def test_value_regex_filter_to_dict_w_str(): + from google.cloud.bigtable.data.row_filters import ValueRegexFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + value = "value-regex" + regex = value.encode("ascii") + row_filter = ValueRegexFilter(value) + expected_dict = {"value_regex_filter": regex} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_value_regex_filter___repr__(): + from google.cloud.bigtable.data.row_filters import ValueRegexFilter + + value = "value-regex" + row_filter = ValueRegexFilter(value) + expected = "ValueRegexFilter(regex=b'value-regex')" + assert repr(row_filter) == expected + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_literal_value_filter_to_pb_w_bytes(): + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + + value = regex = b"value_regex" + row_filter = LiteralValueFilter(value) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(value_regex_filter=regex) + assert pb_val == expected_pb + + +def test_literal_value_filter_to_dict_w_bytes(): + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + value = regex = b"value_regex" + row_filter = LiteralValueFilter(value) + expected_dict = {"value_regex_filter": regex} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_literal_value_filter_to_pb_w_str(): + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + + value = "value_regex" + regex = value.encode("ascii") + row_filter = LiteralValueFilter(value) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(value_regex_filter=regex) + assert pb_val == expected_pb + + +def test_literal_value_filter_to_dict_w_str(): + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + value = "value_regex" + regex = value.encode("ascii") + row_filter = LiteralValueFilter(value) + expected_dict = {"value_regex_filter": regex} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +@pytest.mark.parametrize( + "value,expected_byte_string", + [ + # null bytes are encoded as "\x00" in ascii characters + # others are just prefixed with "\" + (0, b"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00"), + (1, b"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\\x01"), + ( + 68, + b"\\x00\\x00\\x00\\x00\\x00\\x00\\x00D", + ), # bytes 
that encode to alphanum are not escaped + (570, b"\\x00\\x00\\x00\\x00\\x00\\x00\\\x02\\\x3a"), + (2852126720, b"\\x00\\x00\\x00\\x00\xaa\\x00\\x00\\x00"), + (-1, b"\xff\xff\xff\xff\xff\xff\xff\xff"), + (-1096642724096, b"\xff\xff\xff\\x00\xaa\xff\xff\\x00"), + ], +) +def test_literal_value_filter_w_int(value, expected_byte_string): + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + row_filter = LiteralValueFilter(value) + # test pb + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(value_regex_filter=expected_byte_string) + assert pb_val == expected_pb + # test dict + expected_dict = {"value_regex_filter": expected_byte_string} + assert row_filter._to_dict() == expected_dict + assert data_v2_pb2.RowFilter(**expected_dict) == pb_val + + +def test_literal_value_filter___repr__(): + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + + value = "value_regex" + row_filter = LiteralValueFilter(value) + expected = "LiteralValueFilter(value=b'value_regex')" + assert repr(row_filter) == expected + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_value_range_filter_constructor_defaults(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + row_filter = ValueRangeFilter() + + assert row_filter.start_value is None + assert row_filter.end_value is None + assert row_filter.inclusive_start + assert row_filter.inclusive_end + + +def test_value_range_filter_constructor_explicit(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + start_value = object() + end_value = object() + inclusive_start = object() + inclusive_end = object() + + row_filter = ValueRangeFilter( + start_value=start_value, + end_value=end_value, + inclusive_start=inclusive_start, + inclusive_end=inclusive_end, + ) + + assert row_filter.start_value is start_value + assert row_filter.end_value is end_value + assert row_filter.inclusive_start is inclusive_start + assert row_filter.inclusive_end is inclusive_end + + +def test_value_range_filter_constructor_w_int_values(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + import struct + + start_value = 1 + end_value = 10 + + row_filter = ValueRangeFilter(start_value=start_value, end_value=end_value) + + expected_start_value = struct.Struct(">q").pack(start_value) + expected_end_value = struct.Struct(">q").pack(end_value) + + assert row_filter.start_value == expected_start_value + assert row_filter.end_value == expected_end_value + assert row_filter.inclusive_start + assert row_filter.inclusive_end + + +def test_value_range_filter_constructor_bad_start(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + with pytest.raises(ValueError): + ValueRangeFilter(inclusive_start=True) + + +def test_value_range_filter_constructor_bad_end(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + with pytest.raises(ValueError): + ValueRangeFilter(inclusive_end=True) + + +def test_value_range_filter___eq__(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + start_value = object() + end_value = object() + inclusive_start = object() + inclusive_end = object() + row_filter1 = ValueRangeFilter( + start_value=start_value, + end_value=end_value, + inclusive_start=inclusive_start, + inclusive_end=inclusive_end, + ) + row_filter2 = ValueRangeFilter( + start_value=start_value, + end_value=end_value, + 
inclusive_start=inclusive_start, + inclusive_end=inclusive_end, + ) + assert row_filter1 == row_filter2 + + +def test_value_range_filter___eq__type_differ(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + row_filter1 = ValueRangeFilter() + row_filter2 = object() + assert not (row_filter1 == row_filter2) + + +def test_value_range_filter___ne__(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + start_value = object() + other_start_value = object() + end_value = object() + inclusive_start = object() + inclusive_end = object() + row_filter1 = ValueRangeFilter( + start_value=start_value, + end_value=end_value, + inclusive_start=inclusive_start, + inclusive_end=inclusive_end, + ) + row_filter2 = ValueRangeFilter( + start_value=other_start_value, + end_value=end_value, + inclusive_start=inclusive_start, + inclusive_end=inclusive_end, + ) + assert row_filter1 != row_filter2 + + +def test_value_range_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + row_filter = ValueRangeFilter() + expected_pb = _RowFilterPB(value_range_filter=_ValueRangePB()) + assert row_filter._to_pb() == expected_pb + + +def test_value_range_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + row_filter = ValueRangeFilter() + expected_dict = {"value_range_filter": {}} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_value_range_filter_to_pb_inclusive_start(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + value = b"some-value" + row_filter = ValueRangeFilter(start_value=value) + val_range_pb = _ValueRangePB(start_value_closed=value) + expected_pb = _RowFilterPB(value_range_filter=val_range_pb) + assert row_filter._to_pb() == expected_pb + + +def test_value_range_filter_to_pb_exclusive_start(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + value = b"some-value" + row_filter = ValueRangeFilter(start_value=value, inclusive_start=False) + val_range_pb = _ValueRangePB(start_value_open=value) + expected_pb = _RowFilterPB(value_range_filter=val_range_pb) + assert row_filter._to_pb() == expected_pb + + +def test_value_range_filter_to_pb_inclusive_end(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + value = b"some-value" + row_filter = ValueRangeFilter(end_value=value) + val_range_pb = _ValueRangePB(end_value_closed=value) + expected_pb = _RowFilterPB(value_range_filter=val_range_pb) + assert row_filter._to_pb() == expected_pb + + +def test_value_range_filter_to_pb_exclusive_end(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + value = b"some-value" + row_filter = ValueRangeFilter(end_value=value, inclusive_end=False) + val_range_pb = _ValueRangePB(end_value_open=value) + expected_pb = _RowFilterPB(value_range_filter=val_range_pb) + assert row_filter._to_pb() == expected_pb + + +def test_value_range_filter___repr__(): + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + start_value = b"some-value" + end_value = b"some-other-value" + row_filter = ValueRangeFilter( + start_value=start_value, end_value=end_value, inclusive_end=False + ) + expected = "ValueRangeFilter(start_value=b'some-value', end_value=b'some-other-value', inclusive_start=True, inclusive_end=False)" + assert repr(row_filter) 
== expected + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_cell_count_constructor(): + for FilterType in _get_cell_count_filters(): + num_cells = object() + row_filter = FilterType(num_cells) + assert row_filter.num_cells is num_cells + + +def test_cell_count___eq__type_differ(): + for FilterType in _get_cell_count_filters(): + num_cells = object() + row_filter1 = FilterType(num_cells) + row_filter2 = object() + assert not (row_filter1 == row_filter2) + + +def test_cell_count___eq__same_value(): + for FilterType in _get_cell_count_filters(): + num_cells = object() + row_filter1 = FilterType(num_cells) + row_filter2 = FilterType(num_cells) + assert row_filter1 == row_filter2 + + +def test_cell_count___ne__same_value(): + for FilterType in _get_cell_count_filters(): + num_cells = object() + row_filter1 = FilterType(num_cells) + row_filter2 = FilterType(num_cells) + assert not (row_filter1 != row_filter2) + + +def test_cells_row_offset_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import CellsRowOffsetFilter + + num_cells = 76 + row_filter = CellsRowOffsetFilter(num_cells) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(cells_per_row_offset_filter=num_cells) + assert pb_val == expected_pb + + +def test_cells_row_offset_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import CellsRowOffsetFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + num_cells = 76 + row_filter = CellsRowOffsetFilter(num_cells) + expected_dict = {"cells_per_row_offset_filter": num_cells} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_cells_row_offset_filter___repr__(): + from google.cloud.bigtable.data.row_filters import CellsRowOffsetFilter + + num_cells = 76 + row_filter = CellsRowOffsetFilter(num_cells) + expected = "CellsRowOffsetFilter(num_cells={})".format(num_cells) + assert repr(row_filter) == expected + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_cells_row_limit_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter + + num_cells = 189 + row_filter = CellsRowLimitFilter(num_cells) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(cells_per_row_limit_filter=num_cells) + assert pb_val == expected_pb + + +def test_cells_row_limit_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + num_cells = 189 + row_filter = CellsRowLimitFilter(num_cells) + expected_dict = {"cells_per_row_limit_filter": num_cells} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_cells_row_limit_filter___repr__(): + from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter + + num_cells = 189 + row_filter = CellsRowLimitFilter(num_cells) + expected = "CellsRowLimitFilter(num_cells={})".format(num_cells) + assert repr(row_filter) == expected + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_cells_column_limit_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import CellsColumnLimitFilter + + num_cells = 10 + row_filter = CellsColumnLimitFilter(num_cells) + pb_val = row_filter._to_pb() + expected_pb = 
_RowFilterPB(cells_per_column_limit_filter=num_cells) + assert pb_val == expected_pb + + +def test_cells_column_limit_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import CellsColumnLimitFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + num_cells = 10 + row_filter = CellsColumnLimitFilter(num_cells) + expected_dict = {"cells_per_column_limit_filter": num_cells} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_cells_column_limit_filter___repr__(): + from google.cloud.bigtable.data.row_filters import CellsColumnLimitFilter + + num_cells = 10 + row_filter = CellsColumnLimitFilter(num_cells) + expected = "CellsColumnLimitFilter(num_cells={})".format(num_cells) + assert repr(row_filter) == expected + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_strip_value_transformer_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + flag = True + row_filter = StripValueTransformerFilter(flag) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(strip_value_transformer=flag) + assert pb_val == expected_pb + + +def test_strip_value_transformer_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + flag = True + row_filter = StripValueTransformerFilter(flag) + expected_dict = {"strip_value_transformer": flag} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_strip_value_transformer_filter___repr__(): + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + flag = True + row_filter = StripValueTransformerFilter(flag) + expected = "StripValueTransformerFilter(flag={})".format(flag) + assert repr(row_filter) == expected + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_apply_label_filter_constructor(): + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + label = object() + row_filter = ApplyLabelFilter(label) + assert row_filter.label is label + + +def test_apply_label_filter___eq__type_differ(): + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + label = object() + row_filter1 = ApplyLabelFilter(label) + row_filter2 = object() + assert not (row_filter1 == row_filter2) + + +def test_apply_label_filter___eq__same_value(): + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + label = object() + row_filter1 = ApplyLabelFilter(label) + row_filter2 = ApplyLabelFilter(label) + assert row_filter1 == row_filter2 + + +def test_apply_label_filter___ne__(): + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + label = object() + other_label = object() + row_filter1 = ApplyLabelFilter(label) + row_filter2 = ApplyLabelFilter(other_label) + assert row_filter1 != row_filter2 + + +def test_apply_label_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + label = "label" + row_filter = ApplyLabelFilter(label) + pb_val = row_filter._to_pb() + expected_pb = _RowFilterPB(apply_label_transformer=label) + assert pb_val == expected_pb + + +def test_apply_label_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import 
ApplyLabelFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + label = "label" + row_filter = ApplyLabelFilter(label) + expected_dict = {"apply_label_transformer": label} + assert row_filter._to_dict() == expected_dict + expected_pb_value = row_filter._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_apply_label_filter___repr__(): + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + label = "label" + row_filter = ApplyLabelFilter(label) + expected = "ApplyLabelFilter(label={})".format(label) + assert repr(row_filter) == expected + assert repr(row_filter) == str(row_filter) + assert eval(repr(row_filter)) == row_filter + + +def test_filter_combination_constructor_defaults(): + for FilterType in _get_filter_combination_filters(): + row_filter = FilterType() + assert row_filter.filters == [] + + +def test_filter_combination_constructor_explicit(): + for FilterType in _get_filter_combination_filters(): + filters = object() + row_filter = FilterType(filters=filters) + assert row_filter.filters is filters + + +def test_filter_combination___eq__(): + for FilterType in _get_filter_combination_filters(): + filters = object() + row_filter1 = FilterType(filters=filters) + row_filter2 = FilterType(filters=filters) + assert row_filter1 == row_filter2 + + +def test_filter_combination___eq__type_differ(): + for FilterType in _get_filter_combination_filters(): + filters = object() + row_filter1 = FilterType(filters=filters) + row_filter2 = object() + assert not (row_filter1 == row_filter2) + + +def test_filter_combination___ne__(): + for FilterType in _get_filter_combination_filters(): + filters = object() + other_filters = object() + row_filter1 = FilterType(filters=filters) + row_filter2 = FilterType(filters=other_filters) + assert row_filter1 != row_filter2 + + +def test_filter_combination_len(): + for FilterType in _get_filter_combination_filters(): + filters = [object(), object()] + row_filter = FilterType(filters=filters) + assert len(row_filter) == len(filters) + + +def test_filter_combination_iter(): + for FilterType in _get_filter_combination_filters(): + filters = [object(), object()] + row_filter = FilterType(filters=filters) + assert list(iter(row_filter)) == filters + for filter_, expected in zip(row_filter, filters): + assert filter_ is expected + + +def test_filter_combination___getitem__(): + for FilterType in _get_filter_combination_filters(): + filters = [object(), object()] + row_filter = FilterType(filters=filters) + # indexing and slicing should expose the wrapped filters + assert row_filter[0] is filters[0] + assert row_filter[1] is filters[1] + with pytest.raises(IndexError): + row_filter[2] + assert row_filter[:] == filters[:] + + +def test_filter_combination___str__(): + from google.cloud.bigtable.data.row_filters import PassAllFilter + + for FilterType in _get_filter_combination_filters(): + filters = [PassAllFilter(True), PassAllFilter(False)] + row_filter = FilterType(filters=filters) + expected = ( + "([\n PassAllFilter(flag=True),\n PassAllFilter(flag=False),\n])" + ) + assert expected in str(row_filter) + + +def test_row_filter_chain_to_pb(): + from google.cloud.bigtable.data.row_filters import RowFilterChain + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter1_pb = row_filter1._to_pb() + + row_filter2 = RowSampleFilter(0.25) + row_filter2_pb = row_filter2._to_pb() + + row_filter3 = 
RowFilterChain(filters=[row_filter1, row_filter2]) + filter_pb = row_filter3._to_pb() + + expected_pb = _RowFilterPB( + chain=_RowFilterChainPB(filters=[row_filter1_pb, row_filter2_pb]) + ) + assert filter_pb == expected_pb + + +def test_row_filter_chain_to_dict(): + from google.cloud.bigtable.data.row_filters import RowFilterChain + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + row_filter1 = StripValueTransformerFilter(True) + row_filter1_dict = row_filter1._to_dict() + + row_filter2 = RowSampleFilter(0.25) + row_filter2_dict = row_filter2._to_dict() + + row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2]) + filter_dict = row_filter3._to_dict() + + expected_dict = {"chain": {"filters": [row_filter1_dict, row_filter2_dict]}} + assert filter_dict == expected_dict + expected_pb_value = row_filter3._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_row_filter_chain_to_pb_nested(): + from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter + from google.cloud.bigtable.data.row_filters import RowFilterChain + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter2 = RowSampleFilter(0.25) + + row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2]) + row_filter3_pb = row_filter3._to_pb() + + row_filter4 = CellsRowLimitFilter(11) + row_filter4_pb = row_filter4._to_pb() + + row_filter5 = RowFilterChain(filters=[row_filter3, row_filter4]) + filter_pb = row_filter5._to_pb() + + expected_pb = _RowFilterPB( + chain=_RowFilterChainPB(filters=[row_filter3_pb, row_filter4_pb]) + ) + assert filter_pb == expected_pb + + +def test_row_filter_chain_to_dict_nested(): + from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter + from google.cloud.bigtable.data.row_filters import RowFilterChain + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + row_filter1 = StripValueTransformerFilter(True) + + row_filter2 = RowSampleFilter(0.25) + + row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2]) + row_filter3_dict = row_filter3._to_dict() + + row_filter4 = CellsRowLimitFilter(11) + row_filter4_dict = row_filter4._to_dict() + + row_filter5 = RowFilterChain(filters=[row_filter3, row_filter4]) + filter_dict = row_filter5._to_dict() + + expected_dict = {"chain": {"filters": [row_filter3_dict, row_filter4_dict]}} + assert filter_dict == expected_dict + expected_pb_value = row_filter5._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_row_filter_chain___repr__(): + from google.cloud.bigtable.data.row_filters import RowFilterChain + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter2 = RowSampleFilter(0.25) + + row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2]) + expected = f"RowFilterChain(filters={[row_filter1, row_filter2]})" + assert repr(row_filter3) == expected + assert eval(repr(row_filter3)) == row_filter3 + + +def 
test_row_filter_chain___str__(): + from google.cloud.bigtable.data.row_filters import RowFilterChain + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter2 = RowSampleFilter(0.25) + + row_filter3 = RowFilterChain(filters=[row_filter1, row_filter2]) + expected = "RowFilterChain([\n StripValueTransformerFilter(flag=True),\n RowSampleFilter(sample=0.25),\n])" + assert str(row_filter3) == expected + # test nested + row_filter4 = RowFilterChain(filters=[row_filter3]) + expected = "RowFilterChain([\n RowFilterChain([\n StripValueTransformerFilter(flag=True),\n RowSampleFilter(sample=0.25),\n ]),\n])" + assert str(row_filter4) == expected + + +def test_row_filter_union_to_pb(): + from google.cloud.bigtable.data.row_filters import RowFilterUnion + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter1_pb = row_filter1._to_pb() + + row_filter2 = RowSampleFilter(0.25) + row_filter2_pb = row_filter2._to_pb() + + row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2]) + filter_pb = row_filter3._to_pb() + + expected_pb = _RowFilterPB( + interleave=_RowFilterInterleavePB(filters=[row_filter1_pb, row_filter2_pb]) + ) + assert filter_pb == expected_pb + + +def test_row_filter_union_to_dict(): + from google.cloud.bigtable.data.row_filters import RowFilterUnion + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + row_filter1 = StripValueTransformerFilter(True) + row_filter1_dict = row_filter1._to_dict() + + row_filter2 = RowSampleFilter(0.25) + row_filter2_dict = row_filter2._to_dict() + + row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2]) + filter_dict = row_filter3._to_dict() + + expected_dict = {"interleave": {"filters": [row_filter1_dict, row_filter2_dict]}} + assert filter_dict == expected_dict + expected_pb_value = row_filter3._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_row_filter_union_to_pb_nested(): + from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter + from google.cloud.bigtable.data.row_filters import RowFilterUnion + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter2 = RowSampleFilter(0.25) + + row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2]) + row_filter3_pb = row_filter3._to_pb() + + row_filter4 = CellsRowLimitFilter(11) + row_filter4_pb = row_filter4._to_pb() + + row_filter5 = RowFilterUnion(filters=[row_filter3, row_filter4]) + filter_pb = row_filter5._to_pb() + + expected_pb = _RowFilterPB( + interleave=_RowFilterInterleavePB(filters=[row_filter3_pb, row_filter4_pb]) + ) + assert filter_pb == expected_pb + + +def test_row_filter_union_to_dict_nested(): + from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter + from google.cloud.bigtable.data.row_filters import RowFilterUnion + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + from 
google.cloud.bigtable_v2.types import data as data_v2_pb2 + + row_filter1 = StripValueTransformerFilter(True) + + row_filter2 = RowSampleFilter(0.25) + + row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2]) + row_filter3_dict = row_filter3._to_dict() + + row_filter4 = CellsRowLimitFilter(11) + row_filter4_dict = row_filter4._to_dict() + + row_filter5 = RowFilterUnion(filters=[row_filter3, row_filter4]) + filter_dict = row_filter5._to_dict() + + expected_dict = {"interleave": {"filters": [row_filter3_dict, row_filter4_dict]}} + assert filter_dict == expected_dict + expected_pb_value = row_filter5._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_row_filter_union___repr__(): + from google.cloud.bigtable.data.row_filters import RowFilterUnion + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter2 = RowSampleFilter(0.25) + + row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2]) + expected = "RowFilterUnion(filters=[StripValueTransformerFilter(flag=True), RowSampleFilter(sample=0.25)])" + assert repr(row_filter3) == expected + assert eval(repr(row_filter3)) == row_filter3 + + +def test_row_filter_union___str__(): + from google.cloud.bigtable.data.row_filters import RowFilterUnion + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter2 = RowSampleFilter(0.25) + + row_filter3 = RowFilterUnion(filters=[row_filter1, row_filter2]) + expected = "RowFilterUnion([\n StripValueTransformerFilter(flag=True),\n RowSampleFilter(sample=0.25),\n])" + assert str(row_filter3) == expected + # test nested + row_filter4 = RowFilterUnion(filters=[row_filter3]) + expected = "RowFilterUnion([\n RowFilterUnion([\n StripValueTransformerFilter(flag=True),\n RowSampleFilter(sample=0.25),\n ]),\n])" + assert str(row_filter4) == expected + + +def test_conditional_row_filter_constructor(): + from google.cloud.bigtable.data.row_filters import ConditionalRowFilter + + predicate_filter = object() + true_filter = object() + false_filter = object() + cond_filter = ConditionalRowFilter( + predicate_filter, true_filter=true_filter, false_filter=false_filter + ) + assert cond_filter.predicate_filter is predicate_filter + assert cond_filter.true_filter is true_filter + assert cond_filter.false_filter is false_filter + + +def test_conditional_row_filter___eq__(): + from google.cloud.bigtable.data.row_filters import ConditionalRowFilter + + predicate_filter = object() + true_filter = object() + false_filter = object() + cond_filter1 = ConditionalRowFilter( + predicate_filter, true_filter=true_filter, false_filter=false_filter + ) + cond_filter2 = ConditionalRowFilter( + predicate_filter, true_filter=true_filter, false_filter=false_filter + ) + assert cond_filter1 == cond_filter2 + + +def test_conditional_row_filter___eq__type_differ(): + from google.cloud.bigtable.data.row_filters import ConditionalRowFilter + + predicate_filter = object() + true_filter = object() + false_filter = object() + cond_filter1 = ConditionalRowFilter( + predicate_filter, true_filter=true_filter, false_filter=false_filter + ) + cond_filter2 = object() + assert not (cond_filter1 == cond_filter2) + + +def test_conditional_row_filter___ne__(): + from 
google.cloud.bigtable.data.row_filters import ConditionalRowFilter + + predicate_filter = object() + other_predicate_filter = object() + true_filter = object() + false_filter = object() + cond_filter1 = ConditionalRowFilter( + predicate_filter, true_filter=true_filter, false_filter=false_filter + ) + cond_filter2 = ConditionalRowFilter( + other_predicate_filter, true_filter=true_filter, false_filter=false_filter + ) + assert cond_filter1 != cond_filter2 + + +def test_conditional_row_filter_to_pb(): + from google.cloud.bigtable.data.row_filters import ConditionalRowFilter + from google.cloud.bigtable.data.row_filters import CellsRowOffsetFilter + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter1_pb = row_filter1._to_pb() + + row_filter2 = RowSampleFilter(0.25) + row_filter2_pb = row_filter2._to_pb() + + row_filter3 = CellsRowOffsetFilter(11) + row_filter3_pb = row_filter3._to_pb() + + row_filter4 = ConditionalRowFilter( + row_filter1, true_filter=row_filter2, false_filter=row_filter3 + ) + filter_pb = row_filter4._to_pb() + + expected_pb = _RowFilterPB( + condition=_RowFilterConditionPB( + predicate_filter=row_filter1_pb, + true_filter=row_filter2_pb, + false_filter=row_filter3_pb, + ) + ) + assert filter_pb == expected_pb + + +def test_conditional_row_filter_to_dict(): + from google.cloud.bigtable.data.row_filters import ConditionalRowFilter + from google.cloud.bigtable.data.row_filters import CellsRowOffsetFilter + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + row_filter1 = StripValueTransformerFilter(True) + row_filter1_dict = row_filter1._to_dict() + + row_filter2 = RowSampleFilter(0.25) + row_filter2_dict = row_filter2._to_dict() + + row_filter3 = CellsRowOffsetFilter(11) + row_filter3_dict = row_filter3._to_dict() + + row_filter4 = ConditionalRowFilter( + row_filter1, true_filter=row_filter2, false_filter=row_filter3 + ) + filter_dict = row_filter4._to_dict() + + expected_dict = { + "condition": { + "predicate_filter": row_filter1_dict, + "true_filter": row_filter2_dict, + "false_filter": row_filter3_dict, + } + } + assert filter_dict == expected_dict + expected_pb_value = row_filter4._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_conditional_row_filter_to_pb_true_only(): + from google.cloud.bigtable.data.row_filters import ConditionalRowFilter + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter1_pb = row_filter1._to_pb() + + row_filter2 = RowSampleFilter(0.25) + row_filter2_pb = row_filter2._to_pb() + + row_filter3 = ConditionalRowFilter(row_filter1, true_filter=row_filter2) + filter_pb = row_filter3._to_pb() + + expected_pb = _RowFilterPB( + condition=_RowFilterConditionPB( + predicate_filter=row_filter1_pb, true_filter=row_filter2_pb + ) + ) + assert filter_pb == expected_pb + + +def test_conditional_row_filter_to_dict_true_only(): + from google.cloud.bigtable.data.row_filters import ConditionalRowFilter + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import 
StripValueTransformerFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + row_filter1 = StripValueTransformerFilter(True) + row_filter1_dict = row_filter1._to_dict() + + row_filter2 = RowSampleFilter(0.25) + row_filter2_dict = row_filter2._to_dict() + + row_filter3 = ConditionalRowFilter(row_filter1, true_filter=row_filter2) + filter_dict = row_filter3._to_dict() + + expected_dict = { + "condition": { + "predicate_filter": row_filter1_dict, + "true_filter": row_filter2_dict, + } + } + assert filter_dict == expected_dict + expected_pb_value = row_filter3._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_conditional_row_filter_to_pb_false_only(): + from google.cloud.bigtable.data.row_filters import ConditionalRowFilter + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter1_pb = row_filter1._to_pb() + + row_filter2 = RowSampleFilter(0.25) + row_filter2_pb = row_filter2._to_pb() + + row_filter3 = ConditionalRowFilter(row_filter1, false_filter=row_filter2) + filter_pb = row_filter3._to_pb() + + expected_pb = _RowFilterPB( + condition=_RowFilterConditionPB( + predicate_filter=row_filter1_pb, false_filter=row_filter2_pb + ) + ) + assert filter_pb == expected_pb + + +def test_conditional_row_filter_to_dict_false_only(): + from google.cloud.bigtable.data.row_filters import ConditionalRowFilter + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + row_filter1 = StripValueTransformerFilter(True) + row_filter1_dict = row_filter1._to_dict() + + row_filter2 = RowSampleFilter(0.25) + row_filter2_dict = row_filter2._to_dict() + + row_filter3 = ConditionalRowFilter(row_filter1, false_filter=row_filter2) + filter_dict = row_filter3._to_dict() + + expected_dict = { + "condition": { + "predicate_filter": row_filter1_dict, + "false_filter": row_filter2_dict, + } + } + assert filter_dict == expected_dict + expected_pb_value = row_filter3._to_pb() + assert data_v2_pb2.RowFilter(**expected_dict) == expected_pb_value + + +def test_conditional_row_filter___repr__(): + from google.cloud.bigtable.data.row_filters import ConditionalRowFilter + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter2 = RowSampleFilter(0.25) + row_filter3 = ConditionalRowFilter(row_filter1, true_filter=row_filter2) + expected = ( + "ConditionalRowFilter(predicate_filter=StripValueTransformerFilter(" + "flag=True), true_filter=RowSampleFilter(sample=0.25), false_filter=None)" + ) + assert repr(row_filter3) == expected + assert eval(repr(row_filter3)) == row_filter3 + # test nested + row_filter4 = ConditionalRowFilter(row_filter3, true_filter=row_filter2) + expected = "ConditionalRowFilter(predicate_filter=ConditionalRowFilter(predicate_filter=StripValueTransformerFilter(flag=True), true_filter=RowSampleFilter(sample=0.25), false_filter=None), true_filter=RowSampleFilter(sample=0.25), false_filter=None)" + assert repr(row_filter4) == expected + assert eval(repr(row_filter4)) == row_filter4 + + +def test_conditional_row_filter___str__(): + from google.cloud.bigtable.data.row_filters import 
ConditionalRowFilter + from google.cloud.bigtable.data.row_filters import RowSampleFilter + from google.cloud.bigtable.data.row_filters import RowFilterUnion + from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter + + row_filter1 = StripValueTransformerFilter(True) + row_filter2 = RowSampleFilter(0.25) + row_filter3 = ConditionalRowFilter(row_filter1, true_filter=row_filter2) + expected = "ConditionalRowFilter(\n predicate_filter=StripValueTransformerFilter(flag=True),\n true_filter=RowSampleFilter(sample=0.25),\n)" + assert str(row_filter3) == expected + # test nested + row_filter4 = ConditionalRowFilter( + row_filter3, + true_filter=row_filter2, + false_filter=RowFilterUnion([row_filter1, row_filter2]), + ) + expected = "ConditionalRowFilter(\n predicate_filter=ConditionalRowFilter(\n predicate_filter=StripValueTransformerFilter(flag=True),\n true_filter=RowSampleFilter(sample=0.25),\n ),\n true_filter=RowSampleFilter(sample=0.25),\n false_filter=RowFilterUnion([\n StripValueTransformerFilter(flag=True),\n RowSampleFilter(sample=0.25),\n ]),\n)" + assert str(row_filter4) == expected + + +@pytest.mark.parametrize( + "input_arg, expected_bytes", + [ + (b"abc", b"abc"), + ("abc", b"abc"), + (1, b"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\\x01"), # null bytes are ascii + (b"*", b"\\*"), + (".", b"\\."), + (b"\\", b"\\\\"), + (b"h.*i", b"h\\.\\*i"), + (b'""', b'\\"\\"'), + (b"[xyz]", b"\\[xyz\\]"), + (b"\xe2\x98\xba\xef\xb8\x8f", b"\xe2\x98\xba\xef\xb8\x8f"), + ("☃", b"\xe2\x98\x83"), + (r"\C☃", b"\\\\C\xe2\x98\x83"), + ], +) +def test_literal_value__write_literal_regex(input_arg, expected_bytes): + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + + filter_ = LiteralValueFilter(input_arg) + assert filter_.regex == expected_bytes + + +def _ColumnRangePB(*args, **kw): + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + return data_v2_pb2.ColumnRange(*args, **kw) + + +def _RowFilterPB(*args, **kw): + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + return data_v2_pb2.RowFilter(*args, **kw) + + +def _RowFilterChainPB(*args, **kw): + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + return data_v2_pb2.RowFilter.Chain(*args, **kw) + + +def _RowFilterConditionPB(*args, **kw): + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + return data_v2_pb2.RowFilter.Condition(*args, **kw) + + +def _RowFilterInterleavePB(*args, **kw): + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + return data_v2_pb2.RowFilter.Interleave(*args, **kw) + + +def _TimestampRangePB(*args, **kw): + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + return data_v2_pb2.TimestampRange(*args, **kw) + + +def _ValueRangePB(*args, **kw): + from google.cloud.bigtable_v2.types import data as data_v2_pb2 + + return data_v2_pb2.ValueRange(*args, **kw) + + +def _get_regex_filters(): + from google.cloud.bigtable.data.row_filters import ( + RowKeyRegexFilter, + FamilyNameRegexFilter, + ColumnQualifierRegexFilter, + ValueRegexFilter, + LiteralValueFilter, + ) + + return [ + RowKeyRegexFilter, + FamilyNameRegexFilter, + ColumnQualifierRegexFilter, + ValueRegexFilter, + LiteralValueFilter, + ] + + +def _get_bool_filters(): + from google.cloud.bigtable.data.row_filters import ( + SinkFilter, + PassAllFilter, + BlockAllFilter, + StripValueTransformerFilter, + ) + + return [ + SinkFilter, + PassAllFilter, + BlockAllFilter, + StripValueTransformerFilter, + ] + + +def _get_cell_count_filters(): + from 
google.cloud.bigtable.data.row_filters import ( + CellsRowLimitFilter, + CellsRowOffsetFilter, + CellsColumnLimitFilter, + ) + + return [ + CellsRowLimitFilter, + CellsRowOffsetFilter, + CellsColumnLimitFilter, + ] + + +def _get_filter_combination_filters(): + from google.cloud.bigtable.data.row_filters import ( + RowFilterChain, + RowFilterUnion, + ) + + return [ + RowFilterChain, + RowFilterUnion, + ] diff --git a/tests/unit/v2_client/__init__.py b/tests/unit/v2_client/__init__.py new file mode 100644 index 000000000..e8e1c3845 --- /dev/null +++ b/tests/unit/v2_client/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/_testing.py b/tests/unit/v2_client/_testing.py similarity index 100% rename from tests/unit/_testing.py rename to tests/unit/v2_client/_testing.py diff --git a/tests/unit/v2_client/read-rows-acceptance-test.json b/tests/unit/v2_client/read-rows-acceptance-test.json new file mode 100644 index 000000000..011ace2b9 --- /dev/null +++ b/tests/unit/v2_client/read-rows-acceptance-test.json @@ -0,0 +1,1665 @@ +{ + "readRowsTests": [ + { + "description": "invalid - no commit", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": false + } + ], + "results": [ + { + "error": true + } + ] + }, + { + "description": "invalid - no cell key before commit", + "chunks": [ + { + "commitRow": true + } + ], + "results": [ + { + "error": true + } + ] + }, + { + "description": "invalid - no cell key before value", + "chunks": [ + { + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": false + } + ], + "results": [ + { + "error": true + } + ] + }, + { + "description": "invalid - new col family must specify qualifier", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "familyName": "B", + "timestampMicros": "98", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": true + } + ], + "results": [ + { + "error": true + } + ] + }, + { + "description": "bare commit implies ts=0", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": false + }, + { + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + }, + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C" + } + ] + }, + { + "description": "simple row with timestamp", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + } + ] + }, + { + "description": "missing timestamp, implied 
ts=0", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "value": "dmFsdWUtVkFM", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "value": "value-VAL" + } + ] + }, + { + "description": "empty cell value", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C" + } + ] + }, + { + "description": "two unsplit cells", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "timestampMicros": "98", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "99", + "value": "value-VAL_1" + }, + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "98", + "value": "value-VAL_2" + } + ] + }, + { + "description": "two qualifiers", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "qualifier": "RA==", + "timestampMicros": "98", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "99", + "value": "value-VAL_1" + }, + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "D", + "timestampMicros": "98", + "value": "value-VAL_2" + } + ] + }, + { + "description": "two families", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "familyName": "B", + "qualifier": "RQ==", + "timestampMicros": "98", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "99", + "value": "value-VAL_1" + }, + { + "rowKey": "RK", + "familyName": "B", + "qualifier": "E", + "timestampMicros": "98", + "value": "value-VAL_2" + } + ] + }, + { + "description": "with labels", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "labels": [ + "L_1" + ], + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "timestampMicros": "98", + "labels": [ + "L_2" + ], + "value": "dmFsdWUtVkFMXzI=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "99", + "value": "value-VAL_1", + "label": "L_1" + }, + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "98", + "value": "value-VAL_2", + "label": "L_2" + } + ] + }, + { + "description": "split cell, bare commit", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dg==", + "valueSize": 9, + "commitRow": false + }, + { + "value": "YWx1ZS1WQUw=", + "commitRow": false + }, + { + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + }, + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C" + } + ] + }, + { + "description": "split cell", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": 
"100", + "value": "dg==", + "valueSize": 9, + "commitRow": false + }, + { + "value": "YWx1ZS1WQUw=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + } + ] + }, + { + "description": "split four ways", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "labels": [ + "L" + ], + "value": "dg==", + "valueSize": 9, + "commitRow": false + }, + { + "value": "YQ==", + "valueSize": 9, + "commitRow": false + }, + { + "value": "bA==", + "valueSize": 9, + "commitRow": false + }, + { + "value": "dWUtVkFM", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL", + "label": "L" + } + ] + }, + { + "description": "two split cells", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "value": "dg==", + "valueSize": 11, + "commitRow": false + }, + { + "value": "YWx1ZS1WQUxfMQ==", + "commitRow": false + }, + { + "timestampMicros": "98", + "value": "dg==", + "valueSize": 11, + "commitRow": false + }, + { + "value": "YWx1ZS1WQUxfMg==", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "99", + "value": "value-VAL_1" + }, + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "98", + "value": "value-VAL_2" + } + ] + }, + { + "description": "multi-qualifier splits", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "value": "dg==", + "valueSize": 11, + "commitRow": false + }, + { + "value": "YWx1ZS1WQUxfMQ==", + "commitRow": false + }, + { + "qualifier": "RA==", + "timestampMicros": "98", + "value": "dg==", + "valueSize": 11, + "commitRow": false + }, + { + "value": "YWx1ZS1WQUxfMg==", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "99", + "value": "value-VAL_1" + }, + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "D", + "timestampMicros": "98", + "value": "value-VAL_2" + } + ] + }, + { + "description": "multi-qualifier multi-split", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "value": "dg==", + "valueSize": 11, + "commitRow": false + }, + { + "value": "YQ==", + "valueSize": 11, + "commitRow": false + }, + { + "value": "bHVlLVZBTF8x", + "commitRow": false + }, + { + "qualifier": "RA==", + "timestampMicros": "98", + "value": "dg==", + "valueSize": 11, + "commitRow": false + }, + { + "value": "YQ==", + "valueSize": 11, + "commitRow": false + }, + { + "value": "bHVlLVZBTF8y", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "99", + "value": "value-VAL_1" + }, + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "D", + "timestampMicros": "98", + "value": "value-VAL_2" + } + ] + }, + { + "description": "multi-family split", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "value": "dg==", + "valueSize": 11, + "commitRow": false + }, + { + "value": "YWx1ZS1WQUxfMQ==", + "commitRow": false + }, + { + "familyName": "B", + "qualifier": "RQ==", + "timestampMicros": "98", + "value": "dg==", + "valueSize": 11, + "commitRow": false + }, 
+ { + "value": "YWx1ZS1WQUxfMg==", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "99", + "value": "value-VAL_1" + }, + { + "rowKey": "RK", + "familyName": "B", + "qualifier": "E", + "timestampMicros": "98", + "value": "value-VAL_2" + } + ] + }, + { + "description": "invalid - no commit between rows", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": false + }, + { + "rowKey": "UktfMg==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": false + } + ], + "results": [ + { + "error": true + } + ] + }, + { + "description": "invalid - no commit after first row", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": false + }, + { + "rowKey": "UktfMg==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + } + ], + "results": [ + { + "error": true + } + ] + }, + { + "description": "invalid - last row missing commit", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + }, + { + "rowKey": "UktfMg==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": false + } + ], + "results": [ + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + }, + { + "error": true + } + ] + }, + { + "description": "invalid - duplicate row key", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + }, + { + "rowKey": "UktfMQ==", + "familyName": "B", + "qualifier": "RA==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + }, + { + "error": true + } + ] + }, + { + "description": "invalid - new row missing row key", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + }, + { + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + }, + { + "error": true + } + ] + }, + { + "description": "two rows", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + }, + { + "rowKey": "UktfMg==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + }, + { + "rowKey": "RK_2", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + } + ] + }, + { + "description": "two rows implicit timestamp", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "value": "dmFsdWUtVkFM", + 
"commitRow": true + }, + { + "rowKey": "UktfMg==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "value": "value-VAL" + }, + { + "rowKey": "RK_2", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + } + ] + }, + { + "description": "two rows empty value", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "commitRow": true + }, + { + "rowKey": "UktfMg==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C" + }, + { + "rowKey": "RK_2", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + } + ] + }, + { + "description": "two rows, one with multiple cells", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "timestampMicros": "98", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": true + }, + { + "rowKey": "UktfMg==", + "familyName": "B", + "qualifier": "RA==", + "timestampMicros": "97", + "value": "dmFsdWUtVkFMXzM=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "99", + "value": "value-VAL_1" + }, + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "98", + "value": "value-VAL_2" + }, + { + "rowKey": "RK_2", + "familyName": "B", + "qualifier": "D", + "timestampMicros": "97", + "value": "value-VAL_3" + } + ] + }, + { + "description": "two rows, multiple cells", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "qualifier": "RA==", + "timestampMicros": "98", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": true + }, + { + "rowKey": "UktfMg==", + "familyName": "B", + "qualifier": "RQ==", + "timestampMicros": "97", + "value": "dmFsdWUtVkFMXzM=", + "commitRow": false + }, + { + "qualifier": "Rg==", + "timestampMicros": "96", + "value": "dmFsdWUtVkFMXzQ=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "99", + "value": "value-VAL_1" + }, + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "D", + "timestampMicros": "98", + "value": "value-VAL_2" + }, + { + "rowKey": "RK_2", + "familyName": "B", + "qualifier": "E", + "timestampMicros": "97", + "value": "value-VAL_3" + }, + { + "rowKey": "RK_2", + "familyName": "B", + "qualifier": "F", + "timestampMicros": "96", + "value": "value-VAL_4" + } + ] + }, + { + "description": "two rows, multiple cells, multiple families", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "familyName": "B", + "qualifier": "RQ==", + "timestampMicros": "98", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": true + }, + { + "rowKey": "UktfMg==", + "familyName": "M", + "qualifier": "Tw==", + "timestampMicros": "97", + "value": "dmFsdWUtVkFMXzM=", + "commitRow": false + }, + { + "familyName": "N", + "qualifier": "UA==", + "timestampMicros": "96", + "value": "dmFsdWUtVkFMXzQ=", + 
"commitRow": true + } + ], + "results": [ + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "99", + "value": "value-VAL_1" + }, + { + "rowKey": "RK_1", + "familyName": "B", + "qualifier": "E", + "timestampMicros": "98", + "value": "value-VAL_2" + }, + { + "rowKey": "RK_2", + "familyName": "M", + "qualifier": "O", + "timestampMicros": "97", + "value": "value-VAL_3" + }, + { + "rowKey": "RK_2", + "familyName": "N", + "qualifier": "P", + "timestampMicros": "96", + "value": "value-VAL_4" + } + ] + }, + { + "description": "two rows, four cells, 2 labels", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "99", + "labels": [ + "L_1" + ], + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "timestampMicros": "98", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": true + }, + { + "rowKey": "UktfMg==", + "familyName": "B", + "qualifier": "RA==", + "timestampMicros": "97", + "labels": [ + "L_3" + ], + "value": "dmFsdWUtVkFMXzM=", + "commitRow": false + }, + { + "timestampMicros": "96", + "value": "dmFsdWUtVkFMXzQ=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "99", + "value": "value-VAL_1", + "label": "L_1" + }, + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "98", + "value": "value-VAL_2" + }, + { + "rowKey": "RK_2", + "familyName": "B", + "qualifier": "D", + "timestampMicros": "97", + "value": "value-VAL_3", + "label": "L_3" + }, + { + "rowKey": "RK_2", + "familyName": "B", + "qualifier": "D", + "timestampMicros": "96", + "value": "value-VAL_4" + } + ] + }, + { + "description": "two rows with splits, same timestamp", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dg==", + "valueSize": 11, + "commitRow": false + }, + { + "value": "YWx1ZS1WQUxfMQ==", + "commitRow": true + }, + { + "rowKey": "UktfMg==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dg==", + "valueSize": 11, + "commitRow": false + }, + { + "value": "YWx1ZS1WQUxfMg==", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL_1" + }, + { + "rowKey": "RK_2", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL_2" + } + ] + }, + { + "description": "invalid - bare reset", + "chunks": [ + { + "resetRow": true + } + ], + "results": [ + { + "error": true + } + ] + }, + { + "description": "invalid - bad reset, no commit", + "chunks": [ + { + "resetRow": true + }, + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": false + } + ], + "results": [ + { + "error": true + } + ] + }, + { + "description": "invalid - missing key after reset", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": false + }, + { + "resetRow": true + }, + { + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + } + ], + "results": [ + { + "error": true + } + ] + }, + { + "description": "no data after reset", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": false + }, + { + "resetRow": true + } + ] 
+ }, + { + "description": "simple reset", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": false + }, + { + "resetRow": true + }, + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + } + ] + }, + { + "description": "reset to new val", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "resetRow": true + }, + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL_2" + } + ] + }, + { + "description": "reset to new qual", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "resetRow": true + }, + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "RA==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "D", + "timestampMicros": "100", + "value": "value-VAL_1" + } + ] + }, + { + "description": "reset with splits", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "timestampMicros": "98", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": false + }, + { + "resetRow": true + }, + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL_2" + } + ] + }, + { + "description": "reset two cells", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "resetRow": true + }, + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": false + }, + { + "timestampMicros": "97", + "value": "dmFsdWUtVkFMXzM=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL_2" + }, + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "97", + "value": "value-VAL_3" + } + ] + }, + { + "description": "two resets", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "resetRow": true + }, + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": false + }, + { + "resetRow": true + }, + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzM=", + "commitRow": true + } + ], + "results": [ + { 
+ "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL_3" + } + ] + }, + { + "description": "reset then two cells", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "resetRow": true + }, + { + "rowKey": "Uks=", + "familyName": "B", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": false + }, + { + "qualifier": "RA==", + "timestampMicros": "97", + "value": "dmFsdWUtVkFMXzM=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "B", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL_2" + }, + { + "rowKey": "RK", + "familyName": "B", + "qualifier": "D", + "timestampMicros": "97", + "value": "value-VAL_3" + } + ] + }, + { + "description": "reset to new row", + "chunks": [ + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": false + }, + { + "resetRow": true + }, + { + "rowKey": "UktfMg==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzI=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK_2", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL_2" + } + ] + }, + { + "description": "reset in between chunks", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "labels": [ + "L" + ], + "value": "dg==", + "valueSize": 10, + "commitRow": false + }, + { + "value": "YQ==", + "valueSize": 10, + "commitRow": false + }, + { + "resetRow": true + }, + { + "rowKey": "UktfMQ==", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFMXzE=", + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK_1", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL_1" + } + ] + }, + { + "description": "invalid - reset with chunk", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "labels": [ + "L" + ], + "value": "dg==", + "valueSize": 10, + "commitRow": false + }, + { + "value": "YQ==", + "valueSize": 10, + "resetRow": true + } + ], + "results": [ + { + "error": true + } + ] + }, + { + "description": "invalid - commit with chunk", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "labels": [ + "L" + ], + "value": "dg==", + "valueSize": 10, + "commitRow": false + }, + { + "value": "YQ==", + "valueSize": 10, + "commitRow": true + } + ], + "results": [ + { + "error": true + } + ] + }, + { + "description": "empty cell chunk", + "chunks": [ + { + "rowKey": "Uks=", + "familyName": "A", + "qualifier": "Qw==", + "timestampMicros": "100", + "value": "dmFsdWUtVkFM", + "commitRow": false + }, + { + "commitRow": false + }, + { + "commitRow": true + } + ], + "results": [ + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C", + "timestampMicros": "100", + "value": "value-VAL" + }, + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C" + }, + { + "rowKey": "RK", + "familyName": "A", + "qualifier": "C" + } + ] + } + ] +} diff --git a/tests/unit/test_app_profile.py b/tests/unit/v2_client/test_app_profile.py similarity index 100% rename from tests/unit/test_app_profile.py rename to 
tests/unit/v2_client/test_app_profile.py diff --git a/tests/unit/test_backup.py b/tests/unit/v2_client/test_backup.py similarity index 100% rename from tests/unit/test_backup.py rename to tests/unit/v2_client/test_backup.py diff --git a/tests/unit/test_batcher.py b/tests/unit/v2_client/test_batcher.py similarity index 98% rename from tests/unit/test_batcher.py rename to tests/unit/v2_client/test_batcher.py index 741d9f282..fcf606972 100644 --- a/tests/unit/test_batcher.py +++ b/tests/unit/v2_client/test_batcher.py @@ -198,7 +198,7 @@ def test_mutations_batcher_response_with_error_codes(): mocked_response = [Status(code=1), Status(code=5)] - with mock.patch("tests.unit.test_batcher._Table") as mocked_table: + with mock.patch("tests.unit.v2_client.test_batcher._Table") as mocked_table: table = mocked_table.return_value mutation_batcher = MutationsBatcher(table=table) diff --git a/tests/unit/test_client.py b/tests/unit/v2_client/test_client.py similarity index 100% rename from tests/unit/test_client.py rename to tests/unit/v2_client/test_client.py diff --git a/tests/unit/test_cluster.py b/tests/unit/v2_client/test_cluster.py similarity index 100% rename from tests/unit/test_cluster.py rename to tests/unit/v2_client/test_cluster.py diff --git a/tests/unit/test_column_family.py b/tests/unit/v2_client/test_column_family.py similarity index 99% rename from tests/unit/test_column_family.py rename to tests/unit/v2_client/test_column_family.py index 80b05d744..e4f74e264 100644 --- a/tests/unit/test_column_family.py +++ b/tests/unit/v2_client/test_column_family.py @@ -336,7 +336,7 @@ def _create_test_helper(gc_rule=None): from google.cloud.bigtable_admin_v2.types import ( bigtable_table_admin as table_admin_v2_pb2, ) - from tests.unit._testing import _FakeStub + from ._testing import _FakeStub from google.cloud.bigtable_admin_v2.services.bigtable_table_admin import ( BigtableTableAdminClient, ) @@ -404,7 +404,7 @@ def test_column_family_create_with_gc_rule(): def _update_test_helper(gc_rule=None): - from tests.unit._testing import _FakeStub + from ._testing import _FakeStub from google.cloud.bigtable_admin_v2.types import ( bigtable_table_admin as table_admin_v2_pb2, ) @@ -478,7 +478,7 @@ def test_column_family_delete(): from google.cloud.bigtable_admin_v2.types import ( bigtable_table_admin as table_admin_v2_pb2, ) - from tests.unit._testing import _FakeStub + from ._testing import _FakeStub from google.cloud.bigtable_admin_v2.services.bigtable_table_admin import ( BigtableTableAdminClient, ) diff --git a/tests/unit/test_encryption_info.py b/tests/unit/v2_client/test_encryption_info.py similarity index 100% rename from tests/unit/test_encryption_info.py rename to tests/unit/v2_client/test_encryption_info.py diff --git a/tests/unit/test_error.py b/tests/unit/v2_client/test_error.py similarity index 100% rename from tests/unit/test_error.py rename to tests/unit/v2_client/test_error.py diff --git a/tests/unit/test_instance.py b/tests/unit/v2_client/test_instance.py similarity index 100% rename from tests/unit/test_instance.py rename to tests/unit/v2_client/test_instance.py diff --git a/tests/unit/test_policy.py b/tests/unit/v2_client/test_policy.py similarity index 100% rename from tests/unit/test_policy.py rename to tests/unit/v2_client/test_policy.py diff --git a/tests/unit/test_row.py b/tests/unit/v2_client/test_row.py similarity index 99% rename from tests/unit/test_row.py rename to tests/unit/v2_client/test_row.py index 49bbfc45c..f04802f5c 100644 --- a/tests/unit/test_row.py +++ 
b/tests/unit/v2_client/test_row.py @@ -480,7 +480,7 @@ def test_conditional_row_commit_too_many_mutations(): def test_conditional_row_commit_no_mutations(): - from tests.unit._testing import _FakeStub + from ._testing import _FakeStub project_id = "project-id" row_key = b"row_key" @@ -607,7 +607,7 @@ def mock_parse_rmw_row_response(row_response): def test_append_row_commit_no_rules(): - from tests.unit._testing import _FakeStub + from ._testing import _FakeStub project_id = "project-id" row_key = b"row_key" diff --git a/tests/unit/test_row_data.py b/tests/unit/v2_client/test_row_data.py similarity index 100% rename from tests/unit/test_row_data.py rename to tests/unit/v2_client/test_row_data.py diff --git a/tests/unit/test_row_filters.py b/tests/unit/v2_client/test_row_filters.py similarity index 100% rename from tests/unit/test_row_filters.py rename to tests/unit/v2_client/test_row_filters.py diff --git a/tests/unit/test_row_merger.py b/tests/unit/v2_client/test_row_merger.py similarity index 100% rename from tests/unit/test_row_merger.py rename to tests/unit/v2_client/test_row_merger.py diff --git a/tests/unit/test_row_set.py b/tests/unit/v2_client/test_row_set.py similarity index 100% rename from tests/unit/test_row_set.py rename to tests/unit/v2_client/test_row_set.py diff --git a/tests/unit/test_table.py b/tests/unit/v2_client/test_table.py similarity index 100% rename from tests/unit/test_table.py rename to tests/unit/v2_client/test_table.py
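
Nearly every *_to_dict test in the new row-filter suite above asserts the same round-trip property: building a bigtable_v2 RowFilter proto from a filter's _to_dict() output must equal the proto returned by its _to_pb(). The snippet below is only a compact, standalone restatement of that check for a nested filter, a sketch reusing the internal _to_dict/_to_pb helpers the tests call; it is illustrative and not part of the patch.

from google.cloud.bigtable.data.row_filters import (
    ConditionalRowFilter,
    RowFilterChain,
    RowSampleFilter,
    StripValueTransformerFilter,
)
from google.cloud.bigtable_v2.types import data as data_v2_pb2

# Build a conditional filter whose predicate is itself a chain, mirroring the
# nested cases in test_row_filter_chain_to_dict_nested and related tests.
predicate = RowFilterChain(
    filters=[StripValueTransformerFilter(True), RowSampleFilter(0.25)]
)
cond = ConditionalRowFilter(predicate, true_filter=RowSampleFilter(0.25))

# The dict form feeds straight into the generated proto type and must equal
# the proto produced directly by the filter.
assert data_v2_pb2.RowFilter(**cond._to_dict()) == cond._to_pb()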
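
The read-rows-acceptance-test.json file added above stores each chunk's rowKey, qualifier, and value as base64 (for example "Uks=" is row key b"RK" and "dmFsdWUtVkFM" is b"value-VAL"), while the expected "results" entries are plain strings. The harness that consumes this file is outside this section; the sketch below only illustrates loading the cases and decoding those fields. The decode_chunk and load_cases names are hypothetical helpers invented for this example, and the default path assumes the repository root.

import base64
import json

# Chunk fields that the acceptance file base64-encodes.
_B64_CHUNK_FIELDS = ("rowKey", "qualifier", "value")


def decode_chunk(chunk):
    # Return a copy of the chunk with its base64 fields decoded to bytes.
    decoded = dict(chunk)
    for field in _B64_CHUNK_FIELDS:
        if field in decoded:
            decoded[field] = base64.b64decode(decoded[field])
    return decoded


def load_cases(path="tests/unit/v2_client/read-rows-acceptance-test.json"):
    # The file wraps all cases in a top-level "readRowsTests" list.
    with open(path) as handle:
        return json.load(handle)["readRowsTests"]


if __name__ == "__main__":
    for case in load_cases():
        chunks = [decode_chunk(chunk) for chunk in case.get("chunks", [])]
        expects_error = any(result.get("error") for result in case.get("results", []))
        print(case["description"], len(chunks), "chunks,", "expects error" if expects_error else "ok")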