From 79426fc41f0461b128698e6249506793db13585a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E8=85=BE?= <101850389+hangters@users.noreply.github.com> Date: Mon, 19 Aug 2024 10:36:57 +0800 Subject: [PATCH] add support for Replicate (#1980) ### What problem does this PR solve? #1853 add support for Replicate ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Zhedong Cen --- api/apps/llm_app.py | 4 +- conf/llm_factories.json | 7 +++ rag/llm/__init__.py | 6 +- rag/llm/chat_model.py | 55 ++++++++++++++++++- rag/llm/embedding_model.py | 22 +++++++- requirements.txt | 3 +- requirements_arm.txt | 3 +- web/src/assets/svg/llm/replicate.svg | 1 + web/src/pages/user-setting/constants.tsx | 2 +- .../user-setting/setting-model/constant.ts | 3 +- 10 files changed, 94 insertions(+), 12 deletions(-) create mode 100644 web/src/assets/svg/llm/replicate.svg diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 467ea878d3c..f3aebb1749f 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -149,7 +149,7 @@ def add_llm(): msg = "" if llm["model_type"] == LLMType.EMBEDDING.value: mdl = EmbeddingModel[factory]( - key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None, + key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None, model_name=llm["llm_name"], base_url=llm["api_base"]) try: @@ -160,7 +160,7 @@ def add_llm(): msg += f"\nFail to access embedding model({llm['llm_name']})." 
+ str(e) elif llm["model_type"] == LLMType.CHAT.value: mdl = ChatModel[factory]( - key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None, + key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None, model_name=llm["llm_name"], base_url=llm["api_base"] ) diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 0947dbe3c87..be1065c8e1c 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -3113,6 +3113,13 @@ "model_type": "image2text" } ] + }, + { + "name": "Replicate", + "logo": "", + "tags": "LLM,TEXT EMBEDDING", + "status": "1", + "llm": [] } ] } diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 4f059291826..142fc60de32 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -42,7 +42,8 @@ "TogetherAI": TogetherAIEmbed, "PerfXCloud": PerfXCloudEmbed, "Upstage": UpstageEmbed, - "SILICONFLOW": SILICONFLOWEmbed + "SILICONFLOW": SILICONFLOWEmbed, + "Replicate": ReplicateEmbed } @@ -96,7 +97,8 @@ "Upstage":UpstageChat, "novita.ai": NovitaAIChat, "SILICONFLOW": SILICONFLOWChat, - "01.AI": YiChat + "01.AI": YiChat, + "Replicate": ReplicateChat } diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 1c52b32ee30..5e338338cf4 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1003,7 +1003,7 @@ def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"): base_url = "https://api.together.xyz/v1" super().__init__(key, model_name, base_url) - + class PerfXCloudChat(Base): def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"): if not base_url: @@ -1036,4 +1036,55 @@ class YiChat(Base): def __init__(self, key, model_name, base_url="https://api.01.ai/v1"): if not base_url: base_url = "https://api.01.ai/v1" - super().__init__(key, model_name, base_url) \ No newline at end of file + super().__init__(key, model_name, base_url) + + +class ReplicateChat(Base): + def __init__(self, key, 
model_name, base_url=None): + from replicate.client import Client + + self.model_name = model_name + self.client = Client(api_token=key) + self.system = "" + + def chat(self, system, history, gen_conf): + if "max_tokens" in gen_conf: + gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens") + if system: + self.system = system + prompt = "\n".join( + [item["role"] + ":" + item["content"] for item in history[-5:]] + ) + ans = "" + try: + response = self.client.run( + self.model_name, + input={"system_prompt": self.system, "prompt": prompt, **gen_conf}, + ) + ans = "".join(response) + return ans, num_tokens_from_string(ans) + except Exception as e: + return ans + "\n**ERROR**: " + str(e), 0 + + def chat_streamly(self, system, history, gen_conf): + if "max_tokens" in gen_conf: + gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens") + if system: + self.system = system + prompt = "\n".join( + [item["role"] + ":" + item["content"] for item in history[-5:]] + ) + ans = "" + try: + response = self.client.run( + self.model_name, + input={"system_prompt": self.system, "prompt": prompt, **gen_conf}, + ) + for resp in response: + ans += resp + yield ans + + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield num_tokens_from_string(ans) diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index cda6b2429c8..2045e41690f 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -561,7 +561,7 @@ def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"): base_url = "https://api.together.xyz/v1" super().__init__(key, model_name, base_url) - + class PerfXCloudEmbed(OpenAIEmbed): def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"): if not base_url: @@ -580,4 +580,22 @@ class SILICONFLOWEmbed(OpenAIEmbed): def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"): if not base_url: base_url = "https://api.siliconflow.cn/v1" - super().__init__(key, model_name, base_url) \ 
No newline at end of file + super().__init__(key, model_name, base_url) + + +class ReplicateEmbed(Base): + def __init__(self, key, model_name, base_url=None): + from replicate.client import Client + + self.model_name = model_name + self.client = Client(api_token=key) + + def encode(self, texts: list, batch_size=32): + from json import dumps + + res = self.client.run(self.model_name, input={"texts": dumps(texts)}) + return np.array(res), sum([num_tokens_from_string(text) for text in texts]) + + def encode_queries(self, text): + from json import dumps + + res = self.client.run(self.model_name, input={"texts": dumps([text])}) + return np.array(res), num_tokens_from_string(text) diff --git a/requirements.txt b/requirements.txt index c8921d496db..f7de9a36734 100644 --- a/requirements.txt +++ b/requirements.txt @@ -65,6 +65,7 @@ python_pptx==0.6.23 readability_lxml==0.8.1 redis==5.0.3 Requests==2.32.2 +replicate==0.31.0 roman_numbers==1.0.2 ruamel.base==1.0.0 scholarly==1.7.11 @@ -87,4 +88,4 @@ wikipedia==1.4.0 word2number==1.1 xgboost==2.1.0 xpinyin==0.7.6 -zhipuai==2.0.1 +zhipuai==2.0.1 \ No newline at end of file diff --git a/requirements_arm.txt b/requirements_arm.txt index 5e4dfc7e86e..9b684a8a2c9 100644 --- a/requirements_arm.txt +++ b/requirements_arm.txt @@ -102,6 +102,7 @@ python-pptx==0.6.23 PyYAML==6.0.1 redis==5.0.3 regex==2023.12.25 +replicate==0.31.0 requests==2.32.2 ruamel.yaml==0.18.6 ruamel.yaml.clib==0.2.8 @@ -161,4 +162,4 @@ markdown_to_json==2.1.1 scholarly==1.7.11 deepl==1.18.0 psycopg2-binary==2.9.9 -tabulate-0.9.0 +tabulate-0.9.0 \ No newline at end of file diff --git a/web/src/assets/svg/llm/replicate.svg b/web/src/assets/svg/llm/replicate.svg new file mode 100644 index 00000000000..31241923ed3 --- /dev/null +++ b/web/src/assets/svg/llm/replicate.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/src/pages/user-setting/constants.tsx b/web/src/pages/user-setting/constants.tsx index 98b6e42bed6..01e2544b484 --- a/web/src/pages/user-setting/constants.tsx +++ 
b/web/src/pages/user-setting/constants.tsx @@ -17,4 +17,4 @@ export const UserSettingIconMap = { export * from '@/constants/setting'; -export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI']; +export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI','Replicate']; diff --git a/web/src/pages/user-setting/setting-model/constant.ts b/web/src/pages/user-setting/setting-model/constant.ts index 714159dcd8c..3b59364e6ed 100644 --- a/web/src/pages/user-setting/setting-model/constant.ts +++ b/web/src/pages/user-setting/setting-model/constant.ts @@ -30,7 +30,8 @@ export const IconMap = { Upstage: 'upstage', 'novita.ai': 'novita-ai', SILICONFLOW: 'siliconflow', - "01.AI": 'yi' + "01.AI": 'yi', + "Replicate": 'replicate' }; export const BedrockRegionList = [