Add test for API #3134

Merged · 7 commits · Nov 1, 2024
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -78,7 +78,7 @@ jobs:
           echo "Waiting for service to be available..."
           sleep 5
         done
-        cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest t_dataset.py t_chat.py t_session.py
+        cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py

     - name: Stop ragflow:dev
       if: always() # always run this step even if previous steps failed
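The two new suites added to the pytest invocation, t_document.py and t_chunk.py, are not expanded in this view. Purely as a hypothetical sketch of what one such case might look like — mirroring the fixture and SDK calls visible in the t_chat.py diff below, and assuming upload_documents accepts displayed_name/blob mappings as in the SDK docs:

from ragflow_sdk import RAGFlow
from common import HOST_ADDRESS

# Hypothetical example only -- the real t_document.py is not shown in this diff.
# It follows the pattern from t_chat.py: create a dataset, upload, then assert.
def test_upload_document(get_api_key_fixture):
    API_KEY = get_api_key_fixture
    rag = RAGFlow(API_KEY, HOST_ADDRESS)
    kb = rag.create_dataset(name="test_upload_document")
    # Assumed payload shape: a display name plus raw bytes.
    documents = [{"displayed_name": "test.txt", "blob": b"This is a test document."}]
    docs = kb.upload_documents(documents)
    assert len(docs) == 1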
18 changes: 11 additions & 7 deletions api/apps/sdk/doc.py
@@ -194,8 +194,11 @@ def list_docs(dataset_id, tenant_id):
     if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
         return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}. ")
     id = request.args.get("id")
+    name = request.args.get("name")
     if not DocumentService.query(id=id,kb_id=dataset_id):
         return get_error_data_result(retmsg=f"You don't own the document {id}.")
+    if not DocumentService.query(name=name,kb_id=dataset_id):
+        return get_error_data_result(retmsg=f"You don't own the document {name}.")
     offset = int(request.args.get("offset", 1))
     keywords = request.args.get("keywords","")
     limit = int(request.args.get("limit", 1024))
@@ -204,7 +207,7 @@ def list_docs(dataset_id, tenant_id):
         desc = False
     else:
         desc = True
-    docs, tol = DocumentService.get_list(dataset_id, offset, limit, orderby, desc, keywords, id)
+    docs, tol = DocumentService.get_list(dataset_id, offset, limit, orderby, desc, keywords, id,name)

     # rename key's name
     renamed_doc_list = []
@@ -321,8 +324,8 @@ def stop_parsing(tenant_id,dataset_id):
     doc = DocumentService.query(id=id, kb_id=dataset_id)
     if not doc:
         return get_error_data_result(retmsg=f"You don't own the document {id}.")
-    if doc[0].progress == 100.0 or doc[0].progress == 0.0:
-        return get_error_data_result("Can't stop parsing document with progress at 0 or 100")
+    if int(doc[0].progress) == 1 or int(doc[0].progress) == 0:
+        return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
     info = {"run": "2", "progress": 0,"chunk_num":0}
     DocumentService.update_by_id(id, info)
     ELASTICSEARCH.deleteByQuery(
@@ -414,9 +417,9 @@ def list_chunks(tenant_id,dataset_id,document_id):
         for key, value in chunk.items():
             new_key = key_mapping.get(key, key)
             renamed_chunk[new_key] = value
-        if renamed_chunk["available"] == "0":
+        if renamed_chunk["available"] == 0:
             renamed_chunk["available"] = False
-        if renamed_chunk["available"] == "1":
+        if renamed_chunk["available"] == 1:
             renamed_chunk["available"] = True
         res["chunks"].append(renamed_chunk)
     return get_result(data=res)
@@ -464,6 +467,7 @@ def add_chunk(tenant_id,dataset_id,document_id):
     DocumentService.increment_chunk_num(
         doc.id, doc.kb_id, c, 1, 0)
     d["chunk_id"] = chunk_id
+    d["kb_id"]=doc.kb_id
     # rename keys
     key_mapping = {
         "chunk_id": "id",
@@ -581,10 +585,10 @@ def update_chunk(tenant_id,dataset_id,document_id,chunk_id):
 def retrieval_test(tenant_id):
     req = request.json
     if not req.get("dataset_ids"):
-        return get_error_data_result("`datasets` is required.")
+        return get_error_data_result("`dataset_ids` is required.")
     kb_ids = req["dataset_ids"]
     if not isinstance(kb_ids,list):
-        return get_error_data_result("`datasets` should be a list")
+        return get_error_data_result("`dataset_ids` should be a list")
     kbs = KnowledgebaseService.get_by_ids(kb_ids)
     for id in kb_ids:
         if not KnowledgebaseService.query(id=id,tenant_id=tenant_id):
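list_docs now accepts a name query parameter alongside id, and rejects the request when the name does not resolve to a document in the dataset. A hedged sketch of exercising the endpoint over HTTP — the route shape follows RAGFlow's v1 SDK API and bearer-token auth, but verify against your deployment; all values are placeholders:

import requests

HOST_ADDRESS = "http://127.0.0.1:9380"  # placeholder deployment address
API_KEY = "<your-api-key>"              # placeholder key
dataset_id = "<dataset-id>"             # placeholder dataset

# Filter the listing by exact document name (new), or by id as before.
resp = requests.get(
    f"{HOST_ADDRESS}/api/v1/datasets/{dataset_id}/documents",
    headers={"Authorization": f"Bearer {API_KEY}"},
    params={"name": "test.txt", "offset": 1, "limit": 30},
)
print(resp.json())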
79 changes: 42 additions & 37 deletions api/db/services/document_service.py
@@ -52,11 +52,15 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_list(cls, kb_id, page_number, items_per_page,
-                 orderby, desc, keywords, id):
-        docs =cls.model.select().where(cls.model.kb_id==kb_id)
+                 orderby, desc, keywords, id, name):
+        docs = cls.model.select().where(cls.model.kb_id == kb_id)
         if id:
             docs = docs.where(
-                cls.model.id== id )
+                cls.model.id == id)
+        if name:
+            docs = docs.where(
+                cls.model.name == name
+            )
         if keywords:
             docs = docs.where(
                 fn.LOWER(cls.model.name).contains(keywords.lower())
@@ -70,7 +74,6 @@ def get_list(cls, kb_id, page_number, items_per_page,
         count = docs.count()
         return list(docs.dicts()), count

-
     @classmethod
     @DB.connection_context()
     def get_by_kb_id(cls, kb_id, page_number, items_per_page,
@@ -162,26 +165,27 @@ def get_newly_uploaded(cls):
                   cls.model.update_time]
         docs = cls.model.select(*fields) \
             .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
-            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
+            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
             .where(
-                cls.model.status == StatusEnum.VALID.value,
-                ~(cls.model.type == FileType.VIRTUAL.value),
-                cls.model.progress == 0,
-                cls.model.update_time >= current_timestamp() - 1000 * 600,
-                cls.model.run == TaskStatus.RUNNING.value)\
+            cls.model.status == StatusEnum.VALID.value,
+            ~(cls.model.type == FileType.VIRTUAL.value),
+            cls.model.progress == 0,
+            cls.model.update_time >= current_timestamp() - 1000 * 600,
+            cls.model.run == TaskStatus.RUNNING.value) \
             .order_by(cls.model.update_time.asc())
         return list(docs.dicts())

     @classmethod
     @DB.connection_context()
     def get_unfinished_docs(cls):
-        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run]
+        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
+                  cls.model.run]
         docs = cls.model.select(*fields) \
             .where(
-                cls.model.status == StatusEnum.VALID.value,
-                ~(cls.model.type == FileType.VIRTUAL.value),
-                cls.model.progress < 1,
-                cls.model.progress > 0)
+            cls.model.status == StatusEnum.VALID.value,
+            ~(cls.model.type == FileType.VIRTUAL.value),
+            cls.model.progress < 1,
+            cls.model.progress > 0)
         return list(docs.dicts())

     @classmethod
@@ -196,12 +200,12 @@ def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
             "Document not found which is supposed to be there")
         num = Knowledgebase.update(
             token_num=Knowledgebase.token_num +
-            token_num,
+                      token_num,
             chunk_num=Knowledgebase.chunk_num +
-            chunk_num).where(
+                      chunk_num).where(
             Knowledgebase.id == kb_id).execute()
         return num

     @classmethod
     @DB.connection_context()
     def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
@@ -214,13 +218,13 @@ def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
             "Document not found which is supposed to be there")
         num = Knowledgebase.update(
             token_num=Knowledgebase.token_num -
-            token_num,
+                      token_num,
             chunk_num=Knowledgebase.chunk_num -
-            chunk_num
+                      chunk_num
         ).where(
             Knowledgebase.id == kb_id).execute()
         return num

     @classmethod
     @DB.connection_context()
     def clear_chunk_num(cls, doc_id):
@@ -229,10 +233,10 @@ def clear_chunk_num(cls, doc_id):

         num = Knowledgebase.update(
             token_num=Knowledgebase.token_num -
-            doc.token_num,
+                      doc.token_num,
             chunk_num=Knowledgebase.chunk_num -
-            doc.chunk_num,
-            doc_num=Knowledgebase.doc_num-1
+                      doc.chunk_num,
+            doc_num=Knowledgebase.doc_num - 1
         ).where(
             Knowledgebase.id == doc.kb_id).execute()
         return num
@@ -243,8 +247,8 @@ def get_tenant_id(cls, doc_id):
         docs = cls.model.select(
             Knowledgebase.tenant_id).join(
             Knowledgebase, on=(
-                Knowledgebase.id == cls.model.kb_id)).where(
-            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
+                    Knowledgebase.id == cls.model.kb_id)).where(
+            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
         docs = docs.dicts()
         if not docs:
             return
@@ -270,8 +274,8 @@ def accessible(cls, doc_id, user_id):
             cls.model.id).join(
             Knowledgebase, on=(
                 Knowledgebase.id == cls.model.kb_id)
-        ).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
-        ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
+            ).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
+                   ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
         docs = docs.dicts()
         if not docs:
             return False
@@ -284,7 +288,7 @@ def accessible4deletion(cls, doc_id, user_id):
             cls.model.id).join(
             Knowledgebase, on=(
                 Knowledgebase.id == cls.model.kb_id)
-        ).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
+            ).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
         docs = docs.dicts()
         if not docs:
             return False
@@ -296,13 +300,13 @@ def get_embd_id(cls, doc_id):
         docs = cls.model.select(
             Knowledgebase.embd_id).join(
             Knowledgebase, on=(
-                Knowledgebase.id == cls.model.kb_id)).where(
-            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
+                    Knowledgebase.id == cls.model.kb_id)).where(
+            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
         docs = docs.dicts()
         if not docs:
             return
         return docs[0]["embd_id"]

     @classmethod
     @DB.connection_context()
     def get_doc_id_by_doc_name(cls, doc_name):
@@ -338,6 +342,7 @@ def dfs_update(old, new):
                     dfs_update(old[k], v)
                 else:
                     old[k] = v
+
         dfs_update(d.parser_config, config)
         cls.update_by_id(id, {"parser_config": d.parser_config})

@@ -372,7 +377,7 @@ def update_progress(cls):
             finished = True
             bad = 0
             e, doc = DocumentService.get_by_id(d["id"])
-            status = doc.run#TaskStatus.RUNNING.value
+            status = doc.run  # TaskStatus.RUNNING.value
             for t in tsks:
                 if 0 <= t.progress < 1:
                     finished = False
@@ -386,9 +391,10 @@
                     prg = -1
                     status = TaskStatus.FAIL.value
                 elif finished:
-                    if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(" raptor")<0:
+                    if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(
+                            " raptor") < 0:
                         queue_raptor_tasks(d)
-                        prg = 0.98 * len(tsks)/(len(tsks)+1)
+                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
                         msg.append("------ RAPTOR -------")
                     else:
                         status = TaskStatus.DONE.value
@@ -414,7 +420,6 @@ def get_kb_doc_count(cls, kb_id):
         return len(cls.model.select(cls.model.id).where(
             cls.model.kb_id == kb_id).dicts())

-
     @classmethod
     @DB.connection_context()
     def do_cancel(cls, doc_id):
@@ -579,4 +584,4 @@ def embedding(doc_id, cnts, batch_size=16):
         DocumentService.increment_chunk_num(
             doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

-    return [d["id"] for d,_ in files]
+    return [d["id"] for d, _ in files]
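Most of this file is editor reformatting; the functional change is the name parameter threaded into get_list. A minimal sketch of the updated call site, matching the signature in the hunk above (the sort column is an assumption):

from api.db.services.document_service import DocumentService

# Server-side sketch: both filters are optional; a falsy value skips that clause.
docs, total = DocumentService.get_list(
    "dataset-id",           # kb_id (placeholder)
    page_number=1,
    items_per_page=30,
    orderby="create_time",  # assumed sort column
    desc=True,
    keywords="",
    id=None,                # skip the id filter
    name="test.txt",        # new: exact-match filter on the document name
)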
2 changes: 1 addition & 1 deletion sdk/python/ragflow_sdk/ragflow.py
@@ -166,7 +166,7 @@ def retrieve(self, dataset_ids, document_ids=None, question="", offset=1, limit=
             "rerank_id": rerank_id,
             "keyword": keyword,
             "question": question,
-            "datasets": dataset_ids,
+            "dataset_ids": dataset_ids,
             "documents": document_ids
         }
         # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
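With this fix the client payload key matches the dataset_ids key that retrieval_test validates server-side. A usage sketch with placeholder credentials and IDs, assuming the returned chunk objects expose a content attribute:

from ragflow_sdk import RAGFlow

rag = RAGFlow("<your-api-key>", "http://127.0.0.1:9380")  # placeholder credentials
# The SDK now posts the IDs under "dataset_ids", matching what
# retrieval_test() on the server expects.
chunks = rag.retrieve(
    dataset_ids=["<dataset-id>"],  # placeholder ID
    question="What is RAGFlow?",
)
for chunk in chunks:
    print(chunk.content)  # assumed attribute on the returned chunk objects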
2 changes: 2 additions & 0 deletions sdk/python/test/common.py
@@ -0,0 +1,2 @@
+import os
+HOST_ADDRESS=os.getenv('HOST_ADDRESS', 'http://127.0.0.1:9380')
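The shared constant gives every suite one place to resolve the backend address, overridable per run; a quick sketch of the intended pattern (the CI hostname is hypothetical):

# Point the suites at a different backend without editing the tests, e.g.:
#   HOST_ADDRESS=http://ragflow-ci:9380 pytest t_dataset.py t_chat.py
from common import HOST_ADDRESS

print(HOST_ADDRESS)  # http://127.0.0.1:9380 unless the env var overrides it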
10 changes: 4 additions & 6 deletions sdk/python/test/t_chat.py
@@ -1,7 +1,5 @@
-import os
 from ragflow_sdk import RAGFlow
-
-HOST_ADDRESS = os.getenv('HOST_ADDRESS', 'http://127.0.0.1:9380')
+from common import HOST_ADDRESS

 def test_create_chat_with_name(get_api_key_fixture):
     API_KEY = get_api_key_fixture
@@ -16,7 +14,7 @@ def test_create_chat_with_name(get_api_key_fixture):
     docs= kb.upload_documents(documents)
     for doc in docs:
         doc.add_chunk("This is a test to add chunk")
-    rag.create_chat("test_create", dataset_ids=[kb.id])
+    rag.create_chat("test_create_chat", dataset_ids=[kb.id])


 def test_update_chat_with_name(get_api_key_fixture):
@@ -32,7 +30,7 @@ def test_update_chat_with_name(get_api_key_fixture):
     docs = kb.upload_documents(documents)
     for doc in docs:
         doc.add_chunk("This is a test to add chunk")
-    chat = rag.create_chat("test_update", dataset_ids=[kb.id])
+    chat = rag.create_chat("test_update_chat", dataset_ids=[kb.id])
     chat.update({"name": "new_chat"})


@@ -49,7 +47,7 @@ def test_delete_chats_with_success(get_api_key_fixture):
     docs = kb.upload_documents(documents)
     for doc in docs:
         doc.add_chunk("This is a test to add chunk")
-    chat = rag.create_chat("test_delete", dataset_ids=[kb.id])
+    chat = rag.create_chat("test_delete_chat", dataset_ids=[kb.id])
     rag.delete_chats(ids=[chat.id])

 def test_list_chats_with_success(get_api_key_fixture):