Skip to content

Commit

Permalink
feat: add endpoint for policy (#271)
Browse files Browse the repository at this point in the history
Co-authored-by: Mahyar Ebadi <42976936+mahyareb@users.noreply.github.com>
  • Loading branch information
Yuan325 and mahyareb committed Mar 21, 2024
1 parent 199a67c commit 397ca79
Show file tree
Hide file tree
Showing 13 changed files with 1,863 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ def close_clients(self):
conversations and provide responses that are coherent and relevant to the topic at hand.
Assistant is a powerful tool that can help answer a wide range of questions pertaining to travel on Cymbal Air
as well as ammenities of San Francisco Airport."""

TOOLS_PREFIX = """
Expand Down
25 changes: 25 additions & 0 deletions llm_demo/orchestrator/langchain_tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ async def search_amenities(query: str):
return search_amenities


def generate_search_policies(client: aiohttp.ClientSession):
async def search_policies(query: str):
response = await client.get(
url=f"{BASE_URL}/policies/search",
params={"top_k": "5", "query": query},
headers=get_headers(client),
)

response = await response.json()
return response

return search_policies


class TicketInput(BaseModel):
airline: str = Field(description="Airline unique 2 letter identifier")
flight_number: str = Field(description="1 to 4 digit number")
Expand Down Expand Up @@ -330,6 +344,17 @@ async def initialize_tools(client: aiohttp.ClientSession):
""",
args_schema=QueryInput,
),
StructuredTool.from_function(
coroutine=generate_search_policies(client),
name="Search Policies",
description="""
Use this tool to search for cymbal air passenger policy.
Policy that are listed is unchangeable.
You will not answer any questions outside of the policy given.
Policy includes information on ticket purchase and changes, baggage, check-in and boarding, special assistance, overbooking, flight delays and cancellations.
""",
args_schema=QueryInput,
),
StructuredTool.from_function(
coroutine=generate_insert_ticket(client),
name="Insert Ticket",
Expand Down
17 changes: 17 additions & 0 deletions llm_demo/orchestrator/vertexai_function_calling/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@
},
)

search_policies_func = generative_models.FunctionDeclaration(
name="policies_search",
description="Use this tool to search for cymbal air passenger policy. Policy that are listed is unchangeable. You will not answer any questions outside of the policy given. Policy includes information on ticket purchase and changes, baggage, check-in and boarding, special assistance, overbooking, flight delays and cancellations. If top_k is not specified, default to 5.",
parameters={
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"},
"top_k": {
"type": "integer",
"description": "Number of matching policy to return. Default this value to 5.",
},
},
},
)

search_flights_by_number_func = generative_models.FunctionDeclaration(
name="search_flights_by_number",
description="""
Expand Down Expand Up @@ -224,6 +239,7 @@ def function_request(function_call_name: str) -> str:
"search_flights_by_number": "flights/search",
"list_flights": "flights/search",
"amenities_search": "amenities/search",
"policies_search": "policies/search",
"insert_ticket": "tickets/insert",
"list_tickets": "tickets/list",
}
Expand All @@ -235,6 +251,7 @@ def assistant_tool():
function_declarations=[
search_airports_func,
search_amenities_func,
search_policies_func,
search_flights_by_number_func,
list_flights_func,
insert_ticket_func,
Expand Down
44 changes: 44 additions & 0 deletions retrieval_service/app/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,3 +656,47 @@ def test_search_flights_with_bad_params(m_datastore, app, params):
with TestClient(app) as client:
response = client.get("/flights/search", params=params)
assert response.status_code == 422


policies_search_params = [
pytest.param(
"policies_search",
{
"query": "Additional fee for flight changes.",
"top_k": 1,
},
[
models.Policy(
id=1,
content="foo bar",
),
],
[
{
"id": 1,
"content": "foo bar",
"embedding": None,
},
],
)
]


@pytest.mark.parametrize(
"method_name, params, mock_return, expected", policies_search_params
)
@patch.object(datastore, "create")
def test_policies_search(m_datastore, app, method_name, params, mock_return, expected):
with TestClient(app) as client:
with patch.object(
m_datastore.return_value, method_name, AsyncMock(return_value=mock_return)
) as mock_method:
response = client.get(
"/policies/search",
params=params,
)
assert response.status_code == 200
output = response.json()
assert len(output) == params["top_k"]
assert output == expected
assert models.Policy.model_validate(output[0])
11 changes: 11 additions & 0 deletions retrieval_service/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,14 @@ async def list_tickets(
ds: datastore.Client = request.app.state.datastore
results = await ds.list_tickets(user_info["user_id"])
return results


@routes.get("/policies/search")
async def policies_search(query: str, top_k: int, request: Request):
ds: datastore.Client = request.app.state.datastore

embed_service: Embeddings = request.app.state.embed_service
query_embedding = embed_service.embed_query(query)

results = await ds.policies_search(query_embedding, 0.5, top_k)
return results
6 changes: 6 additions & 0 deletions retrieval_service/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,12 @@ async def list_tickets(
) -> list[models.Ticket]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> list[models.Policy]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def close(self):
pass
Expand Down
50 changes: 50 additions & 0 deletions retrieval_service/datastore/providers/cloudsql_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,27 @@ async def initialize_data(
],
)

# If the table already exists, drop it to avoid conflicts
await conn.execute(text("DROP TABLE IF EXISTS tickets CASCADE"))
# Create a new table
await conn.execute(
text(
"""
CREATE TABLE tickets(
user_id TEXT,
user_name TEXT,
user_email TEXT,
airline TEXT,
flight_number TEXT,
departure_airport TEXT,
arrival_airport TEXT,
departure_time TIMESTAMP,
arrival_time TIMESTAMP
)
"""
)
)

# If the table already exists, drop it to avoid conflicts
await conn.execute(text("DROP TABLE IF EXISTS policies CASCADE"))
# Create a new table
Expand Down Expand Up @@ -293,6 +314,9 @@ async def export_data(
policy_task = asyncio.create_task(
conn.execute(text("""SELECT * FROM policies"""))
)
policy_task = asyncio.create_task(
conn.execute(text("""SELECT * FROM policies"""))
)

airport_results = (await airport_task).mappings().fetchall()
amenity_results = (await amenity_task).mappings().fetchall()
Expand Down Expand Up @@ -485,5 +509,31 @@ async def list_tickets(
) -> list[models.Ticket]:
raise NotImplementedError("Not Implemented")

async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> list[models.Policy]:
async with self.__pool.connect() as conn:
s = text(
"""
SELECT id, content
FROM (
SELECT id, content, 1 - (embedding <=> :query_embedding) AS similarity
FROM policies
WHERE 1 - (embedding <=> :query_embedding) > :similarity_threshold
ORDER BY similarity DESC
LIMIT :top_k
) AS sorted_policies
"""
)
params = {
"query_embedding": query_embedding,
"similarity_threshold": similarity_threshold,
"top_k": top_k,
}
results = (await conn.execute(s, params)).mappings().fetchall()

res = [models.Policy.model_validate(r) for r in results]
return res

async def close(self):
await self.__pool.dispose()
73 changes: 69 additions & 4 deletions retrieval_service/datastore/providers/cloudsql_postgres_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@

from .. import datastore
from . import cloudsql_postgres
from .test_data import query_embedding1, query_embedding2, query_embedding3
from .test_data import (
amenities_query_embedding1,
amenities_query_embedding2,
foobar_query_embedding,
policies_query_embedding1,
policies_query_embedding2,
)
from .utils import get_env_var

pytestmark = pytest.mark.asyncio(scope="module")
Expand Down Expand Up @@ -336,7 +342,7 @@ async def test_get_amenity(ds: cloudsql_postgres.Client):
amenities_search_test_data = [
pytest.param(
# "Where can I get coffee near gate A6?"
query_embedding1,
amenities_query_embedding1,
0.65,
1,
[
Expand Down Expand Up @@ -370,7 +376,7 @@ async def test_get_amenity(ds: cloudsql_postgres.Client):
),
pytest.param(
# "Where can I look for luxury goods?"
query_embedding2,
amenities_query_embedding2,
0.65,
2,
[
Expand Down Expand Up @@ -429,7 +435,7 @@ async def test_get_amenity(ds: cloudsql_postgres.Client):
),
pytest.param(
# "FOO BAR"
query_embedding3,
foobar_query_embedding,
0.9,
1,
[],
Expand Down Expand Up @@ -650,3 +656,62 @@ async def test_search_flights_by_airports(
):
res = await ds.search_flights_by_airports(date, departure_airport, arrival_airport)
assert res == expected


policies_search_test_data = [
pytest.param(
# "What is the fee for extra baggage?"
policies_query_embedding1,
0.65,
1,
[
models.Policy(
id=4,
content="## Baggage\nChecked Baggage: Each passenger is allowed 2 checked baggage allowance. Business class and First class passengers are allowed 4 checked baggage. Additional baggage will cost $70 and a $30 fee applies for checked bags over 50 lbs. We don’t accept checked bags over 100 lbs. We only accept checked bags up to 115 inches in total dimensions (length + width + height), and oversized baggage will cost $30. Checked bags above 160 inches in total dimensions will not be accepted.",
embedding=None,
),
],
id="search_extra_baggage_fee",
),
pytest.param(
# "Can I change my flight?"
policies_query_embedding2,
0.65,
2,
[
models.Policy(
id=1,
content="Changes: Changes or reschedules to flights may be permitted depending on the fare type. Changes are permitted right after the ticket is confirmed. The fees for flight changes are $100 for Economy, $50 for Premium Economy, and free for Business Class and First class fares.",
embedding=None,
),
models.Policy(
id=0,
content="# Cymbal Air: Passenger Policy \n## Ticket Purchase and Changes\nTypes of Fares: Cymbal Air offers a variety of fares (Economy, Premium Economy, Business Class, and First Class). Fare restrictions, such as change fees and refundability, vary depending on the fare purchased.",
embedding=None,
),
],
id="search_flight_delays",
),
pytest.param(
# "FOO BAR"
foobar_query_embedding,
0.65,
1,
[],
id="no_results",
),
]


@pytest.mark.parametrize(
"query_embedding, similarity_threshold, top_k, expected", policies_search_test_data
)
async def test_policies_search(
ds: cloudsql_postgres.Client,
query_embedding: List[float],
similarity_threshold: float,
top_k: int,
expected: List[models.Policy],
):
res = await ds.policies_search(query_embedding, similarity_threshold, top_k)
assert res == expected
6 changes: 5 additions & 1 deletion retrieval_service/datastore/providers/firestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ async def export_data(
policies_docs = self.__client.collection("policies").stream()

airports = []

async for doc in airport_docs:
airport_dict = doc.to_dict()
airport_dict["id"] = doc.id
Expand Down Expand Up @@ -310,5 +309,10 @@ async def list_tickets(
) -> list[models.Ticket]:
raise NotImplementedError("Not Implemented")

async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> list[models.Policy]:
raise NotImplementedError("Semantic search not yet supported in Firestore.")

async def close(self):
self.__client.close()
23 changes: 23 additions & 0 deletions retrieval_service/datastore/providers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,5 +519,28 @@ async def list_tickets(
results = [models.Ticket.model_validate(dict(r)) for r in results]
return results

async def policies_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> list[models.Policy]:
results = await self.__pool.fetch(
"""
SELECT id, content
FROM (
SELECT id, content, 1 - (embedding <=> $1) AS similarity
FROM policies
WHERE 1 - (embedding <=> $1) > $2
ORDER BY similarity DESC
LIMIT $3
) AS sorted_policies
""",
query_embedding,
similarity_threshold,
top_k,
timeout=10,
)

results = [models.Policy.model_validate(dict(r)) for r in results]
return results

async def close(self):
await self.__pool.close()
Loading

0 comments on commit 397ca79

Please sign in to comment.