From c57012b6e8178d1ed24b71e08e1272393fbdb81f Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Fri, 12 Jul 2024 12:30:42 +0800 Subject: [PATCH] fix bugs of rerank model with xinference --- api/apps/llm_app.py | 11 +++++++++++ rag/llm/rerank_model.py | 16 +++++++++------- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 678f74ab722..638db1851c6 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -165,6 +165,17 @@ def add_llm(): except Exception as e: msg += f"\nFail to access model({llm['llm_name']})." + str( e) + elif llm["model_type"] == LLMType.RERANK: + mdl = RerankModel[factory]( + key=None, model_name=llm["llm_name"], base_url=llm["api_base"] + ) + try: + arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!"]) + if len(arr) == 0 or tc == 0: + raise Exception("Not known.") + except Exception as e: + msg += f"\nFail to access model({llm['llm_name']})." + str( + e) else: # TODO: check other type of models pass diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index fabf11ec52e..53f87a5e548 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -136,10 +136,11 @@ def similarity(self, query: str, texts: list): else: res.extend(scores) return np.array(res), token_count + class XInferenceRerank(Base): - def __init__(self,model_name="",base_url=""): - self.model_name=model_name - self.base_url=base_url + def __init__(self, key="xxxxxxx", model_name="", base_url=""): + self.model_name = model_name + self.base_url = base_url self.headers = { "Content-Type": "application/json", "accept": "application/json" @@ -147,11 +148,12 @@ def __init__(self,model_name="",base_url=""): def similarity(self, query: str, texts: list): data = { - "model":self.model_name, - "query":query, + "model": self.model_name, + "query": query, "return_documents": "true", "return_len": "true", - "documents":texts + "documents": texts } res = requests.post(self.base_url, headers=self.headers, json=data).json() - return np.array([d["relevance_score"] for d in res["results"]]),res["tokens"]["input_tokens"]+res["tokens"]["output_tokens"] + return np.array([d["relevance_score"] for d in res["results"]]), res["tokens"]["input_tokens"] + res["tokens"][ + "output_tokens"]