Skip to content

Commit

Permalink
fix rag stream (#895)
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty authored Aug 20, 2024
1 parent 289ccfb commit a1b6497
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 23 deletions.
38 changes: 31 additions & 7 deletions py/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .models import R2RException
from .restructure import RestructureMethods
from .retrieval import RetrievalMethods

from typing import AsyncGenerator, Generator
nest_asyncio.apply()

# The empty args become necessary after a recent modification to `base_endpoint`
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(
base_url: str = "http://localhost:8000",
prefix: str = "/v1",
custom_client=None,
timeout: float = 60.0,
timeout: float = 300.0,
):
self.base_url = base_url
self.prefix = prefix
Expand Down Expand Up @@ -136,6 +136,19 @@ async def _make_request(self, method, endpoint, **kwargs):
status_code=500, message=f"Request failed: {str(e)}"
)


async def _make_streaming_request(self, method: str, endpoint: str, **kwargs) -> AsyncGenerator[str, None]:
url = f"{self.base_url}{self.prefix}/{endpoint}"
headers = kwargs.pop("headers", {})
if self.access_token and endpoint not in ["register", "login", "verify_email"]:
headers.update(self._get_auth_header())

async with httpx.AsyncClient() as client:
async with client.stream(method, url, headers=headers, timeout=self.timeout, **kwargs) as response:
handle_request_error(response)
async for chunk in response.aiter_text():
yield chunk

def _get_auth_header(self) -> dict:
if not self.access_token:
return {}
Expand Down Expand Up @@ -175,15 +188,26 @@ class R2RClient:
def __init__(self, *args, **kwargs):
self.async_client = R2RAsyncClient(*args, **kwargs)

def _sync_generator(self, async_gen: AsyncGenerator) -> Generator:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

try:
while True:
yield loop.run_until_complete(async_gen.__anext__())
except StopAsyncIteration:
pass
finally:
loop.close()

def __getattr__(self, name):
async_attr = getattr(self.async_client, name)
if callable(async_attr):

def sync_wrapper(*args, **kwargs):
return asyncio.get_event_loop().run_until_complete(
async_attr(*args, **kwargs)
)

result = asyncio.get_event_loop().run_until_complete(async_attr(*args, **kwargs))
if isinstance(result, AsyncGenerator):
return self._sync_generator(result)
return result
return sync_wrapper
return async_attr

Expand Down
40 changes: 24 additions & 16 deletions py/sdk/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,14 @@ async def rag(
}

if rag_generation_config.stream:

async def stream_response():
async for chunk in await client._make_request(
"POST", "rag", json=data, stream=True
):
yield RAGResponse(**chunk)

return stream_response()
return client._make_streaming_request(
"POST",
"rag",
json=data
)
else:
return await client._make_request("POST", "rag", json=data)


@staticmethod
async def agent(
Expand Down Expand Up @@ -172,13 +170,23 @@ async def agent(
}

if rag_generation_config.stream:
return client._make_streaming_request(
"POST",
"agent",
json=data
)
else:
return await client._make_request("POST", "rag", json=data)

async def stream_response():
async for chunk in await client._make_request(
"POST", "agent", json=data, stream=True
):
yield Message(**chunk)

return stream_response()
else:
return await client._make_request("POST", "agent", json=data)
# if rag_generation_config.stream:

# async def stream_response():
# async for chunk in await client._make_request(
# "POST", "agent", json=data, stream=True
# ):
# yield Message(**chunk)

# return stream_response()
# else:
# return await client._make_request("POST", "agent", json=data)

0 comments on commit a1b6497

Please sign in to comment.