Skip to content

Commit

Permalink
Chat Use CVmodel (infiniflow#1607)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

infiniflow#1230 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
guoyuhao2330 authored Jul 19, 2024
1 parent 5ec7b37 commit 3bae5d0
Show file tree
Hide file tree
Showing 4 changed files with 325 additions and 6 deletions.
27 changes: 24 additions & 3 deletions api/db/services/dialog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import json
import re
from copy import deepcopy

Expand All @@ -26,6 +28,7 @@
from rag.nlp import keyword_extraction
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory


class DialogService(CommonService):
Expand Down Expand Up @@ -73,6 +76,15 @@ def count():
return max_length, msg


def llm_id2llm_type(llm_id):
    """Return the model type registered for ``llm_id`` in ``conf/llm_factories.json``.

    The factory file is read from ``<project base>/conf/llm_factories.json``
    and every model listed under ``factory_llm_infos[*]["llm"]`` is compared
    against ``llm_id`` by its ``llm_name``.

    Args:
        llm_id: Model name to look up (matched against ``llm_name``).

    Returns:
        The last comma-separated entry of the matching model's
        ``model_type`` (e.g. ``"image2text"``), or ``None`` when no model
        with that name is listed.
    """
    conf_dir = os.path.join(get_project_base_directory(), "conf")
    # Use a context manager so the file handle is closed deterministically.
    with open(os.path.join(conf_dir, "llm_factories.json"), "r", encoding="utf-8") as f:
        llm_factories = json.load(f)

    for llm_factory in llm_factories["factory_llm_infos"]:
        for llm in llm_factory["llm"]:
            if llm_id == llm["llm_name"]:
                # ``strip(",")[-1]`` indexed the string and yielded only its
                # final character, so callers comparing against
                # "image2text" never matched. Split on commas instead and
                # take the last type; stray leading/trailing commas and
                # surrounding whitespace are ignored.
                return llm["model_type"].strip(",").split(",")[-1].strip()
    return None


def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
llm = LLMService.query(llm_name=dialog.llm_id)
Expand All @@ -91,7 +103,10 @@ def chat(dialog, messages, stream=True, **kwargs):

questions = [m["content"] for m in messages if m["role"] == "user"]
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
if llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
Expand Down Expand Up @@ -328,7 +343,10 @@ def get_table():


def relevant(tenant_id, llm_id, question, contents: list):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
if llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
else:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
prompt = """
You are a grader assessing relevance of a retrieved document to a user question.
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
Expand All @@ -347,7 +365,10 @@ def relevant(tenant_id, llm_id, question, contents: list):


def rewrite(tenant_id, llm_id, question):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
if llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
else:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
prompt = """
You are an expert at query expansion to generate a paraphrasing of a question.
I can't retrieval relevant information from the knowledge base by using user's question directly.
Expand Down
2 changes: 1 addition & 1 deletion api/db/services/llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def model_instance(cls, tenant_id, llm_type,
elif llm_type == LLMType.SPEECH2TEXT.value:
mdlnm = tenant.asr_id
elif llm_type == LLMType.IMAGE2TEXT.value:
mdlnm = tenant.img2txt_id
mdlnm = tenant.img2txt_id if not llm_name else llm_name
elif llm_type == LLMType.CHAT.value:
mdlnm = tenant.llm_id if not llm_name else llm_name
elif llm_type == LLMType.RERANK:
Expand Down
Loading

0 comments on commit 3bae5d0

Please sign in to comment.