Skip to content

Commit

Permalink
feat(server): replace Flask with FastAPI and Uvicorn
Browse files Browse the repository at this point in the history
- replace Flask with FastAPI and Uvicorn
- fix web page not found error
- port is now defaulted to 7001
- bind to localhost (127.0.0.1) instead of 0.0.0.0
- improve performance by using Uvicorn
- add OpenAPI docs for endpoints
  • Loading branch information
leafspark committed Aug 31, 2024
1 parent db1733b commit 22bd74b
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 59 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ sentencepiece~=0.2.0
PyYAML~=6.0.2
pynvml~=11.5.3
PySide6~=6.7.2
flask~=3.0.3
python-dotenv~=1.0.1
safetensors~=0.4.4
setuptools~=68.2.0
huggingface-hub~=0.24.6
transformers~=4.44.2
fastapi~=0.112.2
uvicorn~=0.30.6
214 changes: 156 additions & 58 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,178 @@
import os
import sys
import threading
from enum import Enum
from typing import List, Optional

from PySide6.QtCore import QTimer
from PySide6.QtWidgets import QApplication
from fastapi import FastAPI, Query
from pydantic import BaseModel, Field
from uvicorn import Config, Server

from AutoGGUF import AutoGGUF
from flask import Flask, Response, jsonify
from Localizations import AUTOGGUF_VERSION

server = Flask(__name__)
app = FastAPI(
title="AutoGGUF",
description="API for AutoGGUF - automatically quant GGUF models",
version=AUTOGGUF_VERSION,
license_info={
"name": "Apache 2.0",
"url": "https://raw.githubusercontent.com/leafspark/AutoGGUF/main/LICENSE",
},
)

# Global variable to hold the window reference
window = None

def main() -> None:
@server.route("/v1/models", methods=["GET"])
def models() -> Response:
if window:
return jsonify({"models": window.get_models_data()})
return jsonify({"models": []})

@server.route("/v1/tasks", methods=["GET"])
def tasks() -> Response:
if window:
return jsonify({"tasks": window.get_tasks_data()})
return jsonify({"tasks": []})

@server.route("/v1/health", methods=["GET"])
def ping() -> Response:
return jsonify({"status": "alive"})

@server.route("/v1/backends", methods=["GET"])
def get_backends() -> Response:
backends = []

class ModelType(str, Enum):
single = "single"
sharded = "sharded"


class Model(BaseModel):
name: str = Field(..., description="Name of the model")
type: str = Field(..., description="Type of the model")
path: str = Field(..., description="Path to the model file")
size: Optional[int] = Field(None, description="Size of the model in bytes")

class Config:
json_schema_extra = {
"example": {
"name": "Llama-3.1-8B-Instruct.fp16.gguf",
"type": "single",
"path": "Llama-3.1-8B-Instruct.fp16.gguf",
"size": 13000000000,
}
}


class Task(BaseModel):
id: str = Field(..., description="Unique identifier for the task")
status: str = Field(..., description="Current status of the task")
progress: float = Field(..., description="Progress of the task as a percentage")

class Config:
json_json_schema_extra = {
"example": {"id": "task_123", "status": "running", "progress": 75.5}
}


class Backend(BaseModel):
name: str = Field(..., description="Name of the backend")
path: str = Field(..., description="Path to the backend executable")


class Plugin(BaseModel):
name: str = Field(..., description="Name of the plugin")
version: str = Field(..., description="Version of the plugin")
description: str = Field(..., description="Description of the plugin")
author: str = Field(..., description="Author of the plugin")


@app.get("/v1/models", response_model=List[Model], tags=["Models"])
async def get_models(
type: Optional[ModelType] = Query(None, description="Filter models by type")
) -> List[Model]:
"""
Get a list of all available models.
- **type**: Optional filter for model type
Returns a list of Model objects containing name, type, path, and optional size.
"""
if window:
models = window.get_models_data()
if type:
models = [m for m in models if m["type"] == type]

# Convert to Pydantic models, handling missing 'size' field
return [Model(**m) for m in models]
return []


@app.get("/v1/tasks", response_model=List[Task], tags=["Tasks"])
async def get_tasks() -> List[Task]:
"""
Get a list of all current tasks.
Returns a list of Task objects containing id, status, and progress.
"""
if window:
return window.get_tasks_data()
return []


@app.get("/v1/health", tags=["System"])
async def health_check() -> dict:
"""
Check the health status of the API.
Returns a simple status message indicating the API is alive.
"""
return {"status": "alive"}


@app.get("/v1/backends", response_model=List[Backend], tags=["System"])
async def get_backends() -> List[Backend]:
"""
Get a list of all available llama.cpp backends.
Returns a list of Backend objects containing name and path.
"""
backends = []
if window:
for i in range(window.backend_combo.count()):
backends.append(
{
"name": window.backend_combo.itemText(i),
"path": window.backend_combo.itemData(i),
}
)
return jsonify({"backends": backends})

@server.route("/v1/plugins", methods=["GET"])
def get_plugins() -> Response:
if window:
return jsonify(
{
"plugins": [
{
"name": plugin_data["data"]["name"],
"version": plugin_data["data"]["version"],
"description": plugin_data["data"]["description"],
"author": plugin_data["data"]["author"],
}
for plugin_data in window.plugins.values()
]
}
)
return jsonify({"plugins": []})

def run_flask() -> None:
if os.environ.get("AUTOGGUF_SERVER", "").lower() == "enabled":
server.run(
host="0.0.0.0",
port=int(os.environ.get("AUTOGGUF_SERVER_PORT", 5000)),
debug=False,
use_reloader=False,
Backend(
name=window.backend_combo.itemText(i),
path=window.backend_combo.itemData(i),
)
)
return backends

app = QApplication(sys.argv)

@app.get("/v1/plugins", response_model=List[Plugin], tags=["System"])
async def get_plugins() -> List[Plugin]:
"""
Get a list of all installed plugins.
Returns a list of Plugin objects containing name, version, description, and author.
"""
if window:
return [
Plugin(**plugin_data["data"]) for plugin_data in window.plugins.values()
]
return []


def run_uvicorn() -> None:
if os.environ.get("AUTOGGUF_SERVER", "").lower() == "enabled":
config = Config(
app=app,
host="127.0.0.1",
port=int(os.environ.get("AUTOGGUF_SERVER_PORT", 7001)),
log_level="info",
)
server = Server(config)
server.run()


def main() -> None:
global window
qt_app = QApplication(sys.argv)
window = AutoGGUF(sys.argv)
window.show()
# Start Flask in a separate thread after a short delay

# Start Uvicorn in a separate thread after a short delay
timer = QTimer()
timer.singleShot(
100, lambda: threading.Thread(target=run_flask, daemon=True).start()
100, lambda: threading.Thread(target=run_uvicorn, daemon=True).start()
)
sys.exit(app.exec())

sys.exit(qt_app.exec())


if __name__ == "__main__":
Expand Down

0 comments on commit 22bd74b

Please sign in to comment.