Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Pydantic v2 update for Qdrant #40

Merged
merged 6 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions dbs/qdrant/api/config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from pydantic import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env",
extra="allow",
)

qdrant_service: str
qdrant_port: str
qdrant_host: str
qdrant_service: str
api_port = str
api_port: str
embedding_model_checkpoint: str
onnx_model_filename: str
tag: str

class Config:
env_file = ".env"
33 changes: 4 additions & 29 deletions dbs/qdrant/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,12 @@

from fastapi import FastAPI
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer

from api.config import Settings
from api.routers import rest

try:
from optimum.onnxruntime import ORTModelForCustomTasks
from optimum.pipelines import pipeline
from transformers import AutoTokenizer

model_type = "onnx"
except ModuleNotFoundError:
from sentence_transformers import SentenceTransformer

model_type = "sbert"
model_type = "sbert"


@lru_cache()
Expand All @@ -26,30 +18,13 @@ def get_settings():
return Settings()


def get_embedding_pipeline(onnx_path, model_filename: str):
"""
Create a sentence embedding pipeline using the optimized ONNX model, if available in the environment
"""
# Reload tokenizer
tokenizer = AutoTokenizer.from_pretrained(onnx_path)
optimized_model = ORTModelForCustomTasks.from_pretrained(onnx_path, file_name=model_filename)
embedding_pipeline = pipeline("feature-extraction", model=optimized_model, tokenizer=tokenizer)
return embedding_pipeline


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Async context manager for Qdrant database connection."""
settings = get_settings()
model_checkpoint = settings.embedding_model_checkpoint
if model_type == "sbert":
app.model = SentenceTransformer(model_checkpoint)
app.model_type = "sbert"
elif model_type == "onnx":
app.model = get_embedding_pipeline(
"onnx_model/onnx", model_filename=settings.onnx_model_filename
)
app.model_type = "onnx"
app.model = SentenceTransformer(model_checkpoint)
app.model_type = "sbert"
# Define Qdrant client
app.client = QdrantClient(host=settings.qdrant_service, port=settings.qdrant_port)
print("Successfully connected to Qdrant")
Expand Down
26 changes: 7 additions & 19 deletions dbs/qdrant/api/routers/rest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import APIRouter, HTTPException, Query, Request
from qdrant_client.http import models

from schemas.retriever import CountByCountry, SimilaritySearch
from api.schemas.rest import CountByCountry, SimilaritySearch

router = APIRouter()

Expand Down Expand Up @@ -124,14 +124,10 @@ def _search_by_similarity(
collection: str,
terms: str,
) -> list[SimilaritySearch] | None:
if request.app.model_type == "sbert":
vector = request.app.model.encode(terms, batch_size=64).tolist()
elif request.app.model_type == "onnx":
vector = request.app.model(terms, truncate=True)[0][0]

vector = request.app.model.encode(terms, batch_size=64).tolist()
# Use `vector` for similarity search on the closest vectors in the collection
search_result = request.app.client.search(
collection_name=collection, query_vector=vector, top=5
collection_name=collection, query_vector=vector, limit=5
)
# `search_result` contains found vector ids with similarity scores along with the stored payload
# For now we are interested in payload only
Expand All @@ -144,11 +140,7 @@ def _search_by_similarity(
def _search_by_similarity_and_country(
request: Request, collection: str, terms: str, country: str
) -> list[SimilaritySearch] | None:
if request.app.model_type == "sbert":
vector = request.app.model.encode(terms, batch_size=64).tolist()
elif request.app.model_type == "onnx":
vector = request.app.model(terms, truncate=True)[0][0]

vector = request.app.model.encode(terms, batch_size=64).tolist()
filter = models.Filter(
**{
"must": [
Expand All @@ -162,7 +154,7 @@ def _search_by_similarity_and_country(
}
)
search_result = request.app.client.search(
collection_name=collection, query_vector=vector, query_filter=filter, top=5
collection_name=collection, query_vector=vector, query_filter=filter, limit=5
)
payloads = [hit.payload for hit in search_result]
if not payloads:
Expand All @@ -178,11 +170,7 @@ def _search_by_similarity_and_filters(
points: int,
price: float,
) -> list[SimilaritySearch] | None:
if request.app.model_type == "sbert":
vector = request.app.model.encode(terms, batch_size=64).tolist()
elif request.app.model_type == "onnx":
vector = request.app.model(terms, truncate=True)[0][0]

vector = request.app.model.encode(terms, batch_size=64).tolist()
filter = models.Filter(
**{
"must": [
Expand All @@ -208,7 +196,7 @@ def _search_by_similarity_and_filters(
}
)
search_result = request.app.client.search(
collection_name=collection, query_vector=vector, query_filter=filter, top=5
collection_name=collection, query_vector=vector, query_filter=filter, limit=5
)
payloads = [hit.payload for hit in search_result]
if not payloads:
Expand Down
Empty file.
33 changes: 17 additions & 16 deletions dbs/qdrant/schemas/retriever.py → dbs/qdrant/api/schemas/rest.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


class SimilaritySearch(BaseModel):
id: int
country: str
province: str | None
title: str
description: str | None
points: int
price: float | str | None
variety: str | None
winery: str | None

class Config:
extra = "ignore"
schema_extra = {
model_config = ConfigDict(
extra="ignore",
json_schema_extra={
"example": {
"id": 3845,
"wineID": 3845,
"country": "Italy",
"title": "Castellinuzza e Piuca 2010 Chianti Classico",
"description": "This gorgeous Chianti Classico boasts lively cherry, strawberry and violet aromas. The mouthwatering palate shows concentrated wild-cherry flavor layered with mint, white pepper and clove. It has fresh acidity and firm tannins that will develop complexity with more bottle age. A textbook Chianti Classico.",
Expand All @@ -25,7 +15,18 @@ class Config:
"variety": "Red Blend",
"winery": "Castellinuzza e Piuca",
}
}
},
)

id: int
country: str
province: str | None
title: str
description: str | None
points: int
price: float | str | None
variety: str | None
winery: str | None


class CountByCountry(BaseModel):
Expand Down
15 changes: 9 additions & 6 deletions dbs/qdrant/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
qdrant-client>=1.1.4
transformers==4.28.1
sentence-transformers==2.2.2
pydantic[dotenv]>=1.10.7, <2.0.0
fastapi>=0.95.0, <1.0.0
qdrant-client~=1.3.0
transformers~=4.28.0
sentence-transformers~=2.2.0
pydantic~=2.0.0
pydantic-settings>=2.0.0
python-dotenv>=1.0.0
fastapi~=0.100.0
httpx>=0.24.0
aiohttp>=3.8.4
uvloop>=0.17.0
uvicorn>=0.21.0, <1.0.0
srsly>=2.4.6
srsly>=2.4.6
56 changes: 25 additions & 31 deletions dbs/qdrant/schemas/wine.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,13 @@
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator


class Wine(BaseModel):
id: int
points: int
title: str
description: str | None
price: float | None
variety: str | None
winery: str | None
vineyard: str | None
country: str | None
province: str | None
region_1: str | None
region_2: str | None
taster_name: str | None
taster_twitter_handle: str | None

class Config:
extra = "allow"
allow_population_by_field_name = True
validate_assignment = True
schema_extra = {
model_config = ConfigDict(
populate_by_name=True,
validate_assignment=True,
extra="allow",
str_strip_whitespace=True,
json_schema_extra={
"example": {
"id": 45100,
"points": 85,
Expand All @@ -38,25 +24,33 @@ class Config:
"taster_name": "Michael Schachner",
"taster_twitter_handle": "@wineschach",
}
}
},
)

@root_validator(pre=True)
def _get_vineyard(cls, values):
"Rename designation to vineyard"
vineyard = values.pop("designation", None)
if vineyard:
values["vineyard"] = vineyard.strip()
return values
id: int
points: int
title: str
description: str | None
price: float | None
variety: str | None
winery: str | None
vineyard: str | None = Field(..., alias="designation")
country: str | None
province: str | None
region_1: str | None
region_2: str | None
taster_name: str | None
taster_twitter_handle: str | None

@root_validator
@model_validator(mode="before")
def _fill_country_unknowns(cls, values):
"Fill in missing country values with 'Unknown', as we always want this field to be queryable"
country = values.get("country")
if not country:
values["country"] = "Unknown"
return values

@root_validator
@model_validator(mode="before")
def _add_to_vectorize_fields(cls, values):
"Add a field to_vectorize that will be used to create sentence embeddings"
variety = values.get("variety", "")
Expand Down
9 changes: 4 additions & 5 deletions dbs/qdrant/scripts/bulk_index_sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

import srsly
from dotenv import load_dotenv
from pydantic.main import ModelMetaclass
from qdrant_client import QdrantClient
from qdrant_client.http import models

sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1]))
from sentence_transformers import SentenceTransformer

from api.config import Settings
from schemas.wine import Wine
from sentence_transformers import SentenceTransformer

load_dotenv()
# Custom types
Expand Down Expand Up @@ -56,10 +56,9 @@ def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]:

def validate(
data: list[JsonBlob],
model: ModelMetaclass,
exclude_none: bool = False,
) -> list[JsonBlob]:
validated_data = [model(**item).dict(exclude_none=exclude_none) for item in data]
validated_data = [Wine(**item).model_dump(exclude_none=exclude_none) for item in data]
return validated_data


Expand Down Expand Up @@ -92,7 +91,7 @@ def add_vectors_to_index(data_chunk: tuple[JsonBlob, ...]) -> None:
settings = get_settings()
collection = "wines"
client = QdrantClient(host=settings.qdrant_host, port=settings.qdrant_port, timeout=None)
data = validate(data_chunk, Wine, exclude_none=True)
data = validate(data_chunk, exclude_none=True)

# Load a sentence transformer model for semantic similarity from a specified checkpoint
model_id = get_settings().embedding_model_checkpoint
Expand Down
32 changes: 4 additions & 28 deletions dbs/weaviate/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,9 @@
from api.config import Settings
from api.routers import rest

try:
from optimum.onnxruntime import ORTModelForCustomTasks
from optimum.pipelines import pipeline
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

model_type = "onnx"
except ModuleNotFoundError:
from sentence_transformers import SentenceTransformer

model_type = "sbert"
model_type = "sbert"


@lru_cache()
Expand All @@ -26,30 +19,13 @@ def get_settings():
return Settings()


def get_embedding_pipeline(onnx_path, model_filename: str):
"""
Create a sentence embedding pipeline using the optimized ONNX model, if available in the environment
"""
# Reload tokenizer
tokenizer = AutoTokenizer.from_pretrained(onnx_path)
optimized_model = ORTModelForCustomTasks.from_pretrained(onnx_path, file_name=model_filename)
embedding_pipeline = pipeline("feature-extraction", model=optimized_model, tokenizer=tokenizer)
return embedding_pipeline


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Async context manager for Weaviate database connection."""
settings = get_settings()
model_checkpoint = settings.embedding_model_checkpoint
if model_type == "sbert":
app.model = SentenceTransformer(model_checkpoint)
app.model_type = "sbert"
elif model_type == "onnx":
app.model = get_embedding_pipeline(
"onnx_model/onnx", model_filename=settings.onnx_model_filename
)
app.model_type = "onnx"
app.model = SentenceTransformer(model_checkpoint)
app.model_type = "sbert"
# Create Weaviate client
HOST = settings.weaviate_service
PORT = settings.weaviate_port
Expand Down
Loading