diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 6676941df8d..bae0527d537 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -20,7 +20,7 @@ from flask import request, Response from flask_login import login_required, current_user -from api.db import FileType, ParserType, FileSource +from api.db import FileType, ParserType, FileSource, LLMType from api.db.db_models import APIToken, API4Conversation, Task, File from api.db.services import duplicate_name from api.db.services.api_service import APITokenService, API4ConversationService @@ -29,6 +29,7 @@ from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.llm_service import TenantLLMService from api.db.services.task_service import queue_tasks, TaskService from api.db.services.user_service import UserTenantService from api.settings import RetCode, retrievaler @@ -37,6 +38,7 @@ from itsdangerous import URLSafeTimedSerializer from api.utils.file_utils import filename_type, thumbnail +from rag.nlp import keyword_extraction from rag.utils.minio_conn import MINIO @@ -587,3 +589,55 @@ def fillin_conv(ans): except Exception as e: return server_error_response(e) + + +@manager.route('/retrieval', methods=['POST']) +@validate_request("kb_id", "question") +def retrieval(): + token = request.headers.get('Authorization').split()[1] + objs = APIToken.query(token=token) + if not objs: + return get_json_result( + data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) + + req = request.json + kb_id = req.get("kb_id") + doc_ids = req.get("doc_ids", []) + question = req.get("question") + page = int(req.get("page", 1)) + size = int(req.get("size", 30)) + similarity_threshold = float(req.get("similarity_threshold", 0.2)) + vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) + top = int(req.get("top_k", 1024)) + + try: + e, kb = KnowledgebaseService.get_by_id(kb_id) + if not e: + return get_data_error_result(retmsg="Knowledgebase not found!") + + embd_mdl = TenantLLMService.model_instance( + kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) + + rerank_mdl = None + if req.get("rerank_id"): + rerank_mdl = TenantLLMService.model_instance( + kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) + + if req.get("keyword", False): + chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) + question += keyword_extraction(chat_mdl, question) + + ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, + similarity_threshold, vector_similarity_weight, top, + doc_ids, rerank_mdl=rerank_mdl) + for c in ranks["chunks"]: + if "vector" in c: + del c["vector"] + + return get_json_result(data=ranks) + except Exception as e: + if str(e).find("not_found") > 0: + return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!', + retcode=RetCode.DATA_ERROR) + return server_error_response(e) +