Fix: rerank_model and pdf_parser bugs | Update: session API #2601

Merged: 4 commits, Sep 26, 2024
13 changes: 8 additions & 5 deletions api/apps/sdk/session.py
@@ -87,9 +87,9 @@ def completion(tenant_id):
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
# {"role": "user", "content": "上海有吗?"}
# ]}
if "id" not in req:
return get_data_error_result(retmsg="id is required")
conv = ConversationService.query(id=req["id"])
if "session_id" not in req:
return get_data_error_result(retmsg="session_id is required")
conv = ConversationService.query(id=req["session_id"])
if not conv:
return get_data_error_result(retmsg="Session does not exist")
conv = conv[0]
@@ -108,7 +108,7 @@ def completion(tenant_id):
msg.append(m)
message_id = msg[-1].get("id")
e, dia = DialogService.get_by_id(conv.dialog_id)
del req["id"]
del req["session_id"]

if not conv.reference:
conv.reference = []
@@ -168,6 +168,9 @@ def get(tenant_id):
return get_data_error_result(retmsg="Session does not exist")
if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result(retmsg="You do not own the session")
if "assistant_id" in req:
if req["assistant_id"] != conv[0].dialog_id:
return get_data_error_result(retmsg="The session doesn't belong to the assistant")
conv = conv[0].to_dict()
conv['messages'] = conv.pop("message")
conv["assistant_id"] = conv.pop("dialog_id")
@@ -207,7 +210,7 @@ def list(tenant_id):
assistant_id = request.args["assistant_id"]
if not DialogService.query(tenant_id=tenant_id, id=assistant_id, status=StatusEnum.VALID.value):
return get_json_result(
- data=False, retmsg=f'Only owner of the assistant is authorized for this operation.',
+ data=False, retmsg=f"You don't own the assistant.",
retcode=RetCode.OPERATING_ERROR)
convs = ConversationService.query(
dialog_id=assistant_id,
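Taken together, the server-side changes rename the completion payload key from "id" to "session_id" and let get verify that a session actually belongs to the given assistant. A minimal sketch of calling the updated endpoints with requests; the /session/completion and /session/get paths mirror the SDK calls further down, while the base URL and bearer-token header are assumptions about the deployment:

import requests

BASE_URL = "http://localhost:9380/api/v1"  # assumption: deployment-specific prefix
HEADERS = {"Authorization": "Bearer <API_KEY>"}  # assumption: token auth

# completion: "session_id" is now the required key (formerly "id").
res = requests.post(
    f"{BASE_URL}/session/completion",
    headers=HEADERS,
    json={"session_id": "<session_id>", "question": "What is AI", "stream": False},
)
print(res.json())

# get: sending assistant_id lets the server reject a session that is
# owned by a different assistant.
res = requests.get(
    f"{BASE_URL}/session/get",
    headers=HEADERS,
    params={"id": "<session_id>", "assistant_id": "<assistant_id>"},
)
print(res.json())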
2 changes: 1 addition & 1 deletion deepdoc/parser/pdf_parser.py
@@ -488,7 +488,7 @@ def dfs(up, dp):
i += 1
continue

if not down["text"].strip():
if not down["text"].strip() or not up["text"].strip():
i += 1
continue

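The parser fix makes the blank-text guard in dfs symmetric: a candidate pair of boxes is now skipped when either the upper or the lower box has no text, not only the lower one. A toy illustration of the guard with hypothetical box dicts (the real dfs carries much more layout state):

def should_skip(up: dict, down: dict) -> bool:
    # Before the fix only down["text"] was tested, so a whitespace-only
    # upper box could still be merged into its neighbour.
    return not down["text"].strip() or not up["text"].strip()

print(should_skip({"text": "Header"}, {"text": "   "}))    # True: blank lower box
print(should_skip({"text": "   "}, {"text": "Body"}))      # True: blank upper box (the fix)
print(should_skip({"text": "Header"}, {"text": "Body"}))   # False: a real merge candidate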
76 changes: 48 additions & 28 deletions rag/llm/rerank_model.py
@@ -26,9 +26,11 @@
from rag.utils import num_tokens_from_string, truncate
import json

+
def sigmoid(x):
return 1 / (1 + np.exp(-x))

+
class Base(ABC):
def __init__(self, key, model_name):
pass
@@ -59,16 +61,19 @@ def __init__(self, key, model_name, **kwargs):
with DefaultRerank._model_lock:
if not DefaultRerank._model:
try:
- DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), use_fp16=torch.cuda.is_available())
+ DefaultRerank._model = FlagReranker(
+ os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
+ use_fp16=torch.cuda.is_available())
except Exception as e:
- model_dir = snapshot_download(repo_id= model_name,
- local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
+ model_dir = snapshot_download(repo_id=model_name,
+ local_dir=os.path.join(get_home_cache_dir(),
+ re.sub(r"^[a-zA-Z]+/", "", model_name)),
local_dir_use_symlinks=False)
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
self._model = DefaultRerank._model

def similarity(self, query: str, texts: list):
- pairs = [(query,truncate(t, 2048)) for t in texts]
+ pairs = [(query, truncate(t, 2048)) for t in texts]
token_count = 0
for _, t in pairs:
token_count += num_tokens_from_string(t)
@@ -77,8 +82,10 @@ def similarity(self, query: str, texts: list):
for i in range(0, len(pairs), batch_size):
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
scores = sigmoid(np.array(scores)).tolist()
- if isinstance(scores, float): res.append(scores)
- else: res.extend(scores)
+ if isinstance(scores, float):
+ res.append(scores)
+ else:
+ res.extend(scores)
return np.array(res), token_count
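The scoring loop above is a pattern worth noting: raw cross-encoder logits are squashed into [0, 1] by the module-level sigmoid, and compute_score may return a bare float for a single pair or a list for a batch, hence the isinstance branch. A self-contained sketch of the same pattern with a stubbed scorer:

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def score_in_batches(pairs, compute_score, batch_size=4096):
    # compute_score stands in for FlagReranker.compute_score and, like it,
    # may return a single float or a list of floats.
    res = []
    for i in range(0, len(pairs), batch_size):
        scores = sigmoid(np.array(compute_score(pairs[i:i + batch_size]))).tolist()
        if isinstance(scores, float):
            res.append(scores)
        else:
            res.extend(scores)
    return np.array(res)

stub = lambda batch: [0.5 * len(t) - len(q) for q, t in batch]  # fake logits
print(score_in_batches([("q", "short doc"), ("q", "a longer document")], stub))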


@@ -101,7 +108,10 @@ def similarity(self, query: str, texts: list):
"top_n": len(texts)
}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
rank = np.zeros(len(texts), dtype=float)
for d in res["results"]:
rank[d["index"]] = d["relevance_score"]
return rank, res["usage"]["total_tokens"]


class YoudaoRerank(DefaultRerank):
@@ -124,7 +134,7 @@ def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **k
"maidalun1020", "InfiniFlow"))

self._model = YoudaoRerank._model

def similarity(self, query: str, texts: list):
pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
token_count = 0
@@ -135,8 +145,10 @@ def similarity(self, query: str, texts: list):
for i in range(0, len(pairs), batch_size):
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
scores = sigmoid(np.array(scores)).tolist()
- if isinstance(scores, float): res.append(scores)
- else: res.extend(scores)
+ if isinstance(scores, float):
+ res.append(scores)
+ else:
+ res.extend(scores)
return np.array(res), token_count


@@ -162,7 +174,10 @@ def similarity(self, query: str, texts: list):
"documents": texts
}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
return np.array([d["relevance_score"] for d in res["results"]]), res["meta"]["tokens"]["input_tokens"]+res["meta"]["tokens"]["output_tokens"]
rank = np.zeros(len(texts), dtype=float)
for d in res["results"]:
rank[d["index"]] = d["relevance_score"]
return rank, res["meta"]["tokens"]["input_tokens"] + res["meta"]["tokens"]["output_tokens"]


class LocalAIRerank(Base):
@@ -175,7 +190,7 @@ def similarity(self, query: str, texts: list):

class NvidiaRerank(Base):
def __init__(
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
):
if not base_url:
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
@@ -208,9 +223,10 @@ def similarity(self, query: str, texts: list):
"top_n": len(texts),
}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
rank = np.array([d["logit"] for d in res["rankings"]])
indexs = [d["index"] for d in res["rankings"]]
return rank[indexs], token_count
rank = np.zeros(len(texts), dtype=float)
for d in res["rankings"]:
rank[d["index"]] = d["logit"]
return rank, token_count


class LmStudioRerank(Base):
@@ -247,9 +263,10 @@ def similarity(self, query: str, texts: list):
top_n=len(texts),
return_documents=False,
)
- rank = np.array([d.relevance_score for d in res.results])
- indexs = [d.index for d in res.results]
- return rank[indexs], token_count
+ rank = np.zeros(len(texts), dtype=float)
+ for d in res.results:
+ rank[d.index] = d.relevance_score
+ return rank, token_count


class TogetherAIRerank(Base):
@@ -262,7 +279,7 @@ def similarity(self, query: str, texts: list):

class SILICONFLOWRerank(Base):
def __init__(
self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
):
if not base_url:
base_url = "https://api.siliconflow.cn/v1/rerank"
@@ -287,10 +304,11 @@ def similarity(self, query: str, texts: list):
response = requests.post(
self.base_url, json=payload, headers=self.headers
).json()
rank = np.array([d["relevance_score"] for d in response["results"]])
indexs = [d["index"] for d in response["results"]]
rank = np.zeros(len(texts), dtype=float)
for d in response["results"]:
rank[d["index"]] = d["relevance_score"]
return (
- rank[indexs],
+ rank,
response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
)

@@ -312,9 +330,10 @@ def similarity(self, query: str, texts: list):
documents=texts,
top_n=len(texts),
).body
rank = np.array([d["relevance_score"] for d in res["results"]])
indexs = [d["index"] for d in res["results"]]
return rank[indexs], res["usage"]["total_tokens"]
rank = np.zeros(len(texts), dtype=float)
for d in res["results"]:
rank[d["index"]] = d["relevance_score"]
return rank, res["usage"]["total_tokens"]


class VoyageRerank(Base):
@@ -328,6 +347,7 @@ def similarity(self, query: str, texts: list):
res = self.client.rerank(
query=query, documents=texts, model=self.model_name, top_k=len(texts)
)
- rank = np.array([r.relevance_score for r in res.results])
- indexs = [r.index for r in res.results]
- return rank[indexs], res.total_tokens
+ rank = np.zeros(len(texts), dtype=float)
+ for r in res.results:
+ rank[r.index] = r.relevance_score
+ return rank, res.total_tokens
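The recurring change in this file is a single bug fix: these rerank APIs return results sorted by relevance, each result carrying the index of the source document, so the old rank[indexs] gather permuted the already-sorted scores a second time instead of restoring input order. The new code scatters each score back to its document's slot. A self-contained before/after demonstration with a mocked response:

import numpy as np

# Mocked API response: results arrive sorted by relevance_score, and
# "index" points at the document's position in the request.
results = [
    {"index": 2, "relevance_score": 0.9},
    {"index": 0, "relevance_score": 0.5},
    {"index": 1, "relevance_score": 0.1},
]

# Old (buggy): gathering by index reorders the sorted scores yet again.
rank = np.array([d["relevance_score"] for d in results])
indexs = [d["index"] for d in results]
print(rank[indexs])  # [0.1 0.9 0.5] -- not aligned with the input texts

# New (fixed): scatter each score to its document's position.
rank = np.zeros(3, dtype=float)
for d in results:
    rank[d["index"]] = d["relevance_score"]
print(rank)          # [0.5 0.1 0.9] -- scores in input order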
2 changes: 1 addition & 1 deletion sdk/python/ragflow/modules/assistant.py
@@ -76,7 +76,7 @@ def list_session(self) -> List[Session]:
raise Exception(res["retmsg"])

def get_session(self, id) -> Session:
res = self.get("/session/get", {"id": id})
res = self.get("/session/get", {"id": id,"assistant_id":self.id})
res = res.json()
if res.get("retmsg") == "success":
return Session(self.rag, res["data"])
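On the SDK side, Assistant.get_session now sends its own id along with the session id, so the server can apply the ownership check added above. A short usage sketch with placeholder ids:

# `assistant` is an Assistant obtained from a RAGFlow client,
# e.g. rag = RAGFlow(API_KEY, HOST_ADDRESS) as in the tests below.
session = assistant.get_session("<session_id>")
# A session id owned by a different assistant now comes back as
# "The session doesn't belong to the assistant".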
14 changes: 9 additions & 5 deletions sdk/python/ragflow/modules/session.py
@@ -16,9 +16,12 @@ def chat(self, question: str, stream: bool = False):
if "reference" in message:
message.pop("reference")
res = self.post("/session/completion",
{"id": self.id, "question": question, "stream": stream}, stream=True)
{"session_id": self.id, "question": question, "stream": True}, stream=stream)
for line in res.iter_lines():
line = line.decode("utf-8")
if line.startswith("{"):
json_data = json.loads(line)
raise Exception(json_data["retmsg"])
if line.startswith("data:"):
json_data = json.loads(line[5:])
if json_data["data"] != True:
@@ -69,17 +72,18 @@ def __init__(self, rag, res_dict):
self.reference = None
self.role = "assistant"
self.prompt = None
+ self.id = None
super().__init__(rag, res_dict)


class Chunk(Base):
def __init__(self, rag, res_dict):
self.id = None
self.content = None
- self.document_id = None
- self.document_name = None
- self.knowledgebase_id = None
- self.image_id = None
+ self.document_id = ""
+ self.document_name = ""
+ self.knowledgebase_id = ""
+ self.image_id = ""
self.similarity = None
self.vector_similarity = None
self.term_similarity = None
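Session.chat now always asks the server to stream ("stream": True in the payload; the stream argument only controls whether requests buffers the response) and raises as soon as the server answers with a bare JSON error body instead of data: lines. Consumption stays the same as in the test below:

# `session` obtained from an Assistant, as elsewhere in the SDK.
for ans in session.chat("What is AI", stream=True):
    print(ans.content)  # incremental assistant content
# A line starting with "{" is an error payload and raises
# Exception(json_data["retmsg"]) instead of failing silently.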
2 changes: 1 addition & 1 deletion sdk/python/test/t_session.py
@@ -19,7 +19,7 @@ def test_create_chat_with_success(self):
question = "What is AI"
for ans in session.chat(question, stream=True):
pass
- assert ans.content!="\n**ERROR**", "Please check this error."
+ assert not ans.content.startswith("**ERROR**"), "Please check this error."

def test_delete_session_with_success(self):
rag = RAGFlow(API_KEY, HOST_ADDRESS)