From 49df36a46b1ba441760298a130cfe5de56c95acb Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Sat, 18 Jan 2025 08:05:17 -0600 Subject: [PATCH 1/5] fix Signed-off-by: Edward Oakes --- python/ray/serve/_private/proxy.py | 13 +++------- .../serve/_private/proxy_request_response.py | 26 +++++++------------ python/ray/serve/_private/replica.py | 22 +++++++--------- .../tests/unit/test_user_callable_wrapper.py | 14 +++------- 4 files changed, 26 insertions(+), 49 deletions(-) diff --git a/python/ray/serve/_private/proxy.py b/python/ray/serve/_private/proxy.py index 6b018e78d142..4b4a4e2a6d83 100644 --- a/python/ray/serve/_private/proxy.py +++ b/python/ray/serve/_private/proxy.py @@ -708,11 +708,8 @@ async def send_request_to_replica( proxy_request: ProxyRequest, app_is_cross_language: bool = False, ) -> ResponseGenerator: - handle_arg = proxy_request.request_object() response_generator = ProxyResponseGenerator( - # NOTE(edoakes): it's important that the request is sent as raw bytes to - # skip the Ray cloudpickle serialization codepath for performance. - handle.remote(pickle.dumps(handle_arg)), + handle.remote(proxy_request.serialized_replica_arg()), timeout_s=self.request_timeout_s, ) @@ -956,12 +953,8 @@ async def send_request_to_replica( # Response is returned as raw bytes, convert it to ASGI messages. result_callback = convert_object_to_asgi_messages else: - # NOTE(edoakes): it's important that the request is sent as raw bytes to - # skip the Ray cloudpickle serialization codepath for performance. - handle_arg_bytes = pickle.dumps( - proxy_request.request_object( - proxy_actor_name=self.self_actor_name, - ) + handle_arg_bytes = proxy_request.serialized_replica_arg( + proxy_actor_name=self.self_actor_name, ) # Messages are returned as pickled dictionaries. result_callback = pickle.loads diff --git a/python/ray/serve/_private/proxy_request_response.py b/python/ray/serve/_private/proxy_request_response.py index a1f6b42ecbce..4d9af2d4868a 100644 --- a/python/ray/serve/_private/proxy_request_response.py +++ b/python/ray/serve/_private/proxy_request_response.py @@ -95,14 +95,13 @@ def set_path(self, path: str): def set_root_path(self, root_path: str): self.scope["root_path"] = root_path - def request_object( - self, - proxy_actor_name: str, - ) -> StreamingHTTPRequest: - return StreamingHTTPRequest( + def serialized_replica_arg(self, proxy_actor_name: str) -> bytes: + # NOTE(edoakes): it's important that the request is sent as raw bytes to + # skip the Ray cloudpickle serialization codepath for performance. + return pickle.dumps(StreamingHTTPRequest( asgi_scope=self.scope, proxy_actor_name=proxy_actor_name, - ) + )) class gRPCProxyRequest(ProxyRequest): @@ -115,7 +114,7 @@ def __init__( service_method: str, stream: bool, ): - self.request = request_proto + self._request_proto = request_proto self.context = context self.service_method = service_method self.stream = stream @@ -131,7 +130,6 @@ def __init__( def setup_variables(self): if not self.is_route_request and not self.is_health_request: service_method_split = self.service_method.split("/") - self.request = pickle.dumps(self.request) self.method_name = service_method_split[-1] for key, value in self.context.invocation_metadata(): if key == "application": @@ -161,20 +159,16 @@ def is_route_request(self) -> bool: def is_health_request(self) -> bool: return self.service_method == "/ray.serve.RayServeAPIService/Healthz" - @property - def user_request(self) -> bytes: - return self.request - def send_request_id(self, request_id: str): # Setting the trailing metadata on the ray_serve_grpc_context object, so it's # not overriding the ones set from the user and will be sent back to the # client altogether. self.ray_serve_grpc_context.set_trailing_metadata([("request_id", request_id)]) - def request_object(self) -> gRPCRequest: - return gRPCRequest( - grpc_user_request=self.user_request, - ) + def serialized_replica_arg(self) -> bytes: + # NOTE(edoakes): it's important that the request is sent as raw bytes to + # skip the Ray cloudpickle serialization codepath for performance. + return pickle.dumps(self._request_proto) @dataclass(frozen=True) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 382b46a27ea1..1e39347da180 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -1357,23 +1357,21 @@ async def _send(message: Message): return request_args, asgi_args, receive_task - def _prepare_args_for_grpc_request( + def _prepare_kwargs_for_grpc_request( self, - request: gRPCRequest, request_metadata: RequestMetadata, user_method_params: Dict[str, inspect.Parameter], - ) -> Tuple[Tuple[Any], Dict[str, Any]]: - """Prepare arguments for a user method handling a gRPC request. + ) -> Dict[str, Any]: + """Prepare kwargs for a user method handling a gRPC request. - Returns (request_args, request_kwargs). + If the method has a "context" kwarg, we pass it the gRPC context, else nothing. """ - request_args = (pickle.loads(request.grpc_user_request),) if GRPC_CONTEXT_ARG_NAME in user_method_params: request_kwargs = {GRPC_CONTEXT_ARG_NAME: request_metadata.grpc_context} else: request_kwargs = {} - return request_args, request_kwargs + return request_kwargs async def _handle_user_method_result( self, @@ -1492,12 +1490,10 @@ async def call_user_method( generator_result_callback=generator_result_callback, ) elif request_metadata.is_grpc_request: - # Ensure the request args are a single gRPCRequest object. - assert len(request_args) == 1 and isinstance( - request_args[0], gRPCRequest - ) - request_args, request_kwargs = self._prepare_args_for_grpc_request( - request_args[0], request_metadata, user_method_params + # The sole request argument is the user proto request object. + assert len(request_args) == 1 + request_kwargs = self._prepare_kwargs_for_grpc_request( + request_metadata, user_method_params ) result, sync_gen_consumed = await self._call_func_or_gen( diff --git a/python/ray/serve/tests/unit/test_user_callable_wrapper.py b/python/ray/serve/tests/unit/test_user_callable_wrapper.py index 39188f32c421..96e873832b3c 100644 --- a/python/ray/serve/tests/unit/test_user_callable_wrapper.py +++ b/python/ray/serve/tests/unit/test_user_callable_wrapper.py @@ -17,7 +17,6 @@ RequestMetadata, RequestProtocol, StreamingHTTPRequest, - gRPCRequest, ) from ray.serve._private.replica import UserCallableWrapper from ray.serve.generated import serve_pb2 @@ -556,13 +555,10 @@ def test_grpc_unary_request(run_sync_methods_in_threadpool: bool): ) user_callable_wrapper.initialize_callable().result() - grpc_request = gRPCRequest( - pickle.dumps(serve_pb2.UserDefinedResponse(greeting="world")) - ) - + request_proto = serve_pb2.UserDefinedResponse(greeting="world") request_metadata = _make_request_metadata(call_method="greet", is_grpc_request=True) _, result_bytes = user_callable_wrapper.call_user_method( - request_metadata, (grpc_request,), dict() + request_metadata, (request_proto,), dict() ).result() assert isinstance(result_bytes, bytes) @@ -579,9 +575,7 @@ def test_grpc_streaming_request(run_sync_methods_in_threadpool: bool): ) user_callable_wrapper.initialize_callable() - grpc_request = gRPCRequest( - pickle.dumps(serve_pb2.UserDefinedResponse(greeting="world")) - ) + request_proto = serve_pb2.UserDefinedResponse(greeting="world") result_list = [] @@ -590,7 +584,7 @@ def test_grpc_streaming_request(run_sync_methods_in_threadpool: bool): ) user_callable_wrapper.call_user_method( request_metadata, - (grpc_request,), + (request_proto,), dict(), generator_result_callback=result_list.append, ).result() From 24311c2862f76dcda19a368be9ab9906feecd86e Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Sat, 18 Jan 2025 08:11:04 -0600 Subject: [PATCH 2/5] fix Signed-off-by: Edward Oakes --- python/ray/serve/_private/common.py | 5 ++-- .../serve/_private/proxy_request_response.py | 16 ++++++++----- python/ray/serve/_private/replica.py | 24 +++++++++++-------- .../tests/unit/test_user_callable_wrapper.py | 9 +++---- 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index f7c57135f93d..ce32a32670f1 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -1,7 +1,7 @@ import json from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Awaitable, Callable, Dict, List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional from starlette.types import Scope @@ -573,8 +573,7 @@ class MultiplexedReplicaInfo: @dataclass class gRPCRequest: """Sent from the GRPC proxy to replicas on both unary and streaming codepaths.""" - - grpc_user_request: bytes + user_request_proto: Any class RequestProtocol(str, Enum): diff --git a/python/ray/serve/_private/proxy_request_response.py b/python/ray/serve/_private/proxy_request_response.py index 4d9af2d4868a..b93eb04633a4 100644 --- a/python/ray/serve/_private/proxy_request_response.py +++ b/python/ray/serve/_private/proxy_request_response.py @@ -7,7 +7,7 @@ import grpc from starlette.types import Receive, Scope, Send -from ray.serve._private.common import StreamingHTTPRequest, gRPCRequest +from ray.serve._private.common import gRPCRequest, StreamingHTTPRequest from ray.serve._private.constants import SERVE_LOGGER_NAME from ray.serve._private.utils import DEFAULT from ray.serve.grpc_util import RayServegRPCContext @@ -98,10 +98,12 @@ def set_root_path(self, root_path: str): def serialized_replica_arg(self, proxy_actor_name: str) -> bytes: # NOTE(edoakes): it's important that the request is sent as raw bytes to # skip the Ray cloudpickle serialization codepath for performance. - return pickle.dumps(StreamingHTTPRequest( - asgi_scope=self.scope, - proxy_actor_name=proxy_actor_name, - )) + return pickle.dumps( + StreamingHTTPRequest( + asgi_scope=self.scope, + proxy_actor_name=proxy_actor_name, + ) + ) class gRPCProxyRequest(ProxyRequest): @@ -168,7 +170,9 @@ def send_request_id(self, request_id: str): def serialized_replica_arg(self) -> bytes: # NOTE(edoakes): it's important that the request is sent as raw bytes to # skip the Ray cloudpickle serialization codepath for performance. - return pickle.dumps(self._request_proto) + return pickle.dumps(gRPCRequest( + user_request_proto=self._request_proto + )) @dataclass(frozen=True) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 1e39347da180..68647548f239 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -39,8 +39,8 @@ ReplicaQueueLengthInfo, RequestMetadata, ServeComponentType, - StreamingHTTPRequest, gRPCRequest, + StreamingHTTPRequest, ) from ray.serve._private.config import DeploymentConfig from ray.serve._private.constants import ( @@ -1357,21 +1357,24 @@ async def _send(message: Message): return request_args, asgi_args, receive_task - def _prepare_kwargs_for_grpc_request( + def _prepare_args_for_grpc_request( self, + request: gRPCRequest, request_metadata: RequestMetadata, user_method_params: Dict[str, inspect.Parameter], - ) -> Dict[str, Any]: - """Prepare kwargs for a user method handling a gRPC request. + ) -> Tuple[Tuple[Any], Dict[str, Any]]: + """Prepare args and kwargs for a user method handling a gRPC request. + + The sole argument is always the user request proto. - If the method has a "context" kwarg, we pass it the gRPC context, else nothing. + If the method has a "context" kwarg, we pass the gRPC context, else no kwargs. """ if GRPC_CONTEXT_ARG_NAME in user_method_params: request_kwargs = {GRPC_CONTEXT_ARG_NAME: request_metadata.grpc_context} else: request_kwargs = {} - return request_kwargs + return (request.user_request_proto,), request_kwargs async def _handle_user_method_result( self, @@ -1490,10 +1493,11 @@ async def call_user_method( generator_result_callback=generator_result_callback, ) elif request_metadata.is_grpc_request: - # The sole request argument is the user proto request object. - assert len(request_args) == 1 - request_kwargs = self._prepare_kwargs_for_grpc_request( - request_metadata, user_method_params + assert len(request_args) == 1 and isinstance( + request_args[0], gRPCRequest + ) + request_args, request_kwargs = self._prepare_args_for_grpc_request( + request_args[0], request_metadata, user_method_params ) result, sync_gen_consumed = await self._call_func_or_gen( diff --git a/python/ray/serve/tests/unit/test_user_callable_wrapper.py b/python/ray/serve/tests/unit/test_user_callable_wrapper.py index 96e873832b3c..99c3de736d95 100644 --- a/python/ray/serve/tests/unit/test_user_callable_wrapper.py +++ b/python/ray/serve/tests/unit/test_user_callable_wrapper.py @@ -16,6 +16,7 @@ DeploymentID, RequestMetadata, RequestProtocol, + gRPCRequest, StreamingHTTPRequest, ) from ray.serve._private.replica import UserCallableWrapper @@ -555,10 +556,10 @@ def test_grpc_unary_request(run_sync_methods_in_threadpool: bool): ) user_callable_wrapper.initialize_callable().result() - request_proto = serve_pb2.UserDefinedResponse(greeting="world") + grpc_request = gRPCRequest(serve_pb2.UserDefinedResponse(greeting="world")) request_metadata = _make_request_metadata(call_method="greet", is_grpc_request=True) _, result_bytes = user_callable_wrapper.call_user_method( - request_metadata, (request_proto,), dict() + request_metadata, (grpc_request,), dict() ).result() assert isinstance(result_bytes, bytes) @@ -575,7 +576,7 @@ def test_grpc_streaming_request(run_sync_methods_in_threadpool: bool): ) user_callable_wrapper.initialize_callable() - request_proto = serve_pb2.UserDefinedResponse(greeting="world") + grpc_request = gRPCRequest(serve_pb2.UserDefinedResponse(greeting="world")) result_list = [] @@ -584,7 +585,7 @@ def test_grpc_streaming_request(run_sync_methods_in_threadpool: bool): ) user_callable_wrapper.call_user_method( request_metadata, - (request_proto,), + (grpc_request,), dict(), generator_result_callback=result_list.append, ).result() From 08cef437d39e8f09bd0d65364c925f4454edd16f Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Sat, 18 Jan 2025 08:11:15 -0600 Subject: [PATCH 3/5] fix Signed-off-by: Edward Oakes --- python/ray/serve/_private/common.py | 1 + python/ray/serve/_private/proxy_request_response.py | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index ce32a32670f1..e9d6b55c784b 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -573,6 +573,7 @@ class MultiplexedReplicaInfo: @dataclass class gRPCRequest: """Sent from the GRPC proxy to replicas on both unary and streaming codepaths.""" + user_request_proto: Any diff --git a/python/ray/serve/_private/proxy_request_response.py b/python/ray/serve/_private/proxy_request_response.py index b93eb04633a4..fbb023fa2535 100644 --- a/python/ray/serve/_private/proxy_request_response.py +++ b/python/ray/serve/_private/proxy_request_response.py @@ -170,9 +170,7 @@ def send_request_id(self, request_id: str): def serialized_replica_arg(self) -> bytes: # NOTE(edoakes): it's important that the request is sent as raw bytes to # skip the Ray cloudpickle serialization codepath for performance. - return pickle.dumps(gRPCRequest( - user_request_proto=self._request_proto - )) + return pickle.dumps(gRPCRequest(user_request_proto=self._request_proto)) @dataclass(frozen=True) From 5dd45ff28857b1a08dffed3c343ed637e7736b65 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Sat, 18 Jan 2025 08:14:50 -0600 Subject: [PATCH 4/5] fix import orders Signed-off-by: Edward Oakes --- python/ray/serve/_private/proxy_request_response.py | 2 +- python/ray/serve/_private/replica.py | 2 +- python/ray/serve/tests/unit/test_user_callable_wrapper.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/serve/_private/proxy_request_response.py b/python/ray/serve/_private/proxy_request_response.py index fbb023fa2535..922621f62424 100644 --- a/python/ray/serve/_private/proxy_request_response.py +++ b/python/ray/serve/_private/proxy_request_response.py @@ -7,7 +7,7 @@ import grpc from starlette.types import Receive, Scope, Send -from ray.serve._private.common import gRPCRequest, StreamingHTTPRequest +from ray.serve._private.common import StreamingHTTPRequest, gRPCRequest from ray.serve._private.constants import SERVE_LOGGER_NAME from ray.serve._private.utils import DEFAULT from ray.serve.grpc_util import RayServegRPCContext diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 68647548f239..dcae08bdecd1 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -39,8 +39,8 @@ ReplicaQueueLengthInfo, RequestMetadata, ServeComponentType, - gRPCRequest, StreamingHTTPRequest, + gRPCRequest, ) from ray.serve._private.config import DeploymentConfig from ray.serve._private.constants import ( diff --git a/python/ray/serve/tests/unit/test_user_callable_wrapper.py b/python/ray/serve/tests/unit/test_user_callable_wrapper.py index 99c3de736d95..b2ad0ed52e67 100644 --- a/python/ray/serve/tests/unit/test_user_callable_wrapper.py +++ b/python/ray/serve/tests/unit/test_user_callable_wrapper.py @@ -16,8 +16,8 @@ DeploymentID, RequestMetadata, RequestProtocol, - gRPCRequest, StreamingHTTPRequest, + gRPCRequest, ) from ray.serve._private.replica import UserCallableWrapper from ray.serve.generated import serve_pb2 From dcde2b2f1a91943f7d15e3dc9fb91c54de6e51f8 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Sat, 18 Jan 2025 10:38:17 -0600 Subject: [PATCH 5/5] fix Signed-off-by: Edward Oakes --- python/ray/serve/tests/unit/test_proxy_request_response.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/ray/serve/tests/unit/test_proxy_request_response.py b/python/ray/serve/tests/unit/test_proxy_request_response.py index 7b2c4388b657..51768f7a74b0 100644 --- a/python/ray/serve/tests/unit/test_proxy_request_response.py +++ b/python/ray/serve/tests/unit/test_proxy_request_response.py @@ -219,7 +219,6 @@ def test_calling_user_defined_method(self): ) assert isinstance(proxy_request, ProxyRequest) assert proxy_request.route_path == application - assert pickle.loads(proxy_request.request) == request_proto assert proxy_request.method_name == method_name assert proxy_request.app_name == application assert proxy_request.request_id == request_id @@ -232,9 +231,11 @@ def test_calling_user_defined_method(self): ("request_id", request_id) ] - request_object = proxy_request.request_object() + serialized_arg = proxy_request.serialized_replica_arg() + assert isinstance(serialized_arg, bytes) + request_object = pickle.loads(serialized_arg) assert isinstance(request_object, gRPCRequest) - assert pickle.loads(request_object.grpc_user_request) == request_proto + assert request_object.user_request_proto == request_proto if __name__ == "__main__":