Add streaming support and tests for query-server. #1027

Merged
merged 8 commits, Nov 19, 2024
Changes from 7 commits
30 changes: 22 additions & 8 deletions apps/query-server/poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions apps/query-server/pyproject.toml
@@ -10,6 +10,7 @@ python = ">=3.9,<3.13"
sycamore-ai = { path = "../../lib/sycamore", develop = true, extras = ["opensearch", "local-inference"] }
fastapi = { version = "^0.115.0", extras = ["standard"]}
pydantic = "^2.9.2"
sse-starlette = "^2.1.3"

[tool.poetry.group.dev.dependencies]
black = "^24.4"
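The only dependency change here is the addition of sse-starlette, which supplies EventSourceResponse: it wraps an async generator so that each yielded dict is delivered to the client as one Server-Sent Events message. Below is a minimal sketch of that pattern, with an illustrative endpoint name and payload that are not part of this diff, assuming FastAPI and sse-starlette are installed:

import asyncio

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


@app.get("/demo/stream")  # illustrative endpoint name, not part of this PR
async def demo_stream() -> EventSourceResponse:
    async def event_generator():
        for step in range(3):
            # A dict with "event" and "data" keys becomes one SSE message on the wire.
            yield {"event": "status", "data": f"step {step}"}
            await asyncio.sleep(0.1)

    return EventSourceResponse(event_generator())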
202 changes: 180 additions & 22 deletions apps/query-server/queryserver/main.py
@@ -3,23 +3,28 @@
# Run with:
# poetry run fastapi dev queryserver/main.py

import asyncio
import logging
import os
import tempfile
from typing import Annotated, Any, List, Optional
import time
from typing import Annotated, Any, List, Optional, Union

from fastapi import FastAPI, Path
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from sycamore import DocSet
from sycamore.data import Document, MetadataDocument
from sycamore.query.client import SycamoreQueryClient
from sycamore.query.logical_plan import LogicalPlan
from sycamore.query.result import SycamoreQueryResult
from sycamore.query.schema import OpenSearchSchema

import queryserver.util as util
logger = logging.getLogger("uvicorn.error")


app = FastAPI()


# The query and LLM cache paths.
CACHE_PATH = os.getenv("QUERYSERVER_CACHE_PATH", os.path.join(tempfile.gettempdir(), "queryserver_cache"))
LLM_CACHE_PATH = os.getenv("QUERYSERVER_LLM_CACHE_PATH", os.path.join(tempfile.gettempdir(), "queryserver_llm_cache"))

@@ -30,33 +35,58 @@ class Index(BaseModel):
"""Represents an index that can be queried."""

index: str
"""The index name."""

description: Optional[str] = None
"""Description of the index."""

index_schema: OpenSearchSchema
"""The schema for this index."""


class Query(BaseModel):
"""Query an index with a given natural language query string."""
"""Query an index with a given natural language query string. One of 'query' or 'plan' must be provided."""

query: str
index: str
"""The index to query."""

query: Optional[str] = None
"""The natural language query to run. if specified, `plan` must not be set."""

plan: Optional[LogicalPlan] = None
"""The logical query plan to run. If specified, `query` must not be set."""

stream: bool = False
"""If true, query results will be streamed back to the client as they are generated."""


class QueryResult(BaseModel):
"""Result of a query."""
"""The result of a non-streaming query."""

query_id: str
"""The unique ID of the query operation."""

plan: LogicalPlan
"""The logical query plan that was executed."""

result: Any
retrieved_docs: list[str]
"""The result of the query operation. Depending on the query, this could be a list of documents,
a single document, a string, an integer, etc.
"""

retrieved_docs: List[str]
"""A list of document paths for the documents retrieved by the query."""


@app.get("/v1/indices")
async def list_indices() -> List[Index]:
"""List all available indices."""

retval = []
indices = util.get_opensearch_indices()
# Exclude the 'internal' indices that start with a dot.
indices = {x for x in sqclient.get_opensearch_indices() if not x.startswith(".")}
for index in indices:
index_schema = util.get_schema(sqclient, index)
index_schema = sqclient.get_opensearch_schema(index)
retval.append(Index(index=index, index_schema=index_schema))
return retval

@@ -67,29 +97,157 @@ async def get_index(
) -> Index:
"""Return details on the given index."""

schema = util.get_schema(sqclient, index)
schema = sqclient.get_opensearch_schema(index)
return Index(index=index, index_schema=schema)


@app.post("/v1/plan")
async def generate_plan(query: Query) -> LogicalPlan:
"""Generate a query plan for the given query, but do not run it."""

plan = sqclient.generate_plan(query.query, query.index, util.get_schema(sqclient, query.index))
if query.query is None:
raise ValueError("query is required")
if query.plan is not None:
raise ValueError("plan must not be specified")

plan = sqclient.generate_plan(query.query, query.index, sqclient.get_opensearch_schema(query.index))
return plan


@app.post("/v1/plan/run")
async def run_plan(plan: LogicalPlan) -> SycamoreQueryResult:
"""Run the provided query plan."""
def doc_to_json(doc: Document) -> Optional[dict[str, Any]]:
"""Render a Document as a JSON object. Returns None for MetadataDocuments."""
NUM_TEXT_CHARS_GENERATE = 1024

if isinstance(doc, MetadataDocument):
return None

props_dict = {}
props_dict.update(doc.properties)
if "_schema" in props_dict:
del props_dict["_schema"]
if "_schema_class" in props_dict:
del props_dict["_schema_class"]
if "_doc_source" in props_dict:
del props_dict["_doc_source"]
props_dict["text_representation"] = (
doc.text_representation[:NUM_TEXT_CHARS_GENERATE] if doc.text_representation is not None else None
)
return props_dict


async def run_query_stream(query: Query) -> EventSourceResponse:
"""Streaming version of run_query. Returns a stream of results as they are generated."""

async def query_runner():
try:
logger.info(f"Generating plan for {query.index}: {query.query}")
yield {
"event": "status",
"data": "Generating plan",
}
await asyncio.sleep(0.1)
plan = sqclient.generate_plan(query.query, query.index, sqclient.get_opensearch_schema(query.index))
logger.info(f"Generated plan: {plan}")
# Don't want to return these through the API.
plan.llm_plan = None
plan.llm_prompt = None
yield {
"event": "plan",
"data": plan.model_dump_json(),
}
await asyncio.sleep(0.1)
logger.info("Running plan")
yield {
"event": "status",
"data": "Running plan",
}
await asyncio.sleep(0.1)
sqresult = sqclient.run_plan(plan)
t1 = time.time()
num_results = 0
if isinstance(sqresult.result, DocSet):
logger.info("Got DocSet result")
for doc in sqresult.result.take_all():
rendered = doc_to_json(doc)
logger.debug(f"Doc: {rendered}")
if rendered is not None:
num_results += 1
yield {
"event": "result_doc",
"data": rendered,
}
await asyncio.sleep(0.1)
else:
num_results += 1
yield {
"event": "result",
"data": sqresult.result,
}
await asyncio.sleep(0.1)

for doc in sqresult.retrieved_docs():
yield {
"event": "retrieved_doc",
"data": doc,
}
await asyncio.sleep(0.1)

t2 = time.time()
logger.info(f"Finished query in {t2 - t1:.2f} seconds with {num_results} results")
yield {
"event": "status",
"data": f"Query complete - {num_results} results in {t2 - t1:.2f} seconds",
}
await asyncio.sleep(0.1)
except asyncio.CancelledError:
logger.info("Disconnected from client")

return EventSourceResponse(query_runner())
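# For reference, a client consuming this stream sees events in roughly this order
# (summarizing the generator above; payload contents depend on the query):
#   "status"         -- "Generating plan"
#   "plan"           -- the serialized LogicalPlan, with llm_plan/llm_prompt stripped
#   "result_doc" x N -- one per rendered document when the result is a DocSet,
#                       otherwise a single "result" event with the raw result value
#   "retrieved_doc"  -- one per retrieved document path
#   "status"         -- "Query complete - N results in S seconds"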


@app.post("/v1/query", response_model=None)
async def run_query(query: Query) -> Union[EventSourceResponse, QueryResult]:
"""Run the given query.

If the `stream` parameter is set to true, the result will be streamed back to the client as a series of SSE events.
Otherwise, the result will be returned as a QueryResult object.
"""

logger.info(f"Running query: {query}")

if query.query is None and query.plan is None:
raise ValueError("query or plan is required")
if query.query is not None and query.plan is not None:
raise ValueError("query and plan cannot both be specified")

if query.stream:
return await run_query_stream(query)

if query.plan is None:
assert query.query is not None
logger.info(f"Generating plan for {query.index}: {query.query}")
plan = sqclient.generate_plan(query.query, query.index, sqclient.get_opensearch_schema(query.index))
logger.info(f"Generated plan: {plan}")
else:
plan = query.plan

return sqclient.run_plan(plan)
sqresult = sqclient.run_plan(plan)
returned_plan = sqresult.plan

# Don't want to return these through the API.
returned_plan.llm_plan = None
returned_plan.llm_prompt = None

@app.post("/v1/query")
async def run_query(query: Query) -> QueryResult:
"""Generate a plan for the given query, run it, and return the result."""
query_result = QueryResult(query_id=sqresult.query_id, plan=returned_plan, result=[], retrieved_docs=[])

plan = sqclient.generate_plan(query.query, query.index, util.get_schema(sqclient, query.index))
sqresult = sqclient.run_plan(plan)
return QueryResult(plan=sqresult.plan, result=sqresult.result, retrieved_docs=sqresult.retrieved_docs())
if isinstance(sqresult.result, DocSet):
for doc in sqresult.result.take_all():
rendered = doc_to_json(doc)
logger.debug(f"Doc: {rendered}")
if rendered is not None:
query_result.result.append(rendered)
else:
query_result.result = sqresult.result

query_result.retrieved_docs = sqresult.retrieved_docs()
return query_result
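For completeness, a hedged client-side sketch of the new endpoint. The base URL assumes the server is running locally via `poetry run fastapi dev queryserver/main.py`, the index name and query text are invented, and httpx is used purely for illustration (it is not a dependency added by this PR):

import httpx

BASE_URL = "http://127.0.0.1:8000"  # assumed local dev server
payload = {"index": "my-index", "query": "Summarize the reports."}  # illustrative index/query

# Non-streaming: a single QueryResult JSON object comes back.
result = httpx.post(f"{BASE_URL}/v1/query", json=payload, timeout=None).json()
print(result["query_id"], len(result["retrieved_docs"]))

# Streaming: the same endpoint with "stream": true returns Server-Sent Events.
with httpx.stream("POST", f"{BASE_URL}/v1/query", json={**payload, "stream": True}, timeout=None) as r:
    for line in r.iter_lines():
        if line:  # SSE messages arrive as "event: ..." / "data: ..." lines separated by blanks
            print(line)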