Skip to content

Commit

Permalink
Refactor (#537)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

### Type of change

- [x] Refactoring
  • Loading branch information
KevinHuSh authored Apr 25, 2024
1 parent cf9b554 commit 66f8d35
Show file tree
Hide file tree
Showing 14 changed files with 124 additions and 34 deletions.
3 changes: 2 additions & 1 deletion api/apps/document_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def upload():
if not e:
return get_data_error_result(
retmsg="Can't find this knowledgebase!")
if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
return get_data_error_result(
retmsg="Exceed the maximum file number of a free user!")

Expand Down
4 changes: 2 additions & 2 deletions api/apps/llm_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
def factories():
try:
fac = LLMFactoriesService.get_all()
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["QAnything", "FastEmbed"]])
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]])
except Exception as e:
return server_error_response(e)

Expand Down Expand Up @@ -174,7 +174,7 @@ def list():
llms = [m.to_dict()
for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["QAnything","FastEmbed"]
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"]

llm_set = set([m["llm_name"] for m in llms])
for o in objs:
Expand Down
2 changes: 1 addition & 1 deletion api/db/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ class Dialog(DataBaseModel):
null=True,
default="Chinese",
help_text="English|Chinese")
llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
"presence_penalty": 0.4, "max_tokens": 215})
prompt_type = CharField(
Expand Down
8 changes: 5 additions & 3 deletions api/db/init_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def init_superuser():
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "QAnything",
"name": "Youdao",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
Expand Down Expand Up @@ -323,7 +323,7 @@ def init_llm_factory():
"max_tokens": 2147483648,
"model_type": LLMType.EMBEDDING.value
},
# ------------------------ QAnything -----------------------
# ------------------------ Youdao -----------------------
{
"fid": factory_infos[7]["name"],
"llm_name": "maidalun1020/bce-embedding-base_v1",
Expand All @@ -347,7 +347,9 @@ def init_llm_factory():
LLMService.filter_delete([LLM.fid == "Local"])
LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])

LLMFactoriesService.filter_update([LLMFactoriesService.model.name == "QAnything"], {"name": "Youdao"})
LLMService.filter_update([LLMService.model.fid == "QAnything"], {"fid": "Youdao"})
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
"""
drop table llm;
drop table llm_factories;
Expand Down
2 changes: 1 addition & 1 deletion api/db/services/llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def model_instance(cls, tenant_id, llm_type,
if not model_config:
if llm_type == LLMType.EMBEDDING.value:
llm = LLMService.query(llm_name=llm_name)
if llm and llm[0].fid in ["QAnything", "FastEmbed"]:
if llm and llm[0].fid in ["Youdao", "FastEmbed"]:
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
if not model_config:
if llm_name == "flag-embedding":
Expand Down
20 changes: 20 additions & 0 deletions api/db/services/task_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from api.db.db_models import Task, Document, Knowledgebase, Tenant
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.utils import current_timestamp


class TaskService(CommonService):
Expand Down Expand Up @@ -70,6 +71,25 @@ def get_tasks(cls, tm, mod=0, comm=1, items_per_page=1, takeit=True):
cls.model.id == docs[0]["id"]).execute()
return docs

@classmethod
@DB.connection_context()
def get_ongoing_doc_name(cls):
    """Collect the (kb_id, location) pairs of documents whose parsing
    tasks are currently in progress.

    A task counts as "ongoing" when its document is valid, its run state
    is RUNNING, it is not a virtual file, its progress is in [0, 1), and
    it was created recently.

    :return: de-duplicated list of (kb_id, location) tuples; empty list
             when nothing is running.  NOTE(review): the set() pass makes
             the ordering non-deterministic.
    """
    # Serialize with other consumers of the "get_task" lock; -1 presumably
    # means block indefinitely — confirm DB.lock timeout semantics.
    with DB.lock("get_task", -1):
        docs = cls.model.select(*[Document.kb_id, Document.location]) \
            .join(Document, on=(cls.model.doc_id == Document.id)) \
            .where(
                Document.status == StatusEnum.VALID.value,
                Document.run == TaskStatus.RUNNING.value,
                ~(Document.type == FileType.VIRTUAL.value),
                cls.model.progress >= 0,
                cls.model.progress < 1,
                # 180000 is presumably milliseconds (= 3 minutes) —
                # confirm the unit returned by current_timestamp().
                cls.model.create_time >= current_timestamp() - 180000
            )
        docs = list(docs.dicts())
        if not docs: return []

        # Multiple tasks may reference the same document; de-duplicate.
        return list(set([(d["kb_id"], d["location"]) for d in docs]))

@classmethod
@DB.connection_context()
def do_cancel(cls, id):
Expand Down
34 changes: 17 additions & 17 deletions deepdoc/parser/pdf_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(self):
self.updown_cnt_mdl.set_param({"device": "cuda"})
try:
model_dir = os.path.join(
get_project_base_directory(),
"rag/res/deepdoc")
get_project_base_directory(),
"rag/res/deepdoc")
self.updown_cnt_mdl.load_model(os.path.join(
model_dir, "updown_concat_xgb.model"))
except Exception as e:
Expand All @@ -49,7 +49,6 @@ def __init__(self):
self.updown_cnt_mdl.load_model(os.path.join(
model_dir, "updown_concat_xgb.model"))


self.page_from = 0
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
Expand All @@ -76,7 +75,7 @@ def _x_dis(self, a, b):
def _y_dis(
self, a, b):
return (
b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2

def _match_proj(self, b):
proj_patt = [
Expand All @@ -99,9 +98,9 @@ def _updown_concat_features(self, up, down):
tks_down = huqie.qie(down["text"][:LEN]).split(" ")
tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
tks_all = up["text"][-LEN:].strip() \
+ (" " if re.match(r"[a-zA-Z0-9]+",
up["text"][-1] + down["text"][0]) else "") \
+ down["text"][:LEN].strip()
+ (" " if re.match(r"[a-zA-Z0-9]+",
up["text"][-1] + down["text"][0]) else "") \
+ down["text"][:LEN].strip()
tks_all = huqie.qie(tks_all).split(" ")
fea = [
up.get("R", -1) == down.get("R", -1),
Expand All @@ -123,7 +122,7 @@ def _updown_concat_features(self, up, down):
True if re.search(r"[,,][^。.]+$", up["text"]) else False,
True if re.search(r"[,,][^。.]+$", up["text"]) else False,
True if re.search(r"[\((][^\))]+$", up["text"])
and re.search(r"[\))]", down["text"]) else False,
and re.search(r"[\))]", down["text"]) else False,
self._match_proj(down),
True if re.match(r"[A-Z]", down["text"]) else False,
True if re.match(r"[A-Z]", up["text"][-1]) else False,
Expand Down Expand Up @@ -185,7 +184,7 @@ def _table_transformer_job(self, ZM):
continue
for tb in tbls: # for table
left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
tb["x1"] + MARGIN, tb["bottom"] + MARGIN
tb["x1"] + MARGIN, tb["bottom"] + MARGIN
left *= ZM
top *= ZM
right *= ZM
Expand Down Expand Up @@ -297,7 +296,7 @@ def __ocr(self, pagenum, img, chars, ZM=3):
for b in bxs:
if not b["text"]:
left, right, top, bott = b["x0"] * ZM, b["x1"] * \
ZM, b["top"] * ZM, b["bottom"] * ZM
ZM, b["top"] * ZM, b["bottom"] * ZM
b["text"] = self.ocr.recognize(np.array(img),
np.array([[left, top], [right, top], [right, bott], [left, bott]],
dtype=np.float32))
Expand Down Expand Up @@ -622,7 +621,7 @@ def _extract_table_figure(self, need_image, ZM,
i += 1
continue
lout_no = str(self.boxes[i]["page_number"]) + \
"-" + str(self.boxes[i]["layoutno"])
"-" + str(self.boxes[i]["layoutno"])
if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption",
"title",
"figure caption",
Expand Down Expand Up @@ -975,6 +974,7 @@ def dfs(arr, depth):
self.outlines.append((a["/Title"], depth))
continue
dfs(a, depth + 1)

dfs(outlines, 0)
except Exception as e:
logging.warning(f"Outlines exception: {e}")
Expand All @@ -984,7 +984,7 @@ def dfs(arr, depth):
logging.info("Images converted.")
self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
range(len(self.page_chars))]
range(len(self.page_chars))]
if sum([1 if e else 0 for e in self.is_english]) > len(
self.page_images) / 2:
self.is_english = True
Expand Down Expand Up @@ -1012,9 +1012,9 @@ def dfs(arr, depth):
j += 1

self.__ocr(i + 1, img, chars, zoomin)
#if callback:
# callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
#print("OCR:", timer()-st)
if callback and i % 6 == 5:
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
# print("OCR:", timer()-st)

if not self.is_english and not any(
[c for c in self.page_chars]) and self.boxes:
Expand Down Expand Up @@ -1050,7 +1050,7 @@ def crop(self, text, ZM=3, need_position=False):
left, right, top, bottom = float(left), float(
right), float(top), float(bottom)
poss.append(([int(p) - 1 for p in pn.split("-")],
left, right, top, bottom))
left, right, top, bottom))
if not poss:
if need_position:
return None, None
Expand All @@ -1076,7 +1076,7 @@ def crop(self, text, ZM=3, need_position=False):
self.page_images[pns[0]].crop((left * ZM, top * ZM,
right *
ZM, min(
bottom, self.page_images[pns[0]].size[1])
bottom, self.page_images[pns[0]].size[1])
))
)
if 0 < ii < len(poss) - 1:
Expand Down
2 changes: 1 addition & 1 deletion rag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
"ZHIPU-AI": ZhipuEmbed,
"FastEmbed": FastEmbed,
"QAnything": QAnythingEmbed
"Youdao": YoudaoEmbed
}


Expand Down
12 changes: 6 additions & 6 deletions rag/llm/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,19 +229,19 @@ def encode_queries(self, text):
return np.array(res.data[0].embedding), res.usage.total_tokens


class QAnythingEmbed(Base):
class YoudaoEmbed(Base):
_client = None

def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
from BCEmbedding import EmbeddingModel as qanthing
if not QAnythingEmbed._client:
if not YoudaoEmbed._client:
try:
print("LOADING BCE...")
QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join(
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
get_project_base_directory(),
"rag/res/bce-embedding-base_v1"))
except Exception as e:
QAnythingEmbed._client = qanthing(
YoudaoEmbed._client = qanthing(
model_name_or_path=model_name.replace(
"maidalun1020", "InfiniFlow"))

Expand All @@ -251,10 +251,10 @@ def encode(self, texts: list, batch_size=10):
for t in texts:
token_count += num_tokens_from_string(t)
for i in range(0, len(texts), batch_size):
embds = QAnythingEmbed._client.encode(texts[i:i + batch_size])
embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
res.extend(embds)
return np.array(res), token_count

def encode_queries(self, text):
embds = QAnythingEmbed._client.encode([text])
embds = YoudaoEmbed._client.encode([text])
return np.array(embds[0]), num_tokens_from_string(text)
43 changes: 43 additions & 0 deletions rag/svr/cache_file_svr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import random
import time
import traceback

from api.db.db_models import close_connection
from api.db.services.task_service import TaskService
from rag.utils import MINIO
from rag.utils.redis_conn import REDIS_CONN


def collect():
    """Fetch the (kb_id, location) pairs of documents with ongoing tasks.

    Delegates to TaskService.get_ongoing_doc_name().  When nothing is
    running, sleeps briefly (back-off so the polling loop does not hammer
    the database) and returns an empty list, so callers always receive a
    list rather than sometimes None.
    """
    doc_locations = TaskService.get_ongoing_doc_name()
    if not doc_locations:
        time.sleep(1)
        return []
    return doc_locations

def main():
    """Warm the Redis cache with the raw files of currently-running tasks.

    For each (kb_id, location) pair returned by collect(), copies the file
    from MinIO into Redis (key "kb_id/location", 12-minute expiry) unless
    it is already cached.  Errors on one file are logged and do not stop
    the others.
    """
    locations = collect()
    if not locations:
        return
    print("TASKS:", len(locations))
    for kb_id, loc in locations:
        try:
            # Skip everything when Redis is down; nothing to warm.
            if not REDIS_CONN.is_alive():
                continue
            key = "{}/{}".format(kb_id, loc)
            if REDIS_CONN.exist(key):
                continue
            file_bin = MINIO.get(kb_id, loc)
            # 12 * 60 seconds: long enough to cover the task's processing.
            REDIS_CONN.transaction(key, file_bin, 12 * 60)
            print("CACHE:", loc)
        except Exception:
            # Original called traceback.print_stack(e), which expects a
            # frame, not an exception, and itself raises; print_exc() is
            # the correct API for logging the current exception.
            traceback.print_exc()



if __name__ == "__main__":
    # Poll forever: warm the cache, release the DB connection between
    # rounds (so pooled peewee connections are returned), then sleep
    # briefly before the next pass.
    while True:
        main()
        close_connection()
        time.sleep(1)
2 changes: 1 addition & 1 deletion rag/svr/task_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def update_progress():
info = {
"process_duation": datetime.timestamp(
datetime.now()) -
d["process_begin_at"].timestamp(),
d["process_begin_at"].timestamp(),
"run": status}
if prg != 0:
info["progress"] = prg
Expand Down
6 changes: 6 additions & 0 deletions rag/svr/task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,14 @@ def get_minio_binary(bucket, name):
global MINIO
if REDIS_CONN.is_alive():
try:
for _ in range(30):
if REDIS_CONN.exist("{}/{}".format(bucket, name)):
time.sleep(1)
break
time.sleep(1)
r = REDIS_CONN.get("{}/{}".format(bucket, name))
if r: return r
cron_logger.warning("Cache missing: {}".format(name))
except Exception as e:
cron_logger.warning("Get redis[EXCEPTION]:" + str(e))
return MINIO.get(bucket, name)
Expand Down
1 change: 0 additions & 1 deletion rag/utils/minio_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def rm(self, bucket, fnm):
except Exception as e:
minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))


def get(self, bucket, fnm):
for _ in range(1):
try:
Expand Down
19 changes: 19 additions & 0 deletions rag/utils/redis_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ def __open__(self):
def is_alive(self):
return self.REDIS is not None

def exist(self, k):
    """Return the Redis EXISTS result for key *k*.

    Yields None when no connection is held or when the lookup raises; a
    failure additionally triggers a reconnect attempt via __open__().
    """
    client = self.REDIS
    if not client:
        return
    try:
        return client.exists(k)
    except Exception as err:
        logging.warning("[EXCEPTION]exist" + str(k) + "||" + str(err))
        self.__open__()

def get(self, k):
if not self.REDIS: return
try:
Expand All @@ -51,5 +59,16 @@ def set(self, k, v, exp=3600):
self.__open__()
return False

def transaction(self, key, value, exp=3600):
    """Set *key* to *value* only if the key does not already exist.

    Runs SET ... NX inside a transactional pipeline so concurrent writers
    do not overwrite each other's cached value.

    :param exp: expiry in seconds (default one hour).
    :return: True when the pipeline executed (even if NX skipped the
             write because the key was present); False on a Redis error,
             in which case the connection is reopened.
    """
    try:
        pipeline = self.REDIS.pipeline(transaction=True)
        pipeline.set(key, value, exp, nx=True)
        pipeline.execute()
        return True
    except Exception as e:
        # Fixed copy-pasted log tag: it previously said "set".
        logging.warning("[EXCEPTION]transaction" + str(key) + "||" + str(e))
        self.__open__()
        return False


REDIS_CONN = RedisDB()

0 comments on commit 66f8d35

Please sign in to comment.