Skip to content

Commit

Permalink
fix aio exception handling (NVIDIA#2084)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv authored Oct 18, 2023
1 parent 58cf5a6 commit 5e28e31
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 64 deletions.
9 changes: 9 additions & 0 deletions nvflare/fuel/f3/drivers/aio_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,19 @@ def get_event_loop(self):

return self.loop

def _handle_exception(self, loop, context):
try:
msg = context.get("exception", context["message"])
self.logger.debug(f"AIO Exception: {msg}")
except Exception as ex:
# ignore exception in the exception handler
self.logger.debug(f"exception in aio exception handler: {ex}")

def run_aio_loop(self):
self.logger.debug(f"{self.name}: started AioContext in thread {threading.current_thread().name}")
# self.loop = asyncio.get_event_loop()
self.loop = asyncio.new_event_loop()
self.loop.set_exception_handler(self._handle_exception)
asyncio.set_event_loop(self.loop)
self.logger.debug(f"{self.name}: got loop: {id(self.loop)}")
self.ready.set()
Expand Down
59 changes: 24 additions & 35 deletions nvflare/fuel/f3/drivers/aio_grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@ def __init__(self, aio_ctx: AioContext, connector: ConnectorInfo, conn_props: di

conf = CommConfigurator()
if conf.get_bool_var("simulate_unstable_network", default=False):
self.disconn = threading.Thread(target=self._disconnect, daemon=True)
self.disconn.start()
if context:
# only server side
self.disconn = threading.Thread(target=self._disconnect, daemon=True)
self.disconn.start()

def _disconnect(self):
t = random.randint(10, 60)
Expand All @@ -88,9 +90,11 @@ def close(self):
if self.context:
self.aio_ctx.run_coro(self.context.abort(grpc.StatusCode.CANCELLED, "service closed"))
self.context = None
self.logger.info("Closed GRPC context")
if self.channel:
self.aio_ctx.run_coro(self.channel.close())
self.channel = None
self.logger.info("Closed GRPC Channel")

def send_frame(self, frame: BytesAlike):
try:
Expand Down Expand Up @@ -298,57 +302,42 @@ async def _start_connect(self, connector: ConnectorInfo, aio_ctx: AioContext, co
address = get_address(params)

self.logger.debug(f"CLIENT: trying to connect {address}")
connection = None
try:
secure = ssl_required(params)
if secure:
grpc_channel = grpc.aio.secure_channel(
channel = grpc.aio.secure_channel(
address, options=self.options, credentials=get_grpc_client_credentials(params)
)
self.logger.info(f"created secure channel at {address}")
else:
grpc_channel = grpc.aio.insecure_channel(address, options=self.options)
channel = grpc.aio.insecure_channel(address, options=self.options)
self.logger.info(f"created insecure channel at {address}")
stub = StreamerStub(channel)

async with grpc_channel as channel:
self.logger.debug(f"CLIENT: connected to {address}")
stub = StreamerStub(channel)
conn_props = {DriverParams.PEER_ADDR.value: address}

if secure:
conn_props[DriverParams.PEER_CN.value] = "N/A"
self.logger.debug(f"CLIENT: connected to {address}")
conn_props = {DriverParams.PEER_ADDR.value: address}

connection = AioStreamSession(
aio_ctx=aio_ctx, connector=connector, conn_props=conn_props, channel=channel
)
if secure:
conn_props[DriverParams.PEER_CN.value] = "N/A"

try:
self.logger.debug(f"CLIENT: start streaming on connection {connection}")
msg_iter = stub.Stream(connection.generate_output())
conn_ctx.conn = connection
await connection.read_loop(msg_iter)
except asyncio.CancelledError as error:
self.logger.debug(f"CLIENT: RPC cancelled: {error}")
except Exception as ex:
if self.closing:
self.logger.debug(
f"Connection {connection} closed by {type(ex)}: {secure_format_exception(ex)}"
)
else:
self.logger.debug(
f"Connection {connection} client read exception {type(ex)}: {secure_format_exception(ex)}"
)
self.logger.debug(secure_format_traceback())
connection = AioStreamSession(aio_ctx=aio_ctx, connector=connector, conn_props=conn_props, channel=channel)

with connection.lock:
connection.channel = None
connection.close()
self.logger.debug(f"CLIENT: start streaming on connection {connection}")
msg_iter = stub.Stream(connection.generate_output())
conn_ctx.conn = connection
await connection.read_loop(msg_iter)
except asyncio.CancelledError:
self.logger.debug("CLIENT: RPC cancelled")
except grpc.FutureCancelledError:
self.logger.info("CLIENT: Future cancelled")
except Exception as ex:
conn_ctx.error = f"connection {connection} error: {type(ex)}: {secure_format_exception(ex)}"
self.logger.debug(conn_ctx.error)
self.logger.debug(secure_format_traceback())

finally:
if connection:
connection.close()
conn_ctx.waiter.set()

def connect(self, connector: ConnectorInfo):
Expand Down
60 changes: 31 additions & 29 deletions nvflare/fuel/f3/drivers/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,18 @@ def close(self):
if self.context:
try:
self.context.abort(grpc.StatusCode.CANCELLED, "service closed")
except:
except Exception as ex:
# ignore any exception when aborting
pass
self.logger.debug(f"exception aborting GRPC context: {secure_format_exception(ex)}")
self.context = None
self.logger.info("Closed GRPC context")
if self.channel:
self.channel.close()
try:
self.channel.close()
except Exception as ex:
self.logger.debug(f"exception closing GRPC channel: {secure_format_exception(ex)}")
self.channel = None
self.logger.info("Closed GRPC Channel")

def send_frame(self, frame: Union[bytes, bytearray, memoryview]):
try:
Expand Down Expand Up @@ -233,39 +238,36 @@ def connect(self, connector: ConnectorInfo):
params = connector.params
address = get_address(params)
conn_props = {DriverParams.PEER_ADDR.value: address}
connection = None
try:
secure = ssl_required(params)
if secure:
self.logger.debug("CLIENT: creating secure channel")
channel = grpc.secure_channel(
address, options=self.options, credentials=get_grpc_client_credentials(params)
)
self.logger.info(f"created secure channel at {address}")
else:
self.logger.info("CLIENT: creating insecure channel")
channel = grpc.insecure_channel(address, options=self.options)
self.logger.info(f"created insecure channel at {address}")

secure = ssl_required(params)
if secure:
self.logger.debug("CLIENT: creating secure channel")
channel = grpc.secure_channel(
address, options=self.options, credentials=get_grpc_client_credentials(params)
)
self.logger.info(f"created secure channel at {address}")
else:
self.logger.info("CLIENT: creating insecure channel")
channel = grpc.insecure_channel(address, options=self.options)
self.logger.info(f"created insecure channel at {address}")

with channel:
stub = StreamerStub(channel)
self.logger.debug("CLIENT: got stub")
oq = QQ()
connection = StreamConnection(oq, connector, conn_props, "CLIENT", channel=channel)
self.add_connection(connection)
self.logger.debug("CLIENT: added connection")
try:
received = stub.Stream(connection.generate_output())
connection.read_loop(received)

except BaseException as ex:
self.logger.info(f"CLIENT: connection done: {type(ex)}")

with connection.lock:
# when we get here the channel is already closed
# set connection.channel to None to prevent closing channel again in connection.close().
connection.channel = None
connection.close()
self.close_connection(connection)
received = stub.Stream(connection.generate_output())
connection.read_loop(received)
except grpc.FutureCancelledError:
self.logger.debug("RPC Cancelled")
except Exception as ex:
self.logger.error(f"connection {connection} error: {type(ex)}: {secure_format_exception(ex)}")
finally:
if connection:
connection.close()
self.close_connection(connection)
self.logger.info(f"CLIENT: finished connection {connection}")

@staticmethod
Expand Down

0 comments on commit 5e28e31

Please sign in to comment.