-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b5b1e4f
commit 29a404d
Showing
15 changed files
with
898 additions
and
227 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,21 @@ | ||
"""Backend operations of RAG.""" | ||
|
||
import os | ||
|
||
from dotenv import load_dotenv | ||
|
||
load_dotenv() | ||
|
||
DATABASE_URL = os.getenv("DATABASE_URL") | ||
|
||
# Private key used to generate the JWT tokens for secure authentication | ||
SECRET_KEY = os.getenv("SECRET_KEY", "default_unsecure_key") | ||
|
||
# Algorithm used to generate JWT tokens | ||
ALGORITHM = os.getenv("ALGORITHM", "HS256") | ||
|
||
# Activate or deactivate the secure authentication | ||
USE_AUTHENTICATION = bool(int(os.getenv("USE_AUTHENTICATION", True))) | ||
|
||
# If the API runs in admin mode, it will allow the creation of new users | ||
ADMIN_MODE = bool(int(os.getenv("ADMIN_MODE", False))) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from datetime import datetime, timedelta | ||
|
||
import argon2 | ||
from jose import jwt | ||
from pydantic import BaseModel | ||
|
||
from backend import ALGORITHM, SECRET_KEY | ||
from backend.database import Database | ||
|
||
|
||
class UnsecureUser(BaseModel): | ||
email: str = None | ||
password: bytes = None | ||
|
||
|
||
class User(BaseModel): | ||
email: str = None | ||
hashed_password: str = None | ||
|
||
@classmethod | ||
def from_unsecure_user(cls, unsecure_user: UnsecureUser): | ||
hashed_password = argon2.hash_password(unsecure_user.password).decode("utf-8") | ||
return cls(email=unsecure_user.email, hashed_password=hashed_password) | ||
|
||
|
||
def create_user(user: User) -> None: | ||
with Database() as connection: | ||
connection.execute( | ||
"INSERT INTO users (email, password) VALUES (?, ?)", | ||
(user.email, user.hashed_password), | ||
) | ||
|
||
|
||
def user_exists(email: str) -> bool: | ||
with Database() as connection: | ||
result = connection.fetchone("SELECT 1 FROM users WHERE email = ?", (email,)) | ||
return bool(result) | ||
|
||
|
||
def get_user(email: str) -> User | None: | ||
with Database() as connection: | ||
user_row = connection.fetchone("SELECT * FROM users WHERE email = ?", (email,)) | ||
if user_row: | ||
return User(email=user_row[0], hashed_password=user_row[1]) | ||
return None | ||
|
||
|
||
def delete_user(email: str) -> None: | ||
with Database() as connection: | ||
connection.execute("DELETE FROM users WHERE email = ?", (email,)) | ||
|
||
|
||
def authenticate_user(username: str, password: bytes) -> bool | User: | ||
user = get_user(username) | ||
if not user: | ||
return False | ||
|
||
if argon2.verify_password( | ||
user.hashed_password.encode("utf-8"), password.encode("utf-8") | ||
): | ||
return user | ||
|
||
return False | ||
|
||
|
||
def create_access_token(*, data: dict, expires_delta: timedelta | None = None) -> str: | ||
to_encode = data.copy() | ||
if expires_delta: | ||
expire = datetime.utcnow() + expires_delta | ||
else: | ||
expire = datetime.utcnow() + timedelta(minutes=60) | ||
to_encode.update({"exp": expire}) | ||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) |
106 changes: 106 additions & 0 deletions
106
backend/api_plugins/secure_authentication/secure_authentication.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from pathlib import Path | ||
|
||
from fastapi import Depends, HTTPException, Response, status | ||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm | ||
from jose import JWTError, jwt | ||
|
||
from backend import ADMIN_MODE | ||
from backend.api_plugins.lib.user_management import ( | ||
ALGORITHM, | ||
SECRET_KEY, | ||
UnsecureUser, | ||
User, | ||
authenticate_user, | ||
create_access_token, | ||
create_user, | ||
delete_user, | ||
get_user, | ||
user_exists, | ||
) | ||
|
||
|
||
def authentication_routes(app, dependencies=list[Depends]): | ||
from backend.database import Database | ||
|
||
with Database() as connection: | ||
connection.run_script(Path(__file__).parent / "users_tables.sql") | ||
|
||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/user/login") | ||
|
||
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: | ||
credentials_exception = HTTPException( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
detail="Could not validate credentials", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
try: | ||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) | ||
email: str = payload.get("email") | ||
if email is None: | ||
raise credentials_exception | ||
|
||
user = get_user(email) | ||
if user is None: | ||
raise credentials_exception | ||
return user | ||
except JWTError: | ||
raise credentials_exception | ||
|
||
@app.post("/user/signup", include_in_schema=ADMIN_MODE) | ||
async def signup(user: UnsecureUser) -> dict: | ||
if not ADMIN_MODE: | ||
raise HTTPException( | ||
status_code=status.HTTP_403_FORBIDDEN, detail="Signup is disabled" | ||
) | ||
|
||
user = User.from_unsecure_user(user) | ||
if user_exists(user.email): | ||
raise HTTPException( | ||
status_code=status.HTTP_400_BAD_REQUEST, | ||
detail=f"User {user.email} already registered", | ||
) | ||
|
||
create_user(user) | ||
return {"email": user.email} | ||
|
||
@app.delete("/user/") | ||
async def del_user(current_user: User = Depends(get_current_user)) -> dict: | ||
email = current_user.email | ||
try: | ||
user = get_user(email) | ||
if user is None: | ||
raise HTTPException( | ||
status_code=status.HTTP_404_NOT_FOUND, | ||
detail=f"User {email} not found", | ||
) | ||
delete_user(email) | ||
return {"detail": f"User {email} deleted"} | ||
except Exception: | ||
raise HTTPException( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail="Internal Server Error", | ||
) | ||
|
||
@app.post("/user/login") | ||
async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict: | ||
user = authenticate_user(form_data.username, form_data.password) | ||
if not user: | ||
raise HTTPException( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
detail="Incorrect username or password", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
user_data = user.dict() | ||
del user_data["hashed_password"] | ||
access_token = create_access_token(data=user_data) | ||
return {"access_token": access_token, "token_type": "bearer"} | ||
|
||
@app.get("/user/me") | ||
async def user_me(current_user: User = Depends(get_current_user)) -> User: | ||
return current_user | ||
|
||
@app.get("/user") | ||
async def user_root() -> dict: | ||
return Response("User management routes are enabled.", status_code=200) | ||
|
||
return Depends(get_current_user) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
-- Dialect MUST be sqlite, even if the database you use is different. | ||
-- It is transpiled to the right dialect when executed. | ||
|
||
CREATE TABLE IF NOT EXISTS "users" ( | ||
"email" VARCHAR(255) PRIMARY KEY, | ||
"password" TEXT | ||
); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import json | ||
from collections.abc import Sequence | ||
from datetime import datetime | ||
from pathlib import Path | ||
from uuid import uuid4 | ||
|
||
from fastapi import APIRouter, Depends, FastAPI, Response | ||
|
||
from backend.api_plugins.lib.user_management import User | ||
|
||
|
||
def session_routes( | ||
app: FastAPI | APIRouter, | ||
*, | ||
authentication: Depends = None, | ||
dependencies: Sequence[Depends] | None = None, | ||
): | ||
from backend.database import Database | ||
from backend.model import Message | ||
|
||
with Database() as connection: | ||
connection.run_script(Path(__file__).parent / "sessions_tables.sql") | ||
|
||
@app.post("/session/new") | ||
async def chat_new( | ||
current_user: User = authentication, dependencies=dependencies | ||
) -> dict: | ||
chat_id = str(uuid4()) | ||
timestamp = datetime.utcnow().isoformat() | ||
user_id = current_user.email if current_user else "unauthenticated" | ||
with Database() as connection: | ||
connection.execute( | ||
"INSERT INTO session (id, timestamp, user_id) VALUES (?, ?, ?)", | ||
(chat_id, timestamp, user_id), | ||
) | ||
return {"session_id": chat_id} | ||
|
||
@app.get("/session/list") | ||
async def chat_list( | ||
current_user: User = authentication, dependencies=dependencies | ||
) -> list[dict]: | ||
user_email = current_user.email if current_user else "unauthenticated" | ||
chats = [] | ||
with Database() as connection: | ||
result = connection.execute( | ||
"SELECT id, timestamp FROM session WHERE user_id = ? ORDER BY timestamp DESC", | ||
(user_email,), | ||
) | ||
chats = [{"id": row[0], "timestamp": row[1]} for row in result] | ||
return chats | ||
|
||
@app.get("/session/{session_id}") | ||
async def chat( | ||
session_id: str, current_user: User = authentication, dependencies=dependencies | ||
) -> dict: | ||
messages: list[Message] = [] | ||
with Database() as connection: | ||
result = connection.execute( | ||
"SELECT id, timestamp, session_id, message FROM message_history WHERE session_id = ? ORDER BY timestamp ASC", | ||
(session_id,), | ||
) | ||
for row in result: | ||
content = json.loads(row[3])["data"]["content"] | ||
message_type = json.loads(row[3])["type"] | ||
message = Message( | ||
id=row[0], | ||
timestamp=row[1], | ||
session_id=row[2], | ||
sender=message_type if message_type == "human" else "ai", | ||
content=content, | ||
) | ||
messages.append(message) | ||
return { | ||
"chat_id": session_id, | ||
"messages": [message.dict() for message in messages], | ||
} | ||
|
||
@app.get("/session") | ||
async def session_root( | ||
current_user: User = authentication, dependencies=dependencies | ||
) -> dict: | ||
return Response("Sessions management routes are enabled.", status_code=200) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
-- Dialect MUST be sqlite, even if the database you use is different. | ||
-- It is transpiled to the right dialect when executed. | ||
|
||
CREATE TABLE IF NOT EXISTS "session" ( | ||
"id" VARCHAR(255) PRIMARY KEY, | ||
"timestamp" DATETIME, | ||
"user_id" VARCHAR(255), | ||
FOREIGN KEY ("user_id") REFERENCES "users" ("email") | ||
); |
Oops, something went wrong.