Skip to content

Commit

Permalink
Add Authorization checks (#2221)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Add Authorization checks
#2203

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
  • Loading branch information
3 people authored Sep 4, 2024
1 parent 4f05803 commit 0164856
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 25 deletions.
5 changes: 5 additions & 0 deletions api/apps/canvas_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flask import request, Response
from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas
Expand All @@ -43,6 +44,10 @@ def canvas_list():
@login_required
def rm():
for i in request.json["canvas_ids"]:
if not UserCanvasService.query(user_id=current_user.id,id=i):
return get_json_result(
data=False, retmsg=f'Only owner of canvas authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
UserCanvasService.delete_by_id(i)
return get_json_result(data=True)

Expand Down
71 changes: 50 additions & 21 deletions api/apps/conversation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from copy import deepcopy

from db.services.user_service import UserTenantService
from flask import request, Response
from flask_login import login_required,current_user
from flask_login import login_required, current_user

from api.db import LLMType
from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.db.services.llm_service import LLMBundle, TenantService
from api.db import LLMType
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
import json
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request


@manager.route('/set', methods=['POST'])
Expand Down Expand Up @@ -72,6 +76,14 @@ def get():
e, conv = ConversationService.get_by_id(conv_id)
if not e:
return get_data_error_result(retmsg="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of conversation authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
conv = conv.to_dict()
return get_json_result(data=conv)
except Exception as e:
Expand All @@ -84,6 +96,17 @@ def rm():
conv_ids = request.json["conversation_ids"]
try:
for cid in conv_ids:
exist, conv = ConversationService.get_by_id(cid)
if not exist:
return get_data_error_result(retmsg="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of conversation authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
ConversationService.delete_by_id(cid)
return get_json_result(data=True)
except Exception as e:
Expand All @@ -95,6 +118,10 @@ def rm():
def list_convsersation():
dialog_id = request.args["dialog_id"]
try:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
return get_json_result(
data=False, retmsg=f'Only owner of dialog authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
convs = ConversationService.query(
dialog_id=dialog_id,
order_by=ConversationService.model.create_time,
Expand All @@ -107,12 +134,12 @@ def list_convsersation():

@manager.route('/completion', methods=['POST'])
@login_required
#@validate_request("conversation_id", "messages")
@validate_request("conversation_id", "messages")
def completion():
req = request.json
#req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
# {"role": "user", "content": "上海有吗?"}
#]}
# ]}
msg = []
for m in req["messages"]:
if m["role"] == "system":
Expand Down Expand Up @@ -141,7 +168,8 @@ def fillin_conv(ans):
nonlocal conv, message_id
if not conv.reference:
conv.reference.append(ans["reference"])
else: conv.reference[-1] = ans["reference"]
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"],
"id": message_id, "prompt": ans.get("prompt", "")}
ans["id"] = message_id
Expand All @@ -151,13 +179,13 @@ def stream():
try:
for ans in chat(dia, msg, True, **req):
fillin_conv(ans)
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: "+str(e), "reference": []}},
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"

if req.get("stream", True):
resp = Response(stream(), mimetype="text/event-stream")
Expand All @@ -184,33 +212,34 @@ def stream():
def tts():
req = request.json
text = req["text"]

tenants = TenantService.get_by_user_id(current_user.id)
if not tenants:
return get_data_error_result(retmsg="Tenant not found!")

tts_id = tenants[0]["tts_id"]
if not tts_id:
return get_data_error_result(retmsg="No default TTS model is set")

tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)

def stream_audio():
try:
for chunk in tts_mdl.tts(text):
yield chunk
except Exception as e:
yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: "+str(e)}},
ensure_ascii=False)).encode('utf-8')
"data": {"answer": "**ERROR**: " + str(e)}},
ensure_ascii=False)).encode('utf-8')

resp = Response(stream_audio(), mimetype="audio/mpeg")
resp = Response(stream_audio(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")

return resp


@manager.route('/delete_msg', methods=['POST'])
@login_required
@validate_request("conversation_id", "message_id")
Expand All @@ -224,10 +253,10 @@ def delete_msg():
for i, msg in enumerate(conv["message"]):
if req["message_id"] != msg.get("id", ""):
continue
assert conv["message"][i+1]["id"] == req["message_id"]
assert conv["message"][i + 1]["id"] == req["message_id"]
conv["message"].pop(i)
conv["message"].pop(i)
conv["reference"].pop(max(0, i//2-1))
conv["reference"].pop(max(0, i // 2 - 1))
break

ConversationService.update_by_id(conv["id"], conv)
Expand Down
17 changes: 14 additions & 3 deletions api/apps/dialog_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from api.db.services.dialog_service import DialogService
from api.db import StatusEnum
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import RetCode
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
Expand Down Expand Up @@ -164,9 +165,19 @@ def list_dialogs():
@validate_request("dialog_ids")
def rm():
req = request.json
dialog_list=[]
tenants = UserTenantService.query(user_id=current_user.id)
try:
DialogService.update_many_by_id(
[{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
for id in req["dialog_ids"]:
for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of dialog authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
DialogService.update_many_by_id(dialog_list)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
11 changes: 10 additions & 1 deletion api/apps/document_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.user_service import TenantService
from api.db.services.user_service import TenantService, UserTenantService
from graphrag.mind_map_extractor import MindMapExtractor
from rag.app import naive
from rag.nlp import search
Expand Down Expand Up @@ -189,6 +189,15 @@ def list_docs():
if not kb_id:
return get_json_result(
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if KnowledgebaseService.query(
tenant_id=tenant.tenant_id, id=kb_id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
keywords = request.args.get("keywords", "")

page_number = int(request.args.get("page", 1))
Expand Down

0 comments on commit 0164856

Please sign in to comment.