diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py
index 9854cfff7..d785c2445 100644
--- a/comps/cores/proto/docarray.py
+++ b/comps/cores/proto/docarray.py
@@ -58,6 +58,7 @@ class EmbedDoc1024(BaseDoc):
 class SearchedDoc(BaseDoc):
     retrieved_docs: DocList[TextDoc]
     initial_query: str
+    top_n: int = 1
 
     class Config:
         json_encoders = {np.ndarray: lambda x: x.tolist()}
diff --git a/comps/reranks/README.md b/comps/reranks/README.md
index ac3ab3f78..ecec38272 100644
--- a/comps/reranks/README.md
+++ b/comps/reranks/README.md
@@ -100,3 +100,12 @@ curl http://localhost:8000/v1/reranking \
   -d '{"initial_query":"What is Deep Learning?", "retrieved_docs": [{"text":"Deep Learning is not..."}, {"text":"Deep learning is..."}]}' \
   -H 'Content-Type: application/json'
 ```
+
+You can add the parameter `top_n` to specify how many top-ranked documents the reranker returns; the default value is 1.
+
+```bash
+curl http://localhost:8000/v1/reranking \
+  -X POST \
+  -d '{"initial_query":"What is Deep Learning?", "retrieved_docs": [{"text":"Deep Learning is not..."}, {"text":"Deep learning is..."}], "top_n":2}' \
+  -H 'Content-Type: application/json'
+```
diff --git a/comps/reranks/langchain/reranking_tei_xeon.py b/comps/reranks/langchain/reranking_tei_xeon.py
index 0bfe88fb4..40bd5a6a8 100644
--- a/comps/reranks/langchain/reranking_tei_xeon.py
+++ b/comps/reranks/langchain/reranking_tei_xeon.py
@@ -1,6 +1,7 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+import heapq
 import json
 import os
 import re
@@ -40,9 +41,11 @@ def reranking(input: SearchedDoc) -> LLMParamsDoc:
     headers = {"Content-Type": "application/json"}
     response = requests.post(url, data=json.dumps(data), headers=headers)
     response_data = response.json()
-    best_response = max(response_data, key=lambda response: response["score"])
-    doc = input.retrieved_docs[best_response["index"]]
-    if doc.text and len(re.findall("[\u4E00-\u9FFF]", doc.text)) / len(doc.text) >= 0.3:
+    best_response_list = heapq.nlargest(input.top_n, response_data, key=lambda x: x["score"])
+    context_str = ""
+    for best_response in best_response_list:
+        context_str = context_str + " " + input.retrieved_docs[best_response["index"]].text
+    if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3:
         # chinese context
         template = "仅基于以下背景回答问题:\n{context}\n问题: {question}"
     else:
@@ -51,7 +54,7 @@ def reranking(input: SearchedDoc) -> LLMParamsDoc:
     Question: {question}
     """
     prompt = ChatPromptTemplate.from_template(template)
-    final_prompt = prompt.format(context=doc.text, question=input.initial_query)
+    final_prompt = prompt.format(context=context_str, question=input.initial_query)
     statistics_dict["opea_service@reranking_tgi_gaudi"].append_latency(time.time() - start, None)
     return LLMParamsDoc(query=final_prompt.strip())
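
For reference, the selection logic this patch introduces can be exercised in isolation. Below is a minimal standalone sketch, not part of the patch: the `retrieved_docs` and `response_data` values are fabricated sample data shaped like the `{"index": ..., "score": ...}` records the reranking service receives, and `top_n = 2` mirrors the curl example in the README change.

```python
import heapq

# Fabricated sample data (hypothetical), shaped like the response_data
# that reranking_tei_xeon.py parses: one {"index", "score"} record per doc.
retrieved_docs = [
    "Deep Learning is not a single algorithm.",
    "Deep learning is a subset of machine learning.",
    "An unrelated sentence about databases.",
]
response_data = [
    {"index": 0, "score": 0.42},
    {"index": 1, "score": 0.97},
    {"index": 2, "score": 0.05},
]

top_n = 2  # mirrors the new SearchedDoc.top_n field; default is 1

# heapq.nlargest returns the top_n highest-scoring records in descending
# score order, in O(len(response_data) * log(top_n)) time.
best_response_list = heapq.nlargest(top_n, response_data, key=lambda x: x["score"])

# Concatenate the selected documents into a single context string,
# exactly as the patched reranking() does.
context_str = ""
for best_response in best_response_list:
    context_str = context_str + " " + retrieved_docs[best_response["index"]]

print(context_str.strip())
# Deep learning is a subset of machine learning. Deep Learning is not a single algorithm.
```

A consequence of this design is that the concatenated contexts arrive in descending rerank score order, so the strongest match leads the prompt.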
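
The template choice then hinges on the share of CJK characters in the concatenated context. Pulled out as a standalone helper for readability (a sketch; `is_mostly_chinese` is a hypothetical name, while the 0.3 threshold and character range come from the patched condition):

```python
import re

def is_mostly_chinese(text: str, threshold: float = 0.3) -> bool:
    # Count CJK Unified Ideographs (U+4E00..U+9FFF) and compare their share
    # of all characters against the threshold; empty text never qualifies.
    if not text:
        return False
    return len(re.findall("[\u4E00-\u9FFF]", text)) / len(text) >= threshold

print(is_mostly_chinese("深度学习是什么"))   # True -> Chinese prompt template
print(is_mostly_chinese("What is DL?"))      # False -> English prompt template
```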