Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: run warmup on Runtimes and Executor #5579

Merged
merged 40 commits into from
Jan 14, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
5817aac
feat: add warmup coroutine to Gateway and Worker runtimes
Jan 6, 2023
330dd7a
style: fix overload and cli autocomplete
jina-bot Jan 6, 2023
10947b6
test: reduce traced operations to account for warmup
Jan 6, 2023
967111a
refactor: user threding.Event to signal graceful warmup task cancella…
Jan 9, 2023
02faa77
style: fix overload and cli autocomplete
jina-bot Jan 9, 2023
23e1777
chore: reduce debug logging
Jan 9, 2023
aa9189d
feat: implement warmup using discovery requests
Jan 9, 2023
6824fbf
refactor: use asyncio.sleep
Jan 9, 2023
b70008c
feat: create warmup task per deployment
Jan 10, 2023
c5155e8
feat: implement warmup for HeadRuntime
Jan 10, 2023
280d481
feat: remove executor warmup task
Jan 10, 2023
2885b4e
feat: don't warmup deprecated head uses before and after
Jan 11, 2023
0f2bf67
Merge remote-tracking branch 'origin/master' into feat-serve-5467-run…
Jan 11, 2023
2be40b1
style: fix overload and cli autocomplete
jina-bot Jan 11, 2023
4651f3b
fix: revert changes to unrelated tests
Jan 11, 2023
d1c7e21
Merge remote-tracking branch 'origin/feat-serve-5467-runtime-warmup' …
Jan 11, 2023
111db68
test: start worker deployment before gateway
Jan 11, 2023
8090ea8
fix: don't use asyncio.gather or single task
Jan 11, 2023
fcb122b
Merge branch 'master' into feat-serve-5467-runtime-warmup
girishc13 Jan 11, 2023
367a115
fix: await gather_endpoints coroutine
Jan 12, 2023
34d7d4b
feat: wait until grpc channel is ready before warmup
Jan 12, 2023
4c86668
Revert "feat: wait until grpc channel is ready before warmup"
Jan 12, 2023
a82c128
feat: create JinaInfoRPC stub in the ConnectionStubs for reuse
Jan 12, 2023
87ee55f
feat: enable grpc SO_REUSEPORT for multi process/threading
Jan 12, 2023
017dfea
Revert "feat: enable grpc SO_REUSEPORT for multi process/threading"
Jan 12, 2023
99e0823
Revert "fix: await gather_endpoints coroutine"
Jan 12, 2023
e11a41f
ci: debug warmup task
Jan 12, 2023
ab6119d
fix: pop removed connection channel from dict
Jan 12, 2023
e9b6561
Merge remote-tracking branch 'origin/master' into feat-serve-5467-run…
Jan 12, 2023
479c810
fix: pop removed connection channel from dict
Jan 12, 2023
95a5767
Revert "fix: pop removed connection channel from dict"
Jan 13, 2023
9deda11
Revert "fix: pop removed connection channel from dict"
Jan 13, 2023
dd3516b
Revert "feat: create JinaInfoRPC stub in the ConnectionStubs for reuse"
Jan 13, 2023
ada4e5e
feat: create duplicate stubs for warmup requests
Jan 13, 2023
a23caf1
Merge remote-tracking branch 'origin/master' into feat-serve-5467-run…
Jan 13, 2023
51c7e60
style: fix overload and cli autocomplete
jina-bot Jan 13, 2023
855bd7c
Merge remote-tracking branch 'origin/master' into feat-serve-5467-run…
Jan 13, 2023
7211337
chore: remove debug logging
Jan 13, 2023
463a998
chore: clean up imports
Jan 13, 2023
6232b42
feat: close channels created for warmup stubs
Jan 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions jina/serve/networking.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import ipaddress
import os
import threading
import time
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove unneeded imports

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
Expand Down Expand Up @@ -613,9 +615,6 @@ def _get_connection_list(
return self._get_connection_list(
deployment, type_, 0, increase_access_count
)
self._logger.debug(
f'did not find a connection for deployment {deployment}, type {type_} and entity_id {entity_id}. There are {len(self._deployments[deployment][type_]) if deployment in self._deployments else 0} available connections for this deployment and type. '
)
return None

def _add_deployment(self, deployment: str):
Expand Down Expand Up @@ -1114,6 +1113,71 @@ async def task_wrapper():

return asyncio.create_task(task_wrapper())

async def warmup(
    self,
    deployment: str,
    stop_event: threading.Event,
):
    """Executes JinaInfoRPC against the provided deployment. A single task is created for each replica connection.

    Rounds of ``_status`` calls are repeated every 0.2 seconds against the targets that have not
    answered successfully yet. The coroutine returns when ``stop_event`` is set, when a round
    started with no pending target, or once five minutes have elapsed.

    :param deployment: deployment name and the replicas that needs to be warmed up.
    :param stop_event: signal to indicate if an early termination of the task is required for graceful teardown.
    """

    async def task_wrapper(target_warmup_responses, target, channel):
        # Record in ``target_warmup_responses`` whether one ``_status`` call to ``target`` succeeded.
        try:
            stub = jina_pb2_grpc.JinaInfoRPCStub(channel=channel)
            await stub._status(
                request=jina_pb2.google_dot_protobuf_dot_empty__pb2.Empty(),
            )
            target_warmup_responses[target] = True
        except Exception:
            target_warmup_responses[target] = False

    try:
        # Monotonic clock: the deadline must not move when the wall clock is adjusted.
        deadline = time.monotonic() + 60 * 5  # 5 minutes from now
        warmed_up_targets = set()

        while not stop_event.is_set():
            # refresh channels in case connection has been reset due to InternalNetworkError
            refreshed = self.__extract_target_to_channel(deployment)
            # Filter instead of ``pop``: a warmed target may be missing after the refresh,
            # and a KeyError here would abort the warmup of the remaining targets.
            target_to_channel = {
                target: channel
                for target, channel in refreshed.items()
                if target not in warmed_up_targets
            }

            replica_warmup_responses = {}
            tasks = [
                asyncio.create_task(
                    task_wrapper(replica_warmup_responses, target, channel)
                )
                for target, channel in target_to_channel.items()
            ]
            await asyncio.gather(*tasks, return_exceptions=True)

            warmed_up_targets.update(
                target
                for target, succeeded in replica_warmup_responses.items()
                if succeeded
            )

            if time.monotonic() > deadline or not target_to_channel:
                return

            await asyncio.sleep(0.2)
    except Exception as ex:
        self._logger.error(f'error with warmup up task: {ex}')
        return

def __extract_target_to_channel(self, deployment):
replica_set = set()
replica_set.update(self._connections.get_replicas_all_shards(deployment))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can extract out the warmup logic into a Warmer class but the self._connections is the inner class _ConnectionPoolMap of the GrpcConnectionPool. Does the inner class still make sense?

replica_set.add(
self._connections.get_replicas(deployment=deployment, head=True)
)

target_to_channel = {}
for replica_list in filter(None, replica_set):
target_to_channel.update(replica_list._address_to_channel)
return target_to_channel

@staticmethod
def __aio_channel_with_tracing_interceptor(
address,
Expand Down
17 changes: 17 additions & 0 deletions jina/serve/runtimes/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import asyncio
import signal
import threading
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove unneeded import

import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Union
Expand Down Expand Up @@ -76,6 +77,8 @@ def _cancel(signum, frame):
self._start_time = time.time()
self._loop.run_until_complete(self.async_setup())
self._send_telemetry_event()
self.warmup_task = None
self.warmup_stop_event = threading.Event()

def _send_telemetry_event(self):
send_telemetry_event(event='start', obj=self, entity_id=self._entity_id)
Expand Down Expand Up @@ -161,6 +164,20 @@ async def async_run_forever(self):
"""The async method to run until it is stopped."""
...

async def cancel_warmup_task(self):
    """Cancel warmup task if exists and is not completed. Cancellation is required if the Flow is being terminated before the
    task is successful or hasn't reached the max timeout.

    The task is asked to stop through ``warmup_stop_event`` and then awaited. Any failure of the
    task is logged at debug level instead of being raised, so teardown can continue.
    """
    task = self.warmup_task
    if not task:
        return

    try:
        if not task.done():
            self.logger.debug('Cancelling warmup task.')
            self.warmup_stop_event.set()
            await task
        # Retrieve the outcome so asyncio does not report a never-retrieved exception.
        task.exception()
    except (Exception, asyncio.CancelledError) as ex:
        # Best effort: a failed or cancelled warmup must not break teardown.
        # KeyboardInterrupt and SystemExit are deliberately not caught here.
        self.logger.debug(f'warmup task ended with {ex!r} while being cancelled')

# Static methods used by the Pod to communicate with the `Runtime` in the separate process

@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions jina/serve/runtimes/gateway/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,22 @@ async def _wait_for_cancel(self):

async def async_teardown(self):
    """Shutdown the server.

    Stops the warmup task, closes the gateway streamer, shuts the gateway down and
    finally runs ``async_cancel``.
    """
    await self.cancel_warmup_task()
    gateway = self.gateway
    await gateway.streamer.close()
    await gateway.shutdown()
    await self.async_cancel()

async def async_cancel(self):
    """Stop the server.

    Stops the warmup task first, then closes the gateway streamer and shuts the gateway down.
    """
    await self.cancel_warmup_task()
    gateway = self.gateway
    await gateway.streamer.close()
    await gateway.shutdown()

async def async_run_forever(self):
    """Running method of the server.

    Schedules the streamer warmup as a background task stored in ``self.warmup_task``,
    runs the gateway server, and sets ``is_cancel`` once the server returns.
    """
    warmup_coro = self.gateway.streamer.warmup(self.warmup_stop_event)
    self.warmup_task = asyncio.create_task(warmup_coro)
    await self.gateway.run_server()
    self.is_cancel.set()

Expand Down
16 changes: 13 additions & 3 deletions jina/serve/runtimes/gateway/composite/gateway.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import copy
from typing import Any, List, Optional

Expand Down Expand Up @@ -36,18 +37,27 @@ async def setup_server(self):
"""
setup GRPC server
"""
tasks = []
for gateway in self.gateways:
await gateway.setup_server()
tasks.append(asyncio.create_task(gateway.setup_server()))

await asyncio.gather(*tasks)

async def shutdown(self):
    """Free other resources allocated with the server, e.g, gateway object, ...

    Every gateway in ``self.gateways`` is shut down exactly once; the shutdowns run
    concurrently and this coroutine returns when all of them have finished.
    """
    # One task per gateway; no additional sequential ``await gateway.shutdown()``,
    # which would shut each gateway down a second time.
    shutdown_tasks = [
        asyncio.create_task(gateway.shutdown()) for gateway in self.gateways
    ]

    await asyncio.gather(*shutdown_tasks)

async def run_server(self):
    """Run GRPC server forever

    Starts ``run_server`` of every gateway in ``self.gateways`` as its own task so that
    all servers run concurrently, then waits for all of them to finish.
    """
    # One task per gateway; awaiting ``gateway.run_server()`` inline would block on the
    # first gateway and keep the remaining ones from starting.
    run_server_tasks = [
        asyncio.create_task(gateway.run_server()) for gateway in self.gateways
    ]

    await asyncio.gather(*run_server_tasks)

@staticmethod
def _deepcopy_with_ignore_attrs(obj: Any, ignore_attrs: List[str]) -> Any:
Expand Down
17 changes: 15 additions & 2 deletions jina/serve/runtimes/head/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import asyncio
import json
import os
from abc import ABC
Expand Down Expand Up @@ -158,24 +159,35 @@ async def async_setup(self):
service, health_pb2.HealthCheckResponse.SERVING
)
reflection.enable_server_reflection(service_names, self._grpc_server)

bind_addr = f'{self.args.host}:{self.args.port}'
self._grpc_server.add_insecure_port(bind_addr)
self.logger.debug(f'start listening on {bind_addr}')
await self._grpc_server.start()

def _warmup(self):
    # Schedule the request handler's warmup coroutine as a background task and keep a
    # handle to it in ``self.warmup_task``.
    warmup_coro = self.request_handler.warmup(
        connection_pool=self.connection_pool,
        stop_event=self.warmup_stop_event,
        deployment=self._deployment_name,
    )
    self.warmup_task = asyncio.create_task(warmup_coro)

async def async_run_forever(self):
    """Block until the GRPC server is terminated.

    The warmup task is scheduled first through ``_warmup``.
    """
    self._warmup()
    termination = self._grpc_server.wait_for_termination()
    await termination

async def async_cancel(self):
    """Stop the GRPC server.

    The warmup task is stopped before the server is stopped with a grace period of ``0``.
    """
    self.logger.debug('cancel HeadRuntime')

    await self.cancel_warmup_task()
    grpc_server = self._grpc_server
    await grpc_server.stop(0)

async def async_teardown(self):
"""Close the connection pool"""
await self.cancel_warmup_task()
await self._health_servicer.enter_graceful_shutdown()
await self.async_cancel()
await self.connection_pool.close()
Expand Down Expand Up @@ -294,6 +306,7 @@ async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
:param context: grpc context
:returns: the response request
"""
self.logger.debug('recv _status request')
infoProto = jina_pb2.JinaInfoProto()
version, env_info = get_full_version()
for k, v in version.items():
Expand Down
30 changes: 28 additions & 2 deletions jina/serve/runtimes/head/request_handling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import threading
girishc13 marked this conversation as resolved.
Show resolved Hide resolved
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from jina.serve.networking import GrpcConnectionPool
from jina.serve.runtimes.monitoring import MonitoringRequestMixin
from jina.serve.runtimes.worker.request_handling import WorkerRequestHandler

Expand Down Expand Up @@ -164,7 +167,9 @@ async def _handle_data_request(
elif len(worker_results) > 1 and not reduce:
# worker returned multiple responses, but the head is configured to skip reduction
# just concatenate the docs in this case
response_request.data.docs = WorkerRequestHandler.get_docs_from_request(requests)
response_request.data.docs = WorkerRequestHandler.get_docs_from_request(
requests
)

merged_metadata = self._merge_metadata(
metadata,
Expand All @@ -177,3 +182,24 @@ async def _handle_data_request(
self._update_end_request_metrics(response_request)

return response_request, merged_metadata

async def warmup(
    self,
    connection_pool: GrpcConnectionPool,
    stop_event: threading.Event,
    deployment: str,
):
    """Run the connection pool warmup for ``deployment`` and wait for it to finish.

    Errors raised by the warmup are logged and not propagated.

    :param connection_pool: GrpcConnectionPool that implements the warmup to the connected deployments.
    :param stop_event: signal to indicate if an early termination of the task is required for graceful teardown.
    :param deployment: deployment name that need to be warmed up.
    """
    self.logger.debug('Running HeadRuntime warmup')

    try:
        pool_warmup = asyncio.create_task(
            connection_pool.warmup(deployment=deployment, stop_event=stop_event)
        )
        await pool_warmup
    except Exception as ex:
        self.logger.error(f'error with HeadRuntime warmup up task: {ex}')
        return
3 changes: 2 additions & 1 deletion jina/serve/runtimes/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ async def _async_setup_grpc_server(self):
self._health_servicer, self._grpc_server
)

reflection.enable_server_reflection(service_names, self._grpc_server)
reflection.enable_server_reflection(service_names, self._grpc_server)
bind_addr = f'{self.args.host}:{self.args.port}'
self.logger.debug(f'start listening on {bind_addr}')
self._grpc_server.add_insecure_port(bind_addr)
Expand Down Expand Up @@ -306,6 +306,7 @@ async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
:param context: grpc context
:returns: the response request
"""
self.logger.debug('recv _status request')
info_proto = jina_pb2.JinaInfoProto()
version, env_info = get_full_version()
for k, v in version.items():
Expand Down
4 changes: 1 addition & 3 deletions jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,9 +538,7 @@ def get_docs_from_request(
"""
if len(requests) > 1:
result = DocumentArray(
d
for r in reversed(requests)
for d in getattr(r, 'docs')
d for r in reversed(requests) for d in getattr(r, 'docs')
)
else:
result = getattr(requests[0], 'docs')
Expand Down
32 changes: 31 additions & 1 deletion jina/serve/streamer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import asyncio
import json
import os
import threading
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union

from docarray import DocumentArray
from docarray import Document, DocumentArray

from jina.logging.logger import JinaLogger
from jina.serve.networking import GrpcConnectionPool
Expand Down Expand Up @@ -65,6 +68,7 @@ def __init__(
:param aio_tracing_client_interceptors: Optional list of aio grpc tracing server interceptors.
:param tracing_client_interceptor: Optional gprc tracing server interceptor.
"""
self.logger = logger
topology_graph = TopologyGraph(
graph_representation=graph_representation,
graph_conditions=graph_conditions,
Expand All @@ -78,6 +82,7 @@ def __init__(
self.runtime_name = runtime_name
self.aio_tracing_client_interceptors = aio_tracing_client_interceptors
self.tracing_client_interceptor = tracing_client_interceptor
self._executor_addresses = executor_addresses

self._connection_pool = self._create_connection_pool(
executor_addresses,
Expand Down Expand Up @@ -221,3 +226,28 @@ def get_streamer():
@staticmethod
def _set_env_streamer_args(**kwargs):
    # Serialize the keyword arguments as JSON and store them in the
    # JINA_STREAMER_ARGS environment variable of the current process.
    serialized_args = json.dumps(kwargs)
    os.environ['JINA_STREAMER_ARGS'] = serialized_args

async def warmup(self, stop_event: threading.Event):
    """Run the connection pool warmup once for every deployment in ``_executor_addresses``.

    This forces the gateway to establish a connection and open a gRPC channel to each executor,
    so that the first request does not pay the penalty of establishing a brand new gRPC channel.
    One task is created per deployment and all of them are awaited together.

    :param stop_event: signal to indicate if an early termination of the task is required for graceful teardown.
    """
    self.logger.debug('Running GatewayRuntime warmup')
    deployments = set(self._executor_addresses.keys())

    try:
        deployment_warmup_tasks = [
            asyncio.create_task(
                self._connection_pool.warmup(
                    deployment=name, stop_event=stop_event
                )
            )
            for name in deployments
        ]

        await asyncio.gather(*deployment_warmup_tasks, return_exceptions=True)
    except Exception as ex:
        self.logger.error(f'error with GatewayRuntime warmup up task: {ex}')
        return
Loading