From 60fc566e2cdff74992e48c630f037daa5aff0fac Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Fri, 14 Jun 2024 22:35:38 +0000 Subject: [PATCH 01/15] create mysql provider & test --- .../datastore/providers/cloudsql_mysql.py | 541 +++++++++++++++ .../providers/cloudsql_mysql_test.py | 642 ++++++++++++++++++ 2 files changed, 1183 insertions(+) create mode 100644 retrieval_service/datastore/providers/cloudsql_mysql.py create mode 100644 retrieval_service/datastore/providers/cloudsql_mysql_test.py diff --git a/retrieval_service/datastore/providers/cloudsql_mysql.py b/retrieval_service/datastore/providers/cloudsql_mysql.py new file mode 100644 index 00000000..675f68ce --- /dev/null +++ b/retrieval_service/datastore/providers/cloudsql_mysql.py @@ -0,0 +1,541 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from datetime import datetime +from typing import Any, Literal, Optional + +from google.cloud.sql.connector import Connector +from pydantic import BaseModel +from sqlalchemy import text, create_engine, Engine +from sqlalchemy.engine.base import Engine + +import pymysql +import models + +from .. 
import datastore + +MYSQL_IDENTIFIER = "cloudsql-mysql" + +class Config(BaseModel, datastore.AbstractConfig): + kind: Literal["cloudsql-mysql"] + project: str + region: str + instance: str + user: str + password: str + database: str + + +class Client(datastore.Client[Config]): + __pool: Engine + + @datastore.classproperty + def kind(cls): + return "cloudsql-mysql" + + def __init__(self, pool: Engine): + self.__pool = pool + + @classmethod + async def create(cls, config: Config) -> "Client": + loop = asyncio.get_running_loop() + + def getconn() -> pymysql.Connection: + with Connector() as connector: + conn: pymysql.Connection = connector.connect( + # Cloud SQL instance connection name + "juliaofferman-playground:us-central1:my-vector-app", + "pymysql", + user="mysql", + password="my-cloudsql-pass", + db="assistantdemo", + autocommit=True, + ) + return conn + + pool = create_engine( + "mysql+pymysql://", + creator=getconn, + ) + if pool is None: + raise TypeError("pool not instantiated") + return cls(pool) + + async def initialize_data( + self, + airports: list[models.Airport], + amenities: list[models.Amenity], + flights: list[models.Flight], + policies: list[models.Policy], + ) -> None: + with self.__pool.connect() as conn: + # If the table already exists, drop it to avoid conflicts + conn.execute(text("DROP TABLE IF EXISTS airports")) + # Create a new table + conn.execute( + text( + """ + CREATE TABLE airports( + id INT PRIMARY KEY, + iata TEXT, + name TEXT, + city TEXT, + country TEXT + ) + """ + ) + ) + # Insert all the data + conn.execute( + text( + """INSERT INTO airports VALUES (:id, :iata, :name, :city, :country)""" + ),parameters=[{ + "id": a.id, + "iata": a.iata, + "name": a.name, + "city": a.city, + "country": a.country, + } for a in airports] + ) + + # If the table already exists, drop it to avoid conflicts + conn.execute(text("DROP TABLE IF EXISTS amenities CASCADE")) + + # Create a new table + conn.execute( + text( + """ + CREATE TABLE amenities( + id 
INT PRIMARY KEY, + name TEXT, + description TEXT, + location TEXT, + terminal TEXT, + category TEXT, + hour TEXT, + sunday_start_hour TIME, + sunday_end_hour TIME, + monday_start_hour TIME, + monday_end_hour TIME, + tuesday_start_hour TIME, + tuesday_end_hour TIME, + wednesday_start_hour TIME, + wednesday_end_hour TIME, + thursday_start_hour TIME, + thursday_end_hour TIME, + friday_start_hour TIME, + friday_end_hour TIME, + saturday_start_hour TIME, + saturday_end_hour TIME, + content TEXT NOT NULL, + embedding vector(768) USING VARBINARY NOT NULL + ) + """ + ) + ) + + # Insert all the data + conn.execute( + text( + """ + INSERT INTO amenities VALUES (:id, :name, :description, :location, + :terminal, :category, :hour, :sunday_start_hour, :sunday_end_hour, + :monday_start_hour, :monday_end_hour, :tuesday_start_hour, + :tuesday_end_hour, :wednesday_start_hour, :wednesday_end_hour, + :thursday_start_hour, :thursday_end_hour, :friday_start_hour, + :friday_end_hour, :saturday_start_hour, :saturday_end_hour, :content, string_to_vector(:embedding)) + """ + ),parameters=[{ + "id": a.id, + "name": a.name, + "description": a.description, + "location": a.location, + "terminal": a.terminal, + "category": a.category, + "hour": a.hour, + "sunday_start_hour": a.sunday_start_hour, + "sunday_end_hour": a.sunday_end_hour, + "monday_start_hour": a.monday_start_hour, + "monday_end_hour": a.monday_end_hour, + "tuesday_start_hour": a.tuesday_start_hour, + "tuesday_end_hour": a.tuesday_end_hour, + "wednesday_start_hour": a.wednesday_start_hour, + "wednesday_end_hour": a.wednesday_end_hour, + "thursday_start_hour": a.thursday_start_hour, + "thursday_end_hour": a.thursday_end_hour, + "friday_start_hour": a.friday_start_hour, + "friday_end_hour": a.friday_end_hour, + "saturday_start_hour": a.saturday_start_hour, + "saturday_end_hour": a.saturday_end_hour, + "content": a.content, + "embedding": f"{a.embedding}", + } for a in amenities] + ) + + # Create a vector index for the embeddings 
column + conn.execute(text("CALL mysql.create_vector_index('amenities_index', 'assistantdemo.amenities', 'embedding', '')")) + + # If the table already exists, drop it to avoid conflicts + conn.execute(text("DROP TABLE IF EXISTS flights")) + # Create a new table + conn.execute( + text( + """ + CREATE TABLE flights( + id INTEGER PRIMARY KEY, + airline TEXT, + flight_number TEXT, + departure_airport TEXT, + arrival_airport TEXT, + departure_time TIMESTAMP, + arrival_time TIMESTAMP, + departure_gate TEXT, + arrival_gate TEXT + ) + """ + ) + ) + # Insert all the data + conn.execute( + text( + """ + INSERT INTO flights VALUES (:id, :airline, :flight_number, + :departure_airport, :arrival_airport, :departure_time, + :arrival_time, :departure_gate, :arrival_gate) + """ + ),parameters=[{ + "id": f.id, + "airline": f.airline, + "flight_number": f.flight_number, + "departure_airport": f.departure_airport, + "arrival_airport": f.arrival_airport, + "departure_time": f.departure_time, + "arrival_time": f.arrival_time, + "departure_gate": f.departure_gate, + "arrival_gate": f.arrival_gate, + } for f in flights] + ) + + # If the table already exists, drop it to avoid conflicts + conn.execute(text("DROP TABLE IF EXISTS tickets")) + # Create a new table + conn.execute( + text( + """ + CREATE TABLE tickets( + user_id TEXT, + user_name TEXT, + user_email TEXT, + airline TEXT, + flight_number TEXT, + departure_airport TEXT, + arrival_airport TEXT, + departure_time TIMESTAMP, + arrival_time TIMESTAMP + ) + """ + ) + ) + + # If the table already exists, drop it to avoid conflicts + conn.execute(text("DROP TABLE IF EXISTS policies")) + # Create a new table + conn.execute( + text( + """ + CREATE TABLE policies( + id INT PRIMARY KEY, + content TEXT NOT NULL, + embedding vector(768) USING VARBINARY NOT NULL + ) + """ + ) + ) + # Insert all the data + conn.execute( + text( + """ + INSERT INTO policies VALUES (:id, :content, string_to_vector(:embedding)) + """ + ),parameters=[{ + "id": p.id, 
+ "content": p.content, + "embedding": f"{p.embedding}", + } for p in policies]) + + # Create a vector index on the embedding column + conn.execute(text("CALL mysql.create_vector_index('policies_index', 'assistantdemo.policies', 'embedding', '')")) + + async def export_data( + self, + ) -> tuple[ + list[models.Airport], + list[models.Amenity], + list[models.Flight], + list[models.Policy], + ]: + with self.__pool.connect() as conn: + airport_task = conn.execute(text("""SELECT * FROM airports ORDER BY id ASC""")) + amenity_task = conn.execute(text(""" + SELECT id, + name, + description, + location, + terminal, + category, + hour, + DATE_FORMAT(sunday_start_hour, '%H:%i') AS sunday_start_hour, + DATE_FORMAT(sunday_end_hour, '%H:%i') AS sunday_end_hour, + DATE_FORMAT(monday_start_hour, '%H:%i') AS monday_start_hour, + DATE_FORMAT(monday_end_hour, '%H:%i') AS monday_end_hour, + DATE_FORMAT(tuesday_start_hour, '%H:%i') AS tuesday_start_hour, + DATE_FORMAT(tuesday_end_hour, '%H:%i') AS tuesday_end_hour, + DATE_FORMAT(wednesday_start_hour, '%H:%i') AS wednesday_start_hour, + DATE_FORMAT(wednesday_end_hour, '%H:%i') AS wednesday_end_hour, + DATE_FORMAT(thursday_start_hour, '%H:%i') AS thursday_start_hour, + DATE_FORMAT(thursday_end_hour, '%H:%i') AS thursday_end_hour, + DATE_FORMAT(friday_start_hour, '%H:%i') AS friday_start_hour, + DATE_FORMAT(friday_end_hour, '%H:%i') AS friday_end_hour, + DATE_FORMAT(saturday_start_hour, '%H:%i') AS saturday_start_hour, + DATE_FORMAT(saturday_end_hour, '%H:%i') AS saturday_end_hour, + content, + vector_to_string(embedding) as embedding + FROM amenities ORDER BY id ASC + """)) + flights_task = conn.execute(text("""SELECT * FROM flights ORDER BY id ASC""")) + policy_task = conn.execute(text("""SELECT id, content, vector_to_string(embedding) as embedding FROM policies ORDER BY id ASC""")) + + airport_results = (airport_task).mappings().fetchall() + amenity_results = (amenity_task).mappings().fetchall() + flights_results = 
(flights_task).mappings().fetchall() + policy_results = (policy_task).mappings().fetchall() + + airports = [models.Airport.model_validate(a) for a in airport_results] + amenities = [models.Amenity.model_validate(a) for a in amenity_results] + flights = [models.Flight.model_validate(f) for f in flights_results] + policies = [models.Policy.model_validate(p) for p in policy_results] + + return airports, amenities, flights, policies + + async def get_airport_by_id(self, id: int) -> Optional[models.Airport]: + with self.__pool.connect() as conn: + s = text("""SELECT * FROM airports WHERE id=:id""") + params = {"id" : id} + result = (conn.execute(s, params)).mappings().fetchone() + + if result is None: + return None + + res = models.Airport.model_validate(result) + return res + + async def get_airport_by_iata(self, iata: str) -> Optional[models.Airport]: + with self.__pool.connect() as conn: + s = text("""SELECT * FROM airports WHERE LOWER(iata) LIKE LOWER(:iata)""") + params = {"iata": iata} + result = (conn.execute(s, params)).mappings().fetchone() + + if result is None: + return None + + res = models.Airport.model_validate(result) + return res + + async def search_airports( + self, + country: Optional[str] = None, + city: Optional[str] = None, + name: Optional[str] = None, + ) -> list[models.Airport]: + with self.__pool.connect() as conn: + s = text( + """ + SELECT * FROM airports + WHERE (:country IS NULL OR LOWER(country) LIKE CONCAT('%', LOWER(:country), '%')) + AND (:city IS NULL OR LOWER(city) LIKE CONCAT('%', LOWER(:city), '%')) + AND (:name IS NULL OR LOWER(name) LIKE CONCAT('%', LOWER(:name), '%')) + LIMIT 10; + """ + ) + params = { + "country": country, + "city": city, + "name": name, + } + results = (conn.execute(s, parameters=params)).mappings().fetchall() + + res = [models.Airport.model_validate(r) for r in results] + return res + + async def get_amenity(self, id: int) -> Optional[models.Amenity]: + with self.__pool.connect() as conn: + s = text( + """ + 
SELECT id, name, description, location, terminal, category, hour + FROM amenities WHERE id=:id + """ + ) + params = {"id" : id} + result = (conn.execute(s, parameters=params)).mappings().fetchone() + + if result is None: + return None + + res = models.Amenity.model_validate(result) + return res + + async def amenities_search( + self, query_embedding: list[float], similarity_threshold: float, top_k: int + ) -> list[Any]: + with self.__pool.connect() as conn: + s = text( + """ + SELECT name, description, location, terminal, category, hour + FROM amenities + WHERE NEAREST(embedding) TO (string_to_vector(:query), :search_options) + """ + ) + params = { + "query": f"{query_embedding}", + "search_options": f"num_neighbors={top_k}" + } + results = (conn.execute(s, parameters=params)).mappings().fetchall() + + res = [r for r in results] + return res + + async def get_flight(self, flight_id: int) -> Optional[models.Flight]: + with self.__pool.connect() as conn: + s = text( + """ + SELECT * FROM flights + WHERE id = :flight_id + """ + ) + params = {"flight_id": flight_id} + result = (conn.execute(s, parameters=params)).mappings().fetchone() + + if result is None: + return None + + res = models.Flight.model_validate(result) + return res + + async def search_flights_by_number( + self, + airline: str, + number: str, + ) -> list[models.Flight]: + with self.__pool.connect() as conn: + s = text( + """ + SELECT * FROM flights + WHERE airline = :airline + AND flight_number = :number + LIMIT 10 + """ + ) + params = { + "airline": airline, + "number": number, + } + results = (conn.execute(s, parameters=params)).mappings().fetchall() + + res = [models.Flight.model_validate(r) for r in results] + return res + + async def search_flights_by_airports( + self, + date: str, + departure_airport: Optional[str] = None, + arrival_airport: Optional[str] = None, + ) -> list[models.Flight]: + with self.__pool.connect() as conn: + s = text( + """ + SELECT * FROM flights + WHERE 
(CAST(:departure_airport AS CHAR(255)) IS NULL OR LOWER(departure_airport) LIKE LOWER(:departure_airport)) + AND (CAST(:arrival_airport AS CHAR(255)) IS NULL OR LOWER(arrival_airport) LIKE LOWER(:arrival_airport)) + AND departure_time >= CAST(:datetime AS DATETIME) + AND (departure_time < DATE_ADD(CAST(:datetime AS DATETIME), interval 1 day)) + LIMIT 10 + """ + ) + params = { + "departure_airport": departure_airport, + "arrival_airport": arrival_airport, + "datetime": datetime.strptime(date, "%Y-%m-%d"), + } + + results = (conn.execute(s, parameters=params)).mappings().fetchall() + + res = [models.Flight.model_validate(r) for r in results] + return res + + async def insert_ticket( + self, + user_id: str, + user_name: str, + user_email: str, + airline: str, + flight_number: str, + departure_airport: str, + arrival_airport: str, + departure_time: str, + arrival_time: str, + ): + raise NotImplementedError("Not Implemented") + + async def list_tickets( + self, + user_id: str, + ) -> list[models.Ticket]: + raise NotImplementedError("Not Implemented") + + async def policies_search( + self, query_embedding: list[float], similarity_threshold: float, top_k: int + ) -> list[str]: + with self.__pool.connect() as conn: + s = text( + """ + SELECT content + FROM policies + WHERE NEAREST(embedding) TO (string_to_vector(:query), :search_options) + """ + ) + params = { + "query": f"{query_embedding}", + "search_options": f"num_neighbors={top_k}" + } + + results = (conn.execute(s, parameters=params)).mappings().fetchall() + + res = [r["content"] for r in results] + return res + + async def close(self): + with self.__pool.connect() as conn: + s = text( + """ + CALL mysql.drop_vector_index(:index_name) + """ + ) + params = [ + {"index_name": "assistantdemo.amenities_index"}, + {"index_name": "assistantdemo.policies_index"}, + ] + + conn.execute(s, parameters=params) + self.__pool.dispose() \ No newline at end of file diff --git 
a/retrieval_service/datastore/providers/cloudsql_mysql_test.py b/retrieval_service/datastore/providers/cloudsql_mysql_test.py new file mode 100644 index 00000000..c6114937 --- /dev/null +++ b/retrieval_service/datastore/providers/cloudsql_mysql_test.py @@ -0,0 +1,642 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from datetime import datetime +from typing import Any, AsyncGenerator, List + +import pymysql +import pytest +import pytest_asyncio +from csv_diff import compare, load_csv # type: ignore +from google.cloud.sql.connector import Connector + +import models + +from .. import datastore +from . 
import cloudsql_mysql +from .test_data import ( + amenities_query_embedding1, + amenities_query_embedding2, + foobar_query_embedding, + policies_query_embedding1, + policies_query_embedding2, +) +from .utils import get_env_var + +pytestmark = pytest.mark.asyncio(scope="module") + + +@pytest.fixture(scope="module") +def db_user() -> str: + return "mysql" + + +@pytest.fixture(scope="module") +def db_pass() -> str: + return "my-cloudsql-pass" + + +@pytest.fixture(scope="module") +def db_project() -> str: + return "juliaofferman-playground" + + +@pytest.fixture(scope="module") +def db_region() -> str: + return "us-central1" + + +@pytest.fixture(scope="module") +def db_instance() -> str: + return "my-vector-app" + + +@pytest.fixture(scope="module") +async def create_db( + db_user: str, db_pass: str, db_project: str, db_region: str, db_instance: str +) -> AsyncGenerator[str, None]: + db_name = "assistantdemo" + loop = asyncio.get_running_loop() + connector = Connector(loop=loop) + # Database does not exist, create it. 
+ sys_conn: pymysql.Connection = await connector.connect_async( + # Cloud SQL instance connection name + f"{db_project}:{db_region}:{db_instance}", + "pymysql", + user=f"{db_user}", + password=f"{db_pass}", + db=f"{db_name}", + ) + cursor = sys_conn.cursor() + + cursor.execute(f'DROP DATABASE IF EXISTS assistantdemo;') + cursor.execute(f'CREATE DATABASE assistantdemo;') + cursor.close() + conn: pymysql.Connection = await connector.connect_async( + # Cloud SQL instance connection name + f"{db_project}:{db_region}:{db_instance}", + "pymysql", + user=f"{db_user}", + password=f"{db_pass}", + db=f"{db_name}", + ) + yield db_name + await conn.close() + + + +@pytest_asyncio.fixture(scope="module") +async def ds( + create_db: AsyncGenerator[str, None], + db_user: str, + db_pass: str, + db_project: str, + db_region: str, + db_instance: str, +) -> AsyncGenerator[datastore.Client, None]: + db_name = await create_db.__anext__() + cfg = cloudsql_mysql.Config( + kind="cloudsql-mysql", + user=db_user, + password=db_pass, + database=db_name, + project=db_project, + region=db_region, + instance=db_instance, + ) + t = create_db + ds = await datastore.create(cfg) + + airports_ds_path = "../data/airport_dataset.csv" + amenities_ds_path = "../data/amenity_dataset.csv" + flights_ds_path = "../data/flights_dataset.csv" + policies_ds_path = "../data/cymbalair_policy.csv" + airports, amenities, flights, policies = await ds.load_dataset( + airports_ds_path, + amenities_ds_path, + flights_ds_path, + policies_ds_path, + ) + await ds.initialize_data(airports, amenities, flights, policies) + + if ds is None: + raise TypeError("datastore creation failure") + yield ds + await ds.close() + + +def check_file_diff(file_diff, has_embedding_column=False): + assert file_diff["added"] == [] + assert file_diff["removed"] == [] + assert file_diff["columns_added"] == [] + assert file_diff["columns_removed"] == [] + # MySQL rounds embedding values, so column will appear 'changed' + if not 
has_embedding_column: + assert file_diff["changed"] == [] + + +async def test_export_dataset(ds: cloudsql_mysql.Client): + airports, amenities, flights, policies = await ds.export_data() + + airports_ds_path = "../data/airport_dataset.csv" + amenities_ds_path = "../data/amenity_dataset.csv" + flights_ds_path = "../data/flights_dataset.csv" + policies_ds_path = "../data/cymbalair_policy.csv" + + airports_new_path = "../data/airport_dataset.csv.new" + amenities_new_path = "../data/amenity_dataset.csv.new" + flights_new_path = "../data/flights_dataset.csv.new" + policies_new_path = "../data/cymbalair_policy.csv.new" + + await ds.export_dataset( + airports, + amenities, + flights, + policies, + airports_new_path, + amenities_new_path, + flights_new_path, + policies_new_path, + ) + + diff_airports = compare( + load_csv(open(airports_ds_path), "id"), load_csv(open(airports_new_path), "id") + ) + check_file_diff(diff_airports) + + diff_amenities = compare( + load_csv(open(amenities_ds_path), "id"), + load_csv(open(amenities_new_path), "id"), + ) + check_file_diff(diff_amenities, True) + + diff_flights = compare( + load_csv(open(flights_ds_path), "id"), load_csv(open(flights_new_path), "id") + ) + check_file_diff(diff_flights) + + diff_policies = compare( + load_csv(open(policies_ds_path), "id"), + load_csv(open(policies_new_path), "id"), + ) + + check_file_diff(diff_policies, True) + + +async def test_get_airport_by_id(ds: cloudsql_mysql.Client): + res = await ds.get_airport_by_id(1) + expected = models.Airport( + id=1, + iata="MAG", + name="Madang Airport", + city="Madang", + country="Papua New Guinea", + ) + assert res == expected + + +@pytest.mark.parametrize( + "iata", + [ + pytest.param("SFO", id="upper_case"), + pytest.param("sfo", id="lower_case"), + ], +) +async def test_get_airport_by_iata(ds: cloudsql_mysql.Client, iata: str): + res = await ds.get_airport_by_iata(iata) + expected = models.Airport( + id=3270, + iata="SFO", + name="San Francisco International 
Airport", + city="San Francisco", + country="United States", + ) + assert res == expected + + +search_airports_test_data = [ + pytest.param( + "Philippines", + "San jose", + None, + [ + models.Airport( + id=2299, + iata="SJI", + name="San Jose Airport", + city="San Jose", + country="Philippines", + ), + models.Airport( + id=2313, + iata="EUQ", + name="Evelio Javier Airport", + city="San Jose", + country="Philippines", + ), + ], + id="country_and_city_only", + ), + pytest.param( + "united states", + "san francisco", + None, + [ + models.Airport( + id=3270, + iata="SFO", + name="San Francisco International Airport", + city="San Francisco", + country="United States", + ) + ], + id="country_and_name_only", + ), + pytest.param( + None, + "San Jose", + "San Jose", + [ + models.Airport( + id=1714, + iata="GSJ", + name="San José Airport", + city="San Jose", + country="Guatemala", + ), + models.Airport( + id=2299, + iata="SJI", + name="San Jose Airport", + city="San Jose", + country="Philippines", + ), + models.Airport( + id=3548, + iata="SJC", + name="Norman Y. 
Mineta San Jose International Airport", + city="San Jose", + country="United States", + ), + ], + id="city_and_name_only", + ), + pytest.param( + "Foo", + "FOO BAR", + "Foo bar", + [], + id="no_results", + ), +] + + +@pytest.mark.parametrize("country, city, name, expected", search_airports_test_data) +async def test_search_airports( + ds: cloudsql_mysql.Client, + country: str, + city: str, + name: str, + expected: List[models.Airport], +): + res = await ds.search_airports(country, city, name) + assert res == expected + + +async def test_get_amenity(ds: cloudsql_mysql.Client): + res = await ds.get_amenity(0) + expected = models.Amenity( + id=0, + name="Coffee Shop 732", + description="Serving American cuisine.", + location="Near Gate B12", + terminal="Terminal 3", + category="restaurant", + hour="Daily 7:00 am - 10:00 pm", + sunday_start_hour=None, + sunday_end_hour=None, + monday_start_hour=None, + monday_end_hour=None, + tuesday_start_hour=None, + tuesday_end_hour=None, + wednesday_start_hour=None, + wednesday_end_hour=None, + thursday_start_hour=None, + thursday_end_hour=None, + friday_start_hour=None, + friday_end_hour=None, + saturday_start_hour=None, + saturday_end_hour=None, + ) + assert res == expected + + +amenities_search_test_data = [ + pytest.param( + # "Where can I get coffee near gate A6?" + amenities_query_embedding1, + 0.35, + 1, + [ + { + "name": "Coffee Shop 732", + "description": "Serving American cuisine.", + "location": "Near Gate B12", + "terminal": "Terminal 3", + "category": "restaurant", + "hour": "Daily 7:00 am - 10:00 pm", + }, + ], + id="search_coffee_shop", + ), + pytest.param( + # "Where can I look for luxury goods?" 
+ amenities_query_embedding2, + 0.35, + 2, + [ + { + "name": "Dufry Duty Free", + "description": "Duty-free shop offering a large selection of luxury goods, including perfumes, cosmetics, and liquor.", + "location": "Gate E2", + "terminal": "International Terminal A", + "category": "shop", + "hour": "Daily 7:00 am-10:00 pm", + }, + { + "name": "Gucci Duty Free", + "description": "Luxury brand duty-free shop offering designer clothing, accessories, and fragrances.", + "location": "Gate E9", + "terminal": "International Terminal A", + "category": "shop", + "hour": "Daily 7:00 am-10:00 pm", + }, + ], + id="search_luxury_goods", + ), +] + + +@pytest.mark.parametrize( + "query_embedding, similarity_threshold, top_k, expected", amenities_search_test_data +) +async def test_amenities_search( + ds: cloudsql_mysql.Client, + query_embedding: List[float], + similarity_threshold: float, + top_k: int, + expected: List[Any], +): + res = await ds.amenities_search(query_embedding, similarity_threshold, top_k) + assert res == expected + + +async def test_get_flight(ds: cloudsql_mysql.Client): + res = await ds.get_flight(1) + expected = models.Flight( + id=1, + airline="UA", + flight_number="1158", + departure_airport="SFO", + arrival_airport="ORD", + departure_time=datetime.strptime("2024-01-01 05:57:00", "%Y-%m-%d %H:%M:%S"), + arrival_time=datetime.strptime("2024-01-01 12:13:00", "%Y-%m-%d %H:%M:%S"), + departure_gate="C38", + arrival_gate="D30", + ) + assert res == expected + + +search_flights_by_number_test_data = [ + pytest.param( + "UA", + "1158", + [ + models.Flight( + id=1, + airline="UA", + flight_number="1158", + departure_airport="SFO", + arrival_airport="ORD", + departure_time=datetime.strptime( + "2024-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" + ), + arrival_time=datetime.strptime( + "2024-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" + ), + departure_gate="C38", + arrival_gate="D30", + ), + models.Flight( + id=55455, + airline="UA", + flight_number="1158", + 
departure_airport="SFO", + arrival_airport="JFK", + departure_time=datetime.strptime( + "2024-10-15 05:18:00", "%Y-%m-%d %H:%M:%S" + ), + arrival_time=datetime.strptime( + "2024-10-15 08:40:00", "%Y-%m-%d %H:%M:%S" + ), + departure_gate="B50", + arrival_gate="E4", + ), + ], + id="successful_airport_search", + ), + pytest.param( + "UU", + "0000", + [], + id="no_results", + ), +] + + +@pytest.mark.parametrize( + "airline, number, expected", search_flights_by_number_test_data +) +async def test_search_flights_by_number( + ds: cloudsql_mysql.Client, + airline: str, + number: str, + expected: List[models.Flight], +): + res = await ds.search_flights_by_number(airline, number) + assert res == expected + + +search_flights_by_airports_test_data = [ + pytest.param( + "2024-01-01", + "SFO", + "ORD", + [ + models.Flight( + id=1, + airline="UA", + flight_number="1158", + departure_airport="SFO", + arrival_airport="ORD", + departure_time=datetime.strptime( + "2024-01-01 05:57:00", "%Y-%m-%d %H:%M:%S" + ), + arrival_time=datetime.strptime( + "2024-01-01 12:13:00", "%Y-%m-%d %H:%M:%S" + ), + departure_gate="C38", + arrival_gate="D30", + ), + models.Flight( + id=13, + airline="UA", + flight_number="616", + departure_airport="SFO", + arrival_airport="ORD", + departure_time=datetime.strptime( + "2024-01-01 07:14:00", "%Y-%m-%d %H:%M:%S" + ), + arrival_time=datetime.strptime( + "2024-01-01 13:24:00", "%Y-%m-%d %H:%M:%S" + ), + departure_gate="A11", + arrival_gate="D8", + ), + models.Flight( + id=25, + airline="AA", + flight_number="242", + departure_airport="SFO", + arrival_airport="ORD", + departure_time=datetime.strptime( + "2024-01-01 08:18:00", "%Y-%m-%d %H:%M:%S" + ), + arrival_time=datetime.strptime( + "2024-01-01 14:26:00", "%Y-%m-%d %H:%M:%S" + ), + departure_gate="E30", + arrival_gate="C1", + ), + models.Flight( + id=109, + airline="UA", + flight_number="1640", + departure_airport="SFO", + arrival_airport="ORD", + departure_time=datetime.strptime( + "2024-01-01 17:01:00", 
"%Y-%m-%d %H:%M:%S" + ), + arrival_time=datetime.strptime( + "2024-01-01 23:02:00", "%Y-%m-%d %H:%M:%S" + ), + departure_gate="E27", + arrival_gate="C24", + ), + models.Flight( + id=119, + airline="AA", + flight_number="197", + departure_airport="SFO", + arrival_airport="ORD", + departure_time=datetime.strptime( + "2024-01-01 17:21:00", "%Y-%m-%d %H:%M:%S" + ), + arrival_time=datetime.strptime( + "2024-01-01 23:33:00", "%Y-%m-%d %H:%M:%S" + ), + departure_gate="D25", + arrival_gate="E49", + ), + models.Flight( + id=136, + airline="UA", + flight_number="1564", + departure_airport="SFO", + arrival_airport="ORD", + departure_time=datetime.strptime( + "2024-01-01 19:14:00", "%Y-%m-%d %H:%M:%S" + ), + arrival_time=datetime.strptime( + "2024-01-02 01:14:00", "%Y-%m-%d %H:%M:%S" + ), + departure_gate="E3", + arrival_gate="C48", + ), + ], + id="successful_airport_search", + ), + pytest.param( + "2024-01-01", + "FOO", + "BAR", + [], + id="no_results", + ), +] + + +@pytest.mark.parametrize( + "date, departure_airport, arrival_airport, expected", + search_flights_by_airports_test_data, +) +async def test_search_flights_by_airports( + ds: cloudsql_mysql.Client, + date: str, + departure_airport: str, + arrival_airport: str, + expected: List[models.Flight], +): + res = await ds.search_flights_by_airports(date, departure_airport, arrival_airport) + assert res == expected + + +policies_search_test_data = [ + pytest.param( + # "What is the fee for extra baggage?" + policies_query_embedding1, + 0.35, + 1, + [ + "## Baggage\nChecked Baggage: Economy passengers are allowed 2 checked bags. Business class and First class passengers are allowed 4 checked bags. Additional baggage will cost $70 and a $30 fee applies for all checked bags over 50 lbs. Cymbal Air cannot accept checked bags over 100 lbs. We only accept checked bags up to 115 inches in total dimensions (length + width + height), and oversized baggage will cost $30. 
Checked bags above 160 inches in total dimensions will not be accepted.", + ], + id="search_extra_baggage_fee", + ), + pytest.param( + # "Can I change my flight?" + policies_query_embedding2, + 0.35, + 2, + [ + "# Cymbal Air: Passenger Policy \n## Ticket Purchase and Changes\nTypes of Fares: Cymbal Air offers a variety of fares (Economy, Premium Economy, Business Class, and First Class). Fare restrictions, such as change fees and refundability, vary depending on the fare purchased.", + "Changes: Changes to tickets are permitted at any time until 60 minutes prior to scheduled departure. There are no fees for changes as long as the new ticket is on Cymbal Air and is at an equal or lower price. If the new ticket has a higher price, the customer must pay the difference between the new and old fares. Changes to a non-Cymbal-Air flight include a $100 change fee.", + ], + id="search_flight_delays", + ), +] + + +@pytest.mark.parametrize( + "query_embedding, similarity_threshold, top_k, expected", policies_search_test_data +) +async def test_policies_search( + ds: cloudsql_mysql.Client, + query_embedding: List[float], + similarity_threshold: float, + top_k: int, + expected: List[str], +): + res = await ds.policies_search(query_embedding, similarity_threshold, top_k) + assert res == expected From 63855be6e3e4158d341cfc43ddc645e153512b59 Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Fri, 14 Jun 2024 23:21:15 +0000 Subject: [PATCH 02/15] Fix export data test --- .../datastore/providers/cloudsql_mysql.py | 9 +++++++++ .../datastore/providers/cloudsql_mysql_test.py | 18 ++++++++++++------ .../example-config-cloudsql-mysql.yml | 10 ++++++++++ retrieval_service/requirements.txt | 1 + 4 files changed, 32 insertions(+), 6 deletions(-) create mode 100644 retrieval_service/example-config-cloudsql-mysql.yml diff --git a/retrieval_service/datastore/providers/cloudsql_mysql.py b/retrieval_service/datastore/providers/cloudsql_mysql.py index 675f68ce..c24fcb09 100644 --- 
a/retrieval_service/datastore/providers/cloudsql_mysql.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql.py @@ -483,6 +483,15 @@ async def search_flights_by_airports( res = [models.Flight.model_validate(r) for r in results] return res + + async def validate_ticket( + self, + airline: str, + flight_number: str, + departure_airport: str, + departure_time: str, + ) -> Optional[models.Flight]: + raise NotImplementedError("Not Implemented") async def insert_ticket( self, diff --git a/retrieval_service/datastore/providers/cloudsql_mysql_test.py b/retrieval_service/datastore/providers/cloudsql_mysql_test.py index c6114937..329d454c 100644 --- a/retrieval_service/datastore/providers/cloudsql_mysql_test.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql_test.py @@ -136,15 +136,21 @@ async def ds( yield ds await ds.close() +def only_embedding_changed(file_diff): + return all( + key == "embedding" + for change in file_diff["changed"] + for key in change["changes"] + ) + + -def check_file_diff(file_diff, has_embedding_column=False): +def check_file_diff(file_diff): assert file_diff["added"] == [] assert file_diff["removed"] == [] assert file_diff["columns_added"] == [] assert file_diff["columns_removed"] == [] - # MySQL rounds embedding values, so column will appear 'changed' - if not has_embedding_column: - assert file_diff["changed"] == [] + assert file_diff["changed"] == [] or only_embedding_changed(file_diff) async def test_export_dataset(ds: cloudsql_mysql.Client): @@ -180,7 +186,7 @@ async def test_export_dataset(ds: cloudsql_mysql.Client): load_csv(open(amenities_ds_path), "id"), load_csv(open(amenities_new_path), "id"), ) - check_file_diff(diff_amenities, True) + check_file_diff(diff_amenities) diff_flights = compare( load_csv(open(flights_ds_path), "id"), load_csv(open(flights_new_path), "id") @@ -192,7 +198,7 @@ async def test_export_dataset(ds: cloudsql_mysql.Client): load_csv(open(policies_new_path), "id"), ) - check_file_diff(diff_policies, 
True) + check_file_diff(diff_policies) async def test_get_airport_by_id(ds: cloudsql_mysql.Client): diff --git a/retrieval_service/example-config-cloudsql-mysql.yml b/retrieval_service/example-config-cloudsql-mysql.yml new file mode 100644 index 00000000..ccc539cd --- /dev/null +++ b/retrieval_service/example-config-cloudsql-mysql.yml @@ -0,0 +1,10 @@ +host: 0.0.0.0 +datastore: + # Example for Cloud SQL + kind: "cloudsql-mysql" + project: "my-project" + region: "my-region" + instance: "my-instance" + database: "my_database" + user: "my-user" + password: "my-password" \ No newline at end of file diff --git a/retrieval_service/requirements.txt b/retrieval_service/requirements.txt index e4c7c6b9..22c6c069 100644 --- a/retrieval_service/requirements.txt +++ b/retrieval_service/requirements.txt @@ -17,3 +17,4 @@ langchain-text-splitters==0.2.0 langchain-google-vertexai==1.0.4 asyncio==3.4.3 datetime==5.5 +pymysql==1.1.1 From 4f147b3fd527a5077bd2d3b5a9103d3e1ef65393 Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Fri, 14 Jun 2024 23:24:11 +0000 Subject: [PATCH 03/15] Read parameters from config instead of hard coding. 
--- retrieval_service/datastore/providers/cloudsql_mysql.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/retrieval_service/datastore/providers/cloudsql_mysql.py b/retrieval_service/datastore/providers/cloudsql_mysql.py index c24fcb09..715a6fa0 100644 --- a/retrieval_service/datastore/providers/cloudsql_mysql.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql.py @@ -56,11 +56,11 @@ def getconn() -> pymysql.Connection: with Connector() as connector: conn: pymysql.Connection = connector.connect( # Cloud SQL instance connection name - "juliaofferman-playground:us-central1:my-vector-app", + f"{config.project}:{config.region}:{config.instance}", "pymysql", - user="mysql", - password="my-cloudsql-pass", - db="assistantdemo", + user=f"{config.user}", + password=f"{config.password}", + db=f"{config.database}", autocommit=True, ) return conn From 330080f64f4b22842f2812e99e6e815b0f594ede Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Fri, 14 Jun 2024 23:40:01 +0000 Subject: [PATCH 04/15] Add Cloud SQL MYSQL to the datastore & providers init files --- retrieval_service/datastore/__init__.py | 1 + retrieval_service/datastore/providers/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/retrieval_service/datastore/__init__.py b/retrieval_service/datastore/__init__.py index abd8d54f..7f376e71 100644 --- a/retrieval_service/datastore/__init__.py +++ b/retrieval_service/datastore/__init__.py @@ -23,6 +23,7 @@ providers.cloudsql_postgres.Config, providers.spanner_gsql.Config, providers.alloydb.Config, + providers.cloudsql_mysql.Config, ] __ALL__ = [Client, Config, create, providers] diff --git a/retrieval_service/datastore/providers/__init__.py b/retrieval_service/datastore/providers/__init__.py index ad63773c..cd4987da 100644 --- a/retrieval_service/datastore/providers/__init__.py +++ b/retrieval_service/datastore/providers/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing 
permissions and # limitations under the License. -from . import alloydb, cloudsql_postgres, firestore, postgres, spanner_gsql +from . import alloydb, cloudsql_postgres, firestore, postgres, spanner_gsql, cloudsql_mysql -__ALL__ = [alloydb, postgres, cloudsql_postgres, firestore, spanner_gsql] +__ALL__ = [alloydb, postgres, cloudsql_postgres, firestore, spanner_gsql, cloudsql_mysql] From 9aafc96355e9bd4dd2fa883d1fb04c07b6ef4aea Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Mon, 17 Jun 2024 18:54:16 +0000 Subject: [PATCH 05/15] Wrap synchronous functions in async executor --- .../datastore/providers/cloudsql_mysql.py | 139 +++++++++++++++--- 1 file changed, 121 insertions(+), 18 deletions(-) diff --git a/retrieval_service/datastore/providers/cloudsql_mysql.py b/retrieval_service/datastore/providers/cloudsql_mysql.py index 715a6fa0..51f492af 100644 --- a/retrieval_service/datastore/providers/cloudsql_mysql.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql.py @@ -49,9 +49,7 @@ def __init__(self, pool: Engine): self.__pool = pool @classmethod - async def create(cls, config: Config) -> "Client": - loop = asyncio.get_running_loop() - + def create_sync(cls, config: Config) -> "Client": def getconn() -> pymysql.Connection: with Connector() as connector: conn: pymysql.Connection = connector.connect( @@ -73,7 +71,15 @@ def getconn() -> pymysql.Connection: raise TypeError("pool not instantiated") return cls(pool) - async def initialize_data( + + @classmethod + async def create(cls, config: Config) -> "Client": + loop = asyncio.get_running_loop() + + pool = await loop.run_in_executor(None, cls.create_sync, config) + return pool + + def initialize_data_sync( self, airports: list[models.Airport], amenities: list[models.Amenity], @@ -278,7 +284,18 @@ async def initialize_data( # Create a vector index on the embedding column conn.execute(text("CALL mysql.create_vector_index('policies_index', 'assistantdemo.policies', 'embedding', '')")) - async def export_data( + 
+ async def initialize_data( + self, + airports: list[models.Airport], + amenities: list[models.Amenity], + flights: list[models.Flight], + policies: list[models.Policy], + ) -> None: + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self.initialize_data_sync, airports, amenities, flights, policies) + + def export_data_sync( self, ) -> tuple[ list[models.Airport], @@ -328,8 +345,20 @@ async def export_data( policies = [models.Policy.model_validate(p) for p in policy_results] return airports, amenities, flights, policies - - async def get_airport_by_id(self, id: int) -> Optional[models.Airport]: + + async def export_data( + self, + ) -> tuple[ + list[models.Airport], + list[models.Amenity], + list[models.Flight], + list[models.Policy], + ]: + loop = asyncio.get_running_loop() + res = await loop.run_in_executor(None, self.export_data_sync) + return res + + def get_airport_by_id_sync(self, id: int) -> Optional[models.Airport]: with self.__pool.connect() as conn: s = text("""SELECT * FROM airports WHERE id=:id""") params = {"id" : id} @@ -340,8 +369,13 @@ async def get_airport_by_id(self, id: int) -> Optional[models.Airport]: res = models.Airport.model_validate(result) return res + + async def get_airport_by_id(self, id: int) -> Optional[models.Airport]: + loop = asyncio.get_running_loop() + res = await loop.run_in_executor(None, self.get_airport_by_id_sync, id) + return res - async def get_airport_by_iata(self, iata: str) -> Optional[models.Airport]: + def get_airport_by_iata_sync(self, iata: str) -> Optional[models.Airport]: with self.__pool.connect() as conn: s = text("""SELECT * FROM airports WHERE LOWER(iata) LIKE LOWER(:iata)""") params = {"iata": iata} @@ -353,7 +387,12 @@ async def get_airport_by_iata(self, iata: str) -> Optional[models.Airport]: res = models.Airport.model_validate(result) return res - async def search_airports( + async def get_airport_by_iata(self, iata: str) -> Optional[models.Airport]: + loop = asyncio.get_running_loop() 
+ res = await loop.run_in_executor(None, self.get_airport_by_iata_sync, iata) + return res + + def search_airports_sync( self, country: Optional[str] = None, city: Optional[str] = None, @@ -379,7 +418,17 @@ async def search_airports( res = [models.Airport.model_validate(r) for r in results] return res - async def get_amenity(self, id: int) -> Optional[models.Amenity]: + async def search_airports( + self, + country: Optional[str] = None, + city: Optional[str] = None, + name: Optional[str] = None, + ) -> list[models.Airport]: + loop = asyncio.get_running_loop() + res = await loop.run_in_executor(None, self.search_airports_sync, country, city, name) + return res + + def get_amenity_sync(self, id: int) -> Optional[models.Amenity]: with self.__pool.connect() as conn: s = text( """ @@ -396,7 +445,12 @@ async def get_amenity(self, id: int) -> Optional[models.Amenity]: res = models.Amenity.model_validate(result) return res - async def amenities_search( + async def get_amenity(self, id: int) -> Optional[models.Amenity]: + loop = asyncio.get_running_loop() + res = await loop.run_in_executor(None, self.get_amenity_sync, id) + return res + + def amenities_search_sync( self, query_embedding: list[float], similarity_threshold: float, top_k: int ) -> list[Any]: with self.__pool.connect() as conn: @@ -416,7 +470,14 @@ async def amenities_search( res = [r for r in results] return res - async def get_flight(self, flight_id: int) -> Optional[models.Flight]: + async def amenities_search( + self, query_embedding: list[float], similarity_threshold: float, top_k: int + ) -> list[Any]: + loop = asyncio.get_running_loop() + res = await loop.run_in_executor(None, self.amenities_search_sync, query_embedding, similarity_threshold, top_k) + return res + + def get_flight_sync(self, flight_id: int) -> Optional[models.Flight]: with self.__pool.connect() as conn: s = text( """ @@ -432,8 +493,13 @@ async def get_flight(self, flight_id: int) -> Optional[models.Flight]: res = 
models.Flight.model_validate(result) return res - - async def search_flights_by_number( + + async def get_flight(self, flight_id: int) -> Optional[models.Flight]: + loop = asyncio.get_running_loop() + res = await loop.run_in_executor(None, self.get_flight_sync, flight_id) + return res + + def search_flights_by_number_sync( self, airline: str, number: str, @@ -455,8 +521,17 @@ async def search_flights_by_number( res = [models.Flight.model_validate(r) for r in results] return res - - async def search_flights_by_airports( + + async def search_flights_by_number( + self, + airline: str, + number: str, + ) -> list[models.Flight]: + loop = asyncio.get_running_loop() + res = await loop.run_in_executor(None, self.search_flights_by_number_sync, airline, number) + return res + + def search_flights_by_airports_sync( self, date: str, departure_airport: Optional[str] = None, @@ -484,7 +559,17 @@ async def search_flights_by_airports( res = [models.Flight.model_validate(r) for r in results] return res - async def validate_ticket( + async def search_flights_by_airports( + self, + date: str, + departure_airport: Optional[str] = None, + arrival_airport: Optional[str] = None, + ) -> list[models.Flight]: + loop = asyncio.get_running_loop() + res = await loop.run_in_executor(None, self.search_flights_by_airports_sync, date, departure_airport, arrival_airport) + return res + + def validate_ticket_sync( self, airline: str, flight_number: str, @@ -493,6 +578,17 @@ async def validate_ticket( ) -> Optional[models.Flight]: raise NotImplementedError("Not Implemented") + async def validate_ticket( + self, + airline: str, + flight_number: str, + departure_airport: str, + departure_time: str, + ) -> Optional[models.Flight]: + loop = asyncio.get_running_loop() + res = await loop.run_in_executor(None, self.validate_ticket_sync, airline, flight_number, departure_airport, departure_time) + return res + async def insert_ticket( self, user_id: str, @@ -513,7 +609,7 @@ async def list_tickets( ) -> 
list[models.Ticket]: raise NotImplementedError("Not Implemented") - async def policies_search( + def policies_search_sync( self, query_embedding: list[float], similarity_threshold: float, top_k: int ) -> list[str]: with self.__pool.connect() as conn: @@ -534,6 +630,13 @@ async def policies_search( res = [r["content"] for r in results] return res + async def policies_search( + self, query_embedding: list[float], similarity_threshold: float, top_k: int + ) -> list[str]: + loop = asyncio.get_running_loop() + res = await loop.run_in_executor(None, self.policies_search_sync, query_embedding, similarity_threshold, top_k) + return res + async def close(self): with self.__pool.connect() as conn: s = text( From bfdad9873513fb7254c7bea50efb5b84c4904655 Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Mon, 17 Jun 2024 19:30:06 +0000 Subject: [PATCH 06/15] Write MySQL setup doc --- docs/datastore/cloudsql_mysql.md | 182 +++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 docs/datastore/cloudsql_mysql.md diff --git a/docs/datastore/cloudsql_mysql.md b/docs/datastore/cloudsql_mysql.md new file mode 100644 index 00000000..24fe04c2 --- /dev/null +++ b/docs/datastore/cloudsql_mysql.md @@ -0,0 +1,182 @@ +# Setup and configure Cloud SQL for MySQL + +## Before you begin + +1. Make sure you have a Google Cloud project and billing is enabled. + +1. Set your `PROJECT_ID` environment variable: + + ```bash + export PROJECT_ID= + ``` + +1. [Install](https://cloud.google.com/sdk/docs/install) the gcloud CLI. + +1. Set gcloud project: + + ```bash + gcloud config set project $PROJECT_ID + ``` + +1. Enable APIs: + + ```bash + gcloud services enable sqladmin.googleapis.com \ + aiplatform.googleapis.com + ``` + +1. [Install python][install-python] and set up a python [virtual environment][venv]. + +1. Make sure you have python version 3.11+ installed. + + ```bash + python -V + ``` + +1. Download and install [mysql-client cli (`mysql`)][install-mysql]. + +1. 
Install the [Cloud SQL Auth Proxy client][install-cloudsql-proxy]. + +[install-python]: https://cloud.google.com/python/docs/setup#installing_python +[venv]: https://cloud.google.com/python/docs/setup#installing_and_using_virtualenv +[install-mysql]: https://dev.mysql.com/doc/mysql-installation-excerpt/8.0/en/ +[install-cloudsql-proxy]: https://cloud.google.com/sql/docs/mysql/connect-auth-proxy + + +## Create a Cloud SQL for MySQL instance + +1. Set environment variables. For security reasons, use a different password for + `$DB_PASS` and note it for future use: + + ```bash + export DB_PASS=my-cloudsql-pass + export DB_USER=mysql + export INSTANCE=my-cloudsql-instance + export REGION=us-central1 + ``` + +1. Create a MySQL instance with vector enabled: + + ```bash + gcloud sql instances create $INSTANCE \ + --database-version=MYSQL_8_0_36 \ + --cpu=4 \ + --memory=16GB \ + --region=$REGION \ + --database-flags=cloudsql_vector=ON + ``` + +1. Set password for mysql user: + + ```bash + gcloud sql users set-password $DB_USER \ + --instance=$INSTANCE \ + --password=$DB_PASS + ``` + + +## Connect to the Cloud SQL instance + +1. Connect to instance using cloud sql proxy: + + ```bash + ./cloud-sql-proxy $PROJECT_ID:$REGION:$INSTANCE + ``` + +1. Verify you can connect to your instance with the `mysql` tool. Enter + password for Cloud SQL (`$DB_PASS` environment variable set above) when prompted: + + ```bash + mysql "host=127.0.0.1 port=5432 sslmode=disable user=$DB_USER" + ``` + +## Update config + +Update `config.yml` with your database information. 
+ +```bash +host: 0.0.0.0 +datastore: + # Example for cloudsql_mysql.py provider + kind: "cloudsql-mysql" + # Update this with your project ID + project: + region: us-central1 + instance: my-cloudsql-instance + # Update this with the database name + database: "assistantdemo" + # Update with database user, the default is `mysql` + user: "mysql" + # Update with database user password + password: "my-cloudsql-pass" +``` + +## Initialize data + +1. While connected using `mysql`, create a database and switch to it: + + ```bash + CREATE DATABASE assistantdemo; + USE assistantdemo; + ``` + +1. Change into the retrieval service directory: + + ```bash + cd genai-databases-retrieval-app/retrieval_service + ``` + +1. Install requirements: + + ```bash + pip install -r requirements.txt + ``` + +1. Make a copy of `example-config.yml` and name it `config.yml`. + + ```bash + cp example-config.yml config.yml + ``` + +1. Populate data into database: + + ```bash + python run_database_init.py + ``` + +## Clean up resources + +Clean up after completing the demo. + +1. Delete the Cloud SQL instance: + + ```bash + gcloud sql instances delete my-cloudsql-instance + ``` + +## Developer information + +This section is for developers that want to develop and run the app locally. + +### Test Environment Variables + +Set environment variables: + +```bash +export DB_USER="" +export DB_PASS="" +export DB_PROJECT="" +export DB_REGION="" +export DB_INSTANCE="" +``` + +### Run tests + +Run retrieval service unit tests: + +```bash +gcloud builds submit --config retrieval_service/cloudsql.tests.cloudbuild.yaml \ + --substitutions _DATABASE_NAME=$DB_NAME,_DATABASE_USER=$DB_USER,_CLOUDSQL_REGION=$DB_REGION,_CLOUDSQL_INSTANCE=$DB_INSTANCE +``` + +Where `$DB_NAME`, `$DB_USER`, `$DB_REGION`, and `$DB_INSTANCE` are environment variables with your database values. 
\ No newline at end of file From 7f142eabc9f1d598924f939ac6726be3bb08c817 Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Mon, 17 Jun 2024 19:36:32 +0000 Subject: [PATCH 07/15] Fixing formatting --- retrieval_service/datastore/providers/cloudsql_mysql.py | 3 +-- retrieval_service/datastore/providers/cloudsql_mysql_test.py | 2 -- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/retrieval_service/datastore/providers/cloudsql_mysql.py b/retrieval_service/datastore/providers/cloudsql_mysql.py index 51f492af..8d837790 100644 --- a/retrieval_service/datastore/providers/cloudsql_mysql.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql.py @@ -71,7 +71,6 @@ def getconn() -> pymysql.Connection: raise TypeError("pool not instantiated") return cls(pool) - @classmethod async def create(cls, config: Config) -> "Client": loop = asyncio.get_running_loop() @@ -284,7 +283,6 @@ def initialize_data_sync( # Create a vector index on the embedding column conn.execute(text("CALL mysql.create_vector_index('policies_index', 'assistantdemo.policies', 'embedding', '')")) - async def initialize_data( self, airports: list[models.Airport], @@ -638,6 +636,7 @@ async def policies_search( return res async def close(self): + # Vector indexes must be dropped before any DDLs on the base table are permitted with self.__pool.connect() as conn: s = text( """ diff --git a/retrieval_service/datastore/providers/cloudsql_mysql_test.py b/retrieval_service/datastore/providers/cloudsql_mysql_test.py index 329d454c..2f996fd8 100644 --- a/retrieval_service/datastore/providers/cloudsql_mysql_test.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql_test.py @@ -96,7 +96,6 @@ async def create_db( await conn.close() - @pytest_asyncio.fixture(scope="module") async def ds( create_db: AsyncGenerator[str, None], @@ -144,7 +143,6 @@ def only_embedding_changed(file_diff): ) - def check_file_diff(file_diff): assert file_diff["added"] == [] assert file_diff["removed"] == [] From 
bad1486770d1c8c7790ddb19bb93965da76d9d1b Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Tue, 18 Jun 2024 21:51:01 +0000 Subject: [PATCH 08/15] Remove hardcoding the database name and formatting changes --- .../datastore/providers/__init__.py | 4 ++-- .../datastore/providers/cloudsql_mysql.py | 18 ++++++++++-------- .../datastore/providers/cloudsql_mysql_test.py | 18 +++++++++--------- .../example-config-cloudsql-mysql.yml | 3 ++- 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/retrieval_service/datastore/providers/__init__.py b/retrieval_service/datastore/providers/__init__.py index cd4987da..17b06c3d 100644 --- a/retrieval_service/datastore/providers/__init__.py +++ b/retrieval_service/datastore/providers/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import alloydb, cloudsql_postgres, firestore, postgres, spanner_gsql, cloudsql_mysql +from . import alloydb, cloudsql_mysql, cloudsql_postgres, firestore, postgres, spanner_gsql -__ALL__ = [alloydb, postgres, cloudsql_postgres, firestore, spanner_gsql, cloudsql_mysql] +__ALL__ = [alloydb, postgres, cloudsql_mysql, cloudsql_postgres, firestore, spanner_gsql] \ No newline at end of file diff --git a/retrieval_service/datastore/providers/cloudsql_mysql.py b/retrieval_service/datastore/providers/cloudsql_mysql.py index 8d837790..8b3bca9d 100644 --- a/retrieval_service/datastore/providers/cloudsql_mysql.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql.py @@ -40,13 +40,15 @@ class Config(BaseModel, datastore.AbstractConfig): class Client(datastore.Client[Config]): __pool: Engine + __db_name: str @datastore.classproperty def kind(cls): return "cloudsql-mysql" - def __init__(self, pool: Engine): + def __init__(self, pool: Engine, db_name: str): self.__pool = pool + self.__db_name = db_name @classmethod def create_sync(cls, config: Config) -> "Client": @@ -69,12 +71,12 @@ def getconn() -> 
pymysql.Connection: ) if pool is None: raise TypeError("pool not instantiated") - return cls(pool) + return cls(pool, config.database) @classmethod async def create(cls, config: Config) -> "Client": loop = asyncio.get_running_loop() - + pool = await loop.run_in_executor(None, cls.create_sync, config) return pool @@ -190,7 +192,7 @@ def initialize_data_sync( ) # Create a vector index for the embeddings column - conn.execute(text("CALL mysql.create_vector_index('amenities_index', 'assistantdemo.amenities', 'embedding', '')")) + conn.execute(text(f"CALL mysql.create_vector_index('amenities_index', '{self.__db_name}.amenities', 'embedding', '')")) # If the table already exists, drop it to avoid conflicts conn.execute(text("DROP TABLE IF EXISTS flights")) @@ -281,7 +283,7 @@ def initialize_data_sync( } for p in policies]) # Create a vector index on the embedding column - conn.execute(text("CALL mysql.create_vector_index('policies_index', 'assistantdemo.policies', 'embedding', '')")) + conn.execute(text(f"CALL mysql.create_vector_index('policies_index', '{self.__db_name}.policies', 'embedding', '')")) async def initialize_data( self, @@ -644,9 +646,9 @@ async def close(self): """ ) params = [ - {"index_name": "assistantdemo.amenities_index"}, - {"index_name": "assistantdemo.policies_index"}, + {"index_name": f"{self.__db_name}.amenities_index"}, + {"index_name": f"{self.__db_name}.policies_index"}, ] conn.execute(s, parameters=params) - self.__pool.dispose() \ No newline at end of file + self.__pool.dispose() diff --git a/retrieval_service/datastore/providers/cloudsql_mysql_test.py b/retrieval_service/datastore/providers/cloudsql_mysql_test.py index 2f996fd8..9fa08e1b 100644 --- a/retrieval_service/datastore/providers/cloudsql_mysql_test.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql_test.py @@ -40,34 +40,34 @@ @pytest.fixture(scope="module") def db_user() -> str: - return "mysql" + return get_env_var("DB_USER", "name of a mysql user") 
@pytest.fixture(scope="module") def db_pass() -> str: - return "my-cloudsql-pass" + return get_env_var("DB_PASS", "password for the mysql user") @pytest.fixture(scope="module") def db_project() -> str: - return "juliaofferman-playground" + return get_env_var("DB_PROJECT", "project id for google cloud") @pytest.fixture(scope="module") def db_region() -> str: - return "us-central1" + return get_env_var("DB_REGION", "region for cloud sql instance") @pytest.fixture(scope="module") def db_instance() -> str: - return "my-vector-app" + return get_env_var("DB_INSTANCE", "instance for cloud sql") @pytest.fixture(scope="module") async def create_db( db_user: str, db_pass: str, db_project: str, db_region: str, db_instance: str ) -> AsyncGenerator[str, None]: - db_name = "assistantdemo" + db_name = get_env_var("DB_NAME", "name of a postgres database") loop = asyncio.get_running_loop() connector = Connector(loop=loop) # Database does not exist, create it. @@ -77,12 +77,12 @@ async def create_db( "pymysql", user=f"{db_user}", password=f"{db_pass}", - db=f"{db_name}", + db="mysql", ) cursor = sys_conn.cursor() - cursor.execute(f'DROP DATABASE IF EXISTS assistantdemo;') - cursor.execute(f'CREATE DATABASE assistantdemo;') + cursor.execute(f'DROP DATABASE IF EXISTS {db_name};') + cursor.execute(f'CREATE DATABASE {db_name};') cursor.close() conn: pymysql.Connection = await connector.connect_async( # Cloud SQL instance connection name diff --git a/retrieval_service/example-config-cloudsql-mysql.yml b/retrieval_service/example-config-cloudsql-mysql.yml index ccc539cd..c3649e70 100644 --- a/retrieval_service/example-config-cloudsql-mysql.yml +++ b/retrieval_service/example-config-cloudsql-mysql.yml @@ -7,4 +7,5 @@ datastore: instance: "my-instance" database: "my_database" user: "my-user" - password: "my-password" \ No newline at end of file + password: "my-password" + \ No newline at end of file From 744f3c220cc3adadffe2046d0e9993e349b30419 Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN 
Date: Tue, 18 Jun 2024 21:56:00 +0000 Subject: [PATCH 09/15] Use MySQL identifier in provider class --- retrieval_service/datastore/providers/cloudsql_mysql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/retrieval_service/datastore/providers/cloudsql_mysql.py b/retrieval_service/datastore/providers/cloudsql_mysql.py index 8b3bca9d..8579d8cf 100644 --- a/retrieval_service/datastore/providers/cloudsql_mysql.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql.py @@ -44,7 +44,7 @@ class Client(datastore.Client[Config]): @datastore.classproperty def kind(cls): - return "cloudsql-mysql" + return MYSQL_IDENTIFIER def __init__(self, pool: Engine, db_name: str): self.__pool = pool @@ -76,7 +76,7 @@ def getconn() -> pymysql.Connection: @classmethod async def create(cls, config: Config) -> "Client": loop = asyncio.get_running_loop() - + pool = await loop.run_in_executor(None, cls.create_sync, config) return pool From d8c75c8ce626982bc2f90b8a590266c84d3d4f98 Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Thu, 20 Jun 2024 20:43:56 +0000 Subject: [PATCH 10/15] Update readme and implement ticket functions --- docs/datastore/cloudsql_mysql.md | 8 +- .../datastore/providers/cloudsql_mysql.py | 103 +++++++++++++++++- .../providers/cloudsql_mysql_test.py | 37 ++++++- .../example-config-cloudsql-mysql.yml | 11 -- 4 files changed, 140 insertions(+), 19 deletions(-) delete mode 100644 retrieval_service/example-config-cloudsql-mysql.yml diff --git a/docs/datastore/cloudsql_mysql.md b/docs/datastore/cloudsql_mysql.md index 24fe04c2..0f805a1e 100644 --- a/docs/datastore/cloudsql_mysql.md +++ b/docs/datastore/cloudsql_mysql.md @@ -50,7 +50,7 @@ ```bash export DB_PASS=my-cloudsql-pass - export DB_USER=mysql + export DB_USER=root export INSTANCE=my-cloudsql-instance export REGION=us-central1 ``` @@ -87,7 +87,7 @@ password for Cloud SQL (`$DB_PASS` environment variable set above) when prompted: ```bash - mysql "host=127.0.0.1 port=5432 sslmode=disable 
user=$DB_USER" + mysql "host=127.0.0.1 port=3306 sslmode=disable user=$DB_USER" ``` ## Update config @@ -106,7 +106,7 @@ datastore: # Update this with the database name database: "assistantdemo" # Update with database user, the default is `mysql` - user: "mysql" + user: "root" # Update with database user password password: "my-cloudsql-pass" ``` @@ -175,7 +175,7 @@ export DB_INSTANCE="" Run retrieval service unit tests: ```bash -gcloud builds submit --config retrieval_service/cloudsql.tests.cloudbuild.yaml \ +gcloud builds submit --config retrieval_service/cloudsql-mysql.tests.cloudbuild.yaml \ --substitutions _DATABASE_NAME=$DB_NAME,_DATABASE_USER=$DB_USER,_CLOUDSQL_REGION=$DB_REGION,_CLOUDSQL_INSTANCE=$DB_INSTANCE ``` diff --git a/retrieval_service/datastore/providers/cloudsql_mysql.py b/retrieval_service/datastore/providers/cloudsql_mysql.py index 8579d8cf..24e80f93 100644 --- a/retrieval_service/datastore/providers/cloudsql_mysql.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql.py @@ -576,7 +576,29 @@ def validate_ticket_sync( departure_airport: str, departure_time: str, ) -> Optional[models.Flight]: - raise NotImplementedError("Not Implemented") + with self.__pool.connect() as conn: + s = text( + """ + SELECT * FROM flights + WHERE LOWER(airline) LIKE LOWER(:airline) + AND LOWER(flight_number) LIKE LOWER(:flight_number) + AND LOWER(departure_airport) LIKE LOWER(:departure_airport) + AND departure_time = CAST(:departure_time AS DATETIME) + LIMIT 10 + """ + ) + params = { + "airline": airline, + "flight_number": flight_number, + "departure_airport": departure_airport, + "departure_time": departure_time, + } + + result = (conn.execute(s, parameters=params)).mappings().fetchone() + if result is None: + return None + res = models.Flight.model_validate(result) + return res async def validate_ticket( self, @@ -589,6 +611,58 @@ async def validate_ticket( res = await loop.run_in_executor(None, self.validate_ticket_sync, airline, flight_number, 
departure_airport, departure_time) return res + def insert_ticket_sync( + self, + user_id: str, + user_name: str, + user_email: str, + airline: str, + flight_number: str, + departure_airport: str, + arrival_airport: str, + departure_time: str, + arrival_time: str, + ): + with self.__pool.connect() as conn: + s = text( + """ + INSERT INTO tickets ( + user_id, + user_name, + user_email, + airline, + flight_number, + departure_airport, + arrival_airport, + departure_time, + arrival_time + ) VALUES ( + :user_id, + :user_name, + :user_email, + :airline, + :flight_number, + :departure_airport, + :arrival_airport, + :departure_time, + :arrival_time + ); + """ + ) + params = { + "user_id": user_id, + "user_name": user_name, + "user_email": user_email, + "airline": airline, + "flight_number": flight_number, + "departure_airport": departure_airport, + "arrival_airport": arrival_airport, + "departure_time": departure_time, + "arrival_time": arrival_time, + } + conn.execute(s, params).mappings() + + async def insert_ticket( self, user_id: str, @@ -601,13 +675,36 @@ async def insert_ticket( departure_time: str, arrival_time: str, ): - raise NotImplementedError("Not Implemented") + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self.insert_ticket_sync, user_id, user_name, user_email, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time) + + def list_tickets_sync( + self, + user_id: str, + ) -> list[models.Ticket]: + with self.__pool.connect() as conn: + s = text( + """ + SELECT * FROM tickets + WHERE user_id = :user_id + """ + ) + params = { + "user_id": user_id, + } + + results = (conn.execute(s, parameters=params)).mappings().fetchall() + res = [models.Ticket.model_validate(r) for r in results] + return res + async def list_tickets( self, user_id: str, ) -> list[models.Ticket]: - raise NotImplementedError("Not Implemented") + loop = asyncio.get_running_loop() + res = await loop.run_in_executor(None, 
self.list_tickets_sync, user_id) + return res def policies_search_sync( self, query_embedding: list[float], similarity_threshold: float, top_k: int diff --git a/retrieval_service/datastore/providers/cloudsql_mysql_test.py b/retrieval_service/datastore/providers/cloudsql_mysql_test.py index 9fa08e1b..20523f84 100644 --- a/retrieval_service/datastore/providers/cloudsql_mysql_test.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql_test.py @@ -67,7 +67,7 @@ def db_instance() -> str: async def create_db( db_user: str, db_pass: str, db_project: str, db_region: str, db_instance: str ) -> AsyncGenerator[str, None]: - db_name = get_env_var("DB_NAME", "name of a postgres database") + db_name = get_env_var("DB_NAME", "name of a cloud sql mysql database") loop = asyncio.get_running_loop() connector = Connector(loop=loop) # Database does not exist, create it. @@ -607,6 +607,41 @@ async def test_search_flights_by_airports( assert res == expected +async def test_insert_ticket(ds: cloudsql_mysql.Client): + await ds.insert_ticket("1", "test", "test", "UA", "1532", "SFO", "DEN", "2024-01-01 05:50:00", "2024-01-01 09:23:00") + +async def test_list_tickets(ds: cloudsql_mysql.Client): + res = await ds.list_tickets("1") + expected = models.Ticket( + user_id=1, + user_name="test", + user_email="test", + airline="UA", + flight_number="1532", + departure_airport="SFO", + arrival_airport="DEN", + departure_time=datetime.strptime("2024-01-01 05:50:00", "%Y-%m-%d %H:%M:%S"), + arrival_time=datetime.strptime("2024-01-01 09:23:00", "%Y-%m-%d %H:%M:%S"), + ) + assert res == [expected] + +async def test_validate_ticket(ds: cloudsql_mysql.Client): + res = await ds.validate_ticket("UA", "1532", "SFO", "2024-01-01 05:50:00") + expected = models.Flight( + id=0, + airline="UA", + flight_number="1532", + departure_airport="SFO", + arrival_airport="DEN", + departure_time=datetime.strptime("2024-01-01 05:50:00", "%Y-%m-%d %H:%M:%S"), + arrival_time=datetime.strptime("2024-01-01 09:23:00", 
"%Y-%m-%d %H:%M:%S"), + departure_gate="E49", + arrival_gate="D6", + ) + assert res == expected + + + policies_search_test_data = [ pytest.param( # "What is the fee for extra baggage?" diff --git a/retrieval_service/example-config-cloudsql-mysql.yml b/retrieval_service/example-config-cloudsql-mysql.yml deleted file mode 100644 index c3649e70..00000000 --- a/retrieval_service/example-config-cloudsql-mysql.yml +++ /dev/null @@ -1,11 +0,0 @@ -host: 0.0.0.0 -datastore: - # Example for Cloud SQL - kind: "cloudsql-mysql" - project: "my-project" - region: "my-region" - instance: "my-instance" - database: "my_database" - user: "my-user" - password: "my-password" - \ No newline at end of file From dd73968f537298ad337d089fd9d70f01f2f023c8 Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Thu, 20 Jun 2024 20:47:28 +0000 Subject: [PATCH 11/15] Fix MySQL default user in Read me --- docs/datastore/cloudsql_mysql.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/datastore/cloudsql_mysql.md b/docs/datastore/cloudsql_mysql.md index 0f805a1e..2a54da54 100644 --- a/docs/datastore/cloudsql_mysql.md +++ b/docs/datastore/cloudsql_mysql.md @@ -105,7 +105,7 @@ datastore: instance: my-cloudsql-instance # Update this with the database name database: "assistantdemo" - # Update with database user, the default is `mysql` + # Update with database user, the default is `root` user: "root" # Update with database user password password: "my-cloudsql-pass" From 073affe45040877fc1e1c25ea960f36f6d9e3808 Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Mon, 24 Jun 2024 15:07:23 -0400 Subject: [PATCH 12/15] chore: newline at EOF --- retrieval_service/datastore/providers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/retrieval_service/datastore/providers/__init__.py b/retrieval_service/datastore/providers/__init__.py index 17b06c3d..dfed0d76 100644 --- a/retrieval_service/datastore/providers/__init__.py +++ 
b/retrieval_service/datastore/providers/__init__.py @@ -14,4 +14,4 @@ from . import alloydb, cloudsql_mysql, cloudsql_postgres, firestore, postgres, spanner_gsql -__ALL__ = [alloydb, postgres, cloudsql_mysql, cloudsql_postgres, firestore, spanner_gsql] \ No newline at end of file +__ALL__ = [alloydb, postgres, cloudsql_mysql, cloudsql_postgres, firestore, spanner_gsql] From 1594020eec637905c417675cfae8f1047ffc3332 Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Mon, 24 Jun 2024 21:23:01 +0000 Subject: [PATCH 13/15] Formatting with linter --- .../datastore/providers/__init__.py | 18 +- .../datastore/providers/cloudsql_mysql.py | 188 +++++++++++++----- .../providers/cloudsql_mysql_test.py | 58 +++--- 3 files changed, 185 insertions(+), 79 deletions(-) diff --git a/retrieval_service/datastore/providers/__init__.py b/retrieval_service/datastore/providers/__init__.py index dfed0d76..fbbcbb46 100644 --- a/retrieval_service/datastore/providers/__init__.py +++ b/retrieval_service/datastore/providers/__init__.py @@ -12,6 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import alloydb, cloudsql_mysql, cloudsql_postgres, firestore, postgres, spanner_gsql +from . 
import ( + alloydb, + cloudsql_mysql, + cloudsql_postgres, + firestore, + postgres, + spanner_gsql, +) -__ALL__ = [alloydb, postgres, cloudsql_mysql, cloudsql_postgres, firestore, spanner_gsql] +__ALL__ = [ + alloydb, + postgres, + cloudsql_mysql, + cloudsql_postgres, + firestore, + spanner_gsql, +] diff --git a/retrieval_service/datastore/providers/cloudsql_mysql.py b/retrieval_service/datastore/providers/cloudsql_mysql.py index 24e80f93..d5be4d6e 100644 --- a/retrieval_service/datastore/providers/cloudsql_mysql.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql.py @@ -16,18 +16,19 @@ from datetime import datetime from typing import Any, Literal, Optional +import pymysql from google.cloud.sql.connector import Connector from pydantic import BaseModel -from sqlalchemy import text, create_engine, Engine +from sqlalchemy import Engine, create_engine, text from sqlalchemy.engine.base import Engine -import pymysql import models from .. import datastore MYSQL_IDENTIFIER = "cloudsql-mysql" + class Config(BaseModel, datastore.AbstractConfig): kind: Literal["cloudsql-mysql"] project: str @@ -108,18 +109,22 @@ def initialize_data_sync( conn.execute( text( """INSERT INTO airports VALUES (:id, :iata, :name, :city, :country)""" - ),parameters=[{ + ), + parameters=[ + { "id": a.id, "iata": a.iata, "name": a.name, "city": a.city, "country": a.country, - } for a in airports] + } + for a in airports + ], ) # If the table already exists, drop it to avoid conflicts conn.execute(text("DROP TABLE IF EXISTS amenities CASCADE")) - + # Create a new table conn.execute( text( @@ -152,7 +157,7 @@ def initialize_data_sync( """ ) ) - + # Insert all the data conn.execute( text( @@ -164,7 +169,9 @@ def initialize_data_sync( :thursday_start_hour, :thursday_end_hour, :friday_start_hour, :friday_end_hour, :saturday_start_hour, :saturday_end_hour, :content, string_to_vector(:embedding)) """ - ),parameters=[{ + ), + parameters=[ + { "id": a.id, "name": a.name, "description": a.description, 
@@ -188,11 +195,17 @@ def initialize_data_sync( "saturday_end_hour": a.saturday_end_hour, "content": a.content, "embedding": f"{a.embedding}", - } for a in amenities] + } + for a in amenities + ], ) # Create a vector index for the embeddings column - conn.execute(text(f"CALL mysql.create_vector_index('amenities_index', '{self.__db_name}.amenities', 'embedding', '')")) + conn.execute( + text( + f"CALL mysql.create_vector_index('amenities_index', '{self.__db_name}.amenities', 'embedding', '')" + ) + ) # If the table already exists, drop it to avoid conflicts conn.execute(text("DROP TABLE IF EXISTS flights")) @@ -222,7 +235,9 @@ def initialize_data_sync( :departure_airport, :arrival_airport, :departure_time, :arrival_time, :departure_gate, :arrival_gate) """ - ),parameters=[{ + ), + parameters=[ + { "id": f.id, "airline": f.airline, "flight_number": f.flight_number, @@ -232,7 +247,9 @@ def initialize_data_sync( "arrival_time": f.arrival_time, "departure_gate": f.departure_gate, "arrival_gate": f.arrival_gate, - } for f in flights] + } + for f in flights + ], ) # If the table already exists, drop it to avoid conflicts @@ -276,14 +293,23 @@ def initialize_data_sync( """ INSERT INTO policies VALUES (:id, :content, string_to_vector(:embedding)) """ - ),parameters=[{ + ), + parameters=[ + { "id": p.id, "content": p.content, "embedding": f"{p.embedding}", - } for p in policies]) - + } + for p in policies + ], + ) + # Create a vector index on the embedding column - conn.execute(text(f"CALL mysql.create_vector_index('policies_index', '{self.__db_name}.policies', 'embedding', '')")) + conn.execute( + text( + f"CALL mysql.create_vector_index('policies_index', '{self.__db_name}.policies', 'embedding', '')" + ) + ) async def initialize_data( self, @@ -293,7 +319,9 @@ async def initialize_data( policies: list[models.Policy], ) -> None: loop = asyncio.get_running_loop() - await loop.run_in_executor(None, self.initialize_data_sync, airports, amenities, flights, policies) + await 
loop.run_in_executor( + None, self.initialize_data_sync, airports, amenities, flights, policies + ) def export_data_sync( self, @@ -304,8 +332,12 @@ def export_data_sync( list[models.Policy], ]: with self.__pool.connect() as conn: - airport_task = conn.execute(text("""SELECT * FROM airports ORDER BY id ASC""")) - amenity_task = conn.execute(text(""" + airport_task = conn.execute( + text("""SELECT * FROM airports ORDER BY id ASC""") + ) + amenity_task = conn.execute( + text( + """ SELECT id, name, description, @@ -330,9 +362,17 @@ def export_data_sync( content, vector_to_string(embedding) as embedding FROM amenities ORDER BY id ASC - """)) - flights_task = conn.execute(text("""SELECT * FROM flights ORDER BY id ASC""")) - policy_task = conn.execute(text("""SELECT id, content, vector_to_string(embedding) as embedding FROM policies ORDER BY id ASC""")) + """ + ) + ) + flights_task = conn.execute( + text("""SELECT * FROM flights ORDER BY id ASC""") + ) + policy_task = conn.execute( + text( + """SELECT id, content, vector_to_string(embedding) as embedding FROM policies ORDER BY id ASC""" + ) + ) airport_results = (airport_task).mappings().fetchall() amenity_results = (amenity_task).mappings().fetchall() @@ -345,7 +385,7 @@ def export_data_sync( policies = [models.Policy.model_validate(p) for p in policy_results] return airports, amenities, flights, policies - + async def export_data( self, ) -> tuple[ @@ -357,11 +397,11 @@ async def export_data( loop = asyncio.get_running_loop() res = await loop.run_in_executor(None, self.export_data_sync) return res - + def get_airport_by_id_sync(self, id: int) -> Optional[models.Airport]: with self.__pool.connect() as conn: s = text("""SELECT * FROM airports WHERE id=:id""") - params = {"id" : id} + params = {"id": id} result = (conn.execute(s, params)).mappings().fetchone() if result is None: @@ -369,7 +409,7 @@ def get_airport_by_id_sync(self, id: int) -> Optional[models.Airport]: res = models.Airport.model_validate(result) return 
res - + async def get_airport_by_id(self, id: int) -> Optional[models.Airport]: loop = asyncio.get_running_loop() res = await loop.run_in_executor(None, self.get_airport_by_id_sync, id) @@ -391,7 +431,7 @@ async def get_airport_by_iata(self, iata: str) -> Optional[models.Airport]: loop = asyncio.get_running_loop() res = await loop.run_in_executor(None, self.get_airport_by_iata_sync, iata) return res - + def search_airports_sync( self, country: Optional[str] = None, @@ -425,7 +465,9 @@ async def search_airports( name: Optional[str] = None, ) -> list[models.Airport]: loop = asyncio.get_running_loop() - res = await loop.run_in_executor(None, self.search_airports_sync, country, city, name) + res = await loop.run_in_executor( + None, self.search_airports_sync, country, city, name + ) return res def get_amenity_sync(self, id: int) -> Optional[models.Amenity]: @@ -436,7 +478,7 @@ def get_amenity_sync(self, id: int) -> Optional[models.Amenity]: FROM amenities WHERE id=:id """ ) - params = {"id" : id} + params = {"id": id} result = (conn.execute(s, parameters=params)).mappings().fetchone() if result is None: @@ -462,9 +504,9 @@ def amenities_search_sync( """ ) params = { - "query": f"{query_embedding}", - "search_options": f"num_neighbors={top_k}" - } + "query": f"{query_embedding}", + "search_options": f"num_neighbors={top_k}", + } results = (conn.execute(s, parameters=params)).mappings().fetchall() res = [r for r in results] @@ -474,9 +516,15 @@ async def amenities_search( self, query_embedding: list[float], similarity_threshold: float, top_k: int ) -> list[Any]: loop = asyncio.get_running_loop() - res = await loop.run_in_executor(None, self.amenities_search_sync, query_embedding, similarity_threshold, top_k) + res = await loop.run_in_executor( + None, + self.amenities_search_sync, + query_embedding, + similarity_threshold, + top_k, + ) return res - + def get_flight_sync(self, flight_id: int) -> Optional[models.Flight]: with self.__pool.connect() as conn: s = text( @@ 
-493,12 +541,12 @@ def get_flight_sync(self, flight_id: int) -> Optional[models.Flight]: res = models.Flight.model_validate(result) return res - + async def get_flight(self, flight_id: int) -> Optional[models.Flight]: loop = asyncio.get_running_loop() res = await loop.run_in_executor(None, self.get_flight_sync, flight_id) return res - + def search_flights_by_number_sync( self, airline: str, @@ -521,16 +569,18 @@ def search_flights_by_number_sync( res = [models.Flight.model_validate(r) for r in results] return res - + async def search_flights_by_number( self, airline: str, number: str, ) -> list[models.Flight]: loop = asyncio.get_running_loop() - res = await loop.run_in_executor(None, self.search_flights_by_number_sync, airline, number) + res = await loop.run_in_executor( + None, self.search_flights_by_number_sync, airline, number + ) return res - + def search_flights_by_airports_sync( self, date: str, @@ -558,7 +608,7 @@ def search_flights_by_airports_sync( res = [models.Flight.model_validate(r) for r in results] return res - + async def search_flights_by_airports( self, date: str, @@ -566,9 +616,15 @@ async def search_flights_by_airports( arrival_airport: Optional[str] = None, ) -> list[models.Flight]: loop = asyncio.get_running_loop() - res = await loop.run_in_executor(None, self.search_flights_by_airports_sync, date, departure_airport, arrival_airport) - return res - + res = await loop.run_in_executor( + None, + self.search_flights_by_airports_sync, + date, + departure_airport, + arrival_airport, + ) + return res + def validate_ticket_sync( self, airline: str, @@ -608,8 +664,15 @@ async def validate_ticket( departure_time: str, ) -> Optional[models.Flight]: loop = asyncio.get_running_loop() - res = await loop.run_in_executor(None, self.validate_ticket_sync, airline, flight_number, departure_airport, departure_time) - return res + res = await loop.run_in_executor( + None, + self.validate_ticket_sync, + airline, + flight_number, + departure_airport, + 
departure_time, + ) + return res def insert_ticket_sync( self, @@ -662,7 +725,6 @@ def insert_ticket_sync( } conn.execute(s, params).mappings() - async def insert_ticket( self, user_id: str, @@ -676,7 +738,19 @@ async def insert_ticket( arrival_time: str, ): loop = asyncio.get_running_loop() - await loop.run_in_executor(None, self.insert_ticket_sync, user_id, user_name, user_email, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time) + await loop.run_in_executor( + None, + self.insert_ticket_sync, + user_id, + user_name, + user_email, + airline, + flight_number, + departure_airport, + arrival_airport, + departure_time, + arrival_time, + ) def list_tickets_sync( self, @@ -697,14 +771,14 @@ def list_tickets_sync( res = [models.Ticket.model_validate(r) for r in results] return res - + async def list_tickets( self, user_id: str, ) -> list[models.Ticket]: loop = asyncio.get_running_loop() res = await loop.run_in_executor(None, self.list_tickets_sync, user_id) - return res + return res def policies_search_sync( self, query_embedding: list[float], similarity_threshold: float, top_k: int @@ -718,9 +792,9 @@ def policies_search_sync( """ ) params = { - "query": f"{query_embedding}", - "search_options": f"num_neighbors={top_k}" - } + "query": f"{query_embedding}", + "search_options": f"num_neighbors={top_k}", + } results = (conn.execute(s, parameters=params)).mappings().fetchall() @@ -731,9 +805,15 @@ async def policies_search( self, query_embedding: list[float], similarity_threshold: float, top_k: int ) -> list[str]: loop = asyncio.get_running_loop() - res = await loop.run_in_executor(None, self.policies_search_sync, query_embedding, similarity_threshold, top_k) - return res - + res = await loop.run_in_executor( + None, + self.policies_search_sync, + query_embedding, + similarity_threshold, + top_k, + ) + return res + async def close(self): # Vector indexes must be dropped before any DDLs on the base table are permitted with 
self.__pool.connect() as conn: diff --git a/retrieval_service/datastore/providers/cloudsql_mysql_test.py b/retrieval_service/datastore/providers/cloudsql_mysql_test.py index 20523f84..56a283d5 100644 --- a/retrieval_service/datastore/providers/cloudsql_mysql_test.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql_test.py @@ -71,27 +71,27 @@ async def create_db( loop = asyncio.get_running_loop() connector = Connector(loop=loop) # Database does not exist, create it. - sys_conn: pymysql.Connection = await connector.connect_async( - # Cloud SQL instance connection name - f"{db_project}:{db_region}:{db_instance}", - "pymysql", - user=f"{db_user}", - password=f"{db_pass}", - db="mysql", - ) + sys_conn: pymysql.Connection = await connector.connect_async( + # Cloud SQL instance connection name + f"{db_project}:{db_region}:{db_instance}", + "pymysql", + user=f"{db_user}", + password=f"{db_pass}", + db="mysql", + ) cursor = sys_conn.cursor() - - cursor.execute(f'DROP DATABASE IF EXISTS {db_name};') - cursor.execute(f'CREATE DATABASE {db_name};') + + cursor.execute(f"DROP DATABASE IF EXISTS {db_name};") + cursor.execute(f"CREATE DATABASE {db_name};") cursor.close() conn: pymysql.Connection = await connector.connect_async( - # Cloud SQL instance connection name - f"{db_project}:{db_region}:{db_instance}", - "pymysql", - user=f"{db_user}", - password=f"{db_pass}", - db=f"{db_name}", - ) + # Cloud SQL instance connection name + f"{db_project}:{db_region}:{db_instance}", + "pymysql", + user=f"{db_user}", + password=f"{db_pass}", + db=f"{db_name}", + ) yield db_name await conn.close() @@ -135,13 +135,14 @@ async def ds( yield ds await ds.close() + def only_embedding_changed(file_diff): return all( - key == "embedding" - for change in file_diff["changed"] + key == "embedding" + for change in file_diff["changed"] for key in change["changes"] ) - + def check_file_diff(file_diff): assert file_diff["added"] == [] @@ -608,7 +609,18 @@ async def 
test_search_flights_by_airports( async def test_insert_ticket(ds: cloudsql_mysql.Client): - await ds.insert_ticket("1", "test", "test", "UA", "1532", "SFO", "DEN", "2024-01-01 05:50:00", "2024-01-01 09:23:00") + await ds.insert_ticket( + "1", + "test", + "test", + "UA", + "1532", + "SFO", + "DEN", + "2024-01-01 05:50:00", + "2024-01-01 09:23:00", + ) + async def test_list_tickets(ds: cloudsql_mysql.Client): res = await ds.list_tickets("1") @@ -625,6 +637,7 @@ async def test_list_tickets(ds: cloudsql_mysql.Client): ) assert res == [expected] + async def test_validate_ticket(ds: cloudsql_mysql.Client): res = await ds.validate_ticket("UA", "1532", "SFO", "2024-01-01 05:50:00") expected = models.Flight( @@ -641,7 +654,6 @@ async def test_validate_ticket(ds: cloudsql_mysql.Client): assert res == expected - policies_search_test_data = [ pytest.param( # "What is the fee for extra baggage?" From 2c36cec5151f8a4a79067806f624b01592885365 Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Mon, 24 Jun 2024 22:14:02 +0000 Subject: [PATCH 14/15] Install pymysql types --- retrieval_service/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/retrieval_service/requirements.txt b/retrieval_service/requirements.txt index 4e8316c7..d39c6c50 100644 --- a/retrieval_service/requirements.txt +++ b/retrieval_service/requirements.txt @@ -18,3 +18,4 @@ langchain-google-vertexai==1.0.5 asyncio==3.4.3 datetime==5.5 pymysql==1.1.1 +types-PyMySQL==1.1.0 \ No newline at end of file From b9162ece5db89190bfbfb6e16349ceb7e39fa340 Mon Sep 17 00:00:00 2001 From: JULIA OFFERMAN Date: Mon, 24 Jun 2024 22:34:10 +0000 Subject: [PATCH 15/15] Remove handling return type in void connection.close() function --- retrieval_service/datastore/providers/cloudsql_mysql_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/retrieval_service/datastore/providers/cloudsql_mysql_test.py b/retrieval_service/datastore/providers/cloudsql_mysql_test.py index 56a283d5..4969790d 100644 --- 
a/retrieval_service/datastore/providers/cloudsql_mysql_test.py +++ b/retrieval_service/datastore/providers/cloudsql_mysql_test.py @@ -93,7 +93,7 @@ async def create_db( db=f"{db_name}", ) yield db_name - await conn.close() + conn.close() @pytest_asyncio.fixture(scope="module")