From 1cd0f2e7dcef3d0027a7a583b59b45e4b6914657 Mon Sep 17 00:00:00 2001
From: Nathan Sarrazin
Date: Tue, 28 Mar 2023 22:54:09 +0200
Subject: [PATCH 1/3] API refactoring

---
 Dockerfile                                    |  17 +-
 README.md                                     |   1 -
 api/.gitignore                                | 160 ++++++++++++
 api/poetry.toml                               |   2 +
 api/pyproject.toml                            |  66 +++++
 api/requirements.txt                          |  44 ----
 api/src/serge/dependencies.py                 |  31 +++
 api/src/serge/main.py                         |  97 ++++++++
 api/src/serge/models/__init__.py              |   1 +
 api/{models.py => src/serge/models/chat.py}   |   1 -
 .../__init__.py => src/serge/models/model.py} |   0
 api/src/serge/routers/__init__.py             |   2 +
 api/{main.py => src/serge/routers/chat.py}    | 233 +++++-------------
 api/src/serge/routers/model.py                | 111 +++++++++
 api/src/serge/utils/__init__.py               |   0
 api/{ => src/serge}/utils/convert.py          |   2 +-
 api/{ => src/serge}/utils/generate.py         |   4 +-
 .../serge}/utils/initiate_database.py         |   2 +-
 api/utils/download.py                         |  56 -----
 compile.sh => scripts/compile.sh              |   0
 deploy.sh => scripts/deploy.sh                |   2 +-
 dev.sh => scripts/dev.sh                      |   4 +-
 web/src/routes/+layout.svelte                 |   4 +-
 web/src/routes/+layout.ts                     |   2 +-
 web/src/routes/+page.svelte                   |  26 +-
 web/src/routes/+page.ts                       |  11 +-
 web/src/routes/models/+page.svelte            |  72 ++++++
 web/src/routes/models/+page.ts                |  16 ++
 28 files changed, 657 insertions(+), 310 deletions(-)
 create mode 100644 api/.gitignore
 create mode 100644 api/poetry.toml
 create mode 100644 api/pyproject.toml
 delete mode 100644 api/requirements.txt
 create mode 100644 api/src/serge/dependencies.py
 create mode 100644 api/src/serge/main.py
 create mode 100644 api/src/serge/models/__init__.py
 rename api/{models.py => src/serge/models/chat.py} (97%)
 rename api/{utils/__init__.py => src/serge/models/model.py} (100%)
 create mode 100644 api/src/serge/routers/__init__.py
 rename api/{main.py => src/serge/routers/chat.py} (53%)
 create mode 100644 api/src/serge/routers/model.py
 create mode 100644 api/src/serge/utils/__init__.py
 rename api/{ => src/serge}/utils/convert.py (98%)
 rename api/{ => src/serge}/utils/generate.py (95%)
 rename api/{ => src/serge}/utils/initiate_database.py (89%)
 delete mode 100644 api/utils/download.py
 rename compile.sh => scripts/compile.sh (100%)
 rename deploy.sh => scripts/deploy.sh (69%)
 rename dev.sh => scripts/dev.sh (65%)
 create mode 100644 web/src/routes/models/+page.svelte
 create mode 100644 web/src/routes/models/+page.ts

diff --git a/Dockerfile b/Dockerfile
index 467ad54818b..34e61065c59 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,7 +12,7 @@ ENV TZ=Europe/Amsterdam
 
 WORKDIR /usr/src/app
 
-COPY --chmod=0755 compile.sh .
+COPY --chmod=0755 scripts/compile.sh .
 
 # Install MongoDB and necessary tools
 RUN apt update && \
@@ -23,11 +23,7 @@ RUN apt update && \
     apt-get install -y mongodb-org && \
     git clone https://github.com/ggerganov/llama.cpp.git --branch master-d5850c5
 
-
-# copy & install python reqs
-COPY ./api/requirements.txt api/requirements.txt
-RUN pip install --upgrade pip && \
-    pip install --no-cache-dir -r ./api/requirements.txt
+RUN pip install --upgrade pip
 
 # Dev environment
 FROM base as dev
@@ -38,11 +34,7 @@ COPY --from=node_base /usr/local /usr/local
 COPY ./web/package*.json ./
 RUN npm ci
 
-# Copy the rest of the project files
-COPY web /usr/src/app/web
-COPY ./api /usr/src/app/api
-
-COPY --chmod=0755 dev.sh /usr/src/app/dev.sh
+COPY --chmod=0755 scripts/dev.sh /usr/src/app/dev.sh
 CMD ./dev.sh
 
 # Build frontend
@@ -64,6 +56,7 @@ WORKDIR /usr/src/app
 COPY --from=frontend_builder /usr/src/app/web/build /usr/src/app/api/static/
 COPY ./api /usr/src/app/api
 
-COPY --chmod=0755 deploy.sh /usr/src/app/deploy.sh
+RUN pip install ./api
+COPY --chmod=0755 scripts/deploy.sh /usr/src/app/deploy.sh
 
 CMD ./deploy.sh
diff --git a/README.md b/README.md
index c459508f59c..79fd91965ae 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,6 @@ cd serge
 docker compose up --build -d
 docker compose exec serge python3 /usr/src/app/api/utils/download.py tokenizer 7B
 ```
-Please note that the models occupy the following storage space: 7B requires 4.21G, 13B requires 8.14G, and 30B requires 20.3G
 
 #### Windows
 
diff --git a/api/.gitignore b/api/.gitignore
new file mode 100644
index 00000000000..6769e21d99a
--- /dev/null
+++ b/api/.gitignore
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
\ No newline at end of file
diff --git a/api/poetry.toml b/api/poetry.toml
new file mode 100644
index 00000000000..084377a033b
--- /dev/null
+++ b/api/poetry.toml
@@ -0,0 +1,2 @@
+[virtualenvs]
+create = false
diff --git a/api/pyproject.toml b/api/pyproject.toml
new file mode 100644
index 00000000000..47135cff06f
--- /dev/null
+++ b/api/pyproject.toml
@@ -0,0 +1,66 @@
+[tool.poetry]
+name = "serge"
+description = "Serge API package"
+version = "0.1.0"
+license = "MIT"
+authors = [
+    "Nathan Sarrazin"
+]
+
+packages = [
+    { include = "serge", from = "src" }
+]
+
+homepage = "https://serge.chat/"
+repository = "https://github.com/nsarrazin/serge"
+
+include = [{path="src"}]
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry.dependencies]
+python=">=3.10,<4.0"
+asyncio = "^3.4.3"
+packaging = "^23.0"
+pydantic = "^1.10.7"
+pymongo = "^4.3.3"
+python-dotenv = "^1.0.0"
+python-multipart = "^0.0.6"
+pyyaml = "^6.0"
+rfc3986 = "^2.0.0"
+sentencepiece = "^0.1.97"
+sniffio = "^1.3.0"
+sse-starlette = "^1.3.3"
+starlette = "^0.26.1"
+toml = "^0.10.2"
+tqdm = "^4.65.0"
+typing-extensions = "^4.5.0"
+ujson = "^5.7.0"
+urllib3 = "^1.26.15"
+uvicorn = "^0.21.1"
+uvloop = "^0.17.0"
+watchfiles = "^0.19.0"
+websockets = "^10.4"
+anyio = "^3.6.2"
+certifi = "^2022.12.7"
+charset-normalizer = "^3.1.0"
+click = "^8.1.3"
+email-validator = "^1.3.1"
+fastapi = "^0.95.0"
+filelock = "^3.10.7"
+h11 = "^0.14.0"
+httpcore = "^0.17.0"
+httptools = "^0.5.0"
+huggingface-hub = "^0.13.3"
+idna = "^3.4"
+itsdangerous = "^2.1.2"
+jinja2 = "^3.1.2"
+markupsafe = "^2.1.2"
+motor = "^3.1.1"
+orjson = "^3.8.8"
+beanie = "^1.17.0"
+dnspython = "^2.3.0"
+lazy-model = "^0.0.5"
+requests = "^2.28.2"
diff --git a/api/requirements.txt b/api/requirements.txt
deleted file mode 100644
index 2574559d2ba..00000000000
--- a/api/requirements.txt
+++ /dev/null
@@ -1,44 +0,0 @@
-anyio==3.6.2
-asyncio==3.4.3
-beanie==1.17.0
-certifi==2022.12.7
-charset-normalizer==3.1.0
-click==8.1.3
-dnspython==2.3.0
-email-validator==1.3.1
-fastapi==0.95.0
-filelock==3.10.2
-h11==0.14.0
-httpcore==0.16.3
-httptools==0.5.0
-httpx==0.23.3
-huggingface-hub==0.13.3
-idna==3.4
-itsdangerous==2.1.2
-Jinja2==3.1.2
-lazy-model==0.0.5
-MarkupSafe==2.1.2
-motor==3.1.1
-orjson==3.8.8
-packaging==23.0
-psutil==5.9.4
-pydantic==1.10.7
-pymongo==4.3.3
-python-dotenv==1.0.0
-python-multipart==0.0.6
-PyYAML==6.0
-requests==2.28.2
-rfc3986==1.5.0
-sentencepiece==0.1.97
-sniffio==1.3.0
-sse-starlette==1.3.3
-starlette==0.26.1
-toml==0.10.2
-tqdm==4.65.0
-typing_extensions==4.5.0
-ujson==5.7.0
-urllib3==1.26.15
-uvicorn==0.21.1
-uvloop==0.17.0
-watchfiles==0.18.1
-websockets==10.4
\ No newline at end of file
diff --git a/api/src/serge/dependencies.py b/api/src/serge/dependencies.py
new file mode 100644
index 00000000000..e33a498a823
--- /dev/null
+++ b/api/src/serge/dependencies.py
@@ -0,0 +1,31 @@
+from fastapi import HTTPException, status
+from .utils.convert import convert_all
+import anyio
+import os
+
+MODEL_IS_READY: bool = False
+
+
+def dep_models_ready() -> list[str]:
+    """
+    FastAPI dependency that checks if models are ready.
+
+    Returns a list of available models
+    """
+    if MODEL_IS_READY is False:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail={
+                "message": "models are not ready"
+            }
+        )
+
+    files = os.listdir("/usr/src/app/weights")
+    files = list(filter(lambda x: x.endswith(".bin"), files))
+    return files
+
+
+async def convert_model_files():
+    global MODEL_IS_READY
+    await anyio.to_thread.run_sync(convert_all, "/usr/src/app/weights/", "/usr/src/app/weights/tokenizer.model")
+    MODEL_IS_READY = True
diff --git a/api/src/serge/main.py b/api/src/serge/main.py
new file mode 100644
index 00000000000..282578e0659
--- /dev/null
+++ b/api/src/serge/main.py
@@ -0,0 +1,97 @@
+import asyncio
+import logging
+
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+from starlette.responses import FileResponse
+
+from serge.routers.chat import chat_router
+from serge.routers.model import model_router
+from serge.utils.initiate_database import initiate_database, Settings
+from serge.dependencies import convert_model_files
+
+# Configure logging settings
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(levelname)s:\t%(name)s\t%(message)s",
+    handlers=[
+        logging.StreamHandler()
+    ]
+)
+
+# Define a logger for the current module
+logger = logging.getLogger(__name__)
+
+settings = Settings()
+
+tags_metadata = [
+    {
+        "name": "misc.",
+        "description": "Miscellaneous endpoints that don't fit anywhere else",
+    },
+    {
+        "name": "chats",
+        "description": "Used to manage chats",
+    },
+]
+
+description = """
+Serge answers your questions poorly using LLaMa/alpaca. 🚀
+"""
+
+origins = [
+    "http://localhost",
+    "http://api:9124",
+    "http://localhost:9123",
+    "http://localhost:9124",
+]
+
+app = FastAPI(
+    title="Serge", version="0.0.1", description=description, tags_metadata=tags_metadata
+)
+
+api_app = FastAPI(title="Serge API")
+api_app.include_router(chat_router)
+api_app.include_router(model_router)
+app.mount('/api', api_app)
+
+# handle serving the frontend as static files in production
+if settings.NODE_ENV == "production":
+    @app.middleware("http")
+    async def add_custom_header(request, call_next):
+        response = await call_next(request)
+        if response.status_code == 404:
+            return FileResponse('static/200.html')
+        return response
+
+    @app.exception_handler(404)
+    def not_found(request, exc):
+        return FileResponse('static/200.html')
+
+    async def homepage(request):
+        return FileResponse('static/200.html')
+
+    app.route('/', homepage)
+    app.mount('/', StaticFiles(directory='static'))
+
+    start_app = app
+else:
+    start_app = api_app
+
+
+@start_app.on_event("startup")
+async def start_database():
+    logger.info("initializing database connection")
+    await initiate_database()
+
+    logger.info("initializing models")
+    asyncio.create_task(convert_model_files())
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
diff --git a/api/src/serge/models/__init__.py b/api/src/serge/models/__init__.py
new file mode 100644
index 00000000000..1260fb039a7
--- /dev/null
+++ b/api/src/serge/models/__init__.py
@@ -0,0 +1 @@
+from .chat import Chat, Question, ChatParameters
\ No newline at end of file
diff --git a/api/models.py b/api/src/serge/models/chat.py
similarity index 97%
rename from api/models.py
rename to api/src/serge/models/chat.py
index d0e7bdffc95..1bca150f9a7 100644
--- a/api/models.py
+++ b/api/src/serge/models/chat.py
@@ -4,7 +4,6 @@
 from pydantic import Field
 
 from datetime import datetime
-from enum import Enum
 
 class ChatParameters(Document):
     model: str = Field(default="ggml-alpaca-7B-q4_0.bin")
diff --git a/api/utils/__init__.py b/api/src/serge/models/model.py
similarity index 100%
rename from api/utils/__init__.py
rename to api/src/serge/models/model.py
diff --git a/api/src/serge/routers/__init__.py b/api/src/serge/routers/__init__.py
new file mode 100644
index 00000000000..19e751991e0
--- /dev/null
+++ b/api/src/serge/routers/__init__.py
@@ -0,0 +1,2 @@
+from .chat import chat_router
+from .model import model_router
\ No newline at end of file
diff --git a/api/main.py b/api/src/serge/routers/chat.py
similarity index 53%
rename from api/main.py
rename to api/src/serge/routers/chat.py
index 5f60815fbc1..a731804828c 100644
--- a/api/main.py
+++ b/api/src/serge/routers/chat.py
@@ -1,146 +1,43 @@
 import asyncio
-import logging
-import os
-import psutil
-from typing import Annotated
-
-import anyio
-from fastapi import FastAPI, HTTPException, status, Depends
-from fastapi.middleware.cors import CORSMiddleware
-from beanie.odm.enums import SortDirection
+from fastapi import APIRouter, HTTPException, Depends
 from sse_starlette.sse import EventSourceResponse
-from utils.initiate_database import initiate_database, Settings
-from utils.generate import generate, get_full_prompt_from_chat
-from utils.convert import convert_all
-from models import Question, Chat, ChatParameters
-from fastapi.staticfiles import StaticFiles
-from starlette.responses import FileResponse
-
-
-# Configure logging settings
-logging.basicConfig(
-    level=logging.INFO,
format="%(levelname)s:\t%(name)s\t%(message)s", - handlers=[ - logging.StreamHandler() - ] -) - -# Define a logger for the current module -logger = logging.getLogger(__name__) - -settings = Settings() - -tags_metadata = [ - { - "name": "misc.", - "description": "Miscellaneous endpoints that don't fit anywhere else", - }, - { - "name": "chats", - "description": "Used to manage chats", - }, -] - -description = """ -Serge answers your questions poorly using LLaMa/alpaca. 🚀 -""" - -app = FastAPI( - title="Serge", version="0.0.1", description=description, tags_metadata=tags_metadata -) - -api_app = FastAPI(title="Serge API") -app.mount('/api', api_app) - -if settings.NODE_ENV == "production": - @app.middleware("http") - async def add_custom_header(request, call_next): - response = await call_next(request) - if response.status_code == 404: - return FileResponse('static/200.html') - return response - - @app.exception_handler(404) - def not_found(request, exc): - return FileResponse('static/200.html') - - async def homepage(request): - return FileResponse('static/200.html') - - app.route('/', homepage) - app.mount('/', StaticFiles(directory='static')) - -if settings.NODE_ENV == "development": - start_app = api_app -else: - start_app = app - -origins = [ - "http://localhost", - "http://api:9124", - "http://localhost:9123", - "http://localhost:9124", -] - -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -MODEL_IS_READY: bool = False - - -def dep_models_ready() -> list[str]: - """ - FastAPI dependency that checks if models are ready. - - Returns a list of available models - """ - if MODEL_IS_READY is False: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail={ - "message": "models are not ready" - } - ) +from beanie.odm.enums import SortDirection - files = os.listdir("/usr/src/app/weights") - files = list(filter(lambda x: x.endswith(".bin"), files)) - return files +from serge.models.chat import Question, Chat,ChatParameters +from serge.utils.generate import generate, get_full_prompt_from_chat +from serge.dependencies import dep_models_ready +async def on_close(chat, prompt, answer=None, error=None): + question = await Question(question=prompt.rstrip(), + answer=answer.rstrip() if answer != None else None, + error=error).create() -async def convert_model_files(): - global MODEL_IS_READY - await anyio.to_thread.run_sync(convert_all, "/usr/src/app/weights/", "/usr/src/app/weights/tokenizer.model") - MODEL_IS_READY = True - logger.info("models are ready") + if chat.questions is None: + chat.questions = [question] + else: + chat.questions.append(question) + await chat.save() -@start_app.on_event("startup") -async def start_database(): - logger.info("initializing database connection") - await initiate_database() - logger.info("initializing models") - asyncio.create_task(convert_model_files()) +def remove_matching_end(a, b): + min_length = min(len(a), len(b)) + for i in range(min_length, 0, -1): + if a[-i:] == b[:i]: + return b[i:] -@api_app.get("/models", tags=["misc."]) -def list_of_installed_models( - models: Annotated[list[str], Depends(dep_models_ready)] -): - return models + return b -THREADS = len(psutil.Process().cpu_affinity()) +chat_router = APIRouter( + prefix="/chat", + tags=["chat"], +) -@api_app.post("/chat", tags=["chats"]) +@chat_router.post("/") async def create_new_chat( - model: str = "ggml-alpaca-7B-q4_0.bin", + model: str = "7B", temperature: float = 0.1, top_k: int 
     top_p: float = 0.95,
@@ -149,7 +46,7 @@ async def create_new_chat(
     repeat_last_n: int = 64,
     repeat_penalty: float = 1.3,
     init_prompt: str = "Below is an instruction that describes a task. Write a response that appropriately completes the request. The response must be accurate, concise and evidence-based whenever possible. A complete answer is always ended by [end of text].",
-    n_threads: int = THREADS / 2,
+    n_threads: int = 4,
 ):
     parameters = await ChatParameters(
         model=model,
@@ -168,14 +65,37 @@ async def create_new_chat(
 
     return chat.id
 
-@api_app.get("/chat/{chat_id}", tags=["chats"])
+
+@chat_router.get("/")
+async def get_all_chats():
+    res = []
+
+    for i in (
+        await Chat.find_all().sort((Chat.created, SortDirection.DESCENDING)).to_list()
+    ):
+        await i.fetch_link(Chat.parameters)
+        await i.fetch_link(Chat.questions)
+
+        first_q = i.questions[0].question if i.questions else ""
+        res.append(
+            {
+                "id": i.id,
+                "created": i.created,
+                "model": i.parameters.model,
+                "subtitle": first_q,
+            }
+        )
+
+    return res
+
+
+@chat_router.get("/{chat_id}")
 async def get_specific_chat(chat_id: str):
     chat = await Chat.get(chat_id)
     await chat.fetch_all_links()
 
     return chat
 
-@api_app.delete("/chat/{chat_id}", tags=["chats"])
+
+@chat_router.delete("/{chat_id}")
 async def delete_chat(chat_id: str):
     chat = await Chat.get(chat_id)
     deleted_chat = await chat.delete()
@@ -185,29 +105,9 @@ async def delete_chat(chat_id: str):
     else:
         raise HTTPException(status_code=404, detail="No chat found with the given id.")
 
-async def on_close(chat, prompt, answer=None, error=None):
-    question = await Question(question=prompt.rstrip(),
-                              answer=answer.rstrip() if answer != None else None,
-                              error=error).create()
-
-    if chat.questions is None:
-        chat.questions = [question]
-    else:
-        chat.questions.append(question)
-
-    await chat.save()
-
-
-def remove_matching_end(a, b):
-    min_length = min(len(a), len(b))
-
-    for i in range(min_length, 0, -1):
-        if a[-i:] == b[:i]:
-            return b[i:]
-
-    return b
 
-@api_app.get("/chat/{chat_id}/question", dependencies=[Depends(dep_models_ready)])
+@chat_router.get("/{chat_id}/question", dependencies=[Depends(dep_models_ready)])
 async def stream_ask_a_question(chat_id: str, prompt: str):
     chat = await Chat.get(chat_id)
 
@@ -238,7 +138,6 @@ async def event_generator():
 
         except Exception as e:
             error = e.__str__()
-            logger.error(error)
             yield({"event" : "error"})
         finally:
             answer = "".join(chunks)[len(full_prompt)+1:]
@@ -248,7 +147,7 @@ async def event_generator():
 
     return EventSourceResponse(event_generator())
 
-@api_app.post("/chat/{chat_id}/question", dependencies=[Depends(dep_models_ready)])
+@chat_router.post("/{chat_id}/question", dependencies=[Depends(dep_models_ready)])
 async def ask_a_question(chat_id: str, prompt: str):
     chat = await Chat.get(chat_id)
     await chat.fetch_link(Chat.parameters)
@@ -270,26 +169,4 @@ async def ask_a_question(chat_id: str, prompt: str):
     finally:
         await on_close(chat, prompt, answer=answer[len(full_prompt)+1:], error=error)
 
-    return {"question" : prompt, "answer" : answer[len(full_prompt)+1:]}
-
-@api_app.get("/chats", tags=["chats"])
-async def get_all_chats():
-    res = []
-
-    for i in (
-        await Chat.find_all().sort((Chat.created, SortDirection.DESCENDING)).to_list()
-    ):
-        await i.fetch_link(Chat.parameters)
-        await i.fetch_link(Chat.questions)
-
-        first_q = i.questions[0].question if i.questions else ""
-        res.append(
-            {
-                "id": i.id,
-                "created": i.created,
-                "model": i.parameters.model,
-                "subtitle": first_q,
-            }
-        )
-
-    return res
+    return {"question" : prompt, "answer" : answer[len(full_prompt)+1:]}
"answer" : answer[len(full_prompt)+1:]} \ No newline at end of file diff --git a/api/src/serge/routers/model.py b/api/src/serge/routers/model.py new file mode 100644 index 00000000000..890eb608422 --- /dev/null +++ b/api/src/serge/routers/model.py @@ -0,0 +1,111 @@ +from fastapi import APIRouter, HTTPException, Depends +from typing import Annotated + +from serge.dependencies import dep_models_ready +from serge.utils.convert import convert_one_file +import huggingface_hub +import os +import urllib.request + +model_router = APIRouter( + prefix="/model", + tags=["model"], +) + +models_info = { + "7B": [ + "nsarrazin/alpaca", + "alpaca-7B-ggml/ggml-model-q4_0.bin", + 4.20E9, + ], + "7B-native": [ + "nsarrazin/alpaca", + "alpaca-native-7B-ggml/ggml-model-q4_0.bin", + 4.20E9, + ], + "13B": [ + "nsarrazin/alpaca", + "alpaca-13B-ggml/ggml-model-q4_0.bin", + 8.13E9, + ], + "30B": [ + "nsarrazin/alpaca", + "alpaca-30B-ggml/ggml-model-q4_0.bin", + 20.2E9, + ], + } + +WEIGHTS = "/usr/src/app/weights/" + +@model_router.get("/all") +async def list_of_all_models(): + res = [] + for model in models_info.keys(): + + progress = await download_status(model) + + res.append({ + "name": model, + "size": models_info[model][2], + "available": model+".bin" in await list_of_installed_models(), + "progress" : progress, + }) + + return res + +@model_router.get("/downloadable") +async def list_of_downloadable_models(): + files = os.listdir(WEIGHTS) + files = list(filter(lambda x: x.endswith(".bin"), files)) + + installed_models = [i.rstrip(".bin") for i in files] + + return list(filter(lambda x: x not in installed_models, models_info.keys())) + +@model_router.get("/installed") +async def list_of_installed_models(): + files = os.listdir(WEIGHTS) + files = list(filter(lambda x: x.endswith(".bin"), files)) + + return files + + +@model_router.post("/{model_name}/download") +def download_model(model_name: str): + models = list(models_info.keys()) + if model_name not in models: + raise HTTPException(status_code=404, detail="Model not found") + + if not os.path.exists(WEIGHTS+ "tokenizer.model"): + print("Downloading tokenizer...") + url = huggingface_hub.hf_hub_url("nsarrazin/alpaca", "alpaca-7B-ggml/tokenizer.model", repo_type="model", revision="main") + urllib.request.urlretrieve(url, WEIGHTS+"tokenizer.model") + + + repo_id, filename,_ = models_info[model_name] + + print(f"Downloading {model_name} model from {repo_id}...") + url = huggingface_hub.hf_hub_url(repo_id, filename, repo_type="model", revision="main") + urllib.request.urlretrieve(url, WEIGHTS+f"{model_name}.bin.tmp") + + os.rename(WEIGHTS+f"{model_name}.bin.tmp", WEIGHTS+f"{model_name}.bin") + convert_one_file(WEIGHTS+ "f{model_name}.bin", WEIGHTS + f"tokenizer.model") + + return {"message": f"Model {model_name} downloaded"} + + +@model_router.get("/{model_name}/download/status") +async def download_status(model_name: str): + models = list(models_info.keys()) + + if model_name not in models: + raise HTTPException(status_code=404, detail="Model not found") + + filesize = models_info[model_name][2] + + bin_path = WEIGHTS+f"{model_name}.bin.tmp" + + if os.path.exists(bin_path): + currentsize = os.path.getsize(bin_path) + return min(round(currentsize / filesize*100, 1), 100) + return None \ No newline at end of file diff --git a/api/src/serge/utils/__init__.py b/api/src/serge/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/api/utils/convert.py b/api/src/serge/utils/convert.py similarity index 98% rename from api/utils/convert.py 
rename to api/src/serge/utils/convert.py
index 2ede507bd97..1622c54650a 100644
--- a/api/utils/convert.py
+++ b/api/src/serge/utils/convert.py
@@ -122,4 +122,4 @@ def convert_all(dir_model: str, tokenizer_model: str):
 
 if __name__ == "__main__":
     args = parse_args()
-    convert_all(args.dir_model, args.tokenizer_model)
+    convert_all(args.dir_model, args.tokenizer_model)
\ No newline at end of file
diff --git a/api/utils/generate.py b/api/src/serge/utils/generate.py
similarity index 95%
rename from api/utils/generate.py
rename to api/src/serge/utils/generate.py
index 20be7b54d8c..97bf0b7ce49 100644
--- a/api/utils/generate.py
+++ b/api/src/serge/utils/generate.py
@@ -1,5 +1,5 @@
 import subprocess, os
-from models import Chat, ChatParameters
+from serge.models.chat import Chat, ChatParameters
 
 import asyncio
 import logging
@@ -16,7 +16,7 @@ async def generate(
     args = (
         "llama",
         "--model",
-        "/usr/src/app/weights/" + params.model,
+        "/usr/src/app/weights/" + params.model + ".bin",
         "--prompt",
         prompt,
         "--n_predict",
diff --git a/api/utils/initiate_database.py b/api/src/serge/utils/initiate_database.py
similarity index 89%
rename from api/utils/initiate_database.py
rename to api/src/serge/utils/initiate_database.py
index b58ad8b4fb8..31b7de362c6 100644
--- a/api/utils/initiate_database.py
+++ b/api/src/serge/utils/initiate_database.py
@@ -4,7 +4,7 @@
 from motor.motor_asyncio import AsyncIOMotorClient
 from pydantic import BaseSettings
 
-from models import Question, Chat, ChatParameters
+from serge.models.chat import Question, Chat, ChatParameters
 
 class Settings(BaseSettings):
diff --git a/api/utils/download.py b/api/utils/download.py
deleted file mode 100644
index 4ab575d4afb..00000000000
--- a/api/utils/download.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import argparse
-import huggingface_hub
-import os
-
-from typing import List
-from convert import convert_all
-
-models_info = {
-    "7B": ["Pi3141/alpaca-7B-ggml", "ggml-model-q4_0.bin"],
-    "13B": ["Pi3141/alpaca-13B-ggml", "ggml-model-q4_0.bin"],
-    "30B": ["Pi3141/alpaca-30B-ggml", "ggml-model-q4_0.bin"],
-    "tokenizer": ["decapoda-research/llama-7b-hf", "tokenizer.model"],
-}
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Download and convert LLaMA models to the current format"
-    )
-    parser.add_argument(
-        "model",
-        help="Model name",
-        nargs="+",
-        choices=["7B", "13B", "30B", "tokenizer"],
-    )
-
-    return parser.parse_args()
-
-
-def download_models(models: List[str]):
-    for model in models:
-        repo_id, filename = models_info[model]
-        print(f"Downloading {model} model from {repo_id}...")
-
-        huggingface_hub.hf_hub_download(
-            repo_id=repo_id,
-            filename=filename,
-            local_dir="/usr/src/app/weights",
-            local_dir_use_symlinks=False,
-            cache_dir="/usr/src/app/weights/.cache",
-        )
-
-        if filename == "ggml-model-q4_0.bin":
-            os.rename(
-                "/usr/src/app/weights/ggml-model-q4_0.bin", f"/usr/src/app/weights/ggml-alpaca-{model}-q4_0.bin"
-            )
-
-
-if __name__ == "__main__":
-    args = parse_args()
-
-    print("Downloading models from HuggingFace")
-    download_models(args.model)
-
-    print("Converting models to the current format")
-    convert_all("/usr/src/app/weights", "/usr/src/app/weights/tokenizer.model")
diff --git a/compile.sh b/scripts/compile.sh
similarity index 100%
rename from compile.sh
rename to scripts/compile.sh
diff --git a/deploy.sh b/scripts/deploy.sh
similarity index 69%
rename from deploy.sh
rename to scripts/deploy.sh
index 51968b8b103..3582b368b27 100644
--- a/deploy.sh
+++ b/scripts/deploy.sh
@@ -4,7 +4,7 @@
 mongod &
 
 # Start the API
-cd api && uvicorn main:app --host 0.0.0.0 --port 8008 &
+cd api && uvicorn src.serge.main:app --host 0.0.0.0 --port 8008 &
 
 # Wait for any process to exit
 wait -n
diff --git a/dev.sh b/scripts/dev.sh
similarity index 65%
rename from dev.sh
rename to scripts/dev.sh
index 2b461a9fdaf..55e12e72f57 100644
--- a/dev.sh
+++ b/scripts/dev.sh
@@ -1,13 +1,15 @@
 #!/bin/bash
 ./compile.sh
 
+pip install -e ./api
+
 mongod &
 
 # Start the web server
 cd web && npm run dev -- --host 0.0.0.0 --port 8008 &
 
 # Start the API
-cd api && uvicorn main:api_app --reload --host 0.0.0.0 --port 9124 --root-path /api/ &
+cd api && uvicorn src.serge.main:api_app --reload --host 0.0.0.0 --port 9124 --root-path /api/ &
 
 # Wait for any process to exit
 wait -n
diff --git a/web/src/routes/+layout.svelte b/web/src/routes/+layout.svelte
index 3f57de09f84..24e477eb84a 100644
--- a/web/src/routes/+layout.svelte
+++ b/web/src/routes/+layout.svelte
@@ -12,7 +12,7 @@
     if (response.status == 200) {
       toggleDeleteConfirm();
       await goto("/");
-      await invalidate("/api/chats");
+      await invalidate("/api/chat/");
    } else {
      console.error("Error " + response.status + ": " + response.statusText);
    }
@@ -69,7 +69,7 @@
  •
diff --git a/web/src/routes/+layout.ts b/web/src/routes/+layout.ts
index 70f8d0dc468..82d3a433ecc 100644
--- a/web/src/routes/+layout.ts
+++ b/web/src/routes/+layout.ts
@@ -8,7 +8,7 @@
 };
 
 export const load: LayoutLoad = async ({ fetch }) => {
-  const r = await fetch("/api/chats");
+  const r = await fetch("/api/chat/");
   const chats = (await r.json()) as t[];
   return {
     chats,
diff --git a/web/src/routes/+page.svelte b/web/src/routes/+page.svelte
index 1af71d99741..6362ae893ed 100644
--- a/web/src/routes/+page.svelte
+++ b/web/src/routes/+page.svelte
@@ -3,7 +3,12 @@
   import { goto, invalidate } from "$app/navigation";
 
   export let data: PageData;
-  const modelAvailable = data.models.length > 0;
+  const models = data.models.filter((el) => el.available);
+
+  console.log(models);
+
+  const modelAvailable = models.length > 0;
+  const modelsLabels = models.map((el) => el.name);
 
   let temp = 0.1;
   let top_k = 50;
@@ -30,7 +35,7 @@
     ]);
     const searchParams = new URLSearchParams(convertedFormEntries);
 
-    const r = await fetch("/api/chat?" + searchParams.toString(), {
+    const r = await fetch("/api/chat/?" + searchParams.toString(), {
       method: "POST",
     });
 
@@ -38,14 +43,14 @@
     if (r.ok) {
       const data = await r.json();
       await goto("/chat/" + data);
-      await invalidate("/api/chats");
+      await invalidate("/api/chat/");
    } else {
      console.log(r.statusText);
    }
  }
-

    Say Hi to Serge!

    +

    Say Hi to Serge 🦙

    An easy way to chat with Alpaca & other LLaMa based models.

    @@ -53,8 +58,15 @@
    - Start a new chat +
    @@ -162,7 +174,7 @@
diff --git a/web/src/routes/+page.ts b/web/src/routes/+page.ts
index 4e5fcd5ec6b..7126d216a8a 100644
--- a/web/src/routes/+page.ts
+++ b/web/src/routes/+page.ts
@@ -1,8 +1,15 @@
 import type { PageLoad } from "./$types";
 
+interface ModelStatus {
+  name: string;
+  size: number;
+  available: boolean;
+  progress?: number;
+}
+
 export const load: PageLoad = async ({ fetch }) => {
-  const r = await fetch("api/models");
-  const models = (await r.json()) as string[];
+  const r = await fetch("/api/model/all");
+  const models = (await r.json()) as Array<ModelStatus>;
   return {
     models,
   };
diff --git a/web/src/routes/models/+page.svelte b/web/src/routes/models/+page.svelte
new file mode 100644
index 00000000000..dac95018a03
--- /dev/null
+++ b/web/src/routes/models/+page.svelte
@@ -0,0 +1,72 @@
+
+

    ⚡ Download a model ⚡

    +

    + Make sure you have enough disk space and available RAM to run them. +

    + +
    +
    +
    + {#each data.models as model} +
    +

    + {model.name + " " + (model.available ? "☑️" : "")} +

    +

    + ({model.size / 1e9}GB) +

    + {#if model.progress} +
    +

    {model.progress}%

    + +
    + {/if} + +
    +
    + {/each} +
    +
diff --git a/web/src/routes/models/+page.ts b/web/src/routes/models/+page.ts
new file mode 100644
index 00000000000..7126d216a8a
--- /dev/null
+++ b/web/src/routes/models/+page.ts
@@ -0,0 +1,16 @@
+import type { PageLoad } from "./$types";
+
+interface ModelStatus {
+  name: string;
+  size: number;
+  available: boolean;
+  progress?: number;
+}
+
+export const load: PageLoad = async ({ fetch }) => {
+  const r = await fetch("/api/model/all");
+  const models = (await r.json()) as Array<ModelStatus>;
+  return {
+    models,
+  };
+};

From 6949201e4e390ba140bfc5cfadb007190940c195 Mon Sep 17 00:00:00 2001
From: Nathan Sarrazin
Date: Tue, 28 Mar 2023 23:19:56 +0200
Subject: [PATCH 2/3] delete partially downloaded files on startup

---
 api/src/serge/dependencies.py  | 29 +++--------------------------
 api/src/serge/main.py          | 26 +++++++++++++++++---------
 api/src/serge/utils/convert.py |  7 +++----
 3 files changed, 23 insertions(+), 39 deletions(-)

diff --git a/api/src/serge/dependencies.py b/api/src/serge/dependencies.py
index e33a498a823..7c10589b3bd 100644
--- a/api/src/serge/dependencies.py
+++ b/api/src/serge/dependencies.py
@@ -1,31 +1,8 @@
-from fastapi import HTTPException, status
 from .utils.convert import convert_all
 import anyio
-import os
-
-MODEL_IS_READY: bool = False
-
-
-def dep_models_ready() -> list[str]:
-    """
-    FastAPI dependency that checks if models are ready.
-
-    Returns a list of available models
-    """
-    if MODEL_IS_READY is False:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail={
-                "message": "models are not ready"
-            }
-        )
-
-    files = os.listdir("/usr/src/app/weights")
-    files = list(filter(lambda x: x.endswith(".bin"), files))
-    return files
 
 
 async def convert_model_files():
-    global MODEL_IS_READY
-    await anyio.to_thread.run_sync(convert_all, "/usr/src/app/weights/", "/usr/src/app/weights/tokenizer.model")
-    MODEL_IS_READY = True
+    await anyio.to_thread.run_sync(
+        convert_all, "/usr/src/app/weights/", "/usr/src/app/weights/tokenizer.model"
+    )
diff --git a/api/src/serge/main.py b/api/src/serge/main.py
index 282578e0659..faec8e2519c 100644
--- a/api/src/serge/main.py
+++ b/api/src/serge/main.py
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import os
 
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
@@ -15,9 +16,7 @@
 logging.basicConfig(
     level=logging.INFO,
     format="%(levelname)s:\t%(name)s\t%(message)s",
-    handlers=[
-        logging.StreamHandler()
-    ]
+    handlers=[logging.StreamHandler()],
 )
 
 # Define a logger for the current module
@@ -54,26 +53,27 @@
 api_app = FastAPI(title="Serge API")
 api_app.include_router(chat_router)
 api_app.include_router(model_router)
-app.mount('/api', api_app)
+app.mount("/api", api_app)
 
 # handle serving the frontend as static files in production
 if settings.NODE_ENV == "production":
+
     @app.middleware("http")
     async def add_custom_header(request, call_next):
         response = await call_next(request)
         if response.status_code == 404:
-            return FileResponse('static/200.html')
+            return FileResponse("static/200.html")
         return response
 
     @app.exception_handler(404)
     def not_found(request, exc):
-        return FileResponse('static/200.html')
+        return FileResponse("static/200.html")
 
     async def homepage(request):
-        return FileResponse('static/200.html')
+        return FileResponse("static/200.html")
 
-    app.route('/', homepage)
-    app.mount('/', StaticFiles(directory='static'))
+    app.route("/", homepage)
+    app.mount("/", StaticFiles(directory="static"))
 
     start_app = app
 else:
@@ -82,12 +82,20 @@
@start_app.on_event("startup") async def start_database(): + WEIGHTS = "/usr/src/app/weights/" + files = os.listdir(WEIGHTS) + files = list(filter(lambda x: x.endswith(".tmp"), files)) + + for file in files: + os.remove(WEIGHTS + file) + logger.info("initializing database connection") await initiate_database() logger.info("initializing models") asyncio.create_task(convert_model_files()) + app.add_middleware( CORSMiddleware, allow_origins=origins, diff --git a/api/src/serge/utils/convert.py b/api/src/serge/utils/convert.py index 1622c54650a..80c1282adc1 100644 --- a/api/src/serge/utils/convert.py +++ b/api/src/serge/utils/convert.py @@ -113,13 +113,12 @@ def convert_all(dir_model: str, tokenizer_model: str): try: tokenizer = SentencePieceProcessor(tokenizer_model) + for file in files: + convert_one_file(file, tokenizer) except OSError: print("Missing tokenizer, don't forget to download it!") - for file in files: - convert_one_file(file, tokenizer) - if __name__ == "__main__": args = parse_args() - convert_all(args.dir_model, args.tokenizer_model) \ No newline at end of file + convert_all(args.dir_model, args.tokenizer_model) From 8c89a9b594b6cc392efccbae429574486d2aa1fc Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Tue, 28 Mar 2023 23:25:08 +0200 Subject: [PATCH 3/3] remove unused deps --- api/src/serge/routers/chat.py | 5 ++--- api/src/serge/routers/model.py | 4 +--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/api/src/serge/routers/chat.py b/api/src/serge/routers/chat.py index a731804828c..fbb6e2e807e 100644 --- a/api/src/serge/routers/chat.py +++ b/api/src/serge/routers/chat.py @@ -6,7 +6,6 @@ from serge.models.chat import Question, Chat,ChatParameters from serge.utils.generate import generate, get_full_prompt_from_chat -from serge.dependencies import dep_models_ready async def on_close(chat, prompt, answer=None, error=None): question = await Question(question=prompt.rstrip(), @@ -107,7 +106,7 @@ async def delete_chat(chat_id: str): -@chat_router.get("/{chat_id}/question", dependencies=[Depends(dep_models_ready)]) +@chat_router.get("/{chat_id}/question") async def stream_ask_a_question(chat_id: str, prompt: str): chat = await Chat.get(chat_id) @@ -147,7 +146,7 @@ async def event_generator(): return EventSourceResponse(event_generator()) -@chat_router.post("/{chat_id}/question", dependencies=[Depends(dep_models_ready)]) +@chat_router.post("/{chat_id}/question") async def ask_a_question(chat_id: str, prompt: str): chat = await Chat.get(chat_id) await chat.fetch_link(Chat.parameters) diff --git a/api/src/serge/routers/model.py b/api/src/serge/routers/model.py index 890eb608422..61d6d57b464 100644 --- a/api/src/serge/routers/model.py +++ b/api/src/serge/routers/model.py @@ -1,7 +1,5 @@ -from fastapi import APIRouter, HTTPException, Depends -from typing import Annotated +from fastapi import APIRouter, HTTPException -from serge.dependencies import dep_models_ready from serge.utils.convert import convert_one_file import huggingface_hub import os