-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
868d99d
commit 462e498
Showing
10 changed files
with
422 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,18 @@ | ||
"""Backend operations of RAG.""" | ||
|
||
import os | ||
|
||
from dotenv import load_dotenv | ||
|
||
load_dotenv() | ||
|
||
DATABASE_URL = os.getenv("DATABASE_URL") | ||
|
||
# Private key used to generate the JWT tokens for secure authentication | ||
SECRET_KEY = os.getenv("SECRET_KEY", "default_unsecure_key") | ||
|
||
# Algorithm used to generate JWT tokens | ||
ALGORITHM = os.getenv("ALGORITHM", "HS256") | ||
|
||
# Activate or deactivate the secure authentication | ||
USE_AUTHENTICATION = bool(int(os.getenv("USE_AUTHENTICATION", True))) |
Empty file.
106 changes: 106 additions & 0 deletions
106
backend/api_plugins/secure_authentication/secure_authentication.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from pathlib import Path | ||
|
||
from fastapi import Depends, HTTPException, Response, status | ||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm | ||
from jose import JWTError, jwt | ||
|
||
from backend import ADMIN_MODE | ||
from backend.api_plugins.lib.user_management import ( | ||
ALGORITHM, | ||
SECRET_KEY, | ||
UnsecureUser, | ||
User, | ||
authenticate_user, | ||
create_access_token, | ||
create_user, | ||
delete_user, | ||
get_user, | ||
user_exists, | ||
) | ||
|
||
|
||
def authentication_routes(app, dependencies=list[Depends]): | ||
from backend.database import Database | ||
|
||
with Database() as connection: | ||
connection.run_script(Path(__file__).parent / "users_tables.sql") | ||
|
||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/user/login") | ||
|
||
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: | ||
credentials_exception = HTTPException( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
detail="Could not validate credentials", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
try: | ||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) | ||
email: str = payload.get("email") | ||
if email is None: | ||
raise credentials_exception | ||
|
||
user = get_user(email) | ||
if user is None: | ||
raise credentials_exception | ||
return user | ||
except JWTError: | ||
raise credentials_exception | ||
|
||
@app.post("/user/signup", include_in_schema=ADMIN_MODE) | ||
async def signup(user: UnsecureUser) -> dict: | ||
if not ADMIN_MODE: | ||
raise HTTPException( | ||
status_code=status.HTTP_403_FORBIDDEN, detail="Signup is disabled" | ||
) | ||
|
||
user = User.from_unsecure_user(user) | ||
if user_exists(user.email): | ||
raise HTTPException( | ||
status_code=status.HTTP_400_BAD_REQUEST, | ||
detail=f"User {user.email} already registered", | ||
) | ||
|
||
create_user(user) | ||
return {"email": user.email} | ||
|
||
@app.delete("/user/") | ||
async def del_user(current_user: User = Depends(get_current_user)) -> dict: | ||
email = current_user.email | ||
try: | ||
user = get_user(email) | ||
if user is None: | ||
raise HTTPException( | ||
status_code=status.HTTP_404_NOT_FOUND, | ||
detail=f"User {email} not found", | ||
) | ||
delete_user(email) | ||
return {"detail": f"User {email} deleted"} | ||
except Exception: | ||
raise HTTPException( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail="Internal Server Error", | ||
) | ||
|
||
@app.post("/user/login") | ||
async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict: | ||
user = authenticate_user(form_data.username, form_data.password) | ||
if not user: | ||
raise HTTPException( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
detail="Incorrect username or password", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
user_data = user.dict() | ||
del user_data["hashed_password"] | ||
access_token = create_access_token(data=user_data) | ||
return {"access_token": access_token, "token_type": "bearer"} | ||
|
||
@app.get("/user/me") | ||
async def user_me(current_user: User = Depends(get_current_user)) -> User: | ||
return current_user | ||
|
||
@app.get("/user") | ||
async def user_root() -> dict: | ||
return Response("User management routes are enabled.", status_code=200) | ||
|
||
return Depends(get_current_user) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
-- Dialect MUST be sqlite, even if the database you use is different. | ||
-- It is transpiled to the right dialect when executed. | ||
|
||
CREATE TABLE IF NOT EXISTS "users" ( | ||
"email" VARCHAR(255) PRIMARY KEY, | ||
"password" TEXT | ||
); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import json | ||
from collections.abc import Sequence | ||
from datetime import datetime | ||
from pathlib import Path | ||
from uuid import uuid4 | ||
|
||
from fastapi import APIRouter, Depends, FastAPI, Response | ||
|
||
from backend.api_plugins.lib.user_management import User | ||
|
||
|
||
def session_routes( | ||
app: FastAPI | APIRouter, | ||
*, | ||
authentication: Depends = None, | ||
dependencies: Sequence[Depends] | None = None, | ||
): | ||
from backend.database import Database | ||
from backend.model import Message | ||
|
||
with Database() as connection: | ||
connection.run_script(Path(__file__).parent / "sessions_tables.sql") | ||
|
||
@app.post("/session/new") | ||
async def chat_new( | ||
current_user: User = authentication, dependencies=dependencies | ||
) -> dict: | ||
chat_id = str(uuid4()) | ||
timestamp = datetime.utcnow().isoformat() | ||
user_id = current_user.email if current_user else "unauthenticated" | ||
with Database() as connection: | ||
connection.execute( | ||
"INSERT INTO session (id, timestamp, user_id) VALUES (?, ?, ?)", | ||
(chat_id, timestamp, user_id), | ||
) | ||
return {"session_id": chat_id} | ||
|
||
@app.get("/session/list") | ||
async def chat_list( | ||
current_user: User = authentication, dependencies=dependencies | ||
) -> list[dict]: | ||
user_email = current_user.email if current_user else "unauthenticated" | ||
chats = [] | ||
with Database() as connection: | ||
result = connection.execute( | ||
"SELECT id, timestamp FROM session WHERE user_id = ? ORDER BY timestamp DESC", | ||
(user_email,), | ||
) | ||
chats = [{"id": row[0], "timestamp": row[1]} for row in result] | ||
return chats | ||
|
||
@app.get("/session/{session_id}") | ||
async def chat( | ||
session_id: str, current_user: User = authentication, dependencies=dependencies | ||
) -> dict: | ||
messages: list[Message] = [] | ||
with Database() as connection: | ||
result = connection.execute( | ||
"SELECT id, timestamp, session_id, message FROM message_history WHERE session_id = ? ORDER BY timestamp ASC", | ||
(session_id,), | ||
) | ||
for row in result: | ||
content = json.loads(row[3])["data"]["content"] | ||
message_type = json.loads(row[3])["type"] | ||
message = Message( | ||
id=row[0], | ||
timestamp=row[1], | ||
session_id=row[2], | ||
sender=message_type if message_type == "human" else "ai", | ||
content=content, | ||
) | ||
messages.append(message) | ||
return { | ||
"chat_id": session_id, | ||
"messages": [message.dict() for message in messages], | ||
} | ||
|
||
@app.get("/session") | ||
async def session_root( | ||
current_user: User = authentication, dependencies=dependencies | ||
) -> dict: | ||
return Response("Sessions management routes are enabled.", status_code=200) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
-- Dialect MUST be sqlite, even if the database you use is different. | ||
-- It is transpiled to the right dialect when executed. | ||
|
||
CREATE TABLE IF NOT EXISTS "session" ( | ||
"id" VARCHAR(255) PRIMARY KEY, | ||
"timestamp" DATETIME, | ||
"user_id" VARCHAR(255), | ||
FOREIGN KEY ("user_id") REFERENCES "users" ("email") | ||
); |
Oops, something went wrong.