diff --git a/extension_service/datastore/datastore.py b/extension_service/datastore/datastore.py index 47fb1a1f3..f6129e746 100644 --- a/extension_service/datastore/datastore.py +++ b/extension_service/datastore/datastore.py @@ -46,12 +46,17 @@ async def create(cls, config: C) -> "Client": @abstractmethod async def initialize_data( - self, flights: List[models.Flight], toys: List[models.Toy], embeddings: List[models.Embedding] + self, + flights: List[models.Flight], + toys: List[models.Toy], + embeddings: List[models.Embedding], ) -> None: pass @abstractmethod - async def export_data(self) -> Tuple[List[models.Flight], List[models.Toy], List[models.Embedding]]: + async def export_data( + self, + ) -> Tuple[List[models.Flight], List[models.Toy], List[models.Embedding]]: pass @abstractmethod diff --git a/extension_service/datastore/providers/postgres.py b/extension_service/datastore/providers/postgres.py index 71832d5cd..4f4af74fc 100644 --- a/extension_service/datastore/providers/postgres.py +++ b/extension_service/datastore/providers/postgres.py @@ -65,12 +65,15 @@ async def init(conn): return cls(pool) async def initialize_data( - self, flights: List[models.Flight], toys: List[models.Toy], embeddings: List[models.Embedding] + self, + flights: List[models.Flight], + toys: List[models.Toy], + embeddings: List[models.Embedding], ) -> None: async with self.__pool.acquire() as conn: - # If the table already exists, drop it to avoid conflicts + # If the table already exists, drop it to avoid conflicts await conn.execute("DROP TABLE IF EXISTS flights CASCADE") - # Create a new table + # Create a new table await conn.execute( """ CREATE TABLE flights( @@ -86,12 +89,23 @@ async def initialize_data( date DATE ) """ - ) - # Insert all the data + ) + # Insert all the data await conn.executemany( """INSERT INTO flights VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)""", [ - (f.id, f.airline, f.flight_number, f.origin_airport, f.destination_airport, f.departure_time, f.arrival_time, f.departure_gate, f.arrival_gate, datetime.datetime.strptime(f.date, '%Y-%m-%d').date()) + ( + f.id, + f.airline, + f.flight_number, + f.origin_airport, + f.destination_airport, + f.departure_time, + f.arrival_time, + f.departure_gate, + f.arrival_gate, + datetime.datetime.strptime(f.date, "%Y-%m-%d").date(), + ) for f in flights ], ) @@ -136,13 +150,20 @@ async def initialize_data( [(e.product_id, e.content, e.embedding) for e in embeddings], ) - async def export_data(self) -> Tuple[List[models.Flight], List[models.Toy], List[models.Embedding]]: - flights_task = asyncio.create_task(self.__pool.fetch("""SELECT * FROM flights""")) + async def export_data( + self, + ) -> Tuple[List[models.Flight], List[models.Toy], List[models.Embedding]]: + flights_task = asyncio.create_task( + self.__pool.fetch("""SELECT * FROM flights""") + ) toy_task = asyncio.create_task(self.__pool.fetch("""SELECT * FROM products""")) emb_task = asyncio.create_task( self.__pool.fetch("""SELECT * FROM product_embeddings""") ) - flights = [models.Flight.model_validate(dict(f, date=f["date"].isoformat())) for f in await flights_task] + flights = [ + models.Flight.model_validate(dict(f, date=f["date"].isoformat())) + for f in await flights_task + ] toys = [models.Toy.model_validate(dict(t)) for t in await toy_task] embeddings = [models.Embedding.model_validate(dict(v)) for v in await emb_task] diff --git a/extension_service/models/models.py b/extension_service/models/models.py index 45754de66..eee252424 100644 --- a/extension_service/models/models.py +++ b/extension_service/models/models.py @@ -19,6 +19,7 @@ from numpy import float32 from pydantic import BaseModel, ConfigDict, FieldValidationInfo, field_validator + class Flight(BaseModel): id: str airline: str @@ -31,6 +32,7 @@ class Flight(BaseModel): arrival_gate: str date: str + class Toy(BaseModel): product_id: str product_name: str diff --git a/extension_service/run_database_export.py b/extension_service/run_database_export.py index d55956d88..431d1f553 100644 --- a/extension_service/run_database_export.py +++ b/extension_service/run_database_export.py @@ -32,7 +32,18 @@ async def main(): await ds.close() with open("../data/flights_dataset.csv.new", "w") as f: - col_names = ['id', 'airline', 'flight_number', 'origin_airport', 'destination_airport', 'departure_time', 'arrival_time', 'departure_gate', 'arrival_gate', 'date'] + col_names = [ + "id", + "airline", + "flight_number", + "origin_airport", + "destination_airport", + "departure_time", + "arrival_time", + "departure_gate", + "arrival_gate", + "date", + ] writer = csv.DictWriter(f, col_names, delimiter=",") writer.writeheader() for t in flights: