From 69a1b56e8950f9fccc98c79626f1ca16848f8079 Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Mon, 15 Jul 2024 17:38:41 +0800
Subject: [PATCH 01/14] Update wikipedia.py

---
 graph/component/wikipedia.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/graph/component/wikipedia.py b/graph/component/wikipedia.py
index 9e67875961e..48abefaf6e7 100644
--- a/graph/component/wikipedia.py
+++ b/graph/component/wikipedia.py
@@ -30,9 +30,16 @@ class WikipediaParam(ComponentParamBase):
     def __init__(self):
         super().__init__()
         self.top_n = 10
+        self.lang = 'en'
 
     def check(self):
         self.check_positive_integer(self.top_n, "Top N")
+        self.check_valid_value(self.lang, "Wikipedia languages",
+                               ['af', 'pl', 'ar', 'ast', 'az', 'bg', 'nan', 'bn', 'be', 'ca', 'cs', 'cy', 'da', 'de',
+                                'et', 'el', 'en', 'es', 'eo', 'eu', 'fa', 'fr', 'gl', 'ko', 'hy', 'hi', 'hr', 'id',
+                                'it', 'he', 'ka', 'lld', 'la', 'lv', 'lt', 'hu', 'mk', 'arz', 'ms', 'min', 'my', 'nl',
+                                'ja', 'nb', 'nn', 'ce', 'uz', 'pt', 'kk', 'ro', 'ru', 'ceb', 'sk', 'sl', 'sr', 'sh',
+                                'fi', 'sv', 'ta', 'tt', 'th', 'tg', 'azb', 'tr', 'uk', 'ur', 'vi', 'war', 'zh', 'yue'])
 
 
 class Wikipedia(ComponentBase, ABC):
@@ -45,9 +52,11 @@ def _run(self, history, **kwargs):
             return Wikipedia.be_output(self._param.no)
 
         wiki_res = []
-        for wiki_key in wikipedia.search(ans, results=self._param.top_n):
+        wikipedia.set_lang(self._param.lang)
+        wiki_engine = wikipedia
+        for wiki_key in wiki_engine.search(ans, results=self._param.top_n):
             try:
-                page = wikipedia.page(title=wiki_key, auto_suggest=False)
+                page = wiki_engine.page(title=wiki_key, auto_suggest=False)
                 wiki_res.append({"content": '<a href="' + page.url + '">' + page.title + '</a> ' + page.summary})
             except Exception as e:
                 print(e)

From 268c0c7eb7fb7df47600341af51ca37a3241411e Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Wed, 17 Jul 2024 15:37:28 +0800
Subject: [PATCH 02/14] Update requirements.txt

---
 requirements.txt | 176 ++++++++++++++---------------------------------
 1 file changed, 51 insertions(+), 125 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 43eddf5b18f..9d02e96486e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,153 +1,79 @@
-accelerate==0.27.2
-aiohttp==3.9.5
-aiosignal==1.3.1
-annotated-types==0.6.0
-anyio==4.3.0
-argon2-cffi==23.1.0
-argon2-cffi-bindings==21.2.0
 Aspose.Slides==24.2.0
-attrs==23.2.0
-blinker==1.7.0
-cachelib==0.12.0
+BCEmbedding==0.1.3
+boto3==1.34.140
+botocore==1.34.140
 cachetools==5.3.3
-certifi==2024.2.2
-cffi==1.16.0
-charset-normalizer==3.3.2
-click==8.1.7
-coloredlogs==15.0.1
-cryptography==42.0.5
+chardet==5.2.0
+cn2an==0.5.22
 dashscope==1.14.1
-datasets==2.17.1
 datrie==0.8.2
 demjson3==3.0.6
-dill==0.3.8
-distro==1.9.0
-elastic-transport==8.12.0
+discord.py==2.3.2
+duckduckgo_search==6.1.9
+elastic_transport==8.12.0
 elasticsearch==8.12.1
-elasticsearch-dsl==8.12.0
-et-xmlfile==1.1.0
-filelock==3.13.1
+elasticsearch_dsl==8.12.0
 fastembed==0.2.6
-FlagEmbedding==1.2.5
-Flask==3.0.2
-Flask-Cors==4.0.0
-Flask-Login==0.6.3
-Flask-Session==0.6.0
-flatbuffers==23.5.26
-frozenlist==1.4.1
-fsspec==2023.10.0
-h11==0.14.0
+fasttext==0.9.3
+filelock==3.15.4
+FlagEmbedding==1.2.10
+Flask==3.0.3
+Flask_Cors==4.0.0
+Flask_Login==0.6.3
+flask_session==0.8.0
+groq==0.9.0
 hanziconv==0.3.2
-httpcore==1.0.4
+html_text==0.6.2
 httpx==0.27.0
-huggingface-hub==0.20.3
-humanfriendly==10.0
-idna==3.6
-install==1.3.5
+huggingface_hub==0.20.3
+infinity_emb==0.0.51
 itsdangerous==2.1.2
-Jinja2==3.1.3
-joblib==1.3.2
-lxml==5.1.0
-MarkupSafe==2.1.5
+Markdown==3.6
 minio==7.2.4
-mpmath==1.3.0
-multidict==6.0.5
-multiprocess==0.70.16
-networkx==3.2.1
+mistralai==0.4.2
 nltk==3.8.1
 numpy==1.26.4
-nvidia-cublas-cu12==12.1.3.1
-nvidia-cuda-cupti-cu12==12.1.105
-nvidia-cuda-nvrtc-cu12==12.1.105
-nvidia-cuda-runtime-cu12==12.1.105
-nvidia-cudnn-cu12==8.9.2.26
-nvidia-cufft-cu12==11.0.2.54
-nvidia-curand-cu12==10.3.2.106
-nvidia-cusolver-cu12==11.4.5.107
-nvidia-cusparse-cu12==12.1.0.106
-nvidia-nccl-cu12==2.19.3
-nvidia-nvjitlink-cu12==12.3.101
-nvidia-nvtx-cu12==12.1.105
-ollama==0.1.9
-onnxruntime-gpu==1.17.1
-openai==1.12.0
-opencv-python==4.9.0.80
+ollama==0.2.1
+onnxruntime==1.17.3
+onnxruntime_gpu==1.17.1
+openai==1.35.14
+opencv_python==4.9.0.80
+opencv_python_headless==4.9.0.80
 openpyxl==3.1.2
-packaging==23.2
-pandas==2.2.1
-pdfminer.six==20221105
+pandas==2.2.2
 pdfplumber==0.10.4
 peewee==3.17.1
-pillow==10.3.0
-protobuf==4.25.3
-psutil==5.9.8
-pyarrow==15.0.0
-pyarrow-hotfix==0.6
+Pillow==10.4.0
+pipreqs==0.5.0
+protobuf==5.27.2
 pyclipper==1.3.0.post5
-pycparser==2.21
-pycryptodome
-pycryptodome-test-vectors
-pycryptodomex
-pydantic==2.6.2
-pydantic_core==2.16.3
-PyJWT==2.8.0
-PyMySQL==1.1.1
+pycryptodomex==3.20.0
 PyPDF2==3.0.1
-pypdfium2==4.27.0
-python-dateutil==2.8.2
-python-docx==1.1.0
+pytest==8.2.2
 python-dotenv==1.0.1
-python-pptx==0.6.23
-PyYAML==6.0.1
+python_dateutil==2.8.2
+python_pptx==0.6.23
+readability_lxml==0.8.1
 redis==5.0.3
-regex==2023.12.25
-requests==2.31.0
-ruamel.yaml==0.18.6
-ruamel.yaml.clib==0.2.8
-safetensors==0.4.2
-scikit-learn==1.4.1.post1
-scipy==1.12.0
-sentence-transformers==2.4.0
-shapely==2.0.3
+Requests==2.32.3
+roman_numbers==1.0.2
+ruamel.base==1.0.0
+scikit_learn==1.4.1.post1
+selenium==4.22.0
+setuptools==69.5.1
+Shapely==2.0.5
 six==1.16.0
-sniffio==1.3.1
 StrEnum==0.4.15
-sympy==1.12
-threadpoolctl==3.3.0
 tika==2.6.0
 tiktoken==0.6.0
-tokenizers==0.15.2
-torch==2.2.1
-tqdm==4.66.2
+torch==2.3.0
 transformers==4.38.1
-triton==2.2.0
-typing_extensions==4.10.0
-tzdata==2024.1
-urllib3==2.2.1
+umap==0.1.1
+volcengine==1.0.146
+webdriver_manager==4.0.1
 Werkzeug==3.0.3
-xgboost==2.0.3
-XlsxWriter==3.2.0
+wikipedia==1.4.0
+word2number==1.1
+xgboost==2.1.0
 xpinyin==0.7.6
-xxhash==3.4.1
-yarl==1.9.4
 zhipuai==2.0.1
-BCEmbedding
-loguru==0.7.2
-umap-learn
-fasttext==0.9.2
-pybind11==2.13.1
-volcengine==1.0.141
-readability-lxml==0.8.1
-html_text==0.6.2
-selenium==4.21.0
-webdriver-manager==4.0.1
-cn2an==0.5.22
-roman-numbers==1.0.2
-word2number==1.1
-markdown==3.6
-mistralai==0.4.2
-boto3==1.34.140
-duckduckgo_search==6.1.9
-google-generativeai==0.7.2
-groq==0.9.0
-wikipedia==1.4.0

From c1e9da13178bc405d4e80fd475536b19495d1a88 Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Fri, 19 Jul 2024 14:22:15 +0800
Subject: [PATCH 03/14] Update dialog_service.py

---
 api/db/services/dialog_service.py | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index ab90ee1c047..1680e0e9320 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -13,6 +13,8 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 #
+import os
+import json
 import re
 from copy import deepcopy
 
@@ -26,6 +28,7 @@ from rag.nlp import keyword_extraction
 from rag.nlp.search import index_name
 from rag.utils import rmSpace, num_tokens_from_string, encoder
+from api.utils.file_utils import get_project_base_directory
 
 
 class DialogService(CommonService):
@@ -73,6 +76,15 @@ def count():
     return max_length, msg
 
 
+def llm_id2llm_type(llm_id):
+    fnm = os.path.join(get_project_base_directory(), "conf")
+    llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
+    for llm_factory in llm_factories["factory_llm_infos"]:
+        for llm in llm_factory["llm"]:
+            if llm_id == llm["llm_name"]:
+                return llm["model_type"]
+
+
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     llm = LLMService.query(llm_name=dialog.llm_id)
@@ -91,7 +103,10 @@ def chat(dialog, messages, stream=True, **kwargs):
         questions = [m["content"] for m in messages if m["role"] == "user"]
     embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
 
-    chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+    if llm_id2llm_type(dialog.llm_id) == "image2text":
+        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
+    else:
+        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
 
     prompt_config = dialog.prompt_config
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
@@ -328,7 +343,10 @@ def get_table():
 
 
 def relevant(tenant_id, llm_id, question, contents: list):
-    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
+    if llm_id2llm_type(llm_id) == "image2text":
+        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
+    else:
+        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
     prompt = """
         You are a grader assessing relevance of a retrieved document to a user question.
         It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
@@ -347,7 +365,10 @@ def relevant(tenant_id, llm_id, question, contents: list):
 
 
 def rewrite(tenant_id, llm_id, question):
-    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
+    if llm_id2llm_type(llm_id) == "image2text":
+        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
+    else:
+        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
     prompt = """
         You are an expert at query expansion to generate a paraphrasing of a question.
        I can't retrieve relevant information from the knowledge base by using the user's question directly.
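
The llm_id2llm_type helper added above is the pivot of this whole series: every later patch uses it to route vision-capable models to the IMAGE2TEXT bundle instead of CHAT. The following is a minimal usage sketch, not part of the patch series; it assumes conf/llm_factories.json keeps the {"factory_llm_infos": [{"llm": [{"llm_name": ..., "model_type": ...}]}]} layout read above, and that "model_type" may hold a comma-separated list of types.

import json
import os

def llm_id2llm_type(llm_id, conf_dir="conf"):
    # Look the model id up in the factory catalogue shipped with the app.
    with open(os.path.join(conf_dir, "llm_factories.json"), "r") as f:
        factories = json.load(f)["factory_llm_infos"]
    for factory in factories:
        for llm in factory["llm"]:
            if llm_id == llm["llm_name"]:
                # "chat,image2text" -> "image2text": the last entry wins.
                return llm["model_type"].split(",")[-1]
    return None  # unknown ids fall through to the CHAT branch in callers

# Callers then pick the bundle type from the result, e.g.:
# bundle_type = LLMType.IMAGE2TEXT if llm_id2llm_type(dialog.llm_id) == "image2text" else LLMType.CHAT
# chat_mdl = LLMBundle(dialog.tenant_id, bundle_type, dialog.llm_id)
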
From 9fd37a71f6a5e049eac167660b6b3410085aea98 Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Fri, 19 Jul 2024 14:33:37 +0800
Subject: [PATCH 04/14] Update cv_model.py

---
 rag/llm/cv_model.py | 75 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 75 insertions(+)

diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 09b7347f4f8..b2eb1dfcd75 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -26,6 +26,7 @@
 import json
 import requests
 
+from rag.nlp import is_english
 from api.utils import get_uuid
 from api.utils.file_utils import get_project_base_directory
 
@@ -36,7 +37,13 @@ def __init__(self, key, model_name):
 
     def describe(self, image, max_tokens=300):
         raise NotImplementedError("Please implement encode method!")
+
+    def chat(self, system, history, gen_conf):
+        raise NotImplementedError("Please implement encode method!")
 
+    def chat_streamly(self, system, history, gen_conf):
+        raise NotImplementedError("Please implement encode method!")
+
     def image2base64(self, image):
         if isinstance(image, bytes):
             return base64.b64encode(image).decode("utf-8")
@@ -68,6 +75,21 @@ def prompt(self, b64):
             }
         ]
 
+    def chat_prompt(self, text, b64):
+        return [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": f"data:image/jpeg;base64,{b64}"
+                },
+            },
+            {
+                "type": "text",
+                "text": text
+            },
+        ]
+
 
 class GptV4(Base):
     def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
@@ -166,6 +188,59 @@ def describe(self, image, max_tokens=1024):
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
 
+    def chat(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+        try:
+            for his in history:
+                if his["role"] == "user":
+                    his["content"] = self.chat_prompt(his["content"], image)
+
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                max_tokens=gen_conf.get("max_tokens", 1000),
+                temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7)
+            )
+            return response.choices[0].message.content.strip(), response.usage.total_tokens
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        ans = ""
+        tk_count = 0
+        try:
+            for his in history:
+                if his["role"] == "user":
+                    his["content"] = self.chat_prompt(his["content"], image)
+
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                max_tokens=gen_conf.get("max_tokens", 1000),
+                temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7),
+                stream=True
+            )
+            for resp in response:
+                if not resp.choices[0].delta.content: continue
+                delta = resp.choices[0].delta.content
+                ans += delta
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    tk_count = resp.usage.total_tokens
+                if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+
 
 class OllamaCV(Base):
     def __init__(self, key, model_name, lang="Chinese", **kwargs):

From a0d14ea05ae890a8cbff3fa7e7502d999781ea80 Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Fri, 19 Jul 2024 14:59:03 +0800
Subject: [PATCH 05/14] Update dialog_service.py

---
 api/db/services/dialog_service.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 1680e0e9320..20cb27cfdc4 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -82,7 +82,7 @@ def llm_id2llm_type(llm_id):
     for llm_factory in llm_factories["factory_llm_infos"]:
         for llm in llm_factory["llm"]:
             if llm_id == llm["llm_name"]:
-                return llm["model_type"]
+                return llm["model_type"].split(",")[-1]

From be822ee3e7c6e4b13b27f06351735ddfb4756f61 Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Fri, 19 Jul 2024 15:27:20 +0800
Subject: [PATCH 06/14] Update cv_model.py

---
 rag/llm/cv_model.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index b2eb1dfcd75..6b86680a925 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -188,6 +188,20 @@ def describe(self, image, max_tokens=1024):
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
 
+    def chat_prompt(self, text, b64):
+        return [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": f"{b64}"
+                },
+            },
+            {
+                "type": "text",
+                "text": text
+            },
+        ]
+
     def chat(self, system, history, gen_conf, image=""):
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

From edb9ceb8c608c82fc2f43c7f73be5f1210e872c7 Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Fri, 19 Jul 2024 15:29:32 +0800
Subject: [PATCH 07/14] Update cv_model.py

---
 rag/llm/cv_model.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 6b86680a925..f131ef95386 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -79,9 +79,7 @@ def chat_prompt(self, text, b64):
         return [
             {
                 "type": "image_url",
-                "image_url": {
-                    "url": f"data:image/jpeg;base64,{b64}"
-                },
+                "image_url": f"{b64}"
             },
             {
                 "type": "text",

From c48d8c080027b2115e1d66c3589b038554c57b2a Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Fri, 19 Jul 2024 15:38:52 +0800
Subject: [PATCH 08/14] Update cv_model.py

---
 rag/llm/cv_model.py | 55 +++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 51 insertions(+), 4 deletions(-)

diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index f131ef95386..1d3306b5cca 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -38,11 +38,58 @@ def __init__(self, key, model_name):
     def describe(self, image, max_tokens=300):
         raise NotImplementedError("Please implement encode method!")
 
-    def chat(self, system, history, gen_conf):
-        raise NotImplementedError("Please implement encode method!")
+    def chat(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+        try:
+            for his in history:
+                if his["role"] == "user":
+                    his["content"] = self.chat_prompt(his["content"], image)
 
-    def chat_streamly(self, system, history, gen_conf):
-        raise NotImplementedError("Please implement encode method!")
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                max_tokens=gen_conf.get("max_tokens", 1000),
+                temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7)
+            )
+            return response.choices[0].message.content.strip(), response.usage.total_tokens
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        ans = ""
+        tk_count = 0
+        try:
+            for his in history:
+                if his["role"] == "user":
+                    his["content"] = self.chat_prompt(his["content"], image)
+
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                max_tokens=gen_conf.get("max_tokens", 1000),
+                temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7),
+                stream=True
+            )
+            for resp in response:
+                if not resp.choices[0].delta.content: continue
+                delta = resp.choices[0].delta.content
+                ans += delta
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                tk_count = resp.usage.total_tokens
+                if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
 
     def image2base64(self, image):
         if isinstance(image, bytes):

From ec585f002ec725295aa90c08b1269aec711b3d08 Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Fri, 19 Jul 2024 16:08:10 +0800
Subject: [PATCH 09/14] Update cv_model.py

---
 rag/llm/cv_model.py | 66 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 1d3306b5cca..7fa9833beb8 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -207,6 +207,12 @@ def prompt(self, binary):
             }
         ]
 
+    def chat_prompt(self, text, b64):
+        return [
+            {"image": f"{b64}"},
+            {"text": text},
+        ]
+
     def describe(self, image, max_tokens=300):
         from http import HTTPStatus
         from dashscope import MultiModalConversation
@@ -216,6 +222,66 @@ def describe(self, image, max_tokens=300):
             return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
         return response.message, 0
 
+    def chat(self, system, history, gen_conf, image=""):
+        from http import HTTPStatus
+        from dashscope import MultiModalConversation
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        for his in history:
+            if his["role"] == "user":
+                his["content"] = self.chat_prompt(his["content"], image)
+        response = MultiModalConversation.call(model=self.model_name, messages=history,
+                                               max_tokens=gen_conf.get("max_tokens", 1000),
+                                               temperature=gen_conf.get("temperature", 0.3),
+                                               top_p=gen_conf.get("top_p", 0.7))
+
+        ans = ""
+        tk_count = 0
+        if response.status_code == HTTPStatus.OK:
+            ans += response.output.choices[0]['message']['content']
+            tk_count += response.usage.total_tokens
+            if response.output.choices[0].get("finish_reason", "") == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, tk_count
+
+        return "**ERROR**: " + response.message, tk_count
+
+    def chat_streamly(self, system, history, gen_conf, image=""):
+        from http import HTTPStatus
+        from dashscope import MultiModalConversation
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        for his in history:
+            if his["role"] == "user":
+                his["content"] = self.chat_prompt(his["content"], image)
+
+        ans = ""
+        tk_count = 0
+        try:
+            response = MultiModalConversation.call(model=self.model_name, messages=history,
+                                                   max_tokens=gen_conf.get("max_tokens", 1000),
+                                                   temperature=gen_conf.get("temperature", 0.3),
+                                                   top_p=gen_conf.get("top_p", 0.7),
+                                                   stream=True)
+            for resp in response:
+                if resp.status_code == HTTPStatus.OK:
+                    ans = resp.output.choices[0]['message']['content']
+                    tk_count = resp.usage.total_tokens
+                    if resp.output.choices[0].get("finish_reason", "") == "length":
+                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    yield ans
+                else:
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
+                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+
 
 class Zhipu4V(Base):
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):

From aede60b4b19103e0d0fe60f3f623c7226fbfce75 Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Fri, 19 Jul 2024 16:26:58 +0800
Subject: [PATCH 10/14] Update cv_model.py

---
 rag/llm/cv_model.py | 57 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 7fa9833beb8..d0c5bcb706d 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -388,6 +388,63 @@ def describe(self, image, max_tokens=1024):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
+    def chat(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        try:
+            for his in history:
+                if his["role"] == "user":
+                    his["images"] = [image]
+            options = {}
+            if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
+            if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
+            if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
+            if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
+            if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                options=options,
+                keep_alive=-1
+            )
+
+            ans = response["message"]["content"].strip()
+            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        for his in history:
+            if his["role"] == "user":
+                his["images"] = [image]
+        options = {}
+        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
+        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
+        if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
+        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+        ans = ""
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                options=options,
+                keep_alive=-1
+            )
+            for resp in response:
+                if resp["done"]:
+                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
+                ans += resp["message"]["content"]
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+        yield 0
+
 
 class XinferenceCV(Base):
     def __init__(self, key, model_name="", lang="Chinese", base_url=""):

From d1e8b8aeeb2242a98d2785826d81a714b662fe8f Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Fri, 19 Jul 2024 17:02:24 +0800
Subject: [PATCH 11/14] Update cv_model.py

---
 rag/llm/cv_model.py | 18 +++---------------
 1 file changed, 3 insertions(+), 15 deletions(-)

diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index d0c5bcb706d..5443dec6652 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -126,7 +126,9 @@ def chat_prompt(self, text, b64):
         return [
             {
                 "type": "image_url",
-                "image_url": f"{b64}"
+                "image_url": {
+                    "url": f"data:image/jpeg;base64,{b64}",
+                },
             },
             {
                 "type": "text",
@@ -299,20 +301,6 @@ def describe(self, image, max_tokens=1024):
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
 
-    def chat_prompt(self, text, b64):
-        return [
-            {
-                "type": "image_url",
-                "image_url": {
-                    "url": f"{b64}"
-                },
-            },
-            {
-                "type": "text",
-                "text": text
-            },
-        ]
-
     def chat(self, system, history, gen_conf, image=""):
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

From 420deb34ad66d6678205af483423d43375f656d8 Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Fri, 19 Jul 2024 18:04:32 +0800
Subject: [PATCH 12/14] Update cv_model.py

---
 rag/llm/cv_model.py | 55 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 54 insertions(+), 1 deletion(-)

diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 5443dec6652..9af84c83c2d 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -452,7 +452,7 @@ def describe(self, image, max_tokens=300):
 
 class GeminiCV(Base):
     def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
-        from google.generativeai import client,GenerativeModel
+        from google.generativeai import client, GenerativeModel, GenerationConfig
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = model_name
@@ -474,6 +474,59 @@ def describe(self, image, max_tokens=2048):
         )
         return res.text,res.usage_metadata.total_token_count
 
+    def chat(self, system, history, gen_conf, image=""):
+        from google.generativeai import GenerationConfig
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+        try:
+            for his in history:
+                if his["role"] == "assistant":
+                    his["role"] = "model"
+                    his["parts"] = [his["content"]]
+                    his.pop("content")
+                if his["role"] == "user":
+                    his["parts"] = [his["content"]]
+                    his.pop("content")
+            history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
+
+            response = self.model.generate_content(history, generation_config=GenerationConfig(
+                max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7)))
+
+            ans = response.text
+            return ans, response.usage_metadata.total_token_count
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf, image=""):
+        from google.generativeai import GenerationConfig
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        ans = ""
+        tk_count = 0
+        try:
+            for his in history:
+                if his["role"] == "assistant":
+                    his["role"] = "model"
+                    his["parts"] = [his["content"]]
+                    his.pop("content")
+                if his["role"] == "user":
+                    his["parts"] = [his["content"]]
+                    his.pop("content")
+            history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
+
+            response = self.model.generate_content(history, generation_config=GenerationConfig(
+                max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7)), stream=True)
+
+            for resp in response:
+                if not resp.text: continue
+                ans += resp.text
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield response._chunks[-1].usage_metadata.total_token_count
+
 
 class OpenRouterCV(Base):
     def __init__(

From 799f60fcbe17435dda833de5f3aef3a337fbc017 Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Fri, 19 Jul 2024 18:26:28 +0800
Subject: [PATCH 13/14] Update llm_service.py

---
 api/db/services/llm_service.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 5906927b7bc..a994d6103eb 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -70,7 +70,7 @@ def model_instance(cls, tenant_id, llm_type,
         elif llm_type == LLMType.SPEECH2TEXT.value:
             mdlnm = tenant.asr_id
         elif llm_type == LLMType.IMAGE2TEXT.value:
-            mdlnm = tenant.img2txt_id
+            mdlnm = tenant.img2txt_id if not llm_name else llm_name
         elif llm_type == LLMType.CHAT.value:
             mdlnm = tenant.llm_id if not llm_name else llm_name
         elif llm_type == LLMType.RERANK:

From 9e5696fc985d99c47d815f97b544b125601594c1 Mon Sep 17 00:00:00 2001
From: H <43509927+guoyuhao2330@users.noreply.github.com>
Date: Fri, 19 Jul 2024 18:27:01 +0800
Subject: [PATCH 14/14] Update index.tsx

---
 web/src/components/llm-setting-items/index.tsx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/web/src/components/llm-setting-items/index.tsx b/web/src/components/llm-setting-items/index.tsx
index a8365d616e8..c31d7aabeea 100644
--- a/web/src/components/llm-setting-items/index.tsx
+++ b/web/src/components/llm-setting-items/index.tsx
@@ -46,7 +46,7 @@ const LlmSettingItems = ({ prefix, formItemLayout = {} }: IProps) => {
         {...formItemLayout}
         rules={[{ required: true, message: t('modelMessage') }]}
       >
-
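
PATCH 06, PATCH 07 and PATCH 11 go back and forth on how the image is embedded in the user turn before settling on the OpenAI-style data-URI form restored in PATCH 11. The following is a minimal sketch, not part of the patch series, of the message shape the final chat_prompt produces; the content-part field names follow the OpenAI vision chat API, and b64 is assumed to be the base64 string returned by image2base64.

def chat_prompt(text, b64):
    # One user turn carrying both the image and the question.
    return [
        {"type": "image_url",
         "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
        {"type": "text", "text": text},
    ]

history = [{"role": "user", "content": chat_prompt("What is in this picture?", b64)}]
# response = client.chat.completions.create(model=model_name, messages=history, ...)

QWenCV and GeminiCV override this with their providers' own formats, {"image": ...}/{"text": ...} items for DashScope and "parts" lists for Gemini, as the diffs above show.
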