diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py new file mode 100644 index 00000000000..7dd129d1093 --- /dev/null +++ b/api/apps/sdk/dify_retrieval.py @@ -0,0 +1,62 @@ +from flask import request, jsonify + +from db import LLMType, ParserType +from db.services.knowledgebase_service import KnowledgebaseService +from db.services.llm_service import LLMBundle +from settings import retrievaler, kg_retrievaler, RetCode +from utils.api_utils import validate_request, build_error_result, apikey_required + + +@manager.route('/dify/retrieval', methods=['POST']) +@apikey_required +@validate_request("knowledge_id", "query") +def retrieval(tenant_id): + req = request.json + question = req["query"] + kb_id = req["knowledge_id"] + retrieval_setting = req.get("retrieval_setting", {}) + similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0)) + top = int(retrieval_setting.get("top_k", 1024)) + + try: + + e, kb = KnowledgebaseService.get_by_id(kb_id) + if not e: + return build_error_result(error_msg="Knowledgebase not found!", retcode=RetCode.NOT_FOUND) + + if kb.tenant_id != tenant_id: + return build_error_result(error_msg="Knowledgebase not found!", retcode=RetCode.NOT_FOUND) + + embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) + + retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler + ranks = retr.retrieval( + question, + embd_mdl, + kb.tenant_id, + [kb_id], + page=1, + page_size=top, + similarity_threshold=similarity_threshold, + vector_similarity_weight=0.3, + top=top + ) + records = [] + for c in ranks["chunks"]: + if "vector" in c: + del c["vector"] + records.append({ + "content": c["content_ltks"], + "score": c["similarity"], + "title": c["docnm_kwd"], + "metadata": "" + }) + + return jsonify({"records": records}) + except Exception as e: + if str(e).find("not_found") > 0: + return build_error_result( + error_msg=f'No chunk found! Check the chunk status please!', + retcode=RetCode.NOT_FOUND + ) + return build_error_result(error_msg=str(e), retcode=RetCode.SERVER_ERROR) diff --git a/api/settings.py b/api/settings.py index 5078903a3d9..9faf7c1698d 100644 --- a/api/settings.py +++ b/api/settings.py @@ -250,3 +250,5 @@ class RetCode(IntEnum, CustomEnum): AUTHENTICATION_ERROR = 109 UNAUTHORIZED = 401 SERVER_ERROR = 500 + FORBIDDEN = 403 + NOT_FOUND = 404 diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 6bfeb8011df..8ca64c6663b 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -200,6 +200,27 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None): response = {"retcode": retcode, "retmsg": retmsg, "data": data} return jsonify(response) +def apikey_required(func): + @wraps(func) + def decorated_function(*args, **kwargs): + token = flask_request.headers.get('Authorization').split()[1] + objs = APIToken.query(token=token) + if not objs: + return build_error_result( + error_msg='API-KEY is invalid!', retcode=RetCode.FORBIDDEN + ) + kwargs['tenant_id'] = objs[0].tenant_id + return func(*args, **kwargs) + + return decorated_function + + +def build_error_result(retcode=RetCode.FORBIDDEN, error_msg='success'): + response = {"error_code": retcode, "error_msg": error_msg} + response = jsonify(response) + response.status_code = retcode + return response + def construct_response(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None):