Skip to content

Commit

Permalink
add support for cohere (#1849)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

_Briefly describe what this PR aims to solve. Include background context
that will help reviewers understand the purpose of the PR._

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
  • Loading branch information
hangters and aopstudio authored Aug 7, 2024
1 parent 60428c4 commit e34817c
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 6 deletions.
110 changes: 110 additions & 0 deletions conf/llm_factories.json
Original file line number Diff line number Diff line change
Expand Up @@ -2216,6 +2216,116 @@
"tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
"status": "1",
"llm": []
},
{
"name": "cohere",
"logo": "",
"tags": "LLM,TEXT EMBEDDING, TEXT RE-RANK",
"status": "1",
"llm": [
{
"llm_name": "command-r-plus",
"tags": "LLM,CHAT,128k",
"max_tokens": 131072,
"model_type": "chat"
},
{
"llm_name": "command-r",
"tags": "LLM,CHAT,128k",
"max_tokens": 131072,
"model_type": "chat"
},
{
"llm_name": "command",
"tags": "LLM,CHAT,4k",
"max_tokens": 4096,
"model_type": "chat"
},
{
"llm_name": "command-nightly",
"tags": "LLM,CHAT,128k",
"max_tokens": 131072,
"model_type": "chat"
},
{
"llm_name": "command-light",
"tags": "LLM,CHAT,4k",
"max_tokens": 4096,
"model_type": "chat"
},
{
"llm_name": "command-light-nightly",
"tags": "LLM,CHAT,4k",
"max_tokens": 4096,
"model_type": "chat"
},
{
"llm_name": "embed-english-v3.0",
"tags": "TEXT EMBEDDING",
"max_tokens": 512,
"model_type": "embedding"
},
{
"llm_name": "embed-english-light-v3.0",
"tags": "TEXT EMBEDDING",
"max_tokens": 512,
"model_type": "embedding"
},
{
"llm_name": "embed-multilingual-v3.0",
"tags": "TEXT EMBEDDING",
"max_tokens": 512,
"model_type": "embedding"
},
{
"llm_name": "embed-multilingual-light-v3.0",
"tags": "TEXT EMBEDDING",
"max_tokens": 512,
"model_type": "embedding"
},
{
"llm_name": "embed-english-v2.0",
"tags": "TEXT EMBEDDING",
"max_tokens": 512,
"model_type": "embedding"
},
{
"llm_name": "embed-english-light-v2.0",
"tags": "TEXT EMBEDDING",
"max_tokens": 512,
"model_type": "embedding"
},
{
"llm_name": "embed-multilingual-v2.0",
"tags": "TEXT EMBEDDING",
"max_tokens": 256,
"model_type": "embedding"
},
{
"llm_name": "rerank-english-v3.0",
"tags": "RE-RANK,4k",
"max_tokens": 4096,
"model_type": "rerank"
},
{
"llm_name": "rerank-multilingual-v3.0",
"tags": "RE-RANK,4k",
"max_tokens": 4096,
"model_type": "rerank"
},
{
"llm_name": "rerank-english-v2.0",
"tags": "RE-RANK,512",
"max_tokens": 8196,
"model_type": "rerank"
},
{
"llm_name": "rerank-multilingual-v2.0",
"tags": "RE-RANK,512",
"max_tokens": 512,
"model_type": "rerank"
}
]
}
]
}
9 changes: 6 additions & 3 deletions rag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
"Gemini": GeminiEmbed,
"NVIDIA": NvidiaEmbed,
"LM-Studio": LmStudioEmbed,
"OpenAI-API-Compatible": OpenAI_APIEmbed
"OpenAI-API-Compatible": OpenAI_APIEmbed,
"cohere": CoHereEmbed
}


Expand Down Expand Up @@ -81,7 +82,8 @@
"StepFun": StepFunChat,
"NVIDIA": NvidiaChat,
"LM-Studio": LmStudioChat,
"OpenAI-API-Compatible": OpenAI_APIChat
"OpenAI-API-Compatible": OpenAI_APIChat,
"cohere": CoHereChat
}


Expand All @@ -92,7 +94,8 @@
"Xinference": XInferenceRerank,
"NVIDIA": NvidiaRerank,
"LM-Studio": LmStudioRerank,
"OpenAI-API-Compatible": OpenAI_APIRerank
"OpenAI-API-Compatible": OpenAI_APIRerank,
"cohere": CoHereRerank
}


Expand Down
81 changes: 81 additions & 0 deletions rag/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,3 +900,84 @@ def __init__(self, key, model_name, base_url):
base_url = os.path.join(base_url, "v1")
model_name = model_name.split("___")[0]
super().__init__(key, model_name, base_url)


class CoHereChat(Base):
def __init__(self, key, model_name, base_url=""):
from cohere import Client

self.client = Client(api_key=key)
self.model_name = model_name

def chat(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
if "top_p" in gen_conf:
gen_conf["p"] = gen_conf.pop("top_p")
if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
gen_conf.pop("presence_penalty")
for item in history:
if "role" in item and item["role"] == "user":
item["role"] = "USER"
if "role" in item and item["role"] == "assistant":
item["role"] = "CHATBOT"
if "content" in item:
item["message"] = item.pop("content")
mes = history.pop()["message"]
ans = ""
try:
response = self.client.chat(
model=self.model_name, chat_history=history, message=mes, **gen_conf
)
ans = response.text
if response.finish_reason == "MAX_TOKENS":
ans += (
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
return (
ans,
response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
)
except Exception as e:
return ans + "\n**ERROR**: " + str(e), 0

def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
if "top_p" in gen_conf:
gen_conf["p"] = gen_conf.pop("top_p")
if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
gen_conf.pop("presence_penalty")
for item in history:
if "role" in item and item["role"] == "user":
item["role"] = "USER"
if "role" in item and item["role"] == "assistant":
item["role"] = "CHATBOT"
if "content" in item:
item["message"] = item.pop("content")
mes = history.pop()["message"]
ans = ""
total_tokens = 0
try:
response = self.client.chat_stream(
model=self.model_name, chat_history=history, message=mes, **gen_conf
)
for resp in response:
if resp.event_type == "text-generation":
ans += resp.text
total_tokens += num_tokens_from_string(resp.text)
elif resp.event_type == "stream-end":
if resp.finish_reason == "MAX_TOKENS":
ans += (
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
yield ans

except Exception as e:
yield ans + "\n**ERROR**: " + str(e)

yield total_tokens
32 changes: 31 additions & 1 deletion rag/llm/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,4 +522,34 @@ def __init__(self, key, model_name, base_url):
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name.split("___")[0]
self.model_name = model_name.split("___")[0]


class CoHereEmbed(Base):
def __init__(self, key, model_name, base_url=None):
from cohere import Client

self.client = Client(api_key=key)
self.model_name = model_name

def encode(self, texts: list, batch_size=32):
res = self.client.embed(
texts=texts,
model=self.model_name,
input_type="search_query",
embedding_types=["float"],
)
return np.array([d for d in res.embeddings.float]), int(
res.meta.billed_units.input_tokens
)

def encode_queries(self, text):
res = self.client.embed(
texts=[text],
model=self.model_name,
input_type="search_query",
embedding_types=["float"],
)
return np.array([d for d in res.embeddings.float]), int(
res.meta.billed_units.input_tokens
)
27 changes: 26 additions & 1 deletion rag/llm/rerank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def similarity(self, query: str, texts: list):
"top_n": len(texts),
}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
return (np.array([d["logit"] for d in res["rankings"]]), token_count)
rank = np.array([d["logit"] for d in res["rankings"]])
indexs = [d["index"] for d in res["rankings"]]
return rank[indexs], token_count


class LmStudioRerank(Base):
Expand All @@ -220,3 +222,26 @@ def __init__(self, key, model_name, base_url):

def similarity(self, query: str, texts: list):
raise NotImplementedError("The api has not been implement")


class CoHereRerank(Base):
def __init__(self, key, model_name, base_url=None):
from cohere import Client

self.client = Client(api_key=key)
self.model_name = model_name

def similarity(self, query: str, texts: list):
token_count = num_tokens_from_string(query) + sum(
[num_tokens_from_string(t) for t in texts]
)
res = self.client.rerank(
model=self.model_name,
query=query,
documents=texts,
top_n=len(texts),
return_documents=False,
)
rank = np.array([d.relevance_score for d in res.results])
indexs = [d.index for d in res.results]
return rank[indexs], token_count
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ botocore==1.34.140
cachetools==5.3.3
chardet==5.2.0
cn2an==0.5.22
cohere==5.6.2
dashscope==1.14.1
datrie==0.8.2
demjson3==3.0.6
Expand Down
1 change: 1 addition & 0 deletions requirements_arm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ certifi==2024.7.4
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
cohere==5.6.2
coloredlogs==15.0.1
cryptography==42.0.5
dashscope==1.14.1
Expand Down
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ certifi==2024.7.4
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
cohere==5.6.2
coloredlogs==15.0.1
cryptography==42.0.5
dashscope==1.14.1
Expand Down
1 change: 1 addition & 0 deletions web/src/assets/svg/llm/cohere.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion web/src/pages/user-setting/setting-model/constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ export const IconMap = {
StepFun: 'stepfun',
NVIDIA:'nvidia',
'LM-Studio':'lm-studio',
'OpenAI-API-Compatible':'openai-api'
'OpenAI-API-Compatible':'openai-api',
'cohere':'cohere'
};

export const BedrockRegionList = [
Expand Down

0 comments on commit e34817c

Please sign in to comment.