Skip to content

Commit

Permalink
add authentification
Browse files Browse the repository at this point in the history
  • Loading branch information
baptiste-pasquier committed Mar 21, 2024
1 parent b5b1e4f commit 29a404d
Show file tree
Hide file tree
Showing 15 changed files with 898 additions and 227 deletions.
14 changes: 11 additions & 3 deletions app/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from hydra import compose, initialize
from langserve import add_routes

from backend import USE_AUTHENTICATION
from backend.api_plugins import authentication_routes, session_routes
from backend.rag_1.chain import get_chain as get_chain_rag_1
from backend.rag_1.config import validate_config as validate_config_1
from backend.rag_2.chain import get_chain as get_chain_rag_2
Expand Down Expand Up @@ -42,15 +44,21 @@ async def redirect_root_to_docs() -> RedirectResponse:
# validate config
_ = validate_config_3(config_3)

if USE_AUTHENTICATION:
auth = authentication_routes(app)
session_routes(app, authentication=auth)
dependencies = [auth]
else:
dependencies = None

chain_rag_1 = get_chain_rag_1(config_1)
add_routes(app, chain_rag_1, path="/rag-1")
add_routes(app, chain_rag_1, path="/rag-1", dependencies=dependencies)

chain_rag_2 = get_chain_rag_2(config_2)
add_routes(app, chain_rag_2, path="/rag-2")
add_routes(app, chain_rag_2, path="/rag-2", dependencies=dependencies)

chain_rag_3 = get_chain_rag_3(config_3)
add_routes(app, chain_rag_3, path="/rag-3")
add_routes(app, chain_rag_3, path="/rag-3", dependencies=dependencies)

if __name__ == "__main__":
import uvicorn
Expand Down
20 changes: 20 additions & 0 deletions backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,21 @@
"""Backend operations of RAG."""

import os

from dotenv import load_dotenv

load_dotenv()

DATABASE_URL = os.getenv("DATABASE_URL")

# Private key used to generate the JWT tokens for secure authentication
SECRET_KEY = os.getenv("SECRET_KEY", "default_unsecure_key")

# Algorithm used to generate JWT tokens
ALGORITHM = os.getenv("ALGORITHM", "HS256")

# Activate or deactivate the secure authentication
USE_AUTHENTICATION = bool(int(os.getenv("USE_AUTHENTICATION", True)))

# If the API runs in admin mode, it will allow the creation of new users
ADMIN_MODE = bool(int(os.getenv("ADMIN_MODE", False)))
Empty file added backend/api_plugins/__init__.py
Empty file.
73 changes: 73 additions & 0 deletions backend/api_plugins/lib/user_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from datetime import datetime, timedelta

import argon2
from jose import jwt
from pydantic import BaseModel

from backend import ALGORITHM, SECRET_KEY
from backend.database import Database


class UnsecureUser(BaseModel):
email: str = None
password: bytes = None


class User(BaseModel):
email: str = None
hashed_password: str = None

@classmethod
def from_unsecure_user(cls, unsecure_user: UnsecureUser):
hashed_password = argon2.hash_password(unsecure_user.password).decode("utf-8")
return cls(email=unsecure_user.email, hashed_password=hashed_password)


def create_user(user: User) -> None:
with Database() as connection:
connection.execute(
"INSERT INTO users (email, password) VALUES (?, ?)",
(user.email, user.hashed_password),
)


def user_exists(email: str) -> bool:
with Database() as connection:
result = connection.fetchone("SELECT 1 FROM users WHERE email = ?", (email,))
return bool(result)


def get_user(email: str) -> User | None:
with Database() as connection:
user_row = connection.fetchone("SELECT * FROM users WHERE email = ?", (email,))
if user_row:
return User(email=user_row[0], hashed_password=user_row[1])
return None


def delete_user(email: str) -> None:
with Database() as connection:
connection.execute("DELETE FROM users WHERE email = ?", (email,))


def authenticate_user(username: str, password: bytes) -> bool | User:
user = get_user(username)
if not user:
return False

if argon2.verify_password(
user.hashed_password.encode("utf-8"), password.encode("utf-8")
):
return user

return False


def create_access_token(*, data: dict, expires_delta: timedelta | None = None) -> str:
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=60)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
106 changes: 106 additions & 0 deletions backend/api_plugins/secure_authentication/secure_authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from pathlib import Path

from fastapi import Depends, HTTPException, Response, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt

from backend import ADMIN_MODE
from backend.api_plugins.lib.user_management import (
ALGORITHM,
SECRET_KEY,
UnsecureUser,
User,
authenticate_user,
create_access_token,
create_user,
delete_user,
get_user,
user_exists,
)


def authentication_routes(app, dependencies=list[Depends]):
from backend.database import Database

with Database() as connection:
connection.run_script(Path(__file__).parent / "users_tables.sql")

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/user/login")

async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
email: str = payload.get("email")
if email is None:
raise credentials_exception

user = get_user(email)
if user is None:
raise credentials_exception
return user
except JWTError:
raise credentials_exception

@app.post("/user/signup", include_in_schema=ADMIN_MODE)
async def signup(user: UnsecureUser) -> dict:
if not ADMIN_MODE:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Signup is disabled"
)

user = User.from_unsecure_user(user)
if user_exists(user.email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"User {user.email} already registered",
)

create_user(user)
return {"email": user.email}

@app.delete("/user/")
async def del_user(current_user: User = Depends(get_current_user)) -> dict:
email = current_user.email
try:
user = get_user(email)
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User {email} not found",
)
delete_user(email)
return {"detail": f"User {email} deleted"}
except Exception:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal Server Error",
)

@app.post("/user/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
user = authenticate_user(form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
user_data = user.dict()
del user_data["hashed_password"]
access_token = create_access_token(data=user_data)
return {"access_token": access_token, "token_type": "bearer"}

@app.get("/user/me")
async def user_me(current_user: User = Depends(get_current_user)) -> User:
return current_user

@app.get("/user")
async def user_root() -> dict:
return Response("User management routes are enabled.", status_code=200)

return Depends(get_current_user)
7 changes: 7 additions & 0 deletions backend/api_plugins/secure_authentication/users_tables.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Dialect MUST be sqlite, even if the database you use is different.
-- It is transpiled to the right dialect when executed.

CREATE TABLE IF NOT EXISTS "users" (
"email" VARCHAR(255) PRIMARY KEY,
"password" TEXT
);
82 changes: 82 additions & 0 deletions backend/api_plugins/sessions/sessions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import json
from collections.abc import Sequence
from datetime import datetime
from pathlib import Path
from uuid import uuid4

from fastapi import APIRouter, Depends, FastAPI, Response

from backend.api_plugins.lib.user_management import User


def session_routes(
app: FastAPI | APIRouter,
*,
authentication: Depends = None,
dependencies: Sequence[Depends] | None = None,
):
from backend.database import Database
from backend.model import Message

with Database() as connection:
connection.run_script(Path(__file__).parent / "sessions_tables.sql")

@app.post("/session/new")
async def chat_new(
current_user: User = authentication, dependencies=dependencies
) -> dict:
chat_id = str(uuid4())
timestamp = datetime.utcnow().isoformat()
user_id = current_user.email if current_user else "unauthenticated"
with Database() as connection:
connection.execute(
"INSERT INTO session (id, timestamp, user_id) VALUES (?, ?, ?)",
(chat_id, timestamp, user_id),
)
return {"session_id": chat_id}

@app.get("/session/list")
async def chat_list(
current_user: User = authentication, dependencies=dependencies
) -> list[dict]:
user_email = current_user.email if current_user else "unauthenticated"
chats = []
with Database() as connection:
result = connection.execute(
"SELECT id, timestamp FROM session WHERE user_id = ? ORDER BY timestamp DESC",
(user_email,),
)
chats = [{"id": row[0], "timestamp": row[1]} for row in result]
return chats

@app.get("/session/{session_id}")
async def chat(
session_id: str, current_user: User = authentication, dependencies=dependencies
) -> dict:
messages: list[Message] = []
with Database() as connection:
result = connection.execute(
"SELECT id, timestamp, session_id, message FROM message_history WHERE session_id = ? ORDER BY timestamp ASC",
(session_id,),
)
for row in result:
content = json.loads(row[3])["data"]["content"]
message_type = json.loads(row[3])["type"]
message = Message(
id=row[0],
timestamp=row[1],
session_id=row[2],
sender=message_type if message_type == "human" else "ai",
content=content,
)
messages.append(message)
return {
"chat_id": session_id,
"messages": [message.dict() for message in messages],
}

@app.get("/session")
async def session_root(
current_user: User = authentication, dependencies=dependencies
) -> dict:
return Response("Sessions management routes are enabled.", status_code=200)
9 changes: 9 additions & 0 deletions backend/api_plugins/sessions/sessions_tables.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-- Dialect MUST be sqlite, even if the database you use is different.
-- It is transpiled to the right dialect when executed.

CREATE TABLE IF NOT EXISTS "session" (
"id" VARCHAR(255) PRIMARY KEY,
"timestamp" DATETIME,
"user_id" VARCHAR(255),
FOREIGN KEY ("user_id") REFERENCES "users" ("email")
);
Loading

0 comments on commit 29a404d

Please sign in to comment.