Skip to content

Commit

Permalink
handle nits in task_executor (infiniflow#2637)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

- fix typo
- fix string format
- format import

### Type of change

- [x] Refactoring
  • Loading branch information
yqkcn authored Sep 29, 2024
1 parent a0e3ec3 commit bdee5e8
Showing 1 changed file with 36 additions and 38 deletions.
74 changes: 36 additions & 38 deletions rag/svr/task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,31 @@
import traceback
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from api.db.services.file2document_service import File2DocumentService
from api.settings import retrievaler
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.utils.storage_factory import STORAGE_IMPL
from api.db.db_models import close_connection
from rag.settings import database_logger, SVR_QUEUE_NAME
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from multiprocessing import Pool
import numpy as np
from elasticsearch_dsl import Q, Search
from io import BytesIO
from multiprocessing.context import TimeoutError
from api.db.services.task_service import TaskService
from rag.utils.es_conn import ELASTICSEARCH
from timeit import default_timer as timer
from rag.utils import rmSpace, findMaxTm, num_tokens_from_string

from rag.nlp import search, rag_tokenizer
from io import BytesIO
import numpy as np
import pandas as pd

from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
from elasticsearch_dsl import Q

from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService
from api.db.services.file2document_service import File2DocumentService
from api.settings import retrievaler
from api.utils.file_utils import get_project_base_directory
from rag.utils.redis_conn import REDIS_CONN
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import database_logger, SVR_QUEUE_NAME
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL

BATCH_SIZE = 64

Expand All @@ -74,11 +71,11 @@
ParserType.KG.value: knowledge_graph
}

CONSUMEER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
PAYLOAD = None
CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
PAYLOAD: Payload | None = None


def set_progress(task_id, from_page=0, to_page=-1,
prog=None, msg="Processing..."):
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
global PAYLOAD
if prog is not None and prog < 0:
msg = "[ERROR]" + msg
Expand Down Expand Up @@ -107,11 +104,11 @@ def set_progress(task_id, from_page=0, to_page=-1,


def collect():
global CONSUMEER_NAME, PAYLOAD
global CONSUMER_NAME, PAYLOAD
try:
PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
if not PAYLOAD:
PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
if not PAYLOAD:
time.sleep(1)
return pd.DataFrame()
Expand Down Expand Up @@ -159,17 +156,16 @@ def build(row):
binary = get_storage_binary(bucket, name)
cron_logger.info(
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
except TimeoutError as e:
callback(-1, f"Internal server error: Fetch file from minio timeout. Could you try it again.")
except TimeoutError:
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
cron_logger.error(
"Minio {}/{}: Fetch file from minio timeout.".format(row["location"], row["name"]))
return
except Exception as e:
if re.search("(No such file|not found)", str(e)):
callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
else:
callback(-1, f"Get file from minio: %s" %
str(e).replace("'", ""))
callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
traceback.print_exc()
return

Expand All @@ -180,7 +176,7 @@ def build(row):
cron_logger.info(
"Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
except Exception as e:
callback(-1, f"Internal server error while chunking: %s" %
callback(-1, "Internal server error while chunking: %s" %
str(e).replace("'", ""))
cron_logger.error(
"Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
Expand Down Expand Up @@ -236,7 +232,9 @@ def init_kb(row):
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))


def embedding(docs, mdl, parser_config={}, callback=None):
def embedding(docs, mdl, parser_config=None, callback=None):
if parser_config is None:
parser_config = {}
batch_size = 32
tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
Expand Down Expand Up @@ -277,7 +275,7 @@ def embedding(docs, mdl, parser_config={}, callback=None):

def run_raptor(row, chat_mdl, embd_mdl, callback=None):
vts, _ = embd_mdl.encode(["ok"])
vctr_nm = "q_%d_vec"%len(vts[0])
vctr_nm = "q_%d_vec" % len(vts[0])
chunks = []
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
Expand Down Expand Up @@ -374,7 +372,7 @@ def main():

cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
if es_r:
callback(-1, f"Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
cron_logger.error(str(es_r))
Expand All @@ -392,15 +390,15 @@ def main():


def report_status():
global CONSUMEER_NAME
global CONSUMER_NAME
while True:
try:
obj = REDIS_CONN.get("TASKEXE")
if not obj: obj = {}
else: obj = json.loads(obj)
if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
obj[CONSUMEER_NAME].append(timer())
obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
if CONSUMER_NAME not in obj: obj[CONSUMER_NAME] = []
obj[CONSUMER_NAME].append(timer())
obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
except Exception as e:
print("[Exception]:", str(e))
Expand Down

0 comments on commit bdee5e8

Please sign in to comment.