
Commit

fix: AnalyticdbVector retrieval scores (langgenius#8803)
lpdink authored and lau-td committed Oct 23, 2024
1 parent ae541b2 commit 5110ce7
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
@@ -40,19 +40,8 @@ def to_analyticdb_client_params(self):
 
 
 class AnalyticdbVector(BaseVector):
-    _instance = None
-    _init = False
-
-    def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance
-
     def __init__(self, collection_name: str, config: AnalyticdbConfig):
-        # collection_name must be updated every time
         self._collection_name = collection_name.lower()
-        if AnalyticdbVector._init:
-            return
         try:
             from alibabacloud_gpdb20160503.client import Client
             from alibabacloud_tea_openapi import models as open_api_models
@@ -62,7 +51,6 @@ def __init__(self, collection_name: str, config: AnalyticdbConfig):
         self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
         self._client = Client(self._client_config)
         self._initialize()
-        AnalyticdbVector._init = True
 
     def _initialize(self) -> None:
         cache_key = f"vector_indexing_{self.config.instance_id}"
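
For context on the deletions above: the removed __new__ and _init lines made AnalyticdbVector a process-wide singleton, so a second instantiation with a different collection name received the already-initialized shared object and its __init__ returned early, skipping client setup and _initialize() for the new collection. A minimal, self-contained sketch of that pitfall, using a hypothetical SingletonVector class rather than code from this repository:

class SingletonVector:
    _instance = None
    _init = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, collection_name: str):
        # The collection name is refreshed on every call...
        self._collection_name = collection_name.lower()
        if SingletonVector._init:
            return  # ...but the rest of the setup only ever runs once
        self._initialized_for = collection_name.lower()
        SingletonVector._init = True


a = SingletonVector("Collection_A")
b = SingletonVector("Collection_B")
print(a is b)               # True: both variables reference one shared object
print(b._collection_name)   # collection_b
print(b._initialized_for)   # collection_a -- setup never ran for the second collection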
@@ -257,11 +245,14 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         documents = []
         for match in response.body.matches.match:
             if match.score > score_threshold:
+                metadata = json.loads(match.metadata.get("metadata_"))
+                metadata["score"] = match.score
                 doc = Document(
                     page_content=match.metadata.get("page_content"),
-                    metadata=json.loads(match.metadata.get("metadata_")),
+                    metadata=metadata,
                 )
                 documents.append(doc)
+        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
         return documents
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -286,12 +277,14 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         for match in response.body.matches.match:
             if match.score > score_threshold:
                 metadata = json.loads(match.metadata.get("metadata_"))
+                metadata["score"] = match.score
                 doc = Document(
                     page_content=match.metadata.get("page_content"),
                     vector=match.metadata.get("vector"),
                     metadata=metadata,
                 )
                 documents.append(doc)
+        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
         return documents
 
     def delete(self) -> None:
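Net effect of the two search hunks: each returned Document now carries its relevance score in metadata["score"], and the result list is explicitly sorted best-first rather than relying on the response order. A short usage sketch; vector (an already-configured AnalyticdbVector) and query_embedding are assumed here and are not part of this diff:

# Hypothetical usage; assumes `vector` is an initialized AnalyticdbVector
# and `query_embedding` is a precomputed embedding for the query text.
docs = vector.search_by_vector(query_embedding)

for doc in docs:  # already sorted by metadata["score"], highest first
    print(f'{doc.metadata["score"]:.4f}  {doc.page_content[:80]}')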
