Skip to content

Commit

Permalink
add support for deepseek (infiniflow#668)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

infiniflow#666 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
KevinHuSh authored May 8, 2024
1 parent dd2cfda commit e0b9b11
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 53 deletions.
22 changes: 21 additions & 1 deletion api/db/init_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ def init_superuser():
"name": "Youdao",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
"status": "1",
},{
"name": "DeepSeek",
"logo": "",
"tags": "LLM",
"status": "1",
},
# {
# "name": "文心一言",
Expand Down Expand Up @@ -331,6 +336,21 @@ def init_llm_factory():
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
},
# ------------------------ DeepSeek -----------------------
{
"fid": factory_infos[8]["name"],
"llm_name": "deepseek-chat",
"tags": "LLM,CHAT,",
"max_tokens": 32768,
"model_type": LLMType.CHAT.value
},
{
"fid": factory_infos[8]["name"],
"llm_name": "deepseek-coder",
"tags": "LLM,CHAT,",
"max_tokens": 16385,
"model_type": LLMType.CHAT.value
},
]
for info in factory_infos:
try:
Expand Down
3 changes: 2 additions & 1 deletion rag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"Tongyi-Qianwen": QWenChat,
"Ollama": OllamaChat,
"Xinference": XinferenceChat,
"Moonshot": MoonshotChat
"Moonshot": MoonshotChat,
"DeepSeek": DeepSeekChat
}

71 changes: 20 additions & 51 deletions rag/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,7 @@


class Base(ABC):
def __init__(self, key, model_name):
pass

def chat(self, system, history, gen_conf):
raise NotImplementedError("Please implement encode method!")


class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1"
def __init__(self, key, model_name, base_url):
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name

Expand All @@ -54,28 +45,28 @@ def chat(self, system, history, gen_conf):
return "**ERROR**: " + str(e), 0


class MoonshotChat(GptTurbo):
class GptTurbo(Base):
    """OpenAI GPT chat model; defaults to gpt-3.5-turbo on the public API."""

    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        # Callers may explicitly pass an empty/None base_url; fall back to
        # the public OpenAI endpoint in that case as well.
        if not base_url:
            base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    """Moonshot AI chat model served through an OpenAI-compatible endpoint.

    The generic OpenAI-style ``chat`` implementation lives in ``Base``
    (which also constructs the client in its ``__init__``); this subclass
    only fixes the default model name and endpoint.
    """

    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        # Guard against callers passing an empty/None base_url explicitly.
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
        # Base.__init__ builds the OpenAI-compatible client; constructing
        # self.client here as well (as the pre-refactor code did) would be
        # redundant, so the duplicated initialization and the copy of the
        # generic ``chat`` method are removed.
        super().__init__(key, model_name, base_url)

class XinferenceChat(Base):
    """Chat model served by a local Xinference deployment.

    Xinference does not validate API keys, so a fixed placeholder key is
    sent regardless of what the caller provides.
    """

    def __init__(self, key=None, model_name="", base_url=""):
        # The supplied key is ignored on purpose; any value is accepted.
        super().__init__("xxx", model_name, base_url)


class DeepSeekChat(Base):
    """DeepSeek chat models (e.g. deepseek-chat, deepseek-coder) reached
    through DeepSeek's OpenAI-compatible endpoint."""

    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        # An explicitly empty/None base_url also falls back to the default.
        if not base_url:
            base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)


class QWenChat(Base):
Expand Down Expand Up @@ -157,25 +148,3 @@ def chat(self, system, history, gen_conf):
except Exception as e:
return "**ERROR**: " + str(e), 0


class XinferenceChat(Base):
    """Chat against a local Xinference server via its OpenAI-compatible API."""

    def __init__(self, key=None, model_name="", base_url=""):
        # Xinference does not check API keys; a placeholder is sufficient.
        self.client = OpenAI(api_key="xxx", base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        """Run one chat completion and return ``(answer, total_tokens)``.

        On an OpenAI API error, returns an ``**ERROR**``-prefixed message
        and a token count of 0 instead of raising.
        """
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            resp = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            choice = resp.choices[0]
            ans = choice.message.content.strip()
            # Warn the user (in their apparent language) when the reply was
            # truncated by the max-token limit.
            if choice.finish_reason == "length":
                suffix = (
                    "...\nFor the content length reason, it stopped, continue?"
                    if is_english([ans])
                    else "······\n由于长度的原因,回答被截断了,要继续吗?")
                ans += suffix
            return ans, resp.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

0 comments on commit e0b9b11

Please sign in to comment.