diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index f7c57135f93d4..e9d6b55c784b9 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -1,7 +1,7 @@ import json from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Awaitable, Callable, Dict, List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional from starlette.types import Scope @@ -574,7 +574,7 @@ class MultiplexedReplicaInfo: class gRPCRequest: """Sent from the GRPC proxy to replicas on both unary and streaming codepaths.""" - grpc_user_request: bytes + user_request_proto: Any class RequestProtocol(str, Enum): diff --git a/python/ray/serve/_private/proxy.py b/python/ray/serve/_private/proxy.py index 6b018e78d142f..4b4a4e2a6d83c 100644 --- a/python/ray/serve/_private/proxy.py +++ b/python/ray/serve/_private/proxy.py @@ -708,11 +708,8 @@ async def send_request_to_replica( proxy_request: ProxyRequest, app_is_cross_language: bool = False, ) -> ResponseGenerator: - handle_arg = proxy_request.request_object() response_generator = ProxyResponseGenerator( - # NOTE(edoakes): it's important that the request is sent as raw bytes to - # skip the Ray cloudpickle serialization codepath for performance. - handle.remote(pickle.dumps(handle_arg)), + handle.remote(proxy_request.serialized_replica_arg()), timeout_s=self.request_timeout_s, ) @@ -956,12 +953,8 @@ async def send_request_to_replica( # Response is returned as raw bytes, convert it to ASGI messages. result_callback = convert_object_to_asgi_messages else: - # NOTE(edoakes): it's important that the request is sent as raw bytes to - # skip the Ray cloudpickle serialization codepath for performance. - handle_arg_bytes = pickle.dumps( - proxy_request.request_object( - proxy_actor_name=self.self_actor_name, - ) + handle_arg_bytes = proxy_request.serialized_replica_arg( + proxy_actor_name=self.self_actor_name, ) # Messages are returned as pickled dictionaries. result_callback = pickle.loads diff --git a/python/ray/serve/_private/proxy_request_response.py b/python/ray/serve/_private/proxy_request_response.py index a1f6b42ecbce7..922621f624248 100644 --- a/python/ray/serve/_private/proxy_request_response.py +++ b/python/ray/serve/_private/proxy_request_response.py @@ -95,13 +95,14 @@ def set_path(self, path: str): def set_root_path(self, root_path: str): self.scope["root_path"] = root_path - def request_object( - self, - proxy_actor_name: str, - ) -> StreamingHTTPRequest: - return StreamingHTTPRequest( - asgi_scope=self.scope, - proxy_actor_name=proxy_actor_name, + def serialized_replica_arg(self, proxy_actor_name: str) -> bytes: + # NOTE(edoakes): it's important that the request is sent as raw bytes to + # skip the Ray cloudpickle serialization codepath for performance. + return pickle.dumps( + StreamingHTTPRequest( + asgi_scope=self.scope, + proxy_actor_name=proxy_actor_name, + ) ) @@ -115,7 +116,7 @@ def __init__( service_method: str, stream: bool, ): - self.request = request_proto + self._request_proto = request_proto self.context = context self.service_method = service_method self.stream = stream @@ -131,7 +132,6 @@ def __init__( def setup_variables(self): if not self.is_route_request and not self.is_health_request: service_method_split = self.service_method.split("/") - self.request = pickle.dumps(self.request) self.method_name = service_method_split[-1] for key, value in self.context.invocation_metadata(): if key == "application": @@ -161,20 +161,16 @@ def is_route_request(self) -> bool: def is_health_request(self) -> bool: return self.service_method == "/ray.serve.RayServeAPIService/Healthz" - @property - def user_request(self) -> bytes: - return self.request - def send_request_id(self, request_id: str): # Setting the trailing metadata on the ray_serve_grpc_context object, so it's # not overriding the ones set from the user and will be sent back to the # client altogether. self.ray_serve_grpc_context.set_trailing_metadata([("request_id", request_id)]) - def request_object(self) -> gRPCRequest: - return gRPCRequest( - grpc_user_request=self.user_request, - ) + def serialized_replica_arg(self) -> bytes: + # NOTE(edoakes): it's important that the request is sent as raw bytes to + # skip the Ray cloudpickle serialization codepath for performance. + return pickle.dumps(gRPCRequest(user_request_proto=self._request_proto)) @dataclass(frozen=True) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 382b46a27ea10..dcae08bdecd10 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -1363,17 +1363,18 @@ def _prepare_args_for_grpc_request( request_metadata: RequestMetadata, user_method_params: Dict[str, inspect.Parameter], ) -> Tuple[Tuple[Any], Dict[str, Any]]: - """Prepare arguments for a user method handling a gRPC request. + """Prepare args and kwargs for a user method handling a gRPC request. - Returns (request_args, request_kwargs). + The sole argument is always the user request proto. + + If the method has a "context" kwarg, we pass the gRPC context, else no kwargs. """ - request_args = (pickle.loads(request.grpc_user_request),) if GRPC_CONTEXT_ARG_NAME in user_method_params: request_kwargs = {GRPC_CONTEXT_ARG_NAME: request_metadata.grpc_context} else: request_kwargs = {} - return request_args, request_kwargs + return (request.user_request_proto,), request_kwargs async def _handle_user_method_result( self, @@ -1492,7 +1493,6 @@ async def call_user_method( generator_result_callback=generator_result_callback, ) elif request_metadata.is_grpc_request: - # Ensure the request args are a single gRPCRequest object. assert len(request_args) == 1 and isinstance( request_args[0], gRPCRequest ) diff --git a/python/ray/serve/tests/unit/test_proxy_request_response.py b/python/ray/serve/tests/unit/test_proxy_request_response.py index 7b2c4388b6572..51768f7a74b0e 100644 --- a/python/ray/serve/tests/unit/test_proxy_request_response.py +++ b/python/ray/serve/tests/unit/test_proxy_request_response.py @@ -219,7 +219,6 @@ def test_calling_user_defined_method(self): ) assert isinstance(proxy_request, ProxyRequest) assert proxy_request.route_path == application - assert pickle.loads(proxy_request.request) == request_proto assert proxy_request.method_name == method_name assert proxy_request.app_name == application assert proxy_request.request_id == request_id @@ -232,9 +231,11 @@ def test_calling_user_defined_method(self): ("request_id", request_id) ] - request_object = proxy_request.request_object() + serialized_arg = proxy_request.serialized_replica_arg() + assert isinstance(serialized_arg, bytes) + request_object = pickle.loads(serialized_arg) assert isinstance(request_object, gRPCRequest) - assert pickle.loads(request_object.grpc_user_request) == request_proto + assert request_object.user_request_proto == request_proto if __name__ == "__main__": diff --git a/python/ray/serve/tests/unit/test_user_callable_wrapper.py b/python/ray/serve/tests/unit/test_user_callable_wrapper.py index 39188f32c421b..b2ad0ed52e675 100644 --- a/python/ray/serve/tests/unit/test_user_callable_wrapper.py +++ b/python/ray/serve/tests/unit/test_user_callable_wrapper.py @@ -556,10 +556,7 @@ def test_grpc_unary_request(run_sync_methods_in_threadpool: bool): ) user_callable_wrapper.initialize_callable().result() - grpc_request = gRPCRequest( - pickle.dumps(serve_pb2.UserDefinedResponse(greeting="world")) - ) - + grpc_request = gRPCRequest(serve_pb2.UserDefinedResponse(greeting="world")) request_metadata = _make_request_metadata(call_method="greet", is_grpc_request=True) _, result_bytes = user_callable_wrapper.call_user_method( request_metadata, (grpc_request,), dict() @@ -579,9 +576,7 @@ def test_grpc_streaming_request(run_sync_methods_in_threadpool: bool): ) user_callable_wrapper.initialize_callable() - grpc_request = gRPCRequest( - pickle.dumps(serve_pb2.UserDefinedResponse(greeting="world")) - ) + grpc_request = gRPCRequest(serve_pb2.UserDefinedResponse(greeting="world")) result_list = []