From e34817c2a949ba8de100d5f3b21dea2ca8151be9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E8=85=BE?= <101850389+hangters@users.noreply.github.com> Date: Wed, 7 Aug 2024 18:40:51 +0800 Subject: [PATCH] add support for cohere (#1849) ### What problem does this PR solve? _Briefly describe what this PR aims to solve. Include background context that will help reviewers understand the purpose of the PR._ ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Zhedong Cen --- conf/llm_factories.json | 110 ++++++++++++++++++ rag/llm/__init__.py | 9 +- rag/llm/chat_model.py | 81 +++++++++++++ rag/llm/embedding_model.py | 32 ++++- rag/llm/rerank_model.py | 27 ++++- requirements.txt | 1 + requirements_arm.txt | 1 + requirements_dev.txt | 1 + web/src/assets/svg/llm/cohere.svg | 1 + .../user-setting/setting-model/constant.ts | 3 +- 10 files changed, 260 insertions(+), 6 deletions(-) create mode 100644 web/src/assets/svg/llm/cohere.svg diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 3eb23c17e95..57b26ffa769 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -2216,6 +2216,116 @@ "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT", "status": "1", "llm": [] + }, + { + "name": "cohere", + "logo": "", + "tags": "LLM,TEXT EMBEDDING, TEXT RE-RANK", + "status": "1", + "llm": [ + { + "llm_name": "command-r-plus", + "tags": "LLM,CHAT,128k", + "max_tokens": 131072, + "model_type": "chat" + }, + { + "llm_name": "command-r", + "tags": "LLM,CHAT,128k", + "max_tokens": 131072, + "model_type": "chat" + }, + { + "llm_name": "command", + "tags": "LLM,CHAT,4k", + "max_tokens": 4096, + "model_type": "chat" + }, + { + "llm_name": "command-nightly", + "tags": "LLM,CHAT,128k", + "max_tokens": 131072, + "model_type": "chat" + }, + { + "llm_name": "command-light", + "tags": "LLM,CHAT,4k", + "max_tokens": 4096, + "model_type": "chat" + }, + { + "llm_name": "command-light-nightly", + "tags": "LLM,CHAT,4k", + "max_tokens": 4096, + "model_type": "chat" + }, + { + "llm_name": "embed-english-v3.0", + "tags": "TEXT EMBEDDING", + "max_tokens": 512, + "model_type": "embedding" + }, + { + "llm_name": "embed-english-light-v3.0", + "tags": "TEXT EMBEDDING", + "max_tokens": 512, + "model_type": "embedding" + }, + { + "llm_name": "embed-multilingual-v3.0", + "tags": "TEXT EMBEDDING", + "max_tokens": 512, + "model_type": "embedding" + }, + { + "llm_name": "embed-multilingual-light-v3.0", + "tags": "TEXT EMBEDDING", + "max_tokens": 512, + "model_type": "embedding" + }, + { + "llm_name": "embed-english-v2.0", + "tags": "TEXT EMBEDDING", + "max_tokens": 512, + "model_type": "embedding" + }, + { + "llm_name": "embed-english-light-v2.0", + "tags": "TEXT EMBEDDING", + "max_tokens": 512, + "model_type": "embedding" + }, + { + "llm_name": "embed-multilingual-v2.0", + "tags": "TEXT EMBEDDING", + "max_tokens": 256, + "model_type": "embedding" + }, + { + "llm_name": "rerank-english-v3.0", + "tags": "RE-RANK,4k", + "max_tokens": 4096, + "model_type": "rerank" + }, + { + "llm_name": "rerank-multilingual-v3.0", + "tags": "RE-RANK,4k", + "max_tokens": 4096, + "model_type": "rerank" + }, + { + "llm_name": "rerank-english-v2.0", + "tags": "RE-RANK,512", + "max_tokens": 8196, + "model_type": "rerank" + }, + { + "llm_name": "rerank-multilingual-v2.0", + "tags": "RE-RANK,512", + "max_tokens": 512, + "model_type": "rerank" + } + ] } ] } \ No newline at end of file diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 4c3182cae95..f652b63364e 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -37,7 +37,8 @@ "Gemini": GeminiEmbed, "NVIDIA": NvidiaEmbed, "LM-Studio": LmStudioEmbed, - "OpenAI-API-Compatible": OpenAI_APIEmbed + "OpenAI-API-Compatible": OpenAI_APIEmbed, + "cohere": CoHereEmbed } @@ -81,7 +82,8 @@ "StepFun": StepFunChat, "NVIDIA": NvidiaChat, "LM-Studio": LmStudioChat, - "OpenAI-API-Compatible": OpenAI_APIChat + "OpenAI-API-Compatible": OpenAI_APIChat, + "cohere": CoHereChat } @@ -92,7 +94,8 @@ "Xinference": XInferenceRerank, "NVIDIA": NvidiaRerank, "LM-Studio": LmStudioRerank, - "OpenAI-API-Compatible": OpenAI_APIRerank + "OpenAI-API-Compatible": OpenAI_APIRerank, + "cohere": CoHereRerank } diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 2b5ab68dde7..ee3acf1d6fb 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -900,3 +900,84 @@ def __init__(self, key, model_name, base_url): base_url = os.path.join(base_url, "v1") model_name = model_name.split("___")[0] super().__init__(key, model_name, base_url) + + +class CoHereChat(Base): + def __init__(self, key, model_name, base_url=""): + from cohere import Client + + self.client = Client(api_key=key) + self.model_name = model_name + + def chat(self, system, history, gen_conf): + if system: + history.insert(0, {"role": "system", "content": system}) + if "top_p" in gen_conf: + gen_conf["p"] = gen_conf.pop("top_p") + if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf: + gen_conf.pop("presence_penalty") + for item in history: + if "role" in item and item["role"] == "user": + item["role"] = "USER" + if "role" in item and item["role"] == "assistant": + item["role"] = "CHATBOT" + if "content" in item: + item["message"] = item.pop("content") + mes = history.pop()["message"] + ans = "" + try: + response = self.client.chat( + model=self.model_name, chat_history=history, message=mes, **gen_conf + ) + ans = response.text + if response.finish_reason == "MAX_TOKENS": + ans += ( + "...\nFor the content length reason, it stopped, continue?" + if is_english([ans]) + else "······\n由于长度的原因,回答被截断了,要继续吗?" + ) + return ( + ans, + response.meta.tokens.input_tokens + response.meta.tokens.output_tokens, + ) + except Exception as e: + return ans + "\n**ERROR**: " + str(e), 0 + + def chat_streamly(self, system, history, gen_conf): + if system: + history.insert(0, {"role": "system", "content": system}) + if "top_p" in gen_conf: + gen_conf["p"] = gen_conf.pop("top_p") + if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf: + gen_conf.pop("presence_penalty") + for item in history: + if "role" in item and item["role"] == "user": + item["role"] = "USER" + if "role" in item and item["role"] == "assistant": + item["role"] = "CHATBOT" + if "content" in item: + item["message"] = item.pop("content") + mes = history.pop()["message"] + ans = "" + total_tokens = 0 + try: + response = self.client.chat_stream( + model=self.model_name, chat_history=history, message=mes, **gen_conf + ) + for resp in response: + if resp.event_type == "text-generation": + ans += resp.text + total_tokens += num_tokens_from_string(resp.text) + elif resp.event_type == "stream-end": + if resp.finish_reason == "MAX_TOKENS": + ans += ( + "...\nFor the content length reason, it stopped, continue?" + if is_english([ans]) + else "······\n由于长度的原因,回答被截断了,要继续吗?" + ) + yield ans + + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield total_tokens diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 3b0ef5f715e..f69ae848425 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -522,4 +522,34 @@ def __init__(self, key, model_name, base_url): if base_url.split("/")[-1] != "v1": base_url = os.path.join(base_url, "v1") self.client = OpenAI(api_key=key, base_url=base_url) - self.model_name = model_name.split("___")[0] \ No newline at end of file + self.model_name = model_name.split("___")[0] + + +class CoHereEmbed(Base): + def __init__(self, key, model_name, base_url=None): + from cohere import Client + + self.client = Client(api_key=key) + self.model_name = model_name + + def encode(self, texts: list, batch_size=32): + res = self.client.embed( + texts=texts, + model=self.model_name, + input_type="search_query", + embedding_types=["float"], + ) + return np.array([d for d in res.embeddings.float]), int( + res.meta.billed_units.input_tokens + ) + + def encode_queries(self, text): + res = self.client.embed( + texts=[text], + model=self.model_name, + input_type="search_query", + embedding_types=["float"], + ) + return np.array([d for d in res.embeddings.float]), int( + res.meta.billed_units.input_tokens + ) diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index f5e89437f13..2f142ef0ecc 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -203,7 +203,9 @@ def similarity(self, query: str, texts: list): "top_n": len(texts), } res = requests.post(self.base_url, headers=self.headers, json=data).json() - return (np.array([d["logit"] for d in res["rankings"]]), token_count) + rank = np.array([d["logit"] for d in res["rankings"]]) + indexs = [d["index"] for d in res["rankings"]] + return rank[indexs], token_count class LmStudioRerank(Base): @@ -220,3 +222,26 @@ def __init__(self, key, model_name, base_url): def similarity(self, query: str, texts: list): raise NotImplementedError("The api has not been implement") + + +class CoHereRerank(Base): + def __init__(self, key, model_name, base_url=None): + from cohere import Client + + self.client = Client(api_key=key) + self.model_name = model_name + + def similarity(self, query: str, texts: list): + token_count = num_tokens_from_string(query) + sum( + [num_tokens_from_string(t) for t in texts] + ) + res = self.client.rerank( + model=self.model_name, + query=query, + documents=texts, + top_n=len(texts), + return_documents=False, + ) + rank = np.array([d.relevance_score for d in res.results]) + indexs = [d.index for d in res.results] + return rank[indexs], token_count diff --git a/requirements.txt b/requirements.txt index a4065d175b3..d49e1c7e8bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ botocore==1.34.140 cachetools==5.3.3 chardet==5.2.0 cn2an==0.5.22 +cohere==5.6.2 dashscope==1.14.1 datrie==0.8.2 demjson3==3.0.6 diff --git a/requirements_arm.txt b/requirements_arm.txt index 1c94316626a..34b3db21f49 100644 --- a/requirements_arm.txt +++ b/requirements_arm.txt @@ -14,6 +14,7 @@ certifi==2024.7.4 cffi==1.16.0 charset-normalizer==3.3.2 click==8.1.7 +cohere==5.6.2 coloredlogs==15.0.1 cryptography==42.0.5 dashscope==1.14.1 diff --git a/requirements_dev.txt b/requirements_dev.txt index 92a015c9831..41143d67087 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -14,6 +14,7 @@ certifi==2024.7.4 cffi==1.16.0 charset-normalizer==3.3.2 click==8.1.7 +cohere==5.6.2 coloredlogs==15.0.1 cryptography==42.0.5 dashscope==1.14.1 diff --git a/web/src/assets/svg/llm/cohere.svg b/web/src/assets/svg/llm/cohere.svg new file mode 100644 index 00000000000..cb1b2a5919e --- /dev/null +++ b/web/src/assets/svg/llm/cohere.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/src/pages/user-setting/setting-model/constant.ts b/web/src/pages/user-setting/setting-model/constant.ts index 865eb29bf34..0ac73d6de77 100644 --- a/web/src/pages/user-setting/setting-model/constant.ts +++ b/web/src/pages/user-setting/setting-model/constant.ts @@ -22,7 +22,8 @@ export const IconMap = { StepFun: 'stepfun', NVIDIA:'nvidia', 'LM-Studio':'lm-studio', - 'OpenAI-API-Compatible':'openai-api' + 'OpenAI-API-Compatible':'openai-api', + 'cohere':'cohere' }; export const BedrockRegionList = [