Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

service implementation stub generation #170

Merged
merged 13 commits into from
Dec 4, 2020
Merged
40 changes: 40 additions & 0 deletions src/betterproto/grpc/grpclib_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from abc import ABC
from collections import AsyncIterable
from typing import Callable, Any, Dict

import grpclib
import grpclib.server


class ServiceBase(ABC):
"""
Base class for async gRPC servers.
"""

async def _call_rpc_handler_server_unary(
self,
handler: Callable,
stream: grpclib.server.Stream,
request_kwargs: Dict[str, Any],
) -> None:

response = await handler(**request_kwargs)
await stream.send_message(response)

async def _call_rpc_handler_server_stream(
self,
handler: Callable,
stream: grpclib.server.Stream,
request_kwargs: Dict[str, Any],
) -> None:

response_iter = handler(**request_kwargs)
# check if response is actually an AsyncIterator
# this might be false if the method just returns without
# yielding at least once
# in that case, we just interpret it as an empty iterator
if isinstance(response_iter, AsyncIterable):
async for response_message in response_iter:
await stream.send_message(response_message)
else:
response_iter.close()
9 changes: 9 additions & 0 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ def proto_name(self) -> str:
def py_name(self) -> str:
return pythonize_class_name(self.proto_name)

@property
def py_name_as_field(self) -> str:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like this isn't used either actually?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not anymore since 5bbe19a. Nice catch. Good thing you're looking over this. I completely forgot.

return pythonize_field_name(self.proto_name)

@property
def annotation(self) -> str:
if self.repeated:
Expand Down Expand Up @@ -553,12 +557,17 @@ class ServiceCompiler(ProtoContentBase):
def __post_init__(self) -> None:
# Add service to output file
self.output_file.services.append(self)
self.output_file.typing_imports.add("Dict")
super().__post_init__() # check for unset fields

@property
def proto_name(self) -> str:
return self.proto_obj.name

@property
def full_proto_name(self) -> str:
return f"{self.parent.package_proto_obj.package}.{self.proto_obj.name}"

@property
def py_name(self) -> str:
return pythonize_class_name(self.proto_name)
Expand Down
89 changes: 88 additions & 1 deletion src/betterproto/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
{% endif %}

import betterproto
from betterproto.grpc.grpclib_server import ServiceBase
{% if output_file.services %}
import grpclib
{% endif %}
Expand Down Expand Up @@ -82,7 +83,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%} =
{%- endif -%} =
{%- if field.py_name not in method.mutable_default_args -%}
{{ field.default_value_string }}
{%- else -%}
Expand Down Expand Up @@ -154,6 +155,92 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% endfor %}
{% endfor %}

{% for service in output_file.services %}
class {{ service.py_name }}Base(ServiceBase):
{% if service.comment %}
{{ service.comment }}

{% endif %}

{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.py_input_message and method.py_input_message.fields -%},
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message.py_name_as_field }}_iterator: AsyncIterable["{{ method.py_input_message_type }}"]
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}

{% endif %}
raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)

{% endfor %}

{% for method in service.methods %}
async def __rpc_{{ method.py_name }}(self, stream) -> None:
{% if not method.client_streaming %}
request = await stream.recv_message()

request_kwargs = {
{% for field in method.py_input_message.fields %}
"{{ field.py_name }}": request.{{ field.py_name }},
{% endfor %}
}

{% else %}
request_kwargs = {"{{ method.py_input_message.py_name_as_field }}_iterator": stream.__aiter__()}
{% endif %}

{% if not method.server_streaming %}
await self._call_rpc_handler_server_unary(
self.{{ method.py_name }},
stream,
request_kwargs,
)
{% else %}
await self._call_rpc_handler_server_stream(
self.{{ method.py_name }},
stream,
request_kwargs,
)
{% endif %}

{% endfor %}

def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
return {
{% for method in service.methods %}
"{{ method.route }}": grpclib.const.Handler(
self.__rpc_{{ method.py_name }},
{% if not method.client_streaming and not method.server_streaming %}
grpclib.const.Cardinality.UNARY_UNARY,
{% elif not method.client_streaming and method.server_streaming %}
grpclib.const.Cardinality.UNARY_STREAM,
{% elif method.client_streaming and not method.server_streaming %}
grpclib.const.Cardinality.STREAM_UNARY,
{% else %}
grpclib.const.Cardinality.STREAM_STREAM,
{% endif %}
{{ method.py_input_message_type }},
{{ method.py_output_message_type }},
),
{% endfor %}
}

{% endfor %}

{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}
23 changes: 23 additions & 0 deletions tests/inputs/example_service/example_service.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
syntax = "proto3";

package example_service;

service ExampleService {
rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse);
rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse);
rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse);
rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse);
}

message ExampleRequest {
string example_string = 1;
int64 example_integer = 2;
}

message ExampleResponse {
string example_string = 1;
int64 example_integer = 2;
}

// Suppress test framework error when it's looking for a "Test" message or service
message Test {}
95 changes: 95 additions & 0 deletions tests/inputs/example_service/test_example_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import AsyncIterator, AsyncIterable

import pytest
from grpclib.testing import ChannelFor

from tests.output_betterproto.example_service.example_service import (
ExampleServiceBase,
ExampleServiceStub,
ExampleRequest,
ExampleResponse,
)


class ExampleService(ExampleServiceBase):
async def example_unary_unary(
self, example_string: str, example_integer: int
) -> "ExampleResponse":
return ExampleResponse(
example_string=example_string,
example_integer=example_integer,
)

async def example_unary_stream(
self, example_string: str, example_integer: int
) -> AsyncIterator["ExampleResponse"]:
response = ExampleResponse(
example_string=example_string,
example_integer=example_integer,
)
yield response
yield response
yield response

async def example_stream_unary(
self, example_request_iterator: AsyncIterable["ExampleRequest"]
) -> "ExampleResponse":
async for example_request in example_request_iterator:
return ExampleResponse(
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)

async def example_stream_stream(
self, example_request_iterator: AsyncIterable["ExampleRequest"]
) -> AsyncIterator["ExampleResponse"]:
async for example_request in example_request_iterator:
yield ExampleResponse(
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)


@pytest.mark.asyncio
async def test_calls_with_different_cardinalities():
test_string = "test string"
test_int = 42

async with ChannelFor([ExampleService()]) as channel:
stub = ExampleServiceStub(channel)

# unary unary
response = await stub.example_unary_unary(
example_string="test string",
example_integer=42,
)
assert response.example_string == test_string
assert response.example_integer == test_int

# unary stream
async for response in stub.example_unary_stream(
example_string="test string",
example_integer=42,
):
assert response.example_string == test_string
assert response.example_integer == test_int

# stream unary
request = ExampleRequest(
example_string=test_string,
example_integer=42,
)

async def request_iterator():
yield request
yield request
yield request

response = await stub.example_stream_unary(request_iterator())
assert response.example_string == test_string
assert response.example_integer == test_int

# stream stream
async for response in stub.example_stream_stream(request_iterator()):
assert response.example_string == test_string
assert response.example_integer == test_int