Skip to content

Commit

Permalink
Add streaming support and tests for query-server. (#1027)
Browse files Browse the repository at this point in the history
* Adding streaming support to queryserver.

* Add unit test.

* Linting.

* Fix typing.

* Update dependencies and logging.

* Fix poetry lock.

* PR comments.
  • Loading branch information
mdwelsh authored Nov 19, 2024
1 parent 9cd2a13 commit 70f25a7
Show file tree
Hide file tree
Showing 7 changed files with 873 additions and 71 deletions.
30 changes: 22 additions & 8 deletions apps/query-server/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions apps/query-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ python = ">=3.9,<3.13"
sycamore-ai = { path = "../../lib/sycamore", develop = true, extras = ["opensearch", "local-inference"] }
fastapi = { version = "^0.115.0", extras = ["standard"]}
pydantic = "^2.9.2"
sse-starlette = "^2.1.3"

[tool.poetry.group.dev.dependencies]
black = "^24.4"
Expand Down
210 changes: 187 additions & 23 deletions apps/query-server/queryserver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,30 @@
# Run with:
# poetry run fastapi dev queryserver/main.py

import asyncio
import logging
import os
import tempfile
from typing import Annotated, Any, List, Optional
import time
from typing import Annotated, Any, List, Optional, Union

from fastapi import FastAPI, Path
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from sse_starlette.sse import EventSourceResponse
from sycamore import DocSet
from sycamore.data import Document, MetadataDocument
from sycamore.query.client import SycamoreQueryClient
from sycamore.query.logical_plan import LogicalPlan
from sycamore.query.result import SycamoreQueryResult
from sycamore.query.schema import OpenSearchSchema

import queryserver.util as util
# This is uvicorn's general-purpose logger; the "error" in its name is misleading.
# https://github.com/encode/uvicorn/issues/562
logger = logging.getLogger("uvicorn.error")


app = FastAPI()


# The query and LLM cache paths.
CACHE_PATH = os.getenv("QUERYSERVER_CACHE_PATH", os.path.join(tempfile.gettempdir(), "queryserver_cache"))
LLM_CACHE_PATH = os.getenv("QUERYSERVER_LLM_CACHE_PATH", os.path.join(tempfile.gettempdir(), "queryserver_llm_cache"))

Expand All @@ -30,33 +37,66 @@ class Index(BaseModel):
"""Represents an index that can be queried."""

index: str
"""The index name."""

description: Optional[str] = None
"""Description of the index."""

index_schema: OpenSearchSchema
"""The schema for this index."""


class Query(BaseModel):
"""Query an index with a given natural language query string."""
"""Query an index with a given natural language query string. One of 'query' or 'plan' must be provided."""

query: str
index: str
"""The index to query."""

query: Optional[str] = None
"""The natural language query to run. If specified, `plan` must not be set."""

plan: Optional[LogicalPlan] = None
"""The logical query plan to run. If specified, `query` must not be set."""

stream: bool = False
"""If true, query results will be streamed back to the client as they are generated."""

@model_validator(mode="after")
def check_not_both_query_and_plan(self):
if self.query is not None and self.plan is not None:
raise ValueError("query and plan cannot both be specified")
if self.query is None and self.plan is None:
raise ValueError("one of query or plan is required")
return self


class QueryResult(BaseModel):
"""Result of a query."""
"""The result of a non-streaming query."""

query_id: str
"""The unique ID of the query operation."""

plan: LogicalPlan
"""The logical query plan that was executed."""

result: Any
retrieved_docs: list[str]
"""The result of the query operation. Depending on the query, this could be a list of documents,
a single document, a string, an integer, etc.
"""

retrieved_docs: List[str]
"""A list of document paths for the documents retrieved by the query."""


@app.get("/v1/indices")
async def list_indices() -> List[Index]:
"""List all available indices."""

retval = []
indices = util.get_opensearch_indices()
# Exclude the 'internal' indices that start with a dot.
indices = {x for x in sqclient.get_opensearch_indices() if not x.startswith(".")}
for index in indices:
index_schema = util.get_schema(sqclient, index)
index_schema = sqclient.get_opensearch_schema(index)
retval.append(Index(index=index, index_schema=index_schema))
return retval

Expand All @@ -67,29 +107,153 @@ async def get_index(
) -> Index:
"""Return details on the given index."""

schema = util.get_schema(sqclient, index)
schema = sqclient.get_opensearch_schema(index)
return Index(index=index, index_schema=schema)


@app.post("/v1/plan")
async def generate_plan(query: Query) -> LogicalPlan:
"""Generate a query plan for the given query, but do not run it."""

plan = sqclient.generate_plan(query.query, query.index, util.get_schema(sqclient, query.index))
if query.query is None:
raise ValueError("query is required")
if query.plan is not None:
raise ValueError("plan must not be specified")

plan = sqclient.generate_plan(query.query, query.index, sqclient.get_opensearch_schema(query.index))
return plan


@app.post("/v1/plan/run")
async def run_plan(plan: LogicalPlan) -> SycamoreQueryResult:
"""Run the provided query plan."""
def doc_to_json(doc: Document) -> Optional[dict[str, Any]]:
    """Render a Document as a JSON-style dict for API responses.

    The returned dict is a copy of the document's properties with the internal
    ``_schema``, ``_schema_class`` and ``_doc_source`` entries removed, plus a
    ``text_representation`` entry holding at most the first 1024 characters of
    the document's text (``None`` when the document has no text).

    Returns None for MetadataDocuments.
    """

    # Metadata documents are not rendered at all.
    if isinstance(doc, MetadataDocument):
        return None

    max_text_chars = 1024
    internal_keys = ("_schema", "_schema_class", "_doc_source")

    # Work on a copy so the document's own properties are left untouched.
    rendered = dict(doc.properties)
    for key in internal_keys:
        rendered.pop(key, None)

    text = doc.text_representation
    rendered["text_representation"] = None if text is None else text[:max_text_chars]
    return rendered


async def run_query_stream(query: Query) -> EventSourceResponse:
    """Streaming version of run_query. Returns a stream of results as they are generated.

    The response is a server-sent event stream produced by an inner async
    generator. Events are emitted in this order:

    * ``status`` - "Generating plan" (only when no plan was supplied).
    * ``plan`` - the generated plan as JSON (only when no plan was supplied).
    * ``status`` - "Running plan".
    * ``result_doc`` - one per rendered document when the result is a DocSet,
      otherwise a single ``result`` event carrying the raw result.
    * ``retrieved_doc`` - one per entry returned by ``retrieved_docs()``.
    * ``status`` - a completion message with the result count and elapsed time.

    If the stream is cancelled, the cancellation is logged and then re-raised
    so the enclosing response can finish shutting down.
    """

    async def query_runner():
        try:
            plan = query.plan
            if plan is None:
                logger.info(f"Generating plan for {query.index}: {query.query}")
                yield {
                    "event": "status",
                    "data": "Generating plan",
                }
                # The short sleep after each event hands control back to the event loop.
                await asyncio.sleep(0.1)
                plan = sqclient.generate_plan(query.query, query.index, sqclient.get_opensearch_schema(query.index))
                logger.info(f"Generated plan: {plan}")
                # Don't want to return these through the API.
                plan.llm_plan = None
                plan.llm_prompt = None
                yield {
                    "event": "plan",
                    "data": plan.model_dump_json(),
                }
                await asyncio.sleep(0.1)

            logger.info("Running plan")
            yield {
                "event": "status",
                "data": "Running plan",
            }
            await asyncio.sleep(0.1)
            sqresult = sqclient.run_plan(plan)

            # The reported elapsed time covers result retrieval only, not plan generation.
            t1 = time.time()
            num_results = 0
            if isinstance(sqresult.result, DocSet):
                logger.info("Got DocSet result")
                for doc in sqresult.result.take_all():
                    rendered = doc_to_json(doc)
                    logger.debug(f"Doc: {rendered}")
                    # doc_to_json returns None for metadata documents; skip those.
                    if rendered is not None:
                        num_results += 1
                        yield {
                            "event": "result_doc",
                            "data": rendered,
                        }
                        await asyncio.sleep(0.1)
            else:
                num_results += 1
                yield {
                    "event": "result",
                    "data": sqresult.result,
                }
                await asyncio.sleep(0.1)

            for doc in sqresult.retrieved_docs():
                yield {
                    "event": "retrieved_doc",
                    "data": doc,
                }
                await asyncio.sleep(0.1)

            t2 = time.time()
            logger.info(f"Finished query in {t2 - t1:.2f} seconds with {num_results} results")
            yield {
                "event": "status",
                "data": f"Query complete - {num_results} results in {t2 - t1:.2f} seconds",
            }
            await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            logger.info("Disconnected from client")
            # Re-raise so the cancellation reaches the SSE response; swallowing it
            # would hide the cancellation from the task that requested it.
            raise

    return EventSourceResponse(query_runner())


@app.post("/v1/query", response_model=None)
async def run_query(query: Query) -> Union[EventSourceResponse, QueryResult]:
"""Run the given query.
If the `stream` parameter is set to true, the result will be streamed back to the client as a series of SSE events.
Otherwise, the result will be returned as a QueryResult object.
"""

logger.info(f"Running query: {query}")

if query.stream:
return await run_query_stream(query)

if query.plan is None:
assert query.query is not None
logger.info(f"Generating plan for {query.index}: {query.query}")
plan = sqclient.generate_plan(query.query, query.index, sqclient.get_opensearch_schema(query.index))
logger.info(f"Generated plan: {plan}")
else:
plan = query.plan

return sqclient.run_plan(plan)
sqresult = sqclient.run_plan(plan)
returned_plan = sqresult.plan

# Don't want to return these through the API.
returned_plan.llm_plan = None
returned_plan.llm_prompt = None

@app.post("/v1/query")
async def run_query(query: Query) -> QueryResult:
"""Generate a plan for the given query, run it, and return the result."""
query_result = QueryResult(query_id=sqresult.query_id, plan=returned_plan, result=[], retrieved_docs=[])

plan = sqclient.generate_plan(query.query, query.index, util.get_schema(sqclient, query.index))
sqresult = sqclient.run_plan(plan)
return QueryResult(plan=sqresult.plan, result=sqresult.result, retrieved_docs=sqresult.retrieved_docs())
if isinstance(sqresult.result, DocSet):
for doc in sqresult.result.take_all():
rendered = doc_to_json(doc)
logger.debug(f"Doc: {rendered}")
if rendered is not None:
query_result.result.append(rendered)
else:
query_result.result = sqresult.result

query_result.retrieved_docs = sqresult.retrieved_docs()
return query_result
Loading

0 comments on commit 70f25a7

Please sign in to comment.