Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor #537

Merged
merged 1 commit into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion api/apps/document_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def upload():
if not e:
return get_data_error_result(
retmsg="Can't find this knowledgebase!")
if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
return get_data_error_result(
retmsg="Exceed the maximum file number of a free user!")

Expand Down
4 changes: 2 additions & 2 deletions api/apps/llm_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
def factories():
try:
fac = LLMFactoriesService.get_all()
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["QAnything", "FastEmbed"]])
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]])
except Exception as e:
return server_error_response(e)

Expand Down Expand Up @@ -174,7 +174,7 @@ def list():
llms = [m.to_dict()
for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["QAnything","FastEmbed"]
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"]

llm_set = set([m["llm_name"] for m in llms])
for o in objs:
Expand Down
2 changes: 1 addition & 1 deletion api/db/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ class Dialog(DataBaseModel):
null=True,
default="Chinese",
help_text="English|Chinese")
llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
"presence_penalty": 0.4, "max_tokens": 215})
prompt_type = CharField(
Expand Down
8 changes: 5 additions & 3 deletions api/db/init_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def init_superuser():
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "QAnything",
"name": "Youdao",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
Expand Down Expand Up @@ -323,7 +323,7 @@ def init_llm_factory():
"max_tokens": 2147483648,
"model_type": LLMType.EMBEDDING.value
},
# ------------------------ QAnything -----------------------
# ------------------------ Youdao -----------------------
{
"fid": factory_infos[7]["name"],
"llm_name": "maidalun1020/bce-embedding-base_v1",
Expand All @@ -347,7 +347,9 @@ def init_llm_factory():
LLMService.filter_delete([LLM.fid == "Local"])
LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])

LLMFactoriesService.filter_update([LLMFactoriesService.model.name == "QAnything"], {"name": "Youdao"})
LLMService.filter_update([LLMService.model.fid == "QAnything"], {"fid": "Youdao"})
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
"""
drop table llm;
drop table llm_factories;
Expand Down
2 changes: 1 addition & 1 deletion api/db/services/llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def model_instance(cls, tenant_id, llm_type,
if not model_config:
if llm_type == LLMType.EMBEDDING.value:
llm = LLMService.query(llm_name=llm_name)
if llm and llm[0].fid in ["QAnything", "FastEmbed"]:
if llm and llm[0].fid in ["Youdao", "FastEmbed"]:
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
if not model_config:
if llm_name == "flag-embedding":
Expand Down
20 changes: 20 additions & 0 deletions api/db/services/task_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from api.db.db_models import Task, Document, Knowledgebase, Tenant
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.utils import current_timestamp


class TaskService(CommonService):
Expand Down Expand Up @@ -70,6 +71,25 @@ def get_tasks(cls, tm, mod=0, comm=1, items_per_page=1, takeit=True):
cls.model.id == docs[0]["id"]).execute()
return docs

@classmethod
@DB.connection_context()
def get_ongoing_doc_name(cls):
with DB.lock("get_task", -1):
docs = cls.model.select(*[Document.kb_id, Document.location]) \
.join(Document, on=(cls.model.doc_id == Document.id)) \
.where(
Document.status == StatusEnum.VALID.value,
Document.run == TaskStatus.RUNNING.value,
~(Document.type == FileType.VIRTUAL.value),
cls.model.progress >= 0,
cls.model.progress < 1,
cls.model.create_time >= current_timestamp() - 180000
)
docs = list(docs.dicts())
if not docs: return []

return list(set([(d["kb_id"], d["location"]) for d in docs]))

@classmethod
@DB.connection_context()
def do_cancel(cls, id):
Expand Down
34 changes: 17 additions & 17 deletions deepdoc/parser/pdf_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(self):
self.updown_cnt_mdl.set_param({"device": "cuda"})
try:
model_dir = os.path.join(
get_project_base_directory(),
"rag/res/deepdoc")
get_project_base_directory(),
"rag/res/deepdoc")
self.updown_cnt_mdl.load_model(os.path.join(
model_dir, "updown_concat_xgb.model"))
except Exception as e:
Expand All @@ -49,7 +49,6 @@ def __init__(self):
self.updown_cnt_mdl.load_model(os.path.join(
model_dir, "updown_concat_xgb.model"))


self.page_from = 0
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
Expand All @@ -76,7 +75,7 @@ def _x_dis(self, a, b):
def _y_dis(
self, a, b):
return (
b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2

def _match_proj(self, b):
proj_patt = [
Expand All @@ -99,9 +98,9 @@ def _updown_concat_features(self, up, down):
tks_down = huqie.qie(down["text"][:LEN]).split(" ")
tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
tks_all = up["text"][-LEN:].strip() \
+ (" " if re.match(r"[a-zA-Z0-9]+",
up["text"][-1] + down["text"][0]) else "") \
+ down["text"][:LEN].strip()
+ (" " if re.match(r"[a-zA-Z0-9]+",
up["text"][-1] + down["text"][0]) else "") \
+ down["text"][:LEN].strip()
tks_all = huqie.qie(tks_all).split(" ")
fea = [
up.get("R", -1) == down.get("R", -1),
Expand All @@ -123,7 +122,7 @@ def _updown_concat_features(self, up, down):
True if re.search(r"[,,][^。.]+$", up["text"]) else False,
True if re.search(r"[,,][^。.]+$", up["text"]) else False,
True if re.search(r"[\((][^\))]+$", up["text"])
and re.search(r"[\))]", down["text"]) else False,
and re.search(r"[\))]", down["text"]) else False,
self._match_proj(down),
True if re.match(r"[A-Z]", down["text"]) else False,
True if re.match(r"[A-Z]", up["text"][-1]) else False,
Expand Down Expand Up @@ -185,7 +184,7 @@ def _table_transformer_job(self, ZM):
continue
for tb in tbls: # for table
left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
tb["x1"] + MARGIN, tb["bottom"] + MARGIN
tb["x1"] + MARGIN, tb["bottom"] + MARGIN
left *= ZM
top *= ZM
right *= ZM
Expand Down Expand Up @@ -297,7 +296,7 @@ def __ocr(self, pagenum, img, chars, ZM=3):
for b in bxs:
if not b["text"]:
left, right, top, bott = b["x0"] * ZM, b["x1"] * \
ZM, b["top"] * ZM, b["bottom"] * ZM
ZM, b["top"] * ZM, b["bottom"] * ZM
b["text"] = self.ocr.recognize(np.array(img),
np.array([[left, top], [right, top], [right, bott], [left, bott]],
dtype=np.float32))
Expand Down Expand Up @@ -622,7 +621,7 @@ def _extract_table_figure(self, need_image, ZM,
i += 1
continue
lout_no = str(self.boxes[i]["page_number"]) + \
"-" + str(self.boxes[i]["layoutno"])
"-" + str(self.boxes[i]["layoutno"])
if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption",
"title",
"figure caption",
Expand Down Expand Up @@ -975,6 +974,7 @@ def dfs(arr, depth):
self.outlines.append((a["/Title"], depth))
continue
dfs(a, depth + 1)

dfs(outlines, 0)
except Exception as e:
logging.warning(f"Outlines exception: {e}")
Expand All @@ -984,7 +984,7 @@ def dfs(arr, depth):
logging.info("Images converted.")
self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
range(len(self.page_chars))]
range(len(self.page_chars))]
if sum([1 if e else 0 for e in self.is_english]) > len(
self.page_images) / 2:
self.is_english = True
Expand Down Expand Up @@ -1012,9 +1012,9 @@ def dfs(arr, depth):
j += 1

self.__ocr(i + 1, img, chars, zoomin)
#if callback:
# callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
#print("OCR:", timer()-st)
if callback and i % 6 == 5:
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
# print("OCR:", timer()-st)

if not self.is_english and not any(
[c for c in self.page_chars]) and self.boxes:
Expand Down Expand Up @@ -1050,7 +1050,7 @@ def crop(self, text, ZM=3, need_position=False):
left, right, top, bottom = float(left), float(
right), float(top), float(bottom)
poss.append(([int(p) - 1 for p in pn.split("-")],
left, right, top, bottom))
left, right, top, bottom))
if not poss:
if need_position:
return None, None
Expand All @@ -1076,7 +1076,7 @@ def crop(self, text, ZM=3, need_position=False):
self.page_images[pns[0]].crop((left * ZM, top * ZM,
right *
ZM, min(
bottom, self.page_images[pns[0]].size[1])
bottom, self.page_images[pns[0]].size[1])
))
)
if 0 < ii < len(poss) - 1:
Expand Down
2 changes: 1 addition & 1 deletion rag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
"ZHIPU-AI": ZhipuEmbed,
"FastEmbed": FastEmbed,
"QAnything": QAnythingEmbed
"Youdao": YoudaoEmbed
}


Expand Down
12 changes: 6 additions & 6 deletions rag/llm/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,19 +229,19 @@ def encode_queries(self, text):
return np.array(res.data[0].embedding), res.usage.total_tokens


class QAnythingEmbed(Base):
class YoudaoEmbed(Base):
_client = None

def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
from BCEmbedding import EmbeddingModel as qanthing
if not QAnythingEmbed._client:
if not YoudaoEmbed._client:
try:
print("LOADING BCE...")
QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join(
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
get_project_base_directory(),
"rag/res/bce-embedding-base_v1"))
except Exception as e:
QAnythingEmbed._client = qanthing(
YoudaoEmbed._client = qanthing(
model_name_or_path=model_name.replace(
"maidalun1020", "InfiniFlow"))

Expand All @@ -251,10 +251,10 @@ def encode(self, texts: list, batch_size=10):
for t in texts:
token_count += num_tokens_from_string(t)
for i in range(0, len(texts), batch_size):
embds = QAnythingEmbed._client.encode(texts[i:i + batch_size])
embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
res.extend(embds)
return np.array(res), token_count

def encode_queries(self, text):
embds = QAnythingEmbed._client.encode([text])
embds = YoudaoEmbed._client.encode([text])
return np.array(embds[0]), num_tokens_from_string(text)
43 changes: 43 additions & 0 deletions rag/svr/cache_file_svr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import random
import time
import traceback

from api.db.db_models import close_connection
from api.db.services.task_service import TaskService
from rag.utils import MINIO
from rag.utils.redis_conn import REDIS_CONN


def collect():
doc_locations = TaskService.get_ongoing_doc_name()
#print(tasks)
if len(doc_locations) == 0:
time.sleep(1)
return
return doc_locations

def main():
locations = collect()
if not locations:return
print("TASKS:", len(locations))
for kb_id, loc in locations:
try:
if REDIS_CONN.is_alive():
try:
key = "{}/{}".format(kb_id, loc)
if REDIS_CONN.exist(key):continue
file_bin = MINIO.get(kb_id, loc)
REDIS_CONN.transaction(key, file_bin, 12 * 60)
print("CACHE:", loc)
except Exception as e:
traceback.print_stack(e)
except Exception as e:
traceback.print_stack(e)



if __name__ == "__main__":
while True:
main()
close_connection()
time.sleep(1)
2 changes: 1 addition & 1 deletion rag/svr/task_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def update_progress():
info = {
"process_duation": datetime.timestamp(
datetime.now()) -
d["process_begin_at"].timestamp(),
d["process_begin_at"].timestamp(),
"run": status}
if prg != 0:
info["progress"] = prg
Expand Down
6 changes: 6 additions & 0 deletions rag/svr/task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,14 @@ def get_minio_binary(bucket, name):
global MINIO
if REDIS_CONN.is_alive():
try:
for _ in range(30):
if REDIS_CONN.exist("{}/{}".format(bucket, name)):
time.sleep(1)
break
time.sleep(1)
r = REDIS_CONN.get("{}/{}".format(bucket, name))
if r: return r
cron_logger.warning("Cache missing: {}".format(name))
except Exception as e:
cron_logger.warning("Get redis[EXCEPTION]:" + str(e))
return MINIO.get(bucket, name)
Expand Down
1 change: 0 additions & 1 deletion rag/utils/minio_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def rm(self, bucket, fnm):
except Exception as e:
minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))


def get(self, bucket, fnm):
for _ in range(1):
try:
Expand Down
19 changes: 19 additions & 0 deletions rag/utils/redis_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ def __open__(self):
def is_alive(self):
return self.REDIS is not None

def exist(self, k):
if not self.REDIS: return
try:
return self.REDIS.exists(k)
except Exception as e:
logging.warning("[EXCEPTION]exist" + str(k) + "||" + str(e))
self.__open__()

def get(self, k):
if not self.REDIS: return
try:
Expand All @@ -51,5 +59,16 @@ def set(self, k, v, exp=3600):
self.__open__()
return False

def transaction(self, key, value, exp=3600):
try:
pipeline = self.REDIS.pipeline(transaction=True)
pipeline.set(key, value, exp, nx=True)
pipeline.execute()
return True
except Exception as e:
logging.warning("[EXCEPTION]set" + str(key) + "||" + str(e))
self.__open__()
return False


REDIS_CONN = RedisDB()