diff --git a/api/db/init_data.py b/api/db/init_data.py index d0f4d86fe49..42ddb15f48d 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -123,7 +123,12 @@ def init_superuser(): "name": "Youdao", "logo": "", "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", - "status": "1", + "status": "1", +},{ + "name": "DeepSeek", + "logo": "", + "tags": "LLM", + "status": "1", }, # { # "name": "文心一言", @@ -331,6 +336,21 @@ def init_llm_factory(): "max_tokens": 512, "model_type": LLMType.EMBEDDING.value }, + # ------------------------ DeepSeek ----------------------- + { + "fid": factory_infos[8]["name"], + "llm_name": "deepseek-chat", + "tags": "LLM,CHAT,", + "max_tokens": 32768, + "model_type": LLMType.CHAT.value + }, + { + "fid": factory_infos[8]["name"], + "llm_name": "deepseek-coder", + "tags": "LLM,CHAT,", + "max_tokens": 16385, + "model_type": LLMType.CHAT.value + }, ] for info in factory_infos: try: diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 14d789a5a9a..0a3386a7090 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -45,6 +45,7 @@ "Tongyi-Qianwen": QWenChat, "Ollama": OllamaChat, "Xinference": XinferenceChat, - "Moonshot": MoonshotChat + "Moonshot": MoonshotChat, + "DeepSeek": DeepSeekChat } diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 0a3ecca944f..797d3fea168 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -24,16 +24,7 @@ class Base(ABC): - def __init__(self, key, model_name): - pass - - def chat(self, system, history, gen_conf): - raise NotImplementedError("Please implement encode method!") - - -class GptTurbo(Base): - def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"): - if not base_url: base_url="https://api.openai.com/v1" + def __init__(self, key, model_name, base_url): self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name @@ -54,28 +45,28 @@ def chat(self, system, history, gen_conf): return "**ERROR**: " + str(e), 0 -class MoonshotChat(GptTurbo): +class GptTurbo(Base): + def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"): + if not base_url: base_url="https://api.openai.com/v1" + super().__init__(key, model_name, base_url) + + +class MoonshotChat(Base): def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"): if not base_url: base_url="https://api.moonshot.cn/v1" - self.client = OpenAI( - api_key=key, base_url=base_url) - self.model_name = model_name + super().__init__(key, model_name, base_url) - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) - try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=history, - **gen_conf) - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english( - [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - return ans, response.usage.total_tokens - except openai.APIError as e: - return "**ERROR**: " + str(e), 0 + +class XinferenceChat(Base): + def __init__(self, key=None, model_name="", base_url=""): + key = "xxx" + super().__init__(key, model_name, base_url) + + +class DeepSeekChat(Base): + def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"): + if not base_url: base_url="https://api.deepseek.com/v1" + super().__init__(key, model_name, base_url) class QWenChat(Base): @@ -157,25 +148,3 @@ def chat(self, system, history, gen_conf): except Exception as e: return "**ERROR**: " + str(e), 0 - -class XinferenceChat(Base): - def __init__(self, key=None, model_name="", base_url=""): - self.client = OpenAI(api_key="xxx", base_url=base_url) - self.model_name = model_name - - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) - try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=history, - **gen_conf) - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english( - [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - return ans, response.usage.total_tokens - except openai.APIError as e: - return "**ERROR**: " + str(e), 0 -