Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create and update dataset #2110

Merged
merged 17 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions api/apps/sdk/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from flask import request
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request


from api.db import StatusEnum
from api.db.db_models import APIToken
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result
from api.utils.api_utils import get_json_result


@manager.route('/save', methods=['POST'])
def save():
    """Create a dataset (knowledgebase) or update an existing one.

    The caller is identified by the second whitespace-separated part of the
    ``Authorization`` header (e.g. ``Bearer <token>``), which is looked up via
    ``APIToken.query``. The JSON body decides the mode:

    * no ``"id"`` key -> create: a new id is generated, the name is stripped
      and checked for duplicates, and the tenant's embedding model is applied.
      The stored record is returned with ``embd_id``/``parser_id``/
      ``chunk_num``/``doc_num`` renamed to their SDK-facing names.
    * ``"id"`` present -> update: ``tenant_id``, ``embd_id``, ``chunk_num`` and
      ``doc_num`` must match the current values, and ``parser_id`` may only
      change while the dataset has no chunks. Returns ``data=True`` on success.

    Every failure is reported through ``get_json_result`` /
    ``get_data_error_result`` rather than by raising.
    """
    req = request.json
    if not isinstance(req, dict):
        return get_data_error_result(retmsg="Request body must be a JSON object.")

    # A missing header or one without a "<scheme> <token>" shape is treated
    # exactly like an unknown token instead of crashing on .split()[1].
    auth_parts = request.headers.get('Authorization', '').split()
    token = auth_parts[1] if len(auth_parts) >= 2 else None
    objs = APIToken.query(token=token) if token else None
    if not objs:
        return get_json_result(
            data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR)

    tenant_id = objs[0].tenant_id
    found, tenant = TenantService.get_by_id(tenant_id)
    if not found:
        return get_data_error_result(retmsg="Tenant not found.")

    if "id" not in req:
        # ---- create ----
        name = req.get("name")
        if not isinstance(name, str) or not name.strip():
            return get_data_error_result(retmsg="Name can't be empty.")
        req["name"] = name.strip()
        req['id'] = get_uuid()

        # Scope the duplicate check the same way the update path does:
        # only this tenant's valid knowledgebases count as a clash.
        if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
                                      status=StatusEnum.VALID.value):
            return get_data_error_result(
                retmsg="Duplicated knowledgebase name")

        req["tenant_id"] = tenant_id
        req['created_by'] = tenant_id
        req['embd_id'] = tenant.embd_id
        if not KnowledgebaseService.save(**req):
            return get_data_error_result(retmsg="Data saving error")

        # Shape the response with the SDK-facing field names.
        req.pop('created_by')
        keys_to_rename = {'embd_id': "embedding_model", 'parser_id': 'parser_method',
                          'chunk_num': 'chunk_count', 'doc_num': 'document_count'}
        for old_key, new_key in keys_to_rename.items():
            if old_key in req:
                req[new_key] = req.pop(old_key)
        return get_json_result(data=req)

    # ---- update ----
    required = ("name", "tenant_id", "embd_id", "chunk_num", "doc_num")
    missing = [key for key in required if key not in req]
    if missing:
        return get_data_error_result(
            retmsg=f"Missing required field(s): {', '.join(missing)}")
    if not isinstance(req["name"], str):
        return get_data_error_result(retmsg="Name must be a string.")

    if req["tenant_id"] != tenant_id or req["embd_id"] != tenant.embd_id:
        return get_data_error_result(
            retmsg="Can't change tenant_id or embedding_model")

    found, kb = KnowledgebaseService.get_by_id(req["id"])
    if not found:
        return get_data_error_result(
            retmsg="Can't find this knowledgebase!")

    if not KnowledgebaseService.query(
            created_by=tenant_id, id=req["id"]):
        return get_json_result(
            data=False, retmsg='Only owner of knowledgebase authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)

    if req["chunk_num"] != kb.chunk_num or req['doc_num'] != kb.doc_num:
        return get_data_error_result(
            retmsg="Can't change document_count or chunk_count ")

    # parser_id is only inspected once chunks exist; an absent parser_id is
    # treated as "unchanged" rather than raising KeyError.
    if kb.chunk_num > 0 and "parser_id" in req and req['parser_id'] != kb.parser_id:
        return get_data_error_result(
            retmsg="if chunk count is not 0, parser method is not changable. ")

    if req["name"].lower() != kb.name.lower() \
            and len(KnowledgebaseService.query(name=req["name"], tenant_id=req['tenant_id'],
                                               status=StatusEnum.VALID.value)) > 0:
        return get_data_error_result(
            retmsg="Duplicated knowledgebase name.")

    del req["id"]
    req['created_by'] = tenant_id
    if not KnowledgebaseService.update_by_id(kb.id, req):
        return get_data_error_result(retmsg="Data update error ")
    return get_json_result(data=True)
1 change: 1 addition & 0 deletions sdk/python/ragflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
__version__ = importlib.metadata.version("ragflow")

from .ragflow import RAGFlow
from .modules.dataset import DataSet
22 changes: 14 additions & 8 deletions sdk/python/ragflow/modules/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class DataSet(Base):
class ParseConfig(Base):
class ParserConfig(Base):
def __init__(self, rag, res_dict):
self.chunk_token_count = 128
self.layout_recognize = True
Expand All @@ -21,13 +21,19 @@ def __init__(self, rag, res_dict):
self.permission = "me"
self.document_count = 0
self.chunk_count = 0
self.parse_method = 0
self.parser_method = "naive"
self.parser_config = None
super().__init__(rag, res_dict)

def delete(self):
try:
self.post("/rm", {"kb_id": self.id})
return True
except Exception:
return False
def save(self):
    """Send this dataset's current attribute values to ``/dataset/save``.

    The SDK attribute names are mapped back to the server field names
    (``embedding_model`` -> ``embd_id``, ``document_count`` -> ``doc_num``,
    ``chunk_count`` -> ``chunk_num``, ``parser_method`` -> ``parser_id``).

    Returns:
        The ``data`` member of the JSON response when it is present,
        otherwise the response's ``retmsg`` member.
    """
    payload = {
        "id": self.id,
        "name": self.name,
        "avatar": self.avatar,
        "tenant_id": self.tenant_id,
        "description": self.description,
        "language": self.language,
        "embd_id": self.embedding_model,
        "permission": self.permission,
        "doc_num": self.document_count,
        "chunk_num": self.chunk_count,
        "parser_id": self.parser_method,
    }
    # parser_config starts out as None; only serialise it when one is set,
    # so a dataset without a config can still be saved.
    if self.parser_config is not None:
        payload["parser_config"] = self.parser_config.to_json()

    # Decode the response body once and reuse it.
    body = self.post('/dataset/save', payload).json()
    if "data" in body:
        return body['data']
    return body['retmsg']
KevinHuSh marked this conversation as resolved.
Show resolved Hide resolved
188 changes: 22 additions & 166 deletions sdk/python/ragflow/ragflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,180 +21,36 @@
class RAGFlow:
def __init__(self, user_key, base_url, version='v1'):
"""
api_url: http://<host_address>/v1
dataset_url: http://<host_address>/v1/kb
document_url: http://<host_address>/v1/dataset/{dataset_id}/documents
api_url: http://<host_address>/api/v1
"""
self.user_key = user_key
self.api_url = f"{base_url}/{version}"
self.dataset_url = f"{self.api_url}/kb"
self.authorization_header = {"Authorization": "{}".format(self.user_key)}
self.base_url = base_url
self.api_url = f"{base_url}/api/{version}"
self.authorization_header = {"Authorization": "{} {}".format("Bearer",self.user_key)}

def post(self, path, param):
res = requests.post(url=self.dataset_url + path, json=param, headers=self.authorization_header)
res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header)
return res

def get(self, path, params=''):
res = requests.get(self.dataset_url + path, params=params, headers=self.authorization_header)
res = requests.get(self.api_url + path, params=params, headers=self.authorization_header)
return res

def create_dataset(self, dataset_name):
"""
name: dataset name
"""
res_create = self.post("/create", {"name": dataset_name})
res_create_data = res_create.json()['data']
res_detail = self.get("/detail", {"kb_id": res_create_data["kb_id"]})
res_detail_data = res_detail.json()['data']
result = {}
result['id'] = res_detail_data['id']
result['name'] = res_detail_data['name']
result['avatar'] = res_detail_data['avatar']
result['description'] = res_detail_data['description']
result['language'] = res_detail_data['language']
result['embedding_model'] = res_detail_data['embd_id']
result['permission'] = res_detail_data['permission']
result['document_count'] = res_detail_data['doc_num']
result['chunk_count'] = res_detail_data['chunk_num']
result['parser_config'] = res_detail_data['parser_config']
dataset = DataSet(self, result)
return dataset

"""
def delete_dataset(self, dataset_name):
dataset_id = self.find_dataset_id_by_name(dataset_name)

endpoint = f"{self.dataset_url}/{dataset_id}"
res = requests.delete(endpoint, headers=self.authorization_header)
return res.json()

def find_dataset_id_by_name(self, dataset_name):
res = requests.get(self.dataset_url, headers=self.authorization_header)
for dataset in res.json()["data"]:
if dataset["name"] == dataset_name:
return dataset["id"]
return None

def get_dataset(self, dataset_name):
dataset_id = self.find_dataset_id_by_name(dataset_name)
endpoint = f"{self.dataset_url}/{dataset_id}"
response = requests.get(endpoint, headers=self.authorization_header)
return response.json()

def update_dataset(self, dataset_name, **params):
dataset_id = self.find_dataset_id_by_name(dataset_name)

endpoint = f"{self.dataset_url}/{dataset_id}"
response = requests.put(endpoint, json=params, headers=self.authorization_header)
return response.json()

# ------------------------------- CONTENT MANAGEMENT -----------------------------------------------------

# ----------------------------upload local files-----------------------------------------------------
def upload_local_file(self, dataset_id, file_paths):
files = []

for file_path in file_paths:
if not isinstance(file_path, str):
return {"code": RetCode.ARGUMENT_ERROR, "message": f"{file_path} is not string."}
if "http" in file_path:
return {"code": RetCode.ARGUMENT_ERROR, "message": "Remote files have not unsupported."}
if os.path.isfile(file_path):
files.append(("file", open(file_path, "rb")))
else:
return {"code": RetCode.DATA_ERROR, "message": f"The file {file_path} does not exist"}

res = requests.request("POST", url=f"{self.dataset_url}/{dataset_id}/documents", files=files,
headers=self.authorization_header)

result_dict = json.loads(res.text)
return result_dict

# ----------------------------delete a file-----------------------------------------------------
def delete_files(self, document_id, dataset_id):
endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}"
res = requests.delete(endpoint, headers=self.authorization_header)
return res.json()

# ----------------------------list files-----------------------------------------------------
def list_files(self, dataset_id, offset=0, count=-1, order_by="create_time", descend=True, keywords=""):
params = {
"offset": offset,
"count": count,
"order_by": order_by,
"descend": descend,
"keywords": keywords
}
endpoint = f"{self.dataset_url}/{dataset_id}/documents/"
res = requests.get(endpoint, params=params, headers=self.authorization_header)
return res.json()

# ----------------------------update files: enable, rename, template_type-------------------------------------------
def update_file(self, dataset_id, document_id, **params):
endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}"
response = requests.put(endpoint, json=params, headers=self.authorization_header)
return response.json()

# ----------------------------download a file-----------------------------------------------------
def download_file(self, dataset_id, document_id):
endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}"
res = requests.get(endpoint, headers=self.authorization_header)

content = res.content # binary data
# decode the binary data
try:
decoded_content = content.decode("utf-8")
json_data = json.loads(decoded_content)
return json_data # message
except json.JSONDecodeError: # binary data
_, document = DocumentService.get_by_id(document_id)
file_path = os.path.join(os.getcwd(), document.name)
with open(file_path, "wb") as file:
file.write(content)
return {"code": RetCode.SUCCESS, "data": content}

# ----------------------------start parsing-----------------------------------------------------
def start_parsing_document(self, dataset_id, document_id):
endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}/status"
res = requests.post(endpoint, headers=self.authorization_header)

return res.json()

def start_parsing_documents(self, dataset_id, doc_ids=None):
endpoint = f"{self.dataset_url}/{dataset_id}/documents/status"
res = requests.post(endpoint, headers=self.authorization_header, json={"doc_ids": doc_ids})

return res.json()

# ----------------------------stop parsing-----------------------------------------------------
def stop_parsing_document(self, dataset_id, document_id):
endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}/status"
res = requests.delete(endpoint, headers=self.authorization_header)

return res.json()

def stop_parsing_documents(self, dataset_id, doc_ids=None):
endpoint = f"{self.dataset_url}/{dataset_id}/documents/status"
res = requests.delete(endpoint, headers=self.authorization_header, json={"doc_ids": doc_ids})

return res.json()

# ----------------------------show the status of the file-----------------------------------------------------
def show_parsing_status(self, dataset_id, document_id):
endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}/status"
res = requests.get(endpoint, headers=self.authorization_header)

return res.json()
# ----------------------------list the chunks of the file-----------------------------------------------------

# ----------------------------delete the chunk-----------------------------------------------------

# ----------------------------edit the status of the chunk-----------------------------------------------------

# ----------------------------insert a new chunk-----------------------------------------------------
def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English",
                   permission: str = "me", document_count: int = 0, chunk_count: int = 0,
                   parser_method: str = "naive",
                   parser_config: "DataSet.ParserConfig" = None):
    """Create a dataset by posting to ``/dataset/save``.

    When ``parser_config`` is omitted, a ``DataSet.ParserConfig`` with the
    default chunking settings below is built and serialised via ``to_json()``.

    Returns:
        A ``DataSet`` built from the response's ``data`` member when it is
        present, otherwise the response's ``retmsg`` member.
    """
    if parser_config is None:
        parser_config = DataSet.ParserConfig(
            self,
            {"chunk_token_count": 128, "layout_recognize": True,
             "delimiter": "\n!?。;!?", "task_page_size": 12},
        )
    payload = {
        "name": name,
        "avatar": avatar,
        "description": description,
        "language": language,
        "permission": permission,
        "doc_num": document_count,
        "chunk_num": chunk_count,
        "parser_id": parser_method,
        "parser_config": parser_config.to_json(),
    }
    # Decode once; from here on `body` is a plain dict, not a Response.
    body = self.post("/dataset/save", payload).json()
    if "data" in body:
        return DataSet(self, body["data"])
    return body["retmsg"]
KevinHuSh marked this conversation as resolved.
Show resolved Hide resolved

# ----------------------------get a specific chunk-----------------------------------------------------

# ----------------------------retrieval test-----------------------------------------------------
"""
2 changes: 1 addition & 1 deletion sdk/python/test/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@


API_KEY = 'IjUxNGM0MmM4NWY5MzExZWY5MDhhMDI0MmFjMTIwMDA2Ig.ZsWebA.mV1NKdSPPllgowiH-7vz36tMWyI'
API_KEY = 'ragflow-k0N2I1MzQwNjNhMzExZWY5ODg1MDI0Mm'
HOST_ADDRESS = 'http://127.0.0.1:9380'
27 changes: 18 additions & 9 deletions sdk/python/test/t_dataset.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
from ragflow import RAGFlow
from ragflow import RAGFlow, DataSet

from common import API_KEY, HOST_ADDRESS
from test_sdkbase import TestSdk


class TestDataset(TestSdk):
def test_create_dataset_with_success(self):
"""
Test creating dataset with success
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("God")
assert ds is not None, "The dataset creation failed, returned None."
assert ds.name == "God", "Dataset name does not match."
if isinstance(ds, DataSet):
assert ds.name == "God", "Name does not match."
else:
assert False, f"Failed to create dataset, error: {ds}"

def test_delete_one_file(self):
def test_update_dataset_with_success(self):
"""
Test deleting one file with success.
Test updating dataset with success.
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("ABC")
assert ds is not None, "Failed to create dataset"
assert ds.name == "ABC", "Dataset name mismatch"
delete_result = ds.delete()
assert delete_result is True, "Failed to delete dataset"
if isinstance(ds, DataSet):
assert ds.name == "ABC", "Name does not match."
ds.name = 'DEF'
res = ds.save()
assert res is True, f"Failed to update dataset, error: {res}"

else:
assert False, f"Failed to create dataset, error: {ds}"