# add lighten control (#2567)
### What problem does this PR solve?

#2295

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
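In short, judging from the diffs below: a new `LIGHTEN` environment variable gates the built-in model defaults in `api/settings.py` and defers the `torch`/`FlagEmbedding` imports in the PDF parser, embedding, and rerank modules, so a lighter build can start without those packages installed.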
KevinHuSh authored Sep 24, 2024
1 parent 9251fb3 commit 7bb28ca
Showing 5 changed files with 86 additions and 78 deletions.
#### api/apps/llm_app.py (3 additions, 1 deletion)

```diff
@@ -18,6 +18,7 @@
 from flask import request
 from flask_login import login_required, current_user
 from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
+from api.settings import LIGHTEN
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
@@ -319,13 +320,14 @@ def my_llms():
 @login_required
 def list_app():
     self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
+    weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN else []
     model_type = request.args.get("model_type")
     try:
         objs = TenantLLMService.query(tenant_id=current_user.id)
         facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
         llms = LLMService.get_all()
         llms = [m.to_dict()
-                for m in llms if m.status == StatusEnum.VALID.value]
+                for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted]
         for m in llms:
             m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deploied
```
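To illustrate the new filter in `list_app()`, here is a minimal sketch of its effect (the `weighted` list and the `not in weighted` test come from the hunk; `LIGHTEN` and the model dicts are invented stand-ins): with `LIGHTEN` set, factories whose default models are locally hosted weights drop out of the listing.

```python
LIGHTEN = "1"  # stand-in for api.settings.LIGHTEN

# Factories backed by local weight files are hidden in lighten mode.
weighted = ["Youdao", "FastEmbed", "BAAI"] if LIGHTEN else []

# Invented examples standing in for LLMService.get_all() results.
llms = [
    {"fid": "OpenAI", "llm_name": "gpt-3.5-turbo"},
    {"fid": "BAAI", "llm_name": "bge-large-zh-v1.5"},
]
visible = [m for m in llms if m["fid"] not in weighted]
print([m["fid"] for m in visible])  # ['OpenAI']
```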
#### api/settings.py (67 additions, 67 deletions)

```diff
@@ -42,6 +42,7 @@
 SERVER_MODULE = "rag_flow_server.py"
 TEMP_DIRECTORY = os.path.join(get_project_base_directory(), "temp")
 RAG_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
+LIGHTEN = os.environ.get('LIGHTEN')

 SUBPROCESS_STD_LOG_NAME = "std.log"

@@ -57,77 +58,76 @@

 USE_REGISTRY = get_base_config("use_registry")

-default_llm = {
-    "Tongyi-Qianwen": {
-        "chat_model": "qwen-plus",
-        "embedding_model": "text-embedding-v2",
-        "image2text_model": "qwen-vl-max",
-        "asr_model": "paraformer-realtime-8k-v1",
-    },
-    "OpenAI": {
-        "chat_model": "gpt-3.5-turbo",
-        "embedding_model": "text-embedding-ada-002",
-        "image2text_model": "gpt-4-vision-preview",
-        "asr_model": "whisper-1",
-    },
-    "Azure-OpenAI": {
-        "chat_model": "azure-gpt-35-turbo",
-        "embedding_model": "azure-text-embedding-ada-002",
-        "image2text_model": "azure-gpt-4-vision-preview",
-        "asr_model": "azure-whisper-1",
-    },
-    "ZHIPU-AI": {
-        "chat_model": "glm-3-turbo",
-        "embedding_model": "embedding-2",
-        "image2text_model": "glm-4v",
-        "asr_model": "",
-    },
-    "Ollama": {
-        "chat_model": "qwen-14B-chat",
-        "embedding_model": "flag-embedding",
-        "image2text_model": "",
-        "asr_model": "",
-    },
-    "Moonshot": {
-        "chat_model": "moonshot-v1-8k",
-        "embedding_model": "",
-        "image2text_model": "",
-        "asr_model": "",
-    },
-    "DeepSeek": {
-        "chat_model": "deepseek-chat",
-        "embedding_model": "",
-        "image2text_model": "",
-        "asr_model": "",
-    },
-    "VolcEngine": {
-        "chat_model": "",
-        "embedding_model": "",
-        "image2text_model": "",
-        "asr_model": "",
-    },
-    "BAAI": {
-        "chat_model": "",
-        "embedding_model": "BAAI/bge-large-zh-v1.5",
-        "image2text_model": "",
-        "asr_model": "",
-        "rerank_model": "BAAI/bge-reranker-v2-m3",
-    }
-}
 LLM = get_base_config("user_default_llm", {})
 LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
 LLM_BASE_URL = LLM.get("base_url")

-if LLM_FACTORY not in default_llm:
-    print(
-        "\33[91m【ERROR】\33[0m:",
-        f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
-    LLM_FACTORY = "Tongyi-Qianwen"
-CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
-EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
-RERANK_MDL = default_llm["BAAI"]["rerank_model"]
-ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
-IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
+if not LIGHTEN:
+    default_llm = {
+        "Tongyi-Qianwen": {
+            "chat_model": "qwen-plus",
+            "embedding_model": "text-embedding-v2",
+            "image2text_model": "qwen-vl-max",
+            "asr_model": "paraformer-realtime-8k-v1",
+        },
+        "OpenAI": {
+            "chat_model": "gpt-3.5-turbo",
+            "embedding_model": "text-embedding-ada-002",
+            "image2text_model": "gpt-4-vision-preview",
+            "asr_model": "whisper-1",
+        },
+        "Azure-OpenAI": {
+            "chat_model": "azure-gpt-35-turbo",
+            "embedding_model": "azure-text-embedding-ada-002",
+            "image2text_model": "azure-gpt-4-vision-preview",
+            "asr_model": "azure-whisper-1",
+        },
+        "ZHIPU-AI": {
+            "chat_model": "glm-3-turbo",
+            "embedding_model": "embedding-2",
+            "image2text_model": "glm-4v",
+            "asr_model": "",
+        },
+        "Ollama": {
+            "chat_model": "qwen-14B-chat",
+            "embedding_model": "flag-embedding",
+            "image2text_model": "",
+            "asr_model": "",
+        },
+        "Moonshot": {
+            "chat_model": "moonshot-v1-8k",
+            "embedding_model": "",
+            "image2text_model": "",
+            "asr_model": "",
+        },
+        "DeepSeek": {
+            "chat_model": "deepseek-chat",
+            "embedding_model": "",
+            "image2text_model": "",
+            "asr_model": "",
+        },
+        "VolcEngine": {
+            "chat_model": "",
+            "embedding_model": "",
+            "image2text_model": "",
+            "asr_model": "",
+        },
+        "BAAI": {
+            "chat_model": "",
+            "embedding_model": "BAAI/bge-large-zh-v1.5",
+            "image2text_model": "",
+            "asr_model": "",
+            "rerank_model": "BAAI/bge-reranker-v2-m3",
+        }
+    }
+
+    CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
+    EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
+    RERANK_MDL = default_llm["BAAI"]["rerank_model"] if not LIGHTEN else ""
+    ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
+    IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
+else:
+    CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = ""

 API_KEY = LLM.get("api_key", "")
 PARSERS = LLM.get(
```
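A behavioral note on the hunk above (an observation about the code, not something stated in the PR): `os.environ.get('LIGHTEN')` returns the raw string value, so any non-empty setting, including `LIGHTEN=0`, switches on lighten mode; only leaving the variable unset keeps the full defaults. A minimal sketch:

```python
# os.environ.get("LIGHTEN") yields None when the variable is unset and the
# literal string otherwise, so "0" and "false" are truthy and still enable
# lighten mode.
for value in (None, "1", "0"):
    LIGHTEN = value
    mode = "full" if not LIGHTEN else "lighten"
    print(f"LIGHTEN={value!r} -> {mode} mode")
# LIGHTEN=None -> full mode
# LIGHTEN='1' -> lighten mode
# LIGHTEN='0' -> lighten mode
```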
#### deepdoc/parser/pdf_parser.py (5 additions, 3 deletions)

```diff
@@ -16,7 +16,6 @@

 import xgboost as xgb
 from io import BytesIO
-import torch
 import re
 import pdfplumber
 import logging
@@ -25,6 +24,7 @@
 from timeit import default_timer as timer
 from pypdf import PdfReader as pdf2_read

+from api.settings import LIGHTEN
 from api.utils.file_utils import get_project_base_directory
 from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
 from rag.nlp import rag_tokenizer
@@ -44,8 +44,10 @@ def __init__(self):
         self.tbl_det = TableStructureRecognizer()

         self.updown_cnt_mdl = xgb.Booster()
-        if torch.cuda.is_available():
-            self.updown_cnt_mdl.set_param({"device": "cuda"})
+        if not LIGHTEN:
+            import torch
+            if torch.cuda.is_available():
+                self.updown_cnt_mdl.set_param({"device": "cuda"})
         try:
             model_dir = os.path.join(
                 get_project_base_directory(),
```
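The hunk above shows the deferred-import pattern the PR applies throughout: `import torch` moves from module scope into the branch that actually needs it, so `deepdoc.parser` can be imported on a machine without torch whenever `LIGHTEN` is set. A minimal sketch of the pattern, with an invented class name:

```python
import os

LIGHTEN = os.environ.get("LIGHTEN")  # read once, as api/settings.py does

class UpDownCounter:  # hypothetical stand-in for the parser's __init__
    def __init__(self):
        if not LIGHTEN:
            # Local import: the module itself loads without torch; torch is
            # only required when a full build actually runs this branch.
            import torch
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = "cpu"  # a lighten build never touches torch
```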
#### rag/llm/embedding_model.py (5 additions, 3 deletions)

```diff
@@ -25,10 +25,10 @@
 from ollama import Client
 import dashscope
 from openai import OpenAI
-from FlagEmbedding import FlagModel
-import torch
 import numpy as np
 import asyncio
+
+from api.settings import LIGHTEN
 from api.utils.file_utils import get_home_cache_dir
 from rag.utils import num_tokens_from_string, truncate
 import google.generativeai as genai
@@ -60,8 +60,10 @@ def __init__(self, key, model_name, **kwargs):
         ^_-
         """
-        if not DefaultEmbedding._model:
+        if not LIGHTEN and not DefaultEmbedding._model:
             with DefaultEmbedding._model_lock:
+                from FlagEmbedding import FlagModel
+                import torch
                 if not DefaultEmbedding._model:
                     try:
                         DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
```
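The `DefaultEmbedding` hunk layers the `LIGHTEN` guard onto the existing double-checked-locking singleton and pulls the `FlagEmbedding`/`torch` imports inside the lock. A stripped-down sketch of that structure, with the weight-loading call replaced by a placeholder:

```python
import threading

LIGHTEN = None  # stand-in for api.settings.LIGHTEN; None means full mode

class DefaultEmbedding:
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, model_name):
        # Cheap check first: once _model exists, no thread takes the lock.
        if not LIGHTEN and not DefaultEmbedding._model:
            with DefaultEmbedding._model_lock:
                # In the real code the heavy imports happen here, so a
                # lighten build never needs FlagEmbedding or torch installed.
                if not DefaultEmbedding._model:  # re-check under the lock
                    DefaultEmbedding._model = f"FlagModel({model_name})"  # placeholder

print(DefaultEmbedding("BAAI/bge-large-zh-v1.5")._model)
```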
#### rag/llm/rerank_model.py (6 additions, 4 deletions)

```diff
@@ -14,14 +14,14 @@
 # limitations under the License.
 #
 import re
-import threading
+import threading
 import requests
-import torch
-from FlagEmbedding import FlagReranker
 from huggingface_hub import snapshot_download
 import os
 from abc import ABC
 import numpy as np
+
+from api.settings import LIGHTEN
 from api.utils.file_utils import get_home_cache_dir
 from rag.utils import num_tokens_from_string, truncate
 import json
@@ -53,7 +53,9 @@ def __init__(self, key, model_name, **kwargs):
         ^_-
         """
-        if not DefaultRerank._model:
+        if not LIGHTEN and not DefaultRerank._model:
+            import torch
+            from FlagEmbedding import FlagReranker
             with DefaultRerank._model_lock:
                 if not DefaultRerank._model:
                     try:
```
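Taken together, a plausible smoke test for a lighten deployment (hypothetical, not part of the PR): with `LIGHTEN` exported before anything imports `api.settings`, the three gated modules should import cleanly even when `torch` and `FlagEmbedding` are absent.

```python
# Hypothetical check; module paths follow the diffs above. LIGHTEN must be
# set before api.settings is first imported, since the flag is read once
# at module load time.
import os
os.environ["LIGHTEN"] = "1"

import importlib

for mod in ("deepdoc.parser.pdf_parser",
            "rag.llm.embedding_model",
            "rag.llm.rerank_model"):
    importlib.import_module(mod)
    print(f"{mod}: imported without torch/FlagEmbedding")
```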
