Skip to content

Commit

Permalink
extract _get_distance_func
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenliang123 committed Feb 11, 2025
1 parent edb25c5 commit 017f6dc
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _create_collection(self, dimension: int):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
tidb_dist_func = self._get_distance_func()
with Session(self._engine) as session:
session.begin()
create_statement = sql_text(f"""
Expand All @@ -108,7 +109,7 @@ def _create_collection(self, dimension: int):
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
KEY (doc_id),
VECTOR INDEX idx_vector ((VEC_COSINE_DISTANCE(vector))) USING HNSW
VECTOR INDEX idx_vector (({tidb_dist_func}(vector))) USING HNSW
);
""")
session.execute(create_statement)
Expand Down Expand Up @@ -194,13 +195,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
)

docs = []
match self._distance_func:
case "l2":
tidb_dist_func = "VEC_L2_DISTANCE"
case "cosine":
tidb_dist_func = "VEC_COSINE_DISTANCE"
case _:
tidb_dist_func = "VEC_COSINE_DISTANCE"
tidb_dist_func = self._get_distance_func()

with Session(self._engine) as session:
select_statement = sql_text(f"""
Expand Down Expand Up @@ -240,6 +235,16 @@ def delete(self) -> None:
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
session.commit()

def _get_distance_func(self) -> str:
    """Return the TiDB SQL distance function name for the configured metric.

    Reads ``self._distance_func`` and maps it as follows:

    * ``"l2"`` -> ``"VEC_L2_DISTANCE"``
    * anything else, including ``"cosine"`` -> ``"VEC_COSINE_DISTANCE"``

    The comparison is an exact, case-sensitive equality check.

    Returns:
        The name of the SQL function to interpolate into vector statements.
    """
    # Cosine is both an explicit option and the fallback, so only "l2"
    # needs a branch of its own.
    if self._distance_func == "l2":
        return "VEC_L2_DISTANCE"
    return "VEC_COSINE_DISTANCE"


class TiDBVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
Expand Down

0 comments on commit 017f6dc

Please sign in to comment.