From 8423fd9d3d96c51cf7d4a34ca5a6dfe495e7ea84 Mon Sep 17 00:00:00 2001 From: Razvan Dinu Date: Fri, 16 Feb 2024 16:53:01 +0200 Subject: [PATCH] Fix #320. Reuse the asyncio loop between sync calls (since other components might cache it). --- nemoguardrails/actions/llm/generation.py | 3 ++- nemoguardrails/colang/v2_x/runtime/statemachine.py | 6 +++--- nemoguardrails/rails/llm/llmrails.py | 12 ++++++++---- tests/test_nest_asyncio.py | 9 ++++++++- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 80d73fe8e..f2f5efa6e 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -100,12 +100,13 @@ def __init__( # There are still some edge cases not covered by nest_asyncio. # Using a separate thread always for now. + loop = asyncio.get_event_loop() if True or check_sync_call_from_async_loop(): t = threading.Thread(target=asyncio.run, args=(self.init(),)) t.start() t.join() else: - asyncio.run(self.init()) + loop.run_until_complete(self.init()) self.llm_task_manager = llm_task_manager diff --git a/nemoguardrails/colang/v2_x/runtime/statemachine.py b/nemoguardrails/colang/v2_x/runtime/statemachine.py index 864da5755..ed492d59a 100644 --- a/nemoguardrails/colang/v2_x/runtime/statemachine.py +++ b/nemoguardrails/colang/v2_x/runtime/statemachine.py @@ -698,9 +698,9 @@ def _resolve_action_conflicts( index = competing_flow_state.action_uids.index( competing_event.action_uid ) - competing_flow_state.action_uids[index] = ( - winning_event.action_uid - ) + competing_flow_state.action_uids[ + index + ] = winning_event.action_uid del state.actions[competing_event.action_uid] advancing_heads.append(head) diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index d41af9aa1..f9b0d0149 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -222,12 +222,13 @@ def __init__( # Next, we initialize the Knowledge Base # There are still some edge cases not covered by nest_asyncio. # Using a separate thread always for now. + loop = asyncio.get_event_loop() if True or check_sync_call_from_async_loop(): t = threading.Thread(target=asyncio.run, args=(self._init_kb(),)) t.start() t.join() else: - asyncio.run(self._init_kb()) + loop.run_until_complete(self._init_kb()) # We also register the kb as a parameter that can be passed to actions. self.runtime.register_action_param("kb", self.kb) @@ -724,7 +725,8 @@ def generate( "You should replace with `await generate_async(...)` or use `nest_asyncio.apply()`." ) - return asyncio.run( + loop = asyncio.get_event_loop() + return loop.run_until_complete( self.generate_async( prompt=prompt, messages=messages, @@ -788,7 +790,8 @@ def generate_events( "You should replace with `await generate_events_async(...)` or use `nest_asyncio.apply()`." ) - return asyncio.run(self.generate_events_async(events=events)) + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.generate_events_async(events=events)) async def process_events_async( self, events: List[dict], state: Optional[dict] = None @@ -835,7 +838,8 @@ def process_events( "You should replace with `await generate_events_async(...)." ) - return asyncio.run(self.process_events_async(events, state)) + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.process_events_async(events, state)) def register_action(self, action: callable, name: Optional[str] = None): """Register a custom action for the rails configuration.""" diff --git a/tests/test_nest_asyncio.py b/tests/test_nest_asyncio.py index 66c21ddd6..66e933e62 100644 --- a/tests/test_nest_asyncio.py +++ b/tests/test_nest_asyncio.py @@ -54,11 +54,18 @@ async def test_async_api_error(monkeypatch): # Reload the module to re-run its top-level code with the new env var importlib.reload(nemoguardrails) + importlib.reload(nemoguardrails.patch_asyncio) importlib.reload(asyncio) + # Remove the patching marker + delattr(asyncio, "_nest_patched") + + assert nemoguardrails.patch_asyncio.nest_asyncio_patch_applied is False + assert not hasattr(asyncio, "_nest_patched") + with pytest.raises( RuntimeError, - match=r"asyncio.run\(\) cannot be called from a running event loop", + match=r"await generate_async", ): chat >> "Hi!" chat << "Hello there!"