Skip to content

Commit

Permalink
fix: sign out when user token invalid (#329)
Browse files Browse the repository at this point in the history
when the user token is invalid, signout to clear the session instead of
just clearning session['user_info'].
  • Loading branch information
Yuan325 committed Apr 4, 2024
1 parent da8d3f6 commit 2ec915b
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 25 deletions.
25 changes: 13 additions & 12 deletions llm_demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,20 @@ async def index(request: Request):
orchestrator = request.app.state.orchestrator
session = request.session

# check if token and user info is still valid
if "uuid" in session:
user_id_token = orchestrator.get_user_id_token(session["uuid"])
if user_id_token:
if session.get("user_info") and not get_user_info(
user_id_token, request.app.state.client_id
):
await logout_google(request)
elif not user_id_token and "user_info" in session:
await logout_google(request)

if "uuid" not in session or not orchestrator.user_session_exist(session["uuid"]):
await orchestrator.user_session_create(session)

# recheck if token and user info is still valid
user_id_token = orchestrator.get_user_id_token(session["uuid"])
if user_id_token:
if not get_user_info(user_id_token, request.app.state.client_id):
clear_user_info(session)
elif not user_id_token and "user_info" in session:
clear_user_info(session)

return templates.TemplateResponse(
"index.html",
{
Expand Down Expand Up @@ -129,10 +132,8 @@ async def logout_google(

uuid = request.session["uuid"]
orchestrator = request.app.state.orchestrator
if not orchestrator.user_session_exist(uuid):
raise HTTPException(status_code=500, detail=f"Current user session not found")

orchestrator.user_session_signout(uuid)
if orchestrator.user_session_exist(uuid):
await orchestrator.user_session_signout(uuid)
request.session.clear()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,12 @@ def get_base_history(self, session: dict[str, Any]):
return base_history
return BASE_HISTORY

async def user_session_signout(self, uuid: str):
user_session = self.get_user_session(uuid)
if user_session:
await user_session.close()
del self._user_sessions[uuid]

def close_clients(self):
close_client_tasks = [
asyncio.create_task(a.close()) for a in self._user_sessions.values()
Expand Down
26 changes: 13 additions & 13 deletions llm_demo/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,26 +61,26 @@ def get_user_session(self, uuid: str) -> Any:
async def user_session_insert_ticket(self, uuid: str, params: str) -> Any:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def user_session_signout(self, uuid: str):
"""Sign out from user session. Clear and restart session."""
raise NotImplementedError("Subclass should implement this!")

def set_user_session_header(self, uuid: str, user_id_token: str):
user_session = self.get_user_session(uuid)
user_session.client.headers["User-Id-Token"] = f"Bearer {user_id_token}"

def get_user_id_token(self, uuid: str) -> Optional[str]:
user_session = self.get_user_session(uuid)
if user_session.client and "User-Id-Token" in user_session.client.headers:
token = user_session.client.headers["User-Id-Token"]
parts = str(token).split(" ")
if len(parts) != 2 or parts[0] != "Bearer":
raise Exception("Invalid ID token")
return parts[1]
if self.user_session_exist(uuid):
user_session = self.get_user_session(uuid)
if user_session.client and "User-Id-Token" in user_session.client.headers:
token = user_session.client.headers["User-Id-Token"]
parts = str(token).split(" ")
if len(parts) != 2 or parts[0] != "Bearer":
raise Exception("Invalid ID token")
return parts[1]
return None

async def user_session_signout(self, uuid: str):
"""Sign out from user session. Clear and restart session."""
user_session = self.get_user_session(uuid)
await user_session.close()
del user_session


def createOrchestrator(orchestration_type: str) -> "BaseOrchestrator":
for cls in BaseOrchestrator.__subclasses__():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ def get_base_history(self, session: dict[str, Any]):
return base_history
return BASE_HISTORY

async def user_session_signout(self, uuid: str):
user_session = self.get_user_session(uuid)
if user_session:
await user_session.close()
del self._user_sessions[uuid]

def close_clients(self):
close_client_tasks = [
asyncio.create_task(a.close()) for a in self._user_sessions.values()
Expand Down

0 comments on commit 2ec915b

Please sign in to comment.