Skip to content

Commit

Permalink
feat: add type annotations to wrapped grpc calls (#554)
Browse files Browse the repository at this point in the history
* add types to grpc call wrappers

* fixed tests

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* changed type

* changed async types

* added tests

* fixed lint issues

* Update tests/asyncio/test_grpc_helpers_async.py

Co-authored-by: Anthonios Partheniou <partheniou@google.com>

* turned GrpcStream into a type alias

* added test for GrpcStream

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* added comment

* reordered types

* changed type var to P

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
Co-authored-by: Anthonios Partheniou <partheniou@google.com>
  • Loading branch information
3 people committed Nov 17, 2023
1 parent 448923a commit fc12b40
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 13 deletions.
14 changes: 11 additions & 3 deletions google/api_core/grpc_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Helpers for :mod:`grpc`."""
from typing import Generic, TypeVar, Iterator

import collections
import functools
Expand Down Expand Up @@ -54,6 +55,9 @@

_LOGGER = logging.getLogger(__name__)

# denotes the proto response type for grpc calls
P = TypeVar("P")


def _patch_callable_name(callable_):
"""Fix-up gRPC callable attributes.
Expand All @@ -79,7 +83,7 @@ def error_remapped_callable(*args, **kwargs):
return error_remapped_callable


class _StreamingResponseIterator(grpc.Call):
class _StreamingResponseIterator(Generic[P], grpc.Call):
def __init__(self, wrapped, prefetch_first_result=True):
self._wrapped = wrapped

Expand All @@ -97,11 +101,11 @@ def __init__(self, wrapped, prefetch_first_result=True):
# ignore stop iteration at this time. This should be handled outside of retry.
pass

def __iter__(self):
def __iter__(self) -> Iterator[P]:
"""This iterator is also an iterable that returns itself."""
return self

def __next__(self):
def __next__(self) -> P:
"""Get the next response from the stream.
Returns:
Expand Down Expand Up @@ -144,6 +148,10 @@ def trailing_metadata(self):
return self._wrapped.trailing_metadata()


# public type alias denoting the return type of streaming gapic calls
GrpcStream = _StreamingResponseIterator[P]


def _wrap_stream_errors(callable_):
"""Wrap errors for Unary-Stream and Stream-Stream gRPC callables.
Expand Down
30 changes: 20 additions & 10 deletions google/api_core/grpc_helpers_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
import asyncio
import functools

from typing import Generic, Iterator, AsyncGenerator, TypeVar

import grpc
from grpc import aio

from google.api_core import exceptions, grpc_helpers

# denotes the proto response type for grpc calls
P = TypeVar("P")

# NOTE(lidiz) Alternatively, we can hack "__getattribute__" to perform
# automatic patching for us. But that means the overhead of creating an
Expand Down Expand Up @@ -75,26 +79,26 @@ async def wait_for_connection(self):
raise exceptions.from_grpc_error(rpc_error) from rpc_error


class _WrappedUnaryResponseMixin(_WrappedCall):
def __await__(self):
class _WrappedUnaryResponseMixin(Generic[P], _WrappedCall):
def __await__(self) -> Iterator[P]:
try:
response = yield from self._call.__await__()
return response
except grpc.RpcError as rpc_error:
raise exceptions.from_grpc_error(rpc_error) from rpc_error


class _WrappedStreamResponseMixin(_WrappedCall):
class _WrappedStreamResponseMixin(Generic[P], _WrappedCall):
def __init__(self):
self._wrapped_async_generator = None

async def read(self):
async def read(self) -> P:
try:
return await self._call.read()
except grpc.RpcError as rpc_error:
raise exceptions.from_grpc_error(rpc_error) from rpc_error

async def _wrapped_aiter(self):
async def _wrapped_aiter(self) -> AsyncGenerator[P, None]:
try:
# NOTE(lidiz) coverage doesn't understand the exception raised from
# __anext__ method. It is covered by test case:
Expand All @@ -104,7 +108,7 @@ async def _wrapped_aiter(self):
except grpc.RpcError as rpc_error:
raise exceptions.from_grpc_error(rpc_error) from rpc_error

def __aiter__(self):
def __aiter__(self) -> AsyncGenerator[P, None]:
if not self._wrapped_async_generator:
self._wrapped_async_generator = self._wrapped_aiter()
return self._wrapped_async_generator
Expand All @@ -127,26 +131,32 @@ async def done_writing(self):
# NOTE(lidiz) Implementing each individual class separately, so we don't
# expose any API that should not be seen. E.g., __aiter__ in unary-unary
# RPC, or __await__ in stream-stream RPC.
class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin, aio.UnaryUnaryCall):
class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin[P], aio.UnaryUnaryCall):
"""Wrapped UnaryUnaryCall to map exceptions."""


class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin, aio.UnaryStreamCall):
class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin[P], aio.UnaryStreamCall):
"""Wrapped UnaryStreamCall to map exceptions."""


class _WrappedStreamUnaryCall(
_WrappedUnaryResponseMixin, _WrappedStreamRequestMixin, aio.StreamUnaryCall
_WrappedUnaryResponseMixin[P], _WrappedStreamRequestMixin, aio.StreamUnaryCall
):
"""Wrapped StreamUnaryCall to map exceptions."""


class _WrappedStreamStreamCall(
_WrappedStreamRequestMixin, _WrappedStreamResponseMixin, aio.StreamStreamCall
_WrappedStreamRequestMixin, _WrappedStreamResponseMixin[P], aio.StreamStreamCall
):
"""Wrapped StreamStreamCall to map exceptions."""


# public type alias denoting the return type of async streaming gapic calls
GrpcAsyncStream = _WrappedStreamResponseMixin[P]
# public type alias denoting the return type of unary gapic calls
AwaitableGrpcCall = _WrappedUnaryResponseMixin[P]


def _wrap_unary_errors(callable_):
"""Map errors for Unary-Unary async callables."""
grpc_helpers._patch_callable_name(callable_)
Expand Down
22 changes: 22 additions & 0 deletions tests/asyncio/test_grpc_helpers_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,28 @@ def test_wrap_errors_non_streaming(wrap_unary_errors):
wrap_unary_errors.assert_called_once_with(callable_)


def test_grpc_async_stream():
"""
GrpcAsyncStream type should be both an AsyncIterator and a grpc.aio.Call.
"""
instance = grpc_helpers_async.GrpcAsyncStream[int]()
assert isinstance(instance, grpc.aio.Call)
# should implement __aiter__ and __anext__
assert hasattr(instance, "__aiter__")
it = instance.__aiter__()
assert hasattr(it, "__anext__")


def test_awaitable_grpc_call():
"""
AwaitableGrpcCall type should be an Awaitable and a grpc.aio.Call.
"""
instance = grpc_helpers_async.AwaitableGrpcCall[int]()
assert isinstance(instance, grpc.aio.Call)
# should implement __await__
assert hasattr(instance, "__await__")


@mock.patch("google.api_core.grpc_helpers_async._wrap_stream_errors")
def test_wrap_errors_streaming(wrap_stream_errors):
callable_ = mock.create_autospec(aio.UnaryStreamMultiCallable)
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_grpc_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,23 @@ def test_trailing_metadata(self):
wrapped.trailing_metadata.assert_called_once_with()


class TestGrpcStream(Test_StreamingResponseIterator):
@staticmethod
def _make_one(wrapped, **kw):
return grpc_helpers.GrpcStream(wrapped, **kw)

def test_grpc_stream_attributes(self):
"""
Should be both a grpc.Call and an iterable
"""
call = self._make_one(None)
assert isinstance(call, grpc.Call)
# should implement __iter__
assert hasattr(call, "__iter__")
it = call.__iter__()
assert hasattr(it, "__next__")


def test_wrap_stream_okay():
expected_responses = [1, 2, 3]
callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses))
Expand Down

0 comments on commit fc12b40

Please sign in to comment.