From 3bae5d0bc938c10509460c434350547b9eb22fb6 Mon Sep 17 00:00:00 2001 From: H <43509927+guoyuhao2330@users.noreply.github.com> Date: Fri, 19 Jul 2024 18:36:34 +0800 Subject: [PATCH] Chat Use CVmodel (#1607) ### What problem does this PR solve? #1230 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/db/services/dialog_service.py | 27 +- api/db/services/llm_service.py | 2 +- rag/llm/cv_model.py | 300 +++++++++++++++++- .../components/llm-setting-items/index.tsx | 2 +- 4 files changed, 325 insertions(+), 6 deletions(-) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index ab90ee1c047..20cb27cfdc4 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os +import json import re from copy import deepcopy @@ -26,6 +28,7 @@ from rag.nlp import keyword_extraction from rag.nlp.search import index_name from rag.utils import rmSpace, num_tokens_from_string, encoder +from api.utils.file_utils import get_project_base_directory class DialogService(CommonService): @@ -73,6 +76,15 @@ def count(): return max_length, msg +def llm_id2llm_type(llm_id): + fnm = os.path.join(get_project_base_directory(), "conf") + llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r")) + for llm_factory in llm_factories["factory_llm_infos"]: + for llm in llm_factory["llm"]: + if llm_id == llm["llm_name"]: + return llm["model_type"].strip(",")[-1] + + def chat(dialog, messages, stream=True, **kwargs): assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." llm = LLMService.query(llm_name=dialog.llm_id) @@ -91,7 +103,10 @@ def chat(dialog, messages, stream=True, **kwargs): questions = [m["content"] for m in messages if m["role"] == "user"] embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0]) - chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + if llm_id2llm_type(dialog.llm_id) == "image2text": + chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) + else: + chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) @@ -328,7 +343,10 @@ def get_table(): def relevant(tenant_id, llm_id, question, contents: list): - chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id) + if llm_id2llm_type(llm_id) == "image2text": + chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id) + else: + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id) prompt = """ You are a grader assessing relevance of a retrieved document to a user question. It does not need to be a stringent test. The goal is to filter out erroneous retrievals. @@ -347,7 +365,10 @@ def relevant(tenant_id, llm_id, question, contents: list): def rewrite(tenant_id, llm_id, question): - chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id) + if llm_id2llm_type(llm_id) == "image2text": + chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id) + else: + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id) prompt = """ You are an expert at query expansion to generate a paraphrasing of a question. I can't retrieval relevant information from the knowledge base by using user's question directly. diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 5906927b7bc..a994d6103eb 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -70,7 +70,7 @@ def model_instance(cls, tenant_id, llm_type, elif llm_type == LLMType.SPEECH2TEXT.value: mdlnm = tenant.asr_id elif llm_type == LLMType.IMAGE2TEXT.value: - mdlnm = tenant.img2txt_id + mdlnm = tenant.img2txt_id if not llm_name else llm_name elif llm_type == LLMType.CHAT.value: mdlnm = tenant.llm_id if not llm_name else llm_name elif llm_type == LLMType.RERANK: diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index e63d6a0a39f..5867ce2e85b 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -26,6 +26,7 @@ import json import requests +from rag.nlp import is_english from api.utils import get_uuid from api.utils.file_utils import get_project_base_directory @@ -36,7 +37,60 @@ def __init__(self, key, model_name): def describe(self, image, max_tokens=300): raise NotImplementedError("Please implement encode method!") + + def chat(self, system, history, gen_conf, image=""): + if system: + history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] + try: + for his in history: + if his["role"] == "user": + his["content"] = self.chat_prompt(his["content"], image) + + response = self.client.chat.completions.create( + model=self.model_name, + messages=history, + max_tokens=gen_conf.get("max_tokens", 1000), + temperature=gen_conf.get("temperature", 0.3), + top_p=gen_conf.get("top_p", 0.7) + ) + return response.choices[0].message.content.strip(), response.usage.total_tokens + except Exception as e: + return "**ERROR**: " + str(e), 0 + + def chat_streamly(self, system, history, gen_conf, image=""): + if system: + history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] + + ans = "" + tk_count = 0 + try: + for his in history: + if his["role"] == "user": + his["content"] = self.chat_prompt(his["content"], image) + response = self.client.chat.completions.create( + model=self.model_name, + messages=history, + max_tokens=gen_conf.get("max_tokens", 1000), + temperature=gen_conf.get("temperature", 0.3), + top_p=gen_conf.get("top_p", 0.7), + stream=True + ) + for resp in response: + if not resp.choices[0].delta.content: continue + delta = resp.choices[0].delta.content + ans += delta + if resp.choices[0].finish_reason == "length": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english( + [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + tk_count = resp.usage.total_tokens + if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens + yield ans + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield tk_count + def image2base64(self, image): if isinstance(image, bytes): return base64.b64encode(image).decode("utf-8") @@ -68,6 +122,21 @@ def prompt(self, b64): } ] + def chat_prompt(self, text, b64): + return [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{b64}", + }, + }, + { + "type": "text", + "text": text + }, + ] + + class GptV4(Base): def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"): @@ -140,6 +209,12 @@ def prompt(self, binary): } ] + def chat_prompt(self, text, b64): + return [ + {"image": f"{b64}"}, + {"text": text}, + ] + def describe(self, image, max_tokens=300): from http import HTTPStatus from dashscope import MultiModalConversation @@ -149,6 +224,66 @@ def describe(self, image, max_tokens=300): return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens return response.message, 0 + def chat(self, system, history, gen_conf, image=""): + from http import HTTPStatus + from dashscope import MultiModalConversation + if system: + history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] + + for his in history: + if his["role"] == "user": + his["content"] = self.chat_prompt(his["content"], image) + response = MultiModalConversation.call(model=self.model_name, messages=history, + max_tokens=gen_conf.get("max_tokens", 1000), + temperature=gen_conf.get("temperature", 0.3), + top_p=gen_conf.get("top_p", 0.7)) + + ans = "" + tk_count = 0 + if response.status_code == HTTPStatus.OK: + ans += response.output.choices[0]['message']['content'] + tk_count += response.usage.total_tokens + if response.output.choices[0].get("finish_reason", "") == "length": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english( + [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + return ans, tk_count + + return "**ERROR**: " + response.message, tk_count + + def chat_streamly(self, system, history, gen_conf, image=""): + from http import HTTPStatus + from dashscope import MultiModalConversation + if system: + history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] + + for his in history: + if his["role"] == "user": + his["content"] = self.chat_prompt(his["content"], image) + + ans = "" + tk_count = 0 + try: + response = MultiModalConversation.call(model=self.model_name, messages=history, + max_tokens=gen_conf.get("max_tokens", 1000), + temperature=gen_conf.get("temperature", 0.3), + top_p=gen_conf.get("top_p", 0.7), + stream=True) + for resp in response: + if resp.status_code == HTTPStatus.OK: + ans = resp.output.choices[0]['message']['content'] + tk_count = resp.usage.total_tokens + if resp.output.choices[0].get("finish_reason", "") == "length": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english( + [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + yield ans + else: + yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find( + "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**" + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield tk_count + class Zhipu4V(Base): def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs): @@ -166,6 +301,59 @@ def describe(self, image, max_tokens=1024): ) return res.choices[0].message.content.strip(), res.usage.total_tokens + def chat(self, system, history, gen_conf, image=""): + if system: + history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] + try: + for his in history: + if his["role"] == "user": + his["content"] = self.chat_prompt(his["content"], image) + + response = self.client.chat.completions.create( + model=self.model_name, + messages=history, + max_tokens=gen_conf.get("max_tokens", 1000), + temperature=gen_conf.get("temperature", 0.3), + top_p=gen_conf.get("top_p", 0.7) + ) + return response.choices[0].message.content.strip(), response.usage.total_tokens + except Exception as e: + return "**ERROR**: " + str(e), 0 + + def chat_streamly(self, system, history, gen_conf, image=""): + if system: + history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] + + ans = "" + tk_count = 0 + try: + for his in history: + if his["role"] == "user": + his["content"] = self.chat_prompt(his["content"], image) + + response = self.client.chat.completions.create( + model=self.model_name, + messages=history, + max_tokens=gen_conf.get("max_tokens", 1000), + temperature=gen_conf.get("temperature", 0.3), + top_p=gen_conf.get("top_p", 0.7), + stream=True + ) + for resp in response: + if not resp.choices[0].delta.content: continue + delta = resp.choices[0].delta.content + ans += delta + if resp.choices[0].finish_reason == "length": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english( + [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + tk_count = resp.usage.total_tokens + if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens + yield ans + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield tk_count + class OllamaCV(Base): def __init__(self, key, model_name, lang="Chinese", **kwargs): @@ -188,6 +376,63 @@ def describe(self, image, max_tokens=1024): except Exception as e: return "**ERROR**: " + str(e), 0 + def chat(self, system, history, gen_conf, image=""): + if system: + history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] + + try: + for his in history: + if his["role"] == "user": + his["images"] = [image] + options = {} + if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"] + if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"] + if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"] + if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"] + if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"] + response = self.client.chat( + model=self.model_name, + messages=history, + options=options, + keep_alive=-1 + ) + + ans = response["message"]["content"].strip() + return ans, response["eval_count"] + response.get("prompt_eval_count", 0) + except Exception as e: + return "**ERROR**: " + str(e), 0 + + def chat_streamly(self, system, history, gen_conf, image=""): + if system: + history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] + + for his in history: + if his["role"] == "user": + his["images"] = [image] + options = {} + if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"] + if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"] + if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"] + if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"] + if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"] + ans = "" + try: + response = self.client.chat( + model=self.model_name, + messages=history, + stream=True, + options=options, + keep_alive=-1 + ) + for resp in response: + if resp["done"]: + yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0) + ans += resp["message"]["content"] + yield ans + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + yield 0 + class LocalAICV(Base): def __init__(self, key, model_name, base_url, lang="Chinese"): @@ -236,7 +481,7 @@ def describe(self, image, max_tokens=300): class GeminiCV(Base): def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs): - from google.generativeai import client,GenerativeModel + from google.generativeai import client, GenerativeModel, GenerationConfig client.configure(api_key=key) _client = client.get_default_generative_client() self.model_name = model_name @@ -258,6 +503,59 @@ def describe(self, image, max_tokens=2048): ) return res.text,res.usage_metadata.total_token_count + def chat(self, system, history, gen_conf, image=""): + if system: + history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] + try: + for his in history: + if his["role"] == "assistant": + his["role"] = "model" + his["parts"] = [his["content"]] + his.pop("content") + if his["role"] == "user": + his["parts"] = [his["content"]] + his.pop("content") + history[-1]["parts"].append(f"data:image/jpeg;base64," + image) + + response = self.model.generate_content(history, generation_config=GenerationConfig( + max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3), + top_p=gen_conf.get("top_p", 0.7))) + + ans = response.text + return ans, response.usage_metadata.total_token_count + except Exception as e: + return "**ERROR**: " + str(e), 0 + + def chat_streamly(self, system, history, gen_conf, image=""): + if system: + history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] + + ans = "" + tk_count = 0 + try: + for his in history: + if his["role"] == "assistant": + his["role"] = "model" + his["parts"] = [his["content"]] + his.pop("content") + if his["role"] == "user": + his["parts"] = [his["content"]] + his.pop("content") + history[-1]["parts"].append(f"data:image/jpeg;base64," + image) + + response = self.model.generate_content(history, generation_config=GenerationConfig( + max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3), + top_p=gen_conf.get("top_p", 0.7)), stream=True) + + for resp in response: + if not resp.text: continue + ans += resp.text + yield ans + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield response._chunks[-1].usage_metadata.total_token_count + class OpenRouterCV(Base): def __init__( diff --git a/web/src/components/llm-setting-items/index.tsx b/web/src/components/llm-setting-items/index.tsx index a8365d616e8..c31d7aabeea 100644 --- a/web/src/components/llm-setting-items/index.tsx +++ b/web/src/components/llm-setting-items/index.tsx @@ -46,7 +46,7 @@ const LlmSettingItems = ({ prefix, formItemLayout = {} }: IProps) => { {...formItemLayout} rules={[{ required: true, message: t('modelMessage') }]} > -