Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix duplicated llm name betweeen different suppliers #2477

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions api/apps/chunk_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from rag.utils import rmSpace
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService
Expand Down Expand Up @@ -141,8 +141,7 @@ def set():
return get_data_error_result(retmsg="Tenant not found!")

embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value, embd_id)
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)

e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
Expand Down Expand Up @@ -235,8 +234,7 @@ def create():
return get_data_error_result(retmsg="Tenant not found!")

embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value, embd_id)
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)

v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1]
Expand Down Expand Up @@ -281,16 +279,14 @@ def retrieval_test():
if not e:
return get_data_error_result(retmsg="Knowledgebase not found!")

embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])

if req.get("keyword", False):
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)

retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
Expand Down
11 changes: 9 additions & 2 deletions api/db/services/dialog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def count():


def llm_id2llm_type(llm_id):
llm_id = llm_id.split("@")[0]
fnm = os.path.join(get_project_base_directory(), "conf")
llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
for llm_factory in llm_factories["factory_llm_infos"]:
Expand All @@ -89,9 +90,15 @@ def llm_id2llm_type(llm_id):
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
st = timer()
llm = LLMService.query(llm_name=dialog.llm_id)
tmp = dialog.llm_id.split("@")
fid = None
llm_id = tmp[0]
if len(tmp)>1: fid = tmp[1]

llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
if not llm:
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid)
if not llm:
raise LookupError("LLM(%s) not found" % dialog.llm_id)
max_tokens = 8192
Expand Down
17 changes: 12 additions & 5 deletions api/db/services/llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from api.settings import database_logger
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
from api.db import LLMType
from api.db.db_models import DB, UserTenant
from api.db.db_models import DB
from api.db.db_models import LLMFactories, LLM, TenantLLM
from api.db.services.common_service import CommonService

Expand All @@ -36,7 +36,11 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_name):
objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
arr = model_name.split("@")
if len(arr) < 2:
objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
else:
objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
if not objs:
return
return objs[0]
Expand Down Expand Up @@ -81,14 +85,17 @@ def model_instance(cls, tenant_id, llm_type,
assert False, "LLM type error"

model_config = cls.get_api_key(tenant_id, mdlnm)
tmp = mdlnm.split("@")
fid = None if len(tmp) < 2 else tmp[1]
mdlnm = tmp[0]
if model_config: model_config = model_config.to_dict()
if not model_config:
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
llm = LLMService.query(llm_name=llm_name if llm_name else mdlnm)
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name if llm_name else mdlnm, "api_base": ""}
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": mdlnm, "api_base": ""}
if not model_config:
if llm_name == "flag-embedding":
if mdlnm == "flag-embedding":
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
"llm_name": llm_name, "api_base": ""}
else:
Expand Down
2 changes: 1 addition & 1 deletion rag/app/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __call__(self, filename, binary=None, from_page=0, to_page=100000):
if last_image:
image_list.insert(0, last_image)
last_image = None
lines.append((self.__clean(p.text), image_list, p.style.name))
lines.append((self.__clean(p.text), image_list, p.style.name if p.style else ""))
else:
if current_image := self.get_picture(self.doc, p):
if lines:
Expand Down