Skip to content

Commit

Permalink
Ran black formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
duwenxin99 committed Oct 12, 2023
1 parent 6683fa3 commit 2e99e54
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 12 deletions.
9 changes: 7 additions & 2 deletions extension_service/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,17 @@ async def create(cls, config: C) -> "Client":

@abstractmethod
async def initialize_data(
self, flights: List[models.Flight], toys: List[models.Toy], embeddings: List[models.Embedding]
self,
flights: List[models.Flight],
toys: List[models.Toy],
embeddings: List[models.Embedding],
) -> None:
pass

@abstractmethod
async def export_data(self) -> Tuple[List[models.Flight], List[models.Toy], List[models.Embedding]]:
async def export_data(
self,
) -> Tuple[List[models.Flight], List[models.Toy], List[models.Embedding]]:
pass

@abstractmethod
Expand Down
39 changes: 30 additions & 9 deletions extension_service/datastore/providers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,15 @@ async def init(conn):
return cls(pool)

async def initialize_data(
self, flights: List[models.Flight], toys: List[models.Toy], embeddings: List[models.Embedding]
self,
flights: List[models.Flight],
toys: List[models.Toy],
embeddings: List[models.Embedding],
) -> None:
async with self.__pool.acquire() as conn:
# If the table already exists, drop it to avoid conflicts
# If the table already exists, drop it to avoid conflicts
await conn.execute("DROP TABLE IF EXISTS flights CASCADE")
# Create a new table
# Create a new table
await conn.execute(
"""
CREATE TABLE flights(
Expand All @@ -86,12 +89,23 @@ async def initialize_data(
date DATE
)
"""
)
# Insert all the data
)
# Insert all the data
await conn.executemany(
"""INSERT INTO flights VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)""",
[
(f.id, f.airline, f.flight_number, f.origin_airport, f.destination_airport, f.departure_time, f.arrival_time, f.departure_gate, f.arrival_gate, datetime.datetime.strptime(f.date, '%Y-%m-%d').date())
(
f.id,
f.airline,
f.flight_number,
f.origin_airport,
f.destination_airport,
f.departure_time,
f.arrival_time,
f.departure_gate,
f.arrival_gate,
datetime.datetime.strptime(f.date, "%Y-%m-%d").date(),
)
for f in flights
],
)
Expand Down Expand Up @@ -136,13 +150,20 @@ async def initialize_data(
[(e.product_id, e.content, e.embedding) for e in embeddings],
)

async def export_data(self) -> Tuple[List[models.Flight], List[models.Toy], List[models.Embedding]]:
flights_task = asyncio.create_task(self.__pool.fetch("""SELECT * FROM flights"""))
async def export_data(
self,
) -> Tuple[List[models.Flight], List[models.Toy], List[models.Embedding]]:
flights_task = asyncio.create_task(
self.__pool.fetch("""SELECT * FROM flights""")
)
toy_task = asyncio.create_task(self.__pool.fetch("""SELECT * FROM products"""))
emb_task = asyncio.create_task(
self.__pool.fetch("""SELECT * FROM product_embeddings""")
)
flights = [models.Flight.model_validate(dict(f, date=f["date"].isoformat())) for f in await flights_task]
flights = [
models.Flight.model_validate(dict(f, date=f["date"].isoformat()))
for f in await flights_task
]
toys = [models.Toy.model_validate(dict(t)) for t in await toy_task]
embeddings = [models.Embedding.model_validate(dict(v)) for v in await emb_task]

Expand Down
2 changes: 2 additions & 0 deletions extension_service/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from numpy import float32
from pydantic import BaseModel, ConfigDict, FieldValidationInfo, field_validator


class Flight(BaseModel):
id: str
airline: str
Expand All @@ -31,6 +32,7 @@ class Flight(BaseModel):
arrival_gate: str
date: str


class Toy(BaseModel):
product_id: str
product_name: str
Expand Down
13 changes: 12 additions & 1 deletion extension_service/run_database_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,18 @@ async def main():
await ds.close()

with open("../data/flights_dataset.csv.new", "w") as f:
col_names = ['id', 'airline', 'flight_number', 'origin_airport', 'destination_airport', 'departure_time', 'arrival_time', 'departure_gate', 'arrival_gate', 'date']
col_names = [
"id",
"airline",
"flight_number",
"origin_airport",
"destination_airport",
"departure_time",
"arrival_time",
"departure_gate",
"arrival_gate",
"date",
]
writer = csv.DictWriter(f, col_names, delimiter=",")
writer.writeheader()
for t in flights:
Expand Down

0 comments on commit 2e99e54

Please sign in to comment.