Skip to content

Commit

Permalink
create dataset (infiniflow#2074)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

You can use sdk to create a dataset

### Type of change

- [x] New Feature

---------

Co-authored-by: root <root@xwg>
  • Loading branch information
Feiue and root authored Aug 23, 2024
1 parent 0af1f93 commit b8fc467
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 27 deletions.
Empty file.
30 changes: 30 additions & 0 deletions sdk/python/ragflow/modules/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class Base(object):
def __init__(self, rag, res_dict):
self.rag = rag
for k, v in res_dict.items():
if isinstance(v, dict):
self.__dict__[k] = Base(rag, v)
else:
self.__dict__[k] = v

def to_json(self):
pr = {}
for name in dir(self):
value = getattr(self, name)
if not name.startswith('__') and not callable(value) and name != "rag":
if isinstance(value, Base):
pr[name] = value.to_json()
else:
pr[name] = value
return pr


def post(self, path, param):
res = self.rag.post(path,param)
return res

def get(self, path, params=''):
res = self.rag.get(path,params)
return res


33 changes: 33 additions & 0 deletions sdk/python/ragflow/modules/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from .base import Base


class DataSet(Base):
class ParseConfig(Base):
def __init__(self, rag, res_dict):
self.chunk_token_count = 128
self.layout_recognize = True
self.delimiter = '\n!?。;!?'
self.task_page_size = 12
super().__init__(rag, res_dict)

def __init__(self, rag, res_dict):
self.id = ""
self.name = ""
self.avatar = ""
self.tenant_id = None
self.description = ""
self.language = "English"
self.embedding_model = ""
self.permission = "me"
self.document_count = 0
self.chunk_count = 0
self.parse_method = 0
self.parser_config = None
super().__init__(rag, res_dict)

def delete(self):
try:
self.post("/rm", {"kb_id": self.id})
return True
except Exception:
return False
61 changes: 36 additions & 25 deletions sdk/python/ragflow/ragflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,56 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

import requests

from api.db.services.document_service import DocumentService
from api.settings import RetCode
from .modules.dataset import DataSet


class RAGFlow:
def __init__(self, user_key, base_url, version='v1'):
"""
api_url: http://<host_address>/api/v1
dataset_url: http://<host_address>/api/v1/dataset
document_url: http://<host_address>/api/v1/dataset/{dataset_id}/documents
api_url: http://<host_address>/v1
dataset_url: http://<host_address>/v1/kb
document_url: http://<host_address>/v1/dataset/{dataset_id}/documents
"""
self.user_key = user_key
self.api_url = f"{base_url}/api/{version}"
self.dataset_url = f"{self.api_url}/dataset"
self.api_url = f"{base_url}/{version}"
self.dataset_url = f"{self.api_url}/kb"
self.authorization_header = {"Authorization": "{}".format(self.user_key)}
self.base_url = base_url

def post(self, path, param):
res = requests.post(url=self.dataset_url + path, json=param, headers=self.authorization_header)
return res

def get(self, path, params=''):
res = requests.get(self.dataset_url + path, params=params, headers=self.authorization_header)
return res

def create_dataset(self, dataset_name):
"""
name: dataset name
"""
res = requests.post(url=self.dataset_url, json={"name": dataset_name}, headers=self.authorization_header)
result_dict = json.loads(res.text)
return result_dict

res_create = self.post("/create", {"name": dataset_name})
res_create_data = res_create.json()['data']
res_detail = self.get("/detail", {"kb_id": res_create_data["kb_id"]})
res_detail_data = res_detail.json()['data']
result = {}
result['id'] = res_detail_data['id']
result['name'] = res_detail_data['name']
result['avatar'] = res_detail_data['avatar']
result['description'] = res_detail_data['description']
result['language'] = res_detail_data['language']
result['embedding_model'] = res_detail_data['embd_id']
result['permission'] = res_detail_data['permission']
result['document_count'] = res_detail_data['doc_num']
result['chunk_count'] = res_detail_data['chunk_num']
result['parser_config'] = res_detail_data['parser_config']
dataset = DataSet(self, result)
return dataset

"""
def delete_dataset(self, dataset_name):
dataset_id = self.find_dataset_id_by_name(dataset_name)
Expand All @@ -55,16 +76,6 @@ def find_dataset_id_by_name(self, dataset_name):
return dataset["id"]
return None
def list_dataset(self, offset=0, count=-1, orderby="create_time", desc=True):
params = {
"offset": offset,
"count": count,
"orderby": orderby,
"desc": desc
}
response = requests.get(url=self.dataset_url, params=params, headers=self.authorization_header)
return response.json()

def get_dataset(self, dataset_name):
dataset_id = self.find_dataset_id_by_name(dataset_name)
endpoint = f"{self.dataset_url}/{dataset_id}"
Expand All @@ -78,7 +89,7 @@ def update_dataset(self, dataset_name, **params):
response = requests.put(endpoint, json=params, headers=self.authorization_header)
return response.json()
# ------------------------------- CONTENT MANAGEMENT -----------------------------------------------------
# ------------------------------- CONTENT MANAGEMENT -----------------------------------------------------
# ----------------------------upload local files-----------------------------------------------------
def upload_local_file(self, dataset_id, file_paths):
Expand Down Expand Up @@ -186,4 +197,4 @@ def show_parsing_status(self, dataset_id, document_id):
# ----------------------------get a specific chunk-----------------------------------------------------
# ----------------------------retrieval test-----------------------------------------------------

"""
4 changes: 3 additions & 1 deletion sdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@
import setuptools

if __name__ == "__main__":
setuptools.setup(packages=['ragflow'])
setuptools.setup(name='ragflow',
version="0.1",
packages=setuptools.find_packages())

2 changes: 1 addition & 1 deletion sdk/python/test/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@


API_KEY = 'IjJkOGQ4ZDE2MzkyMjExZWZhYTk0MzA0M2Q3ZWU1MzdlIg.ZoUfug.RmqcYyCrlAnLtkzk6bYXiXN3eEY'
API_KEY = 'IjUxNGM0MmM4NWY5MzExZWY5MDhhMDI0MmFjMTIwMDA2Ig.ZsWebA.mV1NKdSPPllgowiH-7vz36tMWyI'
HOST_ADDRESS = 'http://127.0.0.1:9380'
23 changes: 23 additions & 0 deletions sdk/python/test/t_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from ragflow import RAGFlow

from common import API_KEY, HOST_ADDRESS
from test_sdkbase import TestSdk


class TestDataset(TestSdk):
def test_create_dataset_with_success(self):
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("God")
assert ds is not None, "The dataset creation failed, returned None."
assert ds.name == "God", "Dataset name does not match."

def test_delete_one_file(self):
"""
Test deleting one file with success.
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("ABC")
assert ds is not None, "Failed to create dataset"
assert ds.name == "ABC", "Dataset name mismatch"
delete_result = ds.delete()
assert delete_result is True, "Failed to delete dataset"

0 comments on commit b8fc467

Please sign in to comment.